brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -14,30 +14,29 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
-
from __future__ import annotations
|
18
17
|
|
19
18
|
import unittest
|
20
19
|
|
21
|
-
import brainstate
|
20
|
+
import brainstate
|
22
21
|
|
23
22
|
|
24
23
|
class TestNormalInit(unittest.TestCase):
|
25
24
|
|
26
25
|
def test_normal_init1(self):
|
27
|
-
init =
|
26
|
+
init = brainstate.init.Normal()
|
28
27
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
29
28
|
weights = init(size)
|
30
29
|
assert weights.shape == size
|
31
30
|
|
32
31
|
def test_normal_init2(self):
|
33
|
-
init =
|
32
|
+
init = brainstate.init.Normal(scale=0.5)
|
34
33
|
for size in [(100,), (10, 20)]:
|
35
34
|
weights = init(size)
|
36
35
|
assert weights.shape == size
|
37
36
|
|
38
37
|
def test_normal_init3(self):
|
39
|
-
init1 =
|
40
|
-
init2 =
|
38
|
+
init1 = brainstate.init.Normal(scale=0.5, seed=10)
|
39
|
+
init2 = brainstate.init.Normal(scale=0.5, seed=10)
|
41
40
|
size = (10,)
|
42
41
|
weights1 = init1(size)
|
43
42
|
weights2 = init2(size)
|
@@ -47,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
|
|
47
46
|
|
48
47
|
class TestUniformInit(unittest.TestCase):
|
49
48
|
def test_uniform_init1(self):
|
50
|
-
init =
|
49
|
+
init = brainstate.init.Normal()
|
51
50
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
52
51
|
weights = init(size)
|
53
52
|
assert weights.shape == size
|
54
53
|
|
55
54
|
def test_uniform_init2(self):
|
56
|
-
init =
|
55
|
+
init = brainstate.init.Uniform(min_val=10, max_val=20)
|
57
56
|
for size in [(100,), (10, 20)]:
|
58
57
|
weights = init(size)
|
59
58
|
assert weights.shape == size
|
@@ -61,20 +60,20 @@ class TestUniformInit(unittest.TestCase):
|
|
61
60
|
|
62
61
|
class TestVarianceScaling(unittest.TestCase):
|
63
62
|
def test_var_scaling1(self):
|
64
|
-
init =
|
63
|
+
init = brainstate.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
|
65
64
|
for size in [(10, 20), (10, 20, 30)]:
|
66
65
|
weights = init(size)
|
67
66
|
assert weights.shape == size
|
68
67
|
|
69
68
|
def test_var_scaling2(self):
|
70
|
-
init =
|
69
|
+
init = brainstate.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
|
71
70
|
for size in [(10, 20), (10, 20, 30)]:
|
72
71
|
weights = init(size)
|
73
72
|
assert weights.shape == size
|
74
73
|
|
75
74
|
def test_var_scaling3(self):
|
76
|
-
init =
|
77
|
-
|
75
|
+
init = brainstate.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
|
76
|
+
distribution='uniform')
|
78
77
|
for size in [(10, 20), (10, 20, 30)]:
|
79
78
|
weights = init(size)
|
80
79
|
assert weights.shape == size
|
@@ -82,7 +81,7 @@ class TestVarianceScaling(unittest.TestCase):
|
|
82
81
|
|
83
82
|
class TestKaimingUniformUnit(unittest.TestCase):
|
84
83
|
def test_kaiming_uniform_init(self):
|
85
|
-
init =
|
84
|
+
init = brainstate.init.KaimingUniform()
|
86
85
|
for size in [(10, 20), (10, 20, 30)]:
|
87
86
|
weights = init(size)
|
88
87
|
assert weights.shape == size
|
@@ -90,7 +89,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
|
|
90
89
|
|
91
90
|
class TestKaimingNormalUnit(unittest.TestCase):
|
92
91
|
def test_kaiming_normal_init(self):
|
93
|
-
init =
|
92
|
+
init = brainstate.init.KaimingNormal()
|
94
93
|
for size in [(10, 20), (10, 20, 30)]:
|
95
94
|
weights = init(size)
|
96
95
|
assert weights.shape == size
|
@@ -98,7 +97,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
|
|
98
97
|
|
99
98
|
class TestXavierUniformUnit(unittest.TestCase):
|
100
99
|
def test_xavier_uniform_init(self):
|
101
|
-
init =
|
100
|
+
init = brainstate.init.XavierUniform()
|
102
101
|
for size in [(10, 20), (10, 20, 30)]:
|
103
102
|
weights = init(size)
|
104
103
|
assert weights.shape == size
|
@@ -106,7 +105,7 @@ class TestXavierUniformUnit(unittest.TestCase):
|
|
106
105
|
|
107
106
|
class TestXavierNormalUnit(unittest.TestCase):
|
108
107
|
def test_xavier_normal_init(self):
|
109
|
-
init =
|
108
|
+
init = brainstate.init.XavierNormal()
|
110
109
|
for size in [(10, 20), (10, 20, 30)]:
|
111
110
|
weights = init(size)
|
112
111
|
assert weights.shape == size
|
@@ -114,7 +113,7 @@ class TestXavierNormalUnit(unittest.TestCase):
|
|
114
113
|
|
115
114
|
class TestLecunUniformUnit(unittest.TestCase):
|
116
115
|
def test_lecun_uniform_init(self):
|
117
|
-
init =
|
116
|
+
init = brainstate.init.LecunUniform()
|
118
117
|
for size in [(10, 20), (10, 20, 30)]:
|
119
118
|
weights = init(size)
|
120
119
|
assert weights.shape == size
|
@@ -122,7 +121,7 @@ class TestLecunUniformUnit(unittest.TestCase):
|
|
122
121
|
|
123
122
|
class TestLecunNormalUnit(unittest.TestCase):
|
124
123
|
def test_lecun_normal_init(self):
|
125
|
-
init =
|
124
|
+
init = brainstate.init.LecunNormal()
|
126
125
|
for size in [(10, 20), (10, 20, 30)]:
|
127
126
|
weights = init(size)
|
128
127
|
assert weights.shape == size
|
@@ -130,13 +129,13 @@ class TestLecunNormalUnit(unittest.TestCase):
|
|
130
129
|
|
131
130
|
class TestOrthogonalUnit(unittest.TestCase):
|
132
131
|
def test_orthogonal_init1(self):
|
133
|
-
init =
|
132
|
+
init = brainstate.init.Orthogonal()
|
134
133
|
for size in [(20, 20), (10, 20, 30)]:
|
135
134
|
weights = init(size)
|
136
135
|
assert weights.shape == size
|
137
136
|
|
138
137
|
def test_orthogonal_init2(self):
|
139
|
-
init =
|
138
|
+
init = brainstate.init.Orthogonal(scale=2., axis=0)
|
140
139
|
for size in [(10, 20), (10, 20, 30)]:
|
141
140
|
weights = init(size)
|
142
141
|
assert weights.shape == size
|
@@ -144,7 +143,7 @@ class TestOrthogonalUnit(unittest.TestCase):
|
|
144
143
|
|
145
144
|
class TestDeltaOrthogonalUnit(unittest.TestCase):
|
146
145
|
def test_delta_orthogonal_init1(self):
|
147
|
-
init =
|
146
|
+
init = brainstate.init.DeltaOrthogonal()
|
148
147
|
for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]:
|
149
148
|
weights = init(size)
|
150
149
|
assert weights.shape == size
|
@@ -14,16 +14,15 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
-
from __future__ import annotations
|
18
17
|
|
19
18
|
import unittest
|
20
19
|
|
21
|
-
import brainstate
|
20
|
+
import brainstate
|
22
21
|
|
23
22
|
|
24
23
|
class TestZeroInit(unittest.TestCase):
|
25
24
|
def test_zero_init(self):
|
26
|
-
init =
|
25
|
+
init = brainstate.init.ZeroInit()
|
27
26
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
28
27
|
weights = init(size)
|
29
28
|
assert weights.shape == size
|
@@ -33,7 +32,7 @@ class TestOneInit(unittest.TestCase):
|
|
33
32
|
def test_one_init(self):
|
34
33
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
35
34
|
for value in [0., 1., -1.]:
|
36
|
-
init =
|
35
|
+
init = brainstate.init.Constant(value=value)
|
37
36
|
weights = init(size)
|
38
37
|
assert weights.shape == size
|
39
38
|
assert (weights == value).all()
|
@@ -43,7 +42,7 @@ class TestIdentityInit(unittest.TestCase):
|
|
43
42
|
def test_identity_init(self):
|
44
43
|
for size in [(100,), (10, 20)]:
|
45
44
|
for value in [0., 1., -1.]:
|
46
|
-
init =
|
45
|
+
init = brainstate.init.Identity(value=value)
|
47
46
|
weights = init(size)
|
48
47
|
if len(size) == 1:
|
49
48
|
assert weights.shape == (size[0], size[0])
|
brainstate/mixin.py
CHANGED
@@ -132,6 +132,7 @@ class AlignPost(Mixin):
|
|
132
132
|
raise NotImplementedError
|
133
133
|
|
134
134
|
|
135
|
+
|
135
136
|
class BindCondData(Mixin):
|
136
137
|
"""Bind temporary conductance data.
|
137
138
|
|
@@ -147,7 +148,6 @@ class BindCondData(Mixin):
|
|
147
148
|
|
148
149
|
|
149
150
|
class UpdateReturn(Mixin):
|
150
|
-
|
151
151
|
def update_return(self) -> PyTree:
|
152
152
|
"""
|
153
153
|
The update function return of the model.
|
@@ -157,19 +157,6 @@ class UpdateReturn(Mixin):
|
|
157
157
|
"""
|
158
158
|
raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
|
159
159
|
|
160
|
-
def update_return_info(self) -> PyTree:
|
161
|
-
"""
|
162
|
-
The update return information of the model.
|
163
|
-
|
164
|
-
It should be a pytree, with each element as a ``jax.Array``.
|
165
|
-
|
166
|
-
.. note::
|
167
|
-
Should not include the batch axis and batch in_size.
|
168
|
-
These information will be inferred from the ``mode`` attribute.
|
169
|
-
|
170
|
-
"""
|
171
|
-
raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
|
172
|
-
|
173
160
|
|
174
161
|
class _MetaUnionType(type):
|
175
162
|
def __new__(cls, name, bases, dct):
|
brainstate/nn/__init__.py
CHANGED
@@ -19,46 +19,110 @@ from ._collective_ops import *
|
|
19
19
|
from ._collective_ops import __all__ as collective_ops_all
|
20
20
|
from ._common import *
|
21
21
|
from ._common import __all__ as common_all
|
22
|
-
from .
|
23
|
-
from .
|
22
|
+
from ._conv import *
|
23
|
+
from ._conv import __all__ as conv_all
|
24
|
+
from ._delay import *
|
25
|
+
from ._delay import __all__ as state_delay_all
|
26
|
+
from ._dropout import *
|
27
|
+
from ._dropout import __all__ as dropout_all
|
24
28
|
from ._dynamics import *
|
25
|
-
from ._dynamics import __all__ as
|
29
|
+
from ._dynamics import __all__ as dyn_all
|
26
30
|
from ._elementwise import *
|
27
31
|
from ._elementwise import __all__ as elementwise_all
|
28
|
-
from .
|
29
|
-
from .
|
32
|
+
from ._embedding import *
|
33
|
+
from ._embedding import __all__ as embed_all
|
30
34
|
from ._exp_euler import *
|
31
35
|
from ._exp_euler import __all__ as exp_euler_all
|
32
|
-
from .
|
33
|
-
from .
|
36
|
+
from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
|
37
|
+
from ._inputs import *
|
38
|
+
from ._inputs import __all__ as inputs_all
|
39
|
+
from ._linear import *
|
40
|
+
from ._linear import __all__ as linear_all
|
41
|
+
from ._linear_mv import EventLinear
|
42
|
+
from ._ltp import *
|
43
|
+
from ._ltp import __all__ as ltp_all
|
34
44
|
from ._module import *
|
35
45
|
from ._module import __all__ as module_all
|
46
|
+
from ._neuron import *
|
47
|
+
from ._neuron import __all__ as dyn_neuron_all
|
48
|
+
from ._normalizations import *
|
49
|
+
from ._normalizations import __all__ as normalizations_all
|
50
|
+
from ._poolings import *
|
51
|
+
from ._poolings import __all__ as poolings_all
|
52
|
+
from ._projection import *
|
53
|
+
from ._projection import __all__ as projection_all
|
54
|
+
from ._rate_rnns import *
|
55
|
+
from ._rate_rnns import __all__ as rate_rnns
|
56
|
+
from ._readout import *
|
57
|
+
from ._readout import __all__ as readout_all
|
58
|
+
from ._stp import *
|
59
|
+
from ._stp import __all__ as stp_all
|
60
|
+
from ._synapse import *
|
61
|
+
from ._synapse import __all__ as dyn_synapse_all
|
62
|
+
from ._synaptic_projection import *
|
63
|
+
from ._synaptic_projection import __all__ as _syn_proj_all
|
64
|
+
from ._synouts import *
|
65
|
+
from ._synouts import __all__ as synouts_all
|
36
66
|
from ._utils import *
|
37
67
|
from ._utils import __all__ as utils_all
|
38
68
|
|
39
69
|
__all__ = (
|
40
|
-
[
|
70
|
+
[
|
71
|
+
'metrics',
|
72
|
+
'EventLinear',
|
73
|
+
'EventFixedProb',
|
74
|
+
'EventFixedNumConn',
|
75
|
+
]
|
41
76
|
+ collective_ops_all
|
42
77
|
+ common_all
|
43
|
-
+ dyn_impl_all
|
44
|
-
+ dynamics_all
|
45
78
|
+ elementwise_all
|
46
79
|
+ module_all
|
47
80
|
+ exp_euler_all
|
48
|
-
+ interaction_all
|
49
81
|
+ utils_all
|
50
|
-
+
|
82
|
+
+ dyn_all
|
83
|
+
+ projection_all
|
84
|
+
+ state_delay_all
|
85
|
+
+ synouts_all
|
86
|
+
+ conv_all
|
87
|
+
+ linear_all
|
88
|
+
+ normalizations_all
|
89
|
+
+ poolings_all
|
90
|
+
+ embed_all
|
91
|
+
+ dropout_all
|
92
|
+
+ elementwise_all
|
93
|
+
+ dyn_neuron_all
|
94
|
+
+ dyn_synapse_all
|
95
|
+
+ inputs_all
|
96
|
+
+ rate_rnns
|
97
|
+
+ readout_all
|
98
|
+
+ stp_all
|
99
|
+
+ ltp_all
|
100
|
+
+ _syn_proj_all
|
51
101
|
)
|
52
102
|
|
53
103
|
del (
|
54
104
|
collective_ops_all,
|
55
105
|
common_all,
|
56
|
-
dyn_impl_all,
|
57
|
-
dynamics_all,
|
58
|
-
elementwise_all,
|
59
106
|
module_all,
|
60
107
|
exp_euler_all,
|
61
|
-
interaction_all,
|
62
108
|
utils_all,
|
63
|
-
|
109
|
+
dyn_all,
|
110
|
+
projection_all,
|
111
|
+
state_delay_all,
|
112
|
+
synouts_all,
|
113
|
+
conv_all,
|
114
|
+
linear_all,
|
115
|
+
normalizations_all,
|
116
|
+
poolings_all,
|
117
|
+
embed_all,
|
118
|
+
dropout_all,
|
119
|
+
elementwise_all,
|
120
|
+
dyn_neuron_all,
|
121
|
+
dyn_synapse_all,
|
122
|
+
inputs_all,
|
123
|
+
readout_all,
|
124
|
+
rate_rnns,
|
125
|
+
stp_all,
|
126
|
+
ltp_all,
|
127
|
+
_syn_proj_all,
|
64
128
|
)
|
@@ -16,21 +16,21 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
import brainstate
|
19
|
+
import brainstate
|
20
20
|
|
21
21
|
|
22
22
|
class Test_vmap_init_all_states:
|
23
23
|
|
24
24
|
def test_vmap_init_all_states(self):
|
25
|
-
gru =
|
26
|
-
|
25
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
26
|
+
brainstate.nn.vmap_init_all_states(gru, axis_size=10)
|
27
27
|
print(gru)
|
28
28
|
|
29
29
|
def test_vmap_init_all_states_v2(self):
|
30
|
-
@
|
30
|
+
@brainstate.compile.jit
|
31
31
|
def init():
|
32
|
-
gru =
|
33
|
-
|
32
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
33
|
+
brainstate.nn.vmap_init_all_states(gru, axis_size=10)
|
34
34
|
print(gru)
|
35
35
|
|
36
36
|
init()
|
@@ -38,6 +38,6 @@ class Test_vmap_init_all_states:
|
|
38
38
|
|
39
39
|
class Test_init_all_states:
|
40
40
|
def test_init_all_states(self):
|
41
|
-
gru =
|
42
|
-
|
41
|
+
gru = brainstate.nn.GRUCell(1, 2)
|
42
|
+
brainstate.nn.init_all_states(gru, batch_size=10)
|
43
43
|
print(gru)
|
@@ -23,8 +23,8 @@ import jax.numpy as jnp
|
|
23
23
|
|
24
24
|
from brainstate import init, functional
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import Module
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import Module
|
28
28
|
|
29
29
|
T = TypeVar('T')
|
30
30
|
|
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax.numpy as jnp
|
6
4
|
import pytest
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
|
12
10
|
|
13
11
|
class TestConv(parameterized.TestCase):
|
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
|
|
19
17
|
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
|
20
18
|
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
|
21
19
|
|
22
|
-
net =
|
23
|
-
|
20
|
+
net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
|
21
|
+
stride=(2, 1), padding='VALID', groups=4)
|
24
22
|
out = net(img)
|
25
23
|
print("out shape: ", out.shape)
|
26
24
|
self.assertEqual(out.shape, (2, 99, 196, 32))
|
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
|
|
30
28
|
# plt.show()
|
31
29
|
|
32
30
|
def test_conv1D(self):
|
33
|
-
model =
|
31
|
+
model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
|
34
32
|
input = jnp.ones((2, 5, 3))
|
35
33
|
out = model(input)
|
36
34
|
print("out shape: ", out.shape)
|
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
|
|
41
39
|
# plt.show()
|
42
40
|
|
43
41
|
def test_conv2D(self):
|
44
|
-
model =
|
42
|
+
model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
|
45
43
|
input = jnp.ones((2, 5, 5, 3))
|
46
44
|
|
47
45
|
out = model(input)
|
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
|
|
49
47
|
self.assertEqual(out.shape, (2, 5, 5, 32))
|
50
48
|
|
51
49
|
def test_conv3D(self):
|
52
|
-
model =
|
50
|
+
model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
|
53
51
|
input = jnp.ones((2, 5, 5, 5, 3))
|
54
52
|
out = model(input)
|
55
53
|
print("out shape: ", out.shape)
|
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
62
60
|
|
63
61
|
x = jnp.ones((1, 8, 3))
|
64
62
|
for use_bias in [True, False]:
|
65
|
-
conv_transpose_module =
|
63
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
66
64
|
in_channels=3,
|
67
65
|
out_channels=4,
|
68
66
|
kernel_size=(3,),
|
69
67
|
padding='VALID',
|
70
|
-
w_initializer=
|
71
|
-
b_initializer=
|
68
|
+
w_initializer=brainstate.init.Constant(1.),
|
69
|
+
b_initializer=brainstate.init.Constant(1.) if use_bias else None,
|
72
70
|
)
|
73
71
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
74
72
|
y = conv_transpose_module(x)
|
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
91
89
|
|
92
90
|
x = jnp.ones((1, 8, 3))
|
93
91
|
m = jnp.tril(jnp.ones((3, 3, 4)))
|
94
|
-
conv_transpose_module =
|
92
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
95
93
|
in_channels=3,
|
96
94
|
out_channels=4,
|
97
95
|
kernel_size=(3,),
|
98
96
|
padding='VALID',
|
99
97
|
mask=m,
|
100
|
-
w_initializer=
|
101
|
-
b_initializer=
|
98
|
+
w_initializer=brainstate.init.Constant(),
|
99
|
+
b_initializer=brainstate.init.Constant(),
|
102
100
|
)
|
103
101
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
104
102
|
y = conv_transpose_module(x)
|
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
119
117
|
|
120
118
|
data = jnp.ones([1, 3, 1])
|
121
119
|
for use_bias in [True, False]:
|
122
|
-
net =
|
120
|
+
net = brainstate.nn.ConvTranspose1d(
|
123
121
|
in_channels=1,
|
124
122
|
out_channels=1,
|
125
123
|
kernel_size=3,
|
126
124
|
stride=1,
|
127
125
|
padding="SAME",
|
128
|
-
w_initializer=
|
129
|
-
b_initializer=
|
126
|
+
w_initializer=brainstate.init.Constant(),
|
127
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
130
128
|
)
|
131
129
|
out = net(data)
|
132
130
|
self.assertEqual(out.shape, (1, 3, 1))
|
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
143
141
|
|
144
142
|
x = jnp.ones((1, 8, 8, 3))
|
145
143
|
for use_bias in [True, False]:
|
146
|
-
conv_transpose_module =
|
144
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
147
145
|
in_channels=3,
|
148
146
|
out_channels=4,
|
149
147
|
kernel_size=(3, 3),
|
150
148
|
padding='VALID',
|
151
|
-
w_initializer=
|
152
|
-
b_initializer=
|
149
|
+
w_initializer=brainstate.init.Constant(),
|
150
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
153
151
|
)
|
154
152
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
|
155
153
|
y = conv_transpose_module(x)
|
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
159
157
|
|
160
158
|
x = jnp.ones((1, 8, 8, 3))
|
161
159
|
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
|
162
|
-
conv_transpose_module =
|
160
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
163
161
|
in_channels=3,
|
164
162
|
out_channels=4,
|
165
163
|
kernel_size=(3, 3),
|
166
164
|
padding='VALID',
|
167
165
|
mask=m,
|
168
|
-
w_initializer=
|
166
|
+
w_initializer=brainstate.init.Constant(),
|
169
167
|
)
|
170
168
|
y = conv_transpose_module(x)
|
171
169
|
print(y.shape)
|
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
174
172
|
|
175
173
|
x = jnp.ones((1, 8, 8, 3))
|
176
174
|
for use_bias in [True, False]:
|
177
|
-
conv_transpose_module =
|
175
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
178
176
|
in_channels=3,
|
179
177
|
out_channels=4,
|
180
178
|
kernel_size=(3, 3),
|
181
179
|
stride=1,
|
182
180
|
padding='SAME',
|
183
|
-
w_initializer=
|
184
|
-
b_initializer=
|
181
|
+
w_initializer=brainstate.init.Constant(),
|
182
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
185
183
|
)
|
186
184
|
y = conv_transpose_module(x)
|
187
185
|
print(y.shape)
|
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
193
191
|
|
194
192
|
x = jnp.ones((1, 8, 8, 8, 3))
|
195
193
|
for use_bias in [True, False]:
|
196
|
-
conv_transpose_module =
|
194
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
197
195
|
in_channels=3,
|
198
196
|
out_channels=4,
|
199
197
|
kernel_size=(3, 3, 3),
|
200
198
|
padding='VALID',
|
201
|
-
w_initializer=
|
202
|
-
b_initializer=
|
199
|
+
w_initializer=brainstate.init.Constant(),
|
200
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
203
201
|
)
|
204
202
|
y = conv_transpose_module(x)
|
205
203
|
print(y.shape)
|
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
208
206
|
|
209
207
|
x = jnp.ones((1, 8, 8, 8, 3))
|
210
208
|
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
|
211
|
-
conv_transpose_module =
|
209
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
212
210
|
in_channels=3,
|
213
211
|
out_channels=4,
|
214
212
|
kernel_size=(3, 3, 3),
|
215
213
|
padding='VALID',
|
216
214
|
mask=m,
|
217
|
-
w_initializer=
|
215
|
+
w_initializer=brainstate.init.Constant(),
|
218
216
|
)
|
219
217
|
y = conv_transpose_module(x)
|
220
218
|
print(y.shape)
|
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
223
221
|
|
224
222
|
x = jnp.ones((1, 8, 8, 8, 3))
|
225
223
|
for use_bias in [True, False]:
|
226
|
-
conv_transpose_module =
|
224
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
227
225
|
in_channels=3,
|
228
226
|
out_channels=4,
|
229
227
|
kernel_size=(3, 3, 3),
|
230
228
|
stride=1,
|
231
229
|
padding='SAME',
|
232
|
-
w_initializer=
|
233
|
-
b_initializer=
|
230
|
+
w_initializer=brainstate.init.Constant(),
|
231
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
234
232
|
)
|
235
233
|
y = conv_transpose_module(x)
|
236
234
|
print(y.shape)
|
@@ -27,9 +27,9 @@ from brainstate import environ
|
|
27
27
|
from brainstate._state import ShortTermState, State
|
28
28
|
from brainstate.compile import jit_error_if
|
29
29
|
from brainstate.graph import Node
|
30
|
-
from brainstate.nn._collective_ops import call_order
|
31
|
-
from brainstate.nn._module import Module
|
32
30
|
from brainstate.typing import ArrayLike, PyTree
|
31
|
+
from ._collective_ops import call_order
|
32
|
+
from ._module import Module
|
33
33
|
|
34
34
|
__all__ = [
|
35
35
|
'Delay', 'DelayAccess', 'StateWithDelay',
|
@@ -135,6 +135,7 @@ class Delay(Module):
|
|
135
135
|
entries: Optional[Dict] = None, # delay access entry
|
136
136
|
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
|
137
137
|
interp_method: str = _INTERP_LINEAR, # interpolation method
|
138
|
+
take_aware_unit: bool = False
|
138
139
|
):
|
139
140
|
# target information
|
140
141
|
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
|
@@ -170,6 +171,9 @@ class Delay(Module):
|
|
170
171
|
for entry, delay_time in entries.items():
|
171
172
|
self.register_entry(entry, delay_time)
|
172
173
|
|
174
|
+
self.take_aware_unit = take_aware_unit
|
175
|
+
self._unit = None
|
176
|
+
|
173
177
|
@property
|
174
178
|
def history(self):
|
175
179
|
return self._history
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import random, environ, init
|
24
24
|
from brainstate._state import ShortTermState
|
25
|
-
from brainstate.nn._module import ElementWiseBlock
|
26
25
|
from brainstate.typing import Size
|
26
|
+
from ._module import ElementWiseBlock
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
|