brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -30,614 +30,624 @@ from brainstate import random
30
30
  from brainstate.typing import ArrayLike
31
31
 
32
32
  __all__ = [
33
- "tanh",
34
- "relu",
35
- "squareplus",
36
- "softplus",
37
- "soft_sign",
38
- "sigmoid",
39
- "silu",
40
- "swish",
41
- "log_sigmoid",
42
- "elu",
43
- "leaky_relu",
44
- "hard_tanh",
45
- "celu",
46
- "selu",
47
- "gelu",
48
- "glu",
49
- "logsumexp",
50
- "log_softmax",
51
- "softmax",
52
- "standardize",
53
- "one_hot",
54
- "relu6",
55
- "hard_sigmoid",
56
- "hard_silu",
57
- "hard_swish",
58
- 'hard_shrink',
59
- 'rrelu',
60
- 'mish',
61
- 'soft_shrink',
62
- 'prelu',
63
- 'tanh_shrink',
64
- 'softmin',
65
- 'sparse_plus',
66
- 'sparse_sigmoid',
33
+ "tanh",
34
+ "relu",
35
+ "squareplus",
36
+ "softplus",
37
+ "soft_sign",
38
+ "sigmoid",
39
+ "silu",
40
+ "swish",
41
+ "log_sigmoid",
42
+ "elu",
43
+ "leaky_relu",
44
+ "hard_tanh",
45
+ "celu",
46
+ "selu",
47
+ "gelu",
48
+ "glu",
49
+ "logsumexp",
50
+ "log_softmax",
51
+ "softmax",
52
+ "standardize",
53
+ "one_hot",
54
+ "relu6",
55
+ "hard_sigmoid",
56
+ "hard_silu",
57
+ "hard_swish",
58
+ 'hard_shrink',
59
+ 'rrelu',
60
+ 'mish',
61
+ 'soft_shrink',
62
+ 'prelu',
63
+ 'tanh_shrink',
64
+ 'softmin',
65
+ 'sparse_plus',
66
+ 'sparse_sigmoid',
67
67
  ]
68
68
 
69
69
 
70
70
  def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
71
- r"""Hyperbolic tangent activation function.
71
+ r"""Hyperbolic tangent activation function.
72
72
 
73
- Computes the element-wise function:
73
+ Computes the element-wise function:
74
74
 
75
- .. math::
76
- \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
75
+ .. math::
76
+ \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
77
77
 
78
- Args:
79
- x : input array
78
+ Args:
79
+ x : input array
80
80
 
81
- Returns:
82
- An array.
83
- """
84
- return u.math.tanh(x)
81
+ Returns:
82
+ An array.
83
+ """
84
+ return u.math.tanh(x)
85
85
 
86
86
 
87
87
  def softmin(x, axis=-1):
88
- r"""
89
- Applies the Softmin function to an n-dimensional input Tensor
90
- rescaling them so that the elements of the n-dimensional output Tensor
91
- lie in the range `[0, 1]` and sum to 1.
88
+ r"""
89
+ Applies the Softmin function to an n-dimensional input Tensor
90
+ rescaling them so that the elements of the n-dimensional output Tensor
91
+ lie in the range `[0, 1]` and sum to 1.
92
92
 
93
- Softmin is defined as:
93
+ Softmin is defined as:
94
94
 
95
- .. math::
96
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
95
+ .. math::
96
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
97
97
 
98
- Shape:
99
- - Input: :math:`(*)` where `*` means, any number of additional
100
- dimensions
101
- - Output: :math:`(*)`, same shape as the input
98
+ Shape:
99
+ - Input: :math:`(*)` where `*` means, any number of additional
100
+ dimensions
101
+ - Output: :math:`(*)`, same shape as the input
102
102
 
103
- Args:
104
- axis (int): A dimension along which Softmin will be computed (so every slice
105
- along dim will sum to 1).
106
- """
107
- unnormalized = u.math.exp(-x)
108
- return unnormalized / unnormalized.sum(axis, keepdims=True)
103
+ Args:
104
+ axis (int): A dimension along which Softmin will be computed (so every slice
105
+ along dim will sum to 1).
106
+ """
107
+ unnormalized = u.math.exp(-x)
108
+ return unnormalized / unnormalized.sum(axis, keepdims=True)
109
109
 
110
110
 
111
111
  def tanh_shrink(x):
112
- r"""
113
- Applies the element-wise function:
112
+ r"""
113
+ Applies the element-wise function:
114
114
 
115
- .. math::
116
- \text{Tanhshrink}(x) = x - \tanh(x)
117
- """
118
- return x - u.math.tanh(x)
115
+ .. math::
116
+ \text{Tanhshrink}(x) = x - \tanh(x)
117
+ """
118
+ return x - u.math.tanh(x)
119
119
 
120
120
 
121
121
  def prelu(x, a=0.25):
122
- r"""
123
- Applies the element-wise function:
122
+ r"""
123
+ Applies the element-wise function:
124
124
 
125
- .. math::
126
- \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
125
+ .. math::
126
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
127
127
 
128
- or
128
+ or
129
129
 
