brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,551 +1,551 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest

import jax.numpy as jnp
import jax.random as jr
import numpy as np

from brainstate.random._rand_state import RandomState, DEFAULT, _formalize_key, _size2shape, _check_py_seq
|
23
|
-
|
24
|
-
|
25
|
-
class TestRandomStateInitialization(unittest.TestCase):
    """Construction of ``RandomState`` from each accepted seed form."""

    def test_init_with_none(self):
        """A ``None`` seed still produces a two-word uint32 key."""
        state = RandomState(None)
        stored = state.value
        self.assertIsNotNone(stored)
        self.assertEqual(stored.shape, (2,))
        self.assertEqual(stored.dtype, jnp.uint32)

    def test_init_with_int_seed(self):
        """An integer seed yields the same key as ``jax.random.PRNGKey``."""
        np.testing.assert_array_equal(RandomState(42).value, jr.PRNGKey(42))

    def test_init_with_prng_key(self):
        """A ready-made PRNG key is stored unchanged."""
        prng_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(RandomState(prng_key).value, prng_key)

    def test_init_with_uint32_array(self):
        """A raw uint32 pair is taken verbatim as the key."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(RandomState(raw).value, raw)

    def test_init_with_invalid_key(self):
        """Only a key that is wrong in BOTH length and dtype is rejected."""
        # Wrong length and wrong dtype at the same time -> ValueError.
        with self.assertRaises(ValueError):
            RandomState(np.array([1, 2, 3], dtype=np.int32))

        # A single mismatch is tolerated: wrong length (uint32), or wrong dtype (length 2).
        tolerated = (
            np.array([1, 2, 3], dtype=np.uint32),
            np.array([1, 2], dtype=np.int32),
        )
        for candidate in tolerated:
            self.assertIsNotNone(RandomState(candidate).value)

    def test_repr(self):
        """The repr mentions both the class name and the seed."""
        text = repr(RandomState(42))
        for fragment in ("RandomState", "42"):
            self.assertIn(fragment, text)
|
75
|
-
|
76
|
-
|
77
|
-
class TestRandomStateKeyManagement(unittest.TestCase):
    """Test key management: seeding, splitting, backup/restore, cloning and direct setting."""

    def setUp(self):
        # Fresh, deterministically seeded state per test so tests cannot influence each other.
        self.rs = RandomState(42)

    def test_seed_with_int(self):
        """Test seeding with integer."""
        self.rs.seed(123)
        expected_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(self.rs.value, expected_key)

    def test_seed_with_none(self):
        """Test seeding with None generates new random seed."""
        original_key = self.rs.value.copy()
        self.rs.seed(None)
        # Should be different (with very high probability)
        self.assertFalse(np.array_equal(self.rs.value, original_key))

    def test_seed_with_prng_key(self):
        """Test seeding with PRNGKey."""
        key = jr.PRNGKey(999)
        self.rs.seed(key)
        np.testing.assert_array_equal(self.rs.value, key)

    def test_seed_with_invalid_input(self):
        """Test seeding with invalid input raises error."""
        with self.assertRaises(ValueError):
            self.rs.seed([1, 2, 3])  # Wrong length list

    def test_split_key_single(self):
        """Test splitting key to get single new key."""
        original_key = self.rs.value.copy()
        new_key = self.rs.split_key()

        # The internal key must advance...
        self.assertFalse(np.array_equal(self.rs.value, original_key))
        # ...and the returned key must differ from both the old and the new internal key.
        self.assertFalse(np.array_equal(new_key, original_key))
        self.assertFalse(np.array_equal(new_key, self.rs.value))

    def test_split_key_multiple(self):
        """Test splitting key to get multiple new keys."""
        n = 3
        original_key = self.rs.value.copy()
        new_keys = self.rs.split_key(n)

        self.assertEqual(len(new_keys), n)
        # All keys should be different from the original and from each other.
        for i, key in enumerate(new_keys):
            self.assertFalse(np.array_equal(key, original_key))
            for j, other_key in enumerate(new_keys):
                if i != j:
                    self.assertFalse(np.array_equal(key, other_key))

    def test_split_key_invalid_n(self):
        """Test split_key with invalid n raises error."""
        # NOTE(review): AssertionError implies split_key validates with `assert`,
        # which is stripped under `python -O`; this test would fail in that mode.
        with self.assertRaises(AssertionError):
            self.rs.split_key(0)

        with self.assertRaises(AssertionError):
            self.rs.split_key(-1)

    def test_backup_restore_key(self):
        """Test backup and restore functionality."""
        original_key = self.rs.value.copy()

        # Backup the key
        self.rs.backup_key()

        # Change the key
        self.rs.split_key()
        changed_key = self.rs.value.copy()
        self.assertFalse(np.array_equal(changed_key, original_key))

        # Restore the key
        self.rs.restore_key()
        np.testing.assert_array_equal(self.rs.value, original_key)

    def test_backup_already_backed_up(self):
        """Test backup when already backed up raises error."""
        self.rs.backup_key()
        with self.assertRaises(ValueError):
            self.rs.backup_key()

    def test_restore_without_backup(self):
        """Test restore without backup raises error."""
        with self.assertRaises(ValueError):
            self.rs.restore_key()

    def test_clone(self):
        """Test cloning creates an independent copy."""
        clone = self.rs.clone()

        # Should be different instances
        self.assertIsNot(clone, self.rs)

        # Advancing the original must leave the clone's key untouched (independence).
        clone_key = clone.value.copy()
        self.rs.split_key()
        np.testing.assert_array_equal(clone.value, clone_key)

        # Once both have advanced, their keys should differ.
        clone.split_key()
        self.assertFalse(np.array_equal(self.rs.value, clone.value))

    def test_set_key(self):
        """Test setting key directly."""
        new_key = jr.PRNGKey(999)
        self.rs.set_key(new_key)
        np.testing.assert_array_equal(self.rs.value, new_key)
|
188
|
-
|
189
|
-
|
190
|
-
class TestRandomStateDistributions(unittest.TestCase):
    """Test random distribution methods: output shapes and support/range checks."""

    def setUp(self):
        # Fixed seed so every distribution test is reproducible.
        self.rs = RandomState(42)

    def test_rand(self):
        """Test rand method."""
        # Single value
        val = self.rs.rand()
        self.assertEqual(val.shape, ())
        self.assertTrue(0 <= val < 1)

        # Multiple dimensions
        arr = self.rs.rand(3, 2)
        self.assertEqual(arr.shape, (3, 2))
        self.assertTrue((arr >= 0).all() and (arr < 1).all())

    def test_randint(self):
        """Test randint method."""
        # Single bound
        val = self.rs.randint(10)
        self.assertTrue(0 <= val < 10)

        # Both bounds
        val = self.rs.randint(5, 15)
        self.assertTrue(5 <= val < 15)

        # With size
        arr = self.rs.randint(0, 5, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_randn(self):
        """Test randn method."""
        # Single value
        val = self.rs.randn()
        self.assertEqual(val.shape, ())

        # Multiple dimensions
        arr = self.rs.randn(3, 2)
        self.assertEqual(arr.shape, (3, 2))

    def test_normal(self):
        """Test normal distribution."""
        # Standard normal
        val = self.rs.normal()
        self.assertEqual(val.shape, ())

        # With parameters
        arr = self.rs.normal(loc=5.0, scale=2.0, size=(3, 2))
        self.assertEqual(arr.shape, (3, 2))

    def test_uniform(self):
        """Test uniform distribution."""
        # Standard uniform
        val = self.rs.uniform()
        self.assertTrue(0.0 <= val < 1.0)

        # With bounds
        arr = self.rs.uniform(low=2.0, high=8.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 2.0).all() and (arr < 8.0).all())

    def test_choice(self):
        """Test choice method."""
        # Choose from range
        val = self.rs.choice(5)
        self.assertTrue(0 <= val < 5)

        # Choose from array
        options = jnp.array([10, 20, 30, 40])
        val = self.rs.choice(options)
        self.assertIn(val, options)

        # Multiple choices
        arr = self.rs.choice(5, size=10)
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_beta(self):
        """Test beta distribution."""
        arr = self.rs.beta(2.0, 3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 1).all())

    def test_exponential(self):
        """Test exponential distribution."""
        arr = self.rs.exponential(scale=2.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_gamma(self):
        """Test gamma distribution."""
        arr = self.rs.gamma(shape=2.0, scale=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_poisson(self):
        """Test Poisson distribution."""
        arr = self.rs.poisson(lam=3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_binomial(self):
        """Test binomial distribution."""
        arr = self.rs.binomial(n=10, p=0.3, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 10).all())

    def test_bernoulli(self):
        """Test Bernoulli distribution."""
        arr = self.rs.bernoulli(p=0.7, size=(100,))
        self.assertEqual(arr.shape, (100,))
        self.assertTrue(jnp.all((arr == 0) | (arr == 1)))

    def test_bernoulli_invalid_p(self):
        """Test Bernoulli with invalid probability."""
        # The validation may surface as a ValueError or as a JAX runtime error
        # (e.g. from a jit-time error check), so accept any Exception.
        # A `(ValueError, Exception)` tuple is redundant: ValueError is an Exception.
        with self.assertRaises(Exception):
            self.rs.bernoulli(p=1.5)  # p > 1

    def test_truncated_normal(self):
        """Test truncated normal distribution."""
        arr = self.rs.truncated_normal(lower=-1.0, upper=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= -1.0).all() and (arr <= 1.0).all())

    def test_multivariate_normal(self):
        """Test multivariate normal distribution."""
        mean = jnp.array([0.0, 1.0])
        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])

        # `size` is the batch shape; the event dimension (len(mean)) is appended.
        arr = self.rs.multivariate_normal(mean, cov, size=(3,))
        self.assertEqual(arr.shape, (3, 2))

    def test_categorical(self):
        """Test categorical distribution."""
        logits = jnp.array([0.1, 0.2, 0.3, 0.4])
        arr = self.rs.categorical(logits, size=(10,))
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < len(logits)).all())
|
332
|
-
|
333
|
-
|
334
|
-
class TestRandomStatePyTorchCompatibility(unittest.TestCase):
    """``*_like`` sampling helpers modeled on the PyTorch API."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_rand_like(self):
        """``rand_like`` copies the template's shape and samples in [0, 1)."""
        template = jnp.zeros((3, 4))
        sample = self.rs.rand_like(template)
        self.assertEqual(sample.shape, template.shape)
        in_range = (sample >= 0).all() and (sample < 1).all()
        self.assertTrue(in_range)

    def test_randn_like(self):
        """``randn_like`` copies the template's shape."""
        template = jnp.zeros((2, 3))
        self.assertEqual(self.rs.randn_like(template).shape, template.shape)

    def test_randint_like(self):
        """``randint_like`` copies the shape and respects [low, high)."""
        template = jnp.zeros((2, 3), dtype=jnp.int32)
        sample = self.rs.randint_like(template, low=0, high=10)
        self.assertEqual(sample.shape, template.shape)
        in_range = (sample >= 0).all() and (sample < 10).all()
        self.assertTrue(in_range)
