brainstate 0.1.0.post20250126__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -730,7 +730,7 @@ def _make_jaxpr(
730
730
  def _abstractify(args, kwargs):
731
731
  flat_args, in_tree = jax.tree.flatten((args, kwargs))
732
732
  if abstracted_axes is None:
733
- return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
733
+ return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
734
734
  else:
735
735
  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
736
736
  in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
@@ -35,6 +35,47 @@ Output = Any
35
35
  class ProgressBar(object):
36
36
  """
37
37
  A progress bar for tracking the progress of a jitted for-loop computation.
38
+
39
+ It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
40
+ and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
41
+ a for-loop.
42
+
43
+ The message displayed in the progress bar can be customized by the following two methods:
44
+
45
+ 1. By passing a string to the `desc` argument. For example:
46
+
47
+ .. code-block:: python
48
+
49
+ ProgressBar(desc="Running 1000 iterations")
50
+
51
+ 2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
52
+ function should take a dictionary as input and return a dictionary. The returned dictionary
53
+ will be used to format the string. For example:
54
+
55
+ .. code-block:: python
56
+
57
+ a = bst.State(1.)
58
+ def loop_fn(x):
59
+ a.value = x.value + 1.
60
+ return jnp.sum(x ** 2)
61
+
62
+ pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
63
+ lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
64
+
65
+ bst.compile.for_loop(loop_fn, xs, pbar=pbar)
66
+
67
+ In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
68
+ the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
69
+
70
+
71
+ Args:
72
+ freq: The frequency at which to print the progress bar. If not specified, the progress
73
+ bar will be printed every 5% of the total iterations.
74
+ count: The number of times to print the progress bar. If not specified, the progress
75
+ bar will be printed every 5% of the total iterations.
76
+ desc: A description of the progress bar. If not specified, a default message will be
77
+ displayed.
78
+ kwargs: Additional keyword arguments to pass to the progress bar.
38
79
  """
39
80
  __module__ = "brainstate.compile"
40
81
 
@@ -42,7 +83,7 @@ class ProgressBar(object):
42
83
  self,
43
84
  freq: Optional[int] = None,
44
85
  count: Optional[int] = None,
45
- desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
86
+ desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
46
87
  **kwargs
47
88
  ):
48
89
  # print rate
@@ -62,9 +103,12 @@ class ProgressBar(object):
62
103
 
63
104
  # description
64
105
  if desc is not None:
65
- assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
66
- assert isinstance(desc[0], str), 'Description should be a string.'
67
- assert callable(desc[1]), 'Description should be a callable.'
106
+ if isinstance(desc, str):
107
+ pass
108
+ else:
109
+ assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
110
+ assert isinstance(desc[0], str), 'Description should be a string.'
111
+ assert callable(desc[1]), 'Description should be a callable.'
68
112
  self.desc = desc
69
113
 
70
114
  # check if tqdm is installed
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
136
180
  self.tqdm_bars[0].close()
137
181
 
138
182
  def __call__(self, iter_num, **kwargs):
139
- data = dict(i=iter_num, **kwargs)
140
- data = dict() if isinstance(self.message, str) else self.message[1](data)
183
+ data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
141
184
  assert isinstance(data, dict), 'Description function should return a dictionary.'
142
185
 
143
186
  _ = jax.lax.cond(
@@ -224,7 +224,8 @@ def poisson_input(
224
224
  weight = maybe_state(weight)
225
225
 
226
226
  assert isinstance(target, State), 'The target must be a State.'
227
- p = (freq * environ.get_dt()).to_decimal()
227
+ p = freq * environ.get_dt()
228
+ p = p.to_decimal() if isinstance(p, u.Quantity) else p
228
229
  a = num_input * p
229
230
  b = num_input * (1 - p)
230
231
  tar_val = target.value
@@ -291,7 +292,8 @@ def poisson_input(
291
292
 
292
293
  # update target variable
293
294
  target.value = jax.tree.map(
294
- lambda x: x * weight,
295
+ lambda tar, x: tar + x * weight,
296
+ target.value,
295
297
  inp,
296
298
  is_leaf=u.math.is_quantity
297
299
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250126
3
+ Version: 0.1.0.post20250127
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -30,9 +30,9 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=BwQZXfpWeXim2Pq6I7NgIvZ8p0hSH_QFyVlR4CHswmw,33322
33
+ brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
35
- brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
35
+ brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
36
36
  brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
38
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
65
65
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
66
66
  brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
67
67
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
68
- brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
68
+ brainstate/nn/_dyn_impl/_inputs.py,sha256=QOUpAb2YJOE78uAvIS8Ep6MFcQHV-V6uRwmYvk5p9bk,11385
69
69
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
70
70
  brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
71
71
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
117
117
  brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
118
118
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
119
119
  brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
120
- brainstate-0.1.0.post20250126.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
- brainstate-0.1.0.post20250126.dist-info/METADATA,sha256=G9QbLQwKk1SwCQ4P2MhFRufKVefEUHsl_DNDut8GRdQ,3585
122
- brainstate-0.1.0.post20250126.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
- brainstate-0.1.0.post20250126.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
- brainstate-0.1.0.post20250126.dist-info/RECORD,,
120
+ brainstate-0.1.0.post20250127.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
+ brainstate-0.1.0.post20250127.dist-info/METADATA,sha256=-j8gDuJN37nbSkijFC9TtG5UbflMZZWK-sftY7fkeps,3585
122
+ brainstate-0.1.0.post20250127.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
+ brainstate-0.1.0.post20250127.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
+ brainstate-0.1.0.post20250127.dist-info/RECORD,,