brainstate 0.1.0.post20241209__py2.py3-none-any.whl → 0.1.0.post20241210__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_conditions.py +5 -7
- brainstate/compile/_jit.py +3 -3
- brainstate/compile/_loop_collect_return.py +5 -6
- brainstate/compile/_loop_no_collection.py +4 -5
- brainstate/compile/_progress_bar.py +20 -19
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/RECORD +10 -10
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/top_level.txt +0 -0
brainstate/compile/_conditions.py
CHANGED
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
94
94
|
return false_fun(*operands)
|
95
95
|
|
96
96
|
# evaluate jaxpr
|
97
|
-
|
98
|
-
|
99
|
-
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
97
|
+
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
|
98
|
+
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
100
99
|
|
101
100
|
# state trace and state values
|
102
101
|
state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
|
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
|
|
175
174
|
return branches[int(index)](*operands)
|
176
175
|
|
177
176
|
# evaluate jaxpr
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
wrapped_branch.make_jaxpr(*operands)
|
177
|
+
wrapped_branches = [StatefulFunction(branch) for branch in branches]
|
178
|
+
for wrapped_branch in wrapped_branches:
|
179
|
+
wrapped_branch.make_jaxpr(*operands)
|
182
180
|
|
183
181
|
# wrap the functions
|
184
182
|
state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
|
brainstate/compile/_jit.py
CHANGED
@@ -83,9 +83,9 @@ def _get_jitted_fun(
|
|
83
83
|
return fun.fun(*args, **params)
|
84
84
|
|
85
85
|
# compile the function and get the state trace
|
86
|
-
|
87
|
-
|
88
|
-
|
86
|
+
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
|
+
read_state_vals = state_trace.get_read_state_values(True)
|
88
|
+
|
89
89
|
# call the jitted function
|
90
90
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
91
91
|
# write the state values back to the states
|
brainstate/compile/_loop_collect_return.py
CHANGED
@@ -202,12 +202,11 @@ def scan(
|
|
202
202
|
# ------------------------------ #
|
203
203
|
xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
|
204
204
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
205
|
+
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
206
|
+
state_trace = stateful_fun.get_state_trace()
|
207
|
+
all_writen_state_vals = state_trace.get_write_state_values(True)
|
208
|
+
all_read_state_vals = state_trace.get_read_state_values(True)
|
209
|
+
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
211
210
|
|
212
211
|
# scan
|
213
212
|
init = (all_writen_state_vals, init)
|
brainstate/compile/_loop_no_collection.py
CHANGED
@@ -103,11 +103,10 @@ def while_loop(
|
|
103
103
|
pass
|
104
104
|
|
105
105
|
# evaluate jaxpr
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
raise ValueError("while_loop: cond_fun should not have any write states.")
|
106
|
+
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
|
107
|
+
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
|
108
|
+
if len(stateful_cond.get_write_states()) != 0:
|
109
|
+
raise ValueError("while_loop: cond_fun should not have any write states.")
|
111
110
|
|
112
111
|
# state trace and state values
|
113
112
|
state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
|
brainstate/compile/_progress_bar.py
CHANGED
@@ -105,25 +105,26 @@ class ProgressBarRunner(object):
|
|
105
105
|
self.tqdm_bars[0].close()
|
106
106
|
|
107
107
|
def __call__(self, iter_num, *args, **kwargs):
|
108
|
-
jax.debug.callback(
|
109
|
-
|
110
|
-
iter_num == 0,
|
111
|
-
(iter_num + 1) % self.print_freq == 0,
|
112
|
-
iter_num == self.n - 1
|
113
|
-
)
|
114
|
-
|
115
|
-
# _ = jax.lax.cond(
|
108
|
+
# jax.debug.callback(
|
109
|
+
# self._tqdm,
|
116
110
|
# iter_num == 0,
|
117
|
-
# lambda: jax.debug.callback(self._define_tqdm, ordered=True),
|
118
|
-
# lambda: None,
|
119
|
-
# )
|
120
|
-
# _ = jax.lax.cond(
|
121
111
|
# (iter_num + 1) % self.print_freq == 0,
|
122
|
-
#
|
123
|
-
# lambda: None,
|
124
|
-
# )
|
125
|
-
# _ = jax.lax.cond(
|
126
|
-
# iter_num == self.n - 1,
|
127
|
-
# lambda: jax.debug.callback(self._close_tqdm, ordered=True),
|
128
|
-
# lambda: None,
|
112
|
+
# iter_num == self.n - 1
|
129
113
|
# )
|
114
|
+
|
115
|
+
_ = jax.lax.cond(
|
116
|
+
iter_num == 0,
|
117
|
+
lambda: jax.debug.callback(self._define_tqdm),
|
118
|
+
lambda: None,
|
119
|
+
)
|
120
|
+
_ = jax.lax.cond(
|
121
|
+
iter_num % self.print_freq == (self.print_freq - 1),
|
122
|
+
lambda: jax.debug.callback(self._update_tqdm),
|
123
|
+
lambda: None,
|
124
|
+
)
|
125
|
+
_ = jax.lax.cond(
|
126
|
+
iter_num == self.n - 1,
|
127
|
+
lambda: jax.debug.callback(self._close_tqdm),
|
128
|
+
lambda: None,
|
129
|
+
)
|
130
|
+
|
{brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20241209
|
3
|
+
Version: 0.1.0.post20241210
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
{brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/RECORD
RENAMED
@@ -20,19 +20,19 @@ brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4
|
|
20
20
|
brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
|
22
22
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
23
|
-
brainstate/compile/_conditions.py,sha256=
|
23
|
+
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
26
|
brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
|
27
|
-
brainstate/compile/_jit.py,sha256=
|
27
|
+
brainstate/compile/_jit.py,sha256=3mQ-RUFz35wceZyKE_MoR58OBL0RK_i6sHm4rWYzMLs,13698
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=DybSBixeuxleKJV6n9FgVBDsUTmexzS0IdgWYRqp5cU,22940
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
|
-
brainstate/compile/_loop_no_collection.py,sha256=
|
31
|
+
brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
33
|
brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=H544Oh10SiF5ccrKHM9ay7ZHigYIhNhSQGEKbDxRJgg,4485
|
36
36
|
brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
38
|
brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
|
@@ -136,8 +136,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
|
|
136
136
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
137
137
|
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
138
138
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
139
|
-
brainstate-0.1.0.
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
139
|
+
brainstate-0.1.0.post20241210.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
140
|
+
brainstate-0.1.0.post20241210.dist-info/METADATA,sha256=E6AATarjpwssXflLfA-OCkxFxqZxqJNxHZteO6UWMhw,3401
|
141
|
+
brainstate-0.1.0.post20241210.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
142
|
+
brainstate-0.1.0.post20241210.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
143
|
+
brainstate-0.1.0.post20241210.dist-info/RECORD,,
|
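{brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/LICENSE
RENAMED
File without changes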
|
{brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/WHEEL
RENAMED
File without changes
|
{brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241210.dist-info}/top_level.txt
RENAMED
File without changes
|