brainstate 0.1.0.post20241209__py2.py3-none-any.whl → 0.1.0.post20241210__py2.py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
94
94
  return false_fun(*operands)
95
95
 
96
96
  # evaluate jaxpr
97
- with jax.ensure_compile_time_eval():
98
- stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
99
- stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
97
+ stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
98
+ stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
100
99
 
101
100
  # state trace and state values
102
101
  state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
175
174
  return branches[int(index)](*operands)
176
175
 
177
176
  # evaluate jaxpr
178
- with jax.ensure_compile_time_eval():
179
- wrapped_branches = [StatefulFunction(branch) for branch in branches]
180
- for wrapped_branch in wrapped_branches:
181
- wrapped_branch.make_jaxpr(*operands)
177
+ wrapped_branches = [StatefulFunction(branch) for branch in branches]
178
+ for wrapped_branch in wrapped_branches:
179
+ wrapped_branch.make_jaxpr(*operands)
182
180
 
183
181
  # wrap the functions
184
182
  state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
@@ -83,9 +83,9 @@ def _get_jitted_fun(
83
83
  return fun.fun(*args, **params)
84
84
 
85
85
  # compile the function and get the state trace
86
- with jax.ensure_compile_time_eval():
87
- state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
88
- read_state_vals = state_trace.get_read_state_values(True)
86
+ state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
87
+ read_state_vals = state_trace.get_read_state_values(True)
88
+
89
89
  # call the jitted function
90
90
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
91
91
  # write the state values back to the states
@@ -202,12 +202,11 @@ def scan(
202
202
  # ------------------------------ #
203
203
  xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
204
204
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
205
- with jax.ensure_compile_time_eval():
206
- stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
207
- state_trace = stateful_fun.get_state_trace()
208
- all_writen_state_vals = state_trace.get_write_state_values(True)
209
- all_read_state_vals = state_trace.get_read_state_values(True)
210
- wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
205
+ stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
206
+ state_trace = stateful_fun.get_state_trace()
207
+ all_writen_state_vals = state_trace.get_write_state_values(True)
208
+ all_read_state_vals = state_trace.get_read_state_values(True)
209
+ wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
211
210
 
212
211
  # scan
213
212
  init = (all_writen_state_vals, init)
@@ -103,11 +103,10 @@ def while_loop(
103
103
  pass
104
104
 
105
105
  # evaluate jaxpr
106
- with jax.ensure_compile_time_eval():
107
- stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
108
- stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
109
- if len(stateful_cond.get_write_states()) != 0:
110
- raise ValueError("while_loop: cond_fun should not have any write states.")
106
+ stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
107
+ stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
108
+ if len(stateful_cond.get_write_states()) != 0:
109
+ raise ValueError("while_loop: cond_fun should not have any write states.")
111
110
 
112
111
  # state trace and state values
113
112
  state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
@@ -105,25 +105,26 @@ class ProgressBarRunner(object):
105
105
  self.tqdm_bars[0].close()
106
106
 
107
107
  def __call__(self, iter_num, *args, **kwargs):
108
- jax.debug.callback(
109
- self._tqdm,
110
- iter_num == 0,
111
- (iter_num + 1) % self.print_freq == 0,
112
- iter_num == self.n - 1
113
- )
114
-
115
- # _ = jax.lax.cond(
108
+ # jax.debug.callback(
109
+ # self._tqdm,
116
110
  # iter_num == 0,
117
- # lambda: jax.debug.callback(self._define_tqdm, ordered=True),
118
- # lambda: None,
119
- # )
120
- # _ = jax.lax.cond(
121
111
  # (iter_num + 1) % self.print_freq == 0,
122
- # lambda: jax.debug.callback(self._update_tqdm, ordered=True),
123
- # lambda: None,
124
- # )
125
- # _ = jax.lax.cond(
126
- # iter_num == self.n - 1,
127
- # lambda: jax.debug.callback(self._close_tqdm, ordered=True),
128
- # lambda: None,
112
+ # iter_num == self.n - 1
129
113
  # )
114
+
115
+ _ = jax.lax.cond(
116
+ iter_num == 0,
117
+ lambda: jax.debug.callback(self._define_tqdm),
118
+ lambda: None,
119
+ )
120
+ _ = jax.lax.cond(
121
+ iter_num % self.print_freq == (self.print_freq - 1),
122
+ lambda: jax.debug.callback(self._update_tqdm),
123
+ lambda: None,
124
+ )
125
+ _ = jax.lax.cond(
126
+ iter_num == self.n - 1,
127
+ lambda: jax.debug.callback(self._close_tqdm),
128
+ lambda: None,
129
+ )
130
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20241209
3
+ Version: 0.1.0.post20241210
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -20,19 +20,19 @@ brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4
20
20
  brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
21
21
  brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
22
22
  brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
23
- brainstate/compile/_conditions.py,sha256=ocz6sDc7Xzabz2GnRsQmS6GDps-WP-OXUd0EZTTlG0k,10217
23
+ brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
24
24
  brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
25
25
  brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
26
26
  brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
27
- brainstate/compile/_jit.py,sha256=bfEszNttEtE6npqHBam1_DBlRa39fE6qP6lGaWw2amA,13750
27
+ brainstate/compile/_jit.py,sha256=3mQ-RUFz35wceZyKE_MoR58OBL0RK_i6sHm4rWYzMLs,13698
28
28
  brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
29
- brainstate/compile/_loop_collect_return.py,sha256=8vDB2l0d4sIn0apspJzkhFhxjsL7reIptDeFRI9b1tc,23002
29
+ brainstate/compile/_loop_collect_return.py,sha256=DybSBixeuxleKJV6n9FgVBDsUTmexzS0IdgWYRqp5cU,22940
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
- brainstate/compile/_loop_no_collection.py,sha256=2rSK20enkBMXPAbsCyb7PCICPNrgaSpl5jfumgWpxA0,7401
31
+ brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
33
  brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
35
- brainstate/compile/_progress_bar.py,sha256=FafEbD9KzmhCCizfQoXXLw46asn9_uiuH1U5_DMtSXg,4529
35
+ brainstate/compile/_progress_bar.py,sha256=H544Oh10SiF5ccrKHM9ay7ZHigYIhNhSQGEKbDxRJgg,4485
36
36
  brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
38
  brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
@@ -136,8 +136,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
136
136
  brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
137
137
  brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
138
138
  brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
139
- brainstate-0.1.0.post20241209.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
140
- brainstate-0.1.0.post20241209.dist-info/METADATA,sha256=gXsiYWSQqOJ0CWKINESG4sSpnDkcmVYgWJWeEFLTHoA,3401
141
- brainstate-0.1.0.post20241209.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
142
- brainstate-0.1.0.post20241209.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
143
- brainstate-0.1.0.post20241209.dist-info/RECORD,,
139
+ brainstate-0.1.0.post20241210.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
140
+ brainstate-0.1.0.post20241210.dist-info/METADATA,sha256=E6AATarjpwssXflLfA-OCkxFxqZxqJNxHZteO6UWMhw,3401
141
+ brainstate-0.1.0.post20241210.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
142
+ brainstate-0.1.0.post20241210.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
143
+ brainstate-0.1.0.post20241210.dist-info/RECORD,,