brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -17,57 +17,122 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import unittest
|
19
19
|
|
20
|
-
import jax.core
|
21
20
|
import jax.numpy as jnp
|
22
21
|
|
23
22
|
import brainstate as bst
|
23
|
+
from brainstate.augment._mapping import BatchAxisError
|
24
24
|
|
25
25
|
|
26
26
|
class TestVmap(unittest.TestCase):
|
27
|
-
def
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
self.assertTrue(
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
27
|
+
def test_vmap_1(self):
|
28
|
+
class Model(bst.nn.Module):
|
29
|
+
def __init__(self):
|
30
|
+
super().__init__()
|
31
|
+
|
32
|
+
self.a = bst.State(bst.random.randn(5))
|
33
|
+
self.b = bst.State(bst.random.randn(5))
|
34
|
+
|
35
|
+
def __call__(self, *args, **kwargs):
|
36
|
+
return self.a.value * self.b.value
|
37
|
+
|
38
|
+
model = Model()
|
39
|
+
r1 = model.a.value * model.b.value
|
40
|
+
r2 = bst.augment.vmap(model, in_states=model.states())()
|
41
|
+
self.assertTrue(jnp.allclose(r1, r2))
|
42
|
+
|
43
|
+
def test_vmap_2(self):
|
44
|
+
class Model(bst.nn.Module):
|
45
|
+
def __init__(self):
|
46
|
+
super().__init__()
|
47
|
+
|
48
|
+
self.a = bst.ShortTermState(bst.random.randn(5))
|
49
|
+
self.b = bst.ShortTermState(bst.random.randn(5))
|
50
|
+
self.c = bst.State(bst.random.randn(1))
|
51
|
+
|
52
|
+
def __call__(self, *args, **kwargs):
|
53
|
+
self.c.value = self.a.value * self.b.value
|
54
|
+
return self.c.value + 1.
|
55
|
+
|
56
|
+
model = Model()
|
57
|
+
with self.assertRaises(BatchAxisError):
|
58
|
+
r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState))()
|
59
|
+
|
60
|
+
model = Model()
|
61
|
+
r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState), out_states=model.c)()
|
62
|
+
|
63
|
+
def test_vmap_3(self):
|
64
|
+
class Model(bst.nn.Module):
|
65
|
+
def __init__(self):
|
66
|
+
super().__init__()
|
67
|
+
|
68
|
+
self.a = bst.State(bst.random.randn(5))
|
69
|
+
self.b = bst.State(bst.random.randn(5))
|
70
|
+
|
71
|
+
def __call__(self, *args, **kwargs):
|
72
|
+
return self.a.value * self.b.value
|
73
|
+
|
74
|
+
model = Model()
|
75
|
+
with self.assertRaises(BatchAxisError):
|
76
|
+
r2 = bst.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
|
77
|
+
|
78
|
+
def test_vmap_with_random(self):
|
79
|
+
class Model(bst.nn.Module):
|
80
|
+
def __init__(self):
|
81
|
+
super().__init__()
|
82
|
+
|
83
|
+
self.a = bst.ShortTermState(bst.random.randn(5))
|
84
|
+
self.b = bst.ShortTermState(bst.random.randn(5))
|
85
|
+
self.c = bst.State(bst.random.randn(1))
|
86
|
+
|
87
|
+
def __call__(self, key):
|
88
|
+
bst.random.set_key(key)
|
89
|
+
self.c.value = self.a.value * self.b.value
|
90
|
+
return self.c.value + bst.random.randn(1)
|
91
|
+
|
92
|
+
model = Model()
|
93
|
+
r2 = bst.augment.vmap(
|
94
|
+
model,
|
95
|
+
in_states=model.states(bst.ShortTermState),
|
96
|
+
out_states=model.c
|
97
|
+
)(
|
98
|
+
bst.random.split_key(5)
|
99
|
+
)
|
100
|
+
print(bst.random.DEFAULT)
|
101
|
+
|
102
|
+
def test_vmap_with_random_2(self):
|
103
|
+
class Model(bst.nn.Module):
|
104
|
+
def __init__(self):
|
105
|
+
super().__init__()
|
106
|
+
|
107
|
+
self.a = bst.ShortTermState(bst.random.randn(5))
|
108
|
+
self.b = bst.ShortTermState(bst.random.randn(5))
|
109
|
+
self.c = bst.State(bst.random.randn(1))
|
110
|
+
self.rng = bst.random.RandomState(1)
|
111
|
+
|
112
|
+
def __call__(self, key):
|
113
|
+
self.rng.set_key(key)
|
114
|
+
self.c.value = self.a.value * self.b.value
|
115
|
+
return self.c.value + bst.random.randn(1)
|
116
|
+
|
117
|
+
model = Model()
|
118
|
+
with self.assertRaises(BatchAxisError):
|
119
|
+
r2 = bst.augment.vmap(
|
120
|
+
model,
|
121
|
+
in_states=model.states(bst.ShortTermState),
|
122
|
+
out_states=model.c
|
123
|
+
)(
|
124
|
+
bst.random.split_key(5)
|
125
|
+
)
|
126
|
+
|
127
|
+
model = Model()
|
128
|
+
r2 = bst.augment.vmap(
|
129
|
+
model,
|
130
|
+
in_states=model.states(bst.ShortTermState),
|
131
|
+
out_states=model.c,
|
132
|
+
rngs=model.rng,
|
133
|
+
)(
|
134
|
+
bst.random.split_key(5)
|
135
|
+
)
|
71
136
|
|
72
137
|
def test_vmap_input(self):
|
73
138
|
model = bst.nn.Linear(2, 3)
|
@@ -98,8 +163,8 @@ class TestVmap(unittest.TestCase):
|
|
98
163
|
|
99
164
|
@bst.augment.vmap(in_axes=(None, 0), out_axes=0)
|
100
165
|
def forward(model, x):
|
101
|
-
self.assertTrue(id(model)
|
102
|
-
self.assertTrue(id(model.weight)
|
166
|
+
self.assertTrue(id(model) == model_id)
|
167
|
+
self.assertTrue(id(model.weight) == weight_id)
|
103
168
|
print(id(model), id(model.weight))
|
104
169
|
return model(x)
|
105
170
|
|
@@ -108,51 +173,7 @@ class TestVmap(unittest.TestCase):
|
|
108
173
|
print(model.weight.value_call(jnp.shape))
|
109
174
|
print(model.weight.value)
|
110
175
|
|
111
|
-
def
|
112
|
-
model = bst.nn.Linear(2, 3)
|
113
|
-
x = jnp.ones((5, 2))
|
114
|
-
|
115
|
-
@bst.augment.vmap(in_axes=(None, 0), out_axes=0)
|
116
|
-
def forward(model, x):
|
117
|
-
return model(x)
|
118
|
-
|
119
|
-
y = forward(model, x)
|
120
|
-
print(y.shape)
|
121
|
-
|
122
|
-
def test_vmap2(self):
|
123
|
-
class LinearEnsemble(bst.nn.Module):
|
124
|
-
def __init__(self, num):
|
125
|
-
super().__init__()
|
126
|
-
self.w = bst.ParamState(bst.random.random((num, 2, 3)))
|
127
|
-
|
128
|
-
model = LinearEnsemble(5)
|
129
|
-
x = jnp.ones((2,))
|
130
|
-
|
131
|
-
@bst.augment.vmap(in_axes=(0, None), out_axes=0)
|
132
|
-
def forward(model, x):
|
133
|
-
return jnp.dot(x, model.w.value)
|
134
|
-
|
135
|
-
y = forward(model, x)
|
136
|
-
print(y.shape)
|
137
|
-
|
138
|
-
def test_vmap3(self):
|
139
|
-
class Foo(bst.nn.Module):
|
140
|
-
def __init__(self):
|
141
|
-
super().__init__()
|
142
|
-
self.a = bst.ParamState(jnp.arange(4))
|
143
|
-
self.b = bst.ShortTermState(jnp.arange(4))
|
144
|
-
|
145
|
-
state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
|
146
|
-
|
147
|
-
@bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
|
148
|
-
def mul(foo):
|
149
|
-
return foo.a.value * foo.b.value
|
150
|
-
|
151
|
-
foo = Foo()
|
152
|
-
y = mul(foo)
|
153
|
-
print(y.shape)
|
154
|
-
|
155
|
-
def test_vmap4(self):
|
176
|
+
def test_vmap_jit(self):
|
156
177
|
class Foo(bst.nn.Module):
|
157
178
|
def __init__(self):
|
158
179
|
super().__init__()
|
@@ -162,60 +183,35 @@ class TestVmap(unittest.TestCase):
|
|
162
183
|
def __call__(self):
|
163
184
|
self.b.value = self.a.value * self.b.value
|
164
185
|
|
165
|
-
@bst.augment.vmap
|
166
|
-
def mul(foo):
|
167
|
-
foo()
|
168
|
-
return foo
|
169
|
-
|
170
186
|
foo = Foo()
|
171
|
-
with bst.StateTraceStack() as trace:
|
172
|
-
m = mul(foo)
|
173
|
-
|
174
|
-
self.assertTrue(m is foo)
|
175
|
-
print(m.a.value, foo.a.value)
|
176
|
-
self.assertTrue(jnp.allclose(m.a.value, foo.a.value))
|
177
|
-
print(m.b.value, foo.b.value)
|
178
|
-
self.assertTrue(jnp.allclose(m.b.value, foo.b.value))
|
179
|
-
print(trace.get_write_states())
|
180
|
-
self.assertTrue(len(trace.get_write_states()) == 1)
|
181
|
-
print(trace.get_read_states())
|
182
|
-
self.assertTrue(len(trace.get_read_states()) == 2)
|
183
|
-
|
184
|
-
def test_vmap5(self):
|
185
|
-
class Foo(bst.nn.Module):
|
186
|
-
def __init__(self):
|
187
|
-
super().__init__()
|
188
|
-
self.a = bst.ParamState(jnp.arange(4))
|
189
|
-
self.b = bst.ShortTermState(jnp.arange(4))
|
190
187
|
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
@bst.augment.vmap
|
195
|
-
def mul(foo):
|
188
|
+
@bst.augment.vmap(in_states=foo.states())
|
189
|
+
def mul():
|
196
190
|
foo()
|
197
191
|
|
198
|
-
|
192
|
+
@bst.compile.jit
|
193
|
+
def mul_jit(inp):
|
194
|
+
mul()
|
195
|
+
foo.a.value += inp
|
196
|
+
|
199
197
|
with bst.StateTraceStack() as trace:
|
200
|
-
|
198
|
+
mul_jit(1.)
|
201
199
|
|
202
200
|
print(foo.a.value)
|
203
201
|
print(foo.b.value)
|
204
|
-
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
|
202
|
+
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
|
205
203
|
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
206
204
|
|
207
205
|
write_state_ids = [id(st) for st in trace.get_write_states()]
|
208
206
|
read_state_ids = [id(st) for st in trace.get_read_states()]
|
209
207
|
|
210
|
-
assert id(foo.a) in
|
208
|
+
assert id(foo.a) in write_state_ids
|
211
209
|
assert id(foo.b) in write_state_ids
|
212
210
|
|
213
211
|
print(trace.get_write_states())
|
214
212
|
print(trace.get_read_states())
|
215
213
|
|
216
|
-
|
217
|
-
|
218
|
-
def test_vmap_jit(self):
|
214
|
+
def test_vmap_jit_2(self):
|
219
215
|
class Foo(bst.nn.Module):
|
220
216
|
def __init__(self):
|
221
217
|
super().__init__()
|
@@ -225,70 +221,46 @@ class TestVmap(unittest.TestCase):
|
|
225
221
|
def __call__(self):
|
226
222
|
self.b.value = self.a.value * self.b.value
|
227
223
|
|
228
|
-
|
229
|
-
|
224
|
+
foo = Foo()
|
225
|
+
|
226
|
+
@bst.augment.vmap(in_states=foo.states())
|
227
|
+
def mul():
|
230
228
|
foo()
|
231
229
|
|
232
230
|
@bst.compile.jit
|
233
231
|
def mul_jit(inp):
|
234
|
-
mul(
|
235
|
-
foo.
|
232
|
+
mul()
|
233
|
+
foo.b.value += inp
|
236
234
|
|
237
|
-
foo = Foo()
|
238
235
|
with bst.StateTraceStack() as trace:
|
239
236
|
mul_jit(1.)
|
240
237
|
|
241
238
|
print(foo.a.value)
|
242
239
|
print(foo.b.value)
|
243
|
-
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)
|
244
|
-
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
240
|
+
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
|
241
|
+
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4) + 1.))
|
245
242
|
|
246
243
|
write_state_ids = [id(st) for st in trace.get_write_states()]
|
247
244
|
read_state_ids = [id(st) for st in trace.get_read_states()]
|
248
245
|
|
249
|
-
assert id(foo.a) in
|
246
|
+
assert id(foo.a) in read_state_ids
|
250
247
|
assert id(foo.b) in write_state_ids
|
251
248
|
|
252
249
|
print(trace.get_write_states())
|
253
250
|
print(trace.get_read_states())
|
254
251
|
|
255
252
|
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
def loss():
|
271
|
-
mul(foo)
|
272
|
-
return jnp.sum(foo.b.value)
|
273
|
-
|
274
|
-
foo = Foo()
|
275
|
-
with bst.StateTraceStack() as trace:
|
276
|
-
grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
|
277
|
-
print(grads)
|
278
|
-
print(loss)
|
279
|
-
|
280
|
-
# print(foo.a.value)
|
281
|
-
# print(foo.b.value)
|
282
|
-
# self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
|
283
|
-
# self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
284
|
-
#
|
285
|
-
# write_state_ids = [id(st) for st in trace.get_write_states()]
|
286
|
-
# read_state_ids = [id(st) for st in trace.get_read_states()]
|
287
|
-
#
|
288
|
-
# assert id(foo.a) in write_state_ids
|
289
|
-
# assert id(foo.b) in write_state_ids
|
290
|
-
#
|
291
|
-
# print(trace.get_write_states())
|
292
|
-
# print(trace.get_read_states())
|
293
|
-
|
294
|
-
|
253
|
+
class TestMap(unittest.TestCase):
|
254
|
+
def test_map(self):
|
255
|
+
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
256
|
+
x = bst.random.rand(*dim)
|
257
|
+
r1 = bst.augment.map(lambda a: a + 1, x, batch_size=None)
|
258
|
+
r2 = bst.augment.map(lambda a: a + 1, x, batch_size=2)
|
259
|
+
r3 = bst.augment.map(lambda a: a + 1, x, batch_size=4)
|
260
|
+
r4 = bst.augment.map(lambda a: a + 1, x, batch_size=5)
|
261
|
+
true_r = x + 1
|
262
|
+
|
263
|
+
self.assertTrue(jnp.allclose(r1, true_r))
|
264
|
+
self.assertTrue(jnp.allclose(r2, true_r))
|
265
|
+
self.assertTrue(jnp.allclose(r3, true_r))
|
266
|
+
self.assertTrue(jnp.allclose(r4, true_r))
|
@@ -182,10 +182,12 @@ def checkpoint(
|
|
182
182
|
|
183
183
|
static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
|
184
184
|
fun = StatefulFunction(fun, static_argnums=static_argnums)
|
185
|
-
checkpointed_fun = jax.checkpoint(
|
186
|
-
|
187
|
-
|
188
|
-
|
185
|
+
checkpointed_fun = jax.checkpoint(
|
186
|
+
fun.jaxpr_call,
|
187
|
+
prevent_cse=prevent_cse,
|
188
|
+
policy=policy,
|
189
|
+
static_argnums=tuple(i + 1 for i in static_argnums)
|
190
|
+
)
|
189
191
|
|
190
192
|
@functools.wraps(fun.fun)
|
191
193
|
def remat_fun(*args, **params):
|
brainstate/compile/_jit.py
CHANGED
@@ -62,7 +62,6 @@ def _get_jitted_fun(
|
|
62
62
|
**kwargs
|
63
63
|
) -> JittedFunction:
|
64
64
|
static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
|
65
|
-
# TODO: add to cache stack for clear_cache
|
66
65
|
fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes, cache_type='jit')
|
67
66
|
jit_fun = jax.jit(fun.jaxpr_call,
|
68
67
|
static_argnums=tuple(i + 1 for i in static_argnums),
|
@@ -83,10 +82,12 @@ def _get_jitted_fun(
|
|
83
82
|
return fun.fun(*args, **params)
|
84
83
|
|
85
84
|
# compile the function and get the state trace
|
85
|
+
# print('Compiling ...')
|
86
86
|
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
87
|
read_state_vals = state_trace.get_read_state_values(True)
|
88
88
|
|
89
89
|
# call the jitted function
|
90
|
+
# print('Running ...')
|
90
91
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
91
92
|
# write the state values back to the states
|
92
93
|
write_back_state_values(state_trace, read_state_vals, write_state_vals)
|
@@ -127,11 +128,15 @@ def _get_jitted_fun(
|
|
127
128
|
"""
|
128
129
|
# compile the function and get the state trace
|
129
130
|
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
130
|
-
read_state_vals = state_trace.get_read_state_values(True)
|
131
|
+
read_state_vals = state_trace.get_read_state_values(replace_writen=True)
|
132
|
+
write_state_vals = state_trace.get_write_state_values(replace_read=True)
|
131
133
|
|
132
|
-
#
|
133
|
-
|
134
|
+
# compile the model
|
135
|
+
ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
|
134
136
|
|
137
|
+
# write the state values back to the states
|
138
|
+
write_back_state_values(state_trace, read_state_vals, write_state_vals)
|
139
|
+
return ret
|
135
140
|
|
136
141
|
jitted_fun: JittedFunction
|
137
142
|
|
@@ -290,31 +295,35 @@ def jit(
|
|
290
295
|
|
291
296
|
if isinstance(fun, Missing):
|
292
297
|
def wrapper(fun_again: Callable) -> JittedFunction:
|
293
|
-
return _get_jitted_fun(
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
298
|
+
return _get_jitted_fun(
|
299
|
+
fun_again,
|
300
|
+
in_shardings=in_shardings,
|
301
|
+
out_shardings=out_shardings,
|
302
|
+
static_argnums=static_argnums,
|
303
|
+
donate_argnums=donate_argnums,
|
304
|
+
donate_argnames=donate_argnames,
|
305
|
+
keep_unused=keep_unused,
|
306
|
+
device=device,
|
307
|
+
backend=backend,
|
308
|
+
inline=inline,
|
309
|
+
abstracted_axes=abstracted_axes,
|
310
|
+
**kwargs
|
311
|
+
)
|
305
312
|
|
306
313
|
return wrapper
|
307
314
|
|
308
315
|
else:
|
309
|
-
return _get_jitted_fun(
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
316
|
+
return _get_jitted_fun(
|
317
|
+
fun,
|
318
|
+
in_shardings,
|
319
|
+
out_shardings,
|
320
|
+
static_argnums,
|
321
|
+
donate_argnums,
|
322
|
+
donate_argnames,
|
323
|
+
keep_unused,
|
324
|
+
device,
|
325
|
+
backend,
|
326
|
+
inline,
|
327
|
+
abstracted_axes,
|
328
|
+
**kwargs
|
329
|
+
)
|
@@ -31,6 +31,7 @@ from ._util import write_back_state_values, wrap_single_fun
|
|
31
31
|
__all__ = [
|
32
32
|
# "scan" syntax, which is similar to jax.lax.scan
|
33
33
|
'scan', 'checkpointed_scan',
|
34
|
+
|
34
35
|
# "for_loop" syntax
|
35
36
|
'for_loop', 'checkpointed_for_loop',
|
36
37
|
]
|
@@ -48,9 +49,9 @@ def _wrap_fun_with_pbar(
|
|
48
49
|
@wraps(fun)
|
49
50
|
def new_fun(new_carry, inputs):
|
50
51
|
i, old_carry = new_carry
|
51
|
-
|
52
|
-
pbar_runner(unvmap(i, op='none'))
|
53
|
-
return (i + 1,
|
52
|
+
new_carry, new_outputs = fun(old_carry, inputs)
|
53
|
+
pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
|
54
|
+
return (i + 1, new_carry), new_outputs
|
54
55
|
|
55
56
|
return new_fun
|
56
57
|
|
@@ -206,7 +207,7 @@ def scan(
|
|
206
207
|
|
207
208
|
# evaluate jaxpr, get all states #
|
208
209
|
# ------------------------------ #
|
209
|
-
xs_avals = [jax.core.
|
210
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
210
211
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
211
212
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
212
213
|
state_trace = stateful_fun.get_state_trace()
|
@@ -302,7 +303,7 @@ def checkpointed_scan(
|
|
302
303
|
pbar_runner = None
|
303
304
|
|
304
305
|
# evaluate jaxpr
|
305
|
-
xs_avals = [jax.core.
|
306
|
+
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
306
307
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
307
308
|
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
308
309
|
state_trace = stateful_fun.get_state_trace()
|
@@ -476,6 +477,8 @@ def checkpointed_for_loop(
|
|
476
477
|
return ys
|
477
478
|
|
478
479
|
|
480
|
+
# This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
|
481
|
+
|
479
482
|
# There's several tricks happening here to work around various limitations of JAX.
|
480
483
|
# (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
|
481
484
|
# 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
|
@@ -134,6 +134,8 @@ def bounded_while_loop(
|
|
134
134
|
"""
|
135
135
|
While loop with a bound on the maximum number of steps.
|
136
136
|
|
137
|
+
This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
|
138
|
+
|
137
139
|
This function is useful when you want to ensure that a while loop terminates
|
138
140
|
even if the condition function is never false. The function is implemented
|
139
141
|
using a scan operation, so it is reverse-mode differentiable.
|
@@ -396,7 +396,7 @@ class StatefulFunction(object):
|
|
396
396
|
# Checking whether the states are returned.
|
397
397
|
for leaf in jax.tree.leaves(out):
|
398
398
|
if isinstance(leaf, State):
|
399
|
-
leaf.
|
399
|
+
leaf.raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
|
400
400
|
return out, state_values
|
401
401
|
|
402
402
|
def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
|
@@ -741,11 +741,15 @@ def _make_jaxpr(
|
|
741
741
|
in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
|
742
742
|
f, out_tree = _flatten_fun(f, in_tree)
|
743
743
|
f = annotate(f, in_type)
|
744
|
-
|
744
|
+
if jax.__version_info__ < (0, 5, 0):
|
745
|
+
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
745
746
|
with ExitStack() as stack:
|
746
747
|
for axis_name, size in axis_env or []:
|
747
748
|
stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
|
748
|
-
|
749
|
+
if jax.__version_info__ < (0, 5, 0):
|
750
|
+
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
|
751
|
+
else:
|
752
|
+
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
749
753
|
closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
|
750
754
|
if return_shape:
|
751
755
|
out_avals, _ = jax.util.unzip2(out_type)
|
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax
|
21
|
+
import jax.extend as je
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import pytest
|
23
24
|
|
@@ -84,7 +85,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
84
85
|
print(jaxpr)
|
85
86
|
jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
|
86
87
|
print(jaxpr)
|
87
|
-
self.assertTrue(jnp.allclose(
|
88
|
+
self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
88
89
|
f3(jnp.zeros(1))))
|
89
90
|
|
90
91
|
def test_compar_jax_make_jaxpr2(self):
|