brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1491 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
@@ -18,121 +18,138 @@ import platform
18
18
  import unittest
19
19
 
20
20
  import jax.numpy as jnp
21
+ import jax.random
21
22
  import jax.random as jr
22
23
  import numpy as np
23
24
  import pytest
24
25
 
25
- import brainstate as bc
26
+ import brainstate as bst
26
27
 
27
28
 
28
29
  class TestRandom(unittest.TestCase):
30
+
31
+ def test_seed2(self):
32
+ test_seed = 299
33
+ key = jax.random.PRNGKey(test_seed)
34
+ bst.random.seed(key)
35
+
36
+ @jax.jit
37
+ def jit_seed(key):
38
+ bst.random.seed(key)
39
+ with bst.random.seed_context(key):
40
+ print(bst.random.DEFAULT.value)
41
+
42
+ jit_seed(key)
43
+ jit_seed(1)
44
+ jit_seed(None)
45
+
29
46
  def test_seed(self):
30
47
  test_seed = 299
31
- bc.random.seed(test_seed)
32
- a = bc.random.rand(3)
33
- bc.random.seed(test_seed)
34
- b = bc.random.rand(3)
48
+ bst.random.seed(test_seed)
49
+ a = bst.random.rand(3)
50
+ bst.random.seed(test_seed)
51
+ b = bst.random.rand(3)
35
52
  self.assertTrue(jnp.array_equal(a, b))
36
53
 
37
54
  def test_rand(self):
38
- bc.random.seed()
39
- a = bc.random.rand(3, 2)
55
+ bst.random.seed()
56
+ a = bst.random.rand(3, 2)
40
57
  self.assertTupleEqual(a.shape, (3, 2))
41
58
  self.assertTrue((a >= 0).all() and (a < 1).all())
42
59
 
43
60
  key = jr.PRNGKey(123)
44
61
  jres = jr.uniform(key, shape=(10, 100))
45
- self.assertTrue(jnp.allclose(jres, bc.random.rand(10, 100, key=key)))
46
- self.assertTrue(jnp.allclose(jres, bc.random.rand(10, 100, key=123)))
62
+ self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=key)))
63
+ self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=123)))
47
64
 
48
65
  def test_randint1(self):
49
- bc.random.seed()
50
- a = bc.random.randint(5)
66
+ bst.random.seed()
67
+ a = bst.random.randint(5)
51
68
  self.assertTupleEqual(a.shape, ())
52
69
  self.assertTrue(0 <= a < 5)
53
70
 
54
71
  def test_randint2(self):
55
- bc.random.seed()
56
- a = bc.random.randint(2, 6, size=(4, 3))
72
+ bst.random.seed()
73
+ a = bst.random.randint(2, 6, size=(4, 3))
57
74
  self.assertTupleEqual(a.shape, (4, 3))
58
75
  self.assertTrue((a >= 2).all() and (a < 6).all())
59
76
 
60
77
  def test_randint3(self):
61
- bc.random.seed()
62
- a = bc.random.randint([1, 2, 3], [10, 7, 8])
78
+ bst.random.seed()
79
+ a = bst.random.randint([1, 2, 3], [10, 7, 8])
63
80
  self.assertTupleEqual(a.shape, (3,))
64
81
  self.assertTrue((a - jnp.array([1, 2, 3]) >= 0).all()
65
82
  and (-a + jnp.array([10, 7, 8]) > 0).all())
66
83
 
67
84
  def test_randint4(self):
68
- bc.random.seed()
69
- a = bc.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
85
+ bst.random.seed()
86
+ a = bst.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
70
87
  self.assertTupleEqual(a.shape, (2, 3))
71
88
 
72
89
  def test_randn(self):
73
- bc.random.seed()
74
- a = bc.random.randn(3, 2)
90
+ bst.random.seed()
91
+ a = bst.random.randn(3, 2)
75
92
  self.assertTupleEqual(a.shape, (3, 2))
76
93
 
77
94
  def test_random1(self):
78
- bc.random.seed()
79
- a = bc.random.random()
95
+ bst.random.seed()
96
+ a = bst.random.random()
80
97
  self.assertTrue(0. <= a < 1)
81
98
 
