brainstate 0.1.0.post20250126__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_make_jaxpr.py +1 -1
- brainstate/compile/_progress_bar.py +49 -6
- brainstate/nn/_dyn_impl/_inputs.py +4 -2
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +8 -8
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
@@ -730,7 +730,7 @@ def _make_jaxpr(
|
|
730
730
|
def _abstractify(args, kwargs):
|
731
731
|
flat_args, in_tree = jax.tree.flatten((args, kwargs))
|
732
732
|
if abstracted_axes is None:
|
733
|
-
return map(
|
733
|
+
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
|
734
734
|
else:
|
735
735
|
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
736
736
|
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
|
@@ -35,6 +35,47 @@ Output = Any
|
|
35
35
|
class ProgressBar(object):
|
36
36
|
"""
|
37
37
|
A progress bar for tracking the progress of a jitted for-loop computation.
|
38
|
+
|
39
|
+
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
40
|
+
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
41
|
+
a for-loop.
|
42
|
+
|
43
|
+
The message displayed in the progress bar can be customized by the following two methods:
|
44
|
+
|
45
|
+
1. By passing a string to the `desc` argument. For example:
|
46
|
+
|
47
|
+
.. code-block:: python
|
48
|
+
|
49
|
+
ProgressBar(desc="Running 1000 iterations")
|
50
|
+
|
51
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
52
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
53
|
+
will be used to format the string. For example:
|
54
|
+
|
55
|
+
.. code-block:: python
|
56
|
+
|
57
|
+
a = bst.State(1.)
|
58
|
+
def loop_fn(x):
|
59
|
+
a.value = x.value + 1.
|
60
|
+
return jnp.sum(x ** 2)
|
61
|
+
|
62
|
+
pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
|
63
|
+
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
64
|
+
|
65
|
+
bst.compile.for_loop(loop_fn, xs, pbar=pbar)
|
66
|
+
|
67
|
+
In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
|
68
|
+
the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
|
69
|
+
|
70
|
+
|
71
|
+
Args:
|
72
|
+
freq: The frequency at which to print the progress bar. If not specified, the progress
|
73
|
+
bar will be printed every 5% of the total iterations.
|
74
|
+
count: The number of times to print the progress bar. If not specified, the progress
|
75
|
+
bar will be printed every 5% of the total iterations.
|
76
|
+
desc: A description of the progress bar. If not specified, a default message will be
|
77
|
+
displayed.
|
78
|
+
kwargs: Additional keyword arguments to pass to the progress bar.
|
38
79
|
"""
|
39
80
|
__module__ = "brainstate.compile"
|
40
81
|
|
@@ -42,7 +83,7 @@ class ProgressBar(object):
|
|
42
83
|
self,
|
43
84
|
freq: Optional[int] = None,
|
44
85
|
count: Optional[int] = None,
|
45
|
-
desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
|
86
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
46
87
|
**kwargs
|
47
88
|
):
|
48
89
|
# print rate
|
@@ -62,9 +103,12 @@ class ProgressBar(object):
|
|
62
103
|
|
63
104
|
# description
|
64
105
|
if desc is not None:
|
65
|
-
|
66
|
-
|
67
|
-
|
106
|
+
if isinstance(desc, str):
|
107
|
+
pass
|
108
|
+
else:
|
109
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
110
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
111
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
68
112
|
self.desc = desc
|
69
113
|
|
70
114
|
# check if tqdm is installed
|
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
|
|
136
180
|
self.tqdm_bars[0].close()
|
137
181
|
|
138
182
|
def __call__(self, iter_num, **kwargs):
|
139
|
-
data = dict(i=iter_num, **kwargs)
|
140
|
-
data = dict() if isinstance(self.message, str) else self.message[1](data)
|
183
|
+
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
141
184
|
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
142
185
|
|
143
186
|
_ = jax.lax.cond(
|
@@ -224,7 +224,8 @@ def poisson_input(
|
|
224
224
|
weight = maybe_state(weight)
|
225
225
|
|
226
226
|
assert isinstance(target, State), 'The target must be a State.'
|
227
|
-
p =
|
227
|
+
p = freq * environ.get_dt()
|
228
|
+
p = p.to_decimal() if isinstance(p, u.Quantity) else p
|
228
229
|
a = num_input * p
|
229
230
|
b = num_input * (1 - p)
|
230
231
|
tar_val = target.value
|
@@ -291,7 +292,8 @@ def poisson_input(
|
|
291
292
|
|
292
293
|
# update target variable
|
293
294
|
target.value = jax.tree.map(
|
294
|
-
lambda x: x * weight,
|
295
|
+
lambda tar, x: tar + x * weight,
|
296
|
+
target.value,
|
295
297
|
inp,
|
296
298
|
is_leaf=u.math.is_quantity
|
297
299
|
)
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250126
|
3
|
+
Version: 0.1.0.post20250127
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD
RENAMED
@@ -30,9 +30,9 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
36
|
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
38
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
@@ -65,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
65
65
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
66
66
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
67
67
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
68
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
68
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=QOUpAb2YJOE78uAvIS8Ep6MFcQHV-V6uRwmYvk5p9bk,11385
|
69
69
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
70
70
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
71
71
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -117,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
117
117
|
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
118
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
119
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
120
|
-
brainstate-0.1.0.post20250126.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
-
brainstate-0.1.0.post20250126.dist-info/METADATA,sha256=
|
122
|
-
brainstate-0.1.0.post20250126.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
-
brainstate-0.1.0.post20250126.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
-
brainstate-0.1.0.post20250126.dist-info/RECORD,,
|
120
|
+
brainstate-0.1.0.post20250127.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250127.dist-info/METADATA,sha256=-j8gDuJN37nbSkijFC9TtG5UbflMZZWK-sftY7fkeps,3585
|
122
|
+
brainstate-0.1.0.post20250127.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250127.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250127.dist-info/RECORD,,
|
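{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE
RENAMED
File without changes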
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL
RENAMED
File without changes
|
{brainstate-0.1.0.post20250126.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt
RENAMED
File without changes
|