brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -146
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -470
  58. brainstate/nn/_delay_test.py +238 -0
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1361
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1120
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -208
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.7.dist-info/RECORD +0 -131
  133. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1409 +1,1409 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from functools import partial
19
- from operator import index
20
- from typing import Optional
21
-
22
- import brainunit as u
23
- import jax
24
- import jax.numpy as jnp
25
- import jax.random as jr
26
- import numpy as np
27
- from jax import jit, vmap
28
- from jax import lax, core, dtypes
29
-
30
- from brainstate import environ
31
- from brainstate._state import State
32
- from brainstate.compile._error_if import jit_error_if
33
- from brainstate.typing import DTypeLike, Size, SeedOrKey
34
- from ._random_for_unit import uniform_for_unit, permutation_for_unit
35
-
36
- __all__ = ['RandomState', 'DEFAULT', ]
37
-
38
- use_prng_key = True
39
-
40
-
41
- class RandomState(State):
42
- """RandomState that track the random generator state. """
43
-
44
- # __slots__ = ('_backup', '_value')
45
-
46
def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
    """RandomState constructor.

    Parameters
    ----------
    seed_or_key: int, Array, optional
        Either an integer seed for the random number generator, or a JAX
        PRNG key: a typed key array (``jax.dtypes.prng_key`` dtype) or a
        legacy ``uint32`` array with two elements.
    """
    with jax.ensure_compile_time_eval():
        if seed_or_key is None:
            # No seed supplied: draw one from NumPy's global RNG at trace time.
            seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
        if isinstance(seed_or_key, int):
            # `use_prng_key` (module flag) selects legacy vs. typed key style.
            key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
        else:
            if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
                # Already a new-style typed PRNG key: accept as-is.
                key = seed_or_key
            else:
                # Legacy raw key: must be a length-2 uint32 array.
                # BUGFIX: the original used `and` here, so a key was only
                # rejected when BOTH the length and the dtype were wrong;
                # `or` enforces both requirements.
                if len(seed_or_key) != 2 or seed_or_key.dtype != np.uint32:
                    raise ValueError('key must be an array with dtype uint32. '
                                     f'But we got {seed_or_key}')
                key = seed_or_key
    super().__init__(key)

    # Slot for a single stashed key; see backup_key()/restore_key().
    self._backup = None
71
-
72
def __repr__(self):
    """Readable representation showing the class name and current key."""
    return f'{self.__class__.__name__}({self.value})'

def check_if_deleted(self):
    """Repair the stored key if its backing buffer is gone.

    Two cases are handled:
    * typed-key mode (``use_prng_key`` is False) but a raw NumPy array is
      stored — replace it with a fresh typed key;
    * a concrete JAX array whose device buffer has been deleted — re-seed.
    """
    if not use_prng_key and isinstance(self._value, np.ndarray):
        self._value = jr.key(np.random.randint(0, 10000))

    current = self._value
    concrete = isinstance(current, jax.Array) and not isinstance(current, jax.core.Tracer)
    if concrete and current.is_deleted():
        self.seed()
85
-
86
- # ------------------- #
87
- # seed and random key #
88
- # ------------------- #
89
-
90
- def backup_key(self):
91
- if self._backup is not None:
92
- raise ValueError('The random key has been backed up, and has not been restored.')
93
- self._backup = self.value
94
-
95
- def restore_key(self):
96
- if self._backup is None:
97
- raise ValueError('The random key has not been backed up.')
98
- self.value = self._backup
99
- self._backup = None
100
-
101
- def clone(self):
102
- return type(self)(self.split_key())
103
-
104
- def set_key(self, key: SeedOrKey):
105
- self.value = key
106
-
107
def seed(self, seed_or_key: Optional[SeedOrKey] = None):
    """Sets a new random seed.

    Parameters
    ----------
    seed_or_key: int, ArrayLike, optional
        Either an integer seed, or a JAX PRNG key (a typed key array, or a
        legacy length-2 ``uint32`` array).
    """
    with jax.ensure_compile_time_eval():
        if seed_or_key is None:
            seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
        if np.size(seed_or_key) == 1:
            # Scalar input: an int seed, an existing typed key, or an
            # integer-dtype 0-d array used as a seed.
            if isinstance(seed_or_key, int):
                key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
            elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
                key = seed_or_key
            elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
                key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
            else:
                raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
        else:
            # Non-scalar input: only a legacy raw key of shape (2,) uint32.
            if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
                key = seed_or_key
            else:
                raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
    self.value = key
134
-
135
- def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
136
- """
137
- Create a new seed from the current seed.
138
-
139
- Parameters
140
- ----------
141
- n: int, optional
142
- The number of seeds to generate.
143
- backup : bool, optional
144
- Whether to backup the current key.
145
-
146
- Returns
147
- -------
148
- key : SeedOrKey
149
- The new seed or a tuple of JAX random keys.
150
- """
151
- if n is not None:
152
- assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
153
-
154
- if not isinstance(self.value, jax.Array):
155
- self.value = jnp.asarray(self.value, dtype=jnp.uint32)
156
- keys = jr.split(self.value, num=2 if n is None else n + 1)
157
- self.value = keys[0]
158
- if backup:
159
- self.backup_key()
160
- if n is None:
161
- return keys[1]
162
- else:
163
- return keys[1:]
164
-
165
def self_assign_multi_keys(self, n: int, backup: bool = True):
    """
    Self-assign multiple keys to the current random state.

    When ``backup`` is True, the advanced scalar key is stashed first so it
    can later be recovered with ``restore_key``.
    """
    if backup:
        fresh = jr.split(self.value, n + 1)
        self.value = fresh[0]
        self.backup_key()
        self.value = fresh[1:]
    else:
        self.value = jr.split(self.value, n)
176
-
177
- # ---------------- #
178
- # random functions #
179
- # ---------------- #
180
-
181
def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
    """Uniform samples on [0, 1) with shape ``dn`` (numpy.random.rand style)."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    dtype = dtype or environ.dftype()
    return uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
186
-
187
def randint(
    self,
    low,
    high=None,
    size: Optional[Size] = None,
    dtype: DTypeLike = None,
    key: Optional[SeedOrKey] = None
):
    """Random integers from [low, high) (numpy.random.randint semantics).

    With a single argument, samples are drawn from ``[0, low)``.
    """
    if high is None:
        # Single-argument form: `low` is actually the exclusive upper bound.
        high = low
        low = 0
    high = _check_py_seq(high)
    low = _check_py_seq(low)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.ditype()
    return jr.randint(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)

def random_integers(
    self,
    low,
    high=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None,
):
    """Random integers from the *inclusive* range [low, high].

    With a single argument, samples are drawn from ``[1, low]`` (legacy
    numpy.random.random_integers semantics).
    """
    low = _check_py_seq(low)
    high = _check_py_seq(high)
    if high is None:
        high = low
        low = 1
    # Shift to an exclusive upper bound for jr.randint.
    high += 1
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.ditype()
    return jr.randint(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
234
-
235
def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
    """Standard-normal samples with shape ``dn`` (numpy.random.randn style)."""
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.normal(key, shape=dn, dtype=dtype)

def random(self,
           size: Optional[Size] = None,
           key: Optional[SeedOrKey] = None,
           dtype: DTypeLike = None):
    """Uniform samples on [0, 1) with the given ``size``."""
    dtype = dtype or environ.dftype()
    key = self.split_key() if key is None else _formalize_key(key)
    return uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)

def random_sample(self,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None,
                  dtype: DTypeLike = None):
    """Alias of :meth:`random` (numpy.random.random_sample)."""
    return self.random(size=size, key=key, dtype=dtype)

def ranf(self,
         size: Optional[Size] = None,
         key: Optional[SeedOrKey] = None,
         dtype: DTypeLike = None):
    """Alias of :meth:`random` (numpy.random.ranf)."""
    return self.random(size=size, key=key, dtype=dtype)

def sample(self,
           size: Optional[Size] = None,
           key: Optional[SeedOrKey] = None,
           dtype: DTypeLike = None):
    """Alias of :meth:`random` (numpy.random.sample)."""
    return self.random(size=size, key=key, dtype=dtype)
270
-
271
def choice(self,
           a,
           size: Optional[Size] = None,
           replace=True,
           p=None,
           key: Optional[SeedOrKey] = None):
    """Sample from ``a`` with/without replacement, optionally weighted by ``p``."""
    a = _check_py_seq(a)
    p = _check_py_seq(p)
    key = self.split_key() if key is None else _formalize_key(key)
    return jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)

def permutation(self,
                x,
                axis: int = 0,
                independent: bool = False,
                key: Optional[SeedOrKey] = None):
    """Return a permuted copy of ``x`` along ``axis``."""
    x = _check_py_seq(x)
    key = self.split_key() if key is None else _formalize_key(key)
    return permutation_for_unit(key, x, axis=axis, independent=independent)

def shuffle(self,
            x,
            axis=0,
            key: Optional[SeedOrKey] = None):
    """Shuffle ``x`` along ``axis`` and return the result."""
    key = self.split_key() if key is None else _formalize_key(key)
    return permutation_for_unit(key, x, axis=axis)
300
-
301
def beta(self,
         a,
         b,
         size: Optional[Size] = None,
         key: Optional[SeedOrKey] = None,
         dtype: DTypeLike = None):
    """Draw samples from a Beta(a, b) distribution."""
    a = _check_py_seq(a)
    b = _check_py_seq(b)
    if size is None:
        # Default shape: broadcast of the two parameter shapes.
        size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
315
-
316
def exponential(self,
                scale=None,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
    """Exponential samples: unit-rate draws, divided by ``scale`` when given.

    NOTE(review): dividing by ``scale`` treats it as a *rate*; numpy's
    ``scale`` parameter multiplies instead — behavior kept as-is for
    backward compatibility.
    """
    if size is None:
        size = jnp.shape(scale)  # np.shape(None) == (), i.e. scalar default
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    # BUGFIX: only coerce `scale` when provided — the original called
    # jnp.asarray(None, ...) for the default scale=None, which raises.
    if scale is not None:
        scale = jnp.asarray(scale, dtype=dtype)
    r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
    if scale is not None:
        r = r / scale
    return r
330
-
331
def gamma(self,
          shape,
          scale=None,
          size: Optional[Size] = None,
          key: Optional[SeedOrKey] = None,
          dtype: DTypeLike = None):
    """Gamma samples with shape parameter ``shape``, scaled by ``scale`` if given."""
    shape = _check_py_seq(shape)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    out = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
    if scale is not None:
        out = out * scale
    return out
347
-
348
def gumbel(self,
           loc=None,
           scale=None,
           size: Optional[Size] = None,
           key: Optional[SeedOrKey] = None,
           dtype: DTypeLike = None):
    """Gumbel samples, shifted/scaled by ``loc`` and ``scale`` via _loc_scale."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    base = jr.gumbel(key, shape=_size2shape(size), dtype=dtype)
    return _loc_scale(loc, scale, base)

def laplace(self,
            loc=None,
            scale=None,
            size: Optional[Size] = None,
            key: Optional[SeedOrKey] = None,
            dtype: DTypeLike = None):
    """Laplace samples, shifted/scaled by ``loc`` and ``scale`` via _loc_scale."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    base = jr.laplace(key, shape=_size2shape(size), dtype=dtype)
    return _loc_scale(loc, scale, base)

def logistic(self,
             loc=None,
             scale=None,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
    """Logistic samples, shifted/scaled by ``loc`` and ``scale`` via _loc_scale."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(
            jnp.shape(loc) if loc is not None else (),
            jnp.shape(scale) if scale is not None else ()
        )
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    base = jr.logistic(key, shape=_size2shape(size), dtype=dtype)
    return _loc_scale(loc, scale, base)