130
- .. math::
131
- \text{PReLU}(x) =
132
- \begin{cases}
133
- x, & \text{ if } x \geq 0 \\
134
- ax, & \text{ otherwise }
135
- \end{cases}
130
+ .. math::
131
+ \text{PReLU}(x) =
132
+ \begin{cases}
133
+ x, & \text{ if } x \geq 0 \\
134
+ ax, & \text{ otherwise }
135
+ \end{cases}
136
136
 
137
- Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
138
- parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
139
- a separate :math:`a` is used for each input channel.
140
- """
141
- return u.math.where(x >= 0., x, a * x)
137
+ Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
138
+ parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
139
+ a separate :math:`a` is used for each input channel.
140
+ """
141
+ return u.math.where(x >= 0., x, a * x)
142
142
 
143
143
 
144
144
  def soft_shrink(x, lambd=0.5):
145
- r"""
146
- Applies the soft shrinkage function elementwise:
147
-
148
- .. math::
149
- \text{SoftShrinkage}(x) =
150
- \begin{cases}
151
- x - \lambda, & \text{ if } x > \lambda \\
152
- x + \lambda, & \text{ if } x < -\lambda \\
153
- 0, & \text{ otherwise }
154
- \end{cases}
155
-
156
- Args:
157
- lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
158
-
159
- Shape:
160
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
161
- - Output: :math:`(*)`, same shape as the input.
162
- """
163
- return u.math.where(x > lambd,
164
- x - lambd,
165
- u.math.where(x < -lambd,
166
- x + lambd,
167
- u.Quantity(0., unit=u.get_unit(lambd))))
145
+ r"""
146
+ Applies the soft shrinkage function elementwise:
147
+
148
+ .. math::
149
+ \text{SoftShrinkage}(x) =
150
+ \begin{cases}
151
+ x - \lambda, & \text{ if } x > \lambda \\
152
+ x + \lambda, & \text{ if } x < -\lambda \\
153
+ 0, & \text{ otherwise }
154
+ \end{cases}
155
+
156
+ Args:
157
+ lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
158
+
159
+ Shape:
160
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
161
+ - Output: :math:`(*)`, same shape as the input.
162
+ """
163
+ return u.math.where(x > lambd,
164
+ x - lambd,
165
+ u.math.where(x < -lambd,
166
+ x + lambd,
167
+ u.Quantity(0., unit=u.get_unit(lambd))))
168
168
 
169
169
 
170
170
  def mish(x):
171
- r"""Applies the Mish function, element-wise.
171
+ r"""Applies the Mish function, element-wise.
172
172
 
173
- Mish: A Self Regularized Non-Monotonic Neural Activation Function.
173
+ Mish: A Self Regularized Non-Monotonic Neural Activation Function.
174
174
 
175
- .. math::
176
- \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
175
+ .. math::
176
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
177
177
 
178
- .. note::
179
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
178
+ .. note::
179
+ See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
180
180
 
181
- Shape:
182
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
183
- - Output: :math:`(*)`, same shape as the input.
184
- """
185
- return x * u.math.tanh(softplus(x))
181
+ Shape:
182
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
183
+ - Output: :math:`(*)`, same shape as the input.
184
+ """
185
+ return x * u.math.tanh(softplus(x))
186
186
 
187
187
 
188
188
  def rrelu(x, lower=0.125, upper=0.3333333333333333):
189
- r"""Applies the randomized leaky rectified liner unit function, element-wise,
190
- as described in the paper:
189
+ r"""Applies the randomized leaky rectified liner unit function, element-wise,
190
+ as described in the paper:
191
191
 
192
- `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
192
+ `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
193
193
 
194
- The function is defined as:
194
+ The function is defined as:
195
195
 
196
- .. math::
197
- \text{RReLU}(x) =
198
- \begin{cases}
199
- x & \text{if } x \geq 0 \\
200
- ax & \text{ otherwise }
201
- \end{cases}
196
+ .. math::
197
+ \text{RReLU}(x) =
198
+ \begin{cases}
199
+ x & \text{if } x \geq 0 \\
200
+ ax & \text{ otherwise }
201
+ \end{cases}
202
202
 
203
- where :math:`a` is randomly sampled from uniform distribution
204
- :math:`\mathcal{U}(\text{lower}, \text{upper})`.
203
+ where :math:`a` is randomly sampled from uniform distribution
204
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
205
205
 
206
- See: https://arxiv.org/pdf/1505.00853.pdf
206
+ See: https://arxiv.org/pdf/1505.00853.pdf
207
207
 
208
- Args:
209
- lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
210
- upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
208
+ Args:
209
+ lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
210
+ upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
211
211
 
212
- Shape:
213
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
214
- - Output: :math:`(*)`, same shape as the input.
212
+ Shape:
213
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
214
+ - Output: :math:`(*)`, same shape as the input.
215
215
 
216
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
217
- https://arxiv.org/abs/1505.00853
218
- """
219
- a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
220
- return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
216
+ .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
217
+ https://arxiv.org/abs/1505.00853
218
+ """
219
+ a = random.uniform(lower, upper, size=u.math.shape(x), dtype=x.dtype)
220
+ return u.math.where(u.get_mantissa(x) >= 0., x, a * x)
221
221
 
222
222
 
223
223
  def hard_shrink(x, lambd=0.5):
224
- r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
224
+ r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
225
225
 
