brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -23,20 +23,20 @@ import brainstate
|
|
23
23
|
class TestNormalInit(unittest.TestCase):
|
24
24
|
|
25
25
|
def test_normal_init1(self):
|
26
|
-
init = brainstate.init.Normal()
|
26
|
+
init = brainstate.nn.init.Normal()
|
27
27
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
28
28
|
weights = init(size)
|
29
29
|
assert weights.shape == size
|
30
30
|
|
31
31
|
def test_normal_init2(self):
|
32
|
-
init = brainstate.init.Normal(scale=0.5)
|
32
|
+
init = brainstate.nn.init.Normal(scale=0.5)
|
33
33
|
for size in [(100,), (10, 20)]:
|
34
34
|
weights = init(size)
|
35
35
|
assert weights.shape == size
|
36
36
|
|
37
37
|
def test_normal_init3(self):
|
38
|
-
init1 = brainstate.init.Normal(scale=0.5, seed=10)
|
39
|
-
init2 = brainstate.init.Normal(scale=0.5, seed=10)
|
38
|
+
init1 = brainstate.nn.init.Normal(scale=0.5, seed=10)
|
39
|
+
init2 = brainstate.nn.init.Normal(scale=0.5, seed=10)
|
40
40
|
size = (10,)
|
41
41
|
weights1 = init1(size)
|
42
42
|
weights2 = init2(size)
|
@@ -46,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
|
|
46
46
|
|
47
47
|
class TestUniformInit(unittest.TestCase):
|
48
48
|
def test_uniform_init1(self):
|
49
|
-
init = brainstate.init.Normal()
|
49
|
+
init = brainstate.nn.init.Normal()
|
50
50
|
for size in [(100,), (10, 20), (10, 20, 30)]:
|
51
51
|
weights = init(size)
|
52
52
|
assert weights.shape == size
|
53
53
|
|
54
54
|
def test_uniform_init2(self):
|
55
|
-
init = brainstate.init.Uniform(min_val=10, max_val=20)
|
55
|
+
init = brainstate.nn.init.Uniform(min_val=10, max_val=20)
|
56
56
|
for size in [(100,), (10, 20)]:
|
57
57
|
weights = init(size)
|
58
58
|
assert weights.shape == size
|
@@ -60,20 +60,21 @@ class TestUniformInit(unittest.TestCase):
|
|
60
60
|
|
61
61
|
class TestVarianceScaling(unittest.TestCase):
|
62
62
|
def test_var_scaling1(self):
|
63
|
-
init = brainstate.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
|
63
|
+
init = brainstate.nn.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
|
64
64
|
for size in [(10, 20), (10, 20, 30)]:
|
65
65
|
weights = init(size)
|
66
66
|
assert weights.shape == size
|
67
67
|
|
68
68
|
def test_var_scaling2(self):
|
69
|
-
init = brainstate.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
|
69
|
+
init = brainstate.nn.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
|
70
70
|
for size in [(10, 20), (10, 20, 30)]:
|
71
71
|
weights = init(size)
|
72
72
|
assert weights.shape == size
|
73
73
|
|
74
74
|
def test_var_scaling3(self):
|
75
|
-
init = brainstate.init.VarianceScaling(
|
76
|
-
|
75
|
+
init = brainstate.nn.init.VarianceScaling(
|
76
|
+
scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1, distribution='uniform'
|
77
|
+
)
|
77
78
|
for size in [(10, 20), (10, 20, 30)]:
|
78
79
|
weights = init(size)
|
79
80
|
assert weights.shape == size
|
@@ -81,7 +82,7 @@ class TestVarianceScaling(unittest.TestCase):
|
|
81
82
|
|
82
83
|
class TestKaimingUniformUnit(unittest.TestCase):
|
83
84
|
def test_kaiming_uniform_init(self):
|
84
|
-
init = brainstate.init.KaimingUniform()
|
85
|
+
init = brainstate.nn.init.KaimingUniform()
|
85
86
|
for size in [(10, 20), (10, 20, 30)]:
|
86
87
|
weights = init(size)
|
87
88
|
assert weights.shape == size
|
@@ -89,7 +90,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
|
|
89
90
|
|
90
91
|
class TestKaimingNormalUnit(unittest.TestCase):
|
91
92
|
def test_kaiming_normal_init(self):
|
92
|
-
init = brainstate.init.KaimingNormal()
|
93
|
+
init = brainstate.nn.init.KaimingNormal()
|
93
94
|
for size in [(10, 20), (10, 20, 30)]:
|
94
95
|
weights = init(size)
|
95
96
|
assert weights.shape == size
|
@@ -97,7 +98,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
|
|
97
98
|
|
98
99
|
class TestXavierUniformUnit(unittest.TestCase):
|
99
100
|
def test_xavier_uniform_init(self):
|
100
|
-
init = brainstate.init.XavierUniform()
|
101
|
+
init = brainstate.nn.init.XavierUniform()
|
101
102
|
for size in [(10, 20), (10, 20, 30)]:
|
102
103
|
weights = init(size)
|
103
104
|
assert weights.shape == size
|
@@ -105,7 +106,7 @@ class TestXavierUniformUnit(unittest.TestCase):
|
|
105
106
|
|
106
107
|
class TestXavierNormalUnit(unittest.TestCase):
|
107
108
|
def test_xavier_normal_init(self):
|
108
|
-
init = brainstate.init.XavierNormal()
|
109
|
+
init = brainstate.nn.init.XavierNormal()
|
109
110
|
for size in [(10, 20), (10, 20, 30)]:
|
110
111
|
weights = init(size)
|
111
112
|
assert weights.shape == size
|
@@ -113,7 +114,7 @@ class TestXavierNormalUnit(unittest.TestCase):
|
|
113
114
|
|
114
115
|
class TestLecunUniformUnit(unittest.TestCase):
|
115
116
|
def test_lecun_uniform_init(self):
|
116
|
-
init = brainstate.init.LecunUniform()
|
117
|
+
init = brainstate.nn.init.LecunUniform()
|
117
118
|
for size in [(10, 20), (10, 20, 30)]:
|
118
119
|
weights = init(size)
|
119
120
|
assert weights.shape == size
|
@@ -121,7 +122,7 @@ class TestLecunUniformUnit(unittest.TestCase):
|
|
121
122
|
|
122
123
|
class TestLecunNormalUnit(unittest.TestCase):
|
123
124
|
def test_lecun_normal_init(self):
|
124
|
-
init = brainstate.init.LecunNormal()
|
125
|
+
init = brainstate.nn.init.LecunNormal()
|
125
126
|
for size in [(10, 20), (10, 20, 30)]:
|
126
127
|
weights = init(size)
|
127
128
|
assert weights.shape == size
|
@@ -129,13 +130,13 @@ class TestLecunNormalUnit(unittest.TestCase):
|
|
129
130
|
|
130
131
|
class TestOrthogonalUnit(unittest.TestCase):
|
131
132
|
def test_orthogonal_init1(self):
|
132
|
-
init = brainstate.init.Orthogonal()
|
133
|
+
init = brainstate.nn.init.Orthogonal()
|
133
134
|
for size in [(20, 20), (10, 20, 30)]:
|
134
135
|
weights = init(size)
|
135
136
|
assert weights.shape == size
|
136
137
|
|
137
138
|
def test_orthogonal_init2(self):
|
138
|
-
init = brainstate.init.Orthogonal(scale=2., axis=0)
|
139
|
+
init = brainstate.nn.init.Orthogonal(scale=2., axis=0)
|
139
140
|
for size in [(10, 20), (10, 20, 30)]:
|
140
141
|
weights = init(size)
|
141
142
|
assert weights.shape == size
|
@@ -143,7 +144,37 @@ class TestOrthogonalUnit(unittest.TestCase):
|
|
143
144
|
|
144
145
|
class TestDeltaOrthogonalUnit(unittest.TestCase):
|
145
146
|
def test_delta_orthogonal_init1(self):
|
146
|
-
init = brainstate.init.DeltaOrthogonal()
|
147
|
+
init = brainstate.nn.init.DeltaOrthogonal()
|
147
148
|
for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]:
|
148
149
|
weights = init(size)
|
149
150
|
assert weights.shape == size
|
151
|
+
|
152
|
+
|
153
|
+
class TestZeroInit(unittest.TestCase):
|
154
|
+
def test_zero_init(self):
|
155
|
+
init = brainstate.nn.init.ZeroInit()
|
156
|
+
for size in [(100,), (10, 20), (10, 20, 30)]:
|
157
|
+
weights = init(size)
|
158
|
+
assert weights.shape == size
|
159
|
+
|
160
|
+
|
161
|
+
class TestOneInit(unittest.TestCase):
|
162
|
+
def test_one_init(self):
|
163
|
+
for size in [(100,), (10, 20), (10, 20, 30)]:
|
164
|
+
for value in [0., 1., -1.]:
|
165
|
+
init = brainstate.nn.init.Constant(value=value)
|
166
|
+
weights = init(size)
|
167
|
+
assert weights.shape == size
|
168
|
+
assert (weights == value).all()
|
169
|
+
|
170
|
+
|
171
|
+
class TestIdentityInit(unittest.TestCase):
|
172
|
+
def test_identity_init(self):
|
173
|
+
for size in [(100,), (10, 20)]:
|
174
|
+
for value in [0., 1., -1.]:
|
175
|
+
init = brainstate.nn.init.Identity(value=value)
|
176
|
+
weights = init(size)
|
177
|
+
if len(size) == 1:
|
178
|
+
assert weights.shape == (size[0], size[0])
|
179
|
+
else:
|
180
|
+
assert weights.shape == size
|
brainstate/random/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,6 +13,252 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""
|
17
|
+
Random number generation module for BrainState.
|
18
|
+
|
19
|
+
This module provides a comprehensive set of random number generation functions and utilities
|
20
|
+
for neural network simulations and scientific computing. It wraps JAX's random number
|
21
|
+
generation capabilities with a stateful interface that simplifies usage while maintaining
|
22
|
+
reproducibility and performance.
|
23
|
+
|
24
|
+
The module includes:
|
25
|
+
|
26
|
+
- Standard random distributions (uniform, normal, exponential, etc.)
|
27
|
+
- Random state management with automatic key splitting
|
28
|
+
- Seed management utilities for reproducible simulations
|
29
|
+
- NumPy-compatible API for easy migration
|
30
|
+
|
31
|
+
Key Features
|
32
|
+
------------
|
33
|
+
|
34
|
+
- **Stateful random generation**: Automatic management of JAX's PRNG keys
|
35
|
+
- **NumPy compatibility**: Drop-in replacement for most NumPy random functions
|
36
|
+
- **Reproducibility**: Robust seed management and state tracking
|
37
|
+
- **Performance**: JIT-compiled random functions for efficient generation
|
38
|
+
- **Thread-safe**: Proper handling of random state in parallel computations
|
39
|
+
|
40
|
+
Random State Management
|
41
|
+
-----------------------
|
42
|
+
|
43
|
+
The module uses a global `DEFAULT` RandomState instance that automatically manages
|
44
|
+
JAX's PRNG keys. This eliminates the need to manually track and split keys:
|
45
|
+
|
46
|
+
.. code-block:: python
|
47
|
+
|
48
|
+
>>> import brainstate as bs
|
49
|
+
>>> import brainstate.random as bsr
|
50
|
+
>>>
|
51
|
+
>>> # Set a global seed for reproducibility
|
52
|
+
>>> bsr.seed(42)
|
53
|
+
>>>
|
54
|
+
>>> # Generate random numbers without manual key management
|
55
|
+
>>> x = bsr.normal(0, 1, size=(3, 3))
|
56
|
+
>>> y = bsr.uniform(0, 1, size=(100,))
|
57
|
+
|
58
|
+
Custom Random States
|
59
|
+
--------------------
|
60
|
+
|
61
|
+
For more control, you can create custom RandomState instances:
|
62
|
+
|
63
|
+
.. code-block:: python
|
64
|
+
|
65
|
+
>>> import brainstate.random as bsr
|
66
|
+
>>>
|
67
|
+
>>> # Create a custom random state
|
68
|
+
>>> rng = bsr.RandomState(seed=123)
|
69
|
+
>>>
|
70
|
+
>>> # Use it for generation
|
71
|
+
>>> data = rng.normal(0, 1, size=(10, 10))
|
72
|
+
>>>
|
73
|
+
>>> # Get the current key
|
74
|
+
>>> current_key = rng.value
|
75
|
+
|
76
|
+
Available Distributions
|
77
|
+
-----------------------
|
78
|
+
|
79
|
+
The module provides a wide range of probability distributions:
|
80
|
+
|
81
|
+
**Uniform Distributions:**
|
82
|
+
|
83
|
+
- `rand`, `random`, `random_sample`, `ranf`, `sample` - Uniform [0, 1)
|
84
|
+
- `randint`, `random_integers` - Uniform integers
|
85
|
+
- `choice` - Random selection from array
|
86
|
+
- `permutation`, `shuffle` - Random ordering
|
87
|
+
|
88
|
+
**Normal Distributions:**
|
89
|
+
|
90
|
+
- `randn`, `normal` - Normal (Gaussian) distribution
|
91
|
+
- `standard_normal` - Standard normal distribution
|
92
|
+
- `multivariate_normal` - Multivariate normal distribution
|
93
|
+
- `truncated_normal` - Truncated normal distribution
|
94
|
+
|
95
|
+
**Other Continuous Distributions:**
|
96
|
+
|
97
|
+
- `beta` - Beta distribution
|
98
|
+
- `exponential`, `standard_exponential` - Exponential distribution
|
99
|
+
- `gamma`, `standard_gamma` - Gamma distribution
|
100
|
+
- `gumbel` - Gumbel distribution
|
101
|
+
- `laplace` - Laplace distribution
|
102
|
+
- `logistic` - Logistic distribution
|
103
|
+
- `pareto` - Pareto distribution
|
104
|
+
- `rayleigh` - Rayleigh distribution
|
105
|
+
- `standard_cauchy` - Cauchy distribution
|
106
|
+
- `standard_t` - Student's t-distribution
|
107
|
+
- `uniform` - Uniform distribution over [low, high)
|
108
|
+
- `weibull` - Weibull distribution
|
109
|
+
|
110
|
+
**Discrete Distributions:**
|
111
|
+
|
112
|
+
- `bernoulli` - Bernoulli distribution
|
113
|
+
- `binomial` - Binomial distribution
|
114
|
+
- `poisson` - Poisson distribution
|
115
|
+
|
116
|
+
Seed Management
|
117
|
+
---------------
|
118
|
+
|
119
|
+
The module provides utilities for managing random seeds:
|
120
|
+
|
121
|
+
.. code-block:: python
|
122
|
+
|
123
|
+
>>> import brainstate.random as bsr
|
124
|
+
>>>
|
125
|
+
>>> # Set a global seed
|
126
|
+
>>> bsr.seed(42)
|
127
|
+
>>>
|
128
|
+
>>> # Get current seed/key
|
129
|
+
>>> key = bsr.get_key()
|
130
|
+
>>>
|
131
|
+
>>> # Split the key for parallel operations
|
132
|
+
>>> keys = bsr.split_key(n=4)
|
133
|
+
>>>
|
134
|
+
>>> # Use context manager for temporary seed
|
135
|
+
>>> with bsr.local_seed(123):
|
136
|
+
... x = bsr.normal(0, 1, (5,)) # Uses seed 123
|
137
|
+
>>> y = bsr.normal(0, 1, (5,)) # Uses original seed
|
138
|
+
|
139
|
+
Examples
|
140
|
+
--------
|
141
|
+
|
142
|
+
**Basic random number generation:**
|
143
|
+
|
144
|
+
.. code-block:: python
|
145
|
+
|
146
|
+
>>> import brainstate.random as bsr
|
147
|
+
>>> import jax.numpy as jnp
|
148
|
+
>>>
|
149
|
+
>>> # Set seed for reproducibility
|
150
|
+
>>> bsr.seed(0)
|
151
|
+
>>>
|
152
|
+
>>> # Generate uniform random numbers
|
153
|
+
>>> uniform_data = bsr.random((3, 3))
|
154
|
+
>>> print(uniform_data.shape)
|
155
|
+
(3, 3)
|
156
|
+
>>>
|
157
|
+
>>> # Generate normal random numbers
|
158
|
+
>>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
|
159
|
+
>>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
|
160
|
+
Mean: -0.045, Std: 0.972
|
161
|
+
|
162
|
+
**Sampling and shuffling:**
|
163
|
+
|
164
|
+
.. code-block:: python
|
165
|
+
|
166
|
+
>>> import brainstate.random as bsr
|
167
|
+
>>> import jax.numpy as jnp
|
168
|
+
>>>
|
169
|
+
>>> bsr.seed(42)
|
170
|
+
>>>
|
171
|
+
>>> # Random choice from array
|
172
|
+
>>> arr = jnp.array([1, 2, 3, 4, 5])
|
173
|
+
>>> samples = bsr.choice(arr, size=3, replace=False)
|
174
|
+
>>> print(samples)
|
175
|
+
[4 1 5]
|
176
|
+
>>>
|
177
|
+
>>> # Random permutation
|
178
|
+
>>> perm = bsr.permutation(10)
|
179
|
+
>>> print(perm)
|
180
|
+
[3 5 1 7 9 0 2 8 4 6]
|
181
|
+
>>>
|
182
|
+
>>> # In-place shuffle
|
183
|
+
>>> data = jnp.arange(5)
|
184
|
+
>>> bsr.shuffle(data)
|
185
|
+
>>> print(data)
|
186
|
+
[2 0 4 1 3]
|
187
|
+
|
188
|
+
**Advanced distributions:**
|
189
|
+
|
190
|
+
.. code-block:: python
|
191
|
+
|
192
|
+
>>> import brainstate.random as bsr
|
193
|
+
>>> import matplotlib.pyplot as plt
|
194
|
+
>>>
|
195
|
+
>>> bsr.seed(123)
|
196
|
+
>>>
|
197
|
+
>>> # Generate samples from different distributions
|
198
|
+
>>> normal_samples = bsr.normal(0, 1, 1000)
|
199
|
+
>>> exponential_samples = bsr.exponential(1.0, 1000)
|
200
|
+
>>> beta_samples = bsr.beta(2, 5, 1000)
|
201
|
+
>>>
|
202
|
+
>>> # Plot histograms
|
203
|
+
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
204
|
+
>>> axes[0].hist(normal_samples, bins=30, density=True)
|
205
|
+
>>> axes[0].set_title('Normal Distribution')
|
206
|
+
>>> axes[1].hist(exponential_samples, bins=30, density=True)
|
207
|
+
>>> axes[1].set_title('Exponential Distribution')
|
208
|
+
>>> axes[2].hist(beta_samples, bins=30, density=True)
|
209
|
+
>>> axes[2].set_title('Beta Distribution')
|
210
|
+
>>> plt.show()
|
211
|
+
|
212
|
+
**Using with neural network simulations:**
|
213
|
+
|
214
|
+
.. code-block:: python
|
215
|
+
|
216
|
+
>>> import brainstate as bs
|
217
|
+
>>> import brainstate.random as bsr
|
218
|
+
>>> import brainstate.nn as nn
|
219
|
+
>>>
|
220
|
+
>>> class NoisyNeuron(bs.Module):
|
221
|
+
... def __init__(self, n_neurons, noise_scale=0.1):
|
222
|
+
... super().__init__()
|
223
|
+
... self.n_neurons = n_neurons
|
224
|
+
... self.noise_scale = noise_scale
|
225
|
+
... self.membrane = bs.State(jnp.zeros(n_neurons))
|
226
|
+
...
|
227
|
+
... def update(self, input_current):
|
228
|
+
... # Add noise to input current
|
229
|
+
... noise = bsr.normal(0, self.noise_scale, self.n_neurons)
|
230
|
+
... self.membrane.value += input_current + noise
|
231
|
+
... return self.membrane.value
|
232
|
+
>>>
|
233
|
+
>>> # Create and run noisy neuron model
|
234
|
+
>>> bsr.seed(42)
|
235
|
+
>>> neuron = NoisyNeuron(100)
|
236
|
+
>>> output = neuron.update(jnp.ones(100) * 0.5)
|
237
|
+
|
238
|
+
Notes
|
239
|
+
-----
|
240
|
+
|
241
|
+
- This module is designed to work seamlessly with JAX's functional programming model
|
242
|
+
- Random functions are JIT-compilable for optimal performance
|
243
|
+
- The global DEFAULT state is thread-local to avoid race conditions
|
244
|
+
- For deterministic results, always set a seed before random operations
|
245
|
+
|
246
|
+
See Also
|
247
|
+
--------
|
248
|
+
|
249
|
+
jax.random : JAX's random number generation module
|
250
|
+
numpy.random : NumPy's random number generation module
|
251
|
+
RandomState : The stateful random number generator class
|
252
|
+
|
253
|
+
References
|
254
|
+
----------
|
255
|
+
.. [1] JAX Random Number Generation:
|
256
|
+
https://jax.readthedocs.io/en/latest/jax.random.html
|
257
|
+
.. [2] NumPy Random Sampling:
|
258
|
+
https://numpy.org/doc/stable/reference/random/index.html
|
259
|
+
|
260
|
+
"""
|
261
|
+
|
16
262
|
from ._rand_funs import *
|
17
263
|
from ._rand_funs import __all__ as __all_random__
|
18
264
|
from ._rand_seed import *
|