brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,133 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax.numpy as jnp
|
19
|
+
|
20
|
+
import brainstate as bc
|
21
|
+
|
22
|
+
|
23
|
+
class TestVarDelay(unittest.TestCase):
    """Tests for ``bc.Delay`` with the default buffer and with ``method='concat'``."""

    def tearDown(self):
        # Clean up after every test, including failing ones. Previously this call was the
        # last statement of each test body and was skipped whenever an assertion failed.
        bc.util.clear_buffer_memory()

    def _assert_close(self, actual, expected):
        """Assert ``jnp.allclose(actual, expected)``, reporting both values on failure."""
        self.assertTrue(bool(jnp.allclose(actual, expected)), msg=f'{actual} != {expected}')

    def test_delay1(self):
        a = bc.State(bc.random.random(10, 20))
        delay = bc.Delay(a.value)
        delay.register_entry('a', 1.)
        delay.register_entry('b', 2.)
        delay.register_entry('c', None)

        delay.init_state()
        # Registering entry 'c' again after init_state() is expected to raise KeyError.
        with self.assertRaises(KeyError):
            delay.register_entry('c', 10.)

    def test_rotation_delay(self):
        rotation_delay = bc.Delay(jnp.ones((1,)))
        t0 = 0.
        t1, n1 = 1., 10  # delay time and the lag (in loop steps) asserted below
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        rotation_delay.register_entry('c2', 1.9)  # registered but not asserted on below
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()

        for i in range(100):
            bc.environ.set(i=i)
            new = jnp.ones((1,)) * i
            rotation_delay(new)
            self._assert_close(rotation_delay.at('a'), new)
            self._assert_close(rotation_delay.at('b'), jnp.maximum(new - n1, 0.))
            self._assert_close(rotation_delay.at('c'), jnp.maximum(new - n2, 0.))

    def test_concat_delay(self):
        concat_delay = bc.Delay(jnp.ones([1]), method='concat')
        t0 = 0.
        t1, n1 = 1., 10  # delay time and the lag (in loop steps) asserted below
        t2, n2 = 2., 20

        concat_delay.register_entry('a', t0)
        concat_delay.register_entry('b', t1)
        concat_delay.register_entry('c', t2)

        concat_delay.init_state()

        for i in range(100):
            bc.environ.set(i=i)
            new = jnp.ones((1,)) * i
            concat_delay(new)
            self._assert_close(concat_delay.at('a'), new)
            self._assert_close(concat_delay.at('b'), jnp.maximum(new - n1, 0.))
            self._assert_close(concat_delay.at('c'), jnp.maximum(new - n2, 0.))

    def test_rotation_and_concat_delay(self):
        # Both update methods must yield identical delayed values for the same inputs.
        rotation_delay = bc.Delay(jnp.ones((1,)))
        concat_delay = bc.Delay(jnp.ones([1]), method='concat')
        t0 = 0.
        t1 = 1.
        t2 = 2.

        for delay in (rotation_delay, concat_delay):
            delay.register_entry('a', t0)
            delay.register_entry('b', t1)
            delay.register_entry('c', t2)

        rotation_delay.init_state()
        concat_delay.init_state()

        for i in range(100):
            bc.environ.set(i=i)
            new = jnp.ones((1,)) * i
            rotation_delay(new)
            concat_delay(new)
            for name in ('a', 'b', 'c'):
                self._assert_close(rotation_delay.at(name), concat_delay.at(name))
|
111
|
+
|
112
|
+
|
113
|
+
class TestModule(unittest.TestCase):
    """Smoke tests for nested ``bc.Module`` state collection."""

    def test_states(self):
        class A(bc.Module):
            def __init__(self):
                super().__init__()
                self.a = bc.State(bc.random.random(10, 20))
                self.b = bc.State(bc.random.random(10, 20))

        class B(bc.Module):
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = bc.State(bc.random.random(10, 20))

        model = B()
        print()
        # Query twice with default arguments, then twice with ``level=0``.
        call_kwargs = ({}, {}, {'level': 0}, {'level': 0})
        for kwargs in call_kwargs:
            print(model.states(**kwargs))
|
133
|
+
|
brainstate/_state.py
ADDED
@@ -0,0 +1,378 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import contextlib
import threading
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import jax
import numpy as np
from jax.api_util import shaped_abstractify
from jax.extend import source_info_util

from .util import DictManager
|
26
|
+
|
27
|
+
PyTree = Any

# Largest int32 value, as a plain ``int``. ``State.tree_unflatten`` stores it as the
# ``_level`` of rebuilt states, and that level is later used as a slice start
# (``stack[level:]``), so it must be an integer. ``np.iinfo(np.int32)`` by itself is an
# ``iinfo`` object and is not a valid slice index.
max_int = int(np.iinfo(np.int32).max)
|
29
|
+
|
30
|
+
# Public API of this module.
__all__ = [
    'State',
    'ShortTermState',
    'LongTermState',
    'ParamState',
    'StateDictManager',
    'visible_state_dict',
    'check_state_value_tree',
]
|
35
|
+
|
36
|
+
_pytree_registered_objects = set()
|
37
|
+
|
38
|
+
|
39
|
+
def _register_pytree_cls(cls):
|
40
|
+
if cls not in _pytree_registered_objects:
|
41
|
+
jax.tree_util.register_pytree_node_class(cls)
|
42
|
+
_pytree_registered_objects.add(cls)
|
43
|
+
|
44
|
+
|
45
|
+
# The global state of the state stack is accessed by a thread-local object.
|
46
|
+
# This allows concurrent tracing in separate threads; passing traced objects
|
47
|
+
# between threads is forbidden.
|
48
|
+
class ThreadLocalStack(threading.local):
    """
    Per-thread holder of the stack of active ``StateTrace`` objects.

    Because this subclasses :py:class:`threading.local`, ``__init__`` runs again the
    first time each new thread accesses the instance, so each thread starts with its
    own empty ``stack``.
    """

    def __init__(self):
        self.stack: List['StateTrace'] = list()
|
51
|
+
|
52
|
+
|
53
|
+
# Stack of active StateTrace objects for the current thread. State reads and writes
# iterate over it (from the state's level upward) to report accesses.
thread_local_stack = ThreadLocalStack()

# Stack of boolean flags pushed/popped by ``check_state_value_tree``; the last entry
# decides whether assignments to ``State.value`` verify the tree structure. The bottom
# ``False`` entry is the default (checking disabled).
_global_context_to_check_state_tree = [False]
|
56
|
+
|
57
|
+
|
58
|
+
@contextlib.contextmanager
def check_state_value_tree() -> Iterator[None]:
    """
    The context manager to check whether the tree structure of state values stays consistent.

    Inside the ``with`` block, every assignment to ``State.value`` compares the tree
    structure of the new value with the structure the state was created with, and
    raises ``ValueError`` on a mismatch.

    Example::

        with check_state_value_tree():
            state.value = new_value  # raises ValueError if the tree structure changed
    """
    # Push before ``try`` so that ``finally`` only ever pops a flag that was pushed here.
    _global_context_to_check_state_tree.append(True)
    try:
        yield
    finally:
        _global_context_to_check_state_tree.pop()
