brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_error_if_test.py +1 -0
  9. brainstate/compile/_jit.py +37 -28
  10. brainstate/compile/_loop_collect_return.py +8 -5
  11. brainstate/compile/_loop_no_collection.py +2 -0
  12. brainstate/compile/_make_jaxpr.py +7 -3
  13. brainstate/compile/_make_jaxpr_test.py +2 -1
  14. brainstate/compile/_progress_bar.py +68 -40
  15. brainstate/compile/_unvmap.py +6 -2
  16. brainstate/environ.py +28 -18
  17. brainstate/environ_test.py +4 -0
  18. brainstate/event/__init__.py +0 -2
  19. brainstate/event/_csr.py +266 -23
  20. brainstate/event/_csr_test.py +187 -0
  21. brainstate/event/_fixedprob_mv.py +4 -2
  22. brainstate/event/_fixedprob_mv_test.py +2 -1
  23. brainstate/event/_xla_custom_op.py +16 -5
  24. brainstate/graph/__init__.py +8 -12
  25. brainstate/graph/_graph_node.py +1 -23
  26. brainstate/graph/_graph_operation.py +1 -1
  27. brainstate/graph/_graph_operation_test.py +0 -159
  28. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  29. brainstate/nn/_interaction/_conv.py +4 -2
  30. brainstate/nn/_interaction/_linear.py +84 -10
  31. brainstate/random/_rand_funs.py +9 -2
  32. brainstate/random/_rand_seed.py +12 -2
  33. brainstate/random/_rand_state.py +50 -179
  34. brainstate/surrogate.py +5 -1
  35. brainstate/util/__init__.py +0 -4
  36. brainstate/util/_caller.py +1 -1
  37. brainstate/util/_dict.py +4 -1
  38. brainstate/util/_filter.py +1 -1
  39. brainstate/util/_pretty_repr.py +1 -1
  40. brainstate/util/_struct.py +1 -1
  41. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  42. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
  43. brainstate/event/_csr_mv_test.py +0 -118
  44. brainstate/graph/_graph_context.py +0 -443
  45. brainstate/graph/_graph_context_test.py +0 -65
  46. brainstate/graph/_graph_convert.py +0 -246
  47. brainstate/util/_tracers.py +0 -68
  48. brainstate/util/_visualization.py +0 -47
  49. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  50. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  51. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  52. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/compile/_progress_bar.py CHANGED
@@ -16,34 +16,59 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import copy
19
- from typing import Optional
19
+ import importlib.util
20
+ from typing import Optional, Callable, Any, Tuple, Dict
20
21
 
21
22
  import jax
22
23
 
23
- try:
24
- from tqdm.auto import tqdm
25
- except (ImportError, ModuleNotFoundError):
26
- tqdm = None
24
+ tqdm_installed = importlib.util.find_spec('tqdm') is not None
27
25
 
28
26
  __all__ = [
29
27
  'ProgressBar',
30
28
  ]
31
29
 
30
+ Index = int
31
+ Carray = Any
32
+ Output = Any
33
+
32
34
 
33
35
  class ProgressBar(object):
36
+ """
37
+ A progress bar for tracking the progress of a jitted for-loop computation.
38
+ """
34
39
  __module__ = "brainstate.compile"
35
40
 
36
- def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
41
+ def __init__(
42
+ self,
43
+ freq: Optional[int] = None,
44
+ count: Optional[int] = None,
45
+ desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
46
+ **kwargs
47
+ ):
48
+ # print rate
37
49
  self.print_freq = freq
38
50
  if isinstance(freq, int):
39
51
  assert freq > 0, "Print rate should be > 0."
52
+
53
+ # print count
40
54
  self.print_count = count
41
55
  if self.print_freq is not None and self.print_count is not None:
42
56
  raise ValueError("Cannot specify both count and freq.")
57
+
58
+ # other parameters
43
59
  for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
44
60
  kwargs.pop(kwarg, None)
45
61
  self.kwargs = kwargs
46
- if tqdm is None:
62
+
63
+ # description
64
+ if desc is not None:
65
+ assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
66
+ assert isinstance(desc[0], str), 'Description should be a string.'
67
+ assert callable(desc[1]), 'Description should be a callable.'
68
+ self.desc = desc
69
+
70
+ # check if tqdm is installed
71
+ if not tqdm_installed:
47
72
  raise ImportError("tqdm is not installed.")
48
73
 
49
74
  def init(self, n: int):
@@ -67,15 +92,22 @@ class ProgressBar(object):
67
92
  raise ValueError("Print rate should be less than the "
68
93
  f"number of steps {n}, got {freq}")
69
94
  remainder = n % freq