|
359
|
-
|
360
|
-
|
361
|
-
class TestRandomStateKeyBehavior(unittest.TestCase):
    """How the optional ``key=`` argument interacts with the stored key."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_external_key_does_not_change_state(self):
        """Sampling with an explicit key leaves the stored key untouched."""
        before = self.rs.value.copy()
        self.rs.rand(5, key=jr.PRNGKey(999))
        np.testing.assert_array_equal(self.rs.value, before)

    def test_no_key_changes_state(self):
        """Sampling without a key consumes (advances) the stored key."""
        before = self.rs.value.copy()
        self.rs.rand(5)
        self.assertFalse(np.array_equal(self.rs.value, before))

    def test_reproducibility_with_same_key(self):
        """Two draws with the same explicit key are identical."""
        shared = jr.PRNGKey(123)
        first, second = [self.rs.rand(5, key=shared) for _ in range(2)]
        np.testing.assert_array_equal(first, second)

    def test_reproducibility_with_seed(self):
        """Re-seeding with the same value reproduces the same draw."""
        def draw_after_seeding():
            self.rs.seed(42)
            return self.rs.rand(5)

        np.testing.assert_array_equal(draw_after_seeding(), draw_after_seeding())
|
406
|
-
|
407
|
-
|
408
|
-
class TestGlobalDefaultInstance(unittest.TestCase):
    """Test the global DEFAULT RandomState instance.

    ``DEFAULT`` is shared module-level state. Each test snapshots its key in
    ``setUp`` and restores it in ``tearDown``, so the reseeding and splitting
    done here cannot leak into other tests or make these tests order-dependent.
    """

    def setUp(self):
        self._saved_default_key = DEFAULT.value.copy()

    def tearDown(self):
        DEFAULT.set_key(self._saved_default_key)

    def test_default_exists(self):
        """Test that DEFAULT instance exists and is RandomState."""
        self.assertIsInstance(DEFAULT, RandomState)

    def test_default_has_valid_key(self):
        """Test that DEFAULT has valid key."""
        self.assertIsNotNone(DEFAULT.value)
        self.assertEqual(DEFAULT.value.shape, (2,))
        self.assertEqual(DEFAULT.value.dtype, jnp.uint32)

    def test_default_seeding(self):
        """Test seeding DEFAULT instance."""
        original_key = DEFAULT.value.copy()
        DEFAULT.seed(12345)
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))

    def test_default_split_key(self):
        """Test splitting DEFAULT key."""
        original_key = DEFAULT.value.copy()
        new_key = DEFAULT.split_key()
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))
        self.assertIsNotNone(new_key)
|
433
|
-
|
434
|
-
|
435
|
-
class TestUtilityFunctions(unittest.TestCase):
    """Test utility functions in _rand_state module."""

    def test_formalize_key_with_int(self):
        """Test _formalize_key with integer."""
        key = _formalize_key(42)
        expected = jr.PRNGKey(42)
        np.testing.assert_array_equal(key, expected)

    def test_formalize_key_with_array(self):
        """Test _formalize_key with array."""
        input_key = jr.PRNGKey(123)
        key = _formalize_key(input_key)
        np.testing.assert_array_equal(key, input_key)

    def test_formalize_key_with_uint32_array(self):
        """Test _formalize_key with uint32 array."""
        input_array = np.array([123, 456], dtype=np.uint32)
        key = _formalize_key(input_array)
        np.testing.assert_array_equal(key, input_array)

    def test_formalize_key_invalid_input(self):
        """Test _formalize_key with invalid input."""
        # NOTE(review): the original invalid inputs are not recoverable from this
        # listing; these stand-ins are neither ints nor arrays. Confirm against
        # _formalize_key's accepted types.
        with self.assertRaises(TypeError):
            _formalize_key("invalid")

        with self.assertRaises(TypeError):
            _formalize_key(3.14)

        with self.assertRaises(TypeError):
            _formalize_key(object())

    def test_size2shape(self):
        """Test _size2shape function."""
        self.assertEqual(_size2shape(None), ())
        self.assertEqual(_size2shape(5), (5,))
        self.assertEqual(_size2shape((3, 4)), (3, 4))
        self.assertEqual(_size2shape([2, 3, 4]), (2, 3, 4))

    def test_check_py_seq(self):
        """Test _check_py_seq function."""
        # Should convert lists/tuples to arrays
        result = _check_py_seq([1, 2, 3])
        self.assertIsInstance(result, jnp.ndarray)
        np.testing.assert_array_equal(result, jnp.array([1, 2, 3]))

        # Should leave other types unchanged (same object for arrays)
        arr = jnp.array([1, 2, 3])
        result = _check_py_seq(arr)
        self.assertIs(result, arr)

        scalar = 5
        result = _check_py_seq(scalar)
        self.assertEqual(result, scalar)
|
489
|
-
|
490
|
-
|
491
|
-
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_invalid_distribution_parameters(self):
        """Test invalid parameters for distributions."""
        # Note: Some distributions may not validate parameters immediately
        # so we test what we can verify

        # Invalid probability for binomial with check_valid=True. Depending on JAX
        # compilation the error may or may not surface here, so this is best-effort.
        try:
            self.rs.binomial(n=10, p=1.5, check_valid=True)
        except Exception:
            # Expected to fail. Catch Exception rather than a bare `except:` so
            # KeyboardInterrupt/SystemExit are not swallowed.
            pass

        # Test normal distribution works with negative scale (JAX allows this)
        result = self.rs.normal(scale=-1.0, size=(2,))
        self.assertEqual(result.shape, (2,))

    def test_invalid_size_parameters(self):
        """Test invalid size parameters."""
        # Test empty shape works for distributions that accept size parameter
        result = self.rs.random(size=())
        self.assertEqual(result.shape, ())

        # Test with None size
        result = self.rs.random(size=None)
        self.assertEqual(result.shape, ())

    def test_dtype_consistency(self):
        """Test dtype consistency across methods."""
        # Integer methods should return integers
        result = self.rs.randint(10, size=(3,))
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.integer))

        # Float methods should return floats
        result = self.rs.rand(3)
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))

    def test_self_assign_multi_keys(self):
        """Test self_assign_multi_keys method."""
        original_shape = self.rs.value.shape

        # Test with backup: the state now holds 3 keys, one per row
        self.rs.self_assign_multi_keys(3, backup=True)
        self.assertEqual(self.rs.value.shape, (3, 2))

        # Restore should work
        self.rs.restore_key()
        self.assertEqual(self.rs.value.shape, original_shape)

        # Test without backup
        self.rs.self_assign_multi_keys(2, backup=False)
        self.assertEqual(self.rs.value.shape, (2, 2))
|
548
|
-
|
549
|
-
|
550
|
-
if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's runner.
    unittest.main()
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax.numpy as jnp
|
19
|
+
import jax.random as jr
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from brainstate.random._state import RandomState, DEFAULT, formalize_key, _size2shape, _check_py_seq
|
23
|
+
|
24
|
+
|
25
|
+
class TestRandomStateInitialization(unittest.TestCase):
    """Construction of ``RandomState`` from each accepted seed form."""

    def test_init_with_none(self):
        """A ``None`` seed still produces a two-word uint32 key."""
        state = RandomState(None)
        stored = state.value
        self.assertIsNotNone(stored)
        self.assertEqual(stored.shape, (2,))
        self.assertEqual(stored.dtype, jnp.uint32)

    def test_init_with_int_seed(self):
        """An integer seed yields the same key as ``jax.random.PRNGKey``."""
        np.testing.assert_array_equal(RandomState(42).value, jr.PRNGKey(42))

    def test_init_with_prng_key(self):
        """A ready-made PRNG key is stored unchanged."""
        prng_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(RandomState(prng_key).value, prng_key)

    def test_init_with_uint32_array(self):
        """A raw uint32 pair is taken verbatim as the key."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(RandomState(raw).value, raw)

    def test_init_with_invalid_key(self):
        """Only a key that is wrong in BOTH length and dtype is rejected."""
        # Wrong length and wrong dtype at the same time -> ValueError.
        with self.assertRaises(ValueError):
            RandomState(np.array([1, 2, 3], dtype=np.int32))

        # A single mismatch is tolerated: wrong length (uint32), or wrong dtype (length 2).
        tolerated = (
            np.array([1, 2, 3], dtype=np.uint32),
            np.array([1, 2], dtype=np.int32),
        )
        for candidate in tolerated:
            self.assertIsNotNone(RandomState(candidate).value)

    def test_repr(self):
        """The repr mentions both the class name and the seed."""
        text = repr(RandomState(42))
        for fragment in ("RandomState", "42"):
            self.assertIn(fragment, text)
