brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,724 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
from collections.abc import Callable
|
20
|
+
from functools import partial
|
21
|
+
from threading import Thread
|
22
|
+
from typing import Any
|
23
|
+
|
24
|
+
import jax
|
25
|
+
import jax.numpy as jnp
|
26
|
+
import pytest
|
27
|
+
from absl.testing import absltest, parameterized
|
28
|
+
|
29
|
+
import brainstate as bst
|
30
|
+
|
31
|
+
|
32
|
+
class TestIter(unittest.TestCase):
|
33
|
+
def test1(self):
|
34
|
+
class Model(bst.nn.Module):
|
35
|
+
def __init__(self):
|
36
|
+
super().__init__()
|
37
|
+
self.a = bst.nn.Linear(1, 2)
|
38
|
+
self.b = bst.nn.Linear(2, 3)
|
39
|
+
self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
|
40
|
+
self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
|
41
|
+
self.b.a = bst.nn.LIF(2)
|
42
|
+
|
43
|
+
for path, node in bst.graph.iter_leaf(Model()):
|
44
|
+
print(path, node)
|
45
|
+
for path, node in bst.graph.iter_node(Model()):
|
46
|
+
print(path, node)
|
47
|
+
for path, node in bst.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
|
48
|
+
print(path, node)
|
49
|
+
for path, node in bst.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
|
50
|
+
print(path, node)
|
51
|
+
|
52
|
+
def test_iter_leaf_v1(self):
|
53
|
+
class Linear(bst.nn.Module):
|
54
|
+
def __init__(self, din, dout):
|
55
|
+
super().__init__()
|
56
|
+
self.weight = bst.ParamState(bst.random.randn(din, dout))
|
57
|
+
self.bias = bst.ParamState(bst.random.randn(dout))
|
58
|
+
self.a = 1
|
59
|
+
|
60
|
+
module = Linear(3, 4)
|
61
|
+
graph = [module, module]
|
62
|
+
|
63
|
+
num = 0
|
64
|
+
for path, value in bst.graph.iter_leaf(graph):
|
65
|
+
print(path, type(value).__name__)
|
66
|
+
num += 1
|
67
|
+
|
68
|
+
assert num == 3
|
69
|
+
|
70
|
+
def test_iter_node_v1(self):
|
71
|
+
class Model(bst.nn.Module):
|
72
|
+
def __init__(self):
|
73
|
+
super().__init__()
|
74
|
+
self.a = bst.nn.Linear(1, 2)
|
75
|
+
self.b = bst.nn.Linear(2, 3)
|
76
|
+
self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
|
77
|
+
self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
|
78
|
+
self.b.a = bst.nn.LIF(2)
|
79
|
+
|
80
|
+
model = Model()
|
81
|
+
|
82
|
+
num = 0
|
83
|
+
for path, node in bst.graph.iter_node([model, model]):
|
84
|
+
print(path, node.__class__.__name__)
|
85
|
+
num += 1
|
86
|
+
assert num == 8
|
87
|
+
|
88
|
+
|
89
|
+
class List(bst.nn.Module):
    """Minimal list-like module used as a graph container in these tests.

    The items are copied into a plain Python ``list`` stored on ``self.items``;
    indexing reads from and writes to that list.
    """

    def __init__(self, items):
        super().__init__()
        self.items = [*items]

    def __getitem__(self, idx):
        storage = self.items
        return storage[idx]

    def __setitem__(self, idx, value):
        storage = self.items
        storage[idx] = value
