brainstate 0.1.4__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +6 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +22 -17
  7. brainstate/augment/_mapping_test.py +162 -0
  8. brainstate/compile/_conditions.py +2 -2
  9. brainstate/compile/_make_jaxpr.py +59 -6
  10. brainstate/compile/_progress_bar.py +2 -2
  11. brainstate/environ.py +19 -19
  12. brainstate/functional/_activations_test.py +12 -12
  13. brainstate/graph/_graph_operation.py +69 -69
  14. brainstate/graph/_graph_operation_test.py +2 -2
  15. brainstate/mixin.py +0 -17
  16. brainstate/nn/_collective_ops.py +4 -4
  17. brainstate/nn/_common.py +7 -19
  18. brainstate/nn/_dropout_test.py +2 -2
  19. brainstate/nn/_dynamics.py +53 -35
  20. brainstate/nn/_elementwise.py +30 -30
  21. brainstate/nn/_exp_euler.py +13 -16
  22. brainstate/nn/_inputs.py +1 -1
  23. brainstate/nn/_linear.py +4 -4
  24. brainstate/nn/_module.py +6 -6
  25. brainstate/nn/_module_test.py +1 -1
  26. brainstate/nn/_normalizations.py +11 -11
  27. brainstate/nn/_normalizations_test.py +6 -6
  28. brainstate/nn/_poolings.py +24 -24
  29. brainstate/nn/_synapse.py +1 -12
  30. brainstate/nn/_utils.py +1 -1
  31. brainstate/nn/metrics.py +4 -4
  32. brainstate/optim/_optax_optimizer.py +8 -8
  33. brainstate/random/_rand_funs.py +37 -37
  34. brainstate/random/_rand_funs_test.py +3 -3
  35. brainstate/random/_rand_seed.py +7 -7
  36. brainstate/random/_rand_state.py +13 -7
  37. brainstate/surrogate.py +40 -40
  38. brainstate/util/pretty_pytree.py +10 -10
  39. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  40. brainstate/util/struct.py +7 -7
  41. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  42. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/RECORD +45 -45
  43. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  44. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.4.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/nn/metrics.py CHANGED
@@ -61,12 +61,12 @@ class Average(Metric):
61
61
  Example usage::
62
62
 
63
63
  >>> import jax.numpy as jnp
64
- >>> import brainstate as bst
64
+ >>> import brainstate as brainstate
65
65
 
66
66
  >>> batch_loss = jnp.array([1, 2, 3, 4])
67
67
  >>> batch_loss2 = jnp.array([3, 2, 1, 0])
68
68
 
69
- >>> metrics = bst.nn.metrics.Average()
69
+ >>> metrics = brainstate.nn.metrics.Average()
70
70
  >>> metrics.compute()
71
71
  Array(nan, dtype=float32)
72
72
  >>> metrics.update(values=batch_loss)
@@ -223,7 +223,7 @@ class Accuracy(Average):
223
223
 
224
224
  Example usage::
225
225
 
226
- >>> import brainstate as bst
226
+ >>> import brainstate as brainstate
227
227
  >>> import jax, jax.numpy as jnp
228
228
 
229
229
  >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
@@ -231,7 +231,7 @@ class Accuracy(Average):
231
231
  >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
232
232
  >>> labels2 = jnp.array([0, 1, 1, 1, 1])
233
233
 
234
- >>> metrics = bst.nn.metrics.Accuracy()
234
+ >>> metrics = brainstate.nn.metrics.Accuracy()
235
235
  >>> metrics.compute()
236
236
  Array(nan, dtype=float32)
237
237
  >>> metrics.update(logits=logits, labels=labels)
brainstate/optim/_optax_optimizer.py CHANGED
@@ -36,29 +36,29 @@ class OptaxOptimizer(Optimizer):
36
36
 
37
37
  >>> import jax
38
38
  >>> import jax.numpy as jnp
