brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/random/__init__.py
CHANGED
@@ -1,270 +1,270 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Random number generation module for BrainState.
|
18
|
-
|
19
|
-
This module provides a comprehensive set of random number generation functions and utilities
|
20
|
-
for neural network simulations and scientific computing. It wraps JAX's random number
|
21
|
-
generation capabilities with a stateful interface that simplifies usage while maintaining
|
22
|
-
reproducibility and performance.
|
23
|
-
|
24
|
-
The module includes:
|
25
|
-
|
26
|
-
- Standard random distributions (uniform, normal, exponential, etc.)
|
27
|
-
- Random state management with automatic key splitting
|
28
|
-
- Seed management utilities for reproducible simulations
|
29
|
-
- NumPy-compatible API for easy migration
|
30
|
-
|
31
|
-
Key Features
|
32
|
-
------------
|
33
|
-
|
34
|
-
- **Stateful random generation**: Automatic management of JAX's PRNG keys
|
35
|
-
- **NumPy compatibility**: Drop-in replacement for most NumPy random functions
|
36
|
-
- **Reproducibility**: Robust seed management and state tracking
|
37
|
-
- **Performance**: JIT-compiled random functions for efficient generation
|
38
|
-
- **Thread-safe**: Proper handling of random state in parallel computations
|
39
|
-
|
40
|
-
Random State Management
|
41
|
-
-----------------------
|
42
|
-
|
43
|
-
The module uses a global `DEFAULT` RandomState instance that automatically manages
|
44
|
-
JAX's PRNG keys. This eliminates the need to manually track and split keys:
|
45
|
-
|
46
|
-
.. code-block:: python
|
47
|
-
|
48
|
-
>>> import brainstate as bs
|
49
|
-
>>> import brainstate.random as bsr
|
50
|
-
>>>
|
51
|
-
>>> # Set a global seed for reproducibility
|
52
|
-
>>> bsr.seed(42)
|
53
|
-
>>>
|
54
|
-
>>> # Generate random numbers without manual key management
|
55
|
-
>>> x = bsr.normal(0, 1, size=(3, 3))
|
56
|
-
>>> y = bsr.uniform(0, 1, size=(100,))
|
57
|
-
|
58
|
-
Custom Random States
|
59
|
-
--------------------
|
60
|
-
|
61
|
-
For more control, you can create custom RandomState instances:
|
62
|
-
|
63
|
-
.. code-block:: python
|
64
|
-
|
65
|
-
>>> import brainstate.random as bsr
|
66
|
-
>>>
|
67
|
-
>>> # Create a custom random state
|
68
|
-
>>> rng = bsr.RandomState(seed=123)
|
69
|
-
>>>
|
70
|
-
>>> # Use it for generation
|
71
|
-
>>> data = rng.normal(0, 1, size=(10, 10))
|
72
|
-
>>>
|
73
|
-
>>> # Get the current key
|
74
|
-
>>> current_key = rng.value
|
75
|
-
|
76
|
-
Available Distributions
|
77
|
-
-----------------------
|
78
|
-
|
79
|
-
The module provides a wide range of probability distributions:
|
80
|
-
|
81
|
-
**Uniform Distributions:**
|
82
|
-
|
83
|
-
- `rand`, `random`, `random_sample`, `ranf`, `sample` - Uniform [0, 1)
|
84
|
-
- `randint`, `random_integers` - Uniform integers
|
85
|
-
- `choice` - Random selection from array
|
86
|
-
- `permutation`, `shuffle` - Random ordering
|
87
|
-
|
88
|
-
**Normal Distributions:**
|
89
|
-
|
90
|
-
- `randn`, `normal` - Normal (Gaussian) distribution
|
91
|
-
- `standard_normal` - Standard normal distribution
|
92
|
-
- `multivariate_normal` - Multivariate normal distribution
|
93
|
-
- `truncated_normal` - Truncated normal distribution
|
94
|
-
|
95
|
-
**Other Continuous Distributions:**
|
96
|
-
|
97
|
-
- `beta` - Beta distribution
|
98
|
-
- `exponential`, `standard_exponential` - Exponential distribution
|
99
|
-
- `gamma`, `standard_gamma` - Gamma distribution
|
100
|
-
- `gumbel` - Gumbel distribution
|
101
|
-
- `laplace` - Laplace distribution
|
102
|
-
- `logistic` - Logistic distribution
|
103
|
-
- `pareto` - Pareto distribution
|
104
|
-
- `rayleigh` - Rayleigh distribution
|
105
|
-
- `standard_cauchy` - Cauchy distribution
|
106
|
-
- `standard_t` - Student's t-distribution
|
107
|
-
- `uniform` - Uniform distribution over [low, high)
|
108
|
-
- `weibull` - Weibull distribution
|
109
|
-
|
110
|
-
**Discrete Distributions:**
|
111
|
-
|
112
|
-
- `bernoulli` - Bernoulli distribution
|
113
|
-
- `binomial` - Binomial distribution
|
114
|
-
- `poisson` - Poisson distribution
|
115
|
-
|
116
|
-
Seed Management
|
117
|
-
---------------
|
118
|
-
|
119
|
-
The module provides utilities for managing random seeds:
|
120
|
-
|
121
|
-
.. code-block:: python
|
122
|
-
|
123
|
-
>>> import brainstate.random as bsr
|
124
|
-
>>>
|
125
|
-
>>> # Set a global seed
|
126
|
-
>>> bsr.seed(42)
|
127
|
-
>>>
|
128
|
-
>>> # Get current seed/key
|
129
|
-
>>> key = bsr.get_key()
|
130
|
-
>>>
|
131
|
-
>>> # Split the key for parallel operations
|
132
|
-
>>> keys = bsr.split_key(n=4)
|
133
|
-
>>>
|
134
|
-
>>> # Use context manager for temporary seed
|
135
|
-
>>> with bsr.local_seed(123):
|
136
|
-
... x = bsr.normal(0, 1, (5,)) # Uses seed 123
|
137
|
-
>>> y = bsr.normal(0, 1, (5,)) # Uses original seed
|
138
|
-
|
139
|
-
Examples
|
140
|
-
--------
|
141
|
-
|
142
|
-
**Basic random number generation:**
|
143
|
-
|
144
|
-
.. code-block:: python
|
145
|
-
|
146
|
-
>>> import brainstate.random as bsr
|
147
|
-
>>> import jax.numpy as jnp
|
148
|
-
>>>
|
149
|
-
>>> # Set seed for reproducibility
|
150
|
-
>>> bsr.seed(0)
|
151
|
-
>>>
|
152
|
-
>>> # Generate uniform random numbers
|
153
|
-
>>> uniform_data = bsr.random((3, 3))
|
154
|
-
>>> print(uniform_data.shape)
|
155
|
-
(3, 3)
|
156
|
-
>>>
|
157
|
-
>>> # Generate normal random numbers
|
158
|
-
>>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
|
159
|
-
>>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
|
160
|
-
Mean: -0.045, Std: 0.972
|
161
|
-
|
162
|
-
**Sampling and shuffling:**
|
163
|
-
|
164
|
-
.. code-block:: python
|
165
|
-
|
166
|
-
>>> import brainstate.random as bsr
|
167
|
-
>>> import jax.numpy as jnp
|
168
|
-
>>>
|
169
|
-
>>> bsr.seed(42)
|
170
|
-
>>>
|
171
|
-
>>> # Random choice from array
|
172
|
-
>>> arr = jnp.array([1, 2, 3, 4, 5])
|
173
|
-
>>> samples = bsr.choice(arr, size=3, replace=False)
|
174
|
-
>>> print(samples)
|
175
|
-
[4 1 5]
|
176
|
-
>>>
|
177
|
-
>>> # Random permutation
|
178
|
-
>>> perm = bsr.permutation(10)
|
179
|
-
>>> print(perm)
|
180
|
-
[3 5 1 7 9 0 2 8 4 6]
|
181
|
-
>>>
|
182
|
-
>>> # In-place shuffle
|
183
|
-
>>> data = jnp.arange(5)
|
184
|
-
>>> bsr.shuffle(data)
|
185
|
-
>>> print(data)
|
186
|
-
[2 0 4 1 3]
|
187
|
-
|
188
|
-
**Advanced distributions:**
|
189
|
-
|
190
|
-
.. code-block:: python
|
191
|
-
|
192
|
-
>>> import brainstate.random as bsr
|
193
|
-
>>> import matplotlib.pyplot as plt
|
194
|
-
>>>
|
195
|
-
>>> bsr.seed(123)
|
196
|
-
>>>
|
197
|
-
>>> # Generate samples from different distributions
|
198
|
-
>>> normal_samples = bsr.normal(0, 1, 1000)
|
199
|
-
>>> exponential_samples = bsr.exponential(1.0, 1000)
|
200
|
-
>>> beta_samples = bsr.beta(2, 5, 1000)
|
201
|
-
>>>
|
202
|
-
>>> # Plot histograms
|
203
|
-
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
204
|
-
>>> axes[0].hist(normal_samples, bins=30, density=True)
|
205
|
-
>>> axes[0].set_title('Normal Distribution')
|
206
|
-
>>> axes[1].hist(exponential_samples, bins=30, density=True)
|
207
|
-
>>> axes[1].set_title('Exponential Distribution')
|
208
|
-
>>> axes[2].hist(beta_samples, bins=30, density=True)
|
209
|
-
>>> axes[2].set_title('Beta Distribution')
|
210
|
-
>>> plt.show()
|
211
|
-
|
212
|
-
**Using with neural network simulations:**
|
213
|
-
|
214
|
-
.. code-block:: python
|
215
|
-
|
216
|
-
>>> import brainstate as bs
|
217
|
-
>>> import brainstate.random as bsr
|
218
|
-
>>> import brainstate.nn as nn
|
219
|
-
>>>
|
220
|
-
>>> class NoisyNeuron(bs.Module):
|
221
|
-
... def __init__(self, n_neurons, noise_scale=0.1):
|
222
|
-
... super().__init__()
|
223
|
-
... self.n_neurons = n_neurons
|
224
|
-
... self.noise_scale = noise_scale
|
225
|
-
... self.membrane = bs.State(jnp.zeros(n_neurons))
|
226
|
-
...
|
227
|
-
... def update(self, input_current):
|
228
|
-
... # Add noise to input current
|
229
|
-
... noise = bsr.normal(0, self.noise_scale, self.n_neurons)
|
230
|
-
... self.membrane.value += input_current + noise
|
231
|
-
... return self.membrane.value
|
232
|
-
>>>
|
233
|
-
>>> # Create and run noisy neuron model
|
234
|
-
>>> bsr.seed(42)
|
235
|
-
>>> neuron = NoisyNeuron(100)
|
236
|
-
>>> output = neuron.update(jnp.ones(100) * 0.5)
|
237
|
-
|
238
|
-
Notes
|
239
|
-
-----
|
240
|
-
|
241
|
-
- This module is designed to work seamlessly with JAX's functional programming model
|
242
|
-
- Random functions are JIT-compilable for optimal performance
|
243
|
-
- The global DEFAULT state is thread-local to avoid race conditions
|
244
|
-
- For deterministic results, always set a seed before random operations
|
245
|
-
|
246
|
-
See Also
|
247
|
-
--------
|
248
|
-
|
249
|
-
jax.random : JAX's random number generation module
|
250
|
-
numpy.random : NumPy's random number generation module
|
251
|
-
RandomState : The stateful random number generator class
|
252
|
-
|
253
|
-
References
|
254
|
-
----------
|
255
|
-
.. [1] JAX Random Number Generation:
|
256
|
-
https://jax.readthedocs.io/en/latest/jax.random.html
|
257
|
-
.. [2] NumPy Random Sampling:
|
258
|
-
https://numpy.org/doc/stable/reference/random/index.html
|
259
|
-
|
260
|
-
"""
|
261
|
-
|
262
|
-
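from ._rand_funs import *
|
263
|
-
from ._rand_funs import __all__ as __all_random__
|
264
|
-
from ._rand_seed import *
|
265
|
-
from ._rand_seed import __all__ as __all_seed__
|
266
|
-
from ._rand_state import *
|
267
|
-
from ._rand_state import __all__ as __all_state__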
|
268
|
-
|
269
|
-
__all__ = __all_random__ + __all_state__ + __all_seed__
|
270
|
-
del __all_random__, __all_state__, __all_seed__
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Random number generation module for BrainState.
|
18
|
+
|
19
|
+
This module provides a comprehensive set of random number generation functions and utilities
|
20
|
+
for neural network simulations and scientific computing. It wraps JAX's random number
|
21
|
+
generation capabilities with a stateful interface that simplifies usage while maintaining
|
22
|
+
reproducibility and performance.
|
23
|
+
|
24
|
+
The module includes:
|
25
|
+
|
26
|
+
- Standard random distributions (uniform, normal, exponential, etc.)
|
27
|
+
- Random state management with automatic key splitting
|
28
|
+
- Seed management utilities for reproducible simulations
|
29
|
+
- NumPy-compatible API for easy migration
|
30
|
+
|
31
|
+
Key Features
|
32
|
+
------------
|
33
|
+
|
34
|
+
- **Stateful random generation**: Automatic management of JAX's PRNG keys
|
35
|
+
- **NumPy compatibility**: Drop-in replacement for most NumPy random functions
|
36
|
+
- **Reproducibility**: Robust seed management and state tracking
|
37
|
+
- **Performance**: JIT-compiled random functions for efficient generation
|
38
|
+
- **Thread-safe**: Proper handling of random state in parallel computations
|
39
|
+
|
40
|
+
Random State Management
|
41
|
+
-----------------------
|
42
|
+
|
43
|
+
The module uses a global `DEFAULT` RandomState instance that automatically manages
|
44
|
+
JAX's PRNG keys. This eliminates the need to manually track and split keys:
|
45
|
+
|
46
|
+
.. code-block:: python
|
47
|
+
|
48
|
+
>>> import brainstate as bs
|
49
|
+
>>> import brainstate.random as bsr
|
50
|
+
>>>
|
51
|
+
>>> # Set a global seed for reproducibility
|
52
|
+
>>> bsr.seed(42)
|
53
|
+
>>>
|
54
|
+
>>> # Generate random numbers without manual key management
|
55
|
+
>>> x = bsr.normal(0, 1, size=(3, 3))
|
56
|
+
>>> y = bsr.uniform(0, 1, size=(100,))
|
57
|
+
|
58
|
+
Custom Random States
|
59
|
+
--------------------
|
60
|
+
|
61
|
+
For more control, you can create custom RandomState instances:
|
62
|
+
|
63
|
+
.. code-block:: python
|
64
|
+
|
65
|
+
>>> import brainstate.random as bsr
|
66
|
+
>>>
|
67
|
+
>>> # Create a custom random state
|
68
|
+
>>> rng = bsr.RandomState(seed=123)
|
69
|
+
>>>
|
70
|
+
>>> # Use it for generation
|
71
|
+
>>> data = rng.normal(0, 1, size=(10, 10))
|
72
|
+
>>>
|
73
|
+
>>> # Get the current key
|
74
|
+
>>> current_key = rng.value
|
75
|
+
|
76
|
+
Available Distributions
|
77
|
+
-----------------------
|
78
|
+
|
79
|
+
The module provides a wide range of probability distributions:
|
80
|
+
|
81
|
+
**Uniform Distributions:**
|
82
|
+
|
83
|
+
- `rand`, `random`, `random_sample`, `ranf`, `sample` - Uniform [0, 1)
|
84
|
+
- `randint`, `random_integers` - Uniform integers
|
85
|
+
- `choice` - Random selection from array
|
86
|
+
- `permutation`, `shuffle` - Random ordering
|
87
|
+
|
88
|
+
**Normal Distributions:**
|
89
|
+
|
90
|
+
- `randn`, `normal` - Normal (Gaussian) distribution
|
91
|
+
- `standard_normal` - Standard normal distribution
|
92
|
+
- `multivariate_normal` - Multivariate normal distribution
|
93
|
+
- `truncated_normal` - Truncated normal distribution
|
94
|
+
|
95
|
+
**Other Continuous Distributions:**
|
96
|
+
|
97
|
+
- `beta` - Beta distribution
|
98
|
+
- `exponential`, `standard_exponential` - Exponential distribution
|
99
|
+
- `gamma`, `standard_gamma` - Gamma distribution
|
100
|
+
- `gumbel` - Gumbel distribution
|
101
|
+
- `laplace` - Laplace distribution
|
102
|
+
- `logistic` - Logistic distribution
|
103
|
+
- `pareto` - Pareto distribution
|
104
|
+
- `rayleigh` - Rayleigh distribution
|
105
|
+
- `standard_cauchy` - Cauchy distribution
|
106
|
+
- `standard_t` - Student's t-distribution
|
107
|
+
- `uniform` - Uniform distribution over [low, high)
|
108
|
+
- `weibull` - Weibull distribution
|
109
|
+
|
110
|
+
**Discrete Distributions:**
|
111
|
+
|
112
|
+
- `bernoulli` - Bernoulli distribution
|
113
|
+
- `binomial` - Binomial distribution
|
114
|
+
- `poisson` - Poisson distribution
|
115
|
+
|
116
|
+
Seed Management
|
117
|
+
---------------
|
118
|
+
|
119
|
+
The module provides utilities for managing random seeds:
|
120
|
+
|
121
|
+
.. code-block:: python
|
122
|
+
|
123
|
+
>>> import brainstate.random as bsr
|
124
|
+
>>>
|
125
|
+
>>> # Set a global seed
|
126
|
+
>>> bsr.seed(42)
|
127
|
+
>>>
|
128
|
+
>>> # Get current seed/key
|
129
|
+
>>> key = bsr.get_key()
|
130
|
+
>>>
|
131
|
+
>>> # Split the key for parallel operations
|
132
|
+
>>> keys = bsr.split_key(n=4)
|
133
|
+
>>>
|
134
|
+
>>> # Use context manager for temporary seed
|
135
|
+
>>> with bsr.local_seed(123):
|
136
|
+
... x = bsr.normal(0, 1, (5,)) # Uses seed 123
|
137
|
+
>>> y = bsr.normal(0, 1, (5,)) # Uses original seed
|
138
|
+
|
139
|
+
Examples
|
140
|
+
--------
|
141
|
+
|
142
|
+
**Basic random number generation:**
|
143
|
+
|
144
|
+
.. code-block:: python
|
145
|
+
|
146
|
+
>>> import brainstate.random as bsr
|
147
|
+
>>> import jax.numpy as jnp
|
148
|
+
>>>
|
149
|
+
>>> # Set seed for reproducibility
|
150
|
+
>>> bsr.seed(0)
|
151
|
+
>>>
|
152
|
+
>>> # Generate uniform random numbers
|
153
|
+
>>> uniform_data = bsr.random((3, 3))
|
154
|
+
>>> print(uniform_data.shape)
|
155
|
+
(3, 3)
|
156
|
+
>>>
|
157
|
+
>>> # Generate normal random numbers
|
158
|
+
>>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
|
159
|
+
>>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
|
160
|
+
Mean: -0.045, Std: 0.972
|
161
|
+
|
162
|
+
**Sampling and shuffling:**
|
163
|
+
|
164
|
+
.. code-block:: python
|
165
|
+
|
166
|
+
>>> import brainstate.random as bsr
|
167
|
+
>>> import jax.numpy as jnp
|
168
|
+
>>>
|
169
|
+
>>> bsr.seed(42)
|
170
|
+
>>>
|
171
|
+
>>> # Random choice from array
|
172
|
+
>>> arr = jnp.array([1, 2, 3, 4, 5])
|
173
|
+
>>> samples = bsr.choice(arr, size=3, replace=False)
|
174
|
+
>>> print(samples)
|
175
|
+
[4 1 5]
|
176
|
+
>>>
|
177
|
+
>>> # Random permutation
|
178
|
+
>>> perm = bsr.permutation(10)
|
179
|
+
>>> print(perm)
|
180
|
+
[3 5 1 7 9 0 2 8 4 6]
|
181
|
+
>>>
|
182
|
+
>>> # In-place shuffle
|
183
|
+
>>> data = jnp.arange(5)
|
184
|
+
>>> bsr.shuffle(data)
|
185
|
+
>>> print(data)
|
186
|
+
[2 0 4 1 3]
|
187
|
+
|
188
|
+
**Advanced distributions:**
|
189
|
+
|
190
|
+
.. code-block:: python
|
191
|
+
|
192
|
+
>>> import brainstate.random as bsr
|
193
|
+
>>> import matplotlib.pyplot as plt
|
194
|
+
>>>
|
195
|
+
>>> bsr.seed(123)
|
196
|
+
>>>
|
197
|
+
>>> # Generate samples from different distributions
|
198
|
+
>>> normal_samples = bsr.normal(0, 1, 1000)
|
199
|
+
>>> exponential_samples = bsr.exponential(1.0, 1000)
|
200
|
+
>>> beta_samples = bsr.beta(2, 5, 1000)
|
201
|
+
>>>
|
202
|
+
>>> # Plot histograms
|
203
|
+
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
204
|
+
>>> axes[0].hist(normal_samples, bins=30, density=True)
|
205
|
+
>>> axes[0].set_title('Normal Distribution')
|
206
|
+
>>> axes[1].hist(exponential_samples, bins=30, density=True)
|
207
|
+
>>> axes[1].set_title('Exponential Distribution')
|
208
|
+
>>> axes[2].hist(beta_samples, bins=30, density=True)
|
209
|
+
>>> axes[2].set_title('Beta Distribution')
|
210
|
+
>>> plt.show()
|
211
|
+
|
212
|
+
**Using with neural network simulations:**
|
213
|
+
|
214
|
+
.. code-block:: python
|
215
|
+
|
216
|
+
>>> import brainstate as bs
|
217
|
+
>>> import brainstate.random as bsr
|
218
|
+
>>> import brainstate.nn as nn
|
219
|
+
>>>
|
220
|
+
>>> class NoisyNeuron(bs.Module):
|
221
|
+
... def __init__(self, n_neurons, noise_scale=0.1):
|
222
|
+
... super().__init__()
|
223
|
+
... self.n_neurons = n_neurons
|
224
|
+
... self.noise_scale = noise_scale
|
225
|
+
... self.membrane = bs.State(jnp.zeros(n_neurons))
|
226
|
+
...
|
227
|
+
... def update(self, input_current):
|
228
|
+
... # Add noise to input current
|
229
|
+
... noise = bsr.normal(0, self.noise_scale, self.n_neurons)
|
230
|
+
... self.membrane.value += input_current + noise
|
231
|
+
... return self.membrane.value
|
232
|
+
>>>
|
233
|
+
>>> # Create and run noisy neuron model
|
234
|
+
>>> bsr.seed(42)
|
235
|
+
>>> neuron = NoisyNeuron(100)
|
236
|
+
>>> output = neuron.update(jnp.ones(100) * 0.5)
|
237
|
+
|
238
|
+
Notes
|
239
|
+
-----
|
240
|
+
|
241
|
+
- This module is designed to work seamlessly with JAX's functional programming model
|
242
|
+
- Random functions are JIT-compilable for optimal performance
|
243
|
+
- The global DEFAULT state is thread-local to avoid race conditions
|
244
|
+
- For deterministic results, always set a seed before random operations
|
245
|
+
|
246
|
+
See Also
|
247
|
+
--------
|
248
|
+
|
249
|
+
jax.random : JAX's random number generation module
|
250
|
+
numpy.random : NumPy's random number generation module
|
251
|
+
RandomState : The stateful random number generator class
|
252
|
+
|
253
|
+
References
|
254
|
+
----------
|
255
|
+
.. [1] JAX Random Number Generation:
|
256
|
+
https://jax.readthedocs.io/en/latest/jax.random.html
|
257
|
+
.. [2] NumPy Random Sampling:
|
258
|
+
https://numpy.org/doc/stable/reference/random/index.html
|
259
|
+
|
260
|
+
"""
|
261
|
+
|
262
|
+
from ._fun import *
|
263
|
+
from ._fun import __all__ as __all_random__
|
264
|
+
from ._seed import *
|
265
|
+
from ._seed import __all__ as __all_seed__
|
266
|
+
from ._state import *
|
267
|
+
from ._state import __all__ as __all_state__
|
268
|
+
|
269
|
+
__all__ = __all_random__ + __all_state__ + __all_seed__
|
270
|
+
del __all_random__, __all_state__, __all_seed__
|