brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -14,14 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from __future__ import annotations
|
18
|
-
|
19
17
|
import unittest
|
20
18
|
|
21
19
|
import brainunit as u
|
22
20
|
from absl.testing import parameterized
|
23
21
|
|
24
|
-
import brainstate
|
22
|
+
import brainstate
|
25
23
|
|
26
24
|
|
27
25
|
class TestDense(parameterized.TestCase):
|
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
|
|
32
30
|
num_out=[20, ]
|
33
31
|
)
|
34
32
|
def test_Dense1(self, size, num_out):
|
35
|
-
f =
|
36
|
-
x =
|
33
|
+
f = brainstate.nn.Linear(10, num_out)
|
34
|
+
x = brainstate.random.random(size)
|
37
35
|
y = f(x)
|
38
36
|
self.assertTrue(y.shape == size[:-1] + (num_out,))
|
39
37
|
|
40
38
|
|
41
39
|
class TestSparseMatrix(unittest.TestCase):
|
42
40
|
def test_csr(self):
|
43
|
-
data =
|
41
|
+
data = brainstate.random.rand(10, 20)
|
44
42
|
data = data * (data > 0.9)
|
45
|
-
f =
|
43
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
|
46
44
|
|
47
|
-
x =
|
45
|
+
x = brainstate.random.rand(10)
|
48
46
|
y = f(x)
|
49
47
|
self.assertTrue(
|
50
48
|
u.math.allclose(
|
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
53
51
|
)
|
54
52
|
)
|
55
53
|
|
56
|
-
x =
|
54
|
+
x = brainstate.random.rand(5, 10)
|
57
55
|
y = f(x)
|
58
56
|
self.assertTrue(
|
59
57
|
u.math.allclose(
|
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
63
61
|
)
|
64
62
|
|
65
63
|
def test_csc(self):
|
66
|
-
data =
|
64
|
+
data = brainstate.random.rand(10, 20)
|
67
65
|
data = data * (data > 0.9)
|
68
|
-
f =
|
66
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
|
69
67
|
|
70
|
-
x =
|
68
|
+
x = brainstate.random.rand(10)
|
71
69
|
y = f(x)
|
72
70
|
self.assertTrue(
|
73
71
|
u.math.allclose(
|
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
76
74
|
)
|
77
75
|
)
|
78
76
|
|
79
|
-
x =
|
77
|
+
x = brainstate.random.rand(5, 10)
|
80
78
|
y = f(x)
|
81
79
|
self.assertTrue(
|
82
80
|
u.math.allclose(
|
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
86
84
|
)
|
87
85
|
|
88
86
|
def test_coo(self):
|
89
|
-
data =
|
87
|
+
data = brainstate.random.rand(10, 20)
|
90
88
|
data = data * (data > 0.9)
|
91
|
-
f =
|
89
|
+
f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
|
92
90
|
|
93
|
-
x =
|
91
|
+
x = brainstate.random.rand(10)
|
94
92
|
y = f(x)
|
95
93
|
self.assertTrue(
|
96
94
|
u.math.allclose(
|
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
99
97
|
)
|
100
98
|
)
|
101
99
|
|
102
|
-
x =
|
100
|
+
x = brainstate.random.rand(5, 10)
|
103
101
|
y = f(x)
|
104
102
|
self.assertTrue(
|
105
103
|
u.math.allclose(
|
@@ -16,11 +16,13 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
from .
|
20
|
-
from ._linear_mv import EventLinear
|
19
|
+
from ._synapse import Synapse
|
21
20
|
|
22
21
|
__all__ = [
|
23
|
-
'
|
24
|
-
'EventFixedProb',
|
25
|
-
'EventFixedNumConn',
|
22
|
+
'LongTermPlasticity',
|
26
23
|
]
|
24
|
+
|
25
|
+
|
26
|
+
class LongTermPlasticity(Synapse):
|
27
|
+
pass
|
28
|
+
|
brainstate/nn/_module_test.py
CHANGED
@@ -13,20 +13,17 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
|
-
import jaxlib.xla_extension
|
22
19
|
|
23
|
-
import brainstate
|
20
|
+
import brainstate
|
24
21
|
|
25
22
|
|
26
23
|
class TestDelay(unittest.TestCase):
|
27
24
|
def test_delay1(self):
|
28
|
-
a =
|
29
|
-
delay =
|
25
|
+
a = brainstate.State(brainstate.random.random(10, 20))
|
26
|
+
delay = brainstate.nn.Delay(a.value)
|
30
27
|
delay.register_entry('a', 1.)
|
31
28
|
delay.register_entry('b', 2.)
|
32
29
|
delay.register_entry('c', None)
|
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
|
|
36
33
|
delay.register_entry('c', 10.)
|
37
34
|
|
38
35
|
def test_rotation_delay(self):
|
39
|
-
rotation_delay =
|
36
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
40
37
|
t0 = 0.
|
41
38
|
t1, n1 = 1., 10
|
42
39
|
t2, n2 = 2., 20
|
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
|
|
53
50
|
# print(rotation_delay.max_length)
|
54
51
|
|
55
52
|
for i in range(100):
|
56
|
-
|
53
|
+
brainstate.environ.set(i=i)
|
57
54
|
rotation_delay.update(jnp.ones((1,)) * i)
|
58
55
|
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
59
56
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
|
|
61
58
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
59
|
|
63
60
|
def test_concat_delay(self):
|
64
|
-
rotation_delay =
|
61
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
65
62
|
t0 = 0.
|
66
63
|
t1, n1 = 1., 10
|
67
64
|
t2, n2 = 2., 20
|
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
|
|
74
71
|
|
75
72
|
print()
|
76
73
|
for i in range(100):
|
77
|
-
|
74
|
+
brainstate.environ.set(i=i)
|
78
75
|
rotation_delay.update(jnp.ones((1,)) * i)
|
79
76
|
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
80
77
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
|
|
83
80
|
# bst.util.clear_buffer_memory()
|
84
81
|
|
85
82
|
def test_jit_erro(self):
|
86
|
-
rotation_delay =
|
83
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
87
84
|
rotation_delay.init_state()
|
88
85
|
|
89
|
-
with
|
86
|
+
with brainstate.environ.context(i=0, t=0, jit_error_check=True):
|
90
87
|
rotation_delay.retrieve_at_time(-2.0)
|
91
|
-
with self.assertRaises(
|
88
|
+
with self.assertRaises(Exception):
|
92
89
|
rotation_delay.retrieve_at_time(-2.1)
|
93
90
|
rotation_delay.retrieve_at_time(-2.01)
|
94
|
-
with self.assertRaises(
|
91
|
+
with self.assertRaises(Exception):
|
95
92
|
rotation_delay.retrieve_at_time(-2.09)
|
96
|
-
with self.assertRaises(
|
93
|
+
with self.assertRaises(Exception):
|
97
94
|
rotation_delay.retrieve_at_time(0.1)
|
98
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(Exception):
|
99
96
|
rotation_delay.retrieve_at_time(0.01)
|
100
97
|
|
101
98
|
def test_round_interp(self):
|
102
99
|
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
103
100
|
for delay_method in ['rotation', 'concat']:
|
104
|
-
rotation_delay =
|
105
|
-
|
101
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
102
|
+
interp_method='round')
|
106
103
|
t0, n1 = 0.01, 0
|
107
104
|
t1, n1 = 1.04, 10
|
108
105
|
t2, n2 = 1.06, 11
|
109
106
|
rotation_delay.init_state()
|
110
107
|
|
111
|
-
@
|
108
|
+
@brainstate.compile.jit
|
112
109
|
def retrieve(td, i):
|
113
|
-
with
|
110
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
114
111
|
return rotation_delay.retrieve_at_time(td)
|
115
112
|
|
116
113
|
print()
|
117
114
|
for i in range(100):
|
118
|
-
t = i *
|
119
|
-
with
|
115
|
+
t = i * brainstate.environ.get_dt()
|
116
|
+
with brainstate.environ.context(i=i, t=t):
|
120
117
|
rotation_delay.update(jnp.ones(shape) * i)
|
121
118
|
print(i,
|
122
119
|
retrieve(t - t0, i),
|
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
|
|
131
128
|
for delay_method in ['rotation', 'concat']:
|
132
129
|
print(shape, delay_method)
|
133
130
|
|
134
|
-
rotation_delay =
|
135
|
-
|
131
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
132
|
+
interp_method='linear_interp')
|
136
133
|
t0, n0 = 0.01, 0.1
|
137
134
|
t1, n1 = 1.04, 10.4
|
138
135
|
t2, n2 = 1.06, 10.6
|
139
136
|
rotation_delay.init_state()
|
140
137
|
|
141
|
-
@
|
138
|
+
@brainstate.compile.jit
|
142
139
|
def retrieve(td, i):
|
143
|
-
with
|
140
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
144
141
|
return rotation_delay.retrieve_at_time(td)
|
145
142
|
|
146
143
|
print()
|
147
144
|
for i in range(100):
|
148
|
-
t = i *
|
149
|
-
with
|
145
|
+
t = i * brainstate.environ.get_dt()
|
146
|
+
with brainstate.environ.context(i=i, t=t):
|
150
147
|
rotation_delay.update(jnp.ones(shape) * i)
|
151
148
|
print(i,
|
152
149
|
retrieve(t - t0, i),
|
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
|
|
157
154
|
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
158
155
|
|
159
156
|
def test_rotation_and_concat_delay(self):
|
160
|
-
rotation_delay =
|
161
|
-
concat_delay =
|
157
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
158
|
+
concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
162
159
|
t0 = 0.
|
163
160
|
t1, n1 = 1., 10
|
164
161
|
t2, n2 = 2., 20
|
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
|
|
175
172
|
|
176
173
|
print()
|
177
174
|
for i in range(100):
|
178
|
-
|
175
|
+
brainstate.environ.set(i=i)
|
179
176
|
new = jnp.ones((1,)) * i
|
180
177
|
rotation_delay.update(new)
|
181
178
|
concat_delay.update(new)
|
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
|
|
186
183
|
|
187
184
|
class TestModule(unittest.TestCase):
|
188
185
|
def test_states(self):
|
189
|
-
class A(
|
186
|
+
class A(brainstate.nn.Module):
|
190
187
|
def __init__(self):
|
191
188
|
super().__init__()
|
192
|
-
self.a =
|
193
|
-
self.b =
|
189
|
+
self.a = brainstate.State(brainstate.random.random(10, 20))
|
190
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
194
191
|
|
195
|
-
class B(
|
192
|
+
class B(brainstate.nn.Module):
|
196
193
|
def __init__(self):
|
197
194
|
super().__init__()
|
198
195
|
self.a = A()
|
199
|
-
self.b =
|
196
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
200
197
|
|
201
198
|
b = B()
|
202
199
|
print()
|
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
|
|
207
204
|
|
208
205
|
|
209
206
|
if __name__ == '__main__':
|
210
|
-
with
|
207
|
+
with brainstate.environ.context(dt=0.1):
|
211
208
|
unittest.main()
|
@@ -22,9 +22,9 @@ import jax
|
|
22
22
|
|
23
23
|
from brainstate import init, surrogate, environ
|
24
24
|
from brainstate._state import HiddenState, ShortTermState
|
25
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics
|
26
|
-
from brainstate.nn._exp_euler import exp_euler_step
|
27
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._dynamics import Dynamics
|
27
|
+
from ._exp_euler import exp_euler_step
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
|
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
18
|
|
20
19
|
import unittest
|
21
20
|
|
@@ -23,7 +22,7 @@ import brainunit as u
|
|
23
22
|
import jax
|
24
23
|
import jax.numpy as jnp
|
25
24
|
|
26
|
-
import brainstate
|
25
|
+
import brainstate
|
27
26
|
from brainstate.nn import IF, LIF, ALIF
|
28
27
|
|
29
28
|
|
@@ -35,13 +34,13 @@ class TestNeuron(unittest.TestCase):
|
|
35
34
|
|
36
35
|
def test_neuron_base_class(self):
|
37
36
|
with self.assertRaises(NotImplementedError):
|
38
|
-
|
37
|
+
brainstate.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
|
39
38
|
|
40
39
|
def generate_input(self):
|
41
|
-
return
|
40
|
+
return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
|
42
41
|
|
43
42
|
def test_if_neuron(self):
|
44
|
-
with
|
43
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
45
44
|
neuron = IF(self.in_size)
|
46
45
|
inputs = self.generate_input()
|
47
46
|
|
@@ -62,7 +61,7 @@ class TestNeuron(unittest.TestCase):
|
|
62
61
|
self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
|
63
62
|
|
64
63
|
def test_lif_neuron(self):
|
65
|
-
with
|
64
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
66
65
|
tau = 20.0 * u.ms
|
67
66
|
neuron = LIF(self.in_size, tau=tau)
|
68
67
|
inputs = self.generate_input()
|
@@ -74,7 +73,7 @@ class TestNeuron(unittest.TestCase):
|
|
74
73
|
|
75
74
|
# Test forward pass
|
76
75
|
state = neuron.init_state(self.batch_size)
|
77
|
-
call =
|
76
|
+
call = brainstate.compile.jit(neuron)
|
78
77
|
|
79
78
|
for t in range(self.time_steps):
|
80
79
|
out = call(inputs[t])
|
@@ -94,8 +93,8 @@ class TestNeuron(unittest.TestCase):
|
|
94
93
|
|
95
94
|
# Test forward pass
|
96
95
|
neuron.init_state(self.batch_size)
|
97
|
-
call =
|
98
|
-
with
|
96
|
+
call = brainstate.compile.jit(neuron)
|
97
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
99
98
|
for t in range(self.time_steps):
|
100
99
|
out = call(inputs[t])
|
101
100
|
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
@@ -113,8 +112,8 @@ class TestNeuron(unittest.TestCase):
|
|
113
112
|
neuron = NeuronClass(self.in_size, spk_reset='soft')
|
114
113
|
inputs = self.generate_input()
|
115
114
|
state = neuron.init_state(self.batch_size)
|
116
|
-
call =
|
117
|
-
with
|
115
|
+
call = brainstate.compile.jit(neuron)
|
116
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
118
117
|
for t in range(self.time_steps):
|
119
118
|
out = call(inputs[t])
|
120
119
|
self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th))
|
@@ -124,8 +123,8 @@ class TestNeuron(unittest.TestCase):
|
|
124
123
|
neuron = NeuronClass(self.in_size, spk_reset='hard')
|
125
124
|
inputs = self.generate_input()
|
126
125
|
state = neuron.init_state(self.batch_size)
|
127
|
-
call =
|
128
|
-
with
|
126
|
+
call = brainstate.compile.jit(neuron)
|
127
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
129
128
|
for t in range(self.time_steps):
|
130
129
|
out = call(inputs[t])
|
131
130
|
self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV)))
|
@@ -135,8 +134,8 @@ class TestNeuron(unittest.TestCase):
|
|
135
134
|
neuron = NeuronClass(self.in_size)
|
136
135
|
inputs = self.generate_input()
|
137
136
|
state = neuron.init_state(self.batch_size)
|
138
|
-
call =
|
139
|
-
with
|
137
|
+
call = brainstate.compile.jit(neuron)
|
138
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
140
139
|
for t in range(self.time_steps):
|
141
140
|
out = call(inputs[t])
|
142
141
|
self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type)
|
@@ -148,15 +147,15 @@ class TestNeuron(unittest.TestCase):
|
|
148
147
|
self.assertEqual(neuron.in_size, in_size)
|
149
148
|
self.assertEqual(neuron.out_size, in_size)
|
150
149
|
|
151
|
-
inputs =
|
150
|
+
inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
|
152
151
|
state = neuron.init_state(self.batch_size)
|
153
|
-
call =
|
154
|
-
with
|
152
|
+
call = brainstate.compile.jit(neuron)
|
153
|
+
with brainstate.environ.context(dt=0.1 * u.ms):
|
155
154
|
for t in range(self.time_steps):
|
156
155
|
out = call(inputs[t])
|
157
156
|
self.assertEqual(out.shape, (self.batch_size, *in_size))
|
158
157
|
|
159
158
|
|
160
159
|
if __name__ == '__main__':
|
161
|
-
with
|
160
|
+
with brainstate.environ.context(dt=0.1):
|
162
161
|
unittest.main()
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import environ, init
|
24
24
|
from brainstate._state import ParamState, BatchState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'BatchNorm0d',
|
@@ -13,12 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from absl.testing import absltest
|
19
17
|
from absl.testing import parameterized
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class Test_Normalization(parameterized.TestCase):
|
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
|
|
26
24
|
fit=[True, False],
|
27
25
|
)
|
28
26
|
def test_BatchNorm1d(self, fit):
|
29
|
-
net =
|
30
|
-
|
31
|
-
input =
|
27
|
+
net = brainstate.nn.BatchNorm1d((3, 10))
|
28
|
+
brainstate.environ.set(fit=fit)
|
29
|
+
input = brainstate.random.randn(1, 3, 10)
|
32
30
|
output = net(input)
|
33
31
|
|
34
32
|
@parameterized.product(
|
35
33
|
fit=[True, False]
|
36
34
|
)
|
37
35
|
def test_BatchNorm2d(self, fit):
|
38
|
-
net =
|
39
|
-
|
40
|
-
input =
|
36
|
+
net = brainstate.nn.BatchNorm2d([3, 4, 10])
|
37
|
+
brainstate.environ.set(fit=fit)
|
38
|
+
input = brainstate.random.randn(1, 3, 4, 10)
|
41
39
|
output = net(input)
|
42
40
|
|
43
41
|
@parameterized.product(
|
44
42
|
fit=[True, False]
|
45
43
|
)
|
46
44
|
def test_BatchNorm3d(self, fit):
|
47
|
-
net =
|
48
|
-
|
49
|
-
input =
|
45
|
+
net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
|
46
|
+
brainstate.environ.set(fit=fit)
|
47
|
+
input = brainstate.random.randn(1, 3, 4, 5, 10)
|
50
48
|
output = net(input)
|
51
49
|
|
52
50
|
# @parameterized.product(
|
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax
|
6
4
|
import numpy as np
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
import brainstate.nn as nn
|
12
10
|
|
13
11
|
|
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
|
|
18
16
|
(32, 8),
|
19
17
|
(10, 20, 30),
|
20
18
|
]:
|
21
|
-
arr =
|
19
|
+
arr = brainstate.random.rand(*size)
|
22
20
|
f = nn.Flatten(start_axis=0)
|
23
21
|
out = f(arr)
|
24
22
|
self.assertTrue(out.shape == (np.prod(size),))
|
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
|
|
29
27
|
(32, 8),
|
30
28
|
(10, 20, 30),
|
31
29
|
]:
|
32
|
-
arr =
|
30
|
+
arr = brainstate.random.rand(*size)
|
33
31
|
f = nn.Flatten(start_axis=1)
|
34
32
|
out = f(arr)
|
35
33
|
self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
|
36
34
|
|
37
35
|
def test_flatten3(self):
|
38
36
|
size = (16, 32, 32, 8)
|
39
|
-
arr =
|
37
|
+
arr = brainstate.random.rand(*size)
|
40
38
|
f = nn.Flatten(start_axis=0, in_size=(32, 8))
|
41
39
|
out = f(arr)
|
42
40
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
43
41
|
|
44
42
|
def test_flatten4(self):
|
45
43
|
size = (16, 32, 32, 8)
|
46
|
-
arr =
|
44
|
+
arr = brainstate.random.rand(*size)
|
47
45
|
f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
|
48
46
|
out = f(arr)
|
49
47
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
|
|
58
56
|
super().__init__(*args, **kwargs)
|
59
57
|
|
60
58
|
def test_MaxPool2d_v1(self):
|
61
|
-
arr =
|
59
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
62
60
|
|
63
61
|
out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
|
64
62
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
|
|
79
77
|
self.assertTrue(out.shape == (16, 17, 32, 5))
|
80
78
|
|
81
79
|
def test_AvgPool2d_v1(self):
|
82
|
-
arr =
|
80
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
83
81
|
|
84
82
|
out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
|
85
83
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -105,9 +103,9 @@ class TestPool(parameterized.TestCase):
|
|
105
103
|
for target_size in [10, 9, 8, 7, 6]
|
106
104
|
)
|
107
105
|
def test_adaptive_pool1d(self, target_size):
|
108
|
-
from brainstate.nn.
|
106
|
+
from brainstate.nn._poolings import _adaptive_pool1d
|
109
107
|
|
110
|
-
arr =
|
108
|
+
arr = brainstate.random.rand(100)
|
111
109
|
op = jax.numpy.mean
|
112
110
|
|
113
111
|
out = _adaptive_pool1d(arr, target_size, op)
|
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
|
|
119
117
|
self.assertTrue(out.shape == (target_size,))
|
120
118
|
|
121
119
|
def test_AdaptiveAvgPool2d_v1(self):
|
122
|
-
input =
|
120
|
+
input = brainstate.random.randn(64, 8, 9)
|
123
121
|
|
124
122
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
125
123
|
self.assertTrue(output.shape == (64, 5, 7))
|
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
|
|
137
135
|
self.assertTrue(output.shape == (64, 2, 3))
|
138
136
|
|
139
137
|
def test_AdaptiveAvgPool2d_v2(self):
|
140
|
-
|
141
|
-
input =
|
138
|
+
brainstate.random.seed()
|
139
|
+
input = brainstate.random.randn(128, 64, 32, 16)
|
142
140
|
|
143
141
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
144
142
|
self.assertTrue(output.shape == (128, 64, 5, 7))
|
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
|
|
154
152
|
print()
|
155
153
|
|
156
154
|
def test_AdaptiveAvgPool3d_v1(self):
|
157
|
-
input =
|
155
|
+
input = brainstate.random.randn(10, 128, 64, 32)
|
158
156
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
|
159
157
|
output = net(input)
|
160
158
|
self.assertTrue(output.shape == (10, 6, 5, 3))
|
161
159
|
|
162
160
|
def test_AdaptiveAvgPool3d_v2(self):
|
163
|
-
input =
|
161
|
+
input = brainstate.random.randn(10, 20, 128, 64, 32)
|
164
162
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
|
165
163
|
output = net(input)
|
166
164
|
self.assertTrue(output.shape == (10, 6, 5, 3, 32))
|
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
|
|
169
167
|
axis=(-1, 0, 1)
|
170
168
|
)
|
171
169
|
def test_AdaptiveMaxPool1d_v1(self, axis):
|
172
|
-
input =
|
170
|
+
input = brainstate.random.randn(32, 16)
|
173
171
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
174
172
|
output = net(input)
|
175
173
|
|
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
|
|
177
175
|
axis=(-1, 0, 1, 2)
|
178
176
|
)
|
179
177
|
def test_AdaptiveMaxPool1d_v2(self, axis):
|
180
|
-
input =
|
178
|
+
input = brainstate.random.randn(2, 32, 16)
|
181
179
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
182
180
|
output = net(input)
|
183
181
|
|
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
|
|
185
183
|
axis=(-1, 0, 1, 2)
|
186
184
|
)
|
187
185
|
def test_AdaptiveMaxPool2d_v1(self, axis):
|
188
|
-
input =
|
186
|
+
input = brainstate.random.randn(32, 16, 12)
|
189
187
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
190
188
|
output = net(input)
|
191
189
|
|
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
|
|
193
191
|
axis=(-1, 0, 1, 2, 3)
|
194
192
|
)
|
195
193
|
def test_AdaptiveMaxPool2d_v2(self, axis):
|
196
|
-
input =
|
194
|
+
input = brainstate.random.randn(2, 32, 16, 12)
|
197
195
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
198
196
|
output = net(input)
|
199
197
|
|
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
|
|
201
199
|
axis=(-1, 0, 1, 2, 3)
|
202
200
|
)
|
203
201
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
204
|
-
input =
|
202
|
+
input = brainstate.random.randn(2, 128, 64, 32)
|
205
203
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
206
204
|
output = net(input)
|
207
205
|
print()
|
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
|
|
210
208
|
axis=(-1, 0, 1, 2, 3, 4)
|
211
209
|
)
|
212
210
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
213
|
-
input =
|
211
|
+
input = brainstate.random.randn(2, 128, 64, 32, 16)
|
214
212
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
215
213
|
output = net(input)
|
216
214
|
|