82
99
  def test_random2(self):
83
- bc.random.seed()
84
- a = bc.random.random(size=(3, 2))
100
+ bst.random.seed()
101
+ a = bst.random.random(size=(3, 2))
85
102
  self.assertTupleEqual(a.shape, (3, 2))
86
103
  self.assertTrue((a >= 0).all() and (a < 1).all())
87
104
 
88
105
  def test_random_sample(self):
89
- bc.random.seed()
90
- a = bc.random.random_sample(size=(3, 2))
106
+ bst.random.seed()
107
+ a = bst.random.random_sample(size=(3, 2))
91
108
  self.assertTupleEqual(a.shape, (3, 2))
92
109
  self.assertTrue((a >= 0).all() and (a < 1).all())
93
110
 
94
111
  def test_choice1(self):
95
- bc.random.seed()
96
- a = bc.random.choice(5)
112
+ bst.random.seed()
113
+ a = bst.random.choice(5)
97
114
  self.assertTupleEqual(jnp.shape(a), ())
98
115
  self.assertTrue(0 <= a < 5)
99
116
 
100
117
  def test_choice2(self):
101
- bc.random.seed()
102
- a = bc.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
118
+ bst.random.seed()
119
+ a = bst.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
103
120
  self.assertTupleEqual(a.shape, (3,))
104
121
  self.assertTrue((a >= 0).all() and (a < 5).all())
105
122
 
106
123
  def test_choice3(self):
107
- bc.random.seed()
108
- a = bc.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
124
+ bst.random.seed()
125
+ a = bst.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
109
126
  self.assertTupleEqual(a.shape, (4, 3))
110
127
  self.assertTrue((a >= 2).all() and (a < 20).all())
111
128
  self.assertEqual(len(jnp.unique(a)), 12)
112
129
 
113
130
  def test_permutation1(self):
114
- bc.random.seed()
115
- a = bc.random.permutation(10)
131
+ bst.random.seed()
132
+ a = bst.random.permutation(10)
116
133
  self.assertTupleEqual(a.shape, (10,))
117
134
  self.assertEqual(len(jnp.unique(a)), 10)
118
135
 
119
136
  def test_permutation2(self):
120
- bc.random.seed()
121
- a = bc.random.permutation(jnp.arange(10))
137
+ bst.random.seed()
138
+ a = bst.random.permutation(jnp.arange(10))
122
139
  self.assertTupleEqual(a.shape, (10,))
123
140
  self.assertEqual(len(jnp.unique(a)), 10)
124
141
 
125
142
  def test_shuffle1(self):
126
- bc.random.seed()
143
+ bst.random.seed()
127
144
  a = jnp.arange(10)
128
- bc.random.shuffle(a)
145
+ bst.random.shuffle(a)
129
146
  self.assertTupleEqual(a.shape, (10,))
130
147
  self.assertEqual(len(jnp.unique(a)), 10)
131
148
 
132
149
  def test_shuffle2(self):
133
- bc.random.seed()
150
+ bst.random.seed()
134
151
  a = jnp.arange(12).reshape(4, 3)
135
- bc.random.shuffle(a, axis=1)
152
+ bst.random.shuffle(a, axis=1)
136
153
  self.assertTupleEqual(a.shape, (4, 3))
137
154
  self.assertEqual(len(jnp.unique(a)), 12)
138
155
 
@@ -141,173 +158,173 @@ class TestRandom(unittest.TestCase):
141
158
  self.assertEqual(uni, jnp.asarray([3]))
142
159
 
143
160
  def test_beta1(self):
144
- bc.random.seed()
145
- a = bc.random.beta(2, 2)
161
+ bst.random.seed()
162
+ a = bst.random.beta(2, 2)
146
163
  self.assertTupleEqual(a.shape, ())
147
164
 
148
165
  def test_beta2(self):
149
- bc.random.seed()
150
- a = bc.random.beta([2, 2, 3], 2, size=(3,))
166
+ bst.random.seed()
167
+ a = bst.random.beta([2, 2, 3], 2, size=(3,))
151
168
  self.assertTupleEqual(a.shape, (3,))
152
169
 
153
170
  def test_exponential1(self):
