brainstate 0.1.5__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. brainstate/__init__.py +1 -1
  2. brainstate/_state.py +5 -5
  3. brainstate/augment/_autograd.py +31 -12
  4. brainstate/augment/_autograd_test.py +46 -46
  5. brainstate/augment/_eval_shape.py +4 -4
  6. brainstate/augment/_mapping.py +13 -8
  7. brainstate/compile/_conditions.py +2 -2
  8. brainstate/compile/_make_jaxpr.py +59 -6
  9. brainstate/compile/_progress_bar.py +2 -2
  10. brainstate/environ.py +19 -19
  11. brainstate/functional/_activations_test.py +12 -12
  12. brainstate/graph/_graph_operation.py +69 -69
  13. brainstate/graph/_graph_operation_test.py +2 -2
  14. brainstate/mixin.py +0 -17
  15. brainstate/nn/_collective_ops.py +4 -4
  16. brainstate/nn/_dropout_test.py +2 -2
  17. brainstate/nn/_dynamics.py +53 -35
  18. brainstate/nn/_elementwise.py +30 -30
  19. brainstate/nn/_linear.py +4 -4
  20. brainstate/nn/_module.py +6 -6
  21. brainstate/nn/_module_test.py +1 -1
  22. brainstate/nn/_normalizations.py +11 -11
  23. brainstate/nn/_normalizations_test.py +6 -6
  24. brainstate/nn/_poolings.py +24 -24
  25. brainstate/nn/_synapse.py +1 -12
  26. brainstate/nn/_utils.py +1 -1
  27. brainstate/nn/metrics.py +4 -4
  28. brainstate/optim/_optax_optimizer.py +8 -8
  29. brainstate/random/_rand_funs.py +37 -37
  30. brainstate/random/_rand_funs_test.py +3 -3
  31. brainstate/random/_rand_seed.py +7 -7
  32. brainstate/surrogate.py +40 -40
  33. brainstate/util/pretty_pytree.py +10 -10
  34. brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
  35. brainstate/util/struct.py +7 -7
  36. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
  37. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/RECORD +40 -40
  38. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
  39. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
  40. {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
@@ -55,8 +55,8 @@ class Flatten(Module):
55
55
  end_axis: last dim to flatten (default = -1).
56
56
 
57
57
  Examples::
58
- >>> import brainstate as bst
59
- >>> inp = bst.random.randn(32, 1, 5, 5)
58
+ >>> import brainstate as brainstate
59
+ >>> inp = brainstate.random.randn(32, 1, 5, 5)
60
60
  >>> # With default parameters
61
61
  >>> m = Flatten()
62
62
  >>> output = m(inp)
@@ -334,10 +334,10 @@ class MaxPool1d(_MaxPool):
334
334
 
335
335
  Examples::
336
336
 
337
- >>> import brainstate as bst
337
+ >>> import brainstate as brainstate
338
338
  >>> # pool of size=3, stride=2
339
339
  >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
340
- >>> input = bst.random.randn(20, 50, 16)
340
+ >>> input = brainstate.random.randn(20, 50, 16)
341
341
  >>> output = m(input)
342
342
  >>> output.shape
343
343
  (20, 24, 16)
@@ -418,12 +418,12 @@ class MaxPool2d(_MaxPool):
418
418
 
419
419
  Examples::
420
420
 
421
- >>> import brainstate as bst
421
+ >>> import brainstate as brainstate
422
422
  >>> # pool of square window of size=3, stride=2
423
423
  >>> m = MaxPool2d(3, stride=2)
424
424
  >>> # pool of non-square window
425
425
  >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
426
- >>> input = bst.random.randn(20, 50, 32, 16)
426
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
427
427
  >>> output = m(input)
428
428
  >>> output.shape
429
429
  (20, 24, 31, 16)
@@ -509,12 +509,12 @@ class MaxPool3d(_MaxPool):
509
509
 
510
510
  Examples::
511
511
 
512
- >>> import brainstate as bst
512
+ >>> import brainstate as brainstate
513
513
  >>> # pool of square window of size=3, stride=2
514
514
  >>> m = MaxPool3d(3, stride=2)
515
515
  >>> # pool of non-square window
516
516
  >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
517
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
517
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
518
518
  >>> output = m(input)
519
519
  >>> output.shape
520
520
  (20, 24, 43, 15, 16)
@@ -588,10 +588,10 @@ class AvgPool1d(_AvgPool):
588
588
 
589
589
  Examples::
590
590
 
591
- >>> import brainstate as bst
591
+ >>> import brainstate as brainstate
592
592
  >>> # pool with window of size=3, stride=2
593
593
  >>> m = AvgPool1d(3, stride=2)
594
- >>> input = bst.random.randn(20, 50, 16)
594
+ >>> input = brainstate.random.randn(20, 50, 16)
595
595
  >>> m(input).shape
596
596
  (20, 24, 16)
597
597
 
@@ -665,12 +665,12 @@ class AvgPool2d(_AvgPool):
665
665
 
666
666
  Examples::
667
667
 
668
- >>> import brainstate as bst
668
+ >>> import brainstate as brainstate
669
669
  >>> # pool of square window of size=3, stride=2
670
670
  >>> m = AvgPool2d(3, stride=2)
671
671
  >>> # pool of non-square window
672
672
  >>> m = AvgPool2d((3, 2), stride=(2, 1))
673
- >>> input = bst.random.randn(20, 50, 32, , 16)
673
+ >>> input = brainstate.random.randn(20, 50, 32, , 16)
674
674
  >>> output = m(input)
675
675
  >>> output.shape
676
676
  (20, 24, 31, 16)
@@ -753,12 +753,12 @@ class AvgPool3d(_AvgPool):
753
753
 
754
754
  Examples::
755
755
 
756
- >>> import brainstate as bst
756
+ >>> import brainstate as brainstate
757
757
  >>> # pool of square window of size=3, stride=2
758
758
  >>> m = AvgPool3d(3, stride=2)
759
759
  >>> # pool of non-square window
760
760
  >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
761
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
761
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
762
762
  >>> output = m(input)
763
763
  >>> output.shape
764
764
  (20, 24, 43, 15, 16)
@@ -931,10 +931,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
931
931
 
932
932
  Examples:
933
933
 
934
- >>> import brainstate as bst
934
+ >>> import brainstate as brainstate
935
935
  >>> # target output size of 5
936
936
  >>> m = AdaptiveMaxPool1d(5)
937
- >>> input = bst.random.randn(1, 64, 8)
937
+ >>> input = brainstate.random.randn(1, 64, 8)
938
938
  >>> output = m(input)
939
939
  >>> output.shape
940
940
  (1, 5, 8)
@@ -979,22 +979,22 @@ class AdaptiveAvgPool2d(_AdaptivePool):
979
979
 
980
980
  Examples:
981
981
 
982
- >>> import brainstate as bst
982
+ >>> import brainstate as brainstate
983
983
  >>> # target output size of 5x7
984
984
  >>> m = AdaptiveMaxPool2d((5, 7))
985
- >>> input = bst.random.randn(1, 8, 9, 64)
985
+ >>> input = brainstate.random.randn(1, 8, 9, 64)
986
986
  >>> output = m(input)
987
987
  >>> output.shape
988
988
  (1, 5, 7, 64)
989
989
  >>> # target output size of 7x7 (square)
990
990
  >>> m = AdaptiveMaxPool2d(7)
991
- >>> input = bst.random.randn(1, 10, 9, 64)
991
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
992
992
  >>> output = m(input)
993
993
  >>> output.shape
994
994
  (1, 7, 7, 64)
995
995
  >>> # target output size of 10x7
996
996
  >>> m = AdaptiveMaxPool2d((None, 7))
997
- >>> input = bst.random.randn(1, 10, 9, 64)
997
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
998
998
  >>> output = m(input)
999
999
  >>> output.shape
1000
1000
  (1, 10, 7, 64)
@@ -1040,22 +1040,22 @@ class AdaptiveAvgPool3d(_AdaptivePool):
1040
1040
 
1041
1041
  Examples:
1042
1042
 
1043
- >>> import brainstate as bst
1043
+ >>> import brainstate as brainstate
1044
1044
  >>> # target output size of 5x7x9
1045
1045
  >>> m = AdaptiveMaxPool3d((5, 7, 9))
1046
- >>> input = bst.random.randn(1, 8, 9, 10, 64)
1046
+ >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1047
1047
  >>> output = m(input)
1048
1048
  >>> output.shape
1049
1049
  (1, 5, 7, 9, 64)
1050
1050
  >>> # target output size of 7x7x7 (cube)
1051
1051
  >>> m = AdaptiveMaxPool3d(7)
1052
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1052
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1053
1053
  >>> output = m(input)
1054
1054
  >>> output.shape
1055
1055
  (1, 7, 7, 7, 64)
1056
1056
  >>> # target output size of 7x9x8
1057
1057
  >>> m = AdaptiveMaxPool3d((7, None, None))
1058
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1058
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1059
1059
  >>> output = m(input)
1060
1060
  >>> output.shape
1061
1061
  (1, 7, 9, 8, 64)
brainstate/nn/_synapse.py CHANGED
@@ -123,9 +123,6 @@ class Expon(Synapse, AlignPost):
123
123
  g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
124
124
  self.g.value = self.sum_delta_inputs(g)
125
125
  if x is not None: self.g.value += x
126
- return self.update_return()
127
-
128
- def update_return(self) -> PyTree:
129
126
  return self.g.value
130
127
 
131
128
 
@@ -232,9 +229,6 @@ class DualExpon(Synapse, AlignPost):
232
229
  if x is not None:
233
230
  self.g_rise.value += x
234
231
  self.g_decay.value += x
235
- return self.update_return()
236
-
237
- def update_return(self) -> PyTree:
238
232
  return self.a * (self.g_decay.value - self.g_rise.value)
239
233
 
240
234
 
@@ -414,12 +408,8 @@ class AMPA(Synapse):
414
408
  t = environ.get('t')
415
409
  self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
416
410
  TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
417
- dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
411
+ dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
418
412
  self.g.value = exp_euler_step(dg, self.g.value)
419
- return self.update_return()
420
-
421
- def update_return(self) -> PyTree:
422
- """Return the synaptic conductance value."""
423
413
  return self.g.value
424
414
 
425
415
 
@@ -513,4 +503,3 @@ class GABAa(AMPA):
513
503
  in_size=in_size,
514
504
  g_initializer=g_initializer
515
505
  )
