brainstate 0.1.0.post20250126__py2.py3-none-any.whl → 0.1.0.post20250129__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -730,7 +730,7 @@ def _make_jaxpr(
730
730
  def _abstractify(args, kwargs):
731
731
  flat_args, in_tree = jax.tree.flatten((args, kwargs))
732
732
  if abstracted_axes is None:
733
- return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
733
+ return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
734
734
  else:
735
735
  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
736
736
  in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
@@ -35,6 +35,47 @@ Output = Any
35
35
  class ProgressBar(object):
36
36
  """
37
37
  A progress bar for tracking the progress of a jitted for-loop computation.
38
+
39
+ It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
40
+ and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
41
+ a for-loop.
42
+
43
+ The message displayed in the progress bar can be customized by the following two methods:
44
+
45
+ 1. By passing a string to the `desc` argument. For example:
46
+
47
+ .. code-block:: python
48
+
49
+ ProgressBar(desc="Running 1000 iterations")
50
+
51
+ 2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
52
+ function should take a dictionary as input and return a dictionary. The returned dictionary
53
+ will be used to format the string. For example:
54
+
55
+ .. code-block:: python
56
+
57
+ a = bst.State(1.)
58
+ def loop_fn(x):
59
+ a.value = x.value + 1.
60
+ return jnp.sum(x ** 2)
61
+
62
+ pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
63
+ lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
64
+
65
+ bst.compile.for_loop(loop_fn, xs, pbar=pbar)
66
+
67
+ In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
68
+ the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
69
+
70
+
71
+ Args:
72
+ freq: The frequency at which to print the progress bar. If not specified, the progress
73
+ bar will be printed every 5% of the total iterations.
74
+ count: The number of times to print the progress bar. If not specified, the progress
75
+ bar will be printed every 5% of the total iterations.
76
+ desc: A description of the progress bar. If not specified, a default message will be
77
+ displayed.
78
+ kwargs: Additional keyword arguments to pass to the progress bar.
38
79
  """
39
80
  __module__ = "brainstate.compile"
40
81
 
@@ -42,7 +83,7 @@ class ProgressBar(object):
42
83
  self,
43
84
  freq: Optional[int] = None,
44
85
  count: Optional[int] = None,
45
- desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
86
+ desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
46
87
  **kwargs
47
88
  ):
48
89
  # print rate
@@ -62,9 +103,12 @@ class ProgressBar(object):
62
103
 
63
104
  # description
64
105
  if desc is not None:
65
- assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
66
- assert isinstance(desc[0], str), 'Description should be a string.'
67
- assert callable(desc[1]), 'Description should be a callable.'
106
+ if isinstance(desc, str):
107
+ pass
108
+ else:
109
+ assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
110
+ assert isinstance(desc[0], str), 'Description should be a string.'
111
+ assert callable(desc[1]), 'Description should be a callable.'
68
112
  self.desc = desc
69
113
 
70
114
  # check if tqdm is installed
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
136
180
  self.tqdm_bars[0].close()
137
181
 
138
182
  def __call__(self, iter_num, **kwargs):
139
- data = dict(i=iter_num, **kwargs)
140
- data = dict() if isinstance(self.message, str) else self.message[1](data)
183
+ data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
141
184
  assert isinstance(data, dict), 'Description function should return a dictionary.'
142
185
 
143
186
  _ = jax.lax.cond(
@@ -216,15 +216,21 @@ def poisson_input(
216
216
  weight: u.Quantity,
217
217
  target: State,
218
218
  indices: Optional[Union[np.ndarray, jax.Array]] = None,
219
+ refractory: Optional[Union[jax.Array]] = None,
219
220
  ):
220
221
  """
221
222
  Poisson Input to the given :py:class:`brainstate.State`.
223
+
224
+ Args:
225
+ freq:
226
+ refractory: should be the same length with ``target``.
222
227
  """
223
228
  freq = maybe_state(freq)
224
229
  weight = maybe_state(weight)
225
230
 
226
231
  assert isinstance(target, State), 'The target must be a State.'
227
- p = (freq * environ.get_dt()).to_decimal()
232
+ p = freq * environ.get_dt()
233
+ p = p.to_decimal() if isinstance(p, u.Quantity) else p
228
234
  a = num_input * p
229
235
  b = num_input * (1 - p)
230
236
  tar_val = target.value
@@ -290,8 +296,9 @@ def poisson_input(
290
296
  # )
291
297
 
292
298
  # update target variable
293
- target.value = jax.tree.map(
294
- lambda x: x * weight,
299
+ data = jax.tree.map(
300
+ lambda tar, x: tar + x * weight,
301
+ target.value,
295
302
  inp,
296
303
  is_leaf=u.math.is_quantity
297
304
  )
@@ -356,9 +363,19 @@ def poisson_input(
356
363
  # )
357
364
 
358
365
  # update target variable
359
- target.value = jax.tree.map(
366
+ data = jax.tree.map(
360
367
  lambda x, tar: tar.at[indices].add(x * weight),
361
368
  inp,
362
369
  tar_val,
363
370
  is_leaf=u.math.is_quantity
364
371
  )
372
+
373
+ if refractory is not None:
374
+ target.value = jax.tree.map(
375
+ lambda x, tar: u.math.where(refractory, tar, x),
376
+ data,
377
+ tar_val,
378
+ is_leaf=u.math.is_quantity
379
+ )
380
+ else:
381
+ target.value = data
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250126
3
+ Version: 0.1.0.post20250129
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -30,9 +30,9 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=BwQZXfpWeXim2Pq6I7NgIvZ8p0hSH_QFyVlR4CHswmw,33322
33
+ brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
35
- brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
35
+ brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
36
36
  brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
38
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
65
65
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
66
66
  brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
67
67
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
68
- brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
68
+ brainstate/nn/_dyn_impl/_inputs.py,sha256=72-UnT-hpG03EvSYx72ldDhbgZwmaoOYxxkANpX6xpo,11779
69
69
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
70
70
  brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
71
71
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
117
117
  brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
118
118
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
119
119
  brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
120
- brainstate-0.1.0.post20250126.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
- brainstate-0.1.0.post20250126.dist-info/METADATA,sha256=G9QbLQwKk1SwCQ4P2MhFRufKVefEUHsl_DNDut8GRdQ,3585
122
- brainstate-0.1.0.post20250126.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
- brainstate-0.1.0.post20250126.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
- brainstate-0.1.0.post20250126.dist-info/RECORD,,
120
+ brainstate-0.1.0.post20250129.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
+ brainstate-0.1.0.post20250129.dist-info/METADATA,sha256=g93rl2oDT8uoaL4wbHVDFtC7Y5F3kuuebI64WBKUTc0,3585
122
+ brainstate-0.1.0.post20250129.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
+ brainstate-0.1.0.post20250129.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
+ brainstate-0.1.0.post20250129.dist-info/RECORD,,