brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
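Before the file-level diff below, a rough migration sketch based solely on the renames listed above (the `compile/` and `augment/` subpackages folded into `transform/`, `init/` folded into `nn/`, `_rate_rnns.py` renamed to `_rnns.py`). The exact public re-exports of brainstate 0.2.1 are not verified here, so the module names are resolved defensively:

```python
# Hypothetical migration sketch inferred from the renamed file paths above;
# the actual public API surface of brainstate 0.2.1 may re-export differently.
#
# 0.1.10 layout (old paths):
#   brainstate/compile/_jit.py, brainstate/augment/_autograd.py,
#   brainstate/init/_random_inits.py, brainstate/nn/_rate_rnns.py
# 0.2.1 layout (new paths):
#   brainstate/transform/_jit.py, brainstate/transform/_autograd.py,
#   brainstate/nn/init.py, brainstate/nn/_rnns.py
import importlib

for modname in ("brainstate.transform", "brainstate.nn"):
    # Import by dotted path so the check works even if the parent package
    # does not eagerly expose the subpackage as an attribute.
    mod = importlib.import_module(modname)
    print(modname, "->", mod.__name__)
```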
@@ -1,808 +1,1100 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- """
18
- Shared neural network activations and other functions.
19
- """
20
-
21
- from typing import Any, Union, Sequence
22
-
23
- import brainunit as u
24
- import jax
25
- from jax.scipy.special import logsumexp
26
-
27
- from brainstate import random
28
- from brainstate.typing import ArrayLike
29
-
30
- __all__ = [
31
- "tanh",
32
- "relu",
33
- "squareplus",
34
- "softplus",
35
- "soft_sign",
36
- "sigmoid",
37
- "silu",
38
- "swish",
39
- "log_sigmoid",
40
- "elu",
41
- "leaky_relu",
42
- "hard_tanh",
43
- "celu",
44
- "selu",
45
- "gelu",
46
- "glu",
47
- "logsumexp",
48
- "log_softmax",
49
- "softmax",
50
- "standardize",
51
- "one_hot",
52
- "relu6",
53
- "hard_sigmoid",
54
- "hard_silu",
55
- "hard_swish",
56
- 'hard_shrink',
57
- 'rrelu',
58
- 'mish',
59
- 'soft_shrink',
60
- 'prelu',
61
- 'tanh_shrink',
62
- 'softmin',
63
- 'sparse_plus',
64
- 'sparse_sigmoid',
65
- ]
66
-
67
-
68
- def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
69
- r"""Hyperbolic tangent activation function.
70
-
71
- Computes the element-wise function:
72
-
73
- .. math::
74
- \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
75
-
76
- Args:
77
- x : input array
78
-
79
- Returns:
80
- An array.
81
- """
82
- return u.math.tanh(x)
83
-
84
-
85
- def softmin(x, axis=-1):
86
- r"""
87
- Applies the Softmin function to an n-dimensional input Tensor
88
- rescaling them so that the elements of the n-dimensional output Tensor
89
- lie in the range `[0, 1]` and sum to 1.
90
-
91
- Softmin is defined as:
92
-
93
- .. math::
94
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
95
-
96
- Shape:
97
- - Input: :math:`(*)` where `*` means, any number of additional
98
- dimensions
99
- - Output: :math:`(*)`, same shape as the input
100
-
101
- Args:
102
- axis (int): A dimension along which Softmin will be computed (so every slice
103
- along dim will sum to 1).
104
- """
105
- unnormalized = u.math.exp(-x)
106
- return unnormalized / unnormalized.sum(axis, keepdims=True)
107
-
108
-
109
- def tanh_shrink(x):
110
- r"""
111
- Applies the element-wise function:
112
-
113
- .. math::
114
- \text{Tanhshrink}(x) = x - \tanh(x)
115
- """
116
- return x - u.math.tanh(x)
117
-
118
-
119
- def prelu(x, a=0.25):
120
- r"""
121
- Applies the element-wise function:
122
-
123
- .. math::
124
- \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
125
-
126
- or
127
-
128
- .. math::
129
- \text{PReLU}(x) =
130
- \begin{cases}
131
- x, & \text{ if } x \geq 0 \\
132
- ax, & \text{ otherwise }
133
- \end{cases}
134
-
135
- Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
136
- parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
137
- a separate :math:`a` is used for each input channel.
138
- """
139
- return u.math.where(x >= 0., x, a * x)
140
-
141
-
142
- def soft_shrink(x, lambd=0.5):
143
- r"""
144
- Applies the soft shrinkage function elementwise:
145
-
146
- .. math::
147
- \text{SoftShrinkage}(x) =
148
- \begin{cases}
149
- x - \lambda, & \text{ if } x > \lambda \\
150
- x + \lambda, & \text{ if } x < -\lambda \\
151
- 0, & \text{ otherwise }
152
- \end{cases}
153
-
154
- Args:
155
- lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
156
-
157
- Shape:
158
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
159
- - Output: :math:`(*)`, same shape as the input.
160
- """
161
- return u.math.where(x > lambd,
162
- x - lambd,
163
- u.math.where(x < -lambd,
164
- x + lambd,
165
- u.Quantity(0., unit=u.get_unit(lambd))))
166
-
167
-
168
- def mish(x):
169
- r"""Applies the Mish function, element-wise.
170
-
171
- Mish: A Self Regularized Non-Monotonic Neural Activation Function.
172
-
173
- .. math::
174
- \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
175
-
176
- .. note::
177
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
178
-
179
- Shape:
180
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
181
- - Output: :math:`(*)`, same shape as the input.
182
- """
183
- return x * u.math.tanh(softplus(x))
184
-
185
-
186
- def rrelu(x, lower=0.125, upper=0.3333333333333333):
187
- r"""Applies the randomized leaky rectified liner unit function, element-wise,
188
- as described in the paper:
189
-
190
- `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
191
-
192
- The function is defined as:
193
-
194
- .. math::
195
- \text{RReLU}(x) =
196
- \begin{cases}
197
- x & \text{if } x \geq 0 \\
198
- ax & \text{ otherwise }
199
- \end{cases}
200
-
201
- where :math:`a` is randomly sampled from uniform distribution
202
- :math:`\mathcal{U}(\text{lower}, \text{upper})`.
203
-
204
- See: https://arxiv.org/pdf/1505.00853.pdf
205
-
206
- Args:
207
- lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
208
- upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
209
-
210
- Shape:
211
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
212
- - Output: :math:`(*)`, same shape as the input.
213
-
214
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
215
- https://arxiv.org/abs/1505.00853
216
- """
217
- a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
218
- return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
219
-
220
-
221
- def hard_shrink(x, lambd=0.5):
222
- r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
223
-
224
- Hardshrink is defined as:
225
-
226
- .. math::
227
- \text{HardShrink}(x) =
228
- \begin{cases}
229
- x, & \text{ if } x > \lambda \\
230
- x, & \text{ if } x < -\lambda \\
231
- 0, & \text{ otherwise }
232
- \end{cases}
233
-
234
- Args:
235
- lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
236
-
237
- Shape:
238
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
239
- - Output: :math:`(*)`, same shape as the input.
240
-
241
- """
242
- return u.math.where(x > lambd,
243
- x,
244
- u.math.where(x < -lambd,
245
- x,
246
- u.Quantity(0., unit=u.get_unit(x))))
247
-
248
-
249
- def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
250
- r"""Rectified linear unit activation function.
251
-
252
- Computes the element-wise function:
253
-
254
- .. math::
255
- \mathrm{relu}(x) = \max(x, 0)
256
-
257
- except under differentiation, we take:
258
-
259
- .. math::
260
- \nabla \mathrm{relu}(0) = 0
261
-
262
- For more information see
263
- `Numerical influence of ReLU’(0) on backpropagation
264
- <https://openreview.net/forum?id=urrcVI-_jRm>`_.
265
-
266
- Args:
267
- x : input array
268
-
269
- Returns:
270
- An array.
271
-
272
- Example:
273
- >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
274
- Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
275
-
276
- See also:
277
- :func:`relu6`
278
-
279
- """
280
- return u.math.relu(x)
281
-
282
-
283
- def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
284
- r"""Squareplus activation function.
285
-
286
- Computes the element-wise function
287
-
288
- .. math::
289
- \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
290
-
291
- as described in https://arxiv.org/abs/2112.11687.
292
-
293
- Args:
294
- x : input array
295
- b : smoothness parameter
296
- """
297
- return u.math.squareplus(x, b=b)
298
-
299
-
300
- def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
301
- r"""Softplus activation function.
302
-
303
- Computes the element-wise function
304
-
305
- .. math::
306
- \mathrm{softplus}(x) = \log(1 + e^x)
307
-
308
- Args:
309
- x : input array
310
- """
311
- return u.math.softplus(x)
312
-
313
-
314
- def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
315
- r"""Soft-sign activation function.
316
-
317
- Computes the element-wise function
318
-
319
- .. math::
320
- \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
321
-
322
- Args:
323
- x : input array
324
- """
325
- return u.math.soft_sign(x)
326
-
327
-
328
- def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
329
- r"""Sigmoid activation function.
330
-
331
- Computes the element-wise function:
332
-
333
- .. math::
334
- \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
335
-
336
- Args:
337
- x : input array
338
-
339
- Returns:
340
- An array.
341
-
342
- See also:
343
- :func:`log_sigmoid`
344
-
345
- """
346
- return u.math.sigmoid(x)
347
-
348
-
349
- def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
350
- r"""SiLU (a.k.a. swish) activation function.
351
-
352
- Computes the element-wise function:
353
-
354
- .. math::
355
- \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
356
-
357
- :func:`swish` and :func:`silu` are both aliases for the same function.
358
-
359
- Args:
360
- x : input array
361
-
362
- Returns:
363
- An array.
364
-
365
- See also:
366
- :func:`sigmoid`
367
- """
368
- return u.math.silu(x)
369
-
370
-
371
- swish = silu
372
-
373
-
374
- def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
375
- r"""Log-sigmoid activation function.
376
-
377
- Computes the element-wise function:
378
-
379
- .. math::
380
- \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
381
-
382
- Args:
383
- x : input array
384
-
385
- Returns:
386
- An array.
387
-
388
- See also:
389
- :func:`sigmoid`
390
- """
391
- return u.math.log_sigmoid(x)
392
-
393
-
394
- def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
395
- r"""Exponential linear unit activation function.
396
-
397
- Computes the element-wise function:
398
-
399
- .. math::
400
- \mathrm{elu}(x) = \begin{cases}
401
- x, & x > 0\\
402
- \alpha \left(\exp(x) - 1\right), & x \le 0
403
- \end{cases}
404
-
405
- Args:
406
- x : input array
407
- alpha : scalar or array of alpha values (default: 1.0)
408
-
409
- Returns:
410
- An array.
411
-
412
- See also:
413
- :func:`selu`
414
- """
415
- return u.math.elu(x, alpha=alpha)
416
-
417
-
418
- def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
419
- r"""Leaky rectified linear unit activation function.
420
-
421
- Computes the element-wise function:
422
-
423
- .. math::
424
- \mathrm{leaky\_relu}(x) = \begin{cases}
425
- x, & x \ge 0\\
426
- \alpha x, & x < 0
427
- \end{cases}
428
-
429
- where :math:`\alpha` = :code:`negative_slope`.
430
-
431
- Args:
432
- x : input array
433
- negative_slope : array or scalar specifying the negative slope (default: 0.01)
434
-
435
- Returns:
436
- An array.
437
-
438
- See also:
439
- :func:`relu`
440
- """
441
- return u.math.leaky_relu(x, negative_slope=negative_slope)
442
-
443
-
444
- def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
445
- return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
446
-
447
-
448
- def hard_tanh(
449
- x: ArrayLike,
450
- min_val: float = - 1.0,
451
- max_val: float = 1.0
452
- ) -> Union[jax.Array, u.Quantity]:
453
- r"""Hard :math:`\mathrm{tanh}` activation function.
454
-
455
- Computes the element-wise function:
456
-
457
- .. math::
458
- \mathrm{hard\_tanh}(x) = \begin{cases}
459
- -1, & x < -1\\
460
- x, & -1 \le x \le 1\\
461
- 1, & 1 < x
462
- \end{cases}
463
-
464
- Args:
465
- x : input array
466
- min_val: float. minimum value of the linear region range. Default: -1
467
- max_val: float. maximum value of the linear region range. Default: 1
468
-
469
- Returns:
470
- An array.
471
- """
472
- x = u.Quantity(x)
473
- min_val = u.Quantity(min_val).to(x.unit).mantissa
474
- max_val = u.Quantity(max_val).to(x.unit).mantissa
475
- return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)
476
-
477
-
478
- def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
479
- r"""Continuously-differentiable exponential linear unit activation.
480
-
481
- Computes the element-wise function:
482
-
483
- .. math::
484
- \mathrm{celu}(x) = \begin{cases}
485
- x, & x > 0\\
486
- \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
487
- \end{cases}
488
-
489
- For more information, see
490
- `Continuously Differentiable Exponential Linear Units
491
- <https://arxiv.org/pdf/1704.07483.pdf>`_.
492
-
493
- Args:
494
- x : input array
495
- alpha : array or scalar (default: 1.0)
496
-
497
- Returns:
498
- An array.
499
- """
500
- return u.math.celu(x, alpha=alpha)
501
-
502
-
503
- def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
504
- r"""Scaled exponential linear unit activation.
505
-
506
- Computes the element-wise function:
507
-
508
- .. math::
509
- \mathrm{selu}(x) = \lambda \begin{cases}
510
- x, & x > 0\\
511
- \alpha e^x - \alpha, & x \le 0
512
- \end{cases}
513
-
514
- where :math:`\lambda = 1.0507009873554804934193349852946` and
515
- :math:`\alpha = 1.6732632423543772848170429916717`.
516
-
517
- For more information, see
518
- `Self-Normalizing Neural Networks
519
- <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
520
-
521
- Args:
522
- x : input array
523
-
524
- Returns:
525
- An array.
526
-
527
- See also:
528
- :func:`elu`
529
- """
530
- return u.math.selu(x)
531
-
532
-
533
- def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
534
- r"""Gaussian error linear unit activation function.
535
-
536
- If ``approximate=False``, computes the element-wise function:
537
-
538
- .. math::
539
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
540
- \frac{x}{\sqrt{2}} \right) \right)
541
-
542
- If ``approximate=True``, uses the approximate formulation of GELU:
543
-
544
- .. math::
545
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
546
- \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
547
-
548
- For more information, see `Gaussian Error Linear Units (GELUs)
549
- <https://arxiv.org/abs/1606.08415>`_, section 2.
550
-
551
- Args:
552
- x : input array
553
- approximate: whether to use the approximate or exact formulation.
554
- """
555
- return u.math.gelu(x, approximate=approximate)
556
-
557
-
558
- def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
559
- r"""Gated linear unit activation function.
560
-
561
- Computes the function:
562
-
563
- .. math::
564
- \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
565
- \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
566
- \right)
567
-
568
- where the array is split into two along ``axis``. The size of the ``axis``
569
- dimension must be divisible by two.
570
-
571
- Args:
572
- x : input array
573
- axis: the axis along which the split should be computed (default: -1)
574
-
575
- Returns:
576
- An array.
577
-
578
- See also:
579
- :func:`sigmoid`
580
- """
581
- return u.math.glu(x, axis=axis)
582
-
583
-
584
- def log_softmax(x: ArrayLike,
585
- axis: int | tuple[int, ...] | None = -1,
586
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
587
- r"""Log-Softmax function.
588
-
589
- Computes the logarithm of the :code:`softmax` function, which rescales
590
- elements to the range :math:`[-\infty, 0)`.
591
-
592
- .. math ::
593
- \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
594
- \right)
595
-
596
- Args:
597
- x : input array
598
- axis: the axis or axes along which the :code:`log_softmax` should be
599
- computed. Either an integer or a tuple of integers.
600
- where: Elements to include in the :code:`log_softmax`.
601
-
602
- Returns:
603
- An array.
604
-
605
- See also:
606
- :func:`softmax`
607
- """
608
- return jax.nn.log_softmax(x, axis=axis, where=where)
609
-
610
-
611
- def softmax(x: ArrayLike,
612
- axis: int | tuple[int, ...] | None = -1,
613
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
614
- r"""Softmax function.
615
-
616
- Computes the function which rescales elements to the range :math:`[0, 1]`
617
- such that the elements along :code:`axis` sum to :math:`1`.
618
-
619
- .. math ::
620
- \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
621
-
622
- Args:
623
- x : input array
624
- axis: the axis or axes along which the softmax should be computed. The
625
- softmax output summed across these dimensions should sum to :math:`1`.
626
- Either an integer or a tuple of integers.
627
- where: Elements to include in the :code:`softmax`.
628
- initial: The minimum value used to shift the input array. Must be present
629
- when :code:`where` is not None.
630
-
631
- Returns:
632
- An array.
633
-
634
- See also:
635
- :func:`log_softmax`
636
- """
637
- return jax.nn.softmax(x, axis=axis, where=where)
638
-
639
-
640
- def standardize(x: ArrayLike,
641
- axis: int | tuple[int, ...] | None = -1,
642
- variance: ArrayLike | None = None,
643
- epsilon: ArrayLike = 1e-5,
644
- where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
645
- r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
646
- return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
647
-
648
-
649
- def one_hot(x: Any,
650
- num_classes: int, *,
651
- dtype: Any = jax.numpy.float_,
652
- axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
653
- """One-hot encodes the given indices.
654
-
655
- Each index in the input ``x`` is encoded as a vector of zeros of length
656
- ``num_classes`` with the element at ``index`` set to one::
657
-
658
- >>> one_hot(jnp.array([0, 1, 2]), 3)
659
- Array([[1., 0., 0.],
660
- [0., 1., 0.],
661
- [0., 0., 1.]], dtype=float32)
662
-
663
- Indices outside the range [0, num_classes) will be encoded as zeros::
664
-
665
- >>> one_hot(jnp.array([-1, 3]), 3)
666
- Array([[0., 0., 0.],
667
- [0., 0., 0.]], dtype=float32)
668
-
669
- Args:
670
- x: A tensor of indices.
671
- num_classes: Number of classes in the one-hot dimension.
672
- dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
673
- axis: the axis or axes along which the function should be
674
- computed.
675
- """
676
- return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)
677
-
678
-
679
- def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
680
- r"""Rectified Linear Unit 6 activation function.
681
-
682
- Computes the element-wise function
683
-
684
- .. math::
685
- \mathrm{relu6}(x) = \min(\max(x, 0), 6)
686
-
687
- except under differentiation, we take:
688
-
689
- .. math::
690
- \nabla \mathrm{relu}(0) = 0
691
-
692
- and
693
-
694
- .. math::
695
- \nabla \mathrm{relu}(6) = 0
696
-
697
- Args:
698
- x : input array
699
-
700
- Returns:
701
- An array.
702
-
703
- See also:
704
- :func:`relu`
705
- """
706
- return u.math.relu6(x)
707
-
708
-
709
- def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
710
- r"""Hard Sigmoid activation function.
711
-
712
- Computes the element-wise function
713
-
714
- .. math::
715
- \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
716
-
717
- Args:
718
- x : input array
719
-
720
- Returns:
721
- An array.
722
-
723
- See also:
724
- :func:`relu6`
725
- """
726
- return u.math.hard_sigmoid(x)
727
-
728
-
729
- def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
730
- r"""Hard SiLU (swish) activation function
731
-
732
- Computes the element-wise function
733
-
734
- .. math::
735
- \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
736
-
737
- Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
738
- function.
739
-
740
- Args:
741
- x : input array
742
-
743
- Returns:
744
- An array.
745
-
746
- See also:
747
- :func:`hard_sigmoid`
748
- """
749
- return u.math.hard_silu(x)
750
-
751
-
752
- hard_swish = hard_silu
753
-
754
-
755
- def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
756
- r"""Sparse plus function.
757
-
758
- Computes the function:
759
-
760
- .. math::
761
-
762
- \mathrm{sparse\_plus}(x) = \begin{cases}
763
- 0, & x \leq -1\\
764
- \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
765
- x, & 1 \leq x
766
- \end{cases}
767
-
768
- This is the twin function of the softplus activation ensuring a zero output
769
- for inputs less than -1 and a linear output for inputs greater than 1,
770
- while remaining smooth, convex, monotonic by an adequate definition between
771
- -1 and 1.
772
-
773
- Args:
774
- x: input (float)
775
- """
776
- return u.math.sparse_plus(x)
777
-
778
-
779
- def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
780
- r"""Sparse sigmoid activation function.
781
-
782
- Computes the function:
783
-
784
- .. math::
785
-
786
- \mathrm{sparse\_sigmoid}(x) = \begin{cases}
787
- 0, & x \leq -1\\
788
- \frac{1}{2}(x+1), & -1 < x < 1 \\
789
- 1, & 1 \leq x
790
- \end{cases}
791
-
792
- This is the twin function of the ``sigmoid`` activation ensuring a zero output
793
- for inputs less than -1, a 1 output for inputs greater than 1, and a linear
794
- output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
795
-
796
- For more information, see `Learning with Fenchel-Young Losses (section 6.2)
797
- <https://arxiv.org/abs/1901.02324>`_.
798
-
799
- Args:
800
- x : input array
801
-
802
- Returns:
803
- An array.
804
-
805
- See also:
806
- :func:`sigmoid`
807
- """
808
- return u.math.sparse_sigmoid(x)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ """
18
+ Shared neural network activations and other functions.
19
+ """
20
+
21
+ from typing import Any, Union, Sequence
22
+
23
+ import brainunit as u
24
+ import jax
25
+ from jax.scipy.special import logsumexp
26
+
27
+ from brainstate import random
28
+ from brainstate.typing import ArrayLike
29
+
30
+ __all__ = [
31
+ "tanh",
32
+ "relu",
33
+ "squareplus",
34
+ "softplus",
35
+ "soft_sign",
36
+ "sigmoid",
37
+ "silu",
38
+ "swish",
39
+ "log_sigmoid",
40
+ "elu",
41
+ "leaky_relu",
42
+ "hard_tanh",
43
+ "celu",
44
+ "selu",
45
+ "gelu",
46
+ "glu",
47
+ "logsumexp",
48
+ "log_softmax",
49
+ "softmax",
50
+ "standardize",
51
+ "one_hot",
52
+ "relu6",
53
+ "hard_sigmoid",
54
+ "hard_silu",
55
+ "hard_swish",
56
+ 'hard_shrink',
57
+ 'rrelu',
58
+ 'mish',
59
+ 'soft_shrink',
60
+ 'prelu',
61
+ 'tanh_shrink',
62
+ 'softmin',
63
+ 'sparse_plus',
64
+ 'sparse_sigmoid',
65
+ ]
66
+
67
+
68
+ def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
69
+ r"""
70
+ Hyperbolic tangent activation function.
71
+
72
+ Computes the element-wise function:
73
+
74
+ .. math::
75
+ \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
76
+
77
+ Parameters
78
+ ----------
79
+ x : ArrayLike
80
+ Input array.
81
+
82
+ Returns
83
+ -------
84
+ jax.Array or Quantity
85
+ An array with the same shape as the input.
86
+ """
87
+ return u.math.tanh(x)
88
+
89
+
90
+ def softmin(x, axis=-1):
91
+ r"""
92
+ Softmin activation function.
93
+
94
+ Applies the Softmin function to an n-dimensional input tensor, rescaling elements
95
+ so that they lie in the range [0, 1] and sum to 1 along the specified axis.
96
+
97
+ .. math::
98
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
99
+
100
+ Parameters
101
+ ----------
102
+ x : ArrayLike
103
+ Input array of any shape.
104
+ axis : int, optional
105
+ The axis along which Softmin will be computed. Every slice along this
106
+ dimension will sum to 1. Default is -1.
107
+
108
+ Returns
109
+ -------
110
+ jax.Array or Quantity
111
+ Output array with the same shape as the input.
112
+ """
113
+ unnormalized = u.math.exp(-x)
114
+ return unnormalized / unnormalized.sum(axis, keepdims=True)
115
+
116
+
117
+ def tanh_shrink(x):
118
+ r"""
119
+ Tanh shrink activation function.
120
+
121
+ Applies the element-wise function:
122
+
123
+ .. math::
124
+ \text{Tanhshrink}(x) = x - \tanh(x)
125
+
126
+ Parameters
127
+ ----------
128
+ x : ArrayLike
129
+ Input array.
130
+
131
+ Returns
132
+ -------
133
+ jax.Array or Quantity
134
+ Output array with the same shape as the input.
135
+ """
136
+ return x - u.math.tanh(x)
137
+
138
+
139
+ def prelu(x, a=0.25):
140
+ r"""
141
+ Parametric Rectified Linear Unit activation function.
142
+
143
+ Applies the element-wise function:
144
+
145
+ .. math::
146
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
147
+
148
+ or equivalently:
149
+
150
+ .. math::
151
+ \text{PReLU}(x) =
152
+ \begin{cases}
153
+ x, & \text{ if } x \geq 0 \\
154
+ ax, & \text{ otherwise }
155
+ \end{cases}
156
+
157
+ Parameters
158
+ ----------
159
+ x : ArrayLike
160
+ Input array.
161
+ a : float or ArrayLike, optional
162
+ The negative slope coefficient. Can be a learnable parameter.
163
+ Default is 0.25.
164
+
165
+ Returns
166
+ -------
167
+ jax.Array or Quantity
168
+ Output array with the same shape as the input.
169
+
170
+ Notes
171
+ -----
172
+ When used in neural network layers, :math:`a` can be a learnable parameter
173
+ that is optimized during training.
174
+ """
175
+ return u.math.where(x >= 0., x, a * x)
176
+
177
+
178
+ def soft_shrink(x, lambd=0.5):
179
+ r"""
180
+ Soft shrinkage activation function.
181
+
182
+ Applies the soft shrinkage function element-wise:
183
+
184
+ .. math::
185
+ \text{SoftShrinkage}(x) =
186
+ \begin{cases}
187
+ x - \lambda, & \text{ if } x > \lambda \\
188
+ x + \lambda, & \text{ if } x < -\lambda \\
189
+ 0, & \text{ otherwise }
190
+ \end{cases}
191
+
192
+ Parameters
193
+ ----------
194
+ x : ArrayLike
195
+ Input array of any shape.
196
+ lambd : float, optional
197
+ The :math:`\lambda` value for the soft shrinkage formulation.
198
+ Must be non-negative. Default is 0.5.
199
+
200
+ Returns
201
+ -------
202
+ jax.Array or Quantity
203
+ Output array with the same shape as the input.
204
+ """
205
+ return u.math.where(
206
+ x > lambd,
207
+ x - lambd,
208
+ u.math.where(
209
+ x < -lambd,
210
+ x + lambd,
211
+ u.Quantity(0., unit=u.get_unit(lambd))
212
+ )
213
+ )
214
+
215
+
216
+ def mish(x):
217
+ r"""
218
+ Mish activation function.
219
+
220
+ Mish is a self-regularized non-monotonic activation function.
221
+
222
+ .. math::
223
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
224
+
225
+ Parameters
226
+ ----------
227
+ x : ArrayLike
228
+ Input array of any shape.
229
+
230
+ Returns
231
+ -------
232
+ jax.Array or Quantity
233
+ Output array with the same shape as the input.
234
+
235
+ References
236
+ ----------
237
+ .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
238
+ arXiv:1908.08681
239
+ """
240
+ return x * u.math.tanh(softplus(x))
241
+
242
+
243
+ def rrelu(x, lower=0.125, upper=0.3333333333333333):
244
+ r"""
245
+ Randomized Leaky Rectified Linear Unit activation function.
246
+
247
+ The function is defined as:
248
+
249
+ .. math::
250
+ \text{RReLU}(x) =
251
+ \begin{cases}
252
+ x & \text{if } x \geq 0 \\
253
+ ax & \text{ otherwise }
254
+ \end{cases}
255
+
256
+ where :math:`a` is randomly sampled from uniform distribution
257
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
258
+
259
+ Parameters
260
+ ----------
261
+ x : ArrayLike
262
+ Input array of any shape.
263
+ lower : float, optional
264
+ Lower bound of the uniform distribution for sampling the negative slope.
265
+ Default is 1/8.
266
+ upper : float, optional
267
+ Upper bound of the uniform distribution for sampling the negative slope.
268
+ Default is 1/3.
269
+
270
+ Returns
271
+ -------
272
+ jax.Array or Quantity
273
+ Output array with the same shape as the input.
274
+
275
+ References
276
+ ----------
277
+ .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
278
+ in Convolutional Network." arXiv:1505.00853
279
+ """
280
+ a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
281
+ return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
282
+
283
+
284
+ def hard_shrink(x, lambd=0.5):
285
+ r"""
286
+ Hard shrinkage activation function.
287
+
288
+ Applies the hard shrinkage function element-wise:
289
+
290
+ .. math::
291
+ \text{HardShrink}(x) =
292
+ \begin{cases}
293
+ x, & \text{ if } x > \lambda \\
294
+ x, & \text{ if } x < -\lambda \\
295
+ 0, & \text{ otherwise }
296
+ \end{cases}
297
+
298
+ Parameters
299
+ ----------
300
+ x : ArrayLike
301
+ Input array of any shape.
302
+ lambd : float, optional
303
+ The :math:`\lambda` threshold value for the hard shrinkage formulation.
304
+ Default is 0.5.
305
+
306
+ Returns
307
+ -------
308
+ jax.Array or Quantity
309
+ Output array with the same shape as the input.
310
+ """
311
+ return u.math.where(
312
+ x > lambd,
313
+ x,
314
+ u.math.where(
315
+ x < -lambd,
316
+ x,
317
+ u.Quantity(0., unit=u.get_unit(x))
318
+ )
319
+ )
320
+
321
+
322
+ def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
323
+ r"""
324
+ Rectified Linear Unit activation function.
325
+
326
+ Computes the element-wise function:
327
+
328
+ .. math::
329
+ \mathrm{relu}(x) = \max(x, 0)
330
+
331
+ Under differentiation, we take:
332
+
333
+ .. math::
334
+ \nabla \mathrm{relu}(0) = 0
335
+
336
+ Parameters
337
+ ----------
338
+ x : ArrayLike
339
+ Input array.
340
+
341
+ Returns
342
+ -------
343
+ jax.Array or Quantity
344
+ An array with the same shape as the input.
345
+
346
+ Examples
347
+ --------
348
+ .. code-block:: python
349
+
350
+ >>> import jax.numpy as jnp
351
+ >>> import brainstate
352
+ >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
353
+ Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
354
+
355
+ See Also
356
+ --------
357
+ relu6 : ReLU6 activation function.
358
+ leaky_relu : Leaky ReLU activation function.
359
+
360
+ References
361
+ ----------
362
+ .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
363
+ https://openreview.net/forum?id=urrcVI-_jRm
364
+ """
365
+ return u.math.relu(x)
366
+
367
+
368
+ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
369
+ r"""
370
+ Squareplus activation function.
371
+
372
+ Computes the element-wise function:
373
+
374
+ .. math::
375
+ \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
376
+
377
+ Parameters
378
+ ----------
379
+ x : ArrayLike
380
+ Input array.
381
+ b : ArrayLike, optional
382
+ Smoothness parameter. Default is 4.
383
+
384
+ Returns
385
+ -------
386
+ jax.Array or Quantity
387
+ An array with the same shape as the input.
388
+
389
+ References
390
+ ----------
391
+ .. [1] Barron, J. T. (2021). "Squareplus: A Softplus-Like Algebraic
392
+ Rectifier." arXiv:2112.11687
393
+ """
394
+ return u.math.squareplus(x, b=b)
395
+
396
+
397
+ def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
398
+ r"""
399
+ Softplus activation function.
400
+
401
+ Computes the element-wise function:
402
+
403
+ .. math::
404
+ \mathrm{softplus}(x) = \log(1 + e^x)
405
+
406
+ Parameters
407
+ ----------
408
+ x : ArrayLike
409
+ Input array.
410
+
411
+ Returns
412
+ -------
413
+ jax.Array or Quantity
414
+ An array with the same shape as the input.
415
+ """
416
+ return u.math.softplus(x)
417
+
418
+
419
+ def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
420
+ r"""
421
+ Soft-sign activation function.
422
+
423
+ Computes the element-wise function:
424
+
425
+ .. math::
426
+ \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
427
+
428
+ Parameters
429
+ ----------
430
+ x : ArrayLike
431
+ Input array.
432
+
433
+ Returns
434
+ -------
435
+ jax.Array or Quantity
436
+ An array with the same shape as the input.
437
+ """
438
+ return u.math.soft_sign(x)
439
+
440
+
441
+ def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
442
+ r"""
443
+ Sigmoid activation function.
444
+
445
+ Computes the element-wise function:
446
+
447
+ .. math::
448
+ \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
449
+
450
+ Parameters
451
+ ----------
452
+ x : ArrayLike
453
+ Input array.
454
+
455
+ Returns
456
+ -------
457
+ jax.Array or Quantity
458
+ An array with the same shape as the input.
459
+
460
+ See Also
461
+ --------
462
+ log_sigmoid : Logarithm of the sigmoid function.
463
+ """
464
+ return u.math.sigmoid(x)
465
+
466
+
467
+ def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
468
+ r"""
469
+ SiLU (Sigmoid Linear Unit) activation function.
470
+
471
+ Computes the element-wise function:
472
+
473
+ .. math::
474
+ \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
475
+
476
+ Parameters
477
+ ----------
478
+ x : ArrayLike
479
+ Input array.
480
+
481
+ Returns
482
+ -------
483
+ jax.Array or Quantity
484
+ An array with the same shape as the input.
485
+
486
+ See Also
487
+ --------
488
+ sigmoid : The sigmoid function.
489
+ swish : Alias for silu.
490
+
491
+ Notes
492
+ -----
493
+ `swish` and `silu` are both aliases for the same function.
494
+ """
495
+ return u.math.silu(x)
496
+
497
+
498
+ swish = silu
499
+
500
+
501
+ def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
502
+ r"""
503
+ Log-sigmoid activation function.
504
+
505
+ Computes the element-wise function:
506
+
507
+ .. math::
508
+ \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
509
+
510
+ Parameters
511
+ ----------
512
+ x : ArrayLike
513
+ Input array.
514
+
515
+ Returns
516
+ -------
517
+ jax.Array or Quantity
518
+ An array with the same shape as the input.
519
+
520
+ See Also
521
+ --------
522
+ sigmoid : The sigmoid function.
523
+ """
524
+ return u.math.log_sigmoid(x)
525
+
526
+
527
+ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
528
+ r"""
529
+ Exponential Linear Unit activation function.
530
+
531
+ Computes the element-wise function:
532
+
533
+ .. math::
534
+ \mathrm{elu}(x) = \begin{cases}
535
+ x, & x > 0\\
536
+ \alpha \left(\exp(x) - 1\right), & x \le 0
537
+ \end{cases}
538
+
539
+ Parameters
540
+ ----------
541
+ x : ArrayLike
542
+ Input array.
543
+ alpha : ArrayLike, optional
544
+ Scalar or array of alpha values. Default is 1.0.
545
+
546
+ Returns
547
+ -------
548
+ jax.Array or Quantity
549
+ An array with the same shape as the input.
550
+
551
+ See Also
552
+ --------
553
+ selu : Scaled ELU activation function.
554
+ celu : Continuously-differentiable ELU activation function.
555
+ """
556
+ return u.math.elu(x, alpha=alpha)
557
+
558
+
559
+ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
560
+ r"""
561
+ Leaky Rectified Linear Unit activation function.
562
+
563
+ Computes the element-wise function:
564
+
565
+ .. math::
566
+ \mathrm{leaky\_relu}(x) = \begin{cases}
567
+ x, & x \ge 0\\
568
+ \alpha x, & x < 0
569
+ \end{cases}
570
+
571
+ where :math:`\alpha` = :code:`negative_slope`.
572
+
573
+ Parameters
574
+ ----------
575
+ x : ArrayLike
576
+ Input array.
577
+ negative_slope : ArrayLike, optional
578
+ Array or scalar specifying the negative slope. Default is 0.01.
579
+
580
+ Returns
581
+ -------
582
+ jax.Array or Quantity
583
+ An array with the same shape as the input.
584
+
585
+ See Also
586
+ --------
587
+ relu : Standard ReLU activation function.
588
+ prelu : Parametric ReLU with learnable slope.
589
+ """
590
+ return u.math.leaky_relu(x, negative_slope=negative_slope)
591
+
592
+
593
+ def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
594
+ return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
595
+
596
+
597
+ def hard_tanh(
598
+ x: ArrayLike,
599
+ min_val: float = - 1.0,
600
+ max_val: float = 1.0
601
+ ) -> Union[jax.Array, u.Quantity]:
602
+ r"""
603
+ Hard hyperbolic tangent activation function.
604
+
605
+ Computes the element-wise function:
606
+
607
+ .. math::
608
+ \mathrm{hard\_tanh}(x) = \begin{cases}
609
+ -1, & x < -1\\
610
+ x, & -1 \le x \le 1\\
611
+ 1, & 1 < x
612
+ \end{cases}
613
+
614
+ Parameters
615
+ ----------
616
+ x : ArrayLike
617
+ Input array.
618
+ min_val : float, optional
619
+ Minimum value of the linear region range. Default is -1.
620
+ max_val : float, optional
621
+ Maximum value of the linear region range. Default is 1.
622
+
623
+ Returns
624
+ -------
625
+ jax.Array or Quantity
626
+ An array with the same shape as the input.
627
+ """
628
+ x = u.Quantity(x)
629
+ min_val = u.Quantity(min_val).to(x.unit).mantissa
630
+ max_val = u.Quantity(max_val).to(x.unit).mantissa
631
+ return u.maybe_decimal(_hard_tanh(x.mantissa, min_val=min_val, max_val=max_val) * x.unit)
632
+
633
+
634
+ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
635
+ r"""
636
+ Continuously-differentiable Exponential Linear Unit activation.
637
+
638
+ Computes the element-wise function:
639
+
640
+ .. math::
641
+ \mathrm{celu}(x) = \begin{cases}
642
+ x, & x > 0\\
643
+ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
644
+ \end{cases}
645
+
646
+ Parameters
647
+ ----------
648
+ x : ArrayLike
649
+ Input array.
650
+ alpha : ArrayLike, optional
651
+ Scalar or array value controlling the smoothness. Default is 1.0.
652
+
653
+ Returns
654
+ -------
655
+ jax.Array or Quantity
656
+ An array with the same shape as the input.
657
+
658
+ References
659
+ ----------
660
+ .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
661
+ arXiv:1704.07483
662
+ """
663
+ return u.math.celu(x, alpha=alpha)
664
+
665
+
666
+ def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
667
+ r"""
668
+ Scaled Exponential Linear Unit activation.
669
+
670
+ Computes the element-wise function:
671
+
672
+ .. math::
673
+ \mathrm{selu}(x) = \lambda \begin{cases}
674
+ x, & x > 0\\
675
+ \alpha e^x - \alpha, & x \le 0
676
+ \end{cases}
677
+
678
+ where :math:`\lambda = 1.0507009873554804934193349852946` and
679
+ :math:`\alpha = 1.6732632423543772848170429916717`.
680
+
681
+ Parameters
682
+ ----------
683
+ x : ArrayLike
684
+ Input array.
685
+
686
+ Returns
687
+ -------
688
+ jax.Array or Quantity
689
+ An array with the same shape as the input.
690
+
691
+ See Also
692
+ --------
693
+ elu : Exponential Linear Unit activation function.
694
+
695
+ References
696
+ ----------
697
+ .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
698
+ NeurIPS 2017.
699
+ """
700
+ return u.math.selu(x)
701
+
702
+
703
+ def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
704
+ r"""
705
+ Gaussian Error Linear Unit activation function.
706
+
707
+ If ``approximate=False``, computes the element-wise function:
708
+
709
+ .. math::
710
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
711
+ \frac{x}{\sqrt{2}} \right) \right)
712
+
713
+ If ``approximate=True``, uses the approximate formulation of GELU:
714
+
715
+ .. math::
716
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
717
+ \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
718
+
719
+ Parameters
720
+ ----------
721
+ x : ArrayLike
722
+ Input array.
723
+ approximate : bool, optional
724
+ Whether to use the approximate (True) or exact (False) formulation.
725
+ Default is True.
726
+
727
+ Returns
728
+ -------
729
+ jax.Array or Quantity
730
+ An array with the same shape as the input.
731
+
732
+ References
733
+ ----------
734
+ .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
735
+ arXiv:1606.08415
736
+ """
737
+ return u.math.gelu(x, approximate=approximate)
738
+
739
+
740
+ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
741
+ r"""
742
+ Gated Linear Unit activation function.
743
+
744
+ Computes the function:
745
+
746
+ .. math::
747
+ \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
748
+ \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
749
+ \right)
750
+
751
+ where the array is split into two along ``axis``. The size of the ``axis``
752
+ dimension must be divisible by two.
753
+
754
+ Parameters
755
+ ----------
756
+ x : ArrayLike
757
+ Input array. The dimension specified by ``axis`` must be divisible by 2.
758
+ axis : int, optional
759
+ The axis along which the split should be computed. Default is -1.
760
+
761
+ Returns
762
+ -------
763
+ jax.Array or Quantity
764
+ An array with the same shape as input except the ``axis`` dimension
765
+ is halved.
766
+
767
+ See Also
768
+ --------
769
+ sigmoid : The sigmoid activation function.
770
+ """
771
+ return u.math.glu(x, axis=axis)
772
+
773
+
774
+ def log_softmax(x: ArrayLike,
775
+ axis: int | tuple[int, ...] | None = -1,
776
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
777
+ r"""
778
+ Log-Softmax function.
779
+
780
+ Computes the logarithm of the softmax function, which rescales
781
+ elements to the range :math:`[-\infty, 0)`.
782
+
783
+ .. math ::
784
+ \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
785
+ \right)
786
+
787
+ Parameters
788
+ ----------
789
+ x : ArrayLike
790
+ Input array.
791
+ axis : int or tuple of int, optional
792
+ The axis or axes along which the log-softmax should be computed.
793
+ Either an integer or a tuple of integers. Default is -1.
794
+ where : ArrayLike, optional
795
+ Elements to include in the log-softmax computation.
796
+
797
+ Returns
798
+ -------
799
+ jax.Array or Quantity
800
+ An array with the same shape as the input.
801
+
802
+ See Also
803
+ --------
804
+ softmax : The softmax function.
805
+ """
806
+ return jax.nn.log_softmax(x, axis=axis, where=where)
807
+
808
+
809
+ def softmax(x: ArrayLike,
810
+ axis: int | tuple[int, ...] | None = -1,
811
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
812
+ r"""
813
+ Softmax activation function.
814
+
815
+ Computes the function which rescales elements to the range :math:`[0, 1]`
816
+ such that the elements along :code:`axis` sum to :math:`1`.
817
+
818
+ .. math ::
819
+ \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
820
+
821
+ Parameters
822
+ ----------
823
+ x : ArrayLike
824
+ Input array.
825
+ axis : int or tuple of int, optional
826
+ The axis or axes along which the softmax should be computed. The
827
+ softmax output summed across these dimensions should sum to :math:`1`.
828
+ Either an integer or a tuple of integers. Default is -1.
829
+ where : ArrayLike, optional
830
+ Elements to include in the softmax computation.
831
+
832
+ Returns
833
+ -------
834
+ jax.Array or Quantity
835
+ An array with the same shape as the input.
836
+
837
+ See Also
838
+ --------
839
+ log_softmax : Logarithm of the softmax function.
840
+ softmin : Softmin activation function.
841
+ """
842
+ return jax.nn.softmax(x, axis=axis, where=where)
843
+
844
+
845
+ def standardize(x: ArrayLike,
846
+ axis: int | tuple[int, ...] | None = -1,
847
+ variance: ArrayLike | None = None,
848
+ epsilon: ArrayLike = 1e-5,
849
+ where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
850
+ r"""
851
+ Standardize (normalize) an array.
852
+
853
+ Normalizes an array by subtracting the mean and dividing by the standard
854
+ deviation :math:`\sqrt{\mathrm{variance}}`.
855
+
856
+ Parameters
857
+ ----------
858
+ x : ArrayLike
859
+ Input array.
860
+ axis : int or tuple of int, optional
861
+ The axis or axes along which to compute the mean and variance.
862
+ Default is -1.
863
+ variance : ArrayLike, optional
864
+ Pre-computed variance. If None, variance is computed from ``x``.
865
+ epsilon : ArrayLike, optional
866
+ A small constant added to the variance to avoid division by zero.
867
+ Default is 1e-5.
868
+ where : ArrayLike, optional
869
+ Elements to include in the computation.
870
+
871
+ Returns
872
+ -------
873
+ jax.Array or Quantity
874
+ Standardized array with the same shape as the input.
875
+ """
876
+ return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
877
+
878
+
879
+ def one_hot(x: Any,
880
+ num_classes: int, *,
881
+ dtype: Any = jax.numpy.float_,
882
+ axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
883
+ """
884
+ One-hot encode the given indices.
885
+
886
+ Each index in the input ``x`` is encoded as a vector of zeros of length
887
+ ``num_classes`` with the element at ``index`` set to one.
888
+
889
+ Indices outside the range [0, num_classes) will be encoded as zeros.
890
+
891
+ Parameters
892
+ ----------
893
+ x : ArrayLike
894
+ A tensor of indices.
895
+ num_classes : int
896
+ Number of classes in the one-hot dimension.
897
+ dtype : dtype, optional
898
+ The dtype for the returned values. Default is ``jnp.float_``.
899
+ axis : int or Sequence of int, optional
900
+ The axis or axes along which the function should be computed.
901
+ Default is -1.
902
+
903
+ Returns
904
+ -------
905
+ jax.Array or Quantity
906
+ One-hot encoded array.
907
+
908
+ Examples
909
+ --------
910
+ .. code-block:: python
911
+
912
+ >>> import jax.numpy as jnp
913
+ >>> import brainstate
914
+ >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
915
+ Array([[1., 0., 0.],
916
+ [0., 1., 0.],
917
+ [0., 0., 1.]], dtype=float32)
918
+
919
+ >>> # Indices outside the range are encoded as zeros
920
+ >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
921
+ Array([[0., 0., 0.],
922
+ [0., 0., 0.]], dtype=float32)
923
+ """
924
+ return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)
925
+
926
+
927
+ def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
928
+ r"""
929
+ Rectified Linear Unit 6 activation function.
930
+
931
+ Computes the element-wise function:
932
+
933
+ .. math::
934
+ \mathrm{relu6}(x) = \min(\max(x, 0), 6)
935
+
936
+ Under differentiation, we take:
937
+
938
+ .. math::
939
+ \nabla \mathrm{relu}(0) = 0
940
+
941
+ and
942
+
943
+ .. math::
944
+ \nabla \mathrm{relu}(6) = 0
945
+
946
+ Parameters
947
+ ----------
948
+ x : ArrayLike
949
+ Input array.
950
+
951
+ Returns
952
+ -------
953
+ jax.Array or Quantity
954
+ An array with the same shape as the input.
955
+
956
+ See Also
957
+ --------
958
+ relu : Standard ReLU activation function.
959
+ """
960
+ return u.math.relu6(x)
961
+
962
+
963
+ def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
964
+ r"""
965
+ Hard Sigmoid activation function.
966
+
967
+ Computes the element-wise function:
968
+
969
+ .. math::
970
+ \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
971
+
972
+ Parameters
973
+ ----------
974
+ x : ArrayLike
975
+ Input array.
976
+
977
+ Returns
978
+ -------
979
+ jax.Array or Quantity
980
+ An array with the same shape as the input.
981
+
982
+ See Also
983
+ --------
984
+ relu6 : ReLU6 activation function.
985
+ sigmoid : Standard sigmoid function.
986
+ """
987
+ return u.math.hard_sigmoid(x)
988
+
989
+
990
+ def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
991
+ r"""
992
+ Hard SiLU (Swish) activation function.
993
+
994
+ Computes the element-wise function:
995
+
996
+ .. math::
997
+ \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
998
+
999
+ Parameters
1000
+ ----------
1001
+ x : ArrayLike
1002
+ Input array.
1003
+
1004
+ Returns
1005
+ -------
1006
+ jax.Array or Quantity
1007
+ An array with the same shape as the input.
1008
+
1009
+ See Also
1010
+ --------
1011
+ hard_sigmoid : Hard sigmoid activation function.
1012
+ silu : Standard SiLU activation function.
1013
+ hard_swish : Alias for hard_silu.
1014
+
1015
+ Notes
1016
+ -----
1017
+ Both `hard_silu` and `hard_swish` are aliases for the same function.
1018
+ """
1019
+ return u.math.hard_silu(x)
1020
+
1021
+
1022
+ hard_swish = hard_silu
1023
+
1024
+
1025
+ def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
1026
+ r"""
1027
+ Sparse plus activation function.
1028
+
1029
+ Computes the function:
1030
+
1031
+ .. math::
1032
+
1033
+ \mathrm{sparse\_plus}(x) = \begin{cases}
1034
+ 0, & x \leq -1\\
1035
+ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
1036
+ x, & 1 \leq x
1037
+ \end{cases}
1038
+
1039
+ This is the twin function of the softplus activation, ensuring a zero output
1040
+ for inputs less than -1 and a linear output for inputs greater than 1,
1041
+ while remaining smooth, convex, and monotonic between -1 and 1.
1042
+
1043
+ Parameters
1044
+ ----------
1045
+ x : ArrayLike
1046
+ Input array.
1047
+
1048
+ Returns
1049
+ -------
1050
+ jax.Array or Quantity
1051
+ An array with the same shape as the input.
1052
+
1053
+ See Also
1054
+ --------
1055
+ sparse_sigmoid : Derivative of sparse_plus.
1056
+ softplus : Standard softplus activation function.
1057
+ """
1058
+ return u.math.sparse_plus(x)
1059
+
1060
+
1061
+ def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
1062
+ r"""
1063
+ Sparse sigmoid activation function.
1064
+
1065
+ Computes the function:
1066
+
1067
+ .. math::
1068
+
1069
+ \mathrm{sparse\_sigmoid}(x) = \begin{cases}
1070
+ 0, & x \leq -1\\
1071
+ \frac{1}{2}(x+1), & -1 < x < 1 \\
1072
+ 1, & 1 \leq x
1073
+ \end{cases}
1074
+
1075
+ This is the twin function of the standard sigmoid activation, ensuring a zero
1076
+ output for inputs less than -1, a 1 output for inputs greater than 1, and a
1077
+ linear output for inputs between -1 and 1. It is the derivative of `sparse_plus`.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ x : ArrayLike
1082
+ Input array.
1083
+
1084
+ Returns
1085
+ -------
1086
+ jax.Array or Quantity
1087
+ An array with the same shape as the input.
1088
+
1089
+ See Also
1090
+ --------
1091
+ sigmoid : Standard sigmoid activation function.
1092
+ sparse_plus : Sparse plus activation function.
1093
+
1094
+ References
1095
+ ----------
1096
+ .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
1097
+ A Sparse Model of Attention and Multi-Label Classification."
1098
+ In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
1099
+ """
1100
+ return u.math.sparse_sigmoid(x)
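A quick smoke test of a few of the activations documented in the new file, assuming they are exposed under `brainstate.nn` as the rewritten docstring examples (`brainstate.nn.relu`, `brainstate.nn.one_hot`) suggest; this is a sketch, not part of the package diff itself:

```python
# Minimal usage sketch for the activations shown above. Assumes brainstate
# 0.2.1 re-exports _activations.py members from brainstate.nn, as the new
# docstring examples indicate.
import jax.numpy as jnp
import brainstate

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

print(brainstate.nn.relu(x))              # [0., 0., 0., 0.5, 2.]
print(brainstate.nn.soft_shrink(x, 0.5))  # values shrunk toward zero by lambda = 0.5
print(brainstate.nn.softmin(x))           # rescaled so the last axis sums to 1
print(brainstate.nn.one_hot(jnp.array([0, 2]), 3))  # rows [1,0,0] and [0,0,1]
```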