226
- Hardshrink is defined as:
226
+ Hardshrink is defined as:
227
227
 
228
- .. math::
229
- \text{HardShrink}(x) =
230
- \begin{cases}
231
- x, & \text{ if } x > \lambda \\
232
- x, & \text{ if } x < -\lambda \\
233
- 0, & \text{ otherwise }
234
- \end{cases}
228
+ .. math::
229
+ \text{HardShrink}(x) =
230
+ \begin{cases}
231
+ x, & \text{ if } x > \lambda \\
232
+ x, & \text{ if } x < -\lambda \\
233
+ 0, & \text{ otherwise }
234
+ \end{cases}
235
235
 
236
- Args:
237
- lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
236
+ Args:
237
+ lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
238
238
 
239
- Shape:
240
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
241
- - Output: :math:`(*)`, same shape as the input.
239
+ Shape:
240
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
241
+ - Output: :math:`(*)`, same shape as the input.
242
242
 
243
- """
244
- return u.math.where(x > lambd,
245
- x,
246
- u.math.where(x < -lambd,
247
- x,
248
- u.Quantity(0., unit=u.get_unit(x))))
243
+ """
244
+ return u.math.where(x > lambd,
245
+ x,
246
+ u.math.where(x < -lambd,
247
+ x,
248
+ u.Quantity(0., unit=u.get_unit(x))))
249
249
 
250
250
 
251
251
  def _keep_unit(fun, x, **kwargs):
252
- unit = u.get_unit(x)
253
- x = fun(u.get_mantissa(x), **kwargs)
254
- return x if unit.is_unitless else u.Quantity(x, unit=unit)
252
+ unit = u.get_unit(x)
253
+ x = fun(u.get_mantissa(x), **kwargs)
254
+ return x if unit.is_unitless else u.Quantity(x, unit=unit)
255
255
 
256
256
 
257
257
  def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
258
- r"""Rectified linear unit activation function.
258
+ r"""Rectified linear unit activation function.
259
259
 
260
- Computes the element-wise function:
260
+ Computes the element-wise function:
261
261
 
262
- .. math::
263
- \mathrm{relu}(x) = \max(x, 0)
262
+ .. math::
263
+ \mathrm{relu}(x) = \max(x, 0)
264
264
 
265
- except under differentiation, we take:
265
+ except under differentiation, we take:
266
266
 
267
- .. math::
268
- \nabla \mathrm{relu}(0) = 0
267
+ .. math::
268
+ \nabla \mathrm{relu}(0) = 0
269
269
 
270
- For more information see
271
- `Numerical influence of ReLU’(0) on backpropagation
272
- <https://openreview.net/forum?id=urrcVI-_jRm>`_.
270
+ For more information see
271
+ `Numerical influence of ReLU’(0) on backpropagation
272
+ <https://openreview.net/forum?id=urrcVI-_jRm>`_.
273
273
 
274
- Args:
275
- x : input array
274
+ Args:
275
+ x : input array
276
276
 
277
- Returns:
278
- An array.
277
+ Returns:
278
+ An array.
279
279
 
280
- Example:
281
- >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
282
- Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
280
+ Example:
281
+ >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
282
+ Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
283
283
 
284
- See also:
285
- :func:`relu6`
284
+ See also:
285
+ :func:`relu6`
286
286
 
287
- """
288
- return _keep_unit(jax.nn.relu, x)
287
+ """
288
+ return _keep_unit(jax.nn.relu, x)
289
289
 
290
290
 
291
291
  def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
292
- r"""Squareplus activation function.
292
+ r"""Squareplus activation function.
293
293
 
294
- Computes the element-wise function
294
+ Computes the element-wise function
295
295
 
296
- .. math::
297
- \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
296
+ .. math::
297
+ \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
298
298
 
299
- as described in https://arxiv.org/abs/2112.11687.
299
+ as described in https://arxiv.org/abs/2112.11687.
300
300
 
301
- Args:
302
- x : input array
303
- b : smoothness parameter
304
- """
305
- return _keep_unit(jax.nn.squareplus, x, b=b)
301
+ Args:
302
+ x : input array
303
+ b : smoothness parameter
304
+ """
305
+ return _keep_unit(jax.nn.squareplus, x, b=b)
306
306
 
307
307
 
308
308
  def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
309
- r"""Softplus activation function.
309
+ r"""Softplus activation function.
310
310
 
311
- Computes the element-wise function
311
+ Computes the element-wise function
312
312
 
313
- .. math::
314
- \mathrm{softplus}(x) = \log(1 + e^x)
313
+ .. math::
314
+ \mathrm{softplus}(x) = \log(1 + e^x)
315
315
 
316
- Args:
317
- x : input array
318
- """
319
- return _keep_unit(jax.nn.softplus, x)
316
+ Args:
317
+ x : input array
318
+ """
319
+ return _keep_unit(jax.nn.softplus, x)
320
320
 
321
321
 
322
322
  def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
323
- r"""Soft-sign activation function.
323
+ r"""Soft-sign activation function.
324
324
 
325
- Computes the element-wise function
325
+ Computes the element-wise function
326
326
 
327
- .. math::
328
- \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
327
+ .. math::
328
+ \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
329
329
 
