brainstate-0.1.0.post20250105-py2.py3-none-any.whl → brainstate-0.1.0.post20250120-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_jit.py +37 -28
  9. brainstate/compile/_loop_collect_return.py +6 -3
  10. brainstate/compile/_loop_no_collection.py +2 -0
  11. brainstate/compile/_make_jaxpr.py +7 -3
  12. brainstate/compile/_progress_bar.py +68 -40
  13. brainstate/compile/_unvmap.py +6 -3
  14. brainstate/event/__init__.py +0 -2
  15. brainstate/event/_csr.py +266 -23
  16. brainstate/event/_csr_test.py +187 -0
  17. brainstate/event/_xla_custom_op.py +7 -3
  18. brainstate/graph/__init__.py +8 -12
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_interaction/_conv.py +4 -2
  24. brainstate/nn/_interaction/_linear.py +84 -10
  25. brainstate/random/_rand_funs.py +9 -2
  26. brainstate/random/_rand_seed.py +12 -2
  27. brainstate/random/_rand_state.py +50 -179
  28. brainstate/surrogate.py +5 -1
  29. brainstate/util/__init__.py +0 -4
  30. brainstate/util/_caller.py +1 -1
  31. brainstate/util/_dict.py +4 -1
  32. brainstate/util/_filter.py +1 -1
  33. brainstate/util/_pretty_repr.py +1 -1
  34. brainstate/util/_struct.py +1 -1
  35. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
  37. brainstate/event/_csr_mv_test.py +0 -118
  38. brainstate/graph/_graph_context.py +0 -443
  39. brainstate/graph/_graph_context_test.py +0 -65
  40. brainstate/graph/_graph_convert.py +0 -246
  41. brainstate/util/_tracers.py +0 -68
  42. brainstate/util/_visualization.py +0 -47
  43. brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  44. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  46. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -17,57 +17,122 @@ from __future__ import annotations
17
17
 
18
18
  import unittest
19
19
 
20
- import jax.core
21
20
  import jax.numpy as jnp
22
21
 
23
22
  import brainstate as bst
23
+ from brainstate.augment._mapping import BatchAxisError
24
24
 
25
25
 
26
26
  class TestVmap(unittest.TestCase):
27
- def test_vmap_return_keep_reference_return(self):
28
- @bst.augment.vmap(in_axes=0, out_axes=0)
29
- def create_model(key):
30
- bst.random.set_key(key)
31
- m1 = bst.nn.Linear(2, 3)
32
-
33
- m2 = bst.nn.Linear(3, 4)
34
- m2.a = m1
35
- m3 = bst.nn.Linear(3, 5)
36
- m3.a = m1
37
- self.assertTrue(id(m2.a) == id(m3.a))
38
- return m2, m3
39
-
40
- m2, m3 = create_model(bst.random.split_key(10))
41
- self.assertTrue(id(m2.a) == id(m3.a))
42
- jax.core.concrete_or_error(None, bst.random.DEFAULT.value)
43
-
44
- def test_vmap_return_keep_reference_pass_into_fun(self):
45
- @bst.augment.vmap(in_axes=(None, None, 0), out_axes=0)
46
- def run_model(m2, m3, x):
47
- self.assertTrue(id(m2.a) == id(m3.a))
48
- self.assertTrue(id(m2) != m2_id)
49
- self.assertTrue(id(m3) != m3_id)
50
- return m2(x), m3(x)
51
-
52
- m1 = bst.nn.Linear(2, 3)
53
- m2 = bst.nn.Linear(4, 3)
54
- m2.a = m1
55
- m3 = bst.nn.Linear(4, 5)
56
- m3.a = m1
57
- m3_id = id(m3)
58
- m2_id = id(m2)
59
- r1, r2 = run_model(m2, m3, jnp.ones((4, 3, 4)))
60
-
61
- def test_vmap_set_key(self):
62
- @bst.augment.vmap(in_axes=0, out_axes=0)
63
- def create_model(key):
64
- bst.random.set_key(key)
65
- return bst.nn.Linear(2, 3)
66
-
67
- model = create_model(bst.random.split_keys(10))
68
- print(model.weight.value_call(jnp.shape))
69
- model.weight.value_call(lambda x: jax.core.concrete_or_error(None, x))
70
- bst.random.seed()
27
+ def test_vmap_1(self):
28
+ class Model(bst.nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+
32
+ self.a = bst.State(bst.random.randn(5))
33
+ self.b = bst.State(bst.random.randn(5))
34
+
35
+ def __call__(self, *args, **kwargs):
36
+ return self.a.value * self.b.value
37
+
38
+ model = Model()
39
+ r1 = model.a.value * model.b.value
40
+ r2 = bst.augment.vmap(model, in_states=model.states())()
41
+ self.assertTrue(jnp.allclose(r1, r2))
42
+
43
+ def test_vmap_2(self):
44
+ class Model(bst.nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+
48
+ self.a = bst.ShortTermState(bst.random.randn(5))
49
+ self.b = bst.ShortTermState(bst.random.randn(5))
50
+ self.c = bst.State(bst.random.randn(1))
51
+
52
+ def __call__(self, *args, **kwargs):
53
+ self.c.value = self.a.value * self.b.value
54
+ return self.c.value + 1.
55
+
56
+ model = Model()
57
+ with self.assertRaises(BatchAxisError):
58
+ r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState))()
59
+
60
+ model = Model()
61
+ r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState), out_states=model.c)()
62
+
63
+ def test_vmap_3(self):
64
+ class Model(bst.nn.Module):
65
+ def __init__(self):
66
+ super().__init__()
67
+
68
+ self.a = bst.State(bst.random.randn(5))
69
+ self.b = bst.State(bst.random.randn(5))
70
+
71
+ def __call__(self, *args, **kwargs):
72
+ return self.a.value * self.b.value
73
+
74
+ model = Model()
75
+ with self.assertRaises(BatchAxisError):
76
+ r2 = bst.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
77
+
78
+ def test_vmap_with_random(self):
79
+ class Model(bst.nn.Module):
80
+ def __init__(self):
81
+ super().__init__()
82
+
83
+ self.a = bst.ShortTermState(bst.random.randn(5))
84
+ self.b = bst.ShortTermState(bst.random.randn(5))
85
+ self.c = bst.State(bst.random.randn(1))
86
+
87
+ def __call__(self, key):
88
+ bst.random.set_key(key)
89
+ self.c.value = self.a.value * self.b.value
90
+ return self.c.value + bst.random.randn(1)
91
+
92
+ model = Model()
93
+ r2 = bst.augment.vmap(
94
+ model,
95
+ in_states=model.states(bst.ShortTermState),
96
+ out_states=model.c
97
+ )(
98
+ bst.random.split_key(5)
99
+ )
100
+ print(bst.random.DEFAULT)
101
+
102
+ def test_vmap_with_random_2(self):
103
+ class Model(bst.nn.Module):
104
+ def __init__(self):
105
+ super().__init__()
106
+
107
+ self.a = bst.ShortTermState(bst.random.randn(5))
108
+ self.b = bst.ShortTermState(bst.random.randn(5))
109
+ self.c = bst.State(bst.random.randn(1))
110
+ self.rng = bst.random.RandomState(1)
111
+
112
+ def __call__(self, key):
113
+ self.rng.set_key(key)
114
+ self.c.value = self.a.value * self.b.value
115
+ return self.c.value + bst.random.randn(1)
116
+
117
+ model = Model()
118
+ with self.assertRaises(BatchAxisError):
119
+ r2 = bst.augment.vmap(
120
+ model,
121
+ in_states=model.states(bst.ShortTermState),
122
+ out_states=model.c
123
+ )(
124
+ bst.random.split_key(5)
125
+ )
126
+
127
+ model = Model()
128
+ r2 = bst.augment.vmap(
129
+ model,
130
+ in_states=model.states(bst.ShortTermState),
131
+ out_states=model.c,
132
+ rngs=model.rng,
133
+ )(
134
+ bst.random.split_key(5)
135
+ )
71
136
 
72
137
  def test_vmap_input(self):
73
138
  model = bst.nn.Linear(2, 3)
@@ -98,8 +163,8 @@ class TestVmap(unittest.TestCase):
98
163
 
99
164
  @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
100
165
  def forward(model, x):
101
- self.assertTrue(id(model) != model_id)
102
- self.assertTrue(id(model.weight) != weight_id)
166
+ self.assertTrue(id(model) == model_id)
167
+ self.assertTrue(id(model.weight) == weight_id)
103
168
  print(id(model), id(model.weight))
104
169
  return model(x)
105
170
 
@@ -108,51 +173,7 @@ class TestVmap(unittest.TestCase):
108
173
  print(model.weight.value_call(jnp.shape))
109
174
  print(model.weight.value)
110
175
 
111
- def test_vmap1(self):
112
- model = bst.nn.Linear(2, 3)
113
- x = jnp.ones((5, 2))
114
-
115
- @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
116
- def forward(model, x):
117
- return model(x)
118
-
119
- y = forward(model, x)
120
- print(y.shape)
121
-
122
- def test_vmap2(self):
123
- class LinearEnsemble(bst.nn.Module):
124
- def __init__(self, num):
125
- super().__init__()
126
- self.w = bst.ParamState(bst.random.random((num, 2, 3)))
127
-
128
- model = LinearEnsemble(5)
129
- x = jnp.ones((2,))
130
-
131
- @bst.augment.vmap(in_axes=(0, None), out_axes=0)
132
- def forward(model, x):
133
- return jnp.dot(x, model.w.value)
134
-
135
- y = forward(model, x)
136
- print(y.shape)
137
-
138
- def test_vmap3(self):
139
- class Foo(bst.nn.Module):
140
- def __init__(self):
141
- super().__init__()
142
- self.a = bst.ParamState(jnp.arange(4))
143
- self.b = bst.ShortTermState(jnp.arange(4))
144
-
145
- state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
146
-
147
- @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
148
- def mul(foo):
149
- return foo.a.value * foo.b.value
150
-
151
- foo = Foo()
152
- y = mul(foo)
153
- print(y.shape)
154
-
155
- def test_vmap4(self):
176
+ def test_vmap_jit(self):
156
177
  class Foo(bst.nn.Module):
157
178
  def __init__(self):
158
179
  super().__init__()
@@ -162,60 +183,35 @@ class TestVmap(unittest.TestCase):
162
183
  def __call__(self):
163
184
  self.b.value = self.a.value * self.b.value
164
185
 
165
- @bst.augment.vmap
166
- def mul(foo):
167
- foo()
168
- return foo
169
-
170
186
  foo = Foo()
171
- with bst.StateTraceStack() as trace:
172
- m = mul(foo)
173
-
174
- self.assertTrue(m is foo)
175
- print(m.a.value, foo.a.value)
176
- self.assertTrue(jnp.allclose(m.a.value, foo.a.value))
177
- print(m.b.value, foo.b.value)
178
- self.assertTrue(jnp.allclose(m.b.value, foo.b.value))
179
- print(trace.get_write_states())
180
- self.assertTrue(len(trace.get_write_states()) == 1)
181
- print(trace.get_read_states())
182
- self.assertTrue(len(trace.get_read_states()) == 2)
183
-
184
- def test_vmap5(self):
185
- class Foo(bst.nn.Module):
186
- def __init__(self):
187
- super().__init__()
188
- self.a = bst.ParamState(jnp.arange(4))
189
- self.b = bst.ShortTermState(jnp.arange(4))
190
187
 
