brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +3 -0
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module_test.py +34 -37
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax
|
6
4
|
import numpy as np
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
import brainstate.nn as nn
|
12
10
|
|
13
11
|
|
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
|
|
18
16
|
(32, 8),
|
19
17
|
(10, 20, 30),
|
20
18
|
]:
|
21
|
-
arr =
|
19
|
+
arr = brainstate.random.rand(*size)
|
22
20
|
f = nn.Flatten(start_axis=0)
|
23
21
|
out = f(arr)
|
24
22
|
self.assertTrue(out.shape == (np.prod(size),))
|
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
|
|
29
27
|
(32, 8),
|
30
28
|
(10, 20, 30),
|
31
29
|
]:
|
32
|
-
arr =
|
30
|
+
arr = brainstate.random.rand(*size)
|
33
31
|
f = nn.Flatten(start_axis=1)
|
34
32
|
out = f(arr)
|
35
33
|
self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
|
36
34
|
|
37
35
|
def test_flatten3(self):
|
38
36
|
size = (16, 32, 32, 8)
|
39
|
-
arr =
|
37
|
+
arr = brainstate.random.rand(*size)
|
40
38
|
f = nn.Flatten(start_axis=0, in_size=(32, 8))
|
41
39
|
out = f(arr)
|
42
40
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
43
41
|
|
44
42
|
def test_flatten4(self):
|
45
43
|
size = (16, 32, 32, 8)
|
46
|
-
arr =
|
44
|
+
arr = brainstate.random.rand(*size)
|
47
45
|
f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
|
48
46
|
out = f(arr)
|
49
47
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
|
|
58
56
|
super().__init__(*args, **kwargs)
|
59
57
|
|
60
58
|
def test_MaxPool2d_v1(self):
|
61
|
-
arr =
|
59
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
62
60
|
|
63
61
|
out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
|
64
62
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
|
|
79
77
|
self.assertTrue(out.shape == (16, 17, 32, 5))
|
80
78
|
|
81
79
|
def test_AvgPool2d_v1(self):
|
82
|
-
arr =
|
80
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
83
81
|
|
84
82
|
out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
|
85
83
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -107,7 +105,7 @@ class TestPool(parameterized.TestCase):
|
|
107
105
|
def test_adaptive_pool1d(self, target_size):
|
108
106
|
from brainstate.nn._interaction._poolings import _adaptive_pool1d
|
109
107
|
|
110
|
-
arr =
|
108
|
+
arr = brainstate.random.rand(100)
|
111
109
|
op = jax.numpy.mean
|
112
110
|
|
113
111
|
out = _adaptive_pool1d(arr, target_size, op)
|
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
|
|
119
117
|
self.assertTrue(out.shape == (target_size,))
|
120
118
|
|
121
119
|
def test_AdaptiveAvgPool2d_v1(self):
|
122
|
-
input =
|
120
|
+
input = brainstate.random.randn(64, 8, 9)
|
123
121
|
|
124
122
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
125
123
|
self.assertTrue(output.shape == (64, 5, 7))
|
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
|
|
137
135
|
self.assertTrue(output.shape == (64, 2, 3))
|
138
136
|
|
139
137
|
def test_AdaptiveAvgPool2d_v2(self):
|
140
|
-
|
141
|
-
input =
|
138
|
+
brainstate.random.seed()
|
139
|
+
input = brainstate.random.randn(128, 64, 32, 16)
|
142
140
|
|
143
141
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
144
142
|
self.assertTrue(output.shape == (128, 64, 5, 7))
|
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
|
|
154
152
|
print()
|
155
153
|
|
156
154
|
def test_AdaptiveAvgPool3d_v1(self):
|
157
|
-
input =
|
155
|
+
input = brainstate.random.randn(10, 128, 64, 32)
|
158
156
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
|
159
157
|
output = net(input)
|
160
158
|
self.assertTrue(output.shape == (10, 6, 5, 3))
|
161
159
|
|
162
160
|
def test_AdaptiveAvgPool3d_v2(self):
|
163
|
-
input =
|
161
|
+
input = brainstate.random.randn(10, 20, 128, 64, 32)
|
164
162
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
|
165
163
|
output = net(input)
|
166
164
|
self.assertTrue(output.shape == (10, 6, 5, 3, 32))
|
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
|
|
169
167
|
axis=(-1, 0, 1)
|
170
168
|
)
|
171
169
|
def test_AdaptiveMaxPool1d_v1(self, axis):
|
172
|
-
input =
|
170
|
+
input = brainstate.random.randn(32, 16)
|
173
171
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
174
172
|
output = net(input)
|
175
173
|
|
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
|
|
177
175
|
axis=(-1, 0, 1, 2)
|
178
176
|
)
|
179
177
|
def test_AdaptiveMaxPool1d_v2(self, axis):
|
180
|
-
input =
|
178
|
+
input = brainstate.random.randn(2, 32, 16)
|
181
179
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
182
180
|
output = net(input)
|
183
181
|
|
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
|
|
185
183
|
axis=(-1, 0, 1, 2)
|
186
184
|
)
|
187
185
|
def test_AdaptiveMaxPool2d_v1(self, axis):
|
188
|
-
input =
|
186
|
+
input = brainstate.random.randn(32, 16, 12)
|
189
187
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
190
188
|
output = net(input)
|
191
189
|
|
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
|
|
193
191
|
axis=(-1, 0, 1, 2, 3)
|
194
192
|
)
|
195
193
|
def test_AdaptiveMaxPool2d_v2(self, axis):
|
196
|
-
input =
|
194
|
+
input = brainstate.random.randn(2, 32, 16, 12)
|
197
195
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
198
196
|
output = net(input)
|
199
197
|
|
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
|
|
201
199
|
axis=(-1, 0, 1, 2, 3)
|
202
200
|
)
|
203
201
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
204
|
-
input =
|
202
|
+
input = brainstate.random.randn(2, 128, 64, 32)
|
205
203
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
206
204
|
output = net(input)
|
207
205
|
print()
|
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
|
|
210
208
|
axis=(-1, 0, 1, 2, 3, 4)
|
211
209
|
)
|
212
210
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
213
|
-
input =
|
211
|
+
input = brainstate.random.randn(2, 128, 64, 32, 16)
|
214
212
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
215
213
|
output = net(input)
|
216
214
|
|
brainstate/nn/_module_test.py
CHANGED
@@ -13,20 +13,17 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
|
-
import jaxlib.xla_extension
|
22
19
|
|
23
|
-
import brainstate
|
20
|
+
import brainstate
|
24
21
|
|
25
22
|
|
26
23
|
class TestDelay(unittest.TestCase):
|
27
24
|
def test_delay1(self):
|
28
|
-
a =
|
29
|
-
delay =
|
25
|
+
a = brainstate.State(brainstate.random.random(10, 20))
|
26
|
+
delay = brainstate.nn.Delay(a.value)
|
30
27
|
delay.register_entry('a', 1.)
|
31
28
|
delay.register_entry('b', 2.)
|
32
29
|
delay.register_entry('c', None)
|
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
|
|
36
33
|
delay.register_entry('c', 10.)
|
37
34
|
|
38
35
|
def test_rotation_delay(self):
|
39
|
-
rotation_delay =
|
36
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
40
37
|
t0 = 0.
|
41
38
|
t1, n1 = 1., 10
|
42
39
|
t2, n2 = 2., 20
|
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
|
|
53
50
|
# print(rotation_delay.max_length)
|
54
51
|
|
55
52
|
for i in range(100):
|
56
|
-
|
53
|
+
brainstate.environ.set(i=i)
|
57
54
|
rotation_delay.update(jnp.ones((1,)) * i)
|
58
55
|
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
59
56
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
|
|
61
58
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
59
|
|
63
60
|
def test_concat_delay(self):
|
64
|
-
rotation_delay =
|
61
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
65
62
|
t0 = 0.
|
66
63
|
t1, n1 = 1., 10
|
67
64
|
t2, n2 = 2., 20
|
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
|
|
74
71
|
|
75
72
|
print()
|
76
73
|
for i in range(100):
|
77
|
-
|
74
|
+
brainstate.environ.set(i=i)
|
78
75
|
rotation_delay.update(jnp.ones((1,)) * i)
|
79
76
|
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
80
77
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
|
|
83
80
|
# bst.util.clear_buffer_memory()
|
84
81
|
|
85
82
|
def test_jit_erro(self):
|
86
|
-
rotation_delay =
|
83
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
87
84
|
rotation_delay.init_state()
|
88
85
|
|
89
|
-
with
|
86
|
+
with brainstate.environ.context(i=0, t=0, jit_error_check=True):
|
90
87
|
rotation_delay.retrieve_at_time(-2.0)
|
91
|
-
with self.assertRaises(
|
88
|
+
with self.assertRaises(Exception):
|
92
89
|
rotation_delay.retrieve_at_time(-2.1)
|
93
90
|
rotation_delay.retrieve_at_time(-2.01)
|
94
|
-
with self.assertRaises(
|
91
|
+
with self.assertRaises(Exception):
|
95
92
|
rotation_delay.retrieve_at_time(-2.09)
|
96
|
-
with self.assertRaises(
|
93
|
+
with self.assertRaises(Exception):
|
97
94
|
rotation_delay.retrieve_at_time(0.1)
|
98
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(Exception):
|
99
96
|
rotation_delay.retrieve_at_time(0.01)
|
100
97
|
|
101
98
|
def test_round_interp(self):
|
102
99
|
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
103
100
|
for delay_method in ['rotation', 'concat']:
|
104
|
-
rotation_delay =
|
105
|
-
|
101
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
102
|
+
interp_method='round')
|
106
103
|
t0, n1 = 0.01, 0
|
107
104
|
t1, n1 = 1.04, 10
|
108
105
|
t2, n2 = 1.06, 11
|
109
106
|
rotation_delay.init_state()
|
110
107
|
|
111
|
-
@
|
108
|
+
@brainstate.compile.jit
|
112
109
|
def retrieve(td, i):
|
113
|
-
with
|
110
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
114
111
|
return rotation_delay.retrieve_at_time(td)
|
115
112
|
|
116
113
|
print()
|
117
114
|
for i in range(100):
|
118
|
-
t = i *
|
119
|
-
with
|
115
|
+
t = i * brainstate.environ.get_dt()
|
116
|
+
with brainstate.environ.context(i=i, t=t):
|
120
117
|
rotation_delay.update(jnp.ones(shape) * i)
|
121
118
|
print(i,
|
122
119
|
retrieve(t - t0, i),
|
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
|
|
131
128
|
for delay_method in ['rotation', 'concat']:
|
132
129
|
print(shape, delay_method)
|
133
130
|
|
134
|
-
rotation_delay =
|
135
|
-
|
131
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
132
|
+
interp_method='linear_interp')
|
136
133
|
t0, n0 = 0.01, 0.1
|
137
134
|
t1, n1 = 1.04, 10.4
|
138
135
|
t2, n2 = 1.06, 10.6
|
139
136
|
rotation_delay.init_state()
|
140
137
|
|
141
|
-
@
|
138
|
+
@brainstate.compile.jit
|
142
139
|
def retrieve(td, i):
|
143
|
-
with
|
140
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
144
141
|
return rotation_delay.retrieve_at_time(td)
|
145
142
|
|
146
143
|
print()
|
147
144
|
for i in range(100):
|
148
|
-
t = i *
|
149
|
-
with
|
145
|
+
t = i * brainstate.environ.get_dt()
|
146
|
+
with brainstate.environ.context(i=i, t=t):
|
150
147
|
rotation_delay.update(jnp.ones(shape) * i)
|
151
148
|
print(i,
|
152
149
|
retrieve(t - t0, i),
|
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
|
|
157
154
|
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
158
155
|
|
159
156
|
def test_rotation_and_concat_delay(self):
|
160
|
-
rotation_delay =
|
161
|
-
concat_delay =
|
157
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
158
|
+
concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
162
159
|
t0 = 0.
|
163
160
|
t1, n1 = 1., 10
|
164
161
|
t2, n2 = 2., 20
|
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
|
|
175
172
|
|
176
173
|
print()
|
177
174
|
for i in range(100):
|
178
|
-
|
175
|
+
brainstate.environ.set(i=i)
|
179
176
|
new = jnp.ones((1,)) * i
|
180
177
|
rotation_delay.update(new)
|
181
178
|
concat_delay.update(new)
|
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
|
|
186
183
|
|
187
184
|
class TestModule(unittest.TestCase):
|
188
185
|
def test_states(self):
|
189
|
-
class A(
|
186
|
+
class A(brainstate.nn.Module):
|
190
187
|
def __init__(self):
|
191
188
|
super().__init__()
|
192
|
-
self.a =
|
193
|
-
self.b =
|
189
|
+
self.a = brainstate.State(brainstate.random.random(10, 20))
|
190
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
194
191
|
|
195
|
-
class B(
|
192
|
+
class B(brainstate.nn.Module):
|
196
193
|
def __init__(self):
|
197
194
|
super().__init__()
|
198
195
|
self.a = A()
|
199
|
-
self.b =
|
196
|
+
self.b = brainstate.State(brainstate.random.random(10, 20))
|
200
197
|
|
201
198
|
b = B()
|
202
199
|
print()
|
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
|
|
207
204
|
|
208
205
|
|
209
206
|
if __name__ == '__main__':
|
210
|
-
with
|
207
|
+
with brainstate.environ.context(dt=0.1):
|
211
208
|
unittest.main()
|
@@ -19,12 +19,12 @@ import unittest
|
|
19
19
|
|
20
20
|
import jax.numpy as jnp
|
21
21
|
|
22
|
-
import brainstate
|
22
|
+
import brainstate
|
23
23
|
|
24
24
|
|
25
25
|
class TestMultiStepLR(unittest.TestCase):
|
26
26
|
def test1(self):
|
27
|
-
lr =
|
27
|
+
lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
|
28
28
|
for i in range(40):
|
29
29
|
r = lr(i)
|
30
30
|
if i < 10:
|
@@ -37,7 +37,7 @@ class TestMultiStepLR(unittest.TestCase):
|
|
37
37
|
self.assertTrue(jnp.allclose(r, 0.0001))
|
38
38
|
|
39
39
|
def test2(self):
|
40
|
-
lr =
|
40
|
+
lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
|
41
41
|
for i in range(40):
|
42
42
|
r = lr(i)
|
43
43
|
if i < 10:
|
@@ -13,39 +13,38 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import jax
|
21
20
|
import optax
|
22
21
|
|
23
|
-
import brainstate
|
22
|
+
import brainstate
|
24
23
|
|
25
24
|
|
26
25
|
class TestOptaxOptimizer(unittest.TestCase):
|
27
26
|
def test1(self):
|
28
|
-
class Model(
|
27
|
+
class Model(brainstate.nn.Module):
|
29
28
|
def __init__(self):
|
30
29
|
super().__init__()
|
31
|
-
self.linear1 =
|
32
|
-
self.linear2 =
|
30
|
+
self.linear1 = brainstate.nn.Linear(2, 3)
|
31
|
+
self.linear2 = brainstate.nn.Linear(3, 4)
|
33
32
|
|
34
33
|
def __call__(self, x):
|
35
34
|
return self.linear2(self.linear1(x))
|
36
35
|
|
37
|
-
x =
|
36
|
+
x = brainstate.random.randn(1, 2)
|
38
37
|
y = jax.numpy.ones((1, 4))
|
39
38
|
|
40
39
|
model = Model()
|
41
40
|
tx = optax.adam(1e-3)
|
42
|
-
optimizer =
|
43
|
-
optimizer.register_trainable_weights(model.states(
|
41
|
+
optimizer = brainstate.optim.OptaxOptimizer(tx)
|
42
|
+
optimizer.register_trainable_weights(model.states(brainstate.ParamState))
|
44
43
|
|
45
44
|
loss_fn = lambda: ((model(x) - y) ** 2).mean()
|
46
45
|
prev_loss = loss_fn()
|
47
46
|
|
48
|
-
grads =
|
47
|
+
grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
|
49
48
|
optimizer.update(grads)
|
50
49
|
|
51
50
|
new_loss = loss_fn()
|