516
-
brainstate/nn/_utils.py CHANGED
@@ -62,7 +62,7 @@ def count_parameters(
62
62
 
63
63
  Parameters:
64
64
  -----------
65
- model : bst.nn.Module
65
+ model : brainstate.nn.Module
66
66
  The neural network model for which to count parameters.
67
67
 
68
68
  Returns:
brainstate/nn/metrics.py CHANGED
@@ -61,12 +61,12 @@ class Average(Metric):
61
61
  Example usage::
62
62
 
63
63
  >>> import jax.numpy as jnp
64
- >>> import brainstate as bst
64
+ >>> import brainstate as brainstate
65
65
 
66
66
  >>> batch_loss = jnp.array([1, 2, 3, 4])
67
67
  >>> batch_loss2 = jnp.array([3, 2, 1, 0])
68
68
 
69
- >>> metrics = bst.nn.metrics.Average()
69
+ >>> metrics = brainstate.nn.metrics.Average()
70
70
  >>> metrics.compute()
71
71
  Array(nan, dtype=float32)
72
72
  >>> metrics.update(values=batch_loss)
@@ -223,7 +223,7 @@ class Accuracy(Average):
223
223
 
224
224
  Example usage::
225
225
 
226
- >>> import brainstate as bst
226
+ >>> import brainstate as brainstate
227
227
  >>> import jax, jax.numpy as jnp
228
228
 
229
229
  >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
@@ -231,7 +231,7 @@ class Accuracy(Average):
231
231
  >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
232
232
  >>> labels2 = jnp.array([0, 1, 1, 1, 1])
233
233
 
234
- >>> metrics = bst.nn.metrics.Accuracy()
234
+ >>> metrics = brainstate.nn.metrics.Accuracy()
235
235
  >>> metrics.compute()
236
236
  Array(nan, dtype=float32)
237
237
  >>> metrics.update(logits=logits, labels=labels)
@@ -36,29 +36,29 @@ class OptaxOptimizer(Optimizer):
36
36
 
37
37
  >>> import jax
38
38
  >>> import jax.numpy as jnp
39
- >>> import brainstate as bst
39
+ >>> import brainstate as brainstate
40
40
  >>> import optax
41
41
  ...
42
- >>> class Model(bst.nn.Module):
42
+ >>> class Model(brainstate.nn.Module):
43
43
  ... def __init__(self):
44
44
  ... super().__init__()
45
- ... self.linear1 = bst.nn.Linear(2, 3)
46
- ... self.linear2 = bst.nn.Linear(3, 4)
45
+ ... self.linear1 = brainstate.nn.Linear(2, 3)
46
+ ... self.linear2 = brainstate.nn.Linear(3, 4)
47
47
  ... def __call__(self, x):
48
48
  ... return self.linear2(self.linear1(x))
49
49
  ...
50
- >>> x = bst.random.randn(1, 2)
50
+ >>> x = brainstate.random.randn(1, 2)
51
51
  >>> y = jnp.ones((1, 4))
52
52
  ...
53
53
  >>> model = Model()
54
54
  >>> tx = optax.adam(1e-3)
55
- >>> optimizer = bst.optim.OptaxOptimizer(tx)
56
- >>> optimizer.register_trainable_weights(model.states(bst.ParamState))
55
+ >>> optimizer = brainstate.optim.OptaxOptimizer(tx)
56
+ >>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
57
57
  ...
58
58
  >>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
59
59
  >>> loss_fn()
60
60
  Array(1.7055722, dtype=float32)
61
- >>> grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
61
+ >>> grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
62
62
  >>> optimizer.update(grads)
63
63
  >>> loss_fn()
64
64
  Array(1.6925814, dtype=float32)
@@ -78,8 +78,8 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
78
78
 
79
79
  Examples
80
80
  --------
81
- >>> import brainstate as bst
82
- >>> bst.random.rand(3,2)
81
+ >>> import brainstate as brainstate
82
+ >>> brainstate.random.rand(3,2)
83
83
  array([[ 0.14022471, 0.96360618], #random
84
84
  [ 0.37601032, 0.25528411], #random
85
85
  [ 0.49313049, 0.94909878]]) #random
@@ -135,31 +135,31 @@ def randint(low, high=None, size: Optional[Size] = None,
135
135
 
136
136
  Examples
137
137
  --------
138
- >>> import brainstate as bst
139
- >>> bst.random.randint(2, size=10)
138
+ >>> import brainstate as brainstate
139
+ >>> brainstate.random.randint(2, size=10)
140
140
  array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
141
- >>> bst.random.randint(1, size=10)
141
+ >>> brainstate.random.randint(1, size=10)
142
142
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
143
143
 
144
144
  Generate a 2 x 4 array of ints between 0 and 4, inclusive:
145
145
 
146
- >>> bst.random.randint(5, size=(2, 4))
146
+ >>> brainstate.random.randint(5, size=(2, 4))
147
147
  array([[4, 0, 2, 1], # random
148
148
  [3, 2, 2, 0]])
149
149
 
150
150
  Generate a 1 x 3 array with 3 different upper bounds
151
151
 
152
- >>> bst.random.randint(1, [3, 5, 10])
152
+ >>> brainstate.random.randint(1, [3, 5, 10])
153
153
  array([2, 2, 9]) # random
154
154
 
155
155
  Generate a 1 by 3 array with 3 different lower bounds
156
156
 
157
- >>> bst.random.randint([1, 5, 7], 10)
157
+ >>> brainstate.random.randint([1, 5, 7], 10)
158
158
  array([9, 8, 7]) # random
159
159
 
160
160
  Generate a 2 by 4 array using broadcasting with dtype of uint8
161
161
 
162
- >>> bst.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
162
+ >>> brainstate.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
163
163
  array([[ 8, 6, 9, 7], # random
164
164
  [ 1, 16, 9, 12]], dtype=uint8)
165
165
  """
@@ -219,12 +219,12 @@ def random_integers(low,
219
219
 
220
220
  Examples
221
221
  --------
222
- >>> import brainstate as bst
223
- >>> bst.random.random_integers(5)
222
+ >>> import brainstate as brainstate
223
+ >>> brainstate.random.random_integers(5)
224
224
  4 # random
225
- >>> type(bst.random.random_integers(5))
225
+ >>> type(brainstate.random.random_integers(5))
226
226
  <class 'numpy.int64'>
227
- >>> bst.random.random_integers(5, size=(3,2))
227
+ >>> brainstate.random.random_integers(5, size=(3,2))
228
228
  array([[5, 4], # random
229
229
  [3, 3],
230
230
  [4, 5]])
@@ -233,13 +233,13 @@ def random_integers(low,
233
233
  numbers between 0 and 2.5, inclusive (*i.e.*, from the set
234
234
  :math:`{0, 5/8, 10/8, 15/8, 20/8}`):
235
235
 
236
- >>> 2.5 * (bst.random.random_integers(5, size=(5,)) - 1) / 4.
236
+ >>> 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
237
237
  array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
238
238
 
239
239
  Roll two six sided dice 1000 times and sum the results:
240
240
 
241
- >>> d1 = bst.random.random_integers(1, 6, 1000)
242
- >>> d2 = bst.random.random_integers(1, 6, 1000)
241
+ >>> d1 = brainstate.random.random_integers(1, 6, 1000)
242
+ >>> d2 = brainstate.random.random_integers(1, 6, 1000)
243
243
  >>> dsums = d1 + d2
244
244
 
245
245
  Display results as a histogram:
@@ -301,13 +301,13 @@ def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
301
301
 
302
302
  Examples
303
303
  --------
304
- >>> import brainstate as bst
305
- >>> bst.random.randn()
304
+ >>> import brainstate as brainstate
305
+ >>> brainstate.random.randn()
306
306
  2.1923875335537315 # random
307
307
 
308
308
  Two-by-four array of samples from N(3, 6.25):
309
309
 
310
- >>> 3 + 2.5 * bst.random.randn(2, 4)
310
+ >>> 3 + 2.5 * brainstate.random.randn(2, 4)
311
311
  array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
312
312
  [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
313
313
  """
@@ -359,17 +359,17 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
359
359
 
360
360
  Examples
361
361
  --------
362
- >>> import brainstate as bst
363
- >>> bst.random.random_sample()
362
+ >>> import brainstate as brainstate
363
+ >>> brainstate.random.random_sample()
364
364
  0.47108547995356098 # random
365
- >>> type(bst.random.random_sample())
365
+ >>> type(brainstate.random.random_sample())
366
366
  <class 'float'>
367
- >>> bst.random.random_sample((5,))
367
+ >>> brainstate.random.random_sample((5,))
368
368
  array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
369
369
 
370
370
  Three-by-two array of random numbers from [-5, 0):
371
371
 
372
- >>> 5 * bst.random.random_sample((3, 2)) - 5
372
+ >>> 5 * brainstate.random.random_sample((3, 2)) - 5
373
373
  array([[-3.99149989, -0.52338984], # random
374
374
  [-2.99091858, -0.79479508],
375
375
  [-1.23204345, -1.75224494]])
@@ -450,34 +450,34 @@ def choice(a, size: Optional[Size] = None, replace=True, p=None,
450
450
  --------
451
451
  Generate a uniform random sample from np.arange(5) of size 3:
452
452
 
453
- >>> import brainstate as bst
454
- >>> bst.random.choice(5, 3)
453
+ >>> import brainstate as brainstate
454
+ >>> brainstate.random.choice(5, 3)
455
455
  array([0, 3, 4]) # random
456
456
  >>> #This is equivalent to brainpy.math.random.randint(0,5,3)
457
457
 
458
458
  Generate a non-uniform random sample from np.arange(5) of size 3:
459
459
 
460
- >>> bst.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
460
+ >>> brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
461
461
  array([3, 3, 0]) # random
462
462
 
463
463
  Generate a uniform random sample from np.arange(5) of size 3 without
464
464
  replacement:
465
465
 
466
- >>> bst.random.choice(5, 3, replace=False)
466
+ >>> brainstate.random.choice(5, 3, replace=False)
467
467
  array([3,1,0]) # random
468
468
  >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
469
469
 
470
470
  Generate a non-uniform random sample from np.arange(5) of size
471
471
  3 without replacement:
472
472
 
473
- >>> bst.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
473
+ >>> brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
474
474
  array([2, 3, 0]) # random
475
475
 
476
476
  Any of the above can be repeated with an arbitrary array-like
477
477
  instead of just integers. For instance:
478
478
 
479
479
  >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
480
- >>> bst.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
480
+ >>> brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
481
481
  array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
482
482
  dtype='<U11')
483
483
  """
@@ -519,15 +519,15 @@ def permutation(x,
519
519
 
520
520
  Examples
521
521
  --------
522
- >>> import brainstate as bst
523
- >>> bst.random.permutation(10)
522
+ >>> import brainstate as brainstate
523
+ >>> brainstate.random.permutation(10)
524
524
  array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
525
525
 
526
- >>> bst.random.permutation([1, 4, 9, 12, 15])
526
+ >>> brainstate.random.permutation([1, 4, 9, 12, 15])
527
527
  array([15, 1, 9, 4, 12]) # random
528
528
 
529
529
  >>> arr = np.arange(9).reshape((3, 3))
530
- >>> bst.random.permutation(arr)
530
+ >>> brainstate.random.permutation(arr)
531
531
  array([[6, 7, 8], # random
532
532
  [0, 1, 2],
533
533
  [3, 4, 5]])
@@ -557,16 +557,16 @@ def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
557
557
 
558
558
  Examples
559
559
  --------
560
- >>> import brainstate as bst
560
+ >>> import brainstate as brainstate
561
561
  >>> arr = np.arange(10)
562
- >>> bst.random.shuffle(arr)
562
+ >>> brainstate.random.shuffle(arr)
563
563
  >>> arr
564
564
  [1 7 5 2 9 4 3 6 0 8] # random
565
565
 
566
566
  Multi-dimensional arrays are only shuffled along the first axis:
567
567
 
568
568
  >>> arr = np.arange(9).reshape((3, 3))
569
- >>> bst.random.shuffle(arr)
569
+ >>> brainstate.random.shuffle(arr)
570
570
  >>> arr
571
571
  array([[3, 4, 5], # random
572
572
  [6, 7, 8],
@@ -562,6 +562,6 @@ class TestRandom(unittest.TestCase):
562
562
 
563
563
  # class TestRandomKey(unittest.TestCase):
564
564
  # def test_clear_memory(self):
565
- # bst.random.split_key()
566
- # print(bst.random.DEFAULT.value)
567
- # self.assertTrue(isinstance(bst.random.DEFAULT.value, np.ndarray))
565
+ # brainstate.random.split_key()
566
+ # print(brainstate.random.DEFAULT.value)
567
+ # self.assertTrue(isinstance(brainstate.random.DEFAULT.value, np.ndarray))
@@ -183,16 +183,16 @@ def seed_context(seed_or_key: SeedOrKey):
183
183
 
184
184
  Examples:
185
185
 
186
- >>> import brainstate as bst
187
- >>> print(bst.random.rand(2))
186
+ >>> import brainstate as brainstate
187
+ >>> print(brainstate.random.rand(2))
188
188
  [0.57721865 0.9820676 ]
189
- >>> print(bst.random.rand(2))
189
+ >>> print(brainstate.random.rand(2))
190
190
  [0.8511752 0.95312667]
191
- >>> with bst.random.seed_context(42):
192
- ... print(bst.random.rand(2))
191
+ >>> with brainstate.random.seed_context(42):
192
+ ... print(brainstate.random.rand(2))
193
193
  [0.95598125 0.4032725 ]
194
- >>> with bst.random.seed_context(42):
195
- ... print(bst.random.rand(2))
194
+ >>> with brainstate.random.seed_context(42):
195
+ ... print(brainstate.random.rand(2))
196
196
  [0.95598125 0.4032725 ]
197
197
 
198
198
  Args: