brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/nn/_elementwise.py
DELETED
@@ -1,1438 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from __future__ import annotations
|
19
|
-
|
20
|
-
from typing import Optional
|
21
|
-
|
22
|
-
import brainunit as u
|
23
|
-
import jax.numpy as jnp
|
24
|
-
import jax.typing
|
25
|
-
|
26
|
-
from ._base import ElementWiseBlock
|
27
|
-
from brainstate import environ, random, functional as F
|
28
|
-
from brainstate._module import Module
|
29
|
-
from brainstate._state import ParamState
|
30
|
-
from brainstate.mixin import Mode
|
31
|
-
from brainstate.typing import ArrayLike
|
32
|
-
|
33
|
-
# Public API of this module, in export order.
__all__ = [
    # --- activation functions ---
    'Threshold',
    'ReLU',
    'RReLU',
    'Hardtanh',
    'ReLU6',
    'Sigmoid',
    'Hardsigmoid',
    'Tanh',
    'SiLU',
    'Mish',
    'Hardswish',
    'ELU',
    'CELU',
    'SELU',
    'GLU',
    'GELU',
    'Hardshrink',
    'LeakyReLU',
    'LogSigmoid',
    'Softplus',
    'Softshrink',
    'PReLU',
    'Softsign',
    'Tanhshrink',
    'Softmin',
    'Softmax',
    'Softmax2d',
    'LogSoftmax',
    # --- dropout ---
    'Dropout',
    'Dropout1d',
    'Dropout2d',
    'Dropout3d',
    'AlphaDropout',
    'FeatureAlphaDropout',
    # --- others ---
    'Identity',
    'SpikeBitwise',
]
|
46
|
-
|
47
|
-
|
48
|
-
class Threshold(Module, ElementWiseBlock):
    r"""Thresholds each element of the input array.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at.
        value: The value to replace with. It is cast to the input's dtype so
            the output keeps the same dtype as the input.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Threshold(0.1, 20)
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    threshold: float
    value: float

    def __init__(self, threshold: float, value: float) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value

    def __call__(self, x: ArrayLike) -> ArrayLike:
        dtype = u.math.get_dtype(x)
        # Compare against the raw threshold instead of first casting it to
        # ``x``'s dtype. For integer inputs that cast truncates toward zero
        # (e.g. -0.5 -> 0), which would wrongly replace elements equal to 0.
        # Comparing with the Python scalar lets JAX promote to a float dtype.
        mask = x > self.threshold
        # Only the replacement value is cast, so the result keeps x's dtype.
        return jnp.where(mask, x, jnp.asarray(self.value, dtype=dtype))

    def __repr__(self):
        return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
|
93
|
-
|
94
|
-
|
95
|
-
class ReLU(Module, ElementWiseBlock):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.ReLU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> import jax.numpy as jnp
        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.ReLU()
        >>> x = jnp.expand_dims(bst.random.randn(2), 0)
        >>> output = jnp.concatenate((m(x), m(-x)))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.relu(x)

    def __repr__(self):
        return f'{self.__class__.__name__}()'
|
128
|
-
|
129
|
-
|
130
|
-
class RReLU(Module, ElementWiseBlock):
    r"""Randomized leaky rectified linear unit, applied element-wise.

    Introduced in the paper
    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
        x & \text{if } x \geq 0 \\
        ax & \text{ otherwise }
        \end{cases}

    where the negative slope :math:`a` is sampled from the uniform
    distribution :math:`\mathcal{U}(\text{lower}, \text{upper})`.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.RReLU(0.1, 0.3)
        >>> x = bst.random.randn(2)
        >>> output = m(x)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """
    __module__ = 'brainstate.nn'
    lower: float
    upper: float

    def __init__(self, lower: float = 1. / 8, upper: float = 1. / 3):
        super().__init__()
        self.lower, self.upper = lower, upper

    def __call__(self, x: ArrayLike) -> ArrayLike:
        # The slope range is forwarded positionally as (lower, upper).
        bounds = (self.lower, self.upper)
        return F.rrelu(x, *bounds)

    def extra_repr(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(lower={self.lower}, upper={self.upper})'
|
187
|
-
|
188
|
-
|
189
|
-
class Hardtanh(Module, ElementWiseBlock):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1

    Raises:
        ValueError: if ``max_val`` is not strictly greater than ``min_val``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Hardtanh(-2, 2)
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    min_val: float
    max_val: float

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
    ) -> None:
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept an empty/inverted range.
        if not self.max_val > self.min_val:
            raise ValueError(
                f'max_val ({self.max_val}) must be greater than min_val ({self.min_val}).'
            )

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.hard_tanh(x, self.min_val, self.max_val)

    def extra_repr(self) -> str:
        return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
|
239
|
-
|
240
|
-
|
241
|
-
class ReLU6(Hardtanh, ElementWiseBlock):
    r"""Rectified linear unit clipped at six, applied element-wise:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    This is a :class:`Hardtanh` whose linear region is fixed to ``[0, 6]``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.ReLU6()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __init__(self):
        super().__init__(min_val=0., max_val=6.)
