brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1617 +1,1320 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from functools import partial
19
- from operator import index
20
- from typing import Optional
21
-
22
- import brainunit as u
23
- import jax
24
- import jax.numpy as jnp
25
- import jax.random as jr
26
- import numpy as np
27
- from jax import jit, vmap
28
- from jax import lax, core, dtypes
29
-
30
- from brainstate import environ
31
- from brainstate._state import State
32
- from brainstate.typing import DTypeLike, Size, SeedOrKey
33
-
34
- __all__ = [
35
- 'RandomState',
36
- 'DEFAULT',
37
- ]
38
-
39
- use_prng_key = True
40
-
41
-
42
- class RandomState(State):
43
- """RandomState that tracks the random generator state. """
44
-
45
- # __slots__ = ('_backup', '_value')
46
-
47
- def __init__(
48
- self,
49
- seed_or_key: Optional[SeedOrKey] = None
50
- ):
51
- """RandomState constructor.
52
-
53
- Parameters
54
- ----------
55
- seed_or_key: int, Array, optional
56
- It can be an integer for initial seed of the random number generator,
57
- or it can be a JAX PRNGKey, which is an array with two elements and `uint32` dtype.
58
- """
59
- with jax.ensure_compile_time_eval():
60
- if seed_or_key is None:
61
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
62
- if isinstance(seed_or_key, int):
63
- key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
64
- else:
65
- if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
66
- key = seed_or_key
67
- else:
68
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
69
- raise ValueError('key must be an array with dtype uint32. '
70
- f'But we got {seed_or_key}')
71
- key = seed_or_key
72
- super().__init__(key)
73
-
74
- self._backup = None
75
-
76
- def __repr__(
77
- self
78
- ):
79
- return f'{self.__class__.__name__}({self.value})'
80
-
81
- def check_if_deleted(
82
- self
83
- ):
84
- if not use_prng_key and isinstance(self._value, np.ndarray):
85
- self._value = jr.key(np.random.randint(0, 10000))
86
-
87
- if (
88
- isinstance(self._value, jax.Array) and
89
- not isinstance(self._value, jax.core.Tracer) and
90
- self._value.is_deleted()
91
- ):
92
- self.seed()
93
-
94
- # ------------------- #
95
- # seed and random key #
96
- # ------------------- #
97
-
98
- def backup_key(self):
99
- if self._backup is not None:
100
- raise ValueError('The random key has been backed up, and has not been restored.')
101
- self._backup = self.value
102
-
103
- def restore_key(self):
104
- if self._backup is None:
105
- raise ValueError('The random key has not been backed up.')
106
- self.value = self._backup
107
- self._backup = None
108
-
109
- def clone(self):
110
- return type(self)(self.split_key())
111
-
112
- def set_key(self, key: SeedOrKey):
113
- self.value = key
114
-
115
- def seed(
116
- self,
117
- seed_or_key: Optional[SeedOrKey] = None
118
- ):
119
- """Sets a new random seed.
120
-
121
- Parameters
122
- ----------
123
- seed_or_key: int, ArrayLike, optional
124
- It can be an integer for initial seed of the random number generator,
125
- or it can be a JAX PRNGKey, which is an array with two elements and `uint32` dtype.
126
- """
127
- with jax.ensure_compile_time_eval():
128
- if seed_or_key is None:
129
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
130
- if np.size(seed_or_key) == 1:
131
- if isinstance(seed_or_key, int):
132
- key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
133
- elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
134
- key = seed_or_key
135
- elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
136
- key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
137
- else:
138
- raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
139
- else:
140
- if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
141
- key = seed_or_key
142
- else:
143
- raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
144
- self.value = key
145
-
146
- def split_key(
147
- self,
148
- n: Optional[int] = None,
149
- backup: bool = False
150
- ) -> SeedOrKey:
151
- """
152
- Create a new seed from the current seed.
153
-
154
- Parameters
155
- ----------
156
- n: int, optional
157
- The number of seeds to generate.
158
- backup : bool, optional
159
- Whether to backup the current key.
160
-
161
- Returns
162
- -------
163
- key : SeedOrKey
164
- The new seed or a tuple of JAX random keys.
165
- """
166
- if n is not None:
167
- assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
168
-
169
- if not isinstance(self.value, jax.Array):
170
- self.value = u.math.asarray(self.value, dtype=jnp.uint32)
171
- keys = jr.split(self.value, num=2 if n is None else n + 1)
172
- self.value = keys[0]
173
- if backup:
174
- self.backup_key()
175
- if n is None:
176
- return keys[1]
177
- else:
178
- return keys[1:]
179
-
180
- def self_assign_multi_keys(
181
- self,
182
- n: int,
183
- backup: bool = True
184
- ):
185
- """
186
- Self-assign multiple keys to the current random state.
187
- """
188
- if backup:
189
- keys = jr.split(self.value, n + 1)
190
- self.value = keys[0]
191
- self.backup_key()
192
- self.value = keys[1:]
193
- else:
194
- self.value = jr.split(self.value, n)
195
-
196
- # ---------------- #
197
- # random functions #
198
- # ---------------- #
199
-
200
- def rand(
201
- self,
202
- *dn,
203
- key: Optional[SeedOrKey] = None,
204
- dtype: DTypeLike = None
205
- ):
206
- key = self.split_key() if key is None else _formalize_key(key)
207
- dtype = dtype or environ.dftype()
208
- r = jr.uniform(key, dn, dtype)
209
- return r
210
-
211
- def randint(
212
- self,
213
- low,
214
- high=None,
215
- size: Optional[Size] = None,
216
- dtype: DTypeLike = None,
217
- key: Optional[SeedOrKey] = None
218
- ):
219
- if high is None:
220
- high = low
221
- low = 0
222
- high = _check_py_seq(high)
223
- low = _check_py_seq(low)
224
- if size is None:
225
- size = lax.broadcast_shapes(u.math.shape(low),
226
- u.math.shape(high))
227
- key = self.split_key() if key is None else _formalize_key(key)
228
- dtype = dtype or environ.ditype()
229
- r = jr.randint(key,
230
- shape=_size2shape(size),
231
- minval=low, maxval=high, dtype=dtype)
232
- return r
233
-
234
- def random_integers(
235
- self,
236
- low,
237
- high=None,
238
- size: Optional[Size] = None,
239
- key: Optional[SeedOrKey] = None,
240
- dtype: DTypeLike = None
241
- ):
242
- low = _check_py_seq(low)
243
- high = _check_py_seq(high)
244
- if high is None:
245
- high = low
246
- low = 1
247
- high += 1
248
- if size is None:
249
- size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
250
- key = self.split_key() if key is None else _formalize_key(key)
251
- dtype = dtype or environ.ditype()
252
- r = jr.randint(key,
253
- shape=_size2shape(size),
254
- minval=low,
255
- maxval=high,
256
- dtype=dtype)
257
- return r
258
-
259
- def randn(
260
- self,
261
- *dn,
262
- key: Optional[SeedOrKey] = None,
263
- dtype: DTypeLike = None
264
- ):
265
- key = self.split_key() if key is None else _formalize_key(key)
266
- dtype = dtype or environ.dftype()
267
- r = jr.normal(key, shape=dn, dtype=dtype)
268
- return r
269
-
270
- def random(
271
- self,
272
- size: Optional[Size] = None,
273
- key: Optional[SeedOrKey] = None,
274
- dtype: DTypeLike = None
275
- ):
276
- dtype = dtype or environ.dftype()
277
- key = self.split_key() if key is None else _formalize_key(key)
278
- r = jr.uniform(key, _size2shape(size), dtype)
279
- return r
280
-
281
- def random_sample(
282
- self,
283
- size: Optional[Size] = None,
284
- key: Optional[SeedOrKey] = None,
285
- dtype: DTypeLike = None
286
- ):
287
- r = self.random(size=size, key=key, dtype=dtype)
288
- return r
289
-
290
- def ranf(
291
- self,
292
- size: Optional[Size] = None,
293
- key: Optional[SeedOrKey] = None,
294
- dtype: DTypeLike = None
295
- ):
296
- r = self.random(size=size, key=key, dtype=dtype)
297
- return r
298
-
299
- def sample(
300
- self,
301
- size: Optional[Size] = None,
302
- key: Optional[SeedOrKey] = None,
303
- dtype: DTypeLike = None
304
- ):
305
- r = self.random(size=size, key=key, dtype=dtype)
306
- return r
307
-
308
- def choice(
309
- self,
310
- a,
311
- size: Optional[Size] = None,
312
- replace=True,
313
- p=None,
314
- key: Optional[SeedOrKey] = None
315
- ):
316
- a = _check_py_seq(a)
317
- a, unit = u.split_mantissa_unit(a)
318
- p = _check_py_seq(p)
319
- key = self.split_key() if key is None else _formalize_key(key)
320
- r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
321
- return u.maybe_decimal(r * unit)
322
-
323
- def permutation(
324
- self,
325
- x,
326
- axis: int = 0,
327
- independent: bool = False,
328
- key: Optional[SeedOrKey] = None
329
- ):
330
- x = _check_py_seq(x)
331
- x, unit = u.split_mantissa_unit(x)
332
- key = self.split_key() if key is None else _formalize_key(key)
333
- r = jr.permutation(key, x, axis, independent=independent)
334
- return u.maybe_decimal(r * unit)
335
-
336
- def shuffle(
337
- self,
338
- x,
339
- axis=0,
340
- key: Optional[SeedOrKey] = None
341
- ):
342
- return self.permutation(x, axis=axis, key=key, independent=False)
343
-
344
- def beta(
345
- self,
346
- a,
347
- b,
348
- size: Optional[Size] = None,
349
- key: Optional[SeedOrKey] = None,
350
- dtype: DTypeLike = None
351
- ):
352
- a = _check_py_seq(a)
353
- b = _check_py_seq(b)
354
- if size is None:
355
- size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
356
- key = self.split_key() if key is None else _formalize_key(key)
357
- dtype = dtype or environ.dftype()
358
- r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
359
- return r
360
-
361
- def exponential(
362
- self,
363
- scale=None,
364
- size: Optional[Size] = None,
365
- key: Optional[SeedOrKey] = None,
366
- dtype: DTypeLike = None
367
- ):
368
- if size is None:
369
- size = u.math.shape(scale)
370
- key = self.split_key() if key is None else _formalize_key(key)
371
- dtype = dtype or environ.dftype()
372
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
373
- if scale is not None:
374
- scale = u.math.asarray(scale, dtype=dtype)
375
- r = r / scale
376
- return r
377
-
378
- def gamma(
379
- self,
380
- shape,
381
- scale=None,
382
- size: Optional[Size] = None,
383
- key: Optional[SeedOrKey] = None,
384
- dtype: DTypeLike = None
385
- ):
386
- shape = _check_py_seq(shape)
387
- scale = _check_py_seq(scale)
388
- if size is None:
389
- size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
390
- key = self.split_key() if key is None else _formalize_key(key)
391
- dtype = dtype or environ.dftype()
392
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
393
- if scale is not None:
394
- r = r * scale
395
- return r
396
-
397
- def gumbel(
398
- self,
399
- loc=None,
400
- scale=None,
401
- size: Optional[Size] = None,
402
- key: Optional[SeedOrKey] = None,
403
- dtype: DTypeLike = None
404
- ):
405
- loc = _check_py_seq(loc)
406
- scale = _check_py_seq(scale)
407
- if size is None:
408
- size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
409
- key = self.split_key() if key is None else _formalize_key(key)
410
- dtype = dtype or environ.dftype()
411
- r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
412
- return r
413
-
414
- def laplace(
415
- self,
416
- loc=None,
417
- scale=None,
418
- size: Optional[Size] = None,
419
- key: Optional[SeedOrKey] = None,
420
- dtype: DTypeLike = None
421
- ):
422
- loc = _check_py_seq(loc)
423
- scale = _check_py_seq(scale)
424
- if size is None:
425
- size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
426
- key = self.split_key() if key is None else _formalize_key(key)
427
- dtype = dtype or environ.dftype()
428
- r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
429
- return r
430
-
431
- def logistic(
432
- self,
433
- loc=None,
434
- scale=None,
435
- size: Optional[Size] = None,
436
- key: Optional[SeedOrKey] = None,
437
- dtype: DTypeLike = None
438
- ):
439
- loc = _check_py_seq(loc)
440
- scale = _check_py_seq(scale)
441
- if size is None:
442
- size = lax.broadcast_shapes(
443
- u.math.shape(loc) if loc is not None else (),
444
- u.math.shape(scale) if scale is not None else ()
445
- )
446
- key = self.split_key() if key is None else _formalize_key(key)
447
- dtype = dtype or environ.dftype()
448
- r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
449
- return r
450
-
451
- def normal(
452
- self,
453
- loc=None,
454
- scale=None,
455
- size: Optional[Size] = None,
456
- key: Optional[SeedOrKey] = None,
457
- dtype: DTypeLike = None
458
- ):
459
- loc = _check_py_seq(loc)
460
- scale = _check_py_seq(scale)
461
- if size is None:
462
- size = lax.broadcast_shapes(
463
- u.math.shape(scale) if scale is not None else (),
464
- u.math.shape(loc) if loc is not None else ()
465
- )
466
- key = self.split_key() if key is None else _formalize_key(key)
467
- dtype = dtype or environ.dftype()
468
- r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
469
- return r
470
-
471
- def pareto(
472
- self,
473
- a,
474
- size: Optional[Size] = None,
475
- key: Optional[SeedOrKey] = None,
476
- dtype: DTypeLike = None
477
- ):
478
- if size is None:
479
- size = u.math.shape(a)
480
- key = self.split_key() if key is None else _formalize_key(key)
481
- dtype = dtype or environ.dftype()
482
- a = u.math.asarray(a, dtype=dtype)
483
- r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
484
- return r
485
-
486
- def poisson(
487
- self,
488
- lam=1.0,
489
- size: Optional[Size] = None,
490
- key: Optional[SeedOrKey] = None,
491
- dtype: DTypeLike = None
492
- ):
493
- lam = _check_py_seq(lam)
494
- if size is None:
495
- size = u.math.shape(lam)
496
- key = self.split_key() if key is None else _formalize_key(key)
497
- dtype = dtype or environ.ditype()
498
- r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
499
- return r
500
-
501
- def standard_cauchy(
502
- self,
503
- size: Optional[Size] = None,
504
- key: Optional[SeedOrKey] = None,
505
- dtype: DTypeLike = None
506
- ):
507
- key = self.split_key() if key is None else _formalize_key(key)
508
- dtype = dtype or environ.dftype()
509
- r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
510
- return r
511
-
512
- def standard_exponential(
513
- self,
514
- size: Optional[Size] = None,
515
- key: Optional[SeedOrKey] = None,
516
- dtype: DTypeLike = None
517
- ):
518
- key = self.split_key() if key is None else _formalize_key(key)
519
- dtype = dtype or environ.dftype()
520
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
521
- return r
522
-
523
- def standard_gamma(
524
- self,
525
- shape,
526
- size: Optional[Size] = None,
527
- key: Optional[SeedOrKey] = None,
528
- dtype: DTypeLike = None
529
- ):
530
- shape = _check_py_seq(shape)
531
- if size is None:
532
- size = u.math.shape(shape) if shape is not None else ()
533
- key = self.split_key() if key is None else _formalize_key(key)
534
- dtype = dtype or environ.dftype()
535
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
536
- return r
537
-
538
- def standard_normal(
539
- self,
540
- size: Optional[Size] = None,
541
- key: Optional[SeedOrKey] = None,
542
- dtype: DTypeLike = None
543
- ):
544
- key = self.split_key() if key is None else _formalize_key(key)
545
- dtype = dtype or environ.dftype()
546
- r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
547
- return r
548
-
549
- def standard_t(
550
- self,
551
- df,
552
- size: Optional[Size] = None,
553
- key: Optional[SeedOrKey] = None,
554
- dtype: DTypeLike = None
555
- ):
556
- df = _check_py_seq(df)
557
- if size is None:
558
- size = u.math.shape(size) if size is not None else ()
559
- key = self.split_key() if key is None else _formalize_key(key)
560
- dtype = dtype or environ.dftype()
561
- r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
562
- return r
563
-
564
- def uniform(
565
- self,
566
- low=0.0,
567
- high=1.0,
568
- size: Optional[Size] = None,
569
- key: Optional[SeedOrKey] = None,
570
- dtype: DTypeLike = None
571
- ):
572
- low, unit = u.split_mantissa_unit(_check_py_seq(low))
573
- high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
574
- if size is None:
575
- size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
576
- key = self.split_key() if key is None else _formalize_key(key)
577
- dtype = dtype or environ.dftype()
578
- r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
579
- return u.maybe_decimal(r * unit)
580
-
581
- def __norm_cdf(
582
- self,
583
- x,
584
- sqrt2,
585
- dtype
586
- ):
587
- # Computes standard normal cumulative distribution function
588
- return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
589
-
590
- def truncated_normal(
591
- self,
592
- lower,
593
- upper,
594
- size: Optional[Size] = None,
595
- loc=0.0,
596
- scale=1.0,
597
- key: Optional[SeedOrKey] = None,
598
- dtype: DTypeLike = None,
599
- check_valid: bool = True
600
- ):
601
- lower = _check_py_seq(lower)
602
- upper = _check_py_seq(upper)
603
- loc = _check_py_seq(loc)
604
- scale = _check_py_seq(scale)
605
- dtype = dtype or environ.dftype()
606
-
607
- lower, unit = u.split_mantissa_unit(u.math.asarray(lower, dtype=dtype))
608
- upper = u.math.asarray(upper, dtype=dtype)
609
- loc = u.math.asarray(loc, dtype=dtype)
610
- scale = u.math.asarray(scale, dtype=dtype)
611
- upper, loc, scale = (
612
- u.Quantity(upper).in_unit(unit).mantissa,
613
- u.Quantity(loc).in_unit(unit).mantissa,
614
- u.Quantity(scale).in_unit(unit).mantissa
615
- )
616
-
617
- if check_valid:
618
- from brainstate.transform._error_if import jit_error_if
619
- jit_error_if(
620
- u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
621
- "mean is more than 2 std from [lower, upper] in truncated_normal. "
622
- "The distribution of values may be incorrect."
623
- )
624
-
625
- if size is None:
626
- size = u.math.broadcast_shapes(
627
- u.math.shape(lower),
628
- u.math.shape(upper),
629
- u.math.shape(loc),
630
- u.math.shape(scale)
631
- )
632
-
633
- # Values are generated by using a truncated uniform distribution and
634
- # then using the inverse CDF for the normal distribution.
635
- # Get upper and lower cdf values
636
- sqrt2 = np.array(np.sqrt(2), dtype=dtype)
637
- l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
638
- u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
639
-
640
- # Uniformly fill tensor with values from [l, u], then translate to
641
- # [2l-1, 2u-1].
642
- key = self.split_key() if key is None else _formalize_key(key)
643
- out = jr.uniform(
644
- key, size, dtype,
645
- minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
646
- maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
647
- )
648
-
649
- # Use inverse cdf transform for normal distribution to get truncated
650
- # standard normal
651
- out = lax.erf_inv(out)
652
-
653
- # Transform to proper mean, std
654
- out = out * scale * sqrt2 + loc
655
-
656
- # Clamp to ensure it's in the proper range
657
- out = jnp.clip(
658
- out,
659
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
660
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
661
- )
662
- return u.maybe_decimal(out * unit)
663
-
664
- def _check_p(self, *args, **kwargs):
665
- raise ValueError('Parameter p should be within [0, 1], but we got {p}')
666
-
667
- def bernoulli(
668
- self,
669
- p,
670
- size: Optional[Size] = None,
671
- key: Optional[SeedOrKey] = None,
672
- check_valid: bool = True
673
- ):
674
- p = _check_py_seq(p)
675
- if check_valid:
676
- from brainstate.transform._error_if import jit_error_if
677
- jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
678
- if size is None:
679
- size = u.math.shape(p)
680
- key = self.split_key() if key is None else _formalize_key(key)
681
- r = jr.bernoulli(key, p=p, shape=_size2shape(size))
682
- return r
683
-
684
- def lognormal(
685
- self,
686
- mean=None,
687
- sigma=None,
688
- size: Optional[Size] = None,
689
- key: Optional[SeedOrKey] = None,
690
- dtype: DTypeLike = None
691
- ):
692
- mean = _check_py_seq(mean)
693
- sigma = _check_py_seq(sigma)
694
- mean = u.math.asarray(mean, dtype=dtype)
695
- sigma = u.math.asarray(sigma, dtype=dtype)
696
- unit = mean.unit if isinstance(mean, u.Quantity) else u.UNITLESS
697
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
698
- sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
699
-
700
- if size is None:
701
- size = jnp.broadcast_shapes(
702
- u.math.shape(mean) if mean is not None else (),
703
- u.math.shape(sigma) if sigma is not None else ()
704
- )
705
- key = self.split_key() if key is None else _formalize_key(key)
706
- dtype = dtype or environ.dftype()
707
- samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
708
- samples = _loc_scale(mean, sigma, samples)
709
- samples = jnp.exp(samples)
710
- return u.maybe_decimal(samples * unit)
711
-
712
- def binomial(
713
- self,
714
- n,
715
- p,
716
- size: Optional[Size] = None,
717
- key: Optional[SeedOrKey] = None,
718
- dtype: DTypeLike = None,
719
- check_valid: bool = True
720
- ):
721
- n = _check_py_seq(n)
722
- p = _check_py_seq(p)
723
- if check_valid:
724
- from brainstate.transform._error_if import jit_error_if
725
- jit_error_if(
726
- jnp.any(jnp.logical_or(p < 0, p > 1)),
727
- 'Parameter p should be within [0, 1], but we got {p}',
728
- p=p
729
- )
730
- if size is None:
731
- size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
732
- key = self.split_key() if key is None else _formalize_key(key)
733
- r = jr.binomial(key, n, p, shape=_size2shape(size))
734
- dtype = dtype or environ.ditype()
735
- return u.math.asarray(r, dtype=dtype)
736
-
737
- def chisquare(
738
- self,
739
- df,
740
- size: Optional[Size] = None,
741
- key: Optional[SeedOrKey] = None,
742
- dtype: DTypeLike = None
743
- ):
744
- df = _check_py_seq(df)
745
- key = self.split_key() if key is None else _formalize_key(key)
746
- dtype = dtype or environ.dftype()
747
- if size is None:
748
- if jnp.ndim(df) == 0:
749
- dist = jr.normal(key, (df,), dtype=dtype) ** 2
750
- dist = dist.sum()
751
- else:
752
- raise NotImplementedError('Do not support non-scale "df" when "size" is None')
753
- else:
754
- dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
755
- dist = dist.sum(axis=0)
756
- return dist
757
-
758
- def dirichlet(
759
- self,
760
- alpha,
761
- size: Optional[Size] = None,
762
- key: Optional[SeedOrKey] = None,
763
- dtype: DTypeLike = None
764
- ):
765
- key = self.split_key() if key is None else _formalize_key(key)
766
- alpha = _check_py_seq(alpha)
767
- dtype = dtype or environ.dftype()
768
- r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
769
- return r
770
-
771
- def geometric(
772
- self,
773
- p,
774
- size: Optional[Size] = None,
775
- key: Optional[SeedOrKey] = None,
776
- dtype: DTypeLike = None
777
- ):
778
- p = _check_py_seq(p)
779
- if size is None:
780
- size = u.math.shape(p)
781
- key = self.split_key() if key is None else _formalize_key(key)
782
- dtype = dtype or environ.dftype()
783
- u_ = jr.uniform(key, size, dtype)
784
- r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
785
- return r
786
-
787
- def _check_p2(self, p):
788
- raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
789
-
790
- def multinomial(
791
- self,
792
- n,
793
- pvals,
794
- size: Optional[Size] = None,
795
- key: Optional[SeedOrKey] = None,
796
- dtype: DTypeLike = None,
797
- check_valid: bool = True
798
- ):
799
- key = self.split_key() if key is None else _formalize_key(key)
800
- n = _check_py_seq(n)
801
- pvals = _check_py_seq(pvals)
802
- if check_valid:
803
- from brainstate.transform._error_if import jit_error_if
804
- jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
805
- if isinstance(n, jax.core.Tracer):
806
- raise ValueError("The total count parameter `n` should not be a jax abstract array.")
807
- size = _size2shape(size)
808
- n_max = int(np.max(jax.device_get(n)))
809
- batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
810
- r = _multinomial(key, pvals, n, n_max, batch_shape + size)
811
- dtype = dtype or environ.ditype()
812
- return u.math.asarray(r, dtype=dtype)
813
-
814
- def multivariate_normal(
815
- self,
816
- mean,
817
- cov,
818
- size: Optional[Size] = None,
819
- method: str = 'cholesky',
820
- key: Optional[SeedOrKey] = None,
821
- dtype: DTypeLike = None
822
- ):
823
- if method not in {'svd', 'eigh', 'cholesky'}:
824
- raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
825
- dtype = dtype or environ.dftype()
826
- mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
827
- cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
828
- if isinstance(mean, u.Quantity):
829
- assert isinstance(cov, u.Quantity)
830
- assert mean.unit ** 2 == cov.unit
831
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
832
- cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
833
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
834
-
835
- key = self.split_key() if key is None else _formalize_key(key)
836
- if not jnp.ndim(mean) >= 1:
837
- raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
838
- if not jnp.ndim(cov) >= 2:
839
- raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
840
- n = mean.shape[-1]
841
- if u.math.shape(cov)[-2:] != (n, n):
842
- raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
843
- f"but got cov.shape == {u.math.shape(cov)}.")
844
- if size is None:
845
- size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
846
- else:
847
- size = _size2shape(size)
848
- _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
849
-
850
- if method == 'svd':
851
- (u_, s, _) = jnp.linalg.svd(cov)
852
- factor = u_ * jnp.sqrt(s[..., None, :])
853
- elif method == 'eigh':
854
- (w, v) = jnp.linalg.eigh(cov)
855
- factor = v * jnp.sqrt(w[..., None, :])
856
- else: # 'cholesky'
857
- factor = jnp.linalg.cholesky(cov)
858
- normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
859
- r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
860
- return u.maybe_decimal(r * unit)
861
-
862
- def rayleigh(
863
- self,
864
- scale=1.0,
865
- size: Optional[Size] = None,
866
- key: Optional[SeedOrKey] = None,
867
- dtype: DTypeLike = None
868
- ):
869
- scale = _check_py_seq(scale)
870
- if size is None:
871
- size = u.math.shape(scale)
872
- key = self.split_key() if key is None else _formalize_key(key)
873
- dtype = dtype or environ.dftype()
874
- x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype)))
875
- r = x * scale
876
- return r
877
-
878
- def triangular(
879
- self,
880
- size: Optional[Size] = None,
881
- key: Optional[SeedOrKey] = None
882
- ):
883
- key = self.split_key() if key is None else _formalize_key(key)
884
- bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
885
- r = 2 * bernoulli_samples - 1
886
- return r
887
-
888
- def vonmises(
889
- self,
890
- mu,
891
- kappa,
892
- size: Optional[Size] = None,
893
- key: Optional[SeedOrKey] = None,
894
- dtype: DTypeLike = None
895
- ):
896
- key = self.split_key() if key is None else _formalize_key(key)
897
- dtype = dtype or environ.dftype()
898
- mu = u.math.asarray(_check_py_seq(mu), dtype=dtype)
899
- kappa = u.math.asarray(_check_py_seq(kappa), dtype=dtype)
900
- if size is None:
901
- size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
902
- size = _size2shape(size)
903
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
904
- samples = samples + mu
905
- samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
906
- return samples
907
-
908
- def weibull(
909
- self,
910
- a,
911
- size: Optional[Size] = None,
912
- key: Optional[SeedOrKey] = None,
913
- dtype: DTypeLike = None
914
- ):
915
- key = self.split_key() if key is None else _formalize_key(key)
916
- a = _check_py_seq(a)
917
- if size is None:
918
- size = u.math.shape(a)
919
- else:
920
- if jnp.size(a) > 1:
921
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
922
- size = _size2shape(size)
923
- dtype = dtype or environ.dftype()
924
- random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
925
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
926
- return r
927
-
928
- def weibull_min(
929
- self,
930
- a,
931
- scale=None,
932
- size: Optional[Size] = None,
933
- key: Optional[SeedOrKey] = None,
934
- dtype: DTypeLike = None
935
- ):
936
- key = self.split_key() if key is None else _formalize_key(key)
937
- a = _check_py_seq(a)
938
- scale = _check_py_seq(scale)
939
- if size is None:
940
- size = jnp.broadcast_shapes(u.math.shape(a), u.math.shape(scale) if scale is not None else ())
941
- else:
942
- if jnp.size(a) > 1:
943
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
944
- size = _size2shape(size)
945
- dtype = dtype or environ.dftype()
946
- random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
947
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
948
- if scale is not None:
949
- r /= scale
950
- return r
951
-
952
- def maxwell(
953
- self,
954
- size: Optional[Size] = None,
955
- key: Optional[SeedOrKey] = None,
956
- dtype: DTypeLike = None
957
- ):
958
- key = self.split_key() if key is None else _formalize_key(key)
959
- shape = _size2shape(size) + (3,)
960
- dtype = dtype or environ.dftype()
961
- norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
962
- r = jnp.linalg.norm(norm_rvs, axis=-1)
963
- return r
964
-
965
- def negative_binomial(
966
- self,
967
- n,
968
- p,
969
- size: Optional[Size] = None,
970
- key: Optional[SeedOrKey] = None,
971
- dtype: DTypeLike = None
972
- ):
973
- n = _check_py_seq(n)
974
- p = _check_py_seq(p)
975
- if size is None:
976
- size = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p))
977
- size = _size2shape(size)
978
- logits = jnp.log(p) - jnp.log1p(-p)
979
- if key is None:
980
- keys = self.split_key(2)
981
- else:
982
- keys = jr.split(_formalize_key(key), 2)
983
- rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
984
- r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
985
- return r
986
-
987
- def wald(
988
- self,
989
- mean,
990
- scale,
991
- size: Optional[Size] = None,
992
- key: Optional[SeedOrKey] = None,
993
- dtype: DTypeLike = None
994
- ):
995
- dtype = dtype or environ.dftype()
996
- key = self.split_key() if key is None else _formalize_key(key)
997
- mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
998
- scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
999
- if size is None:
1000
- size = lax.broadcast_shapes(u.math.shape(mean), u.math.shape(scale))
1001
- size = _size2shape(size)
1002
- sampled_chi2 = jnp.square(self.randn(*size))
1003
- sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
1004
- # Wikipedia defines an intermediate x with the formula
1005
- # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
1006
- # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
1007
- # Let us write
1008
- # w = loc * y / (2 * conc)
1009
- # Then we can extract the common factor in the last two terms to obtain
1010
- # x = loc + loc * w * (1 - sqrt(2 / w + 1))
1011
- # Now we see that the Wikipedia formula suffers from catastrphic
1012
- # cancellation for large w (e.g., if conc << loc).
1013
- #
1014
- # Fortunately, we can fix this by multiplying both sides
1015
- # by 1 + sqrt(2 / w + 1). We get
1016
- # x * (1 + sqrt(2 / w + 1)) =
1017
- # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
1018
- # = loc * (sqrt(2 / w + 1) - 1)
1019
- # The term sqrt(2 / w + 1) + 1 no longer presents numerical
1020
- # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
1021
- # sqrt1pm1(2 / w), which we know how to compute accurately.
1022
- # This just leaves the matter of small w, where 2 / w may
1023
- # overflow. In the limit a w -> 0, x -> loc, so we just mask
1024
- # that case.
1025
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
1026
- safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
1027
- denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
1028
- ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
1029
- sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
1030
- res = jnp.where(sampled_uniform <= mean / (mean + sampled),
1031
- sampled,
1032
- jnp.square(mean) / sampled)
1033
- return res
1034
-
1035
- def t(
1036
- self,
1037
- df,
1038
- size: Optional[Size] = None,
1039
- key: Optional[SeedOrKey] = None,
1040
- dtype: DTypeLike = None
1041
- ):
1042
- dtype = dtype or environ.dftype()
1043
- df = u.math.asarray(_check_py_seq(df), dtype=dtype)
1044
- if size is None:
1045
- size = np.shape(df)
1046
- else:
1047
- size = _size2shape(size)
1048
- _check_shape("t", size, np.shape(df))
1049
- if key is None:
1050
- keys = self.split_key(2)
1051
- else:
1052
- keys = jr.split(_formalize_key(key), 2)
1053
- n = jr.normal(keys[0], size, dtype=dtype)
1054
- two = _const(n, 2)
1055
- half_df = lax.div(df, two)
1056
- g = jr.gamma(keys[1], half_df, size, dtype=dtype)
1057
- r = n * jnp.sqrt(half_df / g)
1058
- return r
1059
-
1060
- def orthogonal(
1061
- self,
1062
- n: int,
1063
- size: Optional[Size] = None,
1064
- key: Optional[SeedOrKey] = None,
1065
- dtype: DTypeLike = None
1066
- ):
1067
- dtype = dtype or environ.dftype()
1068
- key = self.split_key() if key is None else _formalize_key(key)
1069
- size = _size2shape(size)
1070
- _check_shape("orthogonal", size)
1071
- n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
1072
- z = jr.normal(key, size + (n, n), dtype=dtype)
1073
- q, r = jnp.linalg.qr(z)
1074
- d = jnp.diagonal(r, 0, -2, -1)
1075
- r = q * jnp.expand_dims(d / abs(d), -2)
1076
- return r
1077
-
1078
- def noncentral_chisquare(
1079
- self,
1080
- df,
1081
- nonc,
1082
- size: Optional[Size] = None,
1083
- key: Optional[SeedOrKey] = None,
1084
- dtype: DTypeLike = None
1085
- ):
1086
- dtype = dtype or environ.dftype()
1087
- df = u.math.asarray(_check_py_seq(df), dtype=dtype)
1088
- nonc = u.math.asarray(_check_py_seq(nonc), dtype=dtype)
1089
- if size is None:
1090
- size = lax.broadcast_shapes(u.math.shape(df), u.math.shape(nonc))
1091
- size = _size2shape(size)
1092
- if key is None:
1093
- keys = self.split_key(3)
1094
- else:
1095
- keys = jr.split(_formalize_key(key), 3)
1096
- i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
1097
- n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
1098
- cond = jnp.greater(df, 1.0)
1099
- df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
1100
- chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
1101
- r = jnp.where(cond, chi2 + n * n, chi2)
1102
- return r
1103
-
1104
- def loggamma(
1105
- self,
1106
- a,
1107
- size: Optional[Size] = None,
1108
- key: Optional[SeedOrKey] = None,
1109
- dtype: DTypeLike = None
1110
- ):
1111
- dtype = dtype or environ.dftype()
1112
- key = self.split_key() if key is None else _formalize_key(key)
1113
- a = _check_py_seq(a)
1114
- if size is None:
1115
- size = u.math.shape(a)
1116
- r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
1117
- return r
1118
-
1119
- def categorical(
1120
- self,
1121
- logits,
1122
- axis: int = -1,
1123
- size: Optional[Size] = None,
1124
- key: Optional[SeedOrKey] = None
1125
- ):
1126
- key = self.split_key() if key is None else _formalize_key(key)
1127
- logits = _check_py_seq(logits)
1128
- if size is None:
1129
- size = list(u.math.shape(logits))
1130
- size.pop(axis)
1131
- r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
1132
- return r
1133
-
1134
- def zipf(
1135
- self,
1136
- a,
1137
- size: Optional[Size] = None,
1138
- key: Optional[SeedOrKey] = None,
1139
- dtype: DTypeLike = None
1140
- ):
1141
- a = _check_py_seq(a)
1142
- if size is None:
1143
- size = u.math.shape(a)
1144
- dtype = dtype or environ.ditype()
1145
- r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
1146
- jax.ShapeDtypeStruct(size, dtype),
1147
- a)
1148
- return r
1149
-
1150
- def power(
1151
- self,
1152
- a,
1153
- size: Optional[Size] = None,
1154
- key: Optional[SeedOrKey] = None,
1155
- dtype: DTypeLike = None
1156
- ):
1157
- a = _check_py_seq(a)
1158
- if size is None:
1159
- size = u.math.shape(a)
1160
- size = _size2shape(size)
1161
- dtype = dtype or environ.dftype()
1162
- r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
1163
- jax.ShapeDtypeStruct(size, dtype),
1164
- a)
1165
- return r
1166
-
1167
- def f(
1168
- self,
1169
- dfnum,
1170
- dfden,
1171
- size: Optional[Size] = None,
1172
- key: Optional[SeedOrKey] = None,
1173
- dtype: DTypeLike = None
1174
- ):
1175
- dfnum = _check_py_seq(dfnum)
1176
- dfden = _check_py_seq(dfden)
1177
- if size is None:
1178
- size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
1179
- size = _size2shape(size)
1180
- d = {'dfnum': dfnum, 'dfden': dfden}
1181
- dtype = dtype or environ.dftype()
1182
- r = jax.pure_callback(
1183
- lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1184
- dfden=dfden_,
1185
- size=size).astype(dtype),
1186
- jax.ShapeDtypeStruct(size, dtype),
1187
- dfnum, dfden
1188
- )
1189
- return r
1190
-
1191
- def hypergeometric(
1192
- self,
1193
- ngood,
1194
- nbad,
1195
- nsample,
1196
- size: Optional[Size] = None,
1197
- key: Optional[SeedOrKey] = None,
1198
- dtype: DTypeLike = None
1199
- ):
1200
- ngood = _check_py_seq(ngood)
1201
- nbad = _check_py_seq(nbad)
1202
- nsample = _check_py_seq(nsample)
1203
-
1204
- if size is None:
1205
- size = lax.broadcast_shapes(u.math.shape(ngood),
1206
- u.math.shape(nbad),
1207
- u.math.shape(nsample))
1208
- size = _size2shape(size)
1209
- dtype = dtype or environ.ditype()
1210
- d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1211
- r = jax.pure_callback(
1212
- lambda d: np.random.hypergeometric(
1213
- ngood=d['ngood'],
1214
- nbad=d['nbad'],
1215
- nsample=d['nsample'],
1216
- size=size
1217
- ).astype(dtype),
1218
- jax.ShapeDtypeStruct(size, dtype),
1219
- d
1220
- )
1221
- return r
1222
-
1223
- def logseries(
1224
- self,
1225
- p,
1226
- size: Optional[Size] = None,
1227
- key: Optional[SeedOrKey] = None,
1228
- dtype: DTypeLike = None
1229
- ):
1230
- p = _check_py_seq(p)
1231
- if size is None:
1232
- size = u.math.shape(p)
1233
- size = _size2shape(size)
1234
- dtype = dtype or environ.ditype()
1235
- r = jax.pure_callback(
1236
- lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1237
- jax.ShapeDtypeStruct(size, dtype),
1238
- p
1239
- )
1240
- return r
1241
-
1242
- def noncentral_f(
1243
- self,
1244
- dfnum,
1245
- dfden,
1246
- nonc,
1247
- size: Optional[Size] = None,
1248
- key: Optional[SeedOrKey] = None,
1249
- dtype: DTypeLike = None
1250
- ):
1251
- dfnum = _check_py_seq(dfnum)
1252
- dfden = _check_py_seq(dfden)
1253
- nonc = _check_py_seq(nonc)
1254
- if size is None:
1255
- size = lax.broadcast_shapes(u.math.shape(dfnum),
1256
- u.math.shape(dfden),
1257
- u.math.shape(nonc))
1258
- size = _size2shape(size)
1259
- d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1260
- dtype = dtype or environ.dftype()
1261
- r = jax.pure_callback(
1262
- lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1263
- dfden=x['dfden'],
1264
- nonc=x['nonc'],
1265
- size=size).astype(dtype),
1266
- jax.ShapeDtypeStruct(size, dtype),
1267
- d
1268
- )
1269
- return r
1270
-
1271
- # PyTorch compatibility #
1272
- # --------------------- #
1273
-
1274
- def rand_like(
1275
- self,
1276
- input,
1277
- *,
1278
- dtype=None,
1279
- key: Optional[SeedOrKey] = None
1280
- ):
1281
- """Returns a tensor with the same size as input that is filled with random
1282
- numbers from a uniform distribution on the interval ``[0, 1)``.
1283
-
1284
- Args:
1285
- input: the ``size`` of input will determine size of the output tensor.
1286
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1287
- key: the seed or key for the random.
1288
-
1289
- Returns:
1290
- The random data.
1291
- """
1292
- return self.random(u.math.shape(input), key=key).astype(dtype)
1293
-
1294
- def randn_like(
1295
- self,
1296
- input,
1297
- *,
1298
- dtype=None,
1299
- key: Optional[SeedOrKey] = None
1300
- ):
1301
- """Returns a tensor with the same size as ``input`` that is filled with
1302
- random numbers from a normal distribution with mean 0 and variance 1.
1303
-
1304
- Args:
1305
- input: the ``size`` of input will determine size of the output tensor.
1306
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1307
- key: the seed or key for the random.
1308
-
1309
- Returns:
1310
- The random data.
1311
- """
1312
- return self.randn(*u.math.shape(input), key=key).astype(dtype)
1313
-
1314
- def randint_like(
1315
- self,
1316
- input,
1317
- low=0,
1318
- high=None,
1319
- *,
1320
- dtype=None,
1321
- key: Optional[SeedOrKey] = None
1322
- ):
1323
- if high is None:
1324
- high = max(input)
1325
- return self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key)
1326
-
1327
-
1328
- # default random generator
1329
- DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1330
-
1331
-
1332
- # ---------------------------------------------------------------------------------------------------------------
1333
-
1334
-
1335
- def _formalize_key(key):
1336
- if isinstance(key, int):
1337
- return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1338
- elif isinstance(key, (jax.Array, np.ndarray)):
1339
- if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
1340
- return key
1341
- if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
1342
- return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1343
-
1344
- if key.dtype != jnp.uint32:
1345
- raise TypeError('key must be a int or an array with two uint32.')
1346
- if key.size != 2:
1347
- raise TypeError('key must be a int or an array with two uint32.')
1348
- return u.math.asarray(key, dtype=jnp.uint32)
1349
- else:
1350
- raise TypeError('key must be a int or an array with two uint32.')
1351
-
1352
-
1353
- def _size2shape(size):
1354
- if size is None:
1355
- return ()
1356
- elif isinstance(size, (tuple, list)):
1357
- return tuple(size)
1358
- else:
1359
- return (size,)
1360
-
1361
-
1362
- def _check_shape(
1363
- name,
1364
- shape,
1365
- *param_shapes
1366
- ):
1367
- if param_shapes:
1368
- shape_ = lax.broadcast_shapes(shape, *param_shapes)
1369
- if shape != shape_:
1370
- msg = ("{} parameter shapes must be broadcast-compatible with shape "
1371
- "argument, and the result of broadcasting the shapes must equal "
1372
- "the shape argument, but got result {} for shape argument {}.")
1373
- raise ValueError(msg.format(name, shape_, shape))
1374
-
1375
-
1376
- def _is_python_scalar(x):
1377
- if hasattr(x, 'aval'):
1378
- return x.aval.weak_type
1379
- elif np.ndim(x) == 0:
1380
- return True
1381
- elif isinstance(x, (bool, int, float, complex)):
1382
- return True
1383
- else:
1384
- return False
1385
-
1386
-
1387
- python_scalar_dtypes = {
1388
- bool: np.dtype('bool'),
1389
- int: np.dtype('int64'),
1390
- float: np.dtype('float64'),
1391
- complex: np.dtype('complex128'),
1392
- }
1393
-
1394
-
1395
- def _dtype(
1396
- x,
1397
- *,
1398
- canonicalize: bool = False
1399
- ):
1400
- """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
1401
- if x is None:
1402
- raise ValueError(f"Invalid argument to dtype: {x}.")
1403
- elif isinstance(x, type) and x in python_scalar_dtypes:
1404
- dt = python_scalar_dtypes[x]
1405
- elif type(x) in python_scalar_dtypes:
1406
- dt = python_scalar_dtypes[type(x)]
1407
- elif hasattr(x, 'dtype'):
1408
- dt = x.dtype
1409
- else:
1410
- dt = np.result_type(x)
1411
- return dtypes.canonicalize_dtype(dt) if canonicalize else dt
1412
-
1413
-
1414
- def _const(
1415
- example,
1416
- val
1417
- ):
1418
- if _is_python_scalar(example):
1419
- dtype = dtypes.canonicalize_dtype(type(example))
1420
- val = dtypes.scalar_type_of(example)(val)
1421
- return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
1422
- else:
1423
- dtype = dtypes.canonicalize_dtype(example.dtype)
1424
- return np.array(val, dtype)
1425
-
1426
-
1427
- @partial(jit, static_argnums=(2,))
1428
- def _categorical(
1429
- key,
1430
- p,
1431
- shape
1432
- ):
1433
- # this implementation is fast when event shape is small, and slow otherwise
1434
- # Ref: https://stackoverflow.com/a/34190035
1435
- shape = shape or p.shape[:-1]
1436
- s = jnp.cumsum(p, axis=-1)
1437
- r = jr.uniform(key, shape=shape + (1,))
1438
- return jnp.sum(s < r, axis=-1)
1439
-
1440
-
1441
- def _scatter_add_one(
1442
- operand,
1443
- indices,
1444
- updates
1445
- ):
1446
- return lax.scatter_add(
1447
- operand,
1448
- indices,
1449
- updates,
1450
- lax.ScatterDimensionNumbers(
1451
- update_window_dims=(),
1452
- inserted_window_dims=(0,),
1453
- scatter_dims_to_operand_dims=(0,),
1454
- ),
1455
- )
1456
-
1457
-
1458
- def _reshape(x, shape):
1459
- if isinstance(x, (int, float, np.ndarray, np.generic)):
1460
- return np.reshape(x, shape)
1461
- else:
1462
- return jnp.reshape(x, shape)
1463
-
1464
-
1465
- def _promote_shapes(
1466
- *args,
1467
- shape=()
1468
- ):
1469
- # adapted from lax.lax_numpy
1470
- if len(args) < 2 and not shape:
1471
- return args
1472
- else:
1473
- shapes = [u.math.shape(arg) for arg in args]
1474
- num_dims = len(lax.broadcast_shapes(shape, *shapes))
1475
- return [
1476
- _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1477
- for arg, s in zip(args, shapes)
1478
- ]
1479
-
1480
-
1481
- @partial(jit, static_argnums=(3, 4))
1482
- def _multinomial(
1483
- key,
1484
- p,
1485
- n,
1486
- n_max,
1487
- shape=()
1488
- ):
1489
- if u.math.shape(n) != u.math.shape(p)[:-1]:
1490
- broadcast_shape = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p)[:-1])
1491
- n = jnp.broadcast_to(n, broadcast_shape)
1492
- p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])
1493
- shape = shape or p.shape[:-1]
1494
- if n_max == 0:
1495
- return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1496
- # get indices from categorical distribution then gather the result
1497
- indices = _categorical(key, p, (n_max,) + shape)
1498
- # mask out values when counts is heterogeneous
1499
- if jnp.ndim(n) > 0:
1500
- mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1501
- mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1502
- excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1503
- jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))],
1504
- -1)
1505
- else:
1506
- mask = 1
1507
- excess = 0
1508
- # NB: we transpose to move batch shape to the front
1509
- indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1510
- samples_2D = vmap(_scatter_add_one)(
1511
- jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1512
- jnp.expand_dims(indices_2D, axis=-1),
1513
- jnp.ones(indices_2D.shape, dtype=indices.dtype)
1514
- )
1515
- return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1516
-
1517
-
1518
- @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
1519
- def _von_mises_centered(
1520
- key,
1521
- concentration,
1522
- shape,
1523
- dtype=None
1524
- ):
1525
- """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1526
-
1527
- Returns
1528
- -------
1529
- out: array_like
1530
- centered samples from von Mises
1531
-
1532
- References
1533
- ----------
1534
- .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1535
- Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1536
-
1537
- """
1538
- shape = shape or u.math.shape(concentration)
1539
- dtype = dtype or environ.dftype()
1540
- concentration = lax.convert_element_type(concentration, dtype)
1541
- concentration = jnp.broadcast_to(concentration, shape)
1542
-
1543
- if dtype == jnp.float16:
1544
- s_cutoff = 1.8e-1
1545
- elif dtype == jnp.float32:
1546
- s_cutoff = 2e-2
1547
- elif dtype == jnp.float64:
1548
- s_cutoff = 1.2e-4
1549
- else:
1550
- raise ValueError(f"Unsupported dtype: {dtype}")
1551
-
1552
- r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1553
- rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1554
- s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1555
-
1556
- s_approximate = 1.0 / concentration
1557
-
1558
- s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1559
-
1560
- def cond_fn(
1561
- *args
1562
- ):
1563
- """check if all are done or reached max number of iterations"""
1564
- i, _, done, _, _ = args[0]
1565
- return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1566
-
1567
- def body_fn(
1568
- *args
1569
- ):
1570
- i, key, done, _, w = args[0]
1571
- uni_ukey, uni_vkey, key = jr.split(key, 3)
1572
- u_ = jr.uniform(
1573
- key=uni_ukey,
1574
- shape=shape,
1575
- dtype=concentration.dtype,
1576
- minval=-1.0,
1577
- maxval=1.0,
1578
- )
1579
- z = jnp.cos(jnp.pi * u_)
1580
- w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1581
- y = concentration * (s - w)
1582
- v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1583
- accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1584
- return i + 1, key, accept | done, u_, w
1585
-
1586
- init_done = jnp.zeros(shape, dtype=bool)
1587
- init_u = jnp.zeros(shape)
1588
- init_w = jnp.zeros(shape)
1589
-
1590
- _, _, done, uu, w = lax.while_loop(
1591
- cond_fun=cond_fn,
1592
- body_fun=body_fn,
1593
- init_val=(jnp.array(0), key, init_done, init_u, init_w),
1594
- )
1595
-
1596
- return jnp.sign(uu) * jnp.arccos(w)
1597
-
1598
-
1599
- def _loc_scale(
1600
- loc,
1601
- scale,
1602
- value
1603
- ):
1604
- if loc is None:
1605
- if scale is None:
1606
- return value
1607
- else:
1608
- return value * scale
1609
- else:
1610
- if scale is None:
1611
- return value + loc
1612
- else:
1613
- return value * scale + loc
1614
-
1615
-
1616
- def _check_py_seq(seq):
1617
- return u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from operator import index
19
+ from typing import Optional
20
+
21
+ import brainunit as u
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import jax.random as jr
25
+ import numpy as np
26
+ from jax import lax, core
27
+
28
+ from brainstate import environ
29
+ from brainstate._state import State
30
+ from brainstate.typing import DTypeLike, Size, SeedOrKey
31
+ from ._impl import (
32
+ multinomial, von_mises_centered, const,
33
+ formalize_key, _loc_scale, _size2shape, _check_py_seq, _check_shape,
34
+ noncentral_f, logseries, hypergeometric, f, power, zipf
35
+ )
36
+
37
+ __all__ = [
38
+ 'RandomState',
39
+ 'DEFAULT',
40
+ ]
41
+
42
+ use_prng_key = True
43
+
44
+
45
+ class RandomState(State):
46
+ """RandomState that track the random generator state. """
47
+
48
+ # __slots__ = ('_backup', '_value')
49
+
50
+ def __init__(
51
+ self,
52
+ seed_or_key: Optional[SeedOrKey] = None
53
+ ):
54
+ """RandomState constructor.
55
+
56
+ Parameters
57
+ ----------
58
+ seed_or_key: int, Array, optional
59
+ It can be an integer for initial seed of the random number generator,
60
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
61
+ """
62
+ with jax.ensure_compile_time_eval():
63
+ if seed_or_key is None:
64
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
65
+ if isinstance(seed_or_key, int):
66
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
67
+ else:
68
+ if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
69
+ key = seed_or_key
70
+ else:
71
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
72
+ raise ValueError('key must be an array with dtype uint32. '
73
+ f'But we got {seed_or_key}')
74
+ key = seed_or_key
75
+ super().__init__(key)
76
+
77
+ self._backup = None
78
+
79
+ def __repr__(self):
80
+ return f'{self.__class__.__name__}({self.value})'
81
+
82
+ def check_if_deleted(self):
83
+ if not use_prng_key and isinstance(self._value, np.ndarray):
84
+ self._value = jr.key(np.random.randint(0, 10000))
85
+
86
+ if (
87
+ isinstance(self._value, jax.Array) and
88
+ not isinstance(self._value, jax.core.Tracer) and
89
+ self._value.is_deleted()
90
+ ):
91
+ self.seed()
92
+
93
+ @staticmethod
94
+ def _batch_keys(batch_size: int):
95
+ key = jr.PRNGKey(0) if use_prng_key else jr.key(0)
96
+ return jr.split(key, batch_size)
97
+
98
+ # ------------------- #
99
+ # seed and random key #
100
+ # ------------------- #
101
+
102
+ def backup_key(self):
103
+ if self._backup is not None:
104
+ raise ValueError('The random key has been backed up, and has not been restored.')
105
+ self._backup = self.value
106
+
107
+ def restore_key(self):
108
+ if self._backup is None:
109
+ raise ValueError('The random key has not been backed up.')
110
+ self.value = self._backup
111
+ self._backup = None
112
+
113
+ def clone(self):
114
+ return type(self)(self.split_key())
115
+
116
+ def set_key(self, key: SeedOrKey):
117
+ self.value = key
118
+
119
+ def seed(
120
+ self,
121
+ seed_or_key: Optional[SeedOrKey] = None
122
+ ):
123
+ """Sets a new random seed.
124
+
125
+ Parameters
126
+ ----------
127
+ seed_or_key: int, ArrayLike, optional
128
+ It can be an integer for initial seed of the random number generator,
129
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
130
+ """
131
+ with jax.ensure_compile_time_eval():
132
+ if seed_or_key is None:
133
+ seed_or_key = np.random.randint(0, 10000000, 2, dtype=np.uint32)
134
+ if np.size(seed_or_key) == 1:
135
+ if isinstance(seed_or_key, int):
136
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
137
+ elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
138
+ key = seed_or_key
139
+ elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
140
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
141
+ else:
142
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
143
+ else:
144
+ if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
145
+ key = seed_or_key
146
+ else:
147
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
148
+ self.value = key
149
+
150
+ def split_key(
151
+ self,
152
+ n: Optional[int] = None,
153
+ backup: bool = False
154
+ ) -> SeedOrKey:
155
+ """
156
+ Create a new seed from the current seed.
157
+
158
+ Parameters
159
+ ----------
160
+ n: int, optional
161
+ The number of seeds to generate.
162
+ backup : bool, optional
163
+ Whether to back up the current key.
164
+
165
+ Returns
166
+ -------
167
+ key : SeedOrKey
168
+ The new seed or a tuple of JAX random keys.
169
+ """
170
+ if n is not None:
171
+ assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
172
+
173
+ if not isinstance(self.value, jax.Array):
174
+ self.value = u.math.asarray(self.value, dtype=jnp.uint32)
175
+ keys = jr.split(self.value, num=2 if n is None else n + 1)
176
+ self.value = keys[0]
177
+ if backup:
178
+ self.backup_key()
179
+ if n is None:
180
+ return keys[1]
181
+ else:
182
+ return keys[1:]
183
+
184
def self_assign_multi_keys(
    self,
    n: int,
    backup: bool = True
):
    """
    Replace the internal state with ``n`` freshly split keys.

    When ``backup`` is True the first split key is backed up before the
    remaining ``n`` keys are stored as the new state.
    """
    if backup:
        fresh = jr.split(self.value, n + 1)
        self.value = fresh[0]
        self.backup_key()
        self.value = fresh[1:]
    else:
        self.value = jr.split(self.value, n)
199
+
200
def __get_key(self, key):
    """Resolve ``key``: use it when given, otherwise consume one from the state.

    NOTE(review): ``use_prng_key`` is referenced as a bare name here (and in the
    other key-resolving call sites) — presumably a module-level flag; verify it
    is in scope and was not meant to be ``self.use_prng_key``.
    """
    if key is None:
        return self.split_key()
    return formalize_key(key, use_prng_key)
202
+
203
+ # ---------------- #
204
+ # random functions #
205
+ # ---------------- #
206
+
207
def rand(
    self,
    *dn,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Uniform samples on ``[0, 1)`` with shape given by the positional sizes ``dn``."""
    rng_key = self.__get_key(key)
    out_dtype = dtype or environ.dftype()
    return jr.uniform(rng_key, dn, out_dtype)
217
+
218
def randint(
    self,
    low,
    high=None,
    size: Optional[Size] = None,
    dtype: DTypeLike = None,
    key: Optional[SeedOrKey] = None
):
    """Random integers from ``low`` (inclusive) to ``high`` (exclusive)."""
    if high is None:
        # Single-argument form: sample from [0, low).
        low, high = 0, low
    high = _check_py_seq(high)
    low = _check_py_seq(low)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
    return jr.randint(
        self.__get_key(key),
        shape=_size2shape(size),
        minval=low,
        maxval=high,
        dtype=dtype or environ.ditype(),
    )
239
+
240
def random_integers(
    self,
    low,
    high=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Random integers between ``low`` and ``high``, both inclusive (legacy NumPy API)."""
    low = _check_py_seq(low)
    high = _check_py_seq(high)
    if high is None:
        # Single-argument form samples from [1, low].
        low, high = 1, low
    # jr.randint's maxval is exclusive, so shift the inclusive upper bound.
    high += 1
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
    return jr.randint(
        self.__get_key(key),
        shape=_size2shape(size),
        minval=low,
        maxval=high,
        dtype=dtype or environ.ditype(),
    )
264
+
265
def randn(
    self,
    *dn,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Standard-normal samples with shape given by the positional sizes ``dn``."""
    rng_key = self.__get_key(key)
    return jr.normal(rng_key, shape=dn, dtype=dtype or environ.dftype())
274
+
275
def random(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Uniform samples on ``[0, 1)`` of the given ``size``."""
    rng_key = self.__get_key(key)
    return jr.uniform(rng_key, _size2shape(size), dtype=dtype or environ.dftype())
284
+
285
def random_sample(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Alias of :meth:`random` (NumPy-compatible name)."""
    return self.random(size=size, key=key, dtype=dtype or environ.dftype())

def ranf(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Alias of :meth:`random` (NumPy-compatible name)."""
    return self.random(size=size, key=key, dtype=dtype or environ.dftype())

def sample(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Alias of :meth:`random` (NumPy-compatible name)."""
    return self.random(size=size, key=key, dtype=dtype or environ.dftype())
311
+
312
def choice(
    self,
    a,
    size: Optional[Size] = None,
    replace=True,
    p=None,
    key: Optional[SeedOrKey] = None
):
    """Random sample drawn from ``a``, optionally weighted by probabilities ``p`` (unit-aware)."""
    a = _check_py_seq(a)
    # Separate the unit so jr.choice works on the plain mantissa.
    a, unit = u.split_mantissa_unit(a)
    p = _check_py_seq(p)
    picked = jr.choice(self.__get_key(key), a=a, shape=_size2shape(size), replace=replace, p=p)
    return u.maybe_decimal(picked * unit)
326
+
327
def permutation(
    self,
    x,
    axis: int = 0,
    independent: bool = False,
    key: Optional[SeedOrKey] = None
):
    """Randomly permute ``x`` along ``axis`` (unit-aware)."""
    x = _check_py_seq(x)
    mantissa, unit = u.split_mantissa_unit(x)
    shuffled = jr.permutation(self.__get_key(key), mantissa, axis, independent=independent)
    return u.maybe_decimal(shuffled * unit)
339
+
340
def shuffle(
    self,
    x,
    axis=0,
    key: Optional[SeedOrKey] = None
):
    """Shuffle ``x`` along ``axis``; thin wrapper over :meth:`permutation`."""
    return self.permutation(x, axis=axis, key=key, independent=False)
347
+
348
def beta(
    self,
    a,
    b,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Samples from a Beta(a, b) distribution."""
    a = _check_py_seq(a)
    b = _check_py_seq(b)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
    rng_key = self.__get_key(key)
    return jr.beta(rng_key, a=a, b=b, shape=_size2shape(size), dtype=dtype or environ.dftype())
363
+
364
def exponential(
    self,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """
    Samples from an exponential distribution with mean ``scale``.

    Parameters
    ----------
    scale : optional
        The scale parameter ``beta = 1/lambda`` (the distribution mean), as in
        ``numpy.random.exponential``.
    size : Size, optional
        Output shape; defaults to the shape of ``scale``.
    key : SeedOrKey, optional
        Explicit random key; when None a key is consumed from the state.
    dtype : DTypeLike, optional
        Output dtype; defaults to ``environ.dftype()``.
    """
    if size is None:
        size = u.math.shape(scale)
    key = self.__get_key(key)
    dtype = dtype or environ.dftype()
    r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
    if scale is not None:
        scale = u.math.asarray(scale, dtype=dtype)
        # BUGFIX: jr.exponential draws unit-rate (mean 1) samples; NumPy's
        # `scale` is the mean, so the draw must be *multiplied* by `scale`.
        # The previous `r = r / scale` produced mean 1/scale instead.
        r = r * scale
    return r
379
+
380
def gamma(
    self,
    shape,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Gamma(shape, scale) samples; ``scale`` multiplies the unit-scale draw."""
    shape = _check_py_seq(shape)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
    out = jr.gamma(self.__get_key(key), a=shape, shape=_size2shape(size), dtype=dtype or environ.dftype())
    return out if scale is None else out * scale
397
+
398
def gumbel(
    self,
    loc=None,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Gumbel samples, shifted by ``loc`` and scaled by ``scale``."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
    base = jr.gumbel(self.__get_key(key), shape=_size2shape(size), dtype=dtype or environ.dftype())
    return _loc_scale(loc, scale, base)

def laplace(
    self,
    loc=None,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Laplace samples, shifted by ``loc`` and scaled by ``scale``."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
    base = jr.laplace(self.__get_key(key), shape=_size2shape(size), dtype=dtype or environ.dftype())
    return _loc_scale(loc, scale, base)
429
+
430
def logistic(
    self,
    loc=None,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Logistic samples, shifted by ``loc`` and scaled by ``scale``."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(
            () if loc is None else u.math.shape(loc),
            () if scale is None else u.math.shape(scale),
        )
    base = jr.logistic(self.__get_key(key), shape=_size2shape(size), dtype=dtype or environ.dftype())
    return _loc_scale(loc, scale, base)
448
+
449
def normal(
    self,
    loc=None,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Normal samples with mean ``loc`` and standard deviation ``scale``."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(
            () if scale is None else u.math.shape(scale),
            () if loc is None else u.math.shape(loc),
        )
    rng_key = self.__get_key(key)
    out_dtype = dtype or environ.dftype()
    base = jr.normal(rng_key, shape=_size2shape(size), dtype=out_dtype)
    return _loc_scale(loc, scale, base)
468
+
469
def pareto(
    self,
    a,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Pareto samples with shape parameter ``a``."""
    if size is None:
        size = u.math.shape(a)
    rng_key = self.__get_key(key)
    out_dtype = dtype or environ.dftype()
    b = u.math.asarray(a, dtype=out_dtype)
    return jr.pareto(rng_key, b=b, shape=_size2shape(size), dtype=out_dtype)
483
+
484
def poisson(
    self,
    lam=1.0,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Poisson samples with rate ``lam``."""
    lam = _check_py_seq(lam)
    if size is None:
        size = u.math.shape(lam)
    rng_key = self.__get_key(key)
    return jr.poisson(rng_key, lam=lam, shape=_size2shape(size), dtype=dtype or environ.ditype())
498
+
499
def standard_cauchy(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Standard Cauchy samples."""
    rng_key = self.__get_key(key)
    return jr.cauchy(rng_key, shape=_size2shape(size), dtype=dtype or environ.dftype())

def standard_exponential(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Unit-rate exponential samples."""
    rng_key = self.__get_key(key)
    return jr.exponential(rng_key, shape=_size2shape(size), dtype=dtype or environ.dftype())

def standard_gamma(
    self,
    shape,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Standard (unit-scale) gamma samples with shape parameter ``shape``."""
    shape = _check_py_seq(shape)
    if size is None:
        size = () if shape is None else u.math.shape(shape)
    rng_key = self.__get_key(key)
    return jr.gamma(rng_key, a=shape, shape=_size2shape(size), dtype=dtype or environ.dftype())

def standard_normal(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Standard normal samples."""
    rng_key = self.__get_key(key)
    return jr.normal(rng_key, shape=_size2shape(size), dtype=dtype or environ.dftype())
546
+
547
def standard_t(
    self,
    df,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """
    Samples from a standard Student's t distribution.

    Parameters
    ----------
    df : array-like
        Degrees of freedom.
    size : Size, optional
        Output shape; defaults to the shape of ``df``.
    key : SeedOrKey, optional
        Explicit random key; when None a key is consumed from the state.
    dtype : DTypeLike, optional
        Output dtype; defaults to ``environ.dftype()``.
    """
    df = _check_py_seq(df)
    if size is None:
        # BUGFIX: previously `u.math.shape(size) if size is not None else ()`,
        # which is always `()` inside this branch and silently ignored the
        # shape of `df`. Follow `df`'s shape like the other samplers.
        size = u.math.shape(df)
    key = self.__get_key(key)
    dtype = dtype or environ.dftype()
    r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
    return r
561
+
562
def uniform(
    self,
    low=0.0,
    high=1.0,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Uniform samples on ``[low, high)`` (unit-aware; ``high`` is converted to ``low``'s unit)."""
    low, unit = u.split_mantissa_unit(_check_py_seq(low))
    high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
    rng_key = self.__get_key(key)
    draw = jr.uniform(rng_key, _size2shape(size), dtype=dtype or environ.dftype(),
                      minval=low, maxval=high)
    return u.maybe_decimal(draw * unit)
578
+
579
def __norm_cdf(self, x, sqrt2, dtype):
    """Standard normal CDF, (1 + erf(x / sqrt(2))) / 2, evaluated in ``dtype``."""
    one = np.asarray(1., dtype)
    two = np.asarray(2., dtype)
    return (one + lax.erf(x / sqrt2)) / two
582
+
583
def truncated_normal(
    self,
    lower,
    upper,
    size: Optional[Size] = None,
    loc=0.0,
    scale=1.0,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None,
    check_valid: bool = True
):
    """
    Samples from a normal distribution truncated to ``[lower, upper]``.

    Uses inverse-CDF sampling: draw uniformly between the CDF values of the
    (standardized) bounds, then map back through the inverse normal CDF.
    All parameters are converted into the unit of ``lower``.
    """
    lower = _check_py_seq(lower)
    upper = _check_py_seq(upper)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    dtype = dtype or environ.dftype()

    # Strip the unit from `lower` and express every other parameter in it.
    lower, unit = u.split_mantissa_unit(u.math.asarray(lower, dtype=dtype))
    upper = u.math.asarray(upper, dtype=dtype)
    loc = u.math.asarray(loc, dtype=dtype)
    scale = u.math.asarray(scale, dtype=dtype)
    upper = u.Quantity(upper).in_unit(unit).mantissa
    loc = u.Quantity(loc).in_unit(unit).mantissa
    scale = u.Quantity(scale).in_unit(unit).mantissa

    if check_valid:
        from brainstate.transform._error_if import jit_error_if
        jit_error_if(
            u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
            "mean is more than 2 std from [lower, upper] in truncated_normal. "
            "The distribution of values may be incorrect."
        )

    if size is None:
        size = u.math.broadcast_shapes(
            u.math.shape(lower),
            u.math.shape(upper),
            u.math.shape(loc),
            u.math.shape(scale)
        )

    # CDF values of the standardized truncation bounds.
    sqrt2 = np.array(np.sqrt(2), dtype=dtype)
    cdf_low = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
    cdf_high = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)

    # Uniform draw in [2*cdf_low - 1, 2*cdf_high - 1], nudged strictly inside
    # the interval so erf_inv never sees exactly +/- 1.
    rng_key = self.__get_key(key)
    draw = jr.uniform(
        rng_key, size, dtype,
        minval=lax.nextafter(2 * cdf_low - 1, np.array(np.inf, dtype=dtype)),
        maxval=lax.nextafter(2 * cdf_high - 1, np.array(-np.inf, dtype=dtype))
    )

    # Inverse CDF yields a truncated *standard* normal...
    draw = lax.erf_inv(draw)
    # ...which is then scaled and shifted to the requested std and mean.
    draw = draw * scale * sqrt2 + loc

    # Clamp strictly inside the bounds to guard against round-off.
    draw = jnp.clip(
        draw,
        lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
        lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
    )
    return u.maybe_decimal(draw * unit)
656
+
657
def _check_p(self, *args, **kwargs):
    """Error callback used by :meth:`bernoulli` when ``p`` is out of range.

    Invoked by ``jit_error_if`` with the offending value passed as ``p=...``.
    """
    p = kwargs.get('p', args[0] if args else None)
    # BUGFIX: the message was a plain string containing a literal '{p}'
    # (missing f-prefix), so the offending value was never shown.
    raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
659
+
660
def bernoulli(
    self,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    check_valid: bool = True
):
    """Bernoulli samples with success probability ``p``; optionally validates ``p`` in [0, 1]."""
    p = _check_py_seq(p)
    if check_valid:
        from brainstate.transform._error_if import jit_error_if
        jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
    if size is None:
        size = u.math.shape(p)
    rng_key = self.__get_key(key)
    return jr.bernoulli(rng_key, p=p, shape=_size2shape(size))
676
+
677
def lognormal(
    self,
    mean=None,
    sigma=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Log-normal samples: ``exp(N(mean, sigma))`` (unit-aware)."""
    dtype = dtype or environ.dftype()
    mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
    sigma = u.math.asarray(_check_py_seq(sigma), dtype=dtype)
    # Capture the unit of `mean` and express `sigma` in the same unit.
    unit = mean.unit if isinstance(mean, u.Quantity) else u.UNITLESS
    mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
    sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma

    if size is None:
        size = jnp.broadcast_shapes(
            () if mean is None else u.math.shape(mean),
            () if sigma is None else u.math.shape(sigma),
        )
    rng_key = self.__get_key(key)
    gauss = jr.normal(rng_key, shape=_size2shape(size), dtype=dtype)
    gauss = _loc_scale(mean, sigma, gauss)
    return u.maybe_decimal(jnp.exp(gauss) * unit)
705
+
706
def binomial(
    self,
    n,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None,
    check_valid: bool = True
):
    """Binomial samples with ``n`` trials and success probability ``p``."""
    n = _check_py_seq(n)
    p = _check_py_seq(p)
    if check_valid:
        from brainstate.transform._error_if import jit_error_if
        jit_error_if(
            jnp.any(jnp.logical_or(p < 0, p > 1)),
            'Parameter p should be within [0, 1], but we got {p}',
            p=p
        )
    if size is None:
        size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
    rng_key = self.__get_key(key)
    draw = jr.binomial(rng_key, n, p, shape=_size2shape(size))
    return u.math.asarray(draw, dtype=dtype or environ.ditype())
730
+
731
def chisquare(
    self,
    df,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Chi-square samples built as the sum of ``df`` squared standard normals."""
    df = _check_py_seq(df)
    rng_key = self.__get_key(key)
    out_dtype = dtype or environ.dftype()
    if size is None:
        # Without a size, `df` must be a concrete scalar.
        if jnp.ndim(df) != 0:
            raise NotImplementedError('Do not support non-scale "df" when "size" is None')
        return (jr.normal(rng_key, (df,), dtype=out_dtype) ** 2).sum()
    normals = jr.normal(rng_key, (df,) + _size2shape(size), dtype=out_dtype)
    return (normals ** 2).sum(axis=0)
751
+
752
def dirichlet(
    self,
    alpha,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Dirichlet samples with concentration parameters ``alpha``."""
    rng_key = self.__get_key(key)
    alpha = _check_py_seq(alpha)
    return jr.dirichlet(rng_key, alpha=alpha, shape=_size2shape(size),
                        dtype=dtype or environ.dftype())
764
+
765
def geometric(
    self,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Geometric samples via inverse-CDF: ``floor(log(1-U) / log(1-p))``.

    NOTE(review): this yields values in {0, 1, 2, ...} (failures before the
    first success), whereas ``numpy.random.geometric`` is 1-based — confirm
    which convention callers expect.
    """
    p = _check_py_seq(p)
    if size is None:
        size = u.math.shape(p)
    rng_key = self.__get_key(key)
    uni = jr.uniform(rng_key, size, dtype or environ.dftype())
    return jnp.floor(jnp.log1p(-uni) / jnp.log1p(-p))
780
+
781
def _check_p2(self, p):
    """Error callback used by :meth:`multinomial` when probabilities sum above 1."""
    raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
783
+
784
def multinomial(
    self,
    n,
    pvals,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None,
    check_valid: bool = True
):
    """Multinomial samples for ``n`` trials with category probabilities ``pvals``."""
    rng_key = self.__get_key(key)
    n = _check_py_seq(n)
    pvals = _check_py_seq(pvals)
    if check_valid:
        from brainstate.transform._error_if import jit_error_if
        jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
    if isinstance(n, jax.core.Tracer):
        raise ValueError("The total count parameter `n` should not be a jax abstract array.")
    size = _size2shape(size)
    # `n` must be concrete: the underlying sampler unrolls up to the largest count.
    n_max = int(np.max(jax.device_get(n)))
    batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
    draw = multinomial(rng_key, pvals, n, n_max=n_max, shape=batch_shape + size)
    return u.math.asarray(draw, dtype=dtype or environ.ditype())
807
+
808
def multivariate_normal(
    self,
    mean,
    cov,
    size: Optional[Size] = None,
    method: str = 'cholesky',
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """
    Draw samples from a multivariate normal distribution.

    Parameters
    ----------
    mean : array-like
        Mean vector(s) of shape ``(..., n)``; may carry a unit.
    cov : array-like
        Covariance matrices of shape ``(..., n, n)``; when ``mean`` carries a
        unit, ``cov`` must carry that unit squared.
    size : Size, optional
        Batch shape of the output (the event dimension ``n`` is appended).
    method : str
        Covariance factorization: ``'svd'``, ``'eigh'`` or ``'cholesky'``.
    key : SeedOrKey, optional
        Explicit random key; when None a key is consumed from the state.
    dtype : DTypeLike, optional
        Output dtype; defaults to ``environ.dftype()``.
    """
    if method not in {'svd', 'eigh', 'cholesky'}:
        raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
    dtype = dtype or environ.dftype()
    mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
    cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
    if isinstance(mean, u.Quantity):
        assert isinstance(cov, u.Quantity)
        assert mean.unit ** 2 == cov.unit
    # BUGFIX: the unit must be captured *before* stripping the mantissa.
    # Previously `unit` was computed after `mean` had already been replaced by
    # its mantissa, so the result's unit was silently dropped.
    unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
    mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
    cov = cov.mantissa if isinstance(cov, u.Quantity) else cov

    key = self.__get_key(key)
    if not jnp.ndim(mean) >= 1:
        raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
    if not jnp.ndim(cov) >= 2:
        raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
    n = mean.shape[-1]
    if u.math.shape(cov)[-2:] != (n, n):
        raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
                         f"but got cov.shape == {u.math.shape(cov)}.")
    if size is None:
        size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    else:
        size = _size2shape(size)
        _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])

    # Factor the covariance so that factor @ factor.T == cov.
    if method == 'svd':
        (u_, s, _) = jnp.linalg.svd(cov)
        factor = u_ * jnp.sqrt(s[..., None, :])
    elif method == 'eigh':
        (w, v) = jnp.linalg.eigh(cov)
        factor = v * jnp.sqrt(w[..., None, :])
    else:  # 'cholesky'
        factor = jnp.linalg.cholesky(cov)
    normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
    r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
    return u.maybe_decimal(r * unit)
855
+
856
def rayleigh(
    self,
    scale=1.0,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Rayleigh samples via inverse-CDF: ``scale * sqrt(-2 ln U)``."""
    scale = _check_py_seq(scale)
    if size is None:
        size = u.math.shape(scale)
    rng_key = self.__get_key(key)
    uni = jr.uniform(rng_key, shape=_size2shape(size), dtype=dtype or environ.dftype())
    return jnp.sqrt(-2. * jnp.log(uni)) * scale
871
+
872
def triangular(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None
):
    """Random signs drawn uniformly from {-1, +1}.

    NOTE(review): despite the name, this samples a Rademacher distribution
    (``2 * Bernoulli(0.5) - 1``), not NumPy's ``triangular(left, mode, right)``
    — confirm the naming is intentional.
    """
    flips = jr.bernoulli(self.__get_key(key), p=0.5, shape=_size2shape(size))
    return 2 * flips - 1
881
+
882
def vonmises(
    self,
    mu,
    kappa,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Von Mises samples with mode ``mu`` and concentration ``kappa``, wrapped to [-pi, pi)."""
    rng_key = self.__get_key(key)
    out_dtype = dtype or environ.dftype()
    mu = u.math.asarray(_check_py_seq(mu), dtype=out_dtype)
    kappa = u.math.asarray(_check_py_seq(kappa), dtype=out_dtype)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
    shape = _size2shape(size)
    # Centered draw, shifted by mu, then wrapped back into [-pi, pi).
    out = von_mises_centered(rng_key, kappa, shape, dtype=out_dtype) + mu
    return (out + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
901
+
902
def weibull(
    self,
    a,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Standard Weibull samples with shape ``a``, via the inverse-CDF transform."""
    rng_key = self.__get_key(key)
    a = _check_py_seq(a)
    if size is None:
        size = u.math.shape(a)
    else:
        # With an explicit size, only a scalar shape parameter is supported.
        if jnp.size(a) > 1:
            raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
        size = _size2shape(size)
    uni = jr.uniform(key=rng_key, shape=size, dtype=dtype or environ.dftype())
    return jnp.power(-jnp.log1p(-uni), 1.0 / a)
921
+
922
def weibull_min(
    self,
    a,
    scale=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """
    Samples from a Weibull minimum distribution (``scipy.stats.weibull_min`` convention).

    Parameters
    ----------
    a : float
        Shape parameter ``c``; must be scalar when ``size`` is given.
    scale : optional
        Scale parameter; multiplies the standard draw (``x = scale * z``).
    size : Size, optional
        Output shape; defaults to broadcasting ``a`` with ``scale``.
    key : SeedOrKey, optional
        Explicit random key; when None a key is consumed from the state.
    dtype : DTypeLike, optional
        Output dtype; defaults to ``environ.dftype()``.
    """
    key = self.__get_key(key)
    a = _check_py_seq(a)
    scale = _check_py_seq(scale)
    if size is None:
        size = jnp.broadcast_shapes(u.math.shape(a), u.math.shape(scale) if scale is not None else ())
    else:
        if jnp.size(a) > 1:
            raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
        size = _size2shape(size)
    dtype = dtype or environ.dftype()
    random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
    r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
    if scale is not None:
        # BUGFIX: `scale` is a scale parameter (x = scale * z, as in
        # scipy.stats.weibull_min); the previous `r /= scale` shrank samples
        # by scale**2 relative to the intended distribution.
        r = r * scale
    return r
945
+
946
def maxwell(
    self,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Maxwell samples: the norm of a 3-D standard normal vector."""
    rng_key = self.__get_key(key)
    full_shape = _size2shape(size) + (3,)
    comps = jr.normal(key=rng_key, shape=full_shape, dtype=dtype or environ.dftype())
    return jnp.linalg.norm(comps, axis=-1)
958
+
959
def negative_binomial(
    self,
    n,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Negative binomial samples via the Gamma-Poisson mixture representation."""
    n = _check_py_seq(n)
    p = _check_py_seq(p)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p))
    size = _size2shape(size)
    logits = jnp.log(p) - jnp.log1p(-p)
    # Two keys are needed: one for the gamma mixing rate, one for the Poisson draw.
    if key is None:
        keys = self.split_key(2)
    else:
        keys = jr.split(formalize_key(key, use_prng_key), 2)
    rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
    return self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
980
+
981
def wald(
    self,
    mean,
    scale,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """
    Inverse Gaussian (Wald) samples via the chi-square transformation method.

    ``mean`` is the distribution mean and ``scale`` the concentration (lambda).
    """
    dtype = dtype or environ.dftype()
    key = self.__get_key(key)
    mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
    scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(mean), u.math.shape(scale))
    size = _size2shape(size)
    sampled_chi2 = jnp.square(self.randn(*size))
    sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
    # Wikipedia's intermediate variable is
    #   x = loc + loc**2 * y / (2*conc) - loc/(2*conc) * sqrt(4*loc*conc*y + loc**2*y**2)
    # with y ~ N(0,1)**2 (sampled_chi2) and conc the concentration. Writing
    #   w = loc * y / (2 * conc)
    # gives x = loc + loc*w*(1 - sqrt(2/w + 1)), which suffers catastrophic
    # cancellation for large w (i.e. conc << loc). Multiplying both sides by
    # 1 + sqrt(2/w + 1) gives x in terms of sqrt(2/w + 1) - 1 = sqrt1pm1(2/w),
    # which can be computed accurately. For small w, 2/w may overflow; since
    # x -> loc as w -> 0, that case is masked out below.
    sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2)  # this is 2 / w
    safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
    denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
    ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
    sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0)  # x above
    # Randomly pick between the two roots of the defining quadratic.
    return jnp.where(
        sampled_uniform <= mean / (mean + sampled),
        sampled,
        jnp.square(mean) / sampled,
    )
1028
+
1029
def t(
    self,
    df,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Student's t samples: ``N / sqrt(G / (df/2))`` with ``G ~ Gamma(df/2)``."""
    dtype = dtype or environ.dftype()
    df = u.math.asarray(_check_py_seq(df), dtype=dtype)
    if size is None:
        size = np.shape(df)
    else:
        size = _size2shape(size)
        _check_shape("t", size, np.shape(df))
    # Two keys: one for the normal numerator, one for the gamma denominator.
    if key is None:
        keys = self.split_key(2)
    else:
        keys = jr.split(formalize_key(key, use_prng_key), 2)
    numerator = jr.normal(keys[0], size, dtype=dtype)
    two = const(numerator, 2)
    half_df = lax.div(df, two)
    denom = jr.gamma(keys[1], half_df, size, dtype=dtype)
    return numerator * jnp.sqrt(half_df / denom)
1053
+
1054
def orthogonal(
    self,
    n: int,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Uniformly random ``n x n`` orthogonal matrices (QR of a Gaussian matrix)."""
    dtype = dtype or environ.dftype()
    rng_key = self.__get_key(key)
    size = _size2shape(size)
    _check_shape("orthogonal", size)
    n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
    z = jr.normal(rng_key, size + (n, n), dtype=dtype)
    q, r = jnp.linalg.qr(z)
    # Sign-correct the Q factor with the sign of R's diagonal for a Haar draw.
    d = jnp.diagonal(r, 0, -2, -1)
    return q * jnp.expand_dims(d / abs(d), -2)
1071
+
1072
def noncentral_chisquare(
    self,
    df,
    nonc,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Noncentral chi-square samples with ``df`` degrees of freedom and noncentrality ``nonc``."""
    dtype = dtype or environ.dftype()
    df = u.math.asarray(_check_py_seq(df), dtype=dtype)
    nonc = u.math.asarray(_check_py_seq(nonc), dtype=dtype)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(df), u.math.shape(nonc))
    size = _size2shape(size)
    # Three keys: Poisson mixing count, shifted normal, chi-square remainder.
    if key is None:
        keys = self.split_key(3)
    else:
        keys = jr.split(formalize_key(key, use_prng_key), 3)
    i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
    n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
    cond = jnp.greater(df, 1.0)
    # df > 1: squared shifted normal plus chi2 with df-1 dof;
    # otherwise: the Poisson-mixture representation with df + 2*i dof.
    df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
    chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
    return jnp.where(cond, chi2 + n * n, chi2)
1097
+
1098
def loggamma(
    self,
    a,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Log-gamma samples with shape parameter ``a``."""
    out_dtype = dtype or environ.dftype()
    rng_key = self.__get_key(key)
    a = _check_py_seq(a)
    if size is None:
        size = u.math.shape(a)
    return jr.loggamma(rng_key, a, shape=_size2shape(size), dtype=out_dtype)
1112
+
1113
def categorical(
    self,
    logits,
    axis: int = -1,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None
):
    """Categorical samples from unnormalized ``logits`` along ``axis``."""
    rng_key = self.__get_key(key)
    logits = _check_py_seq(logits)
    if size is None:
        # Default output shape: logits' shape with the category axis removed.
        size = list(u.math.shape(logits))
        size.pop(axis)
    return jr.categorical(rng_key, logits, axis=axis, shape=_size2shape(size))
1127
+
1128
def zipf(
    self,
    a,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Zipf samples with exponent ``a`` (delegates to the module-level sampler)."""
    a = _check_py_seq(a)
    if size is None:
        size = u.math.shape(a)
    # `zipf` here resolves to the module-level function, not this method.
    return zipf(
        self.__get_key(key),
        a,
        shape=size,
        dtype=dtype or environ.ditype()
    )

def power(
    self,
    a,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Power-distribution samples with exponent ``a`` (delegates to the module-level sampler)."""
    a = _check_py_seq(a)
    if size is None:
        size = u.math.shape(a)
    size = _size2shape(size)
    # `power` here resolves to the module-level function, not this method.
    return power(
        self.__get_key(key),
        a,
        shape=size,
        dtype=dtype or environ.dftype(),
    )
1164
+
1165
def f(
    self,
    dfnum,
    dfden,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """F-distribution samples with ``dfnum``/``dfden`` degrees of freedom."""
    dfnum = _check_py_seq(dfnum)
    dfden = _check_py_seq(dfden)
    if size is None:
        size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
    size = _size2shape(size)
    # `f` here resolves to the module-level sampler, not this method.
    return f(
        self.__get_key(key),
        dfnum,
        dfden,
        shape=size,
        dtype=dtype or environ.dftype(),
    )

def hypergeometric(
    self,
    ngood,
    nbad,
    nsample,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Hypergeometric samples from ``ngood`` successes and ``nbad`` failures, drawing ``nsample``."""
    ngood = _check_py_seq(ngood)
    nbad = _check_py_seq(nbad)
    nsample = _check_py_seq(nsample)
    if size is None:
        size = lax.broadcast_shapes(
            u.math.shape(ngood),
            u.math.shape(nbad),
            u.math.shape(nsample)
        )
    size = _size2shape(size)
    # `hypergeometric` here resolves to the module-level sampler, not this method.
    return hypergeometric(
        self.__get_key(key),
        ngood,
        nbad,
        nsample,
        shape=size,
        dtype=dtype or environ.ditype(),
    )

def logseries(
    self,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Logarithmic-series samples with parameter ``p``."""
    p = _check_py_seq(p)
    if size is None:
        size = u.math.shape(p)
    size = _size2shape(size)
    # `logseries` here resolves to the module-level sampler, not this method.
    return logseries(
        self.__get_key(key),
        p,
        shape=size,
        dtype=dtype or environ.ditype()
    )

def noncentral_f(
    self,
    dfnum,
    dfden,
    nonc,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Noncentral F-distribution samples."""
    dfnum = _check_py_seq(dfnum)
    dfden = _check_py_seq(dfden)
    nonc = _check_py_seq(nonc)
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(dfnum),
                                    u.math.shape(dfden),
                                    u.math.shape(nonc))
    size = _size2shape(size)
    # `noncentral_f` here resolves to the module-level sampler, not this method.
    return noncentral_f(
        self.__get_key(key),
        dfnum,
        dfden,
        nonc,
        shape=size,
        dtype=dtype or environ.dftype(),
    )
1261
+
1262
+ # PyTorch compatibility #
1263
+ # --------------------- #
1264
+
1265
+ def rand_like(
1266
+ self,
1267
+ input,
1268
+ *,
1269
+ dtype=None,
1270
+ key: Optional[SeedOrKey] = None
1271
+ ):
1272
+ """Returns a tensor with the same size as input that is filled with random
1273
+ numbers from a uniform distribution on the interval ``[0, 1)``.
1274
+
1275
+ Args:
1276
+ input: the ``size`` of input will determine size of the output tensor.
1277
+ dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1278
+ key: the seed or key for the random.
1279
+
1280
+ Returns:
1281
+ The random data.
1282
+ """
1283
+ return self.random(u.math.shape(input), key=key).astype(dtype)
1284
+
1285
+ def randn_like(
1286
+ self,
1287
+ input,
1288
+ *,
1289
+ dtype=None,
1290
+ key: Optional[SeedOrKey] = None
1291
+ ):
1292
+ """Returns a tensor with the same size as ``input`` that is filled with
1293
+ random numbers from a normal distribution with mean 0 and variance 1.
1294
+
1295
+ Args:
1296
+ input: the ``size`` of input will determine size of the output tensor.
1297
+ dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1298
+ key: the seed or key for the random.
1299
+
1300
+ Returns:
1301
+ The random data.
1302
+ """
1303
+ return self.randn(*u.math.shape(input), key=key).astype(dtype)
1304
+
1305
    def randint_like(
        self,
        input,
        low=0,
        high=None,
        *,
        dtype=None,
        key: Optional[SeedOrKey] = None
    ):
        """Returns a tensor with the same shape as ``input`` filled with random
        integers drawn uniformly from ``[low, high)``.

        Args:
            input: the ``size`` of input will determine size of the output tensor.
            low: lowest integer to be drawn (inclusive). Default: 0.
            high: one above the highest integer to be drawn. If ``None``, it is
                derived from ``input`` (see note below).
            dtype: the desired data type of returned Tensor.
            key: the seed or key for the random.

        Returns:
            The random integer data.
        """
        if high is None:
            # NOTE(review): uses the builtin ``max`` over ``input``, i.e. a
            # reduction along its first axis — this raises for 0-d inputs and
            # yields an array for >1-d inputs. Presumably intended as "largest
            # value in input"; TODO confirm against the intended PyTorch-style
            # ``randint_like`` semantics.
            high = max(input)
        return self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key)
1317
+
1318
+
1319
# default random generator
# Module-level RandomState shared as the package default. Its seed is two
# uint32 words drawn from NumPy's *global* RNG at import time, so DEFAULT's
# stream differs between processes unless ``np.random.seed`` was called
# before importing this module.
DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))