brainstate 0.1.0.post20241209__py2.py3-none-any.whl → 0.1.0.post20241219__py2.py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (29)
  1. brainstate/compile/_conditions.py +5 -7
  2. brainstate/compile/_jit.py +3 -3
  3. brainstate/compile/_loop_collect_return.py +19 -12
  4. brainstate/compile/_loop_no_collection.py +4 -5
  5. brainstate/compile/_progress_bar.py +22 -19
  6. brainstate/event/__init__.py +8 -6
  7. brainstate/event/_csr.py +906 -0
  8. brainstate/event/_csr_mv.py +12 -25
  9. brainstate/event/_csr_mv_test.py +76 -76
  10. brainstate/event/_csr_test.py +90 -0
  11. brainstate/event/_fixedprob_mv.py +52 -32
  12. brainstate/event/_linear_mv.py +2 -2
  13. brainstate/event/_xla_custom_op.py +8 -11
  14. brainstate/graph/_graph_node.py +10 -1
  15. brainstate/graph/_graph_operation.py +8 -6
  16. brainstate/nn/_dyn_impl/_inputs.py +127 -2
  17. brainstate/nn/_dynamics/_dynamics_base.py +12 -0
  18. brainstate/nn/_dynamics/_projection_base.py +25 -7
  19. brainstate/nn/_elementwise/_dropout_test.py +11 -11
  20. brainstate/nn/_interaction/_linear.py +21 -248
  21. brainstate/nn/_interaction/_linear_test.py +73 -6
  22. brainstate/random/_rand_funs.py +7 -3
  23. brainstate/typing.py +3 -0
  24. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/METADATA +3 -2
  25. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/RECORD +28 -27
  26. brainstate/event/_csr_benchmark.py +0 -14
  27. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/top_level.txt +0 -0
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
94
94
  return false_fun(*operands)
95
95
 
96
96
  # evaluate jaxpr
97
- with jax.ensure_compile_time_eval():
98
- stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
99
- stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
97
+ stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
98
+ stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
100
99
 
101
100
  # state trace and state values
102
101
  state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
175
174
  return branches[int(index)](*operands)
176
175
 
177
176
  # evaluate jaxpr
178
- with jax.ensure_compile_time_eval():
179
- wrapped_branches = [StatefulFunction(branch) for branch in branches]
180
- for wrapped_branch in wrapped_branches:
181
- wrapped_branch.make_jaxpr(*operands)
177
+ wrapped_branches = [StatefulFunction(branch) for branch in branches]
178
+ for wrapped_branch in wrapped_branches:
179
+ wrapped_branch.make_jaxpr(*operands)
182
180
 
183
181
  # wrap the functions
184
182
  state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
