brainstate-0.1.0.post20250503-py2.py3-none-any.whl → brainstate-0.1.2-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
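Most files above change by only one or two removed lines; the largest rewrite is in the test modules. The diff reproduced below appears to be from brainstate/augment/_autograd_test.py (its hunks match the +132 −133 summary above). It shows the pattern the 0.1.2 test suite now exercises: the tests import the package as `import brainstate`, seed randomness with `brainstate.random.seed`, and drive the autograd transforms through the top-level namespace (`brainstate.augment.grad`, `jacrev`, `jacfwd`, `vector_grad`, `hessian`). A minimal sketch of that functional call pattern, assuming only calls that appear in the diff (the `call`/`a`/`b`/`c` names simply mirror the test code):

```python
import jax.numpy as jnp
import brainstate

# Functional-gradient pattern from TestPureFuncGrad in the diff below.
brainstate.random.seed(1)

def call(a, b, c):
    return jnp.sum(a + b + c)

a = jnp.ones(10)
b = brainstate.random.randn(10)
c = brainstate.random.uniform(size=10)

# Gradient w.r.t. all three positional arguments; the tests assert each gradient equals 1.
f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2])
grads = f_grad(a, b, c)
```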
@@ -14,7 +14,6 @@
 # ==============================================================================
 
 # -*- coding: utf-8 -*-
-from __future__ import annotations
 
 import unittest
 from pprint import pprint
@@ -24,7 +23,7 @@ import jax
 import jax.numpy as jnp
 import pytest
 
-import brainstate
+import brainstate
 from brainstate.augment._autograd import _jacfwd
 
 
@@ -32,11 +31,11 @@ class TestPureFuncGrad(unittest.TestCase):
     def test_grad_pure_func_1(self):
         def call(a, b, c): return jnp.sum(a + b + c)
 
-
+        brainstate.random.seed(1)
         a = jnp.ones(10)
-        b =
-        c =
-        f_grad =
+        b = brainstate.random.randn(10)
+        c = brainstate.random.uniform(size=10)
+        f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2])
         grads = f_grad(a, b, c)
 
         for g in grads: assert (g == 1.).all()
@@ -44,29 +43,29 @@ class TestPureFuncGrad(unittest.TestCase):
     def test_grad_pure_func_2(self):
         def call(a, b, c): return jnp.sum(a + b + c)
 
-
+        brainstate.random.seed(1)
         a = jnp.ones(10)
-        b =
-        c =
-        f_grad =
+        b = brainstate.random.randn(10)
+        c = brainstate.random.uniform(size=10)
+        f_grad = brainstate.augment.grad(call)
         assert (f_grad(a, b, c) == 1.).all()
 
     def test_grad_pure_func_aux1(self):
         def call(a, b, c):
             return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
 
-
-        f_grad =
+        brainstate.random.seed(1)
+        f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2])
         with pytest.raises(TypeError):
-            f_grad(jnp.ones(10),
+            f_grad(jnp.ones(10), brainstate.random.randn(10), brainstate.random.uniform(size=10))
 
     def test_grad_pure_func_aux2(self):
         def call(a, b, c):
             return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
 
-
-        f_grad =
-        grads, aux = f_grad(jnp.ones(10),
+        brainstate.random.seed(1)
+        f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2], has_aux=True)
+        grads, aux = f_grad(jnp.ones(10), brainstate.random.randn(10), brainstate.random.uniform(size=10))
         for g in grads: assert (g == 1.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
@@ -74,11 +73,11 @@ class TestPureFuncGrad(unittest.TestCase):
     def test_grad_pure_func_return1(self):
         def call(a, b, c): return jnp.sum(a + b + c)
 
-
+        brainstate.random.seed(1)
         a = jnp.ones(10)
-        b =
-        c =
-        f_grad =
+        b = brainstate.random.randn(10)
+        c = brainstate.random.uniform(size=10)
+        f_grad = brainstate.augment.grad(call, return_value=True)
         grads, returns = f_grad(a, b, c)
         assert (grads == 1.).all()
         assert returns == jnp.sum(a + b + c)
@@ -87,11 +86,11 @@ class TestPureFuncGrad(unittest.TestCase):
         def call(a, b, c):
             return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
 
-
+        brainstate.random.seed(1)
         a = jnp.ones(10)
-        b =
-        c =
-        f_grad =
+        b = brainstate.random.randn(10)
+        c = brainstate.random.uniform(size=10)
+        f_grad = brainstate.augment.grad(call, return_value=True, has_aux=True)
         grads, returns, aux = f_grad(a, b, c)
         assert (grads == 1.).all()
         assert returns == jnp.sum(a + b + c)
@@ -101,110 +100,110 @@ class TestPureFuncGrad(unittest.TestCase):
 
 class TestObjectFuncGrad(unittest.TestCase):
     def test_grad_ob1(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
 
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self):
                 return jnp.sum(self.a.value + self.b.value + self.c.value)
 
-
+        brainstate.random.seed(0)
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states={'a': t.a, 'b': t.b, 'c': t.c})
         grads = f_grad()
         for g in grads.values():
             assert (g == 1.).all()
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b])
         grads = f_grad()
         for g in grads: assert (g == 1.).all()
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.a)
         grads = f_grad()
         assert (grads == 1.).all()
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.states())
         grads = f_grad()
         for g in grads.values():
             assert (g == 1.).all()
 
     def test_grad_ob_aux(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self):
                 return jnp.sum(self.a.value + self.b.value + self.c.value), (jnp.sin(100), jnp.exp(0.1))
 
-
+        brainstate.random.seed(0)
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], has_aux=True)
         grads, aux = f_grad()
         for g in grads: assert (g == 1.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.a, has_aux=True)
         grads, aux = f_grad()
         assert (grads == 1.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.states(), has_aux=True)
         grads, aux = f_grad()
         self.assertTrue(len(grads) == len(t.states()))
 
     def test_grad_ob_return(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self):
                 return jnp.sum(self.a.value + self.b.value + self.c.value)
 
-
+        brainstate.random.seed(0)
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], return_value=True)
         grads, returns = f_grad()
         for g in grads: assert (g == 1.).all()
         assert returns == t()
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.a, return_value=True)
         grads, returns = f_grad()
         assert (grads == 1.).all()
         assert returns == t()
 
     def test_grad_ob_aux_return(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self):
                 return jnp.sum(self.a.value + self.b.value + self.c.value), (jnp.sin(100), jnp.exp(0.1))
 
-
+        brainstate.random.seed(0)
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], has_aux=True, return_value=True)
         grads, returns, aux = f_grad()
         for g in grads: assert (g == 1.).all()
         assert returns == jnp.sum(t.a.value + t.b.value + t.c.value)
@@ -212,7 +211,7 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
+        f_grad = brainstate.augment.grad(t, grad_states=t.a, has_aux=True, return_value=True)
         grads, returns, aux = f_grad()
         assert (grads == 1.).all()
         assert returns == jnp.sum(t.a.value + t.b.value + t.c.value)
@@ -220,101 +219,101 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert aux[1] == jnp.exp(0.1)
 
     def test_grad_ob_argnums(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-
-                self.a =
-                self.b =
-                self.c =
+                brainstate.random.seed()
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self, d):
                 return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d)
 
-
+        brainstate.random.seed(0)
 
         t = Test()
-        f_grad =
-        var_grads, arg_grads = f_grad(
+        f_grad = brainstate.augment.grad(t, t.states(), argnums=0)
+        var_grads, arg_grads = f_grad(brainstate.random.random(10))
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads == 2.).all()
 
         t = Test()
-        f_grad =
-        var_grads, arg_grads = f_grad(
+        f_grad = brainstate.augment.grad(t, t.states(), argnums=[0])
+        var_grads, arg_grads = f_grad(brainstate.random.random(10))
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads[0] == 2.).all()
 
         t = Test()
-        f_grad =
-        arg_grads = f_grad(
+        f_grad = brainstate.augment.grad(t, argnums=0)
+        arg_grads = f_grad(brainstate.random.random(10))
         assert (arg_grads == 2.).all()
 
         t = Test()
-        f_grad =
-        arg_grads = f_grad(
+        f_grad = brainstate.augment.grad(t, argnums=[0])
+        arg_grads = f_grad(brainstate.random.random(10))
         assert (arg_grads[0] == 2.).all()
 
     def test_grad_ob_argnums_aux(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self, d):
                 return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d), (jnp.sin(100), jnp.exp(0.1))
 
-
+        brainstate.random.seed(0)
 
         t = Test()
-        f_grad =
-        (var_grads, arg_grads), aux = f_grad(
+        f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True)
+        (var_grads, arg_grads), aux = f_grad(brainstate.random.random(10))
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads == 2.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
-        (var_grads, arg_grads), aux = f_grad(
+        f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True)
+        (var_grads, arg_grads), aux = f_grad(brainstate.random.random(10))
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads[0] == 2.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
-        arg_grads, aux = f_grad(
+        f_grad = brainstate.augment.grad(t, argnums=0, has_aux=True)
+        arg_grads, aux = f_grad(brainstate.random.random(10))
         assert (arg_grads == 2.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
         t = Test()
-        f_grad =
-        arg_grads, aux = f_grad(
+        f_grad = brainstate.augment.grad(t, argnums=[0], has_aux=True)
+        arg_grads, aux = f_grad(brainstate.random.random(10))
         assert (arg_grads[0] == 2.).all()
         assert aux[0] == jnp.sin(100)
         assert aux[1] == jnp.exp(0.1)
 
     def test_grad_ob_argnums_return(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
 
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self, d):
                 return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d)
 
-
+        brainstate.random.seed(0)
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, t.states(), argnums=0, return_value=True)
+        d = brainstate.random.random(10)
         (var_grads, arg_grads), loss = f_grad(d)
         for g in var_grads.values():
             assert (g == 1.).all()
@@ -322,8 +321,8 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert loss == t(d)
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, t.states(), argnums=[0], return_value=True)
+        d = brainstate.random.random(10)
         (var_grads, arg_grads), loss = f_grad(d)
         for g in var_grads.values():
             assert (g == 1.).all()
@@ -331,35 +330,35 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert loss == t(d)
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, argnums=0, return_value=True)
+        d = brainstate.random.random(10)
         arg_grads, loss = f_grad(d)
         assert (arg_grads == 2.).all()
         assert loss == t(d)
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, argnums=[0], return_value=True)
+        d = brainstate.random.random(10)
         arg_grads, loss = f_grad(d)
         assert (arg_grads[0] == 2.).all()
         assert loss == t(d)
 
     def test_grad_ob_argnums_aux_return(self):
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.a =
-                self.b =
-                self.c =
+                self.a = brainstate.ParamState(jnp.ones(10))
+                self.b = brainstate.ParamState(brainstate.random.randn(10))
+                self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
 
             def __call__(self, d):
                 return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d), (jnp.sin(100), jnp.exp(0.1))
 
-
+        brainstate.random.seed(0)
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True, return_value=True)
+        d = brainstate.random.random(10)
         (var_grads, arg_grads), loss, aux = f_grad(d)
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads == 2.).all()
@@ -368,8 +367,8 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert loss == t(d)[0]
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True, return_value=True)
+        d = brainstate.random.random(10)
         (var_grads, arg_grads), loss, aux = f_grad(d)
         for g in var_grads.values(): assert (g == 1.).all()
         assert (arg_grads[0] == 2.).all()
@@ -378,8 +377,8 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert loss == t(d)[0]
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, argnums=0, has_aux=True, return_value=True)
+        d = brainstate.random.random(10)
         arg_grads, loss, aux = f_grad(d)
         assert (arg_grads == 2.).all()
         assert aux[0] == jnp.sin(100)
@@ -387,8 +386,8 @@ class TestObjectFuncGrad(unittest.TestCase):
         assert loss == t(d)[0]
 
         t = Test()
-        f_grad =
-        d =
+        f_grad = brainstate.augment.grad(t, argnums=[0], has_aux=True, return_value=True)
+        d = brainstate.random.random(10)
         arg_grads, loss, aux = f_grad(d)
         assert (arg_grads[0] == 2.).all()
         assert aux[0] == jnp.sin(100)
@@ -436,11 +435,11 @@ class TestPureFuncJacobian(unittest.TestCase):
             r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
             return r
 
-        br =
+        br = brainstate.augment.jacrev(f1)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         jr = jax.jacrev(f1)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         assert (br == jr).all()
 
-        br =
+        br = brainstate.augment.jacrev(f1, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         jr = jax.jacrev(f1, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         assert (br[0] == jr[0]).all()
         assert (br[1] == jr[1]).all()
@@ -456,12 +455,12 @@ class TestPureFuncJacobian(unittest.TestCase):
         jr = jax.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(jr)
 
-        br =
+        br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0], jr[0])
         assert jnp.array_equal(br[1], jr[1])
 
-        br =
+        br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0], jr[0])
         assert jnp.array_equal(br[1], jr[1])
@@ -471,12 +470,12 @@ class TestPureFuncJacobian(unittest.TestCase):
             r2 = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
             return r1, r2
 
-        br =
+        br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0], jr[0])
         assert jnp.array_equal(br[1], jr[1])
 
-        br =
+        br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0], jr[0])
         assert jnp.array_equal(br[1], jr[1])
@@ -492,14 +491,14 @@ class TestPureFuncJacobian(unittest.TestCase):
         jr = jax.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(jr)
 
-        br =
+        br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0][0], jr[0][0])
         assert jnp.array_equal(br[0][1], jr[0][1])
         assert jnp.array_equal(br[1][0], jr[1][0])
         assert jnp.array_equal(br[1][1], jr[1][1])
 
-        br =
+        br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0][0], jr[0][0])
         assert jnp.array_equal(br[0][1], jr[0][1])
@@ -511,14 +510,14 @@ class TestPureFuncJacobian(unittest.TestCase):
             r2 = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
             return r1, r2
 
-        br =
+        br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0][0], jr[0][0])
         assert jnp.array_equal(br[0][1], jr[0][1])
         assert jnp.array_equal(br[1][0], jr[1][0])
         assert jnp.array_equal(br[1][1], jr[1][1])
 
-        br =
+        br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
         pprint(br)
         assert jnp.array_equal(br[0][0], jr[0][0])
         assert jnp.array_equal(br[0][1], jr[0][1])
@@ -537,19 +536,19 @@ class TestPureFuncJacobian(unittest.TestCase):
         f2 = lambda *args: f1(*args)[0]
         jr = jax.jacrev(f2)(x, y) # jax jacobian
         pprint(jr)
-        grads, aux =
+        grads, aux = brainstate.augment.jacrev(f1, has_aux=True)(x, y)
         assert (grads == jr).all()
         assert aux == (4 * x[1] ** 2 - 2 * x[2])
 
         jr = jax.jacrev(f2, argnums=(0, 1))(x, y) # jax jacobian
         pprint(jr)
-        grads, aux =
+        grads, aux = brainstate.augment.jacrev(f1, argnums=(0, 1), has_aux=True)(x, y)
         assert (grads[0] == jr[0]).all()
         assert (grads[1] == jr[1]).all()
         assert aux == (4 * x[1] ** 2 - 2 * x[2])
 
     def test_jacrev_return_aux1(self):
-        with
+        with brainstate.environ.context(precision=64):
             def f1(x, y):
                 a = 4 * x[1] ** 2 - 2 * x[2]
                 r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,12 +563,12 @@ class TestPureFuncJacobian(unittest.TestCase):
             _g2 = jax.jacrev(f2, argnums=(0, 1))(_x, _y) # jax jacobian
             pprint(_g2)
 
-            grads, vec, aux =
+            grads, vec, aux = brainstate.augment.jacrev(f1, return_value=True, has_aux=True)(_x, _y)
             assert (grads == _g1).all()
            assert aux == _a
             assert (vec == _r).all()
 
-            grads, vec, aux =
+            grads, vec, aux = brainstate.augment.jacrev(f1, return_value=True, argnums=(0, 1), has_aux=True)(_x, _y)
             assert (grads[0] == _g2[0]).all()
             assert (grads[1] == _g2[1]).all()
             assert aux == _a
@@ -585,11 +584,11 @@ class TestClassFuncJacobian(unittest.TestCase):
         _x = jnp.array([1., 2., 3.])
         _y = jnp.array([10., 5.])
 
-        class Test(
+        class Test(brainstate.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
-                self.x =
-                self.y =
+                self.x = brainstate.State(jnp.array([1., 2., 3.]))
+                self.y = brainstate.State(jnp.array([10., 5.]))
 
             def __call__(self, ):
                 a = self.x.value[0] * self.y.value[0]
@@ -601,12 +600,12 @@ class TestClassFuncJacobian(unittest.TestCase):
 
         _jr = jax.jacrev(f1)(_x, _y)
         t = Test()
-        br =
+        br = brainstate.augment.jacrev(t, grad_states=t.x)()
         self.assertTrue((br == _jr).all())
 
         _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
         t = Test()
-        br =
+        br = brainstate.augment.jacrev(t, grad_states=[t.x, t.y])()
         self.assertTrue((br[0] == _jr[0]).all())
         self.assertTrue((br[1] == _jr[1]).all())
 
@@ -1202,7 +1201,7 @@ class TestUnitAwareGrad(unittest.TestCase):
             return u.math.sum(x ** 2)
 
         x = jnp.array([1., 2., 3.]) * u.ms
-        g =
+        g = brainstate.augment.grad(f, unit_aware=True)(x)
         self.assertTrue(u.math.allclose(g, 2 * x))
 
     def test_vector_grad1(self):
@@ -1210,7 +1209,7 @@ class TestUnitAwareGrad(unittest.TestCase):
             return x ** 3
 
         x = jnp.array([1., 2., 3.]) * u.ms
-        g =
+        g = brainstate.augment.vector_grad(f, unit_aware=True)(x)
         self.assertTrue(u.math.allclose(g, 3 * x ** 2))
 
     def test_jacrev1(self):
@@ -1222,7 +1221,7 @@ class TestUnitAwareGrad(unittest.TestCase):
         _x = jnp.array([1., 2., 3.]) * u.ms
         _y = jnp.array([10., 5.]) * u.ms
 
-        g =
+        g = brainstate.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
         self.assertTrue(
             u.math.allclose(
                 g[0],
@@ -1254,7 +1253,7 @@ class TestUnitAwareGrad(unittest.TestCase):
         _x = jnp.array([1., 2., 3.]) * u.ms
         _y = jnp.array([10., 5.]) * u.ms
 
-        g =
+        g = brainstate.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
         self.assertTrue(
             u.math.allclose(
                 g[0],
@@ -1283,7 +1282,7 @@ class TestUnitAwareGrad(unittest.TestCase):
         def scalar_function(x):
             return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit
 
-        hess =
+        hess = brainstate.augment.hessian(scalar_function, unit_aware=True)
         x = jnp.array(1.0) * unit
         res = hess(x)
         expected_hessian = jnp.array([[6.0]]) * unit