brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,8 +13,42 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """
17
+ Random seed management utilities for BrainState.
18
+
19
+ This module provides comprehensive random seed management functionality, enabling
20
+ reproducible computations across JAX and NumPy backends. It supports both traditional
21
+ integer seeds and JAX's PRNG key system, providing a unified interface for random
22
+ number generation in scientific computing and machine learning applications.
23
+
24
+ Key Features:
25
+ - Unified seed management for JAX and NumPy
26
+ - Context managers for temporary seed changes
27
+ - Key splitting for parallel computation
28
+ - Automatic seed backup and restoration
29
+ - Thread-safe random state management
30
+
31
+ Example:
32
+ Basic usage for reproducible random number generation:
33
+
34
+ >>> import brainstate
35
+ >>> brainstate.random.seed(42)
36
+ >>> print(brainstate.random.rand(3))
37
+ [0.95598125 0.4032725 0.96086407]
38
+
39
+ Using context managers for temporary seeds:
40
+
41
+ >>> with brainstate.random.seed_context(123):
42
+ ... values = brainstate.random.rand(2)
43
+ >>> print(values) # Reproducible output
44
+
45
+ Key splitting for parallel computation:
46
+
47
+ >>> keys = brainstate.random.split_keys(4) # Generate 4 independent keys
48
+ >>> # Use keys for parallel random number generation
49
+ """
50
+
16
51
  from contextlib import contextmanager
17
- from typing import Optional
18
52
 
19
53
  import jax
20
54
  import numpy as np
@@ -23,74 +57,255 @@ from brainstate.typing import SeedOrKey
23
57
  from ._rand_state import RandomState, DEFAULT, use_prng_key
24
58
 
25
59
  __all__ = [
26
- 'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
60
+ 'seed',
61
+ 'set_key',
62
+ 'get_key',
63
+ 'default_rng',
64
+ 'split_key',
65
+ 'split_keys',
66
+ 'seed_context',
67
+ 'restore_key',
27
68
  'self_assign_multi_keys',
69
+ 'clone_rng',
28
70
  ]
29
71
 
30
72
 
31
- def restore_key():
32
- """Restore the default random key."""
73
+ def restore_key() -> None:
74
+ """
75
+ Restore the default random key to its previous state.
76
+
77
+ This function restores the global random state to a previously backed up state.
78
+ It's useful for undoing changes to the random state or implementing checkpoint
79
+ functionality in computational workflows.
80
+
81
+ Note:
82
+ This operation requires that a backup was previously created. If no backup
83
+ exists, this function may not have any effect or may restore to an initial state.
84
+
85
+ Example:
86
+ >>> import brainstate
87
+ >>> brainstate.random.seed(42)
88
+ >>> original_key = brainstate.random.get_key()
89
+ >>> brainstate.random.seed(123) # Change the seed
90
+ >>> brainstate.random.restore_key() # Restore to previous state
91
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
92
+
93
+ See Also:
94
+ - :func:`set_key`: Set a new random key
95
+ - :func:`get_key`: Get the current random key
96
+ - :func:`seed_context`: Temporary seed changes with automatic restoration
97
+ """
33
98
  DEFAULT.restore_key()
34
99
 
35
100
 
36
- def split_key(n: Optional[int] = None, backup: bool = False):
37
- """Create a new seed from the current seed.
101
+ def split_key(n: int = None, backup: bool = False):
102
+ """
103
+ Create new random key(s) from the current seed.
104
+
105
+ This function generates one or more independent random keys by splitting the
106
+ current global random state. It follows JAX's random paradigm, ensuring that
107
+ each split key produces statistically independent random sequences.
108
+
109
+ Args:
110
+ n: The number of keys to generate. If None, returns a single key.
111
+ If an integer, returns an array of n keys.
112
+ backup: Whether to backup the current key before splitting. This allows
113
+ restoration of the original state using :func:`restore_key`.
114
+
115
+ Returns:
116
+ If n is None: A single JAX PRNG key.
117
+ If n is an integer: An array of n independent JAX PRNG keys.
118
+
119
+ Example:
120
+ Generate a single key:
121
+
122
+ >>> import brainstate
123
+ >>> brainstate.random.seed(42)
124
+ >>> key = brainstate.random.split_key()
125
+ >>> print(key.shape)
126
+ (2,)
38
127
 