|
75
|
+
|
76
|
+
|
77
|
+
class TestRandomStateKeyManagement(unittest.TestCase):
|
78
|
+
"""Test key management functionality."""
|
79
|
+
|
80
|
+
def setUp(self):
|
81
|
+
self.rs = RandomState(42)
|
82
|
+
|
83
|
+
def test_seed_with_int(self):
|
84
|
+
"""Test seeding with integer."""
|
85
|
+
self.rs.seed(123)
|
86
|
+
expected_key = jr.PRNGKey(123)
|
87
|
+
np.testing.assert_array_equal(self.rs.value, expected_key)
|
88
|
+
|
89
|
+
def test_seed_with_none(self):
|
90
|
+
"""Test seeding with None generates new random seed."""
|
91
|
+
original_key = self.rs.value.copy()
|
92
|
+
self.rs.seed(None)
|
93
|
+
# Should be different (with very high probability)
|
94
|
+
self.assertFalse(np.array_equal(self.rs.value, original_key))
|
95
|
+
|
96
|
+
def test_seed_with_prng_key(self):
|
97
|
+
"""Test seeding with PRNGKey."""
|
98
|
+
key = jr.PRNGKey(999)
|
99
|
+
self.rs.seed(key)
|
100
|
+
np.testing.assert_array_equal(self.rs.value, key)
|
101
|
+
|
102
|
+
def test_seed_with_invalid_input(self):
|
103
|
+
"""Test seeding with invalid input raises error."""
|
104
|
+
with self.assertRaises(ValueError):
|
105
|
+
self.rs.seed([1, 2, 3]) # Wrong length list
|
106
|
+
|
107
|
+
def test_split_key_single(self):
|
108
|
+
"""Test splitting key to get single new key."""
|
109
|
+
original_key = self.rs.value.copy()
|
110
|
+
new_key = self.rs.split_key()
|
111
|
+
|
112
|
+
# Original key should have changed
|
113
|
+
self.assertFalse(np.array_equal(self.rs.value, original_key))
|
114
|
+
# New key should be different from both
|
115
|
+
self.assertFalse(np.array_equal(new_key, original_key))
|
116
|
+
self.assertFalse(np.array_equal(new_key, self.rs.value))
|
117
|
+
|
118
|
+
def test_split_key_multiple(self):
|
119
|
+
"""Test splitting key to get multiple new keys."""
|
120
|
+
n = 3
|
121
|
+
original_key = self.rs.value.copy()
|
122
|
+
new_keys = self.rs.split_key(n)
|
123
|
+
|
124
|
+
self.assertEqual(len(new_keys), n)
|
125
|
+
# All keys should be different
|
126
|
+
for i, key in enumerate(new_keys):
|
127
|
+
self.assertFalse(np.array_equal(key, original_key))
|
128
|
+
for j, other_key in enumerate(new_keys):
|
129
|
+
if i != j:
|
130
|
+
self.assertFalse(np.array_equal(key, other_key))
|
131
|
+
|
132
|
+
def test_split_key_invalid_n(self):
    """Non-positive split counts are rejected with ``AssertionError``."""
    for bad_n in (0, -1):
        with self.assertRaises(AssertionError):
            self.rs.split_key(bad_n)
