brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -23,20 +23,20 @@ import brainstate
23
23
  class TestNormalInit(unittest.TestCase):
24
24
 
25
25
  def test_normal_init1(self):
26
- init = brainstate.init.Normal()
26
+ init = brainstate.nn.init.Normal()
27
27
  for size in [(100,), (10, 20), (10, 20, 30)]:
28
28
  weights = init(size)
29
29
  assert weights.shape == size
30
30
 
31
31
  def test_normal_init2(self):
32
- init = brainstate.init.Normal(scale=0.5)
32
+ init = brainstate.nn.init.Normal(scale=0.5)
33
33
  for size in [(100,), (10, 20)]:
34
34
  weights = init(size)
35
35
  assert weights.shape == size
36
36
 
37
37
  def test_normal_init3(self):
38
- init1 = brainstate.init.Normal(scale=0.5, seed=10)
39
- init2 = brainstate.init.Normal(scale=0.5, seed=10)
38
+ init1 = brainstate.nn.init.Normal(scale=0.5, seed=10)
39
+ init2 = brainstate.nn.init.Normal(scale=0.5, seed=10)
40
40
  size = (10,)
41
41
  weights1 = init1(size)
42
42
  weights2 = init2(size)
@@ -46,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
46
46
 
47
47
class TestUniformInit(unittest.TestCase):
    """Shape and range checks for ``brainstate.nn.init.Uniform``."""

    def test_uniform_init1(self):
        # This test previously constructed ``Normal()`` (a copy-paste slip from
        # TestNormalInit), so the Uniform initializer's default configuration
        # was never exercised here.
        init = brainstate.nn.init.Uniform()
        for size in [(100,), (10, 20), (10, 20, 30)]:
            weights = init(size)
            assert weights.shape == size

    def test_uniform_init2(self):
        init = brainstate.nn.init.Uniform(min_val=10, max_val=20)
        for size in [(100,), (10, 20)]:
            weights = init(size)
            assert weights.shape == size
            # Every sample must lie inside the requested interval; a shape-only
            # check would also accept an initializer that ignores the bounds.
            assert (weights >= 10).all()
            assert (weights <= 20).all()
@@ -60,20 +60,21 @@ class TestUniformInit(unittest.TestCase):
60
60
 
61
61
class TestVarianceScaling(unittest.TestCase):
    """Shape checks for ``brainstate.nn.init.VarianceScaling`` across modes and distributions."""

    def _check_shapes(self, initializer, shapes):
        # Draw one sample per requested shape and confirm the shape is preserved.
        for shape in shapes:
            sample = initializer(shape)
            assert sample.shape == shape

    def test_var_scaling1(self):
        initializer = brainstate.nn.init.VarianceScaling(
            scale=1., mode='fan_in', distribution='truncated_normal'
        )
        self._check_shapes(initializer, [(10, 20), (10, 20, 30)])

    def test_var_scaling2(self):
        initializer = brainstate.nn.init.VarianceScaling(
            scale=2, mode='fan_out', distribution='normal'
        )
        self._check_shapes(initializer, [(10, 20), (10, 20, 30)])

    def test_var_scaling3(self):
        initializer = brainstate.nn.init.VarianceScaling(
            scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1, distribution='uniform'
        )
        self._check_shapes(initializer, [(10, 20), (10, 20, 30)])
@@ -81,7 +82,7 @@ class TestVarianceScaling(unittest.TestCase):
81
82
 
82
83
class TestKaimingUniformUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.KaimingUniform``."""

    def test_kaiming_uniform_init(self):
        initializer = brainstate.nn.init.KaimingUniform()
        for shape in ((10, 20), (10, 20, 30)):
            sample = initializer(shape)
            assert sample.shape == shape
@@ -89,7 +90,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
89
90
 
90
91
class TestKaimingNormalUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.KaimingNormal``."""

    def test_kaiming_normal_init(self):
        initializer = brainstate.nn.init.KaimingNormal()
        shapes = ((10, 20), (10, 20, 30))
        for shape in shapes:
            assert initializer(shape).shape == shape
@@ -97,7 +98,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
97
98
 
98
99
class TestXavierUniformUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.XavierUniform``."""

    def test_xavier_uniform_init(self):
        initializer = brainstate.nn.init.XavierUniform()
        for shape in ((10, 20), (10, 20, 30)):
            sample = initializer(shape)
            assert sample.shape == shape
@@ -105,7 +106,7 @@ class TestXavierUniformUnit(unittest.TestCase):
105
106
 
106
107
class TestXavierNormalUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.XavierNormal``."""

    def test_xavier_normal_init(self):
        initializer = brainstate.nn.init.XavierNormal()
        shapes = ((10, 20), (10, 20, 30))
        for shape in shapes:
            assert initializer(shape).shape == shape
