brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,563 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
from collections.abc import Callable
|
18
|
-
from threading import Thread
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
from absl.testing import absltest, parameterized
|
23
|
-
|
24
|
-
import brainstate
|
25
|
-
|
26
|
-
|
27
|
-
class TestIter(unittest.TestCase):
    """Tests for ``brainstate.graph.iter_leaf`` and ``brainstate.graph.iter_node``."""

    def test1(self):
        """Smoke test: iterate leaves and nodes of a nested module.

        Also iterates with two ``allowed_hierarchy`` settings. Only checks that
        iteration completes without raising.
        """

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                # Attach a sub-module underneath the existing sub-module `self.b`.
                self.b.a = brainstate.nn.LIF(2)

        for path, node in brainstate.graph.iter_leaf(Model()):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model()):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
            print(path, node)

    def test_iter_leaf_v1(self):
        """A module referenced twice in the graph yields its three leaves only once."""

        class Linear(brainstate.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
                self.bias = brainstate.ParamState(brainstate.random.randn(dout))
                self.a = 1

        module = Linear(3, 4)
        graph = [module, module]

        num = 0
        for path, value in brainstate.graph.iter_leaf(graph):
            print(path, type(value).__name__)
            num += 1

        # `weight`, `bias` and `a` of the shared module, each counted once.
        # unittest assertions (unlike bare `assert`) survive `python -O` and
        # report the actual value on failure.
        self.assertEqual(num, 3)

    def test_iter_node_v1(self):
        """Count the nodes yielded for a graph that lists the same model twice."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainstate.nn.LIF(2)

        model = Model()

        num = 0
        for path, node in brainstate.graph.iter_node([model, model]):
            print(path, node.__class__.__name__)
            num += 1
        # NOTE(review): 8 presumably reflects the shared `model` being traversed
        # only once — confirm against iter_node's dedup semantics.
        self.assertEqual(num, 8)
|
82
|
-
|
83
|
-
|
84
|
-
class List(brainstate.nn.Module):
    """Minimal list-like module used as a graph container in the tests.

    The items are kept in a plain Python ``list`` attribute, and
    ``obj[idx]`` reads from / writes to that list.
    """

    def __init__(self, items):
        super().__init__()
        self.items = [*items]

    def __getitem__(self, idx):
        stored = self.items
        return stored[idx]

    def __setitem__(self, idx, value):
        stored = self.items
        stored[idx] = value
|
94
|
-
|
95
|
-
|
96
|
-
class Dict(brainstate.nn.Module):
    """Minimal dict-like module used as a graph container in the tests.

    Accepts the same arguments as the ``dict`` constructor. The resulting
    mapping is stored as ``self.items``, and ``obj[key]`` reads from /
    writes to it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        mapping = dict(*args, **kwargs)
        self.items = mapping

    def __getitem__(self, key):
        mapping = self.items
        return mapping[key]

    def __setitem__(self, key, value):
        mapping = self.items
        mapping[key] = value
|
106
|
-
|
107
|
-
|
108
|
-
class StatefulLinear(brainstate.nn.Module):
    """Affine layer that also counts how often it is invoked.

    Attributes:
      w: ``ParamState`` weight, initialised with ``brainstate.random.rand(din, dout)``.
      b: ``ParamState`` bias, initialised with ``jnp.zeros((dout,))``.
      count: ``State`` holding a ``uint32`` counter that starts at 0.
    """

    def __init__(self, din, dout):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Add one to ``count``."""
        self.count.value += 1

    def __call__(self, x):
        """Increment ``count`` and return ``x @ w + b``."""
        self.increment()
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias
|
121
|
-
|
122
|
-
|
123
|
-
class TestGraphUtils(absltest.TestCase):
    """Tests for ``brainstate.graph`` flatten/unflatten, treefy_split/treefy_merge and ``call``."""

    def test_flatten_treey_state(self):
        # `a` is referenced twice in `g`; `a` holds one ParamState and `g` holds another.
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        # With treefy_state=True the flattened states are TreefyState objects inside a NestedDict.
        assert isinstance(states[0]['b'], brainstate.TreefyState)
        assert isinstance(states[3], brainstate.TreefyState)
        assert isinstance(states, brainstate.util.NestedDict)
        # The RefMap ends up holding exactly the two ParamState objects.
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_flatten(self):
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        # With treefy_state=False the flattened entries are State objects.
        assert isinstance(states[0]['b'], brainstate.State)
        assert isinstance(states[3], brainstate.State)
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_unflatten_treey_state(self):
        a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
        g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])

        graphdef, references = brainstate.graph.flatten(g1, treefy_state=True)
        g = brainstate.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        # The shared node `a` is rebuilt as a single object ...
        assert g[0] is g[2]
        # ... while states flattened with treefy_state=True come back as new objects.
        assert g1[3] is not g[3]
        assert g1[0]['b'] is not g[0]['b']

    def test_unflatten(self):
        a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
        g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])

        graphdef, references = brainstate.graph.flatten(g1, treefy_state=False)
        g = brainstate.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        # With treefy_state=False the original State objects are reused by unflatten.
        assert g[0] is g[2]
        assert g1[3] is g[3]
        assert g1[0]['b'] is g[0]['b']

    def test_unflatten_pytree(self):
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        graphdef, references = brainstate.graph.treefy_split(g)
        g = brainstate.graph.treefy_merge(graphdef, references)

        # The two references to the plain dict `a` come back as distinct objects.
        assert g[0] is not g[2]

    def test_unflatten_empty(self):
        a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
        g = List([a, 3, a, brainstate.ParamState(4)])

        graphdef, references = brainstate.graph.treefy_split(g)

        # Unflattening against an empty state tree must raise a ValueError matching 'Expected key'.
        with self.assertRaisesRegex(ValueError, 'Expected key'):
            brainstate.graph.unflatten(graphdef, brainstate.util.NestedDict({}))

    def test_module_list(self):
        ls = [
            brainstate.nn.Linear(2, 2),
            brainstate.nn.BatchNorm1d([10, 2]),
        ]
        graphdef, statetree = brainstate.graph.treefy_split(ls)

        # Check the state shapes produced by Linear(2, 2) and BatchNorm1d([10, 2]).
        assert statetree[0]['weight'].value['weight'].shape == (2, 2)
        assert statetree[0]['weight'].value['bias'].shape == (2,)
        assert statetree[1]['weight'].value['scale'].shape == (1, 2,)
        assert statetree[1]['weight'].value['bias'].shape == (1, 2,)
        assert statetree[1]['running_mean'].value.shape == (1, 2,)
        assert statetree[1]['running_var'].value.shape == (1, 2)

    def test_shared_variables(self):
        # One ParamState referenced twice is stored once and restored as one shared object.
        v = brainstate.ParamState(1)
        g = [v, v]

        graphdef, statetree = brainstate.graph.treefy_split(g)
        assert len(statetree.to_flat()) == 1

        g2 = brainstate.graph.treefy_merge(graphdef, statetree)
        assert g2[0] is g2[1]

    def test_tied_weights(self):
        class Foo(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = brainstate.nn.Linear(2, 2)
                self.baz = brainstate.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = brainstate.graph.treefy_split(node)

        # The tied weight is stored only once ...
        assert len(state.to_flat()) == 1

        node2 = brainstate.graph.treefy_merge(graphdef, state)

        # ... and the tie is preserved after merging.
        assert node2.bar.weight is node2.baz.weight

    def test_tied_weights_example(self):
        class LinearTranspose(brainstate.nn.Module):
            def __init__(self, dout: int, din: int, ) -> None:
                super().__init__()
                self.kernel = brainstate.ParamState(brainstate.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(brainstate.nn.Module):
            def __init__(self, ) -> None:
                super().__init__()
                self.embed = brainstate.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                x = self.embed(x)
                return self.linear_out(x)

        model = Encoder()
        graphdef, state = brainstate.graph.treefy_split(model)

        # The embedding weight doubles as the output projection, so only one state exists.
        assert len(state.to_flat()) == 1

        x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        y = model(x)

        assert y.shape == (2, 10)

    def test_state_variables_not_shared_with_graph(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.ParamState(1)

        m = Foo()
        graphdef, statetree = brainstate.graph.treefy_split(m)

        # The split state tree holds a separate object carrying the same value.
        assert isinstance(m.a, brainstate.ParamState)
        assert issubclass(statetree.a.type, brainstate.ParamState)
        assert m.a is not statetree.a
        assert m.a.value == statetree.a.value

        m2 = brainstate.graph.treefy_merge(graphdef, statetree)

        # The merged node does not share its State object with the state tree either.
        assert isinstance(m2.a, brainstate.ParamState)
        assert issubclass(statetree.a.type, brainstate.ParamState)
        assert m2.a is not statetree.a
        assert m2.a.value == statetree.a.value

    def test_shared_state_variables_not_shared_with_graph(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                p = brainstate.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = brainstate.graph.treefy_split(m)

        # The aliased state is recorded only under its first name `a`.
        assert isinstance(m.a, brainstate.ParamState)
        assert isinstance(m.b, brainstate.ParamState)
        assert issubclass(state.a.type, brainstate.ParamState)
        assert 'b' not in state
        assert m.a is not state.a
        assert m.b is not state.a
        assert m.a.value == state.a.value
        assert m.b.value == state.a.value

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        # Merging restores the aliasing between `a` and `b`.
        assert isinstance(m2.a, brainstate.ParamState)
        assert isinstance(m2.b, brainstate.ParamState)
        assert issubclass(state.a.type, brainstate.ParamState)
        assert m2.a is not state.a
        assert m2.b is not state.a
        assert m2.a.value == state.a.value
        assert m2.b.value == state.a.value
        assert m2.a is m2.b

    def test_pytree_node(self):
        @brainstate.util.dataclass
        class Tree:
            a: brainstate.ParamState
            b: str = brainstate.util.field(pytree_node=False)

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.tree = Tree(brainstate.ParamState(1), 'a')

        m = Foo()

        graphdef, state = brainstate.graph.treefy_split(m)

        # The dataclass is recorded as a pytree subgraph whose state `a` is extracted.
        assert 'tree' in state
        assert 'a' in state.tree
        assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        # Merge rebuilds a new Tree instance, with the non-pytree field `b` restored.
        assert isinstance(m2.tree, Tree)
        assert m2.tree.a.value == 1
        assert m2.tree.b == 'a'
        assert m2.tree.a is not m.tree.a
        assert m2.tree is not m.tree

    def test_call_jit_update(self):
        class Counter(brainstate.graph.Node):
            def __init__(self):
                self.count = brainstate.ParamState(jnp.zeros(()))

            def inc(self):
                self.count.value += 1
                return 1

        graph_state = brainstate.graph.treefy_split(Counter())

        @jax.jit
        def update(graph_state):
            # graph.call(...).inc() returns the method output plus the updated split state.
            out, graph_state = brainstate.graph.call(graph_state).inc()
            self.assertEqual(out, 1)
            return graph_state

        graph_state = update(graph_state)
        graph_state = update(graph_state)

        counter = brainstate.graph.treefy_merge(*graph_state)

        # Two jitted calls, each incrementing once.
        self.assertEqual(counter.count.value, 2)

    def test_stateful_linear(self):
        linear = StatefulLinear(3, 2)
        linear_state = brainstate.graph.treefy_split(linear)

        @jax.jit
        def forward(x, pure_linear):
            y, pure_linear = brainstate.graph.call(pure_linear)(x)
            return y, pure_linear

        x = jnp.ones((1, 3))
        y, linear_state = forward(x, linear_state)
        y, linear_state = forward(x, linear_state)

        # The original module is untouched; the updates live only in the returned state.
        self.assertEqual(linear.count.value, 0)
        new_linear = brainstate.graph.treefy_merge(*linear_state)
        self.assertEqual(new_linear.count.value, 2)

    def test_getitem(self):
        nodes = dict(
            a=StatefulLinear(3, 2),
            b=StatefulLinear(2, 1),
        )
        node_state = brainstate.graph.treefy_split(nodes)
        # Index into the split graph and call a method on sub-node 'b' only.
        _, node_state = brainstate.graph.call(node_state)['b'].increment()

        nodes = brainstate.graph.treefy_merge(*node_state)

        self.assertEqual(nodes['a'].count.value, 0)
        self.assertEqual(nodes['b'].count.value, 1)
|
405
|
-
|
406
|
-
|
407
|
-
class SimpleModule(brainstate.nn.Module):
    """Empty module used as a fixture by ``TestThreading``."""
    pass
|
409
|
-
|
410
|
-
|
411
|
-
class SimplePyTreeModule(brainstate.nn.Module):
    """Empty module used as a fixture by ``TestThreading``."""
    # NOTE(review): this body is identical to `SimpleModule`; the name suggests it
    # was meant to be a pytree-registered variant — confirm intent.
    pass
|
413
|
-
|
414
|
-
|
415
|
-
class TestThreading(parameterized.TestCase):
    """Check that ``brainstate.graph.treefy_split`` works when run on a worker thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
        x = module_fn()
        errors = []

        class MyThread(Thread):

            def run(self) -> None:
                try:
                    brainstate.graph.treefy_split(x)
                except Exception as e:  # re-raised in the main thread below
                    errors.append(e)

        thread = MyThread()
        thread.start()
        thread.join()

        # An exception raised inside Thread.run() is only reported by
        # threading.excepthook; it never propagates to join(). Without this
        # re-raise, a failing treefy_split would still let the test pass.
        if errors:
            raise errors[0]
|
432
|
-
|
433
|
-
|
434
|
-
class TestGraphOperation(unittest.TestCase):
    """Tests for flatten, treefy_split/treefy_merge, pop_states and graph.states."""

    def test1(self):
        """Smoke test: ``flatten`` accepts a node with sub-modules in lists and dicts."""

        class MyNode(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 2)
                self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
                self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}

        graphdef, statetree = brainstate.graph.flatten(MyNode())
        print(statetree)

    def test_split(self):
        """Smoke test: split into ParamState states and all remaining states (``...``)."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        """A filtered split followed by a merge rebuilds a node with the same sub-module types."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        new_node = brainstate.graph.treefy_merge(graphdef, params, others)

        self.assertIsInstance(new_node, Foo)
        self.assertIsInstance(new_node.b, brainstate.nn.BatchNorm1d)
        self.assertIsInstance(new_node.a, brainstate.nn.Linear)

    def test_update_states(self):
        """One manual SGD step on the states from ``model.states()`` lowers the loss."""
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = brainstate.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            # Write the updated value back into the state object.
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        grads = brainstate.augment.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        self.assertLess(loss_fn(x, y), prev_loss)

    def test_pop_states(self):
        """States created under ``catch_new_states('new')`` are removed by ``pop_states``."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.LIF([10, 2])

        model = Model()
        with brainstate.catch_new_states('new'):
            brainstate.nn.init_all_states(model)
        self.assertEqual(len(model.states()), 2)
        model_states = brainstate.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        self.assertFalse(hasattr(model.b, 'V'))

    def test_treefy_split(self):
        """Smoke test: ``treefy_split`` on a node holding a list of sub-modules."""

        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainstate.nn.Linear(dmid, dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = brainstate.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

    def test_states(self):
        """Smoke test: ``graph.states`` with no filter and with two type filters."""

        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainstate.nn.LIF(dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = brainstate.nn.init_all_states(MLP(2, 1, 3))
        states = brainstate.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        # With two filters, one collection is returned per filter.
        params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
        print(params)
        print(others)
|
560
|
-
|
561
|
-
|
562
|
-
if __name__ == '__main__':
    # Run every test case in this module through absl's test runner.
    absltest.main()
|
brainstate/init/__init__.py
DELETED
@@ -1,26 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from ._base import *
|
18
|
-
from ._base import __all__ as _base_all
|
19
|
-
from ._generic import *
|
20
|
-
from ._generic import __all__ as _generic_all
|
21
|
-
from ._random_inits import *
|
22
|
-
from ._random_inits import __all__ as _random_inits_all
|
23
|
-
from ._regular_inits import *
|
24
|
-
from ._regular_inits import __all__ as _regular_inits_all
|
25
|
-
|
26
|
-
# Public API of the package: the concatenation of each submodule's __all__.
__all__ = _generic_all + _base_all + _regular_inits_all + _random_inits_all
|
brainstate/init/_base.py
DELETED
@@ -1,52 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from typing import Optional, Tuple
|
17
|
-
|
18
|
-
import numpy as np
|
19
|
-
|
20
|
-
from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
|
21
|
-
|
22
|
-
# Names exported by `from ._base import *`.
__all__ = ['Initializer', 'to_size']
|
23
|
-
|
24
|
-
|
25
|
-
class Initializer(PrettyRepr):
    """
    Base class for initializers.

    Subclasses implement ``__call__``; the base implementation raises
    ``NotImplementedError``. The pretty representation shows the type followed
    by every instance attribute whose name does not start with ``_``.
    """
    __module__ = 'brainstate.init'

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Yield the type header, then one ``PrettyAttr`` per public instance attribute.
        """
        yield PrettyType(type=type(self))
        public_items = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(self).items()
            if not attr_name.startswith('_')
        )
        for attr_name, attr_value in public_items:
            yield PrettyAttr(attr_name, repr(attr_value))
|
43
|
-
|
44
|
-
|
45
|
-
def to_size(x) -> Optional[Tuple[int, ...]]:
    """
    Normalize a size specification to a tuple.

    Parameters
    ----------
    x : tuple, list, int, numpy integer or None
        The size to normalize.

    Returns
    -------
    tuple or None
        ``tuple(x)`` for a tuple or list (any length), ``(x,)`` for a scalar
        integer, and ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is none of the accepted types.
    """
    if isinstance(x, (tuple, list)):
        return tuple(x)
    # NOTE(review): `bool` is a subclass of `int`, so `to_size(True)` returns
    # `(True,)` — confirm whether booleans should be rejected.
    if isinstance(x, (int, np.integer)):
        return (x,)
    if x is None:
        return None
    raise ValueError(f'Cannot make a size for {x}')
|