brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/random/random_test.py
DELETED
@@ -1,593 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import platform
|
18
|
-
import unittest
|
19
|
-
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import jax.random
|
22
|
-
import jax.random as jr
|
23
|
-
import numpy as np
|
24
|
-
import pytest
|
25
|
-
|
26
|
-
import brainstate as bst
|
27
|
-
|
28
|
-
|
29
|
-
class TestRandom(unittest.TestCase):
    """Shape, range and reproducibility checks for the ``bst.random`` API.

    Most tests first reseed the global generator with ``bst.random.seed()``.
    They then check output shapes and, where the distribution has bounded
    support, the range of the samples. Several tests also draw from the
    matching ``np.random`` function to check that ``bst.random`` follows
    NumPy's shape conventions.
    """

    def test_seed2(self):
        test_seed = 299
        key = jax.random.PRNGKey(test_seed)
        bst.random.seed(key)

        @jax.jit
        def jit_seed(key):
            # Seeding and entering a seed context must both be traceable
            # under ``jax.jit``, for a key, an int seed, or ``None``.
            bst.random.seed(key)
            with bst.random.seed_context(key):
                print(bst.random.DEFAULT.value)

        jit_seed(key)
        jit_seed(1)
        jit_seed(None)

    def test_seed(self):
        # Reseeding with the same value must reproduce the same draws.
        test_seed = 299
        bst.random.seed(test_seed)
        a = bst.random.rand(3)
        bst.random.seed(test_seed)
        b = bst.random.rand(3)
        self.assertTrue(jnp.array_equal(a, b))

    def test_rand(self):
        bst.random.seed()
        a = bst.random.rand(3, 2)
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

        # An explicit key, or an int seed, must match jax.random.uniform.
        key = jr.PRNGKey(123)
        jres = jr.uniform(key, shape=(10, 100))
        self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=key)))
        self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=123)))

    def test_randint1(self):
        bst.random.seed()
        a = bst.random.randint(5)
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(0 <= a < 5)

    def test_randint2(self):
        bst.random.seed()
        a = bst.random.randint(2, 6, size=(4, 3))
        self.assertTupleEqual(a.shape, (4, 3))
        self.assertTrue((a >= 2).all() and (a < 6).all())

    def test_randint3(self):
        bst.random.seed()
        low = jnp.array([1, 2, 3])
        high = jnp.array([10, 7, 8])
        a = bst.random.randint([1, 2, 3], [10, 7, 8])
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue((a >= low).all() and (a < high).all())

    def test_randint4(self):
        bst.random.seed()
        low = jnp.array([1, 2, 3])
        high = jnp.array([10, 7, 8])
        a = bst.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
        self.assertTupleEqual(a.shape, (2, 3))
        # Bounds broadcast against ``size``: low inclusive, high exclusive.
        self.assertTrue((a >= low).all() and (a < high).all())

    def test_randn(self):
        bst.random.seed()
        a = bst.random.randn(3, 2)
        self.assertTupleEqual(a.shape, (3, 2))

    def test_random1(self):
        bst.random.seed()
        a = bst.random.random()
        self.assertTrue(0. <= a < 1)

    def test_random2(self):
        bst.random.seed()
        a = bst.random.random(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

    def test_random_sample(self):
        bst.random.seed()
        a = bst.random.random_sample(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

    def test_choice1(self):
        bst.random.seed()
        a = bst.random.choice(5)
        self.assertTupleEqual(jnp.shape(a), ())
        self.assertTrue(0 <= a < 5)

    def test_choice2(self):
        bst.random.seed()
        a = bst.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue((a >= 0).all() and (a < 5).all())

    def test_choice3(self):
        bst.random.seed()
        a = bst.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
        self.assertTupleEqual(a.shape, (4, 3))
        self.assertTrue((a >= 2).all() and (a < 20).all())
        # Sampling without replacement must produce distinct values.
        self.assertEqual(len(jnp.unique(a)), 12)

    def test_permutation1(self):
        bst.random.seed()
        a = bst.random.permutation(10)
        self.assertTupleEqual(a.shape, (10,))
        self.assertEqual(len(jnp.unique(a)), 10)

    def test_permutation2(self):
        bst.random.seed()
        a = bst.random.permutation(jnp.arange(10))
        self.assertTupleEqual(a.shape, (10,))
        self.assertEqual(len(jnp.unique(a)), 10)

    def test_shuffle1(self):
        bst.random.seed()
        # JAX arrays are immutable, so the shuffled result is the return
        # value; inspecting the input alone would test nothing.
        a = bst.random.shuffle(jnp.arange(10))
        self.assertTupleEqual(a.shape, (10,))
        self.assertEqual(len(jnp.unique(a)), 10)
        self.assertTrue(jnp.array_equal(jnp.sort(a), jnp.arange(10)))

    def test_shuffle2(self):
        bst.random.seed()
        a = bst.random.shuffle(jnp.arange(12).reshape(4, 3), axis=1)
        self.assertTupleEqual(a.shape, (4, 3))
        self.assertEqual(len(jnp.unique(a)), 12)

        # If only axis 1 (the columns) was permuted, every column still
        # steps by 3 from one row to the next.
        uni = jnp.unique(jnp.diff(a, axis=0))
        self.assertTrue(jnp.array_equal(uni, jnp.asarray([3])))

    def test_beta1(self):
        bst.random.seed()
        a = bst.random.beta(2, 2)
        self.assertTupleEqual(a.shape, ())

    def test_beta2(self):
        bst.random.seed()
        a = bst.random.beta([2, 2, 3], 2, size=(3,))
        self.assertTupleEqual(a.shape, (3,))

    def test_exponential1(self):
        bst.random.seed()
        a = bst.random.exponential(10., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_exponential2(self):
        bst.random.seed()
        a = bst.random.exponential([1., 2., 5.])
        self.assertTupleEqual(a.shape, (3,))

    def test_gamma(self):
        bst.random.seed()
        a = bst.random.gamma(2, 10., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_gumbel(self):
        bst.random.seed()
        a = bst.random.gumbel(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_laplace(self):
        bst.random.seed()
        a = bst.random.laplace(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_logistic(self):
        bst.random.seed()
        a = bst.random.logistic(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_normal1(self):
        bst.random.seed()
        a = bst.random.normal()
        self.assertTupleEqual(a.shape, ())

    def test_normal2(self):
        bst.random.seed()
        a = bst.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
        self.assertTupleEqual(a.shape, (3,))

    def test_normal3(self):
        bst.random.seed()
        # loc (3,) broadcasts against scale (2, 3).
        a = bst.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
        self.assertTupleEqual(a.shape, (2, 3))

    def test_pareto(self):
        bst.random.seed()
        a = bst.random.pareto([1, 2, 2])
        self.assertTupleEqual(a.shape, (3,))

    def test_poisson(self):
        bst.random.seed()
        a = bst.random.poisson([1., 2., 2.], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_cauchy(self):
        bst.random.seed()
        a = bst.random.standard_cauchy(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_exponential(self):
        bst.random.seed()
        a = bst.random.standard_exponential(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_gamma(self):
        bst.random.seed()
        a = bst.random.standard_gamma(shape=[1, 2, 4], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_normal(self):
        bst.random.seed()
        a = bst.random.standard_normal(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_t(self):
        bst.random.seed()
        a = bst.random.standard_t(df=[1, 2, 4], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_uniform1(self):
        bst.random.seed()
        a = bst.random.uniform()
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(0 <= a < 1)

    def test_uniform2(self):
        bst.random.seed()
        low = jnp.array([-1., 5., 2.])
        high = jnp.array([2., 6., 10.])
        a = bst.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue((a >= low).all() and (a < high).all())

    def test_uniform3(self):
        bst.random.seed()
        a = bst.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
        self.assertTupleEqual(a.shape, (2, 3))

    def test_uniform4(self):
        bst.random.seed()
        a = bst.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
        self.assertTupleEqual(a.shape, (2, 3))

    def test_truncated_normal1(self):
        bst.random.seed()
        a = bst.random.truncated_normal(-1., 1.)
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(-1. <= a <= 1.)

    def test_truncated_normal2(self):
        bst.random.seed()
        a = bst.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
        self.assertTupleEqual(a.shape, (4, 3))

    def test_truncated_normal3(self):
        bst.random.seed()
        lower = jnp.array([-1., 0., 1.])
        upper = jnp.array([2., 2., 4.])
        a = bst.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
        self.assertTupleEqual(a.shape, (2, 3))
        self.assertTrue((a >= lower).all() and (a <= upper).all())

    def test_bernoulli1(self):
        bst.random.seed()
        a = bst.random.bernoulli()
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(a == 0 or a == 1)

    def test_bernoulli2(self):
        bst.random.seed()
        a = bst.random.bernoulli([0.5, 0.6, 0.8])
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())

    def test_bernoulli3(self):
        bst.random.seed()
        a = bst.random.bernoulli([0.5, 0.6], size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())

    def test_lognormal1(self):
        bst.random.seed()
        a = bst.random.lognormal()
        self.assertTupleEqual(a.shape, ())

    def test_lognormal2(self):
        bst.random.seed()
        a = bst.random.lognormal(sigma=[2., 1.], size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_lognormal3(self):
        bst.random.seed()
        a = bst.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
        self.assertTupleEqual(a.shape, (2, 2))

    def test_binomial1(self):
        bst.random.seed()
        a = bst.random.binomial(5, 0.5)
        self.assertTupleEqual(a.shape, ())
        # The old ``assertTrue(a.dtype, int)`` always passed, because its
        # second argument is only the failure message. Instead, check that
        # the sample is an integral count in [0, n], whatever its dtype.
        self.assertTrue(0 <= a <= 5)
        self.assertTrue(a == jnp.round(a))

    def test_binomial2(self):
        bst.random.seed()
        a = bst.random.binomial(5, 0.5, size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a <= 5).all())

    def test_binomial3(self):
        bst.random.seed()
        a = bst.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
        self.assertTupleEqual(a.shape, (2, 3))

    def test_chisquare1(self):
        bst.random.seed()
        a = bst.random.chisquare(3)
        self.assertTupleEqual(a.shape, ())
        # ``assertTrue(a.dtype, float)`` was vacuous; check the dtype kind.
        self.assertTrue(jnp.issubdtype(a.dtype, jnp.floating))

    def test_chisquare2(self):
        bst.random.seed()
        # Array-valued ``df`` is expected to be rejected.
        with self.assertRaises(NotImplementedError):
            bst.random.chisquare(df=[2, 3, 4])

    def test_chisquare3(self):
        bst.random.seed()
        a = bst.random.chisquare(df=2, size=100)
        self.assertTupleEqual(a.shape, (100,))

    def test_chisquare4(self):
        bst.random.seed()
        a = bst.random.chisquare(df=2, size=(100, 10))
        self.assertTupleEqual(a.shape, (100, 10))

    def test_dirichlet1(self):
        bst.random.seed()
        a = bst.random.dirichlet((10, 5, 3))
        self.assertTupleEqual(a.shape, (3,))
        # A Dirichlet sample lies on the probability simplex.
        self.assertTrue(jnp.allclose(a.sum(), 1.))

    def test_dirichlet2(self):
        bst.random.seed()
        a = bst.random.dirichlet((10, 5, 3), 20)
        self.assertTupleEqual(a.shape, (20, 3))

    def test_f(self):
        bst.random.seed()
        a = bst.random.f(1., 48., 100)
        self.assertTupleEqual(a.shape, (100,))

    def test_geometric(self):
        bst.random.seed()
        a = bst.random.geometric([0.7, 0.5, 0.2])
        self.assertTupleEqual(a.shape, (3,))

    def test_hypergeometric1(self):
        bst.random.seed()
        a = bst.random.hypergeometric(10, 10, 10, 20)
        self.assertTupleEqual(a.shape, (20,))

    # ``unittest.skipIf`` is honoured by both the unittest runner and pytest,
    # whereas a pytest mark is ignored when running under plain unittest.
    @unittest.skipIf(platform.system() == 'Windows', 'Windows jaxlib error')
    def test_hypergeometric2(self):
        bst.random.seed()
        a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
        self.assertTupleEqual(a.shape, (2, 2))

    @unittest.skipIf(platform.system() == 'Windows', 'Windows jaxlib error')
    def test_hypergeometric3(self):
        bst.random.seed()
        a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
        self.assertTupleEqual(a.shape, (3, 2, 2))

    def test_logseries(self):
        bst.random.seed()
        a = bst.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
        self.assertTupleEqual(a.shape, (4, 3))

    def test_multinominal1(self):
        bst.random.seed()
        a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
        b = bst.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2, 3))

    def test_multinominal2(self):
        bst.random.seed()
        a = bst.random.multinomial(100, (0.5, 0.2, 0.3))
        self.assertTupleEqual(a.shape, (3,))
        # The category counts must add up to the number of trials.
        self.assertTrue(a.sum() == 100)

    def test_multivariate_normal1(self):
        bst.random.seed()
        a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
        b = bst.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(a.shape, (3, 2))

    def test_multivariate_normal2(self):
        bst.random.seed()
        a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
        b = bst.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(a.shape, (2,))

    def test_negative_binomial(self):
        bst.random.seed()
        a = np.random.negative_binomial([3., 10.], 0.5)
        b = bst.random.negative_binomial([3., 10.], 0.5)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_negative_binomial2(self):
        bst.random.seed()
        a = np.random.negative_binomial(3., 0.5, 10)
        b = bst.random.negative_binomial(3., 0.5, 10)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (10,))

    def test_noncentral_chisquare(self):
        bst.random.seed()
        a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
        b = bst.random.noncentral_chisquare(3, [3., 2.], (4, 2))
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2))

    def test_noncentral_chisquare2(self):
        bst.random.seed()
        a = bst.random.noncentral_chisquare(3, [3., 2.])
        self.assertTupleEqual(a.shape, (2,))

    def test_noncentral_f(self):
        bst.random.seed()
        a = bst.random.noncentral_f(3, 20, 3., 100)
        self.assertTupleEqual(a.shape, (100,))

    def test_power(self):
        bst.random.seed()
        a = np.random.power(2, (4, 2))
        b = bst.random.power(2, (4, 2))
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2))

    def test_rayleigh(self):
        bst.random.seed()
        # Previously called ``power`` by mistake, leaving rayleigh untested.
        a = bst.random.rayleigh(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_triangular(self):
        bst.random.seed()
        a = bst.random.triangular((2, 2))
        self.assertTupleEqual(a.shape, (2, 2))

    def test_vonmises(self):
        bst.random.seed()
        a = np.random.vonmises(2., 2.)
        b = bst.random.vonmises(2., 2.)
        self.assertTupleEqual(np.shape(a), b.shape)
        self.assertTupleEqual(b.shape, ())

    def test_vonmises2(self):
        bst.random.seed()
        a = np.random.vonmises(2., 2., 10)
        b = bst.random.vonmises(2., 2., 10)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (10,))

    def test_wald(self):
        bst.random.seed()
        a = np.random.wald([2., 0.5], 2.)
        b = bst.random.wald([2., 0.5], 2.)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_wald2(self):
        bst.random.seed()
        a = np.random.wald(2., 2., 100)
        b = bst.random.wald(2., 2., 100)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (100,))

    def test_weibull(self):
        bst.random.seed()
        a = bst.random.weibull(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_weibull2(self):
        bst.random.seed()
        a = bst.random.weibull(2.)
        self.assertTupleEqual(a.shape, ())

    def test_weibull3(self):
        bst.random.seed()
        a = bst.random.weibull([2., 3.])
        self.assertTupleEqual(a.shape, (2,))

    def test_weibull_min(self):
        bst.random.seed()
        a = bst.random.weibull_min(2., 2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_weibull_min2(self):
        bst.random.seed()
        a = bst.random.weibull_min(2., 2.)
        self.assertTupleEqual(a.shape, ())

    def test_weibull_min3(self):
        bst.random.seed()
        a = bst.random.weibull_min([2., 3.], 2.)
        self.assertTupleEqual(a.shape, (2,))

    def test_zipf(self):
        bst.random.seed()
        a = bst.random.zipf(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_zipf2(self):
        bst.random.seed()
        a = np.random.zipf([1.1, 2.])
        b = bst.random.zipf([1.1, 2.])
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_maxwell(self):
        bst.random.seed()
        a = bst.random.maxwell(10)
        self.assertTupleEqual(a.shape, (10,))

    def test_maxwell2(self):
        bst.random.seed()
        a = bst.random.maxwell()
        self.assertTupleEqual(a.shape, ())

    def test_t(self):
        bst.random.seed()
        a = bst.random.t(1., size=10)
        self.assertTupleEqual(a.shape, (10,))

    def test_t2(self):
        bst.random.seed()
        a = bst.random.t([1., 2.], size=None)
        self.assertTupleEqual(a.shape, (2,))
|
586
|
-
|
587
|
-
|
588
|
-
class TestRandomKey(unittest.TestCase):
    """Checks on the global default random key held by ``bst.random``."""

    def test_clear_memory(self):
        # Advance the global key, then release buffered memory.
        bst.random.split_key()
        bst.util.clear_buffer_memory()
        # The test expects the default key to be a NumPy array afterwards.
        # ``assertIsInstance`` reports the actual type if this fails.
        self.assertIsInstance(bst.random.DEFAULT.value, np.ndarray)
|