|
139
|
+
|
140
|
+
def test_backup_restore_key(self):
    """``restore_key`` brings back exactly the key saved by ``backup_key``."""
    saved = self.rs.value.copy()
    self.rs.backup_key()

    # Advance the state so it no longer matches the backup.
    self.rs.split_key()
    advanced = self.rs.value.copy()
    self.assertFalse(np.array_equal(advanced, saved))

    self.rs.restore_key()
    restored = self.rs.value
    np.testing.assert_array_equal(restored, saved)
|
155
|
+
|
156
|
+
def test_backup_already_backed_up(self):
    """A second ``backup_key`` call without an intervening restore raises ``ValueError``."""
    backup = self.rs.backup_key
    backup()
    with self.assertRaises(ValueError):
        backup()
|
161
|
+
|
162
|
+
def test_restore_without_backup(self):
    """``restore_key`` with no prior ``backup_key`` raises ``ValueError``."""
    self.assertRaises(ValueError, self.rs.restore_key)
|
166
|
+
|
167
|
+
def test_clone(self):
    """``clone()`` returns a separate instance whose key evolves independently."""
    clone = self.rs.clone()

    # The clone must be a distinct object, not an alias of the source.
    self.assertIsNot(clone, self.rs)

    clone_key = clone.value.copy()

    # Advancing the source must not touch the clone's key.
    self.rs.split_key()
    np.testing.assert_array_equal(clone.value, clone_key)
    source_key = self.rs.value.copy()

    # Advancing the clone must not touch the source's key.
    clone.split_key()
    np.testing.assert_array_equal(self.rs.value, source_key)

    # After both have split, the two instances hold different keys.
    self.assertFalse(np.array_equal(self.rs.value, clone.value))
