brainstate 0.0.2.post20240910__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -2
- brainstate/_module.py +102 -67
- brainstate/_state.py +2 -2
- brainstate/_visualization.py +47 -0
- brainstate/environ.py +116 -9
- brainstate/environ_test.py +56 -0
- brainstate/functional/_activations.py +134 -56
- brainstate/functional/_activations_test.py +331 -0
- brainstate/functional/_normalization.py +21 -10
- brainstate/init/_generic.py +4 -2
- brainstate/init/_regular_inits.py +7 -7
- brainstate/mixin.py +1 -1
- brainstate/nn/__init__.py +7 -2
- brainstate/nn/_base.py +2 -2
- brainstate/nn/_connections.py +4 -4
- brainstate/nn/_dynamics.py +5 -5
- brainstate/nn/_elementwise.py +9 -9
- brainstate/nn/_embedding.py +3 -3
- brainstate/nn/_normalizations.py +3 -3
- brainstate/nn/_others.py +2 -2
- brainstate/nn/_poolings.py +6 -6
- brainstate/nn/_rate_rnns.py +1 -1
- brainstate/nn/_readout.py +1 -1
- brainstate/nn/_synouts.py +1 -1
- brainstate/nn/event/__init__.py +25 -0
- brainstate/nn/event/_misc.py +34 -0
- brainstate/nn/event/csr.py +312 -0
- brainstate/nn/event/csr_test.py +118 -0
- brainstate/nn/event/fixed_probability.py +276 -0
- brainstate/nn/event/fixed_probability_test.py +127 -0
- brainstate/nn/event/linear.py +220 -0
- brainstate/nn/event/linear_test.py +111 -0
- brainstate/nn/metrics.py +390 -0
- brainstate/optim/__init__.py +5 -1
- brainstate/optim/_optax_optimizer.py +208 -0
- brainstate/optim/_optax_optimizer_test.py +14 -0
- brainstate/random/__init__.py +24 -0
- brainstate/{random.py → random/_rand_funs.py} +7 -1596
- brainstate/random/_rand_seed.py +169 -0
- brainstate/random/_rand_state.py +1491 -0
- brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
- brainstate/{random_test.py → random/random_test.py} +208 -191
- brainstate/transform/_jit.py +1 -1
- brainstate/transform/_jit_test.py +19 -0
- brainstate/transform/_make_jaxpr.py +1 -1
- {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
- brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
- brainstate-0.0.2.post20240910.dist-info/RECORD +0 -70
- {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240910.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
@@ -18,121 +18,138 @@ import platform
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax.numpy as jnp
|
21
|
+
import jax.random
|
21
22
|
import jax.random as jr
|
22
23
|
import numpy as np
|
23
24
|
import pytest
|
24
25
|
|
25
|
-
import brainstate as
|
26
|
+
import brainstate as bst
|
26
27
|
|
27
28
|
|
28
29
|
class TestRandom(unittest.TestCase):
|
30
|
+
|
31
|
+
def test_seed2(self):
|
32
|
+
test_seed = 299
|
33
|
+
key = jax.random.PRNGKey(test_seed)
|
34
|
+
bst.random.seed(key)
|
35
|
+
|
36
|
+
@jax.jit
|
37
|
+
def jit_seed(key):
|
38
|
+
bst.random.seed(key)
|
39
|
+
with bst.random.seed_context(key):
|
40
|
+
print(bst.random.DEFAULT.value)
|
41
|
+
|
42
|
+
jit_seed(key)
|
43
|
+
jit_seed(1)
|
44
|
+
jit_seed(None)
|
45
|
+
|
29
46
|
def test_seed(self):
|
30
47
|
test_seed = 299
|
31
|
-
|
32
|
-
a =
|
33
|
-
|
34
|
-
b =
|
48
|
+
bst.random.seed(test_seed)
|
49
|
+
a = bst.random.rand(3)
|
50
|
+
bst.random.seed(test_seed)
|
51
|
+
b = bst.random.rand(3)
|
35
52
|
self.assertTrue(jnp.array_equal(a, b))
|
36
53
|
|
37
54
|
def test_rand(self):
|
38
|
-
|
39
|
-
a =
|
55
|
+
bst.random.seed()
|
56
|
+
a = bst.random.rand(3, 2)
|
40
57
|
self.assertTupleEqual(a.shape, (3, 2))
|
41
58
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
42
59
|
|
43
60
|
key = jr.PRNGKey(123)
|
44
61
|
jres = jr.uniform(key, shape=(10, 100))
|
45
|
-
self.assertTrue(jnp.allclose(jres,
|
46
|
-
self.assertTrue(jnp.allclose(jres,
|
62
|
+
self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=key)))
|
63
|
+
self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=123)))
|
47
64
|
|
48
65
|
def test_randint1(self):
|
49
|
-
|
50
|
-
a =
|
66
|
+
bst.random.seed()
|
67
|
+
a = bst.random.randint(5)
|
51
68
|
self.assertTupleEqual(a.shape, ())
|
52
69
|
self.assertTrue(0 <= a < 5)
|
53
70
|
|
54
71
|
def test_randint2(self):
|
55
|
-
|
56
|
-
a =
|
72
|
+
bst.random.seed()
|
73
|
+
a = bst.random.randint(2, 6, size=(4, 3))
|
57
74
|
self.assertTupleEqual(a.shape, (4, 3))
|
58
75
|
self.assertTrue((a >= 2).all() and (a < 6).all())
|
59
76
|
|
60
77
|
def test_randint3(self):
|
61
|
-
|
62
|
-
a =
|
78
|
+
bst.random.seed()
|
79
|
+
a = bst.random.randint([1, 2, 3], [10, 7, 8])
|
63
80
|
self.assertTupleEqual(a.shape, (3,))
|
64
81
|
self.assertTrue((a - jnp.array([1, 2, 3]) >= 0).all()
|
65
82
|
and (-a + jnp.array([10, 7, 8]) > 0).all())
|
66
83
|
|
67
84
|
def test_randint4(self):
|
68
|
-
|
69
|
-
a =
|
85
|
+
bst.random.seed()
|
86
|
+
a = bst.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
|
70
87
|
self.assertTupleEqual(a.shape, (2, 3))
|
71
88
|
|
72
89
|
def test_randn(self):
|
73
|
-
|
74
|
-
a =
|
90
|
+
bst.random.seed()
|
91
|
+
a = bst.random.randn(3, 2)
|
75
92
|
self.assertTupleEqual(a.shape, (3, 2))
|
76
93
|
|
77
94
|
def test_random1(self):
|
78
|
-
|
79
|
-
a =
|
95
|
+
bst.random.seed()
|
96
|
+
a = bst.random.random()
|
80
97
|
self.assertTrue(0. <= a < 1)
|
81
98
|
|
82
99
|
def test_random2(self):
|
83
|
-
|
84
|
-
a =
|
100
|
+
bst.random.seed()
|
101
|
+
a = bst.random.random(size=(3, 2))
|
85
102
|
self.assertTupleEqual(a.shape, (3, 2))
|
86
103
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
87
104
|
|
88
105
|
def test_random_sample(self):
|
89
|
-
|
90
|
-
a =
|
106
|
+
bst.random.seed()
|
107
|
+
a = bst.random.random_sample(size=(3, 2))
|
91
108
|
self.assertTupleEqual(a.shape, (3, 2))
|
92
109
|
self.assertTrue((a >= 0).all() and (a < 1).all())
|
93
110
|
|
94
111
|
def test_choice1(self):
|
95
|
-
|
96
|
-
a =
|
112
|
+
bst.random.seed()
|
113
|
+
a = bst.random.choice(5)
|
97
114
|
self.assertTupleEqual(jnp.shape(a), ())
|
98
115
|
self.assertTrue(0 <= a < 5)
|
99
116
|
|
100
117
|
def test_choice2(self):
|
101
|
-
|
102
|
-
a =
|
118
|
+
bst.random.seed()
|
119
|
+
a = bst.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
|
103
120
|
self.assertTupleEqual(a.shape, (3,))
|
104
121
|
self.assertTrue((a >= 0).all() and (a < 5).all())
|
105
122
|
|
106
123
|
def test_choice3(self):
|
107
|
-
|
108
|
-
a =
|
124
|
+
bst.random.seed()
|
125
|
+
a = bst.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
|
109
126
|
self.assertTupleEqual(a.shape, (4, 3))
|
110
127
|
self.assertTrue((a >= 2).all() and (a < 20).all())
|
111
128
|
self.assertEqual(len(jnp.unique(a)), 12)
|
112
129
|
|
113
130
|
def test_permutation1(self):
|
114
|
-
|
115
|
-
a =
|
131
|
+
bst.random.seed()
|
132
|
+
a = bst.random.permutation(10)
|
116
133
|
self.assertTupleEqual(a.shape, (10,))
|
117
134
|
self.assertEqual(len(jnp.unique(a)), 10)
|
118
135
|
|
119
136
|
def test_permutation2(self):
|
120
|
-
|
121
|
-
a =
|
137
|
+
bst.random.seed()
|
138
|
+
a = bst.random.permutation(jnp.arange(10))
|
122
139
|
self.assertTupleEqual(a.shape, (10,))
|
123
140
|
self.assertEqual(len(jnp.unique(a)), 10)
|
124
141
|
|
125
142
|
def test_shuffle1(self):
|
126
|
-
|
143
|
+
bst.random.seed()
|
127
144
|
a = jnp.arange(10)
|
128
|
-
|
145
|
+
bst.random.shuffle(a)
|
129
146
|
self.assertTupleEqual(a.shape, (10,))
|
130
147
|
self.assertEqual(len(jnp.unique(a)), 10)
|
131
148
|
|
132
149
|
def test_shuffle2(self):
|
133
|
-
|
150
|
+
bst.random.seed()
|
134
151
|
a = jnp.arange(12).reshape(4, 3)
|
135
|
-
|
152
|
+
bst.random.shuffle(a, axis=1)
|
136
153
|
self.assertTupleEqual(a.shape, (4, 3))
|
137
154
|
self.assertEqual(len(jnp.unique(a)), 12)
|
138
155
|
|
@@ -141,173 +158,173 @@ class TestRandom(unittest.TestCase):
|
|
141
158
|
self.assertEqual(uni, jnp.asarray([3]))
|
142
159
|
|
143
160
|
def test_beta1(self):
|
144
|
-
|
145
|
-
a =
|
161
|
+
bst.random.seed()
|
162
|
+
a = bst.random.beta(2, 2)
|
146
163
|
self.assertTupleEqual(a.shape, ())
|
147
164
|
|
148
165
|
def test_beta2(self):
|
149
|
-
|
150
|
-
a =
|
166
|
+
bst.random.seed()
|
167
|
+
a = bst.random.beta([2, 2, 3], 2, size=(3,))
|
151
168
|
self.assertTupleEqual(a.shape, (3,))
|
152
169
|
|
153
170
|
def test_exponential1(self):
|
154
|
-
|
155
|
-
a =
|
171
|
+
bst.random.seed()
|
172
|
+
a = bst.random.exponential(10., size=[3, 2])
|
156
173
|
self.assertTupleEqual(a.shape, (3, 2))
|
157
174
|
|
158
175
|
def test_exponential2(self):
|
159
|
-
|
160
|
-
a =
|
176
|
+
bst.random.seed()
|
177
|
+
a = bst.random.exponential([1., 2., 5.])
|
161
178
|
self.assertTupleEqual(a.shape, (3,))
|
162
179
|
|
163
180
|
def test_gamma(self):
|
164
|
-
|
165
|
-
a =
|
181
|
+
bst.random.seed()
|
182
|
+
a = bst.random.gamma(2, 10., size=[3, 2])
|
166
183
|
self.assertTupleEqual(a.shape, (3, 2))
|
167
184
|
|
168
185
|
def test_gumbel(self):
|
169
|
-
|
170
|
-
a =
|
186
|
+
bst.random.seed()
|
187
|
+
a = bst.random.gumbel(0., 2., size=[3, 2])
|
171
188
|
self.assertTupleEqual(a.shape, (3, 2))
|
172
189
|
|
173
190
|
def test_laplace(self):
|
174
|
-
|
175
|
-
a =
|
191
|
+
bst.random.seed()
|
192
|
+
a = bst.random.laplace(0., 2., size=[3, 2])
|
176
193
|
self.assertTupleEqual(a.shape, (3, 2))
|
177
194
|
|
178
195
|
def test_logistic(self):
|
179
|
-
|
180
|
-
a =
|
196
|
+
bst.random.seed()
|
197
|
+
a = bst.random.logistic(0., 2., size=[3, 2])
|
181
198
|
self.assertTupleEqual(a.shape, (3, 2))
|
182
199
|
|
183
200
|
def test_normal1(self):
|
184
|
-
|
185
|
-
a =
|
201
|
+
bst.random.seed()
|
202
|
+
a = bst.random.normal()
|
186
203
|
self.assertTupleEqual(a.shape, ())
|
187
204
|
|
188
205
|
def test_normal2(self):
|
189
|
-
|
190
|
-
a =
|
206
|
+
bst.random.seed()
|
207
|
+
a = bst.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
|
191
208
|
self.assertTupleEqual(a.shape, (3,))
|
192
209
|
|
193
210
|
def test_normal3(self):
|
194
|
-
|
195
|
-
a =
|
211
|
+
bst.random.seed()
|
212
|
+
a = bst.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
|
196
213
|
print(a)
|
197
214
|
self.assertTupleEqual(a.shape, (2, 3))
|
198
215
|
|
199
216
|
def test_pareto(self):
|
200
|
-
|
201
|
-
a =
|
217
|
+
bst.random.seed()
|
218
|
+
a = bst.random.pareto([1, 2, 2])
|
202
219
|
self.assertTupleEqual(a.shape, (3,))
|
203
220
|
|
204
221
|
def test_poisson(self):
|
205
|
-
|
206
|
-
a =
|
222
|
+
bst.random.seed()
|
223
|
+
a = bst.random.poisson([1., 2., 2.], size=3)
|
207
224
|
self.assertTupleEqual(a.shape, (3,))
|
208
225
|
|
209
226
|
def test_standard_cauchy(self):
|
210
|
-
|
211
|
-
a =
|
227
|
+
bst.random.seed()
|
228
|
+
a = bst.random.standard_cauchy(size=(3, 2))
|
212
229
|
self.assertTupleEqual(a.shape, (3, 2))
|
213
230
|
|
214
231
|
def test_standard_exponential(self):
|
215
|
-
|
216
|
-
a =
|
232
|
+
bst.random.seed()
|
233
|
+
a = bst.random.standard_exponential(size=(3, 2))
|
217
234
|
self.assertTupleEqual(a.shape, (3, 2))
|
218
235
|
|
219
236
|
def test_standard_gamma(self):
|
220
|
-
|
221
|
-
a =
|
237
|
+
bst.random.seed()
|
238
|
+
a = bst.random.standard_gamma(shape=[1, 2, 4], size=3)
|
222
239
|
self.assertTupleEqual(a.shape, (3,))
|
223
240
|
|
224
241
|
def test_standard_normal(self):
|
225
|
-
|
226
|
-
a =
|
242
|
+
bst.random.seed()
|
243
|
+
a = bst.random.standard_normal(size=(3, 2))
|
227
244
|
self.assertTupleEqual(a.shape, (3, 2))
|
228
245
|
|
229
246
|
def test_standard_t(self):
|
230
|
-
|
231
|
-
a =
|
247
|
+
bst.random.seed()
|
248
|
+
a = bst.random.standard_t(df=[1, 2, 4], size=3)
|
232
249
|
self.assertTupleEqual(a.shape, (3,))
|
233
250
|
|
234
251
|
def test_standard_uniform1(self):
|
235
|
-
|
236
|
-
a =
|
252
|
+
bst.random.seed()
|
253
|
+
a = bst.random.uniform()
|
237
254
|
self.assertTupleEqual(a.shape, ())
|
238
255
|
self.assertTrue(0 <= a < 1)
|
239
256
|
|
240
257
|
def test_uniform2(self):
|
241
|
-
|
242
|
-
a =
|
258
|
+
bst.random.seed()
|
259
|
+
a = bst.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
|
243
260
|
self.assertTupleEqual(a.shape, (3,))
|
244
261
|
self.assertTrue((a - jnp.array([-1., 5., 2.]) >= 0).all()
|
245
262
|
and (-a + jnp.array([2., 6., 10.]) > 0).all())
|
246
263
|
|
247
264
|
def test_uniform3(self):
|
248
|
-
|
249
|
-
a =
|
265
|
+
bst.random.seed()
|
266
|
+
a = bst.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
|
250
267
|
self.assertTupleEqual(a.shape, (2, 3))
|
251
268
|
|
252
269
|
def test_uniform4(self):
|
253
|
-
|
254
|
-
a =
|
270
|
+
bst.random.seed()
|
271
|
+
a = bst.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
|
255
272
|
self.assertTupleEqual(a.shape, (2, 3))
|
256
273
|
|
257
274
|
def test_truncated_normal1(self):
|
258
|
-
|
259
|
-
a =
|
275
|
+
bst.random.seed()
|
276
|
+
a = bst.random.truncated_normal(-1., 1.)
|
260
277
|
self.assertTupleEqual(a.shape, ())
|
261
278
|
self.assertTrue(-1. <= a <= 1.)
|
262
279
|
|
263
280
|
def test_truncated_normal2(self):
|
264
|
-
|
265
|
-
a =
|
281
|
+
bst.random.seed()
|
282
|
+
a = bst.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
|
266
283
|
self.assertTupleEqual(a.shape, (4, 3))
|
267
284
|
|
268
285
|
def test_truncated_normal3(self):
|
269
|
-
|
270
|
-
a =
|
286
|
+
bst.random.seed()
|
287
|
+
a = bst.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
|
271
288
|
self.assertTupleEqual(a.shape, (2, 3))
|
272
289
|
self.assertTrue((a - jnp.array([-1., 0., 1.]) >= 0.).all()
|
273
290
|
and (- a + jnp.array([2., 2., 4.]) >= 0.).all())
|
274
291
|
|
275
292
|
def test_bernoulli1(self):
|
276
|
-
|
277
|
-
a =
|
293
|
+
bst.random.seed()
|
294
|
+
a = bst.random.bernoulli()
|
278
295
|
self.assertTupleEqual(a.shape, ())
|
279
296
|
self.assertTrue(a == 0 or a == 1)
|
280
297
|
|
281
298
|
def test_bernoulli2(self):
|
282
|
-
|
283
|
-
a =
|
299
|
+
bst.random.seed()
|
300
|
+
a = bst.random.bernoulli([0.5, 0.6, 0.8])
|
284
301
|
self.assertTupleEqual(a.shape, (3,))
|
285
302
|
self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
|
286
303
|
|
287
304
|
def test_bernoulli3(self):
|
288
|
-
|
289
|
-
a =
|
305
|
+
bst.random.seed()
|
306
|
+
a = bst.random.bernoulli([0.5, 0.6], size=(3, 2))
|
290
307
|
self.assertTupleEqual(a.shape, (3, 2))
|
291
308
|
self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
|
292
309
|
|
293
310
|
def test_lognormal1(self):
|
294
|
-
|
295
|
-
a =
|
311
|
+
bst.random.seed()
|
312
|
+
a = bst.random.lognormal()
|
296
313
|
self.assertTupleEqual(a.shape, ())
|
297
314
|
|
298
315
|
def test_lognormal2(self):
|
299
|
-
|
300
|
-
a =
|
316
|
+
bst.random.seed()
|
317
|
+
a = bst.random.lognormal(sigma=[2., 1.], size=[3, 2])
|
301
318
|
self.assertTupleEqual(a.shape, (3, 2))
|
302
319
|
|
303
320
|
def test_lognormal3(self):
|
304
|
-
|
305
|
-
a =
|
321
|
+
bst.random.seed()
|
322
|
+
a = bst.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
|
306
323
|
self.assertTupleEqual(a.shape, (2, 2))
|
307
324
|
|
308
325
|
def test_binomial1(self):
|
309
|
-
|
310
|
-
a =
|
326
|
+
bst.random.seed()
|
327
|
+
a = bst.random.binomial(5, 0.5)
|
311
328
|
b = np.random.binomial(5, 0.5)
|
312
329
|
print(a)
|
313
330
|
print(b)
|
@@ -315,99 +332,99 @@ class TestRandom(unittest.TestCase):
|
|
315
332
|
self.assertTrue(a.dtype, int)
|
316
333
|
|
317
334
|
def test_binomial2(self):
|
318
|
-
|
319
|
-
a =
|
335
|
+
bst.random.seed()
|
336
|
+
a = bst.random.binomial(5, 0.5, size=(3, 2))
|
320
337
|
self.assertTupleEqual(a.shape, (3, 2))
|
321
338
|
self.assertTrue((a >= 0).all() and (a <= 5).all())
|
322
339
|
|
323
340
|
def test_binomial3(self):
|
324
|
-
|
325
|
-
a =
|
341
|
+
bst.random.seed()
|
342
|
+
a = bst.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
|
326
343
|
self.assertTupleEqual(a.shape, (2, 3))
|
327
344
|
|
328
345
|
def test_chisquare1(self):
|
329
|
-
|
330
|
-
a =
|
346
|
+
bst.random.seed()
|
347
|
+
a = bst.random.chisquare(3)
|
331
348
|
self.assertTupleEqual(a.shape, ())
|
332
349
|
self.assertTrue(a.dtype, float)
|
333
350
|
|
334
351
|
def test_chisquare2(self):
|
335
|
-
|
352
|
+
bst.random.seed()
|
336
353
|
with self.assertRaises(NotImplementedError):
|
337
|
-
a =
|
354
|
+
a = bst.random.chisquare(df=[2, 3, 4])
|
338
355
|
|
339
356
|
def test_chisquare3(self):
|
340
|
-
|
341
|
-
a =
|
357
|
+
bst.random.seed()
|
358
|
+
a = bst.random.chisquare(df=2, size=100)
|
342
359
|
self.assertTupleEqual(a.shape, (100,))
|
343
360
|
|
344
361
|
def test_chisquare4(self):
|
345
|
-
|
346
|
-
a =
|
362
|
+
bst.random.seed()
|
363
|
+
a = bst.random.chisquare(df=2, size=(100, 10))
|
347
364
|
self.assertTupleEqual(a.shape, (100, 10))
|
348
365
|
|
349
366
|
def test_dirichlet1(self):
|
350
|
-
|
351
|
-
a =
|
367
|
+
bst.random.seed()
|
368
|
+
a = bst.random.dirichlet((10, 5, 3))
|
352
369
|
self.assertTupleEqual(a.shape, (3,))
|
353
370
|
|
354
371
|
def test_dirichlet2(self):
|
355
|
-
|
356
|
-
a =
|
372
|
+
bst.random.seed()
|
373
|
+
a = bst.random.dirichlet((10, 5, 3), 20)
|
357
374
|
self.assertTupleEqual(a.shape, (20, 3))
|
358
375
|
|
359
376
|
def test_f(self):
|
360
|
-
|
361
|
-
a =
|
377
|
+
bst.random.seed()
|
378
|
+
a = bst.random.f(1., 48., 100)
|
362
379
|
self.assertTupleEqual(a.shape, (100,))
|
363
380
|
|
364
381
|
def test_geometric(self):
|
365
|
-
|
366
|
-
a =
|
382
|
+
bst.random.seed()
|
383
|
+
a = bst.random.geometric([0.7, 0.5, 0.2])
|
367
384
|
self.assertTupleEqual(a.shape, (3,))
|
368
385
|
|
369
386
|
def test_hypergeometric1(self):
|
370
|
-
|
371
|
-
a =
|
387
|
+
bst.random.seed()
|
388
|
+
a = bst.random.hypergeometric(10, 10, 10, 20)
|
372
389
|
self.assertTupleEqual(a.shape, (20,))
|
373
390
|
|
374
391
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
375
392
|
def test_hypergeometric2(self):
|
376
|
-
|
377
|
-
a =
|
393
|
+
bst.random.seed()
|
394
|
+
a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
|
378
395
|
self.assertTupleEqual(a.shape, (2, 2))
|
379
396
|
|
380
397
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
|
381
398
|
def test_hypergeometric3(self):
|
382
|
-
|
383
|
-
a =
|
399
|
+
bst.random.seed()
|
400
|
+
a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
|
384
401
|
self.assertTupleEqual(a.shape, (3, 2, 2))
|
385
402
|
|
386
403
|
def test_logseries(self):
|
387
|
-
|
388
|
-
a =
|
404
|
+
bst.random.seed()
|
405
|
+
a = bst.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
|
389
406
|
self.assertTupleEqual(a.shape, (4, 3))
|
390
407
|
|
391
408
|
def test_multinominal1(self):
|
392
|
-
|
409
|
+
bst.random.seed()
|
393
410
|
a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
|
394
411
|
print(a, a.shape)
|
395
|
-
b =
|
412
|
+
b = bst.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
|
396
413
|
print(b, b.shape)
|
397
414
|
self.assertTupleEqual(a.shape, b.shape)
|
398
415
|
self.assertTupleEqual(b.shape, (4, 2, 3))
|
399
416
|
|
400
417
|
def test_multinominal2(self):
|
401
|
-
|
402
|
-
a =
|
418
|
+
bst.random.seed()
|
419
|
+
a = bst.random.multinomial(100, (0.5, 0.2, 0.3))
|
403
420
|
self.assertTupleEqual(a.shape, (3,))
|
404
421
|
self.assertTrue(a.sum() == 100)
|
405
422
|
|
406
423
|
def test_multivariate_normal1(self):
|
407
|
-
|
424
|
+
bst.random.seed()
|
408
425
|
# self.skipTest('Windows jaxlib error')
|
409
426
|
a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
|
410
|
-
b =
|
427
|
+
b = bst.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
|
411
428
|
print('test_multivariate_normal1')
|
412
429
|
print(a)
|
413
430
|
print(b)
|
@@ -415,162 +432,162 @@ class TestRandom(unittest.TestCase):
|
|
415
432
|
self.assertTupleEqual(a.shape, (3, 2))
|
416
433
|
|
417
434
|
def test_multivariate_normal2(self):
|
418
|
-
|
435
|
+
bst.random.seed()
|
419
436
|
a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
|
420
|
-
b =
|
437
|
+
b = bst.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
|
421
438
|
print(a)
|
422
439
|
print(b)
|
423
440
|
self.assertTupleEqual(a.shape, b.shape)
|
424
441
|
self.assertTupleEqual(a.shape, (2,))
|
425
442
|
|
426
443
|
def test_negative_binomial(self):
|
427
|
-
|
444
|
+
bst.random.seed()
|
428
445
|
a = np.random.negative_binomial([3., 10.], 0.5)
|
429
|
-
b =
|
446
|
+
b = bst.random.negative_binomial([3., 10.], 0.5)
|
430
447
|
print(a)
|
431
448
|
print(b)
|
432
449
|
self.assertTupleEqual(a.shape, b.shape)
|
433
450
|
self.assertTupleEqual(b.shape, (2,))
|
434
451
|
|
435
452
|
def test_negative_binomial2(self):
|
436
|
-
|
453
|
+
bst.random.seed()
|
437
454
|
a = np.random.negative_binomial(3., 0.5, 10)
|
438
|
-
b =
|
455
|
+
b = bst.random.negative_binomial(3., 0.5, 10)
|
439
456
|
print(a)
|
440
457
|
print(b)
|
441
458
|
self.assertTupleEqual(a.shape, b.shape)
|
442
459
|
self.assertTupleEqual(b.shape, (10,))
|
443
460
|
|
444
461
|
def test_noncentral_chisquare(self):
|
445
|
-
|
462
|
+
bst.random.seed()
|
446
463
|
a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
|
447
|
-
b =
|
464
|
+
b = bst.random.noncentral_chisquare(3, [3., 2.], (4, 2))
|
448
465
|
self.assertTupleEqual(a.shape, b.shape)
|
449
466
|
self.assertTupleEqual(b.shape, (4, 2))
|
450
467
|
|
451
468
|
def test_noncentral_chisquare2(self):
|
452
|
-
|
453
|
-
a =
|
469
|
+
bst.random.seed()
|
470
|
+
a = bst.random.noncentral_chisquare(3, [3., 2.])
|
454
471
|
self.assertTupleEqual(a.shape, (2,))
|
455
472
|
|
456
473
|
def test_noncentral_f(self):
|
457
|
-
|
458
|
-
a =
|
474
|
+
bst.random.seed()
|
475
|
+
a = bst.random.noncentral_f(3, 20, 3., 100)
|
459
476
|
self.assertTupleEqual(a.shape, (100,))
|
460
477
|
|
461
478
|
def test_power(self):
|
462
|
-
|
479
|
+
bst.random.seed()
|
463
480
|
a = np.random.power(2, (4, 2))
|
464
|
-
b =
|
481
|
+
b = bst.random.power(2, (4, 2))
|
465
482
|
self.assertTupleEqual(a.shape, b.shape)
|
466
483
|
self.assertTupleEqual(b.shape, (4, 2))
|
467
484
|
|
468
485
|
def test_rayleigh(self):
|
469
|
-
|
470
|
-
a =
|
486
|
+
bst.random.seed()
|
487
|
+
a = bst.random.power(2., (4, 2))
|
471
488
|
self.assertTupleEqual(a.shape, (4, 2))
|
472
489
|
|
473
490
|
def test_triangular(self):
|
474
|
-
|
475
|
-
a =
|
491
|
+
bst.random.seed()
|
492
|
+
a = bst.random.triangular((2, 2))
|
476
493
|
self.assertTupleEqual(a.shape, (2, 2))
|
477
494
|
|
478
495
|
def test_vonmises(self):
|
479
|
-
|
496
|
+
bst.random.seed()
|
480
497
|
a = np.random.vonmises(2., 2.)
|
481
|
-
b =
|
498
|
+
b = bst.random.vonmises(2., 2.)
|
482
499
|
print(a, b)
|
483
500
|
self.assertTupleEqual(np.shape(a), b.shape)
|
484
501
|
self.assertTupleEqual(b.shape, ())
|
485
502
|
|
486
503
|
def test_vonmises2(self):
|
487
|
-
|
504
|
+
bst.random.seed()
|
488
505
|
a = np.random.vonmises(2., 2., 10)
|
489
|
-
b =
|
506
|
+
b = bst.random.vonmises(2., 2., 10)
|
490
507
|
print(a, b)
|
491
508
|
self.assertTupleEqual(a.shape, b.shape)
|
492
509
|
self.assertTupleEqual(b.shape, (10,))
|
493
510
|
|
494
511
|
def test_wald(self):
|
495
|
-
|
512
|
+
bst.random.seed()
|
496
513
|
a = np.random.wald([2., 0.5], 2.)
|
497
|
-
b =
|
514
|
+
b = bst.random.wald([2., 0.5], 2.)
|
498
515
|
self.assertTupleEqual(a.shape, b.shape)
|
499
516
|
self.assertTupleEqual(b.shape, (2,))
|
500
517
|
|
501
518
|
def test_wald2(self):
|
502
|
-
|
519
|
+
bst.random.seed()
|
503
520
|
a = np.random.wald(2., 2., 100)
|
504
|
-
b =
|
521
|
+
b = bst.random.wald(2., 2., 100)
|
505
522
|
self.assertTupleEqual(a.shape, b.shape)
|
506
523
|
self.assertTupleEqual(b.shape, (100,))
|
507
524
|
|
508
525
|
def test_weibull(self):
|
509
|
-
|
510
|
-
a =
|
526
|
+
bst.random.seed()
|
527
|
+
a = bst.random.weibull(2., (4, 2))
|
511
528
|
self.assertTupleEqual(a.shape, (4, 2))
|
512
529
|
|
513
530
|
def test_weibull2(self):
|
514
|
-
|
515
|
-
a =
|
531
|
+
bst.random.seed()
|
532
|
+
a = bst.random.weibull(2., )
|
516
533
|
self.assertTupleEqual(a.shape, ())
|
517
534
|
|
518
535
|
def test_weibull3(self):
|
519
|
-
|
520
|
-
a =
|
536
|
+
bst.random.seed()
|
537
|
+
a = bst.random.weibull([2., 3.], )
|
521
538
|
self.assertTupleEqual(a.shape, (2,))
|
522
539
|
|
523
540
|
def test_weibull_min(self):
|
524
|
-
|
525
|
-
a =
|
541
|
+
bst.random.seed()
|
542
|
+
a = bst.random.weibull_min(2., 2., (4, 2))
|
526
543
|
self.assertTupleEqual(a.shape, (4, 2))
|
527
544
|
|
528
545
|
def test_weibull_min2(self):
|
529
|
-
|
530
|
-
a =
|
546
|
+
bst.random.seed()
|
547
|
+
a = bst.random.weibull_min(2., 2.)
|
531
548
|
self.assertTupleEqual(a.shape, ())
|
532
549
|
|
533
550
|
def test_weibull_min3(self):
|
534
|
-
|
535
|
-
a =
|
551
|
+
bst.random.seed()
|
552
|
+
a = bst.random.weibull_min([2., 3.], 2.)
|
536
553
|
self.assertTupleEqual(a.shape, (2,))
|
537
554
|
|
538
555
|
def test_zipf(self):
|
539
|
-
|
540
|
-
a =
|
556
|
+
bst.random.seed()
|
557
|
+
a = bst.random.zipf(2., (4, 2))
|
541
558
|
self.assertTupleEqual(a.shape, (4, 2))
|
542
559
|
|
543
560
|
def test_zipf2(self):
|
544
|
-
|
561
|
+
bst.random.seed()
|
545
562
|
a = np.random.zipf([1.1, 2.])
|
546
|
-
b =
|
563
|
+
b = bst.random.zipf([1.1, 2.])
|
547
564
|
self.assertTupleEqual(a.shape, b.shape)
|
548
565
|
self.assertTupleEqual(b.shape, (2,))
|
549
566
|
|
550
567
|
def test_maxwell(self):
|
551
|
-
|
552
|
-
a =
|
568
|
+
bst.random.seed()
|
569
|
+
a = bst.random.maxwell(10)
|
553
570
|
self.assertTupleEqual(a.shape, (10,))
|
554
571
|
|
555
572
|
def test_maxwell2(self):
|
556
|
-
|
557
|
-
a =
|
573
|
+
bst.random.seed()
|
574
|
+
a = bst.random.maxwell()
|
558
575
|
self.assertTupleEqual(a.shape, ())
|
559
576
|
|
560
577
|
def test_t(self):
|
561
|
-
|
562
|
-
a =
|
578
|
+
bst.random.seed()
|
579
|
+
a = bst.random.t(1., size=10)
|
563
580
|
self.assertTupleEqual(a.shape, (10,))
|
564
581
|
|
565
582
|
def test_t2(self):
|
566
|
-
|
567
|
-
a =
|
583
|
+
bst.random.seed()
|
584
|
+
a = bst.random.t([1., 2.], size=None)
|
568
585
|
self.assertTupleEqual(a.shape, (2,))
|
569
586
|
|
570
587
|
|
571
588
|
class TestRandomKey(unittest.TestCase):
|
572
589
|
def test_clear_memory(self):
|
573
|
-
|
574
|
-
|
575
|
-
print(
|
576
|
-
self.assertTrue(isinstance(
|
590
|
+
bst.random.split_key()
|
591
|
+
bst.util.clear_buffer_memory()
|
592
|
+
print(bst.random.DEFAULT.value)
|
593
|
+
self.assertTrue(isinstance(bst.random.DEFAULT.value, np.ndarray))
|