brainstate 0.1.5__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_state.py +5 -5
- brainstate/augment/_autograd.py +31 -12
- brainstate/augment/_autograd_test.py +46 -46
- brainstate/augment/_eval_shape.py +4 -4
- brainstate/augment/_mapping.py +13 -8
- brainstate/compile/_conditions.py +2 -2
- brainstate/compile/_make_jaxpr.py +59 -6
- brainstate/compile/_progress_bar.py +2 -2
- brainstate/environ.py +19 -19
- brainstate/functional/_activations_test.py +12 -12
- brainstate/graph/_graph_operation.py +69 -69
- brainstate/graph/_graph_operation_test.py +2 -2
- brainstate/mixin.py +0 -17
- brainstate/nn/_collective_ops.py +4 -4
- brainstate/nn/_dropout_test.py +2 -2
- brainstate/nn/_dynamics.py +53 -35
- brainstate/nn/_elementwise.py +30 -30
- brainstate/nn/_linear.py +4 -4
- brainstate/nn/_module.py +6 -6
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +11 -11
- brainstate/nn/_normalizations_test.py +6 -6
- brainstate/nn/_poolings.py +24 -24
- brainstate/nn/_synapse.py +1 -12
- brainstate/nn/_utils.py +1 -1
- brainstate/nn/metrics.py +4 -4
- brainstate/optim/_optax_optimizer.py +8 -8
- brainstate/random/_rand_funs.py +37 -37
- brainstate/random/_rand_funs_test.py +3 -3
- brainstate/random/_rand_seed.py +7 -7
- brainstate/surrogate.py +40 -40
- brainstate/util/pretty_pytree.py +10 -10
- brainstate/util/{_pretty_pytree_test.py → pretty_pytree_test.py} +36 -37
- brainstate/util/struct.py +7 -7
- {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/METADATA +12 -12
- {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/RECORD +40 -40
- {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/WHEEL +1 -1
- {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/LICENSE +0 -0
- {brainstate-0.1.5.dist-info → brainstate-0.1.6.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py
CHANGED
@@ -55,8 +55,8 @@ class Flatten(Module):
|
|
55
55
|
end_axis: last dim to flatten (default = -1).
|
56
56
|
|
57
57
|
Examples::
|
58
|
-
>>> import brainstate as
|
59
|
-
>>> inp =
|
58
|
+
>>> import brainstate as brainstate
|
59
|
+
>>> inp = brainstate.random.randn(32, 1, 5, 5)
|
60
60
|
>>> # With default parameters
|
61
61
|
>>> m = Flatten()
|
62
62
|
>>> output = m(inp)
|
@@ -334,10 +334,10 @@ class MaxPool1d(_MaxPool):
|
|
334
334
|
|
335
335
|
Examples::
|
336
336
|
|
337
|
-
>>> import brainstate as
|
337
|
+
>>> import brainstate as brainstate
|
338
338
|
>>> # pool of size=3, stride=2
|
339
339
|
>>> m = MaxPool1d(3, stride=2, channel_axis=-1)
|
340
|
-
>>> input =
|
340
|
+
>>> input = brainstate.random.randn(20, 50, 16)
|
341
341
|
>>> output = m(input)
|
342
342
|
>>> output.shape
|
343
343
|
(20, 24, 16)
|
@@ -418,12 +418,12 @@ class MaxPool2d(_MaxPool):
|
|
418
418
|
|
419
419
|
Examples::
|
420
420
|
|
421
|
-
>>> import brainstate as
|
421
|
+
>>> import brainstate as brainstate
|
422
422
|
>>> # pool of square window of size=3, stride=2
|
423
423
|
>>> m = MaxPool2d(3, stride=2)
|
424
424
|
>>> # pool of non-square window
|
425
425
|
>>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
|
426
|
-
>>> input =
|
426
|
+
>>> input = brainstate.random.randn(20, 50, 32, 16)
|
427
427
|
>>> output = m(input)
|
428
428
|
>>> output.shape
|
429
429
|
(20, 24, 31, 16)
|
@@ -509,12 +509,12 @@ class MaxPool3d(_MaxPool):
|
|
509
509
|
|
510
510
|
Examples::
|
511
511
|
|
512
|
-
>>> import brainstate as
|
512
|
+
>>> import brainstate as brainstate
|
513
513
|
>>> # pool of square window of size=3, stride=2
|
514
514
|
>>> m = MaxPool3d(3, stride=2)
|
515
515
|
>>> # pool of non-square window
|
516
516
|
>>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
|
517
|
-
>>> input =
|
517
|
+
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
518
518
|
>>> output = m(input)
|
519
519
|
>>> output.shape
|
520
520
|
(20, 24, 43, 15, 16)
|
@@ -588,10 +588,10 @@ class AvgPool1d(_AvgPool):
|
|
588
588
|
|
589
589
|
Examples::
|
590
590
|
|
591
|
-
>>> import brainstate as
|
591
|
+
>>> import brainstate as brainstate
|
592
592
|
>>> # pool with window of size=3, stride=2
|
593
593
|
>>> m = AvgPool1d(3, stride=2)
|
594
|
-
>>> input =
|
594
|
+
>>> input = brainstate.random.randn(20, 50, 16)
|
595
595
|
>>> m(input).shape
|
596
596
|
(20, 24, 16)
|
597
597
|
|
@@ -665,12 +665,12 @@ class AvgPool2d(_AvgPool):
|
|
665
665
|
|
666
666
|
Examples::
|
667
667
|
|
668
|
-
>>> import brainstate as
|
668
|
+
>>> import brainstate as brainstate
|
669
669
|
>>> # pool of square window of size=3, stride=2
|
670
670
|
>>> m = AvgPool2d(3, stride=2)
|
671
671
|
>>> # pool of non-square window
|
672
672
|
>>> m = AvgPool2d((3, 2), stride=(2, 1))
|
673
|
-
>>> input =
|
673
|
+
>>> input = brainstate.random.randn(20, 50, 32, , 16)
|
674
674
|
>>> output = m(input)
|
675
675
|
>>> output.shape
|
676
676
|
(20, 24, 31, 16)
|
@@ -753,12 +753,12 @@ class AvgPool3d(_AvgPool):
|
|
753
753
|
|
754
754
|
Examples::
|
755
755
|
|
756
|
-
>>> import brainstate as
|
756
|
+
>>> import brainstate as brainstate
|
757
757
|
>>> # pool of square window of size=3, stride=2
|
758
758
|
>>> m = AvgPool3d(3, stride=2)
|
759
759
|
>>> # pool of non-square window
|
760
760
|
>>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
|
761
|
-
>>> input =
|
761
|
+
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
762
762
|
>>> output = m(input)
|
763
763
|
>>> output.shape
|
764
764
|
(20, 24, 43, 15, 16)
|
@@ -931,10 +931,10 @@ class AdaptiveAvgPool1d(_AdaptivePool):
|
|
931
931
|
|
932
932
|
Examples:
|
933
933
|
|
934
|
-
>>> import brainstate as
|
934
|
+
>>> import brainstate as brainstate
|
935
935
|
>>> # target output size of 5
|
936
936
|
>>> m = AdaptiveMaxPool1d(5)
|
937
|
-
>>> input =
|
937
|
+
>>> input = brainstate.random.randn(1, 64, 8)
|
938
938
|
>>> output = m(input)
|
939
939
|
>>> output.shape
|
940
940
|
(1, 5, 8)
|
@@ -979,22 +979,22 @@ class AdaptiveAvgPool2d(_AdaptivePool):
|
|
979
979
|
|
980
980
|
Examples:
|
981
981
|
|
982
|
-
>>> import brainstate as
|
982
|
+
>>> import brainstate as brainstate
|
983
983
|
>>> # target output size of 5x7
|
984
984
|
>>> m = AdaptiveMaxPool2d((5, 7))
|
985
|
-
>>> input =
|
985
|
+
>>> input = brainstate.random.randn(1, 8, 9, 64)
|
986
986
|
>>> output = m(input)
|
987
987
|
>>> output.shape
|
988
988
|
(1, 5, 7, 64)
|
989
989
|
>>> # target output size of 7x7 (square)
|
990
990
|
>>> m = AdaptiveMaxPool2d(7)
|
991
|
-
>>> input =
|
991
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
992
992
|
>>> output = m(input)
|
993
993
|
>>> output.shape
|
994
994
|
(1, 7, 7, 64)
|
995
995
|
>>> # target output size of 10x7
|
996
996
|
>>> m = AdaptiveMaxPool2d((None, 7))
|
997
|
-
>>> input =
|
997
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
998
998
|
>>> output = m(input)
|
999
999
|
>>> output.shape
|
1000
1000
|
(1, 10, 7, 64)
|
@@ -1040,22 +1040,22 @@ class AdaptiveAvgPool3d(_AdaptivePool):
|
|
1040
1040
|
|
1041
1041
|
Examples:
|
1042
1042
|
|
1043
|
-
>>> import brainstate as
|
1043
|
+
>>> import brainstate as brainstate
|
1044
1044
|
>>> # target output size of 5x7x9
|
1045
1045
|
>>> m = AdaptiveMaxPool3d((5, 7, 9))
|
1046
|
-
>>> input =
|
1046
|
+
>>> input = brainstate.random.randn(1, 8, 9, 10, 64)
|
1047
1047
|
>>> output = m(input)
|
1048
1048
|
>>> output.shape
|
1049
1049
|
(1, 5, 7, 9, 64)
|
1050
1050
|
>>> # target output size of 7x7x7 (cube)
|
1051
1051
|
>>> m = AdaptiveMaxPool3d(7)
|
1052
|
-
>>> input =
|
1052
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1053
1053
|
>>> output = m(input)
|
1054
1054
|
>>> output.shape
|
1055
1055
|
(1, 7, 7, 7, 64)
|
1056
1056
|
>>> # target output size of 7x9x8
|
1057
1057
|
>>> m = AdaptiveMaxPool3d((7, None, None))
|
1058
|
-
>>> input =
|
1058
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1059
1059
|
>>> output = m(input)
|
1060
1060
|
>>> output.shape
|
1061
1061
|
(1, 7, 9, 8, 64)
|
brainstate/nn/_synapse.py
CHANGED
@@ -123,9 +123,6 @@ class Expon(Synapse, AlignPost):
|
|
123
123
|
g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
|
124
124
|
self.g.value = self.sum_delta_inputs(g)
|
125
125
|
if x is not None: self.g.value += x
|
126
|
-
return self.update_return()
|
127
|
-
|
128
|
-
def update_return(self) -> PyTree:
|
129
126
|
return self.g.value
|
130
127
|
|
131
128
|
|
@@ -232,9 +229,6 @@ class DualExpon(Synapse, AlignPost):
|
|
232
229
|
if x is not None:
|
233
230
|
self.g_rise.value += x
|
234
231
|
self.g_decay.value += x
|
235
|
-
return self.update_return()
|
236
|
-
|
237
|
-
def update_return(self) -> PyTree:
|
238
232
|
return self.a * (self.g_decay.value - self.g_rise.value)
|
239
233
|
|
240
234
|
|
@@ -414,12 +408,8 @@ class AMPA(Synapse):
|
|
414
408
|
t = environ.get('t')
|
415
409
|
self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
|
416
410
|
TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
|
417
|
-
dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
|
411
|
+
dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
|
418
412
|
self.g.value = exp_euler_step(dg, self.g.value)
|
419
|
-
return self.update_return()
|
420
|
-
|
421
|
-
def update_return(self) -> PyTree:
|
422
|
-
"""Return the synaptic conductance value."""
|
423
413
|
return self.g.value
|
424
414
|
|
425
415
|
|
@@ -513,4 +503,3 @@ class GABAa(AMPA):
|
|
513
503
|
in_size=in_size,
|
514
504
|
g_initializer=g_initializer
|
515
505
|
)
|
516
|
-
|
brainstate/nn/_utils.py
CHANGED
brainstate/nn/metrics.py
CHANGED
@@ -61,12 +61,12 @@ class Average(Metric):
|
|
61
61
|
Example usage::
|
62
62
|
|
63
63
|
>>> import jax.numpy as jnp
|
64
|
-
>>> import brainstate as
|
64
|
+
>>> import brainstate as brainstate
|
65
65
|
|
66
66
|
>>> batch_loss = jnp.array([1, 2, 3, 4])
|
67
67
|
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
|
68
68
|
|
69
|
-
>>> metrics =
|
69
|
+
>>> metrics = brainstate.nn.metrics.Average()
|
70
70
|
>>> metrics.compute()
|
71
71
|
Array(nan, dtype=float32)
|
72
72
|
>>> metrics.update(values=batch_loss)
|
@@ -223,7 +223,7 @@ class Accuracy(Average):
|
|
223
223
|
|
224
224
|
Example usage::
|
225
225
|
|
226
|
-
>>> import brainstate as
|
226
|
+
>>> import brainstate as brainstate
|
227
227
|
>>> import jax, jax.numpy as jnp
|
228
228
|
|
229
229
|
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
|
@@ -231,7 +231,7 @@ class Accuracy(Average):
|
|
231
231
|
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
|
232
232
|
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
|
233
233
|
|
234
|
-
>>> metrics =
|
234
|
+
>>> metrics = brainstate.nn.metrics.Accuracy()
|
235
235
|
>>> metrics.compute()
|
236
236
|
Array(nan, dtype=float32)
|
237
237
|
>>> metrics.update(logits=logits, labels=labels)
|
@@ -36,29 +36,29 @@ class OptaxOptimizer(Optimizer):
|
|
36
36
|
|
37
37
|
>>> import jax
|
38
38
|
>>> import jax.numpy as jnp
|
39
|
-
>>> import brainstate as
|
39
|
+
>>> import brainstate as brainstate
|
40
40
|
>>> import optax
|
41
41
|
...
|
42
|
-
>>> class Model(
|
42
|
+
>>> class Model(brainstate.nn.Module):
|
43
43
|
... def __init__(self):
|
44
44
|
... super().__init__()
|
45
|
-
... self.linear1 =
|
46
|
-
... self.linear2 =
|
45
|
+
... self.linear1 = brainstate.nn.Linear(2, 3)
|
46
|
+
... self.linear2 = brainstate.nn.Linear(3, 4)
|
47
47
|
... def __call__(self, x):
|
48
48
|
... return self.linear2(self.linear1(x))
|
49
49
|
...
|
50
|
-
>>> x =
|
50
|
+
>>> x = brainstate.random.randn(1, 2)
|
51
51
|
>>> y = jnp.ones((1, 4))
|
52
52
|
...
|
53
53
|
>>> model = Model()
|
54
54
|
>>> tx = optax.adam(1e-3)
|
55
|
-
>>> optimizer =
|
56
|
-
>>> optimizer.register_trainable_weights(model.states(
|
55
|
+
>>> optimizer = brainstate.optim.OptaxOptimizer(tx)
|
56
|
+
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
|
57
57
|
...
|
58
58
|
>>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
|
59
59
|
>>> loss_fn()
|
60
60
|
Array(1.7055722, dtype=float32)
|
61
|
-
>>> grads =
|
61
|
+
>>> grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
|
62
62
|
>>> optimizer.update(grads)
|
63
63
|
>>> loss_fn()
|
64
64
|
Array(1.6925814, dtype=float32)
|
brainstate/random/_rand_funs.py
CHANGED
@@ -78,8 +78,8 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
|
78
78
|
|
79
79
|
Examples
|
80
80
|
--------
|
81
|
-
>>> import brainstate as
|
82
|
-
>>>
|
81
|
+
>>> import brainstate as brainstate
|
82
|
+
>>> brainstate.random.rand(3,2)
|
83
83
|
array([[ 0.14022471, 0.96360618], #random
|
84
84
|
[ 0.37601032, 0.25528411], #random
|
85
85
|
[ 0.49313049, 0.94909878]]) #random
|
@@ -135,31 +135,31 @@ def randint(low, high=None, size: Optional[Size] = None,
|
|
135
135
|
|
136
136
|
Examples
|
137
137
|
--------
|
138
|
-
>>> import brainstate as
|
139
|
-
>>>
|
138
|
+
>>> import brainstate as brainstate
|
139
|
+
>>> brainstate.random.randint(2, size=10)
|
140
140
|
array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
|
141
|
-
>>>
|
141
|
+
>>> brainstate.random.randint(1, size=10)
|
142
142
|
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
143
143
|
|
144
144
|
Generate a 2 x 4 array of ints between 0 and 4, inclusive:
|
145
145
|
|
146
|
-
>>>
|
146
|
+
>>> brainstate.random.randint(5, size=(2, 4))
|
147
147
|
array([[4, 0, 2, 1], # random
|
148
148
|
[3, 2, 2, 0]])
|
149
149
|
|
150
150
|
Generate a 1 x 3 array with 3 different upper bounds
|
151
151
|
|
152
|
-
>>>
|
152
|
+
>>> brainstate.random.randint(1, [3, 5, 10])
|
153
153
|
array([2, 2, 9]) # random
|
154
154
|
|
155
155
|
Generate a 1 by 3 array with 3 different lower bounds
|
156
156
|
|
157
|
-
>>>
|
157
|
+
>>> brainstate.random.randint([1, 5, 7], 10)
|
158
158
|
array([9, 8, 7]) # random
|
159
159
|
|
160
160
|
Generate a 2 by 4 array using broadcasting with dtype of uint8
|
161
161
|
|
162
|
-
>>>
|
162
|
+
>>> brainstate.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
|
163
163
|
array([[ 8, 6, 9, 7], # random
|
164
164
|
[ 1, 16, 9, 12]], dtype=uint8)
|
165
165
|
"""
|
@@ -219,12 +219,12 @@ def random_integers(low,
|
|
219
219
|
|
220
220
|
Examples
|
221
221
|
--------
|
222
|
-
>>> import brainstate as
|
223
|
-
>>>
|
222
|
+
>>> import brainstate as brainstate
|
223
|
+
>>> brainstate.random.random_integers(5)
|
224
224
|
4 # random
|
225
|
-
>>> type(
|
225
|
+
>>> type(brainstate.random.random_integers(5))
|
226
226
|
<class 'numpy.int64'>
|
227
|
-
>>>
|
227
|
+
>>> brainstate.random.random_integers(5, size=(3,2))
|
228
228
|
array([[5, 4], # random
|
229
229
|
[3, 3],
|
230
230
|
[4, 5]])
|
@@ -233,13 +233,13 @@ def random_integers(low,
|
|
233
233
|
numbers between 0 and 2.5, inclusive (*i.e.*, from the set
|
234
234
|
:math:`{0, 5/8, 10/8, 15/8, 20/8}`):
|
235
235
|
|
236
|
-
>>> 2.5 * (
|
236
|
+
>>> 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
|
237
237
|
array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
|
238
238
|
|
239
239
|
Roll two six sided dice 1000 times and sum the results:
|
240
240
|
|
241
|
-
>>> d1 =
|
242
|
-
>>> d2 =
|
241
|
+
>>> d1 = brainstate.random.random_integers(1, 6, 1000)
|
242
|
+
>>> d2 = brainstate.random.random_integers(1, 6, 1000)
|
243
243
|
>>> dsums = d1 + d2
|
244
244
|
|
245
245
|
Display results as a histogram:
|
@@ -301,13 +301,13 @@ def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
|
301
301
|
|
302
302
|
Examples
|
303
303
|
--------
|
304
|
-
>>> import brainstate as
|
305
|
-
>>>
|
304
|
+
>>> import brainstate as brainstate
|
305
|
+
>>> brainstate.random.randn()
|
306
306
|
2.1923875335537315 # random
|
307
307
|
|
308
308
|
Two-by-four array of samples from N(3, 6.25):
|
309
309
|
|
310
|
-
>>> 3 + 2.5 *
|
310
|
+
>>> 3 + 2.5 * brainstate.random.randn(2, 4)
|
311
311
|
array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
|
312
312
|
[ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
|
313
313
|
"""
|
@@ -359,17 +359,17 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
|
|
359
359
|
|
360
360
|
Examples
|
361
361
|
--------
|
362
|
-
>>> import brainstate as
|
363
|
-
>>>
|
362
|
+
>>> import brainstate as brainstate
|
363
|
+
>>> brainstate.random.random_sample()
|
364
364
|
0.47108547995356098 # random
|
365
|
-
>>> type(
|
365
|
+
>>> type(brainstate.random.random_sample())
|
366
366
|
<class 'float'>
|
367
|
-
>>>
|
367
|
+
>>> brainstate.random.random_sample((5,))
|
368
368
|
array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
|
369
369
|
|
370
370
|
Three-by-two array of random numbers from [-5, 0):
|
371
371
|
|
372
|
-
>>> 5 *
|
372
|
+
>>> 5 * brainstate.random.random_sample((3, 2)) - 5
|
373
373
|
array([[-3.99149989, -0.52338984], # random
|
374
374
|
[-2.99091858, -0.79479508],
|
375
375
|
[-1.23204345, -1.75224494]])
|
@@ -450,34 +450,34 @@ def choice(a, size: Optional[Size] = None, replace=True, p=None,
|
|
450
450
|
--------
|
451
451
|
Generate a uniform random sample from np.arange(5) of size 3:
|
452
452
|
|
453
|
-
>>> import brainstate as
|
454
|
-
>>>
|
453
|
+
>>> import brainstate as brainstate
|
454
|
+
>>> brainstate.random.choice(5, 3)
|
455
455
|
array([0, 3, 4]) # random
|
456
456
|
>>> #This is equivalent to brainpy.math.random.randint(0,5,3)
|
457
457
|
|
458
458
|
Generate a non-uniform random sample from np.arange(5) of size 3:
|
459
459
|
|
460
|
-
>>>
|
460
|
+
>>> brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
|
461
461
|
array([3, 3, 0]) # random
|
462
462
|
|
463
463
|
Generate a uniform random sample from np.arange(5) of size 3 without
|
464
464
|
replacement:
|
465
465
|
|
466
|
-
>>>
|
466
|
+
>>> brainstate.random.choice(5, 3, replace=False)
|
467
467
|
array([3,1,0]) # random
|
468
468
|
>>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
|
469
469
|
|
470
470
|
Generate a non-uniform random sample from np.arange(5) of size
|
471
471
|
3 without replacement:
|
472
472
|
|
473
|
-
>>>
|
473
|
+
>>> brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
|
474
474
|
array([2, 3, 0]) # random
|
475
475
|
|
476
476
|
Any of the above can be repeated with an arbitrary array-like
|
477
477
|
instead of just integers. For instance:
|
478
478
|
|
479
479
|
>>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
|
480
|
-
>>>
|
480
|
+
>>> brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
|
481
481
|
array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
|
482
482
|
dtype='<U11')
|
483
483
|
"""
|
@@ -519,15 +519,15 @@ def permutation(x,
|
|
519
519
|
|
520
520
|
Examples
|
521
521
|
--------
|
522
|
-
>>> import brainstate as
|
523
|
-
>>>
|
522
|
+
>>> import brainstate as brainstate
|
523
|
+
>>> brainstate.random.permutation(10)
|
524
524
|
array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
|
525
525
|
|
526
|
-
>>>
|
526
|
+
>>> brainstate.random.permutation([1, 4, 9, 12, 15])
|
527
527
|
array([15, 1, 9, 4, 12]) # random
|
528
528
|
|
529
529
|
>>> arr = np.arange(9).reshape((3, 3))
|
530
|
-
>>>
|
530
|
+
>>> brainstate.random.permutation(arr)
|
531
531
|
array([[6, 7, 8], # random
|
532
532
|
[0, 1, 2],
|
533
533
|
[3, 4, 5]])
|
@@ -557,16 +557,16 @@ def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
|
|
557
557
|
|
558
558
|
Examples
|
559
559
|
--------
|
560
|
-
>>> import brainstate as
|
560
|
+
>>> import brainstate as brainstate
|
561
561
|
>>> arr = np.arange(10)
|
562
|
-
>>>
|
562
|
+
>>> brainstate.random.shuffle(arr)
|
563
563
|
>>> arr
|
564
564
|
[1 7 5 2 9 4 3 6 0 8] # random
|
565
565
|
|
566
566
|
Multi-dimensional arrays are only shuffled along the first axis:
|
567
567
|
|
568
568
|
>>> arr = np.arange(9).reshape((3, 3))
|
569
|
-
>>>
|
569
|
+
>>> brainstate.random.shuffle(arr)
|
570
570
|
>>> arr
|
571
571
|
array([[3, 4, 5], # random
|
572
572
|
[6, 7, 8],
|
@@ -562,6 +562,6 @@ class TestRandom(unittest.TestCase):
|
|
562
562
|
|
563
563
|
# class TestRandomKey(unittest.TestCase):
|
564
564
|
# def test_clear_memory(self):
|
565
|
-
#
|
566
|
-
# print(
|
567
|
-
# self.assertTrue(isinstance(
|
565
|
+
# brainstate.random.split_key()
|
566
|
+
# print(brainstate.random.DEFAULT.value)
|
567
|
+
# self.assertTrue(isinstance(brainstate.random.DEFAULT.value, np.ndarray))
|
brainstate/random/_rand_seed.py
CHANGED
@@ -183,16 +183,16 @@ def seed_context(seed_or_key: SeedOrKey):
|
|
183
183
|
|
184
184
|
Examples:
|
185
185
|
|
186
|
-
>>> import brainstate as
|
187
|
-
>>> print(
|
186
|
+
>>> import brainstate as brainstate
|
187
|
+
>>> print(brainstate.random.rand(2))
|
188
188
|
[0.57721865 0.9820676 ]
|
189
|
-
>>> print(
|
189
|
+
>>> print(brainstate.random.rand(2))
|
190
190
|
[0.8511752 0.95312667]
|
191
|
-
>>> with
|
192
|
-
... print(
|
191
|
+
>>> with brainstate.random.seed_context(42):
|
192
|
+
... print(brainstate.random.rand(2))
|
193
193
|
[0.95598125 0.4032725 ]
|
194
|
-
>>> with
|
195
|
-
... print(
|
194
|
+
>>> with brainstate.random.seed_context(42):
|
195
|
+
... print(brainstate.random.rand(2))
|
196
196
|
[0.95598125 0.4032725 ]
|
197
197
|
|
198
198
|
Args:
|