|
182
|
+
|
183
|
+
def test_set_key(self):
    """``set_key`` replaces the internal key with the supplied one."""
    replacement = jr.PRNGKey(999)
    self.rs.set_key(replacement)
    installed = self.rs.value
    np.testing.assert_array_equal(installed, replacement)
|
188
|
+
|
189
|
+
|
190
|
+
class TestRandomStateDistributions(unittest.TestCase):
    """Shape and support checks for the sampling methods of ``RandomState``."""

    def setUp(self):
        # Fixed seed so every draw in this class is reproducible.
        self.rs = RandomState(42)

    @staticmethod
    def _within(values, low, high, *, include_high=False):
        """Return whether every element of ``values`` lies in ``[low, high)``.

        With ``include_high=True`` the upper bound is inclusive: ``[low, high]``.
        """
        if not (values >= low).all():
            return False
        below_high = (values <= high) if include_high else (values < high)
        return bool(below_high.all())

    def test_rand(self):
        """``rand`` yields values in [0, 1) for scalar and shaped requests."""
        # No dimensions -> a 0-d result.
        single = self.rs.rand()
        self.assertEqual(single.shape, ())
        self.assertTrue(0 <= single < 1)

        # Positional dimensions -> an array of that shape.
        grid = self.rs.rand(3, 2)
        self.assertEqual(grid.shape, (3, 2))
        self.assertTrue(self._within(grid, 0, 1))

    def test_randint(self):
        """``randint`` honours the upper-only, two-bound and sized call forms."""
        upper_only = self.rs.randint(10)
        self.assertTrue(0 <= upper_only < 10)

        bounded = self.rs.randint(5, 15)
        self.assertTrue(5 <= bounded < 15)

        sized = self.rs.randint(0, 5, size=(2, 3))
        self.assertEqual(sized.shape, (2, 3))
        self.assertTrue(self._within(sized, 0, 5))

    def test_randn(self):
        """``randn`` returns draws of the requested shape."""
        single = self.rs.randn()
        self.assertEqual(single.shape, ())

        grid = self.rs.randn(3, 2)
        self.assertEqual(grid.shape, (3, 2))

    def test_normal(self):
        """``normal`` supports both default and explicit ``loc``/``scale``/``size``."""
        standard = self.rs.normal()
        self.assertEqual(standard.shape, ())

        shifted = self.rs.normal(loc=5.0, scale=2.0, size=(3, 2))
        self.assertEqual(shifted.shape, (3, 2))

    def test_uniform(self):
        """``uniform`` samples [0, 1) by default and [low, high) when bounds are given."""
        default_draw = self.rs.uniform()
        self.assertTrue(0.0 <= default_draw < 1.0)

        bounded = self.rs.uniform(low=2.0, high=8.0, size=(2, 3))
        self.assertEqual(bounded.shape, (2, 3))
        self.assertTrue(self._within(bounded, 2.0, 8.0))

    def test_choice(self):
        """``choice`` picks from ``range(n)`` or from an explicit array."""
        from_range = self.rs.choice(5)
        self.assertTrue(0 <= from_range < 5)

        candidates = jnp.array([10, 20, 30, 40])
        picked = self.rs.choice(candidates)
        self.assertIn(picked, candidates)

        many = self.rs.choice(5, size=10)
        self.assertEqual(many.shape, (10,))
        self.assertTrue(self._within(many, 0, 5))

    def test_beta(self):
        """Beta samples have the requested shape and lie in [0, 1]."""
        samples = self.rs.beta(2.0, 3.0, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue(self._within(samples, 0, 1, include_high=True))

    def test_exponential(self):
        """Exponential samples have the requested shape and are non-negative."""
        samples = self.rs.exponential(scale=2.0, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue((samples >= 0).all())

    def test_gamma(self):
        """Gamma samples have the requested shape and are non-negative."""
        samples = self.rs.gamma(shape=2.0, scale=1.0, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue((samples >= 0).all())

    def test_poisson(self):
        """Poisson samples have the requested shape and are non-negative."""
        samples = self.rs.poisson(lam=3.0, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue((samples >= 0).all())

    def test_binomial(self):
        """Binomial samples lie in [0, n]."""
        samples = self.rs.binomial(n=10, p=0.3, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue(self._within(samples, 0, 10, include_high=True))

    def test_bernoulli(self):
        """Bernoulli samples are all 0 or 1."""
        samples = self.rs.bernoulli(p=0.7, size=(100,))
        self.assertEqual(samples.shape, (100,))
        is_binary = (samples == 0) | (samples == 1)
        self.assertTrue(jnp.all(is_binary))

    def test_bernoulli_invalid_p(self):
        """A probability above 1 is rejected."""
        # The check is expected to go through ``jit_error_if``, so the concrete
        # exception type is not pinned down; any ``Exception`` (which includes
        # ``ValueError``) counts as a rejection.
        with self.assertRaises(Exception):
            self.rs.bernoulli(p=1.5)

    def test_truncated_normal(self):
        """Truncated-normal samples stay within [lower, upper]."""
        samples = self.rs.truncated_normal(lower=-1.0, upper=1.0, size=(2, 3))
        self.assertEqual(samples.shape, (2, 3))
        self.assertTrue(self._within(samples, -1.0, 1.0, include_high=True))

    def test_multivariate_normal(self):
        """Multivariate-normal samples have shape ``size + mean.shape``."""
        mean = jnp.array([0.0, 1.0])
        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])

        samples = self.rs.multivariate_normal(mean, cov, size=(3,))
        # One row per requested sample, one column per component of ``mean``.
        self.assertEqual(samples.shape, (3, 2))

    def test_categorical(self):
        """Categorical samples are valid indices into ``logits``."""
        logits = jnp.array([0.1, 0.2, 0.3, 0.4])
        samples = self.rs.categorical(logits, size=(10,))
        self.assertEqual(samples.shape, (10,))
        self.assertTrue(self._within(samples, 0, len(logits)))
|
332
|
+
|
333
|
+
|
334
|
+
class TestRandomStatePyTorchCompatibility(unittest.TestCase):
    """Tests for the PyTorch-style ``*_like`` sampling helpers."""

    def setUp(self):
        # Deterministic seed shared by every test in this class.
        self.rs = RandomState(42)

    def test_rand_like(self):
        """``rand_like`` copies the template's shape and samples from [0, 1)."""
        template = jnp.zeros((3, 4))
        drawn = self.rs.rand_like(template)
        self.assertEqual(drawn.shape, template.shape)
        lower_ok = (drawn >= 0).all()
        upper_ok = (drawn < 1).all()
        self.assertTrue(lower_ok and upper_ok)

    def test_randn_like(self):
        """``randn_like`` copies the template's shape."""
        template = jnp.zeros((2, 3))
        drawn = self.rs.randn_like(template)
        self.assertEqual(drawn.shape, template.shape)

    def test_randint_like(self):
        """``randint_like`` copies the template's shape and honours ``[low, high)``."""
        template = jnp.zeros((2, 3), dtype=jnp.int32)
        drawn = self.rs.randint_like(template, low=0, high=10)
        self.assertEqual(drawn.shape, template.shape)
        lower_ok = (drawn >= 0).all()
        upper_ok = (drawn < 10).all()
        self.assertTrue(lower_ok and upper_ok)
|
359
|
+
|
360
|
+
|
361
|
+
class TestRandomStateKeyBehavior(unittest.TestCase):
    """How the optional ``key=`` argument interacts with the internal key."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_external_key_does_not_change_state(self):
        """Passing ``key=`` leaves the internal key untouched."""
        before = self.rs.value.copy()
        self.rs.rand(5, key=jr.PRNGKey(999))
        np.testing.assert_array_equal(self.rs.value, before)

    def test_no_key_changes_state(self):
        """Without ``key=`` the internal key is consumed and advanced."""
        before = self.rs.value.copy()
        self.rs.rand(5)
        advanced = not np.array_equal(self.rs.value, before)
        self.assertTrue(advanced)

    def test_reproducibility_with_same_key(self):
        """Reusing one external key produces identical draws."""
        shared_key = jr.PRNGKey(123)
        first, second = (self.rs.rand(5, key=shared_key) for _ in range(2))
        np.testing.assert_array_equal(first, second)

    def test_reproducibility_with_seed(self):
        """Re-seeding with the same value replays the same stream."""
        def draw_after_seeding():
            self.rs.seed(42)
            return self.rs.rand(5)

        np.testing.assert_array_equal(draw_after_seeding(), draw_after_seeding())
|
406
|
+
|
407
|
+
|
408
|
+
class TestGlobalDefaultInstance(unittest.TestCase):
    """Tests for the module-level ``DEFAULT`` ``RandomState`` instance.

    ``DEFAULT`` is shared global state. Each test therefore saves its key in
    ``setUp`` and puts it back in ``tearDown``, so that seeding or splitting
    here does not leak into other tests that draw from ``DEFAULT``.
    """

    def setUp(self):
        # Snapshot the global key so that tearDown can undo any mutation.
        self._saved_default_key = DEFAULT.value.copy()

    def tearDown(self):
        DEFAULT.set_key(self._saved_default_key)

    def test_default_exists(self):
        """``DEFAULT`` exists and is a ``RandomState``."""
        self.assertIsInstance(DEFAULT, RandomState)

    def test_default_has_valid_key(self):
        """``DEFAULT`` holds a raw ``(2,)`` uint32 key."""
        self.assertIsNotNone(DEFAULT.value)
        self.assertEqual(DEFAULT.value.shape, (2,))
        self.assertEqual(DEFAULT.value.dtype, jnp.uint32)

    def test_default_seeding(self):
        """Seeding ``DEFAULT`` installs the key derived from that seed."""
        DEFAULT.seed(12345)
        # Compare against the expected key rather than "differs from before":
        # the latter check fails whenever DEFAULT already held this exact key.
        np.testing.assert_array_equal(DEFAULT.value, jr.PRNGKey(12345))

    def test_default_split_key(self):
        """Splitting ``DEFAULT`` returns a key and advances the global state."""
        original_key = DEFAULT.value.copy()
        new_key = DEFAULT.split_key()
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))
        self.assertIsNotNone(new_key)
|
433
|
+
|
434
|
+
|
435
|
+
class TestUtilityFunctions(unittest.TestCase):
    """Tests for the key and shape helper functions of the random-state module."""

    def test_formalize_key_with_int(self):
        """``formalize_key`` turns an integer seed into the matching PRNG key."""
        np.testing.assert_array_equal(formalize_key(42), jr.PRNGKey(42))

    def test_formalize_key_with_array(self):
        """An existing PRNG key is returned unchanged."""
        # NOTE(review): the meaning of the positional ``True`` flag is not
        # visible from this file.
        source = jr.PRNGKey(123)
        np.testing.assert_array_equal(formalize_key(source, True), source)

    def test_formalize_key_with_uint32_array(self):
        """A two-element uint32 NumPy array is accepted as a key as-is."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(formalize_key(raw), raw)

    def test_formalize_key_invalid_input(self):
        """Inputs that are not valid keys are rejected with ``TypeError``."""
        invalid_inputs = (
            "invalid",                              # not numeric at all
            np.array([1, 2, 3], dtype=np.uint32),   # wrong size
            np.array([1, 2], dtype=np.int32),       # wrong dtype
        )
        for bad in invalid_inputs:
            with self.assertRaises(TypeError):
                formalize_key(bad)

    def test_size2shape(self):
        """``_size2shape`` normalises None/int/tuple/list sizes to a tuple."""
        cases = [
            (None, ()),
            (5, (5,)),
            ((3, 4), (3, 4)),
            ([2, 3, 4], (2, 3, 4)),
        ]
        for size, expected in cases:
            self.assertEqual(_size2shape(size), expected)

    def test_check_py_seq(self):
        """``_check_py_seq`` converts Python lists and passes other values through."""
        converted = _check_py_seq([1, 2, 3])
        self.assertIsInstance(converted, jnp.ndarray)
        np.testing.assert_array_equal(converted, jnp.array([1, 2, 3]))

        # An array is returned as the very same object.
        already_array = jnp.array([1, 2, 3])
        self.assertIs(_check_py_seq(already_array), already_array)

        # A scalar comes back equal to the input.
        self.assertEqual(_check_py_seq(5), 5)
|
489
|
+
|
490
|
+
|
491
|
+
class TestErrorHandling(unittest.TestCase):
    """Error handling and edge cases of ``RandomState``."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_invalid_distribution_parameters(self):
        """Out-of-range distribution parameters must not break the suite."""
        # Validation may be deferred (e.g. until JAX compiles the computation),
        # so an invalid binomial ``p`` is allowed either to raise or not.
        # Catch ``Exception`` only: a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        try:
            self.rs.binomial(n=10, p=1.5, check_valid=True)
        except Exception:
            pass  # Rejecting p > 1 is an acceptable outcome.

        # A negative scale is accepted (JAX allows this); only the shape is checked.
        result = self.rs.normal(scale=-1.0, size=(2,))
        self.assertEqual(result.shape, (2,))

    def test_invalid_size_parameters(self):
        """Degenerate ``size`` values (``()`` and ``None``) both yield a 0-d result."""
        result = self.rs.random(size=())
        self.assertEqual(result.shape, ())

        result = self.rs.random(size=None)
        self.assertEqual(result.shape, ())

    def test_dtype_consistency(self):
        """Integer samplers return integer dtypes; float samplers return floating dtypes."""
        result = self.rs.randint(10, size=(3,))
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.integer))

        result = self.rs.rand(3)
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))

    def test_self_assign_multi_keys(self):
        """``self_assign_multi_keys(n)`` replaces the key with an ``(n, 2)`` key array."""
        original_shape = self.rs.value.shape

        # With backup=True the single key can be restored afterwards.
        self.rs.self_assign_multi_keys(3, backup=True)
        self.assertEqual(self.rs.value.shape, (3, 2))

        self.rs.restore_key()
        self.assertEqual(self.rs.value.shape, original_shape)

        # With backup=False the keys are still reassigned.
        self.rs.self_assign_multi_keys(2, backup=False)
        self.assertEqual(self.rs.value.shape, (2, 2))
|
548
|
+
|
549
|
+
|
550
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|