39
- This function is useful for the consistency with JAX's random paradigm.
128
+ Generate multiple keys for parallel computation:
40
129
 
41
- Parameters
42
- ----------
43
- n : int, optional
44
- The number of seeds to generate.
45
- backup : bool, optional
46
- Whether to backup the current key.
130
+ >>> keys = brainstate.random.split_key(4)
131
+ >>> print(keys.shape)
132
+ (4, 2)
47
133
 
48
- Returns
49
- -------
50
- key : jax.random.PRNGKey
51
- A new random key.
134
+ Use with backup for state restoration:
52
135
 
136
+ >>> original_key = brainstate.random.get_key()
137
+ >>> keys = brainstate.random.split_key(2, backup=True)
138
+ >>> brainstate.random.restore_key()
139
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
140
+
141
+ Note:
142
+ This function advances the global random state. Each call produces
143
+ different keys unless the state is reset.
144
+
145
+ See Also:
146
+ - :func:`split_keys`: Convenience function for multiple keys
147
+ - :func:`seed`: Set the random seed
148
+ - :func:`restore_key`: Restore backed up key
53
149
  """
54
150
  return DEFAULT.split_key(n=n, backup=backup)
55
151
 
56
152
 
57
153
  def split_keys(n: int, backup: bool = False):
58
- """Create multiple seeds from the current seed. This is used
59
- internally by `pmap` and `vmap` to ensure that random numbers
60
- are different in parallel threads.
154
+ """
155
+ Create multiple independent random keys from the current seed.
61
156
 
62
- Parameters
63
- ----------
64
- n : int
65
- The number of seeds to generate.
66
- backup : bool, optional
67
- Whether to backup the current key
157
+ This is a convenience function that generates exactly n independent random keys
158
+ by splitting the current global random state. It's commonly used internally by
159
+ parallel computation functions like `pmap` and `vmap` to ensure that each
160
+ parallel thread gets a unique random key.
68
161
 
69
- Returns
70
- -------
71
- keys : jax.random.PRNGKey
72
- A tuple of JAX random keys.
162
+ Args:
163
+ n: The number of independent keys to generate. Must be a positive integer.
164
+ backup: Whether to backup the current key before splitting. If True,
165
+ the original key can be restored using :func:`restore_key`.
73
166
 
167
+ Returns:
168
+ An array of n independent JAX PRNG keys with shape (n, 2).
169
+
170
+ Raises:
171
+ ValueError: If n is not a positive integer.
172
+
173
+ Example:
174
+ Generate keys for parallel computation:
175
+
176
+ >>> import brainstate
177
+ >>> brainstate.random.seed(42)
178
+ >>> keys = brainstate.random.split_keys(4)
179
+ >>> print(keys.shape)
180
+ (4, 2)
181
+
182
+ Use with vmap for parallel random number generation:
183
+
184
+ >>> import jax
185
+ >>> keys = brainstate.random.split_keys(8)
186
+ >>> @jax.vmap
187
+ ... def generate_random(key):
188
+ ... return jax.random.normal(key, (10,))
189
+ >>> parallel_randoms = generate_random(keys)
190
+ >>> print(parallel_randoms.shape)
191
+ (8, 10)
192
+
193
+ Use with backup for state preservation:
194
+
195
+ >>> original_state = brainstate.random.get_key()
196
+ >>> keys = brainstate.random.split_keys(3, backup=True)
197
+ >>> # ... use keys for computation ...
198
+ >>> brainstate.random.restore_key() # Restore original state
199
+
200
+ Note:
201
+ This function is equivalent to calling :func:`split_key` with n as an argument.
202
+ It's provided as a convenience function with a more explicit name for clarity.
203
+
204
+ See Also:
205
+ - :func:`split_key`: More general key splitting function
206
+ - :func:`self_assign_multi_keys`: Assign multiple keys to global state
207
+ - :func:`seed_context`: Temporary seed changes
74
208
  """