330
- Args:
331
- x : input array
332
- """
333
- return _keep_unit(jax.nn.soft_sign, x)
330
+ Args:
331
+ x : input array
332
+ """
333
+ return _keep_unit(jax.nn.soft_sign, x)
334
334
 
335
335
 
336
336
  def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
337
- r"""Sigmoid activation function.
337
+ r"""Sigmoid activation function.
338
338
 
339
- Computes the element-wise function:
339
+ Computes the element-wise function:
340
340
 
341
- .. math::
342
- \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
341
+ .. math::
342
+ \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
343
343
 
344
- Args:
345
- x : input array
344
+ Args:
345
+ x : input array
346
346
 
347
- Returns:
348
- An array.
347
+ Returns:
348
+ An array.
349
349
 
350
- See also:
351
- :func:`log_sigmoid`
350
+ See also:
351
+ :func:`log_sigmoid`
352
352
 
353
- """
354
- return _keep_unit(jax.nn.sigmoid, x)
353
+ """
354
+ return _keep_unit(jax.nn.sigmoid, x)
355
355
 
356
356
 
357
357
  def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
358
- r"""SiLU (a.k.a. swish) activation function.
358
+ r"""SiLU (a.k.a. swish) activation function.
359
359
 
360
- Computes the element-wise function:
360
+ Computes the element-wise function:
361
361
 
362
- .. math::
363
- \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
362
+ .. math::
363
+ \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}
364
364
 
365
- :func:`swish` and :func:`silu` are both aliases for the same function.
365
+ :func:`swish` and :func:`silu` are both aliases for the same function.
366
366
 
367
- Args:
368
- x : input array
367
+ Args:
368
+ x : input array
369
369
 
370
- Returns:
371
- An array.
370
+ Returns:
371
+ An array.
372
372
 
373
- See also:
374
- :func:`sigmoid`
375
- """
376
- return _keep_unit(jax.nn.silu, x)
373
+ See also:
374
+ :func:`sigmoid`
375
+ """
376
+ return _keep_unit(jax.nn.silu, x)
377
377
 
378
378
 
379
379
  swish = silu
380
380
 
381
381
 
382
382
  def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
383
- r"""Log-sigmoid activation function.
383
+ r"""Log-sigmoid activation function.
384
384
 
385
- Computes the element-wise function:
385
+ Computes the element-wise function:
386
386
 
387
- .. math::
388
- \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
387
+ .. math::
388
+ \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})
389
389
 
390
- Args:
391
- x : input array
390
+ Args:
391
+ x : input array
392
392
 
393
- Returns:
394
- An array.
393
+ Returns:
394
+ An array.
395
395
 
396
- See also:
397
- :func:`sigmoid`
398
- """
399
- return _keep_unit(jax.nn.log_sigmoid, x)
396
+ See also:
397
+ :func:`sigmoid`
398
+ """
399
+ return _keep_unit(jax.nn.log_sigmoid, x)
400
400
 
401
401
 
402
402
  def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
403
- r"""Exponential linear unit activation function.
403
+ r"""Exponential linear unit activation function.
404
404
 
405
- Computes the element-wise function:
405
+ Computes the element-wise function:
406
406
 
407
- .. math::
408
- \mathrm{elu}(x) = \begin{cases}
409
- x, & x > 0\\
410
- \alpha \left(\exp(x) - 1\right), & x \le 0
411
- \end{cases}
407
+ .. math::
408
+ \mathrm{elu}(x) = \begin{cases}
409
+ x, & x > 0\\
410
+ \alpha \left(\exp(x) - 1\right), & x \le 0
411
+ \end{cases}
412
412
 
413
- Args:
414
- x : input array
415
- alpha : scalar or array of alpha values (default: 1.0)
413
+ Args:
414
+ x : input array
415
+ alpha : scalar or array of alpha values (default: 1.0)
416
416
 
417
- Returns:
418
- An array.
417
+ Returns:
418
+ An array.
419
419
 
420
- See also:
421
- :func:`selu`
422
- """
423
- return _keep_unit(jax.nn.elu, x)
420
+ See also:
421
+ :func:`selu`
422
+ """
423
+ return _keep_unit(jax.nn.elu, x)
424
424
 
425
425
 
426
426
  def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
427
- r"""Leaky rectified linear unit activation function.
427
+ r"""Leaky rectified linear unit activation function.
428
428
 
429
- Computes the element-wise function:
429
+ Computes the element-wise function:
430
430
 
431
- .. math::
432
- \mathrm{leaky\_relu}(x) = \begin{cases}
433
- x, & x \ge 0\\
434
- \alpha x, & x < 0
435
- \end{cases}
431
+ .. math::
432
+ \mathrm{leaky\_relu}(x) = \begin{cases}
433
+ x, & x \ge 0\\
434
+ \alpha x, & x < 0
435
+ \end{cases}
436
436
 
437
- where :math:`\alpha` = :code:`negative_slope`.
437
+ where :math:`\alpha` = :code:`negative_slope`.
438
438
 
439
- Args:
440
- x : input array
441
- negative_slope : array or scalar specifying the negative slope (default: 0.01)
439
+ Args:
440
+ x : input array
441
+ negative_slope : array or scalar specifying the negative slope (default: 0.01)
442
442
 
