brainstate 0.1.0.post20241210__py2.py3-none-any.whl → 0.1.0.post20241219__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_loop_collect_return.py +14 -6
- brainstate/compile/_progress_bar.py +5 -3
- brainstate/event/__init__.py +8 -6
- brainstate/event/_csr.py +906 -0
- brainstate/event/_csr_mv.py +12 -25
- brainstate/event/_csr_mv_test.py +76 -76
- brainstate/event/_csr_test.py +90 -0
- brainstate/event/_fixedprob_mv.py +52 -32
- brainstate/event/_linear_mv.py +2 -2
- brainstate/event/_xla_custom_op.py +8 -11
- brainstate/graph/_graph_node.py +10 -1
- brainstate/graph/_graph_operation.py +8 -6
- brainstate/nn/_dyn_impl/_inputs.py +127 -2
- brainstate/nn/_dynamics/_dynamics_base.py +12 -0
- brainstate/nn/_dynamics/_projection_base.py +25 -7
- brainstate/nn/_elementwise/_dropout_test.py +11 -11
- brainstate/nn/_interaction/_linear.py +21 -248
- brainstate/nn/_interaction/_linear_test.py +73 -6
- brainstate/random/_rand_funs.py +7 -3
- brainstate/typing.py +3 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241219.dist-info}/METADATA +3 -2
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241219.dist-info}/RECORD +25 -24
- brainstate/event/_csr_benchmark.py +0 -14
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241219.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241219.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241219.dist-info}/top_level.txt +0 -0
@@ -63,7 +63,7 @@ def scan(
|
|
63
63
|
length: int | None = None,
|
64
64
|
reverse: bool = False,
|
65
65
|
unroll: int | bool = 1,
|
66
|
-
pbar: ProgressBar | None = None,
|
66
|
+
pbar: ProgressBar | int | None = None,
|
67
67
|
) -> Tuple[Carry, Y]:
|
68
68
|
"""
|
69
69
|
Scan a function over leading array axes while carrying along state.
|
@@ -177,7 +177,13 @@ def scan(
|
|
177
177
|
has_pbar = False
|
178
178
|
if pbar is not None:
|
179
179
|
has_pbar = True
|
180
|
-
|
180
|
+
if isinstance(pbar, ProgressBar):
|
181
|
+
pbar_runner = pbar.init(length)
|
182
|
+
elif isinstance(pbar, int):
|
183
|
+
pbar_runner = ProgressBar(freq=pbar).init(length)
|
184
|
+
else:
|
185
|
+
raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
|
186
|
+
f = _wrap_fun_with_pbar(f, pbar_runner)
|
181
187
|
init = (0, init) if pbar else init
|
182
188
|
|
183
189
|
# not jit
|
@@ -230,7 +236,7 @@ def checkpointed_scan(
|
|
230
236
|
xs: X,
|
231
237
|
length: Optional[int] = None,
|
232
238
|
base: int = 16,
|
233
|
-
pbar: Optional[ProgressBar] = None,
|
239
|
+
pbar: Optional[ProgressBar | int] = None,
|
234
240
|
) -> Tuple[Carry, Y]:
|
235
241
|
"""
|
236
242
|
Scan a function over leading array axes while carrying along state.
|
@@ -288,8 +294,10 @@ def checkpointed_scan(
|
|
288
294
|
length, = unique_lengths
|
289
295
|
|
290
296
|
# function with progress bar
|
291
|
-
if pbar
|
297
|
+
if isinstance(pbar, ProgressBar):
|
292
298
|
pbar_runner = pbar.init(length)
|
299
|
+
elif isinstance(pbar, int):
|
300
|
+
pbar_runner = ProgressBar(freq=pbar).init(length)
|
293
301
|
else:
|
294
302
|
pbar_runner = None
|
295
303
|
|
@@ -380,7 +388,7 @@ def for_loop(
|
|
380
388
|
length: Optional[int] = None,
|
381
389
|
reverse: bool = False,
|
382
390
|
unroll: int | bool = 1,
|
383
|
-
pbar: Optional[ProgressBar] = None
|
391
|
+
pbar: Optional[ProgressBar | int] = None
|
384
392
|
) -> Y:
|
385
393
|
"""
|
386
394
|
``for-loop`` control flow with :py:class:`~.State`.
|
@@ -432,7 +440,7 @@ def checkpointed_for_loop(
|
|
432
440
|
*xs: X,
|
433
441
|
length: Optional[int] = None,
|
434
442
|
base: int = 16,
|
435
|
-
pbar: Optional[ProgressBar] = None,
|
443
|
+
pbar: Optional[ProgressBar | int] = None,
|
436
444
|
) -> Y:
|
437
445
|
"""
|
438
446
|
``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
|
@@ -35,6 +35,8 @@ class ProgressBar(object):
|
|
35
35
|
|
36
36
|
def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
|
37
37
|
self.print_freq = freq
|
38
|
+
if isinstance(freq, int):
|
39
|
+
assert freq > 0, "Print rate should be > 0."
|
38
40
|
self.print_count = count
|
39
41
|
if self.print_freq is not None and self.print_count is not None:
|
40
42
|
raise ValueError("Cannot specify both count and freq.")
|
@@ -114,17 +116,17 @@ class ProgressBarRunner(object):
|
|
114
116
|
|
115
117
|
_ = jax.lax.cond(
|
116
118
|
iter_num == 0,
|
117
|
-
lambda: jax.debug.callback(self._define_tqdm),
|
119
|
+
lambda: jax.debug.callback(self._define_tqdm, ordered=True),
|
118
120
|
lambda: None,
|
119
121
|
)
|
120
122
|
_ = jax.lax.cond(
|
121
123
|
iter_num % self.print_freq == (self.print_freq - 1),
|
122
|
-
lambda: jax.debug.callback(self._update_tqdm),
|
124
|
+
lambda: jax.debug.callback(self._update_tqdm, ordered=True),
|
123
125
|
lambda: None,
|
124
126
|
)
|
125
127
|
_ = jax.lax.cond(
|
126
128
|
iter_num == self.n - 1,
|
127
|
-
lambda: jax.debug.callback(self._close_tqdm),
|
129
|
+
lambda: jax.debug.callback(self._close_tqdm, ordered=True),
|
128
130
|
lambda: None,
|
129
131
|
)
|
130
132
|
|
brainstate/event/__init__.py
CHANGED
@@ -14,14 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
+
from ._csr import *
|
17
18
|
from ._csr_mv import *
|
18
|
-
from ._csr_mv import __all__ as __all_csr
|
19
19
|
from ._fixedprob_mv import *
|
20
|
-
from ._fixedprob_mv import __all__ as __all_fixed_probability
|
21
20
|
from ._linear_mv import *
|
22
21
|
from ._xla_custom_op import *
|
23
|
-
from ._xla_custom_op import __all__ as __all_xla_custom_op
|
24
|
-
from ._linear_mv import __all__ as __all_linear
|
25
22
|
|
26
|
-
__all__ =
|
27
|
-
|
23
|
+
__all__ = [
|
24
|
+
'CSRLinear',
|
25
|
+
'FixedProb',
|
26
|
+
'XLACustomOp',
|
27
|
+
'CSR',
|
28
|
+
'CSC',
|
29
|
+
]
|