@@ -113,7 +114,7 @@ class TestXavierNormalUnit(unittest.TestCase):
113
114
 
114
115
class TestLecunUniformUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.LecunUniform``."""

    def test_lecun_uniform_init(self):
        initializer = brainstate.nn.init.LecunUniform()
        for shape in ((10, 20), (10, 20, 30)):
            sample = initializer(shape)
            assert sample.shape == shape
@@ -121,7 +122,7 @@ class TestLecunUniformUnit(unittest.TestCase):
121
122
 
122
123
class TestLecunNormalUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.LecunNormal``."""

    def test_lecun_normal_init(self):
        initializer = brainstate.nn.init.LecunNormal()
        shapes = ((10, 20), (10, 20, 30))
        for shape in shapes:
            assert initializer(shape).shape == shape
@@ -129,13 +130,13 @@ class TestLecunNormalUnit(unittest.TestCase):
129
130
 
130
131
class TestOrthogonalUnit(unittest.TestCase):
    """Shape checks for ``brainstate.nn.init.Orthogonal`` with default and custom settings."""

    def test_orthogonal_init1(self):
        initializer = brainstate.nn.init.Orthogonal()
        for shape in ((20, 20), (10, 20, 30)):
            assert initializer(shape).shape == shape

    def test_orthogonal_init2(self):
        initializer = brainstate.nn.init.Orthogonal(scale=2., axis=0)
        for shape in ((10, 20), (10, 20, 30)):
            assert initializer(shape).shape == shape
@@ -143,7 +144,37 @@ class TestOrthogonalUnit(unittest.TestCase):
143
144
 
144
145
class TestDeltaOrthogonalUnit(unittest.TestCase):
    """Shape check for ``brainstate.nn.init.DeltaOrthogonal`` on 3-D to 5-D kernels."""

    def test_delta_orthogonal_init1(self):
        initializer = brainstate.nn.init.DeltaOrthogonal()
        shapes = ((20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20))
        for shape in shapes:
            sample = initializer(shape)
            assert sample.shape == shape
151
+
152
+
153
class TestZeroInit(unittest.TestCase):
    """Shape and value checks for ``brainstate.nn.init.ZeroInit``."""

    def test_zero_init(self):
        init = brainstate.nn.init.ZeroInit()
        for size in [(100,), (10, 20), (10, 20, 30)]:
            weights = init(size)
            assert weights.shape == size
            # The shape check alone would accept an initializer that returned
            # arbitrary values; ZeroInit must produce only zeros.
            assert (weights == 0).all()
159
+
160
+
161
class TestOneInit(unittest.TestCase):
    """Checks that ``brainstate.nn.init.Constant`` fills every entry with its value.

    Despite the class name, several constants (0, 1 and -1) are exercised.
    """

    def test_one_init(self):
        shapes = [(100,), (10, 20), (10, 20, 30)]
        constants = [0., 1., -1.]
        for shape in shapes:
            for constant in constants:
                filled = brainstate.nn.init.Constant(value=constant)(shape)
                assert filled.shape == shape
                assert (filled == constant).all()
169
+
170
+
171
class TestIdentityInit(unittest.TestCase):
    """Shape and value checks for ``brainstate.nn.init.Identity``."""

    def test_identity_init(self):
        import jax.numpy as jnp

        for size in [(100,), (10, 20)]:
            # A 1-D size ``(n,)`` is expanded to a square ``(n, n)`` matrix.
            expected_shape = (size[0], size[0]) if len(size) == 1 else size
            for value in [0., 1., -1.]:
                init = brainstate.nn.init.Identity(value=value)
                weights = init(size)
                assert weights.shape == expected_shape
                # Previously ``value`` was iterated but never checked: the
                # diagonal should hold ``value`` and all other entries zero.
                expected = value * jnp.eye(*expected_shape)
                assert jnp.allclose(weights, expected)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,6 +13,252 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """
17
+ Random number generation module for BrainState.
18
+
19
+ This module provides a comprehensive set of random number generation functions and utilities
20
+ for neural network simulations and scientific computing. It wraps JAX's random number
21
+ generation capabilities with a stateful interface that simplifies usage while maintaining
22
+ reproducibility and performance.
23
+
24
+ The module includes:
25
+
26
+ - Standard random distributions (uniform, normal, exponential, etc.)
27
+ - Random state management with automatic key splitting
28
+ - Seed management utilities for reproducible simulations
29
+ - NumPy-compatible API for easy migration
30
+
31
+ Key Features
32
+ ------------
33
+
34
+ - **Stateful random generation**: Automatic management of JAX's PRNG keys
35
+ - **NumPy compatibility**: Drop-in replacement for most NumPy random functions
36
+ - **Reproducibility**: Robust seed management and state tracking
37
+ - **Performance**: JIT-compiled random functions for efficient generation
38
+ - **Thread-safe**: Proper handling of random state in parallel computations
39
+
40
+ Random State Management
41
+ -----------------------
42
+
43
+ The module uses a global `DEFAULT` RandomState instance that automatically manages
44
+ JAX's PRNG keys. This eliminates the need to manually track and split keys:
45
+
46
+ .. code-block:: python
47
+
48
+ >>> import brainstate as bs
49
+ >>> import brainstate.random as bsr
50
+ >>>
51
+ >>> # Set a global seed for reproducibility
52
+ >>> bsr.seed(42)
53
+ >>>
54
+ >>> # Generate random numbers without manual key management
55
+ >>> x = bsr.normal(0, 1, size=(3, 3))
56
+ >>> y = bsr.uniform(0, 1, size=(100,))
57
+
58
+ Custom Random States
59
+ --------------------
60
+
61
+ For more control, you can create custom RandomState instances:
62
+
63
+ .. code-block:: python
64
+
65
+ >>> import brainstate.random as bsr
66
+ >>>
67
+ >>> # Create a custom random state
68
+ >>> rng = bsr.RandomState(123)
69
+ >>>
70
+ >>> # Use it for generation
71
+ >>> data = rng.normal(0, 1, size=(10, 10))
72
+ >>>
73
+ >>> # Get the current key
74
+ >>> current_key = rng.value
75
+
76
+ Available Distributions
77
+ -----------------------
78
+
79
+ The module provides a wide range of probability distributions:
80
+
81
+ **Uniform Distributions:**
82
+
83
+ - `rand`, `random`, `random_sample`, `ranf`, `sample` - Uniform [0, 1)
84
+ - `randint`, `random_integers` - Uniform integers
85
+ - `choice` - Random selection from array
86
+ - `permutation`, `shuffle` - Random ordering
87
+
88
+ **Normal Distributions:**
89
+
90
+ - `randn`, `normal` - Normal (Gaussian) distribution
91
+ - `standard_normal` - Standard normal distribution
92
+ - `multivariate_normal` - Multivariate normal distribution
93
+ - `truncated_normal` - Truncated normal distribution
94
+
95
+ **Other Continuous Distributions:**
96
+
97
+ - `beta` - Beta distribution
98
+ - `exponential`, `standard_exponential` - Exponential distribution
99
+ - `gamma`, `standard_gamma` - Gamma distribution
100
+ - `gumbel` - Gumbel distribution
101
+ - `laplace` - Laplace distribution
102
+ - `logistic` - Logistic distribution
103
+ - `pareto` - Pareto distribution
104
+ - `rayleigh` - Rayleigh distribution
105
+ - `standard_cauchy` - Cauchy distribution
106
+ - `standard_t` - Student's t-distribution
107
+ - `uniform` - Uniform distribution over [low, high)
108
+ - `weibull` - Weibull distribution
109
+
110
+ **Discrete Distributions:**
111
+
112
+ - `bernoulli` - Bernoulli distribution
113
+ - `binomial` - Binomial distribution
114
+ - `poisson` - Poisson distribution
115
+
116
+ Seed Management
117
+ ---------------
118
+
119
+ The module provides utilities for managing random seeds:
120
+
121
+ .. code-block:: python
122
+
123
+ >>> import brainstate.random as bsr
124
+ >>>
125
+ >>> # Set a global seed
126
+ >>> bsr.seed(42)
127
+ >>>
128
+ >>> # Get current seed/key
129
+ >>> key = bsr.get_key()
130
+ >>>
131
+ >>> # Split the key for parallel operations
132
+ >>> keys = bsr.split_key(n=4)
133
+ >>>
134
+ >>> # Use context manager for temporary seed
135
+ >>> with bsr.local_seed(123):
136
+ ... x = bsr.normal(0, 1, (5,)) # Uses seed 123
137
+ >>> y = bsr.normal(0, 1, (5,)) # Uses original seed
138
+
139
+ Examples
140
+ --------
141
+
142
+ **Basic random number generation:**
143
+
144
+ .. code-block:: python
145
+
146
+ >>> import brainstate.random as bsr
147
+ >>> import jax.numpy as jnp
148
+ >>>
149
+ >>> # Set seed for reproducibility
150
+ >>> bsr.seed(0)
151
+ >>>
152
+ >>> # Generate uniform random numbers
153
+ >>> uniform_data = bsr.random((3, 3))
154
+ >>> print(uniform_data.shape)
155
+ (3, 3)
156
+ >>>
157
+ >>> # Generate normal random numbers
158
+ >>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
159
+ >>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
160
+ Mean: -0.045, Std: 0.972
161
+
162
+ **Sampling and shuffling:**
163
+
164
+ .. code-block:: python
165
+
166
+ >>> import brainstate.random as bsr
167
+ >>> import jax.numpy as jnp
168
+ >>>
169
+ >>> bsr.seed(42)
170
+ >>>
171
+ >>> # Random choice from array
172
+ >>> arr = jnp.array([1, 2, 3, 4, 5])
173
+ >>> samples = bsr.choice(arr, size=3, replace=False)
174
+ >>> print(samples)
175
+ [4 1 5]
176
+ >>>
177
+ >>> # Random permutation
178
+ >>> perm = bsr.permutation(10)
179
+ >>> print(perm)
180
+ [3 5 1 7 9 0 2 8 4 6]
181
+ >>>
182
+ >>> # Shuffle (JAX arrays are immutable, so the result must be reassigned)
183
+ >>> data = jnp.arange(5)
184
+ >>> data = bsr.shuffle(data)
185
+ >>> print(data)
186
+ [2 0 4 1 3]
187
+
188
+ **Advanced distributions:**
189
+
190
+ .. code-block:: python
191
+
192
+ >>> import brainstate.random as bsr
193
+ >>> import matplotlib.pyplot as plt
194
+ >>>
195
+ >>> bsr.seed(123)
196
+ >>>
197
+ >>> # Generate samples from different distributions
198
+ >>> normal_samples = bsr.normal(0, 1, 1000)
199
+ >>> exponential_samples = bsr.exponential(1.0, 1000)
200
+ >>> beta_samples = bsr.beta(2, 5, 1000)
201
+ >>>
202
+ >>> # Plot histograms
203
+ >>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
204
+ >>> axes[0].hist(normal_samples, bins=30, density=True)
205
+ >>> axes[0].set_title('Normal Distribution')
206
+ >>> axes[1].hist(exponential_samples, bins=30, density=True)
207
+ >>> axes[1].set_title('Exponential Distribution')
208
+ >>> axes[2].hist(beta_samples, bins=30, density=True)
209
+ >>> axes[2].set_title('Beta Distribution')
210
+ >>> plt.show()
211
+
212
+ **Using with neural network simulations:**
213
+
214
+ .. code-block:: python
215
+
216
+ >>> import brainstate as bs
217
+ >>> import brainstate.random as bsr
218
+ >>> import brainstate.nn as nn
219
+ >>> import jax.numpy as jnp
220
+ >>> class NoisyNeuron(nn.Module):
221
+ ... def __init__(self, n_neurons, noise_scale=0.1):
222
+ ... super().__init__()
223
+ ... self.n_neurons = n_neurons
224
+ ... self.noise_scale = noise_scale
225
+ ... self.membrane = bs.State(jnp.zeros(n_neurons))
226
+ ...
227
+ ... def update(self, input_current):
228
+ ... # Add noise to input current
229
+ ... noise = bsr.normal(0, self.noise_scale, self.n_neurons)
230
+ ... self.membrane.value += input_current + noise
231
+ ... return self.membrane.value
232
+ >>>
233
+ >>> # Create and run noisy neuron model
234
+ >>> bsr.seed(42)
235
+ >>> neuron = NoisyNeuron(100)
236
+ >>> output = neuron.update(jnp.ones(100) * 0.5)
237
+
238
+ Notes
239
+ -----
240
+
241
+ - This module is designed to work seamlessly with JAX's functional programming model
242
+ - Random functions are JIT-compilable for optimal performance
243
+ - The global DEFAULT state is thread-local to avoid race conditions
244
+ - For deterministic results, always set a seed before random operations
245
+
246
+ See Also
247
+ --------
248
+
249
+ jax.random : JAX's random number generation module
250
+ numpy.random : NumPy's random number generation module
251
+ RandomState : The stateful random number generator class
252
+
253
+ References
254
+ ----------
255
+ .. [1] JAX Random Number Generation:
256
+ https://jax.readthedocs.io/en/latest/jax.random.html
257
+ .. [2] NumPy Random Sampling:
258
+ https://numpy.org/doc/stable/reference/random/index.html
259
+
260
+ """
261
+
16
262
  from ._rand_funs import *
17
263
  from ._rand_funs import __all__ as __all_random__
18
264
  from ._rand_seed import *