39
- >>> import brainstate as bst
39
+ >>> import brainstate as brainstate
40
40
  >>> import optax
41
41
  ...
42
- >>> class Model(bst.nn.Module):
42
+ >>> class Model(brainstate.nn.Module):
43
43
  ... def __init__(self):
44
44
  ... super().__init__()
45
- ... self.linear1 = bst.nn.Linear(2, 3)
46
- ... self.linear2 = bst.nn.Linear(3, 4)
45
+ ... self.linear1 = brainstate.nn.Linear(2, 3)
46
+ ... self.linear2 = brainstate.nn.Linear(3, 4)
47
47
  ... def __call__(self, x):
48
48
  ... return self.linear2(self.linear1(x))
49
49
  ...
50
- >>> x = bst.random.randn(1, 2)
50
+ >>> x = brainstate.random.randn(1, 2)
51
51
  >>> y = jnp.ones((1, 4))
52
52
  ...
53
53
  >>> model = Model()
54
54
  >>> tx = optax.adam(1e-3)
55
- >>> optimizer = bst.optim.OptaxOptimizer(tx)
56
- >>> optimizer.register_trainable_weights(model.states(bst.ParamState))
55
+ >>> optimizer = brainstate.optim.OptaxOptimizer(tx)
56
+ >>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
57
57
  ...
58
58
  >>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
59
59
  >>> loss_fn()
60
60
  Array(1.7055722, dtype=float32)
61
- >>> grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
61
+ >>> grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
62
62
  >>> optimizer.update(grads)
63
63
  >>> loss_fn()
64
64
  Array(1.6925814, dtype=float32)
brainstate/random/_rand_funs.py CHANGED
@@ -78,8 +78,8 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
78
78
 
79
79
  Examples
80
80
  --------
81
- >>> import brainstate as bst
82
- >>> bst.random.rand(3,2)
81
+ >>> import brainstate as brainstate
82
+ >>> brainstate.random.rand(3,2)
83
83
  array([[ 0.14022471, 0.96360618], #random
84
84
  [ 0.37601032, 0.25528411], #random
85
85
  [ 0.49313049, 0.94909878]]) #random
