brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -66,41 +66,49 @@ __all__ = [


  def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Hyperbolic tangent activation function.
+ r"""
+ Hyperbolic tangent activation function.

  Computes the element-wise function:

  .. math::
  \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Returns:
- An array.
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
  """
  return u.math.tanh(x)


  def softmin(x, axis=-1):
  r"""
- Applies the Softmin function to an n-dimensional input Tensor
- rescaling them so that the elements of the n-dimensional output Tensor
- lie in the range `[0, 1]` and sum to 1.
+ Softmin activation function.

- Softmin is defined as:
+ Applies the Softmin function to an n-dimensional input tensor, rescaling elements
+ so that they lie in the range [0, 1] and sum to 1 along the specified axis.

  .. math::
  \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

- Shape:
- - Input: :math:`(*)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(*)`, same shape as the input
-
- Args:
- axis (int): A dimension along which Softmin will be computed (so every slice
- along dim will sum to 1).
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array of any shape.
+ axis : int, optional
+ The axis along which Softmin will be computed. Every slice along this
+ dimension will sum to 1. Default is -1.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
  """
  unnormalized = u.math.exp(-x)
  return unnormalized / unnormalized.sum(axis, keepdims=True)
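(Illustrative note, not part of the diff.) A minimal usage sketch of the rewritten softmin docstring above, assuming brainstate 0.2.0 exports it as brainstate.nn.softmin like the other activations:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([1.0, 2.0, 3.0])
    y = brainstate.nn.softmin(x, axis=-1)
    # Smaller inputs get larger weights, and each slice along `axis` sums to 1.
    print(y)        # approximately [0.665, 0.245, 0.090]
    print(y.sum())  # approximately 1.0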
@@ -108,22 +116,36 @@ def softmin(x, axis=-1):

  def tanh_shrink(x):
  r"""
+ Tanh shrink activation function.
+
  Applies the element-wise function:

  .. math::
  \text{Tanhshrink}(x) = x - \tanh(x)
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
  """
  return x - u.math.tanh(x)


  def prelu(x, a=0.25):
  r"""
+ Parametric Rectified Linear Unit activation function.
+
  Applies the element-wise function:

  .. math::
  \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

- or
+ or equivalently:

  .. math::
  \text{PReLU}(x) =
@@ -132,16 +154,32 @@ def prelu(x, a=0.25):
  ax, & \text{ otherwise }
  \end{cases}

- Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
- parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
- a separate :math:`a` is used for each input channel.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ a : float or ArrayLike, optional
+ The negative slope coefficient. Can be a learnable parameter.
+ Default is 0.25.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
+
+ Notes
+ -----
+ When used in neural network layers, :math:`a` can be a learnable parameter
+ that is optimized during training.
  """
  return u.math.where(x >= 0., x, a * x)
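(Illustrative note, not part of the diff.) A quick numeric check of prelu with the default slope, assuming it is available as brainstate.nn.prelu:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2.0, -0.5, 0.0, 1.5])
    # Negative inputs are scaled by a; non-negative inputs pass through unchanged.
    print(brainstate.nn.prelu(x, a=0.25))  # [-0.5, -0.125, 0.0, 1.5]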


  def soft_shrink(x, lambd=0.5):
  r"""
- Applies the soft shrinkage function elementwise:
+ Soft shrinkage activation function.
+
+ Applies the soft shrinkage function element-wise:

  .. math::
  \text{SoftShrinkage}(x) =
@@ -151,43 +189,60 @@ def soft_shrink(x, lambd=0.5):
  0, & \text{ otherwise }
  \end{cases}

- Args:
- lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
-
- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array of any shape.
+ lambd : float, optional
+ The :math:`\lambda` value for the soft shrinkage formulation.
+ Must be non-negative. Default is 0.5.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
  """
- return u.math.where(x > lambd,
- x - lambd,
- u.math.where(x < -lambd,
- x + lambd,
- u.Quantity(0., unit=u.get_unit(lambd))))
+ return u.math.where(
+ x > lambd,
+ x - lambd,
+ u.math.where(
+ x < -lambd,
+ x + lambd,
+ u.Quantity(0., unit=u.get_unit(lambd))
+ )
+ )


  def mish(x):
- r"""Applies the Mish function, element-wise.
+ r"""
+ Mish activation function.

- Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+ Mish is a self-regularized non-monotonic activation function.

  .. math::
  \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

- .. note::
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array of any shape.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
+
+ References
+ ----------
+ .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
+ arXiv:1908.08681
  """
  return x * u.math.tanh(softplus(x))


  def rrelu(x, lower=0.125, upper=0.3333333333333333):
- r"""Applies the randomized leaky rectified liner unit function, element-wise,
- as described in the paper:
-
- `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+ r"""
+ Randomized Leaky Rectified Linear Unit activation function.

  The function is defined as:

@@ -201,27 +256,36 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
  where :math:`a` is randomly sampled from uniform distribution
  :math:`\mathcal{U}(\text{lower}, \text{upper})`.

- See: https://arxiv.org/pdf/1505.00853.pdf
-
- Args:
- lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
- upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
-
- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
-
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
- https://arxiv.org/abs/1505.00853
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array of any shape.
+ lower : float, optional
+ Lower bound of the uniform distribution for sampling the negative slope.
+ Default is 1/8.
+ upper : float, optional
+ Upper bound of the uniform distribution for sampling the negative slope.
+ Default is 1/3.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
+
+ References
+ ----------
+ .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
+ in Convolutional Network." arXiv:1505.00853
  """
  a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
  return u.math.where(u.get_mantissa(x) >= 0., x, a * x)


  def hard_shrink(x, lambd=0.5):
- r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
+ r"""
+ Hard shrinkage activation function.

- Hardshrink is defined as:
+ Applies the hard shrinkage function element-wise:

  .. math::
  \text{HardShrink}(x) =
@@ -231,139 +295,202 @@ def hard_shrink(x, lambd=0.5):
  0, & \text{ otherwise }
  \end{cases}

- Args:
- lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
-
- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
-
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array of any shape.
+ lambd : float, optional
+ The :math:`\lambda` threshold value for the hard shrinkage formulation.
+ Default is 0.5.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Output array with the same shape as the input.
  """
- return u.math.where(x > lambd,
- x,
- u.math.where(x < -lambd,
- x,
- u.Quantity(0., unit=u.get_unit(x))))
+ return u.math.where(
+ x > lambd,
+ x,
+ u.math.where(
+ x < -lambd,
+ x,
+ u.Quantity(0., unit=u.get_unit(x))
+ )
+ )
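(Illustrative note, not part of the diff.) Contrasting the two shrinkage functions rewritten above, assuming both are exported from brainstate.nn:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-1.0, -0.2, 0.0, 0.3, 2.0])
    # soft_shrink pulls surviving values toward zero by lambd; hard_shrink keeps them unchanged.
    print(brainstate.nn.soft_shrink(x, lambd=0.5))  # [-0.5, 0.0, 0.0, 0.0, 1.5]
    print(brainstate.nn.hard_shrink(x, lambd=0.5))  # [-1.0, 0.0, 0.0, 0.0, 2.0]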


  def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Rectified linear unit activation function.
+ r"""
+ Rectified Linear Unit activation function.

  Computes the element-wise function:

  .. math::
  \mathrm{relu}(x) = \max(x, 0)

- except under differentiation, we take:
+ Under differentiation, we take:

  .. math::
  \nabla \mathrm{relu}(0) = 0

- For more information see
- `Numerical influence of ReLU’(0) on backpropagation
- <https://openreview.net/forum?id=urrcVI-_jRm>`_.
-
- Args:
- x : input array
-
- Returns:
- An array.
-
- Example:
- >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
- Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
-
- See also:
- :func:`relu6`
-
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import jax.numpy as jnp
+ >>> import brainstate
+ >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
+ Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
+
+ See Also
+ --------
+ relu6 : ReLU6 activation function.
+ leaky_relu : Leaky ReLU activation function.
+
+ References
+ ----------
+ .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
+ https://openreview.net/forum?id=urrcVI-_jRm
  """
  return u.math.relu(x)


  def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
- r"""Squareplus activation function.
+ r"""
+ Squareplus activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}

- as described in https://arxiv.org/abs/2112.11687.
-
- Args:
- x : input array
- b : smoothness parameter
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ b : ArrayLike, optional
+ Smoothness parameter. Default is 4.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ References
+ ----------
+ .. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
+ for Language Modeling." arXiv:2112.11687
  """
  return u.math.squareplus(x, b=b)
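(Illustrative note, not part of the diff.) A small sanity check of the squareplus formula documented above, assuming brainstate.nn.squareplus is available:

    import jax.numpy as jnp
    import brainstate

    # squareplus(x) = (x + sqrt(x**2 + b)) / 2; with b=4 the value at x=0 is exactly 1.
    x = jnp.array([-2.0, 0.0, 2.0])
    print(brainstate.nn.squareplus(x, b=4))  # approximately [0.414, 1.0, 2.414]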


  def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Softplus activation function.
+ r"""
+ Softplus activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{softplus}(x) = \log(1 + e^x)

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
  """
  return u.math.softplus(x)


  def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Soft-sign activation function.
+ r"""
+ Soft-sign activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
  """
  return u.math.soft_sign(x)


  def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Sigmoid activation function.
+ r"""
+ Sigmoid activation function.

  Computes the element-wise function:

  .. math::
  \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

- Args:
- x : input array
-
- Returns:
- An array.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- See also:
- :func:`log_sigmoid`
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

+ See Also
+ --------
+ log_sigmoid : Logarithm of the sigmoid function.
  """
  return u.math.sigmoid(x)


  def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""SiLU (a.k.a. swish) activation function.
+ r"""
+ SiLU (Sigmoid Linear Unit) activation function.

  Computes the element-wise function:

  .. math::
  \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

- :func:`swish` and :func:`silu` are both aliases for the same function.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Args:
- x : input array
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

- Returns:
- An array.
+ See Also
+ --------
+ sigmoid : The sigmoid function.
+ swish : Alias for silu.

- See also:
- :func:`sigmoid`
+ Notes
+ -----
+ `swish` and `silu` are both aliases for the same function.
  """
  return u.math.silu(x)

@@ -372,27 +499,34 @@ swish = silu


  def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Log-sigmoid activation function.
+ r"""
+ Log-sigmoid activation function.

  Computes the element-wise function:

  .. math::
  \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Returns:
- An array.
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

- See also:
- :func:`sigmoid`
+ See Also
+ --------
+ sigmoid : The sigmoid function.
  """
  return u.math.log_sigmoid(x)


  def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
- r"""Exponential linear unit activation function.
+ r"""
+ Exponential Linear Unit activation function.

  Computes the element-wise function:

@@ -402,21 +536,29 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
  \alpha \left(\exp(x) - 1\right), & x \le 0
  \end{cases}

- Args:
- x : input array
- alpha : scalar or array of alpha values (default: 1.0)
-
- Returns:
- An array.
-
- See also:
- :func:`selu`
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ alpha : ArrayLike, optional
+ Scalar or array of alpha values. Default is 1.0.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ selu : Scaled ELU activation function.
+ celu : Continuously-differentiable ELU activation function.
  """
  return u.math.elu(x, alpha=alpha)


  def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
- r"""Leaky rectified linear unit activation function.
+ r"""
+ Leaky Rectified Linear Unit activation function.

  Computes the element-wise function:

@@ -428,15 +570,22 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Arra

  where :math:`\alpha` = :code:`negative_slope`.

- Args:
- x : input array
- negative_slope : array or scalar specifying the negative slope (default: 0.01)
-
- Returns:
- An array.
-
- See also:
- :func:`relu`
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ negative_slope : ArrayLike, optional
+ Array or scalar specifying the negative slope. Default is 0.01.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ relu : Standard ReLU activation function.
+ prelu : Parametric ReLU with learnable slope.
  """
  return u.math.leaky_relu(x, negative_slope=negative_slope)

@@ -450,7 +599,8 @@ def hard_tanh(
  min_val: float = - 1.0,
  max_val: float = 1.0
  ) -> Union[jax.Array, u.Quantity]:
- r"""Hard :math:`\mathrm{tanh}` activation function.
+ r"""
+ Hard hyperbolic tangent activation function.

  Computes the element-wise function:

@@ -461,13 +611,19 @@ def hard_tanh(
  1, & 1 < x
  \end{cases}

- Args:
- x : input array
- min_val: float. minimum value of the linear region range. Default: -1
- max_val: float. maximum value of the linear region range. Default: 1
-
- Returns:
- An array.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ min_val : float, optional
+ Minimum value of the linear region range. Default is -1.
+ max_val : float, optional
+ Maximum value of the linear region range. Default is 1.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
  """
  x = u.Quantity(x)
  min_val = u.Quantity(min_val).to(x.unit).mantissa
@@ -476,7 +632,8 @@


  def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
- r"""Continuously-differentiable exponential linear unit activation.
+ r"""
+ Continuously-differentiable Exponential Linear Unit activation.

  Computes the element-wise function:

@@ -486,22 +643,29 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
  \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
  \end{cases}

- For more information, see
- `Continuously Differentiable Exponential Linear Units
- <https://arxiv.org/pdf/1704.07483.pdf>`_.
-
- Args:
- x : input array
- alpha : array or scalar (default: 1.0)
-
- Returns:
- An array.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ alpha : ArrayLike, optional
+ Scalar or array value controlling the smoothness. Default is 1.0.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ References
+ ----------
+ .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
+ arXiv:1704.07483
  """
  return u.math.celu(x, alpha=alpha)


  def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Scaled exponential linear unit activation.
+ r"""
+ Scaled Exponential Linear Unit activation.

  Computes the element-wise function:

@@ -514,24 +678,31 @@ def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

- For more information, see
- `Self-Normalizing Neural Networks
- <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Args:
- x : input array
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

- Returns:
- An array.
+ See Also
+ --------
+ elu : Exponential Linear Unit activation function.

- See also:
- :func:`elu`
+ References
+ ----------
+ .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
+ NeurIPS 2017.
  """
  return u.math.selu(x)


  def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
- r"""Gaussian error linear unit activation function.
+ r"""
+ Gaussian Error Linear Unit activation function.

  If ``approximate=False``, computes the element-wise function:

@@ -545,18 +716,30 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]
  \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
  \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

- For more information, see `Gaussian Error Linear Units (GELUs)
- <https://arxiv.org/abs/1606.08415>`_, section 2.
-
- Args:
- x : input array
- approximate: whether to use the approximate or exact formulation.
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ approximate : bool, optional
+ Whether to use the approximate (True) or exact (False) formulation.
+ Default is True.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ References
+ ----------
+ .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
+ arXiv:1606.08415
  """
  return u.math.gelu(x, approximate=approximate)
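(Illustrative note, not part of the diff.) Comparing the two gelu formulations described above, assuming brainstate.nn.gelu:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-1.0, 0.0, 1.0])
    # The tanh approximation and the exact erf-based form agree closely on these inputs.
    print(brainstate.nn.gelu(x, approximate=True))
    print(brainstate.nn.gelu(x, approximate=False))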


  def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
- r"""Gated linear unit activation function.
+ r"""
+ Gated Linear Unit activation function.

  Computes the function:

@@ -568,15 +751,22 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
  where the array is split into two along ``axis``. The size of the ``axis``
  dimension must be divisible by two.

- Args:
- x : input array
- axis: the axis along which the split should be computed (default: -1)
-
- Returns:
- An array.
-
- See also:
- :func:`sigmoid`
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array. The dimension specified by ``axis`` must be divisible by 2.
+ axis : int, optional
+ The axis along which the split should be computed. Default is -1.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as input except the ``axis`` dimension
+ is halved.
+
+ See Also
+ --------
+ sigmoid : The sigmoid activation function.
  """
  return u.math.glu(x, axis=axis)
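(Illustrative note, not part of the diff.) A shape-oriented sketch of glu, whose output halves the split axis as the new docstring states, assuming brainstate.nn.glu:

    import jax.numpy as jnp
    import brainstate

    x = jnp.ones((4, 6))
    # The input is split into halves a and b along `axis`; the result is a * sigmoid(b).
    y = brainstate.nn.glu(x, axis=-1)
    print(y.shape)  # (4, 3)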

@@ -584,26 +774,34 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
  def log_softmax(x: ArrayLike,
  axis: int | tuple[int, ...] | None = -1,
  where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
- r"""Log-Softmax function.
+ r"""
+ Log-Softmax function.

- Computes the logarithm of the :code:`softmax` function, which rescales
+ Computes the logarithm of the softmax function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
  \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  \right)

- Args:
- x : input array
- axis: the axis or axes along which the :code:`log_softmax` should be
- computed. Either an integer or a tuple of integers.
- where: Elements to include in the :code:`log_softmax`.
-
- Returns:
- An array.
-
- See also:
- :func:`softmax`
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ axis : int or tuple of int, optional
+ The axis or axes along which the log-softmax should be computed.
+ Either an integer or a tuple of integers. Default is -1.
+ where : ArrayLike, optional
+ Elements to include in the log-softmax computation.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ softmax : The softmax function.
  """
  return jax.nn.log_softmax(x, axis=axis, where=where)

@@ -611,7 +809,8 @@ def log_softmax(x: ArrayLike,
  def softmax(x: ArrayLike,
  axis: int | tuple[int, ...] | None = -1,
  where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
- r"""Softmax function.
+ r"""
+ Softmax activation function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.
@@ -619,20 +818,26 @@ def softmax(x: ArrayLike,

  .. math ::
  \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

- Args:
- x : input array
- axis: the axis or axes along which the softmax should be computed. The
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ axis : int or tuple of int, optional
+ The axis or axes along which the softmax should be computed. The
  softmax output summed across these dimensions should sum to :math:`1`.
- Either an integer or a tuple of integers.
- where: Elements to include in the :code:`softmax`.
- initial: The minimum value used to shift the input array. Must be present
- when :code:`where` is not None.
-
- Returns:
- An array.
-
- See also:
- :func:`log_softmax`
+ Either an integer or a tuple of integers. Default is -1.
+ where : ArrayLike, optional
+ Elements to include in the softmax computation.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ log_softmax : Logarithm of the softmax function.
+ softmin : Softmin activation function.
  """
  return jax.nn.softmax(x, axis=axis, where=where)
@@ -642,7 +847,32 @@ def standardize(x: ArrayLike,
  variance: ArrayLike | None = None,
  epsilon: ArrayLike = 1e-5,
  where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
- r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+ r"""
+ Standardize (normalize) an array.
+
+ Normalizes an array by subtracting the mean and dividing by the standard
+ deviation :math:`\sqrt{\mathrm{variance}}`.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+ axis : int or tuple of int, optional
+ The axis or axes along which to compute the mean and variance.
+ Default is -1.
+ variance : ArrayLike, optional
+ Pre-computed variance. If None, variance is computed from ``x``.
+ epsilon : ArrayLike, optional
+ A small constant added to the variance to avoid division by zero.
+ Default is 1e-5.
+ where : ArrayLike, optional
+ Elements to include in the computation.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ Standardized array with the same shape as the input.
+ """
  return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)
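(Illustrative note, not part of the diff.) A minimal sketch of standardize along the default axis, assuming brainstate.nn.standardize:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
    z = brainstate.nn.standardize(x, axis=-1)
    print(z.mean(axis=-1))  # ~0
    print(z.std(axis=-1))   # ~1, up to the epsilon term in the denominator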


@@ -650,41 +880,60 @@ def one_hot(x: Any,
  num_classes: int, *,
  dtype: Any = jax.numpy.float_,
  axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
- """One-hot encodes the given indices.
+ """
+ One-hot encode the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
- ``num_classes`` with the element at ``index`` set to one::
-
- >>> one_hot(jnp.array([0, 1, 2]), 3)
- Array([[1., 0., 0.],
- [0., 1., 0.],
- [0., 0., 1.]], dtype=float32)
-
- Indices outside the range [0, num_classes) will be encoded as zeros::
-
- >>> one_hot(jnp.array([-1, 3]), 3)
- Array([[0., 0., 0.],
- [0., 0., 0.]], dtype=float32)
-
- Args:
- x: A tensor of indices.
- num_classes: Number of classes in the one-hot dimension.
- dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
- axis: the axis or axes along which the function should be
- computed.
+ ``num_classes`` with the element at ``index`` set to one.
+
+ Indices outside the range [0, num_classes) will be encoded as zeros.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ A tensor of indices.
+ num_classes : int
+ Number of classes in the one-hot dimension.
+ dtype : dtype, optional
+ The dtype for the returned values. Default is ``jnp.float_``.
+ axis : int or Sequence of int, optional
+ The axis or axes along which the function should be computed.
+ Default is -1.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ One-hot encoded array.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import jax.numpy as jnp
+ >>> import brainstate
+ >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
+ Array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]], dtype=float32)
+
+ >>> # Indices outside the range are encoded as zeros
+ >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
+ Array([[0., 0., 0.],
+ [0., 0., 0.]], dtype=float32)
  """
  return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)


  def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Rectified Linear Unit 6 activation function.
+ r"""
+ Rectified Linear Unit 6 activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{relu6}(x) = \min(\max(x, 0), 6)

- except under differentiation, we take:
+ Under differentiation, we take:

  .. math::
  \nabla \mathrm{relu}(0) = 0
@@ -694,57 +943,78 @@ def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
  .. math::
  \nabla \mathrm{relu}(6) = 0

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Returns:
- An array.
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

- See also:
- :func:`relu`
+ See Also
+ --------
+ relu : Standard ReLU activation function.
  """
  return u.math.relu6(x)


  def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Hard Sigmoid activation function.
+ r"""
+ Hard Sigmoid activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

- Args:
- x : input array
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.

- Returns:
- An array.
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.

- See also:
- :func:`relu6`
+ See Also
+ --------
+ relu6 : ReLU6 activation function.
+ sigmoid : Standard sigmoid function.
  """
  return u.math.hard_sigmoid(x)


  def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Hard SiLU (swish) activation function
+ r"""
+ Hard SiLU (Swish) activation function.

- Computes the element-wise function
+ Computes the element-wise function:

  .. math::
  \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

- Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
- function.
-
- Args:
- x : input array
-
- Returns:
- An array.
-
- See also:
- :func:`hard_sigmoid`
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ hard_sigmoid : Hard sigmoid activation function.
+ silu : Standard SiLU activation function.
+ hard_swish : Alias for hard_silu.
+
+ Notes
+ -----
+ Both `hard_silu` and `hard_swish` are aliases for the same function.
  """
  return u.math.hard_silu(x)

@@ -753,7 +1023,8 @@ hard_swish = hard_silu


  def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Sparse plus function.
+ r"""
+ Sparse plus activation function.

  Computes the function:

@@ -765,19 +1036,31 @@ def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
  x, & 1 \leq x
  \end{cases}

- This is the twin function of the softplus activation ensuring a zero output
+ This is the twin function of the softplus activation, ensuring a zero output
  for inputs less than -1 and a linear output for inputs greater than 1,
- while remaining smooth, convex, monotonic by an adequate definition between
- -1 and 1.
-
- Args:
- x: input (float)
+ while remaining smooth, convex, and monotonic between -1 and 1.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ sparse_sigmoid : Derivative of sparse_plus.
+ softplus : Standard softplus activation function.
  """
  return u.math.sparse_plus(x)


  def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
- r"""Sparse sigmoid activation function.
+ r"""
+ Sparse sigmoid activation function.

  Computes the function:

@@ -789,20 +1072,29 @@ def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
  1, & 1 \leq x
  \end{cases}

- This is the twin function of the ``sigmoid`` activation ensuring a zero output
- for inputs less than -1, a 1 output for inputs greater than 1, and a linear
- output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
-
- For more information, see `Learning with Fenchel-Young Losses (section 6.2)
- <https://arxiv.org/abs/1901.02324>`_.
-
- Args:
- x : input array
-
- Returns:
- An array.
-
- See also:
- :func:`sigmoid`
+ This is the twin function of the standard sigmoid activation, ensuring a zero
+ output for inputs less than -1, a 1 output for inputs greater than 1, and a
+ linear output for inputs between -1 and 1. It is the derivative of `sparse_plus`.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ jax.Array or Quantity
+ An array with the same shape as the input.
+
+ See Also
+ --------
+ sigmoid : Standard sigmoid activation function.
+ sparse_plus : Sparse plus activation function.
+
+ References
+ ----------
+ .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
+ A Sparse Model of Attention and Multi-Label Classification."
+ In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
  """
  return u.math.sparse_sigmoid(x)
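(Illustrative note, not part of the diff.) Finally, a sketch of the sparse_plus / sparse_sigmoid pair documented above, assuming both are exported from brainstate.nn:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    # sparse_plus: 0 below -1, x above 1, quadratic (x + 1)**2 / 4 in between.
    print(brainstate.nn.sparse_plus(x))
    # sparse_sigmoid is its derivative: 0 below -1, 1 above 1, linear (x + 1) / 2 in between.
    print(brainstate.nn.sparse_sigmoid(x))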