|
68
|
+
|
69
|
+
|
70
|
+
class State(object):
    """
    The pointer to specify the dynamical data.

    To implement a new subclass of :py:class:`~.State`, you only need to inherit from this class:

    Example::

        class MyState(State):
            pass

    The typical examples of :py:class:`~.State` subclass are:

    - :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
    - :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
    - :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
    - :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.

    Args:
      value: PyTree. It can be anything as a pyTree. If another :py:class:`~.State`
        is given, its current value is used instead.
    """
    __module__ = 'brainstate'
    __slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')

    def __init__(self, value: PyTree):
        if isinstance(value, State):
            # Unwrap so that a State never stores another State as its value.
            value = value.value
        self._value = value
        # Tree structure at creation; new values are compared against it when checking is on.
        self._tree = jax.tree.structure(value)
        self._check_tree = False
        # Depth of the trace stack at creation. Reads and writes are reported only to the
        # traces at stack positions >= this level.
        self._level = len(thread_local_stack.stack)
        # Where the state was defined; attached to errors raised about this state.
        self._source_info = source_info_util.current()

    @property
    def value(self) -> PyTree:
        """
        The current value of the state.

        Reading it reports the access to every active ``StateTrace`` at stack positions
        >= this state's level. A trace may replace the stored value while recording it
        (see ``StateTrace.new_arg``), so the returned object can differ from the last
        value assigned.
        """
        self._check_if_deleted()

        # read the value by the stack (>= level)
        trace: StateTrace
        for trace in thread_local_stack.stack[self._level:]:
            trace.read_its_value(self)
        return self._value

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the state.

        Args:
          v: The new value. A :py:class:`~.State` is unwrapped to its value first.

        Raises:
          ValueError: If tree checking is enabled (``_check_tree`` on this state, or
            inside ``check_state_value_tree()``) and ``v`` has a different tree structure.
        """
        v = v.value if isinstance(v, State) else v
        self._check_value(v)
        # write the value by the stack (>= level)
        trace: StateTrace
        for trace in thread_local_stack.stack[self._level:]:
            trace.write_its_value(self)
        self._value = v

    def _check_value(self, v):
        """Raise ``ValueError`` if checking is enabled and ``v``'s tree structure differs."""
        if self._check_tree or _global_context_to_check_state_tree[-1]:
            in_tree = jax.tree.structure(v)
            if in_tree != self._tree:
                self._raise_error_with_source_info(
                    ValueError(f'The given value {in_tree} does not '
                               f'match with the origin tree structure '
                               f'{self._tree}.')
                )

    def _raise_error_with_source_info(self, error: Exception):
        """Raise ``error`` with this state's definition site attached as user context."""
        name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
        with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
            raise error

    def _check_if_deleted(self):
        # No-op in the base class; presumably a hook for subclasses that can be deleted.
        pass

    @property
    def source_info(self) -> source_info_util.SourceInfo:
        """
        The source information of the state, useful to identify the source code
        where the state was defined.

        Returns:
          The source information captured when the state was constructed.
        """
        return self._source_info

    def tree_flatten(self):
        """Flattens this variable.

        Returns:
          A pair where the first element is a tuple holding the state's value as the
          single child, and the second element is the auxiliary data ``(level,)``.
        """
        return (self._value,), (self._level,)

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Reconstructs a variable from the aux_data and the leaves.

        The level stored in ``aux_data`` is deliberately not restored: the rebuilt state
        gets ``max_int`` as its level, which is intended to exceed any trace-stack depth
        so that no ``StateTrace`` records accesses to the rebuilt state.

        Args:
          aux_data: ``(level,)`` as produced by :py:meth:`tree_flatten` (unused).
          flat_contents: ``(value,)`` as produced by :py:meth:`tree_flatten`.

        Returns:
          The variable.
        """
        self = cls(flat_contents[0])
        self._level = max_int
        return self

    @staticmethod
    def _leaf_shape_dtype(leaf) -> Tuple[Any, Any]:
        """Return ``(shape, dtype)`` of ``leaf``; leaves lacking them go through ``np.asarray``."""
        shape = getattr(leaf, 'shape', None)
        dtype = getattr(leaf, 'dtype', None)
        if shape is None or dtype is None:
            # e.g. Python scalars, which have neither attribute.
            arr = np.asarray(leaf)
            shape = arr.shape if shape is None else shape
            dtype = arr.dtype if dtype is None else dtype
        return shape, dtype

    def __repr__(self):
        leaves, tree = jax.tree.flatten(self._value)
        leaves_info = [ShapeDtype(*self._leaf_shape_dtype(leaf)) for leaf in leaves]
        tree_info = jax.tree.unflatten(tree, leaves_info)
        return f'{self.__class__.__name__}({tree_info})'
|
195
|
+
|
196
|
+
|
197
|
+
class ShapeDtype:
    """
    Lightweight ``(shape, dtype)`` record used by ``State.__repr__``.

    Its ``repr`` renders as ``<dtype>[d0, d1, ...]``, e.g. ``float32[10, 20]``.
    """

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

    def __repr__(self):
        dims = list(self.shape)
        return f'{self.dtype}{dims}'
|
204
|
+
|
205
|
+
|
206
|
+
class ShortTermState(State):
    """
    Short-term state: holds data that only lives briefly within the program.

    For example, during training the gradients of a model are short-term states.
    """

    __module__ = 'brainstate'
|
214
|
+
|
215
|
+
|
216
|
+
class LongTermState(State):
    """
    Long-term state: holds data that persists throughout the program.

    For example, during training the weights of a model are long-term states.
    """

    __module__ = 'brainstate'
|
225
|
+
|
226
|
+
|
227
|
+
class ParamState(LongTermState):
    """
    Parameter state: a long-term state that holds the trainable parameters of a model.
    """

    __module__ = 'brainstate'
|
232
|
+
|
233
|
+
|
234
|
+
class StateDictManager(DictManager):
    """
    State stack, for collecting all :py:class:`~.State` used in the program.

    :py:class:`~.StateDictManager` supports all features of python dict. Its elements
    are expected to be :py:class:`~.State` instances (see ``_check_elem``).
    """

    __module__ = 'brainstate'

    def assign_values(self, *args: Dict) -> None:
        """
        Assign new values to the managed states.

        Args:
          *args: Dicts mapping keys of this manager to new values. Each value is written
            through the matching state's ``.value`` setter, so tree-structure checks and
            trace recording apply.

        Raises:
          AssertionError: If an argument is not a ``dict``.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self._set_elem(k, v)

    def split_values(self, *filters: type) -> Tuple[Dict, ...]:
        """
        Split the state values into ``len(filters) + 1`` groups by state type.

        Each state is placed in the group of the first filter it is an instance of;
        states matching no filter go into the last group. Keys are preserved and the
        groups hold values (``state.value``), not the states themselves.

        Args:
          *filters: State types to split by, tested in order.

        Returns:
          A tuple of ``len(filters) + 1`` ``DictManager`` objects.
        """
        results = tuple(DictManager() for _ in range(len(filters) + 1))
        for k, v in self.items():
            for i, filt in enumerate(filters):
                if isinstance(v, filt):
                    results[i][k] = v.value
                    break
            else:
                results[-1][k] = v.value
        return results

    def collect_values(self) -> Dict:
        """
        Collect the current value of every managed state.

        Returns:
          A new ``DictManager`` with the same keys, mapping each key to ``state.value``.
        """
        results = DictManager()
        for k, v in self.items():
            results[k] = v.value
        return results

    def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
        """
        Split the states themselves (not their values) by type.

        Delegates to ``DictManager.split``; this override only narrows the return
        annotation. See ``DictManager.split`` for the exact grouping rules.
        """
        return super().split(first, *others)

    def to_dict_values(self) -> Dict:
        """
        Convert the values into a plain ``dict`` mapping each key to ``state.value``.
        """
        return {k: v.value for k, v in self.items()}

    def _check_elem(self, elem):
        """Assert that ``elem`` is a :py:class:`~.State`."""
        assert isinstance(elem, State), f'must be instance of {State}'

    def _set_elem(self, key: Any, value: Any) -> None:
        """Write ``value`` into the state stored under ``key`` via its ``.value`` setter."""
        self[key].value = value