154
- bc.random.seed()
155
- a = bc.random.exponential(10., size=[3, 2])
171
+ bst.random.seed()
172
+ a = bst.random.exponential(10., size=[3, 2])
156
173
  self.assertTupleEqual(a.shape, (3, 2))
157
174
 
158
175
  def test_exponential2(self):
159
- bc.random.seed()
160
- a = bc.random.exponential([1., 2., 5.])
176
+ bst.random.seed()
177
+ a = bst.random.exponential([1., 2., 5.])
161
178
  self.assertTupleEqual(a.shape, (3,))
162
179
 
163
180
  def test_gamma(self):
164
- bc.random.seed()
165
- a = bc.random.gamma(2, 10., size=[3, 2])
181
+ bst.random.seed()
182
+ a = bst.random.gamma(2, 10., size=[3, 2])
166
183
  self.assertTupleEqual(a.shape, (3, 2))
167
184
 
168
185
  def test_gumbel(self):
169
- bc.random.seed()
170
- a = bc.random.gumbel(0., 2., size=[3, 2])
186
+ bst.random.seed()
187
+ a = bst.random.gumbel(0., 2., size=[3, 2])
171
188
  self.assertTupleEqual(a.shape, (3, 2))
172
189
 
173
190
  def test_laplace(self):
174
- bc.random.seed()
175
- a = bc.random.laplace(0., 2., size=[3, 2])
191
+ bst.random.seed()
192
+ a = bst.random.laplace(0., 2., size=[3, 2])
176
193
  self.assertTupleEqual(a.shape, (3, 2))
177
194
 
178
195
  def test_logistic(self):
179
- bc.random.seed()
180
- a = bc.random.logistic(0., 2., size=[3, 2])
196
+ bst.random.seed()
197
+ a = bst.random.logistic(0., 2., size=[3, 2])
181
198
  self.assertTupleEqual(a.shape, (3, 2))
182
199
 
183
200
  def test_normal1(self):
184
- bc.random.seed()
185
- a = bc.random.normal()
201
+ bst.random.seed()
202
+ a = bst.random.normal()
186
203
  self.assertTupleEqual(a.shape, ())
187
204
 
188
205
  def test_normal2(self):
189
- bc.random.seed()
190
- a = bc.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
206
+ bst.random.seed()
207
+ a = bst.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
191
208
  self.assertTupleEqual(a.shape, (3,))
192
209
 
193
210
  def test_normal3(self):
194
- bc.random.seed()
195
- a = bc.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
211
+ bst.random.seed()
212
+ a = bst.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
196
213
  print(a)
197
214
  self.assertTupleEqual(a.shape, (2, 3))
198
215
 
199
216
  def test_pareto(self):
200
- bc.random.seed()
201
- a = bc.random.pareto([1, 2, 2])
217
+ bst.random.seed()
218
+ a = bst.random.pareto([1, 2, 2])
202
219
  self.assertTupleEqual(a.shape, (3,))
203
220
 
204
221
  def test_poisson(self):
205
- bc.random.seed()
206
- a = bc.random.poisson([1., 2., 2.], size=3)
222
+ bst.random.seed()
223
+ a = bst.random.poisson([1., 2., 2.], size=3)
207
224
  self.assertTupleEqual(a.shape, (3,))
208
225
 
209
226
  def test_standard_cauchy(self):
210
- bc.random.seed()
211
- a = bc.random.standard_cauchy(size=(3, 2))
227
+ bst.random.seed()
228
+ a = bst.random.standard_cauchy(size=(3, 2))
212
229
  self.assertTupleEqual(a.shape, (3, 2))
213
230
 
214
231
  def test_standard_exponential(self):
215
- bc.random.seed()
216
- a = bc.random.standard_exponential(size=(3, 2))
232
+ bst.random.seed()
233
+ a = bst.random.standard_exponential(size=(3, 2))
217
234
  self.assertTupleEqual(a.shape, (3, 2))
218
235
 
219
236
  def test_standard_gamma(self):
220
- bc.random.seed()
221
- a = bc.random.standard_gamma(shape=[1, 2, 4], size=3)
237
+ bst.random.seed()
238
+ a = bst.random.standard_gamma(shape=[1, 2, 4], size=3)
222
239
  self.assertTupleEqual(a.shape, (3,))
223
240
 
224
241
  def test_standard_normal(self):
225
- bc.random.seed()
226
- a = bc.random.standard_normal(size=(3, 2))
242
+ bst.random.seed()
243
+ a = bst.random.standard_normal(size=(3, 2))
227
244
  self.assertTupleEqual(a.shape, (3, 2))
228
245
 
229
246
  def test_standard_t(self):
230
- bc.random.seed()
231
- a = bc.random.standard_t(df=[1, 2, 4], size=3)
247
+ bst.random.seed()
248
+ a = bst.random.standard_t(df=[1, 2, 4], size=3)
232
249
  self.assertTupleEqual(a.shape, (3,))
233
250
 
234
251
  def test_standard_uniform1(self):
235
- bc.random.seed()
236
- a = bc.random.uniform()
252
+ bst.random.seed()
253
+ a = bst.random.uniform()
237
254
  self.assertTupleEqual(a.shape, ())
238
255
  self.assertTrue(0 <= a < 1)
239
256
 
240
257
  def test_uniform2(self):
241
- bc.random.seed()
242
- a = bc.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
258
+ bst.random.seed()
259
+ a = bst.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
243
260
  self.assertTupleEqual(a.shape, (3,))
244
261
  self.assertTrue((a - jnp.array([-1., 5., 2.]) >= 0).all()
245
262
  and (-a + jnp.array([2., 6., 10.]) > 0).all())
246
263
 
247
264
  def test_uniform3(self):
248
- bc.random.seed()
249
- a = bc.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
265
+ bst.random.seed()
266
+ a = bst.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
250
267
  self.assertTupleEqual(a.shape, (2, 3))
251
268
 
252
269
  def test_uniform4(self):
253
- bc.random.seed()
254
- a = bc.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
270
+ bst.random.seed()
271
+ a = bst.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
255
272
  self.assertTupleEqual(a.shape, (2, 3))
256
273
 
257
274
  def test_truncated_normal1(self):
258
- bc.random.seed()
259
- a = bc.random.truncated_normal(-1., 1.)
275
+ bst.random.seed()
276
+ a = bst.random.truncated_normal(-1., 1.)
260
277
  self.assertTupleEqual(a.shape, ())
261
278
  self.assertTrue(-1. <= a <= 1.)
262
279
 
263
280
  def test_truncated_normal2(self):
264
- bc.random.seed()
265
- a = bc.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
281
+ bst.random.seed()
282
+ a = bst.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
266
283
  self.assertTupleEqual(a.shape, (4, 3))
267
284
 
268
285
  def test_truncated_normal3(self):
269
- bc.random.seed()
270
- a = bc.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
286
+ bst.random.seed()
287
+ a = bst.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
271
288
  self.assertTupleEqual(a.shape, (2, 3))
272
289
  self.assertTrue((a - jnp.array([-1., 0., 1.]) >= 0.).all()
273
290
  and (- a + jnp.array([2., 2., 4.]) >= 0.).all())
274
291
 
275
292
  def test_bernoulli1(self):
276
- bc.random.seed()
277
- a = bc.random.bernoulli()
293
+ bst.random.seed()
294
+ a = bst.random.bernoulli()
278
295
  self.assertTupleEqual(a.shape, ())
279
296
  self.assertTrue(a == 0 or a == 1)
280
297
 
281
298
  def test_bernoulli2(self):
282
- bc.random.seed()
283
- a = bc.random.bernoulli([0.5, 0.6, 0.8])
299
+ bst.random.seed()
300
+ a = bst.random.bernoulli([0.5, 0.6, 0.8])
284
301
  self.assertTupleEqual(a.shape, (3,))
285
302
  self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
286
303
 
287
304
  def test_bernoulli3(self):
288
- bc.random.seed()
289
- a = bc.random.bernoulli([0.5, 0.6], size=(3, 2))
305
+ bst.random.seed()
306
+ a = bst.random.bernoulli([0.5, 0.6], size=(3, 2))
290
307
  self.assertTupleEqual(a.shape, (3, 2))
291
308
  self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
292
309
 
