brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +853 -90
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +8 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +193 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +6 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +68 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/_utils.py +89 -0
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +5 -5
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +1 -2
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250217.dist-info/RECORD +128 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import
|
19
|
-
|
18
|
+
import jax
|
20
19
|
import jax.numpy as jnp
|
20
|
+
import numpy as np
|
21
|
+
import unittest
|
21
22
|
|
22
23
|
import brainstate as bst
|
23
24
|
from brainstate.augment._mapping import BatchAxisError
|
25
|
+
from brainstate.augment._mapping import _remove_axis
|
24
26
|
|
25
27
|
|
26
28
|
class TestVmap(unittest.TestCase):
|
@@ -99,6 +101,27 @@ class TestVmap(unittest.TestCase):
|
|
99
101
|
)
|
100
102
|
print(bst.random.DEFAULT)
|
101
103
|
|
104
|
+
def test_vmap_with_random_v3(self):
|
105
|
+
class Model(bst.nn.Module):
|
106
|
+
def __init__(self):
|
107
|
+
super().__init__()
|
108
|
+
|
109
|
+
self.a = bst.ShortTermState(bst.random.randn(5))
|
110
|
+
self.b = bst.ShortTermState(bst.random.randn(5))
|
111
|
+
self.c = bst.State(bst.random.randn(1))
|
112
|
+
|
113
|
+
def __call__(self):
|
114
|
+
self.c.value = self.a.value * self.b.value
|
115
|
+
return self.c.value + bst.random.randn(1)
|
116
|
+
|
117
|
+
model = Model()
|
118
|
+
r2 = bst.augment.vmap(
|
119
|
+
model,
|
120
|
+
in_states=model.states(bst.ShortTermState),
|
121
|
+
out_states=model.c
|
122
|
+
)()
|
123
|
+
print(bst.random.DEFAULT)
|
124
|
+
|
102
125
|
def test_vmap_with_random_2(self):
|
103
126
|
class Model(bst.nn.Module):
|
104
127
|
def __init__(self):
|
@@ -114,22 +137,11 @@ class TestVmap(unittest.TestCase):
|
|
114
137
|
self.c.value = self.a.value * self.b.value
|
115
138
|
return self.c.value + bst.random.randn(1)
|
116
139
|
|
117
|
-
model = Model()
|
118
|
-
with self.assertRaises(BatchAxisError):
|
119
|
-
r2 = bst.augment.vmap(
|
120
|
-
model,
|
121
|
-
in_states=model.states(bst.ShortTermState),
|
122
|
-
out_states=model.c
|
123
|
-
)(
|
124
|
-
bst.random.split_key(5)
|
125
|
-
)
|
126
|
-
|
127
140
|
model = Model()
|
128
141
|
r2 = bst.augment.vmap(
|
129
142
|
model,
|
130
143
|
in_states=model.states(bst.ShortTermState),
|
131
|
-
out_states=model.c
|
132
|
-
rngs=model.rng,
|
144
|
+
out_states=model.c
|
133
145
|
)(
|
134
146
|
bst.random.split_key(5)
|
135
147
|
)
|
@@ -154,24 +166,17 @@ class TestVmap(unittest.TestCase):
|
|
154
166
|
print(model.weight.value_call(jnp.shape))
|
155
167
|
print(model.weight.value)
|
156
168
|
|
157
|
-
def
|
158
|
-
|
159
|
-
|
160
|
-
weight_id = id(model.weight)
|
161
|
-
print(id(model), id(model.weight))
|
162
|
-
x = jnp.ones((5, 2))
|
169
|
+
def test_vmap_states_and_input_1(self):
|
170
|
+
gru = bst.nn.GRUCell(2, 3)
|
171
|
+
gru.init_state(5)
|
163
172
|
|
164
|
-
@bst.augment.vmap(
|
165
|
-
def forward(
|
166
|
-
|
167
|
-
self.assertTrue(id(model.weight) == weight_id)
|
168
|
-
print(id(model), id(model.weight))
|
169
|
-
return model(x)
|
173
|
+
@bst.augment.vmap(in_states=gru.states(bst.HiddenState))
|
174
|
+
def forward(x):
|
175
|
+
return gru(x)
|
170
176
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
print(model.weight.value)
|
177
|
+
xs = bst.random.randn(5, 2)
|
178
|
+
y = forward(xs)
|
179
|
+
self.assertTrue(y.shape == (5, 3))
|
175
180
|
|
176
181
|
def test_vmap_jit(self):
|
177
182
|
class Foo(bst.nn.Module):
|
@@ -249,6 +254,16 @@ class TestVmap(unittest.TestCase):
|
|
249
254
|
print(trace.get_write_states())
|
250
255
|
print(trace.get_read_states())
|
251
256
|
|
257
|
+
def test_auto_rand_key_split(self):
|
258
|
+
def f():
|
259
|
+
return bst.random.rand(1)
|
260
|
+
|
261
|
+
res = bst.augment.vmap(f, axis_size=10)()
|
262
|
+
self.assertTrue(jnp.all(~(res[0] == res[1:])))
|
263
|
+
|
264
|
+
res2 = jax.vmap(f, axis_size=10)()
|
265
|
+
self.assertTrue(jnp.all((res2[0] == res2[1:])))
|
266
|
+
|
252
267
|
|
253
268
|
class TestMap(unittest.TestCase):
|
254
269
|
def test_map(self):
|
@@ -264,3 +279,72 @@ class TestMap(unittest.TestCase):
|
|
264
279
|
self.assertTrue(jnp.allclose(r2, true_r))
|
265
280
|
self.assertTrue(jnp.allclose(r3, true_r))
|
266
281
|
self.assertTrue(jnp.allclose(r4, true_r))
|
282
|
+
|
283
|
+
|
284
|
+
class TestRemoveAxis:
|
285
|
+
|
286
|
+
def test_remove_axis_2d_array_axis_0(self):
|
287
|
+
input_array = np.array([[1, 2, 3], [4, 5, 6]])
|
288
|
+
expected_output = np.array([1, 2, 3])
|
289
|
+
|
290
|
+
result = _remove_axis(input_array, axis=0)
|
291
|
+
|
292
|
+
np.testing.assert_array_equal(result, expected_output)
|
293
|
+
|
294
|
+
def test_remove_axis_3d_array(self):
|
295
|
+
# Create a 3D array
|
296
|
+
x = np.arange(24).reshape((2, 3, 4))
|
297
|
+
|
298
|
+
# Remove axis 1
|
299
|
+
result = _remove_axis(x, axis=1)
|
300
|
+
|
301
|
+
# Expected result: a 2D array with shape (2, 4)
|
302
|
+
expected = x[:, 0, :]
|
303
|
+
|
304
|
+
np.testing.assert_array_equal(result, expected)
|
305
|
+
assert result.shape == (2, 4)
|
306
|
+
|
307
|
+
def test_remove_axis_1d_array(self):
|
308
|
+
# Create a 1D array
|
309
|
+
x = np.array([1, 2, 3, 4, 5])
|
310
|
+
|
311
|
+
# Remove axis 0 (the only axis in a 1D array)
|
312
|
+
result = _remove_axis(x, axis=0)
|
313
|
+
|
314
|
+
# Check that the result is a scalar (0D array) and equal to the first element
|
315
|
+
assert np.isscalar(result), "Result should be a scalar"
|
316
|
+
assert result == 1, "Result should be equal to the first element of the input array"
|
317
|
+
|
318
|
+
def test_remove_axis_out_of_bounds(self):
|
319
|
+
x = jnp.array([[1, 2], [3, 4]])
|
320
|
+
with unittest.TestCase().assertRaises(IndexError):
|
321
|
+
_remove_axis(x, axis=2)
|
322
|
+
|
323
|
+
def test_remove_axis_negative(self):
|
324
|
+
x = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
325
|
+
result = _remove_axis(x, -1)
|
326
|
+
expected = jnp.array([[1, 3], [5, 7]])
|
327
|
+
np.testing.assert_array_equal(result, expected)
|
328
|
+
|
329
|
+
def test_remove_axis_with_nan_and_inf(self):
|
330
|
+
x = jnp.array([[1.0, jnp.nan, 3.0], [4.0, 5.0, jnp.inf]])
|
331
|
+
result = _remove_axis(x, axis=0)
|
332
|
+
expected = jnp.array([1.0, jnp.nan, 3.0])
|
333
|
+
np.testing.assert_array_equal(result, expected)
|
334
|
+
assert jnp.isnan(result[1])
|
335
|
+
|
336
|
+
def test_remove_axis_different_dtypes(self):
|
337
|
+
# Test with integer array
|
338
|
+
int_array = jnp.array([[1, 2, 3], [4, 5, 6]])
|
339
|
+
int_result = _remove_axis(int_array, 0)
|
340
|
+
assert jnp.array_equal(int_result, jnp.array([1, 2, 3]))
|
341
|
+
|
342
|
+
# Test with float array
|
343
|
+
float_array = jnp.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
|
344
|
+
float_result = _remove_axis(float_array, 1)
|
345
|
+
assert jnp.allclose(float_result, jnp.array([1.1, 4.4]))
|
346
|
+
|
347
|
+
# Test with complex array
|
348
|
+
complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
|
349
|
+
complex_result = _remove_axis(complex_array, 0)
|
350
|
+
assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
|
brainstate/augment/_random.py
CHANGED
@@ -20,34 +20,62 @@ from typing import Callable, Sequence, Union
|
|
20
20
|
|
21
21
|
from brainstate.random import DEFAULT, RandomState
|
22
22
|
from brainstate.typing import Missing
|
23
|
+
from brainstate.util import PrettyObject
|
23
24
|
|
24
25
|
__all__ = [
|
25
26
|
'restore_rngs'
|
26
27
|
]
|
27
28
|
|
28
29
|
|
29
|
-
class RngRestore:
|
30
|
+
class RngRestore(PrettyObject):
|
30
31
|
"""
|
31
32
|
Backup and restore the random state of a sequence of RandomState instances.
|
33
|
+
|
34
|
+
This class provides functionality to save the current state of multiple
|
35
|
+
RandomState instances and later restore them to their saved states.
|
36
|
+
|
37
|
+
Attributes:
|
38
|
+
rngs (Sequence[RandomState]): A sequence of RandomState instances to manage.
|
39
|
+
rng_keys (list): A list to store the backed up random keys.
|
32
40
|
"""
|
33
41
|
|
34
42
|
def __init__(self, rngs: Sequence[RandomState]):
|
43
|
+
"""
|
44
|
+
Initialize the RngRestore instance.
|
45
|
+
|
46
|
+
Args:
|
47
|
+
rngs (Sequence[RandomState]): A sequence of RandomState instances
|
48
|
+
whose states will be managed.
|
49
|
+
"""
|
35
50
|
self.rngs: Sequence[RandomState] = rngs
|
36
51
|
self.rng_keys = []
|
37
52
|
|
38
53
|
def backup(self):
|
39
54
|
"""
|
40
55
|
Backup the current random key of the RandomState instances.
|
56
|
+
|
57
|
+
This method saves the current value (state) of each RandomState
|
58
|
+
instance in the rngs sequence.
|
41
59
|
"""
|
42
60
|
self.rng_keys = [rng.value for rng in self.rngs]
|
43
61
|
|
44
62
|
def restore(self):
|
45
63
|
"""
|
46
64
|
Restore the random key of the RandomState instances.
|
65
|
+
|
66
|
+
This method restores each RandomState instance to its previously
|
67
|
+
saved state. It raises an error if the number of saved keys doesn't
|
68
|
+
match the number of RandomState instances.
|
69
|
+
|
70
|
+
Raises:
|
71
|
+
ValueError: If the number of saved random keys does not match
|
72
|
+
the number of RandomState instances.
|
47
73
|
"""
|
74
|
+
if len(self.rng_keys) != len(self.rngs):
|
75
|
+
raise ValueError('The number of random keys does not match the number of random states.')
|
48
76
|
for rng, key in zip(self.rngs, self.rng_keys):
|
49
77
|
rng.restore_value(key)
|
50
|
-
self.rng_keys
|
78
|
+
self.rng_keys.clear()
|
51
79
|
|
52
80
|
|
53
81
|
def _rng_backup(
|
@@ -74,19 +102,45 @@ def restore_rngs(
|
|
74
102
|
rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
|
75
103
|
) -> Callable:
|
76
104
|
"""
|
77
|
-
|
105
|
+
Decorator to backup and restore the random state before and after a function call.
|
106
|
+
|
107
|
+
This function can be used as a decorator or called directly. It ensures that the
|
108
|
+
random state of the specified RandomState instances is preserved across function calls,
|
109
|
+
which is useful for maintaining reproducibility in stochastic operations.
|
78
110
|
|
79
111
|
Parameters
|
80
112
|
----------
|
81
113
|
fn : Callable, optional
|
82
|
-
The function to be wrapped.
|
83
|
-
|
84
|
-
|
114
|
+
The function to be wrapped. If not provided, the decorator can be used
|
115
|
+
with parameters.
|
116
|
+
rngs : Union[RandomState, Sequence[RandomState]], optional
|
117
|
+
The random state(s) to be backed up and restored. This can be a single
|
118
|
+
RandomState instance or a sequence of RandomState instances. If not provided,
|
119
|
+
the default RandomState instance will be used.
|
85
120
|
|
86
121
|
Returns
|
87
122
|
-------
|
88
123
|
Callable
|
89
|
-
|
124
|
+
If `fn` is provided, returns the wrapped function that will backup the
|
125
|
+
random state before execution and restore it afterwards.
|
126
|
+
If `fn` is not provided, returns a partial function that can be used as
|
127
|
+
a decorator with the specified `rngs`.
|
128
|
+
|
129
|
+
Raises
|
130
|
+
------
|
131
|
+
AssertionError
|
132
|
+
If `rngs` is not a RandomState instance or a sequence of RandomState instances.
|
133
|
+
|
134
|
+
Examples
|
135
|
+
--------
|
136
|
+
>>> @restore_rngs
|
137
|
+
... def my_random_function():
|
138
|
+
... return random.random()
|
139
|
+
|
140
|
+
>>> rng = RandomState(42)
|
141
|
+
>>> @restore_rngs(rngs=rng)
|
142
|
+
... def another_random_function():
|
143
|
+
... return rng.random()
|
90
144
|
"""
|
91
145
|
if isinstance(fn, Missing):
|
92
146
|
return functools.partial(restore_rngs, rngs=rngs)
|
@@ -16,9 +16,8 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
from typing import Callable, Tuple, Union
|
20
|
-
|
21
19
|
import jax
|
20
|
+
from typing import Callable, Tuple, Union
|
22
21
|
|
23
22
|
from brainstate.typing import Missing
|
24
23
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
@@ -181,7 +180,7 @@ def checkpoint(
|
|
181
180
|
return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)
|
182
181
|
|
183
182
|
static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
|
184
|
-
fun = StatefulFunction(fun, static_argnums=static_argnums)
|
183
|
+
fun = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')
|
185
184
|
checkpointed_fun = jax.checkpoint(
|
186
185
|
fun.jaxpr_call,
|
187
186
|
prevent_cse=prevent_cse,
|
@@ -15,11 +15,10 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
from collections.abc import Callable, Sequence
|
19
|
-
|
20
18
|
import jax
|
21
19
|
import jax.numpy as jnp
|
22
20
|
import numpy as np
|
21
|
+
from collections.abc import Callable, Sequence
|
23
22
|
|
24
23
|
from brainstate._utils import set_module_as
|
25
24
|
from ._error_if import jit_error_if
|
@@ -94,8 +93,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
94
93
|
return false_fun(*operands)
|
95
94
|
|
96
95
|
# evaluate jaxpr
|
97
|
-
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
|
98
|
-
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
96
|
+
stateful_true = StatefulFunction(true_fun, name='cond:true').make_jaxpr(*operands)
|
97
|
+
stateful_false = StatefulFunction(false_fun, name='conda:false').make_jaxpr(*operands)
|
99
98
|
|
100
99
|
# state trace and state values
|
101
100
|
state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
|
@@ -174,7 +173,7 @@ def switch(index, branches: Sequence[Callable], *operands):
|
|
174
173
|
return branches[int(index)](*operands)
|
175
174
|
|
176
175
|
# evaluate jaxpr
|
177
|
-
wrapped_branches = [StatefulFunction(branch) for branch in branches]
|
176
|
+
wrapped_branches = [StatefulFunction(branch, name='switch') for branch in branches]
|
178
177
|
for wrapped_branch in wrapped_branches:
|
179
178
|
wrapped_branch.make_jaxpr(*operands)
|
180
179
|
|
brainstate/compile/_error_if.py
CHANGED
brainstate/compile/_jit.py
CHANGED
@@ -16,12 +16,11 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
from collections.abc import Iterable, Sequence
|
20
|
-
from typing import (Any, Callable, Union)
|
21
|
-
|
22
19
|
import jax
|
20
|
+
from collections.abc import Iterable, Sequence
|
23
21
|
from jax._src import sharding_impls
|
24
22
|
from jax.lib import xla_client as xc
|
23
|
+
from typing import (Any, Callable, Union)
|
25
24
|
|
26
25
|
from brainstate._utils import set_module_as
|
27
26
|
from brainstate.typing import Missing
|
@@ -62,19 +61,27 @@ def _get_jitted_fun(
|
|
62
61
|
**kwargs
|
63
62
|
) -> JittedFunction:
|
64
63
|
static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
|
65
|
-
fun = StatefulFunction(
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
64
|
+
fun = StatefulFunction(
|
65
|
+
fun,
|
66
|
+
static_argnums=static_argnums,
|
67
|
+
abstracted_axes=abstracted_axes,
|
68
|
+
cache_type='jit',
|
69
|
+
name='jit'
|
70
|
+
)
|
71
|
+
jit_fun = jax.jit(
|
72
|
+
fun.jaxpr_call,
|
73
|
+
static_argnums=tuple(i + 1 for i in static_argnums),
|
74
|
+
donate_argnums=donate_argnums,
|
75
|
+
donate_argnames=donate_argnames,
|
76
|
+
keep_unused=keep_unused,
|
77
|
+
device=device,
|
78
|
+
backend=backend,
|
79
|
+
inline=inline,
|
80
|
+
in_shardings=in_shardings,
|
81
|
+
out_shardings=out_shardings,
|
82
|
+
abstracted_axes=abstracted_axes,
|
83
|
+
**kwargs
|
84
|
+
)
|
78
85
|
|
79
86
|
@functools.wraps(fun.fun)
|
80
87
|
def jitted_fun(*args, **params):
|
brainstate/compile/_jit_test.py
CHANGED
@@ -16,11 +16,11 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
-
from functools import wraps
|
20
|
-
from typing import Callable, Optional, TypeVar, Tuple, Any
|
21
19
|
|
22
20
|
import jax
|
23
21
|
import jax.numpy as jnp
|
22
|
+
from functools import wraps
|
23
|
+
from typing import Callable, Optional, TypeVar, Tuple, Any
|
24
24
|
|
25
25
|
from brainstate._utils import set_module_as
|
26
26
|
from ._make_jaxpr import StatefulFunction
|
@@ -209,7 +209,7 @@ def scan(
|
|
209
209
|
# ------------------------------ #
|
210
210
|
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
211
211
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
212
|
-
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
212
|
+
stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
|
213
213
|
state_trace = stateful_fun.get_state_trace()
|
214
214
|
all_writen_state_vals = state_trace.get_write_state_values(True)
|
215
215
|
all_read_state_vals = state_trace.get_read_state_values(True)
|
@@ -217,12 +217,20 @@ def scan(
|
|
217
217
|
|
218
218
|
# scan
|
219
219
|
init = (all_writen_state_vals, init)
|
220
|
-
(
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
220
|
+
(
|
221
|
+
(
|
222
|
+
all_writen_state_vals,
|
223
|
+
carry
|
224
|
+
),
|
225
|
+
ys
|
226
|
+
) = jax.lax.scan(
|
227
|
+
wrapped_f,
|
228
|
+
init,
|
229
|
+
xs,
|
230
|
+
length=length,
|
231
|
+
reverse=reverse,
|
232
|
+
unroll=unroll
|
233
|
+
)
|
226
234
|
# assign the written state values and restore the read state values
|
227
235
|
write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
|
228
236
|
# carry
|
@@ -305,7 +313,7 @@ def checkpointed_scan(
|
|
305
313
|
# evaluate jaxpr
|
306
314
|
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
307
315
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
308
|
-
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
316
|
+
stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
|
309
317
|
state_trace = stateful_fun.get_state_trace()
|
310
318
|
# get all states
|
311
319
|
been_written = state_trace.been_writen
|
@@ -16,9 +16,9 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
-
from typing import Any, Callable, TypeVar
|
20
19
|
|
21
20
|
import jax
|
21
|
+
from typing import Any, Callable, TypeVar
|
22
22
|
|
23
23
|
from brainstate._utils import set_module_as
|
24
24
|
from ._loop_collect_return import _bounded_while_loop
|
@@ -103,8 +103,8 @@ def while_loop(
|
|
103
103
|
pass
|
104
104
|
|
105
105
|
# evaluate jaxpr
|
106
|
-
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
|
107
|
-
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
|
106
|
+
stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
|
107
|
+
stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
|
108
108
|
if len(stateful_cond.get_write_states()) != 0:
|
109
109
|
raise ValueError("while_loop: cond_fun should not have any write states.")
|
110
110
|
|
@@ -162,8 +162,8 @@ def bounded_while_loop(
|
|
162
162
|
return init_val
|
163
163
|
|
164
164
|
# evaluate jaxpr
|
165
|
-
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
|
166
|
-
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
|
165
|
+
stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
|
166
|
+
stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
|
167
167
|
if len(stateful_cond.get_write_states()) != 0:
|
168
168
|
raise ValueError("while_loop: cond_fun should not have any write states.")
|
169
169
|
|