brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +3 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/nn/_collective_ops_test.py +8 -8
  20. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  21. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  22. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  23. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  24. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  25. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  26. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  27. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  28. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  29. brainstate/nn/_event/_linear_mv_test.py +0 -1
  30. brainstate/nn/_exp_euler_test.py +5 -6
  31. brainstate/nn/_interaction/_conv_test.py +31 -33
  32. brainstate/nn/_interaction/_linear_test.py +15 -17
  33. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  34. brainstate/nn/_interaction/_poolings_test.py +19 -21
  35. brainstate/nn/_module_test.py +34 -37
  36. brainstate/optim/_lr_scheduler_test.py +3 -3
  37. brainstate/optim/_optax_optimizer_test.py +8 -9
  38. brainstate/random/_rand_funs_test.py +183 -184
  39. brainstate/random/_rand_seed_test.py +10 -12
  40. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
  41. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
  42. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  43. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  44. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax
6
4
  import numpy as np
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
  import brainstate.nn as nn
12
10
 
13
11
 
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
18
16
  (32, 8),
19
17
  (10, 20, 30),
20
18
  ]:
21
- arr = bst.random.rand(*size)
19
+ arr = brainstate.random.rand(*size)
22
20
  f = nn.Flatten(start_axis=0)
23
21
  out = f(arr)
24
22
  self.assertTrue(out.shape == (np.prod(size),))
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
29
27
  (32, 8),
30
28
  (10, 20, 30),
31
29
  ]:
32
- arr = bst.random.rand(*size)
30
+ arr = brainstate.random.rand(*size)
33
31
  f = nn.Flatten(start_axis=1)
34
32
  out = f(arr)
35
33
  self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
36
34
 
37
35
  def test_flatten3(self):
38
36
  size = (16, 32, 32, 8)
39
- arr = bst.random.rand(*size)
37
+ arr = brainstate.random.rand(*size)
40
38
  f = nn.Flatten(start_axis=0, in_size=(32, 8))
41
39
  out = f(arr)
42
40
  self.assertTrue(out.shape == (16, 32, 32 * 8))
43
41
 
44
42
  def test_flatten4(self):
45
43
  size = (16, 32, 32, 8)
46
- arr = bst.random.rand(*size)
44
+ arr = brainstate.random.rand(*size)
47
45
  f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
48
46
  out = f(arr)
49
47
  self.assertTrue(out.shape == (16, 32, 32 * 8))
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
58
56
  super().__init__(*args, **kwargs)
59
57
 
60
58
  def test_MaxPool2d_v1(self):
61
- arr = bst.random.rand(16, 32, 32, 8)
59
+ arr = brainstate.random.rand(16, 32, 32, 8)
62
60
 
63
61
  out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
64
62
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
79
77
  self.assertTrue(out.shape == (16, 17, 32, 5))
80
78
 
81
79
  def test_AvgPool2d_v1(self):
82
- arr = bst.random.rand(16, 32, 32, 8)
80
+ arr = brainstate.random.rand(16, 32, 32, 8)
83
81
 
84
82
  out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
85
83
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -107,7 +105,7 @@ class TestPool(parameterized.TestCase):
107
105
  def test_adaptive_pool1d(self, target_size):
108
106
  from brainstate.nn._interaction._poolings import _adaptive_pool1d
109
107
 
110
- arr = bst.random.rand(100)
108
+ arr = brainstate.random.rand(100)
111
109
  op = jax.numpy.mean
112
110
 
113
111
  out = _adaptive_pool1d(arr, target_size, op)
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
119
117
  self.assertTrue(out.shape == (target_size,))
120
118
 
121
119
  def test_AdaptiveAvgPool2d_v1(self):
122
- input = bst.random.randn(64, 8, 9)
120
+ input = brainstate.random.randn(64, 8, 9)
123
121
 
124
122
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
125
123
  self.assertTrue(output.shape == (64, 5, 7))
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
137
135
  self.assertTrue(output.shape == (64, 2, 3))
138
136
 
139
137
  def test_AdaptiveAvgPool2d_v2(self):
