brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/random/_rand_seed.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,8 +13,42 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""
|
17
|
+
Random seed management utilities for BrainState.
|
18
|
+
|
19
|
+
This module provides comprehensive random seed management functionality, enabling
|
20
|
+
reproducible computations across JAX and NumPy backends. It supports both traditional
|
21
|
+
integer seeds and JAX's PRNG key system, providing a unified interface for random
|
22
|
+
number generation in scientific computing and machine learning applications.
|
23
|
+
|
24
|
+
Key Features:
|
25
|
+
- Unified seed management for JAX and NumPy
|
26
|
+
- Context managers for temporary seed changes
|
27
|
+
- Key splitting for parallel computation
|
28
|
+
- Automatic seed backup and restoration
|
29
|
+
- Thread-safe random state management
|
30
|
+
|
31
|
+
Example:
|
32
|
+
Basic usage for reproducible random number generation:
|
33
|
+
|
34
|
+
>>> import brainstate
|
35
|
+
>>> brainstate.random.seed(42)
|
36
|
+
>>> print(brainstate.random.rand(3))
|
37
|
+
[0.95598125 0.4032725 0.96086407]
|
38
|
+
|
39
|
+
Using context managers for temporary seeds:
|
40
|
+
|
41
|
+
>>> with brainstate.random.seed_context(123):
|
42
|
+
... values = brainstate.random.rand(2)
|
43
|
+
>>> print(values) # Reproducible output
|
44
|
+
|
45
|
+
Key splitting for parallel computation:
|
46
|
+
|
47
|
+
>>> keys = brainstate.random.split_keys(4) # Generate 4 independent keys
|
48
|
+
>>> # Use keys for parallel random number generation
|
49
|
+
"""
|
50
|
+
|
16
51
|
from contextlib import contextmanager
|
17
|
-
from typing import Optional
|
18
52
|
|
19
53
|
import jax
|
20
54
|
import numpy as np
|
@@ -23,74 +57,255 @@ from brainstate.typing import SeedOrKey
|
|
23
57
|
from ._rand_state import RandomState, DEFAULT, use_prng_key
|
24
58
|
|
25
59
|
__all__ = [
|
26
|
-
'seed',
|
60
|
+
'seed',
|
61
|
+
'set_key',
|
62
|
+
'get_key',
|
63
|
+
'default_rng',
|
64
|
+
'split_key',
|
65
|
+
'split_keys',
|
66
|
+
'seed_context',
|
67
|
+
'restore_key',
|
27
68
|
'self_assign_multi_keys',
|
69
|
+
'clone_rng',
|
28
70
|
]
|
29
71
|
|
30
72
|
|
31
|
-
def restore_key():
|
32
|
-
"""
|
73
|
+
def restore_key() -> None:
|
74
|
+
"""
|
75
|
+
Restore the default random key to its previous state.
|
76
|
+
|
77
|
+
This function restores the global random state to a previously backed up state.
|
78
|
+
It's useful for undoing changes to the random state or implementing checkpoint
|
79
|
+
functionality in computational workflows.
|
80
|
+
|
81
|
+
Note:
|
82
|
+
This operation requires that a backup was previously created. If no backup
|
83
|
+
exists, this function may not have any effect or may restore to an initial state.
|
84
|
+
|
85
|
+
Example:
|
86
|
+
>>> import brainstate
|
87
|
+
>>> brainstate.random.seed(42)
|
88
|
+
>>> original_key = brainstate.random.get_key()
|
89
|
+
>>> brainstate.random.seed(123) # Change the seed
|
90
|
+
>>> brainstate.random.restore_key() # Restore to previous state
|
91
|
+
>>> assert np.array_equal(brainstate.random.get_key(), original_key)
|
92
|
+
|
93
|
+
See Also:
|
94
|
+
- :func:`set_key`: Set a new random key
|
95
|
+
- :func:`get_key`: Get the current random key
|
96
|
+
- :func:`seed_context`: Temporary seed changes with automatic restoration
|
97
|
+
"""
|
33
98
|
DEFAULT.restore_key()
|
34
99
|
|
35
100
|
|
36
|
-
def split_key(n:
|
37
|
-
"""
|
101
|
+
def split_key(n: int = None, backup: bool = False):
|
102
|
+
"""
|
103
|
+
Create new random key(s) from the current seed.
|
104
|
+
|
105
|
+
This function generates one or more independent random keys by splitting the
|
106
|
+
current global random state. It follows JAX's random paradigm, ensuring that
|
107
|
+
each split key produces statistically independent random sequences.
|
108
|
+
|
109
|
+
Args:
|
110
|
+
n: The number of keys to generate. If None, returns a single key.
|
111
|
+
If an integer, returns an array of n keys.
|
112
|
+
backup: Whether to backup the current key before splitting. This allows
|
113
|
+
restoration of the original state using :func:`restore_key`.
|
114
|
+
|
115
|
+
Returns:
|
116
|
+
If n is None: A single JAX PRNG key.
|
117
|
+
If n is an integer: An array of n independent JAX PRNG keys.
|
118
|
+
|
119
|
+
Example:
|
120
|
+
Generate a single key:
|
121
|
+
|
122
|
+
>>> import brainstate
|
123
|
+
>>> brainstate.random.seed(42)
|
124
|
+
>>> key = brainstate.random.split_key()
|
125
|
+
>>> print(key.shape)
|
126
|
+
(2,)
|
38
127
|
|
39
|
-
|
128
|
+
Generate multiple keys for parallel computation:
|
40
129
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
The number of seeds to generate.
|
45
|
-
backup : bool, optional
|
46
|
-
Whether to backup the current key.
|
130
|
+
>>> keys = brainstate.random.split_key(4)
|
131
|
+
>>> print(keys.shape)
|
132
|
+
(4, 2)
|
47
133
|
|
48
|
-
|
49
|
-
-------
|
50
|
-
key : jax.random.PRNGKey
|
51
|
-
A new random key.
|
134
|
+
Use with backup for state restoration:
|
52
135
|
|
136
|
+
>>> original_key = brainstate.random.get_key()
|
137
|
+
>>> keys = brainstate.random.split_key(2, backup=True)
|
138
|
+
>>> brainstate.random.restore_key()
|
139
|
+
>>> assert np.array_equal(brainstate.random.get_key(), original_key)
|
140
|
+
|
141
|
+
Note:
|
142
|
+
This function advances the global random state. Each call produces
|
143
|
+
different keys unless the state is reset.
|
144
|
+
|
145
|
+
See Also:
|
146
|
+
- :func:`split_keys`: Convenience function for multiple keys
|
147
|
+
- :func:`seed`: Set the random seed
|
148
|
+
- :func:`restore_key`: Restore backed up key
|
53
149
|
"""
|
54
150
|
return DEFAULT.split_key(n=n, backup=backup)
|
55
151
|
|
56
152
|
|
57
153
|
def split_keys(n: int, backup: bool = False):
|
58
|
-
"""
|
59
|
-
|
60
|
-
are different in parallel threads.
|
154
|
+
"""
|
155
|
+
Create multiple independent random keys from the current seed.
|
61
156
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
backup : bool, optional
|
67
|
-
Whether to backup the current key
|
157
|
+
This is a convenience function that generates exactly n independent random keys
|
158
|
+
by splitting the current global random state. It's commonly used internally by
|
159
|
+
parallel computation functions like `pmap` and `vmap` to ensure that each
|
160
|
+
parallel thread gets a unique random key.
|
68
161
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
162
|
+
Args:
|
163
|
+
n: The number of independent keys to generate. Must be a positive integer.
|
164
|
+
backup: Whether to backup the current key before splitting. If True,
|
165
|
+
the original key can be restored using :func:`restore_key`.
|
73
166
|
|
167
|
+
Returns:
|
168
|
+
An array of n independent JAX PRNG keys with shape (n, 2).
|
169
|
+
|
170
|
+
Raises:
|
171
|
+
ValueError: If n is not a positive integer.
|
172
|
+
|
173
|
+
Example:
|
174
|
+
Generate keys for parallel computation:
|
175
|
+
|
176
|
+
>>> import brainstate
|
177
|
+
>>> brainstate.random.seed(42)
|
178
|
+
>>> keys = brainstate.random.split_keys(4)
|
179
|
+
>>> print(keys.shape)
|
180
|
+
(4, 2)
|
181
|
+
|
182
|
+
Use with vmap for parallel random number generation:
|
183
|
+
|
184
|
+
>>> import jax
|
185
|
+
>>> keys = brainstate.random.split_keys(8)
|
186
|
+
>>> @jax.vmap
|
187
|
+
... def generate_random(key):
|
188
|
+
... return jax.random.normal(key, (10,))
|
189
|
+
>>> parallel_randoms = generate_random(keys)
|
190
|
+
>>> print(parallel_randoms.shape)
|
191
|
+
(8, 10)
|
192
|
+
|
193
|
+
Use with backup for state preservation:
|
194
|
+
|
195
|
+
>>> original_state = brainstate.random.get_key()
|
196
|
+
>>> keys = brainstate.random.split_keys(3, backup=True)
|
197
|
+
>>> # ... use keys for computation ...
|
198
|
+
>>> brainstate.random.restore_key() # Restore original state
|
199
|
+
|
200
|
+
Note:
|
201
|
+
This function is equivalent to calling :func:`split_key` with n as an argument.
|
202
|
+
It's provided as a convenience function with a more explicit name for clarity.
|
203
|
+
|
204
|
+
See Also:
|
205
|
+
- :func:`split_key`: More general key splitting function
|
206
|
+
- :func:`self_assign_multi_keys`: Assign multiple keys to global state
|
207
|
+
- :func:`seed_context`: Temporary seed changes
|
74
208
|
"""
|
209
|
+
if not isinstance(n, int) or n <= 0:
|
210
|
+
raise ValueError(f"n must be a positive integer, got {n}")
|
75
211
|
return split_key(n, backup=backup)
|
76
212
|
|
77
213
|
|
78
|
-
def self_assign_multi_keys(n: int, backup: bool = True):
|
214
|
+
def self_assign_multi_keys(n: int, backup: bool = True) -> None:
|
79
215
|
"""
|
80
|
-
Assign multiple keys to the
|
216
|
+
Assign multiple keys to the global random state for parallel access.
|
217
|
+
|
218
|
+
This function prepares the global random state for parallel computation by
|
219
|
+
pre-generating n independent keys. It's particularly useful when you need
|
220
|
+
to ensure that parallel computations have access to independent random
|
221
|
+
sequences without the overhead of key splitting during computation.
|
222
|
+
|
223
|
+
Args:
|
224
|
+
n: The number of independent keys to pre-generate and assign.
|
225
|
+
Must be a positive integer.
|
226
|
+
backup: Whether to backup the current random state before assignment.
|
227
|
+
If True, the original state can be restored using :func:`restore_key`.
|
228
|
+
|
229
|
+
Raises:
|
230
|
+
ValueError: If n is not a positive integer.
|
231
|
+
|
232
|
+
Example:
|
233
|
+
Prepare for parallel computation:
|
234
|
+
|
235
|
+
>>> import brainstate
|
236
|
+
>>> brainstate.random.seed(42)
|
237
|
+
>>> # Prepare 4 independent keys for parallel access
|
238
|
+
>>> brainstate.random.self_assign_multi_keys(4)
|
239
|
+
|
240
|
+
Use in parallel context:
|
241
|
+
|
242
|
+
>>> # The random state now has 4 independent keys ready for use
|
243
|
+
>>> # Each parallel thread can access a different key
|
244
|
+
|
245
|
+
Note:
|
246
|
+
This is an advanced function primarily used internally for optimizing
|
247
|
+
parallel random number generation. In most cases, :func:`split_keys`
|
248
|
+
provides a more straightforward interface for parallel computation.
|
249
|
+
|
250
|
+
See Also:
|
251
|
+
- :func:`split_keys`: Generate multiple independent keys
|
252
|
+
- :func:`restore_key`: Restore backed up state
|
253
|
+
- :func:`seed_context`: Temporary state changes
|
81
254
|
"""
|
255
|
+
if not isinstance(n, int) or n <= 0:
|
256
|
+
raise ValueError(f"n must be a positive integer, got {n}")
|
82
257
|
DEFAULT.self_assign_multi_keys(n, backup=backup)
|
83
258
|
|
84
259
|
|
85
|
-
def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
|
86
|
-
"""
|
260
|
+
def clone_rng(seed_or_key: SeedOrKey = None, clone: bool = True) -> RandomState:
|
261
|
+
"""
|
262
|
+
Create a clone of the random state or a new random state.
|
263
|
+
|
264
|
+
This function provides a flexible way to create independent random states,
|
265
|
+
either by cloning the current global state or by creating a new state with
|
266
|
+
a specific seed or key. Cloned states are independent and don't affect each
|
267
|
+
other when used for random number generation.
|
87
268
|
|
88
269
|
Args:
|
89
|
-
|
90
|
-
|
270
|
+
seed_or_key: Optional seed (integer) or JAX random key to initialize
|
271
|
+
the new random state. If None, uses the current global state.
|
272
|
+
clone: Whether to clone the default random state. If False and
|
273
|
+
seed_or_key is None, returns the global state directly (not recommended
|
274
|
+
for most use cases as it shares state).
|
91
275
|
|
92
276
|
Returns:
|
93
|
-
|
277
|
+
A RandomState instance that can be used independently for random
|
278
|
+
number generation.
|
279
|
+
|
280
|
+
Example:
|
281
|
+
Clone the current global state:
|
282
|
+
|
283
|
+
>>> import brainstate
|
284
|
+
>>> brainstate.random.seed(42)
|
285
|
+
>>> rng1 = brainstate.random.clone_rng()
|
286
|
+
>>> rng2 = brainstate.random.clone_rng()
|
287
|
+
>>> # rng1 and rng2 are independent copies
|
288
|
+
|
289
|
+
Create a new state with specific seed:
|
290
|
+
|
291
|
+
>>> rng_fixed = brainstate.random.clone_rng(123)
|
292
|
+
>>> # Always produces the same sequences when reset to seed 123
|
293
|
+
|
294
|
+
Use for independent computations:
|
295
|
+
|
296
|
+
>>> rng = brainstate.random.clone_rng(456)
|
297
|
+
>>> values1 = rng.normal(size=5)
|
298
|
+
>>> values2 = rng.normal(size=5)
|
299
|
+
>>> # values1 and values2 are different but reproducible
|
300
|
+
|
301
|
+
Note:
|
302
|
+
Cloned random states are completely independent. Changes to one state
|
303
|
+
(like advancing through random number generation) don't affect others.
|
304
|
+
|
305
|
+
See Also:
|
306
|
+
- :func:`default_rng`: Get or create a random state
|
307
|
+
- :func:`seed`: Set the global random seed
|
308
|
+
- :class:`RandomState`: The random state class
|
94
309
|
"""
|
95
310
|
if seed_or_key is None:
|
96
311
|
return DEFAULT.clone() if clone else DEFAULT
|
@@ -98,15 +313,53 @@ def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
|
|
98
313
|
return RandomState(seed_or_key)
|
99
314
|
|
100
315
|
|
101
|
-
def default_rng(seed_or_key=None) -> RandomState:
|
316
|
+
def default_rng(seed_or_key: SeedOrKey = None) -> RandomState:
|
102
317
|
"""
|
103
|
-
Get the default random state.
|
318
|
+
Get the default random state or create a new one with specified seed.
|
319
|
+
|
320
|
+
This function provides access to the global random state used throughout
|
321
|
+
BrainState, or creates a new independent random state if a seed is provided.
|
322
|
+
It's the primary interface for obtaining random state objects in BrainState.
|
104
323
|
|
105
324
|
Args:
|
106
|
-
|
325
|
+
seed_or_key: Optional seed (integer) or JAX random key. If None,
|
326
|
+
returns the global default random state. If provided, creates
|
327
|
+
a new independent RandomState with the specified seed.
|
107
328
|
|
108
329
|
Returns:
|
109
|
-
|
330
|
+
The default RandomState if seed_or_key is None, otherwise a new
|
331
|
+
RandomState initialized with the provided seed or key.
|
332
|
+
|
333
|
+
Example:
|
334
|
+
Get the global random state:
|
335
|
+
|
336
|
+
>>> import brainstate
|
337
|
+
>>> rng = brainstate.random.default_rng()
|
338
|
+
>>> # rng is the global random state used by brainstate.random functions
|
339
|
+
|
340
|
+
Create a new independent random state:
|
341
|
+
|
342
|
+
>>> rng_local = brainstate.random.default_rng(42)
|
343
|
+
>>> values = rng_local.normal(size=10)
|
344
|
+
|
345
|
+
Use for reproducible local computations:
|
346
|
+
|
347
|
+
>>> def reproducible_computation():
|
348
|
+
... local_rng = brainstate.random.default_rng(12345)
|
349
|
+
... return local_rng.uniform(size=5)
|
350
|
+
>>> result1 = reproducible_computation()
|
351
|
+
>>> result2 = reproducible_computation()
|
352
|
+
>>> assert np.allclose(result1, result2) # Always the same
|
353
|
+
|
354
|
+
Note:
|
355
|
+
When seed_or_key is None, this returns the actual global state object.
|
356
|
+
Modifications to this state (through random number generation) will
|
357
|
+
affect all subsequent calls to global random functions.
|
358
|
+
|
359
|
+
See Also:
|
360
|
+
- :func:`clone_rng`: Create independent clones of random states
|
361
|
+
- :func:`seed`: Set the global random seed
|
362
|
+
- :class:`RandomState`: The underlying random state implementation
|
110
363
|
"""
|
111
364
|
if seed_or_key is None:
|
112
365
|
return DEFAULT
|
@@ -114,45 +367,191 @@ def default_rng(seed_or_key=None) -> RandomState:
|
|
114
367
|
return RandomState(seed_or_key)
|
115
368
|
|
116
369
|
|
117
|
-
def set_key(seed_or_key: SeedOrKey):
|
118
|
-
"""
|
370
|
+
def set_key(seed_or_key: SeedOrKey) -> None:
|
371
|
+
"""
|
372
|
+
Set a new random key for the global random state.
|
373
|
+
|
374
|
+
This function updates the global random state with a new key, which can be
|
375
|
+
either an integer seed or a JAX PRNG key. All subsequent calls to global
|
376
|
+
random functions will use this new key state.
|
377
|
+
|
378
|
+
Args:
|
379
|
+
seed_or_key: The new random key to set. Can be:
|
380
|
+
- An integer seed (will be converted to a JAX PRNG key)
|
381
|
+
- A JAX PRNG key array
|
382
|
+
- A numpy array representing a PRNG key
|
119
383
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
384
|
+
Raises:
|
385
|
+
ValueError: If the provided key is not in a valid format.
|
386
|
+
|
387
|
+
Example:
|
388
|
+
Set with integer seed:
|
389
|
+
|
390
|
+
>>> import brainstate
|
391
|
+
>>> brainstate.random.set_key(42)
|
392
|
+
>>> values1 = brainstate.random.rand(3)
|
393
|
+
|
394
|
+
Set with JAX key:
|
395
|
+
|
396
|
+
>>> import jax
|
397
|
+
>>> key = jax.random.key(123)
|
398
|
+
>>> brainstate.random.set_key(key)
|
399
|
+
>>> values2 = brainstate.random.rand(3)
|
400
|
+
|
401
|
+
Restore reproducible state:
|
402
|
+
|
403
|
+
>>> brainstate.random.set_key(42)
|
404
|
+
>>> # Now random functions will produce the same sequences as first example
|
405
|
+
|
406
|
+
Note:
|
407
|
+
This function immediately changes the global random state. All threads
|
408
|
+
and computations using the global random functions will be affected.
|
409
|
+
|
410
|
+
See Also:
|
411
|
+
- :func:`get_key`: Get the current random key
|
412
|
+
- :func:`seed`: Set seed (also affects NumPy)
|
413
|
+
- :func:`restore_key`: Restore a backed up key
|
124
414
|
"""
|
125
415
|
if isinstance(seed_or_key, int):
|
126
|
-
# key
|
127
|
-
key = jax.random.PRNGKey(seed_or_key) if use_prng_key else
|
416
|
+
# Create key using appropriate JAX function based on version
|
417
|
+
key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
|
128
418
|
elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
|
129
419
|
if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
130
420
|
key = seed_or_key
|
131
421
|
elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
|
132
422
|
key = seed_or_key
|
133
423
|
else:
|
134
|
-
raise ValueError(
|
424
|
+
raise ValueError(
|
425
|
+
f"seed_or_key should be an integer, a JAX PRNG key, or a uint32 array of size 2. "
|
426
|
+
f"Got array with dtype {seed_or_key.dtype} and size {seed_or_key.size}."
|
427
|
+
)
|
428
|
+
else:
|
429
|
+
raise ValueError(
|
430
|
+
f"seed_or_key must be an integer or a JAX-compatible array. "
|
431
|
+
f"Got {type(seed_or_key)}."
|
432
|
+
)
|
135
433
|
DEFAULT.set_key(key)
|
136
434
|
|
137
435
|
|
138
436
|
def get_key():
|
139
|
-
"""
|
437
|
+
"""
|
438
|
+
Get the current global random key.
|
439
|
+
|
440
|
+
This function returns the current random key used by the global random state.
|
441
|
+
The returned key represents the internal state of the JAX PRNG and can be used
|
442
|
+
to restore the random state later or to create independent random number generators.
|
443
|
+
|
444
|
+
Returns:
|
445
|
+
The current JAX PRNG key as a numpy array. This is typically a 2-element
|
446
|
+
uint32 array representing the internal state of the random number generator.
|
447
|
+
|
448
|
+
Example:
|
449
|
+
Get and store the current random state:
|
450
|
+
|
451
|
+
>>> import brainstate
|
452
|
+
>>> brainstate.random.seed(42)
|
453
|
+
>>> current_key = brainstate.random.get_key()
|
454
|
+
>>> print(current_key.shape)
|
455
|
+
(2,)
|
456
|
+
|
457
|
+
Use the key to restore state later:
|
458
|
+
|
459
|
+
>>> # Generate some random numbers
|
460
|
+
>>> values1 = brainstate.random.rand(3)
|
461
|
+
>>> # Restore the previous state
|
462
|
+
>>> brainstate.random.set_key(current_key)
|
463
|
+
>>> values2 = brainstate.random.rand(3)
|
464
|
+
>>> # values1 and values2 will be identical
|
465
|
+
|
466
|
+
Compare keys for debugging:
|
467
|
+
|
468
|
+
>>> brainstate.random.seed(123)
|
469
|
+
>>> key1 = brainstate.random.get_key()
|
470
|
+
>>> brainstate.random.seed(123)
|
471
|
+
>>> key2 = brainstate.random.get_key()
|
472
|
+
>>> assert jax.numpy.array_equal(key1, key2) # Same seed gives same key
|
473
|
+
|
474
|
+
Note:
|
475
|
+
The returned key is a snapshot of the current state. Subsequent calls to
|
476
|
+
random functions will advance the internal state, so calling get_key()
|
477
|
+
again will return a different key unless the state is reset.
|
478
|
+
|
479
|
+
See Also:
|
480
|
+
- :func:`set_key`: Set a new random key
|
481
|
+
- :func:`seed`: Set the random seed (also affects NumPy)
|
482
|
+
- :func:`split_key`: Create new keys from current state
|
483
|
+
- :func:`seed_context`: Temporary seed changes with automatic restoration
|
140
484
|
|
141
|
-
Returns
|
142
|
-
-------
|
143
|
-
seed_or_key: int
|
144
|
-
The random key.
|
145
485
|
"""
|
146
486
|
return DEFAULT.value
|
147
487
|
|
148
488
|
|
149
489
|
def seed(seed_or_key: SeedOrKey = None):
|
150
|
-
"""
|
490
|
+
"""
|
491
|
+
Set the global random seed for both JAX and NumPy.
|
492
|
+
|
493
|
+
This function initializes the global random state with a new seed, affecting
|
494
|
+
both JAX and NumPy random number generators. It ensures reproducible random
|
495
|
+
number generation across the entire BrainState ecosystem.
|
496
|
+
|
497
|
+
Args:
|
498
|
+
seed_or_key: The seed or key to set. Can be:
|
499
|
+
- None: Generates a random seed automatically
|
500
|
+
- int: An integer seed (0 to 2^32-1)
|
501
|
+
- JAX PRNG key: A JAX random key array
|
502
|
+
If None, a random seed is generated using NumPy's random generator.
|
503
|
+
|
504
|
+
Raises:
|
505
|
+
ValueError: If seed_or_key is not a valid seed format (not an integer,
|
506
|
+
valid JAX key, or None).
|
507
|
+
|
508
|
+
Example:
|
509
|
+
Set a specific seed for reproducible results:
|
510
|
+
|
511
|
+
>>> import brainstate
|
512
|
+
>>> brainstate.random.seed(42)
|
513
|
+
>>> values1 = brainstate.random.rand(3)
|
514
|
+
>>> brainstate.random.seed(42) # Reset to same seed
|
515
|
+
>>> values2 = brainstate.random.rand(3)
|
516
|
+
>>> assert np.allclose(values1, values2) # Same values
|
517
|
+
|
518
|
+
Use automatic random seeding:
|
519
|
+
|
520
|
+
>>> brainstate.random.seed() # Uses random seed
|
521
|
+
>>> # Each call will produce different sequences
|
522
|
+
|
523
|
+
Use with JAX keys:
|
524
|
+
|
525
|
+
>>> import jax
|
526
|
+
>>> key = jax.random.key(123)
|
527
|
+
>>> brainstate.random.seed(key)
|
528
|
+
>>> # Now both JAX and NumPy use consistent seeds
|
529
|
+
|
530
|
+
Ensure reproducibility in scientific experiments:
|
531
|
+
|
532
|
+
>>> def experiment():
|
533
|
+
... brainstate.random.seed(12345) # Fixed seed for reproducibility
|
534
|
+
... data = brainstate.random.normal(size=(100, 10))
|
535
|
+
... return data.mean()
|
536
|
+
>>> result1 = experiment()
|
537
|
+
>>> result2 = experiment()
|
538
|
+
>>> assert result1 == result2 # Always same result
|
539
|
+
|
540
|
+
Note:
|
541
|
+
- This function affects the global random state used by all BrainState
|
542
|
+
random functions and NumPy's global random state.
|
543
|
+
- When using automatic seeding (seed_or_key=None), NumPy's seed is not
|
544
|
+
set to maintain its current state.
|
545
|
+
- JAX compilation is handled automatically with compile-time evaluation.
|
546
|
+
- For JAX keys, only the first element is used to seed NumPy to maintain
|
547
|
+
compatibility between the two random systems.
|
548
|
+
|
549
|
+
See Also:
|
550
|
+
- :func:`set_key`: Set only the JAX random key
|
551
|
+
- :func:`get_key`: Get the current random key
|
552
|
+
- :func:`seed_context`: Temporary seed changes
|
553
|
+
- :func:`split_key`: Create independent random keys
|
151
554
|
|
152
|
-
Parameters
|
153
|
-
----------
|
154
|
-
seed_or_key: int, optional
|
155
|
-
The random seed (an integer) or jax random key.
|
156
555
|
"""
|
157
556
|
with jax.ensure_compile_time_eval():
|
158
557
|
_set_numpy_seed = True
|
@@ -179,24 +578,90 @@ def seed(seed_or_key: SeedOrKey = None):
|
|
179
578
|
@contextmanager
|
180
579
|
def seed_context(seed_or_key: SeedOrKey):
|
181
580
|
"""
|
182
|
-
|
183
|
-
|
184
|
-
Examples:
|
581
|
+
Context manager for temporary random seed changes with automatic restoration.
|
185
582
|
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
[0.8511752 0.95312667]
|
191
|
-
>>> with brainstate.random.seed_context(42):
|
192
|
-
... print(brainstate.random.rand(2))
|
193
|
-
[0.95598125 0.4032725 ]
|
194
|
-
>>> with brainstate.random.seed_context(42):
|
195
|
-
... print(brainstate.random.rand(2))
|
196
|
-
[0.95598125 0.4032725 ]
|
583
|
+
This context manager temporarily changes the global random seed for the duration
|
584
|
+
of the block, then automatically restores the previous random state when exiting.
|
585
|
+
It's ideal for ensuring reproducible computations in specific code sections without
|
586
|
+
permanently affecting the global random state.
|
197
587
|
|
198
588
|
Args:
|
199
|
-
|
589
|
+
seed_or_key: The temporary seed or key to use within the context. Can be:
|
590
|
+
- int: An integer seed for reproducible sequences
|
591
|
+
- JAX PRNG key: A JAX random key array
|
592
|
+
The seed affects both JAX and NumPy random states during the context.
|
593
|
+
|
594
|
+
Yields:
|
595
|
+
None: The context manager doesn't yield any value, but provides a
|
596
|
+
controlled random environment for the enclosed code block.
|
597
|
+
|
598
|
+
Example:
|
599
|
+
Reproducible computations without affecting global state:
|
600
|
+
|
601
|
+
>>> import brainstate
>>> import numpy as np
|
602
|
+
>>> # Global state remains unaffected
|
603
|
+
>>> global_values1 = brainstate.random.rand(2)
|
604
|
+
>>>
|
605
|
+
>>> with brainstate.random.seed_context(42):
|
606
|
+
... temp_values1 = brainstate.random.rand(2)
|
607
|
+
... print(f"First run: {temp_values1}")
|
608
|
+
First run: [0.95598125 0.4032725 ]
|
609
|
+
>>>
|
610
|
+
>>> with brainstate.random.seed_context(42):
|
611
|
+
... temp_values2 = brainstate.random.rand(2)
|
612
|
+
... print(f"Second run: {temp_values2}")
|
613
|
+
Second run: [0.95598125 0.4032725 ]
|
614
|
+
>>>
|
615
|
+
>>> # Values are identical within context
|
616
|
+
>>> assert np.allclose(temp_values1, temp_values2)
|
617
|
+
>>>
|
618
|
+
>>> # Global state continues from where it left off
|
619
|
+
>>> global_values2 = brainstate.random.rand(2)
|
620
|
+
|
621
|
+
Nested contexts for complex scenarios:
|
622
|
+
|
623
|
+
>>> with brainstate.random.seed_context(123):
|
624
|
+
... outer_values = brainstate.random.rand(2)
|
625
|
+
... with brainstate.random.seed_context(456):
|
626
|
+
... inner_values = brainstate.random.rand(2)
|
627
|
+
... # Outer context is restored here
|
628
|
+
... outer_values2 = brainstate.random.rand(2)
|
629
|
+
|
630
|
+
Exception safety - state is restored even on errors:
|
631
|
+
|
632
|
+
>>> try:
|
633
|
+
... with brainstate.random.seed_context(789):
|
634
|
+
... some_values = brainstate.random.rand(3)
|
635
|
+
... raise ValueError("Something went wrong")
|
636
|
+
... except ValueError:
|
637
|
+
... pass
|
638
|
+
>>> # Random state is properly restored
|
639
|
+
|
640
|
+
Testing reproducible algorithms:
|
641
|
+
|
642
|
+
>>> def test_algorithm():
|
643
|
+
... with brainstate.random.seed_context(42):
|
644
|
+
... data = brainstate.random.normal(size=(100,))
|
645
|
+
... return data.mean()
|
646
|
+
>>>
|
647
|
+
>>> result1 = test_algorithm()
|
648
|
+
>>> result2 = test_algorithm()
|
649
|
+
>>> assert result1 == result2 # Always same result
|
650
|
+
|
651
|
+
Note:
|
652
|
+
- The context manager saves and restores the complete JAX random state
|
653
|
+
- NumPy's random state is also temporarily modified during the context
|
654
|
+
- Nested contexts work correctly - each level restores its own state
|
655
|
+
- Exception safety is guaranteed - random state is restored even if
|
656
|
+
exceptions occur within the context
|
657
|
+
- This is more convenient than manually saving/restoring state with
|
658
|
+
get_key() and set_key()
|
659
|
+
|
660
|
+
See Also:
|
661
|
+
- :func:`seed`: Permanently set the global random seed
|
662
|
+
- :func:`get_key`: Get the current random key for manual state management
|
663
|
+
- :func:`set_key`: Set the random key for manual state management
|
664
|
+
- :func:`clone_rng`: Create independent random states
|
200
665
|
|
201
666
|
"""
|
202
667
|
# get the old random key
|