brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -25,48 +25,48 @@ from absl.testing import parameterized
25
25
  from jax._src import test_util as jtu
26
26
  from jax.test_util import check_grads
27
27
 
28
- import brainstate as bst
28
+ import brainstate
29
29
 
30
30
 
31
31
  class NNFunctionsTest(jtu.JaxTestCase):
32
32
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
33
33
  def testSoftplusGrad(self):
34
- check_grads(bst.functional.softplus, (1e-8,), order=4, )
34
+ check_grads(brainstate.functional.softplus, (1e-8,), order=4, )
35
35
 
36
36
  def testSoftplusGradZero(self):
37
- check_grads(bst.functional.softplus, (0.,), order=1)
37
+ check_grads(brainstate.functional.softplus, (0.,), order=1)
38
38
 
39
39
  def testSoftplusGradInf(self):
40
- self.assertAllClose(1., jax.grad(bst.functional.softplus)(float('inf')))
40
+ self.assertAllClose(1., jax.grad(brainstate.functional.softplus)(float('inf')))
41
41
 
42
42
  def testSoftplusGradNegInf(self):
43
- check_grads(bst.functional.softplus, (-float('inf'),), order=1)
43
+ check_grads(brainstate.functional.softplus, (-float('inf'),), order=1)
44
44
 
45
45
  def testSoftplusGradNan(self):
46
- check_grads(bst.functional.softplus, (float('nan'),), order=1)
46
+ check_grads(brainstate.functional.softplus, (float('nan'),), order=1)
47
47
 
48
48
  @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
49
49
  def testSoftplusZero(self, dtype):
50
- self.assertEqual(jnp.log(dtype(2)), bst.functional.softplus(dtype(0)))
50
+ self.assertEqual(jnp.log(dtype(2)), brainstate.functional.softplus(dtype(0)))
51
51
 
52
52
  def testSparseplusGradZero(self):
53
- check_grads(bst.functional.sparse_plus, (-2.,), order=1)
53
+ check_grads(brainstate.functional.sparse_plus, (-2.,), order=1)
54
54
 
55
55
  def testSparseplusGrad(self):
56
- check_grads(bst.functional.sparse_plus, (0.,), order=1)
56
+ check_grads(brainstate.functional.sparse_plus, (0.,), order=1)
57
57
 
58
58
  def testSparseplusAndSparseSigmoid(self):
59
59
  self.assertAllClose(
60
- jax.grad(bst.functional.sparse_plus)(0.),
61
- bst.functional.sparse_sigmoid(0.),
60
+ jax.grad(brainstate.functional.sparse_plus)(0.),
61
+ brainstate.functional.sparse_sigmoid(0.),
62
62
  check_dtypes=False)
63
63
  self.assertAllClose(
64
- jax.grad(bst.functional.sparse_plus)(2.),
65
- bst.functional.sparse_sigmoid(2.),
64
+ jax.grad(brainstate.functional.sparse_plus)(2.),
65
+ brainstate.functional.sparse_sigmoid(2.),
66
66
  check_dtypes=False)
67
67
  self.assertAllClose(
68
- jax.grad(bst.functional.sparse_plus)(-2.),
69
- bst.functional.sparse_sigmoid(-2.),
68
+ jax.grad(brainstate.functional.sparse_plus)(-2.),
69
+ brainstate.functional.sparse_sigmoid(-2.),
70
70
  check_dtypes=False)
71
71
 
72
72
  # def testSquareplusGrad(self):
@@ -107,55 +107,55 @@ class NNFunctionsTest(jtu.JaxTestCase):
107
107
 
108
108
  @parameterized.parameters([float] + jtu.dtypes.floating)
109
109
  def testMishZero(self, dtype):
110
- self.assertEqual(dtype(0), bst.functional.mish(dtype(0)))
110
+ self.assertEqual(dtype(0), brainstate.functional.mish(dtype(0)))
111
111
 
112
112
  def testReluGrad(self):
113
113
  rtol = None
114
- check_grads(bst.functional.relu, (1.,), order=3, rtol=rtol)
115
- check_grads(bst.functional.relu, (-1.,), order=3, rtol=rtol)
116
- jaxpr = jax.make_jaxpr(jax.grad(bst.functional.relu))(0.)
114
+ check_grads(brainstate.functional.relu, (1.,), order=3, rtol=rtol)
115
+ check_grads(brainstate.functional.relu, (-1.,), order=3, rtol=rtol)
116
+ jaxpr = jax.make_jaxpr(jax.grad(brainstate.functional.relu))(0.)
117
117
  self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
118
118
 
119
119
  def testRelu6Grad(self):
120
120
  rtol = None
121
- check_grads(bst.functional.relu6, (1.,), order=3, rtol=rtol)
122
- check_grads(bst.functional.relu6, (-1.,), order=3, rtol=rtol)
123
- self.assertAllClose(jax.grad(bst.functional.relu6)(0.), 0., check_dtypes=False)
124
- self.assertAllClose(jax.grad(bst.functional.relu6)(6.), 0., check_dtypes=False)
121
+ check_grads(brainstate.functional.relu6, (1.,), order=3, rtol=rtol)
122
+ check_grads(brainstate.functional.relu6, (-1.,), order=3, rtol=rtol)
123
+ self.assertAllClose(jax.grad(brainstate.functional.relu6)(0.), 0., check_dtypes=False)
124
+ self.assertAllClose(jax.grad(brainstate.functional.relu6)(6.), 0., check_dtypes=False)
125
125
 
126
126
  def testSoftplusValue(self):
127
- val = bst.functional.softplus(89.)
127
+ val = brainstate.functional.softplus(89.)
128
128
  self.assertAllClose(val, 89., check_dtypes=False)
129
129
 
130
130
  def testSparseplusValue(self):
131
- val = bst.functional.sparse_plus(89.)
131
+ val = brainstate.functional.sparse_plus(89.)
132
132
  self.assertAllClose(val, 89., check_dtypes=False)
133
133
 
134
134
  def testSparsesigmoidValue(self):
135
- self.assertAllClose(bst.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
136
- self.assertAllClose(bst.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
137
- self.assertAllClose(bst.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
135
+ self.assertAllClose(brainstate.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
136
+ self.assertAllClose(brainstate.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
137
+ self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
138
138
 
139
139
  # def testSquareplusValue(self):
140
140
  # val = bst.functional.squareplus(1e3)
141
141
  # self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
142
142
 
143
143
  def testMishValue(self):
144
- val = bst.functional.mish(1e3)
144
+ val = brainstate.functional.mish(1e3)
145
145
  self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
146
146
 
147
147
  def testEluValue(self):
148
- val = bst.functional.elu(1e4)
148
+ val = brainstate.functional.elu(1e4)
149
149
  self.assertAllClose(val, 1e4, check_dtypes=False)
150
150
 
151
151
  def testGluValue(self):
152
- val = bst.functional.glu(jnp.array([1.0, 0.0]), axis=0)
152
+ val = brainstate.functional.glu(jnp.array([1.0, 0.0]), axis=0)
153
153
  self.assertAllClose(val, jnp.array([0.5]))
154
154
 
155
155
  @parameterized.parameters(False, True)
156
156
  def testGeluIntType(self, approximate):
157
- val_float = bst.functional.gelu(jnp.array(-1.0), approximate=approximate)
158
- val_int = bst.functional.gelu(jnp.array(-1), approximate=approximate)
157
+ val_float = brainstate.functional.gelu(jnp.array(-1.0), approximate=approximate)
158
+ val_int = brainstate.functional.gelu(jnp.array(-1), approximate=approximate)
159
159
  self.assertAllClose(val_float, val_int)
160
160
 
161
161
  @parameterized.parameters(False, True)
@@ -166,19 +166,19 @@ class NNFunctionsTest(jtu.JaxTestCase):
166
166
  rng = jtu.rand_default(self.rng())
167
167
  args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
168
168
  self._CheckAgainstNumpy(
169
- gelu_reference, partial(bst.functional.gelu, approximate=approximate), args_maker,
169
+ gelu_reference, partial(brainstate.functional.gelu, approximate=approximate), args_maker,
170
170
  check_dtypes=False, tol=1e-3 if approximate else None)
171
171
 
172
172
  @parameterized.parameters(*itertools.product(
173
173
  (jnp.float32, jnp.bfloat16, jnp.float16),
174
- (partial(bst.functional.gelu, approximate=False),
175
- partial(bst.functional.gelu, approximate=True),
176
- bst.functional.relu,
177
- bst.functional.softplus,
178
- bst.functional.sparse_plus,
179
- bst.functional.sigmoid,
174
+ (partial(brainstate.functional.gelu, approximate=False),
175
+ partial(brainstate.functional.gelu, approximate=True),
176
+ brainstate.functional.relu,
177
+ brainstate.functional.softplus,
178
+ brainstate.functional.sparse_plus,
179
+ brainstate.functional.sigmoid,
180
180
  # bst.functional.squareplus,
181
- bst.functional.mish)))
181
+ brainstate.functional.mish)))
182
182
  def testDtypeMatchesInput(self, dtype, fn):
183
183
  x = jnp.zeros((), dtype=dtype)
184
184
  out = fn(x)
@@ -187,26 +187,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
187
187
  def testEluMemory(self):
188
188
  # see https://github.com/google/jax/pull/1640
189
189
  with jax.enable_checks(False): # With checks we materialize the array
190
- jax.make_jaxpr(lambda: bst.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
190
+ jax.make_jaxpr(lambda: brainstate.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
191
191
 
192
192
  def testHardTanhMemory(self):
193
193
  # see https://github.com/google/jax/pull/1640
194
194
  with jax.enable_checks(False): # With checks we materialize the array
195
- jax.make_jaxpr(lambda: bst.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
195
+ jax.make_jaxpr(lambda: brainstate.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
196
196
 
197
- @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
197
+ @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
198
198
  def testSoftmaxEmptyArray(self, fn):
199
199
  x = jnp.array([], dtype=float)
200
200
  self.assertArraysEqual(fn(x), x)
201
201
 
202
- @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
202
+ @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
203
203
  def testSoftmaxEmptyMask(self, fn):
204
204
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
205
205
  m = jnp.zeros_like(x, dtype=bool)
206
- expected = jnp.full_like(x, 0.0 if fn is bst.functional.softmax else -jnp.inf)
206
+ expected = jnp.full_like(x, 0.0 if fn is brainstate.functional.softmax else -jnp.inf)
207
207
  self.assertArraysEqual(fn(x, where=m), expected)
208
208
 
209
- @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
209
+ @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
210
210
  def testSoftmaxWhereMask(self, fn):
211
211
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
212
212
  m = jnp.array([True, False, True, True])
@@ -214,10 +214,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
214
214
  out = fn(x, where=m)
215
215
  self.assertAllClose(out[m], fn(x[m]))
216
216
 
217
- probs = out if fn is bst.functional.softmax else jnp.exp(out)
217
+ probs = out if fn is brainstate.functional.softmax else jnp.exp(out)
218
218
  self.assertAllClose(probs.sum(), 1.0)
219
219
 
220
- @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
220
+ @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
221
221
  def testSoftmaxWhereGrad(self, fn):
222
222
  # regression test for https://github.com/google/jax/issues/19490
223
223
  x = jnp.array([36., 10000.])
@@ -229,46 +229,46 @@ class NNFunctionsTest(jtu.JaxTestCase):
229
229
 
230
230
  def testSoftmaxGrad(self):
231
231
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
232
- jtu.check_grads(bst.functional.softmax, (x,), order=2, atol=5e-3)
232
+ jtu.check_grads(brainstate.functional.softmax, (x,), order=2, atol=5e-3)
233
233
 
234
234
  def testStandardizeWhereMask(self):
235
235
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
236
236
  m = jnp.array([True, False, True, True])
237
237
  x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
238
238
 
239
- out_masked = jnp.take(bst.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
240
- out_filtered = bst.functional.standardize(x_filtered)
239
+ out_masked = jnp.take(brainstate.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
240
+ out_filtered = brainstate.functional.standardize(x_filtered)
241
241
 
242
242
  self.assertAllClose(out_masked, out_filtered)
243
243
 
244
244
  def testOneHot(self):
245
- actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3)
245
+ actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)
246
246
  expected = jnp.array([[1., 0., 0.],
247
247
  [0., 1., 0.],
248
248
  [0., 0., 1.]])
249
249
  self.assertAllClose(actual, expected, check_dtypes=False)
250
250
 
251
- actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3)
251
+ actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3)
252
252
  expected = jnp.array([[0., 1., 0.],
253
253
  [0., 0., 1.],
254
254
  [1., 0., 0.]])
255
255
  self.assertAllClose(actual, expected, check_dtypes=False)
256
256
 
257
257
  def testOneHotOutOfBound(self):
258
- actual = bst.functional.one_hot(jnp.array([-1, 3]), 3)
258
+ actual = brainstate.functional.one_hot(jnp.array([-1, 3]), 3)
259
259
  expected = jnp.array([[0., 0., 0.],
260
260
  [0., 0., 0.]])
261
261
  self.assertAllClose(actual, expected, check_dtypes=False)
262
262
 
263
263
  def testOneHotNonArrayInput(self):
264
- actual = bst.functional.one_hot([0, 1, 2], 3)
264
+ actual = brainstate.functional.one_hot([0, 1, 2], 3)
265
265
  expected = jnp.array([[1., 0., 0.],
266
266
  [0., 1., 0.],
267
267
  [0., 0., 1.]])
268
268
  self.assertAllClose(actual, expected, check_dtypes=False)
269
269
 
270
270
  def testOneHotCustomDtype(self):
271
- actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
271
+ actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
272
272
  expected = jnp.array([[True, False, False],
273
273
  [False, True, False],
274
274
  [False, False, True]])
@@ -279,14 +279,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
279
279
  [0., 0., 1.],
280
280
  [1., 0., 0.]]).T
281
281
 
282
- actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
282
+ actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
283
283
  self.assertAllClose(actual, expected, check_dtypes=False)
284
284
 
285
- actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
285
+ actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
286
286
  self.assertAllClose(actual, expected, check_dtypes=False)
287
287
 
288
288
  def testTanhExists(self):
289
- print(bst.functional.tanh) # doesn't crash
289
+ print(brainstate.functional.tanh) # doesn't crash
290
290
 
291
291
  def testCustomJVPLeak(self):
292
292
  # https://github.com/google/jax/issues/8171
@@ -295,7 +295,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
295
295
  a = jnp.array(1.)
296
296
 
297
297
  def f(hx, _):
298
- hx = bst.functional.sigmoid(hx + a)
298
+ hx = brainstate.functional.sigmoid(hx + a)
299
299
  return hx, None
300
300
 
301
301
  hx = jnp.array(0.)
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from typing import Optional, Union
19
17
 
20
18
  import brainunit as u
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from functools import partial
19
17
 
20
18
  import jax
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  __all__ = [
19
18
  'spike_bitwise_or',
@@ -15,8 +15,6 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
- from __future__ import annotations
19
-
20
18
  from abc import ABCMeta
21
19
  from copy import deepcopy
22
20
  from typing import Any, Callable, Type, TypeVar, Tuple, TYPE_CHECKING, Mapping, Iterator, Sequence
@@ -210,7 +208,7 @@ class List(Node):
210
208
  def __len__(self):
211
209
  return len(vars(self))
212
210
 
213
- def __add__(self, other: Sequence[A]) -> List[A]:
211
+ def __add__(self, other: Sequence[A]) -> 'List[A]':
214
212
  return List(list(self) + list(other))
215
213
 
216
214
  def append(self, value):
@@ -13,63 +13,61 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
- import brainstate as bst
18
+ import brainstate
21
19
 
22
20
 
23
21
  class TestSequential(unittest.TestCase):
24
22
  def test1(self):
25
- s = bst.graph.Sequential(bst.nn.Linear(1, 2),
26
- bst.nn.Linear(2, 3))
27
- graphdef, states = bst.graph.treefy_split(s)
23
+ s = brainstate.graph.Sequential(brainstate.nn.Linear(1, 2),
24
+ brainstate.nn.Linear(2, 3))
25
+ graphdef, states = brainstate.graph.treefy_split(s)
28
26
  print(states)
29
27
  self.assertTrue(len(states.to_flat()) == 2)
30
28
 
31
29
 
32
30
  class TestStateRetrieve(unittest.TestCase):
33
31
  def test_list_of_states_1(self):
34
- class Model(bst.graph.Node):
32
+ class Model(brainstate.graph.Node):
35
33
  def __init__(self):
36
34
  self.a = [1, 2, 3]
37
- self.b = [bst.State(1), bst.State(2), bst.State(3)]
35
+ self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]
38
36
 
39
37
  m = Model()
40
- graphdef, states = bst.graph.treefy_split(m)
38
+ graphdef, states = brainstate.graph.treefy_split(m)
41
39
  print(states.to_flat())
42
40
  self.assertTrue(len(states.to_flat()) == 3)
43
41
 
44
42
  def test_list_of_states_2(self):
45
- class Model(bst.graph.Node):
43
+ class Model(brainstate.graph.Node):
46
44
  def __init__(self):
47
45
  self.a = [1, 2, 3]
48
- self.b = [bst.State(1), [bst.State(2), bst.State(3)]]
46
+ self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]
49
47
 
50
48
  m = Model()
51
- graphdef, states = bst.graph.treefy_split(m)
49
+ graphdef, states = brainstate.graph.treefy_split(m)
52
50
  print(states.to_flat())
53
51
  self.assertTrue(len(states.to_flat()) == 3)
54
52
 
55
53
  def test_list_of_node_1(self):
56
- class Model(bst.graph.Node):
54
+ class Model(brainstate.graph.Node):
57
55
  def __init__(self):
58
56
  self.a = [1, 2, 3]
59
- self.b = [bst.nn.Linear(1, 2), bst.nn.Linear(2, 3)]
57
+ self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
60
58
 
61
59
  m = Model()
62
- graphdef, states = bst.graph.treefy_split(m)
60
+ graphdef, states = brainstate.graph.treefy_split(m)
63
61
  print(states.to_flat())
64
62
  self.assertTrue(len(states.to_flat()) == 2)
65
63
 
66
64
  def test_list_of_node_2(self):
67
- class Model(bst.graph.Node):
65
+ class Model(brainstate.graph.Node):
68
66
  def __init__(self):
69
67
  self.a = [1, 2, 3]
70
- self.b = [bst.nn.Linear(1, 2), [bst.nn.Linear(2, 3)], (bst.nn.Linear(3, 4), bst.nn.Linear(4, 5))]
68
+ self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)], (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]
71
69
 
72
70
  m = Model()
73
- graphdef, states = bst.graph.treefy_split(m)
71
+ graphdef, states = brainstate.graph.treefy_split(m)
74
72
  print(states.to_flat())
75
73
  self.assertTrue(len(states.to_flat()) == 4)
@@ -18,8 +18,10 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import dataclasses
21
- from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
22
- Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
21
+ from typing import (
22
+ Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload
24
+ )
23
25
 
24
26
  import jax
25
27
  import numpy as np