140
- bst.random.seed()
141
- input = bst.random.randn(128, 64, 32, 16)
138
+ brainstate.random.seed()
139
+ input = brainstate.random.randn(128, 64, 32, 16)
142
140
 
143
141
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
144
142
  self.assertTrue(output.shape == (128, 64, 5, 7))
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
154
152
  print()
155
153
 
156
154
  def test_AdaptiveAvgPool3d_v1(self):
157
- input = bst.random.randn(10, 128, 64, 32)
155
+ input = brainstate.random.randn(10, 128, 64, 32)
158
156
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
159
157
  output = net(input)
160
158
  self.assertTrue(output.shape == (10, 6, 5, 3))
161
159
 
162
160
  def test_AdaptiveAvgPool3d_v2(self):
163
- input = bst.random.randn(10, 20, 128, 64, 32)
161
+ input = brainstate.random.randn(10, 20, 128, 64, 32)
164
162
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
165
163
  output = net(input)
166
164
  self.assertTrue(output.shape == (10, 6, 5, 3, 32))
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
169
167
  axis=(-1, 0, 1)
170
168
  )
171
169
  def test_AdaptiveMaxPool1d_v1(self, axis):
172
- input = bst.random.randn(32, 16)
170
+ input = brainstate.random.randn(32, 16)
173
171
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
174
172
  output = net(input)
175
173
 
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
177
175
  axis=(-1, 0, 1, 2)
178
176
  )
179
177
  def test_AdaptiveMaxPool1d_v2(self, axis):
180
- input = bst.random.randn(2, 32, 16)
178
+ input = brainstate.random.randn(2, 32, 16)
181
179
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
182
180
  output = net(input)
183
181
 
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
185
183
  axis=(-1, 0, 1, 2)
186
184
  )
187
185
  def test_AdaptiveMaxPool2d_v1(self, axis):
188
- input = bst.random.randn(32, 16, 12)
186
+ input = brainstate.random.randn(32, 16, 12)
189
187
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
190
188
  output = net(input)
191
189
 
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
193
191
  axis=(-1, 0, 1, 2, 3)
194
192
  )
195
193
  def test_AdaptiveMaxPool2d_v2(self, axis):
196
- input = bst.random.randn(2, 32, 16, 12)
194
+ input = brainstate.random.randn(2, 32, 16, 12)
197
195
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
198
196
  output = net(input)
199
197
 
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
201
199
  axis=(-1, 0, 1, 2, 3)
202
200
  )
203
201
  def test_AdaptiveMaxPool3d_v1(self, axis):
204
- input = bst.random.randn(2, 128, 64, 32)
202
+ input = brainstate.random.randn(2, 128, 64, 32)
205
203
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
206
204
  output = net(input)
207
205
  print()
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
210
208
  axis=(-1, 0, 1, 2, 3, 4)
211
209
  )
212
210
  def test_AdaptiveMaxPool3d_v1(self, axis):
213
- input = bst.random.randn(2, 128, 64, 32, 16)
211
+ input = brainstate.random.randn(2, 128, 64, 32, 16)
214
212
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
215
213
  output = net(input)
216
214
 
@@ -13,20 +13,17 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax.numpy as jnp
21
- import jaxlib.xla_extension
22
19
 
23
- import brainstate as bst
20
+ import brainstate
24
21
 
25
22
 
26
23
  class TestDelay(unittest.TestCase):
27
24
  def test_delay1(self):
28
- a = bst.State(bst.random.random(10, 20))
29
- delay = bst.nn.Delay(a.value)
25
+ a = brainstate.State(brainstate.random.random(10, 20))
26
+ delay = brainstate.nn.Delay(a.value)
30
27
  delay.register_entry('a', 1.)
31
28
  delay.register_entry('b', 2.)
32
29
  delay.register_entry('c', None)
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
36
33
  delay.register_entry('c', 10.)
37
34
 
38
35
  def test_rotation_delay(self):
39
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
36
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
40
37
  t0 = 0.
41
38
  t1, n1 = 1., 10