@@ -83,9 +83,9 @@ def _get_jitted_fun(
83
83
  return fun.fun(*args, **params)
84
84
 
85
85
  # compile the function and get the state trace
86
- with jax.ensure_compile_time_eval():
87
- state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
88
- read_state_vals = state_trace.get_read_state_values(True)
86
+ state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
87
+ read_state_vals = state_trace.get_read_state_values(True)
88
+
89
89
  # call the jitted function
90
90
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
91
91
  # write the state values back to the states
@@ -63,7 +63,7 @@ def scan(
63
63
  length: int | None = None,
64
64
  reverse: bool = False,
65
65
  unroll: int | bool = 1,
66
- pbar: ProgressBar | None = None,
66
+ pbar: ProgressBar | int | None = None,
67
67
  ) -> Tuple[Carry, Y]:
68
68
  """
69
69
  Scan a function over leading array axes while carrying along state.
@@ -177,7 +177,13 @@ def scan(
177
177
  has_pbar = False
178
178
  if pbar is not None:
179
179
  has_pbar = True
180
- f = _wrap_fun_with_pbar(f, pbar.init(length))
180
+ if isinstance(pbar, ProgressBar):
181
+ pbar_runner = pbar.init(length)
182
+ elif isinstance(pbar, int):
183
+ pbar_runner = ProgressBar(freq=pbar).init(length)
184
+ else:
185
+ raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
186
+ f = _wrap_fun_with_pbar(f, pbar_runner)
181
187
  init = (0, init) if pbar else init
182
188
 
183
189
  # not jit
@@ -202,12 +208,11 @@ def scan(
202
208
  # ------------------------------ #
203
209
  xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
204
210
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
205
- with jax.ensure_compile_time_eval():
206
- stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
207
- state_trace = stateful_fun.get_state_trace()
208
- all_writen_state_vals = state_trace.get_write_state_values(True)
209
- all_read_state_vals = state_trace.get_read_state_values(True)
210
- wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
211
+ stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
212
+ state_trace = stateful_fun.get_state_trace()
213
+ all_writen_state_vals = state_trace.get_write_state_values(True)
214
+ all_read_state_vals = state_trace.get_read_state_values(True)
215
+ wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
211
216
 
212
217
  # scan
213
218
  init = (all_writen_state_vals, init)
@@ -231,7 +236,7 @@ def checkpointed_scan(
231
236
  xs: X,
232
237
  length: Optional[int] = None,
233
238
  base: int = 16,
234
- pbar: Optional[ProgressBar] = None,
239
+ pbar: Optional[ProgressBar | int] = None,
235
240
  ) -> Tuple[Carry, Y]:
236
241
  """
237
242
  Scan a function over leading array axes while carrying along state.
@@ -289,8 +294,10 @@ def checkpointed_scan(
289
294
  length, = unique_lengths
290
295
 
291
296
  # function with progress bar
292
- if pbar is not None:
297
+ if isinstance(pbar, ProgressBar):
293
298
  pbar_runner = pbar.init(length)
299
+ elif isinstance(pbar, int):
300
+ pbar_runner = ProgressBar(freq=pbar).init(length)
294
301
  else:
295
302
  pbar_runner = None
296
303
 
@@ -381,7 +388,7 @@ def for_loop(
381
388
  length: Optional[int] = None,
382
389
  reverse: bool = False,
383
390
  unroll: int | bool = 1,
384
- pbar: Optional[ProgressBar] = None
391
+ pbar: Optional[ProgressBar | int] = None
385
392
  ) -> Y:
386
393
  """
387
394
  ``for-loop`` control flow with :py:class:`~.State`.
@@ -433,7 +440,7 @@ def checkpointed_for_loop(
433
440
  *xs: X,
434
441
  length: Optional[int] = None,
435
442
  base: int = 16,
436
- pbar: Optional[ProgressBar] = None,
443
+ pbar: Optional[ProgressBar | int] = None,
437
444
  ) -> Y:
438
445
  """
439
446
  ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
@@ -103,11 +103,10 @@ def while_loop(
103
103
  pass
104
104
 
105
105
  # evaluate jaxpr
106
- with jax.ensure_compile_time_eval():
107
- stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
108
- stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
109
- if len(stateful_cond.get_write_states()) != 0:
110
- raise ValueError("while_loop: cond_fun should not have any write states.")
106
+ stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
107
+ stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
108
+ if len(stateful_cond.get_write_states()) != 0:
109
+ raise ValueError("while_loop: cond_fun should not have any write states.")
111
110
 
112
111
  # state trace and state values
113
112
  state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
@@ -35,6 +35,8 @@ class ProgressBar(object):
35
35
 
36
36
  def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
37
37
  self.print_freq = freq
38
+ if isinstance(freq, int):
39
+ assert freq > 0, "Print rate should be > 0."
38
40
  self.print_count = count
39
41
  if self.print_freq is not None and self.print_count is not None:
40
42
  raise ValueError("Cannot specify both count and freq.")
@@ -105,25 +107,26 @@ class ProgressBarRunner(object):
105
107
  self.tqdm_bars[0].close()
106
108
 
107
109
  def __call__(self, iter_num, *args, **kwargs):
108
- jax.debug.callback(
109
- self._tqdm,
110
- iter_num == 0,
111
- (iter_num + 1) % self.print_freq == 0,
112
- iter_num == self.n - 1
113
- )
114
-
115
- # _ = jax.lax.cond(
110
+ # jax.debug.callback(
111
+ # self._tqdm,
116
112
  # iter_num == 0,
117
- # lambda: jax.debug.callback(self._define_tqdm, ordered=True),
118
- # lambda: None,
119
- # )
120
- # _ = jax.lax.cond(
121
113
  # (iter_num + 1) % self.print_freq == 0,
122
- # lambda: jax.debug.callback(self._update_tqdm, ordered=True),
123
- # lambda: None,
124
- # )
125
- # _ = jax.lax.cond(
126
- # iter_num == self.n - 1,
127
- # lambda: jax.debug.callback(self._close_tqdm, ordered=True),
128
- # lambda: None,
114
+ # iter_num == self.n - 1
129
115
  # )
116
+
117
+ _ = jax.lax.cond(
118
+ iter_num == 0,
119
+ lambda: jax.debug.callback(self._define_tqdm, ordered=True),
120
+ lambda: None,
121
+ )
122
+ _ = jax.lax.cond(
123
+ iter_num % self.print_freq == (self.print_freq - 1),
124
+ lambda: jax.debug.callback(self._update_tqdm, ordered=True),
125
+ lambda: None,
126
+ )
127
+ _ = jax.lax.cond(
128
+ iter_num == self.n - 1,
129
+ lambda: jax.debug.callback(self._close_tqdm, ordered=True),
130
+ lambda: None,
131
+ )
132
+
@@ -14,14 +14,16 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
+ from ._csr import *
17
18
  from ._csr_mv import *
18
- from ._csr_mv import __all__ as __all_csr
19
19
  from ._fixedprob_mv import *
20
- from ._fixedprob_mv import __all__ as __all_fixed_probability
21
20
  from ._linear_mv import *
22
21
  from ._xla_custom_op import *
23
- from ._xla_custom_op import __all__ as __all_xla_custom_op
24
- from ._linear_mv import __all__ as __all_linear
25
22
 
26
- __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
27
- del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
23
+ __all__ = [
24
+ 'CSRLinear',
25
+ 'FixedProb',
26
+ 'XLACustomOp',
27
+ 'CSR',
28
+ 'CSC',
29
+ ]