443
- Returns:
444
- An array.
443
+ Returns:
444
+ An array.
445
445
 
446
- See also:
447
- :func:`relu`
448
- """
449
- return _keep_unit(jax.nn.leaky_relu, x, negative_slope=negative_slope)
446
+ See also:
447
+ :func:`relu`
448
+ """
449
+ return _keep_unit(jax.nn.leaky_relu, x, negative_slope=negative_slope)
450
450
 
451
451
 
452
- def hard_tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
453
- r"""Hard :math:`\mathrm{tanh}` activation function.
452
+ def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
453
+ return jax.numpy.where(x > max_val, max_val, jax.numpy.where(x < min_val, min_val, x))
454
454
 
455
- Computes the element-wise function:
456
455
 
457
- .. math::
458
- \mathrm{hard\_tanh}(x) = \begin{cases}
459
- -1, & x < -1\\
460
- x, & -1 \le x \le 1\\
461
- 1, & 1 < x
462
- \end{cases}
456
+ def hard_tanh(
457
+ x: ArrayLike,
458
+ min_val: float = - 1.0,
459
+ max_val: float = 1.0
460
+ ) -> Union[jax.Array, u.Quantity]:
461
+ r"""Hard :math:`\mathrm{tanh}` activation function.
463
462
 
464
- Args:
465
- x : input array
463
+ Computes the element-wise function:
466
464
 
467
- Returns:
468
- An array.
469
- """
470
- return _keep_unit(jax.nn.hard_tanh, x)
465
+ .. math::
466
+ \mathrm{hard\_tanh}(x) = \begin{cases}
467
+ -1, & x < -1\\
468
+ x, & -1 \le x \le 1\\
469
+ 1, & 1 < x
470
+ \end{cases}
471
+
472
+ Args:
473
+ x : input array
474
+ min_val: float. minimum value of the linear region range. Default: -1
475
+ max_val: float. maximum value of the linear region range. Default: 1
476
+
477
+ Returns:
478
+ An array.
479
+ """
480
+ return _keep_unit(_hard_tanh, x, min_val=min_val, max_val=max_val)
471
481
 
472
482
 
473
483
  def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
474
- r"""Continuously-differentiable exponential linear unit activation.
484
+ r"""Continuously-differentiable exponential linear unit activation.
475
485
 
476
- Computes the element-wise function:
486
+ Computes the element-wise function:
477
487
 
478
- .. math::
479
- \mathrm{celu}(x) = \begin{cases}
480
- x, & x > 0\\
481
- \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
482
- \end{cases}
488
+ .. math::
489
+ \mathrm{celu}(x) = \begin{cases}
490
+ x, & x > 0\\
491
+ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
492
+ \end{cases}
483
493
 
484
- For more information, see
485
- `Continuously Differentiable Exponential Linear Units
486
- <https://arxiv.org/pdf/1704.07483.pdf>`_.
494
+ For more information, see
495
+ `Continuously Differentiable Exponential Linear Units
496
+ <https://arxiv.org/pdf/1704.07483.pdf>`_.
487
497
 
488
- Args:
489
- x : input array
490
- alpha : array or scalar (default: 1.0)
498
+ Args:
499
+ x : input array
500
+ alpha : array or scalar (default: 1.0)
491
501
 
492
- Returns:
493
- An array.
494
- """
495
- return _keep_unit(jax.nn.celu, x, alpha=alpha)
502
+ Returns:
503
+ An array.
504
+ """
505
+ return _keep_unit(jax.nn.celu, x, alpha=alpha)
496
506
 
497
507
 
498
508
  def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
499
- r"""Scaled exponential linear unit activation.
509
+ r"""Scaled exponential linear unit activation.
500
510
 
501
- Computes the element-wise function:
511
+ Computes the element-wise function:
502
512
 
503
- .. math::
504
- \mathrm{selu}(x) = \lambda \begin{cases}
505
- x, & x > 0\\
506
- \alpha e^x - \alpha, & x \le 0
507
- \end{cases}
513
+ .. math::
514
+ \mathrm{selu}(x) = \lambda \begin{cases}
515
+ x, & x > 0\\
516
+ \alpha e^x - \alpha, & x \le 0
517
+ \end{cases}
508
518
 
509
- where :math:`\lambda = 1.0507009873554804934193349852946` and
510
- :math:`\alpha = 1.6732632423543772848170429916717`.
519
+ where :math:`\lambda = 1.0507009873554804934193349852946` and
520
+ :math:`\alpha = 1.6732632423543772848170429916717`.
511
521
 
512
- For more information, see
513
- `Self-Normalizing Neural Networks
514
- <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
522
+ For more information, see
523
+ `Self-Normalizing Neural Networks
524
+ <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
515
525
 
516
- Args:
517
- x : input array
526
+ Args:
527
+ x : input array
518
528
 
519
- Returns:
520
- An array.
529
+ Returns:
530
+ An array.
521
531
 
522
- See also:
523
- :func:`elu`
524
- """
525
- return _keep_unit(jax.nn.selu, x)
532
+ See also:
533
+ :func:`elu`
534
+ """
535
+ return _keep_unit(jax.nn.selu, x)
526
536
 
527
537
 
528
538
  def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
