brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,597 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import jax
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import numpy as np
|
22
|
-
from jax import vmap
|
23
|
-
from jax.lax import psum, pmean, pmax
|
24
|
-
|
25
|
-
import brainstate
|
26
|
-
import brainstate.augment
|
27
|
-
from brainstate.augment._mapping import BatchAxisError
|
28
|
-
from brainstate.augment._mapping import _remove_axis
|
29
|
-
|
30
|
-
|
31
|
-
class TestVmap(unittest.TestCase):
|
32
|
-
def test_vmap_1(self):
|
33
|
-
class Model(brainstate.nn.Module):
|
34
|
-
def __init__(self):
|
35
|
-
super().__init__()
|
36
|
-
|
37
|
-
self.a = brainstate.State(brainstate.random.randn(5))
|
38
|
-
self.b = brainstate.State(brainstate.random.randn(5))
|
39
|
-
|
40
|
-
def __call__(self, *args, **kwargs):
|
41
|
-
return self.a.value * self.b.value
|
42
|
-
|
43
|
-
model = Model()
|
44
|
-
r1 = model.a.value * model.b.value
|
45
|
-
r2 = brainstate.augment.vmap(model, in_states=model.states())()
|
46
|
-
self.assertTrue(jnp.allclose(r1, r2))
|
47
|
-
|
48
|
-
def test_vmap_2(self):
|
49
|
-
class Model(brainstate.nn.Module):
|
50
|
-
def __init__(self):
|
51
|
-
super().__init__()
|
52
|
-
|
53
|
-
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
54
|
-
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
55
|
-
self.c = brainstate.State(brainstate.random.randn(1))
|
56
|
-
|
57
|
-
def __call__(self, *args, **kwargs):
|
58
|
-
self.c.value = self.a.value * self.b.value
|
59
|
-
return self.c.value + 1.
|
60
|
-
|
61
|
-
model = Model()
|
62
|
-
with self.assertRaises(BatchAxisError):
|
63
|
-
r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState))()
|
64
|
-
|
65
|
-
model = Model()
|
66
|
-
r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState), out_states=model.c)()
|
67
|
-
|
68
|
-
def test_vmap_3(self):
|
69
|
-
class Model(brainstate.nn.Module):
|
70
|
-
def __init__(self):
|
71
|
-
super().__init__()
|
72
|
-
|
73
|
-
self.a = brainstate.State(brainstate.random.randn(5))
|
74
|
-
self.b = brainstate.State(brainstate.random.randn(5))
|
75
|
-
|
76
|
-
def __call__(self, *args, **kwargs):
|
77
|
-
return self.a.value * self.b.value
|
78
|
-
|
79
|
-
model = Model()
|
80
|
-
with self.assertRaises(BatchAxisError):
|
81
|
-
r2 = brainstate.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
|
82
|
-
|
83
|
-
def test_vmap_with_random(self):
|
84
|
-
class Model(brainstate.nn.Module):
|
85
|
-
def __init__(self):
|
86
|
-
super().__init__()
|
87
|
-
|
88
|
-
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
89
|
-
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
90
|
-
self.c = brainstate.State(brainstate.random.randn(1))
|
91
|
-
|
92
|
-
def __call__(self, key):
|
93
|
-
brainstate.random.set_key(key)
|
94
|
-
self.c.value = self.a.value * self.b.value
|
95
|
-
return self.c.value + brainstate.random.randn(1)
|
96
|
-
|
97
|
-
model = Model()
|
98
|
-
r2 = brainstate.augment.vmap(
|
99
|
-
model,
|
100
|
-
in_states=model.states(brainstate.ShortTermState),
|
101
|
-
out_states=model.c
|
102
|
-
)(
|
103
|
-
brainstate.random.split_key(5)
|
104
|
-
)
|
105
|
-
print(brainstate.random.DEFAULT)
|
106
|
-
|
107
|
-
def test_vmap_with_random_v3(self):
|
108
|
-
class Model(brainstate.nn.Module):
|
109
|
-
def __init__(self):
|
110
|
-
super().__init__()
|
111
|
-
|
112
|
-
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
113
|
-
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
114
|
-
self.c = brainstate.State(brainstate.random.randn(1))
|
115
|
-
|
116
|
-
def __call__(self):
|
117
|
-
self.c.value = self.a.value * self.b.value
|
118
|
-
return self.c.value + brainstate.random.randn(1)
|
119
|
-
|
120
|
-
model = Model()
|
121
|
-
r2 = brainstate.augment.vmap(
|
122
|
-
model,
|
123
|
-
in_states=model.states(brainstate.ShortTermState),
|
124
|
-
out_states=model.c
|
125
|
-
)()
|
126
|
-
print(brainstate.random.DEFAULT)
|
127
|
-
|
128
|
-
def test_vmap_with_random_2(self):
|
129
|
-
class Model(brainstate.nn.Module):
|
130
|
-
def __init__(self):
|
131
|
-
super().__init__()
|
132
|
-
|
133
|
-
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
134
|
-
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
135
|
-
self.c = brainstate.State(brainstate.random.randn(1))
|
136
|
-
self.rng = brainstate.random.RandomState(1)
|
137
|
-
|
138
|
-
def __call__(self, key):
|
139
|
-
self.rng.set_key(key)
|
140
|
-
self.c.value = self.a.value * self.b.value
|
141
|
-
return self.c.value + brainstate.random.randn(1)
|
142
|
-
|
143
|
-
model = Model()
|
144
|
-
r2 = brainstate.augment.vmap(
|
145
|
-
model,
|
146
|
-
in_states=model.states(brainstate.ShortTermState),
|
147
|
-
out_states=model.c
|
148
|
-
)(
|
149
|
-
brainstate.random.split_key(5)
|
150
|
-
)
|
151
|
-
|
152
|
-
def test_vmap_input(self):
|
153
|
-
model = brainstate.nn.Linear(2, 3)
|
154
|
-
print(id(model), id(model.weight))
|
155
|
-
model_id = id(model)
|
156
|
-
weight_id = id(model.weight)
|
157
|
-
|
158
|
-
x = jnp.ones((5, 2))
|
159
|
-
|
160
|
-
@brainstate.augment.vmap
|
161
|
-
def forward(x):
|
162
|
-
self.assertTrue(id(model) == model_id)
|
163
|
-
self.assertTrue(id(model.weight) == weight_id)
|
164
|
-
return model(x)
|
165
|
-
|
166
|
-
y = forward(x)
|
167
|
-
self.assertTrue(y.shape == (5, 3))
|
168
|
-
print(y.shape)
|
169
|
-
print(model.weight.value_call(jnp.shape))
|
170
|
-
print(model.weight.value)
|
171
|
-
|
172
|
-
def test_vmap_states_and_input_1(self):
|
173
|
-
gru = brainstate.nn.GRUCell(2, 3)
|
174
|
-
gru.init_state(5)
|
175
|
-
|
176
|
-
@brainstate.augment.vmap(in_states=gru.states(brainstate.HiddenState))
|
177
|
-
def forward(x):
|
178
|
-
return gru(x)
|
179
|
-
|
180
|
-
xs = brainstate.random.randn(5, 2)
|
181
|
-
y = forward(xs)
|
182
|
-
self.assertTrue(y.shape == (5, 3))
|
183
|
-
|
184
|
-
def test_vmap_jit(self):
|
185
|
-
class Foo(brainstate.nn.Module):
|
186
|
-
def __init__(self):
|
187
|
-
super().__init__()
|
188
|
-
self.a = brainstate.ParamState(jnp.arange(4))
|
189
|
-
self.b = brainstate.ShortTermState(jnp.arange(4))
|
190
|
-
|
191
|
-
def __call__(self):
|
192
|
-
self.b.value = self.a.value * self.b.value
|
193
|
-
|
194
|
-
foo = Foo()
|
195
|
-
|
196
|
-
@brainstate.augment.vmap(in_states=foo.states())
|
197
|
-
def mul():
|
198
|
-
foo()
|
199
|
-
|
200
|
-
@brainstate.compile.jit
|
201
|
-
def mul_jit(inp):
|
202
|
-
mul()
|
203
|
-
foo.a.value += inp
|
204
|
-
|
205
|
-
with brainstate.StateTraceStack() as trace:
|
206
|
-
mul_jit(1.)
|
207
|
-
|
208
|
-
print(foo.a.value)
|
209
|
-
print(foo.b.value)
|
210
|
-
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
|
211
|
-
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
212
|
-
|
213
|
-
write_state_ids = [id(st) for st in trace.get_write_states()]
|
214
|
-
read_state_ids = [id(st) for st in trace.get_read_states()]
|
215
|
-
|
216
|
-
assert id(foo.a) in write_state_ids
|
217
|
-
assert id(foo.b) in write_state_ids
|
218
|
-
|
219
|
-
print(trace.get_write_states())
|
220
|
-
print(trace.get_read_states())
|
221
|
-
|
222
|
-
def test_vmap_jit_2(self):
|
223
|
-
class Foo(brainstate.nn.Module):
|
224
|
-
def __init__(self):
|
225
|
-
super().__init__()
|
226
|
-
self.a = brainstate.ParamState(jnp.arange(4))
|
227
|
-
self.b = brainstate.ShortTermState(jnp.arange(4))
|
228
|
-
|
229
|
-
def __call__(self):
|
230
|
-
self.b.value = self.a.value * self.b.value
|
231
|
-
|
232
|
-
foo = Foo()
|
233
|
-
|
234
|
-
@brainstate.augment.vmap(in_states=foo.states())
|
235
|
-
def mul():
|
236
|
-
foo()
|
237
|
-
|
238
|
-
@brainstate.compile.jit
|
239
|
-
def mul_jit(inp):
|
240
|
-
mul()
|
241
|
-
foo.b.value += inp
|
242
|
-
|
243
|
-
with brainstate.StateTraceStack() as trace:
|
244
|
-
mul_jit(1.)
|
245
|
-
|
246
|
-
print(foo.a.value)
|
247
|
-
print(foo.b.value)
|
248
|
-
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
|
249
|
-
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4) + 1.))
|
250
|
-
|
251
|
-
write_state_ids = [id(st) for st in trace.get_write_states()]
|
252
|
-
read_state_ids = [id(st) for st in trace.get_read_states()]
|
253
|
-
|
254
|
-
assert id(foo.a) in read_state_ids
|
255
|
-
assert id(foo.b) in write_state_ids
|
256
|
-
|
257
|
-
print(trace.get_write_states())
|
258
|
-
print(trace.get_read_states())
|
259
|
-
|
260
|
-
def test_auto_rand_key_split(self):
|
261
|
-
def f():
|
262
|
-
return brainstate.random.rand(1)
|
263
|
-
|
264
|
-
res = brainstate.augment.vmap(f, axis_size=10)()
|
265
|
-
self.assertTrue(jnp.all(~(res[0] == res[1:])))
|
266
|
-
|
267
|
-
res2 = jax.vmap(f, axis_size=10)()
|
268
|
-
self.assertTrue(jnp.all((res2[0] == res2[1:])))
|
269
|
-
|
270
|
-
def test_axis(self):
|
271
|
-
def f(x):
|
272
|
-
return x - jax.lax.pmean(x, 'i')
|
273
|
-
|
274
|
-
r = jax.vmap(f, axis_name='i')(jnp.arange(10))
|
275
|
-
print(r)
|
276
|
-
|
277
|
-
r2 = brainstate.augment.vmap(f, axis_name='i')(jnp.arange(10))
|
278
|
-
print(r2)
|
279
|
-
self.assertTrue(jnp.allclose(r, r2))
|
280
|
-
|
281
|
-
def test_vmap_init(self):
|
282
|
-
class Foo(brainstate.nn.Module):
|
283
|
-
def __init__(self):
|
284
|
-
super().__init__()
|
285
|
-
self.a = brainstate.ParamState(jnp.arange(4))
|
286
|
-
self.b = brainstate.ShortTermState(jnp.arange(4))
|
287
|
-
|
288
|
-
def init_state_v1(self, *args, **kwargs):
|
289
|
-
self.c = brainstate.State(jnp.arange(4))
|
290
|
-
|
291
|
-
def init_state_v2(self):
|
292
|
-
self.d = brainstate.State(self.c.value * 2.)
|
293
|
-
|
294
|
-
foo = Foo()
|
295
|
-
|
296
|
-
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
297
|
-
def init1():
|
298
|
-
foo.init_state_v1()
|
299
|
-
|
300
|
-
init1()
|
301
|
-
print(foo.c.value)
|
302
|
-
|
303
|
-
@brainstate.augment.vmap_new_states(state_tag='new2', axis_size=5, in_states=foo.states('new1'))
|
304
|
-
def init2():
|
305
|
-
foo.init_state_v2()
|
306
|
-
|
307
|
-
init2()
|
308
|
-
print(foo.c.value)
|
309
|
-
print(foo.d.value)
|
310
|
-
|
311
|
-
self.assertTrue(
|
312
|
-
jnp.allclose(
|
313
|
-
foo.d.value,
|
314
|
-
foo.c.value * 2.
|
315
|
-
)
|
316
|
-
)
|
317
|
-
|
318
|
-
|
319
|
-
class TestMap(unittest.TestCase):
|
320
|
-
def test_map(self):
|
321
|
-
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
322
|
-
x = brainstate.random.rand(*dim)
|
323
|
-
r1 = brainstate.augment.map(lambda a: a + 1, x, batch_size=None)
|
324
|
-
r2 = brainstate.augment.map(lambda a: a + 1, x, batch_size=2)
|
325
|
-
r3 = brainstate.augment.map(lambda a: a + 1, x, batch_size=4)
|
326
|
-
r4 = brainstate.augment.map(lambda a: a + 1, x, batch_size=5)
|
327
|
-
true_r = x + 1
|
328
|
-
|
329
|
-
self.assertTrue(jnp.allclose(r1, true_r))
|
330
|
-
self.assertTrue(jnp.allclose(r2, true_r))
|
331
|
-
self.assertTrue(jnp.allclose(r3, true_r))
|
332
|
-
self.assertTrue(jnp.allclose(r4, true_r))
|
333
|
-
|
334
|
-
|
335
|
-
class TestRemoveAxis:
|
336
|
-
|
337
|
-
def test_remove_axis_2d_array_axis_0(self):
|
338
|
-
input_array = np.array([[1, 2, 3], [4, 5, 6]])
|
339
|
-
expected_output = np.array([1, 2, 3])
|
340
|
-
|
341
|
-
result = _remove_axis(input_array, axis=0)
|
342
|
-
|
343
|
-
np.testing.assert_array_equal(result, expected_output)
|
344
|
-
|
345
|
-
def test_remove_axis_3d_array(self):
|
346
|
-
# Create a 3D array
|
347
|
-
x = np.arange(24).reshape((2, 3, 4))
|
348
|
-
|
349
|
-
# Remove axis 1
|
350
|
-
result = _remove_axis(x, axis=1)
|
351
|
-
|
352
|
-
# Expected result: a 2D array with shape (2, 4)
|
353
|
-
expected = x[:, 0, :]
|
354
|
-
|
355
|
-
np.testing.assert_array_equal(result, expected)
|
356
|
-
assert result.shape == (2, 4)
|
357
|
-
|
358
|
-
def test_remove_axis_1d_array(self):
|
359
|
-
# Create a 1D array
|
360
|
-
x = np.array([1, 2, 3, 4, 5])
|
361
|
-
|
362
|
-
# Remove axis 0 (the only axis in a 1D array)
|
363
|
-
result = _remove_axis(x, axis=0)
|
364
|
-
|
365
|
-
# Check that the result is a scalar (0D array) and equal to the first element
|
366
|
-
assert np.isscalar(result), "Result should be a scalar"
|
367
|
-
assert result == 1, "Result should be equal to the first element of the input array"
|
368
|
-
|
369
|
-
def test_remove_axis_out_of_bounds(self):
|
370
|
-
x = jnp.array([[1, 2], [3, 4]])
|
371
|
-
with unittest.TestCase().assertRaises(IndexError):
|
372
|
-
_remove_axis(x, axis=2)
|
373
|
-
|
374
|
-
def test_remove_axis_negative(self):
|
375
|
-
x = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
376
|
-
result = _remove_axis(x, -1)
|
377
|
-
expected = jnp.array([[1, 3], [5, 7]])
|
378
|
-
np.testing.assert_array_equal(result, expected)
|
379
|
-
|
380
|
-
def test_remove_axis_with_nan_and_inf(self):
|
381
|
-
x = jnp.array([[1.0, jnp.nan, 3.0], [4.0, 5.0, jnp.inf]])
|
382
|
-
result = _remove_axis(x, axis=0)
|
383
|
-
expected = jnp.array([1.0, jnp.nan, 3.0])
|
384
|
-
np.testing.assert_array_equal(result, expected)
|
385
|
-
assert jnp.isnan(result[1])
|
386
|
-
|
387
|
-
def test_remove_axis_different_dtypes(self):
|
388
|
-
# Test with integer array
|
389
|
-
int_array = jnp.array([[1, 2, 3], [4, 5, 6]])
|
390
|
-
int_result = _remove_axis(int_array, 0)
|
391
|
-
assert jnp.array_equal(int_result, jnp.array([1, 2, 3]))
|
392
|
-
|
393
|
-
# Test with float array
|
394
|
-
float_array = jnp.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
|
395
|
-
float_result = _remove_axis(float_array, 1)
|
396
|
-
assert jnp.allclose(float_result, jnp.array([1.1, 4.4]))
|
397
|
-
|
398
|
-
# Test with complex array
|
399
|
-
complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
|
400
|
-
complex_result = _remove_axis(complex_array, 0)
|
401
|
-
assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
|
402
|
-
|
403
|
-
|
404
|
-
class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
405
|
-
|
406
|
-
def test_axis_size_zero(self):
|
407
|
-
foo = brainstate.nn.LIF(3)
|
408
|
-
# Testing that axis_size of 0 raises an error.
|
409
|
-
with self.assertRaises(ValueError):
|
410
|
-
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=0)
|
411
|
-
def faulty_init():
|
412
|
-
foo.init_state()
|
413
|
-
|
414
|
-
# Call the decorated function to trigger validation
|
415
|
-
faulty_init()
|
416
|
-
|
417
|
-
def test_axis_size_negative(self):
|
418
|
-
foo = brainstate.nn.LIF(3)
|
419
|
-
# Testing that a negative axis_size raises an error.
|
420
|
-
with self.assertRaises(ValueError):
|
421
|
-
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=-3)
|
422
|
-
def faulty_init():
|
423
|
-
foo.init_state()
|
424
|
-
|
425
|
-
faulty_init()
|
426
|
-
|
427
|
-
def test_incompatible_shapes(self):
|
428
|
-
foo = brainstate.nn.LIF(3)
|
429
|
-
|
430
|
-
# Simulate an incompatible shapes scenario:
|
431
|
-
# We intentionally assign a state with a different shape than expected.
|
432
|
-
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
433
|
-
def faulty_init():
|
434
|
-
# Modify state to produce an incompatible shape
|
435
|
-
foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)
|
436
|
-
|
437
|
-
faulty_init()
|
438
|
-
|
439
|
-
|
440
|
-
class TestAxisName:
|
441
|
-
def test1(self):
|
442
|
-
def compute_stats_with_axis_name(x):
|
443
|
-
"""Compute statistics using named axis operations"""
|
444
|
-
# Sum across the named axis 'batch'
|
445
|
-
total_sum = psum(x, axis_name='batch')
|
446
|
-
|
447
|
-
# Mean across the named axis 'batch'
|
448
|
-
mean_val = pmean(x, axis_name='batch')
|
449
|
-
|
450
|
-
# Max across the named axis 'batch'
|
451
|
-
max_val = pmax(x, axis_name='batch')
|
452
|
-
|
453
|
-
return {
|
454
|
-
'sum': total_sum,
|
455
|
-
'mean': mean_val,
|
456
|
-
'max': max_val,
|
457
|
-
'original': x
|
458
|
-
}
|
459
|
-
|
460
|
-
batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
461
|
-
print("Input batch data:", batch_data)
|
462
|
-
|
463
|
-
# vmap with axis name 'batch'
|
464
|
-
vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
|
465
|
-
result_jax = vectorized_stats_jax(batch_data)
|
466
|
-
|
467
|
-
# vmap with axis name 'batch'
|
468
|
-
vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
469
|
-
result = vectorized_stats(batch_data)
|
470
|
-
|
471
|
-
# vmap with axis name 'batch'
|
472
|
-
vectorized_stats_v2 = brainstate.transform.jit(
|
473
|
-
brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
474
|
-
)
|
475
|
-
result_v2 = vectorized_stats_v2(batch_data)
|
476
|
-
|
477
|
-
for key in result_jax.keys():
|
478
|
-
print(f" {key}: {result_jax[key]}")
|
479
|
-
assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
|
480
|
-
assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
|
481
|
-
|
482
|
-
def test_nested_vmap(self):
|
483
|
-
def nested_computation(x):
|
484
|
-
"""Computation with multiple named axes"""
|
485
|
-
# Sum over 'inner' axis, then mean over 'outer' axis
|
486
|
-
inner_sum = psum(x, axis_name='inner')
|
487
|
-
outer_mean = pmean(inner_sum, axis_name='outer')
|
488
|
-
return outer_mean
|
489
|
-
|
490
|
-
# Create 2D batch data
|
491
|
-
data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
|
492
|
-
print("Input 2D data shape:", data_2d.shape)
|
493
|
-
print("Input 2D data:\n", data_2d)
|
494
|
-
|
495
|
-
# Nested vmap: first over inner dimension, then outer dimension
|
496
|
-
inner_vmap = vmap(nested_computation, axis_name='inner')
|
497
|
-
nested_vmap = vmap(inner_vmap, axis_name='outer')
|
498
|
-
|
499
|
-
result_2d = nested_vmap(data_2d)
|
500
|
-
print("Result after nested vmap:", result_2d)
|
501
|
-
|
502
|
-
inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
|
503
|
-
nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
|
504
|
-
result_2d_bst = nested_vmap_bst(data_2d)
|
505
|
-
print("Result after nested vmap:", result_2d_bst)
|
506
|
-
|
507
|
-
assert jnp.allclose(result_2d, result_2d_bst)
|
508
|
-
|
509
|
-
def _gradient_averaging_simulation_bst(self):
|
510
|
-
def loss_function(params, x, y):
|
511
|
-
"""Simple quadratic loss"""
|
512
|
-
pred = params * x
|
513
|
-
return (pred - y) ** 2
|
514
|
-
|
515
|
-
def compute_gradients_with_averaging(params, batch_x, batch_y):
|
516
|
-
"""Compute gradients and average them across the batch"""
|
517
|
-
# Compute per-sample gradients
|
518
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
519
|
-
per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
520
|
-
|
521
|
-
# Average gradients across batch using named axis
|
522
|
-
def average_grads(grads):
|
523
|
-
return pmean(grads, axis_name='batch')
|
524
|
-
|
525
|
-
# Apply averaging with named axis
|
526
|
-
averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
|
527
|
-
return averaged_grads
|
528
|
-
|
529
|
-
# Example data
|
530
|
-
params = 2.0
|
531
|
-
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
532
|
-
batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
|
533
|
-
|
534
|
-
print("Parameters:", params)
|
535
|
-
print("Batch X:", batch_x)
|
536
|
-
print("Batch Y:", batch_y)
|
537
|
-
|
538
|
-
# Compute individual gradients first
|
539
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
540
|
-
individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
541
|
-
print("Individual gradients:", individual_grads)
|
542
|
-
|
543
|
-
# Now compute averaged gradients using axis names
|
544
|
-
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
545
|
-
print("Averaged gradients:", averaged_grads)
|
546
|
-
|
547
|
-
return individual_grads, averaged_grads
|
548
|
-
|
549
|
-
def _gradient_averaging_simulation_jax(self):
|
550
|
-
def loss_function(params, x, y):
|
551
|
-
"""Simple quadratic loss"""
|
552
|
-
pred = params * x
|
553
|
-
return (pred - y) ** 2
|
554
|
-
|
555
|
-
def compute_gradients_with_averaging(params, batch_x, batch_y):
|
556
|
-
"""Compute gradients and average them across the batch"""
|
557
|
-
# Compute per-sample gradients
|
558
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
559
|
-
per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
560
|
-
|
561
|
-
# Average gradients across batch using named axis
|
562
|
-
def average_grads(grads):
|
563
|
-
return pmean(grads, axis_name='batch')
|
564
|
-
|
565
|
-
# Apply averaging with named axis
|
566
|
-
averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
|
567
|
-
return averaged_grads
|
568
|
-
|
569
|
-
# Example data
|
570
|
-
params = 2.0
|
571
|
-
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
572
|
-
batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
|
573
|
-
|
574
|
-
print("Parameters:", params)
|
575
|
-
print("Batch X:", batch_x)
|
576
|
-
print("Batch Y:", batch_y)
|
577
|
-
|
578
|
-
# Compute individual gradients first
|
579
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
580
|
-
individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
581
|
-
print("Individual gradients:", individual_grads)
|
582
|
-
|
583
|
-
# Now compute averaged gradients using axis names
|
584
|
-
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
585
|
-
print("Averaged gradients:", averaged_grads)
|
586
|
-
|
587
|
-
return individual_grads, averaged_grads
|
588
|
-
|
589
|
-
def test_gradient_averaging_simulation(self):
|
590
|
-
individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
|
591
|
-
individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
|
592
|
-
assert jnp.allclose(individual_grads, individual_grads_jax)
|
593
|
-
assert jnp.allclose(averaged_grads, averaged_grads_jax)
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
brainstate/compile/__init__.py
DELETED
@@ -1,38 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
This module contains the functions for the compilation of JAX code.
|
18
|
-
"""
|
19
|
-
|
20
|
-
from ._ad_checkpoint import checkpoint, remat
|
21
|
-
from ._conditions import cond, switch, ifelse
|
22
|
-
from ._error_if import jit_error_if
|
23
|
-
from ._jit import jit
|
24
|
-
from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
|
25
|
-
from ._loop_no_collection import while_loop, bounded_while_loop
|
26
|
-
from ._make_jaxpr import StatefulFunction, make_jaxpr
|
27
|
-
from ._progress_bar import ProgressBar
|
28
|
-
|
29
|
-
__all__ = [
|
30
|
-
'checkpoint', 'remat',
|
31
|
-
'cond', 'switch', 'ifelse',
|
32
|
-
'jit_error_if',
|
33
|
-
'jit',
|
34
|
-
'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
|
35
|
-
'while_loop', 'bounded_while_loop',
|
36
|
-
'StatefulFunction', 'make_jaxpr',
|
37
|
-
'ProgressBar',
|
38
|
-
]
|