brainstate-0.1.0.post20250503-py2.py3-none-any.whl → brainstate-0.1.2-py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/functional/_activations_test.py CHANGED

@@ -25,48 +25,48 @@ from absl.testing import parameterized
 from jax._src import test_util as jtu
 from jax.test_util import check_grads
 
-import brainstate
+import brainstate
 
 
 class NNFunctionsTest(jtu.JaxTestCase):
     @jtu.skip_on_flag("jax_skip_slow_tests", True)
     def testSoftplusGrad(self):
-        check_grads(
+        check_grads(brainstate.functional.softplus, (1e-8,), order=4, )
 
     def testSoftplusGradZero(self):
-        check_grads(
+        check_grads(brainstate.functional.softplus, (0.,), order=1)
 
     def testSoftplusGradInf(self):
-        self.assertAllClose(1., jax.grad(
+        self.assertAllClose(1., jax.grad(brainstate.functional.softplus)(float('inf')))
 
     def testSoftplusGradNegInf(self):
-        check_grads(
+        check_grads(brainstate.functional.softplus, (-float('inf'),), order=1)
 
     def testSoftplusGradNan(self):
-        check_grads(
+        check_grads(brainstate.functional.softplus, (float('nan'),), order=1)
 
     @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
     def testSoftplusZero(self, dtype):
-        self.assertEqual(jnp.log(dtype(2)),
+        self.assertEqual(jnp.log(dtype(2)), brainstate.functional.softplus(dtype(0)))
 
     def testSparseplusGradZero(self):
-        check_grads(
+        check_grads(brainstate.functional.sparse_plus, (-2.,), order=1)
 
     def testSparseplusGrad(self):
-        check_grads(
+        check_grads(brainstate.functional.sparse_plus, (0.,), order=1)
 
     def testSparseplusAndSparseSigmoid(self):
         self.assertAllClose(
-            jax.grad(
-
+            jax.grad(brainstate.functional.sparse_plus)(0.),
+            brainstate.functional.sparse_sigmoid(0.),
             check_dtypes=False)
         self.assertAllClose(
-            jax.grad(
-
+            jax.grad(brainstate.functional.sparse_plus)(2.),
+            brainstate.functional.sparse_sigmoid(2.),
             check_dtypes=False)
         self.assertAllClose(
-            jax.grad(
-
+            jax.grad(brainstate.functional.sparse_plus)(-2.),
+            brainstate.functional.sparse_sigmoid(-2.),
             check_dtypes=False)
 
     # def testSquareplusGrad(self):

@@ -107,55 +107,55 @@ class NNFunctionsTest(jtu.JaxTestCase):
 
     @parameterized.parameters([float] + jtu.dtypes.floating)
     def testMishZero(self, dtype):
-        self.assertEqual(dtype(0),
+        self.assertEqual(dtype(0), brainstate.functional.mish(dtype(0)))
 
     def testReluGrad(self):
         rtol = None
-        check_grads(
-        check_grads(
-        jaxpr = jax.make_jaxpr(jax.grad(
+        check_grads(brainstate.functional.relu, (1.,), order=3, rtol=rtol)
+        check_grads(brainstate.functional.relu, (-1.,), order=3, rtol=rtol)
+        jaxpr = jax.make_jaxpr(jax.grad(brainstate.functional.relu))(0.)
         self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
 
     def testRelu6Grad(self):
         rtol = None
-        check_grads(
-        check_grads(
-        self.assertAllClose(jax.grad(
-        self.assertAllClose(jax.grad(
+        check_grads(brainstate.functional.relu6, (1.,), order=3, rtol=rtol)
+        check_grads(brainstate.functional.relu6, (-1.,), order=3, rtol=rtol)
+        self.assertAllClose(jax.grad(brainstate.functional.relu6)(0.), 0., check_dtypes=False)
+        self.assertAllClose(jax.grad(brainstate.functional.relu6)(6.), 0., check_dtypes=False)
 
     def testSoftplusValue(self):
-        val =
+        val = brainstate.functional.softplus(89.)
         self.assertAllClose(val, 89., check_dtypes=False)
 
     def testSparseplusValue(self):
-        val =
+        val = brainstate.functional.sparse_plus(89.)
         self.assertAllClose(val, 89., check_dtypes=False)
 
     def testSparsesigmoidValue(self):
-        self.assertAllClose(
-        self.assertAllClose(
-        self.assertAllClose(
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
+        self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
 
     # def testSquareplusValue(self):
     #     val = bst.functional.squareplus(1e3)
     #     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
     def testMishValue(self):
-        val =
+        val = brainstate.functional.mish(1e3)
         self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
     def testEluValue(self):
-        val =
+        val = brainstate.functional.elu(1e4)
         self.assertAllClose(val, 1e4, check_dtypes=False)
 
     def testGluValue(self):
-        val =
+        val = brainstate.functional.glu(jnp.array([1.0, 0.0]), axis=0)
         self.assertAllClose(val, jnp.array([0.5]))
 
     @parameterized.parameters(False, True)
     def testGeluIntType(self, approximate):
-        val_float =
-        val_int =
+        val_float = brainstate.functional.gelu(jnp.array(-1.0), approximate=approximate)
+        val_int = brainstate.functional.gelu(jnp.array(-1), approximate=approximate)
         self.assertAllClose(val_float, val_int)
 
     @parameterized.parameters(False, True)

@@ -166,19 +166,19 @@ class NNFunctionsTest(jtu.JaxTestCase):
         rng = jtu.rand_default(self.rng())
         args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
         self._CheckAgainstNumpy(
-            gelu_reference, partial(
+            gelu_reference, partial(brainstate.functional.gelu, approximate=approximate), args_maker,
             check_dtypes=False, tol=1e-3 if approximate else None)
 
     @parameterized.parameters(*itertools.product(
         (jnp.float32, jnp.bfloat16, jnp.float16),
-        (partial(
-         partial(
-
-
-
-
+        (partial(brainstate.functional.gelu, approximate=False),
+         partial(brainstate.functional.gelu, approximate=True),
+         brainstate.functional.relu,
+         brainstate.functional.softplus,
+         brainstate.functional.sparse_plus,
+         brainstate.functional.sigmoid,
         # bst.functional.squareplus,
-
+         brainstate.functional.mish)))
     def testDtypeMatchesInput(self, dtype, fn):
         x = jnp.zeros((), dtype=dtype)
         out = fn(x)

@@ -187,26 +187,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
     def testEluMemory(self):
         # see https://github.com/google/jax/pull/1640
         with jax.enable_checks(False): # With checks we materialize the array
-            jax.make_jaxpr(lambda:
+            jax.make_jaxpr(lambda: brainstate.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
 
     def testHardTanhMemory(self):
         # see https://github.com/google/jax/pull/1640
         with jax.enable_checks(False): # With checks we materialize the array
-            jax.make_jaxpr(lambda:
+            jax.make_jaxpr(lambda: brainstate.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
 
-    @parameterized.parameters([
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxEmptyArray(self, fn):
         x = jnp.array([], dtype=float)
         self.assertArraysEqual(fn(x), x)
 
-    @parameterized.parameters([
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxEmptyMask(self, fn):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.zeros_like(x, dtype=bool)
-        expected = jnp.full_like(x, 0.0 if fn is
+        expected = jnp.full_like(x, 0.0 if fn is brainstate.functional.softmax else -jnp.inf)
         self.assertArraysEqual(fn(x, where=m), expected)
 
-    @parameterized.parameters([
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxWhereMask(self, fn):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.array([True, False, True, True])

@@ -214,10 +214,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
         out = fn(x, where=m)
         self.assertAllClose(out[m], fn(x[m]))
 
-        probs = out if fn is
+        probs = out if fn is brainstate.functional.softmax else jnp.exp(out)
         self.assertAllClose(probs.sum(), 1.0)
 
-    @parameterized.parameters([
+    @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
     def testSoftmaxWhereGrad(self, fn):
         # regression test for https://github.com/google/jax/issues/19490
         x = jnp.array([36., 10000.])

@@ -229,46 +229,46 @@ class NNFunctionsTest(jtu.JaxTestCase):
 
     def testSoftmaxGrad(self):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
-        jtu.check_grads(
+        jtu.check_grads(brainstate.functional.softmax, (x,), order=2, atol=5e-3)
 
     def testStandardizeWhereMask(self):
         x = jnp.array([5.5, 1.3, -4.2, 0.9])
         m = jnp.array([True, False, True, True])
         x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
 
-        out_masked = jnp.take(
-        out_filtered =
+        out_masked = jnp.take(brainstate.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
+        out_filtered = brainstate.functional.standardize(x_filtered)
 
         self.assertAllClose(out_masked, out_filtered)
 
     def testOneHot(self):
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)
         expected = jnp.array([[1., 0., 0.],
                               [0., 1., 0.],
                               [0., 0., 1.]])
         self.assertAllClose(actual, expected, check_dtypes=False)
 
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3)
         expected = jnp.array([[0., 1., 0.],
                               [0., 0., 1.],
                               [1., 0., 0.]])
         self.assertAllClose(actual, expected, check_dtypes=False)
 
     def testOneHotOutOfBound(self):
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([-1, 3]), 3)
         expected = jnp.array([[0., 0., 0.],
                               [0., 0., 0.]])
         self.assertAllClose(actual, expected, check_dtypes=False)
 
     def testOneHotNonArrayInput(self):
-        actual =
+        actual = brainstate.functional.one_hot([0, 1, 2], 3)
         expected = jnp.array([[1., 0., 0.],
                               [0., 1., 0.],
                               [0., 0., 1.]])
         self.assertAllClose(actual, expected, check_dtypes=False)
 
     def testOneHotCustomDtype(self):
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
         expected = jnp.array([[True, False, False],
                               [False, True, False],
                               [False, False, True]])

@@ -279,14 +279,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
                               [0., 0., 1.],
                               [1., 0., 0.]]).T
 
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
         self.assertAllClose(actual, expected, check_dtypes=False)
 
-        actual =
+        actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
         self.assertAllClose(actual, expected, check_dtypes=False)
 
     def testTanhExists(self):
-        print(
+        print(brainstate.functional.tanh) # doesn't crash
 
     def testCustomJVPLeak(self):
         # https://github.com/google/jax/issues/8171

@@ -295,7 +295,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
         a = jnp.array(1.)
 
         def f(hx, _):
-            hx =
+            hx = brainstate.functional.sigmoid(hx + a)
             return hx, None
 
         hx = jnp.array(0.)
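The removed side of these hunks is truncated by the diff viewer, but every added line follows one pattern: after a plain `import brainstate`, the tests spell out each call through the fully qualified `brainstate.functional` namespace. A minimal sketch of that calling pattern, using only functions that appear in the added lines above (the printed values are illustrative, not taken from the diff):

import jax
import jax.numpy as jnp

import brainstate

# Call activations through the brainstate.functional namespace, as the updated tests do.
y = brainstate.functional.softplus(0.)                            # softplus(0) == log(2)
g = jax.grad(brainstate.functional.relu)(1.)                      # derivative of relu at 1.0
onehot = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)   # 3x3 one-hot matrix
probs = brainstate.functional.softmax(jnp.array([5.5, 1.3, -4.2, 0.9]))

print(y, g, onehot.shape, float(probs.sum()))                     # probs sum to 1.0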
brainstate/functional/_others.py CHANGED
brainstate/functional/_spikes.py CHANGED
brainstate/graph/_graph_node.py CHANGED
@@ -15,8 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import annotations
-
 from abc import ABCMeta
 from copy import deepcopy
 from typing import Any, Callable, Type, TypeVar, Tuple, TYPE_CHECKING, Mapping, Iterator, Sequence

@@ -210,7 +208,7 @@ class List(Node):
     def __len__(self):
         return len(vars(self))
 
-    def __add__(self, other: Sequence[A]) -> List[A]:
+    def __add__(self, other: Sequence[A]) -> 'List[A]':
         return List(list(self) + list(other))
 
     def append(self, value):
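The two hunks above are connected: once `from __future__ import annotations` is removed, annotations in the class body are evaluated eagerly, so a method that names its own class in a return annotation must fall back to a string forward reference, which is exactly the change to `List.__add__`. A minimal sketch of the same pattern with a hypothetical `Box` class (not part of brainstate):

from typing import Sequence, TypeVar

A = TypeVar('A')


class Box:
    """Hypothetical container; illustrates the quoted forward reference only."""

    def __init__(self, items: Sequence[A]):
        self.items = list(items)

    # Without `from __future__ import annotations`, `Box` is not yet bound while the
    # class body executes, so the return annotation must be the string 'Box' rather
    # than the bare name -- the same fix as the 'List[A]' annotation above.
    def __add__(self, other: Sequence[A]) -> 'Box':
        return Box(self.items + list(other))


print((Box([1, 2]) + [3]).items)  # [1, 2, 3]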
brainstate/graph/_graph_node_test.py CHANGED

@@ -13,63 +13,61 @@
 # limitations under the License.
 # ==============================================================================
 
-from __future__ import annotations
-
 import unittest
 
-import brainstate
+import brainstate
 
 
 class TestSequential(unittest.TestCase):
     def test1(self):
-        s =
-
-        graphdef, states =
+        s = brainstate.graph.Sequential(brainstate.nn.Linear(1, 2),
+                                        brainstate.nn.Linear(2, 3))
+        graphdef, states = brainstate.graph.treefy_split(s)
         print(states)
         self.assertTrue(len(states.to_flat()) == 2)
 
 
 class TestStateRetrieve(unittest.TestCase):
     def test_list_of_states_1(self):
-        class Model(
+        class Model(brainstate.graph.Node):
             def __init__(self):
                 self.a = [1, 2, 3]
-                self.b = [
+                self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]
 
         m = Model()
-        graphdef, states =
+        graphdef, states = brainstate.graph.treefy_split(m)
         print(states.to_flat())
         self.assertTrue(len(states.to_flat()) == 3)
 
     def test_list_of_states_2(self):
-        class Model(
+        class Model(brainstate.graph.Node):
             def __init__(self):
                 self.a = [1, 2, 3]
-                self.b = [
+                self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]
 
         m = Model()
-        graphdef, states =
+        graphdef, states = brainstate.graph.treefy_split(m)
         print(states.to_flat())
         self.assertTrue(len(states.to_flat()) == 3)
 
     def test_list_of_node_1(self):
-        class Model(
+        class Model(brainstate.graph.Node):
             def __init__(self):
                 self.a = [1, 2, 3]
-                self.b = [
+                self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
 
         m = Model()
-        graphdef, states =
+        graphdef, states = brainstate.graph.treefy_split(m)
         print(states.to_flat())
         self.assertTrue(len(states.to_flat()) == 2)
 
     def test_list_of_node_2(self):
-        class Model(
+        class Model(brainstate.graph.Node):
             def __init__(self):
                 self.a = [1, 2, 3]
-                self.b = [
+                self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)], (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]
 
         m = Model()
-        graphdef, states =
+        graphdef, states = brainstate.graph.treefy_split(m)
         print(states.to_flat())
         self.assertTrue(len(states.to_flat()) == 4)
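With the right-hand sides restored, the intent of these tests is visible: a model is a `brainstate.graph.Node` subclass whose attributes may hold nested lists of `brainstate.State` objects or `brainstate.nn` modules, and `brainstate.graph.treefy_split` separates the static graph definition from the collected states. A small sketch assembled from the same calls used in the tests above (the expected count mirrors their assertions, not an independent claim):

import brainstate


class Model(brainstate.graph.Node):
    def __init__(self):
        self.a = [1, 2, 3]                                    # plain Python data
        self.b = [brainstate.State(1), brainstate.State(2)]   # states nested in a list


m = Model()
graphdef, states = brainstate.graph.treefy_split(m)

# Following test_list_of_states_1 above, only the State instances in `self.b`
# appear in the flattened state mapping; the integers in `self.a` do not.
print(len(states.to_flat()))  # expected: 2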
brainstate/graph/_graph_operation.py CHANGED

@@ -18,8 +18,10 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import (
-
+from typing import (
+    Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
+    Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload
+)
 
 import jax
 import numpy as np