brainstate-0.1.8-py2.py3-none-any.whl → brainstate-0.1.9-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
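
The file-level changes above can be checked locally. A minimal sketch, assuming both wheels have been downloaded into the working directory under their standard names (the local paths are hypothetical); it lists which archive members were added or removed between the two releases:

    # Every wheel is a zip archive, so zipfile.ZipFile(...).namelist()
    # returns the member paths summarized in the listing above.
    import zipfile

    def wheel_files(path):
        with zipfile.ZipFile(path) as wheel:
            return set(wheel.namelist())

    old = wheel_files("brainstate-0.1.8-py2.py3-none-any.whl")
    new = wheel_files("brainstate-0.1.9-py2.py3-none-any.whl")

    print("removed:", sorted(old - new))  # e.g. brainstate/functional/_activations.py
    print("added:  ", sorted(new - old))  # e.g. brainstate/nn/_activations.py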
brainstate/{functional → nn}/_activations.py
@@ -1,813 +1,808 @@

[removed: brainstate/functional/_activations.py, 813 lines in 0.1.8]

Old lines 1-248 match lines 1-248 of the added file below (Apache-2.0 header, module
docstring, imports, __all__, and the tanh, softmin, tanh_shrink, prelu, soft_shrink,
mish, rrelu, and hard_shrink definitions). Old lines 249-809 did not survive the page
rendering; only stray fragments remain. The recoverable tail (old lines 810-813), the
end of the 0.1.8 sparse_sigmoid, read:

    See also:
        :func:`sigmoid`
    """
    return _keep_unit(jax.nn.sparse_sigmoid, x)

[added: brainstate/nn/_activations.py, 808 lines in 0.1.9; condensed below — each
public function carries a reference-style docstring giving its defining formula, and
most bodies are thin wrappers over brainunit (u.math.*) or jax.nn primitives, shown
here as signature plus return expression]

# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (standard 14-line header unchanged).

"""
Shared neural network activations and other functions.
"""

from typing import Any, Union, Sequence

import brainunit as u
import jax
from jax.scipy.special import logsumexp

from brainstate import random
from brainstate.typing import ArrayLike

__all__ = [
    "tanh", "relu", "squareplus", "softplus", "soft_sign", "sigmoid", "silu",
    "swish", "log_sigmoid", "elu", "leaky_relu", "hard_tanh", "celu", "selu",
    "gelu", "glu", "logsumexp", "log_softmax", "softmax", "standardize",
    "one_hot", "relu6", "hard_sigmoid", "hard_silu", "hard_swish",
    "hard_shrink", "rrelu", "mish", "soft_shrink", "prelu", "tanh_shrink",
    "softmin", "sparse_plus", "sparse_sigmoid",
]


def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.tanh(x)

def softmin(x, axis=-1):
    unnormalized = u.math.exp(-x)
    return unnormalized / unnormalized.sum(axis, keepdims=True)

def tanh_shrink(x): return x - u.math.tanh(x)

def prelu(x, a=0.25): return u.math.where(x >= 0., x, a * x)

def soft_shrink(x, lambd=0.5):
    return u.math.where(x > lambd, x - lambd,
                        u.math.where(x < -lambd, x + lambd,
                                     u.Quantity(0., unit=u.get_unit(lambd))))

def mish(x): return x * u.math.tanh(softplus(x))

def rrelu(x, lower=0.125, upper=0.3333333333333333):
    a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
    return u.math.where(u.get_mantissa(x) >= 0., x, a * x)

def hard_shrink(x, lambd=0.5):
    return u.math.where(x > lambd, x,
                        u.math.where(x < -lambd, x,
                                     u.Quantity(0., unit=u.get_unit(x))))

def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.relu(x)

def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]: return u.math.squareplus(x, b=b)

def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.softplus(x)

def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.soft_sign(x)

def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.sigmoid(x)

def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.silu(x)

swish = silu

def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.log_sigmoid(x)

def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]: return u.math.elu(x, alpha=alpha)

def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]: return u.math.leaky_relu(x, negative_slope=negative_slope)

def _hard_tanh(x, min_val=-1.0, max_val=1.0):
    return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))

def hard_tanh(x: ArrayLike, min_val: float = -1.0, max_val: float = 1.0) -> Union[jax.Array, u.Quantity]:
    x = u.Quantity(x)
    min_val = u.Quantity(min_val).to(x.unit).mantissa
    max_val = u.Quantity(max_val).to(x.unit).mantissa
    return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)

def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]: return u.math.celu(x, alpha=alpha)

def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.selu(x)

def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]: return u.math.gelu(x, approximate=approximate)

def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]: return u.math.glu(x, axis=axis)

def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    return jax.nn.log_softmax(x, axis=axis, where=where)

def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1,
            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    return jax.nn.softmax(x, axis=axis, where=where)

def standardize(x: ArrayLike, axis: int | tuple[int, ...] | None = -1,
                variance: ArrayLike | None = None, epsilon: ArrayLike = 1e-5,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)

def one_hot(x: Any, num_classes: int, *, dtype: Any = jax.numpy.float_,
            axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
    return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)

def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.relu6(x)

def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.hard_sigmoid(x)

def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.hard_silu(x)

hard_swish = hard_silu

def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.sparse_plus(x)

def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]: return u.math.sparse_sigmoid(x)
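
After this move the activation helpers ship under brainstate/nn/. A minimal usage sketch: the brainstate.nn attribute access is assumed from the file move and the nn/__init__.py change listed above, and whether the old brainstate.functional names are still re-exported is not visible in this hunk.

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

    # Thin wrappers over u.math.* / jax.nn, per the condensed listing above.
    print(brainstate.nn.relu(x))         # zeros for x <= 0, identity for x > 0
    print(brainstate.nn.soft_shrink(x))  # shrinks each entry toward zero by the default lambd=0.5
    print(brainstate.nn.softmin(x))      # positive weights that sum to 1 along the last axis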