|
289
|
+
|
290
|
+
|
291
|
+
class visible_state_dict(StateDictManager):
    """
    A :py:class:`~.StateDictManager` whose elements are visible to ``.states()``
    collection functions.

    It adds no behavior of its own; presumably the collection functions recognize it
    by type.
    """
|
296
|
+
|
297
|
+
|
298
|
+
class StateTrace(object):
    """
    The state trace, which is used to trace the states automatically.

    Used as a context manager: ``__enter__`` pushes the trace onto the thread-local trace
    stack and ``__exit__`` pops it. While the trace is on the stack, a :py:class:`State`
    reports its reads and writes here if the trace's stack index is >= the state's level
    (the stack depth at the time the state was created).

    Args:
      new_arg: Optional callable. When given, each state recorded for the first time has
        every leaf of its value replaced by ``new_arg(shaped_abstractify(leaf))`` —
        presumably a JAX tracing hook that creates a new input argument; confirm with callers.
    """

    def __init__(self, new_arg: Optional[Callable] = None):
        self.states: List[State] = []
        self.types: List[str] = []  # 'read' or 'write', parallel to ``self.states``
        self._id2index = dict()  # id(state) -> index into ``self.states``
        self._org_values = []  # each state's ``_value`` when first recorded, before ``new_arg``
        self._jax_trace_new_arg = new_arg
        self._written_ids = set()  # ids of states already marked as 'write'

    def set_new_arg(self, new_arg: Callable) -> None:
        """Replace the callable used by :py:meth:`new_arg`."""
        self._jax_trace_new_arg = new_arg

    def new_arg(self, state: State) -> None:
        """
        Replace every leaf of ``state``'s value by ``new_arg(shaped_abstractify(leaf))``.

        Does nothing when no ``new_arg`` callable is set.
        """
        if self._jax_trace_new_arg is not None:
            # internal use: write ``_value`` directly so the traced ``value`` setter is bypassed.
            state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)

    def __enter__(self) -> 'StateTrace':
        thread_local_stack.stack.append(self)
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        thread_local_stack.stack.pop()

    def read_its_value(self, state: State) -> None:
        """
        Record a read of ``state``.

        The first time a state is seen it is appended with type ``'read'``, its current
        ``_value`` is saved for :py:meth:`recovery_original_values`, and :py:meth:`new_arg`
        is applied to it. Later reads of the same state do nothing.

        Args:
          state: The state.
        """
        id_ = id(state)
        if id_ not in self._id2index:
            self._id2index[id_] = len(self.states)
            self.states.append(state)
            self.types.append('read')
            self._org_values.append(state._value)  # internal use
            self.new_arg(state)

    def write_its_value(self, state: State) -> None:
        """
        Record a write of ``state``.

        A state not seen before is first recorded as a read; its type is then changed
        to ``'write'`` (only once per state).

        Args:
          state: The state.
        """
        id_ = id(state)
        if id_ not in self._id2index:
            self.read_its_value(state)
        if id_ not in self._written_ids:
            index = self._id2index[id_]
            self.types[index] = 'write'
            self._written_ids.add(id_)

    def collect_values(self, *categories: str) -> Tuple:
        """
        Collect the values of the recorded states whose type is in ``categories``.

        Values are read through ``State.value`` (which also reports the read to active
        traces) and returned in recording order.

        Args:
          *categories: The categories, e.g. ``'read'`` and/or ``'write'``.

        Returns:
          results: A tuple of the values.
        """
        results = []
        for st, ty in zip(self.states, self.types):
            if ty in categories:
                results.append(st.value)
        return tuple(results)

    def recovery_original_values(self) -> None:
        """
        Restore each recorded state's ``_value`` to what it was when first recorded.
        """
        for st, val in zip(self.states, self._org_values):
            # internal use: bypass the ``value`` setter so nothing is re-recorded.
            st._value = val
