brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/random/_rand_seed.py +675 -210
@@ -1,210 +1,675 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from contextlib import contextmanager
17
- from typing import Optional
18
-
19
- import jax
20
- import numpy as np
21
-
22
- from brainstate.typing import SeedOrKey
23
- from ._rand_state import RandomState, DEFAULT, use_prng_key
24
-
25
- __all__ = [
26
- 'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
27
- 'self_assign_multi_keys',
28
- ]
29
-
30
-
31
- def restore_key():
32
- """Restore the default random key."""
33
- DEFAULT.restore_key()
34
-
35
-
36
- def split_key(n: Optional[int] = None, backup: bool = False):
37
- """Create a new seed from the current seed.
38
-
39
- This function is useful for the consistency with JAX's random paradigm.
40
-
41
- Parameters
42
- ----------
43
- n : int, optional
44
- The number of seeds to generate.
45
- backup : bool, optional
46
- Whether to backup the current key.
47
-
48
- Returns
49
- -------
50
- key : jax.random.PRNGKey
51
- A new random key.
52
-
53
- """
54
- return DEFAULT.split_key(n=n, backup=backup)
55
-
56
-
57
- def split_keys(n: int, backup: bool = False):
58
- """Create multiple seeds from the current seed. This is used
59
- internally by `pmap` and `vmap` to ensure that random numbers
60
- are different in parallel threads.
61
-
62
- Parameters
63
- ----------
64
- n : int
65
- The number of seeds to generate.
66
- backup : bool, optional
67
- Whether to backup the current key
68
-
69
- Returns
70
- -------
71
- keys : jax.random.PRNGKey
72
- A tuple of JAX random keys.
73
-
74
- """
75
- return split_key(n, backup=backup)
76
-
77
-
78
- def self_assign_multi_keys(n: int, backup: bool = True):
79
- """
80
- Assign multiple keys to the current key.
81
- """
82
- DEFAULT.self_assign_multi_keys(n, backup=backup)
83
-
84
-
85
- def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
86
- """Clone the random state according to the given setting.
87
-
88
- Args:
89
- seed_or_key: The seed (an integer) or the random key.
90
- clone: Bool. Whether clone the default random state.
91
-
92
- Returns:
93
- The random state.
94
- """
95
- if seed_or_key is None:
96
- return DEFAULT.clone() if clone else DEFAULT
97
- else:
98
- return RandomState(seed_or_key)
99
-
100
-
101
- def default_rng(seed_or_key=None) -> RandomState:
102
- """
103
- Get the default random state.
104
-
105
- Args:
106
- seed_or_key: The seed (an integer) or the jax random key.
107
-
108
- Returns:
109
- The random state.
110
- """
111
- if seed_or_key is None:
112
- return DEFAULT
113
- else:
114
- return RandomState(seed_or_key)
115
-
116
-
117
- def set_key(seed_or_key: SeedOrKey):
118
- """Sets a new random key.
119
-
120
- Parameters
121
- ----------
122
- seed_or_key: int
123
- The random key.
124
- """
125
- if isinstance(seed_or_key, int):
126
- # key = jax.random.key(seed_or_key)
127
- key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jrjax.random.key(seed_or_key)
128
- elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
129
- if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
130
- key = seed_or_key
131
- elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
132
- key = seed_or_key
133
- else:
134
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
135
- DEFAULT.set_key(key)
136
-
137
-
138
- def get_key():
139
- """Get a new random key.
140
-
141
- Returns
142
- -------
143
- seed_or_key: int
144
- The random key.
145
- """
146
- return DEFAULT.value
147
-
148
-
149
- def seed(seed_or_key: SeedOrKey = None):
150
- """Sets a new random seed.
151
-
152
- Parameters
153
- ----------
154
- seed_or_key: int, optional
155
- The random seed (an integer) or jax random key.
156
- """
157
- with jax.ensure_compile_time_eval():
158
- _set_numpy_seed = True
159
- if seed_or_key is None:
160
- seed_or_key = np.random.randint(0, 100000)
161
- _set_numpy_seed = False
162
-
163
- # numpy random seed
164
- if _set_numpy_seed:
165
- try:
166
- if np.size(seed_or_key) == 1: # seed
167
- np.random.seed(seed_or_key)
168
- elif np.size(seed_or_key) == 2: # jax random key
169
- np.random.seed(seed_or_key[0])
170
- else:
171
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
172
- except jax.errors.TracerArrayConversionError:
173
- pass
174
-
175
- # jax random seed
176
- DEFAULT.seed(seed_or_key)
177
-
178
-
179
- @contextmanager
180
- def seed_context(seed_or_key: SeedOrKey):
181
- """
182
- A context manager that sets the random seed for the duration of the block.
183
-
184
- Examples:
185
-
186
- >>> import brainstate as brainstate
187
- >>> print(brainstate.random.rand(2))
188
- [0.57721865 0.9820676 ]
189
- >>> print(brainstate.random.rand(2))
190
- [0.8511752 0.95312667]
191
- >>> with brainstate.random.seed_context(42):
192
- ... print(brainstate.random.rand(2))
193
- [0.95598125 0.4032725 ]
194
- >>> with brainstate.random.seed_context(42):
195
- ... print(brainstate.random.rand(2))
196
- [0.95598125 0.4032725 ]
197
-
198
- Args:
199
- seed_or_key: The seed (an integer) or jax random key.
200
-
201
- """
202
- # get the old random key
203
- old_jrand_key = DEFAULT.value
204
- try:
205
- # set the seed of jax random state
206
- DEFAULT.seed(seed_or_key)
207
- yield
208
- finally:
209
- # restore the random state
210
- DEFAULT.seed(old_jrand_key)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Random seed management utilities for BrainState.
18
+
19
+ This module provides comprehensive random seed management functionality, enabling
20
+ reproducible computations across JAX and NumPy backends. It supports both traditional
21
+ integer seeds and JAX's PRNG key system, providing a unified interface for random
22
+ number generation in scientific computing and machine learning applications.
23
+
24
+ Key Features:
25
+ - Unified seed management for JAX and NumPy
26
+ - Context managers for temporary seed changes
27
+ - Key splitting for parallel computation
28
+ - Automatic seed backup and restoration
29
+ - Thread-safe random state management
30
+
31
+ Example:
32
+ Basic usage for reproducible random number generation:
33
+
34
+ >>> import brainstate
35
+ >>> brainstate.random.seed(42)
36
+ >>> print(brainstate.random.rand(3))
37
+ [0.95598125 0.4032725 0.96086407]
38
+
39
+ Using context managers for temporary seeds:
40
+
41
+ >>> with brainstate.random.seed_context(123):
42
+ ... values = brainstate.random.rand(2)
43
+ >>> print(values) # Reproducible output
44
+
45
+ Key splitting for parallel computation:
46
+
47
+ >>> keys = brainstate.random.split_keys(4) # Generate 4 independent keys
48
+ >>> # Use keys for parallel random number generation
49
+ """
50
+
51
+ from contextlib import contextmanager
52
+
53
+ import jax
54
+ import numpy as np
55
+
56
+ from brainstate.typing import SeedOrKey
57
+ from ._rand_state import RandomState, DEFAULT, use_prng_key
58
+
59
+ __all__ = [
60
+ 'seed',
61
+ 'set_key',
62
+ 'get_key',
63
+ 'default_rng',
64
+ 'split_key',
65
+ 'split_keys',
66
+ 'seed_context',
67
+ 'restore_key',
68
+ 'self_assign_multi_keys',
69
+ 'clone_rng',
70
+ ]
71
+
72
+
73
+ def restore_key() -> None:
74
+ """
75
+ Restore the default random key to its previous state.
76
+
77
+ This function restores the global random state to a previously backed up state.
78
+ It's useful for undoing changes to the random state or implementing checkpoint
79
+ functionality in computational workflows.
80
+
81
+ Note:
82
+ This operation requires that a backup was previously created. If no backup
83
+ exists, this function may not have any effect or may restore to an initial state.
84
+
85
+ Example:
86
+ >>> import brainstate
87
+ >>> brainstate.random.seed(42)
88
+ >>> original_key = brainstate.random.get_key()
89
+ >>> brainstate.random.seed(123) # Change the seed
90
+ >>> brainstate.random.restore_key() # Restore to previous state
91
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
92
+
93
+ See Also:
94
+ - :func:`set_key`: Set a new random key
95
+ - :func:`get_key`: Get the current random key
96
+ - :func:`seed_context`: Temporary seed changes with automatic restoration
97
+ """
98
+ DEFAULT.restore_key()
99
+
100
+
101
+ def split_key(n: int = None, backup: bool = False):
102
+ """
103
+ Create new random key(s) from the current seed.
104
+
105
+ This function generates one or more independent random keys by splitting the
106
+ current global random state. It follows JAX's random paradigm, ensuring that
107
+ each split key produces statistically independent random sequences.
108
+
109
+ Args:
110
+ n: The number of keys to generate. If None, returns a single key.
111
+ If an integer, returns an array of n keys.
112
+ backup: Whether to backup the current key before splitting. This allows
113
+ restoration of the original state using :func:`restore_key`.
114
+
115
+ Returns:
116
+ If n is None: A single JAX PRNG key.
117
+ If n is an integer: An array of n independent JAX PRNG keys.
118
+
119
+ Example:
120
+ Generate a single key:
121
+
122
+ >>> import brainstate
123
+ >>> brainstate.random.seed(42)
124
+ >>> key = brainstate.random.split_key()
125
+ >>> print(key.shape)
126
+ (2,)
127
+
128
+ Generate multiple keys for parallel computation:
129
+
130
+ >>> keys = brainstate.random.split_key(4)
131
+ >>> print(keys.shape)
132
+ (4, 2)
133
+
134
+ Use with backup for state restoration:
135
+
136
+ >>> original_key = brainstate.random.get_key()
137
+ >>> keys = brainstate.random.split_key(2, backup=True)
138
+ >>> brainstate.random.restore_key()
139
+ >>> assert np.array_equal(brainstate.random.get_key(), original_key)
140
+
141
+ Note:
142
+ This function advances the global random state. Each call produces
143
+ different keys unless the state is reset.
144
+
145
+ See Also:
146
+ - :func:`split_keys`: Convenience function for multiple keys
147
+ - :func:`seed`: Set the random seed
148
+ - :func:`restore_key`: Restore backed up key
149
+ """
150
+ return DEFAULT.split_key(n=n, backup=backup)
151
+
152
+
153
+ def split_keys(n: int, backup: bool = False):
154
+ """
155
+ Create multiple independent random keys from the current seed.
156
+
157
+ This is a convenience function that generates exactly n independent random keys
158
+ by splitting the current global random state. It's commonly used internally by
159
+ parallel computation functions like `pmap` and `vmap` to ensure that each
160
+ parallel thread gets a unique random key.
161
+
162
+ Args:
163
+ n: The number of independent keys to generate. Must be a positive integer.
164
+ backup: Whether to backup the current key before splitting. If True,
165
+ the original key can be restored using :func:`restore_key`.
166
+
167
+ Returns:
168
+ An array of n independent JAX PRNG keys with shape (n, 2).
169
+
170
+ Raises:
171
+ ValueError: If n is not a positive integer.
172
+
173
+ Example:
174
+ Generate keys for parallel computation:
175
+
176
+ >>> import brainstate
177
+ >>> brainstate.random.seed(42)
178
+ >>> keys = brainstate.random.split_keys(4)
179
+ >>> print(keys.shape)
180
+ (4, 2)
181
+
182
+ Use with vmap for parallel random number generation:
183
+
184
+ >>> import jax
185
+ >>> keys = brainstate.random.split_keys(8)
186
+ >>> @jax.vmap
187
+ ... def generate_random(key):
188
+ ... return jax.random.normal(key, (10,))
189
+ >>> parallel_randoms = generate_random(keys)
190
+ >>> print(parallel_randoms.shape)
191
+ (8, 10)
192
+
193
+ Use with backup for state preservation:
194
+
195
+ >>> original_state = brainstate.random.get_key()
196
+ >>> keys = brainstate.random.split_keys(3, backup=True)
197
+ >>> # ... use keys for computation ...
198
+ >>> brainstate.random.restore_key() # Restore original state
199
+
200
+ Note:
201
+ This function is equivalent to calling :func:`split_key` with n as an argument.
202
+ It's provided as a convenience function with a more explicit name for clarity.
203
+
204
+ See Also:
205
+ - :func:`split_key`: More general key splitting function
206
+ - :func:`self_assign_multi_keys`: Assign multiple keys to global state
207
+ - :func:`seed_context`: Temporary seed changes
208
+ """
209
+ if not isinstance(n, int) or n <= 0:
210
+ raise ValueError(f"n must be a positive integer, got {n}")
211
+ return split_key(n, backup=backup)
212
+
213
+
214
+ def self_assign_multi_keys(n: int, backup: bool = True) -> None:
215
+ """
216
+ Assign multiple keys to the global random state for parallel access.
217
+
218
+ This function prepares the global random state for parallel computation by
219
+ pre-generating n independent keys. It's particularly useful when you need
220
+ to ensure that parallel computations have access to independent random
221
+ sequences without the overhead of key splitting during computation.
222
+
223
+ Args:
224
+ n: The number of independent keys to pre-generate and assign.
225
+ Must be a positive integer.
226
+ backup: Whether to backup the current random state before assignment.
227
+ If True, the original state can be restored using :func:`restore_key`.
228
+
229
+ Raises:
230
+ ValueError: If n is not a positive integer.
231
+
232
+ Example:
233
+ Prepare for parallel computation:
234
+
235
+ >>> import brainstate
236
+ >>> brainstate.random.seed(42)
237
+ >>> # Prepare 4 independent keys for parallel access
238
+ >>> brainstate.random.self_assign_multi_keys(4)
239
+
240
+ Use in parallel context:
241
+
242
+ >>> # The random state now has 4 independent keys ready for use
243
+ >>> # Each parallel thread can access a different key
244
+
245
+ Note:
246
+ This is an advanced function primarily used internally for optimizing
247
+ parallel random number generation. In most cases, :func:`split_keys`
248
+ provides a more straightforward interface for parallel computation.
249
+
250
+ See Also:
251
+ - :func:`split_keys`: Generate multiple independent keys
252
+ - :func:`restore_key`: Restore backed up state
253
+ - :func:`seed_context`: Temporary state changes
254
+ """
255
+ if not isinstance(n, int) or n <= 0:
256
+ raise ValueError(f"n must be a positive integer, got {n}")
257
+ DEFAULT.self_assign_multi_keys(n, backup=backup)
258
+
259
+
260
+ def clone_rng(seed_or_key: SeedOrKey = None, clone: bool = True) -> RandomState:
261
+ """
262
+ Create a clone of the random state or a new random state.
263
+
264
+ This function provides a flexible way to create independent random states,
265
+ either by cloning the current global state or by creating a new state with
266
+ a specific seed or key. Cloned states are independent and don't affect each
267
+ other when used for random number generation.
268
+
269
+ Args:
270
+ seed_or_key: Optional seed (integer) or JAX random key to initialize
271
+ the new random state. If None, uses the current global state.
272
+ clone: Whether to clone the default random state. If False and
273
+ seed_or_key is None, returns the global state directly (not recommended
274
+ for most use cases as it shares state).
275
+
276
+ Returns:
277
+ A RandomState instance that can be used independently for random
278
+ number generation.
279
+
280
+ Example:
281
+ Clone the current global state:
282
+
283
+ >>> import brainstate
284
+ >>> brainstate.random.seed(42)
285
+ >>> rng1 = brainstate.random.clone_rng()
286
+ >>> rng2 = brainstate.random.clone_rng()
287
+ >>> # rng1 and rng2 are independent copies
288
+
289
+ Create a new state with specific seed:
290
+
291
+ >>> rng_fixed = brainstate.random.clone_rng(123)
292
+ >>> # Always produces the same sequences when reset to seed 123
293
+
294
+ Use for independent computations:
295
+
296
+ >>> rng = brainstate.random.clone_rng(456)
297
+ >>> values1 = rng.normal(size=5)
298
+ >>> values2 = rng.normal(size=5)
299
+ >>> # values1 and values2 are different but reproducible
300
+
301
+ Note:
302
+ Cloned random states are completely independent. Changes to one state
303
+ (like advancing through random number generation) don't affect others.
304
+
305
+ See Also:
306
+ - :func:`default_rng`: Get or create a random state
307
+ - :func:`seed`: Set the global random seed
308
+ - :class:`RandomState`: The random state class
309
+ """
310
+ if seed_or_key is None:
311
+ return DEFAULT.clone() if clone else DEFAULT
312
+ else:
313
+ return RandomState(seed_or_key)
314
+
315
+
316
+ def default_rng(seed_or_key: SeedOrKey = None) -> RandomState:
317
+ """
318
+ Get the default random state or create a new one with specified seed.
319
+
320
+ This function provides access to the global random state used throughout
321
+ BrainState, or creates a new independent random state if a seed is provided.
322
+ It's the primary interface for obtaining random state objects in BrainState.
323
+
324
+ Args:
325
+ seed_or_key: Optional seed (integer) or JAX random key. If None,
326
+ returns the global default random state. If provided, creates
327
+ a new independent RandomState with the specified seed.
328
+
329
+ Returns:
330
+ The default RandomState if seed_or_key is None, otherwise a new
331
+ RandomState initialized with the provided seed or key.
332
+
333
+ Example:
334
+ Get the global random state:
335
+
336
+ >>> import brainstate
337
+ >>> rng = brainstate.random.default_rng()
338
+ >>> # rng is the global random state used by brainstate.random functions
339
+
340
+ Create a new independent random state:
341
+
342
+ >>> rng_local = brainstate.random.default_rng(42)
343
+ >>> values = rng_local.normal(size=10)
344
+
345
+ Use for reproducible local computations:
346
+
347
+ >>> def reproducible_computation():
348
+ ... local_rng = brainstate.random.default_rng(12345)
349
+ ... return local_rng.uniform(size=5)
350
+ >>> result1 = reproducible_computation()
351
+ >>> result2 = reproducible_computation()
352
+ >>> assert np.allclose(result1, result2) # Always the same
353
+
354
+ Note:
355
+ When seed_or_key is None, this returns the actual global state object.
356
+ Modifications to this state (through random number generation) will
357
+ affect all subsequent calls to global random functions.
358
+
359
+ See Also:
360
+ - :func:`clone_rng`: Create independent clones of random states
361
+ - :func:`seed`: Set the global random seed
362
+ - :class:`RandomState`: The underlying random state implementation
363
+ """
364
+ if seed_or_key is None:
365
+ return DEFAULT
366
+ else:
367
+ return RandomState(seed_or_key)
368
+
369
+
370
+ def set_key(seed_or_key: SeedOrKey) -> None:
371
+ """
372
+ Set a new random key for the global random state.
373
+
374
+ This function updates the global random state with a new key, which can be
375
+ either an integer seed or a JAX PRNG key. All subsequent calls to global
376
+ random functions will use this new key state.
377
+
378
+ Args:
379
+ seed_or_key: The new random key to set. Can be:
380
+ - An integer seed (will be converted to a JAX PRNG key)
381
+ - A JAX PRNG key array
382
+ - A numpy array representing a PRNG key
383
+
384
+ Raises:
385
+ ValueError: If the provided key is not in a valid format.
386
+
387
+ Example:
388
+ Set with integer seed:
389
+
390
+ >>> import brainstate
391
+ >>> brainstate.random.set_key(42)
392
+ >>> values1 = brainstate.random.rand(3)
393
+
394
+ Set with JAX key:
395
+
396
+ >>> import jax
397
+ >>> key = jax.random.key(123)
398
+ >>> brainstate.random.set_key(key)
399
+ >>> values2 = brainstate.random.rand(3)
400
+
401
+ Restore reproducible state:
402
+
403
+ >>> brainstate.random.set_key(42)
404
+ >>> # Now random functions will produce the same sequences as first example
405
+
406
+ Note:
407
+ This function immediately changes the global random state. All threads
408
+ and computations using the global random functions will be affected.
409
+
410
+ See Also:
411
+ - :func:`get_key`: Get the current random key
412
+ - :func:`seed`: Set seed (also affects NumPy)
413
+ - :func:`restore_key`: Restore a backed up key
414
+ """
415
+ if isinstance(seed_or_key, int):
416
+ # Create key using appropriate JAX function based on version
417
+ key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
418
+ elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
419
+ if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
420
+ key = seed_or_key
421
+ elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
422
+ key = seed_or_key
423
+ else:
424
+ raise ValueError(
425
+ f"seed_or_key should be an integer, a JAX PRNG key, or a uint32 array of size 2. "
426
+ f"Got array with dtype {seed_or_key.dtype} and size {seed_or_key.size}."
427
+ )
428
+ else:
429
+ raise ValueError(
430
+ f"seed_or_key must be an integer or a JAX-compatible array. "
431
+ f"Got {type(seed_or_key)}."
432
+ )
433
+ DEFAULT.set_key(key)
434
+
435
+
436
def get_key():
    """
    Return the key currently held by the global random state.

    The value is read directly from ``DEFAULT.value``; this call does not
    itself reseed or otherwise modify the global state.

    Returns:
        The object stored in ``DEFAULT.value`` at the time of the call.
        Presumably a JAX PRNG key (for legacy raw keys, a 2-element uint32
        array) -- confirm against the ``DEFAULT`` implementation.

    Example:
        Save the key, draw some numbers, then rewind:

        >>> import brainstate
        >>> brainstate.random.seed(42)
        >>> saved = brainstate.random.get_key()
        >>> first = brainstate.random.rand(3)
        >>> brainstate.random.set_key(saved)
        >>> second = brainstate.random.rand(3)

    Note:
        The result is a snapshot. Drawing random numbers afterwards is
        expected to move the global state on, so a later ``get_key()`` call
        may return a different key.

    See Also:
        - :func:`set_key`: Set a new random key
        - :func:`seed`: Set the random seed (also affects NumPy)
        - :func:`seed_context`: Temporary seed changes with automatic restoration

    """
    current_key = DEFAULT.value
    return current_key
487
+
488
+
489
def seed(seed_or_key: SeedOrKey = None):
    """
    Set the global random seed for both JAX and NumPy.

    The JAX side is always reseeded through ``DEFAULT.seed``. NumPy's global
    generator is reseeded only when an explicit ``seed_or_key`` is supplied.

    Args:
        seed_or_key: The seed or key to set. Can be:
            - None: a seed is drawn with ``np.random.randint(0, 100000)`` and
              NumPy's own state is left as it is.
            - int (or any size-1 value): used directly as the NumPy seed.
            - Raw JAX key (size-2 array): its first element seeds NumPy.
            - Typed JAX key (``jax.random.key(...)``): the first word of its
              underlying key data seeds NumPy.

    Raises:
        ValueError: If ``seed_or_key`` is neither a typed key nor a value of
            size 1 or 2. ``DEFAULT.seed`` may apply further validation of its
            own.

    Example:
        Set a specific seed for reproducible results:

        >>> import brainstate
        >>> brainstate.random.seed(42)
        >>> values1 = brainstate.random.rand(3)
        >>> brainstate.random.seed(42)  # Reset to same seed
        >>> values2 = brainstate.random.rand(3)

        Use automatic random seeding:

        >>> brainstate.random.seed()

        Use with JAX keys:

        >>> import jax
        >>> key = jax.random.key(123)
        >>> brainstate.random.seed(key)

    Note:
        - The NumPy seeding step runs under
          ``jax.ensure_compile_time_eval()``.
        - If the value is a JAX tracer that cannot be converted to a concrete
          array, the NumPy seeding step is skipped silently and only the JAX
          state is updated.

    See Also:
        - :func:`set_key`: Set only the JAX random key
        - :func:`get_key`: Get the current random key
        - :func:`seed_context`: Temporary seed changes

    """
    with jax.ensure_compile_time_eval():
        _set_numpy_seed = True
        if seed_or_key is None:
            seed_or_key = np.random.randint(0, 100000)
            # Automatic seeding must not disturb NumPy's current state.
            _set_numpy_seed = False

        # numpy random seed
        if _set_numpy_seed:
            try:
                dtype = getattr(seed_or_key, "dtype", None)
                if dtype is not None and jax.numpy.issubdtype(dtype, jax.dtypes.prng_key):
                    # Typed keys are opaque scalars: np.size() reports 1 but
                    # NumPy cannot convert the key object itself, so seed
                    # from the first word of the underlying key data.
                    key_words = np.asarray(jax.random.key_data(seed_or_key)).ravel()
                    np.random.seed(key_words[0])
                elif np.size(seed_or_key) == 1:  # seed
                    np.random.seed(seed_or_key)
                elif np.size(seed_or_key) == 2:  # raw jax random key
                    np.random.seed(seed_or_key[0])
                else:
                    raise ValueError("seed_or_key should be an integer or a tuple of two integers.")
            except jax.errors.TracerArrayConversionError:
                # Deliberate best effort: a traced value has no concrete
                # contents to hand to NumPy, so only JAX is reseeded.
                pass

    # jax random seed
    DEFAULT.seed(seed_or_key)
576
+
577
+
578
@contextmanager
def seed_context(seed_or_key: SeedOrKey):
    """
    Temporarily reseed the global random state, restoring it on exit.

    On entry the current ``DEFAULT.value`` is remembered and ``DEFAULT`` is
    reseeded with ``seed_or_key``. On exit -- whether the block finishes
    normally or raises -- ``DEFAULT`` is reseeded with the remembered value.

    Args:
        seed_or_key: The temporary seed or key, passed unchanged to
            ``DEFAULT.seed``.

    Yields:
        None.

    Example:
        The same seed gives the same draws inside the context:

        >>> import brainstate
        >>> with brainstate.random.seed_context(42):
        ...     temp_values1 = brainstate.random.rand(2)
        ...     print(f"First run: {temp_values1}")
        First run: [0.95598125 0.4032725 ]
        >>> with brainstate.random.seed_context(42):
        ...     temp_values2 = brainstate.random.rand(2)
        ...     print(f"Second run: {temp_values2}")
        Second run: [0.95598125 0.4032725 ]

        Contexts can be nested; each level restores the key it saved:

        >>> with brainstate.random.seed_context(123):
        ...     outer_values = brainstate.random.rand(2)
        ...     with brainstate.random.seed_context(456):
        ...         inner_values = brainstate.random.rand(2)
        ...     outer_values2 = brainstate.random.rand(2)

        The previous key is restored even when the block raises:

        >>> try:
        ...     with brainstate.random.seed_context(789):
        ...         raise ValueError("Something went wrong")
        ... except ValueError:
        ...     pass

    Note:
        - Restoration happens in a ``finally`` clause, so it also runs when
          the enclosed block raises.
        - NOTE(review): only ``DEFAULT`` is reseeded here; NumPy's global
          generator does not appear to be touched by this function --
          confirm against ``DEFAULT.seed``.

    See Also:
        - :func:`seed`: Permanently set the global random seed
        - :func:`get_key`: Get the current random key for manual state management
        - :func:`set_key`: Set the random key for manual state management

    """
    # Remember the key in force before entering the block.
    saved_key = DEFAULT.value
    try:
        # Switch the JAX random state to the temporary seed.
        DEFAULT.seed(seed_or_key)
        yield
    finally:
        # Put the remembered key back, whatever happened inside the block.
        DEFAULT.seed(saved_key)