brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
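
The renames above fold the former `brainstate.augment`, `brainstate.compile`, `brainstate.functional`, `brainstate.init`, and `brainstate.optim` packages into fewer top-level modules (notably `brainstate.transform` and `brainstate.nn`). A minimal migration sketch, assuming the public import paths in 0.2.0 mirror the new file layout shown here (an inference from the file moves, not confirmed by this diff):

    # Hypothetical migration sketch for 0.1.10 -> 0.2.0, inferred only from
    # the file renames listed above; exact public re-exports may differ.

    # 0.1.10: compilation and augmentation helpers lived in separate packages
    #   from brainstate.compile import ...   # jit, loops, conditions, ...
    #   from brainstate.augment import ...   # autograd, mapping, ...

    # 0.2.0: both sets of transforms now live under brainstate.transform
    import brainstate
    # brainstate.transform   # jit, conditions, loops, autograd, mapping, make_jaxpr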
brainstate/nn/_activations.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,41 +66,49 @@ __all__ = [


 def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Hyperbolic tangent activation function.

     Computes the element-wise function:

     .. math::
         \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-    Returns
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
     """
     return u.math.tanh(x)


 def softmin(x, axis=-1):
     r"""
-
-    rescaling them so that the elements of the n-dimensional output Tensor
-    lie in the range `[0, 1]` and sum to 1.
+    Softmin activation function.

-    Softmin
+    Applies the Softmin function to an n-dimensional input tensor, rescaling elements
+    so that they lie in the range [0, 1] and sum to 1 along the specified axis.

     .. math::
         \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array of any shape.
+    axis : int, optional
+        The axis along which Softmin will be computed. Every slice along this
+        dimension will sum to 1. Default is -1.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
     """
     unnormalized = u.math.exp(-x)
     return unnormalized / unnormalized.sum(axis, keepdims=True)
@@ -108,22 +116,36 @@ def softmin(x, axis=-1):


 def tanh_shrink(x):
     r"""
+    Tanh shrink activation function.
+
     Applies the element-wise function:

     .. math::
         \text{Tanhshrink}(x) = x - \tanh(x)
+
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
     """
     return x - u.math.tanh(x)


 def prelu(x, a=0.25):
     r"""
+    Parametric Rectified Linear Unit activation function.
+
     Applies the element-wise function:

     .. math::
         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

-    or
+    or equivalently:

     .. math::
         \text{PReLU}(x) =
@@ -132,16 +154,32 @@ def prelu(x, a=0.25):
         ax, & \text{ otherwise }
         \end{cases}

-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    a : float or ArrayLike, optional
+        The negative slope coefficient. Can be a learnable parameter.
+        Default is 0.25.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
+
+    Notes
+    -----
+    When used in neural network layers, :math:`a` can be a learnable parameter
+    that is optimized during training.
     """
     return u.math.where(x >= 0., x, a * x)


 def soft_shrink(x, lambd=0.5):
     r"""
-
+    Soft shrinkage activation function.
+
+    Applies the soft shrinkage function element-wise:

     .. math::
         \text{SoftShrinkage}(x) =
@@ -151,43 +189,60 @@ def soft_shrink(x, lambd=0.5):
         0, & \text{ otherwise }
         \end{cases}

-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array of any shape.
+    lambd : float, optional
+        The :math:`\lambda` value for the soft shrinkage formulation.
+        Must be non-negative. Default is 0.5.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
     """
-    return u.math.where(
-
-
-
-
+    return u.math.where(
+        x > lambd,
+        x - lambd,
+        u.math.where(
+            x < -lambd,
+            x + lambd,
+            u.Quantity(0., unit=u.get_unit(lambd))
+        )
+    )


 def mish(x):
-    r"""
+    r"""
+    Mish activation function.

-    Mish
+    Mish is a self-regularized non-monotonic activation function.

     .. math::
         \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array of any shape.

-
-
-
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
+
+    References
+    ----------
+    .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
+           arXiv:1908.08681
     """
     return x * u.math.tanh(softplus(x))


 def rrelu(x, lower=0.125, upper=0.3333333333333333):
-    r"""
-
-
-    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+    r"""
+    Randomized Leaky Rectified Linear Unit activation function.

     The function is defined as:

@@ -201,27 +256,36 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
     where :math:`a` is randomly sampled from uniform distribution
     :math:`\mathcal{U}(\text{lower}, \text{upper})`.

-
-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array of any shape.
+    lower : float, optional
+        Lower bound of the uniform distribution for sampling the negative slope.
+        Default is 1/8.
+    upper : float, optional
+        Upper bound of the uniform distribution for sampling the negative slope.
+        Default is 1/3.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
+
+    References
+    ----------
+    .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
+           in Convolutional Network." arXiv:1505.00853
     """
     a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
     return u.math.where(u.get_mantissa(x) >= 0., x, a * x)


 def hard_shrink(x, lambd=0.5):
-    r"""
+    r"""
+    Hard shrinkage activation function.

-
+    Applies the hard shrinkage function element-wise:

     .. math::
         \text{HardShrink}(x) =
@@ -231,139 +295,202 @@ def hard_shrink(x, lambd=0.5):
         0, & \text{ otherwise }
         \end{cases}

-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array of any shape.
+    lambd : float, optional
+        The :math:`\lambda` threshold value for the hard shrinkage formulation.
+        Default is 0.5.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Output array with the same shape as the input.
     """
-    return u.math.where(
-
-
-
-
+    return u.math.where(
+        x > lambd,
+        x,
+        u.math.where(
+            x < -lambd,
+            x,
+            u.Quantity(0., unit=u.get_unit(x))
+        )
+    )


 def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Rectified Linear Unit activation function.

     Computes the element-wise function:

     .. math::
         \mathrm{relu}(x) = \max(x, 0)

-
+    Under differentiation, we take:

     .. math::
         \nabla \mathrm{relu}(0) = 0

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import jax.numpy as jnp
+        >>> import brainstate
+        >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
+        Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
+
+    See Also
+    --------
+    relu6 : ReLU6 activation function.
+    leaky_relu : Leaky ReLU activation function.
+
+    References
+    ----------
+    .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
+           https://openreview.net/forum?id=urrcVI-_jRm
     """
     return u.math.relu(x)


 def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Squareplus activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}

-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    b : ArrayLike, optional
+        Smoothness parameter. Default is 4.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    References
+    ----------
+    .. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
+           for Language Modeling." arXiv:2112.11687
     """
     return u.math.squareplus(x, b=b)


 def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Softplus activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{softplus}(x) = \log(1 + e^x)

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
     """
     return u.math.softplus(x)


 def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Soft-sign activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
     """
     return u.math.soft_sign(x)


 def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Sigmoid activation function.

     Computes the element-wise function:

     .. math::
         \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

-
-
-
-
-    An array.
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

+    See Also
+    --------
+    log_sigmoid : Logarithm of the sigmoid function.
     """
     return u.math.sigmoid(x)


 def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    SiLU (Sigmoid Linear Unit) activation function.

     Computes the element-wise function:

     .. math::
         \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

-
-
+    See Also
+    --------
+    sigmoid : The sigmoid function.
+    swish : Alias for silu.

-
-
+    Notes
+    -----
+    `swish` and `silu` are both aliases for the same function.
     """
     return u.math.silu(x)

@@ -372,27 +499,34 @@ swish = silu


 def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Log-sigmoid activation function.

     Computes the element-wise function:

     .. math::
         \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-    Returns
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

-    See
-
+    See Also
+    --------
+    sigmoid : The sigmoid function.
     """
     return u.math.log_sigmoid(x)


 def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Exponential Linear Unit activation function.

     Computes the element-wise function:

@@ -402,21 +536,29 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
         \alpha \left(\exp(x) - 1\right), & x \le 0
         \end{cases}

-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    alpha : ArrayLike, optional
+        Scalar or array of alpha values. Default is 1.0.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    selu : Scaled ELU activation function.
+    celu : Continuously-differentiable ELU activation function.
     """
     return u.math.elu(x, alpha=alpha)


 def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Leaky Rectified Linear Unit activation function.

     Computes the element-wise function:

@@ -428,15 +570,22 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Arra

     where :math:`\alpha` = :code:`negative_slope`.

-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    negative_slope : ArrayLike, optional
+        Array or scalar specifying the negative slope. Default is 0.01.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    relu : Standard ReLU activation function.
+    prelu : Parametric ReLU with learnable slope.
     """
     return u.math.leaky_relu(x, negative_slope=negative_slope)

@@ -450,7 +599,8 @@ def hard_tanh(
     min_val: float = - 1.0,
     max_val: float = 1.0
 ) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Hard hyperbolic tangent activation function.

     Computes the element-wise function:

@@ -461,13 +611,19 @@ def hard_tanh(
         1, & 1 < x
         \end{cases}

-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    min_val : float, optional
+        Minimum value of the linear region range. Default is -1.
+    max_val : float, optional
+        Maximum value of the linear region range. Default is 1.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
     """
     x = u.Quantity(x)
     min_val = u.Quantity(min_val).to(x.unit).mantissa
@@ -476,7 +632,8 @@ def hard_tanh(


 def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Continuously-differentiable Exponential Linear Unit activation.

     Computes the element-wise function:

@@ -486,22 +643,29 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
         \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
         \end{cases}

-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    alpha : ArrayLike, optional
+        Scalar or array value controlling the smoothness. Default is 1.0.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    References
+    ----------
+    .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
+           arXiv:1704.07483
     """
     return u.math.celu(x, alpha=alpha)


 def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Scaled Exponential Linear Unit activation.

     Computes the element-wise function:

@@ -514,24 +678,31 @@ def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
     where :math:`\lambda = 1.0507009873554804934193349852946` and
     :math:`\alpha = 1.6732632423543772848170429916717`.

-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

-
-
+    See Also
+    --------
+    elu : Exponential Linear Unit activation function.

-
-
+    References
+    ----------
+    .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
+           NeurIPS 2017.
     """
     return u.math.selu(x)


 def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Gaussian Error Linear Unit activation function.

     If ``approximate=False``, computes the element-wise function:

@@ -545,18 +716,30 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]
         \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
         \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    approximate : bool, optional
+        Whether to use the approximate (True) or exact (False) formulation.
+        Default is True.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    References
+    ----------
+    .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
+           arXiv:1606.08415
     """
     return u.math.gelu(x, approximate=approximate)


 def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Gated Linear Unit activation function.

     Computes the function:

@@ -568,15 +751,22 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
     where the array is split into two along ``axis``. The size of the ``axis``
     dimension must be divisible by two.

-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array. The dimension specified by ``axis`` must be divisible by 2.
+    axis : int, optional
+        The axis along which the split should be computed. Default is -1.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as input except the ``axis`` dimension
+        is halved.
+
+    See Also
+    --------
+    sigmoid : The sigmoid activation function.
     """
     return u.math.glu(x, axis=axis)

@@ -584,26 +774,34 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
                 where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Log-Softmax function.

-    Computes the logarithm of the
+    Computes the logarithm of the softmax function, which rescales
     elements to the range :math:`[-\infty, 0)`.

     .. math ::
         \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
         \right)

-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    axis : int or tuple of int, optional
+        The axis or axes along which the log-softmax should be computed.
+        Either an integer or a tuple of integers. Default is -1.
+    where : ArrayLike, optional
+        Elements to include in the log-softmax computation.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    softmax : The softmax function.
     """
     return jax.nn.log_softmax(x, axis=axis, where=where)

@@ -611,7 +809,8 @@ def log_softmax(x: ArrayLike,
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
             where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Softmax activation function.

     Computes the function which rescales elements to the range :math:`[0, 1]`
     such that the elements along :code:`axis` sum to :math:`1`.
@@ -619,20 +818,26 @@ def softmax(x: ArrayLike,
     .. math ::
         \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    axis : int or tuple of int, optional
+        The axis or axes along which the softmax should be computed. The
         softmax output summed across these dimensions should sum to :math:`1`.
-    Either an integer or a tuple of integers.
-
-
-
-
-
-
-
-
-
+        Either an integer or a tuple of integers. Default is -1.
+    where : ArrayLike, optional
+        Elements to include in the softmax computation.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    log_softmax : Logarithm of the softmax function.
+    softmin : Softmin activation function.
     """
     return jax.nn.softmax(x, axis=axis, where=where)

@@ -642,7 +847,32 @@ def standardize(x: ArrayLike,
                 variance: ArrayLike | None = None,
                 epsilon: ArrayLike = 1e-5,
                 where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Standardize (normalize) an array.
+
+    Normalizes an array by subtracting the mean and dividing by the standard
+    deviation :math:`\sqrt{\mathrm{variance}}`.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    axis : int or tuple of int, optional
+        The axis or axes along which to compute the mean and variance.
+        Default is -1.
+    variance : ArrayLike, optional
+        Pre-computed variance. If None, variance is computed from ``x``.
+    epsilon : ArrayLike, optional
+        A small constant added to the variance to avoid division by zero.
+        Default is 1e-5.
+    where : ArrayLike, optional
+        Elements to include in the computation.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        Standardized array with the same shape as the input.
+    """
     return jax.nn.standardize(x, axis=axis, where=where, variance=variance, epsilon=epsilon)


@@ -650,41 +880,60 @@ def one_hot(x: Any,
             num_classes: int, *,
             dtype: Any = jax.numpy.float_,
             axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
-    """
+    """
+    One-hot encode the given indices.

     Each index in the input ``x`` is encoded as a vector of zeros of length
-    ``num_classes`` with the element at ``index`` set to one
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ``num_classes`` with the element at ``index`` set to one.
+
+    Indices outside the range [0, num_classes) will be encoded as zeros.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        A tensor of indices.
+    num_classes : int
+        Number of classes in the one-hot dimension.
+    dtype : dtype, optional
+        The dtype for the returned values. Default is ``jnp.float_``.
+    axis : int or Sequence of int, optional
+        The axis or axes along which the function should be computed.
+        Default is -1.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        One-hot encoded array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import jax.numpy as jnp
+        >>> import brainstate
+        >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
+        Array([[1., 0., 0.],
+               [0., 1., 0.],
+               [0., 0., 1.]], dtype=float32)
+
+        >>> # Indices outside the range are encoded as zeros
+        >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
+        Array([[0., 0., 0.],
+               [0., 0., 0.]], dtype=float32)
     """
     return jax.nn.one_hot(x, axis=axis, num_classes=num_classes, dtype=dtype)


 def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Rectified Linear Unit 6 activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{relu6}(x) = \min(\max(x, 0), 6)

-
+    Under differentiation, we take:

     .. math::
         \nabla \mathrm{relu}(0) = 0
@@ -694,57 +943,78 @@ def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
     .. math::
         \nabla \mathrm{relu}(6) = 0

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-    Returns
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

-    See
-
+    See Also
+    --------
+    relu : Standard ReLU activation function.
     """
     return u.math.relu6(x)


 def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Hard Sigmoid activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.

-    Returns
-
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.

-    See
-
+    See Also
+    --------
+    relu6 : ReLU6 activation function.
+    sigmoid : Standard sigmoid function.
     """
     return u.math.hard_sigmoid(x)


 def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Hard SiLU (Swish) activation function.

-    Computes the element-wise function
+    Computes the element-wise function:

     .. math::
         \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    hard_sigmoid : Hard sigmoid activation function.
+    silu : Standard SiLU activation function.
+    hard_swish : Alias for hard_silu.
+
+    Notes
+    -----
+    Both `hard_silu` and `hard_swish` are aliases for the same function.
     """
     return u.math.hard_silu(x)

@@ -753,7 +1023,8 @@ hard_swish = hard_silu


 def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Sparse plus activation function.

     Computes the function:

@@ -765,19 +1036,31 @@ def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
         x, & 1 \leq x
         \end{cases}

-    This is the twin function of the softplus activation ensuring a zero output
+    This is the twin function of the softplus activation, ensuring a zero output
     for inputs less than -1 and a linear output for inputs greater than 1,
-    while remaining smooth, convex, monotonic
-
-
-
-
+    while remaining smooth, convex, and monotonic between -1 and 1.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    sparse_sigmoid : Derivative of sparse_plus.
+    softplus : Standard softplus activation function.
     """
     return u.math.sparse_plus(x)


 def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
-    r"""
+    r"""
+    Sparse sigmoid activation function.

     Computes the function:

@@ -789,20 +1072,29 @@ def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
         1, & 1 \leq x
         \end{cases}

-    This is the twin function of the
-    for inputs less than -1, a 1 output for inputs greater than 1, and a
-    output for inputs between -1 and 1. It is the derivative of
-
-
-
-
-
-
-
-
-
-
-
-
+    This is the twin function of the standard sigmoid activation, ensuring a zero
+    output for inputs less than -1, a 1 output for inputs greater than 1, and a
+    linear output for inputs between -1 and 1. It is the derivative of `sparse_plus`.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+
+    Returns
+    -------
+    jax.Array or Quantity
+        An array with the same shape as the input.
+
+    See Also
+    --------
+    sigmoid : Standard sigmoid activation function.
+    sparse_plus : Sparse plus activation function.
+
+    References
+    ----------
+    .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
+           A Sparse Model of Attention and Multi-Label Classification."
+           In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
     """
     return u.math.sparse_sigmoid(x)
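
The refreshed docstrings above embed runnable examples for `relu` and `one_hot`. A short usage sketch combining a few of the documented activations; the `brainstate.nn.relu` and `brainstate.nn.one_hot` paths come from the docstring examples in this diff, and the same `brainstate.nn.<fn>` path for the other functions is an assumption based on their location in `brainstate/nn/_activations.py`:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-2., -1., -0.5, 0., 0.5, 1., 2.])

    # Element-wise activations documented above
    print(brainstate.nn.relu(x))                            # clamps negatives to 0
    print(brainstate.nn.leaky_relu(x, negative_slope=0.1))  # small slope for x < 0
    print(brainstate.nn.silu(x))                            # x * sigmoid(x)

    # softmax rescales along an axis so each slice sums to 1
    logits = jnp.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
    print(brainstate.nn.softmax(logits, axis=-1))

    # one_hot encodes integer indices; out-of-range indices become all-zero rows
    print(brainstate.nn.one_hot(jnp.array([0, 2, 3]), num_classes=3))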