brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -25,32 +25,76 @@ from ._rand_state import RandomState, DEFAULT
25
25
 
26
26
  __all__ = [
27
27
  # numpy compatibility
28
- 'rand', 'randint', 'random_integers', 'randn', 'random',
29
- 'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta',
30
- 'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto',
31
- 'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma',
32
- 'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli',
33
- 'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f',
34
- 'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
35
- 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
36
- 'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
37
- 'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical',
28
+ 'rand',
29
+ 'randint',
30
+ 'random_integers',
31
+ 'randn',
32
+ 'random',
33
+ 'random_sample',
34
+ 'ranf',
35
+ 'sample',
36
+ 'choice',
37
+ 'permutation',
38
+ 'shuffle',
39
+ 'beta',
40
+ 'exponential',
41
+ 'gamma',
42
+ 'gumbel',
43
+ 'laplace',
44
+ 'logistic',
45
+ 'normal',
46
+ 'pareto',
47
+ 'poisson',
48
+ 'standard_cauchy',
49
+ 'standard_exponential',
50
+ 'standard_gamma',
51
+ 'standard_normal',
52
+ 'standard_t',
53
+ 'uniform',
54
+ 'truncated_normal',
55
+ 'bernoulli',
56
+ 'lognormal',
57
+ 'binomial',
58
+ 'chisquare',
59
+ 'dirichlet',
60
+ 'geometric',
61
+ 'f',
62
+ 'hypergeometric',
63
+ 'logseries',
64
+ 'multinomial',
65
+ 'multivariate_normal',
66
+ 'negative_binomial',
67
+ 'noncentral_chisquare',
68
+ 'noncentral_f',
69
+ 'power',
70
+ 'rayleigh',
71
+ 'triangular',
72
+ 'vonmises',
73
+ 'wald',
74
+ 'weibull',
75
+ 'weibull_min',
76
+ 'zipf',
77
+ 'maxwell',
78
+ 't',
79
+ 'orthogonal',
80
+ 'loggamma',
81
+ 'categorical',
38
82
 
39
83
  # pytorch compatibility
40
- 'rand_like', 'randint_like', 'randn_like',
84
+ 'rand_like',
85
+ 'randint_like',
86
+ 'randn_like',
41
87
  ]
42
88
 
43
89
 
44
- def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
90
+ def rand(
91
+ *dn,
92
+ key: Optional[SeedOrKey] = None,
93
+ dtype: DTypeLike = None
94
+ ):
45
95
  r"""
46
96
  Random values in a given shape.
47
97
 
48
- .. note::
49
- This is a convenience function for users porting code from Matlab,
50
- and wraps `random_sample`. That function takes a
51
- tuple to specify the size of the output, which is consistent with
52
- other NumPy functions like `numpy.zeros` and `numpy.ones`.
53
-
54
98
  Create an array of the given shape and populate it with
55
99
  random samples from a uniform distribution
56
100
  over ``[0, 1)``.
@@ -78,18 +122,25 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
78
122
 
79
123
  Examples
80
124
  --------
81
- >>> import brainstate as brainstate
82
- >>> brainstate.random.rand(3,2)
83
- array([[ 0.14022471, 0.96360618], #random
84
- [ 0.37601032, 0.25528411], #random
85
- [ 0.49313049, 0.94909878]]) #random
125
+ Generate random values in a 3x2 array:
126
+
127
+ .. code-block:: python
128
+
129
+ >>> import brainstate
130
+ >>> arr = brainstate.random.rand(3, 2)
131
+ >>> print(arr.shape) # (3, 2)
132
+ >>> print((arr >= 0).all() and (arr < 1).all()) # True
86
133
  """
87
134
  return DEFAULT.rand(*dn, key=key, dtype=dtype)
88
135
 
89
136
 
90
- def randint(low, high=None, size: Optional[Size] = None,
91
- key: Optional[SeedOrKey] = None,
92
- dtype: DTypeLike = None):
137
+ def randint(
138
+ low,
139
+ high=None,
140
+ size: Optional[Size] = None,
141
+ key: Optional[SeedOrKey] = None,
142
+ dtype: DTypeLike = None
143
+ ):
93
144
  r"""Return random integers from `low` (inclusive) to `high` (exclusive).
94
145
 
95
146
  Return random integers from the "discrete uniform" distribution of
@@ -110,9 +161,6 @@ def randint(low, high=None, size: Optional[Size] = None,
110
161
  Output shape. If the given shape is, e.g., ``(m, n, k)``, then
111
162
  ``m * n * k`` samples are drawn. Default is None, in which case a
112
163
  single value is returned.
113
- dtype : dtype, optional
114
- Desired dtype of the result. Byteorder must be native.
115
- The default value is int.
116
164
  key : PRNGKey, optional
117
165
  The key for the random number generator. If not given, the
118
166
  default random number generator is used.
@@ -135,43 +183,49 @@ def randint(low, high=None, size: Optional[Size] = None,
135
183
 
136
184
  Examples
137
185
  --------
138
- >>> import brainstate as brainstate
139
- >>> brainstate.random.randint(2, size=10)
140
- array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
141
- >>> brainstate.random.randint(1, size=10)
142
- array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
186
+ Generate 10 random integers from 0 (inclusive) to 2 (exclusive):
187
+
188
+ .. code-block:: python
143
189
 
144
- Generate a 2 x 4 array of ints between 0 and 4, inclusive:
190
+ >>> import brainstate
191
+ >>> arr = brainstate.random.randint(2, size=10)
192
+ >>> print(arr.shape) # (10,)
193
+ >>> print((arr >= 0).all() and (arr < 2).all()) # True
145
194
 
146
- >>> brainstate.random.randint(5, size=(2, 4))
147
- array([[4, 0, 2, 1], # random
148
- [3, 2, 2, 0]])
195
+ Generate a 2x4 array of integers from 0 to 4 (exclusive):
149
196
 
150
- Generate a 1 x 3 array with 3 different upper bounds
197
+ .. code-block:: python
151
198
 
152
- >>> brainstate.random.randint(1, [3, 5, 10])
153
- array([2, 2, 9]) # random
199
+ >>> arr = brainstate.random.randint(5, size=(2, 4))
200
+ >>> print(arr.shape) # (2, 4)
201
+ >>> print((arr >= 0).all() and (arr < 5).all()) # True
154
202
 
155
- Generate a 1 by 3 array with 3 different lower bounds
203
+ Generate integers with different upper bounds using broadcasting:
156
204
 
157
- >>> brainstate.random.randint([1, 5, 7], 10)
158
- array([9, 8, 7]) # random
205
+ .. code-block:: python
159
206
 
160
- Generate a 2 by 4 array using broadcasting with dtype of uint8
207
+ >>> arr = brainstate.random.randint(1, [3, 5, 10])
208
+ >>> print(arr.shape) # (3,)
161
209
 
162
- >>> brainstate.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
163
- array([[ 8, 6, 9, 7], # random
164
- [ 1, 16, 9, 12]], dtype=uint8)
210
+ Generate integers with different lower bounds:
211
+
212
+ .. code-block:: python
213
+
214
+ >>> arr = brainstate.random.randint([1, 5, 7], 10)
215
+ >>> print(arr.shape) # (3,)
216
+ >>> print((arr >= [1, 5, 7]).all()) # True
165
217
  """
166
218
 
167
219
  return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key)
168
220
 
169
221
 
