brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_module_test.py
CHANGED
@@ -13,20 +13,17 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
|
-
import jaxlib.xla_extension
|
22
19
|
|
23
|
-
import brainstate
|
20
|
+
import brainstate
|
24
21
|
|
25
22
|
|
26
23
|
class TestDelay(unittest.TestCase):
|
27
24
|
def test_delay1(self):
|
28
|
-
a =
|
29
|
-
delay =
|
25
|
+
a = brainstate.State(brainstate.random.random(10, 20))
|
26
|
+
delay = brainstate.nn.Delay(a.value)
|
30
27
|
delay.register_entry('a', 1.)
|
31
28
|
delay.register_entry('b', 2.)
|
32
29
|
delay.register_entry('c', None)
|
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
|
|
36
33
|
delay.register_entry('c', 10.)
|
37
34
|
|
38
35
|
def test_rotation_delay(self):
|
39
|
-
rotation_delay =
|
36
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
40
37
|
t0 = 0.
|
41
38
|
t1, n1 = 1., 10
|
42
39
|
t2, n2 = 2., 20
|
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
|
|
53
50
|
# print(rotation_delay.max_length)
|
54
51
|
|
55
52
|
for i in range(100):
|
56
|
-
|
53
|
+
brainstate.environ.set(i=i)
|
57
54
|
rotation_delay.update(jnp.ones((1,)) * i)
|
58
55
|
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
59
56
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
|
|
61
58
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
59
|
|
63
60
|
def test_concat_delay(self):
|
64
|
-
rotation_delay =
|
61
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
65
62
|
t0 = 0.
|
66
63
|
t1, n1 = 1., 10
|
67
64
|
t2, n2 = 2., 20
|
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
|
|
74
71
|
|
75
72
|
print()
|
76
73
|
for i in range(100):
|
77
|
-
|
74
|
+
brainstate.environ.set(i=i)
|
78
75
|
rotation_delay.update(jnp.ones((1,)) * i)
|
79
76
|
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
80
77
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
|
|
83
80
|
# bst.util.clear_buffer_memory()
|
84
81
|
|
85
82
|
def test_jit_erro(self):
|
86
|
-
rotation_delay =
|
83
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
87
84
|
rotation_delay.init_state()
|
88
85
|
|
89
|
-
with
|
86
|
+
with brainstate.environ.context(i=0, t=0, jit_error_check=True):
|
90
87
|
rotation_delay.retrieve_at_time(-2.0)
|
91
|
-
with self.assertRaises(
|
88
|
+
with self.assertRaises(Exception):
|
92
89
|
rotation_delay.retrieve_at_time(-2.1)
|
93
90
|
rotation_delay.retrieve_at_time(-2.01)
|
94
|
-
with self.assertRaises(
|
91
|
+
with self.assertRaises(Exception):
|
95
92
|
rotation_delay.retrieve_at_time(-2.09)
|
96
|
-
with self.assertRaises(
|
93
|
+
with self.assertRaises(Exception):
|
97
94
|
rotation_delay.retrieve_at_time(0.1)
|
98
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(Exception):
|
99
96
|
rotation_delay.retrieve_at_time(0.01)
|
100
97
|
|
101
98
|
def test_round_interp(self):
|
102
99
|
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
103
100
|
for delay_method in ['rotation', 'concat']:
|
104
|
-
rotation_delay =
|
105
|
-
|
101
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
102
|
+
interp_method='round')
|
106
103
|
t0, n1 = 0.01, 0
|
107
104
|
t1, n1 = 1.04, 10
|
108
105
|
t2, n2 = 1.06, 11
|
109
106
|
rotation_delay.init_state()
|
110
107
|
|
111
|
-
@
|
108
|
+
@brainstate.compile.jit
|
112
109
|
def retrieve(td, i):
|
113
|
-
with
|
110
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
114
111
|
return rotation_delay.retrieve_at_time(td)
|
115
112
|
|
116
113
|
print()
|
117
114
|
for i in range(100):
|
118
|
-
t = i *
|
119
|
-
with
|
115
|
+
t = i * brainstate.environ.get_dt()
|
116
|
+
with brainstate.environ.context(i=i, t=t):
|
120
117
|
rotation_delay.update(jnp.ones(shape) * i)
|
121
118
|
print(i,
|
122
119
|
retrieve(t - t0, i),
|
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
|
|
131
128
|
for delay_method in ['rotation', 'concat']:
|
132
129
|
print(shape, delay_method)
|
133
130
|
|
134
|
-
rotation_delay =
|
135
|
-
|
131
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
132
|
+
interp_method='linear_interp')
|
136
133
|
t0, n0 = 0.01, 0.1
|
137
134
|
t1, n1 = 1.04, 10.4
|
138
135
|
t2, n2 = 1.06, 10.6
|
139
136
|
rotation_delay.init_state()
|
140
137
|
|
141
|
-
@
|
138
|
+
@brainstate.compile.jit
|
142
139
|
def retrieve(td, i):
|
143
|
-
with
|
140
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
144
141
|
return rotation_delay.retrieve_at_time(td)
|
145
142
|
|
146
143
|
print()
|
147
144
|
for i in range(100):
|
148
|
-
t = i *
|
149
|
-
with
|
145
|
+
t = i * brainstate.environ.get_dt()
|
146
|
+
with brainstate.environ.context(i=i, t=t):
|
150
147
|
rotation_delay.update(jnp.ones(shape) * i)
|
151
148
|
print(i,
|
152
149
|
retrieve(t - t0, i),
|
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
|
|
157
154
|
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
158
155
|
|
159
156
|
def test_rotation_and_concat_delay(self):
|
160
|
-
rotation_delay =
|
161
|
-
concat_delay =
|
157
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
158
|
+
concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
162
159
|
t0 = 0.
|
163
160
|
t1, n1 = 1., 10
|
164
161
|
t2, n2 = 2., 20
|
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
|
|
175
172
|
|
176
173
|
print()
|
177
174
|
for i in range(100):
|
178
|
-
|
175
|
+
brainstate.environ.set(i=i)
|
179
176
|
new = jnp.ones((1,)) * i
|
180
177
|
rotation_delay.update(new)
|
181
178
|
concat_delay.update(new)
|
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
|
|
186
183
|
|
187
184
|
class TestModule(unittest.TestCase):
|
188
185
|
def test_states(self):
|
189
|
-
class A(
|
186
|
+
class A(brainstate.nn.Module):
|
190
187
|
def __init__(self):
|
191
188
|
super().__init__()
|
192
|
-
self.a =
|
193
|
-
self.b =
|
189
|
+
self.a = brainstate.State(brainstate.random.random(10, 20))
|
190
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
194
191
|
|
195
|
-
class B(
|
192
|
+
class B(brainstate.nn.Module):
|
196
193
|
def __init__(self):
|
197
194
|
super().__init__()
|
198
195
|
self.a = A()
|
199
|
-
self.b =
|
196
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
200
197
|
|
201
198
|
b = B()
|
202
199
|
print()
|
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
|
|
207
204
|
|
208
205
|
|
209
206
|
if __name__ == '__main__':
|
210
|
-
with
|
207
|
+
with brainstate.environ.context(dt=0.1):
|
211
208
|
unittest.main()
|
brainstate/nn/metrics.py
CHANGED
brainstate/optim/_base.py
CHANGED
@@ -19,12 +19,12 @@ import unittest
|
|
19
19
|
|
20
20
|
import jax.numpy as jnp
|
21
21
|
|
22
|
-
import brainstate
|
22
|
+
import brainstate
|
23
23
|
|
24
24
|
|
25
25
|
class TestMultiStepLR(unittest.TestCase):
|
26
26
|
def test1(self):
|
27
|
-
lr =
|
27
|
+
lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
|
28
28
|
for i in range(40):
|
29
29
|
r = lr(i)
|
30
30
|
if i < 10:
|
@@ -37,7 +37,7 @@ class TestMultiStepLR(unittest.TestCase):
|
|
37
37
|
self.assertTrue(jnp.allclose(r, 0.0001))
|
38
38
|
|
39
39
|
def test2(self):
|
40
|
-
lr =
|
40
|
+
lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
|
41
41
|
for i in range(40):
|
42
42
|
r = lr(i)
|
43
43
|
if i < 10:
|
@@ -13,39 +13,38 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import jax
|
21
20
|
import optax
|
22
21
|
|
23
|
-
import brainstate
|
22
|
+
import brainstate
|
24
23
|
|
25
24
|
|
26
25
|
class TestOptaxOptimizer(unittest.TestCase):
|
27
26
|
def test1(self):
|
28
|
-
class Model(
|
27
|
+
class Model(brainstate.nn.Module):
|
29
28
|
def __init__(self):
|
30
29
|
super().__init__()
|
31
|
-
self.linear1 =
|
32
|
-
self.linear2 =
|
30
|
+
self.linear1 = brainstate.nn.Linear(2, 3)
|
31
|
+
self.linear2 = brainstate.nn.Linear(3, 4)
|
33
32
|
|
34
33
|
def __call__(self, x):
|
35
34
|
return self.linear2(self.linear1(x))
|
36
35
|
|
37
|
-
x =
|
36
|
+
x = brainstate.random.randn(1, 2)
|
38
37
|
y = jax.numpy.ones((1, 4))
|
39
38
|
|
40
39
|
model = Model()
|
41
40
|
tx = optax.adam(1e-3)
|
42
|
-
optimizer =
|
43
|
-
optimizer.register_trainable_weights(model.states(
|
41
|
+
optimizer = brainstate.optim.OptaxOptimizer(tx)
|
42
|
+
optimizer.register_trainable_weights(model.states(brainstate.ParamState))
|
44
43
|
|
45
44
|
loss_fn = lambda: ((model(x) - y) ** 2).mean()
|
46
45
|
prev_loss = loss_fn()
|
47
46
|
|
48
|
-
grads =
|
47
|
+
grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
|
49
48
|
optimizer.update(grads)
|
50
49
|
|
51
50
|
new_loss = loss_fn()
|