brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,551 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax.numpy as jnp
19
+ import jax.random as jr
20
+ import numpy as np
21
+
22
+ from brainstate.random._rand_state import RandomState, DEFAULT, _formalize_key, _size2shape, _check_py_seq
23
+
24
+
25
class TestRandomStateInitialization(unittest.TestCase):
    """Checks how ``RandomState`` builds its key from the supported seed forms."""

    def test_init_with_none(self):
        """A ``None`` seed still produces a two-element uint32 key."""
        state = RandomState(None)
        key = state.value
        self.assertIsNotNone(key)
        self.assertEqual(key.shape, (2,))
        self.assertEqual(key.dtype, jnp.uint32)

    def test_init_with_int_seed(self):
        """An integer seed yields the same key as ``jax.random.PRNGKey(seed)``."""
        np.testing.assert_array_equal(RandomState(42).value, jr.PRNGKey(42))

    def test_init_with_prng_key(self):
        """A ready-made JAX PRNG key is stored as-is."""
        prng_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(RandomState(prng_key).value, prng_key)

    def test_init_with_uint32_array(self):
        """A raw length-2 uint32 array is stored as-is."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(RandomState(raw).value, raw)

    def test_init_with_invalid_key(self):
        """Only an array that is both the wrong length and the wrong dtype is rejected."""
        # NOTE(review): these assertions pin the constructor's current check,
        # which raises only when len != 2 AND dtype != uint32. The
        # ``_formalize_key`` tests in TestUtilityFunctions expect either defect
        # alone to be rejected. Confirm whether the lenient constructor check
        # is intentional.
        with self.assertRaises(ValueError):
            RandomState(np.array([1, 2, 3], dtype=np.int32))

        # Wrong length, but uint32 dtype: currently accepted.
        wrong_length = RandomState(np.array([1, 2, 3], dtype=np.uint32))
        self.assertIsNotNone(wrong_length.value)

        # Length two, but int32 dtype: currently accepted.
        wrong_dtype = RandomState(np.array([1, 2], dtype=np.int32))
        self.assertIsNotNone(wrong_dtype.value)

    def test_repr(self):
        """The repr names the class and mentions the seed value."""
        text = repr(RandomState(42))
        for fragment in ("RandomState", "42"):
            self.assertIn(fragment, text)
75
+
76
+
77
class TestRandomStateKeyManagement(unittest.TestCase):
    """Test key management: seeding, splitting, backup/restore, cloning and set_key."""

    def setUp(self):
        # A fresh, deterministically seeded state for every test.
        self.rs = RandomState(42)

    def test_seed_with_int(self):
        """Seeding with an integer matches ``jax.random.PRNGKey``."""
        self.rs.seed(123)
        expected_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(self.rs.value, expected_key)

    def test_seed_with_none(self):
        """Seeding with None replaces the key with a freshly drawn one."""
        original_key = self.rs.value.copy()
        self.rs.seed(None)
        # Different with overwhelming probability.
        self.assertFalse(np.array_equal(self.rs.value, original_key))

    def test_seed_with_prng_key(self):
        """Seeding with a PRNG key stores that key."""
        key = jr.PRNGKey(999)
        self.rs.seed(key)
        np.testing.assert_array_equal(self.rs.value, key)

    def test_seed_with_invalid_input(self):
        """Seeding with a wrong-length list raises ValueError."""
        with self.assertRaises(ValueError):
            self.rs.seed([1, 2, 3])

    def test_split_key_single(self):
        """split_key() advances the internal key and returns a distinct new key."""
        original_key = self.rs.value.copy()
        new_key = self.rs.split_key()

        # The internal key must have advanced.
        self.assertFalse(np.array_equal(self.rs.value, original_key))
        # The returned key differs from both the old and the new internal key.
        self.assertFalse(np.array_equal(new_key, original_key))
        self.assertFalse(np.array_equal(new_key, self.rs.value))

    def test_split_key_multiple(self):
        """split_key(n) returns n pairwise-distinct keys, none equal to the old key."""
        n = 3
        original_key = self.rs.value.copy()
        new_keys = self.rs.split_key(n)

        self.assertEqual(len(new_keys), n)
        for i, key in enumerate(new_keys):
            self.assertFalse(np.array_equal(key, original_key))
            # Compare each unordered pair exactly once.
            for other_key in new_keys[i + 1:]:
                self.assertFalse(np.array_equal(key, other_key))

    def test_split_key_invalid_n(self):
        """Non-positive n is rejected with an AssertionError."""
        with self.assertRaises(AssertionError):
            self.rs.split_key(0)

        with self.assertRaises(AssertionError):
            self.rs.split_key(-1)

    def test_backup_restore_key(self):
        """restore_key() brings back the key saved by backup_key()."""
        original_key = self.rs.value.copy()

        self.rs.backup_key()

        # Advance the key so the restore is observable.
        self.rs.split_key()
        changed_key = self.rs.value.copy()
        self.assertFalse(np.array_equal(changed_key, original_key))

        self.rs.restore_key()
        np.testing.assert_array_equal(self.rs.value, original_key)

    def test_backup_already_backed_up(self):
        """A second backup_key() without an intervening restore raises ValueError."""
        self.rs.backup_key()
        with self.assertRaises(ValueError):
            self.rs.backup_key()

    def test_restore_without_backup(self):
        """restore_key() without a prior backup raises ValueError."""
        with self.assertRaises(ValueError):
            self.rs.restore_key()

    def test_clone(self):
        """clone() returns a separate instance whose random stream is independent."""
        clone = self.rs.clone()

        self.assertIsNot(clone, self.rs)

        # If the clone merely duplicated the key, one split on each side would
        # leave them equal. Differing keys show the streams are independent.
        self.rs.split_key()
        clone.split_key()

        self.assertFalse(np.array_equal(self.rs.value, clone.value))

    def test_set_key(self):
        """set_key() replaces the internal key directly."""
        new_key = jr.PRNGKey(999)
        self.rs.set_key(new_key)
        np.testing.assert_array_equal(self.rs.value, new_key)
188
+
189
+
190
class TestRandomStateDistributions(unittest.TestCase):
    """Test the output shapes and value ranges of the sampling methods."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_rand(self):
        """rand() draws from [0, 1) with the requested dimensions."""
        # Single value
        val = self.rs.rand()
        self.assertEqual(val.shape, ())
        self.assertTrue(0 <= val < 1)

        # Multiple dimensions
        arr = self.rs.rand(3, 2)
        self.assertEqual(arr.shape, (3, 2))
        self.assertTrue((arr >= 0).all() and (arr < 1).all())

    def test_randint(self):
        """randint() respects a single upper bound, a [low, high) pair, and size."""
        # Single bound
        val = self.rs.randint(10)
        self.assertTrue(0 <= val < 10)

        # Both bounds
        val = self.rs.randint(5, 15)
        self.assertTrue(5 <= val < 15)

        # With size
        arr = self.rs.randint(0, 5, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_randn(self):
        """randn() returns the requested shape."""
        # Single value
        val = self.rs.randn()
        self.assertEqual(val.shape, ())

        # Multiple dimensions
        arr = self.rs.randn(3, 2)
        self.assertEqual(arr.shape, (3, 2))

    def test_normal(self):
        """normal() returns a scalar by default and honours size."""
        # Standard normal
        val = self.rs.normal()
        self.assertEqual(val.shape, ())

        # With parameters
        arr = self.rs.normal(loc=5.0, scale=2.0, size=(3, 2))
        self.assertEqual(arr.shape, (3, 2))

    def test_uniform(self):
        """uniform() defaults to [0, 1) and honours low/high/size."""
        # Standard uniform
        val = self.rs.uniform()
        self.assertTrue(0.0 <= val < 1.0)

        # With bounds
        arr = self.rs.uniform(low=2.0, high=8.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 2.0).all() and (arr < 8.0).all())

    def test_choice(self):
        """choice() samples from range(n) or from an explicit array."""
        # Choose from range
        val = self.rs.choice(5)
        self.assertTrue(0 <= val < 5)

        # Choose from array
        options = jnp.array([10, 20, 30, 40])
        val = self.rs.choice(options)
        self.assertIn(val, options)

        # Multiple choices
        arr = self.rs.choice(5, size=10)
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < 5).all())

    def test_beta(self):
        """beta() samples lie in [0, 1]."""
        arr = self.rs.beta(2.0, 3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 1).all())

    def test_exponential(self):
        """exponential() samples are non-negative."""
        arr = self.rs.exponential(scale=2.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_gamma(self):
        """gamma() samples are non-negative."""
        arr = self.rs.gamma(shape=2.0, scale=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_poisson(self):
        """poisson() samples are non-negative."""
        arr = self.rs.poisson(lam=3.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all())

    def test_binomial(self):
        """binomial() samples lie in [0, n]."""
        arr = self.rs.binomial(n=10, p=0.3, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= 0).all() and (arr <= 10).all())

    def test_bernoulli(self):
        """bernoulli() samples are 0 or 1."""
        arr = self.rs.bernoulli(p=0.7, size=(100,))
        self.assertEqual(arr.shape, (100,))
        self.assertTrue(jnp.all((arr == 0) | (arr == 1)))

    def test_bernoulli_invalid_p(self):
        """bernoulli() with p outside [0, 1] raises."""
        # The concrete exception type depends on how the range check is
        # surfaced (reportedly via jit_error_if; confirm), so any Exception
        # is accepted. ValueError needs no separate listing: it is already an
        # Exception subclass.
        with self.assertRaises(Exception):
            self.rs.bernoulli(p=1.5)  # p > 1

    def test_truncated_normal(self):
        """truncated_normal() samples lie within [lower, upper]."""
        arr = self.rs.truncated_normal(lower=-1.0, upper=1.0, size=(2, 3))
        self.assertEqual(arr.shape, (2, 3))
        self.assertTrue((arr >= -1.0).all() and (arr <= 1.0).all())

    def test_multivariate_normal(self):
        """multivariate_normal() appends the event dimension to size."""
        mean = jnp.array([0.0, 1.0])
        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])

        arr = self.rs.multivariate_normal(mean, cov, size=(3,))
        self.assertEqual(arr.shape, (3, 2))

    def test_categorical(self):
        """categorical() returns valid category indices."""
        logits = jnp.array([0.1, 0.2, 0.3, 0.4])
        arr = self.rs.categorical(logits, size=(10,))
        self.assertEqual(arr.shape, (10,))
        self.assertTrue((arr >= 0).all() and (arr < len(logits)).all())