|
@@ -0,0 +1,41 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import brainstate as bc
|
20
|
+
|
21
|
+
|
22
|
+
class TestStateSourceInfo(unittest.TestCase):
    """Tests that a ``State`` records where it was defined."""

    def test_state_source_info(self):
        state = bc.State(bc.random.randn(10))
        # The source info is captured at construction and exposed via ``source_info``.
        self.assertIsNotNone(state.source_info)
        self.assertIs(state.source_info, state._source_info)
|
27
|
+
|
28
|
+
|
29
|
+
class TestStateRepr(unittest.TestCase):
    """Tests ``State.__repr__`` for array, dict and list values."""

    def test_state_repr(self):
        # A single array renders as ``State(<dtype>[10])``.
        state = bc.State(bc.random.randn(10))
        text = repr(state)
        self.assertTrue(text.startswith('State('), text)
        self.assertIn('[10]', text)

        # A dict keeps its keys, with each leaf rendered as ``<dtype>[10]``.
        state2 = bc.State({'a': bc.random.randn(10), 'b': bc.random.randn(10)})
        text2 = repr(state2)
        self.assertTrue(text2.startswith('State('), text2)
        self.assertIn("'a'", text2)
        self.assertIn("'b'", text2)
        self.assertEqual(text2.count('[10]'), 2)

        # A list renders one ``<dtype>[10]`` entry per element.
        state3 = bc.State([bc.random.randn(10), bc.random.randn(10)])
        text3 = repr(state3)
        self.assertTrue(text3.startswith('State('), text3)
        self.assertEqual(text3.count('[10]'), 2)
|
brainstate/_utils.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
def set_module_as(module: str):
    """
    Return a decorator that sets ``__module__`` of the decorated object to ``module``.

    For example, decorating a function with ``set_module_as('brainstate')`` makes it
    report ``'brainstate'`` as its module.

    Args:
      module: The module name to assign.

    Returns:
      A decorator that sets ``fun.__module__`` in place and returns ``fun`` itself.
    """
    # ``callable`` is a builtin function, not a type; annotate with ``typing.Callable``.
    from typing import Callable

    def wrapper(fun: Callable) -> Callable:
        fun.__module__ = module
        return fun

    return wrapper
|