brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_activations.py
@@ -1,1100 +1,1100 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- """
18
- Shared neural network activations and other functions.
19
- """
20
-
21
- from typing import Any, Union, Sequence
22
-
23
- import brainunit as u
24
- import jax
25
- from jax.scipy.special import logsumexp
26
-
27
- from brainstate import random
28
- from brainstate.typing import ArrayLike
29
-
30
- __all__ = [
31
- "tanh",
32
- "relu",
33
- "squareplus",
34
- "softplus",
35
- "soft_sign",
36
- "sigmoid",
37
- "silu",
38
- "swish",
39
- "log_sigmoid",
40
- "elu",
41
- "leaky_relu",
42
- "hard_tanh",
43
- "celu",
44
- "selu",
45
- "gelu",
46
- "glu",
47
- "logsumexp",
48
- "log_softmax",
49
- "softmax",
50
- "standardize",
51
- "one_hot",
52
- "relu6",
53
- "hard_sigmoid",
54
- "hard_silu",
55
- "hard_swish",
56
- 'hard_shrink',
57
- 'rrelu',
58
- 'mish',
59
- 'soft_shrink',
60
- 'prelu',
61
- 'tanh_shrink',
62
- 'softmin',
63
- 'sparse_plus',
64
- 'sparse_sigmoid',
65
- ]
66
-
67
-
68
- def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
69
- r"""
70
- Hyperbolic tangent activation function.
71
-
72
- Computes the element-wise function:
73
-
74
- .. math::
75
- \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
76
-
77
- Parameters
78
- ----------
79
- x : ArrayLike
80
- Input array.
81
-
82
- Returns
83
- -------
84
- jax.Array or Quantity
85
- An array with the same shape as the input.
86
- """
87
- return u.math.tanh(x)
88
-
89
-
90
- def softmin(x, axis=-1):
91
- r"""
92
- Softmin activation function.
93
-
94
- Applies the Softmin function to an n-dimensional input tensor, rescaling elements
95
- so that they lie in the range [0, 1] and sum to 1 along the specified axis.
96
-
97
- .. math::
98
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
99
-
100
- Parameters
101
- ----------
102
- x : ArrayLike
103
- Input array of any shape.
104
- axis : int, optional
105
- The axis along which Softmin will be computed. Every slice along this
106
- dimension will sum to 1. Default is -1.
107
-
108
- Returns
109
- -------
110
- jax.Array or Quantity
111
- Output array with the same shape as the input.
112
- """
113
- unnormalized = u.math.exp(-x)
114
- return unnormalized / unnormalized.sum(axis, keepdims=True)
115
-
116
-
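A minimal sketch of the formula above (illustrative only, not part of the packaged module; it assumes the function is exposed as ``brainstate.nn.softmin``, mirroring the ``brainstate.nn.relu`` example later in this file). Softmin is simply softmax applied to the negated input:

    import jax
    import jax.numpy as jnp
    import brainstate

    x = jnp.array([1.0, 2.0, 3.0])
    # exp(-x) / sum(exp(-x)) is the same as softmax evaluated at -x
    assert jnp.allclose(brainstate.nn.softmin(x), jax.nn.softmax(-x))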
117
- def tanh_shrink(x):
118
- r"""
119
- Tanh shrink activation function.
120
-
121
- Applies the element-wise function:
122
-
123
- .. math::
124
- \text{Tanhshrink}(x) = x - \tanh(x)
125
-
126
- Parameters
127
- ----------
128
- x : ArrayLike
129
- Input array.
130
-
131
- Returns
132
- -------
133
- jax.Array or Quantity
134
- Output array with the same shape as the input.
135
- """
136
- return x - u.math.tanh(x)
137
-
138
-
139
- def prelu(x, a=0.25):
140
- r"""
141
- Parametric Rectified Linear Unit activation function.
142
-
143
- Applies the element-wise function:
144
-
145
- .. math::
146
- \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
147
-
148
- or equivalently:
149
-
150
- .. math::
151
- \text{PReLU}(x) =
152
- \begin{cases}
153
- x, & \text{ if } x \geq 0 \\
154
- ax, & \text{ otherwise }
155
- \end{cases}
156
-
157
- Parameters
158
- ----------
159
- x : ArrayLike
160
- Input array.
161
- a : float or ArrayLike, optional
162
- The negative slope coefficient. Can be a learnable parameter.
163
- Default is 0.25.
164
-
165
- Returns
166
- -------
167
- jax.Array or Quantity
168
- Output array with the same shape as the input.
169
-
170
- Notes
171
- -----
172
- When used in neural network layers, :math:`a` can be a learnable parameter
173
- that is optimized during training.
174
- """
175
- return u.math.where(x >= 0., x, a * x)
176
-
177
-
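A quick numeric check of the piecewise definition (illustrative sketch, not part of the packaged module; ``brainstate.nn.prelu`` is assumed to be the public path):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2.0, -1.0, 0.0, 1.0])
    # negative entries are scaled by a, non-negative entries pass through
    out = brainstate.nn.prelu(x, a=0.25)   # expected [-0.5, -0.25, 0.0, 1.0]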
178
- def soft_shrink(x, lambd=0.5):
179
- r"""
180
- Soft shrinkage activation function.
181
-
182
- Applies the soft shrinkage function element-wise:
183
-
184
- .. math::
185
- \text{SoftShrinkage}(x) =
186
- \begin{cases}
187
- x - \lambda, & \text{ if } x > \lambda \\
188
- x + \lambda, & \text{ if } x < -\lambda \\
189
- 0, & \text{ otherwise }
190
- \end{cases}
191
-
192
- Parameters
193
- ----------
194
- x : ArrayLike
195
- Input array of any shape.
196
- lambd : float, optional
197
- The :math:`\lambda` value for the soft shrinkage formulation.
198
- Must be non-negative. Default is 0.5.
199
-
200
- Returns
201
- -------
202
- jax.Array or Quantity
203
- Output array with the same shape as the input.
204
- """
205
- return u.math.where(
206
- x > lambd,
207
- x - lambd,
208
- u.math.where(
209
- x < -lambd,
210
- x + lambd,
211
- u.Quantity(0., unit=u.get_unit(lambd))
212
- )
213
- )
214
-
215
-
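A short worked example of the three branches (illustrative only; assumes the ``brainstate.nn.soft_shrink`` export):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-1.0, -0.25, 0.25, 1.0])
    # entries inside [-0.5, 0.5] collapse to 0, the rest move toward 0 by lambd
    out = brainstate.nn.soft_shrink(x, lambd=0.5)   # expected [-0.5, 0.0, 0.0, 0.5]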
216
- def mish(x):
217
- r"""
218
- Mish activation function.
219
-
220
- Mish is a self-regularized non-monotonic activation function.
221
-
222
- .. math::
223
- \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
224
-
225
- Parameters
226
- ----------
227
- x : ArrayLike
228
- Input array of any shape.
229
-
230
- Returns
231
- -------
232
- jax.Array or Quantity
233
- Output array with the same shape as the input.
234
-
235
- References
236
- ----------
237
- .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
238
- arXiv:1908.08681
239
- """
240
- return x * u.math.tanh(softplus(x))
241
-
242
-
243
- def rrelu(x, lower=0.125, upper=0.3333333333333333):
244
- r"""
245
- Randomized Leaky Rectified Linear Unit activation function.
246
-
247
- The function is defined as:
248
-
249
- .. math::
250
- \text{RReLU}(x) =
251
- \begin{cases}
252
- x & \text{if } x \geq 0 \\
253
- ax & \text{ otherwise }
254
- \end{cases}
255
-
256
- where :math:`a` is randomly sampled from uniform distribution
257
- :math:`\mathcal{U}(\text{lower}, \text{upper})`.
258
-
259
- Parameters
260
- ----------
261
- x : ArrayLike
262
- Input array of any shape.
263
- lower : float, optional
264
- Lower bound of the uniform distribution for sampling the negative slope.
265
- Default is 1/8.
266
- upper : float, optional
267
- Upper bound of the uniform distribution for sampling the negative slope.
268
- Default is 1/3.
269
-
270
- Returns
271
- -------
272
- jax.Array or Quantity
273
- Output array with the same shape as the input.
274
-
275
- References
276
- ----------
277
- .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
278
- in Convolutional Network." arXiv:1505.00853
279
- """
280
- a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
281
- return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
282
-
283
-
284
- def hard_shrink(x, lambd=0.5):
285
- r"""
286
- Hard shrinkage activation function.
287
-
288
- Applies the hard shrinkage function element-wise:
289
-
290
- .. math::
291
- \text{HardShrink}(x) =
292
- \begin{cases}
293
- x, & \text{ if } x > \lambda \\
294
- x, & \text{ if } x < -\lambda \\
295
- 0, & \text{ otherwise }
296
- \end{cases}
297
-
298
- Parameters
299
- ----------
300
- x : ArrayLike
301
- Input array of any shape.
302
- lambd : float, optional
303
- The :math:`\lambda` threshold value for the hard shrinkage formulation.
304
- Default is 0.5.
305
-
306
- Returns
307
- -------
308
- jax.Array or Quantity
309
- Output array with the same shape as the input.
310
- """
311
- return u.math.where(
312
- x > lambd,
313
- x,
314
- u.math.where(
315
- x < -lambd,
316
- x,
317
- u.Quantity(0., unit=u.get_unit(x))
318
- )
319
- )
320
-
321
-
322
- def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
323
- r"""
324
- Rectified Linear Unit activation function.
325
-
326
- Computes the element-wise function:
327
-
328
- .. math::
329
- \mathrm{relu}(x) = \max(x, 0)
330
-
331
- Under differentiation, we take:
332
-
333
- .. math::
334
- \nabla \mathrm{relu}(0) = 0
335
-
336
- Parameters
337
- ----------
338
- x : ArrayLike
339
- Input array.
340
-
341
- Returns
342
- -------
343
- jax.Array or Quantity
344
- An array with the same shape as the input.
345
-
346
- Examples
347
- --------
348
- .. code-block:: python
349
-
350
- >>> import jax.numpy as jnp
351
- >>> import brainstate
352
- >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
353
- Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
354
-
355
- See Also
356
- --------
357
- relu6 : ReLU6 activation function.
358
- leaky_relu : Leaky ReLU activation function.
359
-
360
- References
361
- ----------
362
- .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
363
- https://openreview.net/forum?id=urrcVI-_jRm
364
- """
365
- return u.math.relu(x)
366
-
367
-
368
- def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
369
- r"""
370
- Squareplus activation function.
371
-
372
- Computes the element-wise function:
373
-
374
- .. math::
375
- \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
376
-
377
- Parameters
378
- ----------
379
- x : ArrayLike
380
- Input array.
381
- b : ArrayLike, optional
382
- Smoothness parameter. Default is 4.
383
-
384
- Returns
385
- -------
386
- jax.Array or Quantity
387
- An array with the same shape as the input.
388
-
389
- References
390
- ----------
391
- .. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
392
- for Language Modeling." arXiv:2112.11687
393
- """
394
- return u.math.squareplus(x, b=b)
395
-
396
-
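To see the smoothed-ReLU behaviour described by the formula (illustrative sketch; assumes the ``brainstate.nn.squareplus`` export):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-10.0, 0.0, 10.0])
    # roughly 0 for very negative x, roughly x for large positive x,
    # and exactly sqrt(b) / 2 = 1.0 at x = 0 with the default b = 4
    out = brainstate.nn.squareplus(x)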
397
- def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
398
- r"""
399
- Softplus activation function.
400
-
401
- Computes the element-wise function:
402
-
403
- .. math::
404
- \mathrm{softplus}(x) = \log(1 + e^x)
405
-
406
- Parameters
407
- ----------
408
- x : ArrayLike
409
- Input array.
410
-
411
- Returns
412
- -------
413
- jax.Array or Quantity
414
- An array with the same shape as the input.
415
- """
416
- return u.math.softplus(x)
417
-
418
-
419
- def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
420
- r"""
421
- Soft-sign activation function.
422
-
423
- Computes the element-wise function:
424
-
425
- .. math::
426
- \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
427
-
428
- Parameters
429
- ----------
430
- x : ArrayLike
431
- Input array.
432
-
433
- Returns
434
- -------
435
- jax.Array or Quantity
436
- An array with the same shape as the input.
437
- """
438
- return u.math.soft_sign(x)
439
-
440
-
441
- def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
442
- r"""
443
- Sigmoid activation function.
444
-
445
- Computes the element-wise function:
446
-
447
- .. math::
448
- \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
449
-
450
- Parameters
451
- ----------
452
- x : ArrayLike
453
- Input array.
454
-
455
- Returns
456
- -------
457
- jax.Array or Quantity
458
- An array with the same shape as the input.
459
-
460
- See Also
461
- --------
462
- log_sigmoid : Logarithm of the sigmoid function.
463
- """
464
- return u.math.sigmoid(x)
465
-
466
-
467
- def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
468
- r"""
469
- SiLU (Sigmoid Linear Unit) activation function.
470
-
471
- Computes the element-wise function:
472
-
473
- .. math::
474
- \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
475
-
476
- Parameters
477
- ----------
478
- x : ArrayLike
479
- Input array.
480
-
481
- Returns
482
- -------
483
- jax.Array or Quantity
484
- An array with the same shape as the input.
485
-
486
- See Also
487
- --------
488
- sigmoid : The sigmoid function.
489
- swish : Alias for silu.
490
-
491
- Notes
492
- -----
493
- `swish` and `silu` are both aliases for the same function.
494
- """
495
- return u.math.silu(x)
496
-
497
-
498
- swish = silu
499
-
500
-
501
- def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
502
- r"""
503
- Log-sigmoid activation function.
504
-
505
- Computes the element-wise function:
506
-
507
- .. math::
508
- \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
509
-
510
- Parameters
511
- ----------
512
- x : ArrayLike
513
- Input array.
514
-
515
- Returns
516
- -------
517
- jax.Array or Quantity
518
- An array with the same shape as the input.
519
-
520
- See Also
521
- --------
522
- sigmoid : The sigmoid function.
523
- """
524
- return u.math.log_sigmoid(x)
525
-
526
-
527
- def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
528
- r"""
529
- Exponential Linear Unit activation function.
530
-
531
- Computes the element-wise function:
532
-
533
- .. math::
534
- \mathrm{elu}(x) = \begin{cases}
535
- x, & x > 0\\
536
- \alpha \left(\exp(x) - 1\right), & x \le 0
537
- \end{cases}
538
-
539
- Parameters
540
- ----------
541
- x : ArrayLike
542
- Input array.
543
- alpha : ArrayLike, optional
544
- Scalar or array of alpha values. Default is 1.0.
545
-
546
- Returns
547
- -------
548
- jax.Array or Quantity
549
- An array with the same shape as the input.
550
-
551
- See Also
552
- --------
553
- selu : Scaled ELU activation function.
554
- celu : Continuously-differentiable ELU activation function.
555
- """
556
- return u.math.elu(x, alpha=alpha)
557
-
558
-
559
- def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
560
- r"""
561
- Leaky Rectified Linear Unit activation function.
562
-
563
- Computes the element-wise function:
564
-
565
- .. math::
566
- \mathrm{leaky\_relu}(x) = \begin{cases}
567
- x, & x \ge 0\\
568
- \alpha x, & x < 0
569
- \end{cases}
570
-
571
- where :math:`\alpha` = :code:`negative_slope`.
572
-
573
- Parameters
574
- ----------
575
- x : ArrayLike
576
- Input array.
577
- negative_slope : ArrayLike, optional
578
- Array or scalar specifying the negative slope. Default is 0.01.
579
-
580
- Returns
581
- -------
582
- jax.Array or Quantity
583
- An array with the same shape as the input.
584
-
585
- See Also
586
- --------
587
- relu : Standard ReLU activation function.
588
- prelu : Parametric ReLU with learnable slope.
589
- """
590
- return u.math.leaky_relu(x, negative_slope=negative_slope)
591
-
592
-
593
- def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
594
- return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
595
-
596
-
597
- def hard_tanh(
598
- x: ArrayLike,
599
- min_val: float = - 1.0,
600
- max_val: float = 1.0
601
- ) -> Union[jax.Array, u.Quantity]:
602
- r"""
603
- Hard hyperbolic tangent activation function.
604
-
605
- Computes the element-wise function:
606
-
607
- .. math::
608
- \mathrm{hard\_tanh}(x) = \begin{cases}
609
- -1, & x < -1\\
610
- x, & -1 \le x \le 1\\
611
- 1, & 1 < x
612
- \end{cases}
613
-
614
- Parameters
615
- ----------
616
- x : ArrayLike
617
- Input array.
618
- min_val : float, optional
619
- Minimum value of the linear region range. Default is -1.
620
- max_val : float, optional
621
- Maximum value of the linear region range. Default is 1.
622
-
623
- Returns
624
- -------
625
- jax.Array or Quantity
626
- An array with the same shape as the input.
627
- """
628
- x = u.Quantity(x)
629
- min_val = u.Quantity(min_val).to(x.unit).mantissa
630
- max_val = u.Quantity(max_val).to(x.unit).mantissa
631
- return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)
632
-
633
-
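The clipping behaviour in a nutshell (illustrative only; assumes the ``brainstate.nn.hard_tanh`` export):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2.0, -0.5, 0.5, 2.0])
    # values are clipped to [min_val, max_val] = [-1, 1]
    out = brainstate.nn.hard_tanh(x)   # expected [-1.0, -0.5, 0.5, 1.0]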
634
- def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
635
- r"""
636
- Continuously-differentiable Exponential Linear Unit activation.
637
-
638
- Computes the element-wise function:
639
-
640
- .. math::
641
- \mathrm{celu}(x) = \begin{cases}
642
- x, & x > 0\\
643
- \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
644
- \end{cases}
645
-
646
- Parameters
647
- ----------
648
- x : ArrayLike
649
- Input array.
650
- alpha : ArrayLike, optional
651
- Scalar or array value controlling the smoothness. Default is 1.0.
652
-
653
- Returns
654
- -------
655
- jax.Array or Quantity
656
- An array with the same shape as the input.
657
-
658
- References
659
- ----------
660
- .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
661
- arXiv:1704.07483
662
- """
663
- return u.math.celu(x, alpha=alpha)
664
-
665
-
666
- def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
667
- r"""
668
- Scaled Exponential Linear Unit activation.
669
-
670
- Computes the element-wise function:
671
-
672
- .. math::
673
- \mathrm{selu}(x) = \lambda \begin{cases}
674
- x, & x > 0\\
675
- \alpha e^x - \alpha, & x \le 0
676
- \end{cases}
677
-
678
- where :math:`\lambda = 1.0507009873554804934193349852946` and
679
- :math:`\alpha = 1.6732632423543772848170429916717`.
680
-
681
- Parameters
682
- ----------
683
- x : ArrayLike
684
- Input array.
685
-
686
- Returns
687
- -------
688
- jax.Array or Quantity
689
- An array with the same shape as the input.
690
-
691
- See Also
692
- --------
693
- elu : Exponential Linear Unit activation function.
694
-
695
- References
696
- ----------
697
- .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
698
- NeurIPS 2017.
699
- """
700
- return u.math.selu(x)
701
-
702
-
703
- def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
704
- r"""
705
- Gaussian Error Linear Unit activation function.
706
-
707
- If ``approximate=False``, computes the element-wise function:
708
-
709
- .. math::
710
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
711
- \frac{x}{\sqrt{2}} \right) \right)
712
-
713
- If ``approximate=True``, uses the approximate formulation of GELU:
714
-
715
- .. math::
716
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
717
- \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
718
-
719
- Parameters
720
- ----------
721
- x : ArrayLike
722
- Input array.
723
- approximate : bool, optional
724
- Whether to use the approximate (True) or exact (False) formulation.
725
- Default is True.
726
-
727
- Returns
728
- -------
729
- jax.Array or Quantity
730
- An array with the same shape as the input.
731
-
732
- References
733
- ----------
734
- .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
735
- arXiv:1606.08415
736
- """
737
- return u.math.gelu(x, approximate=approximate)
738
-
739
-
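A small comparison of the two formulations (illustrative sketch using the plain JAX primitive that this docstring mirrors):

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(-3.0, 3.0, 13)
    exact = jax.nn.gelu(x, approximate=False)   # erf-based form
    approx = jax.nn.gelu(x, approximate=True)   # tanh-based form
    # the two forms agree closely over this range
    max_gap = jnp.max(jnp.abs(exact - approx))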
740
- def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
741
- r"""
742
- Gated Linear Unit activation function.
743
-
744
- Computes the function:
745
-
746
- .. math::
747
- \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
748
- \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
749
- \right)
750
-
751
- where the array is split into two along ``axis``. The size of the ``axis``
752
- dimension must be divisible by two.
753
-
754
- Parameters
755
- ----------
756
- x : ArrayLike
757
- Input array. The dimension specified by ``axis`` must be divisible by 2.
758
- axis : int, optional
759
- The axis along which the split should be computed. Default is -1.
760
-
761
- Returns
762
- -------
763
- jax.Array or Quantity
764
- An array with the same shape as input except the ``axis`` dimension
765
- is halved.
766
-
767
- See Also
768
- --------
769
- sigmoid : The sigmoid activation function.
770
- """
771
- return u.math.glu(x, axis=axis)
772
-
773
-
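Shape behaviour at a glance (illustrative only; assumes the ``brainstate.nn.glu`` export):

    import jax.numpy as jnp
    import brainstate

    x = jnp.ones((4, 8))
    y = brainstate.nn.glu(x, axis=-1)
    # the gating halves the chosen axis: y.shape == (4, 4)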
774
- def log_softmax(x: ArrayLike,
775
- axis: int | tuple[int, ...] | None = -1,
776
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
777
- r"""
778
- Log-Softmax function.
779
-
780
- Computes the logarithm of the softmax function, which rescales
781
- elements to the range :math:`[-\infty, 0)`.
782
-
783
- .. math ::
784
- \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
785
- \right)
786
-
787
- Parameters
788
- ----------
789
- x : ArrayLike
790
- Input array.
791
- axis : int or tuple of int, optional
792
- The axis or axes along which the log-softmax should be computed.
793
- Either an integer or a tuple of integers. Default is -1.
794
- where : ArrayLike, optional
795
- Elements to include in the log-softmax computation.
796
-
797
- Returns
798
- -------
799
- jax.Array or Quantity
800
- An array with the same shape as the input.
801
-
802
- See Also
803
- --------
804
- softmax : The softmax function.
805
- """
806
- return jax.nn.log_softmax(x, axis=axis, where=where)
807
-
808
-
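The identity behind the formula, written out with plain JAX (illustrative sketch):

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    x = jnp.array([1.0, 2.0, 3.0])
    # log_softmax(x)_i = x_i - logsumexp(x)
    assert jnp.allclose(jax.nn.log_softmax(x), x - logsumexp(x))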
809
- def softmax(x: ArrayLike,
810
- axis: int | tuple[int, ...] | None = -1,
811
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
812
- r"""
813
- Softmax activation function.
814
-
815
- Computes the function which rescales elements to the range :math:`[0, 1]`
816
- such that the elements along :code:`axis` sum to :math:`1`.
817
-
818
- .. math ::
819
-         \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
820
-
821
- Parameters
822
- ----------
823
- x : ArrayLike
824
- Input array.
825
- axis : int or tuple of int, optional
826
- The axis or axes along which the softmax should be computed. The
827
- softmax output summed across these dimensions should sum to :math:`1`.
828
- Either an integer or a tuple of integers. Default is -1.
829
- where : ArrayLike, optional
830
- Elements to include in the softmax computation.
831
-
832
- Returns
833
- -------
834
- jax.Array or Quantity
835
- An array with the same shape as the input.
836
-
837
- See Also
838
- --------
839
- log_softmax : Logarithm of the softmax function.
840
- softmin : Softmin activation function.
841
- """
842
- return jax.nn.softmax(x, axis=axis, where=where)
843
-
844
-
845
- def standardize(x: ArrayLike,
846
- axis: int | tuple[int, ...] | None = -1,
847
- variance: ArrayLike | None = None,
848
- epsilon: ArrayLike = 1e-5,
849
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
850
- r"""
851
- Standardize (normalize) an array.
852
-
853
- Normalizes an array by subtracting the mean and dividing by the standard
854
- deviation :math:`\sqrt{\mathrm{variance}}`.
855
-
856
- Parameters
857
- ----------
858
- x : ArrayLike
859
- Input array.
860
- axis : int or tuple of int, optional
861
- The axis or axes along which to compute the mean and variance.
862
- Default is -1.
863
- variance : ArrayLike, optional
864
- Pre-computed variance. If None, variance is computed from ``x``.
865
- epsilon : ArrayLike, optional
866
- A small constant added to the variance to avoid division by zero.
867
- Default is 1e-5.
868
- where : ArrayLike, optional
869
- Elements to include in the computation.
870
-
871
- Returns
872
- -------
873
- jax.Array or Quantity
874
- Standardized array with the same shape as the input.
875
- """
876
- return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
877
-
878
-
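What standardization does to the moments along the chosen axis (illustrative sketch using the underlying JAX call shown in the implementation):

    import jax
    import jax.numpy as jnp

    x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
    y = jax.nn.standardize(x, axis=-1)
    # per-row mean is ~0 and per-row standard deviation is ~1
    mean, std = y.mean(axis=-1), y.std(axis=-1)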
879
- def one_hot(x: Any,
880
- num_classes: int, *,
881
- dtype: Any = jax.numpy.float_,
882
- axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
883
- """
884
- One-hot encode the given indices.
885
-
886
- Each index in the input ``x`` is encoded as a vector of zeros of length
887
- ``num_classes`` with the element at ``index`` set to one.
888
-
889
- Indices outside the range [0, num_classes) will be encoded as zeros.
890
-
891
- Parameters
892
- ----------
893
- x : ArrayLike
894
- A tensor of indices.
895
- num_classes : int
896
- Number of classes in the one-hot dimension.
897
- dtype : dtype, optional
898
- The dtype for the returned values. Default is ``jnp.float_``.
899
- axis : int or Sequence of int, optional
900
- The axis or axes along which the function should be computed.
901
- Default is -1.
902
-
903
- Returns
904
- -------
905
- jax.Array or Quantity
906
- One-hot encoded array.
907
-
908
- Examples
909
- --------
910
- .. code-block:: python
911
-
912
- >>> import jax.numpy as jnp
913
- >>> import brainstate
914
- >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
915
- Array([[1., 0., 0.],
916
- [0., 1., 0.],
917
- [0., 0., 1.]], dtype=float32)
918
-
919
- >>> # Indices outside the range are encoded as zeros
920
- >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
921
- Array([[0., 0., 0.],
922
- [0., 0., 0.]], dtype=float32)
923
- """
924
- return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)
925
-
926
-
927
- def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
928
- r"""
929
- Rectified Linear Unit 6 activation function.
930
-
931
- Computes the element-wise function:
932
-
933
- .. math::
934
- \mathrm{relu6}(x) = \min(\max(x, 0), 6)
935
-
936
- Under differentiation, we take:
937
-
938
- .. math::
939
-         \nabla \mathrm{relu6}(0) = 0
940
-
941
- and
942
-
943
- .. math::
944
-         \nabla \mathrm{relu6}(6) = 0
945
-
946
- Parameters
947
- ----------
948
- x : ArrayLike
949
- Input array.
950
-
951
- Returns
952
- -------
953
- jax.Array or Quantity
954
- An array with the same shape as the input.
955
-
956
- See Also
957
- --------
958
- relu : Standard ReLU activation function.
959
- """
960
- return u.math.relu6(x)
961
-
962
-
963
- def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
964
- r"""
965
- Hard Sigmoid activation function.
966
-
967
- Computes the element-wise function:
968
-
969
- .. math::
970
- \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
971
-
972
- Parameters
973
- ----------
974
- x : ArrayLike
975
- Input array.
976
-
977
- Returns
978
- -------
979
- jax.Array or Quantity
980
- An array with the same shape as the input.
981
-
982
- See Also
983
- --------
984
- relu6 : ReLU6 activation function.
985
- sigmoid : Standard sigmoid function.
986
- """
987
- return u.math.hard_sigmoid(x)
988
-
989
-
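The three anchor points of the piecewise-linear approximation (illustrative only; assumes the ``brainstate.nn.hard_sigmoid`` export):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-3.0, 0.0, 3.0])
    # relu6(x + 3) / 6 gives exactly 0.0, 0.5 and 1.0 at these inputs
    out = brainstate.nn.hard_sigmoid(x)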
990
- def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
991
- r"""
992
- Hard SiLU (Swish) activation function.
993
-
994
- Computes the element-wise function:
995
-
996
- .. math::
997
- \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
998
-
999
- Parameters
1000
- ----------
1001
- x : ArrayLike
1002
- Input array.
1003
-
1004
- Returns
1005
- -------
1006
- jax.Array or Quantity
1007
- An array with the same shape as the input.
1008
-
1009
- See Also
1010
- --------
1011
- hard_sigmoid : Hard sigmoid activation function.
1012
- silu : Standard SiLU activation function.
1013
- hard_swish : Alias for hard_silu.
1014
-
1015
- Notes
1016
- -----
1017
- Both `hard_silu` and `hard_swish` are aliases for the same function.
1018
- """
1019
- return u.math.hard_silu(x)
1020
-
1021
-
1022
- hard_swish = hard_silu
1023
-
1024
-
1025
- def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
1026
- r"""
1027
- Sparse plus activation function.
1028
-
1029
- Computes the function:
1030
-
1031
- .. math::
1032
-
1033
- \mathrm{sparse\_plus}(x) = \begin{cases}
1034
- 0, & x \leq -1\\
1035
- \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
1036
- x, & 1 \leq x
1037
- \end{cases}
1038
-
1039
- This is the twin function of the softplus activation, ensuring a zero output
1040
- for inputs less than -1 and a linear output for inputs greater than 1,
1041
- while remaining smooth, convex, and monotonic between -1 and 1.
1042
-
1043
- Parameters
1044
- ----------
1045
- x : ArrayLike
1046
- Input array.
1047
-
1048
- Returns
1049
- -------
1050
- jax.Array or Quantity
1051
- An array with the same shape as the input.
1052
-
1053
- See Also
1054
- --------
1055
- sparse_sigmoid : Derivative of sparse_plus.
1056
- softplus : Standard softplus activation function.
1057
- """
1058
- return u.math.sparse_plus(x)
1059
-
1060
-
1061
- def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
1062
- r"""
1063
- Sparse sigmoid activation function.
1064
-
1065
- Computes the function:
1066
-
1067
- .. math::
1068
-
1069
- \mathrm{sparse\_sigmoid}(x) = \begin{cases}
1070
- 0, & x \leq -1\\
1071
- \frac{1}{2}(x+1), & -1 < x < 1 \\
1072
- 1, & 1 \leq x
1073
- \end{cases}
1074
-
1075
- This is the twin function of the standard sigmoid activation, ensuring a zero
1076
- output for inputs less than -1, a 1 output for inputs greater than 1, and a
1077
- linear output for inputs between -1 and 1. It is the derivative of `sparse_plus`.
1078
-
1079
- Parameters
1080
- ----------
1081
- x : ArrayLike
1082
- Input array.
1083
-
1084
- Returns
1085
- -------
1086
- jax.Array or Quantity
1087
- An array with the same shape as the input.
1088
-
1089
- See Also
1090
- --------
1091
- sigmoid : Standard sigmoid activation function.
1092
- sparse_plus : Sparse plus activation function.
1093
-
1094
- References
1095
- ----------
1096
- .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
1097
- A Sparse Model of Attention and Multi-Label Classification."
1098
- In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
1099
- """
1100
- return u.math.sparse_sigmoid(x)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ """
18
+ Shared neural network activations and other functions.
19
+ """
20
+
21
+ from typing import Any, Union, Sequence
22
+
23
+ import brainunit as u
24
+ import jax
25
+ from jax.scipy.special import logsumexp
26
+
27
+ from brainstate import random
28
+ from brainstate.typing import ArrayLike
29
+
30
+ __all__ = [
31
+ "tanh",
32
+ "relu",
33
+ "squareplus",
34
+ "softplus",
35
+ "soft_sign",
36
+ "sigmoid",
37
+ "silu",
38
+ "swish",
39
+ "log_sigmoid",
40
+ "elu",
41
+ "leaky_relu",
42
+ "hard_tanh",
43
+ "celu",
44
+ "selu",
45
+ "gelu",
46
+ "glu",
47
+ "logsumexp",
48
+ "log_softmax",
49
+ "softmax",
50
+ "standardize",
51
+ "one_hot",
52
+ "relu6",
53
+ "hard_sigmoid",
54
+ "hard_silu",
55
+ "hard_swish",
56
+ 'hard_shrink',
57
+ 'rrelu',
58
+ 'mish',
59
+ 'soft_shrink',
60
+ 'prelu',
61
+ 'tanh_shrink',
62
+ 'softmin',
63
+ 'sparse_plus',
64
+ 'sparse_sigmoid',
65
+ ]
66
+
67
+
68
+ def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
69
+ r"""
70
+ Hyperbolic tangent activation function.
71
+
72
+ Computes the element-wise function:
73
+
74
+ .. math::
75
+ \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
76
+
77
+ Parameters
78
+ ----------
79
+ x : ArrayLike
80
+ Input array.
81
+
82
+ Returns
83
+ -------
84
+ jax.Array or Quantity
85
+ An array with the same shape as the input.
86
+ """
87
+ return u.math.tanh(x)
88
+
89
+
90
+ def softmin(x, axis=-1):
91
+ r"""
92
+ Softmin activation function.
93
+
94
+ Applies the Softmin function to an n-dimensional input tensor, rescaling elements
95
+ so that they lie in the range [0, 1] and sum to 1 along the specified axis.
96
+
97
+ .. math::
98
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
99
+
100
+ Parameters
101
+ ----------
102
+ x : ArrayLike
103
+ Input array of any shape.
104
+ axis : int, optional
105
+ The axis along which Softmin will be computed. Every slice along this
106
+ dimension will sum to 1. Default is -1.
107
+
108
+ Returns
109
+ -------
110
+ jax.Array or Quantity
111
+ Output array with the same shape as the input.
112
+ """
113
+ unnormalized = u.math.exp(-x)
114
+ return unnormalized / unnormalized.sum(axis, keepdims=True)
115
+
116
+
117
+ def tanh_shrink(x):
118
+ r"""
119
+ Tanh shrink activation function.
120
+
121
+ Applies the element-wise function:
122
+
123
+ .. math::
124
+ \text{Tanhshrink}(x) = x - \tanh(x)
125
+
126
+ Parameters
127
+ ----------
128
+ x : ArrayLike
129
+ Input array.
130
+
131
+ Returns
132
+ -------
133
+ jax.Array or Quantity
134
+ Output array with the same shape as the input.
135
+ """
136
+ return x - u.math.tanh(x)
137
+
138
+
139
+ def prelu(x, a=0.25):
140
+ r"""
141
+ Parametric Rectified Linear Unit activation function.
142
+
143
+ Applies the element-wise function:
144
+
145
+ .. math::
146
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
147
+
148
+ or equivalently:
149
+
150
+ .. math::
151
+ \text{PReLU}(x) =
152
+ \begin{cases}
153
+ x, & \text{ if } x \geq 0 \\
154
+ ax, & \text{ otherwise }
155
+ \end{cases}
156
+
157
+ Parameters
158
+ ----------
159
+ x : ArrayLike
160
+ Input array.
161
+ a : float or ArrayLike, optional
162
+ The negative slope coefficient. Can be a learnable parameter.
163
+ Default is 0.25.
164
+
165
+ Returns
166
+ -------
167
+ jax.Array or Quantity
168
+ Output array with the same shape as the input.
169
+
170
+ Notes
171
+ -----
172
+ When used in neural network layers, :math:`a` can be a learnable parameter
173
+ that is optimized during training.
174
+ """
175
+ return u.math.where(x >= 0., x, a * x)
176
+
177
+
178
+ def soft_shrink(x, lambd=0.5):
179
+ r"""
180
+ Soft shrinkage activation function.
181
+
182
+ Applies the soft shrinkage function element-wise:
183
+
184
+ .. math::
185
+ \text{SoftShrinkage}(x) =
186
+ \begin{cases}
187
+ x - \lambda, & \text{ if } x > \lambda \\
188
+ x + \lambda, & \text{ if } x < -\lambda \\
189
+ 0, & \text{ otherwise }
190
+ \end{cases}
191
+
192
+ Parameters
193
+ ----------
194
+ x : ArrayLike
195
+ Input array of any shape.
196
+ lambd : float, optional
197
+ The :math:`\lambda` value for the soft shrinkage formulation.
198
+ Must be non-negative. Default is 0.5.
199
+
200
+ Returns
201
+ -------
202
+ jax.Array or Quantity
203
+ Output array with the same shape as the input.
204
+ """
205
+ return u.math.where(
206
+ x > lambd,
207
+ x - lambd,
208
+ u.math.where(
209
+ x < -lambd,
210
+ x + lambd,
211
+ u.Quantity(0., unit=u.get_unit(lambd))
212
+ )
213
+ )
214
+
215
+
216
+ def mish(x):
217
+ r"""
218
+ Mish activation function.
219
+
220
+ Mish is a self-regularized non-monotonic activation function.
221
+
222
+ .. math::
223
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
224
+
225
+ Parameters
226
+ ----------
227
+ x : ArrayLike
228
+ Input array of any shape.
229
+
230
+ Returns
231
+ -------
232
+ jax.Array or Quantity
233
+ Output array with the same shape as the input.
234
+
235
+ References
236
+ ----------
237
+ .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
238
+ arXiv:1908.08681
239
+ """
240
+ return x * u.math.tanh(softplus(x))
241
+
242
+
243
+ def rrelu(x, lower=0.125, upper=0.3333333333333333):
244
+ r"""
245
+ Randomized Leaky Rectified Linear Unit activation function.
246
+
247
+ The function is defined as:
248
+
249
+ .. math::
250
+ \text{RReLU}(x) =
251
+ \begin{cases}
252
+ x & \text{if } x \geq 0 \\
253
+ ax & \text{ otherwise }
254
+ \end{cases}
255
+
256
+ where :math:`a` is randomly sampled from uniform distribution
257
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
258
+
259
+ Parameters
260
+ ----------
261
+ x : ArrayLike
262
+ Input array of any shape.
263
+ lower : float, optional
264
+ Lower bound of the uniform distribution for sampling the negative slope.
265
+ Default is 1/8.
266
+ upper : float, optional
267
+ Upper bound of the uniform distribution for sampling the negative slope.
268
+ Default is 1/3.
269
+
270
+ Returns
271
+ -------
272
+ jax.Array or Quantity
273
+ Output array with the same shape as the input.
274
+
275
+ References
276
+ ----------
277
+ .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
278
+ in Convolutional Network." arXiv:1505.00853
279
+ """
280
+ a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
281
+ return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
282
+
283
+
284
+ def hard_shrink(x, lambd=0.5):
285
+ r"""
286
+ Hard shrinkage activation function.
287
+
288
+ Applies the hard shrinkage function element-wise:
289
+
290
+ .. math::
291
+ \text{HardShrink}(x) =
292
+ \begin{cases}
293
+ x, & \text{ if } x > \lambda \\
294
+ x, & \text{ if } x < -\lambda \\
295
+ 0, & \text{ otherwise }
296
+ \end{cases}
297
+
298
+ Parameters
299
+ ----------
300
+ x : ArrayLike
301
+ Input array of any shape.
302
+ lambd : float, optional
303
+ The :math:`\lambda` threshold value for the hard shrinkage formulation.
304
+ Default is 0.5.
305
+
306
+ Returns
307
+ -------
308
+ jax.Array or Quantity
309
+ Output array with the same shape as the input.
310
+ """
311
+ return u.math.where(
312
+ x > lambd,
313
+ x,
314
+ u.math.where(
315
+ x < -lambd,
316
+ x,
317
+ u.Quantity(0., unit=u.get_unit(x))
318
+ )
319
+ )
320
+
321
+
322
+ def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
323
+ r"""
324
+ Rectified Linear Unit activation function.
325
+
326
+ Computes the element-wise function:
327
+
328
+ .. math::
329
+ \mathrm{relu}(x) = \max(x, 0)
330
+
331
+ Under differentiation, we take:
332
+
333
+ .. math::
334
+ \nabla \mathrm{relu}(0) = 0
335
+
336
+ Parameters
337
+ ----------
338
+ x : ArrayLike
339
+ Input array.
340
+
341
+ Returns
342
+ -------
343
+ jax.Array or Quantity
344
+ An array with the same shape as the input.
345
+
346
+ Examples
347
+ --------
348
+ .. code-block:: python
349
+
350
+ >>> import jax.numpy as jnp
351
+ >>> import brainstate
352
+ >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
353
+ Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
354
+
355
+ See Also
356
+ --------
357
+ relu6 : ReLU6 activation function.
358
+ leaky_relu : Leaky ReLU activation function.
359
+
360
+ References
361
+ ----------
362
+ .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
363
+ https://openreview.net/forum?id=urrcVI-_jRm
364
+ """
365
+ return u.math.relu(x)
366
+
367
+
368
+ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
369
+ r"""
370
+ Squareplus activation function.
371
+
372
+ Computes the element-wise function:
373
+
374
+ .. math::
375
+ \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
376
+
377
+ Parameters
378
+ ----------
379
+ x : ArrayLike
380
+ Input array.
381
+ b : ArrayLike, optional
382
+ Smoothness parameter. Default is 4.
383
+
384
+ Returns
385
+ -------
386
+ jax.Array or Quantity
387
+ An array with the same shape as the input.
388
+
389
+ References
390
+ ----------
391
+ .. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
392
+ for Language Modeling." arXiv:2112.11687
393
+ """
394
+ return u.math.squareplus(x, b=b)
395
+
396
+
397
+ def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
398
+ r"""
399
+ Softplus activation function.
400
+
401
+ Computes the element-wise function:
402
+
403
+ .. math::
404
+ \mathrm{softplus}(x) = \log(1 + e^x)
405
+
406
+ Parameters
407
+ ----------
408
+ x : ArrayLike
409
+ Input array.
410
+
411
+ Returns
412
+ -------
413
+ jax.Array or Quantity
414
+ An array with the same shape as the input.
415
+ """
416
+ return u.math.softplus(x)
417
+
418
+
419
+ def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
420
+ r"""
421
+ Soft-sign activation function.
422
+
423
+ Computes the element-wise function:
424
+
425
+ .. math::
426
+ \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
427
+
428
+ Parameters
429
+ ----------
430
+ x : ArrayLike
431
+ Input array.
432
+
433
+ Returns
434
+ -------
435
+ jax.Array or Quantity
436
+ An array with the same shape as the input.
437
+ """
438
+ return u.math.soft_sign(x)
439
+
440
+
441
+ def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
442
+ r"""
443
+ Sigmoid activation function.
444
+
445
+ Computes the element-wise function:
446
+
447
+ .. math::
448
+ \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
449
+
450
+ Parameters
451
+ ----------
452
+ x : ArrayLike
453
+ Input array.
454
+
455
+ Returns
456
+ -------
457
+ jax.Array or Quantity
458
+ An array with the same shape as the input.
459
+
460
+ See Also
461
+ --------
462
+ log_sigmoid : Logarithm of the sigmoid function.
463
+ """
464
+ return u.math.sigmoid(x)
465
+
466
+
467
+ def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
468
+ r"""
469
+ SiLU (Sigmoid Linear Unit) activation function.
470
+
471
+ Computes the element-wise function:
472
+
473
+ .. math::
474
+ \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
475
+
476
+ Parameters
477
+ ----------
478
+ x : ArrayLike
479
+ Input array.
480
+
481
+ Returns
482
+ -------
483
+ jax.Array or Quantity
484
+ An array with the same shape as the input.
485
+
486
+ See Also
487
+ --------
488
+ sigmoid : The sigmoid function.
489
+ swish : Alias for silu.
490
+
491
+ Notes
492
+ -----
493
+ `swish` and `silu` are both aliases for the same function.
494
+ """
495
+ return u.math.silu(x)
496
+
497
+
498
+ swish = silu
499
+
500
+
501
+ def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
502
+ r"""
503
+ Log-sigmoid activation function.
504
+
505
+ Computes the element-wise function:
506
+
507
+ .. math::
508
+ \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
509
+
510
+ Parameters
511
+ ----------
512
+ x : ArrayLike
513
+ Input array.
514
+
515
+ Returns
516
+ -------
517
+ jax.Array or Quantity
518
+ An array with the same shape as the input.
519
+
520
+ See Also
521
+ --------
522
+ sigmoid : The sigmoid function.
523
+ """
524
+ return u.math.log_sigmoid(x)
525
+
526
+
527
+ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
528
+ r"""
529
+ Exponential Linear Unit activation function.
530
+
531
+ Computes the element-wise function:
532
+
533
+ .. math::
534
+ \mathrm{elu}(x) = \begin{cases}
535
+ x, & x > 0\\
536
+ \alpha \left(\exp(x) - 1\right), & x \le 0
537
+ \end{cases}
538
+
539
+ Parameters
540
+ ----------
541
+ x : ArrayLike
542
+ Input array.
543
+ alpha : ArrayLike, optional
544
+ Scalar or array of alpha values. Default is 1.0.
545
+
546
+ Returns
547
+ -------
548
+ jax.Array or Quantity
549
+ An array with the same shape as the input.
550
+
551
+ See Also
552
+ --------
553
+ selu : Scaled ELU activation function.
554
+ celu : Continuously-differentiable ELU activation function.
555
+ """
556
+ return u.math.elu(x, alpha=alpha)
557
+
558
+
559
+ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
560
+ r"""
561
+ Leaky Rectified Linear Unit activation function.
562
+
563
+ Computes the element-wise function:
564
+
565
+ .. math::
566
+ \mathrm{leaky\_relu}(x) = \begin{cases}
567
+ x, & x \ge 0\\
568
+ \alpha x, & x < 0
569
+ \end{cases}
570
+
571
+ where :math:`\alpha` = :code:`negative_slope`.
572
+
573
+ Parameters
574
+ ----------
575
+ x : ArrayLike
576
+ Input array.
577
+ negative_slope : ArrayLike, optional
578
+ Array or scalar specifying the negative slope. Default is 0.01.
579
+
580
+ Returns
581
+ -------
582
+ jax.Array or Quantity
583
+ An array with the same shape as the input.
584
+
585
+ See Also
586
+ --------
587
+ relu : Standard ReLU activation function.
588
+ prelu : Parametric ReLU with learnable slope.
589
+ """
590
+ return u.math.leaky_relu(x, negative_slope=negative_slope)
591
+
592
+
593
+ def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
594
+ return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
595
+
596
+
597
+ def hard_tanh(
598
+ x: ArrayLike,
599
+ min_val: float = - 1.0,
600
+ max_val: float = 1.0
601
+ ) -> Union[jax.Array, u.Quantity]:
602
+ r"""
603
+ Hard hyperbolic tangent activation function.
604
+
605
+ Computes the element-wise function:
606
+
607
+ .. math::
608
+ \mathrm{hard\_tanh}(x) = \begin{cases}
609
+ -1, & x < -1\\
610
+ x, & -1 \le x \le 1\\
611
+ 1, & 1 < x
612
+ \end{cases}
613
+
614
+ Parameters
615
+ ----------
616
+ x : ArrayLike
617
+ Input array.
618
+ min_val : float, optional
619
+ Minimum value of the linear region range. Default is -1.
620
+ max_val : float, optional
621
+ Maximum value of the linear region range. Default is 1.
622
+
623
+ Returns
624
+ -------
625
+ jax.Array or Quantity
626
+ An array with the same shape as the input.
627
+ """
628
+ x = u.Quantity(x)
629
+ min_val = u.Quantity(min_val).to(x.unit).mantissa
630
+ max_val = u.Quantity(max_val).to(x.unit).mantissa
631
+ return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)
632
+
633
+
634
+ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
635
+ r"""
636
+ Continuously-differentiable Exponential Linear Unit activation.
637
+
638
+ Computes the element-wise function:
639
+
640
+ .. math::
641
+ \mathrm{celu}(x) = \begin{cases}
642
+ x, & x > 0\\
643
+ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
644
+ \end{cases}
645
+
646
+ Parameters
647
+ ----------
648
+ x : ArrayLike
649
+ Input array.
650
+ alpha : ArrayLike, optional
651
+ Scalar or array value controlling the smoothness. Default is 1.0.
652
+
653
+ Returns
654
+ -------
655
+ jax.Array or Quantity
656
+ An array with the same shape as the input.
657
+
658
+ References
659
+ ----------
660
+ .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
661
+ arXiv:1704.07483
662
+ """
663
+ return u.math.celu(x, alpha=alpha)
664
+
665
+
666
+ def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
667
+ r"""
668
+ Scaled Exponential Linear Unit activation.
669
+
670
+ Computes the element-wise function:
671
+
672
+ .. math::
673
+ \mathrm{selu}(x) = \lambda \begin{cases}
674
+ x, & x > 0\\
675
+ \alpha e^x - \alpha, & x \le 0
676
+ \end{cases}
677
+
678
+ where :math:`\lambda = 1.0507009873554804934193349852946` and
679
+ :math:`\alpha = 1.6732632423543772848170429916717`.
680
+
681
+ Parameters
682
+ ----------
683
+ x : ArrayLike
684
+ Input array.
685
+
686
+ Returns
687
+ -------
688
+ jax.Array or Quantity
689
+ An array with the same shape as the input.
690
+
691
+ See Also
692
+ --------
693
+ elu : Exponential Linear Unit activation function.
694
+
695
+ References
696
+ ----------
697
+ .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
698
+ NeurIPS 2017.
699
+ """
700
+ return u.math.selu(x)
701
+
702
+
703
+ def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
704
+ r"""
705
+ Gaussian Error Linear Unit activation function.
706
+
707
+ If ``approximate=False``, computes the element-wise function:
708
+
709
+ .. math::
710
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
711
+ \frac{x}{\sqrt{2}} \right) \right)
712
+
713
+ If ``approximate=True``, uses the approximate formulation of GELU:
714
+
715
+ .. math::
716
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
717
+ \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
718
+
719
+ Parameters
720
+ ----------
721
+ x : ArrayLike
722
+ Input array.
723
+ approximate : bool, optional
724
+ Whether to use the approximate (True) or exact (False) formulation.
725
+ Default is True.
726
+
727
+ Returns
728
+ -------
729
+ jax.Array or Quantity
730
+ An array with the same shape as the input.
731
+
732
+ References
733
+ ----------
734
+ .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
735
+ arXiv:1606.08415
736
+ """
737
+ return u.math.gelu(x, approximate=approximate)
738
+
739
+
740
+ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
741
+ r"""
742
+ Gated Linear Unit activation function.
743
+
744
+ Computes the function:
745
+
746
+ .. math::
747
+ \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
748
+ \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
749
+ \right)
750
+
751
+ where the array is split into two along ``axis``. The size of the ``axis``
752
+ dimension must be divisible by two.
753
+
754
+ Parameters
755
+ ----------
756
+ x : ArrayLike
757
+ Input array. The dimension specified by ``axis`` must be divisible by 2.
758
+ axis : int, optional
759
+ The axis along which the split should be computed. Default is -1.
760
+
761
+ Returns
762
+ -------
763
+ jax.Array or Quantity
764
+ An array with the same shape as input except the ``axis`` dimension
765
+ is halved.
766
+
767
+ See Also
768
+ --------
769
+ sigmoid : The sigmoid activation function.
770
+ """
771
+ return u.math.glu(x, axis=axis)
772
+
773
+
774
+ def log_softmax(x: ArrayLike,
775
+ axis: int | tuple[int, ...] | None = -1,
776
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
777
+ r"""
778
+ Log-Softmax function.
779
+
780
+ Computes the logarithm of the softmax function, which rescales
781
+ elements to the range :math:`[-\infty, 0)`.
782
+
783
+ .. math ::
784
+ \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
785
+ \right)
786
+
787
+ Parameters
788
+ ----------
789
+ x : ArrayLike
790
+ Input array.
791
+ axis : int or tuple of int, optional
792
+ The axis or axes along which the log-softmax should be computed.
793
+ Either an integer or a tuple of integers. Default is -1.
794
+ where : ArrayLike, optional
795
+ Elements to include in the log-softmax computation.
796
+
797
+ Returns
798
+ -------
799
+ jax.Array or Quantity
800
+ An array with the same shape as the input.
801
+
802
+ See Also
803
+ --------
804
+ softmax : The softmax function.
805
+ """
806
+ return jax.nn.log_softmax(x, axis=axis, where=where)
807
+
808
+
809
+ def softmax(x: ArrayLike,
810
+ axis: int | tuple[int, ...] | None = -1,
811
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
812
+ r"""
813
+ Softmax activation function.
814
+
815
+ Computes the function which rescales elements to the range :math:`[0, 1]`
816
+ such that the elements along :code:`axis` sum to :math:`1`.
817
+
818
+ .. math ::
819
+         \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
820
+
821
+ Parameters
822
+ ----------
823
+ x : ArrayLike
824
+ Input array.
825
+ axis : int or tuple of int, optional
826
+ The axis or axes along which the softmax should be computed. The
827
+ softmax output summed across these dimensions should sum to :math:`1`.
828
+ Either an integer or a tuple of integers. Default is -1.
829
+ where : ArrayLike, optional
830
+ Elements to include in the softmax computation.
831
+
832
+ Returns
833
+ -------
834
+ jax.Array or Quantity
835
+ An array with the same shape as the input.
836
+
837
+ See Also
838
+ --------
839
+ log_softmax : Logarithm of the softmax function.
840
+ softmin : Softmin activation function.
841
+ """
842
+ return jax.nn.softmax(x, axis=axis, where=where)
843
+
844
+
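A minimal sketch of the normalization property, using jax.nn.softmax directly as the wrapper does; the inputs are arbitrary:

    import jax
    import jax.numpy as jnp

    x = jnp.array([[1.0, 2.0, 3.0],
                   [1.0, 1.0, 1.0]])
    p = jax.nn.softmax(x, axis=-1)
    assert jnp.allclose(p.sum(axis=-1), 1.0)           # each row sums to 1
    assert jnp.allclose(p[1], jnp.full(3, 1.0 / 3.0))  # equal logits give a uniform row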
845
+ def standardize(x: ArrayLike,
+ axis: int | tuple[int, ...] | None = -1,
+ variance: ArrayLike | None = None,
+ epsilon: ArrayLike = 1e-5,
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Standardize (normalize) an array.
+
+ Normalizes an array by subtracting the mean and dividing by the standard
+ deviation :math:`\sqrt{\mathrm{variance}}`.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ axis : int or tuple of int, optional
+ The axis or axes along which to compute the mean and variance.
+ Default is -1.
+ variance : ArrayLike, optional
+ Pre-computed variance. If None, variance is computed from ``x``.
+ epsilon : ArrayLike, optional
+ A small constant added to the variance to avoid division by zero.
+ Default is 1e-5.
+ where : ArrayLike, optional
+ Elements to include in the computation.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Standardized array with the same shape as the input.
+ """
+ return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
+
+
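To illustrate what "subtract the mean, divide by the standard deviation" means in practice, a short sketch against jax.nn.standardize (the epsilon term makes the resulting variance only approximately 1):

    import jax
    import jax.numpy as jnp

    x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
    z = jax.nn.standardize(x, axis=-1)                   # per-row normalization
    assert jnp.allclose(z.mean(axis=-1), 0.0, atol=1e-6)
    assert jnp.allclose(z.var(axis=-1), 1.0, atol=1e-3)  # ~1, up to the epsilon regularizer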
879
+ def one_hot(x: Any,
+ num_classes: int, *,
+ dtype: Any = jax.numpy.float_,
+ axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
+ """
+ One-hot encode the given indices.
+
+ Each index in the input ``x`` is encoded as a vector of zeros of length
+ ``num_classes`` with the element at ``index`` set to one.
+
+ Indices outside the range [0, num_classes) will be encoded as zeros.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ A tensor of indices.
+ num_classes : int
+ Number of classes in the one-hot dimension.
+ dtype : dtype, optional
+ The dtype for the returned values. Default is ``jnp.float_``.
+ axis : int or Sequence of int, optional
+ The axis or axes along which the function should be computed.
+ Default is -1.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ One-hot encoded array.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import jax.numpy as jnp
+ >>> import brainstate
+ >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
+ Array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]], dtype=float32)
+
+ >>> # Indices outside the range are encoded as zeros
+ >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
+ Array([[0., 0., 0.],
+ [0., 0., 0.]], dtype=float32)
+ """
+ return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)
+
+
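The docstring example covers the default axis only; as a small supplement, this sketch shows the effect of ``axis=0`` (class dimension placed first), again via jax.nn.one_hot, which the wrapper calls:

    import jax
    import jax.numpy as jnp

    oh = jax.nn.one_hot(jnp.array([0, 2]), 3, axis=0)   # classes along the leading axis
    assert oh.shape == (3, 2)
    assert jnp.allclose(oh, jnp.array([[1., 0.],
                                       [0., 0.],
                                       [0., 1.]]))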
927
+ def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Rectified Linear Unit 6 activation function.
+
+ Computes the element-wise function:
+
+ .. math::
+ \mathrm{relu6}(x) = \min(\max(x, 0), 6)
+
+ For the gradient, we take:
+
+ .. math::
+ \nabla \mathrm{relu6}(0) = 0
+
+ and
+
+ .. math::
+ \nabla \mathrm{relu6}(6) = 0
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ relu : Standard ReLU activation function.
+ """
+ return u.math.relu6(x)
+
+
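A tiny numerical check of the clipping behaviour (illustrative inputs, computed with jax.nn.relu6, which matches the definition above):

    import jax
    import jax.numpy as jnp

    x = jnp.array([-2.0, 3.0, 8.0])
    y = jax.nn.relu6(x)                     # clipped below at 0 and above at 6
    assert jnp.allclose(y, jnp.array([0.0, 3.0, 6.0]))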
963
+ def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Hard Sigmoid activation function.
+
+ Computes the element-wise function:
+
+ .. math::
+ \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ relu6 : ReLU6 activation function.
+ sigmoid : Standard sigmoid function.
+ """
+ return u.math.hard_sigmoid(x)
+
+
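The defining identity relu6(x + 3) / 6 can be checked directly; a short sketch with arbitrary inputs using jax.nn:

    import jax
    import jax.numpy as jnp

    x = jnp.array([-4.0, 0.0, 4.0])
    y = jax.nn.hard_sigmoid(x)
    assert jnp.allclose(y, jax.nn.relu6(x + 3.0) / 6.0)   # the formula above
    assert jnp.allclose(y, jnp.array([0.0, 0.5, 1.0]))    # saturates at 0 and 1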
990
+ def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Hard SiLU (Swish) activation function.
+
+ Computes the element-wise function:
+
+ .. math::
+ \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ hard_sigmoid : Hard sigmoid activation function.
+ silu : Standard SiLU activation function.
+ hard_swish : Alias for hard_silu.
+
+ Notes
+ -----
+ ``hard_silu`` and ``hard_swish`` are two names for the same function.
+ """
+ return u.math.hard_silu(x)
+
+
+ hard_swish = hard_silu
+
+
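A brief check of both the product form and the alias, sketched with jax.nn rather than the brainstate wrappers (which are assumed to agree):

    import jax
    import jax.numpy as jnp

    x = jnp.array([-4.0, 0.0, 4.0])
    y = jax.nn.hard_silu(x)
    assert jnp.allclose(y, x * jax.nn.hard_sigmoid(x))    # x * hard_sigmoid(x)
    assert jnp.allclose(y, jax.nn.hard_swish(x))          # hard_swish computes the same values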
1025
+ def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Sparse plus activation function.
+
+ Computes the function:
+
+ .. math::
+
+ \mathrm{sparse\_plus}(x) = \begin{cases}
+ 0, & x \leq -1\\
+ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+ x, & 1 \leq x
+ \end{cases}
+
+ This is the twin function of the softplus activation, ensuring a zero output
+ for inputs less than -1 and a linear output for inputs greater than 1,
+ while remaining smooth, convex, and monotonic between -1 and 1.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ sparse_sigmoid : Derivative of sparse_plus.
+ softplus : Standard softplus activation function.
+ """
+ return u.math.sparse_plus(x)
+
+
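Evaluating the three branches of the piecewise definition at one point each, a minimal sketch using jax.nn.sparse_plus (available in recent JAX releases):

    import jax
    import jax.numpy as jnp

    x = jnp.array([-2.0, 0.0, 2.0])
    y = jax.nn.sparse_plus(x)
    # 0 below -1, (x + 1)**2 / 4 in between, identity above 1
    assert jnp.allclose(y, jnp.array([0.0, 0.25, 2.0]))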
1061
+ def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+ r"""
+ Sparse sigmoid activation function.
+
+ Computes the function:
+
+ .. math::
+
+ \mathrm{sparse\_sigmoid}(x) = \begin{cases}
+ 0, & x \leq -1\\
+ \frac{1}{2}(x+1), & -1 < x < 1 \\
+ 1, & 1 \leq x
+ \end{cases}
+
+ This is the twin function of the standard sigmoid activation, ensuring a zero
+ output for inputs less than -1, a 1 output for inputs greater than 1, and a
+ linear output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ sigmoid : Standard sigmoid activation function.
+ sparse_plus : Sparse plus activation function.
+
+ References
+ ----------
+ .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
+ A Sparse Model of Attention and Multi-Label Classification."
+ In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
+ """
+ return u.math.sparse_sigmoid(x)
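To close the loop on the "derivative of sparse_plus" remark, a short sketch that evaluates the three branches and compares against jax.grad of jax.nn.sparse_plus (both assumed available in recent JAX releases):

    import jax
    import jax.numpy as jnp

    x = jnp.array([-2.0, 0.0, 2.0])
    y = jax.nn.sparse_sigmoid(x)
    assert jnp.allclose(y, jnp.array([0.0, 0.5, 1.0]))    # the piecewise values
    g = jax.vmap(jax.grad(jax.nn.sparse_plus))(x)         # elementwise derivative of sparse_plus
    assert jnp.allclose(g, y)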