brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -29,303 +29,303 @@ import brainstate as bst
|
|
29
29
|
|
30
30
|
|
31
31
|
class NNFunctionsTest(jtu.JaxTestCase):
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
check_grads(bst.functional.squareplus, (1e-8,), order=4,
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
check_grads(bst.functional.squareplus, (0.,), order=1,
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
check_grads(bst.functional.squareplus, (float('nan'),), order=1,
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
|
91
|
-
|
92
|
-
|
93
|
-
check_grads(bst.functional.mish, (1e-8,), order=4,
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
check_grads(bst.functional.mish, (0.,), order=1,
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
check_grads(bst.functional.mish, (-float('inf'),), order=1,
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
check_grads(bst.functional.mish, (float('nan'),), order=1,
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
val = bst.functional.squareplus(1e3)
|
141
|
-
self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
32
|
+
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
33
|
+
def testSoftplusGrad(self):
|
34
|
+
check_grads(bst.functional.softplus, (1e-8,), order=4, )
|
35
|
+
|
36
|
+
def testSoftplusGradZero(self):
|
37
|
+
check_grads(bst.functional.softplus, (0.,), order=1)
|
38
|
+
|
39
|
+
def testSoftplusGradInf(self):
|
40
|
+
self.assertAllClose(1., jax.grad(bst.functional.softplus)(float('inf')))
|
41
|
+
|
42
|
+
def testSoftplusGradNegInf(self):
|
43
|
+
check_grads(bst.functional.softplus, (-float('inf'),), order=1)
|
44
|
+
|
45
|
+
def testSoftplusGradNan(self):
|
46
|
+
check_grads(bst.functional.softplus, (float('nan'),), order=1)
|
47
|
+
|
48
|
+
@parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
|
49
|
+
def testSoftplusZero(self, dtype):
|
50
|
+
self.assertEqual(jnp.log(dtype(2)), bst.functional.softplus(dtype(0)))
|
51
|
+
|
52
|
+
def testSparseplusGradZero(self):
|
53
|
+
check_grads(bst.functional.sparse_plus, (-2.,), order=1)
|
54
|
+
|
55
|
+
def testSparseplusGrad(self):
|
56
|
+
check_grads(bst.functional.sparse_plus, (0.,), order=1)
|
57
|
+
|
58
|
+
def testSparseplusAndSparseSigmoid(self):
|
59
|
+
self.assertAllClose(
|
60
|
+
jax.grad(bst.functional.sparse_plus)(0.),
|
61
|
+
bst.functional.sparse_sigmoid(0.),
|
62
|
+
check_dtypes=False)
|
63
|
+
self.assertAllClose(
|
64
|
+
jax.grad(bst.functional.sparse_plus)(2.),
|
65
|
+
bst.functional.sparse_sigmoid(2.),
|
66
|
+
check_dtypes=False)
|
67
|
+
self.assertAllClose(
|
68
|
+
jax.grad(bst.functional.sparse_plus)(-2.),
|
69
|
+
bst.functional.sparse_sigmoid(-2.),
|
70
|
+
check_dtypes=False)
|
71
|
+
|
72
|
+
# def testSquareplusGrad(self):
|
73
|
+
# check_grads(bst.functional.squareplus, (1e-8,), order=4,
|
74
|
+
# )
|
75
|
+
|
76
|
+
# def testSquareplusGradZero(self):
|
77
|
+
# check_grads(bst.functional.squareplus, (0.,), order=1,
|
78
|
+
# )
|
79
|
+
|
80
|
+
# def testSquareplusGradNegInf(self):
|
81
|
+
# check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
|
82
|
+
# )
|
83
|
+
|
84
|
+
# def testSquareplusGradNan(self):
|
85
|
+
# check_grads(bst.functional.squareplus, (float('nan'),), order=1,
|
86
|
+
# )
|
87
|
+
|
88
|
+
# @parameterized.parameters([float] + jtu.dtypes.floating)
|
89
|
+
# def testSquareplusZero(self, dtype):
|
90
|
+
# self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
|
91
|
+
#
|
92
|
+
# def testMishGrad(self):
|
93
|
+
# check_grads(bst.functional.mish, (1e-8,), order=4,
|
94
|
+
# )
|
95
|
+
#
|
96
|
+
# def testMishGradZero(self):
|
97
|
+
# check_grads(bst.functional.mish, (0.,), order=1,
|
98
|
+
# )
|
99
|
+
#
|
100
|
+
# def testMishGradNegInf(self):
|
101
|
+
# check_grads(bst.functional.mish, (-float('inf'),), order=1,
|
102
|
+
# )
|
103
|
+
#
|
104
|
+
# def testMishGradNan(self):
|
105
|
+
# check_grads(bst.functional.mish, (float('nan'),), order=1,
|
106
|
+
# )
|
107
|
+
|
108
|
+
@parameterized.parameters([float] + jtu.dtypes.floating)
|
109
|
+
def testMishZero(self, dtype):
|
110
|
+
self.assertEqual(dtype(0), bst.functional.mish(dtype(0)))
|
111
|
+
|
112
|
+
def testReluGrad(self):
|
113
|
+
rtol = None
|
114
|
+
check_grads(bst.functional.relu, (1.,), order=3, rtol=rtol)
|
115
|
+
check_grads(bst.functional.relu, (-1.,), order=3, rtol=rtol)
|
116
|
+
jaxpr = jax.make_jaxpr(jax.grad(bst.functional.relu))(0.)
|
117
|
+
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
|
118
|
+
|
119
|
+
def testRelu6Grad(self):
|
120
|
+
rtol = None
|
121
|
+
check_grads(bst.functional.relu6, (1.,), order=3, rtol=rtol)
|
122
|
+
check_grads(bst.functional.relu6, (-1.,), order=3, rtol=rtol)
|
123
|
+
self.assertAllClose(jax.grad(bst.functional.relu6)(0.), 0., check_dtypes=False)
|
124
|
+
self.assertAllClose(jax.grad(bst.functional.relu6)(6.), 0., check_dtypes=False)
|
125
|
+
|
126
|
+
def testSoftplusValue(self):
|
127
|
+
val = bst.functional.softplus(89.)
|
128
|
+
self.assertAllClose(val, 89., check_dtypes=False)
|
129
|
+
|
130
|
+
def testSparseplusValue(self):
|
131
|
+
val = bst.functional.sparse_plus(89.)
|
132
|
+
self.assertAllClose(val, 89., check_dtypes=False)
|
133
|
+
|
134
|
+
def testSparsesigmoidValue(self):
|
135
|
+
self.assertAllClose(bst.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
|
136
|
+
self.assertAllClose(bst.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
|
137
|
+
self.assertAllClose(bst.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
|
138
|
+
|
139
|
+
# def testSquareplusValue(self):
|
140
|
+
# val = bst.functional.squareplus(1e3)
|
141
|
+
# self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
142
|
+
|
143
|
+
def testMishValue(self):
|
144
|
+
val = bst.functional.mish(1e3)
|
145
|
+
self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
146
|
+
|
147
|
+
def testEluValue(self):
|
148
|
+
val = bst.functional.elu(1e4)
|
149
|
+
self.assertAllClose(val, 1e4, check_dtypes=False)
|
150
|
+
|
151
|
+
def testGluValue(self):
|
152
|
+
val = bst.functional.glu(jnp.array([1.0, 0.0]), axis=0)
|
153
|
+
self.assertAllClose(val, jnp.array([0.5]))
|
154
|
+
|
155
|
+
@parameterized.parameters(False, True)
|
156
|
+
def testGeluIntType(self, approximate):
|
157
|
+
val_float = bst.functional.gelu(jnp.array(-1.0), approximate=approximate)
|
158
|
+
val_int = bst.functional.gelu(jnp.array(-1), approximate=approximate)
|
159
|
+
self.assertAllClose(val_float, val_int)
|
160
|
+
|
161
|
+
@parameterized.parameters(False, True)
|
162
|
+
def testGelu(self, approximate):
|
163
|
+
def gelu_reference(x):
|
164
|
+
return x * scipy.stats.norm.cdf(x)
|
165
|
+
|
166
|
+
rng = jtu.rand_default(self.rng())
|
167
|
+
args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
|
168
|
+
self._CheckAgainstNumpy(
|
169
|
+
gelu_reference, partial(bst.functional.gelu, approximate=approximate), args_maker,
|
170
|
+
check_dtypes=False, tol=1e-3 if approximate else None)
|
171
|
+
|
172
|
+
@parameterized.parameters(*itertools.product(
|
173
|
+
(jnp.float32, jnp.bfloat16, jnp.float16),
|
174
|
+
(partial(bst.functional.gelu, approximate=False),
|
175
|
+
partial(bst.functional.gelu, approximate=True),
|
176
|
+
bst.functional.relu,
|
177
|
+
bst.functional.softplus,
|
178
|
+
bst.functional.sparse_plus,
|
179
|
+
bst.functional.sigmoid,
|
180
|
+
# bst.functional.squareplus,
|
181
|
+
bst.functional.mish)))
|
182
|
+
def testDtypeMatchesInput(self, dtype, fn):
|
183
|
+
x = jnp.zeros((), dtype=dtype)
|
184
|
+
out = fn(x)
|
185
|
+
self.assertEqual(out.dtype, dtype)
|
186
|
+
|
187
|
+
def testEluMemory(self):
|
188
|
+
# see https://github.com/google/jax/pull/1640
|
189
|
+
with jax.enable_checks(False): # With checks we materialize the array
|
190
|
+
jax.make_jaxpr(lambda: bst.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
|
191
|
+
|
192
|
+
def testHardTanhMemory(self):
|
193
|
+
# see https://github.com/google/jax/pull/1640
|
194
|
+
with jax.enable_checks(False): # With checks we materialize the array
|
195
|
+
jax.make_jaxpr(lambda: bst.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
|
196
|
+
|
197
|
+
@parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
|
198
|
+
def testSoftmaxEmptyArray(self, fn):
|
199
|
+
x = jnp.array([], dtype=float)
|
200
|
+
self.assertArraysEqual(fn(x), x)
|
201
|
+
|
202
|
+
@parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
|
203
|
+
def testSoftmaxEmptyMask(self, fn):
|
204
|
+
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
205
|
+
m = jnp.zeros_like(x, dtype=bool)
|
206
|
+
expected = jnp.full_like(x, 0.0 if fn is bst.functional.softmax else -jnp.inf)
|
207
|
+
self.assertArraysEqual(fn(x, where=m), expected)
|
208
|
+
|
209
|
+
@parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
|
210
|
+
def testSoftmaxWhereMask(self, fn):
|
211
|
+
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
212
|
+
m = jnp.array([True, False, True, True])
|
213
|
+
|
214
|
+
out = fn(x, where=m)
|
215
|
+
self.assertAllClose(out[m], fn(x[m]))
|
216
|
+
|
217
|
+
probs = out if fn is bst.functional.softmax else jnp.exp(out)
|
218
|
+
self.assertAllClose(probs.sum(), 1.0)
|
219
|
+
|
220
|
+
@parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
|
221
|
+
def testSoftmaxWhereGrad(self, fn):
|
222
|
+
# regression test for https://github.com/google/jax/issues/19490
|
223
|
+
x = jnp.array([36., 10000.])
|
224
|
+
mask = x < 1000
|
225
|
+
|
226
|
+
f = lambda x, mask: fn(x, where=mask)[0]
|
227
|
+
|
228
|
+
self.assertAllClose(jax.grad(f)(x, mask), jnp.zeros_like(x))
|
229
|
+
|
230
|
+
def testSoftmaxGrad(self):
|
231
|
+
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
232
|
+
jtu.check_grads(bst.functional.softmax, (x,), order=2, atol=5e-3)
|
233
|
+
|
234
|
+
def testStandardizeWhereMask(self):
|
235
|
+
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
236
|
+
m = jnp.array([True, False, True, True])
|
237
|
+
x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
|
238
|
+
|
239
|
+
out_masked = jnp.take(bst.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
|
240
|
+
out_filtered = bst.functional.standardize(x_filtered)
|
241
|
+
|
242
|
+
self.assertAllClose(out_masked, out_filtered)
|
243
|
+
|
244
|
+
def testOneHot(self):
|
245
|
+
actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3)
|
246
|
+
expected = jnp.array([[1., 0., 0.],
|
247
|
+
[0., 1., 0.],
|
248
|
+
[0., 0., 1.]])
|
249
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
250
|
+
|
251
|
+
actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3)
|
252
|
+
expected = jnp.array([[0., 1., 0.],
|
253
|
+
[0., 0., 1.],
|
254
|
+
[1., 0., 0.]])
|
255
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
256
|
+
|
257
|
+
def testOneHotOutOfBound(self):
|
258
|
+
actual = bst.functional.one_hot(jnp.array([-1, 3]), 3)
|
259
|
+
expected = jnp.array([[0., 0., 0.],
|
260
|
+
[0., 0., 0.]])
|
261
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
262
|
+
|
263
|
+
def testOneHotNonArrayInput(self):
|
264
|
+
actual = bst.functional.one_hot([0, 1, 2], 3)
|
265
|
+
expected = jnp.array([[1., 0., 0.],
|
266
|
+
[0., 1., 0.],
|
267
|
+
[0., 0., 1.]])
|
268
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
269
|
+
|
270
|
+
def testOneHotCustomDtype(self):
|
271
|
+
actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
|
272
|
+
expected = jnp.array([[True, False, False],
|
273
|
+
[False, True, False],
|
274
|
+
[False, False, True]])
|
275
|
+
self.assertAllClose(actual, expected)
|
276
|
+
|
277
|
+
def testOneHotAxis(self):
|
278
|
+
expected = jnp.array([[0., 1., 0.],
|
279
|
+
[0., 0., 1.],
|
280
|
+
[1., 0., 0.]]).T
|
281
|
+
|
282
|
+
actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
|
283
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
284
|
+
|
285
|
+
actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
|
286
|
+
self.assertAllClose(actual, expected, check_dtypes=False)
|
287
|
+
|
288
|
+
def testTanhExists(self):
|
289
|
+
print(bst.functional.tanh) # doesn't crash
|
290
|
+
|
291
|
+
def testCustomJVPLeak(self):
|
292
|
+
# https://github.com/google/jax/issues/8171
|
293
|
+
@jax.jit
|
294
|
+
def fwd():
|
295
|
+
a = jnp.array(1.)
|
296
|
+
|
297
|
+
def f(hx, _):
|
298
|
+
hx = bst.functional.sigmoid(hx + a)
|
299
|
+
return hx, None
|
300
|
+
|
301
|
+
hx = jnp.array(0.)
|
302
|
+
jax.lax.scan(f, hx, None, length=2)
|
303
|
+
|
304
|
+
with jax.checking_leaks():
|
305
|
+
fwd() # doesn't crash
|
306
|
+
|
307
|
+
def testCustomJVPLeak2(self):
|
308
|
+
# https://github.com/google/jax/issues/8171
|
309
|
+
# The above test uses jax.bst.functional.sigmoid, as in the original #8171, but that
|
310
|
+
# function no longer actually has a custom_jvp! So we inline the old def.
|
311
|
+
|
312
|
+
@jax.custom_jvp
|
313
|
+
def sigmoid(x):
|
314
|
+
one = jnp.float32(1)
|
315
|
+
return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
|
316
|
+
|
317
|
+
sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))
|
318
|
+
|
319
|
+
@jax.jit
|
320
|
+
def fwd():
|
321
|
+
a = jnp.array(1., 'float32')
|
322
|
+
|
323
|
+
def f(hx, _):
|
324
|
+
hx = sigmoid(hx + a)
|
325
|
+
return hx, None
|
326
|
+
|
327
|
+
hx = jnp.array(0., 'float32')
|
328
|
+
jax.lax.scan(f, hx, None, length=2)
|
329
|
+
|
330
|
+
with jax.checking_leaks():
|
331
|
+
fwd() # doesn't crash
|