brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +5 -2
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +1 -2
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,14 @@

 from __future__ import annotations

-import unittest
-
+import jax
 import jax.numpy as jnp
+import numpy as np
+import unittest

 import brainstate as bst
 from brainstate.augment._mapping import BatchAxisError
+from brainstate.augment._mapping import _remove_axis


 class TestVmap(unittest.TestCase):
@@ -99,6 +101,27 @@ class TestVmap(unittest.TestCase):
         )
         print(bst.random.DEFAULT)

+    def test_vmap_with_random_v3(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.ShortTermState(bst.random.randn(5))
+                self.b = bst.ShortTermState(bst.random.randn(5))
+                self.c = bst.State(bst.random.randn(1))
+
+            def __call__(self):
+                self.c.value = self.a.value * self.b.value
+                return self.c.value + bst.random.randn(1)
+
+        model = Model()
+        r2 = bst.augment.vmap(
+            model,
+            in_states=model.states(bst.ShortTermState),
+            out_states=model.c
+        )()
+        print(bst.random.DEFAULT)
+
     def test_vmap_with_random_2(self):
         class Model(bst.nn.Module):
             def __init__(self):
@@ -114,22 +137,11 @@ class TestVmap(unittest.TestCase):
                 self.c.value = self.a.value * self.b.value
                 return self.c.value + bst.random.randn(1)

-        model = Model()
-        with self.assertRaises(BatchAxisError):
-            r2 = bst.augment.vmap(
-                model,
-                in_states=model.states(bst.ShortTermState),
-                out_states=model.c
-            )(
-                bst.random.split_key(5)
-            )
-
         model = Model()
         r2 = bst.augment.vmap(
             model,
             in_states=model.states(bst.ShortTermState),
-            out_states=model.c,
-            rngs=model.rng,
+            out_states=model.c
         )(
             bst.random.split_key(5)
         )
@@ -154,24 +166,17 @@ class TestVmap(unittest.TestCase):
         print(model.weight.value_call(jnp.shape))
         print(model.weight.value)

-    def test_vmap_model(self):
-        model = bst.nn.Linear(2, 3)
-        model_id = id(model)
-        weight_id = id(model.weight)
-        print(id(model), id(model.weight))
-        x = jnp.ones((5, 2))
+    def test_vmap_states_and_input_1(self):
+        gru = bst.nn.GRUCell(2, 3)
+        gru.init_state(5)

-        @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
-        def forward(model, x):
-            self.assertTrue(id(model) == model_id)
-            self.assertTrue(id(model.weight) == weight_id)
-            print(id(model), id(model.weight))
-            return model(x)
+        @bst.augment.vmap(in_states=gru.states(bst.HiddenState))
+        def forward(x):
+            return gru(x)

-        y = forward(model, x)
-        print(y.shape)
-        print(model.weight.value_call(jnp.shape))
-        print(model.weight.value)
+        xs = bst.random.randn(5, 2)
+        y = forward(xs)
+        self.assertTrue(y.shape == (5, 3))

     def test_vmap_jit(self):
         class Foo(bst.nn.Module):
@@ -249,6 +254,16 @@ class TestVmap(unittest.TestCase):
         print(trace.get_write_states())
         print(trace.get_read_states())

+    def test_auto_rand_key_split(self):
+        def f():
+            return bst.random.rand(1)
+
+        res = bst.augment.vmap(f, axis_size=10)()
+        self.assertTrue(jnp.all(~(res[0] == res[1:])))
+
+        res2 = jax.vmap(f, axis_size=10)()
+        self.assertTrue(jnp.all((res2[0] == res2[1:])))
+

 class TestMap(unittest.TestCase):
     def test_map(self):
@@ -264,3 +279,72 @@ class TestMap(unittest.TestCase):
         self.assertTrue(jnp.allclose(r2, true_r))
         self.assertTrue(jnp.allclose(r3, true_r))
         self.assertTrue(jnp.allclose(r4, true_r))
+
+
+class TestRemoveAxis:
+
+    def test_remove_axis_2d_array_axis_0(self):
+        input_array = np.array([[1, 2, 3], [4, 5, 6]])
+        expected_output = np.array([1, 2, 3])
+
+        result = _remove_axis(input_array, axis=0)
+
+        np.testing.assert_array_equal(result, expected_output)
+
+    def test_remove_axis_3d_array(self):
+        # Create a 3D array
+        x = np.arange(24).reshape((2, 3, 4))
+
+        # Remove axis 1
+        result = _remove_axis(x, axis=1)
+
+        # Expected result: a 2D array with shape (2, 4)
+        expected = x[:, 0, :]
+
+        np.testing.assert_array_equal(result, expected)
+        assert result.shape == (2, 4)
+
+    def test_remove_axis_1d_array(self):
+        # Create a 1D array
+        x = np.array([1, 2, 3, 4, 5])
+
+        # Remove axis 0 (the only axis in a 1D array)
+        result = _remove_axis(x, axis=0)
+
+        # Check that the result is a scalar (0D array) and equal to the first element
+        assert np.isscalar(result), "Result should be a scalar"
+        assert result == 1, "Result should be equal to the first element of the input array"
+
+    def test_remove_axis_out_of_bounds(self):
+        x = jnp.array([[1, 2], [3, 4]])
+        with unittest.TestCase().assertRaises(IndexError):
+            _remove_axis(x, axis=2)
+
+    def test_remove_axis_negative(self):
+        x = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+        result = _remove_axis(x, -1)
+        expected = jnp.array([[1, 3], [5, 7]])
+        np.testing.assert_array_equal(result, expected)
+
+    def test_remove_axis_with_nan_and_inf(self):
+        x = jnp.array([[1.0, jnp.nan, 3.0], [4.0, 5.0, jnp.inf]])
+        result = _remove_axis(x, axis=0)
+        expected = jnp.array([1.0, jnp.nan, 3.0])
+        np.testing.assert_array_equal(result, expected)
+        assert jnp.isnan(result[1])
+
+    def test_remove_axis_different_dtypes(self):
+        # Test with integer array
+        int_array = jnp.array([[1, 2, 3], [4, 5, 6]])
+        int_result = _remove_axis(int_array, 0)
+        assert jnp.array_equal(int_result, jnp.array([1, 2, 3]))
+
+        # Test with float array
+        float_array = jnp.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
+        float_result = _remove_axis(float_array, 1)
+        assert jnp.allclose(float_result, jnp.array([1.1, 4.4]))
+
+        # Test with complex array
+        complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
+        complex_result = _remove_axis(complex_array, 0)
+        assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
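Judging from the assertions in these tests, `_remove_axis(x, axis)` selects index 0 along the given axis and drops that axis (for example, `x[:, 0, :]` when `axis=1`). A rough functional sketch of that behavior, for orientation only and not the helper's actual implementation:

    import numpy as np

    def remove_axis_sketch(x, axis):
        # Hypothetical stand-in for _remove_axis: take index 0 along `axis`,
        # dropping that axis; negative axes wrap around as in NumPy.
        axis = axis % x.ndim if axis < 0 else axis
        return x[(slice(None),) * axis + (0,)]

    x = np.arange(24).reshape((2, 3, 4))
    assert remove_axis_sketch(x, 1).shape == (2, 4)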
@@ -20,34 +20,62 @@ from typing import Callable, Sequence, Union

 from brainstate.random import DEFAULT, RandomState
 from brainstate.typing import Missing
+from brainstate.util import PrettyObject

 __all__ = [
     'restore_rngs'
 ]


-class RngRestore:
+class RngRestore(PrettyObject):
     """
     Backup and restore the random state of a sequence of RandomState instances.
+
+    This class provides functionality to save the current state of multiple
+    RandomState instances and later restore them to their saved states.
+
+    Attributes:
+        rngs (Sequence[RandomState]): A sequence of RandomState instances to manage.
+        rng_keys (list): A list to store the backed up random keys.
     """

     def __init__(self, rngs: Sequence[RandomState]):
+        """
+        Initialize the RngRestore instance.
+
+        Args:
+            rngs (Sequence[RandomState]): A sequence of RandomState instances
+                whose states will be managed.
+        """
         self.rngs: Sequence[RandomState] = rngs
         self.rng_keys = []

     def backup(self):
         """
         Backup the current random key of the RandomState instances.
+
+        This method saves the current value (state) of each RandomState
+        instance in the rngs sequence.
         """
         self.rng_keys = [rng.value for rng in self.rngs]

     def restore(self):
         """
         Restore the random key of the RandomState instances.
+
+        This method restores each RandomState instance to its previously
+        saved state. It raises an error if the number of saved keys doesn't
+        match the number of RandomState instances.
+
+        Raises:
+            ValueError: If the number of saved random keys does not match
+                the number of RandomState instances.
         """
+        if len(self.rng_keys) != len(self.rngs):
+            raise ValueError('The number of random keys does not match the number of random states.')
         for rng, key in zip(self.rngs, self.rng_keys):
             rng.restore_value(key)
-        self.rng_keys = []
+        self.rng_keys.clear()


 def _rng_backup(
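A minimal sketch of the backup/restore cycle documented above, assuming the module-private import path `brainstate.augment._random`, a seed-int constructor, and a `RandomState.rand` method that advances the key (all three are assumptions for illustration, not guarantees of this diff):

    from brainstate.random import RandomState
    from brainstate.augment._random import RngRestore  # private module, shown for illustration

    rng = RandomState(42)
    saver = RngRestore([rng])
    saver.backup()           # snapshot rng.value (the current key)
    first = rng.rand(3)      # advances the key
    saver.restore()          # write the saved key back; rng_keys is cleared
    second = rng.rand(3)     # repeats the same draw as `first`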
@@ -74,19 +102,45 @@ def restore_rngs(
     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
 ) -> Callable:
     """
-    Backup the current random state and restore it after the function call.
+    Decorator to backup and restore the random state before and after a function call.
+
+    This function can be used as a decorator or called directly. It ensures that the
+    random state of the specified RandomState instances is preserved across function calls,
+    which is useful for maintaining reproducibility in stochastic operations.

     Parameters
     ----------
     fn : Callable, optional
-        The function to be wrapped.
-    rngs : Union[RandomState, Sequence[RandomState]]
-        The random state to be backed up and restored. If not provided, the default RandomState instance will be used.
+        The function to be wrapped. If not provided, the decorator can be used
+        with parameters.
+    rngs : Union[RandomState, Sequence[RandomState]], optional
+        The random state(s) to be backed up and restored. This can be a single
+        RandomState instance or a sequence of RandomState instances. If not provided,
+        the default RandomState instance will be used.

     Returns
     -------
     Callable
-        The wrapped function.
+        If `fn` is provided, returns the wrapped function that will backup the
+        random state before execution and restore it afterwards.
+        If `fn` is not provided, returns a partial function that can be used as
+        a decorator with the specified `rngs`.
+
+    Raises
+    ------
+    AssertionError
+        If `rngs` is not a RandomState instance or a sequence of RandomState instances.
+
+    Examples
+    --------
+    >>> @restore_rngs
+    ... def my_random_function():
+    ...     return random.random()
+
+    >>> rng = RandomState(42)
+    >>> @restore_rngs(rngs=rng)
+    ... def another_random_function():
+    ...     return rng.random()
     """
     if isinstance(fn, Missing):
         return functools.partial(restore_rngs, rngs=rngs)
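A self-contained variant of the docstring example, assuming `restore_rngs` is re-exported as `brainstate.augment.restore_rngs` and that the default global `RandomState` is used when `rngs` is omitted (as the signature above indicates):

    import brainstate as bst

    @bst.augment.restore_rngs
    def noisy_draw():
        # the default RandomState's key is saved before the call
        # and written back afterwards
        return bst.random.rand(3)

    a = noisy_draw()
    b = noisy_draw()  # identical to `a`, because the saved key was restored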
@@ -16,9 +16,8 @@
 from __future__ import annotations

 import functools
-from typing import Callable, Tuple, Union
-
 import jax
+from typing import Callable, Tuple, Union

 from brainstate.typing import Missing
 from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
@@ -181,7 +180,7 @@ def checkpoint(
         return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)

     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
-    fun = StatefulFunction(fun, static_argnums=static_argnums)
+    fun = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')
     checkpointed_fun = jax.checkpoint(
         fun.jaxpr_call,
         prevent_cse=prevent_cse,
@@ -15,11 +15,10 @@

 from __future__ import annotations

-from collections.abc import Callable, Sequence
-
 import jax
 import jax.numpy as jnp
 import numpy as np
+from collections.abc import Callable, Sequence

 from brainstate._utils import set_module_as
 from ._error_if import jit_error_if
@@ -94,8 +93,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
         return false_fun(*operands)

     # evaluate jaxpr
-    stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
-    stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
+    stateful_true = StatefulFunction(true_fun, name='cond:true').make_jaxpr(*operands)
+    stateful_false = StatefulFunction(false_fun, name='conda:false').make_jaxpr(*operands)

     # state trace and state values
     state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
@@ -174,7 +173,7 @@ def switch(index, branches: Sequence[Callable], *operands):
         return branches[int(index)](*operands)

     # evaluate jaxpr
-    wrapped_branches = [StatefulFunction(branch) for branch in branches]
+    wrapped_branches = [StatefulFunction(branch, name='switch') for branch in branches]
     for wrapped_branch in wrapped_branches:
         wrapped_branch.make_jaxpr(*operands)

@@ -14,10 +14,9 @@
 # ==============================================================================
 from __future__ import annotations

-import unittest
-
 import jax
 import jax.numpy as jnp
+import unittest

 import brainstate as bst

@@ -16,11 +16,10 @@
 from __future__ import annotations

 import functools
+import jax
 from functools import partial
 from typing import Callable, Union

-import jax
-
 from brainstate._utils import set_module_as
 from ._unvmap import unvmap

@@ -15,11 +15,10 @@

 from __future__ import annotations

-import unittest
-
 import jax
 import jax.numpy as jnp
 import jaxlib.xla_extension
+import unittest

 import brainstate as bst

@@ -16,12 +16,11 @@
 from __future__ import annotations

 import functools
-from collections.abc import Iterable, Sequence
-from typing import (Any, Callable, Union)
-
 import jax
+from collections.abc import Iterable, Sequence
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
+from typing import (Any, Callable, Union)

 from brainstate._utils import set_module_as
 from brainstate.typing import Missing
@@ -62,19 +61,27 @@ def _get_jitted_fun(
     **kwargs
 ) -> JittedFunction:
     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
-    fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
-    jit_fun = jax.jit(fun.jaxpr_call,
-                      static_argnums=tuple(i + 1 for i in static_argnums),
-                      donate_argnums=donate_argnums,
-                      donate_argnames=donate_argnames,
-                      keep_unused=keep_unused,
-                      device=device,
-                      backend=backend,
-                      inline=inline,
-                      in_shardings=in_shardings,
-                      out_shardings=out_shardings,
-                      abstracted_axes=abstracted_axes,
-                      **kwargs)
+    fun = StatefulFunction(
+        fun,
+        static_argnums=static_argnums,
+        abstracted_axes=abstracted_axes,
+        cache_type='jit',
+        name='jit'
+    )
+    jit_fun = jax.jit(
+        fun.jaxpr_call,
+        static_argnums=tuple(i + 1 for i in static_argnums),
+        donate_argnums=donate_argnums,
+        donate_argnames=donate_argnames,
+        keep_unused=keep_unused,
+        device=device,
+        backend=backend,
+        inline=inline,
+        in_shardings=in_shardings,
+        out_shardings=out_shardings,
+        abstracted_axes=abstracted_axes,
+        **kwargs
+    )

     @functools.wraps(fun.fun)
     def jitted_fun(*args, **params):
@@ -15,9 +15,8 @@

 from __future__ import annotations

-import unittest
-
 import jax.numpy as jnp
+import unittest

 import brainstate as bst

@@ -16,11 +16,11 @@
 from __future__ import annotations

 import math
-from functools import wraps
-from typing import Callable, Optional, TypeVar, Tuple, Any

 import jax
 import jax.numpy as jnp
+from functools import wraps
+from typing import Callable, Optional, TypeVar, Tuple, Any

 from brainstate._utils import set_module_as
 from ._make_jaxpr import StatefulFunction
@@ -209,7 +209,7 @@ def scan(
     # ------------------------------ #
     xs_avals = [jax.core.get_aval(x) for x in xs_flat]
     x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
-    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
+    stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
     state_trace = stateful_fun.get_state_trace()
     all_writen_state_vals = state_trace.get_write_state_values(True)
     all_read_state_vals = state_trace.get_read_state_values(True)
@@ -217,12 +217,20 @@

     # scan
     init = (all_writen_state_vals, init)
-    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
-                                                      init,
-                                                      xs,
-                                                      length=length,
-                                                      reverse=reverse,
-                                                      unroll=unroll)
+    (
+        (
+            all_writen_state_vals,
+            carry
+        ),
+        ys
+    ) = jax.lax.scan(
+        wrapped_f,
+        init,
+        xs,
+        length=length,
+        reverse=reverse,
+        unroll=unroll
+    )
     # assign the written state values and restore the read state values
     write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
     # carry
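The restructured call threads the written state values through the scan carry next to the user's own carry. A stripped-down sketch of that carry pattern in plain JAX (the names here are illustrative, not the library's internals):

    import jax
    import jax.numpy as jnp

    def wrapped_step(carry, x):
        # carry = (state_vals, user_carry): library-managed state values ride
        # along with the user's carry for the whole scan
        state_vals, acc = carry
        acc = acc + x * state_vals
        return (state_vals, acc), acc

    init = (jnp.float32(2.0), jnp.float32(0.0))
    (state_vals, acc), ys = jax.lax.scan(wrapped_step, init, jnp.arange(4, dtype=jnp.float32))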
@@ -305,7 +313,7 @@ def checkpointed_scan(
     # evaluate jaxpr
     xs_avals = [jax.core.get_aval(x) for x in xs_flat]
     x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
-    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
+    stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
     state_trace = stateful_fun.get_state_trace()
     # get all states
     been_written = state_trace.been_writen
@@ -14,10 +14,10 @@
 # ==============================================================================

 from __future__ import annotations
-import unittest

 import jax.numpy as jnp
 import numpy as np
+import unittest

 import brainstate as bst

@@ -16,9 +16,9 @@
 from __future__ import annotations

 import math
-from typing import Any, Callable, TypeVar

 import jax
+from typing import Any, Callable, TypeVar

 from brainstate._utils import set_module_as
 from ._loop_collect_return import _bounded_while_loop
@@ -103,8 +103,8 @@ def while_loop(
         pass

     # evaluate jaxpr
-    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
-    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
+    stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
+    stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
     if len(stateful_cond.get_write_states()) != 0:
         raise ValueError("while_loop: cond_fun should not have any write states.")

@@ -162,8 +162,8 @@ def bounded_while_loop(
         return init_val

     # evaluate jaxpr
-    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
-    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
+    stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
+    stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
     if len(stateful_cond.get_write_states()) != 0:
         raise ValueError("while_loop: cond_fun should not have any write states.")