brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
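
The notable structural change in this release is visible in the renames above: `_activations.py` and `_others.py` move from `brainstate/functional/` into `brainstate/nn/`, while `functional/_normalization.py` and `functional/_spikes.py` are deleted (the `functional/__init__.py` itself survives with a small edit, so it plausibly still re-exports the moved names). A minimal, hypothetical sketch of what the move implies for downstream imports; `relu` is an illustrative symbol, and the re-export through `brainstate.nn` is an assumption inferred from the rename, not something this diff confirms:

```python
import brainstate

x = brainstate.random.randn(4)

# 0.1.8 layout: activations lived in brainstate.functional._activations
# y = brainstate.functional.relu(x)

# 0.1.10 layout: _activations.py now sits under brainstate/nn/, so the same
# function is reached through the nn namespace (assumed re-export).
y = brainstate.nn.relu(x)
```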
@@ -1,169 +1,169 @@

The old and new sides of this hunk are line-for-line identical: the viewer shows the whole file removed and re-added unchanged, consistent with the `brainstate/{functional → nn}/_activations_test.py` rename in the file list above. The file content, shown once:

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from absl.testing import absltest
from absl.testing import parameterized

import brainstate


class Test_Activation(parameterized.TestCase):

    def test_Threshold(self):
        threshold_layer = brainstate.nn.Threshold(5, 20)
        input = brainstate.random.randn(2)
        output = threshold_layer(input)

    def test_ReLU(self):
        ReLU_layer = brainstate.nn.ReLU()
        input = brainstate.random.randn(2)
        output = ReLU_layer(input)

    def test_RReLU(self):
        RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
        input = brainstate.random.randn(2)
        output = RReLU_layer(input)

    def test_Hardtanh(self):
        Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
        input = brainstate.random.randn(2)
        output = Hardtanh_layer(input)

    def test_ReLU6(self):
        ReLU6_layer = brainstate.nn.ReLU6()
        input = brainstate.random.randn(2)
        output = ReLU6_layer(input)

    def test_Sigmoid(self):
        Sigmoid_layer = brainstate.nn.Sigmoid()
        input = brainstate.random.randn(2)
        output = Sigmoid_layer(input)

    def test_Hardsigmoid(self):
        Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
        input = brainstate.random.randn(2)
        output = Hardsigmoid_layer(input)

    def test_Tanh(self):
        Tanh_layer = brainstate.nn.Tanh()
        input = brainstate.random.randn(2)
        output = Tanh_layer(input)

    def test_SiLU(self):
        SiLU_layer = brainstate.nn.SiLU()
        input = brainstate.random.randn(2)
        output = SiLU_layer(input)

    def test_Mish(self):
        Mish_layer = brainstate.nn.Mish()
        input = brainstate.random.randn(2)
        output = Mish_layer(input)

    def test_Hardswish(self):
        Hardswish_layer = brainstate.nn.Hardswish()
        input = brainstate.random.randn(2)
        output = Hardswish_layer(input)

    def test_ELU(self):
        ELU_layer = brainstate.nn.ELU(alpha=0.5, )
        input = brainstate.random.randn(2)
        output = ELU_layer(input)

    def test_CELU(self):
        CELU_layer = brainstate.nn.CELU(alpha=0.5, )
        input = brainstate.random.randn(2)
        output = CELU_layer(input)

    def test_SELU(self):
        SELU_layer = brainstate.nn.SELU()
        input = brainstate.random.randn(2)
        output = SELU_layer(input)

    def test_GLU(self):
        GLU_layer = brainstate.nn.GLU()
        input = brainstate.random.randn(4, 2)
        output = GLU_layer(input)

    @parameterized.product(
        approximate=['tanh', 'none']
    )
    def test_GELU(self, approximate):
        GELU_layer = brainstate.nn.GELU()
        input = brainstate.random.randn(2)
        output = GELU_layer(input)

    def test_Hardshrink(self):
        Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
        input = brainstate.random.randn(2)
        output = Hardshrink_layer(input)

    def test_LeakyReLU(self):
        LeakyReLU_layer = brainstate.nn.LeakyReLU()
        input = brainstate.random.randn(2)
        output = LeakyReLU_layer(input)

    def test_LogSigmoid(self):
        LogSigmoid_layer = brainstate.nn.LogSigmoid()
        input = brainstate.random.randn(2)
        output = LogSigmoid_layer(input)

    def test_Softplus(self):
        Softplus_layer = brainstate.nn.Softplus()
        input = brainstate.random.randn(2)
        output = Softplus_layer(input)

    def test_Softshrink(self):
        Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
        input = brainstate.random.randn(2)
        output = Softshrink_layer(input)

    def test_PReLU(self):
        PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
        input = brainstate.random.randn(2)
        output = PReLU_layer(input)

    def test_Softsign(self):
        Softsign_layer = brainstate.nn.Softsign()
        input = brainstate.random.randn(2)
        output = Softsign_layer(input)

    def test_Tanhshrink(self):
        Tanhshrink_layer = brainstate.nn.Tanhshrink()
        input = brainstate.random.randn(2)
        output = Tanhshrink_layer(input)

    def test_Softmin(self):
        Softmin_layer = brainstate.nn.Softmin(dim=2)
        input = brainstate.random.randn(2, 3, 4)
        output = Softmin_layer(input)

    def test_Softmax(self):
        Softmax_layer = brainstate.nn.Softmax(dim=2)
        input = brainstate.random.randn(2, 3, 4)
        output = Softmax_layer(input)

    def test_Softmax2d(self):
        Softmax2d_layer = brainstate.nn.Softmax2d()
        input = brainstate.random.randn(2, 3, 12, 13)
        output = Softmax2d_layer(input)

    def test_LogSoftmax(self):
        LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
        input = brainstate.random.randn(2, 3, 4)
        output = LogSoftmax_layer(input)


if __name__ == '__main__':
    absltest.main()
```
brainstate/nn/_embedding.py CHANGED

@@ -1,58 +1,58 @@

Again the removed and re-added sides are identical, so the file content is shown once:

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Optional, Callable, Union

from brainstate import init
from brainstate._state import ParamState
from brainstate.typing import ArrayLike
from ._module import Module

__all__ = [
    'Embedding',
]


class Embedding(Module):
    r"""
    A simple lookup table that stores embeddings of a fixed size.

    Args:
        num_embeddings: Size of embedding dictionary. Must be non-negative.
        embedding_size: Size of each embedding vector. Must be non-negative.
        embedding_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_size: int,
        embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        if num_embeddings < 0:
            raise ValueError("num_embeddings must not be negative.")
        if embedding_size < 0:
            raise ValueError("embedding_size must not be negative.")
        self.num_embeddings = num_embeddings
        self.embedding_size = embedding_size
        self.out_size = (embedding_size,)

        weight = init.param(embedding_init, (self.num_embeddings, self.embedding_size))
        self.weight = ParamState(weight)

    def update(self, indices: ArrayLike):
        return self.weight.value[indices]
```
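For orientation, a short usage sketch of the `Embedding` layer defined above. This is a minimal example; it assumes `Embedding` is re-exported as `brainstate.nn.Embedding` and that `brainstate.random.randint` mirrors NumPy's `randint(low, high, size)` signature:

```python
import brainstate

# A lookup table with 10 embeddings of size 4; the weight is a
# ParamState of shape (num_embeddings, embedding_size) = (10, 4).
emb = brainstate.nn.Embedding(num_embeddings=10, embedding_size=4)

# update() indexes rows of the table, so integer indices of shape (3,)
# yield embedding vectors of shape (3, 4).
indices = brainstate.random.randint(0, 10, (3,))
vectors = emb.update(indices)
print(vectors.shape)  # (3, 4)
```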
brainstate/nn/_exp_euler.py CHANGED

@@ -1,92 +1,92 @@

As above, the old and new sides of this hunk are identical; the file content:

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Callable

import brainunit as u
import jax.numpy as jnp

from brainstate import environ, random
from brainstate.augment import vector_grad

__all__ = [
    'exp_euler_step',
]


def exp_euler_step(
    fn: Callable, *args, **kwargs
):
    r"""
    One-step Exponential Euler method for solving ODEs.

    Examples
    --------

    >>> def fun(x, t):
    ...     return -x
    >>> x = 1.0
    >>> exp_euler_step(fun, x, None)

    If the variable ( $x$ ) has units of ( $[X]$ ), then the drift term ( $\text{drift_fn}(x)$ ) should
    have units of ( $[X]/[T]$ ), where ( $[T]$ ) is the unit of time.

    If the variable ( x ) has units of ( [X] ), then the diffusion term ( \text{diffusion_fn}(x) )
    should have units of ( [X]/\sqrt{[T]} ).

    Args:
        fun: Callable. The function to be solved.
        diffusion: Callable. The diffusion function.
        *args: The input arguments.
        drift: Callable. The drift function.

    Returns:
        The one-step solution of the ODE.
    """
    assert callable(fn), 'The input function should be callable.'
    assert len(args) > 0, 'The input arguments should not be empty.'
    if callable(args[0]):
        diffusion = args[0]
        args = args[1:]
    else:
        diffusion = None
    assert len(args) > 0, 'The input arguments should not be empty.'
    if u.math.get_dtype(args[0]) not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
        raise ValueError(
            f'The input data type should be float64, float32, float16, or bfloat16 '
            f'when using Exponential Euler method. But we got {args[0].dtype}.'
        )

    # drift
    dt = environ.get('dt')
    linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
    linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
    phi = u.math.exprel(dt * linear)
    x_next = args[0] + dt * phi * derivative

    # diffusion
    if diffusion is not None:
        diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
        if u.get_dim(x_next) != u.get_dim(diffusion_part):
            drift_unit = u.get_unit(x_next)
            time_unit = u.get_unit(dt)
            raise ValueError(
                f"Drift unit is {drift_unit}, "
                f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
                f"but we got {u.get_unit(diffusion_part)}."
            )
        x_next += diffusion_part
    return x_next
```
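Since `exp_euler_step` pulls its step size from the global environment via `environ.get('dt')`, a caller has to register `dt` before stepping. A minimal sketch, assuming `brainstate.environ.set(dt=...)` is the registration API and that `exp_euler_step` is exported under `brainstate.nn`:

```python
import brainstate
import jax.numpy as jnp

# exp_euler_step reads the integration step from the environment (assumed API).
brainstate.environ.set(dt=0.1)

def decay(x, t):
    # dx/dt = -x; for linear dynamics the exponential Euler update
    # x + dt * exprel(dt * df/dx) * f(x) reduces exactly to x * exp(-dt).
    return -x

x = jnp.asarray(1.0)  # must be a floating dtype, per the check in the function
for i in range(10):
    x = brainstate.nn.exp_euler_step(decay, x, i * 0.1)
print(x)  # after 10 steps of dt=0.1: approximately exp(-1) ≈ 0.3679
```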