293
310
  def test_lognormal1(self):
294
- bc.random.seed()
295
- a = bc.random.lognormal()
311
+ bst.random.seed()
312
+ a = bst.random.lognormal()
296
313
  self.assertTupleEqual(a.shape, ())
297
314
 
298
315
  def test_lognormal2(self):
299
- bc.random.seed()
300
- a = bc.random.lognormal(sigma=[2., 1.], size=[3, 2])
316
+ bst.random.seed()
317
+ a = bst.random.lognormal(sigma=[2., 1.], size=[3, 2])
301
318
  self.assertTupleEqual(a.shape, (3, 2))
302
319
 
303
320
  def test_lognormal3(self):
304
- bc.random.seed()
305
- a = bc.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
321
+ bst.random.seed()
322
+ a = bst.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
306
323
  self.assertTupleEqual(a.shape, (2, 2))
307
324
 
308
325
  def test_binomial1(self):
309
- bc.random.seed()
310
- a = bc.random.binomial(5, 0.5)
326
+ bst.random.seed()
327
+ a = bst.random.binomial(5, 0.5)
311
328
  b = np.random.binomial(5, 0.5)
312
329
  print(a)
313
330
  print(b)
@@ -315,99 +332,99 @@ class TestRandom(unittest.TestCase):
315
332
  self.assertTrue(a.dtype, int)
316
333
 
317
334
  def test_binomial2(self):
318
- bc.random.seed()
319
- a = bc.random.binomial(5, 0.5, size=(3, 2))
335
+ bst.random.seed()
336
+ a = bst.random.binomial(5, 0.5, size=(3, 2))
320
337
  self.assertTupleEqual(a.shape, (3, 2))
321
338
  self.assertTrue((a >= 0).all() and (a <= 5).all())
322
339
 
323
340
  def test_binomial3(self):
324
- bc.random.seed()
325
- a = bc.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
341
+ bst.random.seed()
342
+ a = bst.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
326
343
  self.assertTupleEqual(a.shape, (2, 3))
327
344
 
328
345
  def test_chisquare1(self):
329
- bc.random.seed()
330
- a = bc.random.chisquare(3)
346
+ bst.random.seed()
347
+ a = bst.random.chisquare(3)
331
348
  self.assertTupleEqual(a.shape, ())
332
349
  self.assertTrue(a.dtype, float)
333
350
 
334
351
  def test_chisquare2(self):
335
- bc.random.seed()
352
+ bst.random.seed()
336
353
  with self.assertRaises(NotImplementedError):
337
- a = bc.random.chisquare(df=[2, 3, 4])
354
+ a = bst.random.chisquare(df=[2, 3, 4])
338
355
 
339
356
  def test_chisquare3(self):
340
- bc.random.seed()
341
- a = bc.random.chisquare(df=2, size=100)
357
+ bst.random.seed()
358
+ a = bst.random.chisquare(df=2, size=100)
342
359
  self.assertTupleEqual(a.shape, (100,))
343
360
 
344
361
  def test_chisquare4(self):
345
- bc.random.seed()
346
- a = bc.random.chisquare(df=2, size=(100, 10))
362
+ bst.random.seed()
363
+ a = bst.random.chisquare(df=2, size=(100, 10))
347
364
  self.assertTupleEqual(a.shape, (100, 10))
348
365
 
349
366
  def test_dirichlet1(self):
350
- bc.random.seed()
351
- a = bc.random.dirichlet((10, 5, 3))
367
+ bst.random.seed()
368
+ a = bst.random.dirichlet((10, 5, 3))
352
369
  self.assertTupleEqual(a.shape, (3,))
353
370
 
354
371
  def test_dirichlet2(self):
355
- bc.random.seed()
356
- a = bc.random.dirichlet((10, 5, 3), 20)
372
+ bst.random.seed()
373
+ a = bst.random.dirichlet((10, 5, 3), 20)
357
374
  self.assertTupleEqual(a.shape, (20, 3))
358
375
 
359
376
  def test_f(self):
360
- bc.random.seed()
361
- a = bc.random.f(1., 48., 100)
377
+ bst.random.seed()
378
+ a = bst.random.f(1., 48., 100)
362
379
  self.assertTupleEqual(a.shape, (100,))
363
380
 
364
381
  def test_geometric(self):
365
- bc.random.seed()
366
- a = bc.random.geometric([0.7, 0.5, 0.2])
382
+ bst.random.seed()
383
+ a = bst.random.geometric([0.7, 0.5, 0.2])
367
384
  self.assertTupleEqual(a.shape, (3,))
368
385
 
369
386
  def test_hypergeometric1(self):
370
- bc.random.seed()
371
- a = bc.random.hypergeometric(10, 10, 10, 20)
387
+ bst.random.seed()
388
+ a = bst.random.hypergeometric(10, 10, 10, 20)
372
389
  self.assertTupleEqual(a.shape, (20,))
373
390
 
374
391
  @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
375
392
  def test_hypergeometric2(self):
376
- bc.random.seed()
377
- a = bc.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
393
+ bst.random.seed()
394
+ a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
378
395
  self.assertTupleEqual(a.shape, (2, 2))
379
396
 
380
397
  @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
381
398
  def test_hypergeometric3(self):
382
- bc.random.seed()
383
- a = bc.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
399
+ bst.random.seed()
400
+ a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
384
401
  self.assertTupleEqual(a.shape, (3, 2, 2))
385
402
 
386
403
  def test_logseries(self):
387
- bc.random.seed()
388
- a = bc.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
404
+ bst.random.seed()
405
+ a = bst.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
389
406
  self.assertTupleEqual(a.shape, (4, 3))
390
407
 
391
408
  def test_multinominal1(self):
392
- bc.random.seed()
409
+ bst.random.seed()
393
410
  a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
394
411
  print(a, a.shape)
395
- b = bc.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
412
+ b = bst.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
396
413
  print(b, b.shape)
397
414
  self.assertTupleEqual(a.shape, b.shape)
398
415
  self.assertTupleEqual(b.shape, (4, 2, 3))
399
416
 
400
417
  def test_multinominal2(self):
401
- bc.random.seed()
402
- a = bc.random.multinomial(100, (0.5, 0.2, 0.3))
418
+ bst.random.seed()
419
+ a = bst.random.multinomial(100, (0.5, 0.2, 0.3))
403
420
  self.assertTupleEqual(a.shape, (3,))
404
421
  self.assertTrue(a.sum() == 100)
405
422
 
406
423
  def test_multivariate_normal1(self):
407
- bc.random.seed()
424
+ bst.random.seed()
408
425
  # self.skipTest('Windows jaxlib error')
409
426
  a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
410
- b = bc.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
427
+ b = bst.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
411
428
  print('test_multivariate_normal1')
412
429
  print(a)
413
430
  print(b)
@@ -415,162 +432,162 @@ class TestRandom(unittest.TestCase):
415
432
  self.assertTupleEqual(a.shape, (3, 2))
416
433
 
417
434
  def test_multivariate_normal2(self):
418
- bc.random.seed()
435
+ bst.random.seed()
419
436
  a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
420
- b = bc.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
437
+ b = bst.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
421
438
  print(a)
422
439
  print(b)
423
440
  self.assertTupleEqual(a.shape, b.shape)
424
441
  self.assertTupleEqual(a.shape, (2,))
425
442
 
426
443
  def test_negative_binomial(self):
427
- bc.random.seed()
444
+ bst.random.seed()
428
445
  a = np.random.negative_binomial([3., 10.], 0.5)
429
- b = bc.random.negative_binomial([3., 10.], 0.5)
446
+ b = bst.random.negative_binomial([3., 10.], 0.5)
430
447
  print(a)
431
448
  print(b)
432
449
  self.assertTupleEqual(a.shape, b.shape)
433
450
  self.assertTupleEqual(b.shape, (2,))
434
451
 
435
452
  def test_negative_binomial2(self):
436
- bc.random.seed()
453
+ bst.random.seed()
437
454
  a = np.random.negative_binomial(3., 0.5, 10)
438
- b = bc.random.negative_binomial(3., 0.5, 10)
455
+ b = bst.random.negative_binomial(3., 0.5, 10)
439
456
  print(a)
440
457
  print(b)
441
458
  self.assertTupleEqual(a.shape, b.shape)
442
459
  self.assertTupleEqual(b.shape, (10,))
443
460
 
444
461
  def test_noncentral_chisquare(self):
445
- bc.random.seed()
462
+ bst.random.seed()
446
463
  a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
447
- b = bc.random.noncentral_chisquare(3, [3., 2.], (4, 2))
464
+ b = bst.random.noncentral_chisquare(3, [3., 2.], (4, 2))
448
465
  self.assertTupleEqual(a.shape, b.shape)
449
466
  self.assertTupleEqual(b.shape, (4, 2))
450
467
 
451
468
  def test_noncentral_chisquare2(self):
452
- bc.random.seed()
453
- a = bc.random.noncentral_chisquare(3, [3., 2.])
469
+ bst.random.seed()
470
+ a = bst.random.noncentral_chisquare(3, [3., 2.])
454
471
  self.assertTupleEqual(a.shape, (2,))
455
472
 
456
473
  def test_noncentral_f(self):
457
- bc.random.seed()
458
- a = bc.random.noncentral_f(3, 20, 3., 100)
474
+ bst.random.seed()
475
+ a = bst.random.noncentral_f(3, 20, 3., 100)
459
476
  self.assertTupleEqual(a.shape, (100,))
460
477
 
461
478
  def test_power(self):
462
- bc.random.seed()
479
+ bst.random.seed()
463
480
  a = np.random.power(2, (4, 2))
464
- b = bc.random.power(2, (4, 2))
481
+ b = bst.random.power(2, (4, 2))
465
482
  self.assertTupleEqual(a.shape, b.shape)
466
483
  self.assertTupleEqual(b.shape, (4, 2))
467
484
 
468
485
  def test_rayleigh(self):
469
- bc.random.seed()
470
- a = bc.random.power(2., (4, 2))
486
+ bst.random.seed()
487
+ a = bst.random.power(2., (4, 2))
471
488
  self.assertTupleEqual(a.shape, (4, 2))
472
489
 
473
490
  def test_triangular(self):
474
- bc.random.seed()
475
- a = bc.random.triangular((2, 2))
491
+ bst.random.seed()
492
+ a = bst.random.triangular((2, 2))
476
493
  self.assertTupleEqual(a.shape, (2, 2))
477
494
 
478
495
  def test_vonmises(self):
479
- bc.random.seed()
496
+ bst.random.seed()
480
497
  a = np.random.vonmises(2., 2.)
481
- b = bc.random.vonmises(2., 2.)
498
+ b = bst.random.vonmises(2., 2.)
482
499
  print(a, b)
483
500
  self.assertTupleEqual(np.shape(a), b.shape)
484
501
  self.assertTupleEqual(b.shape, ())
485
502
 
486
503
  def test_vonmises2(self):
487
- bc.random.seed()
504
+ bst.random.seed()
488
505
  a = np.random.vonmises(2., 2., 10)
489
- b = bc.random.vonmises(2., 2., 10)
506
+ b = bst.random.vonmises(2., 2., 10)
490
507
  print(a, b)
491
508
  self.assertTupleEqual(a.shape, b.shape)
492
509
  self.assertTupleEqual(b.shape, (10,))
493
510
 
494
511
  def test_wald(self):
495
- bc.random.seed()
512
+ bst.random.seed()
496
513
  a = np.random.wald([2., 0.5], 2.)
497
- b = bc.random.wald([2., 0.5], 2.)
514
+ b = bst.random.wald([2., 0.5], 2.)
498
515
  self.assertTupleEqual(a.shape, b.shape)
499
516
  self.assertTupleEqual(b.shape, (2,))
500
517
 
501
518
  def test_wald2(self):
502
- bc.random.seed()
519
+ bst.random.seed()
503
520
  a = np.random.wald(2., 2., 100)
504
- b = bc.random.wald(2., 2., 100)
521
+ b = bst.random.wald(2., 2., 100)
505
522
  self.assertTupleEqual(a.shape, b.shape)
506
523
  self.assertTupleEqual(b.shape, (100,))
507
524
 
508
525
  def test_weibull(self):