395
-
396
def normal(self,
           loc=None,
           scale=None,
           size: Optional[Size] = None,
           key: Optional[SeedOrKey] = None,
           dtype: DTypeLike = None):
    """Normal samples with mean ``loc`` and std ``scale`` (via _loc_scale)."""
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(
            jnp.shape(scale) if scale is not None else (),
            jnp.shape(loc) if loc is not None else ()
        )
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    base = jr.normal(key, shape=_size2shape(size), dtype=dtype)
    return _loc_scale(loc, scale, base)
413
-
414
def pareto(self,
           a,
           size: Optional[Size] = None,
           key: Optional[SeedOrKey] = None,
           dtype: DTypeLike = None):
    """Pareto samples with shape parameter ``a``."""
    if size is None:
        size = jnp.shape(a)
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    a = jnp.asarray(a, dtype=dtype)
    return jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)

def poisson(self,
            lam=1.0,
            size: Optional[Size] = None,
            key: Optional[SeedOrKey] = None,
            dtype: DTypeLike = None):
    """Poisson samples with rate ``lam`` (integer dtype by default)."""
    lam = _check_py_seq(lam)
    if size is None:
        size = jnp.shape(lam)
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.ditype()
    return jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
439
-
440
def standard_cauchy(self,
                    size: Optional[Size] = None,
                    key: Optional[SeedOrKey] = None,
                    dtype: DTypeLike = None):
    """Standard Cauchy samples."""
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.cauchy(key, shape=_size2shape(size), dtype=dtype)

def standard_exponential(self,
                         size: Optional[Size] = None,
                         key: Optional[SeedOrKey] = None,
                         dtype: DTypeLike = None):
    """Unit-rate exponential samples."""
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.exponential(key, shape=_size2shape(size), dtype=dtype)

def standard_gamma(self,
                   shape,
                   size: Optional[Size] = None,
                   key: Optional[SeedOrKey] = None,
                   dtype: DTypeLike = None):
    """Standard gamma samples with shape parameter ``shape``."""
    shape = _check_py_seq(shape)
    if size is None:
        size = jnp.shape(shape) if shape is not None else ()
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)

def standard_normal(self,
                    size: Optional[Size] = None,
                    key: Optional[SeedOrKey] = None,
                    dtype: DTypeLike = None):
    """Standard normal samples."""
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.normal(key, shape=_size2shape(size), dtype=dtype)
479
-
480
def standard_t(self, df,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None,
               dtype: DTypeLike = None):
    """Student's t samples with ``df`` degrees of freedom.

    When ``size`` is None the output shape follows ``df``'s shape, matching
    the convention of the other distribution methods in this class.
    """
    df = _check_py_seq(df)
    if size is None:
        # BUGFIX: the original computed `jnp.shape(size) if size is not None
        # else ()` inside this branch, where size is always None — so the
        # shape was always (), ignoring a non-scalar `df`. Broadcast against
        # `df` instead, like poisson/pareto do with their parameters.
        size = jnp.shape(df)
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
491
-
492
def uniform(self,
            low=0.0,
            high=1.0,
            size: Optional[Size] = None,
            key: Optional[SeedOrKey] = None,
            dtype: DTypeLike = None):
    """Uniform samples on [low, high)."""
    low = _check_py_seq(low)
    high = _check_py_seq(high)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    return uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
506
-
507
def __norm_cdf(self, x, sqrt2, dtype):
    """Standard normal CDF via the error function: Phi(x) = (1 + erf(x/sqrt2)) / 2."""
    one = np.asarray(1., dtype)
    two = np.asarray(2., dtype)
    return (one + lax.erf(x / sqrt2)) / two