70
- desc = kwargs.pop("desc", f"Running for {n:,} iterations")
71
- message = kwargs.pop("message", desc)
72
- return ProgressBarRunner(n, message, freq, remainder, **kwargs)
95
+
96
+ message = f"Running for {n:,} iterations" if self.desc is None else self.desc
97
+ return ProgressBarRunner(n, freq, remainder, message, **kwargs)
73
98
 
74
99
 
75
100
  class ProgressBarRunner(object):
76
101
  __module__ = "brainstate.compile"
77
102
 
78
- def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
103
+ def __init__(
104
+ self,
105
+ n: int,
106
+ print_freq: int,
107
+ remainder: int,
108
+ message: str | Tuple[str, Callable[[Dict], Dict]],
109
+ **kwargs
110
+ ):
79
111
  self.tqdm_bars = {}
80
112
  self.kwargs = kwargs
81
113
  self.n = n
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
83
115
  self.remainder = remainder
84
116
  self.message = message
85
117
 
86
- def _define_tqdm(self):
118
+ def _define_tqdm(self, x: dict):
119
+ from tqdm.auto import tqdm
87
120
  self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
88
- self.tqdm_bars[0].set_description(self.message, refresh=False)
121
+ if isinstance(self.message, str):
122
+ self.tqdm_bars[0].set_description(self.message, refresh=False)
123
+ else:
124
+ self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
89
125
 
90
- def _update_tqdm(self):
126
+ def _update_tqdm(self, x: dict):
91
127
  self.tqdm_bars[0].update(self.print_freq)
128
+ if not isinstance(self.message, str):
129
+ self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
92
130
 
93
- def _close_tqdm(self):
131
+ def _close_tqdm(self, x: dict):
94
132
  if self.remainder > 0:
95
133
  self.tqdm_bars[0].update(self.remainder)
134
+ if not isinstance(self.message, str):
135
+ self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
96
136
  self.tqdm_bars[0].close()
97
137
 
98
- def _tqdm(self, is_init, is_print, is_final):
99
- if is_init:
100
- self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
101
- self.tqdm_bars[0].set_description(self.message, refresh=False)
102
- if is_print:
103
- self.tqdm_bars[0].update(self.print_freq)
104
- if is_final:
105
- if self.remainder > 0:
106
- self.tqdm_bars[0].update(self.remainder)
107
- self.tqdm_bars[0].close()
108
-
109
- def __call__(self, iter_num, *args, **kwargs):
110
- # jax.debug.callback(
111
- # self._tqdm,
112
- # iter_num == 0,
113
- # (iter_num + 1) % self.print_freq == 0,
114
- # iter_num == self.n - 1
115
- # )
138
+ def __call__(self, iter_num, **kwargs):
139
+ data = dict(i=iter_num, **kwargs)
140
+ data = dict() if isinstance(self.message, str) else self.message[1](data)
141
+ assert isinstance(data, dict), 'Description function should return a dictionary.'
116
142
 
117
143
  _ = jax.lax.cond(
118
144
  iter_num == 0,
119
- lambda: jax.debug.callback(self._define_tqdm, ordered=True),
120
- lambda: None,
145
+ lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
146
+ lambda x: None,
147
+ data
121
148
  )
122
149
  _ = jax.lax.cond(
123
150
  iter_num % self.print_freq == (self.print_freq - 1),
124
- lambda: jax.debug.callback(self._update_tqdm, ordered=True),
125
- lambda: None,
151
+ lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
152
+ lambda x: None,
153
+ data
126
154
  )
127
155
  _ = jax.lax.cond(
128
156
  iter_num == self.n - 1,
129
- lambda: jax.debug.callback(self._close_tqdm, ordered=True),
130
- lambda: None,
157
+ lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
158
+ lambda x: None,
159
+ data
131
160
  )
132
-
brainstate/compile/_unvmap.py CHANGED
@@ -19,9 +19,13 @@ import jax.core
19
19
  import jax.interpreters.batching as batching
20
20
  import jax.interpreters.mlir as mlir
21
21
  import jax.numpy as jnp
22
-
23
22
  from brainstate._utils import set_module_as
24
23
 
24
+ if jax.__version_info__ < (0, 4, 38):
25
+ from jax.core import Primitive
26
+ else:
27
+ from jax.extend.core import Primitive
28
+
25
29
  __all__ = [
26
30
  "unvmap",
27
31
  ]
@@ -43,7 +47,7 @@ def unvmap(x, op: str = 'any'):
43
47
 
44
48
  # unvmap_all
45
49
 
46
- unvmap_all_p = jax.core.Primitive("unvmap_all")
50
+ unvmap_all_p = Primitive("unvmap_all")
47
51
 
48
52
 
49
53
  def unvmap_all(x):
brainstate/environ.py CHANGED
@@ -249,7 +249,13 @@ def get_precision() -> int:
249
249
  The default precision.