191
- def __call__(self):
192
- self.b.value = self.a.value * self.b.value
193
-
194
- @bst.augment.vmap
195
- def mul(foo):
188
+ @bst.augment.vmap(in_states=foo.states())
189
+ def mul():
196
190
  foo()
197
191
 
198
- foo = Foo()
192
+ @bst.compile.jit
193
+ def mul_jit(inp):
194
+ mul()
195
+ foo.a.value += inp
196
+
199
197
  with bst.StateTraceStack() as trace:
200
- mul(foo)
198
+ mul_jit(1.)
201
199
 
202
200
  print(foo.a.value)
203
201
  print(foo.b.value)
204
- self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
202
+ self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
205
203
  self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
206
204
 
207
205
  write_state_ids = [id(st) for st in trace.get_write_states()]
208
206
  read_state_ids = [id(st) for st in trace.get_read_states()]
209
207
 
210
- assert id(foo.a) in read_state_ids
208
+ assert id(foo.a) in write_state_ids
211
209
  assert id(foo.b) in write_state_ids
212
210
 
213
211
  print(trace.get_write_states())
214
212
  print(trace.get_read_states())
215
213
 
216
-
217
-
218
- def test_vmap_jit(self):
214
+ def test_vmap_jit_2(self):
219
215
  class Foo(bst.nn.Module):
220
216
  def __init__(self):
221
217
  super().__init__()
@@ -225,70 +221,46 @@ class TestVmap(unittest.TestCase):
225
221
  def __call__(self):
226
222
  self.b.value = self.a.value * self.b.value
227
223
 
228
- @bst.augment.vmap
229
- def mul(foo):
224
+ foo = Foo()
225
+
226
+ @bst.augment.vmap(in_states=foo.states())
227
+ def mul():
230
228
  foo()
231
229
 
232
230
  @bst.compile.jit
233
231
  def mul_jit(inp):
234
- mul(foo)
235
- foo.a.value += inp
232
+ mul()
233
+ foo.b.value += inp
236
234
 
237
- foo = Foo()
238
235
  with bst.StateTraceStack() as trace:
239
236
  mul_jit(1.)
240
237
 
241
238
  print(foo.a.value)
242
239
  print(foo.b.value)
243
- self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
244
- self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
240
+ self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
241
+ self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4) + 1.))
245
242
 
246
243
  write_state_ids = [id(st) for st in trace.get_write_states()]
247
244
  read_state_ids = [id(st) for st in trace.get_read_states()]
248
245
 
249
- assert id(foo.a) in write_state_ids
246
+ assert id(foo.a) in read_state_ids
250
247
  assert id(foo.b) in write_state_ids
251
248
 
252
249
  print(trace.get_write_states())
253
250
  print(trace.get_read_states())
254
251
 
255
252
 
