brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -15,47 +15,46 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
18
 
20
19
  import unittest
21
20
 
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestModuleGroup(unittest.TestCase):
28
27
  def test_initialization(self):
29
- group = bst.nn.DynamicsGroup()
30
- self.assertIsInstance(group, bst.nn.DynamicsGroup)
28
+ group = brainstate.nn.DynamicsGroup()
29
+ self.assertIsInstance(group, brainstate.nn.DynamicsGroup)
31
30
 
32
31
 
33
32
  class TestProjection(unittest.TestCase):
34
33
  def test_initialization(self):
35
- proj = bst.nn.Projection()
36
- self.assertIsInstance(proj, bst.nn.Projection)
34
+ proj = brainstate.nn.Projection()
35
+ self.assertIsInstance(proj, brainstate.nn.Projection)
37
36
 
38
37
  def test_update_not_implemented(self):
39
- proj = bst.nn.Projection()
38
+ proj = brainstate.nn.Projection()
40
39
  with self.assertRaises(ValueError):
41
40
  proj.update()
42
41
 
43
42
 
44
43
  class TestDynamics(unittest.TestCase):
45
44
  def test_initialization(self):
46
- dyn = bst.nn.Dynamics(in_size=10)
47
- self.assertIsInstance(dyn, bst.nn.Dynamics)
45
+ dyn = brainstate.nn.Dynamics(in_size=10)
46
+ self.assertIsInstance(dyn, brainstate.nn.Dynamics)
48
47
  self.assertEqual(dyn.in_size, (10,))
49
48
  self.assertEqual(dyn.out_size, (10,))
50
49
 
51
50
  def test_size_validation(self):
52
51
  with self.assertRaises(ValueError):
53
- bst.nn.Dynamics(in_size=[])
52
+ brainstate.nn.Dynamics(in_size=[])
54
53
  with self.assertRaises(ValueError):
55
- bst.nn.Dynamics(in_size="invalid")
54
+ brainstate.nn.Dynamics(in_size="invalid")
56
55
 
57
56
  def test_input_handling(self):
58
- dyn = bst.nn.Dynamics(in_size=10)
57
+ dyn = brainstate.nn.Dynamics(in_size=10)
59
58
  dyn.add_current_input("test_current", lambda: np.random.rand(10))
60
59
  dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
61
60
 
@@ -63,15 +62,15 @@ class TestDynamics(unittest.TestCase):
63
62
  self.assertIn("test_delta", dyn.delta_inputs)
64
63
 
65
64
  def test_duplicate_input_key(self):
66
- dyn = bst.nn.Dynamics(in_size=10)
65
+ dyn = brainstate.nn.Dynamics(in_size=10)
67
66
  dyn.add_current_input("test", lambda: np.random.rand(10))
68
67
  with self.assertRaises(ValueError):
69
68
  dyn.add_current_input("test", lambda: np.random.rand(10))
70
69
 
71
70
  def test_varshape(self):
72
- dyn = bst.nn.Dynamics(in_size=(2, 3))
71
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
73
72
  self.assertEqual(dyn.varshape, (2, 3))
74
- dyn = bst.nn.Dynamics(in_size=(2, 3))
73
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
75
74
  self.assertEqual(dyn.varshape, (2, 3))
76
75
 
77
76
 
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from __future__ import annotations
16
15
 
17
16
  from typing import Union, Callable, Optional
18
17
 
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import math
19
17
  import numbers
20
18
  from functools import partial
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import brainunit as u
21
19
  import jax.numpy as jnp
22
20
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import brainunit as u
21
20
  import jax.numpy as jnp
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestSynOutModels(unittest.TestCase):
@@ -35,19 +34,19 @@ class TestSynOutModels(unittest.TestCase):
35
34
  self.V_offset = jnp.array([0.0])
36
35
 
37
36
  def test_COBA(self):
38
- model = bst.nn.COBA(E=self.E)
37
+ model = brainstate.nn.COBA(E=self.E)
39
38
  output = model.update(self.conductance, self.potential)
40
39
  expected_output = self.conductance * (self.E - self.potential)
41
40
  np.testing.assert_array_almost_equal(output, expected_output)
42
41
 
43
42
  def test_CUBA(self):
44
- model = bst.nn.CUBA()
43
+ model = brainstate.nn.CUBA()
45
44
  output = model.update(self.conductance)
46
45
  expected_output = self.conductance * model.scale
