brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl
This diff compares the contents of publicly released package versions from a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/transform/_mapping_test.py

```diff
@@ -1,194 +1,194 @@
-# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-import unittest
-
-import jax
-import jax.numpy as jnp
-from jax import vmap
-from jax.lax import psum, pmean, pmax
-
-import brainstate
-import brainstate.transform
-from brainstate._error import BatchAxisError
-
-
-
-class TestMap(unittest.TestCase):
-    def test_map(self):
-        for dim in [(10,), (10, 10), (10, 10, 10)]:
-            x = brainstate.random.rand(*dim)
-            r1 = brainstate.transform.map(lambda a: a + 1, x, batch_size=None)
-            r2 = brainstate.transform.map(lambda a: a + 1, x, batch_size=2)
-            r3 = brainstate.transform.map(lambda a: a + 1, x, batch_size=4)
-            r4 = brainstate.transform.map(lambda a: a + 1, x, batch_size=5)
-            true_r = x + 1
-
-            self.assertTrue(jnp.allclose(r1, true_r))
-            self.assertTrue(jnp.allclose(r2, true_r))
-            self.assertTrue(jnp.allclose(r3, true_r))
-            self.assertTrue(jnp.allclose(r4, true_r))
-
-
-class TestAxisName:
-    def test1(self):
-        def compute_stats_with_axis_name(x):
-            """Compute statistics using named axis operations"""
-            # Sum across the named axis 'batch'
-            total_sum = psum(x, axis_name='batch')
-
-            # Mean across the named axis 'batch'
-            mean_val = pmean(x, axis_name='batch')
-
-            # Max across the named axis 'batch'
-            max_val = pmax(x, axis_name='batch')
-
-            return {
-                'sum': total_sum,
-                'mean': mean_val,
-                'max': max_val,
-                'original': x
-            }
-
-        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
-        print("Input batch data:", batch_data)
-
-        # vmap with axis name 'batch'
-        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
-        result_jax = vectorized_stats_jax(batch_data)
-
-        # vmap with axis name 'batch'
-        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
-        result = vectorized_stats(batch_data)
-
-        # vmap with axis name 'batch'
-        vectorized_stats_v2 = brainstate.transform.jit(
-            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
-        )
-        result_v2 = vectorized_stats_v2(batch_data)
-
-        for key in result_jax.keys():
-            print(f" {key}: {result_jax[key]}")
-            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
-            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
-
-    def test_nested_vmap(self):
-        def nested_computation(x):
-            """Computation with multiple named axes"""
-            # Sum over 'inner' axis, then mean over 'outer' axis
-            inner_sum = psum(x, axis_name='inner')
-            outer_mean = pmean(inner_sum, axis_name='outer')
-            return outer_mean
-
-        # Create 2D batch data
-        data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
-        print("Input 2D data shape:", data_2d.shape)
-        print("Input 2D data:\n", data_2d)
-
-        # Nested vmap: first over inner dimension, then outer dimension
-        inner_vmap = vmap(nested_computation, axis_name='inner')
-        nested_vmap = vmap(inner_vmap, axis_name='outer')
-
-        result_2d = nested_vmap(data_2d)
-        print("Result after nested vmap:", result_2d)
-
-        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
-        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
-        result_2d_bst = nested_vmap_bst(data_2d)
-        print("Result after nested vmap:", result_2d_bst)
-
-        assert jnp.allclose(result_2d, result_2d_bst)
-
-    def _gradient_averaging_simulation_bst(self):
-        def loss_function(params, x, y):
-            """Simple quadratic loss"""
-            pred = params * x
-            return (pred - y) ** 2
-
-        def compute_gradients_with_averaging(params, batch_x, batch_y):
-            """Compute gradients and average them across the batch"""
-            # Compute per-sample gradients
-            grad_fn = jax.grad(loss_function, argnums=0)
-            per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
-
-            # Average gradients across batch using named axis
-            def average_grads(grads):
-                return pmean(grads, axis_name='batch')
-
-            # Apply averaging with named axis
-            averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
-            return averaged_grads
-
-        # Example data
-        params = 2.0
-        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
-        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
-
-        print("Parameters:", params)
-        print("Batch X:", batch_x)
-        print("Batch Y:", batch_y)
-
-        # Compute individual gradients first
-        grad_fn = jax.grad(loss_function, argnums=0)
-        individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
-        print("Individual gradients:", individual_grads)
-
-        # Now compute averaged gradients using axis names
-        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
-        print("Averaged gradients:", averaged_grads)
-
-        return individual_grads, averaged_grads
-
-    def _gradient_averaging_simulation_jax(self):
-        def loss_function(params, x, y):
-            """Simple quadratic loss"""
-            pred = params * x
-            return (pred - y) ** 2
-
-        def compute_gradients_with_averaging(params, batch_x, batch_y):
-            """Compute gradients and average them across the batch"""
-            # Compute per-sample gradients
-            grad_fn = jax.grad(loss_function, argnums=0)
-            per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
-
-            # Average gradients across batch using named axis
-            def average_grads(grads):
-                return pmean(grads, axis_name='batch')
-
-            # Apply averaging with named axis
-            averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
-            return averaged_grads
-
-        # Example data
-        params = 2.0
-        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
-        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
-
-        print("Parameters:", params)
-        print("Batch X:", batch_x)
-        print("Batch Y:", batch_y)
-
-        # Compute individual gradients first
-        grad_fn = jax.grad(loss_function, argnums=0)
-        individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
-        print("Individual gradients:", individual_grads)
-
-        # Now compute averaged gradients using axis names
-        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
-        print("Averaged gradients:", averaged_grads)
-
-        return individual_grads, averaged_grads
-
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import unittest
+
+import jax
+import jax.numpy as jnp
+from jax import vmap
+from jax.lax import psum, pmean, pmax
+
+import brainstate
+import brainstate.transform
+from brainstate._error import BatchAxisError
+
+
+
+class TestMap(unittest.TestCase):
+    def test_map(self):
+        for dim in [(10,), (10, 10), (10, 10, 10)]:
+            x = brainstate.random.rand(*dim)
+            r1 = brainstate.transform.map(lambda a: a + 1, x, batch_size=None)
+            r2 = brainstate.transform.map(lambda a: a + 1, x, batch_size=2)
+            r3 = brainstate.transform.map(lambda a: a + 1, x, batch_size=4)
+            r4 = brainstate.transform.map(lambda a: a + 1, x, batch_size=5)
+            true_r = x + 1
+
+            self.assertTrue(jnp.allclose(r1, true_r))
+            self.assertTrue(jnp.allclose(r2, true_r))
+            self.assertTrue(jnp.allclose(r3, true_r))
+            self.assertTrue(jnp.allclose(r4, true_r))
+
+
+class TestAxisName:
+    def test1(self):
+        def compute_stats_with_axis_name(x):
+            """Compute statistics using named axis operations"""
+            # Sum across the named axis 'batch'
+            total_sum = psum(x, axis_name='batch')
+
+            # Mean across the named axis 'batch'
+            mean_val = pmean(x, axis_name='batch')
+
+            # Max across the named axis 'batch'
+            max_val = pmax(x, axis_name='batch')
+
+            return {
+                'sum': total_sum,
+                'mean': mean_val,
+                'max': max_val,
+                'original': x
+            }
+
+        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+        print("Input batch data:", batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+        result_jax = vectorized_stats_jax(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        result = vectorized_stats(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_v2 = brainstate.transform.jit(
+            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        )
+        result_v2 = vectorized_stats_v2(batch_data)
+
+        for key in result_jax.keys():
+            print(f" {key}: {result_jax[key]}")
+            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+    def test_nested_vmap(self):
+        def nested_computation(x):
+            """Computation with multiple named axes"""
+            # Sum over 'inner' axis, then mean over 'outer' axis
+            inner_sum = psum(x, axis_name='inner')
+            outer_mean = pmean(inner_sum, axis_name='outer')
+            return outer_mean
+
+        # Create 2D batch data
+        data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
+        print("Input 2D data shape:", data_2d.shape)
+        print("Input 2D data:\n", data_2d)
+
+        # Nested vmap: first over inner dimension, then outer dimension
+        inner_vmap = vmap(nested_computation, axis_name='inner')
+        nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+        result_2d = nested_vmap(data_2d)
+        print("Result after nested vmap:", result_2d)
+
+        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+        result_2d_bst = nested_vmap_bst(data_2d)
+        print("Result after nested vmap:", result_2d_bst)
+
+        assert jnp.allclose(result_2d, result_2d_bst)
+
+    def _gradient_averaging_simulation_bst(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def _gradient_averaging_simulation_jax(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
```
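For context, the one file diffed above, `brainstate/transform/_mapping_test.py`, exercises `brainstate.transform.map` (chunked mapping over a leading axis) and `brainstate.transform.vmap` with named axes. The sketch below mirrors calls that appear in the test and assumes the brainstate 0.2.x interface shown there; it is illustrative only and not part of the package diff.

```python
# Minimal sketch of the API the test above exercises, assuming brainstate 0.2.x.
import jax.numpy as jnp
from jax.lax import pmean

import brainstate

x = brainstate.random.rand(10)

# map applies a function over the leading axis; batch_size controls chunking,
# and batch_size=None processes the whole axis at once.
y = brainstate.transform.map(lambda a: a + 1, x, batch_size=4)
assert jnp.allclose(y, x + 1)

# vmap with a named axis lets JAX collectives (psum/pmean/pmax) reduce across
# the mapped dimension, matching the behavior of jax.vmap with axis_name.
mean_per_element = brainstate.transform.vmap(
    lambda v: pmean(v, axis_name='batch'),
    axis_name='batch',
)(x)
assert jnp.allclose(mean_per_element, jnp.mean(x))
```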