brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,675 +1,675 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Random seed management utilities for BrainState.
18
-
19
- This module provides comprehensive random seed management functionality, enabling
20
- reproducible computations across JAX and NumPy backends. It supports both traditional
21
- integer seeds and JAX's PRNG key system, providing a unified interface for random
22
- number generation in scientific computing and machine learning applications.
23
-
24
- Key Features:
25
- - Unified seed management for JAX and NumPy
26
- - Context managers for temporary seed changes
27
- - Key splitting for parallel computation
28
- - Automatic seed backup and restoration
29
- - Thread-safe random state management
30
-
31
- Example:
32
- Basic usage for reproducible random number generation:
33
-
34
- >>> import brainstate
35
- >>> brainstate.random.seed(42)
36
- >>> print(brainstate.random.rand(3))
37
- [0.95598125 0.4032725 0.96086407]
38
-
39
- Using context managers for temporary seeds:
40
-
41
- >>> with brainstate.random.seed_context(123):
42
- ... values = brainstate.random.rand(2)
43
- >>> print(values) # Reproducible output
44
-
45
- Key splitting for parallel computation:
46
-
47
- >>> keys = brainstate.random.split_keys(4) # Generate 4 independent keys
48
- >>> # Use keys for parallel random number generation
49
- """
50
-
51
- from contextlib import contextmanager
52
-
53
- import jax
54
- import numpy as np
55
-
56
- from brainstate.typing import SeedOrKey
57
- from ._rand_state import RandomState, DEFAULT, use_prng_key
58
-
59
- __all__ = [
60
- 'seed',
61
- 'set_key',
62
- 'get_key',
63
- 'default_rng',
64
- 'split_key',
65
- 'split_keys',
66
- 'seed_context',
67
- 'restore_key',
68
- 'self_assign_multi_keys',
69
- 'clone_rng',
70
- ]
71
-
72
-
73
- def restore_key() -> None:
74
- """
75
- Restore the default random key to its previous state.
76
-
77
- This function restores the global random state to a previously backed up state.
78
- It's useful for undoing changes to the random state or implementing checkpoint
79
- functionality in computational workflows.
80
-
81
- Note:
82
- This operation requires that a backup was previously created. If no backup
83
- exists, this function may not have any effect or may restore to an initial state.
84
-
85
- Example:
86
- >>> import brainstate
87
- >>> brainstate.random.seed(42)
88
- >>> original_key = brainstate.random.get_key()
89
- >>> brainstate.random.seed(123) # Change the seed
90
- >>> brainstate.random.restore_key() # Restore to previous state
91
- >>> assert np.array_equal(brainstate.random.get_key(), original_key)
92
-
93
- See Also:
94
- - :func:`set_key`: Set a new random key
95
- - :func:`get_key`: Get the current random key
96
- - :func:`seed_context`: Temporary seed changes with automatic restoration
97
- """
98
- DEFAULT.restore_key()
99
-
100
-
101
- def split_key(n: int = None, backup: bool = False):
102
- """
103
- Create new random key(s) from the current seed.
104
-
105
- This function generates one or more independent random keys by splitting the
106
- current global random state. It follows JAX's random paradigm, ensuring that
107
- each split key produces statistically independent random sequences.
108
-
109
- Args:
110
- n: The number of keys to generate. If None, returns a single key.
111
- If an integer, returns an array of n keys.
112
- backup: Whether to backup the current key before splitting. This allows
113
- restoration of the original state using :func:`restore_key`.
114
-
115
- Returns:
116
- If n is None: A single JAX PRNG key.
117
- If n is an integer: An array of n independent JAX PRNG keys.
118
-
119
- Example:
120
- Generate a single key:
121
-
122
- >>> import brainstate
123
- >>> brainstate.random.seed(42)
124
- >>> key = brainstate.random.split_key()
125
- >>> print(key.shape)
126
- (2,)
127
-
128
- Generate multiple keys for parallel computation:
129
-
130
- >>> keys = brainstate.random.split_key(4)
131
- >>> print(keys.shape)
132
- (4, 2)
133
-
134
- Use with backup for state restoration:
135
-
136
- >>> original_key = brainstate.random.get_key()
137
- >>> keys = brainstate.random.split_key(2, backup=True)
138
- >>> brainstate.random.restore_key()
139
- >>> assert np.array_equal(brainstate.random.get_key(), original_key)
140
-
141
- Note:
142
- This function advances the global random state. Each call produces
143
- different keys unless the state is reset.
144
-
145
- See Also:
146
- - :func:`split_keys`: Convenience function for multiple keys
147
- - :func:`seed`: Set the random seed
148
- - :func:`restore_key`: Restore backed up key
149
- """
150
- return DEFAULT.split_key(n=n, backup=backup)
151
-
152
-
153
- def split_keys(n: int, backup: bool = False):
154
- """
155
- Create multiple independent random keys from the current seed.
156
-
157
- This is a convenience function that generates exactly n independent random keys
158
- by splitting the current global random state. It's commonly used internally by
159
- parallel computation functions like `pmap` and `vmap` to ensure that each
160
- parallel thread gets a unique random key.
161
-
162
- Args:
163
- n: The number of independent keys to generate. Must be a positive integer.
164
- backup: Whether to backup the current key before splitting. If True,
165
- the original key can be restored using :func:`restore_key`.
166
-
167
- Returns:
168
- An array of n independent JAX PRNG keys with shape (n, 2).
169
-
170
- Raises:
171
- ValueError: If n is not a positive integer.
172
-
173
- Example:
174
- Generate keys for parallel computation:
175
-
176
- >>> import brainstate
177
- >>> brainstate.random.seed(42)
178
- >>> keys = brainstate.random.split_keys(4)
179
- >>> print(keys.shape)
180
- (4, 2)
181
-
182
- Use with vmap for parallel random number generation:
183
-
184
- >>> import jax
185
- >>> keys = brainstate.random.split_keys(8)
186
- >>> @jax.vmap
187
- ... def generate_random(key):
188
- ... return jax.random.normal(key, (10,))
189
- >>> parallel_randoms = generate_random(keys)
190
- >>> print(parallel_randoms.shape)
191
- (8, 10)
192
-
193
- Use with backup for state preservation:
194
-
195
- >>> original_state = brainstate.random.get_key()
196
- >>> keys = brainstate.random.split_keys(3, backup=True)
197
- >>> # ... use keys for computation ...
198
- >>> brainstate.random.restore_key() # Restore original state
199
-
200
- Note:
201
- This function is equivalent to calling :func:`split_key` with n as an argument.
202
- It's provided as a convenience function with a more explicit name for clarity.
203
-
204
- See Also:
205
- - :func:`split_key`: More general key splitting function
206
- - :func:`self_assign_multi_keys`: Assign multiple keys to global state
207
- - :func:`seed_context`: Temporary seed changes
208
- """
209
- if not isinstance(n, int) or n <= 0:
210
- raise ValueError(f"n must be a positive integer, got {n}")
211
- return split_key(n, backup=backup)
212
-
213
-
214
- def self_assign_multi_keys(n: int, backup: bool = True) -> None:
215
- """
216
- Assign multiple keys to the global random state for parallel access.
217
-
218
- This function prepares the global random state for parallel computation by
219
- pre-generating n independent keys. It's particularly useful when you need
220
- to ensure that parallel computations have access to independent random
221
- sequences without the overhead of key splitting during computation.
222
-
223
- Args:
224
- n: The number of independent keys to pre-generate and assign.
225
- Must be a positive integer.
226
- backup: Whether to backup the current random state before assignment.
227
- If True, the original state can be restored using :func:`restore_key`.
228
-
229
- Raises:
230
- ValueError: If n is not a positive integer.
231
-
232
- Example:
233
- Prepare for parallel computation:
234
-
235
- >>> import brainstate
236
- >>> brainstate.random.seed(42)
237
- >>> # Prepare 4 independent keys for parallel access
238
- >>> brainstate.random.self_assign_multi_keys(4)
239
-
240
- Use in parallel context:
241
-
242
- >>> # The random state now has 4 independent keys ready for use
243
- >>> # Each parallel thread can access a different key
244
-
245
- Note:
246
- This is an advanced function primarily used internally for optimizing
247
- parallel random number generation. In most cases, :func:`split_keys`
248
- provides a more straightforward interface for parallel computation.
249
-
250
- See Also:
251
- - :func:`split_keys`: Generate multiple independent keys
252
- - :func:`restore_key`: Restore backed up state
253
- - :func:`seed_context`: Temporary state changes
254
- """
255
- if not isinstance(n, int) or n <= 0:
256
- raise ValueError(f"n must be a positive integer, got {n}")
257
- DEFAULT.self_assign_multi_keys(n, backup=backup)
258
-
259
-
260
- def clone_rng(seed_or_key: SeedOrKey = None, clone: bool = True) -> RandomState:
261
- """
262
- Create a clone of the random state or a new random state.
263
-
264
- This function provides a flexible way to create independent random states,
265
- either by cloning the current global state or by creating a new state with
266
- a specific seed or key. Cloned states are independent and don't affect each
267
- other when used for random number generation.
268
-
269
- Args:
270
- seed_or_key: Optional seed (integer) or JAX random key to initialize
271
- the new random state. If None, uses the current global state.
272
- clone: Whether to clone the default random state. If False and
273
- seed_or_key is None, returns the global state directly (not recommended
274
- for most use cases as it shares state).
275
-
276
- Returns:
277
- A RandomState instance that can be used independently for random
278
- number generation.
279
-
280
- Example:
281
- Clone the current global state:
282
-
283
- >>> import brainstate
284
- >>> brainstate.random.seed(42)
285
- >>> rng1 = brainstate.random.clone_rng()
286
- >>> rng2 = brainstate.random.clone_rng()
287
- >>> # rng1 and rng2 are independent copies
288
-
289
- Create a new state with specific seed:
290
-
291
- >>> rng_fixed = brainstate.random.clone_rng(123)
292
- >>> # Always produces the same sequences when reset to seed 123
293
-
294
- Use for independent computations:
295
-
296
- >>> rng = brainstate.random.clone_rng(456)
297
- >>> values1 = rng.normal(size=5)
298
- >>> values2 = rng.normal(size=5)
299
- >>> # values1 and values2 are different but reproducible
300
-
301
- Note:
302
- Cloned random states are completely independent. Changes to one state
303
- (like advancing through random number generation) don't affect others.
304
-
305
- See Also:
306
- - :func:`default_rng`: Get or create a random state
307
- - :func:`seed`: Set the global random seed
308
- - :class:`RandomState`: The random state class
309
- """
310
- if seed_or_key is None:
311
- return DEFAULT.clone() if clone else DEFAULT
312
- else:
313
- return RandomState(seed_or_key)
314
-
315
-
316
- def default_rng(seed_or_key: SeedOrKey = None) -> RandomState:
317
- """
318
- Get the default random state or create a new one with specified seed.
319
-
320
- This function provides access to the global random state used throughout
321
- BrainState, or creates a new independent random state if a seed is provided.
322
- It's the primary interface for obtaining random state objects in BrainState.
323
-
324
- Args:
325
- seed_or_key: Optional seed (integer) or JAX random key. If None,
326
- returns the global default random state. If provided, creates
327
- a new independent RandomState with the specified seed.
328
-
329
- Returns:
330
- The default RandomState if seed_or_key is None, otherwise a new
331
- RandomState initialized with the provided seed or key.
332
-
333
- Example:
334
- Get the global random state:
335
-
336
- >>> import brainstate
337
- >>> rng = brainstate.random.default_rng()
338
- >>> # rng is the global random state used by brainstate.random functions
339
-
340
- Create a new independent random state:
341
-
342
- >>> rng_local = brainstate.random.default_rng(42)
343
- >>> values = rng_local.normal(size=10)
344
-
345
- Use for reproducible local computations:
346
-
347
- >>> def reproducible_computation():
348
- ... local_rng = brainstate.random.default_rng(12345)
349
- ... return local_rng.uniform(size=5)
350
- >>> result1 = reproducible_computation()
351
- >>> result2 = reproducible_computation()
352
- >>> assert np.allclose(result1, result2) # Always the same
353
-
354
- Note:
355
- When seed_or_key is None, this returns the actual global state object.
356
- Modifications to this state (through random number generation) will
357
- affect all subsequent calls to global random functions.
358
-
359
- See Also:
360
- - :func:`clone_rng`: Create independent clones of random states
361
- - :func:`seed`: Set the global random seed
362
- - :class:`RandomState`: The underlying random state implementation
363
- """
364
- if seed_or_key is None:
365
- return DEFAULT
366
- else:
367
- return RandomState(seed_or_key)
368
-
369
-
370
- def set_key(seed_or_key: SeedOrKey) -> None:
371
- """
372
- Set a new random key for the global random state.
373
-
374
- This function updates the global random state with a new key, which can be
375
- either an integer seed or a JAX PRNG key. All subsequent calls to global
376
- random functions will use this new key state.
377
-
378
- Args:
379
- seed_or_key: The new random key to set. Can be:
380
- - An integer seed (will be converted to a JAX PRNG key)
381
- - A JAX PRNG key array
382
- - A numpy array representing a PRNG key
383
-
384
- Raises:
385
- ValueError: If the provided key is not in a valid format.
386
-
387
- Example:
388
- Set with integer seed:
389
-
390
- >>> import brainstate
391
- >>> brainstate.random.set_key(42)
392
- >>> values1 = brainstate.random.rand(3)
393
-
394
- Set with JAX key:
395
-
396
- >>> import jax
397
- >>> key = jax.random.key(123)
398
- >>> brainstate.random.set_key(key)
399
- >>> values2 = brainstate.random.rand(3)
400
-
401
- Restore reproducible state:
402
-
403
- >>> brainstate.random.set_key(42)
404
- >>> # Now random functions will produce the same sequences as first example
405
-
406
- Note:
407
- This function immediately changes the global random state. All threads
408
- and computations using the global random functions will be affected.
409
-
410
- See Also:
411
- - :func:`get_key`: Get the current random key
412
- - :func:`seed`: Set seed (also affects NumPy)
413
- - :func:`restore_key`: Restore a backed up key
414
- """
415
- if isinstance(seed_or_key, int):
416
- # Create key using appropriate JAX function based on version
417
- key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
418
- elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
419
- if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
420
- key = seed_or_key
421
- elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
422
- key = seed_or_key
423
- else:
424
- raise ValueError(
425
- f"seed_or_key should be an integer, a JAX PRNG key, or a uint32 array of size 2. "
426
- f"Got array with dtype {seed_or_key.dtype} and size {seed_or_key.size}."
427
- )
428
- else:
429
- raise ValueError(
430
- f"seed_or_key must be an integer or a JAX-compatible array. "
431
- f"Got {type(seed_or_key)}."
432
- )
433
- DEFAULT.set_key(key)
434
-
435
-
436
- def get_key():
437
- """
438
- Get the current global random key.
439
-
440
- This function returns the current random key used by the global random state.
441
- The returned key represents the internal state of the JAX PRNG and can be used
442
- to restore the random state later or to create independent random number generators.
443
-
444
- Returns:
445
- The current JAX PRNG key as a numpy array. This is typically a 2-element
446
- uint32 array representing the internal state of the random number generator.
447
-
448
- Example:
449
- Get and store the current random state:
450
-
451
- >>> import brainstate
452
- >>> brainstate.random.seed(42)
453
- >>> current_key = brainstate.random.get_key()
454
- >>> print(current_key.shape)
455
- (2,)
456
-
457
- Use the key to restore state later:
458
-
459
- >>> # Generate some random numbers
460
- >>> values1 = brainstate.random.rand(3)
461
- >>> # Restore the previous state
462
- >>> brainstate.random.set_key(current_key)
463
- >>> values2 = brainstate.random.rand(3)
464
- >>> # values1 and values2 will be identical
465
-
466
- Compare keys for debugging:
467
-
468
- >>> brainstate.random.seed(123)
469
- >>> key1 = brainstate.random.get_key()
470
- >>> brainstate.random.seed(123)
471
- >>> key2 = brainstate.random.get_key()
472
- >>> assert jax.numpy.array_equal(key1, key2) # Same seed gives same key
473
-
474
- Note:
475
- The returned key is a snapshot of the current state. Subsequent calls to
476
- random functions will advance the internal state, so calling get_key()
477
- again will return a different key unless the state is reset.
478
-
479
- See Also:
480
- - :func:`set_key`: Set a new random key
481
- - :func:`seed`: Set the random seed (also affects NumPy)
482
- - :func:`split_key`: Create new keys from current state
483
- - :func:`seed_context`: Temporary seed changes with automatic restoration
484
-
485
- """
486
- return DEFAULT.value
487
-
488
-
489
- def seed(seed_or_key: SeedOrKey = None):
490
- """
491
- Set the global random seed for both JAX and NumPy.
492
-
493
- This function initializes the global random state with a new seed, affecting
494
- both JAX and NumPy random number generators. It ensures reproducible random
495
- number generation across the entire BrainState ecosystem.
496
-
497
- Args:
498
- seed_or_key: The seed or key to set. Can be:
499
- - None: Generates a random seed automatically
500
- - int: An integer seed (0 to 2^32-1)
501
- - JAX PRNG key: A JAX random key array
502
- If None, a random seed is generated using NumPy's random generator.
503
-
504
- Raises:
505
- ValueError: If seed_or_key is not a valid seed format (not an integer,
506
- valid JAX key, or None).
507
-
508
- Example:
509
- Set a specific seed for reproducible results:
510
-
511
- >>> import brainstate
512
- >>> brainstate.random.seed(42)
513
- >>> values1 = brainstate.random.rand(3)
514
- >>> brainstate.random.seed(42) # Reset to same seed
515
- >>> values2 = brainstate.random.rand(3)
516
- >>> assert np.allclose(values1, values2) # Same values
517
-
518
- Use automatic random seeding:
519
-
520
- >>> brainstate.random.seed() # Uses random seed
521
- >>> # Each call will produce different sequences
522
-
523
- Use with JAX keys:
524
-
525
- >>> import jax
526
- >>> key = jax.random.key(123)
527
- >>> brainstate.random.seed(key)
528
- >>> # Now both JAX and NumPy use consistent seeds
529
-
530
- Ensure reproducibility in scientific experiments:
531
-
532
- >>> def experiment():
533
- ... brainstate.random.seed(12345) # Fixed seed for reproducibility
534
- ... data = brainstate.random.normal(size=(100, 10))
535
- ... return data.mean()
536
- >>> result1 = experiment()
537
- >>> result2 = experiment()
538
- >>> assert result1 == result2 # Always same result
539
-
540
- Note:
541
- - This function affects the global random state used by all BrainState
542
- random functions and NumPy's global random state.
543
- - When using automatic seeding (seed_or_key=None), NumPy's seed is not
544
- set to maintain its current state.
545
- - JAX compilation is handled automatically with compile-time evaluation.
546
- - For JAX keys, only the first element is used to seed NumPy to maintain
547
- compatibility between the two random systems.
548
-
549
- See Also:
550
- - :func:`set_key`: Set only the JAX random key
551
- - :func:`get_key`: Get the current random key
552
- - :func:`seed_context`: Temporary seed changes
553
- - :func:`split_key`: Create independent random keys
554
-
555
- """
556
- with jax.ensure_compile_time_eval():
557
- _set_numpy_seed = True
558
- if seed_or_key is None:
559
- seed_or_key = np.random.randint(0, 100000)
560
- _set_numpy_seed = False
561
-
562
- # numpy random seed
563
- if _set_numpy_seed:
564
- try:
565
- if np.size(seed_or_key) == 1: # seed
566
- np.random.seed(seed_or_key)
567
- elif np.size(seed_or_key) == 2: # jax random key
568
- np.random.seed(seed_or_key[0])
569
- else:
570
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
571
- except jax.errors.TracerArrayConversionError:
572
- pass
573
-
574
- # jax random seed
575
- DEFAULT.seed(seed_or_key)
576
-
577
-
578
- @contextmanager
579
- def seed_context(seed_or_key: SeedOrKey):
580
- """
581
- Context manager for temporary random seed changes with automatic restoration.
582
-
583
- This context manager temporarily changes the global random seed for the duration
584
- of the block, then automatically restores the previous random state when exiting.
585
- It's ideal for ensuring reproducible computations in specific code sections without
586
- permanently affecting the global random state.
587
-
588
- Args:
589
- seed_or_key: The temporary seed or key to use within the context. Can be:
590
- - int: An integer seed for reproducible sequences
591
- - JAX PRNG key: A JAX random key array
592
- The seed affects both JAX and NumPy random states during the context.
593
-
594
- Yields:
595
- None: The context manager doesn't yield any value, but provides a
596
- controlled random environment for the enclosed code block.
597
-
598
- Example:
599
- Reproducible computations without affecting global state:
600
-
601
- >>> import brainstate
602
- >>> # Global state remains unaffected
603
- >>> global_values1 = brainstate.random.rand(2)
604
- >>>
605
- >>> with brainstate.random.seed_context(42):
606
- ... temp_values1 = brainstate.random.rand(2)
607
- ... print(f"First run: {temp_values1}")
608
- [0.95598125 0.4032725 ]
609
- >>>
610
- >>> with brainstate.random.seed_context(42):
611
- ... temp_values2 = brainstate.random.rand(2)
612
- ... print(f"Second run: {temp_values2}")
613
- [0.95598125 0.4032725 ]
614
- >>>
615
- >>> # Values are identical within context
616
- >>> assert np.allclose(temp_values1, temp_values2)
617
- >>>
618
- >>> # Global state continues from where it left off
619
- >>> global_values2 = brainstate.random.rand(2)
620
-
621
- Nested contexts for complex scenarios:
622
-
623
- >>> with brainstate.random.seed_context(123):
624
- ... outer_values = brainstate.random.rand(2)
625
- ... with brainstate.random.seed_context(456):
626
- ... inner_values = brainstate.random.rand(2)
627
- ... # Outer context is restored here
628
- ... outer_values2 = brainstate.random.rand(2)
629
-
630
- Exception safety - state is restored even on errors:
631
-
632
- >>> try:
633
- ... with brainstate.random.seed_context(789):
634
- ... some_values = brainstate.random.rand(3)
635
- ... raise ValueError("Something went wrong")
636
- ... except ValueError:
637
- ... pass
638
- >>> # Random state is properly restored
639
-
640
- Testing reproducible algorithms:
641
-
642
- >>> def test_algorithm():
643
- ... with brainstate.random.seed_context(42):
644
- ... data = brainstate.random.normal(size=(100,))
645
- ... return data.mean()
646
- >>>
647
- >>> result1 = test_algorithm()
648
- >>> result2 = test_algorithm()
649
- >>> assert result1 == result2 # Always same result
650
-
651
- Note:
652
- - The context manager saves and restores the complete JAX random state
653
- - NumPy's random state is also temporarily modified during the context
654
- - Nested contexts work correctly - each level restores its own state
655
- - Exception safety is guaranteed - random state is restored even if
656
- exceptions occur within the context
657
- - This is more convenient than manually saving/restoring state with
658
- get_key() and set_key()
659
-
660
- See Also:
661
- - :func:`seed`: Permanently set the global random seed
662
- - :func:`get_key`: Get the current random key for manual state management
663
- - :func:`set_key`: Set the random key for manual state management
664
- - :func:`clone_rng`: Create independent random states
665
-
666
- """
667
- # get the old random key
668
- old_jrand_key = DEFAULT.value
669
- try:
670
- # set the seed of jax random state
671
- DEFAULT.seed(seed_or_key)
672
- yield
673
- finally:
674
- # restore the random state
675
- DEFAULT.seed(old_jrand_key)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Random seed management utilities for BrainState.
18
+
19
+ This module provides comprehensive random seed management functionality, enabling
20
+ reproducible computations across JAX and NumPy backends. It supports both traditional
21
+ integer seeds and JAX's PRNG key system, providing a unified interface for random
22
+ number generation in scientific computing and machine learning applications.
23
+
24
+ Key Features:
25
+ - Unified seed management for JAX and NumPy
26
+ - Context managers for temporary seed changes
27
+ - Key splitting for parallel computation
28
+ - Automatic seed backup and restoration
29
+ - Thread-safe random state management
30
+
31
+ Example:
32
+ Basic usage for reproducible random number generation:
33
+
34
+ >>> import brainstate
35
+ >>> brainstate.random.seed(42)
36
+ >>> print(brainstate.random.rand(3))
37
+ [0.95598125 0.4032725 0.96086407]
38
+
39
+ Using context managers for temporary seeds:
40
+
41
+ >>> with brainstate.random.seed_context(123):
42
+ ... values = brainstate.random.rand(2)
43
+ >>> print(values) # Reproducible output
44
+
45
+ Key splitting for parallel computation:
46
+
47
+ >>> keys = brainstate.random.split_keys(4) # Generate 4 independent keys
48
+ >>> # Use keys for parallel random number generation
49
+ """
50
+
51
+ from contextlib import contextmanager
52
+
53
+ import jax
54
+ import numpy as np
55
+
56
+ from brainstate.typing import SeedOrKey
57
+ from ._state import RandomState, DEFAULT, use_prng_key
58
+
59
+ __all__ = [
60
+ 'seed',
61
+ 'set_key',
62
+ 'get_key',
63
+ 'default_rng',
64
+ 'split_key',
65
+ 'split_keys',
66
+ 'seed_context',
67
+ 'restore_key',
68
+ 'self_assign_multi_keys',
69
+ 'clone_rng',
70
+ ]
71
+
72
+
73
+ def restore_key() -> None:
74
+ """
75
+ Restore the default random key to its previous state.
76
+
77
+ This function restores the global random state to a previously backed up state.
78
+ It's useful for undoing changes to the random state or implementing checkpoint
79
+ functionality in computational workflows.
80
+
81
+ Note:
82
+ This operation requires that a backup was previously created. If no backup
83
+ exists, this function may not have any effect or may restore to an initial state.
84
+
85
+ Example:
86
+ >>> import brainstate
87
+ >>> brainstate.random.seed(42)
88
+ >>> original_key = brainstate.random.get_key()
89
+ >>> brainstate.random.seed(123) # Change the seed
90
+ >>> brainstate.random.restore_key() # Restore to previous state
91
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
92
+
93
+ See Also:
94
+ - :func:`set_key`: Set a new random key
95
+ - :func:`get_key`: Get the current random key
96
+ - :func:`seed_context`: Temporary seed changes with automatic restoration
97
+ """
98
+ DEFAULT.restore_key()
99
+
100
+
101
+ def split_key(n: int = None, backup: bool = False):
102
+ """
103
+ Create new random key(s) from the current seed.
104
+
105
+ This function generates one or more independent random keys by splitting the
106
+ current global random state. It follows JAX's random paradigm, ensuring that
107
+ each split key produces statistically independent random sequences.
108
+
109
+ Args:
110
+ n: The number of keys to generate. If None, returns a single key.
111
+ If an integer, returns an array of n keys.
112
+ backup: Whether to backup the current key before splitting. This allows
113
+ restoration of the original state using :func:`restore_key`.
114
+
115
+ Returns:
116
+ If n is None: A single JAX PRNG key.
117
+ If n is an integer: An array of n independent JAX PRNG keys.
118
+
119
+ Example:
120
+ Generate a single key:
121
+
122
+ >>> import brainstate
123
+ >>> brainstate.random.seed(42)
124
+ >>> key = brainstate.random.split_key()
125
+ >>> print(key.shape)
126
+ (2,)
127
+
128
+ Generate multiple keys for parallel computation:
129
+
130
+ >>> keys = brainstate.random.split_key(4)
131
+ >>> print(keys.shape)
132
+ (4, 2)
133
+
134
+ Use with backup for state restoration:
135
+
136
+ >>> original_key = brainstate.random.get_key()
137
+ >>> keys = brainstate.random.split_key(2, backup=True)
138
+ >>> brainstate.random.restore_key()
139
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
140
+
141
+ Note:
142
+ This function advances the global random state. Each call produces
143
+ different keys unless the state is reset.
144
+
145
+ See Also:
146
+ - :func:`split_keys`: Convenience function for multiple keys
147
+ - :func:`seed`: Set the random seed
148
+ - :func:`restore_key`: Restore backed up key
149
+ """
150
+ return DEFAULT.split_key(n=n, backup=backup)
151
+
152
+
153
+ def split_keys(n: int, backup: bool = False):
154
+ """
155
+ Create multiple independent random keys from the current seed.
156
+
157
+ This is a convenience function that generates exactly n independent random keys
158
+ by splitting the current global random state. It's commonly used internally by
159
+ parallel computation functions like `pmap` and `vmap` to ensure that each
160
+ parallel thread gets a unique random key.
161
+
162
+ Args:
163
+ n: The number of independent keys to generate. Must be a positive integer.
164
+ backup: Whether to backup the current key before splitting. If True,
165
+ the original key can be restored using :func:`restore_key`.
166
+
167
+ Returns:
168
+ An array of n independent JAX PRNG keys with shape (n, 2).
169
+
170
+ Raises:
171
+ ValueError: If n is not a positive integer.
172
+
173
+ Example:
174
+ Generate keys for parallel computation:
175
+
176
+ >>> import brainstate
177
+ >>> brainstate.random.seed(42)
178
+ >>> keys = brainstate.random.split_keys(4)
179
+ >>> print(keys.shape)
180
+ (4, 2)
181
+
182
+ Use with vmap for parallel random number generation:
183
+
184
+ >>> import jax
185
+ >>> keys = brainstate.random.split_keys(8)
186
+ >>> @jax.vmap
187
+ ... def generate_random(key):
188
+ ... return jax.random.normal(key, (10,))
189
+ >>> parallel_randoms = generate_random(keys)
190
+ >>> print(parallel_randoms.shape)
191
+ (8, 10)
192
+
193
+ Use with backup for state preservation:
194
+
195
+ >>> original_state = brainstate.random.get_key()
196
+ >>> keys = brainstate.random.split_keys(3, backup=True)
197
+ >>> # ... use keys for computation ...
198
+ >>> brainstate.random.restore_key() # Restore original state
199
+
200
+ Note:
201
+ This function is equivalent to calling :func:`split_key` with n as an argument.
202
+ It's provided as a convenience function with a more explicit name for clarity.
203
+
204
+ See Also:
205
+ - :func:`split_key`: More general key splitting function
206
+ - :func:`self_assign_multi_keys`: Assign multiple keys to global state
207
+ - :func:`seed_context`: Temporary seed changes
208
+ """
209
+ if not isinstance(n, int) or n <= 0:
210
+ raise ValueError(f"n must be a positive integer, got {n}")
211
+ return split_key(n, backup=backup)
212
+
213
+
214
+ def self_assign_multi_keys(n: int, backup: bool = True) -> None:
215
+ """
216
+ Assign multiple keys to the global random state for parallel access.
217
+
218
+ This function prepares the global random state for parallel computation by
219
+ pre-generating n independent keys. It's particularly useful when you need
220
+ to ensure that parallel computations have access to independent random
221
+ sequences without the overhead of key splitting during computation.
222
+
223
+ Args:
224
+ n: The number of independent keys to pre-generate and assign.
225
+ Must be a positive integer.
226
+ backup: Whether to backup the current random state before assignment.
227
+ If True, the original state can be restored using :func:`restore_key`.
228
+
229
+ Raises:
230
+ ValueError: If n is not a positive integer.
231
+
232
+ Example:
233
+ Prepare for parallel computation:
234
+
235
+ >>> import brainstate
236
+ >>> brainstate.random.seed(42)
237
+ >>> # Prepare 4 independent keys for parallel access
238
+ >>> brainstate.random.self_assign_multi_keys(4)
239
+
240
+ Use in parallel context:
241
+
242
+ >>> # The random state now has 4 independent keys ready for use
243
+ >>> # Each parallel thread can access a different key
244
+
245
+ Note:
246
+ This is an advanced function primarily used internally for optimizing
247
+ parallel random number generation. In most cases, :func:`split_keys`
248
+ provides a more straightforward interface for parallel computation.
249
+
250
+ See Also:
251
+ - :func:`split_keys`: Generate multiple independent keys
252
+ - :func:`restore_key`: Restore backed up state
253
+ - :func:`seed_context`: Temporary state changes
254
+ """
255
+ if not isinstance(n, int) or n <= 0:
256
+ raise ValueError(f"n must be a positive integer, got {n}")
257
+ DEFAULT.self_assign_multi_keys(n, backup=backup)
258
+
259
+
260
+ def clone_rng(seed_or_key: SeedOrKey = None, clone: bool = True) -> RandomState:
261
+ """
262
+ Create a clone of the random state or a new random state.
263
+
264
+ This function provides a flexible way to create independent random states,
265
+ either by cloning the current global state or by creating a new state with
266
+ a specific seed or key. Cloned states are independent and don't affect each
267
+ other when used for random number generation.
268
+
269
+ Args:
270
+ seed_or_key: Optional seed (integer) or JAX random key to initialize
271
+ the new random state. If None, uses the current global state.
272
+ clone: Whether to clone the default random state. If False and
273
+ seed_or_key is None, returns the global state directly (not recommended
274
+ for most use cases as it shares state).
275
+
276
+ Returns:
277
+ A RandomState instance that can be used independently for random
278
+ number generation.
279
+
280
+ Example:
281
+ Clone the current global state:
282
+
283
+ >>> import brainstate
284
+ >>> brainstate.random.seed(42)
285
+ >>> rng1 = brainstate.random.clone_rng()
286
+ >>> rng2 = brainstate.random.clone_rng()
287
+ >>> # rng1 and rng2 are independent copies
288
+
289
+ Create a new state with specific seed:
290
+
291
+ >>> rng_fixed = brainstate.random.clone_rng(123)
292
+ >>> # Always produces the same sequences when reset to seed 123
293
+
294
+ Use for independent computations:
295
+
296
+ >>> rng = brainstate.random.clone_rng(456)
297
+ >>> values1 = rng.normal(size=5)
298
+ >>> values2 = rng.normal(size=5)
299
+ >>> # values1 and values2 are different but reproducible
300
+
301
+ Note:
302
+ Cloned random states are completely independent. Changes to one state
303
+ (like advancing through random number generation) don't affect others.
304
+
305
+ See Also:
306
+ - :func:`default_rng`: Get or create a random state
307
+ - :func:`seed`: Set the global random seed
308
+ - :class:`RandomState`: The random state class
309
+ """
310
+ if seed_or_key is None:
311
+ return DEFAULT.clone() if clone else DEFAULT
312
+ else:
313
+ return RandomState(seed_or_key)
314
+
315
+
316
+ def default_rng(seed_or_key: SeedOrKey = None) -> RandomState:
317
+ """
318
+ Get the default random state or create a new one with specified seed.
319
+
320
+ This function provides access to the global random state used throughout
321
+ BrainState, or creates a new independent random state if a seed is provided.
322
+ It's the primary interface for obtaining random state objects in BrainState.
323
+
324
+ Args:
325
+ seed_or_key: Optional seed (integer) or JAX random key. If None,
326
+ returns the global default random state. If provided, creates
327
+ a new independent RandomState with the specified seed.
328
+
329
+ Returns:
330
+ The default RandomState if seed_or_key is None, otherwise a new
331
+ RandomState initialized with the provided seed or key.
332
+
333
+ Example:
334
+ Get the global random state:
335
+
336
+ >>> import brainstate
337
+ >>> rng = brainstate.random.default_rng()
338
+ >>> # rng is the global random state used by brainstate.random functions
339
+
340
+ Create a new independent random state:
341
+
342
+ >>> rng_local = brainstate.random.default_rng(42)
343
+ >>> values = rng_local.normal(size=10)
344
+
345
+ Use for reproducible local computations:
346
+
347
+ >>> def reproducible_computation():
348
+ ... local_rng = brainstate.random.default_rng(12345)
349
+ ... return local_rng.uniform(size=5)
350
+ >>> result1 = reproducible_computation()
351
+ >>> result2 = reproducible_computation()
352
+ >>> assert np.allclose(result1, result2) # Always the same
353
+
354
+ Note:
355
+ When seed_or_key is None, this returns the actual global state object.
356
+ Modifications to this state (through random number generation) will
357
+ affect all subsequent calls to global random functions.
358
+
359
+ See Also:
360
+ - :func:`clone_rng`: Create independent clones of random states
361
+ - :func:`seed`: Set the global random seed
362
+ - :class:`RandomState`: The underlying random state implementation
363
+ """
364
+ if seed_or_key is None:
365
+ return DEFAULT
366
+ else:
367
+ return RandomState(seed_or_key)
368
+
369
+
370
+ def set_key(seed_or_key: SeedOrKey) -> None:
371
+ """
372
+ Set a new random key for the global random state.
373
+
374
+ This function updates the global random state with a new key, which can be
375
+ either an integer seed or a JAX PRNG key. All subsequent calls to global
376
+ random functions will use this new key state.
377
+
378
+ Args:
379
+ seed_or_key: The new random key to set. Can be:
380
+ - An integer seed (will be converted to a JAX PRNG key)
381
+ - A JAX PRNG key array
382
+ - A numpy array representing a PRNG key
383
+
384
+ Raises:
385
+ ValueError: If the provided key is not in a valid format.
386
+
387
+ Example:
388
+ Set with integer seed:
389
+
390
+ >>> import brainstate
391
+ >>> brainstate.random.set_key(42)
392
+ >>> values1 = brainstate.random.rand(3)
393
+
394
+ Set with JAX key:
395
+
396
+ >>> import jax
397
+ >>> key = jax.random.key(123)
398
+ >>> brainstate.random.set_key(key)
399
+ >>> values2 = brainstate.random.rand(3)
400
+
401
+ Restore reproducible state:
402
+
403
+ >>> brainstate.random.set_key(42)
404
+ >>> # Now random functions will produce the same sequences as first example
405
+
406
+ Note:
407
+ This function immediately changes the global random state. All threads
408
+ and computations using the global random functions will be affected.
409
+
410
+ See Also:
411
+ - :func:`get_key`: Get the current random key
412
+ - :func:`seed`: Set seed (also affects NumPy)
413
+ - :func:`restore_key`: Restore a backed up key
414
+ """
415
+ if isinstance(seed_or_key, int):
416
+ # Create key using appropriate JAX function based on version
417
+ key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
418
+ elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
419
+ if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
420
+ key = seed_or_key
421
+ elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
422
+ key = seed_or_key
423
+ else:
424
+ raise ValueError(
425
+ f"seed_or_key should be an integer, a JAX PRNG key, or a uint32 array of size 2. "
426
+ f"Got array with dtype {seed_or_key.dtype} and size {seed_or_key.size}."
427
+ )
428
+ else:
429
+ raise ValueError(
430
+ f"seed_or_key must be an integer or a JAX-compatible array. "
431
+ f"Got {type(seed_or_key)}."
432
+ )
433
+ DEFAULT.set_key(key)
434
+
435
+
436
+ def get_key():
437
+ """
438
+ Get the current global random key.
439
+
440
+ This function returns the current random key used by the global random state.
441
+ The returned key represents the internal state of the JAX PRNG and can be used
442
+ to restore the random state later or to create independent random number generators.
443
+
444
+ Returns:
445
+ The current JAX PRNG key as a numpy array. This is typically a 2-element
446
+ uint32 array representing the internal state of the random number generator.
447
+
448
+ Example:
449
+ Get and store the current random state:
450
+
451
+ >>> import brainstate
452
+ >>> brainstate.random.seed(42)
453
+ >>> current_key = brainstate.random.get_key()
454
+ >>> print(current_key.shape)
455
+ (2,)
456
+
457
+ Use the key to restore state later:
458
+
459
+ >>> # Generate some random numbers
460
+ >>> values1 = brainstate.random.rand(3)
461
+ >>> # Restore the previous state
462
+ >>> brainstate.random.set_key(current_key)
463
+ >>> values2 = brainstate.random.rand(3)
464
+ >>> # values1 and values2 will be identical
465
+
466
+ Compare keys for debugging:
467
+
468
+ >>> brainstate.random.seed(123)
469
+ >>> key1 = brainstate.random.get_key()
470
+ >>> brainstate.random.seed(123)
471
+ >>> key2 = brainstate.random.get_key()
472
+ >>> assert jax.numpy.array_equal(key1, key2) # Same seed gives same key
473
+
474
+ Note:
475
+ The returned key is a snapshot of the current state. Subsequent calls to
476
+ random functions will advance the internal state, so calling get_key()
477
+ again will return a different key unless the state is reset.
478
+
479
+ See Also:
480
+ - :func:`set_key`: Set a new random key
481
+ - :func:`seed`: Set the random seed (also affects NumPy)
482
+ - :func:`split_key`: Create new keys from current state
483
+ - :func:`seed_context`: Temporary seed changes with automatic restoration
484
+
485
+ """
486
+ return DEFAULT.value
487
+
488
+
489
+ def seed(seed_or_key: SeedOrKey = None):
490
+ """
491
+ Set the global random seed for both JAX and NumPy.
492
+
493
+ This function initializes the global random state with a new seed, affecting
494
+ both JAX and NumPy random number generators. It ensures reproducible random
495
+ number generation across the entire BrainState ecosystem.
496
+
497
+ Args:
498
+ seed_or_key: The seed or key to set. Can be:
499
+ - None: Generates a random seed automatically
500
+ - int: An integer seed (0 to 2^32-1)
501
+ - JAX PRNG key: A JAX random key array
502
+ If None, a random seed is generated using NumPy's random generator.
503
+
504
+ Raises:
505
+ ValueError: If seed_or_key is not a valid seed format (not an integer,
506
+ valid JAX key, or None).
507
+
508
+ Example:
509
+ Set a specific seed for reproducible results:
510
+
511
+ >>> import brainstate
512
+ >>> brainstate.random.seed(42)
513
+ >>> values1 = brainstate.random.rand(3)
514
+ >>> brainstate.random.seed(42) # Reset to same seed
515
+ >>> values2 = brainstate.random.rand(3)
516
+ >>> assert np.allclose(values1, values2) # Same values
517
+
518
+ Use automatic random seeding:
519
+
520
+ >>> brainstate.random.seed() # Uses random seed
521
+ >>> # Each call will produce different sequences
522
+
523
+ Use with JAX keys:
524
+
525
+ >>> import jax
526
+ >>> key = jax.random.key(123)
527
+ >>> brainstate.random.seed(key)
528
+ >>> # Now both JAX and NumPy use consistent seeds
529
+
530
+ Ensure reproducibility in scientific experiments:
531
+
532
+ >>> def experiment():
533
+ ... brainstate.random.seed(12345) # Fixed seed for reproducibility
534
+ ... data = brainstate.random.normal(size=(100, 10))
535
+ ... return data.mean()
536
+ >>> result1 = experiment()
537
+ >>> result2 = experiment()
538
+ >>> assert result1 == result2 # Always same result
539
+
540
+ Note:
541
+ - This function affects the global random state used by all BrainState
542
+ random functions and NumPy's global random state.
543
+ - When using automatic seeding (seed_or_key=None), NumPy's seed is not
544
+ set to maintain its current state.
545
+ - JAX compilation is handled automatically with compile-time evaluation.
546
+ - For JAX keys, only the first element is used to seed NumPy to maintain
547
+ compatibility between the two random systems.
548
+
549
+ See Also:
550
+ - :func:`set_key`: Set only the JAX random key
551
+ - :func:`get_key`: Get the current random key
552
+ - :func:`seed_context`: Temporary seed changes
553
+ - :func:`split_key`: Create independent random keys
554
+
555
+ """
556
+ with jax.ensure_compile_time_eval():
557
+ _set_numpy_seed = True
558
+ if seed_or_key is None:
559
+ seed_or_key = np.random.randint(0, 100000)
560
+ _set_numpy_seed = False
561
+
562
+ # numpy random seed
563
+ if _set_numpy_seed:
564
+ try:
565
+ if np.size(seed_or_key) == 1: # seed
566
+ np.random.seed(seed_or_key)
567
+ elif np.size(seed_or_key) == 2: # jax random key
568
+ np.random.seed(seed_or_key[0])
569
+ else:
570
+ raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
571
+ except jax.errors.TracerArrayConversionError:
572
+ pass
573
+
574
+ # jax random seed
575
+ DEFAULT.seed(seed_or_key)
576
+
577
+
578
+ @contextmanager
579
+ def seed_context(seed_or_key: SeedOrKey):
580
+ """
581
+ Context manager for temporary random seed changes with automatic restoration.
582
+
583
+ This context manager temporarily changes the global random seed for the duration
584
+ of the block, then automatically restores the previous random state when exiting.
585
+ It's ideal for ensuring reproducible computations in specific code sections without
586
+ permanently affecting the global random state.
587
+
588
+ Args:
589
+ seed_or_key: The temporary seed or key to use within the context. Can be:
590
+ - int: An integer seed for reproducible sequences
591
+ - JAX PRNG key: A JAX random key array
592
+ The seed affects both JAX and NumPy random states during the context.
593
+
594
+ Yields:
595
+ None: The context manager doesn't yield any value, but provides a
596
+ controlled random environment for the enclosed code block.
597
+
598
+ Example:
599
+ Reproducible computations without affecting global state:
600
+
601
+ >>> import brainstate
602
+ >>> # Global state remains unaffected
603
+ >>> global_values1 = brainstate.random.rand(2)
604
+ >>>
605
+ >>> with brainstate.random.seed_context(42):
606
+ ... temp_values1 = brainstate.random.rand(2)
607
+ ... print(f"First run: {temp_values1}")
608
+ [0.95598125 0.4032725 ]
609
+ >>>
610
+ >>> with brainstate.random.seed_context(42):
611
+ ... temp_values2 = brainstate.random.rand(2)
612
+ ... print(f"Second run: {temp_values2}")
613
+ [0.95598125 0.4032725 ]
614
+ >>>
615
+ >>> # Values are identical within context
616
+ >>> assert np.allclose(temp_values1, temp_values2)
617
+ >>>
618
+ >>> # Global state continues from where it left off
619
+ >>> global_values2 = brainstate.random.rand(2)
620
+
621
+ Nested contexts for complex scenarios:
622
+
623
+ >>> with brainstate.random.seed_context(123):
624
+ ... outer_values = brainstate.random.rand(2)
625
+ ... with brainstate.random.seed_context(456):
626
+ ... inner_values = brainstate.random.rand(2)
627
+ ... # Outer context is restored here
628
+ ... outer_values2 = brainstate.random.rand(2)
629
+
630
+ Exception safety - state is restored even on errors:
631
+
632
+ >>> try:
633
+ ... with brainstate.random.seed_context(789):
634
+ ... some_values = brainstate.random.rand(3)
635
+ ... raise ValueError("Something went wrong")
636
+ ... except ValueError:
637
+ ... pass
638
+ >>> # Random state is properly restored
639
+
640
+ Testing reproducible algorithms:
641
+
642
+ >>> def test_algorithm():
643
+ ... with brainstate.random.seed_context(42):
644
+ ... data = brainstate.random.normal(size=(100,))
645
+ ... return data.mean()
646
+ >>>
647
+ >>> result1 = test_algorithm()
648
+ >>> result2 = test_algorithm()
649
+ >>> assert result1 == result2 # Always same result
650
+
651
+ Note:
652
+ - The context manager saves and restores the complete JAX random state
653
+ - NumPy's random state is also temporarily modified during the context
654
+ - Nested contexts work correctly - each level restores its own state
655
+ - Exception safety is guaranteed - random state is restored even if
656
+ exceptions occur within the context
657
+ - This is more convenient than manually saving/restoring state with
658
+ get_key() and set_key()
659
+
660
+ See Also:
661
+ - :func:`seed`: Permanently set the global random seed
662
+ - :func:`get_key`: Get the current random key for manual state management
663
+ - :func:`set_key`: Set the random key for manual state management
664
+ - :func:`clone_rng`: Create independent random states
665
+
666
+ """
667
+ # get the old random key
668
+ old_jrand_key = DEFAULT.value
669
+ try:
670
+ # set the seed of jax random state
671
+ DEFAULT.seed(seed_or_key)
672
+ yield
673
+ finally:
674
+ # restore the random state
675
+ DEFAULT.seed(old_jrand_key)