42
39
  t2, n2 = 2., 20
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
53
50
  # print(rotation_delay.max_length)
54
51
 
55
52
  for i in range(100):
56
- bst.environ.set(i=i)
53
+ brainstate.environ.set(i=i)
57
54
  rotation_delay.update(jnp.ones((1,)) * i)
58
55
  # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
59
56
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
61
58
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
59
 
63
60
  def test_concat_delay(self):
64
- rotation_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
61
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
62
  t0 = 0.
66
63
  t1, n1 = 1., 10
67
64
  t2, n2 = 2., 20
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
74
71
 
75
72
  print()
76
73
  for i in range(100):
77
- bst.environ.set(i=i)
74
+ brainstate.environ.set(i=i)
78
75
  rotation_delay.update(jnp.ones((1,)) * i)
79
76
  print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
77
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
83
80
  # bst.util.clear_buffer_memory()
84
81
 
85
82
  def test_jit_erro(self):
86
- rotation_delay = bst.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
83
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
84
  rotation_delay.init_state()
88
85
 
89
- with bst.environ.context(i=0, t=0, jit_error_check=True):
86
+ with brainstate.environ.context(i=0, t=0, jit_error_check=True):
90
87
  rotation_delay.retrieve_at_time(-2.0)
91
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
88
+ with self.assertRaises(Exception):
92
89
  rotation_delay.retrieve_at_time(-2.1)
93
90
  rotation_delay.retrieve_at_time(-2.01)
94
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
91
+ with self.assertRaises(Exception):
95
92
  rotation_delay.retrieve_at_time(-2.09)
96
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
93
+ with self.assertRaises(Exception):
97
94
  rotation_delay.retrieve_at_time(0.1)
98
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
95
+ with self.assertRaises(Exception):
99
96
  rotation_delay.retrieve_at_time(0.01)
100
97
 
101
98
  def test_round_interp(self):
102
99
  for shape in [(1,), (1, 1), (1, 1, 1)]:
103
100
  for delay_method in ['rotation', 'concat']:
104
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
105
- interp_method='round')
101
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
102
+ interp_method='round')
106
103
  t0, n1 = 0.01, 0
107
104
  t1, n1 = 1.04, 10
108
105
  t2, n2 = 1.06, 11
109
106
  rotation_delay.init_state()
110
107
 
111
- @bst.compile.jit
108
+ @brainstate.compile.jit
112
109
  def retrieve(td, i):
113
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
110
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
114
111
  return rotation_delay.retrieve_at_time(td)
115
112
 
116
113
  print()
117
114
  for i in range(100):
118
- t = i * bst.environ.get_dt()
119
- with bst.environ.context(i=i, t=t):
115
+ t = i * brainstate.environ.get_dt()
116
+ with brainstate.environ.context(i=i, t=t):
120
117
  rotation_delay.update(jnp.ones(shape) * i)
121
118
  print(i,
122
119
  retrieve(t - t0, i),
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
131
128
  for delay_method in ['rotation', 'concat']:
132
129
  print(shape, delay_method)
133
130
 
134
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
135
- interp_method='linear_interp')
131
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
132
+ interp_method='linear_interp')
136
133
  t0, n0 = 0.01, 0.1
137
134
  t1, n1 = 1.04, 10.4
138
135
  t2, n2 = 1.06, 10.6
139
136
  rotation_delay.init_state()
140
137
 
141
- @bst.compile.jit
138
+ @brainstate.compile.jit
142
139
  def retrieve(td, i):
143
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
140
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
144
141
  return rotation_delay.retrieve_at_time(td)
145
142
 
146
143
  print()
147
144
  for i in range(100):
148
- t = i * bst.environ.get_dt()
149
- with bst.environ.context(i=i, t=t):
145
+ t = i * brainstate.environ.get_dt()
146
+ with brainstate.environ.context(i=i, t=t):
150
147
  rotation_delay.update(jnp.ones(shape) * i)
