brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -13,43 +13,40 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax
|
21
19
|
import jax.numpy as jnp
|
22
|
-
import jaxlib.xla_extension
|
23
20
|
|
24
|
-
import brainstate
|
21
|
+
import brainstate
|
25
22
|
|
26
23
|
|
27
24
|
class TestJitError(unittest.TestCase):
|
28
25
|
def test1(self):
|
29
|
-
with self.assertRaises(
|
30
|
-
|
26
|
+
with self.assertRaises(Exception):
|
27
|
+
brainstate.compile.jit_error_if(True, 'error')
|
31
28
|
|
32
29
|
def err_f(x):
|
33
30
|
raise ValueError(f'error: {x}')
|
34
31
|
|
35
|
-
|
36
|
-
with self.assertRaises(
|
37
|
-
|
32
|
+
brainstate.compile.jit_error_if(False, err_f, 1.)
|
33
|
+
with self.assertRaises(Exception):
|
34
|
+
brainstate.compile.jit_error_if(True, err_f, 1.)
|
38
35
|
|
39
36
|
def test_vmap(self):
|
40
37
|
def f(x):
|
41
|
-
|
38
|
+
brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
|
42
39
|
|
43
40
|
jax.vmap(f)(jnp.array([False, False, False]))
|
44
|
-
with self.assertRaises(
|
41
|
+
with self.assertRaises(Exception):
|
45
42
|
jax.vmap(f)(jnp.array([True, False, False]))
|
46
43
|
|
47
44
|
def test_vmap_vmap(self):
|
48
45
|
def f(x):
|
49
|
-
|
46
|
+
brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
|
50
47
|
|
51
48
|
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
52
49
|
[False, False, False]]))
|
53
|
-
with self.assertRaises(
|
50
|
+
with self.assertRaises(Exception):
|
54
51
|
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
55
52
|
[True, False, False]]))
|
brainstate/compile/_jit.py
CHANGED
@@ -13,16 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import functools
|
19
17
|
from collections.abc import Iterable, Sequence
|
20
18
|
from typing import (Any, Callable, Union)
|
21
19
|
|
22
20
|
import jax
|
23
21
|
from jax._src import sharding_impls
|
24
|
-
from jax.lib import xla_client as xc
|
25
22
|
|
23
|
+
from brainstate._compatible_import import Device
|
26
24
|
from brainstate._utils import set_module_as
|
27
25
|
from brainstate.typing import Missing
|
28
26
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
@@ -94,8 +92,8 @@ def _get_jitted_fun(
|
|
94
92
|
read_state_vals = state_trace.get_read_state_values(True)
|
95
93
|
|
96
94
|
# call the jitted function
|
97
|
-
# print('Running ...')
|
98
95
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
96
|
+
|
99
97
|
# write the state values back to the states
|
100
98
|
write_back_state_values(state_trace, read_state_vals, write_state_vals)
|
101
99
|
return outs
|
@@ -106,8 +104,11 @@ def _get_jitted_fun(
|
|
106
104
|
"""
|
107
105
|
# clear the cache of the stateful function
|
108
106
|
fun.clear_cache()
|
109
|
-
|
110
|
-
|
107
|
+
try:
|
108
|
+
# clear the cache of the jitted function
|
109
|
+
jit_fun.clear_cache()
|
110
|
+
except AttributeError:
|
111
|
+
pass
|
111
112
|
|
112
113
|
def eval_shape():
|
113
114
|
raise NotImplementedError
|
@@ -165,7 +166,7 @@ def _get_jitted_fun(
|
|
165
166
|
# compile the jitted function
|
166
167
|
jitted_fun.compile = compile
|
167
168
|
|
168
|
-
# trace the jitted
|
169
|
+
# trace the jitted function
|
169
170
|
jitted_fun.trace = trace
|
170
171
|
|
171
172
|
return jitted_fun
|
@@ -180,7 +181,7 @@ def jit(
|
|
180
181
|
donate_argnums: int | Sequence[int] | None = None,
|
181
182
|
donate_argnames: str | Iterable[str] | None = None,
|
182
183
|
keep_unused: bool = False,
|
183
|
-
device:
|
184
|
+
device: Device | None = None,
|
184
185
|
backend: str | None = None,
|
185
186
|
inline: bool = False,
|
186
187
|
abstracted_axes: Any | None = None,
|
@@ -13,8 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import math
|
19
17
|
from functools import wraps
|
20
18
|
from typing import Callable, Optional, TypeVar, Tuple, Any
|
@@ -13,20 +13,18 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
19
|
import numpy as np
|
22
20
|
|
23
|
-
import brainstate
|
21
|
+
import brainstate
|
24
22
|
|
25
23
|
|
26
24
|
class TestForLoop(unittest.TestCase):
|
27
25
|
def test_for_loop(self):
|
28
|
-
a =
|
29
|
-
b =
|
26
|
+
a = brainstate.ShortTermState(0.)
|
27
|
+
b = brainstate.ShortTermState(0.)
|
30
28
|
|
31
29
|
def f(i):
|
32
30
|
a.value += (1 + b.value)
|
@@ -34,7 +32,7 @@ class TestForLoop(unittest.TestCase):
|
|
34
32
|
|
35
33
|
n_iter = 10
|
36
34
|
ops = np.arange(n_iter)
|
37
|
-
r =
|
35
|
+
r = brainstate.compile.for_loop(f, ops)
|
38
36
|
|
39
37
|
print(a)
|
40
38
|
print(b)
|
@@ -42,8 +40,8 @@ class TestForLoop(unittest.TestCase):
|
|
42
40
|
self.assertTrue(jnp.allclose(r, ops + 1))
|
43
41
|
|
44
42
|
def test_checkpointed_for_loop(self):
|
45
|
-
a =
|
46
|
-
b =
|
43
|
+
a = brainstate.ShortTermState(0.)
|
44
|
+
b = brainstate.ShortTermState(0.)
|
47
45
|
|
48
46
|
def f(i):
|
49
47
|
a.value += (1 + b.value)
|
@@ -51,7 +49,7 @@ class TestForLoop(unittest.TestCase):
|
|
51
49
|
|
52
50
|
n_iter = 18
|
53
51
|
ops = jnp.arange(n_iter)
|
54
|
-
r =
|
52
|
+
r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())
|
55
53
|
|
56
54
|
print(a)
|
57
55
|
print(b)
|
@@ -13,17 +13,16 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
from unittest import TestCase
|
19
18
|
|
20
|
-
import brainstate
|
19
|
+
import brainstate
|
21
20
|
|
22
21
|
|
23
22
|
class TestWhileLoop(TestCase):
|
24
23
|
def test1(self):
|
25
|
-
a =
|
26
|
-
b =
|
24
|
+
a = brainstate.State(1.)
|
25
|
+
b = brainstate.State(20.)
|
27
26
|
|
28
27
|
def cond(_):
|
29
28
|
return a.value < b.value
|
@@ -31,13 +30,13 @@ class TestWhileLoop(TestCase):
|
|
31
30
|
def body(_):
|
32
31
|
a.value += 1.
|
33
32
|
|
34
|
-
|
33
|
+
brainstate.compile.while_loop(cond, body, None)
|
35
34
|
|
36
35
|
print(a.value, b.value)
|
37
36
|
|
38
37
|
def test2(self):
|
39
|
-
a =
|
40
|
-
b =
|
38
|
+
a = brainstate.State(1.)
|
39
|
+
b = brainstate.State(20.)
|
41
40
|
|
42
41
|
def cond(x):
|
43
42
|
return a.value < b.value
|
@@ -46,6 +45,6 @@ class TestWhileLoop(TestCase):
|
|
46
45
|
a.value += x
|
47
46
|
return x
|
48
47
|
|
49
|
-
r =
|
48
|
+
r = brainstate.compile.while_loop(cond, body, 1.)
|
50
49
|
|
51
50
|
print(a.value, b.value, r)
|
@@ -51,8 +51,6 @@ function.
|
|
51
51
|
|
52
52
|
"""
|
53
53
|
|
54
|
-
from __future__ import annotations
|
55
|
-
|
56
54
|
import functools
|
57
55
|
import inspect
|
58
56
|
import operator
|
@@ -65,7 +63,7 @@ from jax._src import source_info_util
|
|
65
63
|
from jax._src.linear_util import annotate
|
66
64
|
from jax._src.traceback_util import api_boundary
|
67
65
|
from jax.api_util import shaped_abstractify
|
68
|
-
from jax.extend.linear_util import transformation_with_aux
|
66
|
+
from jax.extend.linear_util import transformation_with_aux
|
69
67
|
from jax.interpreters import partial_eval as pe
|
70
68
|
|
71
69
|
from brainstate._compatible_import import (
|
@@ -75,6 +73,7 @@ from brainstate._compatible_import import (
|
|
75
73
|
safe_zip,
|
76
74
|
unzip2,
|
77
75
|
wraps,
|
76
|
+
wrap_init,
|
78
77
|
)
|
79
78
|
from brainstate._state import State, StateTraceStack
|
80
79
|
from brainstate._utils import set_module_as
|
@@ -98,7 +97,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
|
98
97
|
return tuple(safe_map(operator.index, x))
|
99
98
|
|
100
99
|
|
101
|
-
def
|
100
|
+
def _jax_v04_new_arg_fn(frame, trace, aval):
|
102
101
|
"""
|
103
102
|
Transform a new argument to a tracer.
|
104
103
|
|
@@ -119,27 +118,41 @@ def _new_arg_fn(frame, trace, aval):
|
|
119
118
|
return tracer
|
120
119
|
|
121
120
|
|
122
|
-
def
|
121
|
+
def _jax_v04_new_jax_trace():
|
123
122
|
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
124
123
|
frame = main.jaxpr_stack[-1]
|
125
124
|
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
126
125
|
return frame, trace
|
127
126
|
|
128
127
|
|
128
|
+
def _jax_v04_new_arg():
|
129
|
+
# Should be within the calling of ``jax.make_jaxpr()``
|
130
|
+
frame, trace = _jax_v04_new_jax_trace()
|
131
|
+
# Set the function to transform the new argument to a tracer
|
132
|
+
fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
|
133
|
+
return fn
|
134
|
+
|
135
|
+
|
136
|
+
def _jax_new_version_new_arg():
|
137
|
+
trace = jax.core.trace_ctx.trace
|
138
|
+
|
139
|
+
def wrapper(x):
|
140
|
+
if jax.__version_info__ < (0, 6, 1):
|
141
|
+
return trace.new_arg(shaped_abstractify(x))
|
142
|
+
else:
|
143
|
+
return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
|
144
|
+
|
145
|
+
return wrapper
|
146
|
+
|
147
|
+
|
129
148
|
def _init_state_trace_stack(name) -> StateTraceStack:
|
130
149
|
state_trace: StateTraceStack = StateTraceStack(name=name)
|
131
150
|
|
132
151
|
if jax.__version_info__ < (0, 4, 36):
|
133
|
-
|
134
|
-
frame, trace = _new_jax_trace()
|
135
|
-
# Set the function to transform the new argument to a tracer
|
136
|
-
state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
|
137
|
-
return state_trace
|
138
|
-
|
152
|
+
state_trace.set_new_arg(_jax_v04_new_arg())
|
139
153
|
else:
|
140
|
-
|
141
|
-
|
142
|
-
return state_trace
|
154
|
+
state_trace.set_new_arg(_jax_new_version_new_arg())
|
155
|
+
return state_trace
|
143
156
|
|
144
157
|
|
145
158
|
class StatefulFunction(PrettyObject):
|
@@ -745,7 +758,7 @@ def _make_jaxpr(
|
|
745
758
|
@wraps(fun)
|
746
759
|
@api_boundary
|
747
760
|
def make_jaxpr_f(*args, **kwargs):
|
748
|
-
f = wrap_init(fun)
|
761
|
+
f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
|
749
762
|
if static_argnums:
|
750
763
|
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
751
764
|
f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
|
@@ -754,12 +767,12 @@ def _make_jaxpr(
|
|
754
767
|
f, out_tree = _flatten_fun(f, in_tree)
|
755
768
|
f = annotate(f, in_type)
|
756
769
|
if jax.__version_info__ < (0, 5, 0):
|
757
|
-
|
770
|
+
debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
758
771
|
with ExitStack() as stack:
|
759
772
|
if axis_env is not None:
|
760
773
|
stack.enter_context(extend_axis_env_nd(axis_env))
|
761
774
|
if jax.__version_info__ < (0, 5, 0):
|
762
|
-
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=
|
775
|
+
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
|
763
776
|
else:
|
764
777
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
765
778
|
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
@@ -21,7 +21,7 @@ import jax
|
|
21
21
|
import jax.numpy as jnp
|
22
22
|
import pytest
|
23
23
|
|
24
|
-
import brainstate
|
24
|
+
import brainstate
|
25
25
|
from brainstate._compatible_import import jaxpr_as_fun
|
26
26
|
|
27
27
|
|
@@ -29,10 +29,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
29
29
|
def test_compar_jax_make_jaxpr(self):
|
30
30
|
def func4(arg): # Arg is a pair
|
31
31
|
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
32
|
-
c =
|
32
|
+
c = brainstate.random.rand_like(arg[0])
|
33
33
|
return jnp.sum(temp + c)
|
34
34
|
|
35
|
-
key =
|
35
|
+
key = brainstate.random.DEFAULT.value
|
36
36
|
jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
37
37
|
print(jaxpr)
|
38
38
|
self.assertTrue(len(jaxpr.in_avals) == 2)
|
@@ -40,66 +40,66 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
40
40
|
self.assertTrue(len(jaxpr.out_avals) == 1)
|
41
41
|
self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
|
42
42
|
|
43
|
-
|
44
|
-
print(
|
43
|
+
brainstate.random.seed(1)
|
44
|
+
print(brainstate.random.DEFAULT.value)
|
45
45
|
|
46
|
-
jaxpr2, states =
|
46
|
+
jaxpr2, states = brainstate.compile.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
47
47
|
print(jaxpr2)
|
48
48
|
self.assertTrue(len(jaxpr2.in_avals) == 3)
|
49
49
|
self.assertTrue(len(jaxpr2.out_avals) == 2)
|
50
50
|
self.assertTrue(len(jaxpr2.consts) == 0)
|
51
|
-
print(
|
51
|
+
print(brainstate.random.DEFAULT.value)
|
52
52
|
|
53
53
|
def test_StatefulFunction_1(self):
|
54
54
|
def func4(arg): # Arg is a pair
|
55
55
|
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
56
|
-
c =
|
56
|
+
c = brainstate.random.rand_like(arg[0])
|
57
57
|
return jnp.sum(temp + c)
|
58
58
|
|
59
|
-
fun =
|
59
|
+
fun = brainstate.compile.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
|
60
60
|
print(fun.get_states())
|
61
61
|
print(fun.get_jaxpr())
|
62
62
|
|
63
63
|
def test_StatefulFunction_2(self):
|
64
|
-
st1 =
|
64
|
+
st1 = brainstate.State(jnp.ones(10))
|
65
65
|
|
66
66
|
def f1(x):
|
67
67
|
st1.value = x + st1.value
|
68
68
|
|
69
69
|
def f2(x):
|
70
|
-
jaxpr =
|
70
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(x)
|
71
71
|
c = 1. + x
|
72
72
|
return c
|
73
73
|
|
74
74
|
def f3(x):
|
75
|
-
jaxpr =
|
75
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(x)
|
76
76
|
c = 1.
|
77
77
|
return c
|
78
78
|
|
79
79
|
print()
|
80
|
-
jaxpr =
|
80
|
+
jaxpr = brainstate.compile.make_jaxpr(f1)(jnp.zeros(1))
|
81
81
|
print(jaxpr)
|
82
82
|
jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
|
83
83
|
print(jaxpr)
|
84
84
|
jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
|
85
85
|
print(jaxpr)
|
86
|
-
jaxpr, _ =
|
86
|
+
jaxpr, _ = brainstate.compile.make_jaxpr(f3)(jnp.zeros(1))
|
87
87
|
print(jaxpr)
|
88
88
|
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
89
|
f3(jnp.zeros(1))))
|
90
90
|
|
91
91
|
def test_compar_jax_make_jaxpr2(self):
|
92
|
-
st1 =
|
92
|
+
st1 = brainstate.State(jnp.ones(10))
|
93
93
|
|
94
94
|
def fa(x):
|
95
95
|
st1.value = x + st1.value
|
96
96
|
|
97
97
|
def ffa(x):
|
98
|
-
jaxpr, states =
|
98
|
+
jaxpr, states = brainstate.compile.make_jaxpr(fa)(x)
|
99
99
|
c = 1. + x
|
100
100
|
return c
|
101
101
|
|
102
|
-
jaxpr, states =
|
102
|
+
jaxpr, states = brainstate.compile.make_jaxpr(ffa)(jnp.zeros(1))
|
103
103
|
print()
|
104
104
|
print(jaxpr)
|
105
105
|
print(states)
|
@@ -112,7 +112,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
112
112
|
def fa(x):
|
113
113
|
return 1.
|
114
114
|
|
115
|
-
jaxpr, states =
|
115
|
+
jaxpr, states = brainstate.compile.make_jaxpr(fa)(jnp.zeros(1))
|
116
116
|
print()
|
117
117
|
print(jaxpr)
|
118
118
|
print(states)
|
@@ -125,9 +125,9 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
125
125
|
def test_return_states():
|
126
126
|
import jax.numpy
|
127
127
|
|
128
|
-
a =
|
128
|
+
a = brainstate.State(jax.numpy.ones(3))
|
129
129
|
|
130
|
-
@
|
130
|
+
@brainstate.compile.jit
|
131
131
|
def f():
|
132
132
|
return a
|
133
133
|
|
brainstate/compile/_unvmap.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
import jax
|
18
17
|
import jax.core
|
brainstate/compile/_util.py
CHANGED
brainstate/environ.py
CHANGED