brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping_test.py
@@ -17,57 +17,122 @@ from __future__ import annotations

 import unittest

-import jax.core
 import jax.numpy as jnp

 import brainstate as bst
+from brainstate.augment._mapping import BatchAxisError


 class TestVmap(unittest.TestCase):
-    def test_vmap_return_keep_reference_return(self):
-        @bst.augment.vmap(in_axes=0, out_axes=0)
-        def create_model(key):
-            bst.random.set_key(key)
-            m1 = bst.nn.Linear(2, 3)
-
-            m2 = bst.nn.Linear(3, 4)
-            m2.a = m1
-            m3 = bst.nn.Linear(3, 5)
-            m3.a = m1
-            self.assertTrue(id(m2.a) == id(m3.a))
-            return m2, m3
-
-        m2, m3 = create_model(bst.random.split_key(10))
-        self.assertTrue(id(m2.a) == id(m3.a))
-        jax.core.concrete_or_error(None, bst.random.DEFAULT.value)
-
-    def test_vmap_return_keep_reference_pass_into_fun(self):
-        @bst.augment.vmap(in_axes=(None, None, 0), out_axes=0)
-        def run_model(m2, m3, x):
-            self.assertTrue(id(m2.a) == id(m3.a))
-            self.assertTrue(id(m2) != m2_id)
-            self.assertTrue(id(m3) != m3_id)
-            return m2(x), m3(x)
-
-        m1 = bst.nn.Linear(2, 3)
-        m2 = bst.nn.Linear(4, 3)
-        m2.a = m1
-        m3 = bst.nn.Linear(4, 5)
-        m3.a = m1
-        m3_id = id(m3)
-        m2_id = id(m2)
-        r1, r2 = run_model(m2, m3, jnp.ones((4, 3, 4)))
-
-    def test_vmap_set_key(self):
-        @bst.augment.vmap(in_axes=0, out_axes=0)
-        def create_model(key):
-            bst.random.set_key(key)
-            return bst.nn.Linear(2, 3)
-
-        model = create_model(bst.random.split_keys(10))
-        print(model.weight.value_call(jnp.shape))
-        model.weight.value_call(lambda x: jax.core.concrete_or_error(None, x))
-        bst.random.seed()
+    def test_vmap_1(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.State(bst.random.randn(5))
+                self.b = bst.State(bst.random.randn(5))
+
+            def __call__(self, *args, **kwargs):
+                return self.a.value * self.b.value
+
+        model = Model()
+        r1 = model.a.value * model.b.value
+        r2 = bst.augment.vmap(model, in_states=model.states())()
+        self.assertTrue(jnp.allclose(r1, r2))
+
+    def test_vmap_2(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.ShortTermState(bst.random.randn(5))
+                self.b = bst.ShortTermState(bst.random.randn(5))
+                self.c = bst.State(bst.random.randn(1))
+
+            def __call__(self, *args, **kwargs):
+                self.c.value = self.a.value * self.b.value
+                return self.c.value + 1.
+
+        model = Model()
+        with self.assertRaises(BatchAxisError):
+            r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState))()
+
+        model = Model()
+        r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState), out_states=model.c)()
+
+    def test_vmap_3(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.State(bst.random.randn(5))
+                self.b = bst.State(bst.random.randn(5))
+
+            def __call__(self, *args, **kwargs):
+                return self.a.value * self.b.value
+
+        model = Model()
+        with self.assertRaises(BatchAxisError):
+            r2 = bst.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
+
+    def test_vmap_with_random(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.ShortTermState(bst.random.randn(5))
+                self.b = bst.ShortTermState(bst.random.randn(5))
+                self.c = bst.State(bst.random.randn(1))
+
+            def __call__(self, key):
+                bst.random.set_key(key)
+                self.c.value = self.a.value * self.b.value
+                return self.c.value + bst.random.randn(1)
+
+        model = Model()
+        r2 = bst.augment.vmap(
+            model,
+            in_states=model.states(bst.ShortTermState),
+            out_states=model.c
+        )(
+            bst.random.split_key(5)
+        )
+        print(bst.random.DEFAULT)
+
+    def test_vmap_with_random_2(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.a = bst.ShortTermState(bst.random.randn(5))
+                self.b = bst.ShortTermState(bst.random.randn(5))
+                self.c = bst.State(bst.random.randn(1))
+                self.rng = bst.random.RandomState(1)
+
+            def __call__(self, key):
+                self.rng.set_key(key)
+                self.c.value = self.a.value * self.b.value
+                return self.c.value + bst.random.randn(1)
+
+        model = Model()
+        with self.assertRaises(BatchAxisError):
+            r2 = bst.augment.vmap(
+                model,
+                in_states=model.states(bst.ShortTermState),
+                out_states=model.c
+            )(
+                bst.random.split_key(5)
+            )
+
+        model = Model()
+        r2 = bst.augment.vmap(
+            model,
+            in_states=model.states(bst.ShortTermState),
+            out_states=model.c,
+            rngs=model.rng,
+        )(
+            bst.random.split_key(5)
+        )

     def test_vmap_input(self):
         model = bst.nn.Linear(2, 3)
@@ -98,8 +163,8 @@ class TestVmap(unittest.TestCase):

         @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
         def forward(model, x):
-            self.assertTrue(id(model) != model_id)
-            self.assertTrue(id(model.weight) != weight_id)
+            self.assertTrue(id(model) == model_id)
+            self.assertTrue(id(model.weight) == weight_id)
             print(id(model), id(model.weight))
             return model(x)

@@ -108,51 +173,7 @@ class TestVmap(unittest.TestCase):
         print(model.weight.value_call(jnp.shape))
         print(model.weight.value)

-    def test_vmap1(self):
-        model = bst.nn.Linear(2, 3)
-        x = jnp.ones((5, 2))
-
-        @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
-        def forward(model, x):
-            return model(x)
-
-        y = forward(model, x)
-        print(y.shape)
-
-    def test_vmap2(self):
-        class LinearEnsemble(bst.nn.Module):
-            def __init__(self, num):
-                super().__init__()
-                self.w = bst.ParamState(bst.random.random((num, 2, 3)))
-
-        model = LinearEnsemble(5)
-        x = jnp.ones((2,))
-
-        @bst.augment.vmap(in_axes=(0, None), out_axes=0)
-        def forward(model, x):
-            return jnp.dot(x, model.w.value)
-
-        y = forward(model, x)
-        print(y.shape)
-
-    def test_vmap3(self):
-        class Foo(bst.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.a = bst.ParamState(jnp.arange(4))
-                self.b = bst.ShortTermState(jnp.arange(4))
-
-        state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
-
-        @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
-        def mul(foo):
-            return foo.a.value * foo.b.value
-
-        foo = Foo()
-        y = mul(foo)
-        print(y.shape)
-
-    def test_vmap4(self):
+    def test_vmap_jit(self):
         class Foo(bst.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -162,60 +183,35 @@ class TestVmap(unittest.TestCase):
         def __call__(self):
             self.b.value = self.a.value * self.b.value

-        @bst.augment.vmap
-        def mul(foo):
-            foo()
-            return foo
-
         foo = Foo()
-        with bst.StateTraceStack() as trace:
-            m = mul(foo)
-
-        self.assertTrue(m is foo)
-        print(m.a.value, foo.a.value)
-        self.assertTrue(jnp.allclose(m.a.value, foo.a.value))
-        print(m.b.value, foo.b.value)
-        self.assertTrue(jnp.allclose(m.b.value, foo.b.value))
-        print(trace.get_write_states())
-        self.assertTrue(len(trace.get_write_states()) == 1)
-        print(trace.get_read_states())
-        self.assertTrue(len(trace.get_read_states()) == 2)
-
-    def test_vmap5(self):
-        class Foo(bst.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.a = bst.ParamState(jnp.arange(4))
-                self.b = bst.ShortTermState(jnp.arange(4))

-            def __call__(self):
-                self.b.value = self.a.value * self.b.value
-
-        @bst.augment.vmap
-        def mul(foo):
+        @bst.augment.vmap(in_states=foo.states())
+        def mul():
             foo()

-        foo = Foo()
+        @bst.compile.jit
+        def mul_jit(inp):
+            mul()
+            foo.a.value += inp
+
         with bst.StateTraceStack() as trace:
-            mul(foo)
+            mul_jit(1.)

         print(foo.a.value)
         print(foo.b.value)
-        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
+        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
         self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))

         write_state_ids = [id(st) for st in trace.get_write_states()]
         read_state_ids = [id(st) for st in trace.get_read_states()]

-        assert id(foo.a) in read_state_ids
+        assert id(foo.a) in write_state_ids
         assert id(foo.b) in write_state_ids

         print(trace.get_write_states())
         print(trace.get_read_states())

-
-
-    def test_vmap_jit(self):
+    def test_vmap_jit_2(self):
         class Foo(bst.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -225,70 +221,46 @@ class TestVmap(unittest.TestCase):
         def __call__(self):
             self.b.value = self.a.value * self.b.value

-        @bst.augment.vmap
-        def mul(foo):
+        foo = Foo()
+
+        @bst.augment.vmap(in_states=foo.states())
+        def mul():
             foo()

         @bst.compile.jit
         def mul_jit(inp):
-            mul(foo)
-            foo.a.value += inp
+            mul()
+            foo.b.value += inp

-        foo = Foo()
         with bst.StateTraceStack() as trace:
             mul_jit(1.)

         print(foo.a.value)
         print(foo.b.value)
-        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
-        self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
+        self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4) + 1.))

         write_state_ids = [id(st) for st in trace.get_write_states()]
         read_state_ids = [id(st) for st in trace.get_read_states()]

-        assert id(foo.a) in write_state_ids
+        assert id(foo.a) in read_state_ids
         assert id(foo.b) in write_state_ids

         print(trace.get_write_states())
         print(trace.get_read_states())


-    def test_vmap_grad(self):
-        class Foo(bst.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.a = bst.ParamState(jnp.arange(4.))
-                self.b = bst.ShortTermState(jnp.arange(4.))
-
-            def __call__(self):
-                self.b.value = self.a.value * self.b.value
-
-        @bst.augment.vmap
-        def mul(foo):
-            foo()
-
-        def loss():
-            mul(foo)
-            return jnp.sum(foo.b.value)
-
-        foo = Foo()
-        with bst.StateTraceStack() as trace:
-            grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
-            print(grads)
-            print(loss)
-
-        # print(foo.a.value)
-        # print(foo.b.value)
-        # self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
-        # self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
-        #
-        # write_state_ids = [id(st) for st in trace.get_write_states()]
-        # read_state_ids = [id(st) for st in trace.get_read_states()]
-        #
-        # assert id(foo.a) in write_state_ids
-        # assert id(foo.b) in write_state_ids
-        #
-        # print(trace.get_write_states())
-        # print(trace.get_read_states())
-
-
+class TestMap(unittest.TestCase):
+    def test_map(self):
+        for dim in [(10,), (10, 10), (10, 10, 10)]:
+            x = bst.random.rand(*dim)
+            r1 = bst.augment.map(lambda a: a + 1, x, batch_size=None)
+            r2 = bst.augment.map(lambda a: a + 1, x, batch_size=2)
+            r3 = bst.augment.map(lambda a: a + 1, x, batch_size=4)
+            r4 = bst.augment.map(lambda a: a + 1, x, batch_size=5)
+            true_r = x + 1
+
+            self.assertTrue(jnp.allclose(r1, true_r))
+            self.assertTrue(jnp.allclose(r2, true_r))
+            self.assertTrue(jnp.allclose(r3, true_r))
+            self.assertTrue(jnp.allclose(r4, true_r))
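
Note: the tests above exercise the new state-based `vmap` API, which declares batched `State` objects through `in_states`/`out_states` rather than pytree axes. A minimal sketch of the call pattern, mirroring `test_vmap_1` (the `Scale` module here is illustrative, not part of the diff):

    import brainstate as bst

    class Scale(bst.nn.Module):
        def __init__(self):
            super().__init__()
            # state with a leading batch axis of size 5
            self.w = bst.State(bst.random.randn(5))

        def __call__(self):
            return self.w.value * 2.

    model = Scale()
    # map the call over axis 0 of every state listed in `in_states`
    out = bst.augment.vmap(model, in_states=model.states())()
    assert out.shape == (5,)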
brainstate/compile/__init__.py
@@ -17,41 +17,22 @@
 This module contains the functions for the compilation of JAX code.
 """

-from ._ad_checkpoint import *
-from ._ad_checkpoint import __all__ as _ad_checkpoint_all
-from ._conditions import *
-from ._conditions import __all__ as _conditions_all
-from ._error_if import *
-from ._error_if import __all__ as _jit_error_all
-from ._jit import *
-from ._jit import __all__ as _jit_all
-from ._loop_collect_return import *
-from ._loop_collect_return import __all__ as _loops_collection
-from ._loop_no_collection import *
-from ._loop_no_collection import __all__ as _loops_no_collection
-from ._make_jaxpr import *
-from ._make_jaxpr import __all__ as _make_jaxpr_all
-from ._progress_bar import *
-from ._progress_bar import __all__ as _progress_bar_all
+from ._ad_checkpoint import checkpoint, remat
+from ._conditions import cond, switch, ifelse
+from ._error_if import jit_error_if
+from ._jit import jit
+from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
+from ._loop_no_collection import while_loop, bounded_while_loop
+from ._make_jaxpr import StatefulFunction, make_jaxpr
+from ._progress_bar import ProgressBar

-__all__ = (
-    _jit_error_all
-    + _conditions_all
-    + _make_jaxpr_all
-    + _jit_all
-    + _progress_bar_all
-    + _loops_collection
-    + _loops_no_collection
-    + _ad_checkpoint_all
-)
-
-del (
-    _jit_error_all,
-    _conditions_all,
-    _loops_collection,
-    _make_jaxpr_all,
-    _jit_all,
-    _progress_bar_all,
-    _loops_no_collection,
-    _ad_checkpoint_all
-)
+__all__ = [
+    'checkpoint', 'remat',
+    'cond', 'switch', 'ifelse',
+    'jit_error_if',
+    'jit',
+    'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
+    'while_loop', 'bounded_while_loop',
+    'StatefulFunction', 'make_jaxpr',
+    'ProgressBar',
+]
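
Note: with the star re-exports replaced by explicit imports, the public surface of `brainstate.compile` is exactly the `__all__` list above; existing imports of these names are unaffected, e.g.:

    from brainstate.compile import jit, scan, while_loop, ProgressBar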
brainstate/compile/_ad_checkpoint.py
@@ -182,10 +182,12 @@ def checkpoint(

     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
     fun = StatefulFunction(fun, static_argnums=static_argnums)
-    checkpointed_fun = jax.checkpoint(fun.jaxpr_call,
-                                      prevent_cse=prevent_cse,
-                                      policy=policy,
-                                      static_argnums=tuple(i + 1 for i in static_argnums))
+    checkpointed_fun = jax.checkpoint(
+        fun.jaxpr_call,
+        prevent_cse=prevent_cse,
+        policy=policy,
+        static_argnums=tuple(i + 1 for i in static_argnums)
+    )

     @functools.wraps(fun.fun)
     def remat_fun(*args, **params):
brainstate/compile/_jit.py
@@ -62,7 +62,6 @@ def _get_jitted_fun(
    **kwargs
 ) -> JittedFunction:
     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
-    # TODO: add to cache stack for clear_cache
     fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
     jit_fun = jax.jit(fun.jaxpr_call,
                       static_argnums=tuple(i + 1 for i in static_argnums),
@@ -83,10 +82,12 @@ def _get_jitted_fun(
             return fun.fun(*args, **params)

         # compile the function and get the state trace
+        # print('Compiling ...')
         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
         read_state_vals = state_trace.get_read_state_values(True)

         # call the jitted function
+        # print('Running ...')
         write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
         # write the state values back to the states
         write_back_state_values(state_trace, read_state_vals, write_state_vals)
@@ -127,11 +128,15 @@ def _get_jitted_fun(
         """
         # compile the function and get the state trace
         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
-        read_state_vals = state_trace.get_read_state_values(True)
+        read_state_vals = state_trace.get_read_state_values(replace_writen=True)
+        write_state_vals = state_trace.get_write_state_values(replace_read=True)

-        # call the jitted function
-        return jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
+        # compile the model
+        ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()

+        # write the state values back to the states
+        write_back_state_values(state_trace, read_state_vals, write_state_vals)
+        return ret

     jitted_fun: JittedFunction

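
Note: this change makes ahead-of-time lowering (`.lower(...).compile()`) restore state values afterwards instead of leaving compile-time placeholders behind in the `State` objects. A self-contained sketch of the write-back pattern (illustrative only, not brainstate's actual implementation):

    class State:
        def __init__(self, value):
            self.value = value

    def compile_with_write_back(states, do_compile):
        snapshot = [st.value for st in states]    # concrete values before tracing
        compiled = do_compile()                   # tracing may overwrite .value
        for st, val in zip(states, snapshot):     # restore the concrete values
            st.value = val
        return compiled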
@@ -290,31 +295,35 @@ def jit(

    if isinstance(fun, Missing):
        def wrapper(fun_again: Callable) -> JittedFunction:
-            return _get_jitted_fun(fun_again,
-                                   in_shardings=in_shardings,
-                                   out_shardings=out_shardings,
-                                   static_argnums=static_argnums,
-                                   donate_argnums=donate_argnums,
-                                   donate_argnames=donate_argnames,
-                                   keep_unused=keep_unused,
-                                   device=device,
-                                   backend=backend,
-                                   inline=inline,
-                                   abstracted_axes=abstracted_axes,
-                                   **kwargs)
+            return _get_jitted_fun(
+                fun_again,
+                in_shardings=in_shardings,
+                out_shardings=out_shardings,
+                static_argnums=static_argnums,
+                donate_argnums=donate_argnums,
+                donate_argnames=donate_argnames,
+                keep_unused=keep_unused,
+                device=device,
+                backend=backend,
+                inline=inline,
+                abstracted_axes=abstracted_axes,
+                **kwargs
+            )

        return wrapper

    else:
-        return _get_jitted_fun(fun,
-                               in_shardings,
-                               out_shardings,
-                               static_argnums,
-                               donate_argnums,
-                               donate_argnames,
-                               keep_unused,
-                               device,
-                               backend,
-                               inline,
-                               abstracted_axes,
-                               **kwargs)
+        return _get_jitted_fun(
+            fun,
+            in_shardings,
+            out_shardings,
+            static_argnums,
+            donate_argnums,
+            donate_argnames,
+            keep_unused,
+            device,
+            backend,
+            inline,
+            abstracted_axes,
+            **kwargs
+        )
brainstate/compile/_loop_collect_return.py
@@ -31,6 +31,7 @@ from ._util import write_back_state_values, wrap_single_fun
 __all__ = [
     # "scan" syntax, which is similar to jax.lax.scan
     'scan', 'checkpointed_scan',
+
     # "for_loop" syntax
     'for_loop', 'checkpointed_for_loop',
 ]
@@ -48,9 +49,9 @@ def _wrap_fun_with_pbar(
     @wraps(fun)
     def new_fun(new_carry, inputs):
         i, old_carry = new_carry
-        old_carry, old_outputs = fun(old_carry, inputs)
-        pbar_runner(unvmap(i, op='none'))
-        return (i + 1, old_carry), old_outputs
+        new_carry, new_outputs = fun(old_carry, inputs)
+        pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
+        return (i + 1, new_carry), new_outputs

     return new_fun

@@ -476,6 +477,8 @@ def checkpointed_for_loop(
     return ys


+# This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
+
 # There's several tricks happening here to work around various limitations of JAX.
 # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
 # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
brainstate/compile/_loop_no_collection.py
@@ -134,6 +134,8 @@ def bounded_while_loop(
     """
     While loop with a bound on the maximum number of steps.

+    This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
+
     This function is useful when you want to ensure that a while loop terminates
     even if the condition function is never false. The function is implemented
     using a scan operation, so it is reverse-mode differentiable.
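
Note: a hedged usage sketch for `bounded_while_loop` (the positional signature is assumed from the docstring context, not shown in this diff; check `brainstate.compile` for the exact API):

    import jax.numpy as jnp
    import brainstate as bst

    # double the carry until it reaches 100, but never run more than 16 steps
    result = bst.compile.bounded_while_loop(
        lambda x: x < 100.,    # cond_fun
        lambda x: x * 2.,      # body_fun
        jnp.asarray(1.),       # initial carry
        max_steps=16,
    )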
brainstate/compile/_make_jaxpr.py
@@ -73,6 +73,13 @@ from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
 from brainstate.typing import PyTree

+
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import ClosedJaxpr
+else:
+    from jax.extend.core import ClosedJaxpr
+
+
 AxisName = Hashable

 __all__ = [
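
Note: this shim follows JAX's relocation of `ClosedJaxpr` from `jax.core` to `jax.extend.core` (around JAX 0.4.38), so the rest of the module can reference `ClosedJaxpr` unqualified under both old and new JAX. The same pattern generalizes to other relocated symbols (a sketch, not part of the diff):

    import jax

    if jax.__version_info__ < (0, 4, 38):
        from jax.core import ClosedJaxpr, Jaxpr          # older JAX
    else:
        from jax.extend.core import ClosedJaxpr, Jaxpr   # newer JAX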
@@ -396,7 +403,7 @@ class StatefulFunction(object):
         # Checking whether the states are returned.
         for leaf in jax.tree.leaves(out):
             if isinstance(leaf, State):
-                leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
+                leaf.raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
         return out, state_values

     def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
@@ -660,7 +667,7 @@ def _make_jaxpr(
     axis_env: Sequence[tuple[AxisName, int]] | None = None,
     return_shape: bool = False,
     abstracted_axes: Any | None = None,
-) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
+) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
     """Creates a function that produces its jaxpr given example args.

     Args:
@@ -741,11 +748,15 @@ def _make_jaxpr(
         in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
         f, out_tree = _flatten_fun(f, in_tree)
         f = annotate(f, in_type)
-        debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+        if jax.__version_info__ < (0, 5, 0):
+            debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
         with ExitStack() as stack:
             for axis_name, size in axis_env or []:
                 stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
-            jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+            if jax.__version_info__ < (0, 5, 0):
+                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+            else:
+                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
             closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
             if return_shape:
                 out_avals, _ = jax.util.unzip2(out_type)