brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111):
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -13,43 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax
21
19
  import jax.numpy as jnp
22
- import jaxlib.xla_extension
23
20
 
24
- import brainstate as bst
21
+ import brainstate
25
22
 
26
23
 
27
24
  class TestJitError(unittest.TestCase):
28
25
  def test1(self):
29
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
30
- bst.compile.jit_error_if(True, 'error')
26
+ with self.assertRaises(Exception):
27
+ brainstate.compile.jit_error_if(True, 'error')
31
28
 
32
29
  def err_f(x):
33
30
  raise ValueError(f'error: {x}')
34
31
 
35
- bst.compile.jit_error_if(False, err_f, 1.)
36
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
37
- bst.compile.jit_error_if(True, err_f, 1.)
32
+ brainstate.compile.jit_error_if(False, err_f, 1.)
33
+ with self.assertRaises(Exception):
34
+ brainstate.compile.jit_error_if(True, err_f, 1.)
38
35
 
39
36
  def test_vmap(self):
40
37
  def f(x):
41
- bst.compile.jit_error_if(x, 'error: {x}', x=x)
38
+ brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
42
39
 
43
40
  jax.vmap(f)(jnp.array([False, False, False]))
44
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
41
+ with self.assertRaises(Exception):
45
42
  jax.vmap(f)(jnp.array([True, False, False]))
46
43
 
47
44
  def test_vmap_vmap(self):
48
45
  def f(x):
49
- bst.compile.jit_error_if(x, 'error: {x}', x=x)
46
+ brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
50
47
 
51
48
  jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
52
49
  [False, False, False]]))
53
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
50
+ with self.assertRaises(Exception):
54
51
  jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
55
52
  [True, False, False]]))
@@ -13,16 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from collections.abc import Iterable, Sequence
20
18
  from typing import (Any, Callable, Union)
21
19
 
22
20
  import jax
23
21
  from jax._src import sharding_impls
24
- from jax.lib import xla_client as xc
25
22
 
23
+ from brainstate._compatible_import import Device
26
24
  from brainstate._utils import set_module_as
27
25
  from brainstate.typing import Missing
28
26
  from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
@@ -94,8 +92,8 @@ def _get_jitted_fun(
94
92
  read_state_vals = state_trace.get_read_state_values(True)
95
93
 
96
94
  # call the jitted function
97
- # print('Running ...')
98
95
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
96
+
99
97
  # write the state values back to the states
100
98
  write_back_state_values(state_trace, read_state_vals, write_state_vals)
101
99
  return outs
@@ -106,8 +104,11 @@ def _get_jitted_fun(
106
104
  """
107
105
  # clear the cache of the stateful function
108
106
  fun.clear_cache()
109
- # clear the cache of the jitted function
110
- jit_fun.clear_cache()
107
+ try:
108
+ # clear the cache of the jitted function
109
+ jit_fun.clear_cache()
110
+ except AttributeError:
111
+ pass
111
112
 
112
113
  def eval_shape():
113
114
  raise NotImplementedError
@@ -165,7 +166,7 @@ def _get_jitted_fun(
165
166
  # compile the jitted function
166
167
  jitted_fun.compile = compile
167
168
 
168
- # trace the jitted
169
+ # trace the jitted function
169
170
  jitted_fun.trace = trace
170
171
 
171
172
  return jitted_fun
@@ -180,7 +181,7 @@ def jit(
180
181
  donate_argnums: int | Sequence[int] | None = None,
181
182
  donate_argnames: str | Iterable[str] | None = None,
182
183
  keep_unused: bool = False,
183
- device: xc.Device | None = None,
184
+ device: Device | None = None,
184
185
  backend: str | None = None,
185
186
  inline: bool = False,
186
187
  abstracted_axes: Any | None = None,
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import math
19
17
  from functools import wraps
20
18
  from typing import Callable, Optional, TypeVar, Tuple, Any
@@ -13,20 +13,18 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax.numpy as jnp
21
19
  import numpy as np
22
20
 
23
- import brainstate as bst
21
+ import brainstate
24
22
 
25
23
 
26
24
  class TestForLoop(unittest.TestCase):
27
25
  def test_for_loop(self):
28
- a = bst.ShortTermState(0.)
29
- b = bst.ShortTermState(0.)
26
+ a = brainstate.ShortTermState(0.)
27
+ b = brainstate.ShortTermState(0.)
30
28
 
31
29
  def f(i):
32
30
  a.value += (1 + b.value)
@@ -34,7 +32,7 @@ class TestForLoop(unittest.TestCase):
34
32
 
35
33
  n_iter = 10
36
34
  ops = np.arange(n_iter)
37
- r = bst.compile.for_loop(f, ops)
35
+ r = brainstate.compile.for_loop(f, ops)
38
36
 
39
37
  print(a)
40
38
  print(b)
@@ -42,8 +40,8 @@ class TestForLoop(unittest.TestCase):
42
40
  self.assertTrue(jnp.allclose(r, ops + 1))
43
41
 
44
42
  def test_checkpointed_for_loop(self):
45
- a = bst.ShortTermState(0.)
46
- b = bst.ShortTermState(0.)
43
+ a = brainstate.ShortTermState(0.)
44
+ b = brainstate.ShortTermState(0.)
47
45
 
48
46
  def f(i):
49
47
  a.value += (1 + b.value)
@@ -51,7 +49,7 @@ class TestForLoop(unittest.TestCase):
51
49
 
52
50
  n_iter = 18
53
51
  ops = jnp.arange(n_iter)
54
- r = bst.compile.checkpointed_for_loop(f, ops, base=2, pbar=bst.compile.ProgressBar())
52
+ r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())
55
53
 
56
54
  print(a)
57
55
  print(b)
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import math
19
17
  from typing import Any, Callable, TypeVar
20
18
 
@@ -13,17 +13,16 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  from unittest import TestCase
19
18
 
20
- import brainstate as bst
19
+ import brainstate
21
20
 
22
21
 
23
22
  class TestWhileLoop(TestCase):
24
23
  def test1(self):
25
- a = bst.State(1.)
26
- b = bst.State(20.)
24
+ a = brainstate.State(1.)
25
+ b = brainstate.State(20.)
27
26
 
28
27
  def cond(_):
29
28
  return a.value < b.value
@@ -31,13 +30,13 @@ class TestWhileLoop(TestCase):
31
30
  def body(_):
32
31
  a.value += 1.
33
32
 
34
- bst.compile.while_loop(cond, body, None)
33
+ brainstate.compile.while_loop(cond, body, None)
35
34
 
36
35
  print(a.value, b.value)
37
36
 
38
37
  def test2(self):
39
- a = bst.State(1.)
40
- b = bst.State(20.)
38
+ a = brainstate.State(1.)
39
+ b = brainstate.State(20.)
41
40
 
42
41
  def cond(x):
43
42
  return a.value < b.value
@@ -46,6 +45,6 @@ class TestWhileLoop(TestCase):
46
45
  a.value += x
47
46
  return x
48
47
 
49
- r = bst.compile.while_loop(cond, body, 1.)
48
+ r = brainstate.compile.while_loop(cond, body, 1.)
50
49
 
51
50
  print(a.value, b.value, r)
@@ -51,8 +51,6 @@ function.
51
51
 
52
52
  """
53
53
 
54
- from __future__ import annotations
55
-
56
54
  import functools
57
55
  import inspect
58
56
  import operator
@@ -65,7 +63,7 @@ from jax._src import source_info_util
65
63
  from jax._src.linear_util import annotate
66
64
  from jax._src.traceback_util import api_boundary
67
65
  from jax.api_util import shaped_abstractify
68
- from jax.extend.linear_util import transformation_with_aux, wrap_init
66
+ from jax.extend.linear_util import transformation_with_aux
69
67
  from jax.interpreters import partial_eval as pe
70
68
 
71
69
  from brainstate._compatible_import import (
@@ -75,6 +73,7 @@ from brainstate._compatible_import import (
75
73
  safe_zip,
76
74
  unzip2,
77
75
  wraps,
76
+ wrap_init,
78
77
  )
79
78
  from brainstate._state import State, StateTraceStack
80
79
  from brainstate._utils import set_module_as
@@ -98,7 +97,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
98
97
  return tuple(safe_map(operator.index, x))
99
98
 
100
99
 
101
- def _new_arg_fn(frame, trace, aval):
100
+ def _jax_v04_new_arg_fn(frame, trace, aval):
102
101
  """
103
102
  Transform a new argument to a tracer.
104
103
 
@@ -119,27 +118,41 @@ def _new_arg_fn(frame, trace, aval):
119
118
  return tracer
120
119
 
121
120
 
122
- def _new_jax_trace():
121
+ def _jax_v04_new_jax_trace():
123
122
  main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
124
123
  frame = main.jaxpr_stack[-1]
125
124
  trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
126
125
  return frame, trace
127
126
 
128
127
 
128
+ def _jax_v04_new_arg():
129
+ # Should be within the calling of ``jax.make_jaxpr()``
130
+ frame, trace = _jax_v04_new_jax_trace()
131
+ # Set the function to transform the new argument to a tracer
132
+ fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
133
+ return fn
134
+
135
+
136
+ def _jax_new_version_new_arg():
137
+ trace = jax.core.trace_ctx.trace
138
+
139
+ def wrapper(x):
140
+ if jax.__version_info__ < (0, 6, 1):
141
+ return trace.new_arg(shaped_abstractify(x))
142
+ else:
143
+ return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
144
+
145
+ return wrapper
146
+
147
+
129
148
  def _init_state_trace_stack(name) -> StateTraceStack:
130
149
  state_trace: StateTraceStack = StateTraceStack(name=name)
131
150
 
132
151
  if jax.__version_info__ < (0, 4, 36):
133
- # Should be within the calling of ``jax.make_jaxpr()``
134
- frame, trace = _new_jax_trace()
135
- # Set the function to transform the new argument to a tracer
136
- state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
137
- return state_trace
138
-
152
+ state_trace.set_new_arg(_jax_v04_new_arg())
139
153
  else:
140
- trace = jax.core.trace_ctx.trace
141
- state_trace.set_new_arg(trace.new_arg)
142
- return state_trace
154
+ state_trace.set_new_arg(_jax_new_version_new_arg())
155
+ return state_trace
143
156
 
144
157
 
145
158
  class StatefulFunction(PrettyObject):
@@ -745,7 +758,7 @@ def _make_jaxpr(
745
758
  @wraps(fun)
746
759
  @api_boundary
747
760
  def make_jaxpr_f(*args, **kwargs):
748
- f = wrap_init(fun)
761
+ f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
749
762
  if static_argnums:
750
763
  dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
751
764
  f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
@@ -754,12 +767,12 @@ def _make_jaxpr(
754
767
  f, out_tree = _flatten_fun(f, in_tree)
755
768
  f = annotate(f, in_type)
756
769
  if jax.__version_info__ < (0, 5, 0):
757
- debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
770
+ debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
758
771
  with ExitStack() as stack:
759
772
  if axis_env is not None:
760
773
  stack.enter_context(extend_axis_env_nd(axis_env))
761
774
  if jax.__version_info__ < (0, 5, 0):
762
- jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
775
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
763
776
  else:
764
777
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
765
778
  closed_jaxpr = ClosedJaxpr(jaxpr, consts)
@@ -21,7 +21,7 @@ import jax
21
21
  import jax.numpy as jnp
22
22
  import pytest
23
23
 
24
- import brainstate as bst
24
+ import brainstate
25
25
  from brainstate._compatible_import import jaxpr_as_fun
26
26
 
27
27
 
@@ -29,10 +29,10 @@ class TestMakeJaxpr(unittest.TestCase):
29
29
  def test_compar_jax_make_jaxpr(self):
30
30
  def func4(arg): # Arg is a pair
31
31
  temp = arg[0] + jnp.sin(arg[1]) * 3.
32
- c = bst.random.rand_like(arg[0])
32
+ c = brainstate.random.rand_like(arg[0])
33
33
  return jnp.sum(temp + c)
34
34
 
35
- key = bst.random.DEFAULT.value
35
+ key = brainstate.random.DEFAULT.value
36
36
  jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
37
37
  print(jaxpr)
38
38
  self.assertTrue(len(jaxpr.in_avals) == 2)
@@ -40,66 +40,66 @@ class TestMakeJaxpr(unittest.TestCase):
40
40
  self.assertTrue(len(jaxpr.out_avals) == 1)
41
41
  self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
42
42
 
43
- bst.random.seed(1)
44
- print(bst.random.DEFAULT.value)
43
+ brainstate.random.seed(1)
44
+ print(brainstate.random.DEFAULT.value)
45
45
 
46
- jaxpr2, states = bst.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
46
+ jaxpr2, states = brainstate.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
47
47
  print(jaxpr2)
48
48
  self.assertTrue(len(jaxpr2.in_avals) == 3)
49
49
  self.assertTrue(len(jaxpr2.out_avals) == 2)
50
50
  self.assertTrue(len(jaxpr2.consts) == 0)
51
- print(bst.random.DEFAULT.value)
51
+ print(brainstate.random.DEFAULT.value)
52
52
 
53
53
  def test_StatefulFunction_1(self):
54
54
  def func4(arg): # Arg is a pair
55
55
  temp = arg[0] + jnp.sin(arg[1]) * 3.
56
- c = bst.random.rand_like(arg[0])
56
+ c = brainstate.random.rand_like(arg[0])
57
57
  return jnp.sum(temp + c)
58
58
 
59
- fun = bst.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
59
+ fun = brainstate.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
60
60
  print(fun.get_states())
61
61
  print(fun.get_jaxpr())
62
62
 
63
63
  def test_StatefulFunction_2(self):
64
- st1 = bst.State(jnp.ones(10))
64
+ st1 = brainstate.State(jnp.ones(10))
65
65
 
66
66
  def f1(x):
67
67
  st1.value = x + st1.value
68
68
 
69
69
  def f2(x):
70
- jaxpr = bst.compile.make_jaxpr(f1)(x)
70
+ jaxpr = brainstate.compile.make_jaxpr(f1)(x)
71
71
  c = 1. + x
72
72
  return c
73
73
 
74
74
  def f3(x):
75
- jaxpr = bst.compile.make_jaxpr(f1)(x)
75
+ jaxpr = brainstate.compile.make_jaxpr(f1)(x)
76
76
  c = 1.
77
77
  return c
78
78
 
79
79
  print()
80
- jaxpr = bst.compile.make_jaxpr(f1)(jnp.zeros(1))
80
+ jaxpr = brainstate.compile.make_jaxpr(f1)(jnp.zeros(1))
81
81
  print(jaxpr)
82
82
  jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
83
83
  print(jaxpr)
84
84
  jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
85
85
  print(jaxpr)
86
- jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
86
+ jaxpr, _ = brainstate.compile.make_jaxpr(f3)(jnp.zeros(1))
87
87
  print(jaxpr)
88
88
  self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
89
89
  f3(jnp.zeros(1))))
90
90
 
91
91
  def test_compar_jax_make_jaxpr2(self):
92
- st1 = bst.State(jnp.ones(10))
92
+ st1 = brainstate.State(jnp.ones(10))
93
93
 
94
94
  def fa(x):
95
95
  st1.value = x + st1.value
96
96
 
97
97
  def ffa(x):
98
- jaxpr, states = bst.compile.make_jaxpr(fa)(x)
98
+ jaxpr, states = brainstate.compile.make_jaxpr(fa)(x)
99
99
  c = 1. + x
100
100
  return c
101
101
 
102
- jaxpr, states = bst.compile.make_jaxpr(ffa)(jnp.zeros(1))
102
+ jaxpr, states = brainstate.compile.make_jaxpr(ffa)(jnp.zeros(1))
103
103
  print()
104
104
  print(jaxpr)
105
105
  print(states)
@@ -112,7 +112,7 @@ class TestMakeJaxpr(unittest.TestCase):
112
112
  def fa(x):
113
113
  return 1.
114
114
 
115
- jaxpr, states = bst.compile.make_jaxpr(fa)(jnp.zeros(1))
115
+ jaxpr, states = brainstate.compile.make_jaxpr(fa)(jnp.zeros(1))
116
116
  print()
117
117
  print(jaxpr)
118
118
  print(states)
@@ -125,9 +125,9 @@ class TestMakeJaxpr(unittest.TestCase):
125
125
  def test_return_states():
126
126
  import jax.numpy
127
127
 
128
- a = bst.State(jax.numpy.ones(3))
128
+ a = brainstate.State(jax.numpy.ones(3))
129
129
 
130
- @bst.compile.jit
130
+ @brainstate.compile.jit
131
131
  def f():
132
132
  return a
133
133
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import copy
19
18
  import importlib.util
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from __future__ import annotations
16
15
 
17
16
  import jax
18
17
  import jax.core
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from functools import wraps
19
17
  from typing import Sequence, Tuple
20
18
 
brainstate/environ.py CHANGED
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import contextlib
21
19
  import dataclasses
22
20
  import functools
@@ -18,8 +18,6 @@
18
18
  Shared neural network activations and other functions.
19
19
  """
20
20
 
21
- from __future__ import annotations
22
-
23
21
  from typing import Any, Union, Sequence
24
22
 
25
23
  import brainunit as u