47
46
  self.assertTrue(u.math.allclose(output, expected_output))
48
47
 
49
48
  def test_MgBlock(self):
50
- model = bst.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
49
+ model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
51
50
  output = model.update(self.conductance, self.potential)
52
51
  norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
53
52
  expected_output = self.conductance * (self.E - self.potential) / norm
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  from functools import partial
20
18
  from typing import Optional, Sequence
21
19
 
@@ -18,19 +18,19 @@ import unittest
18
18
 
19
19
  import numpy as np
20
20
 
21
- import brainstate as bst
21
+ import brainstate
22
22
 
23
23
 
24
24
  class TestDropout(unittest.TestCase):
25
25
 
26
26
  def test_dropout(self):
27
27
  # Create a Dropout layer with a dropout rate of 0.5
28
- dropout_layer = bst.nn.Dropout(0.5)
28
+ dropout_layer = brainstate.nn.Dropout(0.5)
29
29
 
30
30
  # Input data
31
31
  input_data = np.arange(20)
32
32
 
33
- with bst.environ.context(fit=True):
33
+ with brainstate.environ.context(fit=True):
34
34
  # Apply dropout
35
35
  output_data = dropout_layer(input_data)
36
36
 
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
47
47
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
48
48
 
49
49
  def test_DropoutFixed(self):
50
- dropout_layer = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
50
+ dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
51
51
  dropout_layer.init_state(batch_size=2)
52
52
  input_data = np.random.randn(2, 2, 3)
53
- with bst.environ.context(fit=True):
53
+ with brainstate.environ.context(fit=True):
54
54
  output_data = dropout_layer.update(input_data)
55
55
  self.assertEqual(input_data.shape, output_data.shape)
56
56
  self.assertTrue(np.any(output_data == 0))
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
72
72
  # np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
73
73
 
74
74
  def test_Dropout2d(self):
75
- dropout_layer = bst.nn.Dropout2d(prob=0.5)
75
+ dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
76
76
  input_data = np.random.randn(2, 3, 4, 5)
77
- with bst.environ.context(fit=True):
77
+ with brainstate.environ.context(fit=True):
78
78
  output_data = dropout_layer(input_data)
79
79
  self.assertEqual(input_data.shape, output_data.shape)
80
80
  self.assertTrue(np.any(output_data == 0))
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
84
84
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
85
85
 
86
86
  def test_Dropout3d(self):
87
- dropout_layer = bst.nn.Dropout3d(prob=0.5)
87
+ dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
88
88
  input_data = np.random.randn(2, 3, 4, 5, 6)
89
- with bst.environ.context(fit=True):
89
+ with brainstate.environ.context(fit=True):
90
90
  output_data = dropout_layer(input_data)
91
91
  self.assertEqual(input_data.shape, output_data.shape)
92
92
  self.assertTrue(np.any(output_data == 0))
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  from typing import Optional
21
19
 
22
20
  import brainunit as u
@@ -13,157 +13,155 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from absl.testing import absltest
19
17
  from absl.testing import parameterized
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class Test_Activation(parameterized.TestCase):
25
23
 
26
24
  def test_Threshold(self):
27
- threshold_layer = bst.nn.Threshold(5, 20)
28
- input = bst.random.randn(2)
25
+ threshold_layer = brainstate.nn.Threshold(5, 20)
26
+ input = brainstate.random.randn(2)
29
27
  output = threshold_layer(input)
30
28
 
31
29
  def test_ReLU(self):
32
- ReLU_layer = bst.nn.ReLU()
33
- input = bst.random.randn(2)
30
+ ReLU_layer = brainstate.nn.ReLU()
31
+ input = brainstate.random.randn(2)
34
32
  output = ReLU_layer(input)
35
33
 
36
34
  def test_RReLU(self):
37
- RReLU_layer = bst.nn.RReLU(lower=0, upper=1)
38
- input = bst.random.randn(2)
35
+ RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
36
+ input = brainstate.random.randn(2)
39
37
  output = RReLU_layer(input)
40
38
 
41
39
  def test_Hardtanh(self):
42
- Hardtanh_layer = bst.nn.Hardtanh(min_val=0, max_val=1, )
43
- input = bst.random.randn(2)
40
+ Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
41
+ input = brainstate.random.randn(2)
44
42
  output = Hardtanh_layer(input)
45
43
 
46
44
  def test_ReLU6(self):
