brainstate 0.1.0.post20250126__py2.py3-none-any.whl → 0.1.0.post20250129__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_make_jaxpr.py +1 -1
- brainstate/compile/_progress_bar.py +49 -6
- brainstate/nn/_dyn_impl/_inputs.py +21 -4
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/RECORD +8 -8
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/top_level.txt +0 -0
@@ -730,7 +730,7 @@ def _make_jaxpr(
|
|
730
730
|
def _abstractify(args, kwargs):
|
731
731
|
flat_args, in_tree = jax.tree.flatten((args, kwargs))
|
732
732
|
if abstracted_axes is None:
|
733
|
-
return map(
|
733
|
+
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
|
734
734
|
else:
|
735
735
|
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
736
736
|
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
|
@@ -35,6 +35,47 @@ Output = Any
|
|
35
35
|
class ProgressBar(object):
|
36
36
|
"""
|
37
37
|
A progress bar for tracking the progress of a jitted for-loop computation.
|
38
|
+
|
39
|
+
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
40
|
+
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
41
|
+
a for-loop.
|
42
|
+
|
43
|
+
The message displayed in the progress bar can be customized by the following two methods:
|
44
|
+
|
45
|
+
1. By passing a string to the `desc` argument. For example:
|
46
|
+
|
47
|
+
.. code-block:: python
|
48
|
+
|
49
|
+
ProgressBar(desc="Running 1000 iterations")
|
50
|
+
|
51
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
52
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
53
|
+
will be used to format the string. For example:
|
54
|
+
|
55
|
+
.. code-block:: python
|
56
|
+
|
57
|
+
a = bst.State(1.)
|
58
|
+
def loop_fn(x):
|
59
|
+
a.value = x.value + 1.
|
60
|
+
return jnp.sum(x ** 2)
|
61
|
+
|
62
|
+
pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
|
63
|
+
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
64
|
+
|
65
|
+
bst.compile.for_loop(loop_fn, xs, pbar=pbar)
|
66
|
+
|
67
|
+
In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
|
68
|
+
while the ``"carry"`` is the dynamic state carried through the loop (``a.value`` in this case).
|
69
|
+
|
70
|
+
|
71
|
+
Args:
|
72
|
+
freq: The frequency at which to print the progress bar. If not specified, the progress
|
73
|
+
bar will be printed every 5% of the total iterations.
|
74
|
+
count: The number of times to print the progress bar. If not specified, the progress
|
75
|
+
bar will be printed every 5% of the total iterations.
|
76
|
+
desc: A description of the progress bar. If not specified, a default message will be
|
77
|
+
displayed.
|
78
|
+
kwargs: Additional keyword arguments to pass to the progress bar.
|
38
79
|
"""
|
39
80
|
__module__ = "brainstate.compile"
|
40
81
|
|
@@ -42,7 +83,7 @@ class ProgressBar(object):
|
|
42
83
|
self,
|
43
84
|
freq: Optional[int] = None,
|
44
85
|
count: Optional[int] = None,
|
45
|
-
desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
|
86
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
46
87
|
**kwargs
|
47
88
|
):
|
48
89
|
# print rate
|
@@ -62,9 +103,12 @@ class ProgressBar(object):
|
|
62
103
|
|
63
104
|
# description
|
64
105
|
if desc is not None:
|
65
|
-
|
66
|
-
|
67
|
-
|
106
|
+
if isinstance(desc, str):
|
107
|
+
pass
|
108
|
+
else:
|
109
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
110
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
111
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
68
112
|
self.desc = desc
|
69
113
|
|
70
114
|
# check if tqdm is installed
|
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
|
|
136
180
|
self.tqdm_bars[0].close()
|
137
181
|
|
138
182
|
def __call__(self, iter_num, **kwargs):
|
139
|
-
data = dict(i=iter_num, **kwargs)
|
140
|
-
data = dict() if isinstance(self.message, str) else self.message[1](data)
|
183
|
+
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
141
184
|
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
142
185
|
|
143
186
|
_ = jax.lax.cond(
|
@@ -216,15 +216,21 @@ def poisson_input(
|
|
216
216
|
weight: u.Quantity,
|
217
217
|
target: State,
|
218
218
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
219
|
+
refractory: Optional[Union[jax.Array]] = None,
|
219
220
|
):
|
220
221
|
"""
|
221
222
|
Poisson Input to the given :py:class:`brainstate.State`.
|
223
|
+
|
224
|
+
Args:
|
225
|
+
freq:
|
226
|
+
refractory: should have the same length as ``target``.
|
222
227
|
"""
|
223
228
|
freq = maybe_state(freq)
|
224
229
|
weight = maybe_state(weight)
|
225
230
|
|
226
231
|
assert isinstance(target, State), 'The target must be a State.'
|
227
|
-
p =
|
232
|
+
p = freq * environ.get_dt()
|
233
|
+
p = p.to_decimal() if isinstance(p, u.Quantity) else p
|
228
234
|
a = num_input * p
|
229
235
|
b = num_input * (1 - p)
|
230
236
|
tar_val = target.value
|
@@ -290,8 +296,9 @@ def poisson_input(
|
|
290
296
|
# )
|
291
297
|
|
292
298
|
# update target variable
|
293
|
-
|
294
|
-
lambda x: x * weight,
|
299
|
+
data = jax.tree.map(
|
300
|
+
lambda tar, x: tar + x * weight,
|
301
|
+
target.value,
|
295
302
|
inp,
|
296
303
|
is_leaf=u.math.is_quantity
|
297
304
|
)
|
@@ -356,9 +363,19 @@ def poisson_input(
|
|
356
363
|
# )
|
357
364
|
|
358
365
|
# update target variable
|
359
|
-
|
366
|
+
data = jax.tree.map(
|
360
367
|
lambda x, tar: tar.at[indices].add(x * weight),
|
361
368
|
inp,
|
362
369
|
tar_val,
|
363
370
|
is_leaf=u.math.is_quantity
|
364
371
|
)
|
372
|
+
|
373
|
+
if refractory is not None:
|
374
|
+
target.value = jax.tree.map(
|
375
|
+
lambda x, tar: u.math.where(refractory, tar, x),
|
376
|
+
data,
|
377
|
+
tar_val,
|
378
|
+
is_leaf=u.math.is_quantity
|
379
|
+
)
|
380
|
+
else:
|
381
|
+
target.value = data
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250129
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -30,9 +30,9 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
36
|
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
38
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
65
65
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
66
66
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
67
67
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
68
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
68
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=72-UnT-hpG03EvSYx72ldDhbgZwmaoOYxxkANpX6xpo,11779
|
69
69
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
70
70
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
71
71
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
117
117
|
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
118
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
119
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
120
|
-
brainstate-0.1.0.
|
121
|
-
brainstate-0.1.0.
|
122
|
-
brainstate-0.1.0.
|
123
|
-
brainstate-0.1.0.
|
124
|
-
brainstate-0.1.0.
|
120
|
+
brainstate-0.1.0.post20250129.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250129.dist-info/METADATA,sha256=g93rl2oDT8uoaL4wbHVDFtC7Y5F3kuuebI64WBKUTc0,3585
|
122
|
+
brainstate-0.1.0.post20250129.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250129.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250129.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250129.dist-info}/top_level.txt
RENAMED
File without changes
|