brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -14,30 +14,29 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import unittest
20
19
 
21
- import brainstate as bst
20
+ import brainstate
22
21
 
23
22
 
24
23
  class TestNormalInit(unittest.TestCase):
25
24
 
26
25
  def test_normal_init1(self):
27
- init = bst.init.Normal()
26
+ init = brainstate.init.Normal()
28
27
  for size in [(100,), (10, 20), (10, 20, 30)]:
29
28
  weights = init(size)
30
29
  assert weights.shape == size
31
30
 
32
31
  def test_normal_init2(self):
33
- init = bst.init.Normal(scale=0.5)
32
+ init = brainstate.init.Normal(scale=0.5)
34
33
  for size in [(100,), (10, 20)]:
35
34
  weights = init(size)
36
35
  assert weights.shape == size
37
36
 
38
37
  def test_normal_init3(self):
39
- init1 = bst.init.Normal(scale=0.5, seed=10)
40
- init2 = bst.init.Normal(scale=0.5, seed=10)
38
+ init1 = brainstate.init.Normal(scale=0.5, seed=10)
39
+ init2 = brainstate.init.Normal(scale=0.5, seed=10)
41
40
  size = (10,)
42
41
  weights1 = init1(size)
43
42
  weights2 = init2(size)
@@ -47,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
47
46
 
48
47
  class TestUniformInit(unittest.TestCase):
49
48
  def test_uniform_init1(self):
50
- init = bst.init.Normal()
49
+ init = brainstate.init.Normal()
51
50
  for size in [(100,), (10, 20), (10, 20, 30)]:
52
51
  weights = init(size)
53
52
  assert weights.shape == size
54
53
 
55
54
  def test_uniform_init2(self):
56
- init = bst.init.Uniform(min_val=10, max_val=20)
55
+ init = brainstate.init.Uniform(min_val=10, max_val=20)
57
56
  for size in [(100,), (10, 20)]:
58
57
  weights = init(size)
59
58
  assert weights.shape == size
@@ -61,20 +60,20 @@ class TestUniformInit(unittest.TestCase):
61
60
 
62
61
  class TestVarianceScaling(unittest.TestCase):
63
62
  def test_var_scaling1(self):
64
- init = bst.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
63
+ init = brainstate.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
65
64
  for size in [(10, 20), (10, 20, 30)]:
66
65
  weights = init(size)
67
66
  assert weights.shape == size
68
67
 
69
68
  def test_var_scaling2(self):
70
- init = bst.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
69
+ init = brainstate.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
71
70
  for size in [(10, 20), (10, 20, 30)]:
72
71
  weights = init(size)
73
72
  assert weights.shape == size
74
73
 
75
74
  def test_var_scaling3(self):
76
- init = bst.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
77
- distribution='uniform')
75
+ init = brainstate.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
76
+ distribution='uniform')
78
77
  for size in [(10, 20), (10, 20, 30)]:
79
78
  weights = init(size)
80
79
  assert weights.shape == size
@@ -82,7 +81,7 @@ class TestVarianceScaling(unittest.TestCase):
82
81
 
83
82
  class TestKaimingUniformUnit(unittest.TestCase):
84
83
  def test_kaiming_uniform_init(self):
85
- init = bst.init.KaimingUniform()
84
+ init = brainstate.init.KaimingUniform()
86
85
  for size in [(10, 20), (10, 20, 30)]:
87
86
  weights = init(size)
88
87
  assert weights.shape == size
@@ -90,7 +89,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
90
89
 
91
90
  class TestKaimingNormalUnit(unittest.TestCase):
92
91
  def test_kaiming_normal_init(self):
93
- init = bst.init.KaimingNormal()
92
+ init = brainstate.init.KaimingNormal()
94
93
  for size in [(10, 20), (10, 20, 30)]:
95
94
  weights = init(size)
96
95
  assert weights.shape == size
@@ -98,7 +97,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
98
97
 
99
98
  class TestXavierUniformUnit(unittest.TestCase):
100
99
  def test_xavier_uniform_init(self):
101
- init = bst.init.XavierUniform()
100
+ init = brainstate.init.XavierUniform()
102
101
  for size in [(10, 20), (10, 20, 30)]:
103
102
  weights = init(size)
104
103
  assert weights.shape == size
@@ -106,7 +105,7 @@ class TestXavierUniformUnit(unittest.TestCase):
106
105
 
107
106
  class TestXavierNormalUnit(unittest.TestCase):
108
107
  def test_xavier_normal_init(self):
109
- init = bst.init.XavierNormal()
108
+ init = brainstate.init.XavierNormal()
110
109
  for size in [(10, 20), (10, 20, 30)]:
111
110
  weights = init(size)
112
111
  assert weights.shape == size
@@ -114,7 +113,7 @@ class TestXavierNormalUnit(unittest.TestCase):
114
113
 
115
114
  class TestLecunUniformUnit(unittest.TestCase):
116
115
  def test_lecun_uniform_init(self):
117
- init = bst.init.LecunUniform()
116
+ init = brainstate.init.LecunUniform()
118
117
  for size in [(10, 20), (10, 20, 30)]:
119
118
  weights = init(size)
120
119
  assert weights.shape == size
@@ -122,7 +121,7 @@ class TestLecunUniformUnit(unittest.TestCase):
122
121
 
123
122
  class TestLecunNormalUnit(unittest.TestCase):
124
123
  def test_lecun_normal_init(self):
125
- init = bst.init.LecunNormal()
124
+ init = brainstate.init.LecunNormal()
126
125
  for size in [(10, 20), (10, 20, 30)]:
127
126
  weights = init(size)
128
127
  assert weights.shape == size
@@ -130,13 +129,13 @@ class TestLecunNormalUnit(unittest.TestCase):
130
129
 
131
130
  class TestOrthogonalUnit(unittest.TestCase):
132
131
  def test_orthogonal_init1(self):
133
- init = bst.init.Orthogonal()
132
+ init = brainstate.init.Orthogonal()
134
133
  for size in [(20, 20), (10, 20, 30)]:
135
134
  weights = init(size)
136
135
  assert weights.shape == size
137
136
 
138
137
  def test_orthogonal_init2(self):
139
- init = bst.init.Orthogonal(scale=2., axis=0)
138
+ init = brainstate.init.Orthogonal(scale=2., axis=0)
140
139
  for size in [(10, 20), (10, 20, 30)]:
141
140
  weights = init(size)
142
141
  assert weights.shape == size
@@ -144,7 +143,7 @@ class TestOrthogonalUnit(unittest.TestCase):
144
143
 
145
144
  class TestDeltaOrthogonalUnit(unittest.TestCase):
146
145
  def test_delta_orthogonal_init1(self):
147
- init = bst.init.DeltaOrthogonal()
146
+ init = brainstate.init.DeltaOrthogonal()
148
147
  for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]:
149
148
  weights = init(size)
150
149
  assert weights.shape == size
@@ -14,16 +14,15 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import unittest
20
19
 
21
- import brainstate as bst
20
+ import brainstate
22
21
 
23
22
 
24
23
  class TestZeroInit(unittest.TestCase):
25
24
  def test_zero_init(self):
26
- init = bst.init.ZeroInit()
25
+ init = brainstate.init.ZeroInit()
27
26
  for size in [(100,), (10, 20), (10, 20, 30)]:
28
27
  weights = init(size)
29
28
  assert weights.shape == size
@@ -33,7 +32,7 @@ class TestOneInit(unittest.TestCase):
33
32
  def test_one_init(self):
34
33
  for size in [(100,), (10, 20), (10, 20, 30)]:
35
34
  for value in [0., 1., -1.]:
36
- init = bst.init.Constant(value=value)
35
+ init = brainstate.init.Constant(value=value)
37
36
  weights = init(size)
38
37
  assert weights.shape == size
39
38
  assert (weights == value).all()
@@ -43,7 +42,7 @@ class TestIdentityInit(unittest.TestCase):
43
42
  def test_identity_init(self):
44
43
  for size in [(100,), (10, 20)]:
45
44
  for value in [0., 1., -1.]:
46
- init = bst.init.Identity(value=value)
45
+ init = brainstate.init.Identity(value=value)
47
46
  weights = init(size)
48
47
  if len(size) == 1:
49
48
  assert weights.shape == (size[0], size[0])
brainstate/mixin.py CHANGED
@@ -132,6 +132,7 @@ class AlignPost(Mixin):
132
132
  raise NotImplementedError
133
133
 
134
134
 
135
+
135
136
  class BindCondData(Mixin):
136
137
  """Bind temporary conductance data.
137
138
 
@@ -147,7 +148,6 @@ class BindCondData(Mixin):
147
148
 
148
149
 
149
150
  class UpdateReturn(Mixin):
150
-
151
151
  def update_return(self) -> PyTree:
152
152
  """
153
153
  The update function return of the model.
@@ -157,19 +157,6 @@ class UpdateReturn(Mixin):
157
157
  """
158
158
  raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
159
159
 
160
- def update_return_info(self) -> PyTree:
161
- """
162
- The update return information of the model.
163
-
164
- It should be a pytree, with each element as a ``jax.Array``.
165
-
166
- .. note::
167
- Should not include the batch axis and batch in_size.
168
- These information will be inferred from the ``mode`` attribute.
169
-
170
- """
171
- raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
172
-
173
160
 
174
161
  class _MetaUnionType(type):
175
162
  def __new__(cls, name, bases, dct):
brainstate/nn/__init__.py CHANGED
@@ -19,46 +19,110 @@ from ._collective_ops import *
19
19
  from ._collective_ops import __all__ as collective_ops_all
20
20
  from ._common import *
21
21
  from ._common import __all__ as common_all
22
- from ._dyn_impl import *
23
- from ._dyn_impl import __all__ as dyn_impl_all
22
+ from ._conv import *
23
+ from ._conv import __all__ as conv_all
24
+ from ._delay import *
25
+ from ._delay import __all__ as state_delay_all
26
+ from ._dropout import *
27
+ from ._dropout import __all__ as dropout_all
24
28
  from ._dynamics import *
25
- from ._dynamics import __all__ as dynamics_all
29
+ from ._dynamics import __all__ as dyn_all
26
30
  from ._elementwise import *
27
31
  from ._elementwise import __all__ as elementwise_all
28
- from ._event import *
29
- from ._event import __all__ as _event_all
32
+ from ._embedding import *
33
+ from ._embedding import __all__ as embed_all
30
34
  from ._exp_euler import *
31
35
  from ._exp_euler import __all__ as exp_euler_all
32
- from ._interaction import *
33
- from ._interaction import __all__ as interaction_all
36
+ from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
37
+ from ._inputs import *
38
+ from ._inputs import __all__ as inputs_all
39
+ from ._linear import *
40
+ from ._linear import __all__ as linear_all
41
+ from ._linear_mv import EventLinear
42
+ from ._ltp import *
43
+ from ._ltp import __all__ as ltp_all
34
44
  from ._module import *
35
45
  from ._module import __all__ as module_all
46
+ from ._neuron import *
47
+ from ._neuron import __all__ as dyn_neuron_all
48
+ from ._normalizations import *
49
+ from ._normalizations import __all__ as normalizations_all
50
+ from ._poolings import *
51
+ from ._poolings import __all__ as poolings_all
52
+ from ._projection import *
53
+ from ._projection import __all__ as projection_all
54
+ from ._rate_rnns import *
55
+ from ._rate_rnns import __all__ as rate_rnns
56
+ from ._readout import *
57
+ from ._readout import __all__ as readout_all
58
+ from ._stp import *
59
+ from ._stp import __all__ as stp_all
60
+ from ._synapse import *
61
+ from ._synapse import __all__ as dyn_synapse_all
62
+ from ._synaptic_projection import *
63
+ from ._synaptic_projection import __all__ as _syn_proj_all
64
+ from ._synouts import *
65
+ from ._synouts import __all__ as synouts_all
36
66
  from ._utils import *
37
67
  from ._utils import __all__ as utils_all
38
68
 
39
69
  __all__ = (
40
- ['metrics']
70
+ [
71
+ 'metrics',
72
+ 'EventLinear',
73
+ 'EventFixedProb',
74
+ 'EventFixedNumConn',
75
+ ]
41
76
  + collective_ops_all
42
77
  + common_all
43
- + dyn_impl_all
44
- + dynamics_all
45
78
  + elementwise_all
46
79
  + module_all
47
80
  + exp_euler_all
48
- + interaction_all
49
81
  + utils_all
50
- + _event_all
82
+ + dyn_all
83
+ + projection_all
84
+ + state_delay_all
85
+ + synouts_all
86
+ + conv_all
87
+ + linear_all
88
+ + normalizations_all
89
+ + poolings_all
90
+ + embed_all
91
+ + dropout_all
92
+ + elementwise_all
93
+ + dyn_neuron_all
94
+ + dyn_synapse_all
95
+ + inputs_all
96
+ + rate_rnns
97
+ + readout_all
98
+ + stp_all
99
+ + ltp_all
100
+ + _syn_proj_all
51
101
  )
52
102
 
53
103
  del (
54
104
  collective_ops_all,
55
105
  common_all,
56
- dyn_impl_all,
57
- dynamics_all,
58
- elementwise_all,
59
106
  module_all,
60
107
  exp_euler_all,
61
- interaction_all,
62
108
  utils_all,
63
- _event_all,
109
+ dyn_all,
110
+ projection_all,
111
+ state_delay_all,
112
+ synouts_all,
113
+ conv_all,
114
+ linear_all,
115
+ normalizations_all,
116
+ poolings_all,
117
+ embed_all,
118
+ dropout_all,
119
+ elementwise_all,
120
+ dyn_neuron_all,
121
+ dyn_synapse_all,
122
+ inputs_all,
123
+ readout_all,
124
+ rate_rnns,
125
+ stp_all,
126
+ ltp_all,
127
+ _syn_proj_all,
64
128
  )
@@ -16,21 +16,21 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- import brainstate as bst
19
+ import brainstate
20
20
 
21
21
 
22
22
  class Test_vmap_init_all_states:
23
23
 
24
24
  def test_vmap_init_all_states(self):
25
- gru = bst.nn.GRUCell(1, 2)
26
- bst.nn.vmap_init_all_states(gru, axis_size=10)
25
+ gru = brainstate.nn.GRUCell(1, 2)
26
+ brainstate.nn.vmap_init_all_states(gru, axis_size=10)
27
27
  print(gru)
28
28
 
29
29
  def test_vmap_init_all_states_v2(self):
30
- @bst.compile.jit
30
+ @brainstate.compile.jit
31
31
  def init():
32
- gru = bst.nn.GRUCell(1, 2)
33
- bst.nn.vmap_init_all_states(gru, axis_size=10)
32
+ gru = brainstate.nn.GRUCell(1, 2)
33
+ brainstate.nn.vmap_init_all_states(gru, axis_size=10)
34
34
  print(gru)
35
35
 
36
36
  init()
@@ -38,6 +38,6 @@ class Test_vmap_init_all_states:
38
38
 
39
39
  class Test_init_all_states:
40
40
  def test_init_all_states(self):
41
- gru = bst.nn.GRUCell(1, 2)
42
- bst.nn.init_all_states(gru, batch_size=10)
41
+ gru = brainstate.nn.GRUCell(1, 2)
42
+ brainstate.nn.init_all_states(gru, batch_size=10)
43
43
  print(gru)
@@ -23,8 +23,8 @@ import jax.numpy as jnp
23
23
 
24
24
  from brainstate import init, functional
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import Module
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import Module
28
28
 
29
29
  T = TypeVar('T')
30
30
 
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax.numpy as jnp
6
4
  import pytest
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
 
12
10
 
13
11
  class TestConv(parameterized.TestCase):
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
19
17
  img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
20
18
  img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
21
19
 
22
- net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
23
- stride=(2, 1), padding='VALID', groups=4)
20
+ net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
21
+ stride=(2, 1), padding='VALID', groups=4)
24
22
  out = net(img)
25
23
  print("out shape: ", out.shape)
26
24
  self.assertEqual(out.shape, (2, 99, 196, 32))
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
30
28
  # plt.show()
31
29
 
32
30
  def test_conv1D(self):
33
- model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
31
+ model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
34
32
  input = jnp.ones((2, 5, 3))
35
33
  out = model(input)
36
34
  print("out shape: ", out.shape)
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
41
39
  # plt.show()
42
40
 
43
41
  def test_conv2D(self):
44
- model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
42
+ model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
45
43
  input = jnp.ones((2, 5, 5, 3))
46
44
 
47
45
  out = model(input)
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
49
47
  self.assertEqual(out.shape, (2, 5, 5, 32))
50
48
 
51
49
  def test_conv3D(self):
52
- model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
50
+ model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
53
51
  input = jnp.ones((2, 5, 5, 5, 3))
54
52
  out = model(input)
55
53
  print("out shape: ", out.shape)
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
62
60
 
63
61
  x = jnp.ones((1, 8, 3))
64
62
  for use_bias in [True, False]:
65
- conv_transpose_module = bst.nn.ConvTranspose1d(
63
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
66
64
  in_channels=3,
67
65
  out_channels=4,
68
66
  kernel_size=(3,),
69
67
  padding='VALID',
70
- w_initializer=bst.init.Constant(1.),
71
- b_initializer=bst.init.Constant(1.) if use_bias else None,
68
+ w_initializer=brainstate.init.Constant(1.),
69
+ b_initializer=brainstate.init.Constant(1.) if use_bias else None,
72
70
  )
73
71
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
74
72
  y = conv_transpose_module(x)
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
91
89
 
92
90
  x = jnp.ones((1, 8, 3))
93
91
  m = jnp.tril(jnp.ones((3, 3, 4)))
94
- conv_transpose_module = bst.nn.ConvTranspose1d(
92
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
95
93
  in_channels=3,
96
94
  out_channels=4,
97
95
  kernel_size=(3,),
98
96
  padding='VALID',
99
97
  mask=m,
100
- w_initializer=bst.init.Constant(),
101
- b_initializer=bst.init.Constant(),
98
+ w_initializer=brainstate.init.Constant(),
99
+ b_initializer=brainstate.init.Constant(),
102
100
  )
103
101
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
104
102
  y = conv_transpose_module(x)
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
119
117
 
120
118
  data = jnp.ones([1, 3, 1])
121
119
  for use_bias in [True, False]:
122
- net = bst.nn.ConvTranspose1d(
120
+ net = brainstate.nn.ConvTranspose1d(
123
121
  in_channels=1,
124
122
  out_channels=1,
125
123
  kernel_size=3,
126
124
  stride=1,
127
125
  padding="SAME",
128
- w_initializer=bst.init.Constant(),
129
- b_initializer=bst.init.Constant() if use_bias else None,
126
+ w_initializer=brainstate.init.Constant(),
127
+ b_initializer=brainstate.init.Constant() if use_bias else None,
130
128
  )
131
129
  out = net(data)
132
130
  self.assertEqual(out.shape, (1, 3, 1))
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
143
141
 
144
142
  x = jnp.ones((1, 8, 8, 3))
145
143
  for use_bias in [True, False]:
146
- conv_transpose_module = bst.nn.ConvTranspose2d(
144
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
147
145
  in_channels=3,
148
146
  out_channels=4,
149
147
  kernel_size=(3, 3),
150
148
  padding='VALID',
151
- w_initializer=bst.init.Constant(),
152
- b_initializer=bst.init.Constant() if use_bias else None,
149
+ w_initializer=brainstate.init.Constant(),
150
+ b_initializer=brainstate.init.Constant() if use_bias else None,
153
151
  )
154
152
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
155
153
  y = conv_transpose_module(x)
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
159
157
 
160
158
  x = jnp.ones((1, 8, 8, 3))
161
159
  m = jnp.tril(jnp.ones((3, 3, 3, 4)))
162
- conv_transpose_module = bst.nn.ConvTranspose2d(
160
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
163
161
  in_channels=3,
164
162
  out_channels=4,
165
163
  kernel_size=(3, 3),
166
164
  padding='VALID',
167
165
  mask=m,
168
- w_initializer=bst.init.Constant(),
166
+ w_initializer=brainstate.init.Constant(),
169
167
  )
170
168
  y = conv_transpose_module(x)
171
169
  print(y.shape)
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
174
172
 
175
173
  x = jnp.ones((1, 8, 8, 3))
176
174
  for use_bias in [True, False]:
177
- conv_transpose_module = bst.nn.ConvTranspose2d(
175
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
178
176
  in_channels=3,
179
177
  out_channels=4,
180
178
  kernel_size=(3, 3),
181
179
  stride=1,
182
180
  padding='SAME',
183
- w_initializer=bst.init.Constant(),
184
- b_initializer=bst.init.Constant() if use_bias else None,
181
+ w_initializer=brainstate.init.Constant(),
182
+ b_initializer=brainstate.init.Constant() if use_bias else None,
185
183
  )
186
184
  y = conv_transpose_module(x)
187
185
  print(y.shape)
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
193
191
 
194
192
  x = jnp.ones((1, 8, 8, 8, 3))
195
193
  for use_bias in [True, False]:
196
- conv_transpose_module = bst.nn.ConvTranspose3d(
194
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
197
195
  in_channels=3,
198
196
  out_channels=4,
199
197
  kernel_size=(3, 3, 3),
200
198
  padding='VALID',
201
- w_initializer=bst.init.Constant(),
202
- b_initializer=bst.init.Constant() if use_bias else None,
199
+ w_initializer=brainstate.init.Constant(),
200
+ b_initializer=brainstate.init.Constant() if use_bias else None,
203
201
  )
204
202
  y = conv_transpose_module(x)
205
203
  print(y.shape)
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
208
206
 
209
207
  x = jnp.ones((1, 8, 8, 8, 3))
210
208
  m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
211
- conv_transpose_module = bst.nn.ConvTranspose3d(
209
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
212
210
  in_channels=3,
213
211
  out_channels=4,
214
212
  kernel_size=(3, 3, 3),
215
213
  padding='VALID',
216
214
  mask=m,
217
- w_initializer=bst.init.Constant(),
215
+ w_initializer=brainstate.init.Constant(),
218
216
  )
219
217
  y = conv_transpose_module(x)
220
218
  print(y.shape)
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
223
221
 
224
222
  x = jnp.ones((1, 8, 8, 8, 3))
225
223
  for use_bias in [True, False]:
226
- conv_transpose_module = bst.nn.ConvTranspose3d(
224
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
227
225
  in_channels=3,
228
226
  out_channels=4,
229
227
  kernel_size=(3, 3, 3),
230
228
  stride=1,
231
229
  padding='SAME',
232
- w_initializer=bst.init.Constant(),
233
- b_initializer=bst.init.Constant() if use_bias else None,
230
+ w_initializer=brainstate.init.Constant(),
231
+ b_initializer=brainstate.init.Constant() if use_bias else None,
234
232
  )
235
233
  y = conv_transpose_module(x)
236
234
  print(y.shape)
@@ -27,9 +27,9 @@ from brainstate import environ
27
27
  from brainstate._state import ShortTermState, State
28
28
  from brainstate.compile import jit_error_if
29
29
  from brainstate.graph import Node
30
- from brainstate.nn._collective_ops import call_order
31
- from brainstate.nn._module import Module
32
30
  from brainstate.typing import ArrayLike, PyTree
31
+ from ._collective_ops import call_order
32
+ from ._module import Module
33
33
 
34
34
  __all__ = [
35
35
  'Delay', 'DelayAccess', 'StateWithDelay',
@@ -135,6 +135,7 @@ class Delay(Module):
135
135
  entries: Optional[Dict] = None, # delay access entry
136
136
  delay_method: Optional[str] = _DELAY_ROTATE, # delay method
137
137
  interp_method: str = _INTERP_LINEAR, # interpolation method
138
+ take_aware_unit: bool = False
138
139
  ):
139
140
  # target information
140
141
  self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
@@ -170,6 +171,9 @@ class Delay(Module):
170
171
  for entry, delay_time in entries.items():
171
172
  self.register_entry(entry, delay_time)
172
173
 
174
+ self.take_aware_unit = take_aware_unit
175
+ self._unit = None
176
+
173
177
  @property
174
178
  def history(self):
175
179
  return self._history
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import random, environ, init
24
24
  from brainstate._state import ShortTermState
25
- from brainstate.nn._module import ElementWiseBlock
26
25
  from brainstate.typing import Size
26
+ from ._module import ElementWiseBlock
27
27
 
28
28
  __all__ = [
29
29
  'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',