170
- def random_integers(low,
171
- high=None,
172
- size: Optional[Size] = None,
173
- key: Optional[SeedOrKey] = None,
174
- dtype: DTypeLike = None):
222
+ def random_integers(
223
+ low,
224
+ high=None,
225
+ size: Optional[Size] = None,
226
+ key: Optional[SeedOrKey] = None,
227
+ dtype: DTypeLike = None
228
+ ):
175
229
  r"""
176
230
  Random integers of type `np.int_` between `low` and `high`, inclusive.
177
231
 
@@ -219,53 +273,54 @@ def random_integers(low,
219
273
 
220
274
  Examples
221
275
  --------
222
- >>> import brainstate as brainstate
223
- >>> brainstate.random.random_integers(5)
224
- 4 # random
225
- >>> type(brainstate.random.random_integers(5))
226
- <class 'numpy.int64'>
227
- >>> brainstate.random.random_integers(5, size=(3,2))
228
- array([[5, 4], # random
229
- [3, 3],
230
- [4, 5]])
276
+ Generate a single random integer from 1 to 5 (inclusive):
277
+
278
+ .. code-block:: python
279
+
280
+ >>> import brainstate
281
+ >>> val = brainstate.random.random_integers(5)
282
+ >>> print(type(val)) # <class 'numpy.int64'>
283
+ >>> print(1 <= val <= 5) # True
284
+
285
+ Generate a 3x2 array of random integers from 1 to 5 (inclusive):
286
+
287
+ .. code-block:: python
288
+
289
+ >>> arr = brainstate.random.random_integers(5, size=(3, 2))
290
+ >>> print(arr.shape) # (3, 2)
291
+ >>> print((arr >= 1).all() and (arr <= 5).all()) # True
231
292
 
232
293
  Choose five random numbers from the set of five evenly-spaced
233
294
  numbers between 0 and 2.5, inclusive (*i.e.*, from the set
234
295
  :math:`{0, 5/8, 10/8, 15/8, 20/8}`):
235
296
 
236
- >>> 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
237
- array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
297
+ .. code-block:: python
238
298
 
239
- Roll two six sided dice 1000 times and sum the results:
299
+ >>> vals = 2.5 * (brainstate.random.random_integers(5, size=(5,)) - 1) / 4.
300
+ >>> print(vals.shape) # (5,)
240
301
 
241
- >>> d1 = brainstate.random.random_integers(1, 6, 1000)
242
- >>> d2 = brainstate.random.random_integers(1, 6, 1000)
243
- >>> dsums = d1 + d2
302
+ Roll two six sided dice 1000 times and sum the results:
244
303
 
245
- Display results as a histogram:
304
+ .. code-block:: python
246
305
 
247
- >>> import matplotlib.pyplot as plt # noqa
248
- >>> count, bins, ignored = plt.hist(dsums, 11, density=True)
249
- >>> plt.show()
306
+ >>> d1 = brainstate.random.random_integers(1, 6, 1000)
307
+ >>> d2 = brainstate.random.random_integers(1, 6, 1000)
308
+ >>> dsums = d1 + d2
309
+ >>> print(dsums.shape) # (1000,)
310
+ >>> print((dsums >= 2).all() and (dsums <= 12).all()) # True
250
311
  """
251
312
 
252
313
  return DEFAULT.random_integers(low, high=high, size=size, key=key, dtype=dtype)
253
314
 
254
315
 
255
- def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
316
+ def randn(
317
+ *dn,
318
+ key: Optional[SeedOrKey] = None,
319
+ dtype: DTypeLike = None
320
+ ):
256
321
  r"""
257
322
  Return a sample (or samples) from the "standard normal" distribution.
258
323
 
259
- .. note::
260
- This is a convenience function for users porting code from Matlab,
261
- and wraps `standard_normal`. That function takes a
262
- tuple to specify the size of the output, which is consistent with
263
- other NumPy functions like `numpy.zeros` and `numpy.ones`.
264
-
265
- .. note::
266
- New code should use the ``standard_normal`` method of a ``default_rng()``
267
- instance instead; please see the :ref:`random-quick-start`.
268
-
269
324
  If positive int_like arguments are provided, `randn` generates an array
270
325
  of shape ``(d0, d1, ..., dn)``, filled
271
326
  with random floats sampled from a univariate "normal" (Gaussian)
@@ -301,21 +356,37 @@ def randn(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
301
356
 
302
357
  Examples
303
358
  --------
304
- >>> import brainstate as brainstate
305
- >>> brainstate.random.randn()
306
- 2.1923875335537315 # random
359
+ Generate a single random number from standard normal distribution:
360
+
361
+ .. code-block:: python
362
+
363
+ >>> import brainstate
364
+ >>> val = brainstate.random.randn()
365
+ >>> print(type(val)) # <class 'numpy.float64'>
366
+
367
+ Generate a 2x4 array of standard normal samples:
368
+
369
+ .. code-block:: python
370
+
371
+ >>> arr = brainstate.random.randn(2, 4)
372
+ >>> print(arr.shape) # (2, 4)
307
373
 
308
374
  Two-by-four array of samples from N(3, 6.25):
309
375
 
310
- >>> 3 + 2.5 * brainstate.random.randn(2, 4)
311
- array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
312
- [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
376
+ .. code-block:: python
377
+
378
+ >>> arr = 3 + 2.5 * brainstate.random.randn(2, 4)
379
+ >>> print(arr.shape) # (2, 4)
313
380
  """
314
381
 
315
382
  return DEFAULT.randn(*dn, key=key, dtype=dtype)
316
383
 
317
384
 
318
- def random(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
385
+ def random(
386
+ size: Optional[Size] = None,
387
+ key: Optional[SeedOrKey] = None,
388
+ dtype: DTypeLike = None
389
+ ):
319
390
  r"""
320
391
  Return random floats in the half-open interval [0.0, 1.0). Alias for
321
392
  `random_sample` to ease forward-porting to the new random API.
@@ -323,7 +394,11 @@ def random(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype:
323
394
  return DEFAULT.random(size, key=key, dtype=dtype)
324
395
 
325
396
 
326
- def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
397
+ def random_sample(
398
+ size: Optional[Size] = None,
399
+ key: Optional[SeedOrKey] = None,
400
+ dtype: DTypeLike = None
401
+ ):
327
402
  r"""
328
403
  Return random floats in the half-open interval [0.0, 1.0).
329
404
 
@@ -333,10 +408,6 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
333
408
 
334
409
  (b - a) * random_sample() + a
335
410
 
336
- .. note::
337
- New code should use the ``random`` method of a ``default_rng()``
338
- instance instead; please see the :ref:`random-quick-start`.
339
-
340
411
  Parameters
341
412
  ----------
342
413
  size : int or tuple of ints, optional
@@ -359,25 +430,39 @@ def random_sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
359
430
 
360
431
  Examples
361
432
  --------
362
- >>> import brainstate as brainstate
363
- >>> brainstate.random.random_sample()
364
- 0.47108547995356098 # random
365
- >>> type(brainstate.random.random_sample())
366
- <class 'float'>
367
- >>> brainstate.random.random_sample((5,))
368
- array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
433
+ Generate a single random float:
434
+
435
+ .. code-block:: python
436
+
437
+ >>> import brainstate
438
+ >>> val = brainstate.random.random_sample()
439
+ >>> print(type(val)) # <class 'float'>
440
+ >>> print(0.0 <= val < 1.0) # True
441
+
442
+ Generate an array of 5 random floats:
443
+
444
+ .. code-block:: python
445
+
446
+ >>> arr = brainstate.random.random_sample((5,))
447
+ >>> print(arr.shape) # (5,)
448
+ >>> print((arr >= 0.0).all() and (arr < 1.0).all()) # True
369
449
 
370
450
  Three-by-two array of random numbers from [-5, 0):
371
451
 
372
- >>> 5 * brainstate.random.random_sample((3, 2)) - 5
373
- array([[-3.99149989, -0.52338984], # random
374
- [-2.99091858, -0.79479508],
375
- [-1.23204345, -1.75224494]])
452
+ .. code-block:: python
453
+
454
+ >>> arr = 5 * brainstate.random.random_sample((3, 2)) - 5
455
+ >>> print(arr.shape) # (3, 2)
456
+ >>> print((arr >= -5.0).all() and (arr < 0.0).all()) # True
376
457
  """
377
458
  return DEFAULT.random_sample(size, key=key, dtype=dtype)
378
459
 
379
460
 
380
- def ranf(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
461
+ def ranf(
462
+ size: Optional[Size] = None,
463
+ key: Optional[SeedOrKey] = None,
464
+ dtype: DTypeLike = None
465
+ ):
381
466
  r"""
382
467
  This is an alias of `random_sample`. See `random_sample` for the complete
383
468
  documentation.
@@ -385,7 +470,11 @@ def ranf(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DT
385
470
  return DEFAULT.ranf(size, key=key, dtype=dtype)
386
471
 
387
472
 
388
- def sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
473
+ def sample(
474
+ size: Optional[Size] = None,
475
+ key: Optional[SeedOrKey] = None,
476
+ dtype: DTypeLike = None
477
+ ):
389
478
  """
390
479
  This is an alias of `random_sample`. See `random_sample` for the complete
391
480
  documentation.
@@ -393,8 +482,13 @@ def sample(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype:
393
482
  return DEFAULT.sample(size, key=key, dtype=dtype)
394
483
 
395
484
 
396
- def choice(a, size: Optional[Size] = None, replace=True, p=None,
397
- key: Optional[SeedOrKey] = None):
485
+ def choice(
486
+ a,
487
+ size: Optional[Size] = None,
488
+ replace=True,
489
+ p=None,
490
+ key: Optional[SeedOrKey] = None
491
+ ):
398
492
  r"""
399
493
  Generates a random sample from a given 1-D array
400
494
 
@@ -450,45 +544,56 @@ def choice(a, size: Optional[Size] = None, replace=True, p=None,
450
544
  --------
451
545
  Generate a uniform random sample from np.arange(5) of size 3:
452
546
 
453
- >>> import brainstate as brainstate
454
- >>> brainstate.random.choice(5, 3)
455
- array([0, 3, 4]) # random
456
- >>> #This is equivalent to brainpy.math.random.randint(0,5,3)
547
+ .. code-block:: python
548
+
549
+ >>> import brainstate
550
+ >>> result = brainstate.random.choice(5, 3)
551
+ >>> print(result.shape) # (3,)
552
+ >>> print((result >= 0).all() and (result < 5).all()) # True
457
553
 
458
554
  Generate a non-uniform random sample from np.arange(5) of size 3:
459
555
 
460
- >>> brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
461
- array([3, 3, 0]) # random
556
+ .. code-block:: python
557
+
558
+ >>> result = brainstate.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
559
+ >>> print(result.shape) # (3,)
560
+ >>> print(set(result.tolist()).issubset({0, 2, 3})) # True (only non-zero prob elements)
561
+
562
+ Generate a uniform random sample from np.arange(5) of size 3 without replacement:
462
563
 
463
- Generate a uniform random sample from np.arange(5) of size 3 without
464
- replacement:
564
+ .. code-block:: python
465
565
 
466
- >>> brainstate.random.choice(5, 3, replace=False)
467
- array([3,1,0]) # random
468
- >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
566
+ >>> result = brainstate.random.choice(5, 3, replace=False)
567
+ >>> print(result.shape) # (3,)
568
+ >>> print(len(set(result.tolist())) == 3) # True (all unique)
469
569
 
470
- Generate a non-uniform random sample from np.arange(5) of size
471
- 3 without replacement:
570
+ Generate a non-uniform random sample from np.arange(5) of size 3 without replacement:
472
571
 
473
- >>> brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
474
- array([2, 3, 0]) # random
572
+ .. code-block:: python
475
573
 
476
- Any of the above can be repeated with an arbitrary array-like
477
- instead of just integers. For instance:
574
+ >>> result = brainstate.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
575
+ >>> print(result.shape) # (3,)
576
+ >>> print(len(set(result.tolist())) == 3) # True (all unique)
478
577
 
479
- >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
480
- >>> brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
481
- array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
482
- dtype='<U11')
578
+ Any of the above can be repeated with an arbitrary array-like instead of just integers:
579
+
580
+ .. code-block:: python
581
+
582
+ >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
583
+ >>> result = brainstate.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
584
+ >>> print(result.shape) # (5,)
585
+ >>> print(result.dtype.kind) # 'U' (unicode string)
483
586
  """
484
587
  a = a
485
588
  return DEFAULT.choice(a=a, size=size, replace=replace, p=p, key=key)
486
589
 
487
590
 
488
- def permutation(x,
489
- axis: int = 0,
490
- independent: bool = False,
491
- key: Optional[SeedOrKey] = None):
591
+ def permutation(
592
+ x,
593
+ axis: int = 0,
594
+ independent: bool = False,
595
+ key: Optional[SeedOrKey] = None
596
+ ):
492
597
  r"""
493
598
  Randomly permute a sequence, or return a permuted range.
494
599
 
@@ -519,23 +624,42 @@ def permutation(x,
519
624
 
520
625
  Examples
521
626
  --------
522
- >>> import brainstate as brainstate
523
- >>> brainstate.random.permutation(10)
524
- array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
525
-
526
- >>> brainstate.random.permutation([1, 4, 9, 12, 15])
527
- array([15, 1, 9, 4, 12]) # random
528
-
529
- >>> arr = np.arange(9).reshape((3, 3))
530
- >>> brainstate.random.permutation(arr)
531
- array([[6, 7, 8], # random
532
- [0, 1, 2],
533
- [3, 4, 5]])
627
+ Permute integers from 0 to 9:
628
+
629
+ .. code-block:: python
630
+
631
+ >>> import brainstate
632
+ >>> result = brainstate.random.permutation(10)
633
+ >>> print(result.shape) # (10,)
634
+ >>> print(sorted(result.tolist())) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
635
+
636
+ Permute a given array:
637
+
638
+ .. code-block:: python
639
+
640
+ >>> arr = [1, 4, 9, 12, 15]
641
+ >>> result = brainstate.random.permutation(arr)
642
+ >>> print(result.shape) # (5,)
643
+ >>> print(sorted(result.tolist())) # [1, 4, 9, 12, 15]
644
+
645
+ Permute rows of a 2D array:
646
+
647
+ .. code-block:: python
648
+
649
+ >>> import numpy as np
650
+ >>> arr = np.arange(9).reshape((3, 3))
651
+ >>> result = brainstate.random.permutation(arr)
652
+ >>> print(result.shape) # (3, 3)
653
+ >>> print((np.sort(result.flatten()) == np.arange(9)).all()) # True
534
654
  """
535
655
  return DEFAULT.permutation(x, axis=axis, independent=independent, key=key)
536
656
 
537
657
 
538
- def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
658
+ def shuffle(
659
+ x,
660
+ axis=0,
661
+ key: Optional[SeedOrKey] = None
662
+ ):
539
663
  r"""
540
664
  Modify a sequence in-place by shuffling its contents.
541
665
 
@@ -557,25 +681,37 @@ def shuffle(x, axis=0, key: Optional[SeedOrKey] = None):
557
681
 
558
682
  Examples
559
683
  --------
560
- >>> import brainstate as brainstate
561
- >>> arr = np.arange(10)
562
- >>> brainstate.random.shuffle(arr)
563
- >>> arr
564
- [1 7 5 2 9 4 3 6 0 8] # random
684
+ Shuffle a 1D array in-place:
685
+
686
+ .. code-block:: python
687
+
688
+ >>> import brainstate
689
+ >>> import numpy as np
690
+ >>> arr = np.arange(10)
691
+ >>> original_elements = set(arr)
692
+ >>> brainstate.random.shuffle(arr)
693
+ >>> print(set(arr) == original_elements) # True (same elements)
565
694
 
566
695
  Multi-dimensional arrays are only shuffled along the first axis:
567
696
 
568
- >>> arr = np.arange(9).reshape((3, 3))
569
- >>> brainstate.random.shuffle(arr)
570
- >>> arr
571
- array([[3, 4, 5], # random
572
- [6, 7, 8],
573
- [0, 1, 2]])
697
+ .. code-block:: python
698
+
699
+ >>> arr = np.arange(9).reshape((3, 3))
700
+ >>> original_shape = arr.shape
701
+ >>> brainstate.random.shuffle(arr)
702
+ >>> print(arr.shape == original_shape) # True (shape preserved)
703
+ >>> print(sorted(arr.flatten()) == list(range(9))) # True (same elements)
574
704
  """
575
705
  return DEFAULT.shuffle(x, axis, key=key)
576
706
 
577
707
 
578
- def beta(a, b, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
708
+ def beta(
709
+ a,
710
+ b,
711
+ size: Optional[Size] = None,
712
+ key: Optional[SeedOrKey] = None,
713
+ dtype: DTypeLike = None
714
+ ):
579
715
  r"""
580
716
  Draw samples from a Beta distribution.
581
717
 
@@ -616,9 +752,12 @@ def beta(a, b, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dty
616
752
  return DEFAULT.beta(a, b, size=size, key=key, dtype=dtype)
617
753
 
618
754
 
619
- def exponential(scale=None, size: Optional[Size] = None,
620
- key: Optional[SeedOrKey] = None,
621
- dtype: DTypeLike = None):
755
+ def exponential(
756
+ scale=None,
757
+ size: Optional[Size] = None,
758
+ key: Optional[SeedOrKey] = None,
759
+ dtype: DTypeLike = None
760
+ ):
622
761
  r"""
623
762
  Draw samples from an exponential distribution.
624
763
 
@@ -667,9 +806,13 @@ def exponential(scale=None, size: Optional[Size] = None,
667
806
  return DEFAULT.exponential(scale, size, key=key, dtype=dtype)
668
807
 
669
808
 
670
- def gamma(shape, scale=None, size: Optional[Size] = None,
671
- key: Optional[SeedOrKey] = None,
672
- dtype: DTypeLike = None):
809
+ def gamma(
810
+ shape,
811
+ scale=None,
812
+ size: Optional[Size] = None,
813
+ key: Optional[SeedOrKey] = None,
814
+ dtype: DTypeLike = None
815
+ ):
673
816
  r"""
674
817
  Draw samples from a Gamma distribution.
675
818
 
@@ -724,9 +867,13 @@ def gamma(shape, scale=None, size: Optional[Size] = None,
724
867
  return DEFAULT.gamma(shape, scale, size=size, key=key, dtype=dtype)
725
868
 
726
869
 
727
- def gumbel(loc=None, scale=None, size: Optional[Size] = None,
728
- key: Optional[SeedOrKey] = None,
729
- dtype: DTypeLike = None):
870
+ def gumbel(
871
+ loc=None,
872
+ scale=None,
873
+ size: Optional[Size] = None,
874
+ key: Optional[SeedOrKey] = None,
875
+ dtype: DTypeLike = None
876
+ ):
730
877
  r"""
731
878
  Draw samples from a Gumbel distribution.
732
879
 
@@ -798,9 +945,13 @@ def gumbel(loc=None, scale=None, size: Optional[Size] = None,
798
945
  return DEFAULT.gumbel(loc, scale, size=size, key=key, dtype=dtype)
799
946
 
800
947
 
801
- def laplace(loc=None, scale=None, size: Optional[Size] = None,
802
- key: Optional[SeedOrKey] = None,
803
- dtype: DTypeLike = None):
948
+ def laplace(
949
+ loc=None,
950
+ scale=None,
951
+ size: Optional[Size] = None,
952
+ key: Optional[SeedOrKey] = None,
953
+ dtype: DTypeLike = None
954
+ ):
804
955
  r"""
805
956
  Draw samples from the Laplace or double exponential distribution with
806
957
  specified location (or mean) and scale (decay).
@@ -883,9 +1034,13 @@ def laplace(loc=None, scale=None, size: Optional[Size] = None,
883
1034
  return DEFAULT.laplace(loc, scale, size, key=key, dtype=dtype)
884
1035
 
885
1036
 
886
- def logistic(loc=None, scale=None, size: Optional[Size] = None,
887
- key: Optional[SeedOrKey] = None,
888
- dtype: DTypeLike = None):
1037
+ def logistic(
1038
+ loc=None,
1039
+ scale=None,
1040
+ size: Optional[Size] = None,
1041
+ key: Optional[SeedOrKey] = None,
1042
+ dtype: DTypeLike = None
1043
+ ):
889
1044
  r"""
890
1045
  Draw samples from a logistic distribution.
891
1046
 
@@ -1028,38 +1183,33 @@ def normal(
1028
1183
  --------
1029
1184
  Draw samples from the distribution:
1030
1185
 
1031
- >>> mu, sigma = 0, 0.1 # mean and standard deviation
1032
- >>> s = brainstate.random.normal(mu, sigma, 1000)
1033
-
1034
- Verify the mean and the variance:
1035
-
1036
- >>> abs(mu - np.mean(s))
1037
- 0.0 # may vary
1186
+ .. code-block:: python
1038
1187
 
1039
- >>> abs(sigma - np.std(s, ddof=1))
1040
- 0.1 # may vary
1041
-
1042
- Display the histogram of the samples, along with
1043
- the probability density function:
1044
-
1045
- >>> import matplotlib.pyplot as plt # noqa
1046
- >>> count, bins, ignored = plt.hist(s, 30, density=True)
1047
- >>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
1048
- ... np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
1049
- ... linewidth=2, color='r')
1050
- >>> plt.show()
1188
+ >>> import brainstate
1189
+ >>> import numpy as np
1190
+ >>> mu, sigma = 0, 0.1 # mean and standard deviation
1191
+ >>> s = brainstate.random.normal(mu, sigma, 1000)
1192
+ >>> print(s.shape) # (1000,)
1193
+ >>> print(abs(mu - np.mean(s)) < 0.1) # True (approximately correct mean)
1194
+ >>> print(abs(sigma - np.std(s, ddof=1)) < 0.1) # True (approximately correct std)
1051
1195
 
1052
1196
  Two-by-four array of samples from the normal distribution with
1053
1197
  mean 3 and standard deviation 2.5:
1054
1198
 
1055
- >>> brainstate.random.normal(3, 2.5, size=(2, 4))
1056
- array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
1057
- [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
1199
+ .. code-block:: python
1200
+
1201
+ >>> samples = brainstate.random.normal(3, 2.5, size=(2, 4))
1202
+ >>> print(samples.shape) # (2, 4)
1058
1203
  """
1059
1204
  return DEFAULT.normal(loc, scale, size, key=key, dtype=dtype)
1060
1205
 
1061
1206
 
1062
- def pareto(a, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1207
+ def pareto(
1208
+ a,
1209
+ size: Optional[Size] = None,
1210
+ key: Optional[SeedOrKey] = None,
1211
+ dtype: DTypeLike = None
1212
+ ):
1063
1213
  r"""
1064
1214
  Draw samples from a Pareto II or Lomax distribution with
1065
1215
  specified shape.
@@ -1154,7 +1304,12 @@ def pareto(a, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtyp
1154
1304
  return DEFAULT.pareto(a, size, key=key, dtype=dtype)
1155
1305
 
1156
1306
 
1157
- def poisson(lam=1.0, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1307
+ def poisson(
1308
+ lam=1.0,
1309
+ size: Optional[Size] = None,
1310
+ key: Optional[SeedOrKey] = None,
1311
+ dtype: DTypeLike = None
1312
+ ):
1158
1313
  r"""
1159
1314
  Draw samples from a Poisson distribution.
1160
1315
 
@@ -1224,7 +1379,11 @@ def poisson(lam=1.0, size: Optional[Size] = None, key: Optional[SeedOrKey] = Non
1224
1379
  return DEFAULT.poisson(lam, size, key=key, dtype=dtype)
1225
1380
 
1226
1381
 
1227
- def standard_cauchy(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1382
+ def standard_cauchy(
1383
+ size: Optional[Size] = None,
1384
+ key: Optional[SeedOrKey] = None,
1385
+ dtype: DTypeLike = None
1386
+ ):
1228
1387
  r"""
1229
1388
  Draw samples from a standard Cauchy distribution with mode = 0.
1230
1389
 
@@ -1289,9 +1448,11 @@ def standard_cauchy(size: Optional[Size] = None, key: Optional[SeedOrKey] = None
1289
1448
  return DEFAULT.standard_cauchy(size, key=key, dtype=dtype)
1290
1449
 
1291
1450
 
1292
- def standard_exponential(size: Optional[Size] = None,
1293
- key: Optional[SeedOrKey] = None,
1294
- dtype: DTypeLike = None):
1451
+ def standard_exponential(
1452
+ size: Optional[Size] = None,
1453
+ key: Optional[SeedOrKey] = None,
1454
+ dtype: DTypeLike = None
1455
+ ):
1295
1456
  r"""
1296
1457
  Draw samples from the standard exponential distribution.
1297
1458
 
@@ -1322,9 +1483,12 @@ def standard_exponential(size: Optional[Size] = None,
1322
1483
  return DEFAULT.standard_exponential(size, key=key, dtype=dtype)
1323
1484
 
1324
1485
 
1325
- def standard_gamma(shape, size: Optional[Size] = None,
1326
- key: Optional[SeedOrKey] = None,
1327
- dtype: DTypeLike = None):
1486
+ def standard_gamma(
1487
+ shape,
1488
+ size: Optional[Size] = None,
1489
+ key: Optional[SeedOrKey] = None,
1490
+ dtype: DTypeLike = None
1491
+ ):
1328
1492
  r"""
1329
1493
  Draw samples from a standard Gamma distribution.
1330
1494
 
@@ -1396,7 +1560,11 @@ def standard_gamma(shape, size: Optional[Size] = None,
1396
1560
  return DEFAULT.standard_gamma(shape, size, key=key, dtype=dtype)
1397
1561
 
1398
1562
 
1399
- def standard_normal(size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1563
+ def standard_normal(
1564
+ size: Optional[Size] = None,
1565
+ key: Optional[SeedOrKey] = None,
1566
+ dtype: DTypeLike = None
1567
+ ):
1400
1568
  r"""
1401
1569
  Draw samples from a standard Normal distribution (mean=0, stdev=1).
1402
1570
 
@@ -1432,30 +1600,45 @@ def standard_normal(size: Optional[Size] = None, key: Optional[SeedOrKey] = None
1432
1600
 
1433
1601
  Examples
1434
1602
  --------
1435
- >>> brainstate.random.standard_normal()
1436
- 2.1923875335537315 #random
1437
-
1438
- >>> s = brainstate.random.standard_normal(8000)
1439
- >>> s
1440
- array([ 0.6888893 , 0.78096262, -0.89086505, ..., 0.49876311, # random
1441
- -0.38672696, -0.4685006 ]) # random
1442
- >>> s.shape
1443
- (8000,)
1444
- >>> s = brainstate.random.standard_normal(size=(3, 4, 2))
1445
- >>> s.shape
1446
- (3, 4, 2)
1603
+ Generate a single standard normal sample:
1604
+
1605
+ .. code-block:: python
1606
+
1607
+ >>> import brainstate
1608
+ >>> val = brainstate.random.standard_normal()
1609
+ >>> print(type(val)) # <class 'numpy.float64'>
1610
+
1611
+ Generate an array of 8000 standard normal samples:
1612
+
1613
+ .. code-block:: python
1614
+
1615
+ >>> s = brainstate.random.standard_normal(8000)
1616
+ >>> print(s.shape) # (8000,)
1617
+
1618
+ Generate a 3x4x2 array of standard normal samples:
1619
+
1620
+ .. code-block:: python
1621
+
1622
+ >>> s = brainstate.random.standard_normal(size=(3, 4, 2))
1623
+ >>> print(s.shape) # (3, 4, 2)
1447
1624
 
1448
1625
  Two-by-four array of samples from the normal distribution with
1449
1626
  mean 3 and standard deviation 2.5:
1450
1627
 
1451
- >>> 3 + 2.5 * brainstate.random.standard_normal(size=(2, 4))
1452
- array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
1453
- [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
1628
+ .. code-block:: python
1629
+
1630
+ >>> samples = 3 + 2.5 * brainstate.random.standard_normal(size=(2, 4))
1631
+ print(samples.shape) # (2, 4)
1454
1632
  """
1455
1633
  return DEFAULT.standard_normal(size, key=key, dtype=dtype)
1456
1634
 
1457
1635
 
1458
- def standard_t(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1636
+ def standard_t(
1637
+ df,
1638
+ size: Optional[Size] = None,
1639
+ key: Optional[SeedOrKey] = None,
1640
+ dtype: DTypeLike = None
1641
+ ):
1459
1642
  r"""
1460
1643
  Draw samples from a standard Student's t distribution with `df` degrees
1461
1644
  of freedom.
@@ -1558,8 +1741,13 @@ def standard_t(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
1558
1741
  return DEFAULT.standard_t(df, size, key=key, dtype=dtype)
1559
1742
 
1560
1743
 
1561
- def uniform(low=0.0, high=1.0, size: Optional[Size] = None,
1562
- key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1744
+ def uniform(
1745
+ low=0.0,
1746
+ high=1.0,
1747
+ size: Optional[Size] = None,
1748
+ key: Optional[SeedOrKey] = None,
1749
+ dtype: DTypeLike = None
1750
+ ):
1563
1751
  r"""
1564
1752
  Draw samples from a uniform distribution.
1565
1753
 
@@ -1649,8 +1837,16 @@ def uniform(low=0.0, high=1.0, size: Optional[Size] = None,
1649
1837
  return DEFAULT.uniform(low, high, size, key=key, dtype=dtype)
1650
1838
 
1651
1839
 
1652
- def truncated_normal(lower, upper, size: Optional[Size] = None, loc=0., scale=1.,
1653
- key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1840
+ def truncated_normal(
1841
+ lower,
1842
+ upper,
1843
+ size: Optional[Size] = None,
1844
+ loc=0.0,
1845
+ scale=1.0,
1846
+ key: Optional[SeedOrKey] = None,
1847
+ dtype: DTypeLike = None,
1848
+ check_valid: bool = True
1849
+ ):
1654
1850
  r"""Sample truncated standard normal random values with given shape and dtype.
1655
1851
 
1656
1852
  Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -1706,13 +1902,27 @@ def truncated_normal(lower, upper, size: Optional[Size] = None, loc=0., scale=1.
1706
1902
  ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
1707
1903
  Returns values in the open interval ``(lower, upper)``.
1708
1904
  """
1709
- return DEFAULT.truncated_normal(lower, upper, size, loc, scale, key=key, dtype=dtype)
1905
+ return DEFAULT.truncated_normal(
1906
+ lower,
1907
+ upper,
1908
+ size,
1909
+ loc,
1910
+ scale,
1911
+ key=key,
1912
+ dtype=dtype,
1913
+ check_valid=check_valid,
1914
+ )
1710
1915
 
1711
1916
 
1712
1917
  RandomState.truncated_normal.__doc__ = truncated_normal.__doc__
1713
1918
 
1714
1919
 
1715
- def bernoulli(p=0.5, size: Optional[Size] = None, key: Optional[SeedOrKey] = None):
1920
+ def bernoulli(
1921
+ p=0.5,
1922
+ size: Optional[Size] = None,
1923
+ key: Optional[SeedOrKey] = None,
1924
+ check_valid: bool = True,
1925
+ ):
1716
1926
  r"""Sample Bernoulli random values with given shape and mean.
1717
1927
 
1718
1928
  Parameters
@@ -1735,11 +1945,16 @@ def bernoulli(p=0.5, size: Optional[Size] = None, key: Optional[SeedOrKey] = Non
1735
1945
  A random array with boolean dtype and shape given by ``shape`` if ``shape``
1736
1946
  is not None, or else ``p.shape``.
1737
1947
  """
1738
- return DEFAULT.bernoulli(p, size, key=key)
1948
+ return DEFAULT.bernoulli(p, size, key=key, check_valid=check_valid)
1739
1949
 
1740
1950
 
1741
- def lognormal(mean=None, sigma=None, size: Optional[Size] = None,
1742
- key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1951
+ def lognormal(
1952
+ mean=None,
1953
+ sigma=None,
1954
+ size: Optional[Size] = None,
1955
+ key: Optional[SeedOrKey] = None,
1956
+ dtype: DTypeLike = None
1957
+ ):
1743
1958
  r"""
1744
1959
  Draw samples from a log-normal distribution.
1745
1960
 
@@ -1853,7 +2068,7 @@ def binomial(
1853
2068
  size: Optional[Size] = None,
1854
2069
  key: Optional[SeedOrKey] = None,
1855
2070
  dtype: DTypeLike = None,
1856
- check_valid: bool = True,
2071
+ check_valid: bool = True
1857
2072
  ):
1858
2073
  r"""
1859
2074
  Draw samples from a binomial distribution.
@@ -1942,7 +2157,12 @@ def binomial(
1942
2157
  return DEFAULT.binomial(n, p, size, key=key, dtype=dtype, check_valid=check_valid)
1943
2158
 
1944
2159
 
1945
- def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2160
+ def chisquare(
2161
+ df,
2162
+ size: Optional[Size] = None,
2163
+ key: Optional[SeedOrKey] = None,
2164
+ dtype: DTypeLike = None
2165
+ ):
1946
2166
  r"""
1947
2167
  Draw samples from a chi-square distribution.
1948
2168
 
@@ -2002,13 +2222,24 @@ def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
2002
2222
 
2003
2223
  Examples
2004
2224
  --------
2005
- >>> brainstate.random.chisquare(2,4)
2006
- array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random
2225
+ Generate chi-square samples with 2 degrees of freedom:
2226
+
2227
+ .. code-block:: python
2228
+
2229
+ >>> import brainstate
2230
+ >>> samples = brainstate.random.chisquare(2, 4)
2231
+ >>> print(samples.shape) # (4,)
2232
+ >>> print((samples >= 0).all()) # True (chi-square is always non-negative)
2007
2233
  """
2008
2234
  return DEFAULT.chisquare(df, size, key=key, dtype=dtype)
2009
2235
 
2010
2236
 
2011
- def dirichlet(alpha, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2237
+ def dirichlet(
2238
+ alpha,
2239
+ size: Optional[Size] = None,
2240
+ key: Optional[SeedOrKey] = None,
2241
+ dtype: DTypeLike = None
2242
+ ):
2012
2243
  r"""
2013
2244
  Draw samples from the Dirichlet distribution.
2014
2245
 
@@ -2077,6 +2308,7 @@ def dirichlet(alpha, size: Optional[Size] = None, key: Optional[SeedOrKey] = Non
2077
2308
  average length, but allowing some variation in the relative sizes of
2078
2309
  the pieces.
2079
2310
 
2311
+ >>> import brainstate
2080
2312
  >>> s = brainstate.random.dirichlet((10, 5, 3), 20).transpose()
2081
2313
 
2082
2314
  >>> import matplotlib.pyplot as plt # noqa
@@ -2088,7 +2320,12 @@ def dirichlet(alpha, size: Optional[Size] = None, key: Optional[SeedOrKey] = Non
2088
2320
  return DEFAULT.dirichlet(alpha, size, key=key, dtype=dtype)
2089
2321
 
2090
2322
 
2091
- def geometric(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2323
+ def geometric(
2324
+ p,
2325
+ size: Optional[Size] = None,
2326
+ key: Optional[SeedOrKey] = None,
2327
+ dtype: DTypeLike = None
2328
+ ):
2092
2329
  r"""
2093
2330
  Draw samples from the geometric distribution.
2094
2331
 
@@ -2127,6 +2364,7 @@ def geometric(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, d
2127
2364
  Draw ten thousand values from the geometric distribution,
2128
2365
  with the probability of an individual success equal to 0.35:
2129
2366
 
2367
+ >>> import brainstate
2130
2368
  >>> z = brainstate.random.geometric(p=0.35, size=10000)
2131
2369
 
2132
2370
  How many trials succeeded after a single run?
@@ -2137,7 +2375,13 @@ def geometric(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, d
2137
2375
  return DEFAULT.geometric(p, size, key=key, dtype=dtype)
2138
2376
 
2139
2377
 
2140
- def f(dfnum, dfden, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2378
+ def f(
2379
+ dfnum,
2380
+ dfden,
2381
+ size: Optional[Size] = None,
2382
+ key: Optional[SeedOrKey] = None,
2383
+ dtype: DTypeLike = None
2384
+ ):
2141
2385
  r"""
2142
2386
  Draw samples from an F distribution.
2143
2387
 
@@ -2207,6 +2451,7 @@ def f(dfnum, dfden, size: Optional[Size] = None, key: Optional[SeedOrKey] = None
2207
2451
 
2208
2452
  Draw samples from the distribution:
2209
2453
 
2454
+ >>> import brainstate
2210
2455
  >>> dfnum = 1. # between group degrees of freedom
2211
2456
  >>> dfden = 48. # within groups degrees of freedom
2212
2457
  >>> s = brainstate.random.f(dfnum, dfden, 1000)
@@ -2223,8 +2468,14 @@ def f(dfnum, dfden, size: Optional[Size] = None, key: Optional[SeedOrKey] = None
2223
2468
  return DEFAULT.f(dfnum, dfden, size, key=key, dtype=dtype)
2224
2469
 
2225
2470
 
2226
- def hypergeometric(ngood, nbad, nsample, size: Optional[Size] = None,
2227
- key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2471
+ def hypergeometric(
2472
+ ngood,
2473
+ nbad,
2474
+ nsample,
2475
+ size: Optional[Size] = None,
2476
+ key: Optional[SeedOrKey] = None,
2477
+ dtype: DTypeLike = None
2478
+ ):
2228
2479
  r"""
2229
2480
  Draw samples from a Hypergeometric distribution.
2230
2481
 
@@ -2300,6 +2551,7 @@ def hypergeometric(ngood, nbad, nsample, size: Optional[Size] = None,
2300
2551
  --------
2301
2552
  Draw samples from the distribution:
2302
2553
 
2554
+ >>> import brainstate
2303
2555
  >>> ngood, nbad, nsamp = 100, 2, 10
2304
2556
  # number of good, number of bad, and number of samples
2305
2557
  >>> s = brainstate.random.hypergeometric(ngood, nbad, nsamp, 1000)
@@ -2318,7 +2570,12 @@ def hypergeometric(ngood, nbad, nsample, size: Optional[Size] = None,
2318
2570
  return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key, dtype=dtype)
2319
2571
 
2320
2572
 
2321
- def logseries(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2573
+ def logseries(
2574
+ p,
2575
+ size: Optional[Size] = None,
2576
+ key: Optional[SeedOrKey] = None,
2577
+ dtype: DTypeLike = None
2578
+ ):
2322
2579
  r"""
2323
2580
  Draw samples from a logarithmic series distribution.
2324
2581
 
@@ -2380,6 +2637,7 @@ def logseries(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, d
2380
2637
  --------
2381
2638
  Draw samples from the distribution:
2382
2639
 
2640
+ >>> import brainstate
2383
2641
  >>> a = .6
2384
2642
  >>> s = brainstate.random.logseries(a, 10000)
2385
2643
  >>> import matplotlib.pyplot as plt # noqa
@@ -2396,11 +2654,14 @@ def logseries(p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, d
2396
2654
  return DEFAULT.logseries(p, size, key=key, dtype=dtype)
2397
2655
 
2398
2656
 
2399
- def multinomial(n,
2400
- pvals,
2401
- size: Optional[Size] = None,
2402
- key: Optional[SeedOrKey] = None,
2403
- dtype: DTypeLike = None):
2657
+ def multinomial(
2658
+ n,
2659
+ pvals,
2660
+ size: Optional[Size] = None,
2661
+ key: Optional[SeedOrKey] = None,
2662
+ dtype: DTypeLike = None,
2663
+ check_valid: bool = True,
2664
+ ):
2404
2665
  r"""
2405
2666
  Draw samples from a multinomial distribution.
2406
2667
 
@@ -2442,45 +2703,49 @@ def multinomial(n,
2442
2703
  --------
2443
2704
  Throw a dice 20 times:
2444
2705
 
2445
- >>> brainstate.random.multinomial(20, [1/6.]*6, size=1)
2446
- array([[4, 1, 7, 5, 2, 1]]) # random
2706
+ .. code-block:: python
2447
2707
 
2448
- It landed 4 times on 1, once on 2, etc.
2708
+ >>> import brainstate
2709
+ >>> result = brainstate.random.multinomial(20, [1/6.]*6, size=1)
2710
+ >>> print(result.shape) # (1, 6)
2711
+ >>> print(result.sum()) # 20 (total throws)
2449
2712
 
2450
2713
  Now, throw the dice 20 times, and 20 times again:
2451
2714
 
2452
- >>> brainstate.random.multinomial(20, [1/6.]*6, size=2)
2453
- array([[3, 4, 3, 3, 4, 3], # random
2454
- [2, 4, 3, 4, 0, 7]])
2715
+ .. code-block:: python
2455
2716
 
2456
- For the first run, we threw 3 times 1, 4 times 2, etc. For the second,
2457
- we threw 2 times 1, 4 times 2, etc.
2717
+ >>> result = brainstate.random.multinomial(20, [1/6.]*6, size=2)
2718
+ >>> print(result.shape) # (2, 6)
2719
+ >>> print(result.sum(axis=1)) # [20, 20] (total throws per experiment)
2458
2720
 
2459
2721
  A loaded die is more likely to land on number 6:
2460
2722
 
2461
- >>> brainstate.random.multinomial(100, [1/7.]*5 + [2/7.])
2462
- array([11, 16, 14, 17, 16, 26]) # random
2723
+ .. code-block:: python
2463
2724
 
2464
- The probability inputs should be normalized. As an implementation
2465
- detail, the value of the last entry is ignored and assumed to take
2466
- up any leftover probability mass, but this should not be relied on.
2467
- A biased coin which has twice as much weight on one side as on the
2468
- other should be sampled like so:
2725
+ >>> result = brainstate.random.multinomial(100, [1/7.]*5 + [2/7.])
2726
+ >>> print(result.shape) # (6,)
2727
+ >>> print(result.sum()) # 100 (total throws)
2469
2728
 
2470
- >>> brainstate.random.multinomial(100, [1.0 / 3, 2.0 / 3]) # RIGHT
2471
- array([38, 62]) # random
2729
+ The probability inputs should be normalized. A biased coin which has
2730
+ twice as much weight on one side as on the other should be sampled like so:
2472
2731
 
2473
- not like:
2732
+ .. code-block:: python
2474
2733
 
2475
- >>> brainstate.random.multinomial(100, [1.0, 2.0]) # WRONG
2476
- Traceback (most recent call last):
2477
- ValueError: pvals < 0, pvals > 1 or pvals contains NaNs
2734
+ >>> result = brainstate.random.multinomial(100, [1.0 / 3, 2.0 / 3])
2735
+ >>> print(result.shape) # (2,)
2736
+ print(result.sum()) # 100 (total throws)
2478
2737
  """
2479
- return DEFAULT.multinomial(n, pvals, size, key=key, dtype=dtype)
2738
+ return DEFAULT.multinomial(n, pvals, size, key=key, dtype=dtype, check_valid=check_valid)
2480
2739
 
2481
2740
 
2482
- def multivariate_normal(mean, cov, size: Optional[Size] = None, method: str = 'cholesky',
2483
- key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
2741
+ def multivariate_normal(
2742
+ mean,
2743
+ cov,
2744
+ size: Optional[Size] = None,
2745
+ method: str = 'cholesky',
2746
+ key: Optional[SeedOrKey] = None,
2747
+ dtype: DTypeLike = None
2748
+ ):
2484
2749
  r"""
2485
2750
  Draw random samples from a multivariate normal distribution.
2486
2751
 
@@ -2549,6 +2814,7 @@ def multivariate_normal(mean, cov, size: Optional[Size] = None, method: str = 'c
2549
2814
 
2550
2815
  Diagonal covariance means that points are oriented along x or y-axis:
2551
2816
 
2817
+ >>> import brainstate
2552
2818
  >>> import matplotlib.pyplot as plt # noqa
2553
2819
  >>> x, y = brainstate.random.multivariate_normal(mean, cov, 5000).T
2554
2820
  >>> plt.plot(x, y, 'x')
@@ -2607,11 +2873,13 @@ def multivariate_normal(mean, cov, size: Optional[Size] = None, method: str = 'c
2607
2873
  return DEFAULT.multivariate_normal(mean, cov, size, method, key=key, dtype=dtype)
2608
2874
 
2609
2875
 
2610
- def negative_binomial(n,
2611
- p,
2612
- size: Optional[Size] = None,
2613
- key: Optional[SeedOrKey] = None,
2614
- dtype: DTypeLike = None):
2876
+ def negative_binomial(
2877
+ n,
2878
+ p,
2879
+ size: Optional[Size] = None,
2880
+ key: Optional[SeedOrKey] = None,
2881
+ dtype: DTypeLike = None
2882
+ ):
2615
2883
  r"""
2616
2884
  Draw samples from a negative binomial distribution.
2617
2885
 
@@ -2677,6 +2945,7 @@ def negative_binomial(n,
2677
2945
  for each successive well, that is what is the probability of a
2678
2946
  single success after drilling 5 wells, after 6 wells, etc.?
2679
2947
 
2948
+ >>> import brainstate
2680
2949
  >>> s = brainstate.random.negative_binomial(1, 0.1, 100000)
2681
2950
  >>> for i in range(1, 11): # doctest: +SKIP
2682
2951
  ... probability = sum(s<i) / 100000.
@@ -2685,9 +2954,13 @@ def negative_binomial(n,
2685
2954
  return DEFAULT.negative_binomial(n, p, size, key=key, dtype=dtype)
2686
2955
 
2687
2956
 
2688
- def noncentral_chisquare(df, nonc, size: Optional[Size] = None,
2689
- key: Optional[SeedOrKey] = None,
2690
- dtype: DTypeLike = None):
2957
+ def noncentral_chisquare(
2958
+ df,
2959
+ nonc,
2960
+ size: Optional[Size] = None,
2961
+ key: Optional[SeedOrKey] = None,
2962
+ dtype: DTypeLike = None
2963
+ ):
2691
2964
  r"""
2692
2965
  Draw samples from a noncentral chi-square distribution.
2693
2966
 
@@ -2734,6 +3007,7 @@ def noncentral_chisquare(df, nonc, size: Optional[Size] = None,
2734
3007
  --------
2735
3008
  Draw values from the distribution and plot the histogram
2736
3009
 
3010
+ >>> import brainstate
2737
3011
  >>> import matplotlib.pyplot as plt # noqa
2738
3012
  >>> values = plt.hist(brainstate.random.noncentral_chisquare(3, 20, 100000),
2739
3013
  ... bins=200, density=True)
@@ -2761,9 +3035,14 @@ def noncentral_chisquare(df, nonc, size: Optional[Size] = None,
2761
3035
  return DEFAULT.noncentral_chisquare(df, nonc, size, key=key, dtype=dtype)
2762
3036
 
2763
3037
 
2764
- def noncentral_f(dfnum, dfden, nonc, size: Optional[Size] = None,
2765
- key: Optional[SeedOrKey] = None,
2766
- dtype: DTypeLike = None):
3038
+ def noncentral_f(
3039
+ dfnum,
3040
+ dfden,
3041
+ nonc,
3042
+ size: Optional[Size] = None,
3043
+ key: Optional[SeedOrKey] = None,
3044
+ dtype: DTypeLike = None
3045
+ ):
2767
3046
  r"""
2768
3047
  Draw samples from the noncentral F distribution.
2769
3048
 
@@ -2820,6 +3099,7 @@ def noncentral_f(dfnum, dfden, nonc, size: Optional[Size] = None,
2820
3099
  distribution for the null hypothesis. We'll plot the two probability
2821
3100
  distributions for comparison.
2822
3101
 
3102
+ >>> import brainstate
2823
3103
  >>> dfnum = 3 # between group deg of freedom
2824
3104
  >>> dfden = 20 # within groups degrees of freedom
2825
3105
  >>> nonc = 3.0
@@ -2835,10 +3115,12 @@ def noncentral_f(dfnum, dfden, nonc, size: Optional[Size] = None,
2835
3115
  return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key, dtype=dtype)
2836
3116
 
2837
3117
 
2838
- def power(a,
2839
- size: Optional[Size] = None,
2840
- key: Optional[SeedOrKey] = None,
2841
- dtype: DTypeLike = None):
3118
+ def power(
3119
+ a,
3120
+ size: Optional[Size] = None,
3121
+ key: Optional[SeedOrKey] = None,
3122
+ dtype: DTypeLike = None
3123
+ ):
2842
3124
  r"""
2843
3125
  Draws samples in [0, 1] from a power distribution with positive
2844
3126
  exponent a - 1.
@@ -2895,6 +3177,7 @@ def power(a,
2895
3177
  --------
2896
3178
  Draw samples from the distribution:
2897
3179
 
3180
+ >>> import brainstate
2898
3181
  >>> a = 5. # shape
2899
3182
  >>> samples = 1000
2900
3183
  >>> s = brainstate.random.power(a, samples)
@@ -2936,10 +3219,12 @@ def power(a,
2936
3219
  return DEFAULT.power(a, size, key=key, dtype=dtype)
2937
3220
 
2938
3221
 
2939
- def rayleigh(scale=1.0,
2940
- size: Optional[Size] = None,
2941
- key: Optional[SeedOrKey] = None,
2942
- dtype: DTypeLike = None):
3222
+ def rayleigh(
3223
+ scale=1.0,
3224
+ size: Optional[Size] = None,
3225
+ key: Optional[SeedOrKey] = None,
3226
+ dtype: DTypeLike = None
3227
+ ):
2943
3228
  r"""
2944
3229
  Draw samples from a Rayleigh distribution.
2945
3230
 
@@ -2986,6 +3271,7 @@ def rayleigh(scale=1.0,
2986
3271
  --------
2987
3272
  Draw values from the distribution and plot the histogram
2988
3273
 
3274
+ >>> import brainstate
2989
3275
  >>> from matplotlib.pyplot import hist # noqa
2990
3276
  >>> values = hist(brainstate.random.rayleigh(3, 100000), bins=200, density=True)
2991
3277
 
@@ -3005,8 +3291,10 @@ def rayleigh(scale=1.0,
3005
3291
  return DEFAULT.rayleigh(scale, size, key=key, dtype=dtype)
3006
3292
 
3007
3293
 
3008
- def triangular(size: Optional[Size] = None,
3009
- key: Optional[SeedOrKey] = None):
3294
+ def triangular(
3295
+ size: Optional[Size] = None,
3296
+ key: Optional[SeedOrKey] = None
3297
+ ):
3010
3298
  r"""
3011
3299
  Draw samples from the triangular distribution over the
3012
3300
  interval ``[left, right]``.
@@ -3057,6 +3345,7 @@ def triangular(size: Optional[Size] = None,
3057
3345
  --------
3058
3346
  Draw values from the distribution and plot the histogram:
3059
3347
 
3348
+ >>> import brainstate
3060
3349
  >>> import matplotlib.pyplot as plt # noqa
3061
3350
  >>> h = plt.hist(brainstate.random.triangular(-3, 0, 8, 100000), bins=200,
3062
3351
  ... density=True)
@@ -3065,11 +3354,13 @@ def triangular(size: Optional[Size] = None,
3065
3354
  return DEFAULT.triangular(size, key=key)
3066
3355
 
3067
3356
 
3068
- def vonmises(mu,
3069
- kappa,
3070
- size: Optional[Size] = None,
3071
- key: Optional[SeedOrKey] = None,
3072
- dtype: DTypeLike = None):
3357
+ def vonmises(
3358
+ mu,
3359
+ kappa,
3360
+ size: Optional[Size] = None,
3361
+ key: Optional[SeedOrKey] = None,
3362
+ dtype: DTypeLike = None
3363
+ ):
3073
3364
  r"""
3074
3365
  Draw samples from a von Mises distribution.
3075
3366
 
@@ -3133,6 +3424,7 @@ def vonmises(mu,
3133
3424
  --------
3134
3425
  Draw samples from the distribution:
3135
3426
 
3427
+ >>> import brainstate
3136
3428
  >>> mu, kappa = 0.0, 4.0 # mean and dispersion
3137
3429
  >>> s = brainstate.random.vonmises(mu, kappa, 1000)
3138
3430
 
@@ -3150,11 +3442,13 @@ def vonmises(mu,
3150
3442
  return DEFAULT.vonmises(mu, kappa, size, key=key, dtype=dtype)
3151
3443
 
3152
3444
 
3153
- def wald(mean,
3154
- scale,
3155
- size: Optional[Size] = None,
3156
- key: Optional[SeedOrKey] = None,
3157
- dtype: DTypeLike = None):
3445
+ def wald(
3446
+ mean,
3447
+ scale,
3448
+ size: Optional[Size] = None,
3449
+ key: Optional[SeedOrKey] = None,
3450
+ dtype: DTypeLike = None
3451
+ ):
3158
3452
  r"""
3159
3453
  Draw samples from a Wald, or inverse Gaussian, distribution.
3160
3454
 
@@ -3213,6 +3507,7 @@ def wald(mean,
3213
3507
  --------
3214
3508
  Draw values from the distribution and plot the histogram:
3215
3509
 
3510
+ >>> import brainstate
3216
3511
  >>> import matplotlib.pyplot as plt # noqa
3217
3512
  >>> h = plt.hist(brainstate.random.wald(3, 2, 100000), bins=200, density=True)
3218
3513
  >>> plt.show()
@@ -3220,10 +3515,12 @@ def wald(mean,
3220
3515
  return DEFAULT.wald(mean, scale, size, key=key, dtype=dtype)
3221
3516
 
3222
3517
 
3223
- def weibull(a,
3224
- size: Optional[Size] = None,
3225
- key: Optional[SeedOrKey] = None,
3226
- dtype: DTypeLike = None):
3518
+ def weibull(
3519
+ a,
3520
+ size: Optional[Size] = None,
3521
+ key: Optional[SeedOrKey] = None,
3522
+ dtype: DTypeLike = None
3523
+ ):
3227
3524
  r"""
3228
3525
  Draw samples from a Weibull distribution.
3229
3526
 
@@ -3237,10 +3534,6 @@ def weibull(a,
3237
3534
  The more common 2-parameter Weibull, including a scale parameter
3238
3535
  :math:`\lambda` is just :math:`X = \lambda(-ln(U))^{1/a}`.
3239
3536
 
3240
- .. note::
3241
- New code should use the ``weibull`` method of a ``default_rng()``
3242
- instance instead; please see the :ref:`random-quick-start`.
3243
-
3244
3537
  Parameters
3245
3538
  ----------
3246
3539
  a : float or array_like of floats
@@ -3296,6 +3589,7 @@ def weibull(a,
3296
3589
  --------
3297
3590
  Draw samples from the distribution:
3298
3591
 
3592
+ >>> import brainstate
3299
3593
  >>> a = 5. # shape
3300
3594
  >>> s = brainstate.random.weibull(a, 1000)
3301
3595
 
@@ -3317,11 +3611,13 @@ def weibull(a,
3317
3611
  return DEFAULT.weibull(a, size, key=key, dtype=dtype)
3318
3612
 
3319
3613
 
3320
- def weibull_min(a,
3321
- scale=None,
3322
- size: Optional[Size] = None,
3323
- key: Optional[SeedOrKey] = None,
3324
- dtype: DTypeLike = None):
3614
+ def weibull_min(
3615
+ a,
3616
+ scale=None,
3617
+ size: Optional[Size] = None,
3618
+ key: Optional[SeedOrKey] = None,
3619
+ dtype: DTypeLike = None
3620
+ ):
3325
3621
  """Sample from a Weibull distribution.
3326
3622
 
3327
3623
  The scipy counterpart is `scipy.stats.weibull_min`.
@@ -3340,10 +3636,12 @@ def weibull_min(a,
3340
3636
  return DEFAULT.weibull_min(a, scale, size, key=key, dtype=dtype)
3341
3637
 
3342
3638
 
3343
- def zipf(a,
3344
- size: Optional[Size] = None,
3345
- key: Optional[SeedOrKey] = None,
3346
- dtype: DTypeLike = None):
3639
+ def zipf(
3640
+ a,
3641
+ size: Optional[Size] = None,
3642
+ key: Optional[SeedOrKey] = None,
3643
+ dtype: DTypeLike = None
3644
+ ):
3347
3645
  r"""
3348
3646
  Draw samples from a Zipf distribution.
3349
3647
 
@@ -3355,10 +3653,6 @@ def zipf(a,
3355
3653
  frequency of an item is inversely proportional to its rank in a
3356
3654
  frequency table.
3357
3655
 
3358
- .. note::
3359
- New code should use the ``zipf`` method of a ``default_rng()``
3360
- instance instead; please see the :ref:`random-quick-start`.
3361
-
3362
3656
  Parameters
3363
3657
  ----------
3364
3658
  a : float or array_like of floats
@@ -3405,6 +3699,7 @@ def zipf(a,
3405
3699
  --------
3406
3700
  Draw samples from the distribution:
3407
3701
 
3702
+ >>> import brainstate
3408
3703
  >>> a = 4.0
3409
3704
  >>> n = 20000
3410
3705
  >>> s = brainstate.random.zipf(a, n)
@@ -3433,9 +3728,11 @@ def zipf(a,
3433
3728
  return DEFAULT.zipf(a, size, key=key, dtype=dtype)
3434
3729
 
3435
3730
 
3436
- def maxwell(size: Optional[Size] = None,
3437
- key: Optional[SeedOrKey] = None,
3438
- dtype: DTypeLike = None):
3731
+ def maxwell(
3732
+ size: Optional[Size] = None,
3733
+ key: Optional[SeedOrKey] = None,
3734
+ dtype: DTypeLike = None
3735
+ ):
3439
3736
  """Sample from a one sided Maxwell distribution.
3440
3737
 
3441
3738
  The scipy counterpart is `scipy.stats.maxwell`.
@@ -3453,10 +3750,12 @@ def maxwell(size: Optional[Size] = None,
3453
3750
  return DEFAULT.maxwell(size, key=key, dtype=dtype)
3454
3751
 
3455
3752
 
3456
- def t(df,
3457
- size: Optional[Size] = None,
3458
- key: Optional[SeedOrKey] = None,
3459
- dtype: DTypeLike = None):
3753
+ def t(
3754
+ df,
3755
+ size: Optional[Size] = None,
3756
+ key: Optional[SeedOrKey] = None,
3757
+ dtype: DTypeLike = None
3758
+ ):
3460
3759
  """Sample Student’s t random values.
3461
3760
 
3462
3761
  Parameters
@@ -3478,10 +3777,12 @@ def t(df,
3478
3777
  return DEFAULT.t(df, size, key=key, dtype=dtype)
3479
3778
 
3480
3779
 
3481
- def orthogonal(n: int,
3482
- size: Optional[Size] = None,
3483
- key: Optional[SeedOrKey] = None,
3484
- dtype: DTypeLike = None):
3780
+ def orthogonal(
3781
+ n: int,
3782
+ size: Optional[Size] = None,
3783
+ key: Optional[SeedOrKey] = None,
3784
+ dtype: DTypeLike = None
3785
+ ):
3485
3786
  """Sample uniformly from the orthogonal group `O(n)`.
3486
3787
 
3487
3788
  Parameters
@@ -3502,10 +3803,12 @@ def orthogonal(n: int,
3502
3803
  return DEFAULT.orthogonal(n, size, key=key, dtype=dtype)
3503
3804
 
3504
3805
 
3505
- def loggamma(a,
3506
- size: Optional[Size] = None,
3507
- key: Optional[SeedOrKey] = None,
3508
- dtype: DTypeLike = None):
3806
+ def loggamma(
3807
+ a,
3808
+ size: Optional[Size] = None,
3809
+ key: Optional[SeedOrKey] = None,
3810
+ dtype: DTypeLike = None
3811
+ ):
3509
3812
  """Sample log-gamma random values.
3510
3813
 
3511
3814
  Parameters
@@ -3530,10 +3833,12 @@ def loggamma(a,
3530
3833
  return DEFAULT.loggamma(a, size, key=key, dtype=dtype)
3531
3834
 
3532
3835
 
3533
- def categorical(logits,
3534
- axis: int = -1,
3535
- size: Optional[Size] = None,
3536
- key: Optional[SeedOrKey] = None):
3836
+ def categorical(
3837
+ logits,
3838
+ axis: int = -1,
3839
+ size: Optional[Size] = None,
3840
+ key: Optional[SeedOrKey] = None
3841
+ ):
3537
3842
  """Sample random values from categorical distributions.
3538
3843
 
3539
3844
  Args:
@@ -3552,7 +3857,12 @@ def categorical(logits,
3552
3857
  return DEFAULT.categorical(logits, axis, size, key=key)
3553
3858
 
3554
3859
 
3555
- def rand_like(input, *, dtype=None, key: Optional[SeedOrKey] = None):
3860
+ def rand_like(
3861
+ input,
3862
+ *,
3863
+ dtype=None,
3864
+ key: Optional[SeedOrKey] = None
3865
+ ):
3556
3866
  """Similar to ``rand_like`` in torch.
3557
3867
 
3558
3868
  Returns a tensor with the same size as input that is filled with random
@@ -3569,7 +3879,12 @@ def rand_like(input, *, dtype=None, key: Optional[SeedOrKey] = None):
3569
3879
  return DEFAULT.rand_like(input, dtype=dtype, key=key)
3570
3880
 
3571
3881
 
3572
- def randn_like(input, *, dtype=None, key: Optional[SeedOrKey] = None):
3882
+ def randn_like(
3883
+ input,
3884
+ *,
3885
+ dtype=None,
3886
+ key: Optional[SeedOrKey] = None
3887
+ ):
3573
3888
  """Similar to ``randn_like`` in torch.
3574
3889
 
3575
3890
  Returns a tensor with the same size as ``input`` that is filled with
@@ -3586,7 +3901,14 @@ def randn_like(input, *, dtype=None, key: Optional[SeedOrKey] = None):
3586
3901
  return DEFAULT.randn_like(input, dtype=dtype, key=key)
3587
3902
 
3588
3903
 
3589
- def randint_like(input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
3904
+ def randint_like(
3905
+ input,
3906
+ low=0,
3907
+ high=None,
3908
+ *,
3909
+ dtype=None,
3910
+ key: Optional[SeedOrKey] = None
3911
+ ):
3590
3912
  """Similar to ``randint_like`` in torch.
3591
3913
 
3592
3914
  Returns a tensor with the same shape as Tensor ``input`` filled with