brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -20,142 +20,162 @@ from functools import partial
|
|
20
20
|
|
21
21
|
import jax
|
22
22
|
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
23
24
|
import scipy.stats
|
24
|
-
from absl.testing import parameterized
|
25
|
-
from jax._src import test_util as jtu
|
25
|
+
from absl.testing import absltest, parameterized
|
26
26
|
from jax.test_util import check_grads
|
27
27
|
|
28
28
|
import brainstate
|
29
29
|
|
30
30
|
|
31
|
-
class NNFunctionsTest(
|
32
|
-
|
31
|
+
class NNFunctionsTest(parameterized.TestCase):
|
32
|
+
def setUp(self):
|
33
|
+
super().setUp()
|
34
|
+
self.rng_key = jax.random.PRNGKey(0)
|
35
|
+
|
36
|
+
def assertAllClose(self, a, b, check_dtypes=True, atol=None, rtol=None):
|
37
|
+
"""Helper method for backwards compatibility with JAX test utilities."""
|
38
|
+
a = np.asarray(a)
|
39
|
+
b = np.asarray(b)
|
40
|
+
kw = {}
|
41
|
+
if atol is not None:
|
42
|
+
kw['atol'] = atol
|
43
|
+
if rtol is not None:
|
44
|
+
kw['rtol'] = rtol
|
45
|
+
np.testing.assert_allclose(a, b, **kw)
|
46
|
+
if check_dtypes:
|
47
|
+
self.assertEqual(a.dtype, b.dtype)
|
48
|
+
|
49
|
+
def assertArraysEqual(self, a, b):
|
50
|
+
"""Helper method for backwards compatibility with JAX test utilities."""
|
51
|
+
np.testing.assert_array_equal(np.asarray(a), np.asarray(b))
|
52
|
+
|
33
53
|
def testSoftplusGrad(self):
|
34
|
-
check_grads(brainstate.
|
54
|
+
check_grads(brainstate.nn.softplus, (1e-8,), order=4, )
|
35
55
|
|
36
56
|
def testSoftplusGradZero(self):
|
37
|
-
check_grads(brainstate.
|
57
|
+
check_grads(brainstate.nn.softplus, (0.,), order=1)
|
38
58
|
|
39
59
|
def testSoftplusGradInf(self):
|
40
|
-
self.assertAllClose(1., jax.grad(brainstate.
|
60
|
+
self.assertAllClose(1., jax.grad(brainstate.nn.softplus)(float('inf')), check_dtypes=False)
|
41
61
|
|
42
62
|
def testSoftplusGradNegInf(self):
|
43
|
-
check_grads(brainstate.
|
63
|
+
check_grads(brainstate.nn.softplus, (-float('inf'),), order=1)
|
44
64
|
|
45
65
|
def testSoftplusGradNan(self):
|
46
|
-
check_grads(brainstate.
|
66
|
+
check_grads(brainstate.nn.softplus, (float('nan'),), order=1)
|
47
67
|
|
48
|
-
@parameterized.parameters([int, float
|
68
|
+
@parameterized.parameters([int, float, jnp.float32, jnp.float64, jnp.int32, jnp.int64])
|
49
69
|
def testSoftplusZero(self, dtype):
|
50
|
-
self.assertEqual(jnp.log(dtype(2)), brainstate.
|
70
|
+
self.assertEqual(jnp.log(dtype(2)), brainstate.nn.softplus(dtype(0)))
|
51
71
|
|
52
72
|
def testSparseplusGradZero(self):
|
53
|
-
check_grads(brainstate.
|
73
|
+
check_grads(brainstate.nn.sparse_plus, (-2.,), order=1)
|
54
74
|
|
55
75
|
def testSparseplusGrad(self):
|
56
|
-
check_grads(brainstate.
|
76
|
+
check_grads(brainstate.nn.sparse_plus, (0.,), order=1)
|
57
77
|
|
58
78
|
def testSparseplusAndSparseSigmoid(self):
|
59
79
|
self.assertAllClose(
|
60
|
-
jax.grad(brainstate.
|
61
|
-
brainstate.
|
80
|
+
jax.grad(brainstate.nn.sparse_plus)(0.),
|
81
|
+
brainstate.nn.sparse_sigmoid(0.),
|
62
82
|
check_dtypes=False)
|
63
83
|
self.assertAllClose(
|
64
|
-
jax.grad(brainstate.
|
65
|
-
brainstate.
|
84
|
+
jax.grad(brainstate.nn.sparse_plus)(2.),
|
85
|
+
brainstate.nn.sparse_sigmoid(2.),
|
66
86
|
check_dtypes=False)
|
67
87
|
self.assertAllClose(
|
68
|
-
jax.grad(brainstate.
|
69
|
-
brainstate.
|
88
|
+
jax.grad(brainstate.nn.sparse_plus)(-2.),
|
89
|
+
brainstate.nn.sparse_sigmoid(-2.),
|
70
90
|
check_dtypes=False)
|
71
91
|
|
72
92
|
# def testSquareplusGrad(self):
|
73
|
-
# check_grads(brainstate.
|
93
|
+
# check_grads(brainstate.nn.squareplus, (1e-8,), order=4,
|
74
94
|
# )
|
75
95
|
|
76
96
|
# def testSquareplusGradZero(self):
|
77
|
-
# check_grads(brainstate.
|
97
|
+
# check_grads(brainstate.nn.squareplus, (0.,), order=1,
|
78
98
|
# )
|
79
99
|
|
80
100
|
# def testSquareplusGradNegInf(self):
|
81
|
-
# check_grads(brainstate.
|
101
|
+
# check_grads(brainstate.nn.squareplus, (-float('inf'),), order=1,
|
82
102
|
# )
|
83
103
|
|
84
104
|
# def testSquareplusGradNan(self):
|
85
|
-
# check_grads(brainstate.
|
105
|
+
# check_grads(brainstate.nn.squareplus, (float('nan'),), order=1,
|
86
106
|
# )
|
87
107
|
|
88
|
-
# @parameterized.parameters([float
|
108
|
+
# @parameterized.parameters([float, jnp.float32, jnp.float64])
|
89
109
|
# def testSquareplusZero(self, dtype):
|
90
|
-
# self.assertEqual(dtype(1), brainstate.
|
110
|
+
# self.assertEqual(dtype(1), brainstate.nn.squareplus(dtype(0), dtype(4)))
|
91
111
|
#
|
92
112
|
# def testMishGrad(self):
|
93
|
-
# check_grads(brainstate.
|
113
|
+
# check_grads(brainstate.nn.mish, (1e-8,), order=4,
|
94
114
|
# )
|
95
115
|
#
|
96
116
|
# def testMishGradZero(self):
|
97
|
-
# check_grads(brainstate.
|
117
|
+
# check_grads(brainstate.nn.mish, (0.,), order=1,
|
98
118
|
# )
|
99
119
|
#
|
100
120
|
# def testMishGradNegInf(self):
|
101
|
-
# check_grads(brainstate.
|
121
|
+
# check_grads(brainstate.nn.mish, (-float('inf'),), order=1,
|
102
122
|
# )
|
103
123
|
#
|
104
124
|
# def testMishGradNan(self):
|
105
|
-
# check_grads(brainstate.
|
125
|
+
# check_grads(brainstate.nn.mish, (float('nan'),), order=1,
|
106
126
|
# )
|
107
127
|
|
108
|
-
@parameterized.parameters([float
|
128
|
+
@parameterized.parameters([float, jnp.float32, jnp.float64])
|
109
129
|
def testMishZero(self, dtype):
|
110
|
-
self.assertEqual(dtype(0), brainstate.
|
130
|
+
self.assertEqual(dtype(0), brainstate.nn.mish(dtype(0)))
|
111
131
|
|
112
132
|
def testReluGrad(self):
|
113
133
|
rtol = None
|
114
|
-
check_grads(brainstate.
|
115
|
-
check_grads(brainstate.
|
116
|
-
jaxpr = jax.make_jaxpr(jax.grad(brainstate.
|
134
|
+
check_grads(brainstate.nn.relu, (1.,), order=3, rtol=rtol)
|
135
|
+
check_grads(brainstate.nn.relu, (-1.,), order=3, rtol=rtol)
|
136
|
+
jaxpr = jax.make_jaxpr(jax.grad(brainstate.nn.relu))(0.)
|
117
137
|
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
|
118
138
|
|
119
139
|
def testRelu6Grad(self):
|
120
140
|
rtol = None
|
121
|
-
check_grads(brainstate.
|
122
|
-
check_grads(brainstate.
|
123
|
-
self.assertAllClose(jax.grad(brainstate.
|
124
|
-
self.assertAllClose(jax.grad(brainstate.
|
141
|
+
check_grads(brainstate.nn.relu6, (1.,), order=3, rtol=rtol)
|
142
|
+
check_grads(brainstate.nn.relu6, (-1.,), order=3, rtol=rtol)
|
143
|
+
self.assertAllClose(jax.grad(brainstate.nn.relu6)(0.), 0., check_dtypes=False)
|
144
|
+
self.assertAllClose(jax.grad(brainstate.nn.relu6)(6.), 0., check_dtypes=False)
|
125
145
|
|
126
146
|
def testSoftplusValue(self):
|
127
|
-
val = brainstate.
|
147
|
+
val = brainstate.nn.softplus(89.)
|
128
148
|
self.assertAllClose(val, 89., check_dtypes=False)
|
129
149
|
|
130
150
|
def testSparseplusValue(self):
|
131
|
-
val = brainstate.
|
151
|
+
val = brainstate.nn.sparse_plus(89.)
|
132
152
|
self.assertAllClose(val, 89., check_dtypes=False)
|
133
153
|
|
134
154
|
def testSparsesigmoidValue(self):
|
135
|
-
self.assertAllClose(brainstate.
|
136
|
-
self.assertAllClose(brainstate.
|
137
|
-
self.assertAllClose(brainstate.
|
155
|
+
self.assertAllClose(brainstate.nn.sparse_sigmoid(-2.), 0., check_dtypes=False)
|
156
|
+
self.assertAllClose(brainstate.nn.sparse_sigmoid(2.), 1., check_dtypes=False)
|
157
|
+
self.assertAllClose(brainstate.nn.sparse_sigmoid(0.), .5, check_dtypes=False)
|
138
158
|
|
139
159
|
# def testSquareplusValue(self):
|
140
|
-
# val = brainstate.
|
160
|
+
# val = brainstate.nn.squareplus(1e3)
|
141
161
|
# self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
142
162
|
|
143
163
|
def testMishValue(self):
|
144
|
-
val = brainstate.
|
164
|
+
val = brainstate.nn.mish(1e3)
|
145
165
|
self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
146
166
|
|
147
167
|
def testEluValue(self):
|
148
|
-
val = brainstate.
|
168
|
+
val = brainstate.nn.elu(1e4)
|
149
169
|
self.assertAllClose(val, 1e4, check_dtypes=False)
|
150
170
|
|
151
171
|
def testGluValue(self):
|
152
|
-
val = brainstate.
|
172
|
+
val = brainstate.nn.glu(jnp.array([1.0, 0.0]), axis=0)
|
153
173
|
self.assertAllClose(val, jnp.array([0.5]))
|
154
174
|
|
155
175
|
@parameterized.parameters(False, True)
|
156
176
|
def testGeluIntType(self, approximate):
|
157
|
-
val_float = brainstate.
|
158
|
-
val_int = brainstate.
|
177
|
+
val_float = brainstate.nn.gelu(jnp.array(-1.0), approximate=approximate)
|
178
|
+
val_int = brainstate.nn.gelu(jnp.array(-1), approximate=approximate)
|
159
179
|
self.assertAllClose(val_float, val_int)
|
160
180
|
|
161
181
|
@parameterized.parameters(False, True)
|
@@ -163,22 +183,21 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
163
183
|
def gelu_reference(x):
|
164
184
|
return x * scipy.stats.norm.cdf(x)
|
165
185
|
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
check_dtypes=False, tol=1e-3 if approximate else None)
|
186
|
+
x = jax.random.normal(self.rng_key, (4, 5, 6), dtype=jnp.float32)
|
187
|
+
expected = gelu_reference(x)
|
188
|
+
actual = brainstate.nn.gelu(x, approximate=approximate)
|
189
|
+
np.testing.assert_allclose(actual, expected, rtol=1e-2 if approximate else 1e-5, atol=1e-3 if approximate else 1e-5)
|
171
190
|
|
172
191
|
@parameterized.parameters(*itertools.product(
|
173
192
|
(jnp.float32, jnp.bfloat16, jnp.float16),
|
174
|
-
(partial(brainstate.
|
175
|
-
partial(brainstate.
|
176
|
-
brainstate.
|
177
|
-
brainstate.
|
178
|
-
brainstate.
|
179
|
-
brainstate.
|
180
|
-
# brainstate.
|
181
|
-
brainstate.
|
193
|
+
(partial(brainstate.nn.gelu, approximate=False),
|
194
|
+
partial(brainstate.nn.gelu, approximate=True),
|
195
|
+
brainstate.nn.relu,
|
196
|
+
brainstate.nn.softplus,
|
197
|
+
brainstate.nn.sparse_plus,
|
198
|
+
brainstate.nn.sigmoid,
|
199
|
+
# brainstate.nn.squareplus,
|
200
|
+
brainstate.nn.mish)))
|
182
201
|
def testDtypeMatchesInput(self, dtype, fn):
|
183
202
|
x = jnp.zeros((), dtype=dtype)
|
184
203
|
out = fn(x)
|
@@ -187,26 +206,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
187
206
|
def testEluMemory(self):
|
188
207
|
# see https://github.com/google/jax/pull/1640
|
189
208
|
with jax.enable_checks(False): # With checks we materialize the array
|
190
|
-
jax.make_jaxpr(lambda: brainstate.
|
209
|
+
jax.make_jaxpr(lambda: brainstate.nn.elu(jnp.ones((10 ** 12,)))) # don't oom
|
191
210
|
|
192
211
|
def testHardTanhMemory(self):
|
193
212
|
# see https://github.com/google/jax/pull/1640
|
194
213
|
with jax.enable_checks(False): # With checks we materialize the array
|
195
|
-
jax.make_jaxpr(lambda: brainstate.
|
214
|
+
jax.make_jaxpr(lambda: brainstate.nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
|
196
215
|
|
197
|
-
@parameterized.parameters([brainstate.
|
216
|
+
@parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
|
198
217
|
def testSoftmaxEmptyArray(self, fn):
|
199
218
|
x = jnp.array([], dtype=float)
|
200
219
|
self.assertArraysEqual(fn(x), x)
|
201
220
|
|
202
|
-
@parameterized.parameters([brainstate.
|
221
|
+
@parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
|
203
222
|
def testSoftmaxEmptyMask(self, fn):
|
204
223
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
205
224
|
m = jnp.zeros_like(x, dtype=bool)
|
206
|
-
expected = jnp.full_like(x, 0.0 if fn is brainstate.
|
225
|
+
expected = jnp.full_like(x, 0.0 if fn is brainstate.nn.softmax else -jnp.inf)
|
207
226
|
self.assertArraysEqual(fn(x, where=m), expected)
|
208
227
|
|
209
|
-
@parameterized.parameters([brainstate.
|
228
|
+
@parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
|
210
229
|
def testSoftmaxWhereMask(self, fn):
|
211
230
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
212
231
|
m = jnp.array([True, False, True, True])
|
@@ -214,10 +233,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
214
233
|
out = fn(x, where=m)
|
215
234
|
self.assertAllClose(out[m], fn(x[m]))
|
216
235
|
|
217
|
-
probs = out if fn is brainstate.
|
218
|
-
self.assertAllClose(probs.sum(), 1.0)
|
236
|
+
probs = out if fn is brainstate.nn.softmax else jnp.exp(out)
|
237
|
+
self.assertAllClose(probs.sum(), 1.0, check_dtypes=False)
|
219
238
|
|
220
|
-
@parameterized.parameters([brainstate.
|
239
|
+
@parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
|
221
240
|
def testSoftmaxWhereGrad(self, fn):
|
222
241
|
# regression test for https://github.com/google/jax/issues/19490
|
223
242
|
x = jnp.array([36., 10000.])
|
@@ -229,46 +248,46 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
229
248
|
|
230
249
|
def testSoftmaxGrad(self):
|
231
250
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
232
|
-
|
251
|
+
check_grads(brainstate.nn.softmax, (x,), order=2, atol=5e-3)
|
233
252
|
|
234
253
|
def testStandardizeWhereMask(self):
|
235
254
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
236
255
|
m = jnp.array([True, False, True, True])
|
237
256
|
x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
|
238
257
|
|
239
|
-
out_masked = jnp.take(brainstate.
|
240
|
-
out_filtered = brainstate.
|
258
|
+
out_masked = jnp.take(brainstate.nn.standardize(x, where=m), jnp.array([0, 2, 3]))
|
259
|
+
out_filtered = brainstate.nn.standardize(x_filtered)
|
241
260
|
|
242
|
-
self.assertAllClose(out_masked, out_filtered)
|
261
|
+
self.assertAllClose(out_masked, out_filtered, rtol=1e-6, atol=1e-6)
|
243
262
|
|
244
263
|
def testOneHot(self):
|
245
|
-
actual = brainstate.
|
264
|
+
actual = brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
|
246
265
|
expected = jnp.array([[1., 0., 0.],
|
247
266
|
[0., 1., 0.],
|
248
267
|
[0., 0., 1.]])
|
249
268
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
250
269
|
|
251
|
-
actual = brainstate.
|
270
|
+
actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3)
|
252
271
|
expected = jnp.array([[0., 1., 0.],
|
253
272
|
[0., 0., 1.],
|
254
273
|
[1., 0., 0.]])
|
255
274
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
256
275
|
|
257
276
|
def testOneHotOutOfBound(self):
|
258
|
-
actual = brainstate.
|
277
|
+
actual = brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
|
259
278
|
expected = jnp.array([[0., 0., 0.],
|
260
279
|
[0., 0., 0.]])
|
261
280
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
262
281
|
|
263
282
|
def testOneHotNonArrayInput(self):
|
264
|
-
actual = brainstate.
|
283
|
+
actual = brainstate.nn.one_hot([0, 1, 2], 3)
|
265
284
|
expected = jnp.array([[1., 0., 0.],
|
266
285
|
[0., 1., 0.],
|
267
286
|
[0., 0., 1.]])
|
268
287
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
269
288
|
|
270
289
|
def testOneHotCustomDtype(self):
|
271
|
-
actual = brainstate.
|
290
|
+
actual = brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
|
272
291
|
expected = jnp.array([[True, False, False],
|
273
292
|
[False, True, False],
|
274
293
|
[False, False, True]])
|
@@ -279,14 +298,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
279
298
|
[0., 0., 1.],
|
280
299
|
[1., 0., 0.]]).T
|
281
300
|
|
282
|
-
actual = brainstate.
|
301
|
+
actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
|
283
302
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
284
303
|
|
285
|
-
actual = brainstate.
|
304
|
+
actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
|
286
305
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
287
306
|
|
288
307
|
def testTanhExists(self):
|
289
|
-
print(brainstate.
|
308
|
+
print(brainstate.nn.tanh) # doesn't crash
|
290
309
|
|
291
310
|
def testCustomJVPLeak(self):
|
292
311
|
# https://github.com/google/jax/issues/8171
|
@@ -295,7 +314,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
295
314
|
a = jnp.array(1.)
|
296
315
|
|
297
316
|
def f(hx, _):
|
298
|
-
hx = brainstate.
|
317
|
+
hx = brainstate.nn.sigmoid(hx + a)
|
299
318
|
return hx, None
|
300
319
|
|
301
320
|
hx = jnp.array(0.)
|
@@ -306,7 +325,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
306
325
|
|
307
326
|
def testCustomJVPLeak2(self):
|
308
327
|
# https://github.com/google/jax/issues/8171
|
309
|
-
# The above test uses jax.brainstate.
|
328
|
+
# The above test uses jax.brainstate.nn.sigmoid, as in the original #8171, but that
|
310
329
|
# function no longer actually has a custom_jvp! So we inline the old def.
|
311
330
|
|
312
331
|
@jax.custom_jvp
|
@@ -329,3 +348,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
329
348
|
|
330
349
|
with jax.checking_leaks():
|
331
350
|
fwd() # doesn't crash
|
351
|
+
|
352
|
+
|
353
|
+
if __name__ == '__main__':
|
354
|
+
absltest.main()
|