brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,202 +1,202 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import copy
|
18
|
-
import importlib.util
|
19
|
-
from typing import Optional, Callable, Any, Tuple, Dict
|
20
|
-
|
21
|
-
import jax
|
22
|
-
|
23
|
-
# ``tqdm`` is an optional dependency. It is only probed here (not imported) so that
# importing this module never fails; ``ProgressBar`` raises ImportError when it is absent.
tqdm_installed = bool(importlib.util.find_spec('tqdm'))

__all__ = ["ProgressBar"]

# Descriptive type aliases; at runtime they are plain ``int`` / ``Any``.
Index, Carray, Output = int, Any, Any
|
32
|
-
|
33
|
-
|
34
|
-
class ProgressBar(object):
|
35
|
-
"""
|
36
|
-
A progress bar for tracking the progress of a jitted for-loop computation.
|
37
|
-
|
38
|
-
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
39
|
-
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
40
|
-
a for-loop.
|
41
|
-
|
42
|
-
The message displayed in the progress bar can be customized by the following two methods:
|
43
|
-
|
44
|
-
1. By passing a string to the `desc` argument. For example:
|
45
|
-
|
46
|
-
.. code-block:: python
|
47
|
-
|
48
|
-
ProgressBar(desc="Running 1000 iterations")
|
49
|
-
|
50
|
-
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
51
|
-
function should take a dictionary as input and return a dictionary. The returned dictionary
|
52
|
-
will be used to format the string. For example:
|
53
|
-
|
54
|
-
.. code-block:: python
|
55
|
-
|
56
|
-
a = brainstate.State(1.)
|
57
|
-
def loop_fn(x):
|
58
|
-
a.value = x.value + 1.
|
59
|
-
return jnp.sum(x ** 2)
|
60
|
-
|
61
|
-
pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
|
62
|
-
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
63
|
-
|
64
|
-
brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)
|
65
|
-
|
66
|
-
In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
|
67
|
-
the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
|
68
|
-
|
69
|
-
|
70
|
-
Args:
|
71
|
-
freq: The frequency at which to print the progress bar. If not specified, the progress
|
72
|
-
bar will be printed every 5% of the total iterations.
|
73
|
-
count: The number of times to print the progress bar. If not specified, the progress
|
74
|
-
bar will be printed every 5% of the total iterations.
|
75
|
-
desc: A description of the progress bar. If not specified, a default message will be
|
76
|
-
displayed.
|
77
|
-
kwargs: Additional keyword arguments to pass to the progress bar.
|
78
|
-
"""
|
79
|
-
__module__ = "brainstate.compile"
|
80
|
-
|
81
|
-
def __init__(
|
82
|
-
self,
|
83
|
-
freq: Optional[int] = None,
|
84
|
-
count: Optional[int] = None,
|
85
|
-
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
86
|
-
**kwargs
|
87
|
-
):
|
88
|
-
# print rate
|
89
|
-
self.print_freq = freq
|
90
|
-
if isinstance(freq, int):
|
91
|
-
assert freq > 0, "Print rate should be > 0."
|
92
|
-
|
93
|
-
# print count
|
94
|
-
self.print_count = count
|
95
|
-
if self.print_freq is not None and self.print_count is not None:
|
96
|
-
raise ValueError("Cannot specify both count and freq.")
|
97
|
-
|
98
|
-
# other parameters
|
99
|
-
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
|
100
|
-
kwargs.pop(kwarg, None)
|
101
|
-
self.kwargs = kwargs
|
102
|
-
|
103
|
-
# description
|
104
|
-
if desc is not None:
|
105
|
-
if isinstance(desc, str):
|
106
|
-
pass
|
107
|
-
else:
|
108
|
-
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
109
|
-
assert isinstance(desc[0], str), 'Description should be a string.'
|
110
|
-
assert callable(desc[1]), 'Description should be a callable.'
|
111
|
-
self.desc = desc
|
112
|
-
|
113
|
-
# check if tqdm is installed
|
114
|
-
if not tqdm_installed:
|
115
|
-
raise ImportError("tqdm is not installed.")
|
116
|
-
|
117
|
-
def init(self, n: int):
|
118
|
-
kwargs = copy.copy(self.kwargs)
|
119
|
-
freq = self.print_freq
|
120
|
-
count = self.print_count
|
121
|
-
if count is not None:
|
122
|
-
freq, remainder = divmod(n, count)
|
123
|
-
if freq == 0:
|
124
|
-
raise ValueError(f"Count {count} is too large for n {n}.")
|
125
|
-
elif freq is None:
|
126
|
-
if n > 20:
|
127
|
-
freq = int(n / 20)
|
128
|
-
else:
|
129
|
-
freq = 1
|
130
|
-
remainder = n % freq
|
131
|
-
else:
|
132
|
-
if freq < 1:
|
133
|
-
raise ValueError(f"Print rate should be > 0 got {freq}")
|
134
|
-
elif freq > n:
|
135
|
-
raise ValueError("Print rate should be less than the "
|
136
|
-
f"number of steps {n}, got {freq}")
|
137
|
-
remainder = n % freq
|
138
|
-
|
139
|
-
message = f"Running for {n:,} iterations" if self.desc is None else self.desc
|
140
|
-
return ProgressBarRunner(n, freq, remainder, message, **kwargs)
|
141
|
-
|
142
|
-
|
143
|
-
class ProgressBarRunner(object):
|
144
|
-
__module__ = "brainstate.compile"
|
145
|
-
|
146
|
-
def __init__(
|
147
|
-
self,
|
148
|
-
n: int,
|
149
|
-
print_freq: int,
|
150
|
-
remainder: int,
|
151
|
-
message: str | Tuple[str, Callable[[Dict], Dict]],
|
152
|
-
**kwargs
|
153
|
-
):
|
154
|
-
self.tqdm_bars = {}
|
155
|
-
self.kwargs = kwargs
|
156
|
-
self.n = n
|
157
|
-
self.print_freq = print_freq
|
158
|
-
self.remainder = remainder
|
159
|
-
self.message = message
|
160
|
-
|
161
|
-
def _define_tqdm(self, x: dict):
|
162
|
-
from tqdm.auto import tqdm
|
163
|
-
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
164
|
-
if isinstance(self.message, str):
|
165
|
-
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
166
|
-
else:
|
167
|
-
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
168
|
-
|
169
|
-
def _update_tqdm(self, x: dict):
|
170
|
-
self.tqdm_bars[0].update(self.print_freq)
|
171
|
-
if not isinstance(self.message, str):
|
172
|
-
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
173
|
-
|
174
|
-
def _close_tqdm(self, x: dict):
|
175
|
-
if self.remainder > 0:
|
176
|
-
self.tqdm_bars[0].update(self.remainder)
|
177
|
-
if not isinstance(self.message, str):
|
178
|
-
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
179
|
-
self.tqdm_bars[0].close()
|
180
|
-
|
181
|
-
def __call__(self, iter_num, **kwargs):
|
182
|
-
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
183
|
-
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
184
|
-
|
185
|
-
_ = jax.lax.cond(
|
186
|
-
iter_num == 0,
|
187
|
-
lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
|
188
|
-
lambda x: None,
|
189
|
-
data
|
190
|
-
)
|
191
|
-
_ = jax.lax.cond(
|
192
|
-
iter_num % self.print_freq == (self.print_freq - 1),
|
193
|
-
lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
|
194
|
-
lambda x: None,
|
195
|
-
data
|
196
|
-
)
|
197
|
-
_ = jax.lax.cond(
|
198
|
-
iter_num == self.n - 1,
|
199
|
-
lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
|
200
|
-
lambda x: None,
|
201
|
-
data
|
202
|
-
)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import copy
|
18
|
+
import importlib.util
|
19
|
+
from typing import Optional, Callable, Any, Tuple, Dict
|
20
|
+
|
21
|
+
import jax
|
22
|
+
|
23
|
+
# ``tqdm`` is an optional dependency. It is only probed here (not imported) so that
# importing this module never fails; ``ProgressBar`` raises ImportError when it is absent.
tqdm_installed = bool(importlib.util.find_spec('tqdm'))

__all__ = ["ProgressBar"]

# Descriptive type aliases; at runtime they are plain ``int`` / ``Any``.
Index, Carray, Output = int, Any, Any
|
32
|
+
|
33
|
+
|
34
|
+
class ProgressBar(object):
|
35
|
+
"""
|
36
|
+
A progress bar for tracking the progress of a jitted for-loop computation.
|
37
|
+
|
38
|
+
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
39
|
+
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
40
|
+
a for-loop.
|
41
|
+
|
42
|
+
The message displayed in the progress bar can be customized by the following two methods:
|
43
|
+
|
44
|
+
1. By passing a string to the `desc` argument. For example:
|
45
|
+
|
46
|
+
.. code-block:: python
|
47
|
+
|
48
|
+
ProgressBar(desc="Running 1000 iterations")
|
49
|
+
|
50
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
51
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
52
|
+
will be used to format the string. For example:
|
53
|
+
|
54
|
+
.. code-block:: python
|
55
|
+
|
56
|
+
a = brainstate.State(1.)
|
57
|
+
def loop_fn(x):
|
58
|
+
a.value = x.value + 1.
|
59
|
+
return jnp.sum(x ** 2)
|
60
|
+
|
61
|
+
pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
|
62
|
+
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
63
|
+
|
64
|
+
brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)
|
65
|
+
|
66
|
+
In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
|
67
|
+
the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
|
68
|
+
|
69
|
+
|
70
|
+
Args:
|
71
|
+
freq: The frequency at which to print the progress bar. If not specified, the progress
|
72
|
+
bar will be printed every 5% of the total iterations.
|
73
|
+
count: The number of times to print the progress bar. If not specified, the progress
|
74
|
+
bar will be printed every 5% of the total iterations.
|
75
|
+
desc: A description of the progress bar. If not specified, a default message will be
|
76
|
+
displayed.
|
77
|
+
kwargs: Additional keyword arguments to pass to the progress bar.
|
78
|
+
"""
|
79
|
+
__module__ = "brainstate.compile"
|
80
|
+
|
81
|
+
def __init__(
|
82
|
+
self,
|
83
|
+
freq: Optional[int] = None,
|
84
|
+
count: Optional[int] = None,
|
85
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
86
|
+
**kwargs
|
87
|
+
):
|
88
|
+
# print rate
|
89
|
+
self.print_freq = freq
|
90
|
+
if isinstance(freq, int):
|
91
|
+
assert freq > 0, "Print rate should be > 0."
|
92
|
+
|
93
|
+
# print count
|
94
|
+
self.print_count = count
|
95
|
+
if self.print_freq is not None and self.print_count is not None:
|
96
|
+
raise ValueError("Cannot specify both count and freq.")
|
97
|
+
|
98
|
+
# other parameters
|
99
|
+
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
|
100
|
+
kwargs.pop(kwarg, None)
|
101
|
+
self.kwargs = kwargs
|
102
|
+
|
103
|
+
# description
|
104
|
+
if desc is not None:
|
105
|
+
if isinstance(desc, str):
|
106
|
+
pass
|
107
|
+
else:
|
108
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
109
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
110
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
111
|
+
self.desc = desc
|
112
|
+
|
113
|
+
# check if tqdm is installed
|
114
|
+
if not tqdm_installed:
|
115
|
+
raise ImportError("tqdm is not installed.")
|
116
|
+
|
117
|
+
def init(self, n: int):
|
118
|
+
kwargs = copy.copy(self.kwargs)
|
119
|
+
freq = self.print_freq
|
120
|
+
count = self.print_count
|
121
|
+
if count is not None:
|
122
|
+
freq, remainder = divmod(n, count)
|
123
|
+
if freq == 0:
|
124
|
+
raise ValueError(f"Count {count} is too large for n {n}.")
|
125
|
+
elif freq is None:
|
126
|
+
if n > 20:
|
127
|
+
freq = int(n / 20)
|
128
|
+
else:
|
129
|
+
freq = 1
|
130
|
+
remainder = n % freq
|
131
|
+
else:
|
132
|
+
if freq < 1:
|
133
|
+
raise ValueError(f"Print rate should be > 0 got {freq}")
|
134
|
+
elif freq > n:
|
135
|
+
raise ValueError("Print rate should be less than the "
|
136
|
+
f"number of steps {n}, got {freq}")
|
137
|
+
remainder = n % freq
|
138
|
+
|
139
|
+
message = f"Running for {n:,} iterations" if self.desc is None else self.desc
|
140
|
+
return ProgressBarRunner(n, freq, remainder, message, **kwargs)
|
141
|
+
|
142
|
+
|
143
|
+
class ProgressBarRunner(object):
|
144
|
+
__module__ = "brainstate.compile"
|
145
|
+
|
146
|
+
def __init__(
|
147
|
+
self,
|
148
|
+
n: int,
|
149
|
+
print_freq: int,
|
150
|
+
remainder: int,
|
151
|
+
message: str | Tuple[str, Callable[[Dict], Dict]],
|
152
|
+
**kwargs
|
153
|
+
):
|
154
|
+
self.tqdm_bars = {}
|
155
|
+
self.kwargs = kwargs
|
156
|
+
self.n = n
|
157
|
+
self.print_freq = print_freq
|
158
|
+
self.remainder = remainder
|
159
|
+
self.message = message
|
160
|
+
|
161
|
+
def _define_tqdm(self, x: dict):
|
162
|
+
from tqdm.auto import tqdm
|
163
|
+
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
164
|
+
if isinstance(self.message, str):
|
165
|
+
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
166
|
+
else:
|
167
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
168
|
+
|
169
|
+
def _update_tqdm(self, x: dict):
|
170
|
+
self.tqdm_bars[0].update(self.print_freq)
|
171
|
+
if not isinstance(self.message, str):
|
172
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
173
|
+
|
174
|
+
def _close_tqdm(self, x: dict):
|
175
|
+
if self.remainder > 0:
|
176
|
+
self.tqdm_bars[0].update(self.remainder)
|
177
|
+
if not isinstance(self.message, str):
|
178
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
179
|
+
self.tqdm_bars[0].close()
|
180
|
+
|
181
|
+
def __call__(self, iter_num, **kwargs):
|
182
|
+
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
183
|
+
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
184
|
+
|
185
|
+
_ = jax.lax.cond(
|
186
|
+
iter_num == 0,
|
187
|
+
lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
|
188
|
+
lambda x: None,
|
189
|
+
data
|
190
|
+
)
|
191
|
+
_ = jax.lax.cond(
|
192
|
+
iter_num % self.print_freq == (self.print_freq - 1),
|
193
|
+
lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
|
194
|
+
lambda x: None,
|
195
|
+
data
|
196
|
+
)
|
197
|
+
_ = jax.lax.cond(
|
198
|
+
iter_num == self.n - 1,
|
199
|
+
lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
|
200
|
+
lambda x: None,
|
201
|
+
data
|
202
|
+
)
|