209
+ if not isinstance(n, int) or n <= 0:
210
+ raise ValueError(f"n must be a positive integer, got {n}")
75
211
  return split_key(n, backup=backup)
76
212
 
77
213
 
78
- def self_assign_multi_keys(n: int, backup: bool = True):
214
+ def self_assign_multi_keys(n: int, backup: bool = True) -> None:
79
215
  """
80
- Assign multiple keys to the current key.
216
+ Assign multiple keys to the global random state for parallel access.
217
+
218
+ This function prepares the global random state for parallel computation by
219
+ pre-generating n independent keys. It's particularly useful when you need
220
+ to ensure that parallel computations have access to independent random
221
+ sequences without the overhead of key splitting during computation.
222
+
223
+ Args:
224
+ n: The number of independent keys to pre-generate and assign.
225
+ Must be a positive integer.
226
+ backup: Whether to backup the current random state before assignment.
227
+ If True, the original state can be restored using :func:`restore_key`.
228
+
229
+ Raises:
230
+ ValueError: If n is not a positive integer.
231
+
232
+ Example:
233
+ Prepare for parallel computation:
234
+
235
+ >>> import brainstate
236
+ >>> brainstate.random.seed(42)
237
+ >>> # Prepare 4 independent keys for parallel access
238
+ >>> brainstate.random.self_assign_multi_keys(4)
239
+
240
+ Use in parallel context:
241
+
242
+ >>> # The random state now has 4 independent keys ready for use
243
+ >>> # Each parallel thread can access a different key
244
+
245
+ Note:
246
+ This is an advanced function primarily used internally for optimizing
247
+ parallel random number generation. In most cases, :func:`split_keys`
248
+ provides a more straightforward interface for parallel computation.
249
+
250
+ See Also:
251
+ - :func:`split_keys`: Generate multiple independent keys
252
+ - :func:`restore_key`: Restore backed up state
253
+ - :func:`seed_context`: Temporary state changes
81
254
  """
255
+ if not isinstance(n, int) or n <= 0:
256
+ raise ValueError(f"n must be a positive integer, got {n}")
82
257
  DEFAULT.self_assign_multi_keys(n, backup=backup)
83
258
 
84
259
 
85
- def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
86
- """Clone the random state according to the given setting.
260
+ def clone_rng(seed_or_key: SeedOrKey = None, clone: bool = True) -> RandomState:
261
+ """
262
+ Create a clone of the random state or a new random state.
263
+
264
+ This function provides a flexible way to create independent random states,
265
+ either by cloning the current global state or by creating a new state with
266
+ a specific seed or key. Cloned states are independent and don't affect each
267
+ other when used for random number generation.
87
268
 
88
269
  Args:
89
- seed_or_key: The seed (an integer) or the random key.
90
- clone: Bool. Whether clone the default random state.
270
+ seed_or_key: Optional seed (integer) or JAX random key to initialize
271
+ the new random state. If None, uses the current global state.
272
+ clone: Whether to clone the default random state. If False and
273
+ seed_or_key is None, returns the global state directly (not recommended
274
+ for most use cases as it shares state).
91
275
 
92
276
  Returns:
93
- The random state.
277
+ A RandomState instance that can be used independently for random
278
+ number generation.
279
+
280
+ Example:
281
+ Clone the current global state:
282
+
283
+ >>> import brainstate
284
+ >>> brainstate.random.seed(42)
285
+ >>> rng1 = brainstate.random.clone_rng()
286
+ >>> rng2 = brainstate.random.clone_rng()
287
+ >>> # rng1 and rng2 are independent copies
288
+
289
+ Create a new state with specific seed:
290
+
291
+ >>> rng_fixed = brainstate.random.clone_rng(123)
292
+ >>> # Always produces the same sequences when reset to seed 123
293
+
294
+ Use for independent computations:
295
+
296
+ >>> rng = brainstate.random.clone_rng(456)
297
+ >>> values1 = rng.normal(size=5)
298
+ >>> values2 = rng.normal(size=5)
299
+ >>> # values1 and values2 are different but reproducible
300
+
301
+ Note:
302
+ Cloned random states are completely independent. Changes to one state
303
+ (like advancing through random number generation) don't affect others.
304
+
305
+ See Also:
306
+ - :func:`default_rng`: Get or create a random state
307
+ - :func:`seed`: Set the global random seed
308
+ - :class:`RandomState`: The random state class
94
309
  """
95
310
  if seed_or_key is None:
96
311
  return DEFAULT.clone() if clone else DEFAULT
@@ -98,15 +313,53 @@ def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
98
313
  return RandomState(seed_or_key)
99
314
 
100
315
 
101
- def default_rng(seed_or_key=None) -> RandomState:
316
+ def default_rng(seed_or_key: SeedOrKey = None) -> RandomState:
102
317
  """
103
- Get the default random state.
318
+ Get the default random state or create a new one with specified seed.
319
+
320
+ This function provides access to the global random state used throughout
321
+ BrainState, or creates a new independent random state if a seed is provided.
322
+ It's the primary interface for obtaining random state objects in BrainState.
104
323
 
105
324
  Args:
106
- seed_or_key: The seed (an integer) or the jax random key.
325
+ seed_or_key: Optional seed (integer) or JAX random key. If None,
326
+ returns the global default random state. If provided, creates
327
+ a new independent RandomState with the specified seed.
107
328
 
108
329
  Returns:
109
- The random state.
330
+ The default RandomState if seed_or_key is None, otherwise a new
331
+ RandomState initialized with the provided seed or key.
332
+
333
+ Example:
334
+ Get the global random state:
335
+
336
+ >>> import brainstate
337
+ >>> rng = brainstate.random.default_rng()
338
+ >>> # rng is the global random state used by brainstate.random functions
339
+
340
+ Create a new independent random state:
341
+
342
+ >>> rng_local = brainstate.random.default_rng(42)
343
+ >>> values = rng_local.normal(size=10)
344
+
345
+ Use for reproducible local computations:
346
+
347
+ >>> def reproducible_computation():
348
+ ... local_rng = brainstate.random.default_rng(12345)
349
+ ... return local_rng.uniform(size=5)
350
+ >>> result1 = reproducible_computation()
351
+ >>> result2 = reproducible_computation()
352
+ >>> assert np.allclose(result1, result2) # Always the same
353
+
354
+ Note:
355
+ When seed_or_key is None, this returns the actual global state object.
356
+ Modifications to this state (through random number generation) will
357
+ affect all subsequent calls to global random functions.
358
+
359
+ See Also:
360
+ - :func:`clone_rng`: Create independent clones of random states
361
+ - :func:`seed`: Set the global random seed
362
+ - :class:`RandomState`: The underlying random state implementation
110
363
  """
111
364
  if seed_or_key is None:
112
365
  return DEFAULT
@@ -114,45 +367,191 @@ def default_rng(seed_or_key=None) -> RandomState:
114
367
  return RandomState(seed_or_key)
115
368
 
116
369
 
117
- def set_key(seed_or_key: SeedOrKey):
118
- """Sets a new random key.
370
+ def set_key(seed_or_key: SeedOrKey) -> None:
371
+ """
372
+ Set a new random key for the global random state.
373
+
374
+ This function updates the global random state with a new key, which can be
375
+ either an integer seed or a JAX PRNG key. All subsequent calls to global
376
+ random functions will use this new key state.
377
+
378
+ Args:
379
+ seed_or_key: The new random key to set. Can be:
380
+ - An integer seed (will be converted to a JAX PRNG key)
381
+ - A JAX PRNG key array
382
+ - A numpy array representing a PRNG key
119
383
 
120
- Parameters
121
- ----------
122
- seed_or_key: int
123
- The random key.
384
+ Raises:
385
+ ValueError: If the provided key is not in a valid format.
386
+
387
+ Example:
388
+ Set with integer seed:
389
+
390
+ >>> import brainstate
391
+ >>> brainstate.random.set_key(42)
392
+ >>> values1 = brainstate.random.rand(3)
393
+
394
+ Set with JAX key:
395
+
396
+ >>> import jax
397
+ >>> key = jax.random.key(123)
398
+ >>> brainstate.random.set_key(key)
399
+ >>> values2 = brainstate.random.rand(3)
400
+
401
+ Restore reproducible state:
402
+
403
+ >>> brainstate.random.set_key(42)
404
+ >>> # Now random functions will produce the same sequences as first example
405
+
406
+ Note:
407
+ This function immediately changes the global random state. All threads
408
+ and computations using the global random functions will be affected.
409
+
410
+ See Also:
411
+ - :func:`get_key`: Get the current random key
412
+ - :func:`seed`: Set seed (also affects NumPy)
413
+ - :func:`restore_key`: Restore a backed up key
124
414
  """