151
148
  print(i,
152
149
  retrieve(t - t0, i),
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
157
154
  self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
158
155
 
159
156
  def test_rotation_and_concat_delay(self):
160
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
161
- concat_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
157
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
158
+ concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
162
159
  t0 = 0.
163
160
  t1, n1 = 1., 10
164
161
  t2, n2 = 2., 20
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
175
172
 
176
173
  print()
177
174
  for i in range(100):
178
- bst.environ.set(i=i)
175
+ brainstate.environ.set(i=i)
179
176
  new = jnp.ones((1,)) * i
180
177
  rotation_delay.update(new)
181
178
  concat_delay.update(new)
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
186
183
 
187
184
  class TestModule(unittest.TestCase):
188
185
  def test_states(self):
189
- class A(bst.nn.Module):
186
+ class A(brainstate.nn.Module):
190
187
  def __init__(self):
191
188
  super().__init__()
192
- self.a = bst.State(bst.random.random(10, 20))
193
- self.b = bst.State(bst.random.random(10, 20))
189
+ self.a = brainstate.State(brainstate.random.random(10, 20))
190
+ self.b = brainstate.State(brainstate.random.random(10, 20))
194
191
 
195
- class B(bst.nn.Module):
192
+ class B(brainstate.nn.Module):
196
193
  def __init__(self):
197
194
  super().__init__()
198
195
  self.a = A()
199
- self.b = bst.State(bst.random.random(10, 20))
196
+ self.b = brainstate.State(brainstate.random.random(10, 20))
200
197
 
201
198
  b = B()
202
199
  print()
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
207
204
 
208
205
 
209
206
  if __name__ == '__main__':
210
- with bst.environ.context(dt=0.1):
207
+ with brainstate.environ.context(dt=0.1):
211
208
  unittest.main()
@@ -19,12 +19,12 @@ import unittest
19
19
 
20
20
  import jax.numpy as jnp
21
21
 
22
- import brainstate as bst
22
+ import brainstate
23
23
 
24
24
 
25
25
  class TestMultiStepLR(unittest.TestCase):
26
26
  def test1(self):
27
- lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
27
+ lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
28
28
  for i in range(40):
29
29
  r = lr(i)
30
30
  if i < 10:
@@ -37,7 +37,7 @@ class TestMultiStepLR(unittest.TestCase):
37
37
  self.assertTrue(jnp.allclose(r, 0.0001))
38
38
 
39
39
  def test2(self):
40
- lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
40
+ lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
41
41
  for i in range(40):
42
42
  r = lr(i)
43
43
  if i < 10:
@@ -13,39 +13,38 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax
21
20
  import optax
22
21
 
23
- import brainstate as bst
22
+ import brainstate
24
23
 
25
24
 
26
25
  class TestOptaxOptimizer(unittest.TestCase):
27
26
  def test1(self):
28
- class Model(bst.nn.Module):
27
+ class Model(brainstate.nn.Module):
29
28
  def __init__(self):
30
29
  super().__init__()
31
- self.linear1 = bst.nn.Linear(2, 3)
32
- self.linear2 = bst.nn.Linear(3, 4)
30
+ self.linear1 = brainstate.nn.Linear(2, 3)
31
+ self.linear2 = brainstate.nn.Linear(3, 4)
33
32
 
34
33
  def __call__(self, x):
35
34
  return self.linear2(self.linear1(x))
36
35
 
37
- x = bst.random.randn(1, 2)
36
+ x = brainstate.random.randn(1, 2)
38
37
  y = jax.numpy.ones((1, 4))
39
38
 
40
39
  model = Model()
41
40
  tx = optax.adam(1e-3)
42
- optimizer = bst.optim.OptaxOptimizer(tx)
43
- optimizer.register_trainable_weights(model.states(bst.ParamState))
41
+ optimizer = brainstate.optim.OptaxOptimizer(tx)
42
+ optimizer.register_trainable_weights(model.states(brainstate.ParamState))
44
43
 
45
44
  loss_fn = lambda: ((model(x) - y) ** 2).mean()
46
45
  prev_loss = loss_fn()
47
46
 
48
- grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
47
+ grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
49
48
  optimizer.update(grads)
50
49
 
51
50
  new_loss = loss_fn()