brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from functools import partial
|
19
18
|
from operator import index
|
20
19
|
from typing import Optional
|
21
20
|
|
@@ -24,12 +23,16 @@ import jax
|
|
24
23
|
import jax.numpy as jnp
|
25
24
|
import jax.random as jr
|
26
25
|
import numpy as np
|
27
|
-
from jax import
|
28
|
-
from jax import lax, core, dtypes
|
26
|
+
from jax import lax, core
|
29
27
|
|
30
28
|
from brainstate import environ
|
31
29
|
from brainstate._state import State
|
32
30
|
from brainstate.typing import DTypeLike, Size, SeedOrKey
|
31
|
+
from ._impl import (
|
32
|
+
multinomial, von_mises_centered, const,
|
33
|
+
formalize_key, _loc_scale, _size2shape, _check_py_seq, _check_shape,
|
34
|
+
noncentral_f, logseries, hypergeometric, f, power, zipf
|
35
|
+
)
|
33
36
|
|
34
37
|
__all__ = [
|
35
38
|
'RandomState',
|
@@ -73,14 +76,10 @@ class RandomState(State):
|
|
73
76
|
|
74
77
|
self._backup = None
|
75
78
|
|
76
|
-
def __repr__(
|
77
|
-
self
|
78
|
-
):
|
79
|
+
def __repr__(self):
|
79
80
|
return f'{self.__class__.__name__}({self.value})'
|
80
81
|
|
81
|
-
def check_if_deleted(
|
82
|
-
self
|
83
|
-
):
|
82
|
+
def check_if_deleted(self):
|
84
83
|
if not use_prng_key and isinstance(self._value, np.ndarray):
|
85
84
|
self._value = jr.key(np.random.randint(0, 10000))
|
86
85
|
|
@@ -91,6 +90,11 @@ class RandomState(State):
|
|
91
90
|
):
|
92
91
|
self.seed()
|
93
92
|
|
93
|
+
@staticmethod
|
94
|
+
def _batch_keys(batch_size: int):
|
95
|
+
key = jr.PRNGKey(0) if use_prng_key else jr.key(0)
|
96
|
+
return jr.split(key, batch_size)
|
97
|
+
|
94
98
|
# ------------------- #
|
95
99
|
# seed and random key #
|
96
100
|
# ------------------- #
|
@@ -126,7 +130,7 @@ class RandomState(State):
|
|
126
130
|
"""
|
127
131
|
with jax.ensure_compile_time_eval():
|
128
132
|
if seed_or_key is None:
|
129
|
-
seed_or_key = np.random.randint(0,
|
133
|
+
seed_or_key = np.random.randint(0, 10000000, 2, dtype=np.uint32)
|
130
134
|
if np.size(seed_or_key) == 1:
|
131
135
|
if isinstance(seed_or_key, int):
|
132
136
|
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
@@ -156,7 +160,7 @@ class RandomState(State):
|
|
156
160
|
n: int, optional
|
157
161
|
The number of seeds to generate.
|
158
162
|
backup : bool, optional
|
159
|
-
Whether to
|
163
|
+
Whether to back up the current key.
|
160
164
|
|
161
165
|
Returns
|
162
166
|
-------
|
@@ -193,6 +197,9 @@ class RandomState(State):
|
|
193
197
|
else:
|
194
198
|
self.value = jr.split(self.value, n)
|
195
199
|
|
200
|
+
def __get_key(self, key):
|
201
|
+
return self.split_key() if key is None else formalize_key(key, use_prng_key)
|
202
|
+
|
196
203
|
# ---------------- #
|
197
204
|
# random functions #
|
198
205
|
# ---------------- #
|
@@ -203,7 +210,7 @@ class RandomState(State):
|
|
203
210
|
key: Optional[SeedOrKey] = None,
|
204
211
|
dtype: DTypeLike = None
|
205
212
|
):
|
206
|
-
key = self.
|
213
|
+
key = self.__get_key(key)
|
207
214
|
dtype = dtype or environ.dftype()
|
208
215
|
r = jr.uniform(key, dn, dtype)
|
209
216
|
return r
|
@@ -222,9 +229,8 @@ class RandomState(State):
|
|
222
229
|
high = _check_py_seq(high)
|
223
230
|
low = _check_py_seq(low)
|
224
231
|
if size is None:
|
225
|
-
size = lax.broadcast_shapes(u.math.shape(low),
|
226
|
-
|
227
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
232
|
+
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
233
|
+
key = self.__get_key(key)
|
228
234
|
dtype = dtype or environ.ditype()
|
229
235
|
r = jr.randint(key,
|
230
236
|
shape=_size2shape(size),
|
@@ -247,7 +253,7 @@ class RandomState(State):
|
|
247
253
|
high += 1
|
248
254
|
if size is None:
|
249
255
|
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
250
|
-
key = self.
|
256
|
+
key = self.__get_key(key)
|
251
257
|
dtype = dtype or environ.ditype()
|
252
258
|
r = jr.randint(key,
|
253
259
|
shape=_size2shape(size),
|
@@ -262,9 +268,8 @@ class RandomState(State):
|
|
262
268
|
key: Optional[SeedOrKey] = None,
|
263
269
|
dtype: DTypeLike = None
|
264
270
|
):
|
265
|
-
key = self.
|
266
|
-
|
267
|
-
r = jr.normal(key, shape=dn, dtype=dtype)
|
271
|
+
key = self.__get_key(key)
|
272
|
+
r = jr.normal(key, shape=dn, dtype=dtype or environ.dftype())
|
268
273
|
return r
|
269
274
|
|
270
275
|
def random(
|
@@ -273,9 +278,8 @@ class RandomState(State):
|
|
273
278
|
key: Optional[SeedOrKey] = None,
|
274
279
|
dtype: DTypeLike = None
|
275
280
|
):
|
276
|
-
|
277
|
-
|
278
|
-
r = jr.uniform(key, _size2shape(size), dtype)
|
281
|
+
key = self.__get_key(key)
|
282
|
+
r = jr.uniform(key, _size2shape(size), dtype=dtype or environ.dftype())
|
279
283
|
return r
|
280
284
|
|
281
285
|
def random_sample(
|
@@ -284,7 +288,7 @@ class RandomState(State):
|
|
284
288
|
key: Optional[SeedOrKey] = None,
|
285
289
|
dtype: DTypeLike = None
|
286
290
|
):
|
287
|
-
r = self.random(size=size, key=key, dtype=dtype)
|
291
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
288
292
|
return r
|
289
293
|
|
290
294
|
def ranf(
|
@@ -293,7 +297,7 @@ class RandomState(State):
|
|
293
297
|
key: Optional[SeedOrKey] = None,
|
294
298
|
dtype: DTypeLike = None
|
295
299
|
):
|
296
|
-
r = self.random(size=size, key=key, dtype=dtype)
|
300
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
297
301
|
return r
|
298
302
|
|
299
303
|
def sample(
|
@@ -302,7 +306,7 @@ class RandomState(State):
|
|
302
306
|
key: Optional[SeedOrKey] = None,
|
303
307
|
dtype: DTypeLike = None
|
304
308
|
):
|
305
|
-
r = self.random(size=size, key=key, dtype=dtype)
|
309
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
306
310
|
return r
|
307
311
|
|
308
312
|
def choice(
|
@@ -316,7 +320,7 @@ class RandomState(State):
|
|
316
320
|
a = _check_py_seq(a)
|
317
321
|
a, unit = u.split_mantissa_unit(a)
|
318
322
|
p = _check_py_seq(p)
|
319
|
-
key = self.
|
323
|
+
key = self.__get_key(key)
|
320
324
|
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
|
321
325
|
return u.maybe_decimal(r * unit)
|
322
326
|
|
@@ -329,7 +333,7 @@ class RandomState(State):
|
|
329
333
|
):
|
330
334
|
x = _check_py_seq(x)
|
331
335
|
x, unit = u.split_mantissa_unit(x)
|
332
|
-
key = self.
|
336
|
+
key = self.__get_key(key)
|
333
337
|
r = jr.permutation(key, x, axis, independent=independent)
|
334
338
|
return u.maybe_decimal(r * unit)
|
335
339
|
|
@@ -353,9 +357,8 @@ class RandomState(State):
|
|
353
357
|
b = _check_py_seq(b)
|
354
358
|
if size is None:
|
355
359
|
size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
|
356
|
-
key = self.
|
357
|
-
|
358
|
-
r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
|
360
|
+
key = self.__get_key(key)
|
361
|
+
r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
359
362
|
return r
|
360
363
|
|
361
364
|
def exponential(
|
@@ -367,9 +370,8 @@ class RandomState(State):
|
|
367
370
|
):
|
368
371
|
if size is None:
|
369
372
|
size = u.math.shape(scale)
|
370
|
-
key = self.
|
371
|
-
|
372
|
-
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
373
|
+
key = self.__get_key(key)
|
374
|
+
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
373
375
|
if scale is not None:
|
374
376
|
scale = u.math.asarray(scale, dtype=dtype)
|
375
377
|
r = r / scale
|
@@ -387,9 +389,8 @@ class RandomState(State):
|
|
387
389
|
scale = _check_py_seq(scale)
|
388
390
|
if size is None:
|
389
391
|
size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
|
390
|
-
key = self.
|
391
|
-
|
392
|
-
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
392
|
+
key = self.__get_key(key)
|
393
|
+
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
393
394
|
if scale is not None:
|
394
395
|
r = r * scale
|
395
396
|
return r
|
@@ -406,9 +407,8 @@ class RandomState(State):
|
|
406
407
|
scale = _check_py_seq(scale)
|
407
408
|
if size is None:
|
408
409
|
size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
|
409
|
-
key = self.
|
410
|
-
|
411
|
-
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
|
410
|
+
key = self.__get_key(key)
|
411
|
+
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
412
412
|
return r
|
413
413
|
|
414
414
|
def laplace(
|
@@ -423,9 +423,8 @@ class RandomState(State):
|
|
423
423
|
scale = _check_py_seq(scale)
|
424
424
|
if size is None:
|
425
425
|
size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
|
426
|
-
key = self.
|
427
|
-
|
428
|
-
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
|
426
|
+
key = self.__get_key(key)
|
427
|
+
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
429
428
|
return r
|
430
429
|
|
431
430
|
def logistic(
|
@@ -443,9 +442,8 @@ class RandomState(State):
|
|
443
442
|
u.math.shape(loc) if loc is not None else (),
|
444
443
|
u.math.shape(scale) if scale is not None else ()
|
445
444
|
)
|
446
|
-
key = self.
|
447
|
-
|
448
|
-
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
|
445
|
+
key = self.__get_key(key)
|
446
|
+
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
449
447
|
return r
|
450
448
|
|
451
449
|
def normal(
|
@@ -463,7 +461,7 @@ class RandomState(State):
|
|
463
461
|
u.math.shape(scale) if scale is not None else (),
|
464
462
|
u.math.shape(loc) if loc is not None else ()
|
465
463
|
)
|
466
|
-
key = self.
|
464
|
+
key = self.__get_key(key)
|
467
465
|
dtype = dtype or environ.dftype()
|
468
466
|
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
|
469
467
|
return r
|
@@ -477,7 +475,7 @@ class RandomState(State):
|
|
477
475
|
):
|
478
476
|
if size is None:
|
479
477
|
size = u.math.shape(a)
|
480
|
-
key = self.
|
478
|
+
key = self.__get_key(key)
|
481
479
|
dtype = dtype or environ.dftype()
|
482
480
|
a = u.math.asarray(a, dtype=dtype)
|
483
481
|
r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
|
@@ -493,7 +491,7 @@ class RandomState(State):
|
|
493
491
|
lam = _check_py_seq(lam)
|
494
492
|
if size is None:
|
495
493
|
size = u.math.shape(lam)
|
496
|
-
key = self.
|
494
|
+
key = self.__get_key(key)
|
497
495
|
dtype = dtype or environ.ditype()
|
498
496
|
r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
|
499
497
|
return r
|
@@ -504,7 +502,7 @@ class RandomState(State):
|
|
504
502
|
key: Optional[SeedOrKey] = None,
|
505
503
|
dtype: DTypeLike = None
|
506
504
|
):
|
507
|
-
key = self.
|
505
|
+
key = self.__get_key(key)
|
508
506
|
dtype = dtype or environ.dftype()
|
509
507
|
r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
|
510
508
|
return r
|
@@ -515,7 +513,7 @@ class RandomState(State):
|
|
515
513
|
key: Optional[SeedOrKey] = None,
|
516
514
|
dtype: DTypeLike = None
|
517
515
|
):
|
518
|
-
key = self.
|
516
|
+
key = self.__get_key(key)
|
519
517
|
dtype = dtype or environ.dftype()
|
520
518
|
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
521
519
|
return r
|
@@ -530,7 +528,7 @@ class RandomState(State):
|
|
530
528
|
shape = _check_py_seq(shape)
|
531
529
|
if size is None:
|
532
530
|
size = u.math.shape(shape) if shape is not None else ()
|
533
|
-
key = self.
|
531
|
+
key = self.__get_key(key)
|
534
532
|
dtype = dtype or environ.dftype()
|
535
533
|
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
536
534
|
return r
|
@@ -541,7 +539,7 @@ class RandomState(State):
|
|
541
539
|
key: Optional[SeedOrKey] = None,
|
542
540
|
dtype: DTypeLike = None
|
543
541
|
):
|
544
|
-
key = self.
|
542
|
+
key = self.__get_key(key)
|
545
543
|
dtype = dtype or environ.dftype()
|
546
544
|
r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
547
545
|
return r
|
@@ -556,7 +554,7 @@ class RandomState(State):
|
|
556
554
|
df = _check_py_seq(df)
|
557
555
|
if size is None:
|
558
556
|
size = u.math.shape(size) if size is not None else ()
|
559
|
-
key = self.
|
557
|
+
key = self.__get_key(key)
|
560
558
|
dtype = dtype or environ.dftype()
|
561
559
|
r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
|
562
560
|
return r
|
@@ -573,17 +571,12 @@ class RandomState(State):
|
|
573
571
|
high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
|
574
572
|
if size is None:
|
575
573
|
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
576
|
-
key = self.
|
574
|
+
key = self.__get_key(key)
|
577
575
|
dtype = dtype or environ.dftype()
|
578
576
|
r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
|
579
577
|
return u.maybe_decimal(r * unit)
|
580
578
|
|
581
|
-
def __norm_cdf(
|
582
|
-
self,
|
583
|
-
x,
|
584
|
-
sqrt2,
|
585
|
-
dtype
|
586
|
-
):
|
579
|
+
def __norm_cdf(self, x, sqrt2, dtype):
|
587
580
|
# Computes standard normal cumulative distribution function
|
588
581
|
return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
|
589
582
|
|
@@ -639,7 +632,7 @@ class RandomState(State):
|
|
639
632
|
|
640
633
|
# Uniformly fill tensor with values from [l, u], then translate to
|
641
634
|
# [2l-1, 2u-1].
|
642
|
-
key = self.
|
635
|
+
key = self.__get_key(key)
|
643
636
|
out = jr.uniform(
|
644
637
|
key, size, dtype,
|
645
638
|
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
@@ -677,7 +670,7 @@ class RandomState(State):
|
|
677
670
|
jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
|
678
671
|
if size is None:
|
679
672
|
size = u.math.shape(p)
|
680
|
-
key = self.
|
673
|
+
key = self.__get_key(key)
|
681
674
|
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
682
675
|
return r
|
683
676
|
|
@@ -689,6 +682,7 @@ class RandomState(State):
|
|
689
682
|
key: Optional[SeedOrKey] = None,
|
690
683
|
dtype: DTypeLike = None
|
691
684
|
):
|
685
|
+
dtype = dtype or environ.dftype()
|
692
686
|
mean = _check_py_seq(mean)
|
693
687
|
sigma = _check_py_seq(sigma)
|
694
688
|
mean = u.math.asarray(mean, dtype=dtype)
|
@@ -702,7 +696,7 @@ class RandomState(State):
|
|
702
696
|
u.math.shape(mean) if mean is not None else (),
|
703
697
|
u.math.shape(sigma) if sigma is not None else ()
|
704
698
|
)
|
705
|
-
key = self.
|
699
|
+
key = self.__get_key(key)
|
706
700
|
dtype = dtype or environ.dftype()
|
707
701
|
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
708
702
|
samples = _loc_scale(mean, sigma, samples)
|
@@ -729,7 +723,7 @@ class RandomState(State):
|
|
729
723
|
)
|
730
724
|
if size is None:
|
731
725
|
size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
|
732
|
-
key = self.
|
726
|
+
key = self.__get_key(key)
|
733
727
|
r = jr.binomial(key, n, p, shape=_size2shape(size))
|
734
728
|
dtype = dtype or environ.ditype()
|
735
729
|
return u.math.asarray(r, dtype=dtype)
|
@@ -742,7 +736,7 @@ class RandomState(State):
|
|
742
736
|
dtype: DTypeLike = None
|
743
737
|
):
|
744
738
|
df = _check_py_seq(df)
|
745
|
-
key = self.
|
739
|
+
key = self.__get_key(key)
|
746
740
|
dtype = dtype or environ.dftype()
|
747
741
|
if size is None:
|
748
742
|
if jnp.ndim(df) == 0:
|
@@ -762,7 +756,7 @@ class RandomState(State):
|
|
762
756
|
key: Optional[SeedOrKey] = None,
|
763
757
|
dtype: DTypeLike = None
|
764
758
|
):
|
765
|
-
key = self.
|
759
|
+
key = self.__get_key(key)
|
766
760
|
alpha = _check_py_seq(alpha)
|
767
761
|
dtype = dtype or environ.dftype()
|
768
762
|
r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
|
@@ -778,7 +772,7 @@ class RandomState(State):
|
|
778
772
|
p = _check_py_seq(p)
|
779
773
|
if size is None:
|
780
774
|
size = u.math.shape(p)
|
781
|
-
key = self.
|
775
|
+
key = self.__get_key(key)
|
782
776
|
dtype = dtype or environ.dftype()
|
783
777
|
u_ = jr.uniform(key, size, dtype)
|
784
778
|
r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
|
@@ -796,7 +790,7 @@ class RandomState(State):
|
|
796
790
|
dtype: DTypeLike = None,
|
797
791
|
check_valid: bool = True
|
798
792
|
):
|
799
|
-
key = self.
|
793
|
+
key = self.__get_key(key)
|
800
794
|
n = _check_py_seq(n)
|
801
795
|
pvals = _check_py_seq(pvals)
|
802
796
|
if check_valid:
|
@@ -807,7 +801,7 @@ class RandomState(State):
|
|
807
801
|
size = _size2shape(size)
|
808
802
|
n_max = int(np.max(jax.device_get(n)))
|
809
803
|
batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
|
810
|
-
r =
|
804
|
+
r = multinomial(key, pvals, n, n_max=n_max, shape=batch_shape + size)
|
811
805
|
dtype = dtype or environ.ditype()
|
812
806
|
return u.math.asarray(r, dtype=dtype)
|
813
807
|
|
@@ -832,7 +826,7 @@ class RandomState(State):
|
|
832
826
|
cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
|
833
827
|
unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
|
834
828
|
|
835
|
-
key = self.
|
829
|
+
key = self.__get_key(key)
|
836
830
|
if not jnp.ndim(mean) >= 1:
|
837
831
|
raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
|
838
832
|
if not jnp.ndim(cov) >= 2:
|
@@ -869,7 +863,7 @@ class RandomState(State):
|
|
869
863
|
scale = _check_py_seq(scale)
|
870
864
|
if size is None:
|
871
865
|
size = u.math.shape(scale)
|
872
|
-
key = self.
|
866
|
+
key = self.__get_key(key)
|
873
867
|
dtype = dtype or environ.dftype()
|
874
868
|
x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype)))
|
875
869
|
r = x * scale
|
@@ -880,7 +874,7 @@ class RandomState(State):
|
|
880
874
|
size: Optional[Size] = None,
|
881
875
|
key: Optional[SeedOrKey] = None
|
882
876
|
):
|
883
|
-
key = self.
|
877
|
+
key = self.__get_key(key)
|
884
878
|
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
|
885
879
|
r = 2 * bernoulli_samples - 1
|
886
880
|
return r
|
@@ -893,14 +887,14 @@ class RandomState(State):
|
|
893
887
|
key: Optional[SeedOrKey] = None,
|
894
888
|
dtype: DTypeLike = None
|
895
889
|
):
|
896
|
-
key = self.
|
890
|
+
key = self.__get_key(key)
|
897
891
|
dtype = dtype or environ.dftype()
|
898
892
|
mu = u.math.asarray(_check_py_seq(mu), dtype=dtype)
|
899
893
|
kappa = u.math.asarray(_check_py_seq(kappa), dtype=dtype)
|
900
894
|
if size is None:
|
901
895
|
size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
|
902
896
|
size = _size2shape(size)
|
903
|
-
samples =
|
897
|
+
samples = von_mises_centered(key, kappa, size, dtype=dtype)
|
904
898
|
samples = samples + mu
|
905
899
|
samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
|
906
900
|
return samples
|
@@ -912,7 +906,7 @@ class RandomState(State):
|
|
912
906
|
key: Optional[SeedOrKey] = None,
|
913
907
|
dtype: DTypeLike = None
|
914
908
|
):
|
915
|
-
key = self.
|
909
|
+
key = self.__get_key(key)
|
916
910
|
a = _check_py_seq(a)
|
917
911
|
if size is None:
|
918
912
|
size = u.math.shape(a)
|
@@ -933,7 +927,7 @@ class RandomState(State):
|
|
933
927
|
key: Optional[SeedOrKey] = None,
|
934
928
|
dtype: DTypeLike = None
|
935
929
|
):
|
936
|
-
key = self.
|
930
|
+
key = self.__get_key(key)
|
937
931
|
a = _check_py_seq(a)
|
938
932
|
scale = _check_py_seq(scale)
|
939
933
|
if size is None:
|
@@ -955,7 +949,7 @@ class RandomState(State):
|
|
955
949
|
key: Optional[SeedOrKey] = None,
|
956
950
|
dtype: DTypeLike = None
|
957
951
|
):
|
958
|
-
key = self.
|
952
|
+
key = self.__get_key(key)
|
959
953
|
shape = _size2shape(size) + (3,)
|
960
954
|
dtype = dtype or environ.dftype()
|
961
955
|
norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
|
@@ -979,7 +973,7 @@ class RandomState(State):
|
|
979
973
|
if key is None:
|
980
974
|
keys = self.split_key(2)
|
981
975
|
else:
|
982
|
-
keys = jr.split(
|
976
|
+
keys = jr.split(formalize_key(key, use_prng_key), 2)
|
983
977
|
rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
|
984
978
|
r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
|
985
979
|
return r
|
@@ -993,7 +987,7 @@ class RandomState(State):
|
|
993
987
|
dtype: DTypeLike = None
|
994
988
|
):
|
995
989
|
dtype = dtype or environ.dftype()
|
996
|
-
key = self.
|
990
|
+
key = self.__get_key(key)
|
997
991
|
mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
|
998
992
|
scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
|
999
993
|
if size is None:
|
@@ -1049,9 +1043,9 @@ class RandomState(State):
|
|
1049
1043
|
if key is None:
|
1050
1044
|
keys = self.split_key(2)
|
1051
1045
|
else:
|
1052
|
-
keys = jr.split(
|
1046
|
+
keys = jr.split(formalize_key(key, use_prng_key), 2)
|
1053
1047
|
n = jr.normal(keys[0], size, dtype=dtype)
|
1054
|
-
two =
|
1048
|
+
two = const(n, 2)
|
1055
1049
|
half_df = lax.div(df, two)
|
1056
1050
|
g = jr.gamma(keys[1], half_df, size, dtype=dtype)
|
1057
1051
|
r = n * jnp.sqrt(half_df / g)
|
@@ -1065,7 +1059,7 @@ class RandomState(State):
|
|
1065
1059
|
dtype: DTypeLike = None
|
1066
1060
|
):
|
1067
1061
|
dtype = dtype or environ.dftype()
|
1068
|
-
key = self.
|
1062
|
+
key = self.__get_key(key)
|
1069
1063
|
size = _size2shape(size)
|
1070
1064
|
_check_shape("orthogonal", size)
|
1071
1065
|
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
@@ -1092,7 +1086,7 @@ class RandomState(State):
|
|
1092
1086
|
if key is None:
|
1093
1087
|
keys = self.split_key(3)
|
1094
1088
|
else:
|
1095
|
-
keys = jr.split(
|
1089
|
+
keys = jr.split(formalize_key(key, use_prng_key), 3)
|
1096
1090
|
i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
|
1097
1091
|
n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
|
1098
1092
|
cond = jnp.greater(df, 1.0)
|
@@ -1109,7 +1103,7 @@ class RandomState(State):
|
|
1109
1103
|
dtype: DTypeLike = None
|
1110
1104
|
):
|
1111
1105
|
dtype = dtype or environ.dftype()
|
1112
|
-
key = self.
|
1106
|
+
key = self.__get_key(key)
|
1113
1107
|
a = _check_py_seq(a)
|
1114
1108
|
if size is None:
|
1115
1109
|
size = u.math.shape(a)
|
@@ -1123,7 +1117,7 @@ class RandomState(State):
|
|
1123
1117
|
size: Optional[Size] = None,
|
1124
1118
|
key: Optional[SeedOrKey] = None
|
1125
1119
|
):
|
1126
|
-
key = self.
|
1120
|
+
key = self.__get_key(key)
|
1127
1121
|
logits = _check_py_seq(logits)
|
1128
1122
|
if size is None:
|
1129
1123
|
size = list(u.math.shape(logits))
|
@@ -1141,10 +1135,12 @@ class RandomState(State):
|
|
1141
1135
|
a = _check_py_seq(a)
|
1142
1136
|
if size is None:
|
1143
1137
|
size = u.math.shape(a)
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1138
|
+
r = zipf(
|
1139
|
+
self.__get_key(key),
|
1140
|
+
a,
|
1141
|
+
shape=size,
|
1142
|
+
dtype=dtype or environ.ditype()
|
1143
|
+
)
|
1148
1144
|
return r
|
1149
1145
|
|
1150
1146
|
def power(
|
@@ -1158,10 +1154,12 @@ class RandomState(State):
|
|
1158
1154
|
if size is None:
|
1159
1155
|
size = u.math.shape(a)
|
1160
1156
|
size = _size2shape(size)
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1157
|
+
r = power(
|
1158
|
+
self.__get_key(key),
|
1159
|
+
a,
|
1160
|
+
shape=size,
|
1161
|
+
dtype=dtype or environ.dftype(),
|
1162
|
+
)
|
1165
1163
|
return r
|
1166
1164
|
|
1167
1165
|
def f(
|
@@ -1177,14 +1175,12 @@ class RandomState(State):
|
|
1177
1175
|
if size is None:
|
1178
1176
|
size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
|
1179
1177
|
size = _size2shape(size)
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
jax.ShapeDtypeStruct(size, dtype),
|
1187
|
-
dfnum, dfden
|
1178
|
+
r = f(
|
1179
|
+
self.__get_key(key),
|
1180
|
+
dfnum,
|
1181
|
+
dfden,
|
1182
|
+
shape=size,
|
1183
|
+
dtype=dtype or environ.dftype(),
|
1188
1184
|
)
|
1189
1185
|
return r
|
1190
1186
|
|
@@ -1200,23 +1196,20 @@ class RandomState(State):
|
|
1200
1196
|
ngood = _check_py_seq(ngood)
|
1201
1197
|
nbad = _check_py_seq(nbad)
|
1202
1198
|
nsample = _check_py_seq(nsample)
|
1203
|
-
|
1204
1199
|
if size is None:
|
1205
|
-
size = lax.broadcast_shapes(
|
1206
|
-
|
1207
|
-
|
1200
|
+
size = lax.broadcast_shapes(
|
1201
|
+
u.math.shape(ngood),
|
1202
|
+
u.math.shape(nbad),
|
1203
|
+
u.math.shape(nsample)
|
1204
|
+
)
|
1208
1205
|
size = _size2shape(size)
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
size=size
|
1217
|
-
).astype(dtype),
|
1218
|
-
jax.ShapeDtypeStruct(size, dtype),
|
1219
|
-
d
|
1206
|
+
r = hypergeometric(
|
1207
|
+
self.__get_key(key),
|
1208
|
+
ngood,
|
1209
|
+
nbad,
|
1210
|
+
nsample,
|
1211
|
+
shape=size,
|
1212
|
+
dtype=dtype or environ.ditype(),
|
1220
1213
|
)
|
1221
1214
|
return r
|
1222
1215
|
|
@@ -1231,11 +1224,11 @@ class RandomState(State):
|
|
1231
1224
|
if size is None:
|
1232
1225
|
size = u.math.shape(p)
|
1233
1226
|
size = _size2shape(size)
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1227
|
+
r = logseries(
|
1228
|
+
self.__get_key(key),
|
1229
|
+
p,
|
1230
|
+
shape=size,
|
1231
|
+
dtype=dtype or environ.ditype()
|
1239
1232
|
)
|
1240
1233
|
return r
|
1241
1234
|
|
@@ -1256,15 +1249,13 @@ class RandomState(State):
|
|
1256
1249
|
u.math.shape(dfden),
|
1257
1250
|
u.math.shape(nonc))
|
1258
1251
|
size = _size2shape(size)
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
jax.ShapeDtypeStruct(size, dtype),
|
1267
|
-
d
|
1252
|
+
r = noncentral_f(
|
1253
|
+
self.__get_key(key),
|
1254
|
+
dfnum,
|
1255
|
+
dfden,
|
1256
|
+
nonc,
|
1257
|
+
shape=size,
|
1258
|
+
dtype=dtype or environ.dftype(),
|
1268
1259
|
)
|
1269
1260
|
return r
|
1270
1261
|
|
@@ -1327,291 +1318,3 @@ class RandomState(State):
|
|
1327
1318
|
|
1328
1319
|
# default random generator
|
1329
1320
|
DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
|
1330
|
-
|
1331
|
-
|
1332
|
-
# ---------------------------------------------------------------------------------------------------------------
|
1333
|
-
|
1334
|
-
|
1335
|
-
def _formalize_key(key):
|
1336
|
-
if isinstance(key, int):
|
1337
|
-
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1338
|
-
elif isinstance(key, (jax.Array, np.ndarray)):
|
1339
|
-
if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
|
1340
|
-
return key
|
1341
|
-
if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
|
1342
|
-
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1343
|
-
|
1344
|
-
if key.dtype != jnp.uint32:
|
1345
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1346
|
-
if key.size != 2:
|
1347
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1348
|
-
return u.math.asarray(key, dtype=jnp.uint32)
|
1349
|
-
else:
|
1350
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1351
|
-
|
1352
|
-
|
1353
|
-
def _size2shape(size):
|
1354
|
-
if size is None:
|
1355
|
-
return ()
|
1356
|
-
elif isinstance(size, (tuple, list)):
|
1357
|
-
return tuple(size)
|
1358
|
-
else:
|
1359
|
-
return (size,)
|
1360
|
-
|
1361
|
-
|
1362
|
-
def _check_shape(
|
1363
|
-
name,
|
1364
|
-
shape,
|
1365
|
-
*param_shapes
|
1366
|
-
):
|
1367
|
-
if param_shapes:
|
1368
|
-
shape_ = lax.broadcast_shapes(shape, *param_shapes)
|
1369
|
-
if shape != shape_:
|
1370
|
-
msg = ("{} parameter shapes must be broadcast-compatible with shape "
|
1371
|
-
"argument, and the result of broadcasting the shapes must equal "
|
1372
|
-
"the shape argument, but got result {} for shape argument {}.")
|
1373
|
-
raise ValueError(msg.format(name, shape_, shape))
|
1374
|
-
|
1375
|
-
|
1376
|
-
def _is_python_scalar(x):
|
1377
|
-
if hasattr(x, 'aval'):
|
1378
|
-
return x.aval.weak_type
|
1379
|
-
elif np.ndim(x) == 0:
|
1380
|
-
return True
|
1381
|
-
elif isinstance(x, (bool, int, float, complex)):
|
1382
|
-
return True
|
1383
|
-
else:
|
1384
|
-
return False
|
1385
|
-
|
1386
|
-
|
1387
|
-
python_scalar_dtypes = {
|
1388
|
-
bool: np.dtype('bool'),
|
1389
|
-
int: np.dtype('int64'),
|
1390
|
-
float: np.dtype('float64'),
|
1391
|
-
complex: np.dtype('complex128'),
|
1392
|
-
}
|
1393
|
-
|
1394
|
-
|
1395
|
-
def _dtype(
|
1396
|
-
x,
|
1397
|
-
*,
|
1398
|
-
canonicalize: bool = False
|
1399
|
-
):
|
1400
|
-
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
|
1401
|
-
if x is None:
|
1402
|
-
raise ValueError(f"Invalid argument to dtype: {x}.")
|
1403
|
-
elif isinstance(x, type) and x in python_scalar_dtypes:
|
1404
|
-
dt = python_scalar_dtypes[x]
|
1405
|
-
elif type(x) in python_scalar_dtypes:
|
1406
|
-
dt = python_scalar_dtypes[type(x)]
|
1407
|
-
elif hasattr(x, 'dtype'):
|
1408
|
-
dt = x.dtype
|
1409
|
-
else:
|
1410
|
-
dt = np.result_type(x)
|
1411
|
-
return dtypes.canonicalize_dtype(dt) if canonicalize else dt
|
1412
|
-
|
1413
|
-
|
1414
|
-
def _const(
|
1415
|
-
example,
|
1416
|
-
val
|
1417
|
-
):
|
1418
|
-
if _is_python_scalar(example):
|
1419
|
-
dtype = dtypes.canonicalize_dtype(type(example))
|
1420
|
-
val = dtypes.scalar_type_of(example)(val)
|
1421
|
-
return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
|
1422
|
-
else:
|
1423
|
-
dtype = dtypes.canonicalize_dtype(example.dtype)
|
1424
|
-
return np.array(val, dtype)
|
1425
|
-
|
1426
|
-
|
1427
|
-
@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
    """Draw category indices by comparing a uniform draw with the cumulative sum of ``p``.

    ``shape`` is a static argument; when it is empty the batch shape
    ``p.shape[:-1]`` is used instead. The returned index is the number of
    cumulative-sum entries along the last axis that are strictly smaller than
    the uniform draw.
    """
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref: https://stackoverflow.com/a/34190035
    batch_shape = shape if shape else p.shape[:-1]
    cumulative = jnp.cumsum(p, axis=-1)
    draws = jr.uniform(key, shape=batch_shape + (1,))
    return jnp.sum(cumulative < draws, axis=-1)
|
1439
|
-
|
1440
|
-
|
1441
|
-
def _scatter_add_one(operand, indices, updates):
    """Scatter-add ``updates`` into ``operand`` at positions ``indices`` along axis 0.

    Each index vector addresses a single element on operand dimension 0
    (no update window dimensions), so every update is one scalar added at
    one position.
    """
    dimension_numbers = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter_add(operand, indices, updates, dimension_numbers=dimension_numbers)
|
1456
|
-
|
1457
|
-
|
1458
|
-
def _reshape(x, shape):
    """Reshape ``x`` to ``shape``, using NumPy for host values and ``jnp`` otherwise.

    Python ``int``/``float`` values, NumPy arrays and NumPy scalars go through
    ``np.reshape``; every other input goes through ``jnp.reshape``.
    """
    host_types = (int, float, np.ndarray, np.generic)
    reshape_fn = np.reshape if isinstance(x, host_types) else jnp.reshape
    return reshape_fn(x, shape)
|
1463
|
-
|
1464
|
-
|
1465
|
-
def _promote_shapes(*args, shape=()):
    """Left-pad the shapes of ``args`` with ones up to a common rank.

    The target rank is the rank of ``shape`` broadcast together with the
    shapes of all ``args``. With fewer than two arguments and an empty
    ``shape`` the positional arguments are returned unchanged (as the
    original tuple); otherwise a list is returned.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args

    arg_shapes = [u.math.shape(arg) for arg in args]
    rank = len(lax.broadcast_shapes(shape, *arg_shapes))
    promoted = []
    for arg, arg_shape in zip(args, arg_shapes):
        if len(arg_shape) < rank:
            arg = _reshape(arg, (1,) * (rank - len(arg_shape)) + arg_shape)
        promoted.append(arg)
    return promoted
|
1479
|
-
|
1480
|
-
|
1481
|
-
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by drawing ``n_max`` categorical indices and tallying them.

    ``n_max`` and ``shape`` are static arguments. ``n`` and the batch part of
    ``p`` are broadcast against each other when their shapes differ. When
    ``n`` is not a scalar, draws beyond each entry's own ``n`` are masked to
    index 0 and the resulting surplus ``n_max - n`` is subtracted from
    category 0 afterwards.
    """
    n_shape = u.math.shape(n)
    p_batch_shape = u.math.shape(p)[:-1]
    if n_shape != p_batch_shape:
        common_shape = lax.broadcast_shapes(n_shape, p_batch_shape)
        n = jnp.broadcast_to(n, common_shape)
        p = jnp.broadcast_to(p, common_shape + u.math.shape(p)[-1:])

    shape = shape if shape else p.shape[:-1]
    out_shape = shape + p.shape[-1:]
    if n_max == 0:
        return jnp.zeros(out_shape, dtype=jnp.result_type(int))

    # get indices from categorical distribution then gather the result
    draws = _categorical(key, p, (n_max,) + shape)

    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        keep = jnp.arange(n_max) < jnp.expand_dims(n, -1)
        keep = _promote_shapes(keep, shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(keep, -1, 0).astype(draws.dtype)
        surplus_first = jnp.expand_dims(n_max - n, -1)
        surplus_rest = jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))
        excess = jnp.concatenate([surplus_first, surplus_rest], -1)
    else:
        mask = 1
        excess = 0

    # NB: we transpose to move batch shape to the front
    flat_draws = jnp.reshape(draws * mask, (n_max, -1)).T
    tally_init = jnp.zeros((flat_draws.shape[0], p.shape[-1]), dtype=draws.dtype)
    tally_index = jnp.expand_dims(flat_draws, axis=-1)
    tally_ones = jnp.ones(flat_draws.shape, dtype=draws.dtype)
    counts = vmap(_scatter_add_one)(tally_init, tally_index, tally_ones)
    return jnp.reshape(counts, out_shape) - excess
|
1516
|
-
|
1517
|
-
|
1518
|
-
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
def _von_mises_centered(
    key,
    concentration,
    shape,
    dtype=None
):
    """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

    Parameters
    ----------
    key
        PRNG key; it is split on every iteration of the rejection loop.
    concentration
        Concentration parameter. It is converted to ``dtype`` and broadcast
        to ``shape``.
    shape
        Static output shape. When empty, the shape of ``concentration`` is used.
    dtype
        Static floating dtype (float16, float32 or float64). When omitted,
        ``environ.dftype()`` is used.

    Returns
    -------
    out: array_like
        centered samples from von Mises

    Raises
    ------
    ValueError
        If ``dtype`` is not float16, float32 or float64.

    Notes
    -----
    The rejection loop stops after at most 100 iterations; entries that were
    not accepted by then keep their last proposal.

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
           Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    """
    shape = shape or u.math.shape(concentration)
    dtype = dtype or environ.dftype()
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)

    # dtype-dependent threshold: at or below it the exact expression for ``s``
    # is replaced by the approximation ``1 / concentration``.
    if dtype == jnp.float16:
        s_cutoff = 1.8e-1
    elif dtype == jnp.float32:
        s_cutoff = 2e-2
    elif dtype == jnp.float64:
        s_cutoff = 1.2e-4
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(state):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = state
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(state):
        """Propose new values for entries that are not done and update the accept mask."""
        i, key, done, _, w = state
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u_ = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u_)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u_, w

    # The loop carry must keep the same dtype on every iteration. The body
    # produces ``u_`` and ``w`` in ``concentration.dtype``, so the initial
    # values are allocated in that dtype rather than the default float dtype.
    carry_dtype = concentration.dtype
    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape, dtype=carry_dtype)
    init_w = jnp.zeros(shape, dtype=carry_dtype)

    _, _, _, uu, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(uu) * jnp.arccos(w)
|
1597
|
-
|
1598
|
-
|
1599
|
-
def _loc_scale(loc, scale, value):
    """Apply an optional scale and then an optional shift to ``value``.

    ``None`` for ``scale`` skips the multiplication and ``None`` for ``loc``
    skips the addition, so ``value`` is returned untouched when both are
    ``None``.
    """
    scaled = value if scale is None else value * scale
    return scaled if loc is None else scaled + loc
|
1614
|
-
|
1615
|
-
|
1616
|
-
def _check_py_seq(seq):
    """Convert a Python ``tuple`` or ``list`` with ``u.math.asarray``; return anything else unchanged."""
    if isinstance(seq, (tuple, list)):
        return u.math.asarray(seq)
    return seq
|