brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,568 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import platform
|
19
|
+
import unittest
|
20
|
+
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import jax.random as jr
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
import brainstate as bst
|
27
|
+
|
28
|
+
|
29
|
+
class TestRandom(unittest.TestCase):
    """Smoke tests for the module-level sampling functions in ``bst.random``.

    Every test first calls ``bst.random.seed()`` and then draws a sample,
    checking the shape of the result and, where the distribution has a known
    support, that the values fall inside it.  Several tests additionally draw
    the same quantity from ``numpy.random`` and compare the two shapes.
    """

    def test_rand(self):
        bst.random.seed()
        a = bst.random.rand(3, 2)
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

        # An explicit key (given either as a PRNG key or as an int seed) must
        # reproduce what jax.random.uniform yields for that key.
        key = jr.PRNGKey(123)
        jres = jr.uniform(key, shape=(10, 100))
        self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=key)))
        self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=123)))

    def test_randint1(self):
        bst.random.seed()
        a = bst.random.randint(5)
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(0 <= a < 5)

    def test_randint2(self):
        bst.random.seed()
        a = bst.random.randint(2, 6, size=(4, 3))
        self.assertTupleEqual(a.shape, (4, 3))
        self.assertTrue((a >= 2).all() and (a < 6).all())

    def test_randint3(self):
        bst.random.seed()
        a = bst.random.randint([1, 2, 3], [10, 7, 8])
        self.assertTupleEqual(a.shape, (3,))
        # Element-wise bounds: low[i] <= a[i] < high[i].
        self.assertTrue((a - jnp.array([1, 2, 3]) >= 0).all()
                        and (-a + jnp.array([10, 7, 8]) > 0).all())

    def test_randint4(self):
        bst.random.seed()
        a = bst.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
        self.assertTupleEqual(a.shape, (2, 3))

    def test_randn(self):
        bst.random.seed()
        a = bst.random.randn(3, 2)
        self.assertTupleEqual(a.shape, (3, 2))

    def test_random1(self):
        bst.random.seed()
        a = bst.random.random()
        self.assertTrue(0. <= a < 1)

    def test_random2(self):
        bst.random.seed()
        a = bst.random.random(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

    def test_random_sample(self):
        bst.random.seed()
        a = bst.random.random_sample(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a < 1).all())

    def test_choice1(self):
        bst.random.seed()
        a = bst.random.choice(5)
        self.assertTupleEqual(jnp.shape(a), ())
        self.assertTrue(0 <= a < 5)

    def test_choice2(self):
        bst.random.seed()
        a = bst.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue((a >= 0).all() and (a < 5).all())

    def test_choice3(self):
        bst.random.seed()
        a = bst.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
        self.assertTupleEqual(a.shape, (4, 3))
        self.assertTrue((a >= 2).all() and (a < 20).all())
        # Sampling without replacement: all 12 drawn values are distinct.
        self.assertEqual(len(jnp.unique(a)), 12)

    def test_permutation1(self):
        bst.random.seed()
        a = bst.random.permutation(10)
        self.assertTupleEqual(a.shape, (10,))
        self.assertEqual(len(jnp.unique(a)), 10)

    def test_permutation2(self):
        bst.random.seed()
        a = bst.random.permutation(jnp.arange(10))
        self.assertTupleEqual(a.shape, (10,))
        self.assertEqual(len(jnp.unique(a)), 10)

    def test_shuffle1(self):
        bst.random.seed()
        a = jnp.arange(10)
        # JAX arrays cannot be modified in place, so the assertions must run
        # on the value returned by shuffle rather than on the input.
        shuffled = bst.random.shuffle(a)
        self.assertTupleEqual(shuffled.shape, (10,))
        self.assertEqual(len(jnp.unique(shuffled)), 10)

    def test_shuffle2(self):
        bst.random.seed()
        a = jnp.arange(12).reshape(4, 3)
        # Check the returned array; the input itself is never changed.
        shuffled = bst.random.shuffle(a, axis=1)
        self.assertTupleEqual(shuffled.shape, (4, 3))
        self.assertEqual(len(jnp.unique(shuffled)), 12)

        # test that a is only shuffled along axis 1: consecutive rows of the
        # original differ by exactly 3 in every column, and that must survive.
        uni = jnp.unique(jnp.diff(shuffled, axis=0))
        self.assertTrue(jnp.array_equal(uni, jnp.asarray([3])))

    def test_beta1(self):
        bst.random.seed()
        a = bst.random.beta(2, 2)
        self.assertTupleEqual(a.shape, ())

    def test_beta2(self):
        bst.random.seed()
        a = bst.random.beta([2, 2, 3], 2, size=(3,))
        self.assertTupleEqual(a.shape, (3,))

    def test_exponential1(self):
        bst.random.seed()
        a = bst.random.exponential(10., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_exponential2(self):
        bst.random.seed()
        a = bst.random.exponential([1., 2., 5.])
        self.assertTupleEqual(a.shape, (3,))

    def test_gamma(self):
        bst.random.seed()
        a = bst.random.gamma(2, 10., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_gumbel(self):
        bst.random.seed()
        a = bst.random.gumbel(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_laplace(self):
        bst.random.seed()
        a = bst.random.laplace(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_logistic(self):
        bst.random.seed()
        a = bst.random.logistic(0., 2., size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_normal1(self):
        bst.random.seed()
        a = bst.random.normal()
        self.assertTupleEqual(a.shape, ())

    def test_normal2(self):
        bst.random.seed()
        a = bst.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
        self.assertTupleEqual(a.shape, (3,))

    def test_normal3(self):
        bst.random.seed()
        # loc of shape (3,) against scale of shape (2, 3) -> result (2, 3).
        a = bst.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
        self.assertTupleEqual(a.shape, (2, 3))

    def test_pareto(self):
        bst.random.seed()
        a = bst.random.pareto([1, 2, 2])
        self.assertTupleEqual(a.shape, (3,))

    def test_poisson(self):
        bst.random.seed()
        a = bst.random.poisson([1., 2., 2.], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_cauchy(self):
        bst.random.seed()
        a = bst.random.standard_cauchy(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_exponential(self):
        bst.random.seed()
        a = bst.random.standard_exponential(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_gamma(self):
        bst.random.seed()
        a = bst.random.standard_gamma(shape=[1, 2, 4], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_normal(self):
        bst.random.seed()
        a = bst.random.standard_normal(size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))

    def test_standard_t(self):
        bst.random.seed()
        a = bst.random.standard_t(df=[1, 2, 4], size=3)
        self.assertTupleEqual(a.shape, (3,))

    def test_standard_uniform1(self):
        bst.random.seed()
        a = bst.random.uniform()
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(0 <= a < 1)

    def test_uniform2(self):
        bst.random.seed()
        a = bst.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
        self.assertTupleEqual(a.shape, (3,))
        # Element-wise bounds: low[i] <= a[i] < high[i].
        self.assertTrue((a - jnp.array([-1., 5., 2.]) >= 0).all()
                        and (-a + jnp.array([2., 6., 10.]) > 0).all())

    def test_uniform3(self):
        bst.random.seed()
        a = bst.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
        self.assertTupleEqual(a.shape, (2, 3))

    def test_uniform4(self):
        bst.random.seed()
        a = bst.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
        self.assertTupleEqual(a.shape, (2, 3))

    def test_truncated_normal1(self):
        bst.random.seed()
        a = bst.random.truncated_normal(-1., 1.)
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(-1. <= a <= 1.)

    def test_truncated_normal2(self):
        bst.random.seed()
        a = bst.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
        self.assertTupleEqual(a.shape, (4, 3))

    def test_truncated_normal3(self):
        bst.random.seed()
        a = bst.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
        self.assertTupleEqual(a.shape, (2, 3))
        # Both rows share the same upper bounds, so one (3,) vector suffices.
        self.assertTrue((a - jnp.array([-1., 0., 1.]) >= 0.).all()
                        and (- a + jnp.array([2., 2., 4.]) >= 0.).all())

    def test_bernoulli1(self):
        bst.random.seed()
        a = bst.random.bernoulli()
        self.assertTupleEqual(a.shape, ())
        self.assertTrue(a == 0 or a == 1)

    def test_bernoulli2(self):
        bst.random.seed()
        a = bst.random.bernoulli([0.5, 0.6, 0.8])
        self.assertTupleEqual(a.shape, (3,))
        self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())

    def test_bernoulli3(self):
        bst.random.seed()
        a = bst.random.bernoulli([0.5, 0.6], size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())

    def test_lognormal1(self):
        bst.random.seed()
        a = bst.random.lognormal()
        self.assertTupleEqual(a.shape, ())

    def test_lognormal2(self):
        bst.random.seed()
        a = bst.random.lognormal(sigma=[2., 1.], size=[3, 2])
        self.assertTupleEqual(a.shape, (3, 2))

    def test_lognormal3(self):
        bst.random.seed()
        a = bst.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
        self.assertTupleEqual(a.shape, (2, 2))

    def test_binomial1(self):
        bst.random.seed()
        a = bst.random.binomial(5, 0.5)
        self.assertTupleEqual(a.shape, ())
        # assertTrue(a.dtype, int) would treat ``int`` as the failure message
        # and check nothing; test the dtype family explicitly instead.
        self.assertTrue(jnp.issubdtype(a.dtype, jnp.integer))

    def test_binomial2(self):
        bst.random.seed()
        a = bst.random.binomial(5, 0.5, size=(3, 2))
        self.assertTupleEqual(a.shape, (3, 2))
        self.assertTrue((a >= 0).all() and (a <= 5).all())

    def test_binomial3(self):
        bst.random.seed()
        a = bst.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
        self.assertTupleEqual(a.shape, (2, 3))

    def test_chisquare1(self):
        bst.random.seed()
        a = bst.random.chisquare(3)
        self.assertTupleEqual(a.shape, ())
        # See test_binomial1: a two-argument assertTrue does not compare.
        self.assertTrue(jnp.issubdtype(a.dtype, jnp.floating))

    def test_chisquare2(self):
        bst.random.seed()
        # A sequence-valued ``df`` is expected to be rejected.
        with self.assertRaises(NotImplementedError):
            bst.random.chisquare(df=[2, 3, 4])

    def test_chisquare3(self):
        bst.random.seed()
        a = bst.random.chisquare(df=2, size=100)
        self.assertTupleEqual(a.shape, (100,))

    def test_chisquare4(self):
        bst.random.seed()
        a = bst.random.chisquare(df=2, size=(100, 10))
        self.assertTupleEqual(a.shape, (100, 10))

    def test_dirichlet1(self):
        bst.random.seed()
        a = bst.random.dirichlet((10, 5, 3))
        self.assertTupleEqual(a.shape, (3,))

    def test_dirichlet2(self):
        bst.random.seed()
        a = bst.random.dirichlet((10, 5, 3), 20)
        self.assertTupleEqual(a.shape, (20, 3))

    def test_f(self):
        bst.random.seed()
        a = bst.random.f(1., 48., 100)
        self.assertTupleEqual(a.shape, (100,))

    def test_geometric(self):
        bst.random.seed()
        a = bst.random.geometric([0.7, 0.5, 0.2])
        self.assertTupleEqual(a.shape, (3,))

    def test_hypergeometric1(self):
        bst.random.seed()
        a = bst.random.hypergeometric(10, 10, 10, 20)
        self.assertTupleEqual(a.shape, (20,))

    @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
    def test_hypergeometric2(self):
        bst.random.seed()
        a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
        self.assertTupleEqual(a.shape, (2, 2))

    @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
    def test_hypergeometric3(self):
        bst.random.seed()
        a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
        self.assertTupleEqual(a.shape, (3, 2, 2))

    def test_logseries(self):
        bst.random.seed()
        a = bst.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
        self.assertTupleEqual(a.shape, (4, 3))

    def test_multinominal1(self):
        bst.random.seed()
        # numpy serves as the shape reference for the same call.
        a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
        b = bst.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2, 3))

    def test_multinominal2(self):
        bst.random.seed()
        a = bst.random.multinomial(100, (0.5, 0.2, 0.3))
        self.assertTupleEqual(a.shape, (3,))
        # The category counts must add up to the number of trials.
        self.assertTrue(a.sum() == 100)

    def test_multivariate_normal1(self):
        bst.random.seed()
        # self.skipTest('Windows jaxlib error')
        a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
        b = bst.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(a.shape, (3, 2))

    def test_multivariate_normal2(self):
        bst.random.seed()
        a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
        b = bst.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(a.shape, (2,))

    def test_negative_binomial(self):
        bst.random.seed()
        a = np.random.negative_binomial([3., 10.], 0.5)
        b = bst.random.negative_binomial([3., 10.], 0.5)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_negative_binomial2(self):
        bst.random.seed()
        a = np.random.negative_binomial(3., 0.5, 10)
        b = bst.random.negative_binomial(3., 0.5, 10)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (10,))

    def test_noncentral_chisquare(self):
        bst.random.seed()
        a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
        b = bst.random.noncentral_chisquare(3, [3., 2.], (4, 2))
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2))

    def test_noncentral_chisquare2(self):
        bst.random.seed()
        a = bst.random.noncentral_chisquare(3, [3., 2.])
        self.assertTupleEqual(a.shape, (2,))

    def test_noncentral_f(self):
        bst.random.seed()
        a = bst.random.noncentral_f(3, 20, 3., 100)
        self.assertTupleEqual(a.shape, (100,))

    def test_power(self):
        bst.random.seed()
        a = np.random.power(2, (4, 2))
        b = bst.random.power(2, (4, 2))
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (4, 2))

    def test_rayleigh(self):
        bst.random.seed()
        # Exercise rayleigh itself; calling power here would leave rayleigh
        # untested and merely repeat test_power.
        a = bst.random.rayleigh(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_triangular(self):
        bst.random.seed()
        a = bst.random.triangular((2, 2))
        self.assertTupleEqual(a.shape, (2, 2))

    def test_vonmises(self):
        bst.random.seed()
        a = np.random.vonmises(2., 2.)
        b = bst.random.vonmises(2., 2.)
        # numpy returns a plain Python float here, hence np.shape(a).
        self.assertTupleEqual(np.shape(a), b.shape)
        self.assertTupleEqual(b.shape, ())

    def test_vonmises2(self):
        bst.random.seed()
        a = np.random.vonmises(2., 2., 10)
        b = bst.random.vonmises(2., 2., 10)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (10,))

    def test_wald(self):
        bst.random.seed()
        a = np.random.wald([2., 0.5], 2.)
        b = bst.random.wald([2., 0.5], 2.)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_wald2(self):
        bst.random.seed()
        a = np.random.wald(2., 2., 100)
        b = bst.random.wald(2., 2., 100)
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (100,))

    def test_weibull(self):
        bst.random.seed()
        a = bst.random.weibull(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_weibull2(self):
        bst.random.seed()
        a = bst.random.weibull(2., )
        self.assertTupleEqual(a.shape, ())

    def test_weibull3(self):
        bst.random.seed()
        a = bst.random.weibull([2., 3.], )
        self.assertTupleEqual(a.shape, (2,))

    def test_weibull_min(self):
        bst.random.seed()
        a = bst.random.weibull_min(2., 2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_weibull_min2(self):
        bst.random.seed()
        a = bst.random.weibull_min(2., 2.)
        self.assertTupleEqual(a.shape, ())

    def test_weibull_min3(self):
        bst.random.seed()
        a = bst.random.weibull_min([2., 3.], 2.)
        self.assertTupleEqual(a.shape, (2,))

    def test_zipf(self):
        bst.random.seed()
        a = bst.random.zipf(2., (4, 2))
        self.assertTupleEqual(a.shape, (4, 2))

    def test_zipf2(self):
        bst.random.seed()
        a = np.random.zipf([1.1, 2.])
        b = bst.random.zipf([1.1, 2.])
        self.assertTupleEqual(a.shape, b.shape)
        self.assertTupleEqual(b.shape, (2,))

    def test_maxwell(self):
        bst.random.seed()
        a = bst.random.maxwell(10)
        self.assertTupleEqual(a.shape, (10,))

    def test_maxwell2(self):
        bst.random.seed()
        a = bst.random.maxwell()
        self.assertTupleEqual(a.shape, ())

    def test_t(self):
        bst.random.seed()
        a = bst.random.t(1., size=10)
        self.assertTupleEqual(a.shape, (10,))

    def test_t2(self):
        bst.random.seed()
        a = bst.random.t([1., 2.], size=None)
        self.assertTupleEqual(a.shape, (2,))
|
563
|
+
|
564
|
+
# class TestRandomKey(unittest.TestCase):
|
565
|
+
# def test_clear_memory(self):
|
566
|
+
# bst.random.split_key()
|
567
|
+
# print(bst.random.DEFAULT.value)
|
568
|
+
# self.assertTrue(isinstance(bst.random.DEFAULT.value, np.ndarray))
|