brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
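The most significant structural change visible in the renames above is the merge of the former brainstate.compile and brainstate.augment packages into a single brainstate.transform package (see the moved _jit.py, _autograd.py, and _mapping.py files), with brainstate.init folded into brainstate.nn as nn.init. A minimal migration sketch follows; it assumes the moved modules keep re-exporting their public names under the new package, which should be verified against the 0.2.0 API documentation before relying on it:

    import brainstate

    # brainstate 0.1.9: jit lived in brainstate.compile and autograd in
    # brainstate.augment. In 0.2.0 both are reached through
    # brainstate.transform, matching the file moves listed above.
    @brainstate.transform.jit  # assumed re-export of the moved compile/_jit.py
    def double(x):
        return x * 2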
brainstate/nn/_elementwise.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -49,20 +49,26 @@ class Threshold(ElementWiseBlock):
         \text{value}, &\text{ otherwise }
         \end{cases}

-
-
-
+    Parameters
+    ----------
+    threshold : float
+        The value to threshold at.
+    value : float
+        The value to replace with.

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
         >>> import brainstate
         >>> m = nn.Threshold(0.1, 20)
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -85,30 +91,38 @@ class Threshold(ElementWiseBlock):


 class ReLU(ElementWiseBlock):
-    r"""Applies the rectified linear unit function element-wise
+    r"""Applies the rectified linear unit function element-wise.
+
+    The ReLU function is defined as:

-
+    .. math::
+        \text{ReLU}(x) = (x)^+ = \max(0, x)

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.ReLU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)

+    An implementation of CReLU - https://arxiv.org/abs/1603.05201

-
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
+        >>> import jax.numpy as jnp
         >>> m = nn.ReLU()
-        >>> x = random.randn(2).unsqueeze(0)
-        >>> output =
+        >>> x = brainstate.random.randn(2).unsqueeze(0)
+        >>> output = jnp.concat((m(x), m(-x)))
     """
     __module__ = 'brainstate.nn'

@@ -120,10 +134,10 @@ class ReLU(ElementWiseBlock):


 class RReLU(ElementWiseBlock):
-    r"""Applies the randomized leaky rectified liner unit function, element-wise
-    as described in the paper:
+    r"""Applies the randomized leaky rectified liner unit function, element-wise.

-    `Empirical Evaluation of Rectified Activations in
+    As described in the paper `Empirical Evaluation of Rectified Activations in
+    Convolutional Network`_.

     The function is defined as:

@@ -137,26 +151,32 @@ class RReLU(ElementWiseBlock):
     where :math:`a` is randomly sampled from uniform distribution
     :math:`\mathcal{U}(\text{lower}, \text{upper})`.

-
+    Parameters
+    ----------
+    lower : float, optional
+        Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+    upper : float, optional
+        Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`

-
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-
-
-
+    References
+    ----------
+    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+        https://arxiv.org/abs/1505.00853

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.RReLU(0.1, 0.3)
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
-
-    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
-        https://arxiv.org/abs/1505.00853
     """
     __module__ = 'brainstate.nn'
     lower: float
@@ -190,23 +210,31 @@ class Hardtanh(ElementWiseBlock):
         x & \text{ otherwise } \\
         \end{cases}

-
-
-
+    Parameters
+    ----------
+    min_val : float, optional
+        Minimum value of the linear region range. Default: -1
+    max_val : float, optional
+        Maximum value of the linear region range. Default: 1

+    Notes
+    -----
     Keyword arguments :attr:`min_value` and :attr:`max_value`
     have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Hardtanh(-2, 2)
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -230,22 +258,27 @@ class Hardtanh(ElementWiseBlock):
         return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'


-class ReLU6(Hardtanh, ElementWiseBlock):
-    r"""Applies the element-wise function
+class ReLU6(Hardtanh):
+    r"""Applies the element-wise function.
+
+    ReLU6 is defined as:

     .. math::
         \text{ReLU6}(x) = \min(\max(0,x), 6)

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.ReLU6()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -255,22 +288,26 @@ class ReLU6(Hardtanh, ElementWiseBlock):


 class Sigmoid(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.
+
+    Sigmoid is defined as:

     .. math::
         \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-
-
-
-
-    Examples::
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Sigmoid()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -291,16 +328,19 @@ class Hardsigmoid(ElementWiseBlock):
         x / 6 + 1 / 2 & \text{otherwise}
         \end{cases}

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Hardsigmoid()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -317,16 +357,19 @@ class Tanh(ElementWiseBlock):
     .. math::
         \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Tanh()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -337,28 +380,34 @@ class Tanh(ElementWiseBlock):

 class SiLU(ElementWiseBlock):
     r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+
     The SiLU function is also known as the swish function.

     .. math::
         \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

-
-
-
-
-
-
-
+    Notes
+    -----
+    See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+    where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+    `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+    in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+    a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+    where the SiLU was experimented with later.

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
+        >>> import brainstate
         >>> m = nn.SiLU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -369,25 +418,30 @@ class SiLU(ElementWiseBlock):

 class Mish(ElementWiseBlock):
     r"""Applies the Mish function, element-wise.
+
     Mish: A Self Regularized Non-Monotonic Neural Activation Function.

     .. math::
         \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

-
-
-
+    Notes
+    -----
+    See `Mish: A Self Regularized Non-Monotonic Neural Activation Function
+    <https://arxiv.org/abs/1908.08681>`_

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Mish()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -397,8 +451,10 @@ class Mish(ElementWiseBlock):


 class Hardswish(ElementWiseBlock):
-    r"""Applies the Hardswish function, element-wise
-
+    r"""Applies the Hardswish function, element-wise.
+
+    As described in the paper `Searching for MobileNetV3
+    <https://arxiv.org/abs/1905.02244>`_.

     Hardswish is defined as:

@@ -409,17 +465,19 @@ class Hardswish(ElementWiseBlock):
         x \cdot (x + 3) /6 & \text{otherwise}
         \end{cases}

+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-
-
-
-
-    Examples::
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Hardswish()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -429,9 +487,10 @@ class Hardswish(ElementWiseBlock):


 class ELU(ElementWiseBlock):
-    r"""Applies the Exponential Linear Unit (ELU) function, element-wise
-
-
+    r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
+
+    As described in the paper: `Fast and Accurate Deep Network Learning by
+    Exponential Linear Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

     ELU is defined as:

@@ -441,19 +500,24 @@ class ELU(ElementWiseBlock):
         \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
         \end{cases}

-
-
+    Parameters
+    ----------
+    alpha : float, optional
+        The :math:`\alpha` value for the ELU formulation. Default: 1.0

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.ELU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -471,30 +535,38 @@ class ELU(ElementWiseBlock):


 class CELU(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

-    More details can be found in the paper `Continuously Differentiable Exponential
+    More details can be found in the paper `Continuously Differentiable Exponential
+    Linear Units`_ .

-
-
+    Parameters
+    ----------
+    alpha : float, optional
+        The :math:`\alpha` value for the CELU formulation. Default: 1.0

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-
+    References
+    ----------
+    .. _`Continuously Differentiable Exponential Linear Units`:
+        https://arxiv.org/abs/1704.07483
+
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.CELU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
-
-    .. _`Continuously Differentiable Exponential Linear Units`:
-        https://arxiv.org/abs/1704.07483
     """
     __module__ = 'brainstate.nn'
     alpha: float
@@ -511,7 +583,7 @@ class CELU(ElementWiseBlock):


 class SELU(ElementWiseBlock):
-    r"""Applied element-wise
+    r"""Applied element-wise.

     .. math::
         \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
@@ -521,20 +593,24 @@ class SELU(ElementWiseBlock):

     More details can be found in the paper `Self-Normalizing Neural Networks`_ .

+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-
-
-
+    References
+    ----------
+    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.SELU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
-
-    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
     """
     __module__ = 'brainstate.nn'

@@ -543,24 +619,33 @@ class SELU(ElementWiseBlock):


 class GLU(ElementWiseBlock):
-    r"""Applies the gated linear unit function
-    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
-    of the input matrices and :math:`b` is the second half.
+    r"""Applies the gated linear unit function.

-
-
+    .. math::
+        {GLU}(a, b)= a \otimes \sigma(b)
+
+    where :math:`a` is the first half of the input matrices and :math:`b` is
+    the second half.

-
-
-
-
+    Parameters
+    ----------
+    dim : int, optional
+        The dimension on which to split the input. Default: -1

-
+    Shape
+    -----
+    - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+      dimensions
+    - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.GLU()
-        >>> x = random.randn(4, 2)
+        >>> x = brainstate.random.randn(4, 2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -578,30 +663,35 @@ class GLU(ElementWiseBlock):


 class GELU(ElementWiseBlock):
-    r"""Applies the Gaussian Error Linear Units function
+    r"""Applies the Gaussian Error Linear Units function.

     .. math:: \text{GELU}(x) = x * \Phi(x)

-    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian
+    Distribution.

-    When the approximate argument is
+    When the approximate argument is True, Gelu is estimated with:

     .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))

-
-
-
+    Parameters
+    ----------
+    approximate : bool, optional
+        Whether to use the tanh approximation algorithm. Default: False

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.GELU()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -631,19 +721,24 @@ class Hardshrink(ElementWiseBlock):
         0, & \text{ otherwise }
         \end{cases}

-
-
+    Parameters
+    ----------
+    lambd : float, optional
+        The :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Hardshrink()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -661,12 +756,11 @@ class Hardshrink(ElementWiseBlock):


 class LeakyReLU(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

-
     or

     .. math::
@@ -676,21 +770,26 @@ class LeakyReLU(ElementWiseBlock):
         \text{negative\_slope} \times x, & \text{ otherwise }
         \end{cases}

-
-
-
+    Parameters
+    ----------
+    negative_slope : float, optional
+        Controls the angle of the negative slope (which is used for
+        negative input values). Default: 1e-2

-    Shape
-
-
-
+    Shape
+    -----
+    - Input: :math:`(*)` where `*` means, any number of additional
+      dimensions
+    - Output: :math:`(*)`, same shape as the input

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.LeakyReLU(0.1)
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -708,21 +807,24 @@ class LeakyReLU(ElementWiseBlock):


 class LogSigmoid(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.LogSigmoid()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -732,8 +834,10 @@ class LogSigmoid(ElementWiseBlock):


 class Softplus(ElementWiseBlock):
-    r"""Applies the Softplus function
-
+    r"""Applies the Softplus function element-wise.
+
+    .. math::
+        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

     SoftPlus is a smooth approximation to the ReLU function and can be used
     to constrain the output of a machine to always be positive.
@@ -741,16 +845,19 @@ class Softplus(ElementWiseBlock):
     For numerical stability the implementation reverts to the linear function
     when :math:`input \times \beta > threshold`.

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softplus()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -760,7 +867,7 @@ class Softplus(ElementWiseBlock):


 class Softshrink(ElementWiseBlock):
-    r"""Applies the soft shrinkage function elementwise
+    r"""Applies the soft shrinkage function elementwise.

     .. math::
         \text{SoftShrinkage}(x) =
@@ -770,19 +877,25 @@ class Softshrink(ElementWiseBlock):
         0, & \text{ otherwise }
         \end{cases}

-
-
+    Parameters
+    ----------
+    lambd : float, optional
+        The :math:`\lambda` (must be no less than zero) value for the
+        Softshrink formulation. Default: 0.5

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softshrink()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -800,7 +913,7 @@ class Softshrink(ElementWiseBlock):


 class PReLU(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
@@ -814,35 +927,43 @@ class PReLU(ElementWiseBlock):
         ax, & \text{ otherwise }
         \end{cases}

-    Here :math:`a` is a learnable parameter. When called without arguments,
-    parameter :math:`a` across all input channels.
-    a separate :math:`a` is used for
-
-
-
-
+    Here :math:`a` is a learnable parameter. When called without arguments,
+    `nn.PReLU()` uses a single parameter :math:`a` across all input channels.
+    If called with `nn.PReLU(nChannels)`, a separate :math:`a` is used for
+    each input channel.
+
+    Parameters
+    ----------
+    num_parameters : int, optional
+        Number of :math:`a` to learn. Although it takes an int as input,
+        there is only two values are legitimate: 1, or the number of channels
+        at input. Default: 1
+    init : float, optional
+        The initial value of :math:`a`. Default: 0.25
+    dtype : optional
+        The data type for the weight parameter.
+
+    Shape
+    -----
+    - Input: :math:`( *)` where `*` means, any number of additional dimensions.
+    - Output: :math:`(*)`, same shape as the input.
+
+    Attributes
+    ----------
+    weight : Tensor
+        The learnable weights of shape (:attr:`num_parameters`).
+
+    Notes
+    -----
+    - Weight decay should not be used when learning :math:`a` for good performance.
+    - Channel dim is the 2nd dim of input. When input has dims < 2, then there is
+      no channel dim and the number of channels = 1.
+
+    Examples
+    --------
+    .. code-block:: python

-
-    Channel dim is the 2nd dim of input. When input has dims < 2, then there is
-    no channel dim and the number of channels = 1.
-
-    Args:
-        num_parameters (int): number of :math:`a` to learn.
-            Although it takes an int as input, there is only two values are legitimate:
-            1, or the number of channels at input. Default: 1
-        init (float): the initial value of :math:`a`. Default: 0.25
-
-    Shape:
-        - Input: :math:`( *)` where `*` means, any number of additional
-          dimensions.
-        - Output: :math:`(*)`, same shape as the input.
-
-    Attributes:
-        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
-
-    Examples::
-
-        >>> import brainstate as brainstate
+        >>> import brainstate
         >>> m = brainstate.nn.PReLU()
         >>> x = brainstate.random.randn(2)
         >>> output = m(x)
@@ -863,21 +984,24 @@ class PReLU(ElementWiseBlock):


 class Softsign(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softsign()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -887,21 +1011,24 @@ class Softsign(ElementWiseBlock):


 class Tanhshrink(ElementWiseBlock):
-    r"""Applies the element-wise function
+    r"""Applies the element-wise function.

     .. math::
         \text{Tanhshrink}(x) = x - \tanh(x)

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+    - Output: :math:`(*)`, same shape as the input.

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Tanhshrink()
-        >>> x = random.randn(2)
+        >>> x = brainstate.random.randn(2)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -911,8 +1038,9 @@ class Tanhshrink(ElementWiseBlock):


 class Softmin(ElementWiseBlock):
-    r"""Applies the Softmin function to an n-dimensional input Tensor
-
+    r"""Applies the Softmin function to an n-dimensional input Tensor.
+
+    Rescales the input so that the elements of the n-dimensional output Tensor
     lie in the range `[0, 1]` and sum to 1.

     Softmin is defined as:
@@ -920,25 +1048,31 @@ class Softmin(ElementWiseBlock):
     .. math::
         \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    dim : int, optional
+        A dimension along which Softmin will be computed (so every slice
+        along dim will sum to 1).
+
+    Shape
+    -----
+    - Input: :math:`(*)` where `*` means, any number of additional dimensions
+    - Output: :math:`(*)`, same shape as the input
+
+    Returns
+    -------
+    Tensor
+        A Tensor of the same dimension and shape as the input, with
         values in the range [0, 1]

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softmin(dim=1)
-        >>> x = random.randn(2, 3)
+        >>> x = brainstate.random.randn(2, 3)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -956,8 +1090,9 @@ class Softmin(ElementWiseBlock):


 class Softmax(ElementWiseBlock):
-    r"""Applies the Softmax function to an n-dimensional input Tensor
-
+    r"""Applies the Softmax function to an n-dimensional input Tensor.
+
+    Rescales the input so that the elements of the n-dimensional output Tensor
     lie in the range [0,1] and sum to 1.

     Softmax is defined as:
@@ -968,32 +1103,38 @@ class Softmax(ElementWiseBlock):
     When the input Tensor is a sparse tensor then the unspecified
     values are treated as ``-inf``.

-
-
-
-
-
-
-
+    Parameters
+    ----------
+    dim : int, optional
+        A dimension along which Softmax will be computed (so every slice
+        along dim will sum to 1).
+
+    Shape
+    -----
+    - Input: :math:`(*)` where `*` means, any number of additional dimensions
+    - Output: :math:`(*)`, same shape as the input
+
+    Returns
+    -------
+    Tensor
+        A Tensor of the same dimension and shape as the input with
         values in the range [0, 1]

-
-
-
+    Notes
+    -----
+    This module doesn't work directly with NLLLoss, which expects the Log to be
+    computed between the Softmax and itself. Use `LogSoftmax` instead (it's
+    faster and has better numerical properties).

-
-
-
-    Use `LogSoftmax` instead (it's faster and has better numerical properties).
-
-    Examples::
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softmax(dim=1)
-        >>> x = random.randn(2, 3)
+        >>> x = brainstate.random.randn(2, 3)
         >>> output = m(x)
-
     """
     __module__ = 'brainstate.nn'
     dim: Optional[int]
@@ -1015,21 +1156,26 @@ class Softmax2d(ElementWiseBlock):
     When given an image of ``Channels x Height x Width``, it will
     apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

-    Shape
-
-
+    Shape
+    -----
+    - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+    - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

-    Returns
-
+    Returns
+    -------
+    Tensor
+        A Tensor of the same dimension and shape as the input with
         values in the range [0, 1]

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.Softmax2d()
         >>> # you softmax over the 2nd dimension
-        >>> x = random.randn(2, 3, 12, 13)
+        >>> x = brainstate.random.randn(2, 3, 12, 13)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -1040,30 +1186,37 @@ class Softmax2d(ElementWiseBlock):


 class LogSoftmax(ElementWiseBlock):
-    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
-
+    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
+
+    The LogSoftmax formulation can be simplified as:

     .. math::
         \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

-
-
-
-
+    Parameters
+    ----------
+    dim : int, optional
+        A dimension along which LogSoftmax will be computed.

-
-
+    Shape
+    -----
+    - Input: :math:`(*)` where `*` means, any number of additional dimensions
+    - Output: :math:`(*)`, same shape as the input

-    Returns
-
+    Returns
+    -------
+    Tensor
+        A Tensor of the same dimension and shape as the input with
         values in the range [-inf, 0)

-    Examples
+    Examples
+    --------
+    .. code-block:: python

         >>> import brainstate.nn as nn
-        >>> import brainstate
+        >>> import brainstate
         >>> m = nn.LogSoftmax(dim=1)
-        >>> x = random.randn(2, 3)
+        >>> x = brainstate.random.randn(2, 3)
         >>> output = m(x)
     """
     __module__ = 'brainstate.nn'
@@ -1082,6 +1235,16 @@ class LogSoftmax(ElementWiseBlock):

 class Identity(ElementWiseBlock):
     r"""A placeholder identity operator that is argument-insensitive.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate.nn as nn
+        >>> m = nn.Identity()
+        >>> x = brainstate.random.randn(2, 3)
+        >>> output = m(x)
+        >>> assert (output == x).all()
     """
     __module__ = 'brainstate.nn'

@@ -1103,17 +1266,33 @@ class SpikeBitwise(ElementWiseBlock):
         \hline
         \end{array}

-
-
-
+    Parameters
+    ----------
+    op : str, optional
+        The bitwise operation. Default: 'add'
+    name : str, optional
+        The name of the dynamic system.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate.nn as nn
+        >>> m = nn.SpikeBitwise(op='and')
+        >>> x = brainstate.random.randn(2, 3) > 0
+        >>> y = brainstate.random.randn(2, 3) > 0
+        >>> output = m(x, y)
     """
     __module__ = 'brainstate.nn'

-    def __init__(
-
-
+    def __init__(
+        self,
+        op: str = 'add',
+        name: Optional[str] = None
+    ) -> None:
         super().__init__(name=name)
         self.op = op

     def __call__(self, x, y):
-
+        import braintools
+        return braintools.spike_bitwise(x, y, self.op)