brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,551 +1,551 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax.numpy as jnp
19
- import jax.random as jr
20
- import numpy as np
21
-
22
- from brainstate.random._rand_state import RandomState, DEFAULT, _formalize_key, _size2shape, _check_py_seq
23
-
24
-
25
- class TestRandomStateInitialization(unittest.TestCase):
26
- """Test RandomState initialization and setup."""
27
-
28
- def test_init_with_none(self):
29
- """Test initialization with None seed."""
30
- rs = RandomState(None)
31
- self.assertIsNotNone(rs.value)
32
- self.assertEqual(rs.value.shape, (2,))
33
- self.assertEqual(rs.value.dtype, jnp.uint32)
34
-
35
- def test_init_with_int_seed(self):
36
- """Test initialization with integer seed."""
37
- seed = 42
38
- rs = RandomState(seed)
39
- expected_key = jr.PRNGKey(seed)
40
- np.testing.assert_array_equal(rs.value, expected_key)
41
-
42
- def test_init_with_prng_key(self):
43
- """Test initialization with JAX PRNGKey."""
44
- key = jr.PRNGKey(123)
45
- rs = RandomState(key)
46
- np.testing.assert_array_equal(rs.value, key)
47
-
48
- def test_init_with_uint32_array(self):
49
- """Test initialization with uint32 array."""
50
- key_array = np.array([123, 456], dtype=np.uint32)
51
- rs = RandomState(key_array)
52
- np.testing.assert_array_equal(rs.value, key_array)
53
-
54
- def test_init_with_invalid_key(self):
55
- """Test initialization with invalid key raises error."""
56
- # Test case that should raise error: wrong length AND wrong dtype
57
- with self.assertRaises(ValueError):
58
- RandomState(np.array([1, 2, 3], dtype=np.int32)) # len != 2 AND dtype != uint32
59
-
60
- # Test valid cases that should NOT raise errors
61
- # Wrong length but correct dtype is OK
62
- rs1 = RandomState(np.array([1, 2, 3], dtype=np.uint32))
63
- self.assertIsNotNone(rs1.value)
64
-
65
- # Correct length but wrong dtype is OK
66
- rs2 = RandomState(np.array([1, 2], dtype=np.int32))
67
- self.assertIsNotNone(rs2.value)
68
-
69
- def test_repr(self):
70
- """Test string representation."""
71
- rs = RandomState(42)
72
- repr_str = repr(rs)
73
- self.assertIn("RandomState", repr_str)
74
- self.assertIn("42", repr_str)
75
-
76
-
77
- class TestRandomStateKeyManagement(unittest.TestCase):
78
- """Test key management functionality."""
79
-
80
- def setUp(self):
81
- self.rs = RandomState(42)
82
-
83
- def test_seed_with_int(self):
84
- """Test seeding with integer."""
85
- self.rs.seed(123)
86
- expected_key = jr.PRNGKey(123)
87
- np.testing.assert_array_equal(self.rs.value, expected_key)
88
-
89
- def test_seed_with_none(self):
90
- """Test seeding with None generates new random seed."""
91
- original_key = self.rs.value.copy()
92
- self.rs.seed(None)
93
- # Should be different (with very high probability)
94
- self.assertFalse(np.array_equal(self.rs.value, original_key))
95
-
96
- def test_seed_with_prng_key(self):
97
- """Test seeding with PRNGKey."""
98
- key = jr.PRNGKey(999)
99
- self.rs.seed(key)
100
- np.testing.assert_array_equal(self.rs.value, key)
101
-
102
- def test_seed_with_invalid_input(self):
103
- """Test seeding with invalid input raises error."""
104
- with self.assertRaises(ValueError):
105
- self.rs.seed([1, 2, 3]) # Wrong length list
106
-
107
- def test_split_key_single(self):
108
- """Test splitting key to get single new key."""
109
- original_key = self.rs.value.copy()
110
- new_key = self.rs.split_key()
111
-
112
- # Original key should have changed
113
- self.assertFalse(np.array_equal(self.rs.value, original_key))
114
- # New key should be different from both
115
- self.assertFalse(np.array_equal(new_key, original_key))
116
- self.assertFalse(np.array_equal(new_key, self.rs.value))
117
-
118
- def test_split_key_multiple(self):
119
- """Test splitting key to get multiple new keys."""
120
- n = 3
121
- original_key = self.rs.value.copy()
122
- new_keys = self.rs.split_key(n)
123
-
124
- self.assertEqual(len(new_keys), n)
125
- # All keys should be different
126
- for i, key in enumerate(new_keys):
127
- self.assertFalse(np.array_equal(key, original_key))
128
- for j, other_key in enumerate(new_keys):
129
- if i != j:
130
- self.assertFalse(np.array_equal(key, other_key))
131
-
132
- def test_split_key_invalid_n(self):
133
- """Test split_key with invalid n raises error."""
134
- with self.assertRaises(AssertionError):
135
- self.rs.split_key(0)
136
-
137
- with self.assertRaises(AssertionError):
138
- self.rs.split_key(-1)
139
-
140
- def test_backup_restore_key(self):
141
- """Test backup and restore functionality."""
142
- original_key = self.rs.value.copy()
143
-
144
- # Backup the key
145
- self.rs.backup_key()
146
-
147
- # Change the key
148
- self.rs.split_key()
149
- changed_key = self.rs.value.copy()
150
- self.assertFalse(np.array_equal(changed_key, original_key))
151
-
152
- # Restore the key
153
- self.rs.restore_key()
154
- np.testing.assert_array_equal(self.rs.value, original_key)
155
-
156
- def test_backup_already_backed_up(self):
157
- """Test backup when already backed up raises error."""
158
- self.rs.backup_key()
159
- with self.assertRaises(ValueError):
160
- self.rs.backup_key()
161
-
162
- def test_restore_without_backup(self):
163
- """Test restore without backup raises error."""
164
- with self.assertRaises(ValueError):
165
- self.rs.restore_key()
166
-
167
- def test_clone(self):
168
- """Test cloning creates independent copy."""
169
- clone = self.rs.clone()
170
-
171
- # Should be different instances
172
- self.assertIsNot(clone, self.rs)
173
-
174
- # Should have different keys after split
175
- original_key = self.rs.value.copy()
176
- clone_key = clone.value.copy()
177
-
178
- self.rs.split_key()
179
- clone.split_key()
180
-
181
- self.assertFalse(np.array_equal(self.rs.value, clone.value))
182
-
183
- def test_set_key(self):
184
- """Test setting key directly."""
185
- new_key = jr.PRNGKey(999)
186
- self.rs.set_key(new_key)
187
- np.testing.assert_array_equal(self.rs.value, new_key)
188
-
189
-
190
- class TestRandomStateDistributions(unittest.TestCase):
191
- """Test random distribution methods."""
192
-
193
- def setUp(self):
194
- self.rs = RandomState(42)
195
-
196
- def test_rand(self):
197
- """Test rand method."""
198
- # Single value
199
- val = self.rs.rand()
200
- self.assertEqual(val.shape, ())
201
- self.assertTrue(0 <= val < 1)
202
-
203
- # Multiple dimensions
204
- arr = self.rs.rand(3, 2)
205
- self.assertEqual(arr.shape, (3, 2))
206
- self.assertTrue((arr >= 0).all() and (arr < 1).all())
207
-
208
- def test_randint(self):
209
- """Test randint method."""
210
- # Single bound
211
- val = self.rs.randint(10)
212
- self.assertTrue(0 <= val < 10)
213
-
214
- # Both bounds
215
- val = self.rs.randint(5, 15)
216
- self.assertTrue(5 <= val < 15)
217
-
218
- # With size
219
- arr = self.rs.randint(0, 5, size=(2, 3))
220
- self.assertEqual(arr.shape, (2, 3))
221
- self.assertTrue((arr >= 0).all() and (arr < 5).all())
222
-
223
- def test_randn(self):
224
- """Test randn method."""
225
- # Single value
226
- val = self.rs.randn()
227
- self.assertEqual(val.shape, ())
228
-
229
- # Multiple dimensions
230
- arr = self.rs.randn(3, 2)
231
- self.assertEqual(arr.shape, (3, 2))
232
-
233
- def test_normal(self):
234
- """Test normal distribution."""
235
- # Standard normal
236
- val = self.rs.normal()
237
- self.assertEqual(val.shape, ())
238
-
239
- # With parameters
240
- arr = self.rs.normal(loc=5.0, scale=2.0, size=(3, 2))
241
- self.assertEqual(arr.shape, (3, 2))
242
-
243
- def test_uniform(self):
244
- """Test uniform distribution."""
245
- # Standard uniform
246
- val = self.rs.uniform()
247
- self.assertTrue(0.0 <= val < 1.0)
248
-
249
- # With bounds
250
- arr = self.rs.uniform(low=2.0, high=8.0, size=(2, 3))
251
- self.assertEqual(arr.shape, (2, 3))
252
- self.assertTrue((arr >= 2.0).all() and (arr < 8.0).all())
253
-
254
- def test_choice(self):
255
- """Test choice method."""
256
- # Choose from range
257
- val = self.rs.choice(5)
258
- self.assertTrue(0 <= val < 5)
259
-
260
- # Choose from array
261
- options = jnp.array([10, 20, 30, 40])
262
- val = self.rs.choice(options)
263
- self.assertIn(val, options)
264
-
265
- # Multiple choices
266
- arr = self.rs.choice(5, size=10)
267
- self.assertEqual(arr.shape, (10,))
268
- self.assertTrue((arr >= 0).all() and (arr < 5).all())
269
-
270
- def test_beta(self):
271
- """Test beta distribution."""
272
- arr = self.rs.beta(2.0, 3.0, size=(2, 3))
273
- self.assertEqual(arr.shape, (2, 3))
274
- self.assertTrue((arr >= 0).all() and (arr <= 1).all())
275
-
276
- def test_exponential(self):
277
- """Test exponential distribution."""
278
- arr = self.rs.exponential(scale=2.0, size=(2, 3))
279
- self.assertEqual(arr.shape, (2, 3))
280
- self.assertTrue((arr >= 0).all())
281
-
282
- def test_gamma(self):
283
- """Test gamma distribution."""
284
- arr = self.rs.gamma(shape=2.0, scale=1.0, size=(2, 3))
285
- self.assertEqual(arr.shape, (2, 3))
286
- self.assertTrue((arr >= 0).all())
287
-
288
- def test_poisson(self):
289
- """Test Poisson distribution."""
290
- arr = self.rs.poisson(lam=3.0, size=(2, 3))
291
- self.assertEqual(arr.shape, (2, 3))
292
- self.assertTrue((arr >= 0).all())
293
-
294
- def test_binomial(self):
295
- """Test binomial distribution."""
296
- arr = self.rs.binomial(n=10, p=0.3, size=(2, 3))
297
- self.assertEqual(arr.shape, (2, 3))
298
- self.assertTrue((arr >= 0).all() and (arr <= 10).all())
299
-
300
- def test_bernoulli(self):
301
- """Test Bernoulli distribution."""
302
- arr = self.rs.bernoulli(p=0.7, size=(100,))
303
- self.assertEqual(arr.shape, (100,))
304
- self.assertTrue(jnp.all((arr == 0) | (arr == 1)))
305
-
306
- def test_bernoulli_invalid_p(self):
307
- """Test Bernoulli with invalid probability."""
308
- # Note: This should trigger jit_error_if, but in test we check the validation exists
309
- with self.assertRaises((ValueError, Exception)):
310
- self.rs.bernoulli(p=1.5) # p > 1
311
-
312
- def test_truncated_normal(self):
313
- """Test truncated normal distribution."""
314
- arr = self.rs.truncated_normal(lower=-1.0, upper=1.0, size=(2, 3))
315
- self.assertEqual(arr.shape, (2, 3))
316
- self.assertTrue((arr >= -1.0).all() and (arr <= 1.0).all())
317
-
318
- def test_multivariate_normal(self):
319
- """Test multivariate normal distribution."""
320
- mean = jnp.array([0.0, 1.0])
321
- cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
322
-
323
- arr = self.rs.multivariate_normal(mean, cov, size=(3,))
324
- self.assertEqual(arr.shape, (3, 2))
325
-
326
- def test_categorical(self):
327
- """Test categorical distribution."""
328
- logits = jnp.array([0.1, 0.2, 0.3, 0.4])
329
- arr = self.rs.categorical(logits, size=(10,))
330
- self.assertEqual(arr.shape, (10,))
331
- self.assertTrue((arr >= 0).all() and (arr < len(logits)).all())
332
-
333
-
334
- class TestRandomStatePyTorchCompatibility(unittest.TestCase):
335
- """Test PyTorch-like methods."""
336
-
337
- def setUp(self):
338
- self.rs = RandomState(42)
339
-
340
- def test_rand_like(self):
341
- """Test rand_like method."""
342
- input_tensor = jnp.zeros((3, 4))
343
- result = self.rs.rand_like(input_tensor)
344
- self.assertEqual(result.shape, input_tensor.shape)
345
- self.assertTrue((result >= 0).all() and (result < 1).all())
346
-
347
- def test_randn_like(self):
348
- """Test randn_like method."""
349
- input_tensor = jnp.zeros((2, 3))
350
- result = self.rs.randn_like(input_tensor)
351
- self.assertEqual(result.shape, input_tensor.shape)
352
-
353
- def test_randint_like(self):
354
- """Test randint_like method."""
355
- input_tensor = jnp.zeros((2, 3), dtype=jnp.int32)
356
- result = self.rs.randint_like(input_tensor, low=0, high=10)
357
- self.assertEqual(result.shape, input_tensor.shape)
358
- self.assertTrue((result >= 0).all() and (result < 10).all())
359
-
360
-
361
- class TestRandomStateKeyBehavior(unittest.TestCase):
362
- """Test key parameter behavior across methods."""
363
-
364
- def setUp(self):
365
- self.rs = RandomState(42)
366
-
367
- def test_external_key_does_not_change_state(self):
368
- """Test that using external key doesn't change internal state."""
369
- original_key = self.rs.value.copy()
370
- external_key = jr.PRNGKey(999)
371
-
372
- # Use external key
373
- self.rs.rand(5, key=external_key)
374
-
375
- # Internal state should be unchanged
376
- np.testing.assert_array_equal(self.rs.value, original_key)
377
-
378
- def test_no_key_changes_state(self):
379
- """Test that not providing key changes internal state."""
380
- original_key = self.rs.value.copy()
381
-
382
- # Use internal key
383
- self.rs.rand(5)
384
-
385
- # Internal state should have changed
386
- self.assertFalse(np.array_equal(self.rs.value, original_key))
387
-
388
- def test_reproducibility_with_same_key(self):
389
- """Test reproducibility when using same external key."""
390
- key = jr.PRNGKey(123)
391
-
392
- result1 = self.rs.rand(5, key=key)
393
- result2 = self.rs.rand(5, key=key)
394
-
395
- np.testing.assert_array_equal(result1, result2)
396
-
397
- def test_reproducibility_with_seed(self):
398
- """Test reproducibility with seeding."""
399
- self.rs.seed(42)
400
- result1 = self.rs.rand(5)
401
-
402
- self.rs.seed(42)
403
- result2 = self.rs.rand(5)
404
-
405
- np.testing.assert_array_equal(result1, result2)
406
-
407
-
408
- class TestGlobalDefaultInstance(unittest.TestCase):
409
- """Test the global DEFAULT RandomState instance."""
410
-
411
- def test_default_exists(self):
412
- """Test that DEFAULT instance exists and is RandomState."""
413
- self.assertIsInstance(DEFAULT, RandomState)
414
-
415
- def test_default_has_valid_key(self):
416
- """Test that DEFAULT has valid key."""
417
- self.assertIsNotNone(DEFAULT.value)
418
- self.assertEqual(DEFAULT.value.shape, (2,))
419
- self.assertEqual(DEFAULT.value.dtype, jnp.uint32)
420
-
421
- def test_default_seeding(self):
422
- """Test seeding DEFAULT instance."""
423
- original_key = DEFAULT.value.copy()
424
- DEFAULT.seed(12345)
425
- self.assertFalse(np.array_equal(DEFAULT.value, original_key))
426
-
427
- def test_default_split_key(self):
428
- """Test splitting DEFAULT key."""
429
- original_key = DEFAULT.value.copy()
430
- new_key = DEFAULT.split_key()
431
- self.assertFalse(np.array_equal(DEFAULT.value, original_key))
432
- self.assertIsNotNone(new_key)
433
-
434
-
435
- class TestUtilityFunctions(unittest.TestCase):
436
- """Test utility functions in _rand_state module."""
437
-
438
- def test_formalize_key_with_int(self):
439
- """Test _formalize_key with integer."""
440
- key = _formalize_key(42)
441
- expected = jr.PRNGKey(42)
442
- np.testing.assert_array_equal(key, expected)
443
-
444
- def test_formalize_key_with_array(self):
445
- """Test _formalize_key with array."""
446
- input_key = jr.PRNGKey(123)
447
- key = _formalize_key(input_key)
448
- np.testing.assert_array_equal(key, input_key)
449
-
450
- def test_formalize_key_with_uint32_array(self):
451
- """Test _formalize_key with uint32 array."""
452
- input_array = np.array([123, 456], dtype=np.uint32)
453
- key = _formalize_key(input_array)
454
- np.testing.assert_array_equal(key, input_array)
455
-
456
- def test_formalize_key_invalid_input(self):
457
- """Test _formalize_key with invalid input."""
458
- with self.assertRaises(TypeError):
459
- _formalize_key("invalid")
460
-
461
- with self.assertRaises(TypeError):
462
- _formalize_key(np.array([1, 2, 3], dtype=np.uint32)) # Wrong size
463
-
464
- with self.assertRaises(TypeError):
465
- _formalize_key(np.array([1, 2], dtype=np.int32)) # Wrong dtype
466
-
467
- def test_size2shape(self):
468
- """Test _size2shape function."""
469
- self.assertEqual(_size2shape(None), ())
470
- self.assertEqual(_size2shape(5), (5,))
471
- self.assertEqual(_size2shape((3, 4)), (3, 4))
472
- self.assertEqual(_size2shape([2, 3, 4]), (2, 3, 4))
473
-
474
- def test_check_py_seq(self):
475
- """Test _check_py_seq function."""
476
- # Should convert lists/tuples to arrays
477
- result = _check_py_seq([1, 2, 3])
478
- self.assertIsInstance(result, jnp.ndarray)
479
- np.testing.assert_array_equal(result, jnp.array([1, 2, 3]))
480
-
481
- # Should leave other types unchanged
482
- arr = jnp.array([1, 2, 3])
483
- result = _check_py_seq(arr)
484
- self.assertIs(result, arr)
485
-
486
- scalar = 5
487
- result = _check_py_seq(scalar)
488
- self.assertEqual(result, scalar)
489
-
490
-
491
- class TestErrorHandling(unittest.TestCase):
492
- """Test error handling and edge cases."""
493
-
494
- def setUp(self):
495
- self.rs = RandomState(42)
496
-
497
- def test_invalid_distribution_parameters(self):
498
- """Test invalid parameters for distributions."""
499
- # Note: Some distributions may not validate parameters immediately
500
- # so we test what we can verify
501
-
502
- # Test invalid probability for binomial should work with check_valid=True
503
- try:
504
- # This may or may not raise immediately depending on JAX compilation
505
- self.rs.binomial(n=10, p=1.5, check_valid=True)
506
- except:
507
- pass # Expected to fail
508
-
509
- # Test normal distribution works with negative scale (JAX allows this)
510
- result = self.rs.normal(scale=-1.0, size=(2,))
511
- self.assertEqual(result.shape, (2,))
512
-
513
- def test_invalid_size_parameters(self):
514
- """Test invalid size parameters."""
515
- # Test empty shape works for distributions that accept size parameter
516
- result = self.rs.random(size=())
517
- self.assertEqual(result.shape, ())
518
-
519
- # Test with None size
520
- result = self.rs.random(size=None)
521
- self.assertEqual(result.shape, ())
522
-
523
- def test_dtype_consistency(self):
524
- """Test dtype consistency across methods."""
525
- # Integer methods should return integers
526
- result = self.rs.randint(10, size=(3,))
527
- self.assertTrue(jnp.issubdtype(result.dtype, jnp.integer))
528
-
529
- # Float methods should return floats
530
- result = self.rs.rand(3)
531
- self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))
532
-
533
- def test_self_assign_multi_keys(self):
534
- """Test self_assign_multi_keys method."""
535
- original_shape = self.rs.value.shape
536
-
537
- # Test with backup
538
- self.rs.self_assign_multi_keys(3, backup=True)
539
- self.assertEqual(self.rs.value.shape, (3, 2))
540
-
541
- # Restore should work
542
- self.rs.restore_key()
543
- self.assertEqual(self.rs.value.shape, original_shape)
544
-
545
- # Test without backup
546
- self.rs.self_assign_multi_keys(2, backup=False)
547
- self.assertEqual(self.rs.value.shape, (2, 2))
548
-
549
-
550
- if __name__ == '__main__':
551
- unittest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax.numpy as jnp
19
+ import jax.random as jr
20
+ import numpy as np
21
+
22
+ from brainstate.random._state import RandomState, DEFAULT, formalize_key, _size2shape, _check_py_seq
23
+
24
+
25
+ class TestRandomStateInitialization(unittest.TestCase):
26
+ """Test RandomState initialization and setup."""
27
+
28
+ def test_init_with_none(self):
29
+ """Test initialization with None seed."""
30
+ rs = RandomState(None)
31
+ self.assertIsNotNone(rs.value)
32
+ self.assertEqual(rs.value.shape, (2,))
33
+ self.assertEqual(rs.value.dtype, jnp.uint32)
34
+
35
+ def test_init_with_int_seed(self):
36
+ """Test initialization with integer seed."""
37
+ seed = 42
38
+ rs = RandomState(seed)
39
+ expected_key = jr.PRNGKey(seed)
40
+ np.testing.assert_array_equal(rs.value, expected_key)
41
+
42
+ def test_init_with_prng_key(self):
43
+ """Test initialization with JAX PRNGKey."""
44
+ key = jr.PRNGKey(123)
45
+ rs = RandomState(key)
46
+ np.testing.assert_array_equal(rs.value, key)
47
+
48
+ def test_init_with_uint32_array(self):
49
+ """Test initialization with uint32 array."""
50
+ key_array = np.array([123, 456], dtype=np.uint32)
51
+ rs = RandomState(key_array)
52
+ np.testing.assert_array_equal(rs.value, key_array)
53
+
54
+ def test_init_with_invalid_key(self):
55
+ """Test initialization with invalid key raises error."""
56
+ # Test case that should raise error: wrong length AND wrong dtype
57
+ with self.assertRaises(ValueError):
58
+ RandomState(np.array([1, 2, 3], dtype=np.int32)) # len != 2 AND dtype != uint32
59
+
60
+ # Test valid cases that should NOT raise errors
61
+ # Wrong length but correct dtype is OK
62
+ rs1 = RandomState(np.array([1, 2, 3], dtype=np.uint32))
63
+ self.assertIsNotNone(rs1.value)
64
+
65
+ # Correct length but wrong dtype is OK
66
+ rs2 = RandomState(np.array([1, 2], dtype=np.int32))
67
+ self.assertIsNotNone(rs2.value)
68
+
69
+ def test_repr(self):
70
+ """Test string representation."""
71
+ rs = RandomState(42)
72
+ repr_str = repr(rs)
73
+ self.assertIn("RandomState", repr_str)
74
+ self.assertIn("42", repr_str)
75
+
76
+
77
+ class TestRandomStateKeyManagement(unittest.TestCase):
78
+ """Test key management functionality."""
79
+
80
+ def setUp(self):
81
+ self.rs = RandomState(42)
82
+
83
+ def test_seed_with_int(self):
84
+ """Test seeding with integer."""
85
+ self.rs.seed(123)
86
+ expected_key = jr.PRNGKey(123)
87
+ np.testing.assert_array_equal(self.rs.value, expected_key)
88
+
89
+ def test_seed_with_none(self):
90
+ """Test seeding with None generates new random seed."""
91
+ original_key = self.rs.value.copy()
92
+ self.rs.seed(None)
93
+ # Should be different (with very high probability)
94
+ self.assertFalse(np.array_equal(self.rs.value, original_key))
95
+
96
+ def test_seed_with_prng_key(self):
97
+ """Test seeding with PRNGKey."""
98
+ key = jr.PRNGKey(999)
99
+ self.rs.seed(key)
100
+ np.testing.assert_array_equal(self.rs.value, key)
101
+
102
+ def test_seed_with_invalid_input(self):
103
+ """Test seeding with invalid input raises error."""
104
+ with self.assertRaises(ValueError):
105
+ self.rs.seed([1, 2, 3]) # Wrong length list
106
+
107
+ def test_split_key_single(self):
108
+ """Test splitting key to get single new key."""
109
+ original_key = self.rs.value.copy()
110
+ new_key = self.rs.split_key()
111
+
112
+ # Original key should have changed
113
+ self.assertFalse(np.array_equal(self.rs.value, original_key))
114
+ # New key should be different from both
115
+ self.assertFalse(np.array_equal(new_key, original_key))
116
+ self.assertFalse(np.array_equal(new_key, self.rs.value))
117
+
118
+ def test_split_key_multiple(self):
119
+ """Test splitting key to get multiple new keys."""
120
+ n = 3
121
+ original_key = self.rs.value.copy()
122
+ new_keys = self.rs.split_key(n)
123
+
124
+ self.assertEqual(len(new_keys), n)
125
+ # All keys should be different
126
+ for i, key in enumerate(new_keys):
127
+ self.assertFalse(np.array_equal(key, original_key))
128
+ for j, other_key in enumerate(new_keys):
129
+ if i != j:
130
+ self.assertFalse(np.array_equal(key, other_key))
131
+
132
+ def test_split_key_invalid_n(self):
133
+ """Test split_key with invalid n raises error."""
134
+ with self.assertRaises(AssertionError):
135
+ self.rs.split_key(0)
136
+
137
+ with self.assertRaises(AssertionError):
138
+ self.rs.split_key(-1)
139
+
140
+ def test_backup_restore_key(self):
141
+ """Test backup and restore functionality."""
142
+ original_key = self.rs.value.copy()
143
+
144
+ # Backup the key
145
+ self.rs.backup_key()
146
+
147
+ # Change the key
148
+ self.rs.split_key()
149
+ changed_key = self.rs.value.copy()
150
+ self.assertFalse(np.array_equal(changed_key, original_key))
151
+
152
+ # Restore the key
153
+ self.rs.restore_key()
154
+ np.testing.assert_array_equal(self.rs.value, original_key)
155
+
156
+ def test_backup_already_backed_up(self):
157
+ """Test backup when already backed up raises error."""
158
+ self.rs.backup_key()
159
+ with self.assertRaises(ValueError):
160
+ self.rs.backup_key()
161
+
162
+ def test_restore_without_backup(self):
163
+ """Test restore without backup raises error."""
164
+ with self.assertRaises(ValueError):
165
+ self.rs.restore_key()
166
+
167
+ def test_clone(self):
168
+ """Test cloning creates independent copy."""
169
+ clone = self.rs.clone()
170
+
171
+ # Should be different instances
172
+ self.assertIsNot(clone, self.rs)
173
+
174
+ # Should have different keys after split
175
+ original_key = self.rs.value.copy()
176
+ clone_key = clone.value.copy()
177
+
178
+ self.rs.split_key()
179
+ clone.split_key()
180
+
181
+ self.assertFalse(np.array_equal(self.rs.value, clone.value))
182
+
183
def test_set_key(self):
    """set_key installs the given key verbatim as the internal state."""
    replacement = jr.PRNGKey(999)
    self.rs.set_key(replacement)
    np.testing.assert_array_equal(self.rs.value, replacement)