|
99
|
+
|
100
|
+
|
101
|
+
class Dict(bst.nn.Module):
    """Minimal dict-like module used as a graph container in these tests.

    Constructor arguments are forwarded to ``dict(...)`` and the result is
    stored on ``self.items``; subscripting reads from and writes to that dict.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        mapping = dict(*args, **kwargs)
        self.items = mapping

    def __getitem__(self, key):
        mapping = self.items
        return mapping[key]

    def __setitem__(self, key, value):
        mapping = self.items
        mapping[key] = value
|
111
|
+
|
112
|
+
|
113
|
+
class StatefulLinear(bst.nn.Module):
    """Affine layer computing ``x @ w + b`` that also counts its uses.

    ``count`` is a ``bst.State`` holding a ``uint32`` scalar. It is increased
    by one in :meth:`increment` and on every forward call.
    """

    def __init__(self, din, dout):
        super().__init__()
        self.w = bst.ParamState(bst.random.rand(din, dout))
        self.b = bst.ParamState(jnp.zeros((dout,)))
        self.count = bst.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Add one to ``count``."""
        self.count.value += 1

    def __call__(self, x):
        """Bump ``count`` and return ``x @ w + b``."""
        self.increment()
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias
|
126
|
+
|
127
|
+
|
128
|
+
class TestGraphUtils(absltest.TestCase):
    """Tests for ``bst.graph`` flatten/unflatten, treefy split/merge, ``call``
    and the ``graph_to_tree`` / ``tree_to_graph`` conversions."""

    def test_flatten_treey_state(self):
        """``flatten(..., treefy_state=True)`` yields ``TreefyState`` leaves in a
        ``NestedDict`` and records the original states in ``ref_index``."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        refmap = bst.graph.RefMap()
        graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.TreefyState)
        assert isinstance(states[3], bst.TreefyState)
        assert isinstance(states, bst.util.NestedDict)
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_flatten(self):
        """``flatten(..., treefy_state=False)`` yields ``State`` objects."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        refmap = bst.graph.RefMap()
        graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.State)
        assert isinstance(states[3], bst.State)
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_unflatten_treey_state(self):
        """Round trip with ``treefy_state=True``: shared nodes stay shared, while
        the rebuilt graph holds state objects distinct from the originals."""
        a = bst.graph.Dict(a=1, b=bst.ParamState(2))
        g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(g1, treefy_state=True)
        g = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        assert g[0] is g[2]
        assert g1[3] is not g[3]
        assert g1[0]['b'] is not g[0]['b']

    def test_unflatten(self):
        """Round trip with ``treefy_state=False`` reuses the original state objects."""
        a = bst.graph.Dict(a=1, b=bst.ParamState(2))
        g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(g1, treefy_state=False)
        g = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        assert g[0] is g[2]
        assert g1[3] is g[3]
        assert g1[0]['b'] is g[0]['b']

    def test_unflatten_pytree(self):
        """Aliasing between two references to the same plain dict is not kept
        by a ``treefy_split`` / ``treefy_merge`` round trip."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        graphdef, references = bst.graph.treefy_split(g)
        g = bst.graph.treefy_merge(graphdef, references)

        assert g[0] is not g[2]

    def test_unflatten_empty(self):
        """``unflatten`` with an empty state tree raises ``ValueError``."""
        a = Dict({'a': 1, 'b': bst.ParamState(2)})
        g = List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.treefy_split(g)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            bst.graph.unflatten(graphdef, bst.util.NestedDict({}))

    def test_module_list(self):
        """``treefy_split`` on a plain list of modules indexes states by position."""
        ls = [
            bst.nn.Linear(2, 2),
            bst.nn.BatchNorm1d([10, 2]),
        ]
        graphdef, statetree = bst.graph.treefy_split(ls)

        assert statetree[0]['weight'].value['weight'].shape == (2, 2)
        assert statetree[0]['weight'].value['bias'].shape == (2,)
        assert statetree[1]['weight'].value['scale'].shape == (1, 2)
        assert statetree[1]['weight'].value['bias'].shape == (1, 2)
        assert statetree[1]['running_mean'].value.shape == (1, 2)
        assert statetree[1]['running_var'].value.shape == (1, 2)

    def test_shared_variables(self):
        """A state referenced twice is stored once and is shared again after merge."""
        v = bst.ParamState(1)
        g = [v, v]

        graphdef, statetree = bst.graph.treefy_split(g)
        assert len(statetree.to_flat()) == 1

        g2 = bst.graph.treefy_merge(graphdef, statetree)
        assert g2[0] is g2[1]

    def test_tied_weights(self):
        """Weights tied between two submodules are stored once and stay tied after merge."""

        class Foo(bst.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = bst.nn.Linear(2, 2)
                self.baz = bst.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = bst.graph.treefy_split(node)

        assert len(state.to_flat()) == 1

        node2 = bst.graph.treefy_merge(graphdef, state)

        assert node2.bar.weight is node2.baz.weight

    def test_tied_weights_example(self):
        """An embedding tied to a transposed output projection is stored once and runs."""

        class LinearTranspose(bst.nn.Module):
            def __init__(self, dout: int, din: int) -> None:
                super().__init__()
                self.kernel = bst.ParamState(bst.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(bst.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.embed = bst.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                x = self.embed(x)
                return self.linear_out(x)

        model = Encoder()
        graphdef, state = bst.graph.treefy_split(model)

        assert len(state.to_flat()) == 1

        x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        y = model(x)

        assert y.shape == (2, 10)

    def test_state_variables_not_shared_with_graph(self):
        """The split state tree holds copies of, not aliases to, the node's states."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)

        m = Foo()
        graphdef, statetree = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m.a is not statetree.a
        assert m.a.value == statetree.a.value

        m2 = bst.graph.treefy_merge(graphdef, statetree)

        assert isinstance(m2.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m2.a is not statetree.a
        assert m2.a.value == statetree.a.value

    def test_shared_state_variables_not_shared_with_graph(self):
        """A state aliased by two attributes appears once in the split tree and
        is re-aliased (as a new object) after merge."""

        class Foo(bst.graph.Node):
            def __init__(self):
                p = bst.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert isinstance(m.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert 'b' not in state
        assert m.a is not state.a
        assert m.b is not state.a
        assert m.a.value == state.a.value
        assert m.b.value == state.a.value

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.a, bst.ParamState)
        assert isinstance(m2.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert m2.a is not state.a
        assert m2.b is not state.a
        assert m2.a.value == state.a.value
        assert m2.b.value == state.a.value
        assert m2.a is m2.b

    def test_pytree_node(self):
        """A ``bst.util.dataclass`` pytree inside a node is split and rebuilt as a new object."""

        @bst.util.dataclass
        class Tree:
            a: bst.ParamState
            b: str = bst.util.field(pytree_node=False)

        class Foo(bst.graph.Node):
            def __init__(self):
                self.tree = Tree(bst.ParamState(1), 'a')

        m = Foo()

        graphdef, state = bst.graph.treefy_split(m)

        assert 'tree' in state
        assert 'a' in state.tree
        assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.tree, Tree)
        assert m2.tree.a.value == 1
        assert m2.tree.b == 'a'
        assert m2.tree.a is not m.tree.a
        assert m2.tree is not m.tree

    # ``unittest.skip`` is honored by both pytest and ``absltest.main()``;
    # ``pytest.mark.skip`` is ignored by the absl runner, which would execute
    # (and fail) these unimplemented tests.
    @unittest.skip('Not implemented')
    def test_cached_unflatten(self):
        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        def f(m: Foo):
            m.a, m.b = m.b, m.a  # type: ignore

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @unittest.skip('Not implemented')
    def test_cached_unflatten_swap_variables(self):
        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)
                self.b = bst.ParamState(2)

        def f(m: Foo):
            m.a, m.b = m.b, m.a

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap[Any, int]()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @unittest.skip('Not implemented')
    def test_cached_unflatten_add_self_reference(self):
        class Foo(bst.graph.Node):
            def __init__(self):
                self.ref = None

        def f(m: Foo):
            m.ref = m

        m = Foo()

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.ref is m2

    def test_call_jit_update(self):
        """``bst.graph.call`` runs a method on a split graph inside ``jax.jit``
        and returns the updated split graph."""

        class Counter(bst.graph.Node):
            def __init__(self):
                self.count = bst.ParamState(jnp.zeros(()))

            def inc(self):
                self.count.value += 1
                return 1

        graph_state = bst.graph.treefy_split(Counter())

        @jax.jit
        def update(graph_state):
            out, graph_state = bst.graph.call(graph_state).inc()
            self.assertEqual(out, 1)
            return graph_state

        graph_state = update(graph_state)
        graph_state = update(graph_state)

        counter = bst.graph.treefy_merge(*graph_state)

        self.assertEqual(counter.count.value, 2)

    def test_stateful_linear(self):
        """Calls through ``bst.graph.call`` update the split copy, not the original module."""
        linear = StatefulLinear(3, 2)
        linear_state = bst.graph.treefy_split(linear)

        @jax.jit
        def forward(x, pure_linear):
            y, pure_linear = bst.graph.call(pure_linear)(x)
            return y, pure_linear

        x = jnp.ones((1, 3))
        y, linear_state = forward(x, linear_state)
        y, linear_state = forward(x, linear_state)

        self.assertEqual(linear.count.value, 0)
        new_linear = bst.graph.treefy_merge(*linear_state)
        self.assertEqual(new_linear.count.value, 2)

    def test_getitem(self):
        """``bst.graph.call(...)[key]`` selects a sub-node before calling a method on it."""
        nodes = dict(
            a=StatefulLinear(3, 2),
            b=StatefulLinear(2, 1),
        )
        node_state = bst.graph.treefy_split(nodes)
        _, node_state = bst.graph.call(node_state)['b'].increment()

        nodes = bst.graph.treefy_merge(*node_state)

        self.assertEqual(nodes['a'].count.value, 0)
        self.assertEqual(nodes['b'].count.value, 1)

    def test_to_tree_simple(self):
        """``graph_to_tree`` stores an aliased module once (later as a ``NodeRef``)
        and ``tree_to_graph`` restores the aliasing."""
        m = bst.nn.Linear(2, 3)
        impure_tree = (m, 1, {'b': m})

        pure_tree = bst.graph.graph_to_tree(impure_tree)

        t1 = pure_tree[0]
        t2 = pure_tree[2]['b']

        self.assertEqual(pure_tree[1], 1)
        self.assertIsInstance(t1, bst.graph.NodeStates)
        assert isinstance(t1, bst.graph.NodeStates)  # narrows the type for checkers
        self.assertIsInstance(t2, bst.graph.NodeStates)
        assert isinstance(t2, bst.graph.NodeStates)  # narrows the type for checkers
        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
        self.assertLen(t1.states[0].to_flat(), 1)
        self.assertLen(t2.states[0].to_flat(), 0)

        impure_tree2 = bst.graph.tree_to_graph(pure_tree)

        m1_out = impure_tree2[0]
        m2_out = impure_tree2[2]['b']

        self.assertIs(m1_out, m2_out)
        self.assertEqual(impure_tree2[1], 1)

    def test_to_tree_consistent_prefix(self):
        """``graph_to_tree`` accepts a prefix that treats both references to the
        same module alike and rejects one that does not."""
        m = bst.nn.Linear(2, 3)
        impure_tree = (m, 1, {'b': m})

        # Consistent prefix: must not raise.
        bst.graph.graph_to_tree(impure_tree, prefix=(0, None, 0))

        prefix = (0, None, 1)
        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