529
- r"""Gaussian error linear unit activation function.
539
+ r"""Gaussian error linear unit activation function.
530
540
 
531
- If ``approximate=False``, computes the element-wise function:
541
+ If ``approximate=False``, computes the element-wise function:
532
542
 
533
- .. math::
534
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
535
- \frac{x}{\sqrt{2}} \right) \right)
543
+ .. math::
544
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
545
+ \frac{x}{\sqrt{2}} \right) \right)
536
546
 
537
- If ``approximate=True``, uses the approximate formulation of GELU:
547
+ If ``approximate=True``, uses the approximate formulation of GELU:
538
548
 
539
- .. math::
540
- \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
541
- \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
549
+ .. math::
550
+ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
551
+ \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
542
552
 
543
- For more information, see `Gaussian Error Linear Units (GELUs)
544
- <https://arxiv.org/abs/1606.08415>`_, section 2.
553
+ For more information, see `Gaussian Error Linear Units (GELUs)
554
+ <https://arxiv.org/abs/1606.08415>`_, section 2.
545
555
 
546
- Args:
547
- x : input array
548
- approximate: whether to use the approximate or exact formulation.
549
- """
550
- return _keep_unit(jax.nn.gelu, x, approximate=approximate)
556
+ Args:
557
+ x : input array
558
+ approximate: whether to use the approximate or exact formulation.
559
+ """
560
+ return _keep_unit(jax.nn.gelu, x, approximate=approximate)
551
561
 
552
562
 
553
563
  def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
554
- r"""Gated linear unit activation function.
564
+ r"""Gated linear unit activation function.
555
565
 
556
- Computes the function:
566
+ Computes the function:
557
567
 
558
- .. math::
559
- \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
560
- \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
561
- \right)
568
+ .. math::
569
+ \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
570
+ \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
571
+ \right)
562
572
 
563
- where the array is split into two along ``axis``. The size of the ``axis``
564
- dimension must be divisible by two.
573
+ where the array is split into two along ``axis``. The size of the ``axis``
574
+ dimension must be divisible by two.
565
575
 
566
- Args:
567
- x : input array
568
- axis: the axis along which the split should be computed (default: -1)
576
+ Args:
577
+ x : input array
578
+ axis: the axis along which the split should be computed (default: -1)
569
579
 
570
- Returns:
571
- An array.
580
+ Returns:
581
+ An array.
572
582
 
573
- See also:
574
- :func:`sigmoid`
575
- """
576
- return _keep_unit(jax.nn.glu, x, axis=axis)
583
+ See also:
584
+ :func:`sigmoid`
585
+ """
586
+ return _keep_unit(jax.nn.glu, x, axis=axis)
577
587
 
578
588
 
579
589
  def log_softmax(x: ArrayLike,
580
590
  axis: int | tuple[int, ...] | None = -1,
581
591
  where: ArrayLike | None = None,
582
592
  initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
583
- r"""Log-Softmax function.
593
+ r"""Log-Softmax function.
584
594
 
585
- Computes the logarithm of the :code:`softmax` function, which rescales
586
- elements to the range :math:`[-\infty, 0)`.
595
+ Computes the logarithm of the :code:`softmax` function, which rescales
596
+ elements to the range :math:`[-\infty, 0)`.
587
597
 
588
- .. math ::
589
- \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
590
- \right)
598
+ .. math ::
599
+ \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
600
+ \right)
591
601
 
592
- Args:
593
- x : input array
594
- axis: the axis or axes along which the :code:`log_softmax` should be
595
- computed. Either an integer or a tuple of integers.
596
- where: Elements to include in the :code:`log_softmax`.
597
- initial: The minimum value used to shift the input array. Must be present
598
- when :code:`where` is not None.
602
+ Args:
603
+ x : input array
604
+ axis: the axis or axes along which the :code:`log_softmax` should be
605
+ computed. Either an integer or a tuple of integers.
606
+ where: Elements to include in the :code:`log_softmax`.
607
+ initial: The minimum value used to shift the input array. Must be present
608
+ when :code:`where` is not None.
599
609
 
600
- Returns:
601
- An array.
610
+ Returns:
611
+ An array.
602
612
 
603
- See also:
604
- :func:`softmax`
605
- """
606
- if initial is not None:
607
- initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
608
- return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
613
+ See also:
614
+ :func:`softmax`
615
+ """
616
+ if initial is not None:
617
+ initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
618
+ return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
609
619
 
610
620
 