256
- def test_vmap_grad(self):
257
- class Foo(bst.nn.Module):
258
- def __init__(self):
259
- super().__init__()
260
- self.a = bst.ParamState(jnp.arange(4.))
261
- self.b = bst.ShortTermState(jnp.arange(4.))
262
-
263
- def __call__(self):
264
- self.b.value = self.a.value * self.b.value
265
-
266
- @bst.augment.vmap
267
- def mul(foo):
268
- foo()
269
-
270
- def loss():
271
- mul(foo)
272
- return jnp.sum(foo.b.value)
273
-
274
- foo = Foo()
275
- with bst.StateTraceStack() as trace:
276
- grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
277
- print(grads)
278
- print(loss)
279
-
280
- # print(foo.a.value)
281
- # print(foo.b.value)
282
- # self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
283
- # self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
284
- #
285
- # write_state_ids = [id(st) for st in trace.get_write_states()]
286
- # read_state_ids = [id(st) for st in trace.get_read_states()]
287
- #
288
- # assert id(foo.a) in write_state_ids
289
- # assert id(foo.b) in write_state_ids
290
- #
291
- # print(trace.get_write_states())
292
- # print(trace.get_read_states())
293
-
294
-
253
+ class TestMap(unittest.TestCase):
254
+ def test_map(self):
255
+ for dim in [(10,), (10, 10), (10, 10, 10)]:
256
+ x = bst.random.rand(*dim)
257
+ r1 = bst.augment.map(lambda a: a + 1, x, batch_size=None)
258
+ r2 = bst.augment.map(lambda a: a + 1, x, batch_size=2)
259
+ r3 = bst.augment.map(lambda a: a + 1, x, batch_size=4)
260
+ r4 = bst.augment.map(lambda a: a + 1, x, batch_size=5)
261
+ true_r = x + 1
262
+
263
+ self.assertTrue(jnp.allclose(r1, true_r))
264
+ self.assertTrue(jnp.allclose(r2, true_r))
265
+ self.assertTrue(jnp.allclose(r3, true_r))
266
+ self.assertTrue(jnp.allclose(r4, true_r))
@@ -182,10 +182,12 @@ def checkpoint(
182
182
 
183
183
  static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
184
184
  fun = StatefulFunction(fun, static_argnums=static_argnums)
185
- checkpointed_fun = jax.checkpoint(fun.jaxpr_call,
186
- prevent_cse=prevent_cse,
187
- policy=policy,
188
- static_argnums=tuple(i + 1 for i in static_argnums))
185
+ checkpointed_fun = jax.checkpoint(
186
+ fun.jaxpr_call,
187
+ prevent_cse=prevent_cse,
188
+ policy=policy,
189
+ static_argnums=tuple(i + 1 for i in static_argnums)
190
+ )
189
191
 
190
192
  @functools.wraps(fun.fun)
191
193
  def remat_fun(*args, **params):
@@ -62,7 +62,6 @@ def _get_jitted_fun(
62
62
  **kwargs
63
63
  ) -> JittedFunction:
64
64
  static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
65
- # TODO: add to cache stack for clear_cache
66
65
  fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
67
66
  jit_fun = jax.jit(fun.jaxpr_call,
68
67
  static_argnums=tuple(i + 1 for i in static_argnums),
@@ -83,10 +82,12 @@ def _get_jitted_fun(
83
82
  return fun.fun(*args, **params)
84
83
 
85
84
  # compile the function and get the state trace
85
+ # print('Compiling ...')
86
86
  state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
87
87
  read_state_vals = state_trace.get_read_state_values(True)
88
88
 
89
89
  # call the jitted function
90
+ # print('Running ...')
90
91
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
91
92
  # write the state values back to the states
92
93
  write_back_state_values(state_trace, read_state_vals, write_state_vals)
@@ -127,11 +128,15 @@ def _get_jitted_fun(
127
128
  """
128
129
  # compile the function and get the state trace
129
130
  state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
130
- read_state_vals = state_trace.get_read_state_values(True)
131
+ read_state_vals = state_trace.get_read_state_values(replace_writen=True)
132
+ write_state_vals = state_trace.get_write_state_values(replace_read=True)
131
133
 
132
- # call the jitted function
133
- return jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
134
+ # compile the model
135
+ ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
134
136
 
137
+ # write the state values back to the states
138
+ write_back_state_values(state_trace, read_state_vals, write_state_vals)
139
+ return ret
135
140
 
136
141
  jitted_fun: JittedFunction
137
142
 
@@ -290,31 +295,35 @@ def jit(
290
295
 
291
296
  if isinstance(fun, Missing):
292
297
  def wrapper(fun_again: Callable) -> JittedFunction:
293
- return _get_jitted_fun(fun_again,
294
- in_shardings=in_shardings,
295
- out_shardings=out_shardings,
296
- static_argnums=static_argnums,
297
- donate_argnums=donate_argnums,
298
- donate_argnames=donate_argnames,
299
- keep_unused=keep_unused,
300
- device=device,
301
- backend=backend,
302
- inline=inline,
303
- abstracted_axes=abstracted_axes,
304
- **kwargs)
298
+ return _get_jitted_fun(
299
+ fun_again,
300
+ in_shardings=in_shardings,
301
+ out_shardings=out_shardings,
302
+ static_argnums=static_argnums,
303
+ donate_argnums=donate_argnums,
304
+ donate_argnames=donate_argnames,
305
+ keep_unused=keep_unused,
306
+ device=device,
307
+ backend=backend,
308
+ inline=inline,
309
+ abstracted_axes=abstracted_axes,
310
+ **kwargs
311
+ )
305
312
 
306
313
  return wrapper
307
314
 
308
315
  else:
309
- return _get_jitted_fun(fun,
310
- in_shardings,
311
- out_shardings,
312
- static_argnums,
313
- donate_argnums,
314
- donate_argnames,
315
- keep_unused,
316
- device,
317
- backend,
318
- inline,
319
- abstracted_axes,
320
- **kwargs)
316
+ return _get_jitted_fun(
317
+ fun,
318
+ in_shardings,
319
+ out_shardings,
320
+ static_argnums,
321
+ donate_argnums,
322
+ donate_argnames,
323
+ keep_unused,
324
+ device,
325
+ backend,
326
+ inline,
327
+ abstracted_axes,
328
+ **kwargs
329
+ )
@@ -31,6 +31,7 @@ from ._util import write_back_state_values, wrap_single_fun
31
31
  __all__ = [
32
32
  # "scan" syntax, which is similar to jax.lax.scan
33
33
  'scan', 'checkpointed_scan',
34
+
34
35
  # "for_loop" syntax
35
36
  'for_loop', 'checkpointed_for_loop',
36
37
  ]
@@ -48,9 +49,9 @@ def _wrap_fun_with_pbar(
48
49
  @wraps(fun)
49
50
  def new_fun(new_carry, inputs):
50
51
  i, old_carry = new_carry
51
- old_carry, old_outputs = fun(old_carry, inputs)
52
- pbar_runner(unvmap(i, op='none'))
53
- return (i + 1, old_carry), old_outputs
52
+ new_carry, new_outputs = fun(old_carry, inputs)
53
+ pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
54
+ return (i + 1, new_carry), new_outputs
54
55
 
55
56
  return new_fun
56
57
 
@@ -476,6 +477,8 @@ def checkpointed_for_loop(
476
477
  return ys
477
478
 
478
479
 
480
+ # This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
481
+
479
482
  # There's several tricks happening here to work around various limitations of JAX.
480
483
  # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
481
484
  # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
@@ -134,6 +134,8 @@ def bounded_while_loop(
134
134
  """
135
135
  While loop with a bound on the maximum number of steps.
136
136
 
137
+ This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
138
+
137
139
  This function is useful when you want to ensure that a while loop terminates
138
140
  even if the condition function is never false. The function is implemented
139
141
  using a scan operation, so it is reverse-mode differentiable.
@@ -396,7 +396,7 @@ class StatefulFunction(object):
396
396
  # Checking whether the states are returned.
397
397
  for leaf in jax.tree.leaves(out):
398
398
  if isinstance(leaf, State):
399
- leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
399
+ leaf.raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
400
400
  return out, state_values
401
401
 
402
402
  def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
@@ -741,11 +741,15 @@ def _make_jaxpr(
741
741
  in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
742
742
  f, out_tree = _flatten_fun(f, in_tree)
743
743
  f = annotate(f, in_type)
744
- debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
744
+ if jax.__version_info__ < (0, 5, 0):
745
+ debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
745
746
  with ExitStack() as stack:
746
747
  for axis_name, size in axis_env or []:
747
748
  stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
748
- jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
749
+ if jax.__version_info__ < (0, 5, 0):
750
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
751
+ else:
752
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
749
753
  closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
750
754
  if return_shape:
751
755
  out_avals, _ = jax.util.unzip2(out_type)