125
415
  if isinstance(seed_or_key, int):
126
- # key = jax.random.key(seed_or_key)
127
- key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jrjax.random.key(seed_or_key)
416
+ # Create key using appropriate JAX function based on version
417
+ key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
128
418
  elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
129
419
  if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
130
420
  key = seed_or_key
131
421
  elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
132
422
  key = seed_or_key
133
423
  else:
134
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
424
+ raise ValueError(
425
+ f"seed_or_key should be an integer, a JAX PRNG key, or a uint32 array of size 2. "
426
+ f"Got array with dtype {seed_or_key.dtype} and size {seed_or_key.size}."
427
+ )
428
+ else:
429
+ raise ValueError(
430
+ f"seed_or_key must be an integer or a JAX-compatible array. "
431
+ f"Got {type(seed_or_key)}."
432
+ )
135
433
  DEFAULT.set_key(key)
136
434
 
137
435
 
138
436
  def get_key():
139
- """Get a new random key.
437
+ """
438
+ Get the current global random key.
439
+
440
+ This function returns the current random key used by the global random state.
441
+ The returned key represents the internal state of the JAX PRNG and can be used
442
+ to restore the random state later or to create independent random number generators.
443
+
444
+ Returns:
445
+ The current JAX PRNG key as a numpy array. This is typically a 2-element
446
+ uint32 array representing the internal state of the random number generator.
447
+
448
+ Example:
449
+ Get and store the current random state:
450
+
451
+ >>> import brainstate
452
+ >>> brainstate.random.seed(42)
453
+ >>> current_key = brainstate.random.get_key()
454
+ >>> print(current_key.shape)
455
+ (2,)
456
+
457
+ Use the key to restore state later:
458
+
459
+ >>> # Generate some random numbers
460
+ >>> values1 = brainstate.random.rand(3)
461
+ >>> # Restore the previous state
462
+ >>> brainstate.random.set_key(current_key)
463
+ >>> values2 = brainstate.random.rand(3)
464
+ >>> # values1 and values2 will be identical
465
+
466
+ Compare keys for debugging:
467
+
468
+ >>> brainstate.random.seed(123)
469
+ >>> key1 = brainstate.random.get_key()
470
+ >>> brainstate.random.seed(123)
471
+ >>> key2 = brainstate.random.get_key()
472
+ >>> assert jax.numpy.array_equal(key1, key2) # Same seed gives same key
473
+
474
+ Note:
475
+ The returned key is a snapshot of the current state. Subsequent calls to
476
+ random functions will advance the internal state, so calling get_key()
477
+ again will return a different key unless the state is reset.
478
+
479
+ See Also:
480
+ - :func:`set_key`: Set a new random key
481
+ - :func:`seed`: Set the random seed (also affects NumPy)
482
+ - :func:`split_key`: Create new keys from current state
483
+ - :func:`seed_context`: Temporary seed changes with automatic restoration
140
484
 
141
- Returns
142
- -------
143
- seed_or_key: int
144
- The random key.
145
485
  """
146
486
  return DEFAULT.value
147
487
 
148
488
 
149
489
  def seed(seed_or_key: SeedOrKey = None):
150
- """Sets a new random seed.
490
+ """
491
+ Set the global random seed for both JAX and NumPy.
492
+
493
+ This function initializes the global random state with a new seed, affecting
494
+ both JAX and NumPy random number generators. It ensures reproducible random
495
+ number generation across the entire BrainState ecosystem.
496
+
497
+ Args:
498
+ seed_or_key: The seed or key to set. Can be:
499
+ - None: Generates a random seed automatically
500
+ - int: An integer seed (0 to 2^32-1)
501
+ - JAX PRNG key: A JAX random key array
502
+ If None, a random seed is generated using NumPy's random generator.
503
+
504
+ Raises:
505
+ ValueError: If seed_or_key is not a valid seed format (not an integer,
506
+ valid JAX key, or None).
507
+
508
+ Example:
509
+ Set a specific seed for reproducible results:
510
+
511
+ >>> import brainstate
512
+ >>> brainstate.random.seed(42)
513
+ >>> values1 = brainstate.random.rand(3)
514
+ >>> brainstate.random.seed(42) # Reset to same seed
515
+ >>> values2 = brainstate.random.rand(3)
516
+ >>> assert np.allclose(values1, values2) # Same values
517
+
518
+ Use automatic random seeding:
519
+
520
+ >>> brainstate.random.seed() # Uses random seed
521
+ >>> # Each call will produce different sequences
522
+
523
+ Use with JAX keys:
524
+
525
+ >>> import jax
526
+ >>> key = jax.random.key(123)
527
+ >>> brainstate.random.seed(key)
528
+ >>> # Now both JAX and NumPy use consistent seeds
529
+
530
+ Ensure reproducibility in scientific experiments:
531
+
532
+ >>> def experiment():
533
+ ... brainstate.random.seed(12345) # Fixed seed for reproducibility
534
+ ... data = brainstate.random.normal(size=(100, 10))
535
+ ... return data.mean()
536
+ >>> result1 = experiment()
537
+ >>> result2 = experiment()
538
+ >>> assert result1 == result2 # Always same result
539
+
540
+ Note:
541
+ - This function affects the global random state used by all BrainState
542
+ random functions and NumPy's global random state.
543
+ - When using automatic seeding (seed_or_key=None), NumPy's seed is not
544
+ set to maintain its current state.
545
+ - JAX compilation is handled automatically with compile-time evaluation.
546
+ - For JAX keys, only the first element is used to seed NumPy to maintain
547
+ compatibility between the two random systems.
548
+
549
+ See Also:
550
+ - :func:`set_key`: Set only the JAX random key
551
+ - :func:`get_key`: Get the current random key
552
+ - :func:`seed_context`: Temporary seed changes
553
+ - :func:`split_key`: Create independent random keys
151
554
 
152
- Parameters
153
- ----------
154
- seed_or_key: int, optional
155
- The random seed (an integer) or jax random key.
156
555
  """
157
556
  with jax.ensure_compile_time_eval():
158
557
  _set_numpy_seed = True
@@ -179,24 +578,90 @@ def seed(seed_or_key: SeedOrKey = None):
179
578
  @contextmanager
180
579
  def seed_context(seed_or_key: SeedOrKey):