+
189
+
190
class TestRandomStateDistributions(unittest.TestCase):
    """Test random distribution methods.

    Each test draws from one sampler of a freshly seeded ``RandomState`` and
    checks the output shape and, where the distribution has a bounded
    support, that every sample lies inside it.
    """

    def setUp(self):
        self.rs = RandomState(42)

    def test_rand(self):
        """Test rand method: scalar and shaped draws lie in [0, 1)."""
        # Single value
        val = self.rs.rand()
        self.assertEqual(val.shape, ())
        self.assertTrue(0 <= val < 1)

        # Multiple dimensions
        arr = self.rs.rand(3, 2)
        self.assertEqual(arr.shape, (3, 2))
        self.assertTrue((arr >= 0).all() and (arr < 1).all())

    def test_randint(self):
        """Test randint method with one bound, two bounds and an explicit size."""
        # Single bound: interpreted as [0, high)
        val = self.rs.randint(10)
        self.assertTrue(0 <= val < 10)

        # Both bounds: [low, high)
        val = self.rs.randint(5, 15)
        self.assertTrue(5 <= val < 15)

        # With size
        arr = self.rs.randint(0, 5, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_randn(self):
        """Test randn method shapes."""
        # Single value
        val = self.rs.randn()
        self.assertEqual(val.shape, ())

        # Multiple dimensions
        arr = self.rs.randn(3, 2)
        self.assertEqual(arr.shape, (3, 2))

    def test_normal(self):
        """Test normal distribution shapes with default and explicit parameters."""
        # Standard normal
        val = self.rs.normal()
        self.assertEqual(val.shape, ())

        # With parameters
        arr = self.rs.normal(loc=5.0, scale=2.0, size=(3, 2))
        self.assertEqual(arr.shape, (3, 2))

    def test_uniform(self):
        """Test uniform distribution respects its [low, high) bounds."""
        # Standard uniform
        val = self.rs.uniform()
        self.assertTrue(0.0 <= val < 1.0)

        # With bounds
        arr = self.rs.uniform(low=2.0, high=8.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 2.0).all() and (arr < 8.0).all())

    def test_choice(self):
        """Test choice from a range, from an array, and with a size."""
        # Choose from range
        val = self.rs.choice(5)
        self.assertTrue(0 <= val < 5)

        # Choose from array
        options = jnp.array([10, 20, 30, 40])
        val = self.rs.choice(options)
        self.assertIn(val, options)

        # Multiple choices
        arr = self.rs.choice(5, size=10)
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_beta(self):
        """Test beta distribution samples lie in [0, 1]."""
        arr = self.rs.beta(2.0, 3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 1).all())

    def test_exponential(self):
        """Test exponential distribution samples are non-negative."""
        arr = self.rs.exponential(scale=2.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_gamma(self):
        """Test gamma distribution samples are non-negative."""
        arr = self.rs.gamma(shape=2.0, scale=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_poisson(self):
        """Test Poisson distribution samples are non-negative."""
        arr = self.rs.poisson(lam=3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_binomial(self):
        """Test binomial distribution samples lie in [0, n]."""
        arr = self.rs.binomial(n=10, p=0.3, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 10).all())

    def test_bernoulli(self):
        """Test Bernoulli distribution samples are only 0 or 1."""
        arr = self.rs.bernoulli(p=0.7, size=(100,))
        self.assertEqual(arr.shape, (100,))
        self.assertTrue(jnp.all((arr == 0) | (arr == 1)))

    def test_bernoulli_invalid_p(self):
        """Test Bernoulli with invalid probability (p > 1) raises."""
        # Per the original note, p > 1 should trip the jit_error_if validation
        # inside bernoulli. The concrete error type it surfaces is not pinned
        # down here, so any Exception is accepted. ValueError need not be
        # listed separately because it is already an Exception subclass.
        with self.assertRaises(Exception):
            self.rs.bernoulli(p=1.5)  # p > 1

    def test_truncated_normal(self):
        """Test truncated normal samples stay within [lower, upper]."""
        arr = self.rs.truncated_normal(lower=-1.0, upper=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= -1.0).all() and (arr <= 1.0).all())

    def test_multivariate_normal(self):
        """Test multivariate normal output is (size..., event_dim)."""
        mean = jnp.array([0.0, 1.0])
        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])

        arr = self.rs.multivariate_normal(mean, cov, size=(3,))
        self.assertEqual(arr.shape, (3, 2))

    def test_categorical(self):
        """Test categorical samples are valid indices into the logits."""
        logits = jnp.array([0.1, 0.2, 0.3, 0.4])
        arr = self.rs.categorical(logits, size=(10,))
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < len(logits)).all())
+
333
+
334
class TestRandomStatePyTorchCompatibility(unittest.TestCase):
    """Exercise the PyTorch-flavoured ``*_like`` samplers of RandomState."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_rand_like(self):
        """rand_like copies the template's shape and draws from [0, 1)."""
        template = jnp.zeros((3, 4))
        drawn = self.rs.rand_like(template)
        self.assertEqual(drawn.shape, template.shape)
        lower_ok = (drawn >= 0).all()
        upper_ok = (drawn < 1).all()
        self.assertTrue(lower_ok and upper_ok)

    def test_randn_like(self):
        """randn_like copies the template's shape."""
        template = jnp.zeros((2, 3))
        self.assertEqual(self.rs.randn_like(template).shape, template.shape)

    def test_randint_like(self):
        """randint_like copies the template's shape and respects [low, high)."""
        template = jnp.zeros((2, 3), dtype=jnp.int32)
        drawn = self.rs.randint_like(template, low=0, high=10)
        self.assertEqual(drawn.shape, template.shape)
        lower_ok = (drawn >= 0).all()
        upper_ok = (drawn < 10).all()
        self.assertTrue(lower_ok and upper_ok)
+
360
+
361
class TestRandomStateKeyBehavior(unittest.TestCase):
    """How the optional ``key`` argument interacts with the stored key."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_external_key_does_not_change_state(self):
        """Sampling with an explicit key leaves the stored key untouched."""
        before = self.rs.value.copy()
        self.rs.rand(5, key=jr.PRNGKey(999))
        np.testing.assert_array_equal(self.rs.value, before)

    def test_no_key_changes_state(self):
        """Sampling without a key advances the stored key."""
        before = self.rs.value.copy()
        self.rs.rand(5)
        moved = not np.array_equal(self.rs.value, before)
        self.assertTrue(moved)

    def test_reproducibility_with_same_key(self):
        """The same explicit key yields the same samples on every call."""
        shared = jr.PRNGKey(123)
        first, second = [self.rs.rand(5, key=shared) for _ in range(2)]
        np.testing.assert_array_equal(first, second)

    def test_reproducibility_with_seed(self):
        """Re-seeding with the same value replays the same stream."""
        draws = []
        for _ in range(2):
            self.rs.seed(42)
            draws.append(self.rs.rand(5))
        np.testing.assert_array_equal(draws[0], draws[1])
+
407
+
408
class TestGlobalDefaultInstance(unittest.TestCase):
    """Test the global DEFAULT RandomState instance.

    DEFAULT is a single module-level instance, and several tests below seed
    or split it. Each test snapshots its key in ``setUp`` and reinstalls it in
    ``tearDown``, so these mutations do not leak into other tests that read
    the global instance.
    """

    def setUp(self):
        self._saved_key = DEFAULT.value.copy()

    def tearDown(self):
        DEFAULT.set_key(self._saved_key)

    def test_default_exists(self):
        """Test that DEFAULT instance exists and is RandomState."""
        self.assertIsInstance(DEFAULT, RandomState)

    def test_default_has_valid_key(self):
        """Test that DEFAULT has a valid raw (2,) uint32 key."""
        self.assertIsNotNone(DEFAULT.value)
        self.assertEqual(DEFAULT.value.shape, (2,))
        self.assertEqual(DEFAULT.value.dtype, jnp.uint32)

    def test_default_seeding(self):
        """Test seeding DEFAULT instance changes its key deterministically."""
        # Start from a known key. Otherwise the check would fail whenever an
        # earlier test happened to leave DEFAULT seeded with 12345.
        DEFAULT.seed(0)
        original_key = DEFAULT.value.copy()
        DEFAULT.seed(12345)
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))

        # Re-seeding with the same value reproduces the same key.
        seeded_key = DEFAULT.value.copy()
        DEFAULT.seed(12345)
        np.testing.assert_array_equal(DEFAULT.value, seeded_key)

    def test_default_split_key(self):
        """Test splitting DEFAULT key advances it and returns a new key."""
        original_key = DEFAULT.value.copy()
        new_key = DEFAULT.split_key()
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))
        self.assertIsNotNone(new_key)
+
434
+
435
class TestUtilityFunctions(unittest.TestCase):
    """Unit tests for the key and shape helper functions of the random-state module."""

    def test_formalize_key_with_int(self):
        """formalize_key expands an integer seed exactly like jr.PRNGKey."""
        np.testing.assert_array_equal(formalize_key(42), jr.PRNGKey(42))

    def test_formalize_key_with_array(self):
        """formalize_key passes a ready-made PRNG key through unchanged."""
        source = jr.PRNGKey(123)
        np.testing.assert_array_equal(formalize_key(source, True), source)

    def test_formalize_key_with_uint32_array(self):
        """formalize_key accepts a raw two-element uint32 array as a key."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(formalize_key(raw), raw)

    def test_formalize_key_invalid_input(self):
        """formalize_key rejects non-keys, wrong sizes and wrong dtypes with TypeError."""
        rejected = (
            "invalid",
            np.array([1, 2, 3], dtype=np.uint32),  # Wrong size
            np.array([1, 2], dtype=np.int32),  # Wrong dtype
        )
        for candidate in rejected:
            with self.assertRaises(TypeError):
                formalize_key(candidate)

    def test_size2shape(self):
        """_size2shape normalises None / int / tuple / list sizes to a tuple."""
        cases = (
            (None, ()),
            (5, (5,)),
            ((3, 4), (3, 4)),
            ([2, 3, 4], (2, 3, 4)),
        )
        for size, shape in cases:
            self.assertEqual(_size2shape(size), shape)

    def test_check_py_seq(self):
        """_check_py_seq turns a Python list into an array and leaves other inputs alone."""
        converted = _check_py_seq([1, 2, 3])
        self.assertIsInstance(converted, jnp.ndarray)
        np.testing.assert_array_equal(converted, jnp.array([1, 2, 3]))

        # An existing array comes back as the very same object.
        already_array = jnp.array([1, 2, 3])
        self.assertIs(_check_py_seq(already_array), already_array)

        # A plain scalar comes back unchanged.
        self.assertEqual(_check_py_seq(5), 5)
