brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -15,47 +15,46 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
18
|
|
20
19
|
import unittest
|
21
20
|
|
22
21
|
import numpy as np
|
23
22
|
|
24
|
-
import brainstate
|
23
|
+
import brainstate
|
25
24
|
|
26
25
|
|
27
26
|
class TestModuleGroup(unittest.TestCase):
|
28
27
|
def test_initialization(self):
|
29
|
-
group =
|
30
|
-
self.assertIsInstance(group,
|
28
|
+
group = brainstate.nn.DynamicsGroup()
|
29
|
+
self.assertIsInstance(group, brainstate.nn.DynamicsGroup)
|
31
30
|
|
32
31
|
|
33
32
|
class TestProjection(unittest.TestCase):
|
34
33
|
def test_initialization(self):
|
35
|
-
proj =
|
36
|
-
self.assertIsInstance(proj,
|
34
|
+
proj = brainstate.nn.Projection()
|
35
|
+
self.assertIsInstance(proj, brainstate.nn.Projection)
|
37
36
|
|
38
37
|
def test_update_not_implemented(self):
|
39
|
-
proj =
|
38
|
+
proj = brainstate.nn.Projection()
|
40
39
|
with self.assertRaises(ValueError):
|
41
40
|
proj.update()
|
42
41
|
|
43
42
|
|
44
43
|
class TestDynamics(unittest.TestCase):
|
45
44
|
def test_initialization(self):
|
46
|
-
dyn =
|
47
|
-
self.assertIsInstance(dyn,
|
45
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
46
|
+
self.assertIsInstance(dyn, brainstate.nn.Dynamics)
|
48
47
|
self.assertEqual(dyn.in_size, (10,))
|
49
48
|
self.assertEqual(dyn.out_size, (10,))
|
50
49
|
|
51
50
|
def test_size_validation(self):
|
52
51
|
with self.assertRaises(ValueError):
|
53
|
-
|
52
|
+
brainstate.nn.Dynamics(in_size=[])
|
54
53
|
with self.assertRaises(ValueError):
|
55
|
-
|
54
|
+
brainstate.nn.Dynamics(in_size="invalid")
|
56
55
|
|
57
56
|
def test_input_handling(self):
|
58
|
-
dyn =
|
57
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
59
58
|
dyn.add_current_input("test_current", lambda: np.random.rand(10))
|
60
59
|
dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
|
61
60
|
|
@@ -63,15 +62,15 @@ class TestDynamics(unittest.TestCase):
|
|
63
62
|
self.assertIn("test_delta", dyn.delta_inputs)
|
64
63
|
|
65
64
|
def test_duplicate_input_key(self):
|
66
|
-
dyn =
|
65
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
67
66
|
dyn.add_current_input("test", lambda: np.random.rand(10))
|
68
67
|
with self.assertRaises(ValueError):
|
69
68
|
dyn.add_current_input("test", lambda: np.random.rand(10))
|
70
69
|
|
71
70
|
def test_varshape(self):
|
72
|
-
dyn =
|
71
|
+
dyn = brainstate.nn.Dynamics(in_size=(2, 3))
|
73
72
|
self.assertEqual(dyn.varshape, (2, 3))
|
74
|
-
dyn =
|
73
|
+
dyn = brainstate.nn.Dynamics(in_size=(2, 3))
|
75
74
|
self.assertEqual(dyn.varshape, (2, 3))
|
76
75
|
|
77
76
|
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
from typing import Union, Callable, Optional
|
18
17
|
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
@@ -21,7 +20,7 @@ import brainunit as u
|
|
21
20
|
import jax.numpy as jnp
|
22
21
|
import numpy as np
|
23
22
|
|
24
|
-
import brainstate
|
23
|
+
import brainstate
|
25
24
|
|
26
25
|
|
27
26
|
class TestSynOutModels(unittest.TestCase):
|
@@ -35,19 +34,19 @@ class TestSynOutModels(unittest.TestCase):
|
|
35
34
|
self.V_offset = jnp.array([0.0])
|
36
35
|
|
37
36
|
def test_COBA(self):
|
38
|
-
model =
|
37
|
+
model = brainstate.nn.COBA(E=self.E)
|
39
38
|
output = model.update(self.conductance, self.potential)
|
40
39
|
expected_output = self.conductance * (self.E - self.potential)
|
41
40
|
np.testing.assert_array_almost_equal(output, expected_output)
|
42
41
|
|
43
42
|
def test_CUBA(self):
|
44
|
-
model =
|
43
|
+
model = brainstate.nn.CUBA()
|
45
44
|
output = model.update(self.conductance)
|
46
45
|
expected_output = self.conductance * model.scale
|
47
46
|
self.assertTrue(u.math.allclose(output, expected_output))
|
48
47
|
|
49
48
|
def test_MgBlock(self):
|
50
|
-
model =
|
49
|
+
model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
|
51
50
|
output = model.update(self.conductance, self.potential)
|
52
51
|
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
|
53
52
|
expected_output = self.conductance * (self.E - self.potential) / norm
|
@@ -18,19 +18,19 @@ import unittest
|
|
18
18
|
|
19
19
|
import numpy as np
|
20
20
|
|
21
|
-
import brainstate
|
21
|
+
import brainstate
|
22
22
|
|
23
23
|
|
24
24
|
class TestDropout(unittest.TestCase):
|
25
25
|
|
26
26
|
def test_dropout(self):
|
27
27
|
# Create a Dropout layer with a dropout rate of 0.5
|
28
|
-
dropout_layer =
|
28
|
+
dropout_layer = brainstate.nn.Dropout(0.5)
|
29
29
|
|
30
30
|
# Input data
|
31
31
|
input_data = np.arange(20)
|
32
32
|
|
33
|
-
with
|
33
|
+
with brainstate.environ.context(fit=True):
|
34
34
|
# Apply dropout
|
35
35
|
output_data = dropout_layer(input_data)
|
36
36
|
|
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
|
|
47
47
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
|
48
48
|
|
49
49
|
def test_DropoutFixed(self):
|
50
|
-
dropout_layer =
|
50
|
+
dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
|
51
51
|
dropout_layer.init_state(batch_size=2)
|
52
52
|
input_data = np.random.randn(2, 2, 3)
|
53
|
-
with
|
53
|
+
with brainstate.environ.context(fit=True):
|
54
54
|
output_data = dropout_layer.update(input_data)
|
55
55
|
self.assertEqual(input_data.shape, output_data.shape)
|
56
56
|
self.assertTrue(np.any(output_data == 0))
|
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
|
|
72
72
|
# np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
73
73
|
|
74
74
|
def test_Dropout2d(self):
|
75
|
-
dropout_layer =
|
75
|
+
dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
|
76
76
|
input_data = np.random.randn(2, 3, 4, 5)
|
77
|
-
with
|
77
|
+
with brainstate.environ.context(fit=True):
|
78
78
|
output_data = dropout_layer(input_data)
|
79
79
|
self.assertEqual(input_data.shape, output_data.shape)
|
80
80
|
self.assertTrue(np.any(output_data == 0))
|
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
|
|
84
84
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
85
85
|
|
86
86
|
def test_Dropout3d(self):
|
87
|
-
dropout_layer =
|
87
|
+
dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
|
88
88
|
input_data = np.random.randn(2, 3, 4, 5, 6)
|
89
|
-
with
|
89
|
+
with brainstate.environ.context(fit=True):
|
90
90
|
output_data = dropout_layer(input_data)
|
91
91
|
self.assertEqual(input_data.shape, output_data.shape)
|
92
92
|
self.assertTrue(np.any(output_data == 0))
|
@@ -13,157 +13,155 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from absl.testing import absltest
|
19
17
|
from absl.testing import parameterized
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class Test_Activation(parameterized.TestCase):
|
25
23
|
|
26
24
|
def test_Threshold(self):
|
27
|
-
threshold_layer =
|
28
|
-
input =
|
25
|
+
threshold_layer = brainstate.nn.Threshold(5, 20)
|
26
|
+
input = brainstate.random.randn(2)
|
29
27
|
output = threshold_layer(input)
|
30
28
|
|
31
29
|
def test_ReLU(self):
|
32
|
-
ReLU_layer =
|
33
|
-
input =
|
30
|
+
ReLU_layer = brainstate.nn.ReLU()
|
31
|
+
input = brainstate.random.randn(2)
|
34
32
|
output = ReLU_layer(input)
|
35
33
|
|
36
34
|
def test_RReLU(self):
|
37
|
-
RReLU_layer =
|
38
|
-
input =
|
35
|
+
RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
|
36
|
+
input = brainstate.random.randn(2)
|
39
37
|
output = RReLU_layer(input)
|
40
38
|
|
41
39
|
def test_Hardtanh(self):
|
42
|
-
Hardtanh_layer =
|
43
|
-
input =
|
40
|
+
Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
|
41
|
+
input = brainstate.random.randn(2)
|
44
42
|
output = Hardtanh_layer(input)
|
45
43
|
|
46
44
|
def test_ReLU6(self):
|
47
|
-
ReLU6_layer =
|
48
|
-
input =
|
45
|
+
ReLU6_layer = brainstate.nn.ReLU6()
|
46
|
+
input = brainstate.random.randn(2)
|
49
47
|
output = ReLU6_layer(input)
|
50
48
|
|
51
49
|
def test_Sigmoid(self):
|
52
|
-
Sigmoid_layer =
|
53
|
-
input =
|
50
|
+
Sigmoid_layer = brainstate.nn.Sigmoid()
|
51
|
+
input = brainstate.random.randn(2)
|
54
52
|
output = Sigmoid_layer(input)
|
55
53
|
|
56
54
|
def test_Hardsigmoid(self):
|
57
|
-
Hardsigmoid_layer =
|
58
|
-
input =
|
55
|
+
Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
|
56
|
+
input = brainstate.random.randn(2)
|
59
57
|
output = Hardsigmoid_layer(input)
|
60
58
|
|
61
59
|
def test_Tanh(self):
|
62
|
-
Tanh_layer =
|
63
|
-
input =
|
60
|
+
Tanh_layer = brainstate.nn.Tanh()
|
61
|
+
input = brainstate.random.randn(2)
|
64
62
|
output = Tanh_layer(input)
|
65
63
|
|
66
64
|
def test_SiLU(self):
|
67
|
-
SiLU_layer =
|
68
|
-
input =
|
65
|
+
SiLU_layer = brainstate.nn.SiLU()
|
66
|
+
input = brainstate.random.randn(2)
|
69
67
|
output = SiLU_layer(input)
|
70
68
|
|
71
69
|
def test_Mish(self):
|
72
|
-
Mish_layer =
|
73
|
-
input =
|
70
|
+
Mish_layer = brainstate.nn.Mish()
|
71
|
+
input = brainstate.random.randn(2)
|
74
72
|
output = Mish_layer(input)
|
75
73
|
|
76
74
|
def test_Hardswish(self):
|
77
|
-
Hardswish_layer =
|
78
|
-
input =
|
75
|
+
Hardswish_layer = brainstate.nn.Hardswish()
|
76
|
+
input = brainstate.random.randn(2)
|
79
77
|
output = Hardswish_layer(input)
|
80
78
|
|
81
79
|
def test_ELU(self):
|
82
|
-
ELU_layer =
|
83
|
-
input =
|
80
|
+
ELU_layer = brainstate.nn.ELU(alpha=0.5, )
|
81
|
+
input = brainstate.random.randn(2)
|
84
82
|
output = ELU_layer(input)
|
85
83
|
|
86
84
|
def test_CELU(self):
|
87
|
-
CELU_layer =
|
88
|
-
input =
|
85
|
+
CELU_layer = brainstate.nn.CELU(alpha=0.5, )
|
86
|
+
input = brainstate.random.randn(2)
|
89
87
|
output = CELU_layer(input)
|
90
88
|
|
91
89
|
def test_SELU(self):
|
92
|
-
SELU_layer =
|
93
|
-
input =
|
90
|
+
SELU_layer = brainstate.nn.SELU()
|
91
|
+
input = brainstate.random.randn(2)
|
94
92
|
output = SELU_layer(input)
|
95
93
|
|
96
94
|
def test_GLU(self):
|
97
|
-
GLU_layer =
|
98
|
-
input =
|
95
|
+
GLU_layer = brainstate.nn.GLU()
|
96
|
+
input = brainstate.random.randn(4, 2)
|
99
97
|
output = GLU_layer(input)
|
100
98
|
|
101
99
|
@parameterized.product(
|
102
100
|
approximate=['tanh', 'none']
|
103
101
|
)
|
104
102
|
def test_GELU(self, approximate):
|
105
|
-
GELU_layer =
|
106
|
-
input =
|
103
|
+
GELU_layer = brainstate.nn.GELU()
|
104
|
+
input = brainstate.random.randn(2)
|
107
105
|
output = GELU_layer(input)
|
108
106
|
|
109
107
|
def test_Hardshrink(self):
|
110
|
-
Hardshrink_layer =
|
111
|
-
input =
|
108
|
+
Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
|
109
|
+
input = brainstate.random.randn(2)
|
112
110
|
output = Hardshrink_layer(input)
|
113
111
|
|
114
112
|
def test_LeakyReLU(self):
|
115
|
-
LeakyReLU_layer =
|
116
|
-
input =
|
113
|
+
LeakyReLU_layer = brainstate.nn.LeakyReLU()
|
114
|
+
input = brainstate.random.randn(2)
|
117
115
|
output = LeakyReLU_layer(input)
|
118
116
|
|
119
117
|
def test_LogSigmoid(self):
|
120
|
-
LogSigmoid_layer =
|
121
|
-
input =
|
118
|
+
LogSigmoid_layer = brainstate.nn.LogSigmoid()
|
119
|
+
input = brainstate.random.randn(2)
|
122
120
|
output = LogSigmoid_layer(input)
|
123
121
|
|
124
122
|
def test_Softplus(self):
|
125
|
-
Softplus_layer =
|
126
|
-
input =
|
123
|
+
Softplus_layer = brainstate.nn.Softplus()
|
124
|
+
input = brainstate.random.randn(2)
|
127
125
|
output = Softplus_layer(input)
|
128
126
|
|
129
127
|
def test_Softshrink(self):
|
130
|
-
Softshrink_layer =
|
131
|
-
input =
|
128
|
+
Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
|
129
|
+
input = brainstate.random.randn(2)
|
132
130
|
output = Softshrink_layer(input)
|
133
131
|
|
134
132
|
def test_PReLU(self):
|
135
|
-
PReLU_layer =
|
136
|
-
input =
|
133
|
+
PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
|
134
|
+
input = brainstate.random.randn(2)
|
137
135
|
output = PReLU_layer(input)
|
138
136
|
|
139
137
|
def test_Softsign(self):
|
140
|
-
Softsign_layer =
|
141
|
-
input =
|
138
|
+
Softsign_layer = brainstate.nn.Softsign()
|
139
|
+
input = brainstate.random.randn(2)
|
142
140
|
output = Softsign_layer(input)
|
143
141
|
|
144
142
|
def test_Tanhshrink(self):
|
145
|
-
Tanhshrink_layer =
|
146
|
-
input =
|
143
|
+
Tanhshrink_layer = brainstate.nn.Tanhshrink()
|
144
|
+
input = brainstate.random.randn(2)
|
147
145
|
output = Tanhshrink_layer(input)
|
148
146
|
|
149
147
|
def test_Softmin(self):
|
150
|
-
Softmin_layer =
|
151
|
-
input =
|
148
|
+
Softmin_layer = brainstate.nn.Softmin(dim=2)
|
149
|
+
input = brainstate.random.randn(2, 3, 4)
|
152
150
|
output = Softmin_layer(input)
|
153
151
|
|
154
152
|
def test_Softmax(self):
|
155
|
-
Softmax_layer =
|
156
|
-
input =
|
153
|
+
Softmax_layer = brainstate.nn.Softmax(dim=2)
|
154
|
+
input = brainstate.random.randn(2, 3, 4)
|
157
155
|
output = Softmax_layer(input)
|
158
156
|
|
159
157
|
def test_Softmax2d(self):
|
160
|
-
Softmax2d_layer =
|
161
|
-
input =
|
158
|
+
Softmax2d_layer = brainstate.nn.Softmax2d()
|
159
|
+
input = brainstate.random.randn(2, 3, 12, 13)
|
162
160
|
output = Softmax2d_layer(input)
|
163
161
|
|
164
162
|
def test_LogSoftmax(self):
|
165
|
-
LogSoftmax_layer =
|
166
|
-
input =
|
163
|
+
LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
|
164
|
+
input = brainstate.random.randn(2, 3, 4)
|
167
165
|
output = LogSoftmax_layer(input)
|
168
166
|
|
169
167
|
|
brainstate/nn/_exp_euler.py
CHANGED
brainstate/nn/_exp_euler_test.py
CHANGED
@@ -13,13 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import brainunit as u
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestExpEuler(unittest.TestCase):
|
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
|
|
27
26
|
def fun(x, tau):
|
28
27
|
return -x / tau
|
29
28
|
|
30
|
-
with
|
29
|
+
with brainstate.environ.context(dt=0.1):
|
31
30
|
with self.assertRaises(AssertionError):
|
32
|
-
r =
|
31
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
33
32
|
|
34
|
-
with
|
35
|
-
r =
|
33
|
+
with brainstate.environ.context(dt=1. * u.ms):
|
34
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
36
35
|
print(r)
|