brainstate 0.1.0.post20241209__py2.py3-none-any.whl → 0.1.0.post20241219__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_conditions.py +5 -7
- brainstate/compile/_jit.py +3 -3
- brainstate/compile/_loop_collect_return.py +19 -12
- brainstate/compile/_loop_no_collection.py +4 -5
- brainstate/compile/_progress_bar.py +22 -19
- brainstate/event/__init__.py +8 -6
- brainstate/event/_csr.py +906 -0
- brainstate/event/_csr_mv.py +12 -25
- brainstate/event/_csr_mv_test.py +76 -76
- brainstate/event/_csr_test.py +90 -0
- brainstate/event/_fixedprob_mv.py +52 -32
- brainstate/event/_linear_mv.py +2 -2
- brainstate/event/_xla_custom_op.py +8 -11
- brainstate/graph/_graph_node.py +10 -1
- brainstate/graph/_graph_operation.py +8 -6
- brainstate/nn/_dyn_impl/_inputs.py +127 -2
- brainstate/nn/_dynamics/_dynamics_base.py +12 -0
- brainstate/nn/_dynamics/_projection_base.py +25 -7
- brainstate/nn/_elementwise/_dropout_test.py +11 -11
- brainstate/nn/_interaction/_linear.py +21 -248
- brainstate/nn/_interaction/_linear_test.py +73 -6
- brainstate/random/_rand_funs.py +7 -3
- brainstate/typing.py +3 -0
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/METADATA +3 -2
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/RECORD +28 -27
- brainstate/event/_csr_benchmark.py +0 -14
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/top_level.txt +0 -0
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
94
94
|
return false_fun(*operands)
|
95
95
|
|
96
96
|
# evaluate jaxpr
|
97
|
-
|
98
|
-
|
99
|
-
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
97
|
+
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
|
98
|
+
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
100
99
|
|
101
100
|
# state trace and state values
|
102
101
|
state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
|
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
|
|
175
174
|
return branches[int(index)](*operands)
|
176
175
|
|
177
176
|
# evaluate jaxpr
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
wrapped_branch.make_jaxpr(*operands)
|
177
|
+
wrapped_branches = [StatefulFunction(branch) for branch in branches]
|
178
|
+
for wrapped_branch in wrapped_branches:
|
179
|
+
wrapped_branch.make_jaxpr(*operands)
|
182
180
|
|
183
181
|
# wrap the functions
|
184
182
|
state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
|
brainstate/compile/_jit.py
CHANGED
@@ -83,9 +83,9 @@ def _get_jitted_fun(
|
|
83
83
|
return fun.fun(*args, **params)
|
84
84
|
|
85
85
|
# compile the function and get the state trace
|
86
|
-
|
87
|
-
|
88
|
-
|
86
|
+
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
|
+
read_state_vals = state_trace.get_read_state_values(True)
|
88
|
+
|
89
89
|
# call the jitted function
|
90
90
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
91
91
|
# write the state values back to the states
|
@@ -63,7 +63,7 @@ def scan(
|
|
63
63
|
length: int | None = None,
|
64
64
|
reverse: bool = False,
|
65
65
|
unroll: int | bool = 1,
|
66
|
-
pbar: ProgressBar | None = None,
|
66
|
+
pbar: ProgressBar | int | None = None,
|
67
67
|
) -> Tuple[Carry, Y]:
|
68
68
|
"""
|
69
69
|
Scan a function over leading array axes while carrying along state.
|
@@ -177,7 +177,13 @@ def scan(
|
|
177
177
|
has_pbar = False
|
178
178
|
if pbar is not None:
|
179
179
|
has_pbar = True
|
180
|
-
|
180
|
+
if isinstance(pbar, ProgressBar):
|
181
|
+
pbar_runner = pbar.init(length)
|
182
|
+
elif isinstance(pbar, int):
|
183
|
+
pbar_runner = ProgressBar(freq=pbar).init(length)
|
184
|
+
else:
|
185
|
+
raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
|
186
|
+
f = _wrap_fun_with_pbar(f, pbar_runner)
|
181
187
|
init = (0, init) if pbar else init
|
182
188
|
|
183
189
|
# not jit
|
@@ -202,12 +208,11 @@ def scan(
|
|
202
208
|
# ------------------------------ #
|
203
209
|
xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
|
204
210
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
211
|
+
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
212
|
+
state_trace = stateful_fun.get_state_trace()
|
213
|
+
all_writen_state_vals = state_trace.get_write_state_values(True)
|
214
|
+
all_read_state_vals = state_trace.get_read_state_values(True)
|
215
|
+
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
211
216
|
|
212
217
|
# scan
|
213
218
|
init = (all_writen_state_vals, init)
|
@@ -231,7 +236,7 @@ def checkpointed_scan(
|
|
231
236
|
xs: X,
|
232
237
|
length: Optional[int] = None,
|
233
238
|
base: int = 16,
|
234
|
-
pbar: Optional[ProgressBar] = None,
|
239
|
+
pbar: Optional[ProgressBar | int] = None,
|
235
240
|
) -> Tuple[Carry, Y]:
|
236
241
|
"""
|
237
242
|
Scan a function over leading array axes while carrying along state.
|
@@ -289,8 +294,10 @@ def checkpointed_scan(
|
|
289
294
|
length, = unique_lengths
|
290
295
|
|
291
296
|
# function with progress bar
|
292
|
-
if pbar
|
297
|
+
if isinstance(pbar, ProgressBar):
|
293
298
|
pbar_runner = pbar.init(length)
|
299
|
+
elif isinstance(pbar, int):
|
300
|
+
pbar_runner = ProgressBar(freq=pbar).init(length)
|
294
301
|
else:
|
295
302
|
pbar_runner = None
|
296
303
|
|
@@ -381,7 +388,7 @@ def for_loop(
|
|
381
388
|
length: Optional[int] = None,
|
382
389
|
reverse: bool = False,
|
383
390
|
unroll: int | bool = 1,
|
384
|
-
pbar: Optional[ProgressBar] = None
|
391
|
+
pbar: Optional[ProgressBar | int] = None
|
385
392
|
) -> Y:
|
386
393
|
"""
|
387
394
|
``for-loop`` control flow with :py:class:`~.State`.
|
@@ -433,7 +440,7 @@ def checkpointed_for_loop(
|
|
433
440
|
*xs: X,
|
434
441
|
length: Optional[int] = None,
|
435
442
|
base: int = 16,
|
436
|
-
pbar: Optional[ProgressBar] = None,
|
443
|
+
pbar: Optional[ProgressBar | int] = None,
|
437
444
|
) -> Y:
|
438
445
|
"""
|
439
446
|
``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
|
@@ -103,11 +103,10 @@ def while_loop(
|
|
103
103
|
pass
|
104
104
|
|
105
105
|
# evaluate jaxpr
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
raise ValueError("while_loop: cond_fun should not have any write states.")
|
106
|
+
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
|
107
|
+
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
|
108
|
+
if len(stateful_cond.get_write_states()) != 0:
|
109
|
+
raise ValueError("while_loop: cond_fun should not have any write states.")
|
111
110
|
|
112
111
|
# state trace and state values
|
113
112
|
state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
|
@@ -35,6 +35,8 @@ class ProgressBar(object):
|
|
35
35
|
|
36
36
|
def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
|
37
37
|
self.print_freq = freq
|
38
|
+
if isinstance(freq, int):
|
39
|
+
assert freq > 0, "Print rate should be > 0."
|
38
40
|
self.print_count = count
|
39
41
|
if self.print_freq is not None and self.print_count is not None:
|
40
42
|
raise ValueError("Cannot specify both count and freq.")
|
@@ -105,25 +107,26 @@ class ProgressBarRunner(object):
|
|
105
107
|
self.tqdm_bars[0].close()
|
106
108
|
|
107
109
|
def __call__(self, iter_num, *args, **kwargs):
|
108
|
-
jax.debug.callback(
|
109
|
-
|
110
|
-
iter_num == 0,
|
111
|
-
(iter_num + 1) % self.print_freq == 0,
|
112
|
-
iter_num == self.n - 1
|
113
|
-
)
|
114
|
-
|
115
|
-
# _ = jax.lax.cond(
|
110
|
+
# jax.debug.callback(
|
111
|
+
# self._tqdm,
|
116
112
|
# iter_num == 0,
|
117
|
-
# lambda: jax.debug.callback(self._define_tqdm, ordered=True),
|
118
|
-
# lambda: None,
|
119
|
-
# )
|
120
|
-
# _ = jax.lax.cond(
|
121
113
|
# (iter_num + 1) % self.print_freq == 0,
|
122
|
-
#
|
123
|
-
# lambda: None,
|
124
|
-
# )
|
125
|
-
# _ = jax.lax.cond(
|
126
|
-
# iter_num == self.n - 1,
|
127
|
-
# lambda: jax.debug.callback(self._close_tqdm, ordered=True),
|
128
|
-
# lambda: None,
|
114
|
+
# iter_num == self.n - 1
|
129
115
|
# )
|
116
|
+
|
117
|
+
_ = jax.lax.cond(
|
118
|
+
iter_num == 0,
|
119
|
+
lambda: jax.debug.callback(self._define_tqdm, ordered=True),
|
120
|
+
lambda: None,
|
121
|
+
)
|
122
|
+
_ = jax.lax.cond(
|
123
|
+
iter_num % self.print_freq == (self.print_freq - 1),
|
124
|
+
lambda: jax.debug.callback(self._update_tqdm, ordered=True),
|
125
|
+
lambda: None,
|
126
|
+
)
|
127
|
+
_ = jax.lax.cond(
|
128
|
+
iter_num == self.n - 1,
|
129
|
+
lambda: jax.debug.callback(self._close_tqdm, ordered=True),
|
130
|
+
lambda: None,
|
131
|
+
)
|
132
|
+
|
brainstate/event/__init__.py
CHANGED
@@ -14,14 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
+
from ._csr import *
|
17
18
|
from ._csr_mv import *
|
18
|
-
from ._csr_mv import __all__ as __all_csr
|
19
19
|
from ._fixedprob_mv import *
|
20
|
-
from ._fixedprob_mv import __all__ as __all_fixed_probability
|
21
20
|
from ._linear_mv import *
|
22
21
|
from ._xla_custom_op import *
|
23
|
-
from ._xla_custom_op import __all__ as __all_xla_custom_op
|
24
|
-
from ._linear_mv import __all__ as __all_linear
|
25
22
|
|
26
|
-
__all__ =
|
27
|
-
|
23
|
+
__all__ = [
|
24
|
+
'CSRLinear',
|
25
|
+
'FixedProb',
|
26
|
+
'XLACustomOp',
|
27
|
+
'CSR',
|
28
|
+
'CSC',
|
29
|
+
]
|