brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_activations.py
CHANGED
@@ -1,1100 +1,1100 @@
The hunk shows every one of the module's 1,100 lines as removed and re-added with identical text, so no substantive change to `brainstate/nn/_activations.py` is visible in this release. The module docstring describes it as "Shared neural network activations and other functions." It imports `brainunit` (as `u`), `jax`, `logsumexp` from `jax.scipy.special`, `brainstate.random`, and `ArrayLike` from `brainstate.typing`, and exports the following names through `__all__`:

- `tanh`, `relu`, `squareplus`, `softplus`, `soft_sign`, `sigmoid`, `silu`, `swish`, `log_sigmoid`, `elu`, `leaky_relu`, `hard_tanh`, `celu`, `selu`, `gelu`, `glu`, `logsumexp`, `log_softmax`, `softmax`, `standardize`, `one_hot`, `relu6`, `hard_sigmoid`, `hard_silu`, `hard_swish`
- `hard_shrink`, `rrelu`, `mish`, `soft_shrink`, `prelu`, `tanh_shrink`, `softmin`, `sparse_plus`, `sparse_sigmoid`

Most of these are thin wrappers over the corresponding `u.math` (brainunit) or `jax.nn` routines, documented with NumPy-style docstrings; `softmin`, `tanh_shrink`, `prelu`, `soft_shrink`, `mish`, `rrelu`, `hard_shrink`, and `hard_tanh` are implemented directly with `u.math` primitives, and `swish` and `hard_swish` are aliases of `silu` and `hard_silu`.
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
"""
|
18
|
+
Shared neural network activations and other functions.
|
19
|
+
"""
|
20
|
+
|
21
|
+
from typing import Any, Union, Sequence
|
22
|
+
|
23
|
+
import brainunit as u
|
24
|
+
import jax
|
25
|
+
from jax.scipy.special import logsumexp
|
26
|
+
|
27
|
+
from brainstate import random
|
28
|
+
from brainstate.typing import ArrayLike
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
"tanh",
|
32
|
+
"relu",
|
33
|
+
"squareplus",
|
34
|
+
"softplus",
|
35
|
+
"soft_sign",
|
36
|
+
"sigmoid",
|
37
|
+
"silu",
|
38
|
+
"swish",
|
39
|
+
"log_sigmoid",
|
40
|
+
"elu",
|
41
|
+
"leaky_relu",
|
42
|
+
"hard_tanh",
|
43
|
+
"celu",
|
44
|
+
"selu",
|
45
|
+
"gelu",
|
46
|
+
"glu",
|
47
|
+
"logsumexp",
|
48
|
+
"log_softmax",
|
49
|
+
"softmax",
|
50
|
+
"standardize",
|
51
|
+
"one_hot",
|
52
|
+
"relu6",
|
53
|
+
"hard_sigmoid",
|
54
|
+
"hard_silu",
|
55
|
+
"hard_swish",
|
56
|
+
'hard_shrink',
|
57
|
+
'rrelu',
|
58
|
+
'mish',
|
59
|
+
'soft_shrink',
|
60
|
+
'prelu',
|
61
|
+
'tanh_shrink',
|
62
|
+
'softmin',
|
63
|
+
'sparse_plus',
|
64
|
+
'sparse_sigmoid',
|
65
|
+
]
|
66
|
+
|
67
|
+
|
68
|
+
def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
69
|
+
r"""
|
70
|
+
Hyperbolic tangent activation function.
|
71
|
+
|
72
|
+
Computes the element-wise function:
|
73
|
+
|
74
|
+
.. math::
|
75
|
+
\mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
|
76
|
+
|
77
|
+
Parameters
|
78
|
+
----------
|
79
|
+
x : ArrayLike
|
80
|
+
Input array.
|
81
|
+
|
82
|
+
Returns
|
83
|
+
-------
|
84
|
+
jax.Array or Quantity
|
85
|
+
An array with the same shape as the input.
|
86
|
+
"""
|
87
|
+
return u.math.tanh(x)
|
88
|
+
|
89
|
+
|
90
|
+
def softmin(x, axis=-1):
|
91
|
+
r"""
|
92
|
+
Softmin activation function.
|
93
|
+
|
94
|
+
Applies the Softmin function to an n-dimensional input tensor, rescaling elements
|
95
|
+
so that they lie in the range [0, 1] and sum to 1 along the specified axis.
|
96
|
+
|
97
|
+
.. math::
|
98
|
+
\text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
|
99
|
+
|
100
|
+
Parameters
|
101
|
+
----------
|
102
|
+
x : ArrayLike
|
103
|
+
Input array of any shape.
|
104
|
+
axis : int, optional
|
105
|
+
The axis along which Softmin will be computed. Every slice along this
|
106
|
+
dimension will sum to 1. Default is -1.
|
107
|
+
|
108
|
+
Returns
|
109
|
+
-------
|
110
|
+
jax.Array or Quantity
|
111
|
+
Output array with the same shape as the input.
|
112
|
+
"""
|
113
|
+
unnormalized = u.math.exp(-x)
|
114
|
+
return unnormalized / unnormalized.sum(axis, keepdims=True)
|
115
|
+
|
116
|
+
|
117
|
+
def tanh_shrink(x):
|
118
|
+
r"""
|
119
|
+
Tanh shrink activation function.
|
120
|
+
|
121
|
+
Applies the element-wise function:
|
122
|
+
|
123
|
+
.. math::
|
124
|
+
\text{Tanhshrink}(x) = x - \tanh(x)
|
125
|
+
|
126
|
+
Parameters
|
127
|
+
----------
|
128
|
+
x : ArrayLike
|
129
|
+
Input array.
|
130
|
+
|
131
|
+
Returns
|
132
|
+
-------
|
133
|
+
jax.Array or Quantity
|
134
|
+
Output array with the same shape as the input.
|
135
|
+
"""
|
136
|
+
return x - u.math.tanh(x)
|
137
|
+
|
138
|
+
|
139
|
+
def prelu(x, a=0.25):
|
140
|
+
r"""
|
141
|
+
Parametric Rectified Linear Unit activation function.
|
142
|
+
|
143
|
+
Applies the element-wise function:
|
144
|
+
|
145
|
+
.. math::
|
146
|
+
\text{PReLU}(x) = \max(0,x) + a * \min(0,x)
|
147
|
+
|
148
|
+
or equivalently:
|
149
|
+
|
150
|
+
.. math::
|
151
|
+
\text{PReLU}(x) =
|
152
|
+
\begin{cases}
|
153
|
+
x, & \text{ if } x \geq 0 \\
|
154
|
+
ax, & \text{ otherwise }
|
155
|
+
\end{cases}
|
156
|
+
|
157
|
+
Parameters
|
158
|
+
----------
|
159
|
+
x : ArrayLike
|
160
|
+
Input array.
|
161
|
+
a : float or ArrayLike, optional
|
162
|
+
The negative slope coefficient. Can be a learnable parameter.
|
163
|
+
Default is 0.25.
|
164
|
+
|
165
|
+
Returns
|
166
|
+
-------
|
167
|
+
jax.Array or Quantity
|
168
|
+
Output array with the same shape as the input.
|
169
|
+
|
170
|
+
Notes
|
171
|
+
-----
|
172
|
+
When used in neural network layers, :math:`a` can be a learnable parameter
|
173
|
+
that is optimized during training.
|
174
|
+
"""
|
175
|
+
return u.math.where(x >= 0., x, a * x)
|
176
|
+
|
177
|
+
|
178
|
+
def soft_shrink(x, lambd=0.5):
|
179
|
+
r"""
|
180
|
+
Soft shrinkage activation function.
|
181
|
+
|
182
|
+
Applies the soft shrinkage function element-wise:
|
183
|
+
|
184
|
+
.. math::
|
185
|
+
\text{SoftShrinkage}(x) =
|
186
|
+
\begin{cases}
|
187
|
+
x - \lambda, & \text{ if } x > \lambda \\
|
188
|
+
x + \lambda, & \text{ if } x < -\lambda \\
|
189
|
+
0, & \text{ otherwise }
|
190
|
+
\end{cases}
|
191
|
+
|
192
|
+
Parameters
|
193
|
+
----------
|
194
|
+
x : ArrayLike
|
195
|
+
Input array of any shape.
|
196
|
+
lambd : float, optional
|
197
|
+
The :math:`\lambda` value for the soft shrinkage formulation.
|
198
|
+
Must be non-negative. Default is 0.5.
|
199
|
+
|
200
|
+
Returns
|
201
|
+
-------
|
202
|
+
jax.Array or Quantity
|
203
|
+
Output array with the same shape as the input.
|
204
|
+
"""
|
205
|
+
return u.math.where(
|
206
|
+
x > lambd,
|
207
|
+
x - lambd,
|
208
|
+
u.math.where(
|
209
|
+
x < -lambd,
|
210
|
+
x + lambd,
|
211
|
+
u.Quantity(0., unit=u.get_unit(lambd))
|
212
|
+
)
|
213
|
+
)
|
214
|
+
|
215
|
+
|
216
|
+
def mish(x):
|
217
|
+
r"""
|
218
|
+
Mish activation function.
|
219
|
+
|
220
|
+
Mish is a self-regularized non-monotonic activation function.
|
221
|
+
|
222
|
+
.. math::
|
223
|
+
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
|
224
|
+
|
225
|
+
Parameters
|
226
|
+
----------
|
227
|
+
x : ArrayLike
|
228
|
+
Input array of any shape.
|
229
|
+
|
230
|
+
Returns
|
231
|
+
-------
|
232
|
+
jax.Array or Quantity
|
233
|
+
Output array with the same shape as the input.
|
234
|
+
|
235
|
+
References
|
236
|
+
----------
|
237
|
+
.. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
|
238
|
+
arXiv:1908.08681
|
239
|
+
"""
|
240
|
+
return x * u.math.tanh(softplus(x))
|
241
|
+
|
242
|
+
|
243
|
+
def rrelu(x, lower=0.125, upper=0.3333333333333333):
|
244
|
+
r"""
|
245
|
+
Randomized Leaky Rectified Linear Unit activation function.
|
246
|
+
|
247
|
+
The function is defined as:
|
248
|
+
|
249
|
+
.. math::
|
250
|
+
\text{RReLU}(x) =
|
251
|
+
\begin{cases}
|
252
|
+
x & \text{if } x \geq 0 \\
|
253
|
+
ax & \text{ otherwise }
|
254
|
+
\end{cases}
|
255
|
+
|
256
|
+
where :math:`a` is randomly sampled from uniform distribution
|
257
|
+
:math:`\mathcal{U}(\text{lower}, \text{upper})`.
|
258
|
+
|
259
|
+
Parameters
|
260
|
+
----------
|
261
|
+
x : ArrayLike
|
262
|
+
Input array of any shape.
|
263
|
+
lower : float, optional
|
264
|
+
Lower bound of the uniform distribution for sampling the negative slope.
|
265
|
+
Default is 1/8.
|
266
|
+
upper : float, optional
|
267
|
+
Upper bound of the uniform distribution for sampling the negative slope.
|
268
|
+
Default is 1/3.
|
269
|
+
|
270
|
+
Returns
|
271
|
+
-------
|
272
|
+
jax.Array or Quantity
|
273
|
+
Output array with the same shape as the input.
|
274
|
+
|
275
|
+
References
|
276
|
+
----------
|
277
|
+
.. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
|
278
|
+
in Convolutional Network." arXiv:1505.00853
|
279
|
+
"""
|
280
|
+
a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
|
281
|
+
return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
|
282
|
+
|
283
|
+
|
284
|
+
def hard_shrink(x, lambd=0.5):
|
285
|
+
r"""
|
286
|
+
Hard shrinkage activation function.
|
287
|
+
|
288
|
+
Applies the hard shrinkage function element-wise:
|
289
|
+
|
290
|
+
.. math::
|
291
|
+
\text{HardShrink}(x) =
|
292
|
+
\begin{cases}
|
293
|
+
x, & \text{ if } x > \lambda \\
|
294
|
+
x, & \text{ if } x < -\lambda \\
|
295
|
+
0, & \text{ otherwise }
|
296
|
+
\end{cases}
|
297
|
+
|
298
|
+
Parameters
|
299
|
+
----------
|
300
|
+
x : ArrayLike
|
301
|
+
Input array of any shape.
|
302
|
+
lambd : float, optional
|
303
|
+
The :math:`\lambda` threshold value for the hard shrinkage formulation.
|
304
|
+
Default is 0.5.
|
305
|
+
|
306
|
+
Returns
|
307
|
+
-------
|
308
|
+
jax.Array or Quantity
|
309
|
+
Output array with the same shape as the input.
|
310
|
+
"""
|
311
|
+
return u.math.where(
|
312
|
+
x > lambd,
|
313
|
+
x,
|
314
|
+
u.math.where(
|
315
|
+
x < -lambd,
|
316
|
+
x,
|
317
|
+
u.Quantity(0., unit=u.get_unit(x))
|
318
|
+
)
|
319
|
+
)
|
320
|
+
|
321
|
+
|
322
|
+
def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
323
|
+
r"""
|
324
|
+
Rectified Linear Unit activation function.
|
325
|
+
|
326
|
+
Computes the element-wise function:
|
327
|
+
|
328
|
+
.. math::
|
329
|
+
\mathrm{relu}(x) = \max(x, 0)
|
330
|
+
|
331
|
+
Under differentiation, we take:
|
332
|
+
|
333
|
+
.. math::
|
334
|
+
\nabla \mathrm{relu}(0) = 0
|
335
|
+
|
336
|
+
Parameters
|
337
|
+
----------
|
338
|
+
x : ArrayLike
|
339
|
+
Input array.
|
340
|
+
|
341
|
+
Returns
|
342
|
+
-------
|
343
|
+
jax.Array or Quantity
|
344
|
+
An array with the same shape as the input.
|
345
|
+
|
346
|
+
Examples
|
347
|
+
--------
|
348
|
+
.. code-block:: python
|
349
|
+
|
350
|
+
>>> import jax.numpy as jnp
|
351
|
+
>>> import brainstate
|
352
|
+
>>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
|
353
|
+
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
|
354
|
+
|
355
|
+
See Also
|
356
|
+
--------
|
357
|
+
relu6 : ReLU6 activation function.
|
358
|
+
leaky_relu : Leaky ReLU activation function.
|
359
|
+
|
360
|
+
References
|
361
|
+
----------
|
362
|
+
.. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
|
363
|
+
https://openreview.net/forum?id=urrcVI-_jRm
|
364
|
+
"""
|
365
|
+
return u.math.relu(x)
|
366
|
+
|
367
|
+
|
368
|
+
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
|
369
|
+
r"""
|
370
|
+
Squareplus activation function.
|
371
|
+
|
372
|
+
Computes the element-wise function:
|
373
|
+
|
374
|
+
.. math::
|
375
|
+
\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
|
376
|
+
|
377
|
+
Parameters
|
378
|
+
----------
|
379
|
+
x : ArrayLike
|
380
|
+
Input array.
|
381
|
+
b : ArrayLike, optional
|
382
|
+
Smoothness parameter. Default is 4.
|
383
|
+
|
384
|
+
Returns
|
385
|
+
-------
|
386
|
+
jax.Array or Quantity
|
387
|
+
An array with the same shape as the input.
|
388
|
+
|
389
|
+
References
|
390
|
+
----------
|
391
|
+
.. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
|
392
|
+
for Language Modeling." arXiv:2112.11687
|
393
|
+
"""
|
394
|
+
return u.math.squareplus(x, b=b)
|
395
|
+
|
396
|
+
|
397
|
+
def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
398
|
+
r"""
|
399
|
+
Softplus activation function.
|
400
|
+
|
401
|
+
Computes the element-wise function:
|
402
|
+
|
403
|
+
.. math::
|
404
|
+
\mathrm{softplus}(x) = \log(1 + e^x)
|
405
|
+
|
406
|
+
Parameters
|
407
|
+
----------
|
408
|
+
x : ArrayLike
|
409
|
+
Input array.
|
410
|
+
|
411
|
+
Returns
|
412
|
+
-------
|
413
|
+
jax.Array or Quantity
|
414
|
+
An array with the same shape as the input.
|
415
|
+
"""
|
416
|
+
return u.math.softplus(x)
|
417
|
+
|
418
|
+
|
419
|
+
def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
420
|
+
r"""
|
421
|
+
Soft-sign activation function.
|
422
|
+
|
423
|
+
Computes the element-wise function:
|
424
|
+
|
425
|
+
.. math::
|
426
|
+
\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
|
427
|
+
|
428
|
+
Parameters
|
429
|
+
----------
|
430
|
+
x : ArrayLike
|
431
|
+
Input array.
|
432
|
+
|
433
|
+
Returns
|
434
|
+
-------
|
435
|
+
jax.Array or Quantity
|
436
|
+
An array with the same shape as the input.
|
437
|
+
"""
|
438
|
+
return u.math.soft_sign(x)
|
439
|
+
|
440
|
+
|
441
|
+
def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
442
|
+
r"""
|
443
|
+
Sigmoid activation function.
|
444
|
+
|
445
|
+
Computes the element-wise function:
|
446
|
+
|
447
|
+
.. math::
|
448
|
+
\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
|
449
|
+
|
450
|
+
Parameters
|
451
|
+
----------
|
452
|
+
x : ArrayLike
|
453
|
+
Input array.
|
454
|
+
|
455
|
+
Returns
|
456
|
+
-------
|
457
|
+
jax.Array or Quantity
|
458
|
+
An array with the same shape as the input.
|
459
|
+
|
460
|
+
See Also
|
461
|
+
--------
|
462
|
+
log_sigmoid : Logarithm of the sigmoid function.
|
463
|
+
"""
|
464
|
+
return u.math.sigmoid(x)
|
465
|
+
|
466
|
+
|
467
|
+
def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
468
|
+
r"""
|
469
|
+
SiLU (Sigmoid Linear Unit) activation function.
|
470
|
+
|
471
|
+
Computes the element-wise function:
|
472
|
+
|
473
|
+
.. math::
|
474
|
+
\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
|
475
|
+
|
476
|
+
Parameters
|
477
|
+
----------
|
478
|
+
x : ArrayLike
|
479
|
+
Input array.
|
480
|
+
|
481
|
+
Returns
|
482
|
+
-------
|
483
|
+
jax.Array or Quantity
|
484
|
+
An array with the same shape as the input.
|
485
|
+
|
486
|
+
See Also
|
487
|
+
--------
|
488
|
+
sigmoid : The sigmoid function.
|
489
|
+
swish : Alias for silu.
|
490
|
+
|
491
|
+
Notes
|
492
|
+
-----
|
493
|
+
`swish` and `silu` are both aliases for the same function.
|
494
|
+
"""
|
495
|
+
return u.math.silu(x)
|
496
|
+
|
497
|
+
|
498
|
+
swish = silu
|
499
|
+
|
500
|
+
|
501
|
+
def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
|
502
|
+
r"""
|
503
|
+
Log-sigmoid activation function.
|
504
|
+
|
505
|
+
Computes the element-wise function:
|
506
|
+
|
507
|
+
.. math::
|
508
|
+
\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
|
509
|
+
|
510
|
+
Parameters
|
511
|
+
----------
|
512
|
+
x : ArrayLike
|
513
|
+
Input array.
|
514
|
+
|
515
|
+
Returns
|
516
|
+
-------
|
517
|
+
jax.Array or Quantity
|
518
|
+
An array with the same shape as the input.
|
519
|
+
|
520
|
+
See Also
|
521
|
+
--------
|
522
|
+
sigmoid : The sigmoid function.
|
523
|
+
"""
|
524
|
+
return u.math.log_sigmoid(x)
|
525
|
+
|
526
|
+
|
527
|
+
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
|
528
|
+
r"""
|
529
|
+
Exponential Linear Unit activation function.
|
530
|
+
|
531
|
+
Computes the element-wise function:
|
532
|
+
|
533
|
+
.. math::
|
534
|
+
\mathrm{elu}(x) = \begin{cases}
|
535
|
+
x, & x > 0\\
|
536
|
+
\alpha \left(\exp(x) - 1\right), & x \le 0
|
537
|
+
\end{cases}
|
538
|
+
|
539
|
+
Parameters
|
540
|
+
----------
|
541
|
+
x : ArrayLike
|
542
|
+
Input array.
|
543
|
+
alpha : ArrayLike, optional
|
544
|
+
Scalar or array of alpha values. Default is 1.0.
|
545
|
+
|
546
|
+
Returns
|
547
|
+
-------
|
548
|
+
jax.Array or Quantity
|
549
|
+
An array with the same shape as the input.
|
550
|
+
|
551
|
+
See Also
|
552
|
+
--------
|
553
|
+
selu : Scaled ELU activation function.
|
554
|
+
celu : Continuously-differentiable ELU activation function.
|
555
|
+
"""
|
556
|
+
return u.math.elu(x, alpha=alpha)
|
557
|
+
|
558
|
+
|
559
|
+
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
|
560
|
+
r"""
|
561
|
+
Leaky Rectified Linear Unit activation function.
|
562
|
+
|
563
|
+
Computes the element-wise function:
|
564
|
+
|
565
|
+
.. math::
|
566
|
+
\mathrm{leaky\_relu}(x) = \begin{cases}
|
567
|
+
x, & x \ge 0\\
|
568
|
+
\alpha x, & x < 0
|
569
|
+
\end{cases}
|
570
|
+
|
571
|
+
where :math:`\alpha` = :code:`negative_slope`.
|
572
|
+
|
573
|
+
Parameters
|
574
|
+
----------
|
575
|
+
x : ArrayLike
|
576
|
+
Input array.
|
577
|
+
negative_slope : ArrayLike, optional
|
578
|
+
Array or scalar specifying the negative slope. Default is 0.01.
|
579
|
+
|
580
|
+
Returns
|
581
|
+
-------
|
582
|
+
jax.Array or Quantity
|
583
|
+
An array with the same shape as the input.
|
584
|
+
|
585
|
+
See Also
|
586
|
+
--------
|
587
|
+
relu : Standard ReLU activation function.
|
588
|
+
prelu : Parametric ReLU with learnable slope.
|
589
|
+
"""
|
590
|
+
return u.math.leaky_relu(x, negative_slope=negative_slope)
|
591
|
+
|
592
|
+
|


def _hard_tanh(x, min_val=-1.0, max_val=1.0):
    return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))


def hard_tanh(
    x: ArrayLike,
    min_val: float = -1.0,
    max_val: float = 1.0
) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard hyperbolic tangent activation function.

    Computes the element-wise function:

    .. math::
        \mathrm{hard\_tanh}(x) = \begin{cases}
            -1, & x < -1\\
            x, & -1 \le x \le 1\\
            1, & 1 < x
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Input array.
    min_val : float, optional
        Minimum value of the linear region range. Default is -1.
    max_val : float, optional
        Maximum value of the linear region range. Default is 1.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.
    """
    x = u.Quantity(x)
    min_val = u.Quantity(min_val).to(x.unit).mantissa
    max_val = u.Quantity(max_val).to(x.unit).mantissa
    return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)
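
# Illustrative usage sketch, assuming re-export as brainstate.nn.hard_tanh; values
# are clipped to [min_val, max_val] with an identity segment in between.
import jax.numpy as jnp
import brainstate

x = jnp.array([-2.0, 0.5, 3.0])
brainstate.nn.hard_tanh(x)                            # ≈ [-1., 0.5, 1.]
brainstate.nn.hard_tanh(x, min_val=0.0, max_val=2.0)  # ≈ [0., 0.5, 2.]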


def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
    r"""
    Continuously-differentiable Exponential Linear Unit activation.

    Computes the element-wise function:

    .. math::
        \mathrm{celu}(x) = \begin{cases}
            x, & x > 0\\
            \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Input array.
    alpha : ArrayLike, optional
        Scalar or array value controlling the smoothness. Default is 1.0.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    References
    ----------
    .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
           arXiv:1704.07483
    """
    return u.math.celu(x, alpha=alpha)
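
# Illustrative usage sketch, assuming re-export as brainstate.nn.celu; with the
# default alpha=1.0 it coincides with elu, while smaller alpha tightens the
# negative branch.
import jax.numpy as jnp
import brainstate

x = jnp.array([-2.0, 0.0, 3.0])
brainstate.nn.celu(x)             # ≈ [-0.865, 0., 3.]  (matches elu for alpha=1)
brainstate.nn.celu(x, alpha=0.5)  # ≈ [-0.491, 0., 3.]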


def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Scaled Exponential Linear Unit activation.

    Computes the element-wise function:

    .. math::
        \mathrm{selu}(x) = \lambda \begin{cases}
            x, & x > 0\\
            \alpha e^x - \alpha, & x \le 0
        \end{cases}

    where :math:`\lambda = 1.0507009873554804934193349852946` and
    :math:`\alpha = 1.6732632423543772848170429916717`.

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    elu : Exponential Linear Unit activation function.

    References
    ----------
    .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
           NeurIPS 2017.
    """
    return u.math.selu(x)
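
# Illustrative usage sketch, assuming re-export as brainstate.nn.selu; the lambda
# and alpha constants are chosen so that standard-normal inputs keep roughly zero
# mean and unit variance after the activation.
import jax
import brainstate

z = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
y = brainstate.nn.selu(z)
y.mean(), y.std()   # close to 0 and 1, respectively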


def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
    r"""
    Gaussian Error Linear Unit activation function.

    If ``approximate=False``, computes the element-wise function:

    .. math::
        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
            \frac{x}{\sqrt{2}} \right) \right)

    If ``approximate=True``, uses the approximate formulation of GELU:

    .. math::
        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
            \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

    Parameters
    ----------
    x : ArrayLike
        Input array.
    approximate : bool, optional
        Whether to use the approximate (True) or exact (False) formulation.
        Default is True.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    References
    ----------
    .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
           arXiv:1606.08415
    """
    return u.math.gelu(x, approximate=approximate)
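
# Illustrative usage sketch, assuming re-export as brainstate.nn.gelu; the default
# tanh approximation tracks the exact erf form to about three decimal places here.
import jax.numpy as jnp
import brainstate

x = jnp.array([-1.0, 0.0, 1.0])
brainstate.nn.gelu(x)                     # ≈ [-0.159, 0., 0.841]
brainstate.nn.gelu(x, approximate=False)  # ≈ [-0.159, 0., 0.841]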


def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
    r"""
    Gated Linear Unit activation function.

    Computes the function:

    .. math::
        \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
            \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)

    where the array is split into two along ``axis``. The size of the ``axis``
    dimension must be divisible by two.

    Parameters
    ----------
    x : ArrayLike
        Input array. The dimension specified by ``axis`` must be divisible by 2.
    axis : int, optional
        The axis along which the split should be computed. Default is -1.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as input except the ``axis`` dimension
        is halved.

    See Also
    --------
    sigmoid : The sigmoid activation function.
    """
    return u.math.glu(x, axis=axis)
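
# Illustrative usage sketch, assuming re-export as brainstate.nn.glu; the chosen
# axis is split in half and the second half gates the first through a sigmoid.
import jax.numpy as jnp
import brainstate

x = jnp.ones((2, 8))
brainstate.nn.glu(x).shape          # (2, 4); values ≈ 1 * sigmoid(1) ≈ 0.731
brainstate.nn.glu(x, axis=0).shape  # (1, 8)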


def log_softmax(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Log-Softmax function.

    Computes the logarithm of the softmax function, which rescales
    elements to the range :math:`(-\infty, 0)`.

    .. math::
        \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        The axis or axes along which the log-softmax should be computed.
        Either an integer or a tuple of integers. Default is -1.
    where : ArrayLike, optional
        Elements to include in the log-softmax computation.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    softmax : The softmax function.
    """
    return jax.nn.log_softmax(x, axis=axis, where=where)


def softmax(x: ArrayLike,
            axis: int | tuple[int, ...] | None = -1,
            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Softmax activation function.

    Computes the function which rescales elements to the range :math:`[0, 1]`
    such that the elements along :code:`axis` sum to :math:`1`.

    .. math::
        \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        The axis or axes along which the softmax should be computed. The
        softmax output summed across these dimensions should sum to :math:`1`.
        Either an integer or a tuple of integers. Default is -1.
    where : ArrayLike, optional
        Elements to include in the softmax computation.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    log_softmax : Logarithm of the softmax function.
    softmin : Softmin activation function.
    """
    return jax.nn.softmax(x, axis=axis, where=where)
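
# Illustrative usage sketch, assuming re-exports as brainstate.nn.softmax and
# brainstate.nn.log_softmax; rows sum to one, and log_softmax agrees with
# log(softmax) while avoiding intermediate under/overflow.
import jax.numpy as jnp
import brainstate

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
p = brainstate.nn.softmax(logits, axis=-1)
p.sum(axis=-1)                     # ≈ [1., 1.]
brainstate.nn.log_softmax(logits)  # ≈ jnp.log(p), computed stably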


def standardize(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                variance: ArrayLike | None = None,
                epsilon: ArrayLike = 1e-5,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Standardize (normalize) an array.

    Normalizes an array by subtracting the mean and dividing by the standard
    deviation :math:`\sqrt{\mathrm{variance}}`.

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        The axis or axes along which to compute the mean and variance.
        Default is -1.
    variance : ArrayLike, optional
        Pre-computed variance. If None, variance is computed from ``x``.
    epsilon : ArrayLike, optional
        A small constant added to the variance to avoid division by zero.
        Default is 1e-5.
    where : ArrayLike, optional
        Elements to include in the computation.

    Returns
    -------
    jax.Array or Quantity
        Standardized array with the same shape as the input.
    """
    return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
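
# Illustrative usage sketch, assuming re-export as brainstate.nn.standardize; the
# result has roughly zero mean and unit standard deviation along the chosen axis
# (slightly below 1 because of the epsilon term).
import jax.numpy as jnp
import brainstate

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = brainstate.nn.standardize(x, axis=-1)
y.mean(axis=-1), y.std(axis=-1)   # ≈ 0. and ≈ 1.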


def one_hot(x: Any,
            num_classes: int, *,
            dtype: Any = jax.numpy.float_,
            axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
    """
    One-hot encode the given indices.

    Each index in the input ``x`` is encoded as a vector of zeros of length
    ``num_classes`` with the element at ``index`` set to one.

    Indices outside the range [0, num_classes) will be encoded as zeros.

    Parameters
    ----------
    x : ArrayLike
        A tensor of indices.
    num_classes : int
        Number of classes in the one-hot dimension.
    dtype : dtype, optional
        The dtype for the returned values. Default is ``jnp.float_``.
    axis : int or Sequence of int, optional
        The axis or axes along which the function should be computed.
        Default is -1.

    Returns
    -------
    jax.Array or Quantity
        One-hot encoded array.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
        Array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

        >>> # Indices outside the range are encoded as zeros
        >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
        Array([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
    """
    return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)


def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Rectified Linear Unit 6 activation function.

    Computes the element-wise function:

    .. math::
        \mathrm{relu6}(x) = \min(\max(x, 0), 6)

    Under differentiation, we take:

    .. math::
        \nabla \mathrm{relu6}(0) = 0

    and

    .. math::
        \nabla \mathrm{relu6}(6) = 0

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    relu : Standard ReLU activation function.
    """
    return u.math.relu6(x)
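
# Illustrative usage sketch, assuming re-export as brainstate.nn.relu6; outputs are
# clipped to the interval [0, 6].
import jax.numpy as jnp
import brainstate

brainstate.nn.relu6(jnp.array([-3.0, 2.5, 9.0]))   # ≈ [0., 2.5, 6.]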


def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard Sigmoid activation function.

    Computes the element-wise function:

    .. math::
        \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    relu6 : ReLU6 activation function.
    sigmoid : Standard sigmoid function.
    """
    return u.math.hard_sigmoid(x)


def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard SiLU (Swish) activation function.

    Computes the element-wise function:

    .. math::
        \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    hard_sigmoid : Hard sigmoid activation function.
    silu : Standard SiLU activation function.
    hard_swish : Alias for hard_silu.

    Notes
    -----
    Both `hard_silu` and `hard_swish` are aliases for the same function.
    """
    return u.math.hard_silu(x)


hard_swish = hard_silu
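
# Illustrative usage sketch, assuming re-exports as brainstate.nn.hard_sigmoid and
# brainstate.nn.hard_silu; hard_silu(x) equals x * hard_sigmoid(x), and hard_swish
# is the same callable.
import jax.numpy as jnp
import brainstate

x = jnp.array([-4.0, 0.0, 1.0, 4.0])
brainstate.nn.hard_sigmoid(x)   # ≈ [0., 0.5, 0.667, 1.]
brainstate.nn.hard_silu(x)      # ≈ [0., 0., 0.667, 4.]
brainstate.nn.hard_swish is brainstate.nn.hard_silu   # True (alias)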


def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Sparse plus activation function.

    Computes the function:

    .. math::

        \mathrm{sparse\_plus}(x) = \begin{cases}
            0, & x \leq -1\\
            \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
            x, & 1 \leq x
        \end{cases}

    This is the twin function of the softplus activation, ensuring a zero output
    for inputs less than -1 and a linear output for inputs greater than 1,
    while remaining smooth, convex, and monotonic between -1 and 1.

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    sparse_sigmoid : Derivative of sparse_plus.
    softplus : Standard softplus activation function.
    """
    return u.math.sparse_plus(x)


def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Sparse sigmoid activation function.

    Computes the function:

    .. math::

        \mathrm{sparse\_sigmoid}(x) = \begin{cases}
            0, & x \leq -1\\
            \frac{1}{2}(x+1), & -1 < x < 1 \\
            1, & 1 \leq x
        \end{cases}

    This is the twin function of the standard sigmoid activation, ensuring a zero
    output for inputs less than -1, a 1 output for inputs greater than 1, and a
    linear output for inputs between -1 and 1. It is the derivative of `sparse_plus`.

    Parameters
    ----------
    x : ArrayLike
        Input array.

    Returns
    -------
    jax.Array or Quantity
        An array with the same shape as the input.

    See Also
    --------
    sigmoid : Standard sigmoid activation function.
    sparse_plus : Sparse plus activation function.

    References
    ----------
    .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
           A Sparse Model of Attention and Multi-Label Classification."
           In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
    """
    return u.math.sparse_sigmoid(x)
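
# Illustrative usage sketch, assuming re-exports as brainstate.nn.sparse_plus and
# brainstate.nn.sparse_sigmoid; the latter is the elementwise derivative of the former.
import jax
import jax.numpy as jnp
import brainstate

x = jnp.array([-2.0, 0.0, 2.0])
brainstate.nn.sparse_plus(x)      # ≈ [0., 0.25, 2.]
brainstate.nn.sparse_sigmoid(x)   # ≈ [0., 0.5, 1.]
jax.vmap(jax.grad(brainstate.nn.sparse_plus))(x)   # should match sparse_sigmoid(x)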