611
621
  def softmax(x: ArrayLike,
612
622
  axis: int | tuple[int, ...] | None = -1,
613
623
  where: ArrayLike | None = None,
614
624
  initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
615
- r"""Softmax function.
625
+ r"""Softmax function.
616
626
 
617
- Computes the function which rescales elements to the range :math:`[0, 1]`
618
- such that the elements along :code:`axis` sum to :math:`1`.
627
+ Computes the function which rescales elements to the range :math:`[0, 1]`
628
+ such that the elements along :code:`axis` sum to :math:`1`.
619
629
 
620
- .. math ::
621
- \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
630
+ .. math ::
631
+ \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
622
632
 
623
- Args:
624
- x : input array
625
- axis: the axis or axes along which the softmax should be computed. The
626
- softmax output summed across these dimensions should sum to :math:`1`.
627
- Either an integer or a tuple of integers.
628
- where: Elements to include in the :code:`softmax`.
629
- initial: The minimum value used to shift the input array. Must be present
630
- when :code:`where` is not None.
633
+ Args:
634
+ x : input array
635
+ axis: the axis or axes along which the softmax should be computed. The
636
+ softmax output summed across these dimensions should sum to :math:`1`.
637
+ Either an integer or a tuple of integers.
638
+ where: Elements to include in the :code:`softmax`.
639
+ initial: The minimum value used to shift the input array. Must be present
640
+ when :code:`where` is not None.
631
641
 
632
- Returns:
633
- An array.
642
+ Returns:
643
+ An array.
634
644
 
635
- See also:
636
- :func:`log_softmax`
637
- """
638
- if initial is not None:
639
- initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
640
- return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
645
+ See also:
646
+ :func:`log_softmax`
647
+ """
648
+ if initial is not None:
649
+ initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
650
+ return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
641
651
 
642
652
 
643
653
  def standardize(x: ArrayLike,
@@ -645,169 +655,169 @@ def standardize(x: ArrayLike,
645
655
  variance: ArrayLike | None = None,
646
656
  epsilon: ArrayLike = 1e-5,
647
657
  where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
648
- r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
649
- return _keep_unit(jax.nn.standardize, x, axis=axis, where=where, variance=variance, epsilon=epsilon)
658
+ r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
659
+ return _keep_unit(jax.nn.standardize, x, axis=axis, where=where, variance=variance, epsilon=epsilon)
650
660
 
651
661
 
652
662
  def one_hot(x: Any,
653
663
  num_classes: int, *,
654
664
  dtype: Any = jax.numpy.float_,
655
665
  axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
656
- """One-hot encodes the given indices.
666
+ """One-hot encodes the given indices.
657
667
 
658
- Each index in the input ``x`` is encoded as a vector of zeros of length
659
- ``num_classes`` with the element at ``index`` set to one::
668
+ Each index in the input ``x`` is encoded as a vector of zeros of length
669
+ ``num_classes`` with the element at ``index`` set to one::
660
670
 
661
- >>> one_hot(jnp.array([0, 1, 2]), 3)
662
- Array([[1., 0., 0.],
663
- [0., 1., 0.],
664
- [0., 0., 1.]], dtype=float32)
671
+ >>> one_hot(jnp.array([0, 1, 2]), 3)
672
+ Array([[1., 0., 0.],
673
+ [0., 1., 0.],
674
+ [0., 0., 1.]], dtype=float32)
665
675
 
666
- Indices outside the range [0, num_classes) will be encoded as zeros::
676
+ Indices outside the range [0, num_classes) will be encoded as zeros::
667
677
 
668
- >>> one_hot(jnp.array([-1, 3]), 3)
669
- Array([[0., 0., 0.],
670
- [0., 0., 0.]], dtype=float32)
678
+ >>> one_hot(jnp.array([-1, 3]), 3)
679
+ Array([[0., 0., 0.],
680
+ [0., 0., 0.]], dtype=float32)
671
681
 
672
- Args:
673
- x: A tensor of indices.
674
- num_classes: Number of classes in the one-hot dimension.
675
- dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
676
- axis: the axis or axes along which the function should be
677
- computed.
678
- """
679
- return _keep_unit(jax.nn.one_hot, x, axis=axis, num_classes=num_classes, dtype=dtype)
682
+ Args:
683
+ x: A tensor of indices.
684
+ num_classes: Number of classes in the one-hot dimension.
685
+ dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
686
+ axis: the axis or axes along which the function should be
687
+ computed.
688
+ """
689
+ return _keep_unit(jax.nn.one_hot, x, axis=axis, num_classes=num_classes, dtype=dtype)
680
690
 
681
691
 
682
692
  def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
683
- r"""Rectified Linear Unit 6 activation function.
693
+ r"""Rectified Linear Unit 6 activation function.
684
694
 
685
- Computes the element-wise function
695
+ Computes the element-wise function
686
696
 
687
- .. math::
688
- \mathrm{relu6}(x) = \min(\max(x, 0), 6)
697
+ .. math::
698
+ \mathrm{relu6}(x) = \min(\max(x, 0), 6)
689
699
 
690
- except under differentiation, we take:
700
+ except under differentiation, we take:
691
701
 
692
- .. math::
693
- \nabla \mathrm{relu}(0) = 0
702
+ .. math::
703
+ \nabla \mathrm{relu}(0) = 0
694
704
 
695
- and
705
+ and
696
706
 
697
- .. math::
698
- \nabla \mathrm{relu}(6) = 0
707
+ .. math::
708
+ \nabla \mathrm{relu}(6) = 0
699
709
 
700
- Args:
701
- x : input array
710
+ Args:
711
+ x : input array
702
712
 
703
- Returns:
704
- An array.
713
+ Returns:
714
+ An array.
705
715
 
706
- See also:
707
- :func:`relu`
708
- """
709
- return _keep_unit(jax.nn.relu6, x)
716
+ See also:
717
+ :func:`relu`
718
+ """
719
+ return _keep_unit(jax.nn.relu6, x)
710
720
 
711
721
 
712
722
  def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
713
- r"""Hard Sigmoid activation function.
723
+ r"""Hard Sigmoid activation function.
714
724
 
715
- Computes the element-wise function
725
+ Computes the element-wise function
716
726
 
717
- .. math::
718
- \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
727
+ .. math::
728
+ \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}
719
729
 
720
- Args:
721
- x : input array
730
+ Args:
731
+ x : input array
722
732
 
723
- Returns:
724
- An array.
733
+ Returns:
734
+ An array.
725
735
 
726
- See also:
727
- :func:`relu6`
728
- """
729
- return _keep_unit(jax.nn.hard_sigmoid, x)
736
+ See also:
737
+ :func:`relu6`
738
+ """
739
+ return _keep_unit(jax.nn.hard_sigmoid, x)
730
740
 
731
741
 
732
742
  def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
733
- r"""Hard SiLU (swish) activation function
743
+ r"""Hard SiLU (swish) activation function
734
744
 
735
- Computes the element-wise function
745
+ Computes the element-wise function
736
746
 
737
- .. math::
738
- \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
747
+ .. math::
748
+ \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)
739
749
 
740
- Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
741
- function.
750
+ Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
751
+ function.
742
752
 
743
- Args:
744
- x : input array
753
+ Args:
754
+ x : input array
745
755
 
746
- Returns:
747
- An array.
756
+ Returns:
757
+ An array.
748
758
 
749
- See also:
750
- :func:`hard_sigmoid`
751
- """
752
- return _keep_unit(jax.nn.hard_silu, x)
759
+ See also:
760
+ :func:`hard_sigmoid`
761
+ """
762
+ return _keep_unit(jax.nn.hard_silu, x)
753
763
 
754
- return jax.nn.hard_silu(x)
764
+ return jax.nn.hard_silu(x)
755
765
 
756
766
 
757
767
  hard_swish = hard_silu
758
768
 
759
769
 
760
770
  def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
761
- r"""Sparse plus function.
771
+ r"""Sparse plus function.
762
772
 
763
- Computes the function:
773
+ Computes the function:
764
774
 
765
- .. math::
775
+ .. math::
766
776
 
767
- \mathrm{sparse\_plus}(x) = \begin{cases}
768
- 0, & x \leq -1\\
769
- \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
770
- x, & 1 \leq x
771
- \end{cases}
777
+ \mathrm{sparse\_plus}(x) = \begin{cases}
778
+ 0, & x \leq -1\\
779
+ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
780
+ x, & 1 \leq x
781
+ \end{cases}
772
782
 
773
- This is the twin function of the softplus activation ensuring a zero output
774
- for inputs less than -1 and a linear output for inputs greater than 1,
775
- while remaining smooth, convex, monotonic by an adequate definition between
776
- -1 and 1.
783
+ This is the twin function of the softplus activation ensuring a zero output
784
+ for inputs less than -1 and a linear output for inputs greater than 1,
785
+ while remaining smooth, convex, monotonic by an adequate definition between
786
+ -1 and 1.
777
787
 
778
- Args:
779
- x: input (float)
780
- """
781
- return _keep_unit(jax.nn.sparse_plus, x)
788
+ Args:
789
+ x: input (float)
790
+ """
791
+ return _keep_unit(jax.nn.sparse_plus, x)
782
792
 
783
793
 
784
794
  def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
785
- r"""Sparse sigmoid activation function.
795
+ r"""Sparse sigmoid activation function.
786
796
 
787
- Computes the function:
797
+ Computes the function:
788
798
 
789
- .. math::
799
+ .. math::
790
800
 
791
- \mathrm{sparse\_sigmoid}(x) = \begin{cases}
792
- 0, & x \leq -1\\
793
- \frac{1}{2}(x+1), & -1 < x < 1 \\
794
- 1, & 1 \leq x
795
- \end{cases}
801
+ \mathrm{sparse\_sigmoid}(x) = \begin{cases}
802
+ 0, & x \leq -1\\
803
+ \frac{1}{2}(x+1), & -1 < x < 1 \\
804
+ 1, & 1 \leq x
805
+ \end{cases}
796
806
 
797
- This is the twin function of the ``sigmoid`` activation ensuring a zero output
798
- for inputs less than -1, a 1 output for inputs greater than 1, and a linear
799
- output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
807
+ This is the twin function of the ``sigmoid`` activation ensuring a zero output
808
+ for inputs less than -1, a 1 output for inputs greater than 1, and a linear
809
+ output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
800
810
 
801
- For more information, see `Learning with Fenchel-Young Losses (section 6.2)
802
- <https://arxiv.org/abs/1901.02324>`_.
811
+ For more information, see `Learning with Fenchel-Young Losses (section 6.2)
812
+ <https://arxiv.org/abs/1901.02324>`_.
803
813
 
804
- Args:
805
- x : input array
814
+ Args:
815
+ x : input array
806
816
 
807
- Returns:
808
- An array.
817
+ Returns:
818
+ An array.
809
819
 
810
- See also:
811
- :func:`sigmoid`
812
- """
813
- return _keep_unit(jax.nn.sparse_sigmoid, x)
820
+ See also:
821
+ :func:`sigmoid`
822
+ """
823
+ return _keep_unit(jax.nn.sparse_sigmoid, x)