509
- bc.random.seed()
510
- a = bc.random.weibull(2., (4, 2))
526
+ bst.random.seed()
527
+ a = bst.random.weibull(2., (4, 2))
511
528
  self.assertTupleEqual(a.shape, (4, 2))
512
529
 
513
530
  def test_weibull2(self):
514
- bc.random.seed()
515
- a = bc.random.weibull(2., )
531
+ bst.random.seed()
532
+ a = bst.random.weibull(2., )
516
533
  self.assertTupleEqual(a.shape, ())
517
534
 
518
535
  def test_weibull3(self):
519
- bc.random.seed()
520
- a = bc.random.weibull([2., 3.], )
536
+ bst.random.seed()
537
+ a = bst.random.weibull([2., 3.], )
521
538
  self.assertTupleEqual(a.shape, (2,))
522
539
 
523
540
  def test_weibull_min(self):
524
- bc.random.seed()
525
- a = bc.random.weibull_min(2., 2., (4, 2))
541
+ bst.random.seed()
542
+ a = bst.random.weibull_min(2., 2., (4, 2))
526
543
  self.assertTupleEqual(a.shape, (4, 2))
527
544
 
528
545
  def test_weibull_min2(self):
529
- bc.random.seed()
530
- a = bc.random.weibull_min(2., 2.)
546
+ bst.random.seed()
547
+ a = bst.random.weibull_min(2., 2.)
531
548
  self.assertTupleEqual(a.shape, ())
532
549
 
533
550
  def test_weibull_min3(self):
534
- bc.random.seed()
535
- a = bc.random.weibull_min([2., 3.], 2.)
551
+ bst.random.seed()
552
+ a = bst.random.weibull_min([2., 3.], 2.)
536
553
  self.assertTupleEqual(a.shape, (2,))
537
554
 
538
555
  def test_zipf(self):
539
- bc.random.seed()
540
- a = bc.random.zipf(2., (4, 2))
556
+ bst.random.seed()
557
+ a = bst.random.zipf(2., (4, 2))
541
558
  self.assertTupleEqual(a.shape, (4, 2))
542
559
 
543
560
  def test_zipf2(self):
544
- bc.random.seed()
561
+ bst.random.seed()
545
562
  a = np.random.zipf([1.1, 2.])
546
- b = bc.random.zipf([1.1, 2.])
563
+ b = bst.random.zipf([1.1, 2.])
547
564
  self.assertTupleEqual(a.shape, b.shape)
548
565
  self.assertTupleEqual(b.shape, (2,))
549
566
 
550
567
  def test_maxwell(self):
551
- bc.random.seed()
552
- a = bc.random.maxwell(10)
568
+ bst.random.seed()
569
+ a = bst.random.maxwell(10)
553
570
  self.assertTupleEqual(a.shape, (10,))
554
571
 
555
572
  def test_maxwell2(self):
556
- bc.random.seed()
557
- a = bc.random.maxwell()
573
+ bst.random.seed()
574
+ a = bst.random.maxwell()
558
575
  self.assertTupleEqual(a.shape, ())
559
576
 
560
577
  def test_t(self):
561
- bc.random.seed()
562
- a = bc.random.t(1., size=10)
578
+ bst.random.seed()
579
+ a = bst.random.t(1., size=10)
563
580
  self.assertTupleEqual(a.shape, (10,))
564
581
 
565
582
  def test_t2(self):
566
- bc.random.seed()
567
- a = bc.random.t([1., 2.], size=None)
583
+ bst.random.seed()
584
+ a = bst.random.t([1., 2.], size=None)
568
585
  self.assertTupleEqual(a.shape, (2,))
569
586
 
570
587
 
571
588
  class TestRandomKey(unittest.TestCase):
572
589
  def test_clear_memory(self):
573
- bc.random.split_key()
574
- bc.util.clear_buffer_memory()
575
- print(bc.random.DEFAULT.value)
576
- self.assertTrue(isinstance(bc.random.DEFAULT.value, np.ndarray))
590
+ bst.random.split_key()
591
+ bst.util.clear_buffer_memory()
592
+ print(bst.random.DEFAULT.value)
593
+ self.assertTrue(isinstance(bst.random.DEFAULT.value, np.ndarray))