brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +3 -0
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module_test.py +34 -37
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -13,20 +13,18 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
19
|
import numpy as np
|
22
20
|
|
23
|
-
import brainstate
|
21
|
+
import brainstate
|
24
22
|
|
25
23
|
|
26
24
|
class TestForLoop(unittest.TestCase):
|
27
25
|
def test_for_loop(self):
|
28
|
-
a =
|
29
|
-
b =
|
26
|
+
a = brainstate.ShortTermState(0.)
|
27
|
+
b = brainstate.ShortTermState(0.)
|
30
28
|
|
31
29
|
def f(i):
|
32
30
|
a.value += (1 + b.value)
|
@@ -34,7 +32,7 @@ class TestForLoop(unittest.TestCase):
|
|
34
32
|
|
35
33
|
n_iter = 10
|
36
34
|
ops = np.arange(n_iter)
|
37
|
-
r =
|
35
|
+
r = brainstate.compile.for_loop(f, ops)
|
38
36
|
|
39
37
|
print(a)
|
40
38
|
print(b)
|
@@ -42,8 +40,8 @@ class TestForLoop(unittest.TestCase):
|
|
42
40
|
self.assertTrue(jnp.allclose(r, ops + 1))
|
43
41
|
|
44
42
|
def test_checkpointed_for_loop(self):
|
45
|
-
a =
|
46
|
-
b =
|
43
|
+
a = brainstate.ShortTermState(0.)
|
44
|
+
b = brainstate.ShortTermState(0.)
|
47
45
|
|
48
46
|
def f(i):
|
49
47
|
a.value += (1 + b.value)
|
@@ -51,7 +49,7 @@ class TestForLoop(unittest.TestCase):
|
|
51
49
|
|
52
50
|
n_iter = 18
|
53
51
|
ops = jnp.arange(n_iter)
|
54
|
-
r =
|
52
|
+
r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())
|
55
53
|
|
56
54
|
print(a)
|
57
55
|
print(b)
|
@@ -13,17 +13,16 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
from unittest import TestCase
|
19
18
|
|
20
|
-
import brainstate
|
19
|
+
import brainstate
|
21
20
|
|
22
21
|
|
23
22
|
class TestWhileLoop(TestCase):
|
24
23
|
def test1(self):
|
25
|
-
a =
|
26
|
-
b =
|
24
|
+
a = brainstate.State(1.)
|
25
|
+
b = brainstate.State(20.)
|
27
26
|
|
28
27
|
def cond(_):
|
29
28
|
return a.value < b.value
|
@@ -31,13 +30,13 @@ class TestWhileLoop(TestCase):
|
|
31
30
|
def body(_):
|
32
31
|
a.value += 1.
|
33
32
|
|
34
|
-
|
33
|
+
brainstate.compile.while_loop(cond, body, None)
|
35
34
|
|
36
35
|
print(a.value, b.value)
|
37
36
|
|
38
37
|
def test2(self):
|
39
|
-
a =
|
40
|
-
b =
|
38
|
+
a = brainstate.State(1.)
|
39
|
+
b = brainstate.State(20.)
|
41
40
|
|
42
41
|
def cond(x):
|
43
42
|
return a.value < b.value
|
@@ -46,6 +45,6 @@ class TestWhileLoop(TestCase):
|
|
46
45
|
a.value += x
|
47
46
|
return x
|
48
47
|
|
49
|
-
r =
|
48
|
+
r = brainstate.compile.while_loop(cond, body, 1.)
|
50
49
|
|
51
50
|
print(a.value, b.value, r)
|
@@ -62,8 +62,8 @@ import jax
|
|
62
62
|
from jax._src import source_info_util
|
63
63
|
from jax._src.linear_util import annotate
|
64
64
|
from jax._src.traceback_util import api_boundary
|
65
|
-
from jax.api_util import shaped_abstractify
|
66
|
-
from jax.extend.linear_util import transformation_with_aux
|
65
|
+
from jax.api_util import shaped_abstractify
|
66
|
+
from jax.extend.linear_util import transformation_with_aux
|
67
67
|
from jax.interpreters import partial_eval as pe
|
68
68
|
|
69
69
|
from brainstate._compatible_import import (
|
@@ -73,6 +73,7 @@ from brainstate._compatible_import import (
|
|
73
73
|
safe_zip,
|
74
74
|
unzip2,
|
75
75
|
wraps,
|
76
|
+
wrap_init,
|
76
77
|
)
|
77
78
|
from brainstate._state import State, StateTraceStack
|
78
79
|
from brainstate._utils import set_module_as
|
@@ -96,7 +97,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
|
96
97
|
return tuple(safe_map(operator.index, x))
|
97
98
|
|
98
99
|
|
99
|
-
def
|
100
|
+
def _jax_v04_new_arg_fn(frame, trace, aval):
|
100
101
|
"""
|
101
102
|
Transform a new argument to a tracer.
|
102
103
|
|
@@ -117,27 +118,41 @@ def _new_arg_fn(frame, trace, aval):
|
|
117
118
|
return tracer
|
118
119
|
|
119
120
|
|
120
|
-
def
|
121
|
+
def _jax_v04_new_jax_trace():
|
121
122
|
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
122
123
|
frame = main.jaxpr_stack[-1]
|
123
124
|
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
124
125
|
return frame, trace
|
125
126
|
|
126
127
|
|
128
|
+
def _jax_v04_new_arg():
|
129
|
+
# Should be within the calling of ``jax.make_jaxpr()``
|
130
|
+
frame, trace = _jax_v04_new_jax_trace()
|
131
|
+
# Set the function to transform the new argument to a tracer
|
132
|
+
fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
|
133
|
+
return fn
|
134
|
+
|
135
|
+
|
136
|
+
def _jax_new_version_new_arg():
|
137
|
+
trace = jax.core.trace_ctx.trace
|
138
|
+
|
139
|
+
def wrapper(x):
|
140
|
+
if jax.__version_info__ < (0, 6, 1):
|
141
|
+
return trace.new_arg(shaped_abstractify(x))
|
142
|
+
else:
|
143
|
+
return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
|
144
|
+
|
145
|
+
return wrapper
|
146
|
+
|
147
|
+
|
127
148
|
def _init_state_trace_stack(name) -> StateTraceStack:
|
128
149
|
state_trace: StateTraceStack = StateTraceStack(name=name)
|
129
150
|
|
130
151
|
if jax.__version_info__ < (0, 4, 36):
|
131
|
-
|
132
|
-
frame, trace = _new_jax_trace()
|
133
|
-
# Set the function to transform the new argument to a tracer
|
134
|
-
state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
|
135
|
-
return state_trace
|
136
|
-
|
152
|
+
state_trace.set_new_arg(_jax_v04_new_arg())
|
137
153
|
else:
|
138
|
-
|
139
|
-
|
140
|
-
return state_trace
|
154
|
+
state_trace.set_new_arg(_jax_new_version_new_arg())
|
155
|
+
return state_trace
|
141
156
|
|
142
157
|
|
143
158
|
class StatefulFunction(PrettyObject):
|
@@ -743,7 +758,7 @@ def _make_jaxpr(
|
|
743
758
|
@wraps(fun)
|
744
759
|
@api_boundary
|
745
760
|
def make_jaxpr_f(*args, **kwargs):
|
746
|
-
f = wrap_init(fun,
|
761
|
+
f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
|
747
762
|
if static_argnums:
|
748
763
|
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
749
764
|
f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
|
@@ -21,7 +21,7 @@ import jax
|
|
21
21
|
import jax.numpy as jnp
|
22
22
|
import pytest
|
23
23
|
|
24
|
-
import brainstate
|
24
|
+
import brainstate
|
25
25
|
from brainstate._compatible_import import jaxpr_as_fun
|
26
26
|
|
27
27
|
|
@@ -29,10 +29,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
29
29
|
def test_compar_jax_make_jaxpr(self):
|
30
30
|
def func4(arg): # Arg is a pair
|
31
31
|
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
32
|
-
c =
|
32
|
+
c = brainstate.random.rand_like(arg[0])
|
33
33
|
return jnp.sum(temp + c)
|
34
34
|
|
35
|
-
key =
|
35
|
+
key = brainstate.random.DEFAULT.value
|
36
36
|
jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
37
37
|
print(jaxpr)
|
38
38
|
self.assertTrue(len(jaxpr.in_avals) == 2)
|
@@ -40,66 +40,66 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
40
40
|
self.assertTrue(len(jaxpr.out_avals) == 1)
|
41
41
|
self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
|
42
42
|
|
43
|
-
|
44
|
-
print(
|
43
|
+
brainstate.random.seed(1)
|
44
|
+
print(brainstate.random.DEFAULT.value)
|
45
45
|
|
46
|
-
jaxpr2, states =
|
46
|
+
jaxpr2, states = brainstate.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
47
47
|
print(jaxpr2)
|
48
48
|
self.assertTrue(len(jaxpr2.in_avals) == 3)
|
49
49
|
self.assertTrue(len(jaxpr2.out_avals) == 2)
|
50
50
|
self.assertTrue(len(jaxpr2.consts) == 0)
|
51
|
-
print(
|
51
|
+
print(brainstate.random.DEFAULT.value)
|
52
52
|
|
53
53
|
def test_StatefulFunction_1(self):
|
54
54
|
def func4(arg): # Arg is a pair
|
55
55
|
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
56
|
-
c =
|
56
|
+
c = brainstate.random.rand_like(arg[0])
|
57
57
|
return jnp.sum(temp + c)
|
58
58
|
|
59
|
-
fun =
|
59
|
+
fun = brainstate.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
|
60
60
|
print(fun.get_states())
|
61
61
|
print(fun.get_jaxpr())
|
62
62
|
|
63
63
|
def test_StatefulFunction_2(self):
|
64
|
-
st1 =
|
64
|
+
st1 = brainstate.State(jnp.ones(10))
|
65
65
|
|
66
66
|
def f1(x):
|
67
67
|
st1.value = x + st1.value
|
68
68
|
|
69
69
|
def f2(x):
|
70
|
-
jaxpr =
|
70
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(x)
|
71
71
|
c = 1. + x
|
72
72
|
return c
|
73
73
|
|
74
74
|
def f3(x):
|
75
|
-
jaxpr =
|
75
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(x)
|
76
76
|
c = 1.
|
77
77
|
return c
|
78
78
|
|
79
79
|
print()
|
80
|
-
jaxpr =
|
80
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(jnp.zeros(1))
|
81
81
|
print(jaxpr)
|
82
82
|
jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
|
83
83
|
print(jaxpr)
|
84
84
|
jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
|
85
85
|
print(jaxpr)
|
86
|
-
jaxpr, _ =
|
86
|
+
jaxpr, _ = brainstate.compile.make_jaxpr(f3)(jnp.zeros(1))
|
87
87
|
print(jaxpr)
|
88
88
|
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
89
|
f3(jnp.zeros(1))))
|
90
90
|
|
91
91
|
def test_compar_jax_make_jaxpr2(self):
|
92
|
-
st1 =
|
92
|
+
st1 = brainstate.State(jnp.ones(10))
|
93
93
|
|
94
94
|
def fa(x):
|
95
95
|
st1.value = x + st1.value
|
96
96
|
|
97
97
|
def ffa(x):
|
98
|
-
jaxpr, states =
|
98
|
+
jaxpr, states = brainstate.compile.make_jaxpr(fa)(x)
|
99
99
|
c = 1. + x
|
100
100
|
return c
|
101
101
|
|
102
|
-
jaxpr, states =
|
102
|
+
jaxpr, states = brainstate.compile.make_jaxpr(ffa)(jnp.zeros(1))
|
103
103
|
print()
|
104
104
|
print(jaxpr)
|
105
105
|
print(states)
|
@@ -112,7 +112,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
112
112
|
def fa(x):
|
113
113
|
return 1.
|
114
114
|
|
115
|
-
jaxpr, states =
|
115
|
+
jaxpr, states = brainstate.compile.make_jaxpr(fa)(jnp.zeros(1))
|
116
116
|
print()
|
117
117
|
print(jaxpr)
|
118
118
|
print(states)
|
@@ -125,9 +125,9 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
125
125
|
def test_return_states():
|
126
126
|
import jax.numpy
|
127
127
|
|
128
|
-
a =
|
128
|
+
a = brainstate.State(jax.numpy.ones(3))
|
129
129
|
|
130
|
-
@
|
130
|
+
@brainstate.compile.jit
|
131
131
|
def f():
|
132
132
|
return a
|
133
133
|
|
@@ -25,48 +25,48 @@ from absl.testing import parameterized
|
|
25
25
|
from jax._src import test_util as jtu
|
26
26
|
from jax.test_util import check_grads
|
27
27
|
|
28
|
-
import brainstate
|
28
|
+
import brainstate
|
29
29
|
|
30
30
|
|
31
31
|
class NNFunctionsTest(jtu.JaxTestCase):
|
32
32
|
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
33
33
|
def testSoftplusGrad(self):
|
34
|
-
check_grads(
|
34
|
+
check_grads(brainstate.functional.softplus, (1e-8,), order=4, )
|
35
35
|
|
36
36
|
def testSoftplusGradZero(self):
|
37
|
-
check_grads(
|
37
|
+
check_grads(brainstate.functional.softplus, (0.,), order=1)
|
38
38
|
|
39
39
|
def testSoftplusGradInf(self):
|
40
|
-
self.assertAllClose(1., jax.grad(
|
40
|
+
self.assertAllClose(1., jax.grad(brainstate.functional.softplus)(float('inf')))
|
41
41
|
|
42
42
|
def testSoftplusGradNegInf(self):
|
43
|
-
check_grads(
|
43
|
+
check_grads(brainstate.functional.softplus, (-float('inf'),), order=1)
|
44
44
|
|
45
45
|
def testSoftplusGradNan(self):
|
46
|
-
check_grads(
|
46
|
+
check_grads(brainstate.functional.softplus, (float('nan'),), order=1)
|
47
47
|
|
48
48
|
@parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
|
49
49
|
def testSoftplusZero(self, dtype):
|
50
|
-
self.assertEqual(jnp.log(dtype(2)),
|
50
|
+
self.assertEqual(jnp.log(dtype(2)), brainstate.functional.softplus(dtype(0)))
|
51
51
|
|
52
52
|
def testSparseplusGradZero(self):
|
53
|
-
check_grads(
|
53
|
+
check_grads(brainstate.functional.sparse_plus, (-2.,), order=1)
|
54
54
|
|
55
55
|
def testSparseplusGrad(self):
|
56
|
-
check_grads(
|
56
|
+
check_grads(brainstate.functional.sparse_plus, (0.,), order=1)
|
57
57
|
|
58
58
|
def testSparseplusAndSparseSigmoid(self):
|
59
59
|
self.assertAllClose(
|
60
|
-
jax.grad(
|
61
|
-
|
60
|
+
jax.grad(brainstate.functional.sparse_plus)(0.),
|
61
|
+
brainstate.functional.sparse_sigmoid(0.),
|
62
62
|
check_dtypes=False)
|
63
63
|
self.assertAllClose(
|
64
|
-
jax.grad(
|
65
|
-
|
64
|
+
jax.grad(brainstate.functional.sparse_plus)(2.),
|
65
|
+
brainstate.functional.sparse_sigmoid(2.),
|
66
66
|
check_dtypes=False)
|
67
67
|
self.assertAllClose(
|
68
|
-
jax.grad(
|
69
|
-
|
68
|
+
jax.grad(brainstate.functional.sparse_plus)(-2.),
|
69
|
+
brainstate.functional.sparse_sigmoid(-2.),
|
70
70
|
check_dtypes=False)
|
71
71
|
|
72
72
|
# def testSquareplusGrad(self):
|
@@ -107,55 +107,55 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
107
107
|
|
108
108
|
@parameterized.parameters([float] + jtu.dtypes.floating)
|
109
109
|
def testMishZero(self, dtype):
|
110
|
-
self.assertEqual(dtype(0),
|
110
|
+
self.assertEqual(dtype(0), brainstate.functional.mish(dtype(0)))
|
111
111
|
|
112
112
|
def testReluGrad(self):
|
113
113
|
rtol = None
|
114
|
-
check_grads(
|
115
|
-
check_grads(
|
116
|
-
jaxpr = jax.make_jaxpr(jax.grad(
|
114
|
+
check_grads(brainstate.functional.relu, (1.,), order=3, rtol=rtol)
|
115
|
+
check_grads(brainstate.functional.relu, (-1.,), order=3, rtol=rtol)
|
116
|
+
jaxpr = jax.make_jaxpr(jax.grad(brainstate.functional.relu))(0.)
|
117
117
|
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
|
118
118
|
|
119
119
|
def testRelu6Grad(self):
|
120
120
|
rtol = None
|
121
|
-
check_grads(
|
122
|
-
check_grads(
|
123
|
-
self.assertAllClose(jax.grad(
|
124
|
-
self.assertAllClose(jax.grad(
|
121
|
+
check_grads(brainstate.functional.relu6, (1.,), order=3, rtol=rtol)
|
122
|
+
check_grads(brainstate.functional.relu6, (-1.,), order=3, rtol=rtol)
|
123
|
+
self.assertAllClose(jax.grad(brainstate.functional.relu6)(0.), 0., check_dtypes=False)
|
124
|
+
self.assertAllClose(jax.grad(brainstate.functional.relu6)(6.), 0., check_dtypes=False)
|
125
125
|
|
126
126
|
def testSoftplusValue(self):
|
127
|
-
val =
|
127
|
+
val = brainstate.functional.softplus(89.)
|
128
128
|
self.assertAllClose(val, 89., check_dtypes=False)
|
129
129
|
|
130
130
|
def testSparseplusValue(self):
|
131
|
-
val =
|
131
|
+
val = brainstate.functional.sparse_plus(89.)
|
132
132
|
self.assertAllClose(val, 89., check_dtypes=False)
|
133
133
|
|
134
134
|
def testSparsesigmoidValue(self):
|
135
|
-
self.assertAllClose(
|
136
|
-
self.assertAllClose(
|
137
|
-
self.assertAllClose(
|
135
|
+
self.assertAllClose(brainstate.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
|
136
|
+
self.assertAllClose(brainstate.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
|
137
|
+
self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
|
138
138
|
|
139
139
|
# def testSquareplusValue(self):
|
140
140
|
# val = bst.functional.squareplus(1e3)
|
141
141
|
# self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
142
142
|
|
143
143
|
def testMishValue(self):
|
144
|
-
val =
|
144
|
+
val = brainstate.functional.mish(1e3)
|
145
145
|
self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
|
146
146
|
|
147
147
|
def testEluValue(self):
|
148
|
-
val =
|
148
|
+
val = brainstate.functional.elu(1e4)
|
149
149
|
self.assertAllClose(val, 1e4, check_dtypes=False)
|
150
150
|
|
151
151
|
def testGluValue(self):
|
152
|
-
val =
|
152
|
+
val = brainstate.functional.glu(jnp.array([1.0, 0.0]), axis=0)
|
153
153
|
self.assertAllClose(val, jnp.array([0.5]))
|
154
154
|
|
155
155
|
@parameterized.parameters(False, True)
|
156
156
|
def testGeluIntType(self, approximate):
|
157
|
-
val_float =
|
158
|
-
val_int =
|
157
|
+
val_float = brainstate.functional.gelu(jnp.array(-1.0), approximate=approximate)
|
158
|
+
val_int = brainstate.functional.gelu(jnp.array(-1), approximate=approximate)
|
159
159
|
self.assertAllClose(val_float, val_int)
|
160
160
|
|
161
161
|
@parameterized.parameters(False, True)
|
@@ -166,19 +166,19 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
166
166
|
rng = jtu.rand_default(self.rng())
|
167
167
|
args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
|
168
168
|
self._CheckAgainstNumpy(
|
169
|
-
gelu_reference, partial(
|
169
|
+
gelu_reference, partial(brainstate.functional.gelu, approximate=approximate), args_maker,
|
170
170
|
check_dtypes=False, tol=1e-3 if approximate else None)
|
171
171
|
|
172
172
|
@parameterized.parameters(*itertools.product(
|
173
173
|
(jnp.float32, jnp.bfloat16, jnp.float16),
|
174
|
-
(partial(
|
175
|
-
partial(
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
174
|
+
(partial(brainstate.functional.gelu, approximate=False),
|
175
|
+
partial(brainstate.functional.gelu, approximate=True),
|
176
|
+
brainstate.functional.relu,
|
177
|
+
brainstate.functional.softplus,
|
178
|
+
brainstate.functional.sparse_plus,
|
179
|
+
brainstate.functional.sigmoid,
|
180
180
|
# bst.functional.squareplus,
|
181
|
-
|
181
|
+
brainstate.functional.mish)))
|
182
182
|
def testDtypeMatchesInput(self, dtype, fn):
|
183
183
|
x = jnp.zeros((), dtype=dtype)
|
184
184
|
out = fn(x)
|
@@ -187,26 +187,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
187
187
|
def testEluMemory(self):
|
188
188
|
# see https://github.com/google/jax/pull/1640
|
189
189
|
with jax.enable_checks(False): # With checks we materialize the array
|
190
|
-
jax.make_jaxpr(lambda:
|
190
|
+
jax.make_jaxpr(lambda: brainstate.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
|
191
191
|
|
192
192
|
def testHardTanhMemory(self):
|
193
193
|
# see https://github.com/google/jax/pull/1640
|
194
194
|
with jax.enable_checks(False): # With checks we materialize the array
|
195
|
-
jax.make_jaxpr(lambda:
|
195
|
+
jax.make_jaxpr(lambda: brainstate.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
|
196
196
|
|
197
|
-
@parameterized.parameters([
|
197
|
+
@parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
|
198
198
|
def testSoftmaxEmptyArray(self, fn):
|
199
199
|
x = jnp.array([], dtype=float)
|
200
200
|
self.assertArraysEqual(fn(x), x)
|
201
201
|
|
202
|
-
@parameterized.parameters([
|
202
|
+
@parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
|
203
203
|
def testSoftmaxEmptyMask(self, fn):
|
204
204
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
205
205
|
m = jnp.zeros_like(x, dtype=bool)
|
206
|
-
expected = jnp.full_like(x, 0.0 if fn is
|
206
|
+
expected = jnp.full_like(x, 0.0 if fn is brainstate.functional.softmax else -jnp.inf)
|
207
207
|
self.assertArraysEqual(fn(x, where=m), expected)
|
208
208
|
|
209
|
-
@parameterized.parameters([
|
209
|
+
@parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
|
210
210
|
def testSoftmaxWhereMask(self, fn):
|
211
211
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
212
212
|
m = jnp.array([True, False, True, True])
|
@@ -214,10 +214,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
214
214
|
out = fn(x, where=m)
|
215
215
|
self.assertAllClose(out[m], fn(x[m]))
|
216
216
|
|
217
|
-
probs = out if fn is
|
217
|
+
probs = out if fn is brainstate.functional.softmax else jnp.exp(out)
|
218
218
|
self.assertAllClose(probs.sum(), 1.0)
|
219
219
|
|
220
|
-
@parameterized.parameters([
|
220
|
+
@parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
|
221
221
|
def testSoftmaxWhereGrad(self, fn):
|
222
222
|
# regression test for https://github.com/google/jax/issues/19490
|
223
223
|
x = jnp.array([36., 10000.])
|
@@ -229,46 +229,46 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
229
229
|
|
230
230
|
def testSoftmaxGrad(self):
|
231
231
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
232
|
-
jtu.check_grads(
|
232
|
+
jtu.check_grads(brainstate.functional.softmax, (x,), order=2, atol=5e-3)
|
233
233
|
|
234
234
|
def testStandardizeWhereMask(self):
|
235
235
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
236
236
|
m = jnp.array([True, False, True, True])
|
237
237
|
x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
|
238
238
|
|
239
|
-
out_masked = jnp.take(
|
240
|
-
out_filtered =
|
239
|
+
out_masked = jnp.take(brainstate.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
|
240
|
+
out_filtered = brainstate.functional.standardize(x_filtered)
|
241
241
|
|
242
242
|
self.assertAllClose(out_masked, out_filtered)
|
243
243
|
|
244
244
|
def testOneHot(self):
|
245
|
-
actual =
|
245
|
+
actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)
|
246
246
|
expected = jnp.array([[1., 0., 0.],
|
247
247
|
[0., 1., 0.],
|
248
248
|
[0., 0., 1.]])
|
249
249
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
250
250
|
|
251
|
-
actual =
|
251
|
+
actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3)
|
252
252
|
expected = jnp.array([[0., 1., 0.],
|
253
253
|
[0., 0., 1.],
|
254
254
|
[1., 0., 0.]])
|
255
255
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
256
256
|
|
257
257
|
def testOneHotOutOfBound(self):
|
258
|
-
actual =
|
258
|
+
actual = brainstate.functional.one_hot(jnp.array([-1, 3]), 3)
|
259
259
|
expected = jnp.array([[0., 0., 0.],
|
260
260
|
[0., 0., 0.]])
|
261
261
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
262
262
|
|
263
263
|
def testOneHotNonArrayInput(self):
|
264
|
-
actual =
|
264
|
+
actual = brainstate.functional.one_hot([0, 1, 2], 3)
|
265
265
|
expected = jnp.array([[1., 0., 0.],
|
266
266
|
[0., 1., 0.],
|
267
267
|
[0., 0., 1.]])
|
268
268
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
269
269
|
|
270
270
|
def testOneHotCustomDtype(self):
|
271
|
-
actual =
|
271
|
+
actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
|
272
272
|
expected = jnp.array([[True, False, False],
|
273
273
|
[False, True, False],
|
274
274
|
[False, False, True]])
|
@@ -279,14 +279,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
279
279
|
[0., 0., 1.],
|
280
280
|
[1., 0., 0.]]).T
|
281
281
|
|
282
|
-
actual =
|
282
|
+
actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
|
283
283
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
284
284
|
|
285
|
-
actual =
|
285
|
+
actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
|
286
286
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
287
287
|
|
288
288
|
def testTanhExists(self):
|
289
|
-
print(
|
289
|
+
print(brainstate.functional.tanh) # doesn't crash
|
290
290
|
|
291
291
|
def testCustomJVPLeak(self):
|
292
292
|
# https://github.com/google/jax/issues/8171
|
@@ -295,7 +295,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|
295
295
|
a = jnp.array(1.)
|
296
296
|
|
297
297
|
def f(hx, _):
|
298
|
-
hx =
|
298
|
+
hx = brainstate.functional.sigmoid(hx + a)
|
299
299
|
return hx, None
|
300
300
|
|
301
301
|
hx = jnp.array(0.)
|