+
490
+
491
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_invalid_distribution_parameters(self):
        """Test invalid parameters for distributions."""
        # Note: some distributions may not validate parameters eagerly, so
        # only what can be verified reliably is asserted here.

        # binomial with p > 1 and check_valid=True may or may not raise
        # immediately, depending on JAX compilation. Either outcome is
        # acceptable. Only ordinary exceptions are tolerated: a bare
        # ``except:`` would also swallow KeyboardInterrupt and SystemExit.
        try:
            self.rs.binomial(n=10, p=1.5, check_valid=True)
        except Exception:
            pass

        # Test normal distribution works with negative scale (JAX allows this)
        result = self.rs.normal(scale=-1.0, size=(2,))
        self.assertEqual(result.shape, (2,))

    def test_invalid_size_parameters(self):
        """Test empty and None size parameters both yield a scalar."""
        # Test empty shape works for distributions that accept size parameter
        result = self.rs.random(size=())
        self.assertEqual(result.shape, ())

        # Test with None size
        result = self.rs.random(size=None)
        self.assertEqual(result.shape, ())

    def test_dtype_consistency(self):
        """Test dtype consistency across methods."""
        # Integer methods should return integers
        result = self.rs.randint(10, size=(3,))
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.integer))

        # Float methods should return floats
        result = self.rs.rand(3)
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))

    def test_self_assign_multi_keys(self):
        """Test self_assign_multi_keys method with and without backup."""
        original_shape = self.rs.value.shape

        # Test with backup: the state becomes a stack of 3 keys.
        self.rs.self_assign_multi_keys(3, backup=True)
        self.assertEqual(self.rs.value.shape, (3, 2))

        # Restore should bring back the single original key.
        self.rs.restore_key()
        self.assertEqual(self.rs.value.shape, original_shape)

        # Test without backup
        self.rs.self_assign_multi_keys(2, backup=False)
        self.assertEqual(self.rs.value.shape, (2, 2))
+
549
+
550
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()