brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl
- brainstate/__init__.py +4 -2
- brainstate/_module.py +102 -67
- brainstate/_state.py +2 -2
- brainstate/_visualization.py +47 -0
- brainstate/environ.py +116 -9
- brainstate/environ_test.py +56 -0
- brainstate/functional/_activations.py +134 -56
- brainstate/functional/_activations_test.py +331 -0
- brainstate/functional/_normalization.py +21 -10
- brainstate/init/_generic.py +4 -2
- brainstate/mixin.py +1 -1
- brainstate/nn/__init__.py +7 -2
- brainstate/nn/_base.py +2 -2
- brainstate/nn/_connections.py +4 -4
- brainstate/nn/_dynamics.py +5 -5
- brainstate/nn/_elementwise.py +9 -9
- brainstate/nn/_embedding.py +3 -3
- brainstate/nn/_normalizations.py +3 -3
- brainstate/nn/_others.py +2 -2
- brainstate/nn/_poolings.py +6 -6
- brainstate/nn/_rate_rnns.py +1 -1
- brainstate/nn/_readout.py +1 -1
- brainstate/nn/_synouts.py +1 -1
- brainstate/nn/event/__init__.py +25 -0
- brainstate/nn/event/_misc.py +34 -0
- brainstate/nn/event/csr.py +312 -0
- brainstate/nn/event/csr_test.py +118 -0
- brainstate/nn/event/fixed_probability.py +276 -0
- brainstate/nn/event/fixed_probability_test.py +127 -0
- brainstate/nn/event/linear.py +220 -0
- brainstate/nn/event/linear_test.py +111 -0
- brainstate/nn/metrics.py +390 -0
- brainstate/optim/__init__.py +5 -1
- brainstate/optim/_optax_optimizer.py +208 -0
- brainstate/optim/_optax_optimizer_test.py +14 -0
- brainstate/random/__init__.py +24 -0
- brainstate/{random.py → random/_rand_funs.py} +7 -1596
- brainstate/random/_rand_seed.py +169 -0
- brainstate/random/_rand_state.py +1491 -0
- brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
- brainstate/{random_test.py → random/random_test.py} +208 -191
- brainstate/transform/_jit.py +1 -1
- brainstate/transform/_jit_test.py +19 -0
- brainstate/transform/_make_jaxpr.py +1 -1
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
- brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
- brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
brainstate/functional/_activations.py

@@ -22,12 +22,12 @@ from __future__ import annotations

 from typing import Any, Union, Sequence

+import brainunit as u
 import jax
-import jax.numpy as jnp
 from jax.scipy.special import logsumexp

+from brainstate import random
 from brainstate.typing import ArrayLike
-from .. import random

 __all__ = [
   "tanh",
@@ -62,10 +62,12 @@ __all__ = [
   'prelu',
   'tanh_shrink',
   'softmin',
+  'sparse_plus',
+  'sparse_sigmoid',
 ]


-def tanh(x: ArrayLike) -> jax.Array:
+def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Hyperbolic tangent activation function.

   Computes the element-wise function:
@@ -79,7 +81,7 @@ def tanh(x: ArrayLike) -> jax.Array:
   Returns:
     An array.
   """
-  return
+  return u.math.tanh(x)


 def softmin(x, axis=-1):
@@ -102,7 +104,7 @@ def softmin(x, axis=-1):
     axis (int): A dimension along which Softmin will be computed (so every slice
       along dim will sum to 1).
   """
-  unnormalized =
+  unnormalized = u.math.exp(-x)
   return unnormalized / unnormalized.sum(axis, keepdims=True)


@@ -113,7 +115,7 @@ def tanh_shrink(x):
   .. math::
     \text{Tanhshrink}(x) = x - \tanh(x)
   """
-  return x -
+  return x - u.math.tanh(x)


 def prelu(x, a=0.25):
@@ -136,7 +138,7 @@ def prelu(x, a=0.25):
   parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
   a separate :math:`a` is used for each input channel.
   """
-  return
+  return u.math.where(x >= 0., x, a * x)


 def soft_shrink(x, lambd=0.5):
@@ -158,7 +160,11 @@ def soft_shrink(x, lambd=0.5):
     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
     - Output: :math:`(*)`, same shape as the input.
   """
-  return
+  return u.math.where(x > lambd,
+                      x - lambd,
+                      u.math.where(x < -lambd,
+                                   x + lambd,
+                                   u.Quantity(0., unit=u.get_unit(lambd))))


 def mish(x):
@@ -176,7 +182,7 @@ def mish(x):
     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
     - Output: :math:`(*)`, same shape as the input.
   """
-  return x *
+  return x * u.math.tanh(softplus(x))


 def rrelu(x, lower=0.125, upper=0.3333333333333333):
@@ -210,8 +216,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
   .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
     https://arxiv.org/abs/1505.00853
   """
-  a = random.uniform(lower, upper, size=
-  return
+  a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
+  return u.math.where(u.get_mantissa(x) >= 0., x, a * x)


 def hard_shrink(x, lambd=0.5):
@@ -235,10 +241,20 @@ def hard_shrink(x, lambd=0.5):
     - Output: :math:`(*)`, same shape as the input.

   """
-  return
+  return u.math.where(x > lambd,
+                      x,
+                      u.math.where(x < -lambd,
+                                   x,
+                                   u.Quantity(0., unit=u.get_unit(x))))


-def relu(x: ArrayLike) -> jax.Array:
+def _keep_unit(fun, x, **kwargs):
+  unit = u.get_unit(x)
+  x = fun(u.get_mantissa(x), **kwargs)
+  return x if unit.is_unitless else u.Quantity(x, unit=unit)
+
+
+def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Rectified linear unit activation function.

   Computes the element-wise function:
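The `_keep_unit` helper introduced in the hunk above is the core of this change: it reads the unit of the input with `u.get_unit`, applies the plain `jax.nn` function to the mantissa, and reattaches the unit to the result (unitless inputs come back as plain arrays). A minimal sketch of the intended behaviour, assuming the wrapped `relu` is re-exported as `brainstate.functional.relu` and using `u.mV` purely for illustration; this is not part of the diff:

```python
import brainunit as u
import jax.numpy as jnp

from brainstate.functional import relu  # assumed re-export of the wrapped relu

x = u.Quantity(jnp.array([-2.0, 0.5, 3.0]), unit=u.mV)
y = relu(x)  # jax.nn.relu is applied to the mantissa [-2., 0.5, 3.]

print(u.get_mantissa(y))  # [0.  0.5 3. ]
print(u.get_unit(y))      # mV: the unit of the input is reattached

# A unitless input skips the re-wrapping and comes back as a plain jax.Array.
print(relu(jnp.array([-1.0, 2.0])))  # [0. 2.]
```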
@@ -269,10 +285,10 @@ def relu(x: ArrayLike) -> jax.Array:
     :func:`relu6`

   """
-  return jax.nn.relu
+  return _keep_unit(jax.nn.relu, x)


-def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
+def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
   r"""Squareplus activation function.

   Computes the element-wise function
@@ -286,10 +302,10 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
     x : input array
     b : smoothness parameter
   """
-  return jax.nn.squareplus
+  return _keep_unit(jax.nn.squareplus, x, b=b)


-def softplus(x: ArrayLike) -> jax.Array:
+def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Softplus activation function.

   Computes the element-wise function
@@ -300,10 +316,10 @@ def softplus(x: ArrayLike) -> jax.Array:
   Args:
     x : input array
   """
-  return jax.nn.softplus
+  return _keep_unit(jax.nn.softplus, x)


-def soft_sign(x: ArrayLike) -> jax.Array:
+def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Soft-sign activation function.

   Computes the element-wise function
@@ -314,10 +330,10 @@ def soft_sign(x: ArrayLike) -> jax.Array:
   Args:
     x : input array
   """
-  return jax.nn.soft_sign
+  return _keep_unit(jax.nn.soft_sign, x)


-def sigmoid(x: ArrayLike) -> jax.Array:
+def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Sigmoid activation function.

   Computes the element-wise function:
@@ -335,10 +351,10 @@ def sigmoid(x: ArrayLike) -> jax.Array:
     :func:`log_sigmoid`

   """
-  return jax.nn.sigmoid
+  return _keep_unit(jax.nn.sigmoid, x)


-def silu(x: ArrayLike) -> jax.Array:
+def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""SiLU (a.k.a. swish) activation function.

   Computes the element-wise function:
@@ -357,13 +373,13 @@ def silu(x: ArrayLike) -> jax.Array:
   See also:
     :func:`sigmoid`
   """
-  return jax.nn.silu
+  return _keep_unit(jax.nn.silu, x)


 swish = silu


-def log_sigmoid(x: ArrayLike) -> jax.Array:
+def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Log-sigmoid activation function.

   Computes the element-wise function:
@@ -380,10 +396,10 @@ def log_sigmoid(x: ArrayLike) -> jax.Array:
   See also:
     :func:`sigmoid`
   """
-  return jax.nn.log_sigmoid
+  return _keep_unit(jax.nn.log_sigmoid, x)


-def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
   r"""Exponential linear unit activation function.

   Computes the element-wise function:
@@ -404,10 +420,10 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
   See also:
     :func:`selu`
   """
-  return jax.nn.elu
+  return _keep_unit(jax.nn.elu, x)


-def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
+def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
   r"""Leaky rectified linear unit activation function.

   Computes the element-wise function:
@@ -430,10 +446,10 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
   See also:
     :func:`relu`
   """
-  return jax.nn.leaky_relu
+  return _keep_unit(jax.nn.leaky_relu, x, negative_slope=negative_slope)


-def hard_tanh(x: ArrayLike) -> jax.Array:
+def hard_tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Hard :math:`\mathrm{tanh}` activation function.

   Computes the element-wise function:
@@ -451,10 +467,10 @@ def hard_tanh(x: ArrayLike) -> jax.Array:
   Returns:
     An array.
   """
-  return jax.nn.hard_tanh
+  return _keep_unit(jax.nn.hard_tanh, x)


-def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
   r"""Continuously-differentiable exponential linear unit activation.

   Computes the element-wise function:
@@ -476,10 +492,10 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
   Returns:
     An array.
   """
-  return jax.nn.celu
+  return _keep_unit(jax.nn.celu, x, alpha=alpha)


-def selu(x: ArrayLike) -> jax.Array:
+def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Scaled exponential linear unit activation.

   Computes the element-wise function:
@@ -506,10 +522,10 @@ def selu(x: ArrayLike) -> jax.Array:
   See also:
     :func:`elu`
   """
-  return jax.nn.selu
+  return _keep_unit(jax.nn.selu, x)


-def gelu(x: ArrayLike, approximate: bool = True) -> jax.Array:
+def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
   r"""Gaussian error linear unit activation function.

   If ``approximate=False``, computes the element-wise function:
@@ -531,10 +547,10 @@ def gelu(x: ArrayLike, approximate: bool = True) -> jax.Array:
     x : input array
     approximate: whether to use the approximate or exact formulation.
   """
-  return jax.nn.gelu
+  return _keep_unit(jax.nn.gelu, x, approximate=approximate)


-def glu(x: ArrayLike, axis: int = -1) -> jax.Array:
+def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
   r"""Gated linear unit activation function.

   Computes the function:
@@ -557,13 +573,13 @@ def glu(x: ArrayLike, axis: int = -1) -> jax.Array:
   See also:
     :func:`sigmoid`
   """
-  return jax.nn.glu
+  return _keep_unit(jax.nn.glu, x, axis=axis)


 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
                 where: ArrayLike | None = None,
-                initial: ArrayLike | None = None) -> jax.Array:
+                initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
   r"""Log-Softmax function.

   Computes the logarithm of the :code:`softmax` function, which rescales
@@ -587,13 +603,15 @@ def log_softmax(x: ArrayLike,
   See also:
     :func:`softmax`
   """
-
+  if initial is not None:
+    initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+  return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)


 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
             where: ArrayLike | None = None,
-            initial: ArrayLike | None = None) -> jax.Array:
+            initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
   r"""Softmax function.

   Computes the function which rescales elements to the range :math:`[0, 1]`
@@ -617,35 +635,37 @@ def softmax(x: ArrayLike,
   See also:
     :func:`log_softmax`
   """
-
+  if initial is not None:
+    initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+  return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)


 def standardize(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
                 variance: ArrayLike | None = None,
                 epsilon: ArrayLike = 1e-5,
-                where: ArrayLike | None = None) -> jax.Array:
+                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
   r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
-  return jax.nn.standardize
+  return _keep_unit(jax.nn.standardize, x, axis=axis, where=where, variance=variance, epsilon=epsilon)


 def one_hot(x: Any,
             num_classes: int, *,
-            dtype: Any =
-            axis: Union[int, Sequence[int]] = -1) -> jax.Array:
+            dtype: Any = jax.numpy.float_,
+            axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
   """One-hot encodes the given indices.

   Each index in the input ``x`` is encoded as a vector of zeros of length
   ``num_classes`` with the element at ``index`` set to one::

-    >>>
+    >>> one_hot(jnp.array([0, 1, 2]), 3)
     Array([[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]], dtype=float32)

   Indices outside the range [0, num_classes) will be encoded as zeros::

-    >>>
+    >>> one_hot(jnp.array([-1, 3]), 3)
     Array([[0., 0., 0.],
            [0., 0., 0.]], dtype=float32)

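In the two softmax hunks above, a unit-bearing `initial` is first re-expressed in the unit of `x` and reduced to its mantissa, so the underlying `jax.nn.softmax` / `jax.nn.log_softmax` only ever see numbers on a consistent scale. A small sketch of just that conversion step; the `u.volt` / `u.mV` units and the example values are illustrative assumptions, not taken from the package:

```python
import brainunit as u
import jax.numpy as jnp

x = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
initial = u.Quantity(0.001, unit=u.volt)  # physically equal to 1.0 mV

# The conversion used above: express `initial` in x's unit, then drop the unit.
initial_mantissa = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
print(initial_mantissa)  # 1.0, i.e. 0.001 V rewritten as 1 mV
```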
@@ -656,10 +676,10 @@ def one_hot(x: Any,
     axis: the axis or axes along which the function should be
       computed.
   """
-  return jax.nn.one_hot
+  return _keep_unit(jax.nn.one_hot, x, axis=axis, num_classes=num_classes, dtype=dtype)


-def relu6(x: ArrayLike) -> jax.Array:
+def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Rectified Linear Unit 6 activation function.

   Computes the element-wise function
@@ -686,10 +706,10 @@ def relu6(x: ArrayLike) -> jax.Array:
   See also:
     :func:`relu`
   """
-  return jax.nn.relu6
+  return _keep_unit(jax.nn.relu6, x)


-def hard_sigmoid(x: ArrayLike) -> jax.Array:
+def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Hard Sigmoid activation function.

   Computes the element-wise function
@@ -706,10 +726,10 @@ def hard_sigmoid(x: ArrayLike) -> jax.Array:
   See also:
     :func:`relu6`
   """
-  return jax.nn.hard_sigmoid
+  return _keep_unit(jax.nn.hard_sigmoid, x)


-def hard_silu(x: ArrayLike) -> jax.Array:
+def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
   r"""Hard SiLU (swish) activation function

   Computes the element-wise function
@@ -729,7 +749,65 @@ def hard_silu(x: ArrayLike) -> jax.Array:
   See also:
     :func:`hard_sigmoid`
   """
+  return _keep_unit(jax.nn.hard_silu, x)
+
   return jax.nn.hard_silu(x)


 hard_swish = hard_silu
+
+
+def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+  r"""Sparse plus function.
+
+  Computes the function:
+
+  .. math::
+
+    \mathrm{sparse\_plus}(x) = \begin{cases}
+      0, & x \leq -1\\
+      \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+      x, & 1 \leq x
+    \end{cases}
+
+  This is the twin function of the softplus activation ensuring a zero output
+  for inputs less than -1 and a linear output for inputs greater than 1,
+  while remaining smooth, convex, monotonic by an adequate definition between
+  -1 and 1.
+
+  Args:
+    x: input (float)
+  """
+  return _keep_unit(jax.nn.sparse_plus, x)
+
+
+def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+  r"""Sparse sigmoid activation function.
+
+  Computes the function:
+
+  .. math::
+
+    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
+      0, & x \leq -1\\
+      \frac{1}{2}(x+1), & -1 < x < 1 \\
+      1, & 1 \leq x
+    \end{cases}
+
+  This is the twin function of the ``sigmoid`` activation ensuring a zero output
+  for inputs less than -1, a 1 output for inputs greater than 1, and a linear
+  output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
+
+  For more information, see `Learning with Fenchel-Young Losses (section 6.2)
+  <https://arxiv.org/abs/1901.02324>`_.
+
+  Args:
+    x : input array
+
+  Returns:
+    An array.
+
+  See also:
+    :func:`sigmoid`
+  """
+  return _keep_unit(jax.nn.sparse_sigmoid, x)