brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,754 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ """
+ Shared neural network activations and other functions.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Union, Sequence
+
+ import jax
+ import jax.numpy as jnp
+ from jax.scipy.special import logsumexp
+ from jax.typing import ArrayLike
+
+ from brainstate import math, random
+
+ __all__ = [
+   "tanh",
+   "relu",
+   "squareplus",
+   "softplus",
+   "soft_sign",
+   "sigmoid",
+   "silu",
+   "swish",
+   "log_sigmoid",
+   "elu",
+   "leaky_relu",
+   "hard_tanh",
+   "celu",
+   "selu",
+   "gelu",
+   "glu",
+   "logsumexp",
+   "log_softmax",
+   "softmax",
+   "standardize",
+   "one_hot",
+   "relu6",
+   "hard_sigmoid",
+   "hard_silu",
+   "hard_swish",
+   'hard_shrink',
+   'rrelu',
+   'mish',
+   'soft_shrink',
+   'prelu',
+   'tanh_shrink',
+   'softmin',
+ ]
+
+
+ def tanh(x: ArrayLike) -> jax.Array:
+   r"""Hyperbolic tangent activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+   """
+   return jnp.tanh(x)
+
+
+ def softmin(x, axis=-1):
+   r"""
+   Applies the Softmin function to an n-dimensional input Tensor,
+   rescaling it so that the elements of the n-dimensional output Tensor
+   lie in the range `[0, 1]` and sum to 1.
+
+   Softmin is defined as:
+
+   .. math::
+     \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
+
+   Shape:
+     - Input: :math:`(*)` where `*` means any number of additional
+       dimensions
+     - Output: :math:`(*)`, same shape as the input
+
+   Args:
+     axis (int): A dimension along which Softmin will be computed (so every slice
+       along dim will sum to 1).
+   """
+   unnormalized = jnp.exp(-x)
+   return unnormalized / unnormalized.sum(axis, keepdims=True)
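Since `softmin` only negates its input before normalizing, it is equivalent to `softmax(-x)`. A small stand-alone check (illustrative only, not part of the wheel; it restates the definition above with plain `jax.numpy`):

```python
import jax
import jax.numpy as jnp

def softmin(x, axis=-1):
    # same definition as in the diff above
    unnormalized = jnp.exp(-x)
    return unnormalized / unnormalized.sum(axis, keepdims=True)

x = jnp.array([[1.0, 2.0, 3.0]])
print(softmin(x))                                    # smallest input gets the largest weight
print(jnp.allclose(softmin(x), jax.nn.softmax(-x)))  # True: softmin(x) == softmax(-x)
```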
+
+
+ def tanh_shrink(x):
+   r"""
+   Applies the element-wise function:
+
+   .. math::
+     \text{Tanhshrink}(x) = x - \tanh(x)
+   """
+   return x - jnp.tanh(x)
+
+
+ def prelu(x, a=0.25):
+   r"""
+   Applies the element-wise function:
+
+   .. math::
+     \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+   or
+
+   .. math::
+     \text{PReLU}(x) =
+     \begin{cases}
+       x, & \text{ if } x \geq 0 \\
+       ax, & \text{ otherwise }
+     \end{cases}
+
+   Here :math:`a` is the negative slope (default 0.25). Pass a scalar to share a single
+   :math:`a` across all input channels, or an array broadcastable against ``x`` to apply
+   a separate :math:`a` to each input channel.
+   """
+   dtype = math.get_dtype(x)
+   return jnp.where(x >= jnp.asarray(0., dtype),
+                    x,
+                    jnp.asarray(a, dtype) * x)
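A quick usage sketch of `prelu` (illustrative only; the package's `math.get_dtype` handling is skipped and `jnp.where` is used directly, as above):

```python
import jax.numpy as jnp

def prelu(x, a=0.25):
    # scalar `a` shares one slope; an array `a` broadcasts per channel
    return jnp.where(x >= 0, x, a * x)

x = jnp.array([[-2.0, -1.0, 3.0]])
print(prelu(x))                                # [[-0.5  -0.25  3.  ]]
print(prelu(x, a=jnp.array([0.1, 0.2, 0.3])))  # per-channel slopes: [[-0.2 -0.2  3. ]]
```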
+
+
+ def soft_shrink(x, lambd=0.5):
+   r"""
+   Applies the soft shrinkage function elementwise:
+
+   .. math::
+     \text{SoftShrinkage}(x) =
+     \begin{cases}
+       x - \lambda, & \text{ if } x > \lambda \\
+       x + \lambda, & \text{ if } x < -\lambda \\
+       0, & \text{ otherwise }
+     \end{cases}
+
+   Args:
+     lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
+
+   Shape:
+     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+     - Output: :math:`(*)`, same shape as the input.
+   """
+   dtype = math.get_dtype(x)
+   lambd = jnp.asarray(lambd, dtype)
+   return jnp.where(x > lambd,
+                    x - lambd,
+                    jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
+
+
+ def mish(x):
+   r"""Applies the Mish function, element-wise.
+
+   Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+   .. math::
+     \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+   .. note::
+     See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
+
+   Shape:
+     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+     - Output: :math:`(*)`, same shape as the input.
+   """
+   return x * jnp.tanh(softplus(x))
+
+
+ def rrelu(x, lower=0.125, upper=0.3333333333333333):
+   r"""Applies the randomized leaky rectified linear unit function, element-wise,
+   as described in the paper:
+
+   `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+
+   The function is defined as:
+
+   .. math::
+     \text{RReLU}(x) =
+     \begin{cases}
+       x & \text{if } x \geq 0 \\
+       ax & \text{ otherwise }
+     \end{cases}
+
+   where :math:`a` is randomly sampled from the uniform distribution
+   :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+   See: https://arxiv.org/pdf/1505.00853.pdf
+
+   Args:
+     lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+     upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
+
+   Shape:
+     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+     - Output: :math:`(*)`, same shape as the input.
+
+   .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+     https://arxiv.org/abs/1505.00853
+   """
+   dtype = math.get_dtype(x)
+   a = random.uniform(lower, upper, size=jnp.shape(x), dtype=dtype)
+   return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
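`rrelu` draws its slope from `brainstate.random`, whose API is not shown in this diff. A stand-alone sketch of the same idea using `jax.random` instead (an assumption for illustration, not the package's RNG):

```python
import jax
import jax.numpy as jnp

def rrelu_sketch(key, x, lower=1 / 8, upper=1 / 3):
    # sample one slope per element from U(lower, upper), apply it to negative inputs
    a = jax.random.uniform(key, jnp.shape(x), minval=lower, maxval=upper)
    return jnp.where(x >= 0, x, a * x)

key = jax.random.PRNGKey(0)
x = jnp.array([-2.0, -0.5, 1.0])
print(rrelu_sketch(key, x))  # negative entries are scaled by random slopes in [1/8, 1/3)
```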
+
+
+ def hard_shrink(x, lambd=0.5):
+   r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
+
+   Hardshrink is defined as:
+
+   .. math::
+     \text{HardShrink}(x) =
+     \begin{cases}
+       x, & \text{ if } x > \lambda \\
+       x, & \text{ if } x < -\lambda \\
+       0, & \text{ otherwise }
+     \end{cases}
+
+   Args:
+     lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
+
+   Shape:
+     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+     - Output: :math:`(*)`, same shape as the input.
+
+   """
+   dtype = math.get_dtype(x)
+   lambd = jnp.asarray(lambd, dtype)
+   return jnp.where(x > lambd,
+                    x,
+                    jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))
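The two shrinkage functions differ only in how they treat values outside the band `[-λ, λ]`: `hard_shrink` keeps them unchanged, while `soft_shrink` also pulls them toward zero by `λ`. A small comparison, restating the formulas above with plain `jax.numpy` (illustrative only):

```python
import jax.numpy as jnp

x = jnp.array([-1.0, -0.3, 0.0, 0.3, 1.0])
lambd = 0.5

hard = jnp.where(jnp.abs(x) > lambd, x, 0.0)                                   # identity outside the band
soft = jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.0))  # shrunk by lambd outside
print(hard)  # [-1.  0.  0.  0.  1.]
print(soft)  # [-0.5  0.   0.   0.   0.5]
```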
+
+
+ def relu(x: ArrayLike) -> jax.Array:
+   r"""Rectified linear unit activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{relu}(x) = \max(x, 0)
+
+   except under differentiation, we take:
+
+   .. math::
+     \nabla \mathrm{relu}(0) = 0
+
+   For more information see
+   `Numerical influence of ReLU’(0) on backpropagation
+   <https://openreview.net/forum?id=urrcVI-_jRm>`_.
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   Example:
+     >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
+     Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
+
+   See also:
+     :func:`relu6`
+
+   """
+   return jax.nn.relu(x)
+
+
+ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
+   r"""Squareplus activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
+
+   as described in https://arxiv.org/abs/2112.11687.
+
+   Args:
+     x : input array
+     b : smoothness parameter
+   """
+   dtype = math.get_dtype(x)
+   return jax.nn.squareplus(x, jnp.asarray(b, dtype))
+
+
+ def softplus(x: ArrayLike) -> jax.Array:
+   r"""Softplus activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{softplus}(x) = \log(1 + e^x)
+
+   Args:
+     x : input array
+   """
+   return jax.nn.softplus(x)
+
+
+ def soft_sign(x: ArrayLike) -> jax.Array:
+   r"""Soft-sign activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
+
+   Args:
+     x : input array
+   """
+   return jax.nn.soft_sign(x)
+
+
+ def sigmoid(x: ArrayLike) -> jax.Array:
+   r"""Sigmoid activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`log_sigmoid`
+
+   """
+   return jax.nn.sigmoid(x)
+
+
+ def silu(x: ArrayLike) -> jax.Array:
+   r"""SiLU (a.k.a. swish) activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
+
+   :func:`swish` and :func:`silu` are both aliases for the same function.
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`sigmoid`
+   """
+   return jax.nn.silu(x)
+
+
+ swish = silu
+
+
+ def log_sigmoid(x: ArrayLike) -> jax.Array:
+   r"""Log-sigmoid activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`sigmoid`
+   """
+   return jax.nn.log_sigmoid(x)
+
+
+ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+   r"""Exponential linear unit activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{elu}(x) = \begin{cases}
+       x, & x > 0\\
+       \alpha \left(\exp(x) - 1\right), & x \le 0
+     \end{cases}
+
+   Args:
+     x : input array
+     alpha : scalar or array of alpha values (default: 1.0)
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`selu`
+   """
+   dtype = math.get_dtype(x)
+   alpha = jnp.asarray(alpha, dtype)
+   return jax.nn.elu(x, alpha)
+
+
+ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
+   r"""Leaky rectified linear unit activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{leaky\_relu}(x) = \begin{cases}
+       x, & x \ge 0\\
+       \alpha x, & x < 0
+     \end{cases}
+
+   where :math:`\alpha` = :code:`negative_slope`.
+
+   Args:
+     x : input array
+     negative_slope : array or scalar specifying the negative slope (default: 0.01)
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`relu`
+   """
+   dtype = math.get_dtype(x)
+   negative_slope = jnp.asarray(negative_slope, dtype)
+   return jax.nn.leaky_relu(x, negative_slope=negative_slope)
+
+
+ def hard_tanh(x: ArrayLike) -> jax.Array:
+   r"""Hard :math:`\mathrm{tanh}` activation function.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{hard\_tanh}(x) = \begin{cases}
+       -1, & x < -1\\
+       x, & -1 \le x \le 1\\
+       1, & 1 < x
+     \end{cases}
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+   """
+   return jax.nn.hard_tanh(x)
+
+
+ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
+   r"""Continuously-differentiable exponential linear unit activation.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{celu}(x) = \begin{cases}
+       x, & x > 0\\
+       \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
+     \end{cases}
+
+   For more information, see
+   `Continuously Differentiable Exponential Linear Units
+   <https://arxiv.org/pdf/1704.07483.pdf>`_.
+
+   Args:
+     x : input array
+     alpha : array or scalar (default: 1.0)
+
+   Returns:
+     An array.
+   """
+   dtype = math.get_dtype(x)
+   alpha = jnp.asarray(alpha, dtype)
+   return jax.nn.celu(x, alpha)
+
+
+ def selu(x: ArrayLike) -> jax.Array:
+   r"""Scaled exponential linear unit activation.
+
+   Computes the element-wise function:
+
+   .. math::
+     \mathrm{selu}(x) = \lambda \begin{cases}
+       x, & x > 0\\
+       \alpha e^x - \alpha, & x \le 0
+     \end{cases}
+
+   where :math:`\lambda = 1.0507009873554804934193349852946` and
+   :math:`\alpha = 1.6732632423543772848170429916717`.
+
+   For more information, see
+   `Self-Normalizing Neural Networks
+   <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`elu`
+   """
+   return jax.nn.selu(x)
+
+
+ def gelu(x: ArrayLike, approximate: bool = True) -> jax.Array:
+   r"""Gaussian error linear unit activation function.
+
+   If ``approximate=False``, computes the element-wise function:
+
+   .. math::
+     \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
+       \frac{x}{\sqrt{2}} \right) \right)
+
+   If ``approximate=True``, uses the approximate formulation of GELU:
+
+   .. math::
+     \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
+       \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
+
+   For more information, see `Gaussian Error Linear Units (GELUs)
+   <https://arxiv.org/abs/1606.08415>`_, section 2.
+
+   Args:
+     x : input array
+     approximate: whether to use the approximate or exact formulation.
+   """
+   return jax.nn.gelu(x, approximate=approximate)
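The exact and tanh-approximate GELU formulations above agree closely; a quick check through `jax.nn.gelu`, which the wrapper delegates to:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
exact = jax.nn.gelu(x, approximate=False)
approx = jax.nn.gelu(x, approximate=True)
print(jnp.max(jnp.abs(exact - approx)))  # small gap between the two formulations
```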
+
+
+ def glu(x: ArrayLike, axis: int = -1) -> jax.Array:
+   r"""Gated linear unit activation function.
+
+   Computes the function:
+
+   .. math::
+     \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
+       \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
+       \right)
+
+   where the array is split into two along ``axis``. The size of the ``axis``
+   dimension must be divisible by two.
+
+   Args:
+     x : input array
+     axis: the axis along which the split should be computed (default: -1)
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`sigmoid`
+   """
+   return jax.nn.glu(x, axis=axis)
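`glu` halves the size of `axis`: one half of the input gates the other through a sigmoid. A shape check via `jax.nn.glu`, which the wrapper above calls:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 10))  # the gated axis must have even size
y = jax.nn.glu(x, axis=-1)
print(y.shape)                                               # (4, 5)
print(jnp.allclose(y, x[:, :5] * jax.nn.sigmoid(x[:, 5:])))  # True: first half * sigmoid(second half)
```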
+
+
+ def log_softmax(x: ArrayLike,
+                 axis: int | tuple[int, ...] | None = -1,
+                 where: ArrayLike | None = None,
+                 initial: ArrayLike | None = None) -> jax.Array:
+   r"""Log-Softmax function.
+
+   Computes the logarithm of the :code:`softmax` function, which rescales
+   elements to the range :math:`[-\infty, 0)`.
+
+   .. math ::
+     \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+     \right)
+
+   Args:
+     x : input array
+     axis: the axis or axes along which the :code:`log_softmax` should be
+       computed. Either an integer or a tuple of integers.
+     where: Elements to include in the :code:`log_softmax`.
+     initial: The minimum value used to shift the input array. Must be present
+       when :code:`where` is not None.
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`softmax`
+   """
+   return jax.nn.log_softmax(x, axis, where, initial)
+
+
+ def softmax(x: ArrayLike,
+             axis: int | tuple[int, ...] | None = -1,
+             where: ArrayLike | None = None,
+             initial: ArrayLike | None = None) -> jax.Array:
+   r"""Softmax function.
+
+   Computes the function which rescales elements to the range :math:`[0, 1]`
+   such that the elements along :code:`axis` sum to :math:`1`.
+
+   .. math ::
+     \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+   Args:
+     x : input array
+     axis: the axis or axes along which the softmax should be computed. The
+       softmax output summed across these dimensions should sum to :math:`1`.
+       Either an integer or a tuple of integers.
+     where: Elements to include in the :code:`softmax`.
+     initial: The minimum value used to shift the input array. Must be present
+       when :code:`where` is not None.
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`log_softmax`
+   """
+   return jax.nn.softmax(x, axis, where, initial)
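Two quick sanity checks for the pair above, via the `jax.nn` functions they wrap: softmax slices sum to one, and `log_softmax` matches `log(softmax)`:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [0.0, 0.0, 0.0]])
p = jax.nn.softmax(x, axis=-1)
print(p.sum(axis=-1))                                   # [1. 1.]
print(jnp.allclose(jnp.log(p), jax.nn.log_softmax(x)))  # True
```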
+
+
+ def standardize(x: ArrayLike,
+                 axis: int | tuple[int, ...] | None = -1,
+                 variance: ArrayLike | None = None,
+                 epsilon: ArrayLike = 1e-5,
+                 where: ArrayLike | None = None) -> jax.Array:
+   r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+   return jax.nn.standardize(x, axis=axis, variance=variance, epsilon=epsilon, where=where)
+
+
+ def one_hot(x: Any,
+             num_classes: int, *,
+             dtype: Any = jnp.float_,
+             axis: Union[int, Sequence[int]] = -1) -> jax.Array:
+   """One-hot encodes the given indices.
+
+   Each index in the input ``x`` is encoded as a vector of zeros of length
+   ``num_classes`` with the element at ``index`` set to one::
+
+     >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
+     Array([[1., 0., 0.],
+            [0., 1., 0.],
+            [0., 0., 1.]], dtype=float32)
+
+   Indices outside the range [0, num_classes) will be encoded as zeros::
+
+     >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
+     Array([[0., 0., 0.],
+            [0., 0., 0.]], dtype=float32)
+
+   Args:
+     x: A tensor of indices.
+     num_classes: Number of classes in the one-hot dimension.
+     dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
+     axis: the axis or axes along which the function should be
+       computed.
+   """
+   return jax.nn.one_hot(x, num_classes, dtype=dtype, axis=axis)
+
+
+ def relu6(x: ArrayLike) -> jax.Array:
+   r"""Rectified Linear Unit 6 activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{relu6}(x) = \min(\max(x, 0), 6)
+
+   except under differentiation, we take:
+
+   .. math::
+     \nabla \mathrm{relu6}(0) = 0
+
+   and
+
+   .. math::
+     \nabla \mathrm{relu6}(6) = 0
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`relu`
+   """
+   return jax.nn.relu6(x)
+
+
+ def hard_sigmoid(x: ArrayLike) -> jax.Array:
+   r"""Hard Sigmoid activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`relu6`
+   """
+   return jax.nn.hard_sigmoid(x)
+
+
+ def hard_silu(x: ArrayLike) -> jax.Array:
+   r"""Hard SiLU (swish) activation function.
+
+   Computes the element-wise function
+
+   .. math::
+     \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
+
+   Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
+   function.
+
+   Args:
+     x : input array
+
+   Returns:
+     An array.
+
+   See also:
+     :func:`hard_sigmoid`
+   """
+   return jax.nn.hard_silu(x)
+
+
+ hard_swish = hard_silu
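Closing illustration (not part of the wheel): the `hard_*` family defined above is internally consistent, since `hard_silu(x) = x * hard_sigmoid(x)` and `hard_sigmoid(x) = relu6(x + 3) / 6`. Checked through the `jax.nn` functions these wrappers call:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-4.0, 4.0, 9)
print(jnp.allclose(jax.nn.hard_sigmoid(x), jax.nn.relu6(x + 3.0) / 6.0))  # True
print(jnp.allclose(jax.nn.hard_silu(x), x * jax.nn.hard_sigmoid(x)))      # True
```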