brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/init/_base.py
CHANGED
brainstate/init/_generic.py
CHANGED
brainstate/init/_random_inits.py
CHANGED
@@ -14,30 +14,29 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
-
from __future__ import annotations
|
18
17
|
|
19
18
|
import unittest
|
20
19
|
|
21
|
-
import brainstate
|
20
|
+
import brainstate
|
22
21
|
|
23
22
|
|
24
23
|
class TestNormalInit(unittest.TestCase):
|
25
24
|
|
26
25
|
def test_normal_init1(self):
|
27
|
-
init =
|
26
|
+
init = brainstate.init.Normal()
|
28
27
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
29
28
|
weights = init(size)
|
30
29
|
assert weights.shape == size
|
31
30
|
|
32
31
|
def test_normal_init2(self):
|
33
|
-
init =
|
32
|
+
init = brainstate.init.Normal(scale=0.5)
|
34
33
|
for size in [(100,), (10, 20)]:
|
35
34
|
weights = init(size)
|
36
35
|
assert weights.shape == size
|
37
36
|
|
38
37
|
def test_normal_init3(self):
|
39
|
-
init1 =
|
40
|
-
init2 =
|
38
|
+
init1 = brainstate.init.Normal(scale=0.5, seed=10)
|
39
|
+
init2 = brainstate.init.Normal(scale=0.5, seed=10)
|
41
40
|
size = (10,)
|
42
41
|
weights1 = init1(size)
|
43
42
|
weights2 = init2(size)
|
@@ -47,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
|
|
47
46
|
|
48
47
|
class TestUniformInit(unittest.TestCase):
|
49
48
|
def test_uniform_init1(self):
|
50
|
-
init =
|
49
|
+
init = brainstate.init.Normal()
|
51
50
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
52
51
|
weights = init(size)
|
53
52
|
assert weights.shape == size
|
54
53
|
|
55
54
|
def test_uniform_init2(self):
|
56
|
-
init =
|
55
|
+
init = brainstate.init.Uniform(min_val=10, max_val=20)
|
57
56
|
for size in [(100,), (10, 20)]:
|
58
57
|
weights = init(size)
|
59
58
|
assert weights.shape == size
|
@@ -61,20 +60,20 @@ class TestUniformInit(unittest.TestCase):
|
|
61
60
|
|
62
61
|
class TestVarianceScaling(unittest.TestCase):
|
63
62
|
def test_var_scaling1(self):
|
64
|
-
init =
|
63
|
+
init = brainstate.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
|
65
64
|
for size in [(10, 20), (10, 20, 30)]:
|
66
65
|
weights = init(size)
|
67
66
|
assert weights.shape == size
|
68
67
|
|
69
68
|
def test_var_scaling2(self):
|
70
|
-
init =
|
69
|
+
init = brainstate.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
|
71
70
|
for size in [(10, 20), (10, 20, 30)]:
|
72
71
|
weights = init(size)
|
73
72
|
assert weights.shape == size
|
74
73
|
|
75
74
|
def test_var_scaling3(self):
|
76
|
-
init =
|
77
|
-
|
75
|
+
init = brainstate.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
|
76
|
+
distribution='uniform')
|
78
77
|
for size in [(10, 20), (10, 20, 30)]:
|
79
78
|
weights = init(size)
|
80
79
|
assert weights.shape == size
|
@@ -82,7 +81,7 @@ class TestVarianceScaling(unittest.TestCase):
|
|
82
81
|
|
83
82
|
class TestKaimingUniformUnit(unittest.TestCase):
|
84
83
|
def test_kaiming_uniform_init(self):
|
85
|
-
init =
|
84
|
+
init = brainstate.init.KaimingUniform()
|
86
85
|
for size in [(10, 20), (10, 20, 30)]:
|
87
86
|
weights = init(size)
|
88
87
|
assert weights.shape == size
|
@@ -90,7 +89,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
|
|
90
89
|
|
91
90
|
class TestKaimingNormalUnit(unittest.TestCase):
|
92
91
|
def test_kaiming_normal_init(self):
|
93
|
-
init =
|
92
|
+
init = brainstate.init.KaimingNormal()
|
94
93
|
for size in [(10, 20), (10, 20, 30)]:
|
95
94
|
weights = init(size)
|
96
95
|
assert weights.shape == size
|
@@ -98,7 +97,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
|
|
98
97
|
|
99
98
|
class TestXavierUniformUnit(unittest.TestCase):
|
100
99
|
def test_xavier_uniform_init(self):
|
101
|
-
init =
|
100
|
+
init = brainstate.init.XavierUniform()
|
102
101
|
for size in [(10, 20), (10, 20, 30)]:
|
103
102
|
weights = init(size)
|
104
103
|
assert weights.shape == size
|
@@ -106,7 +105,7 @@ class TestXavierUniformUnit(unittest.TestCase):
|
|
106
105
|
|
107
106
|
class TestXavierNormalUnit(unittest.TestCase):
|
108
107
|
def test_xavier_normal_init(self):
|
109
|
-
init =
|
108
|
+
init = brainstate.init.XavierNormal()
|
110
109
|
for size in [(10, 20), (10, 20, 30)]:
|
111
110
|
weights = init(size)
|
112
111
|
assert weights.shape == size
|
@@ -114,7 +113,7 @@ class TestXavierNormalUnit(unittest.TestCase):
|
|
114
113
|
|
115
114
|
class TestLecunUniformUnit(unittest.TestCase):
|
116
115
|
def test_lecun_uniform_init(self):
|
117
|
-
init =
|
116
|
+
init = brainstate.init.LecunUniform()
|
118
117
|
for size in [(10, 20), (10, 20, 30)]:
|
119
118
|
weights = init(size)
|
120
119
|
assert weights.shape == size
|
@@ -122,7 +121,7 @@ class TestLecunUniformUnit(unittest.TestCase):
|
|
122
121
|
|
123
122
|
class TestLecunNormalUnit(unittest.TestCase):
|
124
123
|
def test_lecun_normal_init(self):
|
125
|
-
init =
|
124
|
+
init = brainstate.init.LecunNormal()
|
126
125
|
for size in [(10, 20), (10, 20, 30)]:
|
127
126
|
weights = init(size)
|
128
127
|
assert weights.shape == size
|
@@ -130,13 +129,13 @@ class TestLecunNormalUnit(unittest.TestCase):
|
|
130
129
|
|
131
130
|
class TestOrthogonalUnit(unittest.TestCase):
|
132
131
|
def test_orthogonal_init1(self):
|
133
|
-
init =
|
132
|
+
init = brainstate.init.Orthogonal()
|
134
133
|
for size in [(20, 20), (10, 20, 30)]:
|
135
134
|
weights = init(size)
|
136
135
|
assert weights.shape == size
|
137
136
|
|
138
137
|
def test_orthogonal_init2(self):
|
139
|
-
init =
|
138
|
+
init = brainstate.init.Orthogonal(scale=2., axis=0)
|
140
139
|
for size in [(10, 20), (10, 20, 30)]:
|
141
140
|
weights = init(size)
|
142
141
|
assert weights.shape == size
|
@@ -144,7 +143,7 @@ class TestOrthogonalUnit(unittest.TestCase):
|
|
144
143
|
|
145
144
|
class TestDeltaOrthogonalUnit(unittest.TestCase):
|
146
145
|
def test_delta_orthogonal_init1(self):
|
147
|
-
init =
|
146
|
+
init = brainstate.init.DeltaOrthogonal()
|
148
147
|
for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]:
|
149
148
|
weights = init(size)
|
150
149
|
assert weights.shape == size
|
@@ -14,16 +14,15 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
-
from __future__ import annotations
|
18
17
|
|
19
18
|
import unittest
|
20
19
|
|
21
|
-
import brainstate
|
20
|
+
import brainstate
|
22
21
|
|
23
22
|
|
24
23
|
class TestZeroInit(unittest.TestCase):
|
25
24
|
def test_zero_init(self):
|
26
|
-
init =
|
25
|
+
init = brainstate.init.ZeroInit()
|
27
26
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
28
27
|
weights = init(size)
|
29
28
|
assert weights.shape == size
|
@@ -33,7 +32,7 @@ class TestOneInit(unittest.TestCase):
|
|
33
32
|
def test_one_init(self):
|
34
33
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
35
34
|
for value in [0., 1., -1.]:
|
36
|
-
init =
|
35
|
+
init = brainstate.init.Constant(value=value)
|
37
36
|
weights = init(size)
|
38
37
|
assert weights.shape == size
|
39
38
|
assert (weights == value).all()
|
@@ -43,7 +42,7 @@ class TestIdentityInit(unittest.TestCase):
|
|
43
42
|
def test_identity_init(self):
|
44
43
|
for size in [(100,), (10, 20)]:
|
45
44
|
for value in [0., 1., -1.]:
|
46
|
-
init =
|
45
|
+
init = brainstate.init.Identity(value=value)
|
47
46
|
weights = init(size)
|
48
47
|
if len(size) == 1:
|
49
48
|
assert weights.shape == (size[0], size[0])
|
brainstate/mixin.py
CHANGED
brainstate/nn/_collective_ops.py
CHANGED
@@ -13,13 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from collections import namedtuple
|
19
17
|
from typing import Callable, TypeVar, Tuple, Any, Dict
|
20
18
|
|
21
19
|
import jax
|
22
|
-
from typing import Callable, TypeVar, Tuple, Any, Dict
|
23
20
|
|
24
21
|
from brainstate._state import catch_new_states
|
25
22
|
from brainstate._utils import set_module_as
|
@@ -16,21 +16,21 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
import brainstate
|
19
|
+
import brainstate
|
20
20
|
|
21
21
|
|
22
22
|
class Test_vmap_init_all_states:
|
23
23
|
|
24
24
|
def test_vmap_init_all_states(self):
|
25
|
-
gru =
|
26
|
-
|
25
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
26
|
+
brainstate.nn.vmap_init_all_states(gru, axis_size=10)
|
27
27
|
print(gru)
|
28
28
|
|
29
29
|
def test_vmap_init_all_states_v2(self):
|
30
|
-
@
|
30
|
+
@brainstate.compile.jit
|
31
31
|
def init():
|
32
|
-
gru =
|
33
|
-
|
32
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
33
|
+
brainstate.nn.vmap_init_all_states(gru, axis_size=10)
|
34
34
|
print(gru)
|
35
35
|
|
36
36
|
init()
|
@@ -38,6 +38,6 @@ class Test_vmap_init_all_states:
|
|
38
38
|
|
39
39
|
class Test_init_all_states:
|
40
40
|
def test_init_all_states(self):
|
41
|
-
gru =
|
42
|
-
|
41
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
42
|
+
brainstate.nn.init_all_states(gru, batch_size=10)
|
43
43
|
print(gru)
|
brainstate/nn/_common.py
CHANGED
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
18
|
|
20
19
|
import unittest
|
21
20
|
|
@@ -23,7 +22,7 @@ import brainunit as u
|
|
23
22
|
import jax
|
24
23
|
import jax.numpy as jnp
|
25
24
|
|
26
|
-
import brainstate
|
25
|
+
import brainstate
|
27
26
|
from brainstate.nn import IF, LIF, ALIF
|
28
27
|
|
29
28
|
|
@@ -35,13 +34,13 @@ class TestNeuron(unittest.TestCase):
|
|
35
34
|
|
36
35
|
def test_neuron_base_class(self):
|
37
36
|
with self.assertRaises(NotImplementedError):
|
38
|
-
|
37
|
+
brainstate.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
|
39
38
|
|
40
39
|
def generate_input(self):
|
41
|
-
return
|
40
|
+
return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
|
42
41
|
|
43
42
|
def test_if_neuron(self):
|
44
|
-
with
|
43
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
45
44
|
neuron = IF(self.in_size)
|
46
45
|
inputs = self.generate_input()
|
47
46
|
|
@@ -62,7 +61,7 @@ class TestNeuron(unittest.TestCase):
|
|
62
61
|
self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
|
63
62
|
|
64
63
|
def test_lif_neuron(self):
|
65
|
-
with
|
64
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
66
65
|
tau = 20.0 * u.ms
|
67
66
|
neuron = LIF(self.in_size, tau=tau)
|
68
67
|
inputs = self.generate_input()
|
@@ -74,7 +73,7 @@ class TestNeuron(unittest.TestCase):
|
|
74
73
|
|
75
74
|
# Test forward pass
|
76
75
|
state = neuron.init_state(self.batch_size)
|
77
|
-
call =
|
76
|
+
call = brainstate.compile.jit(neuron)
|
78
77
|
|
79
78
|
for t in range(self.time_steps):
|
80
79
|
out = call(inputs[t])
|
@@ -94,8 +93,8 @@ class TestNeuron(unittest.TestCase):
|
|
94
93
|
|
95
94
|
# Test forward pass
|
96
95
|
neuron.init_state(self.batch_size)
|
97
|
-
call =
|
98
|
-
with
|
96
|
+
call = brainstate.compile.jit(neuron)
|
97
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
99
98
|
for t in range(self.time_steps):
|
100
99
|
out = call(inputs[t])
|
101
100
|
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
@@ -113,8 +112,8 @@ class TestNeuron(unittest.TestCase):
|
|
113
112
|
neuron = NeuronClass(self.in_size, spk_reset='soft')
|
114
113
|
inputs = self.generate_input()
|
115
114
|
state = neuron.init_state(self.batch_size)
|
116
|
-
call =
|
117
|
-
with
|
115
|
+
call = brainstate.compile.jit(neuron)
|
116
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
118
117
|
for t in range(self.time_steps):
|
119
118
|
out = call(inputs[t])
|
120
119
|
self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th))
|
@@ -124,8 +123,8 @@ class TestNeuron(unittest.TestCase):
|
|
124
123
|
neuron = NeuronClass(self.in_size, spk_reset='hard')
|
125
124
|
inputs = self.generate_input()
|
126
125
|
state = neuron.init_state(self.batch_size)
|
127
|
-
call =
|
128
|
-
with
|
126
|
+
call = brainstate.compile.jit(neuron)
|
127
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
129
128
|
for t in range(self.time_steps):
|
130
129
|
out = call(inputs[t])
|
131
130
|
self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV)))
|
@@ -135,8 +134,8 @@ class TestNeuron(unittest.TestCase):
|
|
135
134
|
neuron = NeuronClass(self.in_size)
|
136
135
|
inputs = self.generate_input()
|
137
136
|
state = neuron.init_state(self.batch_size)
|
138
|
-
call =
|
139
|
-
with
|
137
|
+
call = brainstate.compile.jit(neuron)
|
138
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
140
139
|
for t in range(self.time_steps):
|
141
140
|
out = call(inputs[t])
|
142
141
|
self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type)
|
@@ -148,15 +147,15 @@ class TestNeuron(unittest.TestCase):
|
|
148
147
|
self.assertEqual(neuron.in_size, in_size)
|
149
148
|
self.assertEqual(neuron.out_size, in_size)
|
150
149
|
|
151
|
-
inputs =
|
150
|
+
inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
|
152
151
|
state = neuron.init_state(self.batch_size)
|
153
|
-
call =
|
154
|
-
with
|
152
|
+
call = brainstate.compile.jit(neuron)
|
153
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
155
154
|
for t in range(self.time_steps):
|
156
155
|
out = call(inputs[t])
|
157
156
|
self.assertEqual(out.shape, (self.batch_size, *in_size))
|
158
157
|
|
159
158
|
|
160
159
|
if __name__ == '__main__':
|
161
|
-
with
|
160
|
+
with brainstate.environ.context(dt=0.1):
|
162
161
|
unittest.main()
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
@@ -21,7 +20,7 @@ import brainunit as u
|
|
21
20
|
import jax.numpy as jnp
|
22
21
|
import pytest
|
23
22
|
|
24
|
-
import brainstate
|
23
|
+
import brainstate
|
25
24
|
from brainstate.nn import Expon, STP, STD
|
26
25
|
|
27
26
|
|
@@ -32,7 +31,7 @@ class TestSynapse(unittest.TestCase):
|
|
32
31
|
self.time_steps = 100
|
33
32
|
|
34
33
|
def generate_input(self):
|
35
|
-
return
|
34
|
+
return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
|
36
35
|
|
37
36
|
def test_expon_synapse(self):
|
38
37
|
tau = 20.0 * u.ms
|
@@ -46,8 +45,8 @@ class TestSynapse(unittest.TestCase):
|
|
46
45
|
|
47
46
|
# Test forward pass
|
48
47
|
state = synapse.init_state(self.batch_size)
|
49
|
-
call =
|
50
|
-
with
|
48
|
+
call = brainstate.compile.jit(synapse)
|
49
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
51
50
|
for t in range(self.time_steps):
|
52
51
|
out = call(inputs[t])
|
53
52
|
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
@@ -75,7 +74,7 @@ class TestSynapse(unittest.TestCase):
|
|
75
74
|
|
76
75
|
# Test forward pass
|
77
76
|
state = synapse.init_state(self.batch_size)
|
78
|
-
call =
|
77
|
+
call = brainstate.compile.jit(synapse)
|
79
78
|
for t in range(self.time_steps):
|
80
79
|
out = call(inputs[t])
|
81
80
|
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
@@ -118,15 +117,15 @@ class TestSynapse(unittest.TestCase):
|
|
118
117
|
self.assertEqual(synapse.in_size, in_size)
|
119
118
|
self.assertEqual(synapse.out_size, in_size)
|
120
119
|
|
121
|
-
inputs =
|
120
|
+
inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
|
122
121
|
state = synapse.init_state(self.batch_size)
|
123
|
-
call =
|
124
|
-
with
|
122
|
+
call = brainstate.compile.jit(synapse)
|
123
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
125
124
|
for t in range(self.time_steps):
|
126
125
|
out = call(inputs[t])
|
127
126
|
self.assertEqual(out.shape, (self.batch_size, *in_size))
|
128
127
|
|
129
128
|
|
130
129
|
if __name__ == '__main__':
|
131
|
-
with
|
130
|
+
with brainstate.environ.context(dt=0.1):
|
132
131
|
unittest.main()
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
from typing import Union, Optional, Sequence, Callable
|
18
17
|
|
@@ -13,13 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import jax.numpy as jnp
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestRateRNNModels(unittest.TestCase):
|
@@ -30,31 +29,31 @@ class TestRateRNNModels(unittest.TestCase):
|
|
30
29
|
self.x = jnp.ones((self.batch_size, self.num_in))
|
31
30
|
|
32
31
|
def test_ValinaRNNCell(self):
|
33
|
-
model =
|
32
|
+
model = brainstate.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
|
34
33
|
model.init_state(batch_size=self.batch_size)
|
35
34
|
output = model.update(self.x)
|
36
35
|
self.assertEqual(output.shape, (self.batch_size, self.num_out))
|
37
36
|
|
38
37
|
def test_GRUCell(self):
|
39
|
-
model =
|
38
|
+
model = brainstate.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
|
40
39
|
model.init_state(batch_size=self.batch_size)
|
41
40
|
output = model.update(self.x)
|
42
41
|
self.assertEqual(output.shape, (self.batch_size, self.num_out))
|
43
42
|
|
44
43
|
def test_MGUCell(self):
|
45
|
-
model =
|
44
|
+
model = brainstate.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
|
46
45
|
model.init_state(batch_size=self.batch_size)
|
47
46
|
output = model.update(self.x)
|
48
47
|
self.assertEqual(output.shape, (self.batch_size, self.num_out))
|
49
48
|
|
50
49
|
def test_LSTMCell(self):
|
51
|
-
model =
|
50
|
+
model = brainstate.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
|
52
51
|
model.init_state(batch_size=self.batch_size)
|
53
52
|
output = model.update(self.x)
|
54
53
|
self.assertEqual(output.shape, (self.batch_size, self.num_out))
|
55
54
|
|
56
55
|
def test_URLSTMCell(self):
|
57
|
-
model =
|
56
|
+
model = brainstate.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
|
58
57
|
model.init_state(batch_size=self.batch_size)
|
59
58
|
output = model.update(self.x)
|
60
59
|
self.assertEqual(output.shape, (self.batch_size, self.num_out))
|
@@ -13,13 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import jax.numpy as jnp
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestReadoutModels(unittest.TestCase):
|
@@ -32,23 +31,23 @@ class TestReadoutModels(unittest.TestCase):
|
|
32
31
|
self.x = jnp.ones((self.batch_size, self.in_size))
|
33
32
|
|
34
33
|
def test_LeakyRateReadout(self):
|
35
|
-
with
|
36
|
-
model =
|
34
|
+
with brainstate.environ.context(dt=0.1):
|
35
|
+
model = brainstate.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
|
37
36
|
model.init_state(batch_size=self.batch_size)
|
38
37
|
output = model.update(self.x)
|
39
38
|
self.assertEqual(output.shape, (self.batch_size, self.out_size))
|
40
39
|
|
41
40
|
def test_LeakySpikeReadout(self):
|
42
|
-
with
|
43
|
-
model =
|
44
|
-
|
45
|
-
|
41
|
+
with brainstate.environ.context(dt=0.1):
|
42
|
+
model = brainstate.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
|
43
|
+
V_initializer=brainstate.init.ZeroInit(),
|
44
|
+
w_init=brainstate.init.KaimingNormal())
|
46
45
|
model.init_state(batch_size=self.batch_size)
|
47
|
-
with
|
46
|
+
with brainstate.environ.context(t=0.):
|
48
47
|
output = model.update(self.x)
|
49
48
|
self.assertEqual(output.shape, (self.batch_size, self.out_size))
|
50
49
|
|
51
50
|
|
52
51
|
if __name__ == '__main__':
|
53
|
-
with
|
52
|
+
with brainstate.environ.context(dt=0.1):
|
54
53
|
unittest.main()
|