brainstate-0.1.1-py2.py3-none-any.whl → brainstate-0.1.2-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +3 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/nn/_collective_ops_test.py +8 -8
  20. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  21. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  22. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  23. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  24. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  25. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  26. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  27. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  28. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  29. brainstate/nn/_event/_linear_mv_test.py +0 -1
  30. brainstate/nn/_exp_euler_test.py +5 -6
  31. brainstate/nn/_interaction/_conv_test.py +31 -33
  32. brainstate/nn/_interaction/_linear_test.py +15 -17
  33. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  34. brainstate/nn/_interaction/_poolings_test.py +19 -21
  35. brainstate/nn/_module_test.py +34 -37
  36. brainstate/optim/_lr_scheduler_test.py +3 -3
  37. brainstate/optim/_optax_optimizer_test.py +8 -9
  38. brainstate/random/_rand_funs_test.py +183 -184
  39. brainstate/random/_rand_seed_test.py +10 -12
  40. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
  41. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
  42. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  43. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  44. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/compile/_loop_collect_return_test.py
@@ -13,20 +13,18 @@
 # limitations under the License.
 # ==============================================================================

-from __future__ import annotations
-
 import unittest

 import jax.numpy as jnp
 import numpy as np

-import brainstate as bst
+import brainstate


 class TestForLoop(unittest.TestCase):
     def test_for_loop(self):
-        a = bst.ShortTermState(0.)
-        b = bst.ShortTermState(0.)
+        a = brainstate.ShortTermState(0.)
+        b = brainstate.ShortTermState(0.)

         def f(i):
             a.value += (1 + b.value)
@@ -34,7 +32,7 @@ class TestForLoop(unittest.TestCase):

         n_iter = 10
         ops = np.arange(n_iter)
-        r = bst.compile.for_loop(f, ops)
+        r = brainstate.compile.for_loop(f, ops)

         print(a)
         print(b)
@@ -42,8 +40,8 @@ class TestForLoop(unittest.TestCase):
         self.assertTrue(jnp.allclose(r, ops + 1))

     def test_checkpointed_for_loop(self):
-        a = bst.ShortTermState(0.)
-        b = bst.ShortTermState(0.)
+        a = brainstate.ShortTermState(0.)
+        b = brainstate.ShortTermState(0.)

         def f(i):
             a.value += (1 + b.value)
@@ -51,7 +49,7 @@ class TestForLoop(unittest.TestCase):

         n_iter = 18
         ops = jnp.arange(n_iter)
-        r = bst.compile.checkpointed_for_loop(f, ops, base=2, pbar=bst.compile.ProgressBar())
+        r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())

         print(a)
         print(b)
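For orientation, here is a minimal, self-contained sketch of the scan-style `brainstate.compile.for_loop` pattern these tests exercise. It mirrors the test above under the renamed top-level import rather than documenting new behavior; the body mutates a `ShortTermState` and the per-iteration returns are collected:

    import numpy as np
    import brainstate

    # State that the loop body mutates across iterations.
    a = brainstate.ShortTermState(0.)

    def body(i):
        a.value += 1.
        return a.value  # collected per iteration

    # r stacks the 10 per-iteration returns, as the assertion above checks.
    r = brainstate.compile.for_loop(body, np.arange(10))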
brainstate/compile/_loop_no_collection_test.py
@@ -13,17 +13,16 @@
 # limitations under the License.
 # ==============================================================================

-from __future__ import annotations

 from unittest import TestCase

-import brainstate as bst
+import brainstate


 class TestWhileLoop(TestCase):
     def test1(self):
-        a = bst.State(1.)
-        b = bst.State(20.)
+        a = brainstate.State(1.)
+        b = brainstate.State(20.)

         def cond(_):
             return a.value < b.value
@@ -31,13 +30,13 @@ class TestWhileLoop(TestCase):
         def body(_):
             a.value += 1.

-        bst.compile.while_loop(cond, body, None)
+        brainstate.compile.while_loop(cond, body, None)

         print(a.value, b.value)

     def test2(self):
-        a = bst.State(1.)
-        b = bst.State(20.)
+        a = brainstate.State(1.)
+        b = brainstate.State(20.)

         def cond(x):
             return a.value < b.value
@@ -46,6 +45,6 @@ class TestWhileLoop(TestCase):
             a.value += x
             return x

-        r = bst.compile.while_loop(cond, body, 1.)
+        r = brainstate.compile.while_loop(cond, body, 1.)

         print(a.value, b.value, r)
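And the corresponding sketch for `brainstate.compile.while_loop`, mirroring `test2` above (again only a restatement of the test, not a spec): `cond` and `body` read and write `State` values, the third argument is the initial carry, and the final carry is returned.

    import brainstate

    a = brainstate.State(1.)
    b = brainstate.State(20.)

    def cond(x):
        return a.value < b.value

    def body(x):
        a.value += x
        return x

    r = brainstate.compile.while_loop(cond, body, 1.)  # loops until a >= b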
brainstate/compile/_make_jaxpr.py
@@ -62,8 +62,8 @@ import jax
 from jax._src import source_info_util
 from jax._src.linear_util import annotate
 from jax._src.traceback_util import api_boundary
-from jax.api_util import shaped_abstractify, debug_info
-from jax.extend.linear_util import transformation_with_aux, wrap_init
+from jax.api_util import shaped_abstractify
+from jax.extend.linear_util import transformation_with_aux
 from jax.interpreters import partial_eval as pe

 from brainstate._compatible_import import (
@@ -73,6 +73,7 @@ from brainstate._compatible_import import (
     safe_zip,
     unzip2,
     wraps,
+    wrap_init,
 )
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
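This import shuffle is the substantive part of the release: `wrap_init`, whose signature changed across recent JAX releases, now comes from brainstate's own compatibility layer instead of `jax.extend.linear_util`, and the `debug_info` import is dropped here. A hypothetical sketch of what such a shim can look like, built only from the two call forms visible in this diff (the real `_compatible_import.py` body, +3 lines per the file list above, is not shown):

    # Hypothetical compatibility shim -- not the actual brainstate source.
    import jax

    try:
        # Newer JAX: wrap_init requires an explicit debug_info.
        from jax.api_util import debug_info as _debug_info
        from jax.extend.linear_util import wrap_init as _wrap_init

        def wrap_init(fun, args, kwargs, name):
            return _wrap_init(fun, debug_info=_debug_info(name, fun, args, kwargs))
    except ImportError:
        # Older JAX: no debug_info machinery; wrap the function directly.
        from jax.extend.linear_util import wrap_init as _wrap_init

        def wrap_init(fun, args, kwargs, name):
            return _wrap_init(fun)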
@@ -96,7 +97,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
     return tuple(safe_map(operator.index, x))


-def _new_arg_fn(frame, trace, aval):
+def _jax_v04_new_arg_fn(frame, trace, aval):
     """
     Transform a new argument to a tracer.

@@ -117,27 +118,41 @@ def _new_arg_fn(frame, trace, aval):
     return tracer


-def _new_jax_trace():
+def _jax_v04_new_jax_trace():
     main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
     frame = main.jaxpr_stack[-1]
     trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
     return frame, trace


+def _jax_v04_new_arg():
+    # Should be within the calling of ``jax.make_jaxpr()``
+    frame, trace = _jax_v04_new_jax_trace()
+    # Set the function to transform the new argument to a tracer
+    fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
+    return fn
+
+
+def _jax_new_version_new_arg():
+    trace = jax.core.trace_ctx.trace
+
+    def wrapper(x):
+        if jax.__version_info__ < (0, 6, 1):
+            return trace.new_arg(shaped_abstractify(x))
+        else:
+            return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
+
+    return wrapper
+
+
 def _init_state_trace_stack(name) -> StateTraceStack:
     state_trace: StateTraceStack = StateTraceStack(name=name)

     if jax.__version_info__ < (0, 4, 36):
-        # Should be within the calling of ``jax.make_jaxpr()``
-        frame, trace = _new_jax_trace()
-        # Set the function to transform the new argument to a tracer
-        state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
-        return state_trace
-
+        state_trace.set_new_arg(_jax_v04_new_arg())
     else:
-        trace = jax.core.trace_ctx.trace
-        state_trace.set_new_arg(trace.new_arg)
-        return state_trace
+        state_trace.set_new_arg(_jax_new_version_new_arg())
+    return state_trace


 class StatefulFunction(PrettyObject):
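The behavioral change hiding in this refactor: on JAX ≥ 0.4.36 the code previously handed `trace.new_arg` to the state trace directly, but from JAX 0.6.1 `new_arg` also expects a `source_info` argument, so the new `wrapper` closes over the version check. Condensed from the hunk above, the two call forms it reconciles are:

    aval = shaped_abstractify(x)
    if jax.__version_info__ < (0, 6, 1):
        tracer = trace.new_arg(aval)
    else:
        tracer = trace.new_arg(aval, source_info=source_info_util.current())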
@@ -743,7 +758,7 @@ def _make_jaxpr(
     @wraps(fun)
     @api_boundary
     def make_jaxpr_f(*args, **kwargs):
-        f = wrap_init(fun, debug_info=debug_info('make_jaxpr', fun, args, kwargs))
+        f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
         if static_argnums:
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
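At the call site the change is then a one-liner: instead of constructing JAX's debug info inline (which needs `jax.api_util.debug_info`, an import that may not exist on older JAX releases), the function passes the raw pieces to the shim. Both lines appear verbatim in the hunk above:

    # before: tied to newer JAX
    f = wrap_init(fun, debug_info=debug_info('make_jaxpr', fun, args, kwargs))
    # after: version handling delegated to brainstate._compatible_import
    f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')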
brainstate/compile/_make_jaxpr_test.py
@@ -21,7 +21,7 @@ import jax
 import jax.numpy as jnp
 import pytest

-import brainstate as bst
+import brainstate
 from brainstate._compatible_import import jaxpr_as_fun


@@ -29,10 +29,10 @@ class TestMakeJaxpr(unittest.TestCase):
     def test_compar_jax_make_jaxpr(self):
         def func4(arg): # Arg is a pair
             temp = arg[0] + jnp.sin(arg[1]) * 3.
-            c = bst.random.rand_like(arg[0])
+            c = brainstate.random.rand_like(arg[0])
             return jnp.sum(temp + c)

-        key = bst.random.DEFAULT.value
+        key = brainstate.random.DEFAULT.value
         jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
         print(jaxpr)
         self.assertTrue(len(jaxpr.in_avals) == 2)
@@ -40,66 +40,66 @@ class TestMakeJaxpr(unittest.TestCase):
         self.assertTrue(len(jaxpr.out_avals) == 1)
         self.assertTrue(jnp.allclose(jaxpr.consts[0], key))

-        bst.random.seed(1)
-        print(bst.random.DEFAULT.value)
+        brainstate.random.seed(1)
+        print(brainstate.random.DEFAULT.value)

-        jaxpr2, states = bst.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
+        jaxpr2, states = brainstate.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
         print(jaxpr2)
         self.assertTrue(len(jaxpr2.in_avals) == 3)
         self.assertTrue(len(jaxpr2.out_avals) == 2)
         self.assertTrue(len(jaxpr2.consts) == 0)
-        print(bst.random.DEFAULT.value)
+        print(brainstate.random.DEFAULT.value)

     def test_StatefulFunction_1(self):
         def func4(arg): # Arg is a pair
             temp = arg[0] + jnp.sin(arg[1]) * 3.
-            c = bst.random.rand_like(arg[0])
+            c = brainstate.random.rand_like(arg[0])
             return jnp.sum(temp + c)

-        fun = bst.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
+        fun = brainstate.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
         print(fun.get_states())
         print(fun.get_jaxpr())

     def test_StatefulFunction_2(self):
-        st1 = bst.State(jnp.ones(10))
+        st1 = brainstate.State(jnp.ones(10))

         def f1(x):
             st1.value = x + st1.value

         def f2(x):
-            jaxpr = bst.compile.make_jaxpr(f1)(x)
+            jaxpr = brainstate.compile.make_jaxpr(f1)(x)
             c = 1. + x
             return c

         def f3(x):
-            jaxpr = bst.compile.make_jaxpr(f1)(x)
+            jaxpr = brainstate.compile.make_jaxpr(f1)(x)
             c = 1.
             return c

         print()
-        jaxpr = bst.compile.make_jaxpr(f1)(jnp.zeros(1))
+        jaxpr = brainstate.compile.make_jaxpr(f1)(jnp.zeros(1))
         print(jaxpr)
         jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
         print(jaxpr)
         jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
         print(jaxpr)
-        jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
+        jaxpr, _ = brainstate.compile.make_jaxpr(f3)(jnp.zeros(1))
         print(jaxpr)
         self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                      f3(jnp.zeros(1))))

     def test_compar_jax_make_jaxpr2(self):
-        st1 = bst.State(jnp.ones(10))
+        st1 = brainstate.State(jnp.ones(10))

         def fa(x):
             st1.value = x + st1.value

         def ffa(x):
-            jaxpr, states = bst.compile.make_jaxpr(fa)(x)
+            jaxpr, states = brainstate.compile.make_jaxpr(fa)(x)
             c = 1. + x
             return c

-        jaxpr, states = bst.compile.make_jaxpr(ffa)(jnp.zeros(1))
+        jaxpr, states = brainstate.compile.make_jaxpr(ffa)(jnp.zeros(1))
         print()
         print(jaxpr)
         print(states)
@@ -112,7 +112,7 @@ class TestMakeJaxpr(unittest.TestCase):
     def fa(x):
         return 1.

-    jaxpr, states = bst.compile.make_jaxpr(fa)(jnp.zeros(1))
+    jaxpr, states = brainstate.compile.make_jaxpr(fa)(jnp.zeros(1))
     print()
     print(jaxpr)
     print(states)
@@ -125,9 +125,9 @@ class TestMakeJaxpr(unittest.TestCase):
 def test_return_states():
     import jax.numpy

-    a = bst.State(jax.numpy.ones(3))
+    a = brainstate.State(jax.numpy.ones(3))

-    @bst.compile.jit
+    @brainstate.compile.jit
     def f():
         return a
brainstate/functional/_activations_test.py
@@ -25,48 +25,48 @@ from absl.testing import parameterized
 from jax._src import test_util as jtu
 from jax.test_util import check_grads

-import brainstate as bst
+import brainstate


 class NNFunctionsTest(jtu.JaxTestCase):
     @jtu.skip_on_flag("jax_skip_slow_tests", True)
     def testSoftplusGrad(self):
-        check_grads(bst.functional.softplus, (1e-8,), order=4, )
+        check_grads(brainstate.functional.softplus, (1e-8,), order=4, )

     def testSoftplusGradZero(self):
-        check_grads(bst.functional.softplus, (0.,), order=1)
+        check_grads(brainstate.functional.softplus, (0.,), order=1)

     def testSoftplusGradInf(self):
-        self.assertAllClose(1., jax.grad(bst.functional.softplus)(float('inf')))
+        self.assertAllClose(1., jax.grad(brainstate.functional.softplus)(float('inf')))

     def testSoftplusGradNegInf(self):
-        check_grads(bst.functional.softplus, (-float('inf'),), order=1)
+        check_grads(brainstate.functional.softplus, (-float('inf'),), order=1)

     def testSoftplusGradNan(self):
-        check_grads(bst.functional.softplus, (float('nan'),), order=1)
+        check_grads(brainstate.functional.softplus, (float('nan'),), order=1)

     @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
     def testSoftplusZero(self, dtype):
-        self.assertEqual(jnp.log(dtype(2)), bst.functional.softplus(dtype(0)))
+        self.assertEqual(jnp.log(dtype(2)), brainstate.functional.softplus(dtype(0)))

     def testSparseplusGradZero(self):
-        check_grads(bst.functional.sparse_plus, (-2.,), order=1)
+        check_grads(brainstate.functional.sparse_plus, (-2.,), order=1)

     def testSparseplusGrad(self):
-        check_grads(bst.functional.sparse_plus, (0.,), order=1)
+        check_grads(brainstate.functional.sparse_plus, (0.,), order=1)

     def testSparseplusAndSparseSigmoid(self):
         self.assertAllClose(
-            jax.grad(bst.functional.sparse_plus)(0.),
-            bst.functional.sparse_sigmoid(0.),
+            jax.grad(brainstate.functional.sparse_plus)(0.),
+            brainstate.functional.sparse_sigmoid(0.),
             check_dtypes=False)
         self.assertAllClose(
-            jax.grad(bst.functional.sparse_plus)(2.),
-            bst.functional.sparse_sigmoid(2.),
+            jax.grad(brainstate.functional.sparse_plus)(2.),
+            brainstate.functional.sparse_sigmoid(2.),
             check_dtypes=False)
         self.assertAllClose(
-            jax.grad(bst.functional.sparse_plus)(-2.),
-            bst.functional.sparse_sigmoid(-2.),
+            jax.grad(brainstate.functional.sparse_plus)(-2.),
+            brainstate.functional.sparse_sigmoid(-2.),
             check_dtypes=False)

     # def testSquareplusGrad(self):
@@ -107,55 +107,55 @@ class NNFunctionsTest(jtu.JaxTestCase):

     @parameterized.parameters([float] + jtu.dtypes.floating)
     def testMishZero(self, dtype):
-        self.assertEqual(dtype(0), bst.functional.mish(dtype(0)))
+        self.assertEqual(dtype(0), brainstate.functional.mish(dtype(0)))

     def testReluGrad(self):
         rtol = None
-        check_grads(bst.functional.relu, (1.,), order=3, rtol=rtol)
-        check_grads(bst.functional.relu, (-1.,), order=3, rtol=rtol)
-        jaxpr = jax.make_jaxpr(jax.grad(bst.functional.relu))(0.)
+        check_grads(brainstate.functional.relu, (1.,), order=3, rtol=rtol)
+        check_grads(brainstate.functional.relu, (-1.,), order=3, rtol=rtol)
+        jaxpr = jax.make_jaxpr(jax.grad(brainstate.functional.relu))(0.)
         self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)

     def testRelu6Grad(self):
         rtol = None
-        check_grads(bst.functional.relu6, (1.,), order=3, rtol=rtol)
-        check_grads(bst.functional.relu6, (-1.,), order=3, rtol=rtol)
-        self.assertAllClose(jax.grad(bst.functional.relu6)(0.), 0., check_dtypes=False)
-        self.assertAllClose(jax.grad(bst.functional.relu6)(6.), 0., check_dtypes=False)
+        check_grads(brainstate.functional.relu6, (1.,), order=3, rtol=rtol)
+        check_grads(brainstate.functional.relu6, (-1.,), order=3, rtol=rtol)
+        self.assertAllClose(jax.grad(brainstate.functional.relu6)(0.), 0., check_dtypes=False)
+        self.assertAllClose(jax.grad(brainstate.functional.relu6)(6.), 0., check_dtypes=False)

     def testSoftplusValue(self):
-        val = bst.functional.softplus(89.)
+        val = brainstate.functional.softplus(89.)
         self.assertAllClose(val, 89., check_dtypes=False)

     def testSparseplusValue(self):
-        val = bst.functional.sparse_plus(89.)
+        val = brainstate.functional.sparse_plus(89.)
         self.assertAllClose(val, 89., check_dtypes=False)

     def testSparsesigmoidValue(self):
-        self.assertAllClose(bst.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
-        self.assertAllClose(bst.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
-        self.assertAllClose(bst.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)

     # def testSquareplusValue(self):
     #     val = bst.functional.squareplus(1e3)
     #     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

     def testMishValue(self):
-        val = bst.functional.mish(1e3)
+        val = brainstate.functional.mish(1e3)
         self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

     def testEluValue(self):
-        val = bst.functional.elu(1e4)
+        val = brainstate.functional.elu(1e4)
         self.assertAllClose(val, 1e4, check_dtypes=False)

     def testGluValue(self):
-        val = bst.functional.glu(jnp.array([1.0, 0.0]), axis=0)
+        val = brainstate.functional.glu(jnp.array([1.0, 0.0]), axis=0)
         self.assertAllClose(val, jnp.array([0.5]))

     @parameterized.parameters(False, True)
     def testGeluIntType(self, approximate):
-        val_float = bst.functional.gelu(jnp.array(-1.0), approximate=approximate)
-        val_int = bst.functional.gelu(jnp.array(-1), approximate=approximate)
+        val_float = brainstate.functional.gelu(jnp.array(-1.0), approximate=approximate)
+        val_int = brainstate.functional.gelu(jnp.array(-1), approximate=approximate)
         self.assertAllClose(val_float, val_int)

     @parameterized.parameters(False, True)
@@ -166,19 +166,19 @@ class NNFunctionsTest(jtu.JaxTestCase):
         rng = jtu.rand_default(self.rng())
         args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
         self._CheckAgainstNumpy(
-            gelu_reference, partial(bst.functional.gelu, approximate=approximate), args_maker,
+            gelu_reference, partial(brainstate.functional.gelu, approximate=approximate), args_maker,
             check_dtypes=False, tol=1e-3 if approximate else None)

     @parameterized.parameters(*itertools.product(
         (jnp.float32, jnp.bfloat16, jnp.float16),
-        (partial(bst.functional.gelu, approximate=False),
-         partial(bst.functional.gelu, approximate=True),
-         bst.functional.relu,
-         bst.functional.softplus,
-         bst.functional.sparse_plus,
-         bst.functional.sigmoid,
+        (partial(brainstate.functional.gelu, approximate=False),
+         partial(brainstate.functional.gelu, approximate=True),
+         brainstate.functional.relu,
+         brainstate.functional.softplus,
+         brainstate.functional.sparse_plus,
+         brainstate.functional.sigmoid,
          # bst.functional.squareplus,
-         bst.functional.mish)))
+         brainstate.functional.mish)))
     def testDtypeMatchesInput(self, dtype, fn):
         x = jnp.zeros((), dtype=dtype)
         out = fn(x)
@@ -187,26 +187,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
     def testEluMemory(self):
         # see https://github.com/google/jax/pull/1640
         with jax.enable_checks(False): # With checks we materialize the array
-            jax.make_jaxpr(lambda: bst.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
+            jax.make_jaxpr(lambda: brainstate.functional.elu(jnp.ones((10 ** 12,)))) # don't oom

     def testHardTanhMemory(self):
         # see https://github.com/google/jax/pull/1640
         with jax.enable_checks(False): # With checks we materialize the array
-            jax.make_jaxpr(lambda: bst.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
+            jax.make_jaxpr(lambda: brainstate.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom

-    @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxEmptyArray(self, fn):
         x = jnp.array([], dtype=float)
         self.assertArraysEqual(fn(x), x)

-    @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxEmptyMask(self, fn):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.zeros_like(x, dtype=bool)
-        expected = jnp.full_like(x, 0.0 if fn is bst.functional.softmax else -jnp.inf)
+        expected = jnp.full_like(x, 0.0 if fn is brainstate.functional.softmax else -jnp.inf)
         self.assertArraysEqual(fn(x, where=m), expected)

-    @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxWhereMask(self, fn):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.array([True, False, True, True])
@@ -214,10 +214,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
         out = fn(x, where=m)
         self.assertAllClose(out[m], fn(x[m]))

-        probs = out if fn is bst.functional.softmax else jnp.exp(out)
+        probs = out if fn is brainstate.functional.softmax else jnp.exp(out)
         self.assertAllClose(probs.sum(), 1.0)

-    @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxWhereGrad(self, fn):
         # regression test for https://github.com/google/jax/issues/19490
         x = jnp.array([36., 10000.])
@@ -229,46 +229,46 @@ class NNFunctionsTest(jtu.JaxTestCase):

     def testSoftmaxGrad(self):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
-        jtu.check_grads(bst.functional.softmax, (x,), order=2, atol=5e-3)
+        jtu.check_grads(brainstate.functional.softmax, (x,), order=2, atol=5e-3)

     def testStandardizeWhereMask(self):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.array([True, False, True, True])
         x_filtered = jnp.take(x, jnp.array([0, 2, 3]))

-        out_masked = jnp.take(bst.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
-        out_filtered = bst.functional.standardize(x_filtered)
+        out_masked = jnp.take(brainstate.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
+        out_filtered = brainstate.functional.standardize(x_filtered)

         self.assertAllClose(out_masked, out_filtered)

     def testOneHot(self):
-        actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3)
+        actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)
         expected = jnp.array([[1., 0., 0.],
                               [0., 1., 0.],
                               [0., 0., 1.]])
         self.assertAllClose(actual, expected, check_dtypes=False)

-        actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3)
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3)
         expected = jnp.array([[0., 1., 0.],
                               [0., 0., 1.],
                               [1., 0., 0.]])
         self.assertAllClose(actual, expected, check_dtypes=False)

     def testOneHotOutOfBound(self):
-        actual = bst.functional.one_hot(jnp.array([-1, 3]), 3)
+        actual = brainstate.functional.one_hot(jnp.array([-1, 3]), 3)
         expected = jnp.array([[0., 0., 0.],
                               [0., 0., 0.]])
         self.assertAllClose(actual, expected, check_dtypes=False)

     def testOneHotNonArrayInput(self):
-        actual = bst.functional.one_hot([0, 1, 2], 3)
+        actual = brainstate.functional.one_hot([0, 1, 2], 3)
         expected = jnp.array([[1., 0., 0.],
                               [0., 1., 0.],
                               [0., 0., 1.]])
         self.assertAllClose(actual, expected, check_dtypes=False)

     def testOneHotCustomDtype(self):
-        actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
+        actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
         expected = jnp.array([[True, False, False],
                               [False, True, False],
                               [False, False, True]])
@@ -279,14 +279,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
                               [0., 0., 1.],
                               [1., 0., 0.]]).T

-        actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
         self.assertAllClose(actual, expected, check_dtypes=False)

-        actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
         self.assertAllClose(actual, expected, check_dtypes=False)

     def testTanhExists(self):
-        print(bst.functional.tanh) # doesn't crash
+        print(brainstate.functional.tanh) # doesn't crash

     def testCustomJVPLeak(self):
         # https://github.com/google/jax/issues/8171
@@ -295,7 +295,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
         a = jnp.array(1.)

         def f(hx, _):
-            hx = bst.functional.sigmoid(hx + a)
+            hx = brainstate.functional.sigmoid(hx + a)
             return hx, None

         hx = jnp.array(0.)