brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -14,27 +14,25 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from __future__ import annotations
|
18
|
-
|
19
17
|
import unittest
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class TestEvalShape(unittest.TestCase):
|
25
23
|
def test1(self):
|
26
|
-
class MLP(
|
24
|
+
class MLP(brainstate.nn.Module):
|
27
25
|
def __init__(self, n_in, n_mid, n_out):
|
28
26
|
super().__init__()
|
29
|
-
self.dense1 =
|
30
|
-
self.dense2 =
|
27
|
+
self.dense1 = brainstate.nn.Linear(n_in, n_mid)
|
28
|
+
self.dense2 = brainstate.nn.Linear(n_mid, n_out)
|
31
29
|
|
32
30
|
def __call__(self, x):
|
33
31
|
x = self.dense1(x)
|
34
|
-
x =
|
32
|
+
x = brainstate.functional.relu(x)
|
35
33
|
x = self.dense2(x)
|
36
34
|
return x
|
37
35
|
|
38
|
-
r =
|
36
|
+
r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
|
39
37
|
print(r)
|
40
|
-
print(
|
38
|
+
print(brainstate.random.DEFAULT)
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
@@ -21,7 +20,7 @@ import jax
|
|
21
20
|
import jax.numpy as jnp
|
22
21
|
import numpy as np
|
23
22
|
|
24
|
-
import brainstate
|
23
|
+
import brainstate
|
25
24
|
import brainstate.augment
|
26
25
|
from brainstate.augment._mapping import BatchAxisError
|
27
26
|
from brainstate.augment._mapping import _remove_axis
|
@@ -29,29 +28,29 @@ from brainstate.augment._mapping import _remove_axis
|
|
29
28
|
|
30
29
|
class TestVmap(unittest.TestCase):
|
31
30
|
def test_vmap_1(self):
|
32
|
-
class Model(
|
31
|
+
class Model(brainstate.nn.Module):
|
33
32
|
def __init__(self):
|
34
33
|
super().__init__()
|
35
34
|
|
36
|
-
self.a =
|
37
|
-
self.b =
|
35
|
+
self.a = brainstate.State(brainstate.random.randn(5))
|
36
|
+
self.b = brainstate.State(brainstate.random.randn(5))
|
38
37
|
|
39
38
|
def __call__(self, *args, **kwargs):
|
40
39
|
return self.a.value * self.b.value
|
41
40
|
|
42
41
|
model = Model()
|
43
42
|
r1 = model.a.value * model.b.value
|
44
|
-
r2 =
|
43
|
+
r2 = brainstate.augment.vmap(model, in_states=model.states())()
|
45
44
|
self.assertTrue(jnp.allclose(r1, r2))
|
46
45
|
|
47
46
|
def test_vmap_2(self):
|
48
|
-
class Model(
|
47
|
+
class Model(brainstate.nn.Module):
|
49
48
|
def __init__(self):
|
50
49
|
super().__init__()
|
51
50
|
|
52
|
-
self.a =
|
53
|
-
self.b =
|
54
|
-
self.c =
|
51
|
+
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
52
|
+
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
53
|
+
self.c = brainstate.State(brainstate.random.randn(1))
|
55
54
|
|
56
55
|
def __call__(self, *args, **kwargs):
|
57
56
|
self.c.value = self.a.value * self.b.value
|
@@ -59,104 +58,104 @@ class TestVmap(unittest.TestCase):
|
|
59
58
|
|
60
59
|
model = Model()
|
61
60
|
with self.assertRaises(BatchAxisError):
|
62
|
-
r2 =
|
61
|
+
r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState))()
|
63
62
|
|
64
63
|
model = Model()
|
65
|
-
r2 =
|
64
|
+
r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState), out_states=model.c)()
|
66
65
|
|
67
66
|
def test_vmap_3(self):
|
68
|
-
class Model(
|
67
|
+
class Model(brainstate.nn.Module):
|
69
68
|
def __init__(self):
|
70
69
|
super().__init__()
|
71
70
|
|
72
|
-
self.a =
|
73
|
-
self.b =
|
71
|
+
self.a = brainstate.State(brainstate.random.randn(5))
|
72
|
+
self.b = brainstate.State(brainstate.random.randn(5))
|
74
73
|
|
75
74
|
def __call__(self, *args, **kwargs):
|
76
75
|
return self.a.value * self.b.value
|
77
76
|
|
78
77
|
model = Model()
|
79
78
|
with self.assertRaises(BatchAxisError):
|
80
|
-
r2 =
|
79
|
+
r2 = brainstate.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
|
81
80
|
|
82
81
|
def test_vmap_with_random(self):
|
83
|
-
class Model(
|
82
|
+
class Model(brainstate.nn.Module):
|
84
83
|
def __init__(self):
|
85
84
|
super().__init__()
|
86
85
|
|
87
|
-
self.a =
|
88
|
-
self.b =
|
89
|
-
self.c =
|
86
|
+
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
87
|
+
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
88
|
+
self.c = brainstate.State(brainstate.random.randn(1))
|
90
89
|
|
91
90
|
def __call__(self, key):
|
92
|
-
|
91
|
+
brainstate.random.set_key(key)
|
93
92
|
self.c.value = self.a.value * self.b.value
|
94
|
-
return self.c.value +
|
93
|
+
return self.c.value + brainstate.random.randn(1)
|
95
94
|
|
96
95
|
model = Model()
|
97
|
-
r2 =
|
96
|
+
r2 = brainstate.augment.vmap(
|
98
97
|
model,
|
99
|
-
in_states=model.states(
|
98
|
+
in_states=model.states(brainstate.ShortTermState),
|
100
99
|
out_states=model.c
|
101
100
|
)(
|
102
|
-
|
101
|
+
brainstate.random.split_key(5)
|
103
102
|
)
|
104
|
-
print(
|
103
|
+
print(brainstate.random.DEFAULT)
|
105
104
|
|
106
105
|
def test_vmap_with_random_v3(self):
|
107
|
-
class Model(
|
106
|
+
class Model(brainstate.nn.Module):
|
108
107
|
def __init__(self):
|
109
108
|
super().__init__()
|
110
109
|
|
111
|
-
self.a =
|
112
|
-
self.b =
|
113
|
-
self.c =
|
110
|
+
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
111
|
+
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
112
|
+
self.c = brainstate.State(brainstate.random.randn(1))
|
114
113
|
|
115
114
|
def __call__(self):
|
116
115
|
self.c.value = self.a.value * self.b.value
|
117
|
-
return self.c.value +
|
116
|
+
return self.c.value + brainstate.random.randn(1)
|
118
117
|
|
119
118
|
model = Model()
|
120
|
-
r2 =
|
119
|
+
r2 = brainstate.augment.vmap(
|
121
120
|
model,
|
122
|
-
in_states=model.states(
|
121
|
+
in_states=model.states(brainstate.ShortTermState),
|
123
122
|
out_states=model.c
|
124
123
|
)()
|
125
|
-
print(
|
124
|
+
print(brainstate.random.DEFAULT)
|
126
125
|
|
127
126
|
def test_vmap_with_random_2(self):
|
128
|
-
class Model(
|
127
|
+
class Model(brainstate.nn.Module):
|
129
128
|
def __init__(self):
|
130
129
|
super().__init__()
|
131
130
|
|
132
|
-
self.a =
|
133
|
-
self.b =
|
134
|
-
self.c =
|
135
|
-
self.rng =
|
131
|
+
self.a = brainstate.ShortTermState(brainstate.random.randn(5))
|
132
|
+
self.b = brainstate.ShortTermState(brainstate.random.randn(5))
|
133
|
+
self.c = brainstate.State(brainstate.random.randn(1))
|
134
|
+
self.rng = brainstate.random.RandomState(1)
|
136
135
|
|
137
136
|
def __call__(self, key):
|
138
137
|
self.rng.set_key(key)
|
139
138
|
self.c.value = self.a.value * self.b.value
|
140
|
-
return self.c.value +
|
139
|
+
return self.c.value + brainstate.random.randn(1)
|
141
140
|
|
142
141
|
model = Model()
|
143
|
-
r2 =
|
142
|
+
r2 = brainstate.augment.vmap(
|
144
143
|
model,
|
145
|
-
in_states=model.states(
|
144
|
+
in_states=model.states(brainstate.ShortTermState),
|
146
145
|
out_states=model.c
|
147
146
|
)(
|
148
|
-
|
147
|
+
brainstate.random.split_key(5)
|
149
148
|
)
|
150
149
|
|
151
150
|
def test_vmap_input(self):
|
152
|
-
model =
|
151
|
+
model = brainstate.nn.Linear(2, 3)
|
153
152
|
print(id(model), id(model.weight))
|
154
153
|
model_id = id(model)
|
155
154
|
weight_id = id(model.weight)
|
156
155
|
|
157
156
|
x = jnp.ones((5, 2))
|
158
157
|
|
159
|
-
@
|
158
|
+
@brainstate.augment.vmap
|
160
159
|
def forward(x):
|
161
160
|
self.assertTrue(id(model) == model_id)
|
162
161
|
self.assertTrue(id(model.weight) == weight_id)
|
@@ -169,39 +168,39 @@ class TestVmap(unittest.TestCase):
|
|
169
168
|
print(model.weight.value)
|
170
169
|
|
171
170
|
def test_vmap_states_and_input_1(self):
|
172
|
-
gru =
|
171
|
+
gru = brainstate.nn.GRUCell(2, 3)
|
173
172
|
gru.init_state(5)
|
174
173
|
|
175
|
-
@
|
174
|
+
@brainstate.augment.vmap(in_states=gru.states(brainstate.HiddenState))
|
176
175
|
def forward(x):
|
177
176
|
return gru(x)
|
178
177
|
|
179
|
-
xs =
|
178
|
+
xs = brainstate.random.randn(5, 2)
|
180
179
|
y = forward(xs)
|
181
180
|
self.assertTrue(y.shape == (5, 3))
|
182
181
|
|
183
182
|
def test_vmap_jit(self):
|
184
|
-
class Foo(
|
183
|
+
class Foo(brainstate.nn.Module):
|
185
184
|
def __init__(self):
|
186
185
|
super().__init__()
|
187
|
-
self.a =
|
188
|
-
self.b =
|
186
|
+
self.a = brainstate.ParamState(jnp.arange(4))
|
187
|
+
self.b = brainstate.ShortTermState(jnp.arange(4))
|
189
188
|
|
190
189
|
def __call__(self):
|
191
190
|
self.b.value = self.a.value * self.b.value
|
192
191
|
|
193
192
|
foo = Foo()
|
194
193
|
|
195
|
-
@
|
194
|
+
@brainstate.augment.vmap(in_states=foo.states())
|
196
195
|
def mul():
|
197
196
|
foo()
|
198
197
|
|
199
|
-
@
|
198
|
+
@brainstate.compile.jit
|
200
199
|
def mul_jit(inp):
|
201
200
|
mul()
|
202
201
|
foo.a.value += inp
|
203
202
|
|
204
|
-
with
|
203
|
+
with brainstate.StateTraceStack() as trace:
|
205
204
|
mul_jit(1.)
|
206
205
|
|
207
206
|
print(foo.a.value)
|
@@ -219,27 +218,27 @@ class TestVmap(unittest.TestCase):
|
|
219
218
|
print(trace.get_read_states())
|
220
219
|
|
221
220
|
def test_vmap_jit_2(self):
|
222
|
-
class Foo(
|
221
|
+
class Foo(brainstate.nn.Module):
|
223
222
|
def __init__(self):
|
224
223
|
super().__init__()
|
225
|
-
self.a =
|
226
|
-
self.b =
|
224
|
+
self.a = brainstate.ParamState(jnp.arange(4))
|
225
|
+
self.b = brainstate.ShortTermState(jnp.arange(4))
|
227
226
|
|
228
227
|
def __call__(self):
|
229
228
|
self.b.value = self.a.value * self.b.value
|
230
229
|
|
231
230
|
foo = Foo()
|
232
231
|
|
233
|
-
@
|
232
|
+
@brainstate.augment.vmap(in_states=foo.states())
|
234
233
|
def mul():
|
235
234
|
foo()
|
236
235
|
|
237
|
-
@
|
236
|
+
@brainstate.compile.jit
|
238
237
|
def mul_jit(inp):
|
239
238
|
mul()
|
240
239
|
foo.b.value += inp
|
241
240
|
|
242
|
-
with
|
241
|
+
with brainstate.StateTraceStack() as trace:
|
243
242
|
mul_jit(1.)
|
244
243
|
|
245
244
|
print(foo.a.value)
|
@@ -258,9 +257,9 @@ class TestVmap(unittest.TestCase):
|
|
258
257
|
|
259
258
|
def test_auto_rand_key_split(self):
|
260
259
|
def f():
|
261
|
-
return
|
260
|
+
return brainstate.random.rand(1)
|
262
261
|
|
263
|
-
res =
|
262
|
+
res = brainstate.augment.vmap(f, axis_size=10)()
|
264
263
|
self.assertTrue(jnp.all(~(res[0] == res[1:])))
|
265
264
|
|
266
265
|
res2 = jax.vmap(f, axis_size=10)()
|
@@ -278,17 +277,17 @@ class TestVmap(unittest.TestCase):
|
|
278
277
|
self.assertTrue(jnp.allclose(r, r2))
|
279
278
|
|
280
279
|
def test_vmap_init(self):
|
281
|
-
class Foo(
|
280
|
+
class Foo(brainstate.nn.Module):
|
282
281
|
def __init__(self):
|
283
282
|
super().__init__()
|
284
|
-
self.a =
|
285
|
-
self.b =
|
283
|
+
self.a = brainstate.ParamState(jnp.arange(4))
|
284
|
+
self.b = brainstate.ShortTermState(jnp.arange(4))
|
286
285
|
|
287
286
|
def init_state_v1(self, *args, **kwargs):
|
288
|
-
self.c =
|
287
|
+
self.c = brainstate.State(jnp.arange(4))
|
289
288
|
|
290
289
|
def init_state_v2(self):
|
291
|
-
self.d =
|
290
|
+
self.d = brainstate.State(self.c.value * 2.)
|
292
291
|
|
293
292
|
foo = Foo()
|
294
293
|
|
@@ -318,11 +317,11 @@ class TestVmap(unittest.TestCase):
|
|
318
317
|
class TestMap(unittest.TestCase):
|
319
318
|
def test_map(self):
|
320
319
|
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
321
|
-
x =
|
322
|
-
r1 =
|
323
|
-
r2 =
|
324
|
-
r3 =
|
325
|
-
r4 =
|
320
|
+
x = brainstate.random.rand(*dim)
|
321
|
+
r1 = brainstate.augment.map(lambda a: a + 1, x, batch_size=None)
|
322
|
+
r2 = brainstate.augment.map(lambda a: a + 1, x, batch_size=2)
|
323
|
+
r3 = brainstate.augment.map(lambda a: a + 1, x, batch_size=4)
|
324
|
+
r4 = brainstate.augment.map(lambda a: a + 1, x, batch_size=5)
|
326
325
|
true_r = x + 1
|
327
326
|
|
328
327
|
self.assertTrue(jnp.allclose(r1, true_r))
|
@@ -406,7 +405,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
|
406
405
|
foo = brainstate.nn.LIF(3)
|
407
406
|
# Testing that axis_size of 0 raises an error.
|
408
407
|
with self.assertRaises(ValueError):
|
409
|
-
@
|
408
|
+
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=0)
|
410
409
|
def faulty_init():
|
411
410
|
foo.init_state()
|
412
411
|
|
@@ -417,7 +416,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
|
417
416
|
foo = brainstate.nn.LIF(3)
|
418
417
|
# Testing that a negative axis_size raises an error.
|
419
418
|
with self.assertRaises(ValueError):
|
420
|
-
@
|
419
|
+
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=-3)
|
421
420
|
def faulty_init():
|
422
421
|
foo.init_state()
|
423
422
|
|
@@ -428,9 +427,9 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
|
428
427
|
|
429
428
|
# Simulate an incompatible shapes scenario:
|
430
429
|
# We intentionally assign a state with a different shape than expected.
|
431
|
-
@
|
430
|
+
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
432
431
|
def faulty_init():
|
433
432
|
# Modify state to produce an incompatible shape
|
434
|
-
foo.c =
|
433
|
+
foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)
|
435
434
|
|
436
435
|
faulty_init()
|
@@ -13,34 +13,32 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import jax
|
19
17
|
import jax.numpy as jnp
|
20
18
|
from absl.testing import absltest
|
21
19
|
|
22
|
-
import brainstate
|
20
|
+
import brainstate
|
23
21
|
|
24
22
|
|
25
23
|
class TestRemat(absltest.TestCase):
|
26
24
|
def test_basic_remat(self):
|
27
|
-
module =
|
25
|
+
module = brainstate.compile.remat(brainstate.nn.Linear(2, 3))
|
28
26
|
y = module(jnp.ones((1, 2)))
|
29
27
|
assert y.shape == (1, 3)
|
30
28
|
|
31
29
|
def test_remat_with_scan(self):
|
32
|
-
class ScanLinear(
|
30
|
+
class ScanLinear(brainstate.nn.Module):
|
33
31
|
def __init__(self):
|
34
32
|
super().__init__()
|
35
|
-
self.linear =
|
33
|
+
self.linear = brainstate.nn.Linear(3, 3)
|
36
34
|
|
37
35
|
def __call__(self, x: jax.Array):
|
38
|
-
@
|
36
|
+
@brainstate.compile.remat
|
39
37
|
def fun(x: jax.Array, _):
|
40
38
|
x = self.linear(x)
|
41
39
|
return x, None
|
42
40
|
|
43
|
-
return
|
41
|
+
return brainstate.compile.scan(fun, x, None, length=10)[0]
|
44
42
|
|
45
43
|
m = ScanLinear()
|
46
44
|
|
@@ -12,27 +12,26 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
import unittest
|
18
17
|
|
19
18
|
import jax
|
20
19
|
import jax.numpy as jnp
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestCond(unittest.TestCase):
|
26
25
|
def test1(self):
|
27
|
-
|
28
|
-
|
29
|
-
|
26
|
+
brainstate.random.seed(1)
|
27
|
+
brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
28
|
+
brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
30
29
|
|
31
30
|
def test2(self):
|
32
|
-
st1 =
|
33
|
-
st2 =
|
34
|
-
st3 =
|
35
|
-
st4 =
|
31
|
+
st1 = brainstate.State(brainstate.random.rand(10))
|
32
|
+
st2 = brainstate.State(brainstate.random.rand(2))
|
33
|
+
st3 = brainstate.State(brainstate.random.rand(5))
|
34
|
+
st4 = brainstate.State(brainstate.random.rand(2, 10))
|
36
35
|
|
37
36
|
def true_fun(x):
|
38
37
|
st1.value = st2.value @ st4.value + x
|
@@ -40,7 +39,7 @@ class TestCond(unittest.TestCase):
|
|
40
39
|
def false_fun(x):
|
41
40
|
st3.value = (st3.value + 1.) * x
|
42
41
|
|
43
|
-
|
42
|
+
brainstate.compile.cond(True, true_fun, false_fun, 2.)
|
44
43
|
assert not isinstance(st1.value, jax.core.Tracer)
|
45
44
|
assert not isinstance(st2.value, jax.core.Tracer)
|
46
45
|
assert not isinstance(st3.value, jax.core.Tracer)
|
@@ -66,7 +65,7 @@ class TestSwitch(unittest.TestCase):
|
|
66
65
|
return branches[2](x)
|
67
66
|
|
68
67
|
def cfun(x):
|
69
|
-
return
|
68
|
+
return brainstate.compile.switch(x, branches, x)
|
70
69
|
|
71
70
|
self.assertEqual(fun(-1), cfun(-1))
|
72
71
|
self.assertEqual(fun(0), cfun(0))
|
@@ -90,7 +89,7 @@ class TestSwitch(unittest.TestCase):
|
|
90
89
|
return branches[i](x, x)
|
91
90
|
|
92
91
|
def cfun(x):
|
93
|
-
return
|
92
|
+
return brainstate.compile.switch(x, branches, x, x)
|
94
93
|
|
95
94
|
self.assertEqual(fun(-1), cfun(-1))
|
96
95
|
self.assertEqual(fun(0), cfun(0))
|
@@ -123,13 +122,13 @@ class TestSwitch(unittest.TestCase):
|
|
123
122
|
branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
|
124
123
|
|
125
124
|
def fun1(x, i):
|
126
|
-
return
|
125
|
+
return brainstate.compile.switch(i + 1, branches1, x)
|
127
126
|
|
128
127
|
def fun2(x, i):
|
129
|
-
return
|
128
|
+
return brainstate.compile.switch(i + 1, branches2, x)
|
130
129
|
|
131
130
|
def fun3(x, i):
|
132
|
-
return
|
131
|
+
return brainstate.compile.switch(i + 1, branches3, x)
|
133
132
|
|
134
133
|
fwd1, bwd1 = get_conds(fun1)
|
135
134
|
fwd2, bwd2 = get_conds(fun2)
|
@@ -149,7 +148,7 @@ class TestSwitch(unittest.TestCase):
|
|
149
148
|
|
150
149
|
def testOneBranchSwitch(self):
|
151
150
|
branch = lambda x: -x
|
152
|
-
f = lambda i, x:
|
151
|
+
f = lambda i, x: brainstate.compile.switch(i, [branch], x)
|
153
152
|
x = 7.
|
154
153
|
self.assertEqual(f(-1, x), branch(x))
|
155
154
|
self.assertEqual(f(0, x), branch(x))
|
@@ -167,12 +166,12 @@ class TestSwitch(unittest.TestCase):
|
|
167
166
|
class TestIfElse(unittest.TestCase):
|
168
167
|
def test1(self):
|
169
168
|
def f(a):
|
170
|
-
return
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
169
|
+
return brainstate.compile.ifelse(conditions=[a < 0,
|
170
|
+
a >= 0 and a < 2,
|
171
|
+
a >= 2 and a < 5,
|
172
|
+
a >= 5 and a < 10,
|
173
|
+
a >= 10],
|
174
|
+
branches=[lambda: 1,
|
176
175
|
lambda: 2,
|
177
176
|
lambda: 3,
|
178
177
|
lambda: 4,
|
@@ -184,38 +183,38 @@ class TestIfElse(unittest.TestCase):
|
|
184
183
|
|
185
184
|
def test_vmap(self):
|
186
185
|
def f(operands):
|
187
|
-
f = lambda a:
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
186
|
+
f = lambda a: brainstate.compile.ifelse([a > 10,
|
187
|
+
jnp.logical_and(a <= 10, a > 5),
|
188
|
+
jnp.logical_and(a <= 5, a > 2),
|
189
|
+
jnp.logical_and(a <= 2, a > 0),
|
190
|
+
a <= 0],
|
191
|
+
[lambda _: 1,
|
193
192
|
lambda _: 2,
|
194
193
|
lambda _: 3,
|
195
194
|
lambda _: 4,
|
196
195
|
lambda _: 5, ],
|
197
|
-
|
196
|
+
a)
|
198
197
|
return jax.vmap(f)(operands)
|
199
198
|
|
200
|
-
r = f(
|
199
|
+
r = f(brainstate.random.randint(-20, 20, 200))
|
201
200
|
self.assertTrue(r.size == 200)
|
202
201
|
|
203
202
|
def test_grad1(self):
|
204
203
|
def F2(x):
|
205
|
-
return
|
206
|
-
|
207
|
-
|
204
|
+
return brainstate.compile.ifelse((x >= 10, x < 10),
|
205
|
+
[lambda x: x, lambda x: x ** 2, ],
|
206
|
+
x)
|
208
207
|
|
209
208
|
self.assertTrue(jax.grad(F2)(9.0) == 18.)
|
210
209
|
self.assertTrue(jax.grad(F2)(11.0) == 1.)
|
211
210
|
|
212
211
|
def test_grad2(self):
|
213
212
|
def F3(x):
|
214
|
-
return
|
215
|
-
|
213
|
+
return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
|
+
[lambda x: x,
|
216
215
|
lambda x: x ** 2,
|
217
216
|
lambda x: x ** 4, ],
|
218
|
-
|
217
|
+
x)
|
219
218
|
|
220
219
|
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
221
220
|
self.assertTrue(jax.grad(F3)(11.0) == 1.)
|
@@ -13,43 +13,40 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax
|
21
19
|
import jax.numpy as jnp
|
22
|
-
import jaxlib.xla_extension
|
23
20
|
|
24
|
-
import brainstate
|
21
|
+
import brainstate
|
25
22
|
|
26
23
|
|
27
24
|
class TestJitError(unittest.TestCase):
|
28
25
|
def test1(self):
|
29
|
-
with self.assertRaises(
|
30
|
-
|
26
|
+
with self.assertRaises(Exception):
|
27
|
+
brainstate.compile.jit_error_if(True, 'error')
|
31
28
|
|
32
29
|
def err_f(x):
|
33
30
|
raise ValueError(f'error: {x}')
|
34
31
|
|
35
|
-
|
36
|
-
with self.assertRaises(
|
37
|
-
|
32
|
+
brainstate.compile.jit_error_if(False, err_f, 1.)
|
33
|
+
with self.assertRaises(Exception):
|
34
|
+
brainstate.compile.jit_error_if(True, err_f, 1.)
|
38
35
|
|
39
36
|
def test_vmap(self):
|
40
37
|
def f(x):
|
41
|
-
|
38
|
+
brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
|
42
39
|
|
43
40
|
jax.vmap(f)(jnp.array([False, False, False]))
|
44
|
-
with self.assertRaises(
|
41
|
+
with self.assertRaises(Exception):
|
45
42
|
jax.vmap(f)(jnp.array([True, False, False]))
|
46
43
|
|
47
44
|
def test_vmap_vmap(self):
|
48
45
|
def f(x):
|
49
|
-
|
46
|
+
brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
|
50
47
|
|
51
48
|
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
52
49
|
[False, False, False]]))
|
53
|
-
with self.assertRaises(
|
50
|
+
with self.assertRaises(Exception):
|
54
51
|
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
55
52
|
[True, False, False]]))
|