181
580
  """
182
- A context manager that sets the random seed for the duration of the block.
183
-
184
- Examples:
581
+ Context manager for temporary random seed changes with automatic restoration.
185
582
 
186
- >>> import brainstate as brainstate
187
- >>> print(brainstate.random.rand(2))
188
- [0.57721865 0.9820676 ]
189
- >>> print(brainstate.random.rand(2))
190
- [0.8511752 0.95312667]
191
- >>> with brainstate.random.seed_context(42):
192
- ... print(brainstate.random.rand(2))
193
- [0.95598125 0.4032725 ]
194
- >>> with brainstate.random.seed_context(42):
195
- ... print(brainstate.random.rand(2))
196
- [0.95598125 0.4032725 ]
583
+ This context manager temporarily changes the global random seed for the duration
584
+ of the block, then automatically restores the previous random state when exiting.
585
+ It's ideal for ensuring reproducible computations in specific code sections without
586
+ permanently affecting the global random state.
197
587
 
198
588
  Args:
199
- seed_or_key: The seed (an integer) or jax random key.
589
+ seed_or_key: The temporary seed or key to use within the context. Can be:
590
+ - int: An integer seed for reproducible sequences
591
+ - JAX PRNG key: A JAX random key array
592
+ The seed affects both JAX and NumPy random states during the context.
593
+
594
+ Yields:
595
+ None: The context manager doesn't yield any value, but provides a
596
+ controlled random environment for the enclosed code block.
597
+
598
+ Example:
599
+ Reproducible computations without affecting global state:
600
+
601
+ >>> import brainstate
602
+ >>> # Global state remains unaffected
603
+ >>> global_values1 = brainstate.random.rand(2)
604
+ >>>
605
+ >>> with brainstate.random.seed_context(42):
606
+ ... temp_values1 = brainstate.random.rand(2)
607
+ ... print(f"First run: {temp_values1}")
608
+ [0.95598125 0.4032725 ]
609
+ >>>
610
+ >>> with brainstate.random.seed_context(42):
611
+ ... temp_values2 = brainstate.random.rand(2)
612
+ ... print(f"Second run: {temp_values2}")
613
+ [0.95598125 0.4032725 ]
614
+ >>>
615
+ >>> # Values are identical within context
616
+ >>> assert np.allclose(temp_values1, temp_values2)
617
+ >>>
618
+ >>> # Global state continues from where it left off
619
+ >>> global_values2 = brainstate.random.rand(2)
620
+
621
+ Nested contexts for complex scenarios:
622
+
623
+ >>> with brainstate.random.seed_context(123):
624
+ ... outer_values = brainstate.random.rand(2)
625
+ ... with brainstate.random.seed_context(456):
626
+ ... inner_values = brainstate.random.rand(2)
627
+ ... # Outer context is restored here
628
+ ... outer_values2 = brainstate.random.rand(2)
629
+
630
+ Exception safety - state is restored even on errors:
631
+
632
+ >>> try:
633
+ ... with brainstate.random.seed_context(789):
634
+ ... some_values = brainstate.random.rand(3)
635
+ ... raise ValueError("Something went wrong")
636
+ ... except ValueError:
637
+ ... pass
638
+ >>> # Random state is properly restored
639
+
640
+ Testing reproducible algorithms:
641
+
642
+ >>> def test_algorithm():
643
+ ... with brainstate.random.seed_context(42):
644
+ ... data = brainstate.random.normal(size=(100,))
645
+ ... return data.mean()
646
+ >>>
647
+ >>> result1 = test_algorithm()
648
+ >>> result2 = test_algorithm()
649
+ >>> assert result1 == result2 # Always same result
650
+
651
+ Note:
652
+ - The context manager saves and restores the complete JAX random state
653
+ - NumPy's random state is also temporarily modified during the context
654
+ - Nested contexts work correctly - each level restores its own state
655
+ - Exception safety is guaranteed - random state is restored even if
656
+ exceptions occur within the context
657
+ - This is more convenient than manually saving/restoring state with
658
+ get_key() and set_key()
659
+
660
+ See Also:
661
+ - :func:`seed`: Permanently set the global random seed
662
+ - :func:`get_key`: Get the current random key for manual state management
663
+ - :func:`set_key`: Set the random key for manual state management
664
+ - :func:`clone_rng`: Create independent random states
200
665
 
201
666
  """
202
667
  # get the old random key