332
+
333
+
334
class TestRandomStatePyTorchCompatibility(unittest.TestCase):
    """The ``*_like`` helpers produce samples shaped like their input tensor."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_rand_like(self):
        """rand_like() mirrors the input shape and draws from [0, 1)."""
        template = jnp.zeros((3, 4))
        sample = self.rs.rand_like(template)
        self.assertEqual(sample.shape, template.shape)
        in_range = (sample >= 0).all() and (sample < 1).all()
        self.assertTrue(in_range)

    def test_randn_like(self):
        """randn_like() mirrors the input shape."""
        template = jnp.zeros((2, 3))
        sample = self.rs.randn_like(template)
        self.assertEqual(sample.shape, template.shape)

    def test_randint_like(self):
        """randint_like() mirrors the input shape and respects [low, high)."""
        template = jnp.zeros((2, 3), dtype=jnp.int32)
        sample = self.rs.randint_like(template, low=0, high=10)
        self.assertEqual(sample.shape, template.shape)
        in_range = (sample >= 0).all() and (sample < 10).all()
        self.assertTrue(in_range)
359
+
360
+
361
class TestRandomStateKeyBehavior(unittest.TestCase):
    """How the optional ``key=`` argument interacts with the internal key."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_external_key_does_not_change_state(self):
        """Sampling with an explicit key leaves the internal key untouched."""
        before = self.rs.value.copy()
        self.rs.rand(5, key=jr.PRNGKey(999))
        np.testing.assert_array_equal(self.rs.value, before)

    def test_no_key_changes_state(self):
        """Sampling without a key consumes (advances) the internal key."""
        before = self.rs.value.copy()
        self.rs.rand(5)
        self.assertFalse(np.array_equal(self.rs.value, before))

    def test_reproducibility_with_same_key(self):
        """The same explicit key yields the same draw."""
        key = jr.PRNGKey(123)
        first, second = [self.rs.rand(5, key=key) for _ in range(2)]
        np.testing.assert_array_equal(first, second)

    def test_reproducibility_with_seed(self):
        """Reseeding with the same value reproduces the next draw."""
        draws = []
        for _ in range(2):
            self.rs.seed(42)
            draws.append(self.rs.rand(5))
        np.testing.assert_array_equal(*draws)
406
+
407
+
408
class TestGlobalDefaultInstance(unittest.TestCase):
    """Test the global DEFAULT RandomState instance.

    ``DEFAULT`` is module-level state shared by the whole process. Tests that
    reseed or split it register a cleanup restoring the key they found, so
    their mutations do not leak into other tests running in the same process.
    """

    def _preserve_default_key(self):
        """Snapshot DEFAULT's key and schedule its restoration after the test.

        Returns:
            A copy of the key as it was before the test mutated it.
        """
        saved_key = DEFAULT.value.copy()
        self.addCleanup(DEFAULT.set_key, saved_key)
        return saved_key

    def test_default_exists(self):
        """DEFAULT is a RandomState instance."""
        self.assertIsInstance(DEFAULT, RandomState)

    def test_default_has_valid_key(self):
        """DEFAULT holds a two-element uint32 key."""
        self.assertIsNotNone(DEFAULT.value)
        self.assertEqual(DEFAULT.value.shape, (2,))
        self.assertEqual(DEFAULT.value.dtype, jnp.uint32)

    def test_default_seeding(self):
        """Seeding DEFAULT replaces its key."""
        original_key = self._preserve_default_key()
        DEFAULT.seed(12345)
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))

    def test_default_split_key(self):
        """Splitting DEFAULT advances its key and returns a new key."""
        original_key = self._preserve_default_key()
        new_key = DEFAULT.split_key()
        self.assertFalse(np.array_equal(DEFAULT.value, original_key))
        self.assertIsNotNone(new_key)