566
|
+
|
567
|
+
|
568
|
+
class SimpleModule(bst.nn.Module):
    """Module with no attributes of its own; used as a fixture by ``TestThreading``."""
|
570
|
+
|
571
|
+
|
572
|
+
class SimplePyTreeModule(bst.nn.Module):
    """Module with no attributes of its own; used as a fixture by ``TestThreading``.

    NOTE(review): this class is currently identical to ``SimpleModule``. Despite
    its name, it opts into no pytree-specific behavior here; confirm intent.
    """
|
574
|
+
|
575
|
+
|
576
|
+
class TestThreading(parameterized.TestCase):
    """``bst.graph.treefy_split`` must succeed when called from a worker thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], bst.nn.Module]):
        x = module_fn()

        class MyThread(Thread):
            def __init__(self) -> None:
                super().__init__()
                # ``Thread.join`` does not re-raise exceptions from ``run``;
                # they only reach ``threading.excepthook``. Keep the exception
                # so the main thread can fail the test instead of passing.
                self.error = None

            def run(self) -> None:
                try:
                    bst.graph.treefy_split(x)
                except Exception as e:  # re-raised in the main thread below
                    self.error = e

        thread = MyThread()
        thread.start()
        thread.join()
        if thread.error is not None:
            raise thread.error
|
593
|
+
|
594
|
+
|
595
|
+
class TestGraphOperation(unittest.TestCase):
|
596
|
+
def test1(self):
|
597
|
+
class MyNode(bst.graph.Node):
|
598
|
+
def __init__(self):
|
599
|
+
self.a = bst.nn.Linear(2, 3)
|
600
|
+
self.b = bst.nn.Linear(3, 2)
|
601
|
+
self.c = [bst.nn.Linear(1, 2), bst.nn.Linear(1, 3)]
|
602
|
+
self.d = {'x': bst.nn.Linear(1, 3), 'y': bst.nn.Linear(1, 4)}
|
603
|
+
|
604
|
+
graphdef, statetree = bst.graph.flatten(MyNode())
|
605
|
+
# print(graphdef)
|
606
|
+
print(statetree)
|
607
|
+
# print(bst.graph.unflatten(graphdef, statetree))
|
608
|
+
|
609
|
+
def test_split(self):
|
610
|
+
class Foo(bst.graph.Node):
|
611
|
+
def __init__(self):
|
612
|
+
self.a = bst.nn.Linear(2, 2)
|
613
|
+
self.b = bst.nn.BatchNorm1d([10, 2])
|
614
|
+
|
615
|
+
node = Foo()
|
616
|
+
graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
|
617
|
+
|
618
|
+
print(params)
|
619
|
+
print(jax.tree.map(jnp.shape, params))
|
620
|
+
|
621
|
+
print(jax.tree.map(jnp.shape, others))
|
622
|
+
|
623
|
+
def test_merge(self):
|
624
|
+
class Foo(bst.graph.Node):
|
625
|
+
def __init__(self):
|
626
|
+
self.a = bst.nn.Linear(2, 2)
|
627
|
+
self.b = bst.nn.BatchNorm1d([10, 2])
|
628
|
+
|
629
|
+
node = Foo()
|
630
|
+
graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
|
631
|
+
|
632
|
+
new_node = bst.graph.treefy_merge(graphdef, params, others)
|
633
|
+
|
634
|
+
assert isinstance(new_node, Foo)
|
635
|
+
assert isinstance(new_node.b, bst.nn.BatchNorm1d)
|
636
|
+
assert isinstance(new_node.a, bst.nn.Linear)
|
637
|
+
|
638
|
+
def test_update_states(self):
|
639
|
+
x = jnp.ones((1, 2))
|
640
|
+
y = jnp.ones((1, 3))
|
641
|
+
model = bst.nn.Linear(2, 3)
|
642
|
+
|
643
|
+
def loss_fn(x, y):
|
644
|
+
return jnp.mean((y - model(x)) ** 2)
|
645
|
+
|
646
|
+
def sgd(ps, gs):
|
647
|
+
updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
|
648
|
+
ps.value = updates
|
649
|
+
|
650
|
+
prev_loss = loss_fn(x, y)
|
651
|
+
weights = model.states()
|
652
|
+
grads = bst.augment.grad(loss_fn, weights)(x, y)
|
653
|
+
for key, val in grads.items():
|
654
|
+
sgd(weights[key], val)
|
655
|
+
assert loss_fn(x, y) < prev_loss
|
656
|
+
|
657
|
+
def test_pop_states(self):
|
658
|
+
class Model(bst.nn.Module):
|
659
|
+
def __init__(self):
|
660
|
+
super().__init__()
|
661
|
+
self.a = bst.nn.Linear(2, 3)
|
662
|
+
self.b = bst.nn.LIF([10, 2])
|
663
|
+
|
664
|
+
model = Model()
|
665
|
+
with bst.catch_new_states('new'):
|
666
|
+
bst.nn.init_all_states(model)
|
667
|
+
# print(model.states())
|
668
|
+
self.assertTrue(len(model.states()) == 2)
|
669
|
+
model_states = bst.graph.pop_states(model, 'new')
|
670
|
+
print(model_states)
|
671
|
+
self.assertTrue(len(model.states()) == 1)
|
672
|
+
assert not hasattr(model.b, 'V')
|
673
|
+
# print(model.states())
|
674
|
+
|
675
|
+
def test_treefy_split(self):
|
676
|
+
class MLP(bst.graph.Node):
|
677
|
+
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
678
|
+
self.input = bst.nn.Linear(din, dmid)
|
679
|
+
self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
680
|
+
self.output = bst.nn.Linear(dmid, dout)
|
681
|
+
|
682
|
+
def __call__(self, x):
|
683
|
+
x = bst.functional.relu(self.input(x))
|
684
|
+
for layer in self.layers:
|
685
|
+
x = bst.functional.relu(layer(x))
|
686
|
+
return self.output(x)
|
687
|
+
|
688
|
+
model = MLP(2, 1, 3)
|
689
|
+
graph_def, treefy_states = bst.graph.treefy_split(model)
|
690
|
+
|
691
|
+
print(graph_def)
|
692
|
+
print(treefy_states)
|
693
|
+
|
694
|
+
# states = bst.graph.states(model)
|
695
|
+
# print(states)
|
696
|
+
# nest_states = states.to_nest()
|
697
|
+
# print(nest_states)
|
698
|
+
|
699
|
+
def test_states(self):
|
700
|
+
class MLP(bst.graph.Node):
|
701
|
+
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
702
|
+
self.input = bst.nn.Linear(din, dmid)
|
703
|
+
self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
704
|
+
self.output = bst.nn.LIF(dout)
|
705
|
+
|
706
|
+
def __call__(self, x):
|
707
|
+
x = bst.functional.relu(self.input(x))
|
708
|
+
for layer in self.layers:
|
709
|
+
x = bst.functional.relu(layer(x))
|
710
|
+
return self.output(x)
|
711
|
+
|
712
|
+
model = bst.nn.init_all_states(MLP(2, 1, 3))
|
713
|
+
states = bst.graph.states(model)
|
714
|
+
print(states)
|
715
|
+
nest_states = states.to_nest()
|
716
|
+
print(nest_states)
|
717
|
+
|
718
|
+
params, others = bst.graph.states(model, bst.ParamState, bst.ShortTermState)
|
719
|
+
print(params)
|
720
|
+
print(others)
|
721
|
+
|
722
|
+
|
723
|
+
if __name__ == '__main__':
    # Running this file directly uses absl's unittest-based runner.
    absltest.main()
|