brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (50)
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1491 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,12 @@ from __future__ import annotations

  from typing import Any, Union, Sequence

+ import brainunit as u
  import jax
- import jax.numpy as jnp
  from jax.scipy.special import logsumexp

+ from brainstate import random
  from brainstate.typing import ArrayLike
- from .. import random

  __all__ = [
    "tanh",
@@ -62,10 +62,12 @@ __all__ = [
    'prelu',
    'tanh_shrink',
    'softmin',
+   'sparse_plus',
+   'sparse_sigmoid',
  ]


- def tanh(x: ArrayLike) -> jax.Array:
+ def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Hyperbolic tangent activation function.

    Computes the element-wise function:
@@ -79,7 +81,7 @@ def tanh(x: ArrayLike) -> jax.Array:
    Returns:
      An array.
    """
-   return jnp.tanh(x)
+   return u.math.tanh(x)


  def softmin(x, axis=-1):
@@ -102,7 +104,7 @@ def softmin(x, axis=-1):
      axis (int): A dimension along which Softmin will be computed (so every slice
        along dim will sum to 1).
    """
-   unnormalized = jnp.exp(-x)
+   unnormalized = u.math.exp(-x)
    return unnormalized / unnormalized.sum(axis, keepdims=True)


@@ -113,7 +115,7 @@ def tanh_shrink(x):
    .. math::
      \text{Tanhshrink}(x) = x - \tanh(x)
    """
-   return x - jnp.tanh(x)
+   return x - u.math.tanh(x)


  def prelu(x, a=0.25):
@@ -136,7 +138,7 @@ def prelu(x, a=0.25):
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.
    """
-   return jnp.where(x >= 0., x, a * x)
+   return u.math.where(x >= 0., x, a * x)


  def soft_shrink(x, lambd=0.5):
@@ -158,7 +160,11 @@ def soft_shrink(x, lambd=0.5):
      - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
      - Output: :math:`(*)`, same shape as the input.
    """
-   return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
+   return u.math.where(x > lambd,
+                       x - lambd,
+                       u.math.where(x < -lambd,
+                                    x + lambd,
+                                    u.Quantity(0., unit=u.get_unit(lambd))))


  def mish(x):
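Note on the rewritten soft_shrink above: the zero branch is now a u.Quantity carrying the unit of the threshold lambd, so a unit-bearing input must come with a threshold in a compatible unit. A minimal usage sketch, not part of the diff, assuming brainunit exposes u.mV and that brainstate.functional re-exports soft_shrink:

import jax.numpy as jnp
import brainunit as u
import brainstate as bst

x = jnp.array([-1.0, 0.2, 1.0]) * u.mV               # hypothetical unit-bearing input
y = bst.functional.soft_shrink(x, lambd=0.5 * u.mV)  # threshold in a compatible unit
# expected: values shrunk toward zero, roughly [-0.5, 0.0, 0.5] in mV, unit preserved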
@@ -176,7 +182,7 @@ def mish(x):
      - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
      - Output: :math:`(*)`, same shape as the input.
    """
-   return x * jnp.tanh(softplus(x))
+   return x * u.math.tanh(softplus(x))


  def rrelu(x, lower=0.125, upper=0.3333333333333333):
@@ -210,8 +216,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
      https://arxiv.org/abs/1505.00853
    """
-   a = random.uniform(lower, upper, size=jnp.shape(x), dtype=x.dtype)
-   return jnp.where(x >= 0., x, a * x)
+   a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
+   return u.math.where(u.get_mantissa(x) >= 0., x, a * x)


  def hard_shrink(x, lambd=0.5):
@@ -235,10 +241,20 @@ def hard_shrink(x, lambd=0.5):
      - Output: :math:`(*)`, same shape as the input.

    """
-   return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))
+   return u.math.where(x > lambd,
+                       x,
+                       u.math.where(x < -lambd,
+                                    x,
+                                    u.Quantity(0., unit=u.get_unit(x))))


- def relu(x: ArrayLike) -> jax.Array:
+ def _keep_unit(fun, x, **kwargs):
+   unit = u.get_unit(x)
+   x = fun(u.get_mantissa(x), **kwargs)
+   return x if unit.is_unitless else u.Quantity(x, unit=unit)
+
+
+ def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Rectified linear unit activation function.

    Computes the element-wise function:
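The _keep_unit helper introduced above is the core pattern of this release: each wrapped jax.nn activation is applied to the mantissa of the input and the unit is reattached afterwards, while plain arrays pass through unchanged. A minimal sketch of the implied behaviour, not part of the diff, assuming brainunit exposes u.mV and that multiplying an array by a unit yields a u.Quantity:

import jax
import jax.numpy as jnp
import brainunit as u

def _keep_unit(fun, x, **kwargs):
  # as in the diff: apply the raw JAX activation to the mantissa, then restore the unit
  unit = u.get_unit(x)
  x = fun(u.get_mantissa(x), **kwargs)
  return x if unit.is_unitless else u.Quantity(x, unit=unit)

v = jnp.array([-2.0, 0.5, 3.0]) * u.mV   # hypothetical unit-bearing input
out = _keep_unit(jax.nn.relu, v)         # ReLU on the mantissa, result back in mV
# expected: [0.0, 0.5, 3.0] in mV; a plain jnp.ndarray would come back as a plain array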
@@ -269,10 +285,10 @@ def relu(x: ArrayLike) -> jax.Array:
    :func:`relu6`

    """
-   return jax.nn.relu(x)
+   return _keep_unit(jax.nn.relu, x)


- def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
+ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
    r"""Squareplus activation function.

    Computes the element-wise function
@@ -286,10 +302,10 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
      x : input array
      b : smoothness parameter
    """
-   return jax.nn.squareplus(x, b)
+   return _keep_unit(jax.nn.squareplus, x, b=b)


- def softplus(x: ArrayLike) -> jax.Array:
+ def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Softplus activation function.

    Computes the element-wise function
@@ -300,10 +316,10 @@ def softplus(x: ArrayLike) -> jax.Array:
    Args:
      x : input array
    """
-   return jax.nn.softplus(x)
+   return _keep_unit(jax.nn.softplus, x)


- def soft_sign(x: ArrayLike) -> jax.Array:
+ def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Soft-sign activation function.

    Computes the element-wise function
@@ -314,10 +330,10 @@ def soft_sign(x: ArrayLike) -> jax.Array:
    Args:
      x : input array
    """
-   return jax.nn.soft_sign(x)
+   return _keep_unit(jax.nn.soft_sign, x)


- def sigmoid(x: ArrayLike) -> jax.Array:
+ def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Sigmoid activation function.

    Computes the element-wise function:
@@ -335,10 +351,10 @@ def sigmoid(x: ArrayLike) -> jax.Array:
    :func:`log_sigmoid`

    """
-   return jax.nn.sigmoid(x)
+   return _keep_unit(jax.nn.sigmoid, x)


- def silu(x: ArrayLike) -> jax.Array:
+ def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""SiLU (a.k.a. swish) activation function.

    Computes the element-wise function:
@@ -357,13 +373,13 @@ def silu(x: ArrayLike) -> jax.Array:
    See also:
      :func:`sigmoid`
    """
-   return jax.nn.silu(x)
+   return _keep_unit(jax.nn.silu, x)


  swish = silu


- def log_sigmoid(x: ArrayLike) -> jax.Array:
+ def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Log-sigmoid activation function.

    Computes the element-wise function:
@@ -380,10 +396,10 @@ def log_sigmoid(x: ArrayLike) -> jax.Array:
    See also:
      :func:`sigmoid`
    """
-   return jax.nn.log_sigmoid(x)
+   return _keep_unit(jax.nn.log_sigmoid, x)


- def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
    r"""Exponential linear unit activation function.

    Computes the element-wise function:
@@ -404,10 +420,10 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
    See also:
      :func:`selu`
    """
-   return jax.nn.elu(x, alpha)
+   return _keep_unit(jax.nn.elu, x)


- def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
+ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
    r"""Leaky rectified linear unit activation function.

    Computes the element-wise function:
@@ -430,10 +446,10 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
    See also:
      :func:`relu`
    """
-   return jax.nn.leaky_relu(x, negative_slope=negative_slope)
+   return _keep_unit(jax.nn.leaky_relu, x, negative_slope=negative_slope)


- def hard_tanh(x: ArrayLike) -> jax.Array:
+ def hard_tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Hard :math:`\mathrm{tanh}` activation function.

    Computes the element-wise function:
@@ -451,10 +467,10 @@ def hard_tanh(x: ArrayLike) -> jax.Array:
    Returns:
      An array.
    """
-   return jax.nn.hard_tanh(x)
+   return _keep_unit(jax.nn.hard_tanh, x)


- def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
    r"""Continuously-differentiable exponential linear unit activation.

    Computes the element-wise function:
@@ -476,10 +492,10 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
    Returns:
      An array.
    """
-   return jax.nn.celu(x, alpha)
+   return _keep_unit(jax.nn.celu, x, alpha=alpha)


- def selu(x: ArrayLike) -> jax.Array:
+ def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Scaled exponential linear unit activation.

    Computes the element-wise function:
@@ -506,10 +522,10 @@ def selu(x: ArrayLike) -> jax.Array:
    See also:
      :func:`elu`
    """
-   return jax.nn.selu(x)
+   return _keep_unit(jax.nn.selu, x)


- def gelu(x: ArrayLike, approximate: bool = True) -> jax.Array:
+ def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
    r"""Gaussian error linear unit activation function.

    If ``approximate=False``, computes the element-wise function:
@@ -531,10 +547,10 @@ def gelu(x: ArrayLike, approximate: bool = True) -> jax.Array:
      x : input array
      approximate: whether to use the approximate or exact formulation.
    """
-   return jax.nn.gelu(x, approximate=approximate)
+   return _keep_unit(jax.nn.gelu, x, approximate=approximate)


- def glu(x: ArrayLike, axis: int = -1) -> jax.Array:
+ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
    r"""Gated linear unit activation function.

    Computes the function:
@@ -557,13 +573,13 @@ def glu(x: ArrayLike, axis: int = -1) -> jax.Array:
    See also:
      :func:`sigmoid`
    """
-   return jax.nn.glu(x, axis=axis)
+   return _keep_unit(jax.nn.glu, x, axis=axis)


  def log_softmax(x: ArrayLike,
                  axis: int | tuple[int, ...] | None = -1,
                  where: ArrayLike | None = None,
-                 initial: ArrayLike | None = None) -> jax.Array:
+                 initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""Log-Softmax function.

    Computes the logarithm of the :code:`softmax` function, which rescales
@@ -587,13 +603,15 @@ def log_softmax(x: ArrayLike,
    See also:
      :func:`softmax`
    """
-   return jax.nn.log_softmax(x, axis, where, initial)
+   if initial is not None:
+     initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+   return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)


  def softmax(x: ArrayLike,
              axis: int | tuple[int, ...] | None = -1,
              where: ArrayLike | None = None,
-             initial: ArrayLike | None = None) -> jax.Array:
+             initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""Softmax function.

    Computes the function which rescales elements to the range :math:`[0, 1]`
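For log_softmax above (and softmax in the next hunk) the optional initial argument is now rescaled into the unit of x before the raw JAX routine sees it. A hedged illustration of that conversion step, not part of the diff, assuming brainunit exposes u.volt and u.mV:

import jax.numpy as jnp
import brainunit as u

x = jnp.array([1.0, 2.0]) * u.mV    # hypothetical unit-bearing logits
initial = 0.001 * u.volt            # same dimension, different scale
# as in the diff: express `initial` in x's unit and keep only the mantissa
initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa   # -> 1.0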
@@ -617,35 +635,37 @@ def softmax(x: ArrayLike,
    See also:
      :func:`log_softmax`
    """
-   return jax.nn.softmax(x, axis, where, initial)
+   if initial is not None:
+     initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+   return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)


  def standardize(x: ArrayLike,
                  axis: int | tuple[int, ...] | None = -1,
                  variance: ArrayLike | None = None,
                  epsilon: ArrayLike = 1e-5,
-                 where: ArrayLike | None = None) -> jax.Array:
+                 where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
-   return jax.nn.standardize(x, axis, variance, epsilon, where)
+   return _keep_unit(jax.nn.standardize, x, axis=axis, where=where, variance=variance, epsilon=epsilon)


  def one_hot(x: Any,
              num_classes: int, *,
-             dtype: Any = jnp.float_,
-             axis: Union[int, Sequence[int]] = -1) -> jax.Array:
+             dtype: Any = jax.numpy.float_,
+             axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
    """One-hot encodes the given indices.

    Each index in the input ``x`` is encoded as a vector of zeros of length
    ``num_classes`` with the element at ``index`` set to one::

-   >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
+   >>> one_hot(jnp.array([0, 1, 2]), 3)
    Array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]], dtype=float32)

    Indices outside the range [0, num_classes) will be encoded as zeros::

-   >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
+   >>> one_hot(jnp.array([-1, 3]), 3)
    Array([[0., 0., 0.],
           [0., 0., 0.]], dtype=float32)

@@ -656,10 +676,10 @@ def one_hot(x: Any,
      axis: the axis or axes along which the function should be
        computed.
    """
-   return jax.nn.one_hot(x, num_classes, dtype=dtype, axis=axis)
+   return _keep_unit(jax.nn.one_hot, x, axis=axis, num_classes=num_classes, dtype=dtype)


- def relu6(x: ArrayLike) -> jax.Array:
+ def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Rectified Linear Unit 6 activation function.

    Computes the element-wise function
@@ -686,10 +706,10 @@ def relu6(x: ArrayLike) -> jax.Array:
    See also:
      :func:`relu`
    """
-   return jax.nn.relu6(x)
+   return _keep_unit(jax.nn.relu6, x)


- def hard_sigmoid(x: ArrayLike) -> jax.Array:
+ def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Hard Sigmoid activation function.

    Computes the element-wise function
@@ -706,10 +726,10 @@ def hard_sigmoid(x: ArrayLike) -> jax.Array:
    See also:
      :func:`relu6`
    """
-   return jax.nn.hard_sigmoid(x)
+   return _keep_unit(jax.nn.hard_sigmoid, x)


- def hard_silu(x: ArrayLike) -> jax.Array:
+ def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""Hard SiLU (swish) activation function

    Computes the element-wise function
@@ -729,7 +749,65 @@ def hard_silu(x: ArrayLike) -> jax.Array:
    See also:
      :func:`hard_sigmoid`
    """
+   return _keep_unit(jax.nn.hard_silu, x)
+
    return jax.nn.hard_silu(x)


  hard_swish = hard_silu
+
+
+ def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+   r"""Sparse plus function.
+
+   Computes the function:
+
+   .. math::
+
+     \mathrm{sparse\_plus}(x) = \begin{cases}
+       0, & x \leq -1\\
+       \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+       x, & 1 \leq x
+     \end{cases}
+
+   This is the twin function of the softplus activation ensuring a zero output
+   for inputs less than -1 and a linear output for inputs greater than 1,
+   while remaining smooth, convex, monotonic by an adequate definition between
+   -1 and 1.
+
+   Args:
+     x: input (float)
+   """
+   return _keep_unit(jax.nn.sparse_plus, x)
+
+
+ def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+   r"""Sparse sigmoid activation function.
+
+   Computes the function:
+
+   .. math::
+
+     \mathrm{sparse\_sigmoid}(x) = \begin{cases}
+       0, & x \leq -1\\
+       \frac{1}{2}(x+1), & -1 < x < 1 \\
+       1, & 1 \leq x
+     \end{cases}
+
+   This is the twin function of the ``sigmoid`` activation ensuring a zero output
+   for inputs less than -1, a 1 output for inputs greater than 1, and a linear
+   output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
+
+   For more information, see `Learning with Fenchel-Young Losses (section 6.2)
+   <https://arxiv.org/abs/1901.02324>`_.
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`sigmoid`
+   """
+   return _keep_unit(jax.nn.sparse_sigmoid, x)
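For quick orientation on the two new exports, a hedged usage sketch that follows the piecewise formulas above; it assumes a JAX version providing jax.nn.sparse_plus / jax.nn.sparse_sigmoid and that brainstate.functional re-exports the new names, as the updated __all__ suggests:

import jax.numpy as jnp
import brainstate as bst

x = jnp.array([-2.0, 0.0, 2.0])
bst.functional.sparse_plus(x)      # expected: [0.0, 0.25, 2.0]
bst.functional.sparse_sigmoid(x)   # expected: [0.0, 0.5, 1.0]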