@@ -135,31 +135,31 @@ def randint(low, high=None, size: Optional[Size] = None,
135
135
 
136
136
  Examples
137
137
  --------
138
- >>> import brainstate as bst
139
- >>> bst.random.randint(2, size=10)
138
+ >>> import brainstate as brainstate
139
+ >>> brainstate.random.randint(2, size=10)
140
140
  array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
141
- >>> bst.random.randint(1, size=10)
141
+ >>> brainstate.random.randint(1, size=10)
142
142
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
143
143
 
144
144
  Generate a 2 x 4 array of ints between 0 and 4, inclusive:
145
145
 
146
- >>> bst.random.randint(5, size=(2, 4))
146
+ >>> brainstate.random.randint(5, size=(2, 4))
147
147
  array([[4, 0, 2, 1], # random
148
148
  [3, 2, 2, 0]])
149
149
 
150
150
  Generate a 1 x 3 array with 3 different upper bounds
151
151
 
152
- >>> bst.random.randint(1, [3, 5, 10])
152
+ >>> brainstate.random.randint(1, [3, 5, 10])
153
153
  array([2, 2, 9]) # random
154
154
 
155
155
  Generate a 1 by 3 array with 3 different lower bounds
156
156
 
157
- >>> bst.random.randint([1, 5, 7], 10)
157
+ >>> brainstate.random.randint([1, 5, 7], 10)
158
158
  array([9, 8, 7]) # random
159
159
 
160
160
  Generate a 2 by 4 array using broadcasting with dtype of uint8
161
161
 
162
- >>> bst.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
162
+ >>> brainstate.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
163
163
  array([[ 8, 6, 9, 7], # random
164
164
  [ 1, 16, 9, 12]], dtype=uint8)
165
165
  """
@@ -219,12 +219,12 @@ def random_integers(low,
219
219
 
220
220
  Examples
221
221
  --------
222
- >>> import brainstate as bst
223
- >>> bst.random.random_integers(5)
222
+ >>> import brainstate as brainstate
223
+ >>> brainstate.random.random_integers(5)
224
224
  4 # random
225
- >>> type(bst.random.random_integers(5))
225
+ >>> type(brainstate.random.random_integers(5))
226
226
  <class 'numpy.int64'>
227
- >>> bst.random.random_integers(5, size=(3,2))
227
+ >>> brainstate.random.random_integers(5, size=(3,2))
228
228
  array([[5, 4], # random
229
229
  [3, 3],
230
230
  [4, 5]])
@@ -233,13 +233,13 @@ def random_integers(low,
233
233
  numbers between 0 and 2.5, inclusive (*i.e.*, from the set
234
234
  :math:`{0, 5/8, 10/8, 15/8, 20/8}`):
235
235
 
236
- >>> 2.5 * (bst.random.random_integers(5, size=(5,)) - 1) / 4.
236
+ >>> 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
237
237
  array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
238
238
 
239
239
  Roll two six sided dice 1000 times and sum the results:
240
240
 
241
- >>> d1 = bst.random.random_integers(1, 6, 1000)
242
- >>> d2 = bst.random.random_integers(1, 6, 1000)
241
+ >>> d1 = brainstate.random.random_integers(1, 6, 1000)
242
+ >>> d2 = brainstate.random.random_integers(1, 6, 1000)
243
243
  >>> dsums = d1 + d2
244
244
 
245
245
  Display results as a histogram:
@@ -301,13 +301,13 @@ def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
301
301
 
302
302
  Examples
303
303
  --------
304
- >>> import brainstate as bst
305
- >>> bst.random.randn()
304
+ >>> import brainstate as brainstate
305
+ >>> brainstate.random.randn()
306
306
  2.1923875335537315 # random
307
307
 
308
308
  Two-by-four array of samples from N(3, 6.25):
309
309
 
310
- >>> 3 + 2.5 * bst.random.randn(2, 4)
310
+ >>> 3 + 2.5 * brainstate.random.randn(2, 4)
311
311
  array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
312
312
  [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
313
313
  """
@@ -359,17 +359,17 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
359
359
 
360
360
  Examples
361
361
  --------
362
- >>> import brainstate as bst
363
- >>> bst.random.random_sample()
362
+ >>> import brainstate as brainstate
363
+ >>> brainstate.random.random_sample()
364
364
  0.47108547995356098 # random
365
- >>> type(bst.random.random_sample())
365
+ >>> type(brainstate.random.random_sample())
366
366
  <class 'float'>
367
- >>> bst.random.random_sample((5,))
367
+ >>> brainstate.random.random_sample((5,))
368
368
  array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
369
369
 
370
370
  Three-by-two array of random numbers from [-5, 0):
371
371
 
372
- >>> 5 * bst.random.random_sample((3, 2)) - 5
372
+ >>> 5 * brainstate.random.random_sample((3, 2)) - 5
373
373
  array([[-3.99149989, -0.52338984], # random
374
374
  [-2.99091858, -0.79479508],
375
375
  [-1.23204345, -1.75224494]])
@@ -450,34 +450,34 @@ def choice(a, size: Optional[Size] = None, replace=True, p=None,
450
450
  --------
451
451
  Generate a uniform random sample from np.arange(5) of size 3:
452
452
 
453
- >>> import brainstate as bst
454
- >>> bst.random.choice(5, 3)
453
+ >>> import brainstate as brainstate
454
+ >>> brainstate.random.choice(5, 3)
455
455
  array([0, 3, 4]) # random
456
456
  >>> #This is equivalent to brainpy.math.random.randint(0,5,3)
457
457
 
458
458
  Generate a non-uniform random sample from np.arange(5) of size 3:
459
459
 
460
- >>> bst.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
460
+ >>> brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
461
461
  array([3, 3, 0]) # random
462
462
 
463
463
  Generate a uniform random sample from np.arange(5) of size 3 without
464
464
  replacement:
465
465
 
466
- >>> bst.random.choice(5, 3, replace=False)
466
+ >>> brainstate.random.choice(5, 3, replace=False)
467
467
  array([3,1,0]) # random
468
468
  >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
469
469
 
470
470
  Generate a non-uniform random sample from np.arange(5) of size
471
471
  3 without replacement:
472
472
 
473
- >>> bst.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
473
+ >>> brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
474
474
  array([2, 3, 0]) # random
475
475
 
476
476
  Any of the above can be repeated with an arbitrary array-like
477
477
  instead of just integers. For instance:
478
478
 
479
479
  >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
480
- >>> bst.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
480
+ >>> brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
481
481
  array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
482
482
  dtype='<U11')
483
483
  """
@@ -519,15 +519,15 @@ def permutation(x,
519
519
 
520
520
  Examples
521
521
  --------
522
- >>> import brainstate as bst
523
- >>> bst.random.permutation(10)
522
+ >>> import brainstate as brainstate
523
+ >>> brainstate.random.permutation(10)
524
524
  array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
525
525
 
526
- >>> bst.random.permutation([1, 4, 9, 12, 15])
526
+ >>> brainstate.random.permutation([1, 4, 9, 12, 15])
527
527
  array([15, 1, 9, 4, 12]) # random
528
528
 
529
529
  >>> arr = np.arange(9).reshape((3, 3))
530
- >>> bst.random.permutation(arr)
530
+ >>> brainstate.random.permutation(arr)
531
531
  array([[6, 7, 8], # random
532
532
  [0, 1, 2],
533
533
  [3, 4, 5]])
@@ -557,16 +557,16 @@ def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
557
557
 
558
558
  Examples
559
559
  --------
560
- >>> import brainstate as bst
560
+ >>> import brainstate as brainstate
561
561
  >>> arr = np.arange(10)
562
- >>> bst.random.shuffle(arr)
562
+ >>> brainstate.random.shuffle(arr)
563
563
  >>> arr
564
564
  [1 7 5 2 9 4 3 6 0 8] # random
565
565
 
566
566
  Multi-dimensional arrays are only shuffled along the first axis:
567
567
 
568
568
  >>> arr = np.arange(9).reshape((3, 3))
569
- >>> bst.random.shuffle(arr)
569
+ >>> brainstate.random.shuffle(arr)
570
570
  >>> arr
571
571
  array([[3, 4, 5], # random
572
572
  [6, 7, 8],
brainstate/random/_rand_funs_test.py CHANGED
@@ -562,6 +562,6 @@ class TestRandom(unittest.TestCase):
562
562
 
563
563
  # class TestRandomKey(unittest.TestCase):
564
564
  # def test_clear_memory(self):
565
- # bst.random.split_key()
566
- # print(bst.random.DEFAULT.value)
567
- # self.assertTrue(isinstance(bst.random.DEFAULT.value, np.ndarray))
565
+ # brainstate.random.split_key()
566
+ # print(brainstate.random.DEFAULT.value)
567
+ # self.assertTrue(isinstance(brainstate.random.DEFAULT.value, np.ndarray))
brainstate/random/_rand_seed.py CHANGED
@@ -183,16 +183,16 @@ def seed_context(seed_or_key: SeedOrKey):
183
183
 
184
184
  Examples:
185
185
 
186
- >>> import brainstate as bst
187
- >>> print(bst.random.rand(2))
186
+ >>> import brainstate as brainstate
187
+ >>> print(brainstate.random.rand(2))
188
188
  [0.57721865 0.9820676 ]
189
- >>> print(bst.random.rand(2))
189
+ >>> print(brainstate.random.rand(2))
190
190
  [0.8511752 0.95312667]
191
- >>> with bst.random.seed_context(42):
192
- ... print(bst.random.rand(2))
191
+ >>> with brainstate.random.seed_context(42):
192
+ ... print(brainstate.random.rand(2))
193
193
  [0.95598125 0.4032725 ]
194
- >>> with bst.random.seed_context(42):
195
- ... print(bst.random.rand(2))
194
+ >>> with brainstate.random.seed_context(42):
195
+ ... print(brainstate.random.rand(2))
196
196
  [0.95598125 0.4032725 ]
197
197
 
198
198
  Args:
brainstate/random/_rand_state.py CHANGED
@@ -384,7 +384,10 @@ class RandomState(State):
384
384
  loc = _check_py_seq(loc)
385
385
  scale = _check_py_seq(scale)
386
386
  if size is None:
387
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
387
+ size = lax.broadcast_shapes(
388
+ jnp.shape(loc) if loc is not None else (),
389
+ jnp.shape(scale) if scale is not None else ()
390
+ )
388
391
  key = self.split_key() if key is None else _formalize_key(key)
389
392
  dtype = dtype or environ.dftype()
390
393
  r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
@@ -399,7 +402,10 @@ class RandomState(State):
399
402
  loc = _check_py_seq(loc)
400
403
  scale = _check_py_seq(scale)
401
404
  if size is None:
402
- size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
405
+ size = lax.broadcast_shapes(
406
+ jnp.shape(scale) if scale is not None else (),
407
+ jnp.shape(loc) if loc is not None else ()
408
+ )
403
409
  key = self.split_key() if key is None else _formalize_key(key)
404
410
  dtype = dtype or environ.dftype()
405
411
  r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
@@ -456,7 +462,7 @@ class RandomState(State):
456
462
  dtype: DTypeLike = None):
457
463
  shape = _check_py_seq(shape)
458
464
  if size is None:
459
- size = jnp.shape(shape)
465
+ size = jnp.shape(shape) if shape is not None else ()
460
466
  key = self.split_key() if key is None else _formalize_key(key)
461
467
  dtype = dtype or environ.dftype()
462
468
  r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
@@ -477,7 +483,7 @@ class RandomState(State):
477
483
  dtype: DTypeLike = None):
478
484
  df = _check_py_seq(df)
479
485
  if size is None:
480
- size = jnp.shape(size)
486
+ size = jnp.shape(size) if size is not None else ()
481
487
  key = self.split_key() if key is None else _formalize_key(key)
482
488
  dtype = dtype or environ.dftype()
483
489
  r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
@@ -606,8 +612,8 @@ class RandomState(State):
606
612
 
607
613
  if size is None:
608
614
  size = jnp.broadcast_shapes(
609
- jnp.shape(mean),
610
- jnp.shape(sigma)
615
+ jnp.shape(mean) if mean is not None else (),
616
+ jnp.shape(sigma) if sigma is not None else ()
611
617
  )
612
618
  key = self.split_key() if key is None else _formalize_key(key)
613
619
  dtype = dtype or environ.dftype()
@@ -822,7 +828,7 @@ class RandomState(State):
822
828
  a = _check_py_seq(a)
823
829
  scale = _check_py_seq(scale)
824
830
  if size is None:
825
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
831
+ size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
826
832
  else:
827
833
  if jnp.size(a) > 1:
828
834
  raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
brainstate/surrogate.py CHANGED
@@ -106,11 +106,11 @@ class Surrogate(PrettyObject):
106
106
  Examples
107
107
  --------
108
108
 
109
- >>> import brainstate as bst
109
+ >>> import brainstate as brainstate
110
110
  >>> import brainstate.nn as nn
111
111
  >>> import jax.numpy as jnp
112
112
 
113
- >>> class MySurrogate(bst.surrogate.Surrogate):
113
+ >>> class MySurrogate(brainstate.surrogate.Surrogate):
114
114
  ... def __init__(self, alpha=1.):
115
115
  ... super().__init__()
116
116
  ... self.alpha = alpha
@@ -236,11 +236,11 @@ def sigmoid(
236
236
 
237
237
  >>> import jax
238
238
  >>> import brainstate.nn as nn
239
- >>> import brainstate as bst
239
+ >>> import brainstate as brainstate
240
240
  >>> import matplotlib.pyplot as plt
241
241
  >>> xs = jax.numpy.linspace(-2, 2, 1000)
242
242
  >>> for alpha in [1., 2., 4.]:
243
- >>> grads = bst.augment.vector_grad(bst.surrogate.sigmoid)(xs, alpha)
243
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
244
244
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
245
245
  >>> plt.legend()
246
246
  >>> plt.show()
@@ -355,11 +355,11 @@ def piecewise_quadratic(
355
355
 
356
356
  >>> import jax
357
357
  >>> import brainstate.nn as nn
358
- >>> import brainstate as bst
358
+ >>> import brainstate as brainstate
359
359
  >>> import matplotlib.pyplot as plt
360
360
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
361
361
  >>> for alpha in [0.5, 1., 2., 4.]:
362
- >>> grads = bst.augment.vector_grad(bst.surrogate.piecewise_quadratic)(xs, alpha)
362
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
363
363
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
364
364
  >>> plt.legend()
365
365
  >>> plt.show()
@@ -522,11 +522,11 @@ def piecewise_exp(
522
522
 
523
523
  >>> import jax
524
524
  >>> import brainstate.nn as nn
525
- >>> import brainstate as bst
525
+ >>> import brainstate as brainstate
526
526
  >>> import matplotlib.pyplot as plt
527
527
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
528
528
  >>> for alpha in [0.5, 1., 2., 4.]:
529
- >>> grads = bst.augment.vector_grad(bst.surrogate.piecewise_exp)(xs, alpha)
529
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
530
530
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
531
531
  >>> plt.legend()
532
532
  >>> plt.show()
@@ -621,11 +621,11 @@ def soft_sign(
621
621
 
622
622
  >>> import jax
623
623
  >>> import brainstate.nn as nn
624
- >>> import brainstate as bst
624
+ >>> import brainstate as brainstate
625
625
  >>> import matplotlib.pyplot as plt
626
626
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
627
627
  >>> for alpha in [0.5, 1., 2., 4.]:
628
- >>> grads = bst.augment.vector_grad(bst.surrogate.soft_sign)(xs, alpha)
628
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
629
629
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
630
630
  >>> plt.legend()
631
631
  >>> plt.show()
@@ -706,11 +706,11 @@ def arctan(
706
706
 
707
707
  >>> import jax
708
708
  >>> import brainstate.nn as nn
709
- >>> import brainstate as bst
709
+ >>> import brainstate as brainstate
710
710
  >>> import matplotlib.pyplot as plt
711
711
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
712
712
  >>> for alpha in [0.5, 1., 2., 4.]:
713
- >>> grads = bst.augment.vector_grad(bst.surrogate.arctan)(xs, alpha)
713
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
714
714
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
715
715
  >>> plt.legend()
716
716
  >>> plt.show()
@@ -804,11 +804,11 @@ def nonzero_sign_log(
804
804
 
805
805
  >>> import jax
806
806
  >>> import brainstate.nn as nn
807
- >>> import brainstate as bst
807
+ >>> import brainstate as brainstate
808
808
  >>> import matplotlib.pyplot as plt
809
809
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
810
810
  >>> for alpha in [0.5, 1., 2., 4.]:
811
- >>> grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
811
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
812
812
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
813
813
  >>> plt.legend()
814
814
  >>> plt.show()
@@ -893,11 +893,11 @@ def erf(
893
893
 
894
894
  >>> import jax
895
895
  >>> import brainstate.nn as nn
896
- >>> import brainstate as bst
896
+ >>> import brainstate as brainstate
897
897
  >>> import matplotlib.pyplot as plt
898
898
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
899
899
  >>> for alpha in [0.5, 1., 2., 4.]:
900
- >>> grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
900
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
901
901
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
902
902
  >>> plt.legend()
903
903
  >>> plt.show()
@@ -1000,12 +1000,12 @@ def piecewise_leaky_relu(
1000
1000
 
1001
1001
  >>> import jax
1002
1002
  >>> import brainstate.nn as nn
1003
- >>> import brainstate as bst
1003
+ >>> import brainstate as brainstate
1004
1004
  >>> import matplotlib.pyplot as plt
1005
1005
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1006
1006
  >>> for c in [0.01, 0.05, 0.1]:
1007
1007
  >>> for w in [1., 2.]:
1008
- >>> grads1 = bst.augment.vector_grad(bst.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
1008
+ >>> grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
1009
1009
  >>> plt.plot(xs, grads1, label=f'x={c}, w={w}')
1010
1010
  >>> plt.legend()
1011
1011
  >>> plt.show()
@@ -1113,12 +1113,12 @@ def squarewave_fourier_series(
1113
1113
 
1114
1114
  >>> import jax
1115
1115
  >>> import brainstate.nn as nn
1116
- >>> import brainstate as bst
1116
+ >>> import brainstate as brainstate
1117
1117
  >>> import matplotlib.pyplot as plt
1118
1118
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1119
1119
  >>> for n in [2, 4, 8]:
1120
- >>> f = bst.surrogate.SquarewaveFourierSeries(n=n)
1121
- >>> grads1 = bst.augment.vector_grad(f)(xs)
1120
+ >>> f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
1121
+ >>> grads1 = brainstate.augment.vector_grad(f)(xs)
1122
1122
  >>> plt.plot(xs, grads1, label=f'n={n}')
1123
1123
  >>> plt.legend()
1124
1124
  >>> plt.show()
@@ -1214,12 +1214,12 @@ def s2nn(
1214
1214
 
1215
1215
  >>> import jax
1216
1216
  >>> import brainstate.nn as nn
1217
- >>> import brainstate as bst
1217
+ >>> import brainstate as brainstate
1218
1218
  >>> import matplotlib.pyplot as plt
1219
1219
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1220
- >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 4., 1.)
1220
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
1221
1221
  >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
1222
- >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 8., 2.)
1222
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
1223
1223
  >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
1224
1224
  >>> plt.legend()
1225
1225
  >>> plt.show()
@@ -1315,11 +1315,11 @@ def q_pseudo_spike(
1315
1315
 
1316
1316
  >>> import jax
1317
1317
  >>> import brainstate.nn as nn
1318
- >>> import brainstate as bst
1318
+ >>> import brainstate as brainstate
1319
1319
  >>> import matplotlib.pyplot as plt
1320
1320
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1321
1321
  >>> for alpha in [0.5, 1., 2., 4.]:
1322
- >>> grads = bst.augment.vector_grad(bst.surrogate.q_pseudo_spike)(xs, alpha)
1322
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
1323
1323
  >>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
1324
1324
  >>> plt.legend()
1325
1325
  >>> plt.show()
@@ -1413,10 +1413,10 @@ def leaky_relu(
1413
1413
 
1414
1414
  >>> import jax
1415
1415
  >>> import brainstate.nn as nn
1416
- >>> import brainstate as bst
1416
+ >>> import brainstate as brainstate
1417
1417
  >>> import matplotlib.pyplot as plt
1418
1418
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1419
- >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1419
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
1420
1420
  >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1421
1421
  >>> plt.legend()
1422
1422
  >>> plt.show()
@@ -1517,10 +1517,10 @@ def log_tailed_relu(
1517
1517
 
1518
1518
  >>> import jax
1519
1519
  >>> import brainstate.nn as nn
1520
- >>> import brainstate as bst
1520
+ >>> import brainstate as brainstate
1521
1521
  >>> import matplotlib.pyplot as plt
1522
1522
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1523
- >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1523
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
1524
1524
  >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1525
1525
  >>> plt.legend()
1526
1526
  >>> plt.show()
@@ -1596,12 +1596,12 @@ def relu_grad(
1596
1596
 
1597
1597
  >>> import jax
1598
1598
  >>> import brainstate.nn as nn
1599
- >>> import brainstate as bst
1599
+ >>> import brainstate as brainstate
1600
1600
  >>> import matplotlib.pyplot as plt
1601
1601
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1602
1602
  >>> for s in [0.5, 1.]:
1603
1603
  >>> for w in [1, 2.]:
1604
- >>> grads = bst.augment.vector_grad(bst.surrogate.relu_grad)(xs, s, w)
1604
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
1605
1605
  >>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
1606
1606
  >>> plt.legend()
1607
1607
  >>> plt.show()
@@ -1678,11 +1678,11 @@ def gaussian_grad(
1678
1678
 
1679
1679
  >>> import jax
1680
1680
  >>> import brainstate.nn as nn
1681
- >>> import brainstate as bst
1681
+ >>> import brainstate as brainstate
1682
1682
  >>> import matplotlib.pyplot as plt
1683
1683
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1684
1684
  >>> for s in [0.5, 1., 2.]:
1685
- >>> grads = bst.augment.vector_grad(bst.surrogate.gaussian_grad)(xs, s, 0.5)
1685
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
1686
1686
  >>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
1687
1687
  >>> plt.legend()
1688
1688
  >>> plt.show()
@@ -1773,10 +1773,10 @@ def multi_gaussian_grad(
1773
1773
 
1774
1774
  >>> import jax
1775
1775
  >>> import brainstate.nn as nn
1776
- >>> import brainstate as bst
1776
+ >>> import brainstate as brainstate
1777
1777
  >>> import matplotlib.pyplot as plt
1778
1778
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1779
- >>> grads = bst.augment.vector_grad(bst.surrogate.multi_gaussian_grad)(xs)
1779
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
1780
1780
  >>> plt.plot(xs, grads)
1781
1781
  >>> plt.show()
1782
1782
 
@@ -1855,11 +1855,11 @@ def inv_square_grad(
1855
1855
 
1856
1856
  >>> import jax
1857
1857
  >>> import brainstate.nn as nn
1858
- >>> import brainstate as bst
1858
+ >>> import brainstate as brainstate
1859
1859
  >>> import matplotlib.pyplot as plt
1860
1860
  >>> xs = jax.numpy.linspace(-1, 1, 1000)
1861
1861
  >>> for alpha in [1., 10., 100.]:
1862
- >>> grads = bst.augment.vector_grad(bst.surrogate.inv_square_grad)(xs, alpha)
1862
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
1863
1863
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1864
1864
  >>> plt.legend()
1865
1865
  >>> plt.show()
@@ -1929,11 +1929,11 @@ def slayer_grad(
1929
1929
 
1930
1930
  >>> import jax
1931
1931
  >>> import brainstate.nn as nn
1932
- >>> import brainstate as bst
1932
+ >>> import brainstate as brainstate
1933
1933
  >>> import matplotlib.pyplot as plt
1934
1934
  >>> xs = jax.numpy.linspace(-3, 3, 1000)
1935
1935
  >>> for alpha in [0.5, 1., 2., 4.]:
1936
- >>> grads = bst.augment.vector_grad(bst.surrogate.slayer_grad)(xs, alpha)
1936
+ >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
1937
1937
  >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1938
1938
  >>> plt.legend()
1939
1939
  >>> plt.show()