47
- ReLU6_layer = bst.nn.ReLU6()
48
- input = bst.random.randn(2)
45
+ ReLU6_layer = brainstate.nn.ReLU6()
46
+ input = brainstate.random.randn(2)
49
47
  output = ReLU6_layer(input)
50
48
 
51
49
  def test_Sigmoid(self):
52
- Sigmoid_layer = bst.nn.Sigmoid()
53
- input = bst.random.randn(2)
50
+ Sigmoid_layer = brainstate.nn.Sigmoid()
51
+ input = brainstate.random.randn(2)
54
52
  output = Sigmoid_layer(input)
55
53
 
56
54
  def test_Hardsigmoid(self):
57
- Hardsigmoid_layer = bst.nn.Hardsigmoid()
58
- input = bst.random.randn(2)
55
+ Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
56
+ input = brainstate.random.randn(2)
59
57
  output = Hardsigmoid_layer(input)
60
58
 
61
59
  def test_Tanh(self):
62
- Tanh_layer = bst.nn.Tanh()
63
- input = bst.random.randn(2)
60
+ Tanh_layer = brainstate.nn.Tanh()
61
+ input = brainstate.random.randn(2)
64
62
  output = Tanh_layer(input)
65
63
 
66
64
  def test_SiLU(self):
67
- SiLU_layer = bst.nn.SiLU()
68
- input = bst.random.randn(2)
65
+ SiLU_layer = brainstate.nn.SiLU()
66
+ input = brainstate.random.randn(2)
69
67
  output = SiLU_layer(input)
70
68
 
71
69
  def test_Mish(self):
72
- Mish_layer = bst.nn.Mish()
73
- input = bst.random.randn(2)
70
+ Mish_layer = brainstate.nn.Mish()
71
+ input = brainstate.random.randn(2)
74
72
  output = Mish_layer(input)
75
73
 
76
74
  def test_Hardswish(self):
77
- Hardswish_layer = bst.nn.Hardswish()
78
- input = bst.random.randn(2)
75
+ Hardswish_layer = brainstate.nn.Hardswish()
76
+ input = brainstate.random.randn(2)
79
77
  output = Hardswish_layer(input)
80
78
 
81
79
  def test_ELU(self):
82
- ELU_layer = bst.nn.ELU(alpha=0.5, )
83
- input = bst.random.randn(2)
80
+ ELU_layer = brainstate.nn.ELU(alpha=0.5, )
81
+ input = brainstate.random.randn(2)
84
82
  output = ELU_layer(input)
85
83
 
86
84
  def test_CELU(self):
87
- CELU_layer = bst.nn.CELU(alpha=0.5, )
88
- input = bst.random.randn(2)
85
+ CELU_layer = brainstate.nn.CELU(alpha=0.5, )
86
+ input = brainstate.random.randn(2)
89
87
  output = CELU_layer(input)
90
88
 
91
89
  def test_SELU(self):
92
- SELU_layer = bst.nn.SELU()
93
- input = bst.random.randn(2)
90
+ SELU_layer = brainstate.nn.SELU()
91
+ input = brainstate.random.randn(2)
94
92
  output = SELU_layer(input)
95
93
 
96
94
  def test_GLU(self):
97
- GLU_layer = bst.nn.GLU()
98
- input = bst.random.randn(4, 2)
95
+ GLU_layer = brainstate.nn.GLU()
96
+ input = brainstate.random.randn(4, 2)
99
97
  output = GLU_layer(input)
100
98
 
101
99
  @parameterized.product(
102
100
  approximate=['tanh', 'none']
103
101
  )
104
102
  def test_GELU(self, approximate):
105
- GELU_layer = bst.nn.GELU()
106
- input = bst.random.randn(2)
103
+ GELU_layer = brainstate.nn.GELU()
104
+ input = brainstate.random.randn(2)
107
105
  output = GELU_layer(input)
108
106
 
109
107
  def test_Hardshrink(self):
110
- Hardshrink_layer = bst.nn.Hardshrink(lambd=1)
111
- input = bst.random.randn(2)
108
+ Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
109
+ input = brainstate.random.randn(2)
112
110
  output = Hardshrink_layer(input)
113
111
 
114
112
  def test_LeakyReLU(self):
115
- LeakyReLU_layer = bst.nn.LeakyReLU()
116
- input = bst.random.randn(2)
113
+ LeakyReLU_layer = brainstate.nn.LeakyReLU()
114
+ input = brainstate.random.randn(2)
117
115
  output = LeakyReLU_layer(input)
118
116
 
119
117
  def test_LogSigmoid(self):
120
- LogSigmoid_layer = bst.nn.LogSigmoid()
121
- input = bst.random.randn(2)
118
+ LogSigmoid_layer = brainstate.nn.LogSigmoid()
119
+ input = brainstate.random.randn(2)
122
120
  output = LogSigmoid_layer(input)
123
121
 
124
122
  def test_Softplus(self):
125
- Softplus_layer = bst.nn.Softplus()
126
- input = bst.random.randn(2)
123
+ Softplus_layer = brainstate.nn.Softplus()
124
+ input = brainstate.random.randn(2)
127
125
  output = Softplus_layer(input)
128
126
 
129
127
  def test_Softshrink(self):
130
- Softshrink_layer = bst.nn.Softshrink(lambd=1)
131
- input = bst.random.randn(2)
128
+ Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
129
+ input = brainstate.random.randn(2)
132
130
  output = Softshrink_layer(input)
133
131
 
134
132
  def test_PReLU(self):
135
- PReLU_layer = bst.nn.PReLU(num_parameters=2, init=0.5)
136
- input = bst.random.randn(2)
133
+ PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
134
+ input = brainstate.random.randn(2)
137
135
  output = PReLU_layer(input)
138
136
 
139
137
  def test_Softsign(self):
140
- Softsign_layer = bst.nn.Softsign()
141
- input = bst.random.randn(2)
138
+ Softsign_layer = brainstate.nn.Softsign()
139
+ input = brainstate.random.randn(2)
142
140
  output = Softsign_layer(input)
143
141
 
144
142
  def test_Tanhshrink(self):
145
- Tanhshrink_layer = bst.nn.Tanhshrink()
146
- input = bst.random.randn(2)
143
+ Tanhshrink_layer = brainstate.nn.Tanhshrink()
144
+ input = brainstate.random.randn(2)
147
145
  output = Tanhshrink_layer(input)
148
146
 
149
147
  def test_Softmin(self):
150
- Softmin_layer = bst.nn.Softmin(dim=2)
151
- input = bst.random.randn(2, 3, 4)
148
+ Softmin_layer = brainstate.nn.Softmin(dim=2)
149
+ input = brainstate.random.randn(2, 3, 4)
152
150
  output = Softmin_layer(input)
153
151
 
154
152
  def test_Softmax(self):
155
- Softmax_layer = bst.nn.Softmax(dim=2)
156
- input = bst.random.randn(2, 3, 4)
153
+ Softmax_layer = brainstate.nn.Softmax(dim=2)
154
+ input = brainstate.random.randn(2, 3, 4)
157
155
  output = Softmax_layer(input)
158
156
 
159
157
  def test_Softmax2d(self):
160
- Softmax2d_layer = bst.nn.Softmax2d()
161
- input = bst.random.randn(2, 3, 12, 13)
158
+ Softmax2d_layer = brainstate.nn.Softmax2d()
159
+ input = brainstate.random.randn(2, 3, 12, 13)
162
160
  output = Softmax2d_layer(input)
163
161
 
164
162
  def test_LogSoftmax(self):
165
- LogSoftmax_layer = bst.nn.LogSoftmax(dim=2)
166
- input = bst.random.randn(2, 3, 4)
163
+ LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
164
+ input = brainstate.random.randn(2, 3, 4)
167
165
  output = LogSoftmax_layer(input)
168
166
 
169
167
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  from typing import Union, Callable, Optional
19
18
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax.numpy
19
18
  import jax.numpy as jnp
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from typing import Union, Callable, Optional
19
17
 
20
18
  import brainunit as u
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax
19
18
  import jax.numpy as jnp
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  from typing import Callable
20
18
 
21
19
  import brainunit as u
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import brainunit as u
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestExpEuler(unittest.TestCase):
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
27
26
  def fun(x, tau):
28
27
  return -x / tau
29
28
 
30
- with bst.environ.context(dt=0.1):
29
+ with brainstate.environ.context(dt=0.1):
31
30
  with self.assertRaises(AssertionError):
32
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
31
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
32
 
34
- with bst.environ.context(dt=1. * u.ms):
35
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
+ with brainstate.environ.context(dt=1. * u.ms):
34
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
36
35
  print(r)
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import collections.abc
21
19
  from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
22
20