brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import platform
|
19
18
|
import unittest
|
@@ -23,110 +22,110 @@ import jax.random as jr
|
|
23
22
|
import numpy as np
|
24
23
|
import pytest
|
25
24
|
|
26
|
-
import brainstate
|
25
|
+
import brainstate
|
27
26
|
|
28
27
|
|
29
28
|
class TestRandom(unittest.TestCase):
|
30
29
|
|
31
30
|
def test_rand(self):
|
32
|
-
|
33
|
-
a =
|
31
|
+
brainstate.random.seed()
|
32
|
+
a = brainstate.random.rand(3, 2)
|
34
33
|
self.assertTupleEqual(a.shape, (3, 2))
|
35
34
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
36
35
|
|
37
36
|
key = jr.PRNGKey(123)
|
38
37
|
jres = jr.uniform(key, shape=(10, 100))
|
39
|
-
self.assertTrue(jnp.allclose(jres,
|
40
|
-
self.assertTrue(jnp.allclose(jres,
|
38
|
+
self.assertTrue(jnp.allclose(jres, brainstate.random.rand(10, 100, key=key)))
|
39
|
+
self.assertTrue(jnp.allclose(jres, brainstate.random.rand(10, 100, key=123)))
|
41
40
|
|
42
41
|
def test_randint1(self):
|
43
|
-
|
44
|
-
a =
|
42
|
+
brainstate.random.seed()
|
43
|
+
a = brainstate.random.randint(5)
|
45
44
|
self.assertTupleEqual(a.shape, ())
|
46
45
|
self.assertTrue(0 <= a < 5)
|
47
46
|
|
48
47
|
def test_randint2(self):
|
49
|
-
|
50
|
-
a =
|
48
|
+
brainstate.random.seed()
|
49
|
+
a = brainstate.random.randint(2, 6, size=(4, 3))
|
51
50
|
self.assertTupleEqual(a.shape, (4, 3))
|
52
51
|
self.assertTrue((a >= 2).all() and (a < 6).all())
|
53
52
|
|
54
53
|
def test_randint3(self):
|
55
|
-
|
56
|
-
a =
|
54
|
+
brainstate.random.seed()
|
55
|
+
a = brainstate.random.randint([1, 2, 3], [10, 7, 8])
|
57
56
|
self.assertTupleEqual(a.shape, (3,))
|
58
57
|
self.assertTrue((a - jnp.array([1, 2, 3]) >= 0).all()
|
59
58
|
and (-a + jnp.array([10, 7, 8]) > 0).all())
|
60
59
|
|
61
60
|
def test_randint4(self):
|
62
|
-
|
63
|
-
a =
|
61
|
+
brainstate.random.seed()
|
62
|
+
a = brainstate.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
|
64
63
|
self.assertTupleEqual(a.shape, (2, 3))
|
65
64
|
|
66
65
|
def test_randn(self):
|
67
|
-
|
68
|
-
a =
|
66
|
+
brainstate.random.seed()
|
67
|
+
a = brainstate.random.randn(3, 2)
|
69
68
|
self.assertTupleEqual(a.shape, (3, 2))
|
70
69
|
|
71
70
|
def test_random1(self):
|
72
|
-
|
73
|
-
a =
|
71
|
+
brainstate.random.seed()
|
72
|
+
a = brainstate.random.random()
|
74
73
|
self.assertTrue(0. <= a < 1)
|
75
74
|
|
76
75
|
def test_random2(self):
|
77
|
-
|
78
|
-
a =
|
76
|
+
brainstate.random.seed()
|
77
|
+
a = brainstate.random.random(size=(3, 2))
|
79
78
|
self.assertTupleEqual(a.shape, (3, 2))
|
80
79
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
81
80
|
|
82
81
|
def test_random_sample(self):
|
83
|
-
|
84
|
-
a =
|
82
|
+
brainstate.random.seed()
|
83
|
+
a = brainstate.random.random_sample(size=(3, 2))
|
85
84
|
self.assertTupleEqual(a.shape, (3, 2))
|
86
85
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
87
86
|
|
88
87
|
def test_choice1(self):
|
89
|
-
|
90
|
-
a =
|
88
|
+
brainstate.random.seed()
|
89
|
+
a = brainstate.random.choice(5)
|
91
90
|
self.assertTupleEqual(jnp.shape(a), ())
|
92
91
|
self.assertTrue(0 <= a < 5)
|
93
92
|
|
94
93
|
def test_choice2(self):
|
95
|
-
|
96
|
-
a =
|
94
|
+
brainstate.random.seed()
|
95
|
+
a = brainstate.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
|
97
96
|
self.assertTupleEqual(a.shape, (3,))
|
98
97
|
self.assertTrue((a >= 0).all() and (a < 5).all())
|
99
98
|
|
100
99
|
def test_choice3(self):
|
101
|
-
|
102
|
-
a =
|
100
|
+
brainstate.random.seed()
|
101
|
+
a = brainstate.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
|
103
102
|
self.assertTupleEqual(a.shape, (4, 3))
|
104
103
|
self.assertTrue((a >= 2).all() and (a < 20).all())
|
105
104
|
self.assertEqual(len(jnp.unique(a)), 12)
|
106
105
|
|
107
106
|
def test_permutation1(self):
|
108
|
-
|
109
|
-
a =
|
107
|
+
brainstate.random.seed()
|
108
|
+
a = brainstate.random.permutation(10)
|
110
109
|
self.assertTupleEqual(a.shape, (10,))
|
111
110
|
self.assertEqual(len(jnp.unique(a)), 10)
|
112
111
|
|
113
112
|
def test_permutation2(self):
|
114
|
-
|
115
|
-
a =
|
113
|
+
brainstate.random.seed()
|
114
|
+
a = brainstate.random.permutation(jnp.arange(10))
|
116
115
|
self.assertTupleEqual(a.shape, (10,))
|
117
116
|
self.assertEqual(len(jnp.unique(a)), 10)
|
118
117
|
|
119
118
|
def test_shuffle1(self):
|
120
|
-
|
119
|
+
brainstate.random.seed()
|
121
120
|
a = jnp.arange(10)
|
122
|
-
|
121
|
+
brainstate.random.shuffle(a)
|
123
122
|
self.assertTupleEqual(a.shape, (10,))
|
124
123
|
self.assertEqual(len(jnp.unique(a)), 10)
|
125
124
|
|
126
125
|
def test_shuffle2(self):
|
127
|
-
|
126
|
+
brainstate.random.seed()
|
128
127
|
a = jnp.arange(12).reshape(4, 3)
|
129
|
-
|
128
|
+
brainstate.random.shuffle(a, axis=1)
|
130
129
|
self.assertTupleEqual(a.shape, (4, 3))
|
131
130
|
self.assertEqual(len(jnp.unique(a)), 12)
|
132
131
|
|
@@ -135,173 +134,173 @@ class TestRandom(unittest.TestCase):
|
|
135
134
|
self.assertEqual(uni, jnp.asarray([3]))
|
136
135
|
|
137
136
|
def test_beta1(self):
|
138
|
-
|
139
|
-
a =
|
137
|
+
brainstate.random.seed()
|
138
|
+
a = brainstate.random.beta(2, 2)
|
140
139
|
self.assertTupleEqual(a.shape, ())
|
141
140
|
|
142
141
|
def test_beta2(self):
|
143
|
-
|
144
|
-
a =
|
142
|
+
brainstate.random.seed()
|
143
|
+
a = brainstate.random.beta([2, 2, 3], 2, size=(3,))
|
145
144
|
self.assertTupleEqual(a.shape, (3,))
|
146
145
|
|
147
146
|
def test_exponential1(self):
|
148
|
-
|
149
|
-
a =
|
147
|
+
brainstate.random.seed()
|
148
|
+
a = brainstate.random.exponential(10., size=[3, 2])
|
150
149
|
self.assertTupleEqual(a.shape, (3, 2))
|
151
150
|
|
152
151
|
def test_exponential2(self):
|
153
|
-
|
154
|
-
a =
|
152
|
+
brainstate.random.seed()
|
153
|
+
a = brainstate.random.exponential([1., 2., 5.])
|
155
154
|
self.assertTupleEqual(a.shape, (3,))
|
156
155
|
|
157
156
|
def test_gamma(self):
|
158
|
-
|
159
|
-
a =
|
157
|
+
brainstate.random.seed()
|
158
|
+
a = brainstate.random.gamma(2, 10., size=[3, 2])
|
160
159
|
self.assertTupleEqual(a.shape, (3, 2))
|
161
160
|
|
162
161
|
def test_gumbel(self):
|
163
|
-
|
164
|
-
a =
|
162
|
+
brainstate.random.seed()
|
163
|
+
a = brainstate.random.gumbel(0., 2., size=[3, 2])
|
165
164
|
self.assertTupleEqual(a.shape, (3, 2))
|
166
165
|
|
167
166
|
def test_laplace(self):
|
168
|
-
|
169
|
-
a =
|
167
|
+
brainstate.random.seed()
|
168
|
+
a = brainstate.random.laplace(0., 2., size=[3, 2])
|
170
169
|
self.assertTupleEqual(a.shape, (3, 2))
|
171
170
|
|
172
171
|
def test_logistic(self):
|
173
|
-
|
174
|
-
a =
|
172
|
+
brainstate.random.seed()
|
173
|
+
a = brainstate.random.logistic(0., 2., size=[3, 2])
|
175
174
|
self.assertTupleEqual(a.shape, (3, 2))
|
176
175
|
|
177
176
|
def test_normal1(self):
|
178
|
-
|
179
|
-
a =
|
177
|
+
brainstate.random.seed()
|
178
|
+
a = brainstate.random.normal()
|
180
179
|
self.assertTupleEqual(a.shape, ())
|
181
180
|
|
182
181
|
def test_normal2(self):
|
183
|
-
|
184
|
-
a =
|
182
|
+
brainstate.random.seed()
|
183
|
+
a = brainstate.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
|
185
184
|
self.assertTupleEqual(a.shape, (3,))
|
186
185
|
|
187
186
|
def test_normal3(self):
|
188
|
-
|
189
|
-
a =
|
187
|
+
brainstate.random.seed()
|
188
|
+
a = brainstate.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
|
190
189
|
print(a)
|
191
190
|
self.assertTupleEqual(a.shape, (2, 3))
|
192
191
|
|
193
192
|
def test_pareto(self):
|
194
|
-
|
195
|
-
a =
|
193
|
+
brainstate.random.seed()
|
194
|
+
a = brainstate.random.pareto([1, 2, 2])
|
196
195
|
self.assertTupleEqual(a.shape, (3,))
|
197
196
|
|
198
197
|
def test_poisson(self):
|
199
|
-
|
200
|
-
a =
|
198
|
+
brainstate.random.seed()
|
199
|
+
a = brainstate.random.poisson([1., 2., 2.], size=3)
|
201
200
|
self.assertTupleEqual(a.shape, (3,))
|
202
201
|
|
203
202
|
def test_standard_cauchy(self):
|
204
|
-
|
205
|
-
a =
|
203
|
+
brainstate.random.seed()
|
204
|
+
a = brainstate.random.standard_cauchy(size=(3, 2))
|
206
205
|
self.assertTupleEqual(a.shape, (3, 2))
|
207
206
|
|
208
207
|
def test_standard_exponential(self):
|
209
|
-
|
210
|
-
a =
|
208
|
+
brainstate.random.seed()
|
209
|
+
a = brainstate.random.standard_exponential(size=(3, 2))
|
211
210
|
self.assertTupleEqual(a.shape, (3, 2))
|
212
211
|
|
213
212
|
def test_standard_gamma(self):
|
214
|
-
|
215
|
-
a =
|
213
|
+
brainstate.random.seed()
|
214
|
+
a = brainstate.random.standard_gamma(shape=[1, 2, 4], size=3)
|
216
215
|
self.assertTupleEqual(a.shape, (3,))
|
217
216
|
|
218
217
|
def test_standard_normal(self):
|
219
|
-
|
220
|
-
a =
|
218
|
+
brainstate.random.seed()
|
219
|
+
a = brainstate.random.standard_normal(size=(3, 2))
|
221
220
|
self.assertTupleEqual(a.shape, (3, 2))
|
222
221
|
|
223
222
|
def test_standard_t(self):
|
224
|
-
|
225
|
-
a =
|
223
|
+
brainstate.random.seed()
|
224
|
+
a = brainstate.random.standard_t(df=[1, 2, 4], size=3)
|
226
225
|
self.assertTupleEqual(a.shape, (3,))
|
227
226
|
|
228
227
|
def test_standard_uniform1(self):
|
229
|
-
|
230
|
-
a =
|
228
|
+
brainstate.random.seed()
|
229
|
+
a = brainstate.random.uniform()
|
231
230
|
self.assertTupleEqual(a.shape, ())
|
232
231
|
self.assertTrue(0 <= a < 1)
|
233
232
|
|
234
233
|
def test_uniform2(self):
|
235
|
-
|
236
|
-
a =
|
234
|
+
brainstate.random.seed()
|
235
|
+
a = brainstate.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
|
237
236
|
self.assertTupleEqual(a.shape, (3,))
|
238
237
|
self.assertTrue((a - jnp.array([-1., 5., 2.]) >= 0).all()
|
239
238
|
and (-a + jnp.array([2., 6., 10.]) > 0).all())
|
240
239
|
|
241
240
|
def test_uniform3(self):
|
242
|
-
|
243
|
-
a =
|
241
|
+
brainstate.random.seed()
|
242
|
+
a = brainstate.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
|
244
243
|
self.assertTupleEqual(a.shape, (2, 3))
|
245
244
|
|
246
245
|
def test_uniform4(self):
|
247
|
-
|
248
|
-
a =
|
246
|
+
brainstate.random.seed()
|
247
|
+
a = brainstate.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
|
249
248
|
self.assertTupleEqual(a.shape, (2, 3))
|
250
249
|
|
251
250
|
def test_truncated_normal1(self):
|
252
|
-
|
253
|
-
a =
|
251
|
+
brainstate.random.seed()
|
252
|
+
a = brainstate.random.truncated_normal(-1., 1.)
|
254
253
|
self.assertTupleEqual(a.shape, ())
|
255
254
|
self.assertTrue(-1. <= a <= 1.)
|
256
255
|
|
257
256
|
def test_truncated_normal2(self):
|
258
|
-
|
259
|
-
a =
|
257
|
+
brainstate.random.seed()
|
258
|
+
a = brainstate.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
|
260
259
|
self.assertTupleEqual(a.shape, (4, 3))
|
261
260
|
|
262
261
|
def test_truncated_normal3(self):
|
263
|
-
|
264
|
-
a =
|
262
|
+
brainstate.random.seed()
|
263
|
+
a = brainstate.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
|
265
264
|
self.assertTupleEqual(a.shape, (2, 3))
|
266
265
|
self.assertTrue((a - jnp.array([-1., 0., 1.]) >= 0.).all()
|
267
266
|
and (- a + jnp.array([2., 2., 4.]) >= 0.).all())
|
268
267
|
|
269
268
|
def test_bernoulli1(self):
|
270
|
-
|
271
|
-
a =
|
269
|
+
brainstate.random.seed()
|
270
|
+
a = brainstate.random.bernoulli()
|
272
271
|
self.assertTupleEqual(a.shape, ())
|
273
272
|
self.assertTrue(a == 0 or a == 1)
|
274
273
|
|
275
274
|
def test_bernoulli2(self):
|
276
|
-
|
277
|
-
a =
|
275
|
+
brainstate.random.seed()
|
276
|
+
a = brainstate.random.bernoulli([0.5, 0.6, 0.8])
|
278
277
|
self.assertTupleEqual(a.shape, (3,))
|
279
278
|
self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
|
280
279
|
|
281
280
|
def test_bernoulli3(self):
|
282
|
-
|
283
|
-
a =
|
281
|
+
brainstate.random.seed()
|
282
|
+
a = brainstate.random.bernoulli([0.5, 0.6], size=(3, 2))
|
284
283
|
self.assertTupleEqual(a.shape, (3, 2))
|
285
284
|
self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
|
286
285
|
|
287
286
|
def test_lognormal1(self):
|
288
|
-
|
289
|
-
a =
|
287
|
+
brainstate.random.seed()
|
288
|
+
a = brainstate.random.lognormal()
|
290
289
|
self.assertTupleEqual(a.shape, ())
|
291
290
|
|
292
291
|
def test_lognormal2(self):
|
293
|
-
|
294
|
-
a =
|
292
|
+
brainstate.random.seed()
|
293
|
+
a = brainstate.random.lognormal(sigma=[2., 1.], size=[3, 2])
|
295
294
|
self.assertTupleEqual(a.shape, (3, 2))
|
296
295
|
|
297
296
|
def test_lognormal3(self):
|
298
|
-
|
299
|
-
a =
|
297
|
+
brainstate.random.seed()
|
298
|
+
a = brainstate.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
|
300
299
|
self.assertTupleEqual(a.shape, (2, 2))
|
301
300
|
|
302
301
|
def test_binomial1(self):
|
303
|
-
|
304
|
-
a =
|
302
|
+
brainstate.random.seed()
|
303
|
+
a = brainstate.random.binomial(5, 0.5)
|
305
304
|
b = np.random.binomial(5, 0.5)
|
306
305
|
print(a)
|
307
306
|
print(b)
|
@@ -309,99 +308,99 @@ class TestRandom(unittest.TestCase):
|
|
309
308
|
self.assertTrue(a.dtype, int)
|
310
309
|
|
311
310
|
def test_binomial2(self):
|
312
|
-
|
313
|
-
a =
|
311
|
+
brainstate.random.seed()
|
312
|
+
a = brainstate.random.binomial(5, 0.5, size=(3, 2))
|
314
313
|
self.assertTupleEqual(a.shape, (3, 2))
|
315
314
|
self.assertTrue((a >= 0).all() and (a <= 5).all())
|
316
315
|
|
317
316
|
def test_binomial3(self):
|
318
|
-
|
319
|
-
a =
|
317
|
+
brainstate.random.seed()
|
318
|
+
a = brainstate.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
|
320
319
|
self.assertTupleEqual(a.shape, (2, 3))
|
321
320
|
|
322
321
|
def test_chisquare1(self):
|
323
|
-
|
324
|
-
a =
|
322
|
+
brainstate.random.seed()
|
323
|
+
a = brainstate.random.chisquare(3)
|
325
324
|
self.assertTupleEqual(a.shape, ())
|
326
325
|
self.assertTrue(a.dtype, float)
|
327
326
|
|
328
327
|
def test_chisquare2(self):
|
329
|
-
|
328
|
+
brainstate.random.seed()
|
330
329
|
with self.assertRaises(NotImplementedError):
|
331
|
-
a =
|
330
|
+
a = brainstate.random.chisquare(df=[2, 3, 4])
|
332
331
|
|
333
332
|
def test_chisquare3(self):
|
334
|
-
|
335
|
-
a =
|
333
|
+
brainstate.random.seed()
|
334
|
+
a = brainstate.random.chisquare(df=2, size=100)
|
336
335
|
self.assertTupleEqual(a.shape, (100,))
|
337
336
|
|
338
337
|
def test_chisquare4(self):
|
339
|
-
|
340
|
-
a =
|
338
|
+
brainstate.random.seed()
|
339
|
+
a = brainstate.random.chisquare(df=2, size=(100, 10))
|
341
340
|
self.assertTupleEqual(a.shape, (100, 10))
|
342
341
|
|
343
342
|
def test_dirichlet1(self):
|
344
|
-
|
345
|
-
a =
|
343
|
+
brainstate.random.seed()
|
344
|
+
a = brainstate.random.dirichlet((10, 5, 3))
|
346
345
|
self.assertTupleEqual(a.shape, (3,))
|
347
346
|
|
348
347
|
def test_dirichlet2(self):
|
349
|
-
|
350
|
-
a =
|
348
|
+
brainstate.random.seed()
|
349
|
+
a = brainstate.random.dirichlet((10, 5, 3), 20)
|
351
350
|
self.assertTupleEqual(a.shape, (20, 3))
|
352
351
|
|
353
352
|
def test_f(self):
|
354
|
-
|
355
|
-
a =
|
353
|
+
brainstate.random.seed()
|
354
|
+
a = brainstate.random.f(1., 48., 100)
|
356
355
|
self.assertTupleEqual(a.shape, (100,))
|
357
356
|
|
358
357
|
def test_geometric(self):
|
359
|
-
|
360
|
-
a =
|
358
|
+
brainstate.random.seed()
|
359
|
+
a = brainstate.random.geometric([0.7, 0.5, 0.2])
|
361
360
|
self.assertTupleEqual(a.shape, (3,))
|
362
361
|
|
363
362
|
def test_hypergeometric1(self):
|
364
|
-
|
365
|
-
a =
|
363
|
+
brainstate.random.seed()
|
364
|
+
a = brainstate.random.hypergeometric(10, 10, 10, 20)
|
366
365
|
self.assertTupleEqual(a.shape, (20,))
|
367
366
|
|
368
367
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
369
368
|
def test_hypergeometric2(self):
|
370
|
-
|
371
|
-
a =
|
369
|
+
brainstate.random.seed()
|
370
|
+
a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
|
372
371
|
self.assertTupleEqual(a.shape, (2, 2))
|
373
372
|
|
374
373
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
375
374
|
def test_hypergeometric3(self):
|
376
|
-
|
377
|
-
a =
|
375
|
+
brainstate.random.seed()
|
376
|
+
a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
|
378
377
|
self.assertTupleEqual(a.shape, (3, 2, 2))
|
379
378
|
|
380
379
|
def test_logseries(self):
|
381
|
-
|
382
|
-
a =
|
380
|
+
brainstate.random.seed()
|
381
|
+
a = brainstate.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
|
383
382
|
self.assertTupleEqual(a.shape, (4, 3))
|
384
383
|
|
385
384
|
def test_multinominal1(self):
|
386
|
-
|
385
|
+
brainstate.random.seed()
|
387
386
|
a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
|
388
387
|
print(a, a.shape)
|
389
|
-
b =
|
388
|
+
b = brainstate.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
|
390
389
|
print(b, b.shape)
|
391
390
|
self.assertTupleEqual(a.shape, b.shape)
|
392
391
|
self.assertTupleEqual(b.shape, (4, 2, 3))
|
393
392
|
|
394
393
|
def test_multinominal2(self):
|
395
|
-
|
396
|
-
a =
|
394
|
+
brainstate.random.seed()
|
395
|
+
a = brainstate.random.multinomial(100, (0.5, 0.2, 0.3))
|
397
396
|
self.assertTupleEqual(a.shape, (3,))
|
398
397
|
self.assertTrue(a.sum() == 100)
|
399
398
|
|
400
399
|
def test_multivariate_normal1(self):
|
401
|
-
|
400
|
+
brainstate.random.seed()
|
402
401
|
# self.skipTest('Windows jaxlib error')
|
403
402
|
a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
|
404
|
-
b =
|
403
|
+
b = brainstate.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
|
405
404
|
print('test_multivariate_normal1')
|
406
405
|
print(a)
|
407
406
|
print(b)
|
@@ -409,156 +408,156 @@ class TestRandom(unittest.TestCase):
|
|
409
408
|
self.assertTupleEqual(a.shape, (3, 2))
|
410
409
|
|
411
410
|
def test_multivariate_normal2(self):
|
412
|
-
|
411
|
+
brainstate.random.seed()
|
413
412
|
a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
|
414
|
-
b =
|
413
|
+
b = brainstate.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
|
415
414
|
print(a)
|
416
415
|
print(b)
|
417
416
|
self.assertTupleEqual(a.shape, b.shape)
|
418
417
|
self.assertTupleEqual(a.shape, (2,))
|
419
418
|
|
420
419
|
def test_negative_binomial(self):
|
421
|
-
|
420
|
+
brainstate.random.seed()
|
422
421
|
a = np.random.negative_binomial([3., 10.], 0.5)
|
423
|
-
b =
|
422
|
+
b = brainstate.random.negative_binomial([3., 10.], 0.5)
|
424
423
|
print(a)
|
425
424
|
print(b)
|
426
425
|
self.assertTupleEqual(a.shape, b.shape)
|
427
426
|
self.assertTupleEqual(b.shape, (2,))
|
428
427
|
|
429
428
|
def test_negative_binomial2(self):
|
430
|
-
|
429
|
+
brainstate.random.seed()
|
431
430
|
a = np.random.negative_binomial(3., 0.5, 10)
|
432
|
-
b =
|
431
|
+
b = brainstate.random.negative_binomial(3., 0.5, 10)
|
433
432
|
print(a)
|
434
433
|
print(b)
|
435
434
|
self.assertTupleEqual(a.shape, b.shape)
|
436
435
|
self.assertTupleEqual(b.shape, (10,))
|
437
436
|
|
438
437
|
def test_noncentral_chisquare(self):
|
439
|
-
|
438
|
+
brainstate.random.seed()
|
440
439
|
a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
|
441
|
-
b =
|
440
|
+
b = brainstate.random.noncentral_chisquare(3, [3., 2.], (4, 2))
|
442
441
|
self.assertTupleEqual(a.shape, b.shape)
|
443
442
|
self.assertTupleEqual(b.shape, (4, 2))
|
444
443
|
|
445
444
|
def test_noncentral_chisquare2(self):
|
446
|
-
|
447
|
-
a =
|
445
|
+
brainstate.random.seed()
|
446
|
+
a = brainstate.random.noncentral_chisquare(3, [3., 2.])
|
448
447
|
self.assertTupleEqual(a.shape, (2,))
|
449
448
|
|
450
449
|
def test_noncentral_f(self):
|
451
|
-
|
452
|
-
a =
|
450
|
+
brainstate.random.seed()
|
451
|
+
a = brainstate.random.noncentral_f(3, 20, 3., 100)
|
453
452
|
self.assertTupleEqual(a.shape, (100,))
|
454
453
|
|
455
454
|
def test_power(self):
|
456
|
-
|
455
|
+
brainstate.random.seed()
|
457
456
|
a = np.random.power(2, (4, 2))
|
458
|
-
b =
|
457
|
+
b = brainstate.random.power(2, (4, 2))
|
459
458
|
self.assertTupleEqual(a.shape, b.shape)
|
460
459
|
self.assertTupleEqual(b.shape, (4, 2))
|
461
460
|
|
462
461
|
def test_rayleigh(self):
|
463
|
-
|
464
|
-
a =
|
462
|
+
brainstate.random.seed()
|
463
|
+
a = brainstate.random.power(2., (4, 2))
|
465
464
|
self.assertTupleEqual(a.shape, (4, 2))
|
466
465
|
|
467
466
|
def test_triangular(self):
|
468
|
-
|
469
|
-
a =
|
467
|
+
brainstate.random.seed()
|
468
|
+
a = brainstate.random.triangular((2, 2))
|
470
469
|
self.assertTupleEqual(a.shape, (2, 2))
|
471
470
|
|
472
471
|
def test_vonmises(self):
|
473
|
-
|
472
|
+
brainstate.random.seed()
|
474
473
|
a = np.random.vonmises(2., 2.)
|
475
|
-
b =
|
474
|
+
b = brainstate.random.vonmises(2., 2.)
|
476
475
|
print(a, b)
|
477
476
|
self.assertTupleEqual(np.shape(a), b.shape)
|
478
477
|
self.assertTupleEqual(b.shape, ())
|
479
478
|
|
480
479
|
def test_vonmises2(self):
|
481
|
-
|
480
|
+
brainstate.random.seed()
|
482
481
|
a = np.random.vonmises(2., 2., 10)
|
483
|
-
b =
|
482
|
+
b = brainstate.random.vonmises(2., 2., 10)
|
484
483
|
print(a, b)
|
485
484
|
self.assertTupleEqual(a.shape, b.shape)
|
486
485
|
self.assertTupleEqual(b.shape, (10,))
|
487
486
|
|
488
487
|
def test_wald(self):
|
489
|
-
|
488
|
+
brainstate.random.seed()
|
490
489
|
a = np.random.wald([2., 0.5], 2.)
|
491
|
-
b =
|
490
|
+
b = brainstate.random.wald([2., 0.5], 2.)
|
492
491
|
self.assertTupleEqual(a.shape, b.shape)
|
493
492
|
self.assertTupleEqual(b.shape, (2,))
|
494
493
|
|
495
494
|
def test_wald2(self):
|
496
|
-
|
495
|
+
brainstate.random.seed()
|
497
496
|
a = np.random.wald(2., 2., 100)
|
498
|
-
b =
|
497
|
+
b = brainstate.random.wald(2., 2., 100)
|
499
498
|
self.assertTupleEqual(a.shape, b.shape)
|
500
499
|
self.assertTupleEqual(b.shape, (100,))
|
501
500
|
|
502
501
|
def test_weibull(self):
|
503
|
-
|
504
|
-
a =
|
502
|
+
brainstate.random.seed()
|
503
|
+
a = brainstate.random.weibull(2., (4, 2))
|
505
504
|
self.assertTupleEqual(a.shape, (4, 2))
|
506
505
|
|
507
506
|
def test_weibull2(self):
|
508
|
-
|
509
|
-
a =
|
507
|
+
brainstate.random.seed()
|
508
|
+
a = brainstate.random.weibull(2., )
|
510
509
|
self.assertTupleEqual(a.shape, ())
|
511
510
|
|
512
511
|
def test_weibull3(self):
|
513
|
-
|
514
|
-
a =
|
512
|
+
brainstate.random.seed()
|
513
|
+
a = brainstate.random.weibull([2., 3.], )
|
515
514
|
self.assertTupleEqual(a.shape, (2,))
|
516
515
|
|
517
516
|
def test_weibull_min(self):
|
518
|
-
|
519
|
-
a =
|
517
|
+
brainstate.random.seed()
|
518
|
+
a = brainstate.random.weibull_min(2., 2., (4, 2))
|
520
519
|
self.assertTupleEqual(a.shape, (4, 2))
|
521
520
|
|
522
521
|
def test_weibull_min2(self):
|
523
|
-
|
524
|
-
a =
|
522
|
+
brainstate.random.seed()
|
523
|
+
a = brainstate.random.weibull_min(2., 2.)
|
525
524
|
self.assertTupleEqual(a.shape, ())
|
526
525
|
|
527
526
|
def test_weibull_min3(self):
|
528
|
-
|
529
|
-
a =
|
527
|
+
brainstate.random.seed()
|
528
|
+
a = brainstate.random.weibull_min([2., 3.], 2.)
|
530
529
|
self.assertTupleEqual(a.shape, (2,))
|
531
530
|
|
532
531
|
def test_zipf(self):
|
533
|
-
|
534
|
-
a =
|
532
|
+
brainstate.random.seed()
|
533
|
+
a = brainstate.random.zipf(2., (4, 2))
|
535
534
|
self.assertTupleEqual(a.shape, (4, 2))
|
536
535
|
|
537
536
|
def test_zipf2(self):
|
538
|
-
|
537
|
+
brainstate.random.seed()
|
539
538
|
a = np.random.zipf([1.1, 2.])
|
540
|
-
b =
|
539
|
+
b = brainstate.random.zipf([1.1, 2.])
|
541
540
|
self.assertTupleEqual(a.shape, b.shape)
|
542
541
|
self.assertTupleEqual(b.shape, (2,))
|
543
542
|
|
544
543
|
def test_maxwell(self):
|
545
|
-
|
546
|
-
a =
|
544
|
+
brainstate.random.seed()
|
545
|
+
a = brainstate.random.maxwell(10)
|
547
546
|
self.assertTupleEqual(a.shape, (10,))
|
548
547
|
|
549
548
|
def test_maxwell2(self):
|
550
|
-
|
551
|
-
a =
|
549
|
+
brainstate.random.seed()
|
550
|
+
a = brainstate.random.maxwell()
|
552
551
|
self.assertTupleEqual(a.shape, ())
|
553
552
|
|
554
553
|
def test_t(self):
|
555
|
-
|
556
|
-
a =
|
554
|
+
brainstate.random.seed()
|
555
|
+
a = brainstate.random.t(1., size=10)
|
557
556
|
self.assertTupleEqual(a.shape, (10,))
|
558
557
|
|
559
558
|
def test_t2(self):
|
560
|
-
|
561
|
-
a =
|
559
|
+
brainstate.random.seed()
|
560
|
+
a = brainstate.random.t([1., 2.], size=None)
|
562
561
|
self.assertTupleEqual(a.shape, (2,))
|
563
562
|
|
564
563
|
# class TestRandomKey(unittest.TestCase):
|