brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax
|
20
|
+
import jax.numpy as jnp
|
21
|
+
from jax import vmap
|
22
|
+
from jax.lax import psum, pmean, pmax
|
23
|
+
|
24
|
+
import brainstate
|
25
|
+
import brainstate.transform
|
26
|
+
from brainstate._error import BatchAxisError
|
27
|
+
|
28
|
+
|
29
|
+
|
30
|
+
class TestMap(unittest.TestCase):
    """Tests for ``brainstate.transform.map`` with and without batching."""

    def test_map(self):
        """``map`` must match an elementwise ``x + 1`` for every batch size.

        Batch sizes that do not evenly divide the leading dimension (4 and 5 for
        a length-10 axis... 4 in particular) exercise the remainder handling.
        """
        shapes = [(10,), (10, 10), (10, 10, 10)]
        batch_sizes = (None, 2, 4, 5)
        for shape in shapes:
            x = brainstate.random.rand(*shape)
            expected = x + 1
            for batch_size in batch_sizes:
                result = brainstate.transform.map(lambda a: a + 1, x, batch_size=batch_size)
                self.assertTrue(jnp.allclose(result, expected))
|
44
|
+
|
45
|
+
|
46
|
+
class TestAxisName(unittest.TestCase):
    """Check that ``brainstate.transform.vmap`` honours ``axis_name`` like ``jax.vmap``."""

    def test1(self):
        """Collectives over a single named axis agree between jax and brainstate."""

        def compute_stats_with_axis_name(x):
            """Compute statistics using named axis operations"""
            # Sum across the named axis 'batch'
            total_sum = psum(x, axis_name='batch')
            # Mean across the named axis 'batch'
            mean_val = pmean(x, axis_name='batch')
            # Max across the named axis 'batch'
            max_val = pmax(x, axis_name='batch')
            return {
                'sum': total_sum,
                'mean': mean_val,
                'max': max_val,
                'original': x
            }

        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
        print("Input batch data:", batch_data)

        # Reference result: plain jax.vmap (under jax.jit) with axis name 'batch'.
        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
        result_jax = vectorized_stats_jax(batch_data)

        # brainstate vmap with axis name 'batch'.
        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
        result = vectorized_stats(batch_data)

        # brainstate vmap wrapped in brainstate jit.
        vectorized_stats_v2 = brainstate.transform.jit(
            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
        )
        result_v2 = vectorized_stats_v2(batch_data)

        for key in result_jax.keys():
            print(f" {key}: {result_jax[key]}")
            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"

    def test_nested_vmap(self):
        """Two nested named axes ('outer' wrapping 'inner') agree with jax."""

        def nested_computation(x):
            """Computation with multiple named axes"""
            # Sum over 'inner' axis, then mean over 'outer' axis
            inner_sum = psum(x, axis_name='inner')
            outer_mean = pmean(inner_sum, axis_name='outer')
            return outer_mean

        # Create 2D batch data
        data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
        print("Input 2D data shape:", data_2d.shape)
        print("Input 2D data:\n", data_2d)

        # Nested vmap: the outer vmap maps axis 0 ('outer'), the inner one maps
        # the remaining axis ('inner').
        inner_vmap = vmap(nested_computation, axis_name='inner')
        nested_vmap = vmap(inner_vmap, axis_name='outer')
        result_2d = nested_vmap(data_2d)
        print("Result after nested vmap:", result_2d)

        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
        result_2d_bst = nested_vmap_bst(data_2d)
        print("Result after nested vmap:", result_2d_bst)

        assert jnp.allclose(result_2d, result_2d_bst)

    def _gradient_averaging_simulation(self, vmap_fn):
        """Per-sample gradients and their batch mean, both mapped with ``vmap_fn``.

        Parameters
        ----------
        vmap_fn : callable
            ``jax.vmap`` or ``brainstate.transform.vmap``. It is used both for
            the per-sample gradient map and for the named-axis averaging step.

        Returns
        -------
        tuple
            ``(individual_grads, averaged_grads)``, each with one entry per sample.
        """

        def loss_function(params, x, y):
            """Simple quadratic loss"""
            pred = params * x
            return (pred - y) ** 2

        grad_fn = jax.grad(loss_function, argnums=0)

        def average_grads(grads):
            # Average gradients across the batch using the named axis.
            return pmean(grads, axis_name='batch')

        # Example data
        params = 2.0
        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])

        # ``params`` is shared (in_axes=None); x and y are mapped per sample.
        individual_grads = vmap_fn(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
        averaged_grads = vmap_fn(average_grads, axis_name='batch')(individual_grads)
        return individual_grads, averaged_grads

    def _gradient_averaging_simulation_bst(self):
        """Gradient-averaging simulation using ``brainstate.transform.vmap``."""
        return self._gradient_averaging_simulation(brainstate.transform.vmap)

    def _gradient_averaging_simulation_jax(self):
        """Gradient-averaging simulation using plain ``jax.vmap`` (reference)."""
        return self._gradient_averaging_simulation(vmap)

    def test_gradient_averaging(self):
        """brainstate and jax produce the same (analytically known) gradients."""
        ind_bst, avg_bst = self._gradient_averaging_simulation_bst()
        ind_jax, avg_jax = self._gradient_averaging_simulation_jax()

        # d/dp (p*x - y)^2 = 2 * (p*x - y) * x. With p = 2 only the third sample
        # (x=3, y=7) has a non-zero gradient: 2 * (6 - 7) * 3 = -6.
        expected_individual = jnp.array([0.0, 0.0, -6.0, 0.0])
        assert jnp.allclose(ind_jax, expected_individual)
        assert jnp.allclose(ind_bst, expected_individual)

        # pmean over 'batch' gives the mean (-6 / 4 = -1.5) at every position.
        assert jnp.allclose(avg_jax, jnp.full((4,), -1.5))
        assert jnp.allclose(avg_bst, avg_jax)
|
194
|
+
|
@@ -1,202 +1,255 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import copy
|
18
|
-
import importlib.util
|
19
|
-
from typing import Optional, Callable, Any, Tuple, Dict
|
20
|
-
|
21
|
-
import jax
|
22
|
-
|
23
|
-
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'ProgressBar',
|
27
|
-
]
|
28
|
-
|
29
|
-
Index = int
|
30
|
-
Carray = Any
|
31
|
-
Output = Any
|
32
|
-
|
33
|
-
|
34
|
-
class ProgressBar(object):
|
35
|
-
"""
|
36
|
-
A progress bar for tracking the progress of a jitted for-loop computation.
|
37
|
-
|
38
|
-
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
39
|
-
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
40
|
-
a for-loop.
|
41
|
-
|
42
|
-
The message displayed in the progress bar can be customized by the following two methods:
|
43
|
-
|
44
|
-
1. By passing a string to the `desc` argument.
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
#
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
self
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
self.
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import copy
|
18
|
+
import importlib.util
|
19
|
+
from typing import Optional, Callable, Any, Tuple, Dict
|
20
|
+
|
21
|
+
import jax
|
22
|
+
|
23
|
+
# Probe for tqdm without importing it: ``find_spec`` returns ``None`` when the
# module cannot be located.
_tqdm_spec = importlib.util.find_spec('tqdm')
tqdm_installed = _tqdm_spec is not None
del _tqdm_spec

__all__ = [
    'ProgressBar',
]

# Type aliases (not referenced by the code visible in this module).
Index = int
Carray = Any
Output = Any
|
32
|
+
|
33
|
+
|
34
|
+
class ProgressBar(object):
|
35
|
+
"""
|
36
|
+
A progress bar for tracking the progress of a jitted for-loop computation.
|
37
|
+
|
38
|
+
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
39
|
+
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
40
|
+
a for-loop.
|
41
|
+
|
42
|
+
The message displayed in the progress bar can be customized by the following two methods:
|
43
|
+
|
44
|
+
1. By passing a string to the `desc` argument.
|
45
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
46
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
47
|
+
will be used to format the string.
|
48
|
+
|
49
|
+
In the second case, ``"i"`` denotes the iteration number and other keys can be computed from the
|
50
|
+
loop outputs and carry values.
|
51
|
+
|
52
|
+
Parameters
|
53
|
+
----------
|
54
|
+
freq : int, optional
|
55
|
+
The frequency at which to print the progress bar. If not specified, the progress
|
56
|
+
bar will be printed every 5% of the total iterations.
|
57
|
+
count : int, optional
|
58
|
+
The number of times to print the progress bar. If not specified, the progress
|
59
|
+
bar will be printed every 5% of the total iterations. Cannot be used together with `freq`.
|
60
|
+
desc : str or tuple, optional
|
61
|
+
A description of the progress bar. If not specified, a default message will be
|
62
|
+
displayed. Can be either a string or a tuple of (format_string, format_function).
|
63
|
+
**kwargs
|
64
|
+
Additional keyword arguments to pass to the progress bar.
|
65
|
+
|
66
|
+
Examples
|
67
|
+
--------
|
68
|
+
Basic usage with default description:
|
69
|
+
|
70
|
+
.. code-block:: python
|
71
|
+
|
72
|
+
>>> import brainstate
|
73
|
+
>>> import jax.numpy as jnp
|
74
|
+
>>>
|
75
|
+
>>> def loop_fn(x):
|
76
|
+
... return x ** 2
|
77
|
+
>>>
|
78
|
+
>>> xs = jnp.arange(100)
|
79
|
+
>>> pbar = brainstate.transform.ProgressBar()
|
80
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
81
|
+
|
82
|
+
With custom description string:
|
83
|
+
|
84
|
+
.. code-block:: python
|
85
|
+
|
86
|
+
>>> pbar = brainstate.transform.ProgressBar(desc="Running 1000 iterations")
|
87
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
88
|
+
|
89
|
+
With frequency control:
|
90
|
+
|
91
|
+
.. code-block:: python
|
92
|
+
|
93
|
+
>>> # Update every 10 iterations
|
94
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
95
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
96
|
+
>>>
|
97
|
+
>>> # Update exactly 20 times during execution
|
98
|
+
>>> pbar = brainstate.transform.ProgressBar(count=20)
|
99
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
100
|
+
|
101
|
+
With dynamic description based on loop variables:
|
102
|
+
|
103
|
+
.. code-block:: python
|
104
|
+
|
105
|
+
>>> state = brainstate.State(1.0)
|
106
|
+
>>>
|
107
|
+
>>> def loop_fn(x):
|
108
|
+
... state.value += x
|
109
|
+
... loss = jnp.sum(x ** 2)
|
110
|
+
... return loss
|
111
|
+
>>>
|
112
|
+
>>> def format_desc(data):
|
113
|
+
... return {"i": data["i"], "loss": data["y"], "state": data["carry"]}
|
114
|
+
>>>
|
115
|
+
>>> pbar = brainstate.transform.ProgressBar(
|
116
|
+
... desc=("Iteration {i}, loss = {loss:.4f}, state = {state:.2f}", format_desc)
|
117
|
+
... )
|
118
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
119
|
+
|
120
|
+
With scan function:
|
121
|
+
|
122
|
+
.. code-block:: python
|
123
|
+
|
124
|
+
>>> def scan_fn(carry, x):
|
125
|
+
... new_carry = carry + x
|
126
|
+
... return new_carry, new_carry ** 2
|
127
|
+
>>>
|
128
|
+
>>> init_carry = 0.0
|
129
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=5)
|
130
|
+
>>> final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs, pbar=pbar)
|
131
|
+
"""
|
132
|
+
__module__ = "brainstate.transform"
|
133
|
+
|
134
|
+
def __init__(
|
135
|
+
self,
|
136
|
+
freq: Optional[int] = None,
|
137
|
+
count: Optional[int] = None,
|
138
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
139
|
+
**kwargs
|
140
|
+
):
|
141
|
+
# print rate
|
142
|
+
self.print_freq = freq
|
143
|
+
if isinstance(freq, int):
|
144
|
+
assert freq > 0, "Print rate should be > 0."
|
145
|
+
|
146
|
+
# print count
|
147
|
+
self.print_count = count
|
148
|
+
if self.print_freq is not None and self.print_count is not None:
|
149
|
+
raise ValueError("Cannot specify both count and freq.")
|
150
|
+
|
151
|
+
# other parameters
|
152
|
+
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
|
153
|
+
kwargs.pop(kwarg, None)
|
154
|
+
self.kwargs = kwargs
|
155
|
+
|
156
|
+
# description
|
157
|
+
if desc is not None:
|
158
|
+
if isinstance(desc, str):
|
159
|
+
pass
|
160
|
+
else:
|
161
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
162
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
163
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
164
|
+
self.desc = desc
|
165
|
+
|
166
|
+
# check if tqdm is installed
|
167
|
+
if not tqdm_installed:
|
168
|
+
raise ImportError("tqdm is not installed.")
|
169
|
+
|
170
|
+
def init(self, n: int):
|
171
|
+
kwargs = copy.copy(self.kwargs)
|
172
|
+
freq = self.print_freq
|
173
|
+
count = self.print_count
|
174
|
+
if count is not None:
|
175
|
+
freq, remainder = divmod(n, count)
|
176
|
+
if freq == 0:
|
177
|
+
raise ValueError(f"Count {count} is too large for n {n}.")
|
178
|
+
elif freq is None:
|
179
|
+
if n > 20:
|
180
|
+
freq = int(n / 20)
|
181
|
+
else:
|
182
|
+
freq = 1
|
183
|
+
remainder = n % freq
|
184
|
+
else:
|
185
|
+
if freq < 1:
|
186
|
+
raise ValueError(f"Print rate should be > 0 got {freq}")
|
187
|
+
elif freq > n:
|
188
|
+
raise ValueError("Print rate should be less than the "
|
189
|
+
f"number of steps {n}, got {freq}")
|
190
|
+
remainder = n % freq
|
191
|
+
|
192
|
+
message = f"Running for {n:,} iterations" if self.desc is None else self.desc
|
193
|
+
return ProgressBarRunner(n, freq, remainder, message, **kwargs)
|
194
|
+
|
195
|
+
|
196
|
+
class ProgressBarRunner(object):
|
197
|
+
__module__ = "brainstate.transform"
|
198
|
+
|
199
|
+
def __init__(
|
200
|
+
self,
|
201
|
+
n: int,
|
202
|
+
print_freq: int,
|
203
|
+
remainder: int,
|
204
|
+
message: str | Tuple[str, Callable[[Dict], Dict]],
|
205
|
+
**kwargs
|
206
|
+
):
|
207
|
+
self.tqdm_bars = {}
|
208
|
+
self.kwargs = kwargs
|
209
|
+
self.n = n
|
210
|
+
self.print_freq = print_freq
|
211
|
+
self.remainder = remainder
|
212
|
+
self.message = message
|
213
|
+
|
214
|
+
def _define_tqdm(self, x: dict):
|
215
|
+
from tqdm.auto import tqdm
|
216
|
+
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
217
|
+
if isinstance(self.message, str):
|
218
|
+
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
219
|
+
else:
|
220
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
221
|
+
|
222
|
+
def _update_tqdm(self, x: dict):
|
223
|
+
self.tqdm_bars[0].update(self.print_freq)
|
224
|
+
if not isinstance(self.message, str):
|
225
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
226
|
+
|
227
|
+
def _close_tqdm(self, x: dict):
|
228
|
+
if self.remainder > 0:
|
229
|
+
self.tqdm_bars[0].update(self.remainder)
|
230
|
+
if not isinstance(self.message, str):
|
231
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
232
|
+
self.tqdm_bars[0].close()
|
233
|
+
|
234
|
+
def __call__(self, iter_num, **kwargs):
|
235
|
+
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
236
|
+
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
237
|
+
|
238
|
+
_ = jax.lax.cond(
|
239
|
+
iter_num == 0,
|
240
|
+
lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
|
241
|
+
lambda x: None,
|
242
|
+
data
|
243
|
+
)
|
244
|
+
_ = jax.lax.cond(
|
245
|
+
iter_num % self.print_freq == (self.print_freq - 1),
|
246
|
+
lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
|
247
|
+
lambda x: None,
|
248
|
+
data
|
249
|
+
)
|
250
|
+
_ = jax.lax.cond(
|
251
|
+
iter_num == self.n - 1,
|
252
|
+
lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
|
253
|
+
lambda x: None,
|
254
|
+
data
|
255
|
+
)
|