|
263
|
-
|
264
|
-
|
265
|
-
class Sigmoid(Module, ElementWiseBlock):
    r"""Logistic sigmoid, applied element-wise:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}


    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Sigmoid()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.sigmoid(x)
        return activated
|
288
|
-
|
289
|
-
|
290
|
-
class Hardsigmoid(Module, ElementWiseBlock):
    r"""Piecewise-linear approximation of the sigmoid, applied element-wise.

    Hardsigmoid is:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Hardsigmoid()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.hard_sigmoid(x)
        return activated
|
318
|
-
|
319
|
-
|
320
|
-
class Tanh(Module, ElementWiseBlock):
    r"""Hyperbolic tangent, applied element-wise.

    Tanh is:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Tanh()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.tanh(x)
        return activated
|
344
|
-
|
345
|
-
|
346
|
-
class SiLU(Module, ElementWiseBlock):
    r"""Sigmoid Linear Unit (SiLU, also called swish), applied element-wise.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        The SiLU was originally coined in
        `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_,
        and later experimented with in
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and
        `Swish: a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.SiLU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.silu(x)
        return activated
|
376
|
-
|
377
|
-
|
378
|
-
class Mish(Module, ElementWiseBlock):
    r"""Mish activation, applied element-wise.

    Mish is a self-regularized, non-monotonic activation:

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_


    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Mish()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.mish(x)
        return activated
|
405
|
-
|
406
|
-
|
407
|
-
class Hardswish(Module, ElementWiseBlock):
    r"""Hardswish activation, applied element-wise.

    Introduced in `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}


    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Hardswish()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.hard_swish(x)
        return activated
|
437
|
-
|
438
|
-
|
439
|
-
class ELU(Module, ElementWiseBlock):
    r"""Exponential Linear Unit (ELU), applied element-wise.

    Introduced in `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.ELU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    alpha: float

    def __init__(self, alpha: float = 1.) -> None:
        super().__init__()
        self.alpha = alpha

    def __call__(self, x: ArrayLike) -> ArrayLike:
        scale = self.alpha
        return F.elu(x, scale)

    def extra_repr(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha})'
|
479
|
-
|
480
|
-
|
481
|
-
class CELU(Module, ElementWiseBlock):
    r"""Continuously differentiable ELU (CELU), applied element-wise:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    See the paper `Continuously Differentiable Exponential Linear Units`_ for details.

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.CELU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """
    __module__ = 'brainstate.nn'
    alpha: float

    def __init__(self, alpha: float = 1.) -> None:
        super().__init__()
        self.alpha = alpha

    def __call__(self, x: ArrayLike) -> ArrayLike:
        scale = self.alpha
        return F.celu(x, scale)

    def extra_repr(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha})'