510
-
511
def truncated_normal(
    self,
    lower,
    upper,
    size: Optional[Size] = None,
    loc=0.,
    scale=1.,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Normal samples restricted to [lower, upper] (inverse-CDF method).

    Unit-aware: ``upper``, ``loc`` and ``scale`` are converted to the unit
    of ``lower``; the result carries that unit when it is not dimensionless.
    """
    lower = _check_py_seq(lower)
    upper = _check_py_seq(upper)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    dtype = dtype or environ.dftype()

    lower = u.math.asarray(lower, dtype=dtype)
    upper = u.math.asarray(upper, dtype=dtype)
    loc = u.math.asarray(loc, dtype=dtype)
    scale = u.math.asarray(scale, dtype=dtype)
    unit = u.get_unit(lower)
    # Express everything in `lower`'s unit and work with bare mantissas.
    lower, upper, loc, scale = (
        lower.mantissa if isinstance(lower, u.Quantity) else lower,
        u.Quantity(upper).in_unit(unit).mantissa,
        u.Quantity(loc).in_unit(unit).mantissa,
        u.Quantity(scale).in_unit(unit).mantissa
    )

    jit_error_if(
        u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
        "mean is more than 2 std from [lower, upper] in truncated_normal. "
        "The distribution of values may be incorrect."
    )

    if size is None:
        size = u.math.broadcast_shapes(jnp.shape(lower),
                                       jnp.shape(upper),
                                       jnp.shape(loc),
                                       jnp.shape(scale))

    # Inverse-CDF sampling: draw uniformly on [2*cdf(lower)-1, 2*cdf(upper)-1]
    # and map through erf^-1 to obtain a truncated standard normal.
    sqrt2 = np.array(np.sqrt(2), dtype=dtype)
    cdf_low = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
    cdf_high = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)

    key = self.split_key() if key is None else _formalize_key(key)
    out = uniform_for_unit(
        key, size, dtype,
        minval=lax.nextafter(2 * cdf_low - 1, np.array(np.inf, dtype=dtype)),
        maxval=lax.nextafter(2 * cdf_high - 1, np.array(-np.inf, dtype=dtype))
    )

    # Truncated standard normal via the inverse error function.
    out = lax.erf_inv(out)

    # Rescale to the requested mean and std.
    out = out * scale * sqrt2 + loc

    # Clamp to the representable interval just inside [lower, upper].
    out = jnp.clip(
        out,
        lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
        lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
    )
    return out if unit.is_unitless else u.Quantity(out, unit=unit)
581
-
582
def _check_p(self, *args, **kwargs):
    """Raise for a probability outside [0, 1] (callback for jit_error_if).

    BUGFIX: the original message contained a literal, never-interpolated
    ``{p}`` placeholder; format the offending value in explicitly.
    """
    p = kwargs.get('p', args[0] if args else None)
    raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
584
-
585
def bernoulli(self,
              p,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None):
    """Bernoulli(p) samples; validates p in [0, 1] under jit."""
    p = _check_py_seq(p)
    jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
    if size is None:
        size = jnp.shape(p)
    key = self.split_key() if key is None else _formalize_key(key)
    return jr.bernoulli(key, p=p, shape=_size2shape(size))
596
-
597
def lognormal(
    self,
    mean=None,
    sigma=None,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Log-normal samples: exp(Normal(mean, sigma)); unit-aware via ``mean``."""
    mean = _check_py_seq(mean)
    sigma = _check_py_seq(sigma)
    # BUGFIX: coerce only when provided — the original unconditionally
    # called u.math.asarray(None, ...) for the default None parameters,
    # which raises, even though None is handled below.
    if mean is not None:
        mean = u.math.asarray(mean, dtype=dtype)
    if sigma is not None:
        sigma = u.math.asarray(sigma, dtype=dtype)
    unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
    mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
    sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma

    if size is None:
        size = jnp.broadcast_shapes(
            jnp.shape(mean) if mean is not None else (),
            jnp.shape(sigma) if sigma is not None else ()
        )
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
    samples = _loc_scale(mean, sigma, samples)
    samples = jnp.exp(samples)
    return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
624
-
625
def binomial(
    self,
    n,
    p,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None,
    check_valid: bool = True,
):
    """Binomial(n, p) samples; optionally validates p in [0, 1] under jit."""
    n = _check_py_seq(n)
    p = _check_py_seq(p)
    if check_valid:
        jit_error_if(
            jnp.any(jnp.logical_or(p < 0, p > 1)),
            'Parameter p should be within [0, 1], but we got {p}',
            p=p
        )
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
    key = self.split_key() if key is None else _formalize_key(key)
    counts = jr.binomial(key, n, p, shape=_size2shape(size))
    dtype = dtype or environ.ditype()
    return jnp.asarray(counts, dtype=dtype)
648
-
649
def chisquare(self,
              df,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
    """Chi-square samples built as a sum of ``df`` squared normals."""
    df = _check_py_seq(df)
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    if size is None:
        if jnp.ndim(df) != 0:
            raise NotImplementedError('Do not support non-scale "df" when "size" is None')
        draws = jr.normal(key, (df,), dtype=dtype) ** 2
        return draws.sum()
    draws = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
    return draws.sum(axis=0)
667
-
668
def dirichlet(self,
              alpha,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
    """Dirichlet samples with concentration vector ``alpha``."""
    key = self.split_key() if key is None else _formalize_key(key)
    alpha = _check_py_seq(alpha)
    dtype = dtype or environ.dftype()
    return jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)

def geometric(self,
              p,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
    """Geometric samples via inverse-CDF of a uniform draw."""
    p = _check_py_seq(p)
    if size is None:
        size = jnp.shape(p)
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    draw = uniform_for_unit(key, size, dtype=dtype)
    return jnp.floor(jnp.log1p(-draw) / jnp.log1p(-p))
692
-
693
def _check_p2(self, p):
    """Raise when multinomial pvals (minus the last entry) sum above 1."""
    raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
695
-
696
def multinomial(self,
                n,
                pvals,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
    """Multinomial counts for ``n`` trials with probabilities ``pvals``.

    ``n`` must be concrete (not traced): its maximum is needed on the host
    to size the sampling loop in ``_multinomial``.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    n = _check_py_seq(n)
    pvals = _check_py_seq(pvals)
    jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
    if isinstance(n, jax.core.Tracer):
        raise ValueError("The total count parameter `n` should not be a jax abstract array.")
    size = _size2shape(size)
    n_max = int(np.max(jax.device_get(n)))
    batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
    counts = _multinomial(key, pvals, n, n_max, batch_shape + size)
    dtype = dtype or environ.ditype()
    return jnp.asarray(counts, dtype=dtype)
714
-
715
def multivariate_normal(
    self,
    mean,
    cov,
    size: Optional[Size] = None,
    method: str = 'cholesky',
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Multivariate normal samples via a factorization of ``cov``.

    ``method`` selects the factorization: 'svd', 'eigh' or 'cholesky'.
    Unit-aware: when ``mean`` is a Quantity, ``cov`` must carry its squared
    unit, and the result carries ``mean``'s unit.
    """
    if method not in {'svd', 'eigh', 'cholesky'}:
        raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
    dtype = dtype or environ.dftype()
    mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
    cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
    if isinstance(mean, u.Quantity):
        assert isinstance(cov, u.Quantity)
        assert mean.unit ** 2 == cov.unit
    # BUGFIX: capture the unit BEFORE stripping `mean` to its mantissa.
    # The original computed it after, so `unit` was always dimensionless
    # and the returned samples silently dropped their physical unit.
    unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
    mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
    cov = cov.mantissa if isinstance(cov, u.Quantity) else cov

    key = self.split_key() if key is None else _formalize_key(key)
    if not jnp.ndim(mean) >= 1:
        raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
    if not jnp.ndim(cov) >= 2:
        raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
    n = mean.shape[-1]
    if jnp.shape(cov)[-2:] != (n, n):
        raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
                         f"but got cov.shape == {jnp.shape(cov)}.")
    if size is None:
        size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    else:
        size = _size2shape(size)
        _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])

    # Factor cov = F F^T (up to the chosen decomposition) so that
    # mean + F z is distributed as N(mean, cov) for z ~ N(0, I).
    if method == 'svd':
        (u_, s, _) = jnp.linalg.svd(cov)
        factor = u_ * jnp.sqrt(s[..., None, :])
    elif method == 'eigh':
        (w, v) = jnp.linalg.eigh(cov)
        factor = v * jnp.sqrt(w[..., None, :])
    else:  # 'cholesky'
        factor = jnp.linalg.cholesky(cov)
    normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
    r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
    return r if unit.is_unitless else u.Quantity(r, unit=unit)
762
-
763
- def rayleigh(self,
764
- scale=1.0,
765
- size: Optional[Size] = None,
766
- key: Optional[SeedOrKey] = None,
767
- dtype: DTypeLike = None):
768
- scale = _check_py_seq(scale)
769
- if size is None:
770
- size = jnp.shape(scale)
771
- key = self.split_key() if key is None else _formalize_key(key)
772
- dtype = dtype or environ.dftype()
773
- x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
774
- r = x * scale
775
- return r
776
-
777
- def triangular(self,
778
- size: Optional[Size] = None,
779
- key: Optional[SeedOrKey] = None):
780
- key = self.split_key() if key is None else _formalize_key(key)
781
- bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
782
- r = 2 * bernoulli_samples - 1
783
- return r
784
-
785
- def vonmises(self,
786
- mu,
787
- kappa,
788
- size: Optional[Size] = None,
789
- key: Optional[SeedOrKey] = None,
790
- dtype: DTypeLike = None):
791
- key = self.split_key() if key is None else _formalize_key(key)
792
- dtype = dtype or environ.dftype()
793
- mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
794
- kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
795
- if size is None:
796
- size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
797
- size = _size2shape(size)
798
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
799
- samples = samples + mu
800
- samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
801
- return samples
802
-
803
- def weibull(self,
804
- a,
805
- size: Optional[Size] = None,
806
- key: Optional[SeedOrKey] = None,
807
- dtype: DTypeLike = None):
808
- key = self.split_key() if key is None else _formalize_key(key)
809
- a = _check_py_seq(a)
810
- if size is None:
811
- size = jnp.shape(a)
812
- else:
813
- if jnp.size(a) > 1:
814
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
815
- size = _size2shape(size)
816
- dtype = dtype or environ.dftype()
817
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
818
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
819
- return r
820
-
821
- def weibull_min(self,
822
- a,
823
- scale=None,
824
- size: Optional[Size] = None,
825
- key: Optional[SeedOrKey] = None,
826
- dtype: DTypeLike = None):
827
- key = self.split_key() if key is None else _formalize_key(key)
828
- a = _check_py_seq(a)
829
- scale = _check_py_seq(scale)
830
- if size is None:
831
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
832
- else:
833
- if jnp.size(a) > 1:
834
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
835
- size = _size2shape(size)
836
- dtype = dtype or environ.dftype()
837
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
838
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
839
- if scale is not None:
840
- r /= scale
841
- return r
842
-
843
- def maxwell(self,
844
- size: Optional[Size] = None,
845
- key: Optional[SeedOrKey] = None,
846
- dtype: DTypeLike = None):
847
- key = self.split_key() if key is None else _formalize_key(key)
848
- shape = _size2shape(size) + (3,)
849
- dtype = dtype or environ.dftype()
850
- norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
851
- r = jnp.linalg.norm(norm_rvs, axis=-1)
852
- return r
853
-
854
- def negative_binomial(self,
855
- n,
856
- p,
857
- size: Optional[Size] = None,
858
- key: Optional[SeedOrKey] = None,
859
- dtype: DTypeLike = None):
860
- n = _check_py_seq(n)
861
- p = _check_py_seq(p)
862
- if size is None:
863
- size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
864
- size = _size2shape(size)
865
- logits = jnp.log(p) - jnp.log1p(-p)
866
- if key is None:
867
- keys = self.split_key(2)
868
- else:
869
- keys = jr.split(_formalize_key(key), 2)
870
- rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
871
- r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
872
- return r
873
-
874
- def wald(self,
875
- mean,
876
- scale,
877
- size: Optional[Size] = None,
878
- key: Optional[SeedOrKey] = None,
879
- dtype: DTypeLike = None):
880
- dtype = dtype or environ.dftype()
881
- key = self.split_key() if key is None else _formalize_key(key)
882
- mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
883
- scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
884
- if size is None:
885
- size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
886
- size = _size2shape(size)
887
- sampled_chi2 = jnp.square(self.randn(*size))
888
- sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
889
- # Wikipedia defines an intermediate x with the formula
890
- # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
891
- # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
892
- # Let us write
893
- # w = loc * y / (2 * conc)
894
- # Then we can extract the common factor in the last two terms to obtain
895
- # x = loc + loc * w * (1 - sqrt(2 / w + 1))
896
- # Now we see that the Wikipedia formula suffers from catastrphic
897
- # cancellation for large w (e.g., if conc << loc).
898
- #
899
- # Fortunately, we can fix this by multiplying both sides
900
- # by 1 + sqrt(2 / w + 1). We get
901
- # x * (1 + sqrt(2 / w + 1)) =
902
- # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
903
- # = loc * (sqrt(2 / w + 1) - 1)
904
- # The term sqrt(2 / w + 1) + 1 no longer presents numerical
905
- # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
906
- # sqrt1pm1(2 / w), which we know how to compute accurately.
907
- # This just leaves the matter of small w, where 2 / w may
908
- # overflow. In the limit a w -> 0, x -> loc, so we just mask
909
- # that case.
910
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
911
- safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
912
- denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
913
- ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
914
- sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
915
- res = jnp.where(sampled_uniform <= mean / (mean + sampled),
916
- sampled,
917
- jnp.square(mean) / sampled)
918
- return res
919
-
920
- def t(self,
921
- df,
922
- size: Optional[Size] = None,
923
- key: Optional[SeedOrKey] = None,
924
- dtype: DTypeLike = None):
925
- dtype = dtype or environ.dftype()
926
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
927
- if size is None:
928
- size = np.shape(df)
929
- else:
930
- size = _size2shape(size)
931
- _check_shape("t", size, np.shape(df))
932
- if key is None:
933
- keys = self.split_key(2)
934
- else:
935
- keys = jr.split(_formalize_key(key), 2)
936
- n = jr.normal(keys[0], size, dtype=dtype)
937
- two = _const(n, 2)
938
- half_df = lax.div(df, two)
939
- g = jr.gamma(keys[1], half_df, size, dtype=dtype)
940
- r = n * jnp.sqrt(half_df / g)
941
- return r
942
-
943
- def orthogonal(self,
944
- n: int,
945
- size: Optional[Size] = None,
946
- key: Optional[SeedOrKey] = None,
947
- dtype: DTypeLike = None):
948
- dtype = dtype or environ.dftype()
949
- key = self.split_key() if key is None else _formalize_key(key)
950
- size = _size2shape(size)
951
- _check_shape("orthogonal", size)
952
- n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
953
- z = jr.normal(key, size + (n, n), dtype=dtype)
954
- q, r = jnp.linalg.qr(z)
955
- d = jnp.diagonal(r, 0, -2, -1)
956
- r = q * jnp.expand_dims(d / abs(d), -2)
957
- return r
958
-
959
- def noncentral_chisquare(self,
960
- df,
961
- nonc,
962
- size: Optional[Size] = None,
963
- key: Optional[SeedOrKey] = None,
964
- dtype: DTypeLike = None):
965
- dtype = dtype or environ.dftype()
966
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
967
- nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
968
- if size is None:
969
- size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
970
- size = _size2shape(size)
971
- if key is None:
972
- keys = self.split_key(3)
973
- else:
974
- keys = jr.split(_formalize_key(key), 3)
975
- i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
976
- n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
977
- cond = jnp.greater(df, 1.0)
978
- df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
979
- chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
980
- r = jnp.where(cond, chi2 + n * n, chi2)
981
- return r
982
-
983
- def loggamma(self,
984
- a,
985
- size: Optional[Size] = None,
986
- key: Optional[SeedOrKey] = None,
987
- dtype: DTypeLike = None):
988
- dtype = dtype or environ.dftype()
989
- key = self.split_key() if key is None else _formalize_key(key)
990
- a = _check_py_seq(a)
991
- if size is None:
992
- size = jnp.shape(a)
993
- r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
994
- return r
995
-
996
- def categorical(self,
997
- logits,
998
- axis: int = -1,
999
- size: Optional[Size] = None,
1000
- key: Optional[SeedOrKey] = None):
1001
- key = self.split_key() if key is None else _formalize_key(key)
1002
- logits = _check_py_seq(logits)
1003
- if size is None:
1004
- size = list(jnp.shape(logits))
1005
- size.pop(axis)
1006
- r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
1007
- return r
1008
-
1009
- def zipf(self,
1010
- a,
1011
- size: Optional[Size] = None,
1012
- key: Optional[SeedOrKey] = None,
1013
- dtype: DTypeLike = None):
1014
- a = _check_py_seq(a)
1015
- if size is None:
1016
- size = jnp.shape(a)
1017
- dtype = dtype or environ.ditype()
1018
- r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
1019
- jax.ShapeDtypeStruct(size, dtype),
1020
- a)
1021
- return r
1022
-
1023
- def power(self,
1024
- a,
1025
- size: Optional[Size] = None,
1026
- key: Optional[SeedOrKey] = None,
1027
- dtype: DTypeLike = None):
1028
- a = _check_py_seq(a)
1029
- if size is None:
1030
- size = jnp.shape(a)
1031
- size = _size2shape(size)
1032
- dtype = dtype or environ.dftype()
1033
- r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
1034
- jax.ShapeDtypeStruct(size, dtype),
1035
- a)
1036
- return r
1037
-
1038
- def f(self,
1039
- dfnum,
1040
- dfden,
1041
- size: Optional[Size] = None,
1042
- key: Optional[SeedOrKey] = None,
1043
- dtype: DTypeLike = None):
1044
- dfnum = _check_py_seq(dfnum)
1045
- dfden = _check_py_seq(dfden)
1046
- if size is None:
1047
- size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
1048
- size = _size2shape(size)
1049
- d = {'dfnum': dfnum, 'dfden': dfden}
1050
- dtype = dtype or environ.dftype()
1051
- r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1052
- dfden=dfden_,
1053
- size=size).astype(dtype),
1054
- jax.ShapeDtypeStruct(size, dtype),
1055
- dfnum, dfden)
1056
- return r
1057
-
1058
- def hypergeometric(
1059
- self,
1060
- ngood,
1061
- nbad,
1062
- nsample,
1063
- size: Optional[Size] = None,
1064
- key: Optional[SeedOrKey] = None,
1065
- dtype: DTypeLike = None
1066
- ):
1067
- ngood = _check_py_seq(ngood)
1068
- nbad = _check_py_seq(nbad)
1069
- nsample = _check_py_seq(nsample)
1070
-
1071
- if size is None:
1072
- size = lax.broadcast_shapes(jnp.shape(ngood),
1073
- jnp.shape(nbad),
1074
- jnp.shape(nsample))
1075
- size = _size2shape(size)
1076
- dtype = dtype or environ.ditype()
1077
- d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1078
- r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
1079
- nbad=d['nbad'],
1080
- nsample=d['nsample'],
1081
- size=size).astype(dtype),
1082
- jax.ShapeDtypeStruct(size, dtype),
1083
- d)
1084
- return r
1085
-
1086
- def logseries(self,
1087
- p,
1088
- size: Optional[Size] = None,
1089
- key: Optional[SeedOrKey] = None,
1090
- dtype: DTypeLike = None):
1091
- p = _check_py_seq(p)
1092
- if size is None:
1093
- size = jnp.shape(p)
1094
- size = _size2shape(size)
1095
- dtype = dtype or environ.ditype()
1096
- r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1097
- jax.ShapeDtypeStruct(size, dtype),
1098
- p)
1099
- return r
1100
-
1101
- def noncentral_f(self,
1102
- dfnum,
1103
- dfden,
1104
- nonc,
1105
- size: Optional[Size] = None,
1106
- key: Optional[SeedOrKey] = None,
1107
- dtype: DTypeLike = None):
1108
- dfnum = _check_py_seq(dfnum)
1109
- dfden = _check_py_seq(dfden)
1110
- nonc = _check_py_seq(nonc)
1111
- if size is None:
1112
- size = lax.broadcast_shapes(jnp.shape(dfnum),
1113
- jnp.shape(dfden),
1114
- jnp.shape(nonc))
1115
- size = _size2shape(size)
1116
- d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1117
- dtype = dtype or environ.dftype()
1118
- r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1119
- dfden=x['dfden'],
1120
- nonc=x['nonc'],
1121
- size=size).astype(dtype),
1122
- jax.ShapeDtypeStruct(size, dtype),
1123
- d)
1124
- return r
1125
-
1126
- # PyTorch compatibility #
1127
- # --------------------- #
1128
-
1129
- def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1130
- """Returns a tensor with the same size as input that is filled with random
1131
- numbers from a uniform distribution on the interval ``[0, 1)``.
1132
-
1133
- Args:
1134
- input: the ``size`` of input will determine size of the output tensor.
1135
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1136
- key: the seed or key for the random.
1137
-
1138
- Returns:
1139
- The random data.
1140
- """
1141
- return self.random(jnp.shape(input), key=key).astype(dtype)
1142
-
1143
- def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1144
- """Returns a tensor with the same size as ``input`` that is filled with
1145
- random numbers from a normal distribution with mean 0 and variance 1.
1146
-
1147
- Args:
1148
- input: the ``size`` of input will determine size of the output tensor.
1149
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1150
- key: the seed or key for the random.
1151
-
1152
- Returns:
1153
- The random data.
1154
- """
1155
- return self.randn(*jnp.shape(input), key=key).astype(dtype)
1156
-
1157
- def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
1158
- if high is None:
1159
- high = max(input)
1160
- return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
1161
-
1162
-
1163
- # default random generator
1164
- DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1165
-
1166
-
1167
- # ---------------------------------------------------------------------------------------------------------------
1168
-
1169
-
1170
- def _formalize_key(key):
1171
- if isinstance(key, int):
1172
- return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1173
- elif isinstance(key, (jax.Array, np.ndarray)):
1174
- if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
1175
- return key
1176
- if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
1177
- return jr.PRNGKey(key) if use_prng_key else jr.key(key)
1178
-
1179
- if key.dtype != jnp.uint32:
1180
- raise TypeError('key must be a int or an array with two uint32.')
1181
- if key.size != 2:
1182
- raise TypeError('key must be a int or an array with two uint32.')
1183
- return jnp.asarray(key, dtype=jnp.uint32)
1184
- else:
1185
- raise TypeError('key must be a int or an array with two uint32.')
1186
-
1187
-
1188
- def _size2shape(size):
1189
- if size is None:
1190
- return ()
1191
- elif isinstance(size, (tuple, list)):
1192
- return tuple(size)
1193
- else:
1194
- return (size,)
1195
-
1196
-
1197
- def _check_shape(name, shape, *param_shapes):
1198
- if param_shapes:
1199
- shape_ = lax.broadcast_shapes(shape, *param_shapes)
1200
- if shape != shape_:
1201
- msg = ("{} parameter shapes must be broadcast-compatible with shape "
1202
- "argument, and the result of broadcasting the shapes must equal "
1203
- "the shape argument, but got result {} for shape argument {}.")
1204
- raise ValueError(msg.format(name, shape_, shape))
1205
-
1206
-
1207
- def _is_python_scalar(x):
1208
- if hasattr(x, 'aval'):
1209
- return x.aval.weak_type
1210
- elif np.ndim(x) == 0:
1211
- return True
1212
- elif isinstance(x, (bool, int, float, complex)):
1213
- return True
1214
- else:
1215
- return False
1216
-
1217
-
1218
- python_scalar_dtypes = {
1219
- bool: np.dtype('bool'),
1220
- int: np.dtype('int64'),
1221
- float: np.dtype('float64'),
1222
- complex: np.dtype('complex128'),
1223
- }
1224
-
1225
-
1226
- def _dtype(x, *, canonicalize: bool = False):
1227
- """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
1228
- if x is None:
1229
- raise ValueError(f"Invalid argument to dtype: {x}.")
1230
- elif isinstance(x, type) and x in python_scalar_dtypes:
1231
- dt = python_scalar_dtypes[x]
1232
- elif type(x) in python_scalar_dtypes:
1233
- dt = python_scalar_dtypes[type(x)]
1234
- elif hasattr(x, 'dtype'):
1235
- dt = x.dtype
1236
- else:
1237
- dt = np.result_type(x)
1238
- return dtypes.canonicalize_dtype(dt) if canonicalize else dt
1239
-
1240
-
1241
- def _const(example, val):
1242
- if _is_python_scalar(example):
1243
- dtype = dtypes.canonicalize_dtype(type(example))
1244
- val = dtypes.scalar_type_of(example)(val)
1245
- return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
1246
- else:
1247
- dtype = dtypes.canonicalize_dtype(example.dtype)
1248
- return np.array(val, dtype)
1249
-
1250
-
1251
- @partial(jit, static_argnums=(2,))
1252
- def _categorical(key, p, shape):
1253
- # this implementation is fast when event shape is small, and slow otherwise
1254
- # Ref: https://stackoverflow.com/a/34190035
1255
- shape = shape or p.shape[:-1]
1256
- s = jnp.cumsum(p, axis=-1)
1257
- r = jr.uniform(key, shape=shape + (1,))
1258
- return jnp.sum(s < r, axis=-1)
1259
-
1260
-
1261
- def _scatter_add_one(operand, indices, updates):
1262
- return lax.scatter_add(
1263
- operand,
1264
- indices,
1265
- updates,
1266
- lax.ScatterDimensionNumbers(
1267
- update_window_dims=(),
1268
- inserted_window_dims=(0,),
1269
- scatter_dims_to_operand_dims=(0,),
1270
- ),
1271
- )
1272
-
1273
-
1274
- def _reshape(x, shape):
1275
- if isinstance(x, (int, float, np.ndarray, np.generic)):
1276
- return np.reshape(x, shape)
1277
- else:
1278
- return jnp.reshape(x, shape)
1279
-
1280
-
1281
- def _promote_shapes(*args, shape=()):
1282
- # adapted from lax.lax_numpy
1283
- if len(args) < 2 and not shape:
1284
- return args
1285
- else:
1286
- shapes = [jnp.shape(arg) for arg in args]
1287
- num_dims = len(lax.broadcast_shapes(shape, *shapes))
1288
- return [
1289
- _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1290
- for arg, s in zip(args, shapes)
1291
- ]
1292
-
1293
-
1294
- @partial(jit, static_argnums=(3, 4))
1295
- def _multinomial(key, p, n, n_max, shape=()):
1296
- if jnp.shape(n) != jnp.shape(p)[:-1]:
1297
- broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
1298
- n = jnp.broadcast_to(n, broadcast_shape)
1299
- p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
1300
- shape = shape or p.shape[:-1]
1301
- if n_max == 0:
1302
- return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1303
- # get indices from categorical distribution then gather the result
1304
- indices = _categorical(key, p, (n_max,) + shape)
1305
- # mask out values when counts is heterogeneous
1306
- if jnp.ndim(n) > 0:
1307
- mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1308
- mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1309
- excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1310
- jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
1311
- -1)
1312
- else:
1313
- mask = 1
1314
- excess = 0
1315
- # NB: we transpose to move batch shape to the front
1316
- indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1317
- samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1318
- jnp.expand_dims(indices_2D, axis=-1),
1319
- jnp.ones(indices_2D.shape, dtype=indices.dtype))
1320
- return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1321
-
1322
-
1323
- @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
1324
- def _von_mises_centered(key, concentration, shape, dtype=None):
1325
- """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1326
-
1327
- Returns
1328
- -------
1329
- out: array_like
1330
- centered samples from von Mises
1331
-
1332
- References
1333
- ----------
1334
- .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1335
- Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1336
-
1337
- """
1338
- shape = shape or jnp.shape(concentration)
1339
- dtype = dtype or environ.dftype()
1340
- concentration = lax.convert_element_type(concentration, dtype)
1341
- concentration = jnp.broadcast_to(concentration, shape)
1342
-
1343
- if dtype == jnp.float16:
1344
- s_cutoff = 1.8e-1
1345
- elif dtype == jnp.float32:
1346
- s_cutoff = 2e-2
1347
- elif dtype == jnp.float64:
1348
- s_cutoff = 1.2e-4
1349
- else:
1350
- raise ValueError(f"Unsupported dtype: {dtype}")
1351
-
1352
- r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1353
- rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1354
- s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1355
-
1356
- s_approximate = 1.0 / concentration
1357
-
1358
- s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1359
-
1360
- def cond_fn(*args):
1361
- """check if all are done or reached max number of iterations"""
1362
- i, _, done, _, _ = args[0]
1363
- return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1364
-
1365
- def body_fn(*args):
1366
- i, key, done, _, w = args[0]
1367
- uni_ukey, uni_vkey, key = jr.split(key, 3)
1368
- u = jr.uniform(
1369
- key=uni_ukey,
1370
- shape=shape,
1371
- dtype=concentration.dtype,
1372
- minval=-1.0,
1373
- maxval=1.0,
1374
- )
1375
- z = jnp.cos(jnp.pi * u)
1376
- w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1377
- y = concentration * (s - w)
1378
- v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1379
- accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1380
- return i + 1, key, accept | done, u, w
1381
-
1382
- init_done = jnp.zeros(shape, dtype=bool)
1383
- init_u = jnp.zeros(shape)
1384
- init_w = jnp.zeros(shape)
1385
-
1386
- _, _, done, u, w = lax.while_loop(
1387
- cond_fun=cond_fn,
1388
- body_fun=body_fn,
1389
- init_val=(jnp.array(0), key, init_done, init_u, init_w),
1390
- )
1391
-
1392
- return jnp.sign(u) * jnp.arccos(w)
1393
-
1394
-
1395
- def _loc_scale(loc, scale, value):
1396
- if loc is None:
1397
- if scale is None:
1398
- return value
1399
- else:
1400
- return value * scale
1401
- else:
1402
- if scale is None:
1403
- return value + loc
1404
- else:
1405
- return value * scale + loc
1406
-
1407
-
1408
- def _check_py_seq(seq):
1409
- return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from functools import partial
19
+ from operator import index
20
+ from typing import Optional
21
+
22
+ import brainunit as u
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import jax.random as jr
26
+ import numpy as np
27
+ from jax import jit, vmap
28
+ from jax import lax, core, dtypes
29
+
30
+ from brainstate import environ
31
+ from brainstate._state import State
32
+ from brainstate.compile._error_if import jit_error_if
33
+ from brainstate.typing import DTypeLike, Size, SeedOrKey
34
+ from ._random_for_unit import uniform_for_unit, permutation_for_unit
35
+
36
+ __all__ = ['RandomState', 'DEFAULT', ]
37
+
38
+ use_prng_key = True
39
+
40
+
41
+ class RandomState(State):
42
+ """RandomState that track the random generator state. """
43
+
44
+ # __slots__ = ('_backup', '_value')
45
+
46
+ def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
47
+ """RandomState constructor.
48
+
49
+ Parameters
50
+ ----------
51
+ seed_or_key: int, Array, optional
52
+ It can be an integer for initial seed of the random number generator,
53
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
54
+ """
55
+ with jax.ensure_compile_time_eval():
56
+ if seed_or_key is None:
57
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
58
+ if isinstance(seed_or_key, int):
59
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
60
+ else:
61
+ if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
62
+ key = seed_or_key
63
+ else:
64
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
65
+ raise ValueError('key must be an array with dtype uint32. '
66
+ f'But we got {seed_or_key}')
67
+ key = seed_or_key
68
+ super().__init__(key)
69
+
70
+ self._backup = None
71
+
72
+ def __repr__(self):
73
+ return f'{self.__class__.__name__}({self.value})'
74
+
75
+ def check_if_deleted(self):
76
+ if not use_prng_key and isinstance(self._value, np.ndarray):
77
+ self._value = jr.key(np.random.randint(0, 10000))
78
+
79
+ if (
80
+ isinstance(self._value, jax.Array) and
81
+ not isinstance(self._value, jax.core.Tracer) and
82
+ self._value.is_deleted()
83
+ ):
84
+ self.seed()
85
+
86
+ # ------------------- #
87
+ # seed and random key #
88
+ # ------------------- #
89
+
90
+ def backup_key(self):
91
+ if self._backup is not None:
92
+ raise ValueError('The random key has been backed up, and has not been restored.')
93
+ self._backup = self.value
94
+
95
+ def restore_key(self):
96
+ if self._backup is None:
97
+ raise ValueError('The random key has not been backed up.')
98
+ self.value = self._backup
99
+ self._backup = None
100
+
101
+ def clone(self):
102
+ return type(self)(self.split_key())
103
+
104
+ def set_key(self, key: SeedOrKey):
105
+ self.value = key
106
+
107
+ def seed(self, seed_or_key: Optional[SeedOrKey] = None):
108
+ """Sets a new random seed.
109
+
110
+ Parameters
111
+ ----------
112
+ seed_or_key: int, ArrayLike, optional
113
+ It can be an integer for initial seed of the random number generator,
114
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
115
+ """
116
+ with jax.ensure_compile_time_eval():
117
+ if seed_or_key is None:
118
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
119
+ if np.size(seed_or_key) == 1:
120
+ if isinstance(seed_or_key, int):
121
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
122
+ elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
123
+ key = seed_or_key
124
+ elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
125
+ key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
126
+ else:
127
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
128
+ else:
129
+ if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
130
+ key = seed_or_key
131
+ else:
132
+ raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
133
+ self.value = key
134
+
135
+ def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
136
+ """
137
+ Create a new seed from the current seed.
138
+
139
+ Parameters
140
+ ----------
141
+ n: int, optional
142
+ The number of seeds to generate.
143
+ backup : bool, optional
144
+ Whether to backup the current key.
145
+
146
+ Returns
147
+ -------
148
+ key : SeedOrKey
149
+ The new seed or a tuple of JAX random keys.
150
+ """
151
+ if n is not None:
152
+ assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
153
+
154
+ if not isinstance(self.value, jax.Array):
155
+ self.value = jnp.asarray(self.value, dtype=jnp.uint32)
156
+ keys = jr.split(self.value, num=2 if n is None else n + 1)
157
+ self.value = keys[0]
158
+ if backup:
159
+ self.backup_key()
160
+ if n is None:
161
+ return keys[1]
162
+ else:
163
+ return keys[1:]
164
+
165
+ def self_assign_multi_keys(self, n: int, backup: bool = True):
166
+ """
167
+ Self-assign multiple keys to the current random state.
168
+ """
169
+ if backup:
170
+ keys = jr.split(self.value, n + 1)
171
+ self.value = keys[0]
172
+ self.backup_key()
173
+ self.value = keys[1:]
174
+ else:
175
+ self.value = jr.split(self.value, n)
176
+
177
+ # ---------------- #
178
+ # random functions #
179
+ # ---------------- #
180
+
181
+ def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
182
+ key = self.split_key() if key is None else _formalize_key(key)
183
+ dtype = dtype or environ.dftype()
184
+ r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
185
+ return r
186
+
187
+ def randint(
188
+ self,
189
+ low,
190
+ high=None,
191
+ size: Optional[Size] = None,
192
+ dtype: DTypeLike = None,
193
+ key: Optional[SeedOrKey] = None
194
+ ):
195
+ if high is None:
196
+ high = low
197
+ low = 0
198
+ high = _check_py_seq(high)
199
+ low = _check_py_seq(low)
200
+ if size is None:
201
+ size = lax.broadcast_shapes(jnp.shape(low),
202
+ jnp.shape(high))
203
+ key = self.split_key() if key is None else _formalize_key(key)
204
+ dtype = dtype or environ.ditype()
205
+ r = jr.randint(key,
206
+ shape=_size2shape(size),
207
+ minval=low, maxval=high, dtype=dtype)
208
+ return r
209
+
210
    def random_integers(
        self,
        low,
        high=None,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None,
    ):
        """Random integers between ``low`` and ``high``, both *inclusive*
        (legacy numpy ``random_integers`` API; single-bound form is ``[1, low]``).
        """
        low = _check_py_seq(low)
        high = _check_py_seq(high)
        if high is None:
            # single-bound form: random_integers(n) -> [1, n]
            high = low
            low = 1
        # jr.randint's upper bound is exclusive; shift to make ``high`` inclusive
        high += 1
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.ditype()
        r = jr.randint(key,
                       shape=_size2shape(size),
                       minval=low,
                       maxval=high,
                       dtype=dtype)
        return r
234
+
235
+ def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
236
+ key = self.split_key() if key is None else _formalize_key(key)
237
+ dtype = dtype or environ.dftype()
238
+ r = jr.normal(key, shape=dn, dtype=dtype)
239
+ return r
240
+
241
+ def random(self,
242
+ size: Optional[Size] = None,
243
+ key: Optional[SeedOrKey] = None,
244
+ dtype: DTypeLike = None):
245
+ dtype = dtype or environ.dftype()
246
+ key = self.split_key() if key is None else _formalize_key(key)
247
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
248
+ return r
249
+
250
+ def random_sample(self,
251
+ size: Optional[Size] = None,
252
+ key: Optional[SeedOrKey] = None,
253
+ dtype: DTypeLike = None):
254
+ r = self.random(size=size, key=key, dtype=dtype)
255
+ return r
256
+
257
+ def ranf(self,
258
+ size: Optional[Size] = None,
259
+ key: Optional[SeedOrKey] = None,
260
+ dtype: DTypeLike = None):
261
+ r = self.random(size=size, key=key, dtype=dtype)
262
+ return r
263
+
264
+ def sample(self,
265
+ size: Optional[Size] = None,
266
+ key: Optional[SeedOrKey] = None,
267
+ dtype: DTypeLike = None):
268
+ r = self.random(size=size, key=key, dtype=dtype)
269
+ return r
270
+
271
    def choice(self,
               a,
               size: Optional[Size] = None,
               replace=True,
               p=None,
               key: Optional[SeedOrKey] = None):
        """Generate a random sample from ``a`` (array, or int meaning ``arange(a)``),
        with or without replacement, optionally weighted by probabilities ``p``.
        """
        a = _check_py_seq(a)
        p = _check_py_seq(p)
        key = self.split_key() if key is None else _formalize_key(key)
        r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
        return r
282
+
283
    def permutation(self,
                    x,
                    axis: int = 0,
                    independent: bool = False,
                    key: Optional[SeedOrKey] = None):
        """Randomly permute ``x`` along ``axis``; ``independent=True`` permutes
        each slice independently (delegates to ``permutation_for_unit``).
        """
        x = _check_py_seq(x)
        key = self.split_key() if key is None else _formalize_key(key)
        r = permutation_for_unit(key, x, axis=axis, independent=independent)
        return r
292
+
293
    def shuffle(self,
                x,
                axis=0,
                key: Optional[SeedOrKey] = None):
        """Return a shuffled copy of ``x`` along ``axis``.

        NOTE(review): unlike ``np.random.shuffle`` this does not modify ``x``
        in place — the permuted array is returned to the caller.
        """
        key = self.split_key() if key is None else _formalize_key(key)
        x = permutation_for_unit(key, x, axis=axis)
        return x
300
+
301
    def beta(self,
             a,
             b,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
        """Draw samples from a Beta distribution with shape parameters ``a`` and ``b``."""
        a = _check_py_seq(a)
        b = _check_py_seq(b)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
        return r
315
+
316
+ def exponential(self,
317
+ scale=None,
318
+ size: Optional[Size] = None,
319
+ key: Optional[SeedOrKey] = None,
320
+ dtype: DTypeLike = None):
321
+ if size is None:
322
+ size = jnp.shape(scale)
323
+ key = self.split_key() if key is None else _formalize_key(key)
324
+ dtype = dtype or environ.dftype()
325
+ scale = jnp.asarray(scale, dtype=dtype)
326
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
327
+ if scale is not None:
328
+ r = r / scale
329
+ return r
330
+
331
    def gamma(self,
              shape,
              scale=None,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
        """Draw samples from a Gamma distribution with shape ``shape`` and
        optional scale ``scale`` (numpy convention: mean = shape * scale).
        """
        shape = _check_py_seq(shape)
        scale = _check_py_seq(scale)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
        if scale is not None:
            # jr.gamma samples have unit scale; apply ``scale`` by multiplication
            r = r * scale
        return r
347
+
348
    def gumbel(self,
               loc=None,
               scale=None,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None,
               dtype: DTypeLike = None):
        """Draw samples from a Gumbel distribution; ``loc``/``scale`` shift and
        scale the standard samples via ``_loc_scale``.
        """
        loc = _check_py_seq(loc)
        scale = _check_py_seq(scale)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
        return r
362
+
363
    def laplace(self,
                loc=None,
                scale=None,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
        """Draw samples from a Laplace distribution; ``loc``/``scale`` shift and
        scale the standard samples via ``_loc_scale``.
        """
        loc = _check_py_seq(loc)
        scale = _check_py_seq(scale)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
        return r
377
+
378
    def logistic(self,
                 loc=None,
                 scale=None,
                 size: Optional[Size] = None,
                 key: Optional[SeedOrKey] = None,
                 dtype: DTypeLike = None):
        """Draw samples from a logistic distribution; ``loc``/``scale`` shift and
        scale the standard samples via ``_loc_scale``.
        """
        loc = _check_py_seq(loc)
        scale = _check_py_seq(scale)
        if size is None:
            # None parameters contribute a scalar shape to the broadcast
            size = lax.broadcast_shapes(
                jnp.shape(loc) if loc is not None else (),
                jnp.shape(scale) if scale is not None else ()
            )
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
        return r
395
+
396
    def normal(self,
               loc=None,
               scale=None,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None,
               dtype: DTypeLike = None):
        """Draw samples from a normal distribution with mean ``loc`` and std
        ``scale`` (applied to standard normals via ``_loc_scale``).
        """
        loc = _check_py_seq(loc)
        scale = _check_py_seq(scale)
        if size is None:
            # None parameters contribute a scalar shape to the broadcast
            size = lax.broadcast_shapes(
                jnp.shape(scale) if scale is not None else (),
                jnp.shape(loc) if loc is not None else ()
            )
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
        return r
413
+
414
    def pareto(self,
               a,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None,
               dtype: DTypeLike = None):
        """Draw samples from a Pareto distribution with shape parameter ``a``
        (passed as ``b`` to ``jr.pareto``).
        """
        if size is None:
            size = jnp.shape(a)
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        a = jnp.asarray(a, dtype=dtype)
        r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
        return r
426
+
427
    def poisson(self,
                lam=1.0,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
        """Draw samples from a Poisson distribution with rate ``lam``."""
        lam = _check_py_seq(lam)
        if size is None:
            size = jnp.shape(lam)
        key = self.split_key() if key is None else _formalize_key(key)
        # integer result dtype by default
        dtype = dtype or environ.ditype()
        r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
        return r
439
+
440
+ def standard_cauchy(self,
441
+ size: Optional[Size] = None,
442
+ key: Optional[SeedOrKey] = None,
443
+ dtype: DTypeLike = None):
444
+ key = self.split_key() if key is None else _formalize_key(key)
445
+ dtype = dtype or environ.dftype()
446
+ r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
447
+ return r
448
+
449
+ def standard_exponential(self,
450
+ size: Optional[Size] = None,
451
+ key: Optional[SeedOrKey] = None,
452
+ dtype: DTypeLike = None):
453
+ key = self.split_key() if key is None else _formalize_key(key)
454
+ dtype = dtype or environ.dftype()
455
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
456
+ return r
457
+
458
    def standard_gamma(self,
                       shape,
                       size: Optional[Size] = None,
                       key: Optional[SeedOrKey] = None,
                       dtype: DTypeLike = None):
        """Draw samples from a standard Gamma distribution with shape ``shape``."""
        shape = _check_py_seq(shape)
        if size is None:
            size = jnp.shape(shape) if shape is not None else ()
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
        return r
470
+
471
+ def standard_normal(self,
472
+ size: Optional[Size] = None,
473
+ key: Optional[SeedOrKey] = None,
474
+ dtype: DTypeLike = None):
475
+ key = self.split_key() if key is None else _formalize_key(key)
476
+ dtype = dtype or environ.dftype()
477
+ r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
478
+ return r
479
+
480
+ def standard_t(self, df,
481
+ size: Optional[Size] = None,
482
+ key: Optional[SeedOrKey] = None,
483
+ dtype: DTypeLike = None):
484
+ df = _check_py_seq(df)
485
+ if size is None:
486
+ size = jnp.shape(size) if size is not None else ()
487
+ key = self.split_key() if key is None else _formalize_key(key)
488
+ dtype = dtype or environ.dftype()
489
+ r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
490
+ return r
491
+
492
    def uniform(self,
                low=0.0,
                high=1.0,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
        """Draw samples uniformly from ``[low, high)``."""
        low = _check_py_seq(low)
        high = _check_py_seq(high)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
        return r
506
+
507
    def __norm_cdf(self, x, sqrt2, dtype):
        # Computes standard normal cumulative distribution function:
        # Phi(x) = (1 + erf(x / sqrt(2))) / 2, with constants cast to ``dtype``.
        return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
510
+
511
    def truncated_normal(
        self,
        lower,
        upper,
        size: Optional[Size] = None,
        loc=0.,
        scale=1.,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        """Sample from a normal distribution truncated to ``[lower, upper]``.

        ``loc`` and ``scale`` are the mean and std of the *untruncated*
        normal. Unit-bearing inputs (``u.Quantity``) are converted to the unit
        of ``lower``; the result carries that unit unless it is unitless.
        """
        lower = _check_py_seq(lower)
        upper = _check_py_seq(upper)
        loc = _check_py_seq(loc)
        scale = _check_py_seq(scale)
        dtype = dtype or environ.dftype()

        lower = u.math.asarray(lower, dtype=dtype)
        upper = u.math.asarray(upper, dtype=dtype)
        loc = u.math.asarray(loc, dtype=dtype)
        scale = u.math.asarray(scale, dtype=dtype)
        unit = u.get_unit(lower)
        # strip units after converting everything to the unit of ``lower``
        lower, upper, loc, scale = (
            lower.mantissa if isinstance(lower, u.Quantity) else lower,
            u.Quantity(upper).in_unit(unit).mantissa,
            u.Quantity(loc).in_unit(unit).mantissa,
            u.Quantity(scale).in_unit(unit).mantissa
        )

        # runtime sanity check: warn when the mean is far outside the bounds
        jit_error_if(
            u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
            "mean is more than 2 std from [lower, upper] in truncated_normal. "
            "The distribution of values may be incorrect."
        )

        if size is None:
            size = u.math.broadcast_shapes(jnp.shape(lower),
                                           jnp.shape(upper),
                                           jnp.shape(loc),
                                           jnp.shape(scale))

        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        sqrt2 = np.array(np.sqrt(2), dtype=dtype)
        l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
        u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        key = self.split_key() if key is None else _formalize_key(key)
        out = uniform_for_unit(
            key, size, dtype,
            minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
            maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
        )

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        out = lax.erf_inv(out)

        # Transform to proper mean, std
        out = out * scale * sqrt2 + loc

        # Clamp to ensure it's in the proper range
        out = jnp.clip(
            out,
            lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
            lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
        )
        return out if unit.is_unitless else u.Quantity(out, unit=unit)
581
+
582
+ def _check_p(self, *args, **kwargs):
583
+ raise ValueError('Parameter p should be within [0, 1], but we got {p}')
584
+
585
    def bernoulli(self,
                  p,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None):
        """Draw boolean Bernoulli samples with success probability ``p``
        (validated at runtime to lie in [0, 1]).
        """
        p = _check_py_seq(p)
        jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
        if size is None:
            size = jnp.shape(p)
        key = self.split_key() if key is None else _formalize_key(key)
        r = jr.bernoulli(key, p=p, shape=_size2shape(size))
        return r
596
+
597
+ def lognormal(
598
+ self,
599
+ mean=None,
600
+ sigma=None,
601
+ size: Optional[Size] = None,
602
+ key: Optional[SeedOrKey] = None,
603
+ dtype: DTypeLike = None
604
+ ):
605
+ mean = _check_py_seq(mean)
606
+ sigma = _check_py_seq(sigma)
607
+ mean = u.math.asarray(mean, dtype=dtype)
608
+ sigma = u.math.asarray(sigma, dtype=dtype)
609
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
610
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
611
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
612
+
613
+ if size is None:
614
+ size = jnp.broadcast_shapes(
615
+ jnp.shape(mean) if mean is not None else (),
616
+ jnp.shape(sigma) if sigma is not None else ()
617
+ )
618
+ key = self.split_key() if key is None else _formalize_key(key)
619
+ dtype = dtype or environ.dftype()
620
+ samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
621
+ samples = _loc_scale(mean, sigma, samples)
622
+ samples = jnp.exp(samples)
623
+ return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
624
+
625
    def binomial(
        self,
        n,
        p,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None,
        check_valid: bool = True,
    ):
        """Draw samples from a binomial distribution with ``n`` trials and
        success probability ``p``; ``check_valid`` enables the runtime range
        check on ``p``.
        """
        n = _check_py_seq(n)
        p = _check_py_seq(p)
        if check_valid:
            jit_error_if(
                jnp.any(jnp.logical_or(p < 0, p > 1)),
                'Parameter p should be within [0, 1], but we got {p}',
                p=p
            )
        if size is None:
            size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
        key = self.split_key() if key is None else _formalize_key(key)
        r = jr.binomial(key, n, p, shape=_size2shape(size))
        # cast the float output of jr.binomial to the integer dtype
        dtype = dtype or environ.ditype()
        return jnp.asarray(r, dtype=dtype)
648
+
649
    def chisquare(self,
                  df,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None,
                  dtype: DTypeLike = None):
        """Sample chi-square variates as the sum of ``df`` squared standard
        normals.

        ``df`` must be a concrete scalar integer (it is used as an array
        dimension); non-scalar ``df`` is only rejected when ``size`` is None.
        """
        df = _check_py_seq(df)
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        if size is None:
            if jnp.ndim(df) == 0:
                dist = jr.normal(key, (df,), dtype=dtype) ** 2
                dist = dist.sum()
            else:
                raise NotImplementedError('Do not support non-scale "df" when "size" is None')
        else:
            # draw df x size normals and reduce over the df axis
            dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
            dist = dist.sum(axis=0)
        return dist
667
+
668
    def dirichlet(self,
                  alpha,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None,
                  dtype: DTypeLike = None):
        """Draw samples from a Dirichlet distribution with concentration ``alpha``."""
        key = self.split_key() if key is None else _formalize_key(key)
        alpha = _check_py_seq(alpha)
        dtype = dtype or environ.dftype()
        r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
        return r
678
+
679
    def geometric(self,
                  p,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None,
                  dtype: DTypeLike = None):
        """Sample geometric variates with success probability ``p`` via the
        inverse-CDF transform.

        NOTE(review): ``floor(log(1-U)/log(1-p))`` counts failures before the
        first success (support {0, 1, ...}); numpy's ``geometric`` counts
        trials (support {1, 2, ...}) — confirm which convention is intended.
        """
        p = _check_py_seq(p)
        if size is None:
            size = jnp.shape(p)
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        u_ = uniform_for_unit(key, size, dtype=dtype)
        r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
        return r
692
+
693
    def _check_p2(self, p):
        # Error callback for ``multinomial``: the probability vector must
        # satisfy sum(pvals[:-1]) <= 1.
        raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
695
+
696
    def multinomial(self,
                    n,
                    pvals,
                    size: Optional[Size] = None,
                    key: Optional[SeedOrKey] = None,
                    dtype: DTypeLike = None):
        """Draw from a multinomial distribution with ``n`` trials and
        probabilities ``pvals``.

        ``n`` must be concrete (not a traced value) because its maximum is
        needed as a Python int for the sampling loop bound.
        """
        key = self.split_key() if key is None else _formalize_key(key)
        n = _check_py_seq(n)
        pvals = _check_py_seq(pvals)
        # validate sum(pvals[:-1]) <= 1 at runtime
        jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
        if isinstance(n, jax.core.Tracer):
            raise ValueError("The total count parameter `n` should not be a jax abstract array.")
        size = _size2shape(size)
        n_max = int(np.max(jax.device_get(n)))
        batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
        r = _multinomial(key, pvals, n, n_max, batch_shape + size)
        dtype = dtype or environ.ditype()
        return jnp.asarray(r, dtype=dtype)
714
+
715
+ def multivariate_normal(
716
+ self,
717
+ mean,
718
+ cov,
719
+ size: Optional[Size] = None,
720
+ method: str = 'cholesky',
721
+ key: Optional[SeedOrKey] = None,
722
+ dtype: DTypeLike = None
723
+ ):
724
+ if method not in {'svd', 'eigh', 'cholesky'}:
725
+ raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
726
+ dtype = dtype or environ.dftype()
727
+ mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
728
+ cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
729
+ if isinstance(mean, u.Quantity):
730
+ assert isinstance(cov, u.Quantity)
731
+ assert mean.unit ** 2 == cov.unit
732
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
733
+ cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
734
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
735
+
736
+ key = self.split_key() if key is None else _formalize_key(key)
737
+ if not jnp.ndim(mean) >= 1:
738
+ raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
739
+ if not jnp.ndim(cov) >= 2:
740
+ raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
741
+ n = mean.shape[-1]
742
+ if jnp.shape(cov)[-2:] != (n, n):
743
+ raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
744
+ f"but got cov.shape == {jnp.shape(cov)}.")
745
+ if size is None:
746
+ size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
747
+ else:
748
+ size = _size2shape(size)
749
+ _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
750
+
751
+ if method == 'svd':
752
+ (u_, s, _) = jnp.linalg.svd(cov)
753
+ factor = u_ * jnp.sqrt(s[..., None, :])
754
+ elif method == 'eigh':
755
+ (w, v) = jnp.linalg.eigh(cov)
756
+ factor = v * jnp.sqrt(w[..., None, :])
757
+ else: # 'cholesky'
758
+ factor = jnp.linalg.cholesky(cov)
759
+ normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
760
+ r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
761
+ return r if unit.is_unitless else u.Quantity(r, unit=unit)
762
+
763
    def rayleigh(self,
                 scale=1.0,
                 size: Optional[Size] = None,
                 key: Optional[SeedOrKey] = None,
                 dtype: DTypeLike = None):
        """Draw samples from a Rayleigh distribution with the given ``scale``
        via the inverse-CDF transform ``scale * sqrt(-2 log U)``.
        """
        scale = _check_py_seq(scale)
        if size is None:
            size = jnp.shape(scale)
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
        r = x * scale
        return r
776
+
777
    def triangular(self,
                   size: Optional[Size] = None,
                   key: Optional[SeedOrKey] = None):
        """Draw random ±1 samples.

        NOTE(review): despite its name, this draws Rademacher samples
        (``2 * Bernoulli(0.5) - 1``), not a triangular distribution — confirm
        whether this naming is intentional.
        """
        key = self.split_key() if key is None else _formalize_key(key)
        bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
        r = 2 * bernoulli_samples - 1
        return r
784
+
785
    def vonmises(self,
                 mu,
                 kappa,
                 size: Optional[Size] = None,
                 key: Optional[SeedOrKey] = None,
                 dtype: DTypeLike = None):
        """Draw samples from a von Mises distribution with mode ``mu`` and
        concentration ``kappa``; the result is wrapped into ``[-pi, pi)``.
        """
        key = self.split_key() if key is None else _formalize_key(key)
        dtype = dtype or environ.dftype()
        mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
        kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
        size = _size2shape(size)
        samples = _von_mises_centered(key, kappa, size, dtype=dtype)
        samples = samples + mu
        # wrap the shifted angles back into [-pi, pi)
        samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
        return samples
802
+
803
    def weibull(self,
                a,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
        """Draw samples from a Weibull distribution with shape ``a`` via the
        inverse-CDF transform ``(-log(1 - U)) ** (1/a)`` (numpy convention).
        """
        key = self.split_key() if key is None else _formalize_key(key)
        a = _check_py_seq(a)
        if size is None:
            size = jnp.shape(a)
        else:
            # a broadcast ``a`` is only supported when ``size`` is derived from it
            if jnp.size(a) > 1:
                raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
        size = _size2shape(size)
        dtype = dtype or environ.dftype()
        random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
        r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
        return r
820
+
821
+ def weibull_min(self,
822
+ a,
823
+ scale=None,
824
+ size: Optional[Size] = None,
825
+ key: Optional[SeedOrKey] = None,
826
+ dtype: DTypeLike = None):
827
+ key = self.split_key() if key is None else _formalize_key(key)
828
+ a = _check_py_seq(a)
829
+ scale = _check_py_seq(scale)
830
+ if size is None:
831
+ size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
832
+ else:
833
+ if jnp.size(a) > 1:
834
+ raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
835
+ size = _size2shape(size)
836
+ dtype = dtype or environ.dftype()
837
+ random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
838
+ r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
839
+ if scale is not None:
840
+ r /= scale
841
+ return r
842
+
843
    def maxwell(self,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
        """Draw samples from a Maxwell distribution as the Euclidean norm of
        three independent standard normals.
        """
        key = self.split_key() if key is None else _formalize_key(key)
        shape = _size2shape(size) + (3,)
        dtype = dtype or environ.dftype()
        norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
        r = jnp.linalg.norm(norm_rvs, axis=-1)
        return r
853
+
854
    def negative_binomial(self,
                          n,
                          p,
                          size: Optional[Size] = None,
                          key: Optional[SeedOrKey] = None,
                          dtype: DTypeLike = None):
        """Draw samples from a negative binomial distribution via the
        Gamma–Poisson mixture (Poisson rate drawn from Gamma(n, (1-p)/p)).
        """
        n = _check_py_seq(n)
        p = _check_py_seq(p)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
        size = _size2shape(size)
        logits = jnp.log(p) - jnp.log1p(-p)
        # two independent keys: one for the Gamma rate, one for the Poisson draw
        if key is None:
            keys = self.split_key(2)
        else:
            keys = jr.split(_formalize_key(key), 2)
        rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
        r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
        return r
873
+
874
    def wald(self,
             mean,
             scale,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
        """Draw samples from a Wald (inverse Gaussian) distribution with the
        given ``mean`` and ``scale`` (concentration).

        NOTE(review): the chi-square draw below uses ``self.randn`` and thus
        consumes the internal PRNG state even when an explicit ``key`` is
        supplied (only the uniform draw uses ``key``) — confirm this is
        intended for reproducibility.
        """
        dtype = dtype or environ.dftype()
        key = self.split_key() if key is None else _formalize_key(key)
        mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
        scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
        size = _size2shape(size)
        sampled_chi2 = jnp.square(self.randn(*size))
        sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
        # Wikipedia defines an intermediate x with the formula
        #   x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
        # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
        # Let us write
        #   w = loc * y / (2 * conc)
        # Then we can extract the common factor in the last two terms to obtain
        #   x = loc + loc * w * (1 - sqrt(2 / w + 1))
        # Now we see that the Wikipedia formula suffers from catastrphic
        # cancellation for large w (e.g., if conc << loc).
        #
        # Fortunately, we can fix this by multiplying both sides
        # by 1 + sqrt(2 / w + 1).  We get
        #   x * (1 + sqrt(2 / w + 1)) =
        #     = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
        #     = loc * (sqrt(2 / w + 1) - 1)
        # The term sqrt(2 / w + 1) + 1 no longer presents numerical
        # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
        # sqrt1pm1(2 / w), which we know how to compute accurately.
        # This just leaves the matter of small w, where 2 / w may
        # overflow.  In the limit a w -> 0, x -> loc, so we just mask
        # that case.
        sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2)  # 2 / w above
        safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
        denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
        ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
        sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0)  # x above
        res = jnp.where(sampled_uniform <= mean / (mean + sampled),
                        sampled,
                        jnp.square(mean) / sampled)
        return res
919
+
920
    def t(self,
          df,
          size: Optional[Size] = None,
          key: Optional[SeedOrKey] = None,
          dtype: DTypeLike = None):
        """Draw samples from a Student's t distribution with ``df`` degrees of
        freedom, constructed as ``N / sqrt(G / (df/2))`` with G ~ Gamma(df/2).
        """
        dtype = dtype or environ.dftype()
        df = jnp.asarray(_check_py_seq(df), dtype=dtype)
        if size is None:
            size = np.shape(df)
        else:
            size = _size2shape(size)
            _check_shape("t", size, np.shape(df))
        # two independent keys: one for the normal, one for the gamma draw
        if key is None:
            keys = self.split_key(2)
        else:
            keys = jr.split(_formalize_key(key), 2)
        n = jr.normal(keys[0], size, dtype=dtype)
        two = _const(n, 2)
        half_df = lax.div(df, two)
        g = jr.gamma(keys[1], half_df, size, dtype=dtype)
        r = n * jnp.sqrt(half_df / g)
        return r
942
+
943
    def orthogonal(self,
                   n: int,
                   size: Optional[Size] = None,
                   key: Optional[SeedOrKey] = None,
                   dtype: DTypeLike = None):
        """Sample uniformly distributed orthogonal ``n x n`` matrices via QR
        decomposition of Gaussian matrices.
        """
        dtype = dtype or environ.dftype()
        key = self.split_key() if key is None else _formalize_key(key)
        size = _size2shape(size)
        _check_shape("orthogonal", size)
        # ``n`` must be a concrete integer (array dimension)
        n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
        z = jr.normal(key, size + (n, n), dtype=dtype)
        q, r = jnp.linalg.qr(z)
        d = jnp.diagonal(r, 0, -2, -1)
        # sign correction makes the QR-based distribution Haar-uniform
        r = q * jnp.expand_dims(d / abs(d), -2)
        return r
958
+
959
    def noncentral_chisquare(self,
                             df,
                             nonc,
                             size: Optional[Size] = None,
                             key: Optional[SeedOrKey] = None,
                             dtype: DTypeLike = None):
        """Sample from a noncentral chi-square distribution with ``df`` degrees
        of freedom and noncentrality ``nonc``.
        """
        dtype = dtype or environ.dftype()
        df = jnp.asarray(_check_py_seq(df), dtype=dtype)
        nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
        size = _size2shape(size)
        # three independent keys: Poisson mixture index, normal, gamma
        if key is None:
            keys = self.split_key(3)
        else:
            keys = jr.split(_formalize_key(key), 3)
        i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
        n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
        cond = jnp.greater(df, 1.0)
        # df > 1: chi2(df - 1) + (normal + sqrt(nonc))**2; otherwise use the
        # Poisson-mixture representation chi2(df + 2 * i)
        df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
        chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
        r = jnp.where(cond, chi2 + n * n, chi2)
        return r
982
+
983
    def loggamma(self,
                 a,
                 size: Optional[Size] = None,
                 key: Optional[SeedOrKey] = None,
                 dtype: DTypeLike = None):
        """Draw samples of ``log(Gamma(a))`` variates (``jr.loggamma``)."""
        dtype = dtype or environ.dftype()
        key = self.split_key() if key is None else _formalize_key(key)
        a = _check_py_seq(a)
        if size is None:
            size = jnp.shape(a)
        r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
        return r
995
+
996
    def categorical(self,
                    logits,
                    axis: int = -1,
                    size: Optional[Size] = None,
                    key: Optional[SeedOrKey] = None):
        """Sample category indices from unnormalized ``logits`` along ``axis``."""
        key = self.split_key() if key is None else _formalize_key(key)
        logits = _check_py_seq(logits)
        if size is None:
            # default shape: logits shape with the category axis removed
            size = list(jnp.shape(logits))
            size.pop(axis)
        r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
        return r
1008
+
1009
    def zipf(self,
             a,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
        """Sample Zipf variates via ``np.random.zipf`` in a host callback.

        NOTE(review): ``key`` is accepted but never used — the callback draws
        from numpy's global RNG, so results are not reproducible from the key.
        """
        a = _check_py_seq(a)
        if size is None:
            size = jnp.shape(a)
        dtype = dtype or environ.ditype()
        r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
                              jax.ShapeDtypeStruct(size, dtype),
                              a)
        return r
1022
+
1023
    def power(self,
              a,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
        """Sample power-function variates via ``np.random.power`` in a host callback.

        NOTE(review): ``key`` is accepted but never used — the callback draws
        from numpy's global RNG, so results are not reproducible from the key.
        """
        a = _check_py_seq(a)
        if size is None:
            size = jnp.shape(a)
        size = _size2shape(size)
        dtype = dtype or environ.dftype()
        r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
                              jax.ShapeDtypeStruct(size, dtype),
                              a)
        return r
1037
+
1038
+ def f(self,
1039
+ dfnum,
1040
+ dfden,
1041
+ size: Optional[Size] = None,
1042
+ key: Optional[SeedOrKey] = None,
1043
+ dtype: DTypeLike = None):
1044
+ dfnum = _check_py_seq(dfnum)
1045
+ dfden = _check_py_seq(dfden)
1046
+ if size is None:
1047
+ size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
1048
+ size = _size2shape(size)
1049
+ d = {'dfnum': dfnum, 'dfden': dfden}
1050
+ dtype = dtype or environ.dftype()
1051
+ r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1052
+ dfden=dfden_,
1053
+ size=size).astype(dtype),
1054
+ jax.ShapeDtypeStruct(size, dtype),
1055
+ dfnum, dfden)
1056
+ return r
1057
+
1058
    def hypergeometric(
        self,
        ngood,
        nbad,
        nsample,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        """Sample hypergeometric variates via ``np.random.hypergeometric`` in a
        host callback.

        NOTE(review): ``key`` is accepted but never used — the callback draws
        from numpy's global RNG, so results are not reproducible from the key.
        """
        ngood = _check_py_seq(ngood)
        nbad = _check_py_seq(nbad)
        nsample = _check_py_seq(nsample)

        if size is None:
            size = lax.broadcast_shapes(jnp.shape(ngood),
                                        jnp.shape(nbad),
                                        jnp.shape(nsample))
        size = _size2shape(size)
        dtype = dtype or environ.ditype()
        d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
        r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
                                                                 nbad=d['nbad'],
                                                                 nsample=d['nsample'],
                                                                 size=size).astype(dtype),
                              jax.ShapeDtypeStruct(size, dtype),
                              d)
        return r
1085
+
1086
    def logseries(self,
                  p,
                  size: Optional[Size] = None,
                  key: Optional[SeedOrKey] = None,
                  dtype: DTypeLike = None):
        """Sample log-series variates via ``np.random.logseries`` in a host callback.

        NOTE(review): ``key`` is accepted but never used — the callback draws
        from numpy's global RNG, so results are not reproducible from the key.
        """
        p = _check_py_seq(p)
        if size is None:
            size = jnp.shape(p)
        size = _size2shape(size)
        dtype = dtype or environ.ditype()
        r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
                              jax.ShapeDtypeStruct(size, dtype),
                              p)
        return r
1100
+
1101
    def noncentral_f(self,
                     dfnum,
                     dfden,
                     nonc,
                     size: Optional[Size] = None,
                     key: Optional[SeedOrKey] = None,
                     dtype: DTypeLike = None):
        """Sample noncentral-F variates via ``np.random.noncentral_f`` in a
        host callback.

        NOTE(review): ``key`` is accepted but never used — the callback draws
        from numpy's global RNG, so results are not reproducible from the key.
        """
        dfnum = _check_py_seq(dfnum)
        dfden = _check_py_seq(dfden)
        nonc = _check_py_seq(nonc)
        if size is None:
            size = lax.broadcast_shapes(jnp.shape(dfnum),
                                        jnp.shape(dfden),
                                        jnp.shape(nonc))
        size = _size2shape(size)
        d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
        dtype = dtype or environ.dftype()
        r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
                                                               dfden=x['dfden'],
                                                               nonc=x['nonc'],
                                                               size=size).astype(dtype),
                              jax.ShapeDtypeStruct(size, dtype),
                              d)
        return r
1125
+
1126
+ # PyTorch compatibility #
1127
+ # --------------------- #
1128
+
1129
    def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
        """Returns a tensor with the same size as input that is filled with random
        numbers from a uniform distribution on the interval ``[0, 1)``.

        Args:
          input: the ``size`` of input will determine size of the output tensor.
          dtype: the desired data type of returned Tensor.
            NOTE(review): despite the PyTorch-style contract, ``None`` does NOT
            fall back to ``input``'s dtype -- ``.astype(None)`` casts to the
            default float dtype instead. Confirm which behavior is intended.
          key: the seed or key for the random.

        Returns:
          The random data.
        """
        # Only the *shape* of ``input`` is used; its values are ignored.
        return self.random(jnp.shape(input), key=key).astype(dtype)
1142
+
1143
    def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
        """Returns a tensor with the same size as ``input`` that is filled with
        random numbers from a normal distribution with mean 0 and variance 1.

        Args:
          input: the ``size`` of input will determine size of the output tensor.
          dtype: the desired data type of returned Tensor.
            NOTE(review): despite the PyTorch-style contract, ``None`` does NOT
            fall back to ``input``'s dtype -- ``.astype(None)`` casts to the
            default float dtype instead. Confirm which behavior is intended.
          key: the seed or key for the random.

        Returns:
          The random data.
        """
        # ``randn`` takes the shape as positional dims, hence the unpacking.
        return self.randn(*jnp.shape(input), key=key).astype(dtype)
1156
+
1157
+ def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
1158
+ if high is None:
1159
+ high = max(input)
1160
+ return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
1161
+
1162
+
1163
# Default module-level random generator, seeded with two uint32 words.
# NOTE(review): the seed words are drawn from [0, 10000) only -- a narrow
# range compared to full 32-bit seeding; confirm this is intentional.
DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1165
+
1166
+
1167
+ # ---------------------------------------------------------------------------------------------------------------
1168
+
1169
+
1170
def _formalize_key(key):
    """Normalize ``key`` into a JAX PRNG key.

    Accepts an int seed, an already-typed PRNG key, a size-1 integer array
    (treated as a seed), or a raw size-2 uint32 key array.

    Raises:
      TypeError: for any other input.
    """
    if isinstance(key, int):
        return jr.PRNGKey(key) if use_prng_key else jr.key(key)
    if isinstance(key, (jax.Array, np.ndarray)):
        # Already a typed PRNG key: pass it through untouched.
        if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key
        # A single integer value acts like an int seed.
        if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
            return jr.PRNGKey(key) if use_prng_key else jr.key(key)
        # Otherwise it must be a raw (2,) uint32 key.
        if key.dtype != jnp.uint32 or key.size != 2:
            raise TypeError('key must be a int or an array with two uint32.')
        return jnp.asarray(key, dtype=jnp.uint32)
    raise TypeError('key must be a int or an array with two uint32.')
1186
+
1187
+
1188
+ def _size2shape(size):
1189
+ if size is None:
1190
+ return ()
1191
+ elif isinstance(size, (tuple, list)):
1192
+ return tuple(size)
1193
+ else:
1194
+ return (size,)
1195
+
1196
+
1197
+ def _check_shape(name, shape, *param_shapes):
1198
+ if param_shapes:
1199
+ shape_ = lax.broadcast_shapes(shape, *param_shapes)
1200
+ if shape != shape_:
1201
+ msg = ("{} parameter shapes must be broadcast-compatible with shape "
1202
+ "argument, and the result of broadcasting the shapes must equal "
1203
+ "the shape argument, but got result {} for shape argument {}.")
1204
+ raise ValueError(msg.format(name, shape_, shape))
1205
+
1206
+
1207
+ def _is_python_scalar(x):
1208
+ if hasattr(x, 'aval'):
1209
+ return x.aval.weak_type
1210
+ elif np.ndim(x) == 0:
1211
+ return True
1212
+ elif isinstance(x, (bool, int, float, complex)):
1213
+ return True
1214
+ else:
1215
+ return False
1216
+
1217
+
1218
# Default NumPy dtypes for bare Python scalar types (before any X64-mode
# canonicalization, which happens later via ``dtypes.canonicalize_dtype``).
python_scalar_dtypes = {
    bool: np.dtype('bool'),
    int: np.dtype('int64'),
    float: np.dtype('float64'),
    complex: np.dtype('complex128'),
}
1224
+
1225
+
1226
def _dtype(x, *, canonicalize: bool = False):
    """Return the dtype object for a value or type, optionally canonicalized based on X64 mode.

    Raises:
      ValueError: when ``x`` is ``None``.
    """
    if x is None:
        raise ValueError(f"Invalid argument to dtype: {x}.")
    # Order matters: a Python scalar *type* first, then a scalar *value*,
    # then anything carrying its own ``dtype``, then NumPy inference.
    if isinstance(x, type) and x in python_scalar_dtypes:
        dt = python_scalar_dtypes[x]
    elif type(x) in python_scalar_dtypes:
        dt = python_scalar_dtypes[type(x)]
    elif hasattr(x, 'dtype'):
        dt = x.dtype
    else:
        dt = np.result_type(x)
    if canonicalize:
        return dtypes.canonicalize_dtype(dt)
    return dt
1239
+
1240
+
1241
def _const(example, val):
    """Build a constant ``val`` whose type/dtype matches ``example``.

    Python scalars are kept as scalars when the canonical dtype agrees;
    array-likes always produce a NumPy array with the canonicalized dtype.
    """
    if not _is_python_scalar(example):
        return np.array(val, dtypes.canonicalize_dtype(example.dtype))
    dtype = dtypes.canonicalize_dtype(type(example))
    val = dtypes.scalar_type_of(example)(val)
    # Preserve the scalar only if its canonical dtype matches; otherwise
    # fall back to an explicit array of the target dtype.
    if dtype == _dtype(val, canonicalize=True):
        return val
    return np.array(val, dtype)
1249
+
1250
+
1251
+ @partial(jit, static_argnums=(2,))
1252
+ def _categorical(key, p, shape):
1253
+ # this implementation is fast when event shape is small, and slow otherwise
1254
+ # Ref: https://stackoverflow.com/a/34190035
1255
+ shape = shape or p.shape[:-1]
1256
+ s = jnp.cumsum(p, axis=-1)
1257
+ r = jr.uniform(key, shape=shape + (1,))
1258
+ return jnp.sum(s < r, axis=-1)
1259
+
1260
+
1261
+ def _scatter_add_one(operand, indices, updates):
1262
+ return lax.scatter_add(
1263
+ operand,
1264
+ indices,
1265
+ updates,
1266
+ lax.ScatterDimensionNumbers(
1267
+ update_window_dims=(),
1268
+ inserted_window_dims=(0,),
1269
+ scatter_dims_to_operand_dims=(0,),
1270
+ ),
1271
+ )
1272
+
1273
+
1274
+ def _reshape(x, shape):
1275
+ if isinstance(x, (int, float, np.ndarray, np.generic)):
1276
+ return np.reshape(x, shape)
1277
+ else:
1278
+ return jnp.reshape(x, shape)
1279
+
1280
+
1281
def _promote_shapes(*args, shape=()):
    """Left-pad argument shapes with 1s so they broadcast to a common rank.

    Adapted from ``jax.numpy`` internals (lax_numpy).
    """
    # Fast path: a single argument with no target shape needs no promotion.
    if len(args) < 2 and not shape:
        return args
    arg_shapes = [jnp.shape(a) for a in args]
    rank = len(lax.broadcast_shapes(shape, *arg_shapes))
    promoted = []
    for arg, s in zip(args, arg_shapes):
        if len(s) < rank:
            arg = _reshape(arg, (1,) * (rank - len(s)) + s)
        promoted.append(arg)
    return promoted
1292
+
1293
+
1294
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by drawing ``n_max`` categorical indices and
    scatter-adding them into per-category bins.

    Args:
      key: JAX PRNG key.
      p: probabilities; the last axis is the event (category) axis.
      n: number of trials; may be an array (heterogeneous counts).
      n_max: static upper bound on ``n`` (number of categorical draws taken).
      shape: optional batch shape; defaults to ``p.shape[:-1]``.

    Returns:
      Integer counts with shape ``shape + p.shape[-1:]``.
    """
    # Align the batch shapes of ``n`` and ``p`` before sampling.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        # No trials anywhere: all counts are zero.
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = _categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # Zero out draws beyond each element's own trial count ``n``; the
        # masked draws all land in category 0 and are subtracted via ``excess``.
        mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
                                  jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
                                 -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
                                        jnp.expand_dims(indices_2D, axis=-1),
                                        jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1321