250
250
  """
251
251
  precision = get('precision')
252
- return 16 if precision == 'bf16' else precision
252
+ if precision == 'bf16':
253
+ return 16
254
+ if isinstance(precision, int):
255
+ return precision
256
+ if isinstance(precision, str):
257
+ return int(precision)
258
+ raise ValueError(f'Unsupported precision: {precision}')
253
259
 
254
260
 
255
261
  def set(
@@ -345,8 +351,8 @@ def _set_jax_precision(precision: int | str):
345
351
  Args:
346
352
  precision: int. The default precision.
347
353
  """
348
- assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
349
- if precision == 64:
354
+ # assert precision in [64, 32, 16, 'bf16', 8], f'Precision must be in [64, 32, 16, "bf16", 8]. But got {precision}.'
355
+ if precision in [64, '64']:
350
356
  config.update("jax_enable_x64", True)
351
357
  else:
352
358
  config.update("jax_enable_x64", False)
@@ -354,13 +360,13 @@ def _set_jax_precision(precision: int | str):
354
360
 
355
361
  @functools.lru_cache()
356
362
  def _get_uint(precision: int):
357
- if precision == 64:
363
+ if precision in [64, '64']:
358
364
  return np.uint64
359
- elif precision == 32:
365
+ elif precision in [32, '32']:
360
366
  return np.uint32
361
- elif precision in [16, 'bf16']:
367
+ elif precision in [16, '16', 'bf16']:
362
368
  return np.uint16
363
- elif precision == 8:
369
+ elif precision in [8, '8']:
364
370
  return np.uint8
365
371
  else:
366
372
  raise ValueError(f'Unsupported precision: {precision}')
@@ -368,13 +374,13 @@ def _get_uint(precision: int):
368
374
 
369
375
  @functools.lru_cache()
370
376
  def _get_int(precision: int):
371
- if precision == 64:
377
+ if precision in [64, '64']:
372
378
  return np.int64
373
- elif precision == 32:
379
+ elif precision in [32, '32']:
374
380
  return np.int32
375
- elif precision in [16, 'bf16']:
381
+ elif precision in [16, '16', 'bf16']:
376
382
  return np.int16
377
- elif precision == 8:
383
+ elif precision in [8, '8']:
378
384
  return np.int8
379
385
  else:
380
386
  raise ValueError(f'Unsupported precision: {precision}')
@@ -382,25 +388,29 @@ def _get_int(precision: int):
382
388
 
383
389
  @functools.lru_cache()
384
390
  def _get_float(precision: int):
385
- if precision == 64:
391
+ if precision in [64, '64']:
386
392
  return np.float64
387
- elif precision == 32:
393
+ elif precision in [32, '32']:
388
394
  return np.float32
389
- elif precision == 16:
395
+ elif precision in [16, '16']:
390
396
  return np.float16
391
- elif precision == 'bf16':
397
+ elif precision in ['bf16']:
392
398
  return jnp.bfloat16
399
+ elif precision in [8, '8']:
400
+ return jnp.float8_e5m2
393
401
  else:
394
402
  raise ValueError(f'Unsupported precision: {precision}')
395
403
 
396
404
 
397
405
  @functools.lru_cache()
398
406
  def _get_complex(precision: int):
399
- if precision == 64:
407
+ if precision == [64, '64']:
400
408
  return np.complex128
401
- elif precision == 32:
409
+ elif precision == [32, '32']:
402
410
  return np.complex64
403
- elif precision in [16, 'bf16']:
411
+ elif precision in [16, '16', 'bf16']:
412
+ return np.complex64
413
+ elif precision == [8, '8']:
404
414
  return np.complex64
405
415
  else:
406
416
  raise ValueError(f'Unsupported precision: {precision}')
brainstate/environ_test.py CHANGED
@@ -32,6 +32,10 @@ class TestEnviron(unittest.TestCase):
32
32
  self.assertEqual(a.dtype, jnp.float32)
33
33
 
34
34
  with bst.environ.context(precision=16):
35
+ a = bst.random.randn(1)
36
+ self.assertEqual(a.dtype, jnp.float16)
37
+
38
+ with bst.environ.context(precision='bf16'):
35
39
  a = bst.random.randn(1)
36
40
  self.assertEqual(a.dtype, jnp.bfloat16)
37
41
 
brainstate/event/__init__.py CHANGED
@@ -15,13 +15,11 @@
15
15
 
16
16
 
17
17
  from ._csr import *
18
- from ._csr_mv import *
19
18
  from ._fixedprob_mv import *
20
19
  from ._linear_mv import *
21
20
  from ._xla_custom_op import *
22
21
 
23
22
  __all__ = [
24
- 'CSRLinear',
25
23
  'FixedProb',
26
24
  'XLACustomOp',
27
25
  'CSR',