|
519
|
-
|
520
|
-
|
521
|
-
class SELU(Module, ElementWiseBlock):
    r"""Scaled Exponential Linear Unit (SELU), applied element-wise:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    where :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    See the paper `Self-Normalizing Neural Networks`_ for details.


    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.SELU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.selu(x)
        return activated
|
551
|
-
|
552
|
-
|
553
|
-
class GLU(Module, ElementWiseBlock):
    r"""Gated linear unit.

    Computes :math:`{GLU}(a, b)= a \otimes \sigma(b)`, where :math:`a` is the
    first half of the input along ``dim`` and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.GLU()
        >>> x = bst.random.randn(4, 2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def __call__(self, x: ArrayLike) -> ArrayLike:
        split_axis = self.dim
        return F.glu(x, split_axis)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(dim={self.dim})'
|
586
|
-
|
587
|
-
|
588
|
-
class GELU(Module, ElementWiseBlock):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When ``approximate`` is enabled, GELU is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (bool or str, optional): whether to use the tanh
            approximation. Accepts a bool, or one of the strings ``'tanh'``
            (equivalent to ``True``) and ``'none'`` (equivalent to ``False``),
            case-insensitive. Default: ``False``

    Raises:
        ValueError: if ``approximate`` is a string other than ``'tanh'`` or ``'none'``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.GELU()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    approximate: bool

    def __init__(self, approximate: bool | str = False) -> None:
        super().__init__()
        # Accept the string spellings previously documented for this argument
        # and normalize them to the bool expected by ``F.gelu``.
        if isinstance(approximate, str):
            mode = approximate.lower()
            if mode not in ('tanh', 'none'):
                raise ValueError(
                    f"approximate must be a bool, 'tanh' or 'none', got {approximate!r}."
                )
            approximate = mode == 'tanh'
        self.approximate = approximate

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.gelu(x, approximate=self.approximate)

    def __repr__(self):
        return f'{self.__class__.__name__}(approximate={self.approximate})'
|
627
|
-
|
628
|
-
|
629
|
-
class Hardshrink(Module, ElementWiseBlock):
    r"""Hard shrinkage, applied element-wise.

    Hardshrink is:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> import brainstate as bst
        >>> import brainstate.nn as nn
        >>> m = nn.Hardshrink()
        >>> x = bst.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def __call__(self, x: ArrayLike) -> ArrayLike:
        cutoff = self.lambd
        return F.hard_shrink(x, cutoff)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(lambd={self.lambd})'
|
669
|
-
|
670
|
-
|
671
|
-
class LeakyReLU(Module, ElementWiseBlock):
  r"""Applies the element-wise leaky rectified linear unit.

  .. math::
      \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

  or, equivalently,

  .. math::
      \text{LeakyReLU}(x) =
      \begin{cases}
      x, & \text{ if } x \geq 0 \\
      \text{negative\_slope} \times x, & \text{ otherwise }
      \end{cases}

  Args:
    negative_slope: Controls the angle of the negative slope (which is used for
      negative input values). Default: 1e-2

  Shape:
    - Input: :math:`(*)` where `*` means any number of additional dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.LeakyReLU(0.1)
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  negative_slope: float

  def __init__(self, negative_slope: float = 1e-2) -> None:
    super().__init__()
    self.negative_slope = negative_slope

  def __call__(self, x: ArrayLike) -> ArrayLike:
    slope = self.negative_slope
    return F.leaky_relu(x, slope)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(negative_slope={self.negative_slope})'
|
716
|
-
|
717
|
-
|
718
|
-
class LogSigmoid(Module, ElementWiseBlock):
  r"""Applies the element-wise logarithm of the sigmoid:

  .. math::
      \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

  Shape:
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.LogSigmoid()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'

  def __call__(self, x: ArrayLike) -> ArrayLike:
    result = F.log_sigmoid(x)
    return result
|
740
|
-
|
741
|
-
|
742
|
-
class Softplus(Module, ElementWiseBlock):
  r"""Applies the Softplus function element-wise:

  .. math::
      \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

  Softplus is a smooth approximation to the ReLU function. Because its output
  is always positive, it can be used to constrain a model's output to be positive.

  For numerical stability, the implementation reverts to the linear function
  when :math:`input \times \beta > threshold`.

  Args:
    beta: the :math:`\beta` value for the Softplus formulation. Default: 1
    threshold: values above this revert to a linear function. Default: 20

  Shape:
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softplus()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  beta: float
  threshold: float

  def __init__(self, beta: float = 1, threshold: float = 20.) -> None:
    super().__init__()
    self.beta = beta
    self.threshold = threshold

  def __call__(self, x: ArrayLike) -> ArrayLike:
    beta, threshold = self.beta, self.threshold
    return F.softplus(x, beta, threshold)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(beta={self.beta}, threshold={self.threshold})'
|
782
|
-
|
783
|
-
|
784
|
-
class Softshrink(Module, ElementWiseBlock):
  r"""Applies the soft shrinkage function element-wise:

  .. math::
      \text{SoftShrinkage}(x) =
      \begin{cases}
      x - \lambda, & \text{ if } x > \lambda \\
      x + \lambda, & \text{ if } x < -\lambda \\
      0, & \text{ otherwise }
      \end{cases}

  Args:
    lambd: the :math:`\lambda` (must be no less than zero) value for the
      Softshrink formulation. Default: 0.5

  Shape:
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softshrink()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  lambd: float

  def __init__(self, lambd: float = 0.5) -> None:
    super().__init__()
    self.lambd = lambd

  def __call__(self, x: ArrayLike) -> ArrayLike:
    threshold = self.lambd
    return F.soft_shrink(x, threshold)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(lambd={self.lambd})'
|
822
|
-
|
823
|
-
|
824
|
-
class PReLU(Module, ElementWiseBlock):
  r"""Applies the parametric ReLU element-wise:

  .. math::
      \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

  or, equivalently,

  .. math::
      \text{PReLU}(x) =
      \begin{cases}
      x, & \text{ if } x \geq 0 \\
      ax, & \text{ otherwise }
      \end{cases}

  Here :math:`a` is a learnable parameter stored in ``self.weight`` (a
  ``ParamState``). ``bst.nn.PReLU()`` learns a single :math:`a` shared by all
  inputs; ``bst.nn.PReLU(n)`` learns ``n`` values.

  .. note::
      Weight decay should not be used when learning :math:`a` for good performance.

  .. note::
      How the ``num_parameters`` weights broadcast against the input is decided
      by ``F.prelu``. The PyTorch convention puts the channels on the 2nd input
      dimension (one channel when the input has fewer than 2 dims).
      TODO(review): confirm ``F.prelu`` follows the same convention.

  Args:
    num_parameters (int): number of :math:`a` to learn. Although any int is
      accepted, only two values are meaningful: 1, or the number of input
      channels. Default: 1
    init (float): the initial value of :math:`a`. Default: 0.25
    dtype: dtype of the learnable weights. Default: None (JAX default dtype).

  Shape:
    - Input: :math:`(*)` where `*` means any number of additional dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Attributes:
    weight (ParamState): the learnable weights, of shape ``(num_parameters,)``.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.PReLU()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  num_parameters: int

  def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
    super().__init__()
    self.num_parameters = num_parameters
    # every slope starts at ``init``
    initial_slopes = jnp.ones(num_parameters, dtype=dtype) * init
    self.weight = ParamState(initial_slopes)

  def __call__(self, x: ArrayLike) -> ArrayLike:
    slopes = self.weight.value
    return F.prelu(x, slopes)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(num_parameters={self.num_parameters})'
|
885
|
-
|
886
|
-
|
887
|
-
class Softsign(Module, ElementWiseBlock):
  r"""Applies the element-wise softsign function:

  .. math::
      \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

  Shape:
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softsign()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'

  def __call__(self, x: ArrayLike) -> ArrayLike:
    result = F.soft_sign(x)
    return result
|
909
|
-
|
910
|
-
|
911
|
-
class Tanhshrink(Module, ElementWiseBlock):
  r"""Applies the element-wise tanh-shrink function:

  .. math::
      \text{Tanhshrink}(x) = x - \tanh(x)

  Shape:
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Tanhshrink()
    >>> x = bst.random.randn(2)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'

  def __call__(self, x: ArrayLike) -> ArrayLike:
    result = F.tanh_shrink(x)
    return result
|
933
|
-
|
934
|
-
|
935
|
-
class Softmin(Module, ElementWiseBlock):
  r"""Applies the Softmin function to an n-dimensional input.

  The input is rescaled so that the elements of the output lie in the range
  `[0, 1]` and sum to 1 along ``dim``. Softmin is defined as:

  .. math::
      \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

  Shape:
    - Input: :math:`(*)` where `*` means any number of additional dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Args:
    dim (int): A dimension along which Softmin will be computed (so every slice
      along ``dim`` will sum to 1). Default: None, which is passed to
      ``F.softmin`` unchanged.

  Returns:
    an array of the same dimension and shape as the input, with values in the
    range [0, 1].

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softmin(dim=1)
    >>> x = bst.random.randn(2, 3)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  dim: Optional[int]

  def __init__(self, dim: Optional[int] = None) -> None:
    super().__init__()
    self.dim = dim

  def __call__(self, x: ArrayLike) -> ArrayLike:
    axis = self.dim
    return F.softmin(x, axis)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(dim={self.dim})'
|
978
|
-
|
979
|
-
|
980
|
-
class Softmax(Module, ElementWiseBlock):
  r"""Applies the Softmax function to an n-dimensional input.

  The input is rescaled so that the elements of the output lie in the range
  [0, 1] and sum to 1 along ``dim``. Softmax is defined as:

  .. math::
      \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Shape:
    - Input: :math:`(*)` where `*` means any number of additional dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Returns:
    an array of the same dimension and shape as the input, with values in the
    range [0, 1].

  Args:
    dim (int): A dimension along which Softmax will be computed (so every slice
      along ``dim`` will sum to 1). Default: None, which is passed to
      ``F.softmax`` unchanged.

  .. note::
      If the next step takes the logarithm of the probabilities (e.g. a
      negative log-likelihood loss), use `LogSoftmax` instead: it is faster
      and has better numerical properties.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softmax(dim=1)
    >>> x = bst.random.randn(2, 3)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  dim: Optional[int]

  def __init__(self, dim: Optional[int] = None) -> None:
    super().__init__()
    self.dim = dim

  def __call__(self, x: ArrayLike) -> ArrayLike:
    axis = self.dim
    return F.softmax(x, axis)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(dim={self.dim})'
|
1032
|
-
|
1033
|
-
|
1034
|
-
class Softmax2d(Module, ElementWiseBlock):
  r"""Applies SoftMax over features at each spatial location.

  For an image of ``Channels x Height x Width``, `Softmax` is applied over the
  channel axis at every location :math:`(Channels, h_i, w_j)`. The softmax is
  taken over axis ``-3``, which is the channel axis of both accepted layouts.

  Shape:
    - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
    - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).

  Returns:
    an array of the same dimension and shape as the input, with values in the
    range [0, 1].

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Softmax2d()
    >>> # softmax over the channel dimension (axis -3)
    >>> x = bst.random.randn(2, 3, 12, 13)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'

  def __call__(self, x: ArrayLike) -> ArrayLike:
    assert x.ndim in (3, 4), 'Softmax2d requires a 3D or 4D tensor as input'
    channel_axis = -3
    return F.softmax(x, channel_axis)
|
1062
|
-
|
1063
|
-
|
1064
|
-
class LogSoftmax(Module, ElementWiseBlock):
  r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input.

  The LogSoftmax formulation can be simplified as:

  .. math::
      \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

  Shape:
    - Input: :math:`(*)` where `*` means any number of additional dimensions.
    - Output: :math:`(*)`, same shape as the input.

  Args:
    dim (int): A dimension along which LogSoftmax will be computed.
      Default: None, which is passed to ``F.log_softmax`` unchanged.

  Returns:
    an array of the same dimension and shape as the input, with values in the
    range :math:`(-\infty, 0]`. Softmax outputs lie in :math:`(0, 1]`, so
    their logarithms are at most 0. The value 0 is reached when a slice has a
    single element.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.LogSoftmax(dim=1)
    >>> x = bst.random.randn(2, 3)
    >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
  dim: Optional[int]

  def __init__(self, dim: Optional[int] = None) -> None:
    super().__init__()
    self.dim = dim

  def __call__(self, x: ArrayLike) -> ArrayLike:
    return F.log_softmax(x, self.dim)

  def __repr__(self):
    return f'{self.__class__.__name__}(dim={self.dim})'
|
1103
|
-
|
1104
|
-
|
1105
|
-
class Identity(Module, ElementWiseBlock):
  r"""A placeholder identity operator that is argument-insensitive.

  Calling the module returns its input object unchanged (not a copy).
  """
  __module__ = 'brainstate.nn'

  def __call__(self, x):
    output = x
    return output
|
1112
|
-
|
1113
|
-
|
1114
|
-
class Dropout(Module, ElementWiseBlock):
  """A layer that stochastically ignores a subset of inputs each training step.

  During fitting, each element is kept independently with probability ``prob``
  and zeroed otherwise. Surviving values are multiplied by ``1 / prob`` to
  compensate for the dropped fraction.

  The layer is active only while ``brainstate.environ.get('fit')`` is truthy.
  Otherwise, or when ``prob == 1``, it returns its input unchanged.

  .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
         neural networks from overfitting." The journal of machine learning
         research 15.1 (2014): 1929-1958.

  Args:
    prob: Probability to keep an element of the tensor, in ``[0, 1]``.
      ``prob=0`` zeroes the whole input during fitting.
    mode: Mode. The computation mode of the object.
    name: str. The name of the dynamic system.
  """
  __module__ = 'brainstate.nn'

  def __init__(
      self,
      prob: float = 0.5,
      mode: Optional[Mode] = None,
      name: Optional[str] = None
  ) -> None:
    super().__init__(mode=mode, name=name)
    assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
    self.prob = prob

  def __call__(self, x):
    dtype = u.math.get_dtype(x)
    fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
    if not fit_phase or self.prob >= 1.:
      return x
    if self.prob == 0.:
      # Every element is dropped. Do not compute ``x / 0``: even though
      # ``jnp.where`` would discard it in the forward pass, its gradient
      # (0 * inf) would be NaN.
      return jnp.zeros(x.shape, dtype=dtype)
    keep_mask = random.bernoulli(self.prob, x.shape)
    return jnp.where(keep_mask,
                     jnp.asarray(x / self.prob, dtype=dtype),
                     jnp.asarray(0., dtype=dtype))
|
1155
|
-
|
1156
|
-
|
1157
|
-
class _DropoutNd(Module, ElementWiseBlock):
  """Base class for channel-wise dropout layers.

  During fitting (``brainstate.environ.get('fit')`` truthy), one Bernoulli
  sample with keep-probability ``prob`` is drawn per channel (and, in batch
  mode, per sample along axis 0). The sample is broadcast over the remaining
  axes. Kept values are scaled by ``1 / prob`` and dropped values become zero.
  Outside fitting, or when ``prob == 1``, the input is returned unchanged.

  Subclasses must set ``minimal_dim``: the exact number of dimensions an
  input must have in non-batch mode.

  Args:
    prob: Probability of keeping a channel, in ``[0, 1]``. ``prob=1`` disables
      dropout; ``prob=0`` zeroes the whole input during fitting.
    channel_axis: Axis holding the channels, counted as if there were no batch
      axis. In batch mode, a non-negative value is shifted by one to skip the
      leading batch axis. Default: -1 (the last axis).
    mode: The computation mode of the object.
    name: The name of the object.
  """
  __module__ = 'brainstate.nn'
  prob: float
  channel_axis: int
  minimal_dim: int

  def __init__(
      self,
      prob: float = 0.5,
      channel_axis: int = -1,
      mode: Optional[Mode] = None,
      name: Optional[str] = None
  ) -> None:
    super().__init__(mode=mode, name=name)
    # ``prob`` is a *keep* probability (see ``__call__``), so 1 is the valid
    # "no dropout" value. The former check ``0 <= prob < 1`` rejected it.
    assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
    self.prob = prob
    self.channel_axis = channel_axis

  def __call__(self, x):
    dtype = u.math.get_dtype(x)
    # get fit phase
    fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

    # Check the input shape and build the mask shape: size 1 on every axis
    # except the channel axis (and the batch axis 0 in batch mode).
    if self.mode.is_nonbatch_mode():
      assert x.ndim == self.minimal_dim, f"Input tensor must be {self.minimal_dim}D. But got {x.ndim}D."
      channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
      mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
    else:
      channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
      assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
      mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]

    if not fit_phase or self.prob >= 1.:
      return x
    if self.prob == 0.:
      # Everything is dropped. Avoid ``x / 0``, whose gradient through
      # ``jnp.where`` would be NaN.
      return jnp.zeros(x.shape, dtype=dtype)

    # generate mask
    keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
    return jnp.where(keep_mask,
                     jnp.asarray(x / self.prob, dtype=dtype),
                     jnp.asarray(0., dtype=dtype))

  def __repr__(self) -> str:
    return f'{self.__class__.__name__}(prob={self.prob}, channel_axis={self.channel_axis})'
|
1201
|
-
|
1202
|
-
|
1203
|
-
class Dropout1d(_DropoutNd):
  r"""Randomly zero out entire channels of 1D feature maps.

  During fitting, each channel is kept or zeroed independently on every
  forward call. The decision is drawn from a Bernoulli distribution with
  keep-probability ``prob``, and kept values are scaled by ``1 / prob``.

  Usually the input comes from :class:`nn.Conv1d` modules.

  As described in the paper
  `Efficient Object Localization Using Convolutional Networks`_, adjacent
  pixels within feature maps are often strongly correlated (as is normally the
  case in early convolution layers). In that case i.i.d. dropout will not
  regularize the activations and will just decrease the effective learning
  rate. :class:`nn.Dropout1d` promotes independence between feature maps and
  should be used instead.

  Args:
    prob: float. Probability of a channel to be kept. Default: 0.5
    channel_axis: int. Axis holding the channels, counted without the batch
      axis. Default: -1 (the last axis).

  Shape:
    - Input: :math:`(N, L, C)` in batch mode or :math:`(L, C)` in non-batch
      mode (with the default ``channel_axis=-1``). In non-batch mode the input
      must be exactly 2D.
    - Output: same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Dropout1d(prob=0.8)
    >>> x = bst.random.randn(20, 16, 32)  # batched input (requires a batching mode)
    >>> output = m(x)
    >>> output.shape
    (20, 16, 32)

  .. _Efficient Object Localization Using Convolutional Networks:
     https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 2
|
1242
|
-
|
1243
|
-
|
1244
|
-
class Dropout2d(_DropoutNd):
  r"""Randomly zero out entire channels of 2D feature maps.

  During fitting, each channel is kept or zeroed independently on every
  forward call. The decision is drawn from a Bernoulli distribution with
  keep-probability ``prob``, and kept values are scaled by ``1 / prob``.

  Usually the input comes from :class:`nn.Conv2d` modules.

  As described in the paper
  `Efficient Object Localization Using Convolutional Networks`_, adjacent
  pixels within feature maps are often strongly correlated (as is normally the
  case in early convolution layers). In that case i.i.d. dropout will not
  regularize the activations and will just decrease the effective learning
  rate. :class:`nn.Dropout2d` promotes independence between feature maps and
  should be used instead.

  Args:
    prob: float. Probability of a channel to be kept. Default: 0.5
    channel_axis: int. Axis holding the channels, counted without the batch
      axis. Default: -1 (the last axis).

  Shape:
    - Input: :math:`(N, H, W, C)` in batch mode or :math:`(H, W, C)` in
      non-batch mode (with the default ``channel_axis=-1``). In non-batch mode
      the input must be exactly 3D.
    - Output: same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Dropout2d(prob=0.8)
    >>> x = bst.random.randn(20, 32, 32, 16)  # batched input (requires a batching mode)
    >>> output = m(x)

  .. _Efficient Object Localization Using Convolutional Networks:
     https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 3
|
1281
|
-
|
1282
|
-
|
1283
|
-
class Dropout3d(_DropoutNd):
  r"""Randomly zero out entire channels of 3D feature maps.

  During fitting, each channel is kept or zeroed independently on every
  forward call. The decision is drawn from a Bernoulli distribution with
  keep-probability ``prob``, and kept values are scaled by ``1 / prob``.

  Usually the input comes from :class:`nn.Conv3d` modules.

  As described in the paper
  `Efficient Object Localization Using Convolutional Networks`_, adjacent
  pixels within feature maps are often strongly correlated (as is normally the
  case in early convolution layers). In that case i.i.d. dropout will not
  regularize the activations and will just decrease the effective learning
  rate. :class:`nn.Dropout3d` promotes independence between feature maps and
  should be used instead.

  Args:
    prob: float. Probability of a channel to be kept. Default: 0.5
    channel_axis: int. Axis holding the channels, counted without the batch
      axis. Default: -1 (the last axis).

  Shape:
    - Input: :math:`(N, D, H, W, C)` in batch mode or :math:`(D, H, W, C)` in
      non-batch mode (with the default ``channel_axis=-1``). In non-batch mode
      the input must be exactly 4D.
    - Output: same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.Dropout3d(prob=0.8)
    >>> x = bst.random.randn(20, 4, 32, 32, 16)  # batched input (requires a batching mode)
    >>> output = m(x)

  .. _Efficient Object Localization Using Convolutional Networks:
     https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 4
|
1320
|
-
|
1321
|
-
|
1322
|
-
class AlphaDropout(_DropoutNd):
  r"""Applies Alpha Dropout over the input.

  Alpha Dropout is a type of Dropout that maintains the self-normalizing
  property. For an input with zero mean and unit standard deviation, the
  output keeps the original mean and standard deviation. Alpha Dropout goes
  hand-in-hand with the SELU activation function, which ensures that the
  outputs have zero mean and unit standard deviation.

  During fitting (``brainstate.environ.get('fit')`` truthy), each element is
  kept with probability ``prob`` (``q`` below). Dropped elements are set to
  the SELU negative saturation value :math:`\alpha' = -1.7580993408473766`.
  The result then goes through the affine map :math:`a y + b`, with
  :math:`a = (q (1 + \alpha'^2 (1 - q)))^{-1/2}` and
  :math:`b = -a \alpha' (1 - q)`. Outside fitting, or when ``prob == 1``, the
  module computes the identity. With ``prob == 0`` the output is all zeros.

  More details can be found in the paper `Self-Normalizing Neural Networks`_.

  Args:
    prob: float. Probability of an element to be kept. Default: 0.5
    channel_axis: accepted for signature compatibility with the base class;
      unused, since the mask is element-wise.

  Shape:
    - Input: :math:`(*)`. Input can be of any shape.
    - Output: :math:`(*)`. Output is of the same shape as input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.AlphaDropout(prob=0.8)
    >>> x = bst.random.randn(20, 16)
    >>> output = m(x)

  .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  """
  __module__ = 'brainstate.nn'

  # SELU negative saturation value: -scale * alpha.
  _alpha_prime: float = -1.7580993408473766

  def __call__(self, x):
    """Apply alpha dropout to ``x`` during fitting; otherwise return ``x``."""
    dtype = u.math.get_dtype(x)
    fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
    if not fit_phase or self.prob >= 1.:
      return x
    if self.prob == 0.:
      # Everything is dropped; the affine correction is undefined (a -> inf).
      return jnp.zeros(x.shape, dtype=dtype)

    keep = self.prob
    drop = 1. - keep
    alpha = self._alpha_prime
    scale = (keep * (1. + alpha * alpha * drop)) ** -0.5
    shift = -scale * alpha * drop
    keep_mask = random.bernoulli(keep, x.shape)
    out = jnp.where(keep_mask, scale * x + shift, scale * alpha + shift)
    return jnp.asarray(out, dtype=dtype)

  def forward(self, x):
    """Alias of :meth:`__call__`, kept for backward compatibility."""
    return self.__call__(x)
|
1361
|
-
|
1362
|
-
|
1363
|
-
class FeatureAlphaDropout(_DropoutNd):
  r"""Randomly masks out entire channels with alpha dropout.

  A channel is a feature map: the slice of the input along ``channel_axis``
  for a given sample. Instead of setting activations to zero, as in regular
  Dropout, masked activations are set to the negative saturation value of the
  SELU activation function, :math:`\alpha' = -1.7580993408473766`. The result
  is then scaled and shifted by
  :math:`a = (q (1 + \alpha'^2 (1 - q)))^{-1/2}` and
  :math:`b = -a \alpha' (1 - q)` (``q = prob``) to maintain zero mean and
  unit variance. More details can be found in the paper
  `Self-Normalizing Neural Networks`_.

  During fitting (``brainstate.environ.get('fit')`` truthy), each channel is
  kept with probability ``prob``, independently for each sample on every
  forward call. Outside fitting, or when ``prob == 1``, the input is returned
  unchanged. With ``prob == 0`` the output is all zeros.

  As described in the paper
  `Efficient Object Localization Using Convolutional Networks`_, adjacent
  pixels within feature maps are often strongly correlated (as is normally the
  case in early convolution layers). In that case i.i.d. dropout will not
  regularize the activations and will just decrease the effective learning
  rate. :class:`nn.FeatureAlphaDropout` promotes independence between feature
  maps and should be used instead.

  Args:
    prob: float. Probability of a channel to be kept. Default: 0.5
    channel_axis: int. Axis holding the channels, counted without the batch
      axis. Default: -1 (the last axis).

  Shape:
    - Input: :math:`(N, *, C)` in batch mode or :math:`(*, C)` in non-batch
      mode (with the default ``channel_axis=-1``).
    - Output: same shape as the input.

  Examples::

    >>> import brainstate as bst
    >>> m = bst.nn.FeatureAlphaDropout(prob=0.8)
    >>> x = bst.random.randn(20, 4, 32, 32, 16)
    >>> output = m(x)

  .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  .. _Efficient Object Localization Using Convolutional Networks:
     https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'

  # SELU negative saturation value: -scale * alpha.
  _alpha_prime: float = -1.7580993408473766

  def _mask_shape(self, x):
    """Return the shape of a per-channel (and per-sample in batch mode) mask for ``x``."""
    if self.mode.is_nonbatch_mode():
      channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
      return [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
    channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
    assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
    return [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]

  def __call__(self, x):
    """Apply channel-wise alpha dropout to ``x`` during fitting; otherwise return ``x``."""
    dtype = u.math.get_dtype(x)
    fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
    mask_shape = self._mask_shape(x)
    if not fit_phase or self.prob >= 1.:
      return x
    if self.prob == 0.:
      # Everything is dropped; the affine correction is undefined (a -> inf).
      return jnp.zeros(x.shape, dtype=dtype)

    keep = self.prob
    drop = 1. - keep
    alpha = self._alpha_prime
    scale = (keep * (1. + alpha * alpha * drop)) ** -0.5
    shift = -scale * alpha * drop
    keep_mask = jnp.broadcast_to(random.bernoulli(keep, mask_shape), x.shape)
    out = jnp.where(keep_mask, scale * x + shift, scale * alpha + shift)
    return jnp.asarray(out, dtype=dtype)

  def forward(self, x):
    """Alias of :meth:`__call__`, kept for backward compatibility."""
    return self.__call__(x)
|
1409
|
-
|
1410
|
-
|
1411
|
-
class SpikeBitwise(Module, ElementWiseBlock):
|
1412
|
-
r"""Bitwise addition for the spiking inputs.
|
1413
|
-
|
1414
|
-
.. math::
|
1415
|
-
|
1416
|
-
\begin{array}{ccc}
|
1417
|
-
\hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
|
1418
|
-
\hline \text { ADD } & x+y & x+y \\
|
1419
|
-
\text { AND } & x \cap y & x \cdot y \\
|
1420
|
-
\text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
|
1421
|
-
\text { OR } & x \cup y & (x+y)-(x \cdot y) \\
|
1422
|
-
\hline
|
1423
|
-
\end{array}
|
1424
|
-
|
1425
|
-
Args:
|
1426
|
-
op: str. The bitwise operation.
|
1427
|
-
name: str. The name of the dynamic system.
|
1428
|
-
"""
|
1429
|
-
__module__ = 'brainstate.nn'
|
1430
|
-
|
1431
|
-
def __init__(
    self,
    op: str = 'add',
    name: Optional[str] = None
) -> None:
  """Store the name of the bitwise operation used by ``F.spike_bitwise``.

  Args:
    op: operation name forwarded to ``F.spike_bitwise``. Default: ``'add'``.
    name: the name of the dynamic system.
  """
  super().__init__(name=name)
  self.op = op
|
1436
|
-
|
1437
|
-
def __call__(self, x, y):
  """Combine spike inputs ``x`` and ``y`` with the configured bitwise operation."""
  operation = self.op
  return F.spike_bitwise(x, y, operation)
|