+
1322
+
1323
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
def _von_mises_centered(key, concentration, shape, dtype=None):
    """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

    Returns
    -------
    out: array_like
        centered samples from von Mises

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
           Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    """
    shape = shape or jnp.shape(concentration)
    dtype = dtype or environ.dftype()
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)

    # Precision-dependent threshold below which the exact expression for ``s``
    # is numerically unstable and the 1/kappa approximation is used instead.
    if dtype == jnp.float16:
        s_cutoff = 1.8e-1
    elif dtype == jnp.float32:
        s_cutoff = 2e-2
    elif dtype == jnp.float64:
        s_cutoff = 1.2e-4
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # Devroye's parameters for the wrapped-Cauchy proposal.
    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = args[0]
        # Cap at 100 rejection rounds as a safety bound.
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        # Two-level acceptance test (quick check, then the exact log test).
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)

    # Iterate until every lane has accepted a sample (or 100 rounds elapse).
    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    # The sign of ``u`` restores symmetry about zero; ``arccos(w)`` is the angle.
    return jnp.sign(u) * jnp.arccos(w)
1393
+
1394
+
1395
+ def _loc_scale(loc, scale, value):
1396
+ if loc is None:
1397
+ if scale is None:
1398
+ return value
1399
+ else:
1400
+ return value * scale
1401
+ else:
1402
+ if scale is None:
1403
+ return value + loc
1404
+ else:
1405
+ return value * scale + loc
1406
+
1407
+
1408
+ def _check_py_seq(seq):
1409
+ return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq