brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -16,34 +16,59 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import copy
|
19
|
-
|
19
|
+
import importlib.util
|
20
|
+
from typing import Optional, Callable, Any, Tuple, Dict
|
20
21
|
|
21
22
|
import jax
|
22
23
|
|
23
|
-
|
24
|
-
from tqdm.auto import tqdm
|
25
|
-
except (ImportError, ModuleNotFoundError):
|
26
|
-
tqdm = None
|
24
|
+
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
27
25
|
|
28
26
|
__all__ = [
|
29
27
|
'ProgressBar',
|
30
28
|
]
|
31
29
|
|
30
|
+
Index = int
|
31
|
+
Carray = Any
|
32
|
+
Output = Any
|
33
|
+
|
32
34
|
|
33
35
|
class ProgressBar(object):
|
36
|
+
"""
|
37
|
+
A progress bar for tracking the progress of a jitted for-loop computation.
|
38
|
+
"""
|
34
39
|
__module__ = "brainstate.compile"
|
35
40
|
|
36
|
-
def __init__(
    self,
    freq: Optional[int] = None,
    count: Optional[int] = None,
    desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
    **kwargs
):
    """
    Store the progress-bar configuration.

    Args:
      freq: Print rate, in loop iterations. Must be a positive int when given.
      count: Alternative way of controlling how often the bar prints (it is
        consumed by ``init``); cannot be combined with ``freq``.
      desc: Optional ``(template, fn)`` pair. At run time ``fn`` receives a
        dict holding the iteration index under ``'i'`` plus any extra keyword
        data, must return a dict, and that dict is used to ``str.format`` the
        ``template`` into the bar's description.
      **kwargs: Extra options kept in ``self.kwargs``. The keys ``total``,
        ``mininterval``, ``maxinterval`` and ``miniters`` are silently dropped.

    Raises:
      ValueError: if both ``freq`` and ``count`` are given.
      ImportError: if ``tqdm`` is not installed.
    """
    # Print rate; only validated when an int is supplied.
    self.print_freq = freq
    if isinstance(freq, int):
        assert freq > 0, "Print rate should be > 0."

    # Print count; mutually exclusive with the print rate.
    self.print_count = count
    both_given = self.print_freq is not None and self.print_count is not None
    if both_given:
        raise ValueError("Cannot specify both count and freq.")

    # Discard options that would clash with the bar's own length / refresh
    # settings — NOTE(review): presumably because the runner manages them.
    self.kwargs = kwargs
    for ignored in ("total", "mininterval", "maxinterval", "miniters"):
        self.kwargs.pop(ignored, None)

    # Validate the (template, formatter) description pair eagerly.
    if desc is not None:
        assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
        assert isinstance(desc[0], str), 'Description should be a string.'
        assert callable(desc[1]), 'Description should be a callable.'
    # Stored unconditionally: ``init`` compares ``self.desc`` against None.
    self.desc = desc

    # tqdm is only located here, and imported lazily when the bar is built.
    if not tqdm_installed:
        raise ImportError("tqdm is not installed.")
|
48
73
|
|
49
74
|
def init(self, n: int):
|
@@ -67,15 +92,22 @@ class ProgressBar(object):
|
|
67
92
|
raise ValueError("Print rate should be less than the "
|
68
93
|
f"number of steps {n}, got {freq}")
|
69
94
|
remainder = n % freq
|
70
|
-
|
71
|
-
message =
|
72
|
-
return ProgressBarRunner(n,
|
95
|
+
|
96
|
+
message = f"Running for {n:,} iterations" if self.desc is None else self.desc
|
97
|
+
return ProgressBarRunner(n, freq, remainder, message, **kwargs)
|
73
98
|
|
74
99
|
|
75
100
|
class ProgressBarRunner(object):
|
76
101
|
__module__ = "brainstate.compile"
|
77
102
|
|
78
|
-
def __init__(
|
103
|
+
def __init__(
|
104
|
+
self,
|
105
|
+
n: int,
|
106
|
+
print_freq: int,
|
107
|
+
remainder: int,
|
108
|
+
message: str | Tuple[str, Callable[[Dict], Dict]],
|
109
|
+
**kwargs
|
110
|
+
):
|
79
111
|
self.tqdm_bars = {}
|
80
112
|
self.kwargs = kwargs
|
81
113
|
self.n = n
|
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
|
|
83
115
|
self.remainder = remainder
|
84
116
|
self.message = message
|
85
117
|
|
86
|
-
def _define_tqdm(self, x: dict):
    """
    Host-side callback for the first iteration: build the tqdm bar.

    Args:
      x: Description data returned by the user formatter. It is empty when
        the message is a plain string.
    """
    from tqdm.auto import tqdm
    bar = tqdm(range(self.n), **self.kwargs)
    self.tqdm_bars[0] = bar
    if isinstance(self.message, str):
        # A static label, so there is no need to force an immediate redraw.
        bar.set_description(self.message, refresh=False)
    else:
        bar.set_description(self.message[0].format(**x), refresh=True)
|
89
125
|
|
90
|
-
def _update_tqdm(self, x: dict):
    """
    Host-side callback run at the end of each print period.

    Advances the bar by ``self.print_freq`` steps. When a ``(template, fn)``
    description is configured, it also re-renders the label from ``x``.
    """
    bar = self.tqdm_bars[0]
    bar.update(self.print_freq)
    if isinstance(self.message, str):
        return
    bar.set_description(self.message[0].format(**x), refresh=True)
|
92
130
|
|
93
|
-
def _close_tqdm(self, x: dict):
    """
    Host-side callback for the last iteration.

    Advances the bar by the leftover steps (``n % freq``) that no full print
    period covered. When a formatter is configured it re-renders the
    description from ``x``, then it closes the bar.
    """
    bar = self.tqdm_bars[0]
    if self.remainder > 0:
        bar.update(self.remainder)
    if not isinstance(self.message, str):
        bar.set_description(self.message[0].format(**x), refresh=True)
    bar.close()
|
97
137
|
|
98
|
-
def
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
if is_print:
|
103
|
-
self.tqdm_bars[0].update(self.print_freq)
|
104
|
-
if is_final:
|
105
|
-
if self.remainder > 0:
|
106
|
-
self.tqdm_bars[0].update(self.remainder)
|
107
|
-
self.tqdm_bars[0].close()
|
108
|
-
|
109
|
-
def __call__(self, iter_num, **kwargs):
    """
    Emit ordered host callbacks that create, advance and close the bar.

    Args:
      iter_num: Current loop index (may be a traced value).
      **kwargs: Extra data passed, together with ``i=iter_num``, to the
        description formatter when one is configured.
    """
    payload = dict(i=iter_num, **kwargs)
    if isinstance(self.message, str):
        payload = dict()
    else:
        payload = self.message[1](payload)
    assert isinstance(payload, dict), 'Description function should return a dictionary.'

    def _host(fn):
        # Wrap a host-side handler as a ``lax.cond`` branch.
        return lambda x: jax.debug.callback(fn, x, ordered=True)

    def _noop(x):
        return None

    stages = (
        (iter_num == 0, self._define_tqdm),
        (iter_num % self.print_freq == (self.print_freq - 1), self._update_tqdm),
        (iter_num == self.n - 1, self._close_tqdm),
    )
    for predicate, handler in stages:
        jax.lax.cond(predicate, _host(handler), _noop, payload)
|
132
|
-
|
brainstate/compile/_unvmap.py
CHANGED
@@ -19,9 +19,13 @@ import jax.core
|
|
19
19
|
import jax.interpreters.batching as batching
|
20
20
|
import jax.interpreters.mlir as mlir
|
21
21
|
import jax.numpy as jnp
|
22
|
-
|
23
22
|
from brainstate._utils import set_module_as
|
24
23
|
|
24
|
+
if jax.__version_info__ < (0, 4, 38):
|
25
|
+
from jax.core import Primitive
|
26
|
+
else:
|
27
|
+
from jax.extend.core import Primitive
|
28
|
+
|
25
29
|
__all__ = [
|
26
30
|
"unvmap",
|
27
31
|
]
|
@@ -43,7 +47,7 @@ def unvmap(x, op: str = 'any'):
|
|
43
47
|
|
44
48
|
# unvmap_all
|
45
49
|
|
46
|
-
# JAX primitive registered under the name "unvmap_all".
# NOTE(review): presumably bound by ``unvmap_all`` below — confirm against its definition.
unvmap_all_p = Primitive(name="unvmap_all")
|
47
51
|
|
48
52
|
|
49
53
|
def unvmap_all(x):
|
brainstate/environ.py
CHANGED
@@ -249,7 +249,13 @@ def get_precision() -> int:
|
|
249
249
|
The default precision.
|
250
250
|
"""
|
251
251
|
precision = get('precision')
|
252
|
-
|
252
|
+
if precision == 'bf16':
|
253
|
+
return 16
|
254
|
+
if isinstance(precision, int):
|
255
|
+
return precision
|
256
|
+
if isinstance(precision, str):
|
257
|
+
return int(precision)
|
258
|
+
raise ValueError(f'Unsupported precision: {precision}')
|
253
259
|
|
254
260
|
|
255
261
|
def set(
|
@@ -345,8 +351,8 @@ def _set_jax_precision(precision: int | str):
|
|
345
351
|
Args:
|
346
352
|
precision: int. The default precision.
|
347
353
|
"""
|
348
|
-
assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
349
|
-
if precision
|
354
|
+
# assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
|
355
|
+
if precision in [64, '64']:
|
350
356
|
config.update("jax_enable_x64", True)
|
351
357
|
else:
|
352
358
|
config.update("jax_enable_x64", False)
|
@@ -354,13 +360,13 @@ def _set_jax_precision(precision: int | str):
|
|
354
360
|
|
355
361
|
@functools.lru_cache()
def _get_uint(precision: int):
    """
    Return the NumPy unsigned-integer dtype for a precision setting.

    Integer and string spellings are both accepted (``64`` or ``'64'``).
    ``'bf16'`` maps to 16 bits.

    Raises:
      ValueError: if ``precision`` is not a supported spelling.
    """
    by_precision = {
        64: np.uint64, '64': np.uint64,
        32: np.uint32, '32': np.uint32,
        16: np.uint16, '16': np.uint16, 'bf16': np.uint16,
        8: np.uint8, '8': np.uint8,
    }
    dtype = by_precision.get(precision)
    if dtype is None:
        raise ValueError(f'Unsupported precision: {precision}')
    return dtype
|
@@ -368,13 +374,13 @@ def _get_uint(precision: int):
|
|
368
374
|
|
369
375
|
@functools.lru_cache()
def _get_int(precision: int):
    """
    Return the NumPy signed-integer dtype for a precision setting.

    Integer and string spellings are both accepted (``32`` or ``'32'``).
    ``'bf16'`` maps to 16 bits.

    Raises:
      ValueError: if ``precision`` is not a supported spelling.
    """
    by_precision = {
        64: np.int64, '64': np.int64,
        32: np.int32, '32': np.int32,
        16: np.int16, '16': np.int16, 'bf16': np.int16,
        8: np.int8, '8': np.int8,
    }
    dtype = by_precision.get(precision)
    if dtype is None:
        raise ValueError(f'Unsupported precision: {precision}')
    return dtype
|
@@ -382,25 +388,29 @@ def _get_int(precision: int):
|
|
382
388
|
|
383
389
|
@functools.lru_cache()
def _get_float(precision: int):
    """
    Return the floating-point dtype for a precision setting.

    ``64``/``32``/``16`` (int or str) map to NumPy IEEE floats. ``'bf16'``
    maps to ``jnp.bfloat16`` and ``8``/``'8'`` maps to ``jnp.float8_e5m2``.

    Raises:
      ValueError: if ``precision`` is not a supported spelling.
    """
    if precision in (64, '64'):
        return np.float64
    if precision in (32, '32'):
        return np.float32
    if precision in (16, '16'):
        return np.float16
    if precision == 'bf16':
        return jnp.bfloat16
    # Looked up only on this branch, so the other precisions still work
    # with JAX builds that lack float8 types.
    if precision in (8, '8'):
        return jnp.float8_e5m2
    raise ValueError(f'Unsupported precision: {precision}')
|
395
403
|
|
396
404
|
|
397
405
|
@functools.lru_cache()
def _get_complex(precision: int):
    """
    Return the NumPy complex dtype for a precision setting.

    ``64``/``'64'`` maps to ``np.complex128``. Every narrower setting (``32``,
    ``16``, ``'bf16'``, ``8``, as int or str) maps to ``np.complex64``, which
    is NumPy's narrowest complex dtype.

    Raises:
      ValueError: if ``precision`` is not a supported spelling.
    """
    # Use membership tests: ``precision == [64, '64']`` compares a scalar
    # with a list and is always False, which made 64/32/8 raise ValueError.
    if precision in (64, '64'):
        return np.complex128
    if precision in (32, '32', 16, '16', 'bf16', 8, '8'):
        return np.complex64
    raise ValueError(f'Unsupported precision: {precision}')
|
brainstate/environ_test.py
CHANGED
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
|
|
32
32
|
self.assertEqual(a.dtype, jnp.float32)
|
33
33
|
|
34
34
|
with bst.environ.context(precision=16):
|
35
|
+
a = bst.random.randn(1)
|
36
|
+
self.assertEqual(a.dtype, jnp.float16)
|
37
|
+
|
38
|
+
with bst.environ.context(precision='bf16'):
|
35
39
|
a = bst.random.randn(1)
|
36
40
|
self.assertEqual(a.dtype, jnp.bfloat16)
|
37
41
|
|
brainstate/event/__init__.py
CHANGED