brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax
|
20
|
+
import jax.numpy as jnp
|
21
|
+
from jax import vmap
|
22
|
+
from jax.lax import psum, pmean, pmax
|
23
|
+
|
24
|
+
import brainstate
|
25
|
+
import brainstate.transform
|
26
|
+
from brainstate._error import BatchAxisError
|
27
|
+
|
28
|
+
|
29
|
+
|
30
|
+
class TestMap(unittest.TestCase):
|
31
|
+
def test_map(self):
|
32
|
+
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
33
|
+
x = brainstate.random.rand(*dim)
|
34
|
+
r1 = brainstate.transform.map(lambda a: a + 1, x, batch_size=None)
|
35
|
+
r2 = brainstate.transform.map(lambda a: a + 1, x, batch_size=2)
|
36
|
+
r3 = brainstate.transform.map(lambda a: a + 1, x, batch_size=4)
|
37
|
+
r4 = brainstate.transform.map(lambda a: a + 1, x, batch_size=5)
|
38
|
+
true_r = x + 1
|
39
|
+
|
40
|
+
self.assertTrue(jnp.allclose(r1, true_r))
|
41
|
+
self.assertTrue(jnp.allclose(r2, true_r))
|
42
|
+
self.assertTrue(jnp.allclose(r3, true_r))
|
43
|
+
self.assertTrue(jnp.allclose(r4, true_r))
|
44
|
+
|
45
|
+
|
46
|
+
class TestAxisName:
|
47
|
+
def test1(self):
|
48
|
+
def compute_stats_with_axis_name(x):
|
49
|
+
"""Compute statistics using named axis operations"""
|
50
|
+
# Sum across the named axis 'batch'
|
51
|
+
total_sum = psum(x, axis_name='batch')
|
52
|
+
|
53
|
+
# Mean across the named axis 'batch'
|
54
|
+
mean_val = pmean(x, axis_name='batch')
|
55
|
+
|
56
|
+
# Max across the named axis 'batch'
|
57
|
+
max_val = pmax(x, axis_name='batch')
|
58
|
+
|
59
|
+
return {
|
60
|
+
'sum': total_sum,
|
61
|
+
'mean': mean_val,
|
62
|
+
'max': max_val,
|
63
|
+
'original': x
|
64
|
+
}
|
65
|
+
|
66
|
+
batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
67
|
+
print("Input batch data:", batch_data)
|
68
|
+
|
69
|
+
# Reference result: jax.vmap with axis name 'batch', wrapped in jax.jit
|
70
|
+
vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
|
71
|
+
result_jax = vectorized_stats_jax(batch_data)
|
72
|
+
|
73
|
+
# brainstate.transform.vmap with axis name 'batch' (no jit)
|
74
|
+
vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
75
|
+
result = vectorized_stats(batch_data)
|
76
|
+
|
77
|
+
# brainstate.transform.vmap with axis name 'batch', wrapped in brainstate.transform.jit
|
78
|
+
vectorized_stats_v2 = brainstate.transform.jit(
|
79
|
+
brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
80
|
+
)
|
81
|
+
result_v2 = vectorized_stats_v2(batch_data)
|
82
|
+
|
83
|
+
for key in result_jax.keys():
|
84
|
+
print(f" {key}: {result_jax[key]}")
|
85
|
+
assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
|
86
|
+
assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
|
87
|
+
|
88
|
+
def test_nested_vmap(self):
|
89
|
+
def nested_computation(x):
|
90
|
+
"""Computation with multiple named axes"""
|
91
|
+
# Sum over 'inner' axis, then mean over 'outer' axis
|
92
|
+
inner_sum = psum(x, axis_name='inner')
|
93
|
+
outer_mean = pmean(inner_sum, axis_name='outer')
|
94
|
+
return outer_mean
|
95
|
+
|
96
|
+
# Create 2D batch data
|
97
|
+
data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
|
98
|
+
print("Input 2D data shape:", data_2d.shape)
|
99
|
+
print("Input 2D data:\n", data_2d)
|
100
|
+
|
101
|
+
# Nested vmap: first over inner dimension, then outer dimension
|
102
|
+
inner_vmap = vmap(nested_computation, axis_name='inner')
|
103
|
+
nested_vmap = vmap(inner_vmap, axis_name='outer')
|
104
|
+
|
105
|
+
result_2d = nested_vmap(data_2d)
|
106
|
+
print("Result after nested vmap:", result_2d)
|
107
|
+
|
108
|
+
inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
|
109
|
+
nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
|
110
|
+
result_2d_bst = nested_vmap_bst(data_2d)
|
111
|
+
print("Result after nested vmap:", result_2d_bst)
|
112
|
+
|
113
|
+
assert jnp.allclose(result_2d, result_2d_bst)
|
114
|
+
|
115
|
+
def _gradient_averaging_simulation_bst(self):
|
116
|
+
def loss_function(params, x, y):
|
117
|
+
"""Simple quadratic loss"""
|
118
|
+
pred = params * x
|
119
|
+
return (pred - y) ** 2
|
120
|
+
|
121
|
+
def compute_gradients_with_averaging(params, batch_x, batch_y):
|
122
|
+
"""Compute gradients and average them across the batch"""
|
123
|
+
# Compute per-sample gradients
|
124
|
+
grad_fn = jax.grad(loss_function, argnums=0)
|
125
|
+
per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
126
|
+
|
127
|
+
# Average gradients across batch using named axis
|
128
|
+
def average_grads(grads):
|
129
|
+
return pmean(grads, axis_name='batch')
|
130
|
+
|
131
|
+
# Apply averaging with named axis
|
132
|
+
averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
|
133
|
+
return averaged_grads
|
134
|
+
|
135
|
+
# Example data
|
136
|
+
params = 2.0
|
137
|
+
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
138
|
+
batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
|
139
|
+
|
140
|
+
print("Parameters:", params)
|
141
|
+
print("Batch X:", batch_x)
|
142
|
+
print("Batch Y:", batch_y)
|
143
|
+
|
144
|
+
# Compute individual gradients first
|
145
|
+
grad_fn = jax.grad(loss_function, argnums=0)
|
146
|
+
individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
147
|
+
print("Individual gradients:", individual_grads)
|
148
|
+
|
149
|
+
# Now compute averaged gradients using axis names
|
150
|
+
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
151
|
+
print("Averaged gradients:", averaged_grads)
|
152
|
+
|
153
|
+
return individual_grads, averaged_grads
|
154
|
+
|
155
|
+
def _gradient_averaging_simulation_jax(self):
|
156
|
+
def loss_function(params, x, y):
|
157
|
+
"""Simple quadratic loss"""
|
158
|
+
pred = params * x
|
159
|
+
return (pred - y) ** 2
|
160
|
+
|
161
|
+
def compute_gradients_with_averaging(params, batch_x, batch_y):
|
162
|
+
"""Compute gradients and average them across the batch"""
|
163
|
+
# Compute per-sample gradients
|
164
|
+
grad_fn = jax.grad(loss_function, argnums=0)
|
165
|
+
per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
166
|
+
|
167
|
+
# Average gradients across batch using named axis
|
168
|
+
def average_grads(grads):
|
169
|
+
return pmean(grads, axis_name='batch')
|
170
|
+
|
171
|
+
# Apply averaging with named axis
|
172
|
+
averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
|
173
|
+
return averaged_grads
|
174
|
+
|
175
|
+
# Example data
|
176
|
+
params = 2.0
|
177
|
+
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
178
|
+
batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
|
179
|
+
|
180
|
+
print("Parameters:", params)
|
181
|
+
print("Batch X:", batch_x)
|
182
|
+
print("Batch Y:", batch_y)
|
183
|
+
|
184
|
+
# Compute individual gradients first
|
185
|
+
grad_fn = jax.grad(loss_function, argnums=0)
|
186
|
+
individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
187
|
+
print("Individual gradients:", individual_grads)
|
188
|
+
|
189
|
+
# Now compute averaged gradients using axis names
|
190
|
+
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
191
|
+
print("Averaged gradients:", averaged_grads)
|
192
|
+
|
193
|
+
return individual_grads, averaged_grads
|
194
|
+
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -41,42 +41,95 @@ class ProgressBar(object):
|
|
41
41
|
|
42
42
|
The message displayed in the progress bar can be customized by the following two methods:
|
43
43
|
|
44
|
-
1. By passing a string to the `desc` argument.
|
44
|
+
1. By passing a string to the `desc` argument.
|
45
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
46
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
47
|
+
will be used to format the string.
|
48
|
+
|
49
|
+
In the second case, ``"i"`` denotes the iteration number and other keys can be computed from the
|
50
|
+
loop outputs and carry values.
|
51
|
+
|
52
|
+
Parameters
|
53
|
+
----------
|
54
|
+
freq : int, optional
|
55
|
+
The frequency at which to print the progress bar. If not specified, the progress
|
56
|
+
bar will be printed every 5% of the total iterations.
|
57
|
+
count : int, optional
|
58
|
+
The number of times to print the progress bar. If not specified, the progress
|
59
|
+
bar will be printed every 5% of the total iterations. Cannot be used together with `freq`.
|
60
|
+
desc : str or tuple, optional
|
61
|
+
A description of the progress bar. If not specified, a default message will be
|
62
|
+
displayed. Can be either a string or a tuple of (format_string, format_function).
|
63
|
+
**kwargs
|
64
|
+
Additional keyword arguments to pass to the progress bar.
|
65
|
+
|
66
|
+
Examples
|
67
|
+
--------
|
68
|
+
Basic usage with default description:
|
45
69
|
|
46
70
|
.. code-block:: python
|
47
71
|
|
48
|
-
|
72
|
+
>>> import brainstate
|
73
|
+
>>> import jax.numpy as jnp
|
74
|
+
>>>
|
75
|
+
>>> def loop_fn(x):
|
76
|
+
... return x ** 2
|
77
|
+
>>>
|
78
|
+
>>> xs = jnp.arange(100)
|
79
|
+
>>> pbar = brainstate.transform.ProgressBar()
|
80
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
49
81
|
|
50
|
-
|
51
|
-
|
52
|
-
|
82
|
+
With custom description string:
|
83
|
+
|
84
|
+
.. code-block:: python
|
85
|
+
|
86
|
+
>>> pbar = brainstate.transform.ProgressBar(desc="Running 100 iterations")
|
87
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
88
|
+
|
89
|
+
With frequency control:
|
53
90
|
|
54
91
|
.. code-block:: python
|
55
92
|
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
93
|
+
>>> # Update every 10 iterations
|
94
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
95
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
96
|
+
>>>
|
97
|
+
>>> # Update exactly 20 times during execution
|
98
|
+
>>> pbar = brainstate.transform.ProgressBar(count=20)
|
99
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
60
100
|
|
61
|
-
|
62
|
-
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
101
|
+
With dynamic description based on loop variables:
|
63
102
|
|
64
|
-
|
103
|
+
.. code-block:: python
|
65
104
|
|
66
|
-
|
67
|
-
|
105
|
+
>>> state = brainstate.State(1.0)
|
106
|
+
>>>
|
107
|
+
>>> def loop_fn(x):
|
108
|
+
... state.value += x
|
109
|
+
... loss = jnp.sum(x ** 2)
|
110
|
+
... return loss
|
111
|
+
>>>
|
112
|
+
>>> def format_desc(data):
|
113
|
+
... return {"i": data["i"], "loss": data["y"], "state": data["carry"]}
|
114
|
+
>>>
|
115
|
+
>>> pbar = brainstate.transform.ProgressBar(
|
116
|
+
... desc=("Iteration {i}, loss = {loss:.4f}, state = {state:.2f}", format_desc)
|
117
|
+
... )
|
118
|
+
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
|
119
|
+
|
120
|
+
With scan function:
|
68
121
|
|
122
|
+
.. code-block:: python
|
69
123
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
kwargs: Additional keyword arguments to pass to the progress bar.
|
124
|
+
>>> def scan_fn(carry, x):
|
125
|
+
... new_carry = carry + x
|
126
|
+
... return new_carry, new_carry ** 2
|
127
|
+
>>>
|
128
|
+
>>> init_carry = 0.0
|
129
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=5)
|
130
|
+
>>> final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs, pbar=pbar)
|
78
131
|
"""
|
79
|
-
__module__ = "brainstate.
|
132
|
+
__module__ = "brainstate.transform"
|
80
133
|
|
81
134
|
def __init__(
|
82
135
|
self,
|
@@ -141,7 +194,7 @@ class ProgressBar(object):
|
|
141
194
|
|
142
195
|
|
143
196
|
class ProgressBarRunner(object):
|
144
|
-
__module__ = "brainstate.
|
197
|
+
__module__ = "brainstate.transform"
|
145
198
|
|
146
199
|
def __init__(
|
147
200
|
self,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,6 +16,7 @@
|
|
16
16
|
import functools
|
17
17
|
from typing import Callable, Sequence, Union
|
18
18
|
|
19
|
+
from brainstate._utils import set_module_as
|
19
20
|
from brainstate.random import DEFAULT, RandomState
|
20
21
|
from brainstate.typing import Missing
|
21
22
|
from brainstate.util import PrettyObject
|
@@ -27,47 +28,66 @@ __all__ = [
|
|
27
28
|
|
28
29
|
class RngRestore(PrettyObject):
|
29
30
|
"""
|
30
|
-
|
31
|
+
Manage backing up and restoring multiple random states.
|
31
32
|
|
32
|
-
|
33
|
-
|
33
|
+
Parameters
|
34
|
+
----------
|
35
|
+
rngs : Sequence[RandomState]
|
36
|
+
Sequence of :class:`~brainstate.random.RandomState` instances whose
|
37
|
+
states should be captured and restored.
|
38
|
+
|
39
|
+
Attributes
|
40
|
+
----------
|
41
|
+
rngs : Sequence[RandomState]
|
42
|
+
Managed random-state instances.
|
43
|
+
rng_keys : list
|
44
|
+
Cached keys captured by :meth:`backup` until :meth:`restore` runs.
|
34
45
|
|
35
|
-
|
36
|
-
|
37
|
-
|
46
|
+
Examples
|
47
|
+
--------
|
48
|
+
.. code-block:: python
|
49
|
+
|
50
|
+
>>> import brainstate
|
51
|
+
>>>
|
52
|
+
>>> rng = brainstate.random.RandomState(0)
|
53
|
+
>>> restorer = brainstate.transform.RngRestore([rng])
|
54
|
+
>>> restorer.backup()
|
55
|
+
>>> _ = rng.random()
|
56
|
+
>>> restorer.restore()
|
38
57
|
"""
|
58
|
+
__module__ = 'brainstate.transform'
|
39
59
|
|
40
60
|
def __init__(self, rngs: Sequence[RandomState]):
|
41
61
|
"""
|
42
|
-
Initialize the
|
62
|
+
Initialize a restorer for the provided random states.
|
43
63
|
|
44
|
-
|
45
|
-
|
46
|
-
|
64
|
+
Parameters
|
65
|
+
----------
|
66
|
+
rngs : Sequence[RandomState]
|
67
|
+
Random states that will be backed up and restored.
|
47
68
|
"""
|
48
69
|
self.rngs: Sequence[RandomState] = rngs
|
49
70
|
self.rng_keys = []
|
50
71
|
|
51
72
|
def backup(self):
|
52
73
|
"""
|
53
|
-
|
74
|
+
Cache the current key for each managed random state.
|
54
75
|
|
55
|
-
|
56
|
-
|
76
|
+
Notes
|
77
|
+
-----
|
78
|
+
The cached keys persist until :meth:`restore` is called, after which the
|
79
|
+
internal cache is cleared.
|
57
80
|
"""
|
58
81
|
self.rng_keys = [rng.value for rng in self.rngs]
|
59
82
|
|
60
83
|
def restore(self):
|
61
84
|
"""
|
62
|
-
Restore
|
85
|
+
Restore each random state to the cached key.
|
63
86
|
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
Raises:
|
69
|
-
ValueError: If the number of saved random keys does not match
|
70
|
-
the number of RandomState instances.
|
87
|
+
Raises
|
88
|
+
------
|
89
|
+
ValueError
|
90
|
+
Raised when the number of stored keys does not match ``rngs``.
|
71
91
|
"""
|
72
92
|
if len(self.rng_keys) != len(self.rngs):
|
73
93
|
raise ValueError('The number of random keys does not match the number of random states.')
|
@@ -95,50 +115,50 @@ def _rng_backup(
|
|
95
115
|
return wrapper
|
96
116
|
|
97
117
|
|
118
|
+
@set_module_as('brainstate.transform')
|
98
119
|
def restore_rngs(
|
99
120
|
fn: Callable = Missing(),
|
100
121
|
rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
|
101
122
|
) -> Callable:
|
102
123
|
"""
|
103
|
-
|
104
|
-
|
105
|
-
This function can be used as a decorator or called directly. It ensures that the
|
106
|
-
random state of the specified RandomState instances is preserved across function calls,
|
107
|
-
which is useful for maintaining reproducibility in stochastic operations.
|
124
|
+
Decorate a function so specified random states are restored after execution.
|
108
125
|
|
109
126
|
Parameters
|
110
127
|
----------
|
111
128
|
fn : Callable, optional
|
112
|
-
|
113
|
-
with
|
129
|
+
Function to wrap. When omitted, :func:`restore_rngs` returns a decorator
|
130
|
+
preconfigured with ``rngs``.
|
114
131
|
rngs : Union[RandomState, Sequence[RandomState]], optional
|
115
|
-
|
116
|
-
|
117
|
-
the default RandomState instance will be used.
|
132
|
+
Random states whose keys should be backed up before running ``fn`` and
|
133
|
+
restored afterwards. Defaults to :data:`brainstate.random.DEFAULT`.
|
118
134
|
|
119
135
|
Returns
|
120
136
|
-------
|
121
137
|
Callable
|
122
|
-
|
123
|
-
|
124
|
-
If `fn` is not provided, returns a partial function that can be used as
|
125
|
-
a decorator with the specified `rngs`.
|
138
|
+
Wrapped callable that restores the random state or a partially applied
|
139
|
+
decorator depending on how :func:`restore_rngs` is used.
|
126
140
|
|
127
141
|
Raises
|
128
142
|
------
|
129
143
|
AssertionError
|
130
|
-
If
|
144
|
+
If ``rngs`` is neither a :class:`~brainstate.random.RandomState` instance nor
|
145
|
+
a sequence of such instances.
|
131
146
|
|
132
147
|
Examples
|
133
148
|
--------
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
149
|
+
.. code-block:: python
|
150
|
+
|
151
|
+
>>> import brainstate
|
152
|
+
>>>
|
153
|
+
>>> rng = brainstate.random.RandomState(0)
|
154
|
+
>>>
|
155
|
+
>>> @brainstate.transform.restore_rngs(rngs=rng)
|
156
|
+
... def sample_pair():
|
157
|
+
... first = rng.random()
|
158
|
+
... second = rng.random()
|
159
|
+
... return first, second
|
160
|
+
>>>
|
161
|
+
>>> assert sample_pair()[0] == sample_pair()[0]
|
142
162
|
"""
|
143
163
|
if isinstance(fn, Missing):
|
144
164
|
return functools.partial(restore_rngs, rngs=rngs)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -27,8 +27,39 @@ __all__ = [
|
|
27
27
|
]
|
28
28
|
|
29
29
|
|
30
|
-
@set_module_as('brainstate.
|
30
|
+
@set_module_as('brainstate.transform')
|
31
31
|
def unvmap(x, op: str = 'any'):
|
32
|
+
"""
|
33
|
+
Remove a leading vmap dimension by aggregating batched values.
|
34
|
+
|
35
|
+
Parameters
|
36
|
+
----------
|
37
|
+
x : Any
|
38
|
+
Value produced inside a :func:`jax.vmap`-transformed function.
|
39
|
+
op : {'all', 'any', 'none', 'max'}, default='any'
|
40
|
+
Reduction to apply across the vmapped axis. ``'none'`` returns ``x`` without
|
41
|
+
reduction, while ``'max'`` computes the maximum element.
|
42
|
+
|
43
|
+
Returns
|
44
|
+
-------
|
45
|
+
Any
|
46
|
+
Result of applying the requested reduction with vmap metadata removed.
|
47
|
+
|
48
|
+
Raises
|
49
|
+
------
|
50
|
+
ValueError
|
51
|
+
If ``op`` is not one of ``'all'``, ``'any'``, ``'none'``, or ``'max'``.
|
52
|
+
|
53
|
+
Examples
|
54
|
+
--------
|
55
|
+
.. code-block:: python
|
56
|
+
|
57
|
+
>>> import jax.numpy as jnp
|
58
|
+
>>> import brainstate
|
59
|
+
>>>
|
60
|
+
>>> xs = jnp.array([[True, False], [True, True]])
|
61
|
+
>>> brainstate.transform.unvmap(xs, op='all')
|
62
|
+
"""
|
32
63
|
if op == 'all':
|
33
64
|
return unvmap_all(x)
|
34
65
|
elif op == 'any':
|
@@ -47,7 +78,29 @@ unvmap_all_p = Primitive("unvmap_all")
|
|
47
78
|
|
48
79
|
|
49
80
|
def unvmap_all(x):
|
50
|
-
"""
|
81
|
+
"""
|
82
|
+
Evaluate :func:`jax.numpy.all` while ignoring vmapped batch dimensions.
|
83
|
+
|
84
|
+
Parameters
|
85
|
+
----------
|
86
|
+
x : Any
|
87
|
+
Input array or pytree produced under :func:`jax.vmap`.
|
88
|
+
|
89
|
+
Returns
|
90
|
+
-------
|
91
|
+
jax.Array
|
92
|
+
Scalar boolean result of ``jnp.all(x)``.
|
93
|
+
|
94
|
+
Examples
|
95
|
+
--------
|
96
|
+
.. code-block:: python
|
97
|
+
|
98
|
+
>>> import jax.numpy as jnp
|
99
|
+
>>> import brainstate
|
100
|
+
>>>
|
101
|
+
>>> values = jnp.array([[True, False], [True, True]])
|
102
|
+
>>> brainstate.transform.unvmap(values, op='all')
|
103
|
+
"""
|
51
104
|
return unvmap_all_p.bind(x)
|
52
105
|
|
53
106
|
|
@@ -78,7 +131,29 @@ unvmap_any_p = Primitive("unvmap_any")
|
|
78
131
|
|
79
132
|
|
80
133
|
def unvmap_any(x):
|
81
|
-
"""
|
134
|
+
"""
|
135
|
+
Evaluate :func:`jax.numpy.any` while ignoring vmapped batch dimensions.
|
136
|
+
|
137
|
+
Parameters
|
138
|
+
----------
|
139
|
+
x : Any
|
140
|
+
Input array or pytree produced under :func:`jax.vmap`.
|
141
|
+
|
142
|
+
Returns
|
143
|
+
-------
|
144
|
+
jax.Array
|
145
|
+
Scalar boolean result of ``jnp.any(x)``.
|
146
|
+
|
147
|
+
Examples
|
148
|
+
--------
|
149
|
+
.. code-block:: python
|
150
|
+
|
151
|
+
>>> import jax.numpy as jnp
|
152
|
+
>>> import brainstate
|
153
|
+
>>>
|
154
|
+
>>> values = jnp.array([[False, False], [False, True]])
|
155
|
+
>>> brainstate.transform.unvmap(values, op='any')
|
156
|
+
"""
|
82
157
|
return unvmap_any_p.bind(x)
|
83
158
|
|
84
159
|
|
@@ -109,7 +184,29 @@ unvmap_max_p = Primitive("unvmap_max")
|
|
109
184
|
|
110
185
|
|
111
186
|
def unvmap_max(x):
|
112
|
-
"""
|
187
|
+
"""
|
188
|
+
Evaluate :func:`jax.numpy.max` while ignoring vmapped batch dimensions.
|
189
|
+
|
190
|
+
Parameters
|
191
|
+
----------
|
192
|
+
x : Any
|
193
|
+
Input array or pytree produced under :func:`jax.vmap`.
|
194
|
+
|
195
|
+
Returns
|
196
|
+
-------
|
197
|
+
jax.Array
|
198
|
+
Scalar containing the maximum value of ``x`` with the same dtype.
|
199
|
+
|
200
|
+
Examples
|
201
|
+
--------
|
202
|
+
.. code-block:: python
|
203
|
+
|
204
|
+
>>> import jax.numpy as jnp
|
205
|
+
>>> import brainstate
|
206
|
+
>>>
|
207
|
+
>>> values = jnp.array([[1.0, 2.0], [0.5, 3.5]])
|
208
|
+
>>> brainstate.transform.unvmap(values, op='max')
|
209
|
+
"""
|
113
210
|
return unvmap_max_p.bind(x)
|
114
211
|
|
115
212
|
|