brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in the registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
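Several subpackages were relocated in this release: the former brainstate.transform tree is split between brainstate.compile (jit, conditions, loops, make_jaxpr, error_if, progress_bar, unvmap) and brainstate.augment (autograd, mapping, eval_shape, random), brainstate.nn.event is promoted to a top-level brainstate.event package, and the single-file brainstate/util.py becomes the brainstate/util/ package. A minimal migration sketch based only on the module paths listed above; the exact symbols re-exported from each subpackage are not shown in this listing and are left out on purpose:

    # Old layout in 0.0.2.post20241010 (kept as comments, for illustration only):
    #   from brainstate import transform      # autograd, jit, conditions, loops, make_jaxpr
    #   from brainstate.nn import event       # event-driven operators lived under nn

    # New layout in 0.1.0, matching the relocated files listed above:
    from brainstate import compile   # compiled control flow: _jit.py, _conditions.py, _loop_*.py, _make_jaxpr.py
    from brainstate import augment   # functional augmentations: _autograd.py, _mapping.py, _eval_shape.py, _random.py
    from brainstate import event     # event-driven operators: _csr.py, _fixed_probability.py, _linear.py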
@@ -30,614 +30,624 @@ from brainstate import random
 from brainstate.typing import ArrayLike
 
 __all__ = [
+    "tanh",
+    "relu",
+    "squareplus",
+    "softplus",
+    "soft_sign",
+    "sigmoid",
+    "silu",
+    "swish",
+    "log_sigmoid",
+    "elu",
+    "leaky_relu",
+    "hard_tanh",
+    "celu",
+    "selu",
+    "gelu",
+    "glu",
+    "logsumexp",
+    "log_softmax",
+    "softmax",
+    "standardize",
+    "one_hot",
+    "relu6",
+    "hard_sigmoid",
+    "hard_silu",
+    "hard_swish",
+    'hard_shrink',
+    'rrelu',
+    'mish',
+    'soft_shrink',
+    'prelu',
+    'tanh_shrink',
+    'softmin',
+    'sparse_plus',
+    'sparse_sigmoid',
 ]
 
 
 def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Hyperbolic tangent activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
+    """
+    return u.math.tanh(x)
 
 
 def softmin(x, axis=-1):
+    r"""
+    Applies the Softmin function to an n-dimensional input Tensor
+    rescaling them so that the elements of the n-dimensional output Tensor
+    lie in the range `[0, 1]` and sum to 1.
 
+    Softmin is defined as:
 
+    .. math::
+        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
 
+    Shape:
+        - Input: :math:`(*)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(*)`, same shape as the input
 
+    Args:
+        axis (int): A dimension along which Softmin will be computed (so every slice
+            along dim will sum to 1).
+    """
+    unnormalized = u.math.exp(-x)
+    return unnormalized / unnormalized.sum(axis, keepdims=True)
 
 
 def tanh_shrink(x):
+    r"""
+    Applies the element-wise function:
 
+    .. math::
+        \text{Tanhshrink}(x) = x - \tanh(x)
+    """
+    return x - u.math.tanh(x)
 
 
 def prelu(x, a=0.25):
+    r"""
+    Applies the element-wise function:
 
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
 
+    or
 
+    .. math::
+        \text{PReLU}(x) =
+        \begin{cases}
+        x, & \text{ if } x \geq 0 \\
+        ax, & \text{ otherwise }
+        \end{cases}
 
+    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
+    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
+    a separate :math:`a` is used for each input channel.
+    """
+    return u.math.where(x >= 0., x, a * x)
 
 
 def soft_shrink(x, lambd=0.5):
+    r"""
+    Applies the soft shrinkage function elementwise:
+
+    .. math::
+        \text{SoftShrinkage}(x) =
+        \begin{cases}
+        x - \lambda, & \text{ if } x > \lambda \\
+        x + \lambda, & \text{ if } x < -\lambda \\
+        0, & \text{ otherwise }
+        \end{cases}
+
+    Args:
+        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+    """
+    return u.math.where(x > lambd,
+                        x - lambd,
+                        u.math.where(x < -lambd,
+                                     x + lambd,
+                                     u.Quantity(0., unit=u.get_unit(lambd))))
 
 
 def mish(x):
+    r"""Applies the Mish function, element-wise.
 
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
 
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
 
+    .. note::
+        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
 
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+    """
+    return x * u.math.tanh(softplus(x))
 
 
 def rrelu(x, lower=0.125, upper=0.3333333333333333):
+    r"""Applies the randomized leaky rectified liner unit function, element-wise,
+    as described in the paper:
 
+    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
 
+    The function is defined as:
 
+    .. math::
+        \text{RReLU}(x) =
+        \begin{cases}
+            x & \text{if } x \geq 0 \\
+            ax & \text{ otherwise }
+        \end{cases}
 
+    where :math:`a` is randomly sampled from uniform distribution
+    :math:`\mathcal{U}(\text{lower}, \text{upper})`.
 
+    See: https://arxiv.org/pdf/1505.00853.pdf
 
+    Args:
+        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
 
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
 
+    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+        https://arxiv.org/abs/1505.00853
+    """
+    a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
+    return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
 
 
 def hard_shrink(x, lambd=0.5):
+    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
 
+    Hardshrink is defined as:
 
+    .. math::
+        \text{HardShrink}(x) =
+        \begin{cases}
+        x, & \text{ if } x > \lambda \\
+        x, & \text{ if } x < -\lambda \\
+        0, & \text{ otherwise }
+        \end{cases}
 
+    Args:
+        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
 
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
 
+    """
+    return u.math.where(x > lambd,
+                        x,
+                        u.math.where(x < -lambd,
+                                     x,
+                                     u.Quantity(0., unit=u.get_unit(x))))
 
 
 def _keep_unit(fun, x, **kwargs):
+    unit = u.get_unit(x)
+    x = fun(u.get_mantissa(x), **kwargs)
+    return x if unit.is_unitless else u.Quantity(x, unit=unit)
 
 
 def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Rectified linear unit activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{relu}(x) = \max(x, 0)
 
+    except under differentiation, we take:
 
+    .. math::
+        \nabla \mathrm{relu}(0) = 0
 
+    For more information see
+    `Numerical influence of ReLU’(0) on backpropagation
+    <https://openreview.net/forum?id=urrcVI-_jRm>`_.
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    Example:
+        >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
+        Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
 
+    See also:
+        :func:`relu6`
 
+    """
+    return _keep_unit(jax.nn.relu, x)
 
 
 def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
+    r"""Squareplus activation function.
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
 
+    as described in https://arxiv.org/abs/2112.11687.
 
+    Args:
+        x : input array
+        b : smoothness parameter
+    """
+    return _keep_unit(jax.nn.squareplus, x, b=b)
 
 
 def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Softplus activation function.
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{softplus}(x) = \log(1 + e^x)
 
+    Args:
+        x : input array
+    """
+    return _keep_unit(jax.nn.softplus, x)
 
 
 def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Soft-sign activation function.
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
 
+    Args:
+        x : input array
+    """
+    return _keep_unit(jax.nn.soft_sign, x)
 
 
 def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Sigmoid activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`log_sigmoid`
 
+    """
+    return _keep_unit(jax.nn.sigmoid, x)
 
 
 def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""SiLU (a.k.a. swish) activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
 
+    :func:`swish` and :func:`silu` are both aliases for the same function.
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`sigmoid`
+    """
+    return _keep_unit(jax.nn.silu, x)
 
 
 swish = silu
 
 
 def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Log-sigmoid activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`sigmoid`
+    """
+    return _keep_unit(jax.nn.log_sigmoid, x)
 
 
 def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
+    r"""Exponential linear unit activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{elu}(x) = \begin{cases}
+            x, & x > 0\\
+            \alpha \left(\exp(x) - 1\right), & x \le 0
+        \end{cases}
 
+    Args:
+        x : input array
+        alpha : scalar or array of alpha values (default: 1.0)
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`selu`
+    """
+    return _keep_unit(jax.nn.elu, x)
 
 
 def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
+    r"""Leaky rectified linear unit activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{leaky\_relu}(x) = \begin{cases}
+            x, & x \ge 0\\
+            \alpha x, & x < 0
+        \end{cases}
 
+    where :math:`\alpha` = :code:`negative_slope`.
 
+    Args:
+        x : input array
+        negative_slope : array or scalar specifying the negative slope (default: 0.01)
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`relu`
+    """
+    return _keep_unit(jax.nn.leaky_relu, x, negative_slope=negative_slope)
 
 
+def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
+    return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
 
 
+def hard_tanh(
+    x: ArrayLike,
+    min_val: float = - 1.0,
+    max_val: float = 1.0
+) -> Union[jax.Array, u.Quantity]:
+    r"""Hard :math:`\mathrm{tanh}` activation function.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{hard\_tanh}(x) = \begin{cases}
+            -1, & x < -1\\
+            x, & -1 \le x \le 1\\
+            1, & 1 < x
+        \end{cases}
+
+    Args:
+        x : input array
+        min_val: float. minimum value of the linear region range. Default: -1
+        max_val: float. maximum value of the linear region range. Default: 1
+
+    Returns:
+        An array.
+    """
+    return _keep_unit(_hard_tanh, x, min_val=min_val, max_val=max_val)
 
 
 def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
+    r"""Continuously-differentiable exponential linear unit activation.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{celu}(x) = \begin{cases}
+            x, & x > 0\\
+            \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
+        \end{cases}
 
+    For more information, see
+    `Continuously Differentiable Exponential Linear Units
+    <https://arxiv.org/pdf/1704.07483.pdf>`_.
 
+    Args:
+        x : input array
+        alpha : array or scalar (default: 1.0)
 
+    Returns:
+        An array.
+    """
+    return _keep_unit(jax.nn.celu, x, alpha=alpha)
 
 
 def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Scaled exponential linear unit activation.
 
+    Computes the element-wise function:
 
+    .. math::
+        \mathrm{selu}(x) = \lambda \begin{cases}
+            x, & x > 0\\
+            \alpha e^x - \alpha, & x \le 0
+        \end{cases}
 
+    where :math:`\lambda = 1.0507009873554804934193349852946` and
+    :math:`\alpha = 1.6732632423543772848170429916717`.
 
+    For more information, see
+    `Self-Normalizing Neural Networks
+    <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`elu`
+    """
+    return _keep_unit(jax.nn.selu, x)
 
 
 def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
+    r"""Gaussian error linear unit activation function.
 
+    If ``approximate=False``, computes the element-wise function:
 
+    .. math::
+        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
+            \frac{x}{\sqrt{2}} \right) \right)
 
+    If ``approximate=True``, uses the approximate formulation of GELU:
 
+    .. math::
+        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
+            \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
 
+    For more information, see `Gaussian Error Linear Units (GELUs)
+    <https://arxiv.org/abs/1606.08415>`_, section 2.
 
+    Args:
+        x : input array
+        approximate: whether to use the approximate or exact formulation.
+    """
+    return _keep_unit(jax.nn.gelu, x, approximate=approximate)
 
 
 def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
+    r"""Gated linear unit activation function.
 
+    Computes the function:
 
+    .. math::
+        \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
+            \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
+                \right)
 
+    where the array is split into two along ``axis``. The size of the ``axis``
+    dimension must be divisible by two.
 
+    Args:
+        x : input array
+        axis: the axis along which the split should be computed (default: -1)
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`sigmoid`
+    """
+    return _keep_unit(jax.nn.glu, x, axis=axis)
 
 
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
                 where: ArrayLike | None = None,
                 initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+    r"""Log-Softmax function.
 
+    Computes the logarithm of the :code:`softmax` function, which rescales
+    elements to the range :math:`[-\infty, 0)`.
 
+    .. math ::
+        \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+        \right)
 
+    Args:
+        x : input array
+        axis: the axis or axes along which the :code:`log_softmax` should be
+            computed. Either an integer or a tuple of integers.
+        where: Elements to include in the :code:`log_softmax`.
+        initial: The minimum value used to shift the input array. Must be present
+            when :code:`where` is not None.
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`softmax`
+    """
+    if initial is not None:
+        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
 
 
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
             where: ArrayLike | None = None,
             initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+    r"""Softmax function.
 
+    Computes the function which rescales elements to the range :math:`[0, 1]`
+    such that the elements along :code:`axis` sum to :math:`1`.
 
+    .. math ::
+        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
 
+    Args:
+        x : input array
+        axis: the axis or axes along which the softmax should be computed. The
+            softmax output summed across these dimensions should sum to :math:`1`.
+            Either an integer or a tuple of integers.
+        where: Elements to include in the :code:`softmax`.
+        initial: The minimum value used to shift the input array. Must be present
+            when :code:`where` is not None.
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`log_softmax`
+    """
+    if initial is not None:
+        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
+    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
 
 
 def standardize(x: ArrayLike,
@@ -645,169 +655,169 @@ def standardize(x: ArrayLike,
                 variance: ArrayLike | None = None,
                 epsilon: ArrayLike = 1e-5,
                 where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+    r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+    return _keep_unit(jax.nn.standardize, x, axis=axis, where=where, variance=variance, epsilon=epsilon)
 
 
 def one_hot(x: Any,
             num_classes: int, *,
             dtype: Any = jax.numpy.float_,
             axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
+    """One-hot encodes the given indices.
 
+    Each index in the input ``x`` is encoded as a vector of zeros of length
+    ``num_classes`` with the element at ``index`` set to one::
 
+        >>> one_hot(jnp.array([0, 1, 2]), 3)
+        Array([[1., 0., 0.],
+               [0., 1., 0.],
+               [0., 0., 1.]], dtype=float32)
 
+    Indices outside the range [0, num_classes) will be encoded as zeros::
 
+        >>> one_hot(jnp.array([-1, 3]), 3)
+        Array([[0., 0., 0.],
+               [0., 0., 0.]], dtype=float32)
 
+    Args:
+        x: A tensor of indices.
+        num_classes: Number of classes in the one-hot dimension.
+        dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
+        axis: the axis or axes along which the function should be
+            computed.
+    """
+    return _keep_unit(jax.nn.one_hot, x, axis=axis, num_classes=num_classes, dtype=dtype)
 
 
 def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Rectified Linear Unit 6 activation function.
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{relu6}(x) = \min(\max(x, 0), 6)
 
+    except under differentiation, we take:
 
+    .. math::
+        \nabla \mathrm{relu}(0) = 0
 
+    and
 
+    .. math::
+        \nabla \mathrm{relu}(6) = 0
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`relu`
+    """
+    return _keep_unit(jax.nn.relu6, x)
 
 
 def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Hard Sigmoid activation function.
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`relu6`
+    """
+    return _keep_unit(jax.nn.hard_sigmoid, x)
 
 
 def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Hard SiLU (swish) activation function
 
+    Computes the element-wise function
 
+    .. math::
+        \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
 
+    Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
+    function.
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`hard_sigmoid`
+    """
+    return _keep_unit(jax.nn.hard_silu, x)
 
+    return jax.nn.hard_silu(x)
 
 
 hard_swish = hard_silu
 
 
 def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Sparse plus function.
 
+    Computes the function:
 
+    .. math::
 
+        \mathrm{sparse\_plus}(x) = \begin{cases}
+            0, & x \leq -1\\
+            \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+            x, & 1 \leq x
+        \end{cases}
 
+    This is the twin function of the softplus activation ensuring a zero output
+    for inputs less than -1 and a linear output for inputs greater than 1,
+    while remaining smooth, convex, monotonic by an adequate definition between
+    -1 and 1.
 
+    Args:
+        x: input (float)
+    """
+    return _keep_unit(jax.nn.sparse_plus, x)
 
 
 def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
+    r"""Sparse sigmoid activation function.
 
+    Computes the function:
 
+    .. math::
 
+        \mathrm{sparse\_sigmoid}(x) = \begin{cases}
+            0, & x \leq -1\\
+            \frac{1}{2}(x+1), & -1 < x < 1 \\
+            1, & 1 \leq x
+        \end{cases}
 
+    This is the twin function of the ``sigmoid`` activation ensuring a zero output
+    for inputs less than -1, a 1 output for inputs greater than 1, and a linear
+    output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
 
+    For more information, see `Learning with Fenchel-Young Losses (section 6.2)
+    <https://arxiv.org/abs/1901.02324>`_.
 
+    Args:
+        x : input array
 
+    Returns:
+        An array.
 
+    See also:
+        :func:`sigmoid`
+    """
+    return _keep_unit(jax.nn.sparse_sigmoid, x)