433
+
434
+
435
class TestUtilityFunctions(unittest.TestCase):
    """Covers the private helpers ``_formalize_key``, ``_size2shape`` and ``_check_py_seq``."""

    def test_formalize_key_with_int(self):
        """An int is turned into ``jax.random.PRNGKey(int)``."""
        np.testing.assert_array_equal(_formalize_key(42), jr.PRNGKey(42))

    def test_formalize_key_with_array(self):
        """A JAX PRNG key passes through unchanged."""
        prng_key = jr.PRNGKey(123)
        np.testing.assert_array_equal(_formalize_key(prng_key), prng_key)

    def test_formalize_key_with_uint32_array(self):
        """A length-2 uint32 NumPy array passes through unchanged."""
        raw = np.array([123, 456], dtype=np.uint32)
        np.testing.assert_array_equal(_formalize_key(raw), raw)

    def test_formalize_key_invalid_input(self):
        """Non-key inputs, wrong sizes and wrong dtypes all raise TypeError."""
        rejected = (
            "invalid",
            np.array([1, 2, 3], dtype=np.uint32),  # wrong size
            np.array([1, 2], dtype=np.int32),  # wrong dtype
        )
        for candidate in rejected:
            with self.assertRaises(TypeError):
                _formalize_key(candidate)

    def test_size2shape(self):
        """None, ints, tuples and lists are normalised to shape tuples."""
        expectations = [
            (None, ()),
            (5, (5,)),
            ((3, 4), (3, 4)),
            ([2, 3, 4], (2, 3, 4)),
        ]
        for size, shape in expectations:
            self.assertEqual(_size2shape(size), shape)

    def test_check_py_seq(self):
        """Python lists become JAX arrays; arrays and scalars pass through."""
        converted = _check_py_seq([1, 2, 3])
        self.assertIsInstance(converted, jnp.ndarray)
        np.testing.assert_array_equal(converted, jnp.array([1, 2, 3]))

        already_array = jnp.array([1, 2, 3])
        self.assertIs(_check_py_seq(already_array), already_array)

        number = 5
        self.assertEqual(_check_py_seq(number), number)
489
+
490
+
491
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases."""

    def setUp(self):
        self.rs = RandomState(42)

    def test_invalid_distribution_parameters(self):
        """Out-of-range parameters are tolerated or reported without crashing the run."""
        # binomial with p > 1 and check_valid=True: whether the error surfaces
        # immediately may depend on JAX compilation, so raising is tolerated
        # but not required. Catch Exception (not a bare except) so that
        # KeyboardInterrupt and SystemExit still propagate.
        try:
            self.rs.binomial(n=10, p=1.5, check_valid=True)
        except Exception:
            pass

        # A negative scale for normal is accepted and still yields the requested shape.
        result = self.rs.normal(scale=-1.0, size=(2,))
        self.assertEqual(result.shape, (2,))

    def test_invalid_size_parameters(self):
        """Degenerate size arguments (an empty tuple or None) yield a scalar."""
        result = self.rs.random(size=())
        self.assertEqual(result.shape, ())

        result = self.rs.random(size=None)
        self.assertEqual(result.shape, ())

    def test_dtype_consistency(self):
        """Integer samplers return integer dtypes; float samplers return float dtypes."""
        result = self.rs.randint(10, size=(3,))
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.integer))

        result = self.rs.rand(3)
        self.assertTrue(jnp.issubdtype(result.dtype, jnp.floating))

    def test_self_assign_multi_keys(self):
        """self_assign_multi_keys(n) stacks n keys; with backup=True it can be undone."""
        original_shape = self.rs.value.shape

        # With backup: the key becomes (n, 2) and restore_key() undoes it.
        self.rs.self_assign_multi_keys(3, backup=True)
        self.assertEqual(self.rs.value.shape, (3, 2))

        self.rs.restore_key()
        self.assertEqual(self.rs.value.shape, original_shape)

        # Without backup: only the new shape is checked.
        self.rs.self_assign_multi_keys(2, backup=False)
        self.assertEqual(self.rs.value.shape, (2, 2))
548
+
549
+
550
if __name__ == "__main__":
    # Allow the test module to be executed directly as a script.
    unittest.main()
@@ -0,0 +1,59 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from ._ad_checkpoint import *
18
+ from ._ad_checkpoint import __all__ as _ad_checkpoint_all
19
+ from ._autograd import *
20
+ from ._autograd import __all__ as _autograd_all
21
+ from ._conditions import *
22
+ from ._conditions import __all__ as _conditions_all
23
+ from ._error_if import *
24
+ from ._error_if import __all__ as _error_if_all
25
+ from ._eval_shape import *
26
+ from ._eval_shape import __all__ as _eval_shape_all
27
+ from ._jit import *
28
+ from ._jit import __all__ as _jit_all
29
+ from ._loop_collect_return import *
30
+ from ._loop_collect_return import __all__ as _loop_collect_return_all
31
+ from ._loop_no_collection import *
32
+ from ._loop_no_collection import __all__ as _loop_no_collection_all
33
+ from ._make_jaxpr import *
34
+ from ._make_jaxpr import __all__ as _make_jaxpr_all
35
+ from ._mapping import *
36
+ from ._mapping import __all__ as _mapping_all
37
+ from ._progress_bar import *
38
+ from ._progress_bar import __all__ as _progress_bar_all
39
+ from ._random import *
40
+ from ._random import __all__ as _random_all
41
+ from ._unvmap import *
42
+ from ._unvmap import __all__ as _unvmap_all
43
+
44
# The package's public API is every submodule's ``__all__``, concatenated in
# import order.
__all__ = (
    _ad_checkpoint_all
    + _autograd_all
    + _conditions_all
    + _error_if_all
)
__all__ += (
    _eval_shape_all
    + _jit_all
    + _loop_collect_return_all
    + _loop_no_collection_all
)
__all__ += (
    _make_jaxpr_all
    + _mapping_all
    + _progress_bar_all
    + _random_all
    + _unvmap_all
)

# Drop the temporary per-module aliases so they do not linger in the package
# namespace.
del (
    _ad_checkpoint_all,
    _autograd_all,
    _conditions_all,
    _error_if_all,
    _eval_shape_all,
    _jit_all,
    _loop_collect_return_all,
    _loop_no_collection_all,
    _make_jaxpr_all,
    _mapping_all,
    _progress_bar_all,
    _random_all,
    _unvmap_all,
)