brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from functools import partial
19
18
  from operator import index
20
19
  from typing import Optional
21
20
 
@@ -24,12 +23,16 @@ import jax
24
23
  import jax.numpy as jnp
25
24
  import jax.random as jr
26
25
  import numpy as np
27
- from jax import jit, vmap
28
- from jax import lax, core, dtypes
26
+ from jax import lax, core
29
27
 
30
28
  from brainstate import environ
31
29
  from brainstate._state import State
32
30
  from brainstate.typing import DTypeLike, Size, SeedOrKey
31
+ from ._impl import (
32
+ multinomial, von_mises_centered, const,
33
+ formalize_key, _loc_scale, _size2shape, _check_py_seq, _check_shape,
34
+ noncentral_f, logseries, hypergeometric, f, power, zipf
35
+ )
33
36
 
34
37
  __all__ = [
35
38
  'RandomState',
@@ -73,14 +76,10 @@ class RandomState(State):
73
76
 
74
77
  self._backup = None
75
78
 
76
- def __repr__(
77
- self
78
- ):
79
+ def __repr__(self):
79
80
  return f'{self.__class__.__name__}({self.value})'
80
81
 
81
- def check_if_deleted(
82
- self
83
- ):
82
+ def check_if_deleted(self):
84
83
  if not use_prng_key and isinstance(self._value, np.ndarray):
85
84
  self._value = jr.key(np.random.randint(0, 10000))
86
85
 
@@ -91,6 +90,11 @@ class RandomState(State):
91
90
  ):
92
91
  self.seed()
93
92
 
93
+ @staticmethod
94
+ def _batch_keys(batch_size: int):
95
+ key = jr.PRNGKey(0) if use_prng_key else jr.key(0)
96
+ return jr.split(key, batch_size)
97
+
94
98
  # ------------------- #
95
99
  # seed and random key #
96
100
  # ------------------- #
@@ -126,7 +130,7 @@ class RandomState(State):
126
130
  """
127
131
  with jax.ensure_compile_time_eval():
128
132
  if seed_or_key is None:
129
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
133
+ seed_or_key = np.random.randint(0, 10000000, 2, dtype=np.uint32)
130
134
  if np.size(seed_or_key) == 1:
131
135
  if isinstance(seed_or_key, int):
132
136
  key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
@@ -156,7 +160,7 @@ class RandomState(State):
156
160
  n: int, optional
157
161
  The number of seeds to generate.
158
162
  backup : bool, optional
159
- Whether to backup the current key.
163
+ Whether to back up the current key.
160
164
 
161
165
  Returns
162
166
  -------
@@ -193,6 +197,9 @@ class RandomState(State):
193
197
  else:
194
198
  self.value = jr.split(self.value, n)
195
199
 
200
+ def __get_key(self, key):
201
+ return self.split_key() if key is None else formalize_key(key, use_prng_key)
202
+
196
203
  # ---------------- #
197
204
  # random functions #
198
205
  # ---------------- #
@@ -203,7 +210,7 @@ class RandomState(State):
203
210
  key: Optional[SeedOrKey] = None,
204
211
  dtype: DTypeLike = None
205
212
  ):
206
- key = self.split_key() if key is None else _formalize_key(key)
213
+ key = self.__get_key(key)
207
214
  dtype = dtype or environ.dftype()
208
215
  r = jr.uniform(key, dn, dtype)
209
216
  return r
@@ -222,9 +229,8 @@ class RandomState(State):
222
229
  high = _check_py_seq(high)
223
230
  low = _check_py_seq(low)
224
231
  if size is None:
225
- size = lax.broadcast_shapes(u.math.shape(low),
226
- u.math.shape(high))
227
- key = self.split_key() if key is None else _formalize_key(key)
232
+ size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
233
+ key = self.__get_key(key)
228
234
  dtype = dtype or environ.ditype()
229
235
  r = jr.randint(key,
230
236
  shape=_size2shape(size),
@@ -247,7 +253,7 @@ class RandomState(State):
247
253
  high += 1
248
254
  if size is None:
249
255
  size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
250
- key = self.split_key() if key is None else _formalize_key(key)
256
+ key = self.__get_key(key)
251
257
  dtype = dtype or environ.ditype()
252
258
  r = jr.randint(key,
253
259
  shape=_size2shape(size),
@@ -262,9 +268,8 @@ class RandomState(State):
262
268
  key: Optional[SeedOrKey] = None,
263
269
  dtype: DTypeLike = None
264
270
  ):
265
- key = self.split_key() if key is None else _formalize_key(key)
266
- dtype = dtype or environ.dftype()
267
- r = jr.normal(key, shape=dn, dtype=dtype)
271
+ key = self.__get_key(key)
272
+ r = jr.normal(key, shape=dn, dtype=dtype or environ.dftype())
268
273
  return r
269
274
 
270
275
  def random(
@@ -273,9 +278,8 @@ class RandomState(State):
273
278
  key: Optional[SeedOrKey] = None,
274
279
  dtype: DTypeLike = None
275
280
  ):
276
- dtype = dtype or environ.dftype()
277
- key = self.split_key() if key is None else _formalize_key(key)
278
- r = jr.uniform(key, _size2shape(size), dtype)
281
+ key = self.__get_key(key)
282
+ r = jr.uniform(key, _size2shape(size), dtype=dtype or environ.dftype())
279
283
  return r
280
284
 
281
285
  def random_sample(
@@ -284,7 +288,7 @@ class RandomState(State):
284
288
  key: Optional[SeedOrKey] = None,
285
289
  dtype: DTypeLike = None
286
290
  ):
287
- r = self.random(size=size, key=key, dtype=dtype)
291
+ r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
288
292
  return r
289
293
 
290
294
  def ranf(
@@ -293,7 +297,7 @@ class RandomState(State):
293
297
  key: Optional[SeedOrKey] = None,
294
298
  dtype: DTypeLike = None
295
299
  ):
296
- r = self.random(size=size, key=key, dtype=dtype)
300
+ r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
297
301
  return r
298
302
 
299
303
  def sample(
@@ -302,7 +306,7 @@ class RandomState(State):
302
306
  key: Optional[SeedOrKey] = None,
303
307
  dtype: DTypeLike = None
304
308
  ):
305
- r = self.random(size=size, key=key, dtype=dtype)
309
+ r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
306
310
  return r
307
311
 
308
312
  def choice(
@@ -316,7 +320,7 @@ class RandomState(State):
316
320
  a = _check_py_seq(a)
317
321
  a, unit = u.split_mantissa_unit(a)
318
322
  p = _check_py_seq(p)
319
- key = self.split_key() if key is None else _formalize_key(key)
323
+ key = self.__get_key(key)
320
324
  r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
321
325
  return u.maybe_decimal(r * unit)
322
326
 
@@ -329,7 +333,7 @@ class RandomState(State):
329
333
  ):
330
334
  x = _check_py_seq(x)
331
335
  x, unit = u.split_mantissa_unit(x)
332
- key = self.split_key() if key is None else _formalize_key(key)
336
+ key = self.__get_key(key)
333
337
  r = jr.permutation(key, x, axis, independent=independent)
334
338
  return u.maybe_decimal(r * unit)
335
339
 
@@ -353,9 +357,8 @@ class RandomState(State):
353
357
  b = _check_py_seq(b)
354
358
  if size is None:
355
359
  size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
356
- key = self.split_key() if key is None else _formalize_key(key)
357
- dtype = dtype or environ.dftype()
358
- r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
360
+ key = self.__get_key(key)
361
+ r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype or environ.dftype())
359
362
  return r
360
363
 
361
364
  def exponential(
@@ -367,9 +370,8 @@ class RandomState(State):
367
370
  ):
368
371
  if size is None:
369
372
  size = u.math.shape(scale)
370
- key = self.split_key() if key is None else _formalize_key(key)
371
- dtype = dtype or environ.dftype()
372
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
373
+ key = self.__get_key(key)
374
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype or environ.dftype())
373
375
  if scale is not None:
374
376
  scale = u.math.asarray(scale, dtype=dtype)
375
377
  r = r / scale
@@ -387,9 +389,8 @@ class RandomState(State):
387
389
  scale = _check_py_seq(scale)
388
390
  if size is None:
389
391
  size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
390
- key = self.split_key() if key is None else _formalize_key(key)
391
- dtype = dtype or environ.dftype()
392
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
392
+ key = self.__get_key(key)
393
+ r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype or environ.dftype())
393
394
  if scale is not None:
394
395
  r = r * scale
395
396
  return r
@@ -406,9 +407,8 @@ class RandomState(State):
406
407
  scale = _check_py_seq(scale)
407
408
  if size is None:
408
409
  size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
409
- key = self.split_key() if key is None else _formalize_key(key)
410
- dtype = dtype or environ.dftype()
411
- r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
410
+ key = self.__get_key(key)
411
+ r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
412
412
  return r
413
413
 
414
414
  def laplace(
@@ -423,9 +423,8 @@ class RandomState(State):
423
423
  scale = _check_py_seq(scale)
424
424
  if size is None:
425
425
  size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
426
- key = self.split_key() if key is None else _formalize_key(key)
427
- dtype = dtype or environ.dftype()
428
- r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
426
+ key = self.__get_key(key)
427
+ r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
429
428
  return r
430
429
 
431
430
  def logistic(
@@ -443,9 +442,8 @@ class RandomState(State):
443
442
  u.math.shape(loc) if loc is not None else (),
444
443
  u.math.shape(scale) if scale is not None else ()
445
444
  )
446
- key = self.split_key() if key is None else _formalize_key(key)
447
- dtype = dtype or environ.dftype()
448
- r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
445
+ key = self.__get_key(key)
446
+ r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
449
447
  return r
450
448
 
451
449
  def normal(
@@ -463,7 +461,7 @@ class RandomState(State):
463
461
  u.math.shape(scale) if scale is not None else (),
464
462
  u.math.shape(loc) if loc is not None else ()
465
463
  )
466
- key = self.split_key() if key is None else _formalize_key(key)
464
+ key = self.__get_key(key)
467
465
  dtype = dtype or environ.dftype()
468
466
  r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
469
467
  return r
@@ -477,7 +475,7 @@ class RandomState(State):
477
475
  ):
478
476
  if size is None:
479
477
  size = u.math.shape(a)
480
- key = self.split_key() if key is None else _formalize_key(key)
478
+ key = self.__get_key(key)
481
479
  dtype = dtype or environ.dftype()
482
480
  a = u.math.asarray(a, dtype=dtype)
483
481
  r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
@@ -493,7 +491,7 @@ class RandomState(State):
493
491
  lam = _check_py_seq(lam)
494
492
  if size is None:
495
493
  size = u.math.shape(lam)
496
- key = self.split_key() if key is None else _formalize_key(key)
494
+ key = self.__get_key(key)
497
495
  dtype = dtype or environ.ditype()
498
496
  r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
499
497
  return r
@@ -504,7 +502,7 @@ class RandomState(State):
504
502
  key: Optional[SeedOrKey] = None,
505
503
  dtype: DTypeLike = None
506
504
  ):
507
- key = self.split_key() if key is None else _formalize_key(key)
505
+ key = self.__get_key(key)
508
506
  dtype = dtype or environ.dftype()
509
507
  r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
510
508
  return r
@@ -515,7 +513,7 @@ class RandomState(State):
515
513
  key: Optional[SeedOrKey] = None,
516
514
  dtype: DTypeLike = None
517
515
  ):
518
- key = self.split_key() if key is None else _formalize_key(key)
516
+ key = self.__get_key(key)
519
517
  dtype = dtype or environ.dftype()
520
518
  r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
521
519
  return r
@@ -530,7 +528,7 @@ class RandomState(State):
530
528
  shape = _check_py_seq(shape)
531
529
  if size is None:
532
530
  size = u.math.shape(shape) if shape is not None else ()
533
- key = self.split_key() if key is None else _formalize_key(key)
531
+ key = self.__get_key(key)
534
532
  dtype = dtype or environ.dftype()
535
533
  r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
536
534
  return r
@@ -541,7 +539,7 @@ class RandomState(State):
541
539
  key: Optional[SeedOrKey] = None,
542
540
  dtype: DTypeLike = None
543
541
  ):
544
- key = self.split_key() if key is None else _formalize_key(key)
542
+ key = self.__get_key(key)
545
543
  dtype = dtype or environ.dftype()
546
544
  r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
547
545
  return r
@@ -556,7 +554,7 @@ class RandomState(State):
556
554
  df = _check_py_seq(df)
557
555
  if size is None:
558
556
  size = u.math.shape(size) if size is not None else ()
559
- key = self.split_key() if key is None else _formalize_key(key)
557
+ key = self.__get_key(key)
560
558
  dtype = dtype or environ.dftype()
561
559
  r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
562
560
  return r
@@ -573,17 +571,12 @@ class RandomState(State):
573
571
  high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
574
572
  if size is None:
575
573
  size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
576
- key = self.split_key() if key is None else _formalize_key(key)
574
+ key = self.__get_key(key)
577
575
  dtype = dtype or environ.dftype()
578
576
  r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
579
577
  return u.maybe_decimal(r * unit)
580
578
 
581
- def __norm_cdf(
582
- self,
583
- x,
584
- sqrt2,
585
- dtype
586
- ):
579
+ def __norm_cdf(self, x, sqrt2, dtype):
587
580
  # Computes standard normal cumulative distribution function
588
581
  return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
589
582
 
@@ -639,7 +632,7 @@ class RandomState(State):
639
632
 
640
633
  # Uniformly fill tensor with values from [l, u], then translate to
641
634
  # [2l-1, 2u-1].
642
- key = self.split_key() if key is None else _formalize_key(key)
635
+ key = self.__get_key(key)
643
636
  out = jr.uniform(
644
637
  key, size, dtype,
645
638
  minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
@@ -677,7 +670,7 @@ class RandomState(State):
677
670
  jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
678
671
  if size is None:
679
672
  size = u.math.shape(p)
680
- key = self.split_key() if key is None else _formalize_key(key)
673
+ key = self.__get_key(key)
681
674
  r = jr.bernoulli(key, p=p, shape=_size2shape(size))
682
675
  return r
683
676
 
@@ -689,6 +682,7 @@ class RandomState(State):
689
682
  key: Optional[SeedOrKey] = None,
690
683
  dtype: DTypeLike = None
691
684
  ):
685
+ dtype = dtype or environ.dftype()
692
686
  mean = _check_py_seq(mean)
693
687
  sigma = _check_py_seq(sigma)
694
688
  mean = u.math.asarray(mean, dtype=dtype)
@@ -702,7 +696,7 @@ class RandomState(State):
702
696
  u.math.shape(mean) if mean is not None else (),
703
697
  u.math.shape(sigma) if sigma is not None else ()
704
698
  )
705
- key = self.split_key() if key is None else _formalize_key(key)
699
+ key = self.__get_key(key)
706
700
  dtype = dtype or environ.dftype()
707
701
  samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
708
702
  samples = _loc_scale(mean, sigma, samples)
@@ -729,7 +723,7 @@ class RandomState(State):
729
723
  )
730
724
  if size is None:
731
725
  size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
732
- key = self.split_key() if key is None else _formalize_key(key)
726
+ key = self.__get_key(key)
733
727
  r = jr.binomial(key, n, p, shape=_size2shape(size))
734
728
  dtype = dtype or environ.ditype()
735
729
  return u.math.asarray(r, dtype=dtype)
@@ -742,7 +736,7 @@ class RandomState(State):
742
736
  dtype: DTypeLike = None
743
737
  ):
744
738
  df = _check_py_seq(df)
745
- key = self.split_key() if key is None else _formalize_key(key)
739
+ key = self.__get_key(key)
746
740
  dtype = dtype or environ.dftype()
747
741
  if size is None:
748
742
  if jnp.ndim(df) == 0:
@@ -762,7 +756,7 @@ class RandomState(State):
762
756
  key: Optional[SeedOrKey] = None,
763
757
  dtype: DTypeLike = None
764
758
  ):
765
- key = self.split_key() if key is None else _formalize_key(key)
759
+ key = self.__get_key(key)
766
760
  alpha = _check_py_seq(alpha)
767
761
  dtype = dtype or environ.dftype()
768
762
  r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
@@ -778,7 +772,7 @@ class RandomState(State):
778
772
  p = _check_py_seq(p)
779
773
  if size is None:
780
774
  size = u.math.shape(p)
781
- key = self.split_key() if key is None else _formalize_key(key)
775
+ key = self.__get_key(key)
782
776
  dtype = dtype or environ.dftype()
783
777
  u_ = jr.uniform(key, size, dtype)
784
778
  r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
@@ -796,7 +790,7 @@ class RandomState(State):
796
790
  dtype: DTypeLike = None,
797
791
  check_valid: bool = True
798
792
  ):
799
- key = self.split_key() if key is None else _formalize_key(key)
793
+ key = self.__get_key(key)
800
794
  n = _check_py_seq(n)
801
795
  pvals = _check_py_seq(pvals)
802
796
  if check_valid:
@@ -807,7 +801,7 @@ class RandomState(State):
807
801
  size = _size2shape(size)
808
802
  n_max = int(np.max(jax.device_get(n)))
809
803
  batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
810
- r = _multinomial(key, pvals, n, n_max, batch_shape + size)
804
+ r = multinomial(key, pvals, n, n_max=n_max, shape=batch_shape + size)
811
805
  dtype = dtype or environ.ditype()
812
806
  return u.math.asarray(r, dtype=dtype)
813
807
 
@@ -832,7 +826,7 @@ class RandomState(State):
832
826
  cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
833
827
  unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
834
828
 
835
- key = self.split_key() if key is None else _formalize_key(key)
829
+ key = self.__get_key(key)
836
830
  if not jnp.ndim(mean) >= 1:
837
831
  raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
838
832
  if not jnp.ndim(cov) >= 2:
@@ -869,7 +863,7 @@ class RandomState(State):
869
863
  scale = _check_py_seq(scale)
870
864
  if size is None:
871
865
  size = u.math.shape(scale)
872
- key = self.split_key() if key is None else _formalize_key(key)
866
+ key = self.__get_key(key)
873
867
  dtype = dtype or environ.dftype()
874
868
  x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype)))
875
869
  r = x * scale
@@ -880,7 +874,7 @@ class RandomState(State):
880
874
  size: Optional[Size] = None,
881
875
  key: Optional[SeedOrKey] = None
882
876
  ):
883
- key = self.split_key() if key is None else _formalize_key(key)
877
+ key = self.__get_key(key)
884
878
  bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
885
879
  r = 2 * bernoulli_samples - 1
886
880
  return r
@@ -893,14 +887,14 @@ class RandomState(State):
893
887
  key: Optional[SeedOrKey] = None,
894
888
  dtype: DTypeLike = None
895
889
  ):
896
- key = self.split_key() if key is None else _formalize_key(key)
890
+ key = self.__get_key(key)
897
891
  dtype = dtype or environ.dftype()
898
892
  mu = u.math.asarray(_check_py_seq(mu), dtype=dtype)
899
893
  kappa = u.math.asarray(_check_py_seq(kappa), dtype=dtype)
900
894
  if size is None:
901
895
  size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
902
896
  size = _size2shape(size)
903
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
897
+ samples = von_mises_centered(key, kappa, size, dtype=dtype)
904
898
  samples = samples + mu
905
899
  samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
906
900
  return samples
@@ -912,7 +906,7 @@ class RandomState(State):
912
906
  key: Optional[SeedOrKey] = None,
913
907
  dtype: DTypeLike = None
914
908
  ):
915
- key = self.split_key() if key is None else _formalize_key(key)
909
+ key = self.__get_key(key)
916
910
  a = _check_py_seq(a)
917
911
  if size is None:
918
912
  size = u.math.shape(a)
@@ -933,7 +927,7 @@ class RandomState(State):
933
927
  key: Optional[SeedOrKey] = None,
934
928
  dtype: DTypeLike = None
935
929
  ):
936
- key = self.split_key() if key is None else _formalize_key(key)
930
+ key = self.__get_key(key)
937
931
  a = _check_py_seq(a)
938
932
  scale = _check_py_seq(scale)
939
933
  if size is None:
@@ -955,7 +949,7 @@ class RandomState(State):
955
949
  key: Optional[SeedOrKey] = None,
956
950
  dtype: DTypeLike = None
957
951
  ):
958
- key = self.split_key() if key is None else _formalize_key(key)
952
+ key = self.__get_key(key)
959
953
  shape = _size2shape(size) + (3,)
960
954
  dtype = dtype or environ.dftype()
961
955
  norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
@@ -979,7 +973,7 @@ class RandomState(State):
979
973
  if key is None:
980
974
  keys = self.split_key(2)
981
975
  else:
982
- keys = jr.split(_formalize_key(key), 2)
976
+ keys = jr.split(formalize_key(key, use_prng_key), 2)
983
977
  rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
984
978
  r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
985
979
  return r
@@ -993,7 +987,7 @@ class RandomState(State):
993
987
  dtype: DTypeLike = None
994
988
  ):
995
989
  dtype = dtype or environ.dftype()
996
- key = self.split_key() if key is None else _formalize_key(key)
990
+ key = self.__get_key(key)
997
991
  mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
998
992
  scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
999
993
  if size is None:
@@ -1049,9 +1043,9 @@ class RandomState(State):
1049
1043
  if key is None:
1050
1044
  keys = self.split_key(2)
1051
1045
  else:
1052
- keys = jr.split(_formalize_key(key), 2)
1046
+ keys = jr.split(formalize_key(key, use_prng_key), 2)
1053
1047
  n = jr.normal(keys[0], size, dtype=dtype)
1054
- two = _const(n, 2)
1048
+ two = const(n, 2)
1055
1049
  half_df = lax.div(df, two)
1056
1050
  g = jr.gamma(keys[1], half_df, size, dtype=dtype)
1057
1051
  r = n * jnp.sqrt(half_df / g)
@@ -1065,7 +1059,7 @@ class RandomState(State):
1065
1059
  dtype: DTypeLike = None
1066
1060
  ):
1067
1061
  dtype = dtype or environ.dftype()
1068
- key = self.split_key() if key is None else _formalize_key(key)
1062
+ key = self.__get_key(key)
1069
1063
  size = _size2shape(size)
1070
1064
  _check_shape("orthogonal", size)
1071
1065
  n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
@@ -1092,7 +1086,7 @@ class RandomState(State):
1092
1086
  if key is None:
1093
1087
  keys = self.split_key(3)
1094
1088
  else:
1095
- keys = jr.split(_formalize_key(key), 3)
1089
+ keys = jr.split(formalize_key(key, use_prng_key), 3)
1096
1090
  i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
1097
1091
  n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
1098
1092
  cond = jnp.greater(df, 1.0)
@@ -1109,7 +1103,7 @@ class RandomState(State):
1109
1103
  dtype: DTypeLike = None
1110
1104
  ):
1111
1105
  dtype = dtype or environ.dftype()
1112
- key = self.split_key() if key is None else _formalize_key(key)
1106
+ key = self.__get_key(key)
1113
1107
  a = _check_py_seq(a)
1114
1108
  if size is None:
1115
1109
  size = u.math.shape(a)
@@ -1123,7 +1117,7 @@ class RandomState(State):
1123
1117
  size: Optional[Size] = None,
1124
1118
  key: Optional[SeedOrKey] = None
1125
1119
  ):
1126
- key = self.split_key() if key is None else _formalize_key(key)
1120
+ key = self.__get_key(key)
1127
1121
  logits = _check_py_seq(logits)
1128
1122
  if size is None:
1129
1123
  size = list(u.math.shape(logits))
@@ -1141,10 +1135,12 @@ class RandomState(State):
1141
1135
  a = _check_py_seq(a)
1142
1136
  if size is None:
1143
1137
  size = u.math.shape(a)
1144
- dtype = dtype or environ.ditype()
1145
- r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
1146
- jax.ShapeDtypeStruct(size, dtype),
1147
- a)
1138
+ r = zipf(
1139
+ self.__get_key(key),
1140
+ a,
1141
+ shape=size,
1142
+ dtype=dtype or environ.ditype()
1143
+ )
1148
1144
  return r
1149
1145
 
1150
1146
  def power(
@@ -1158,10 +1154,12 @@ class RandomState(State):
1158
1154
  if size is None:
1159
1155
  size = u.math.shape(a)
1160
1156
  size = _size2shape(size)
1161
- dtype = dtype or environ.dftype()
1162
- r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
1163
- jax.ShapeDtypeStruct(size, dtype),
1164
- a)
1157
+ r = power(
1158
+ self.__get_key(key),
1159
+ a,
1160
+ shape=size,
1161
+ dtype=dtype or environ.dftype(),
1162
+ )
1165
1163
  return r
1166
1164
 
1167
1165
  def f(
@@ -1177,14 +1175,12 @@ class RandomState(State):
1177
1175
  if size is None:
1178
1176
  size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
1179
1177
  size = _size2shape(size)
1180
- d = {'dfnum': dfnum, 'dfden': dfden}
1181
- dtype = dtype or environ.dftype()
1182
- r = jax.pure_callback(
1183
- lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1184
- dfden=dfden_,
1185
- size=size).astype(dtype),
1186
- jax.ShapeDtypeStruct(size, dtype),
1187
- dfnum, dfden
1178
+ r = f(
1179
+ self.__get_key(key),
1180
+ dfnum,
1181
+ dfden,
1182
+ shape=size,
1183
+ dtype=dtype or environ.dftype(),
1188
1184
  )
1189
1185
  return r
1190
1186
 
@@ -1200,23 +1196,20 @@ class RandomState(State):
1200
1196
  ngood = _check_py_seq(ngood)
1201
1197
  nbad = _check_py_seq(nbad)
1202
1198
  nsample = _check_py_seq(nsample)
1203
-
1204
1199
  if size is None:
1205
- size = lax.broadcast_shapes(u.math.shape(ngood),
1206
- u.math.shape(nbad),
1207
- u.math.shape(nsample))
1200
+ size = lax.broadcast_shapes(
1201
+ u.math.shape(ngood),
1202
+ u.math.shape(nbad),
1203
+ u.math.shape(nsample)
1204
+ )
1208
1205
  size = _size2shape(size)
1209
- dtype = dtype or environ.ditype()
1210
- d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1211
- r = jax.pure_callback(
1212
- lambda d: np.random.hypergeometric(
1213
- ngood=d['ngood'],
1214
- nbad=d['nbad'],
1215
- nsample=d['nsample'],
1216
- size=size
1217
- ).astype(dtype),
1218
- jax.ShapeDtypeStruct(size, dtype),
1219
- d
1206
+ r = hypergeometric(
1207
+ self.__get_key(key),
1208
+ ngood,
1209
+ nbad,
1210
+ nsample,
1211
+ shape=size,
1212
+ dtype=dtype or environ.ditype(),
1220
1213
  )
1221
1214
  return r
1222
1215
 
@@ -1231,11 +1224,11 @@ class RandomState(State):
1231
1224
  if size is None:
1232
1225
  size = u.math.shape(p)
1233
1226
  size = _size2shape(size)
1234
- dtype = dtype or environ.ditype()
1235
- r = jax.pure_callback(
1236
- lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1237
- jax.ShapeDtypeStruct(size, dtype),
1238
- p
1227
+ r = logseries(
1228
+ self.__get_key(key),
1229
+ p,
1230
+ shape=size,
1231
+ dtype=dtype or environ.ditype()
1239
1232
  )
1240
1233
  return r
1241
1234
 
@@ -1256,15 +1249,13 @@ class RandomState(State):
1256
1249
  u.math.shape(dfden),
1257
1250
  u.math.shape(nonc))
1258
1251
  size = _size2shape(size)
1259
- d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1260
- dtype = dtype or environ.dftype()
1261
- r = jax.pure_callback(
1262
- lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1263
- dfden=x['dfden'],
1264
- nonc=x['nonc'],
1265
- size=size).astype(dtype),
1266
- jax.ShapeDtypeStruct(size, dtype),
1267
- d
1252
+ r = noncentral_f(
1253
+ self.__get_key(key),
1254
+ dfnum,
1255
+ dfden,
1256
+ nonc,
1257
+ shape=size,
1258
+ dtype=dtype or environ.dftype(),
1268
1259
  )
1269
1260
  return r
1270
1261
 
@@ -1327,291 +1318,3 @@ class RandomState(State):
1327
1318
 
1328
1319
  # default random generator
1329
1320
  DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1330
-
1331
-
1332
- # ---------------------------------------------------------------------------------------------------------------
1333
-
1334
-
1335
def _formalize_key(key):
    """Normalize a user-supplied seed or key into a JAX random key.

    Parameters
    ----------
    key : int, numpy.integer, jax.Array or numpy.ndarray
        One of:

        * a Python or NumPy integer, used as a seed;
        * an array whose dtype is a JAX typed PRNG key dtype, returned as is;
        * an integer array holding exactly one element, used as a seed;
        * a ``uint32`` array holding exactly two elements (raw key data).

    Returns
    -------
    jax.Array
        The key built from the seed, the typed key itself, or the raw
        ``uint32`` key data converted with ``u.math.asarray``.

    Raises
    ------
    TypeError
        If ``key`` is none of the accepted forms.
    """
    invalid_key_msg = 'key must be a int or an array with two uint32.'

    def _seed_to_key(seed):
        # ``use_prng_key`` is a module-level switch that selects the legacy
        # raw-uint32 constructor over the typed-key constructor.
        return jr.PRNGKey(seed) if use_prng_key else jr.key(seed)

    # NumPy integer scalars are neither ``int`` nor ``ndarray``; accept them
    # explicitly so that e.g. ``np.int64(3)`` is a valid seed.
    if isinstance(key, (int, np.integer)):
        return _seed_to_key(key)

    if isinstance(key, (jax.Array, np.ndarray)):
        if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key
        if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
            # The JAX key constructors expect a scalar seed, so collapse
            # one-element arrays such as shape ``(1,)`` to 0-d first.
            return _seed_to_key(key.reshape(()))
        if key.dtype != jnp.uint32 or key.size != 2:
            raise TypeError(invalid_key_msg)
        return u.math.asarray(key, dtype=jnp.uint32)

    raise TypeError(invalid_key_msg)
1351
-
1352
-
1353
def _size2shape(size):
    """Convert a NumPy-style ``size`` argument into a shape tuple.

    ``None`` maps to the empty shape ``()``, a tuple or list is returned as a
    tuple, and any other value is wrapped in a one-element tuple.
    """
    if size is None:
        return ()
    return tuple(size) if isinstance(size, (tuple, list)) else (size,)
1360
-
1361
-
1362
def _check_shape(name, shape, *param_shapes):
    """Validate that ``shape`` already equals the broadcast of all shapes.

    Parameters
    ----------
    name : str
        Label inserted at the start of the error message.
    shape : tuple
        The requested output shape.
    *param_shapes : tuple
        Shapes of the distribution parameters. With none given, nothing is
        checked.

    Raises
    ------
    ValueError
        If broadcasting ``shape`` with ``param_shapes`` yields a shape
        different from ``shape``.
    """
    if not param_shapes:
        return

    broadcasted = lax.broadcast_shapes(shape, *param_shapes)
    if shape != broadcasted:
        template = (
            "{} parameter shapes must be broadcast-compatible with shape "
            "argument, and the result of broadcasting the shapes must equal "
            "the shape argument, but got result {} for shape argument {}."
        )
        raise ValueError(template.format(name, broadcasted, shape))
1374
-
1375
-
1376
def _is_python_scalar(x):
    """Tell whether ``x`` should be treated as a scalar.

    Objects exposing an ``aval`` attribute report ``x.aval.weak_type``.
    Anything else counts as a scalar when ``np.ndim(x)`` is zero or when it
    is an instance of ``bool``, ``int``, ``float`` or ``complex``.
    """
    if hasattr(x, 'aval'):
        return x.aval.weak_type
    return np.ndim(x) == 0 or isinstance(x, (bool, int, float, complex))
1385
-
1386
-
1387
- python_scalar_dtypes = {
1388
- bool: np.dtype('bool'),
1389
- int: np.dtype('int64'),
1390
- float: np.dtype('float64'),
1391
- complex: np.dtype('complex128'),
1392
- }
1393
-
1394
-
1395
- def _dtype(
1396
- x,
1397
- *,
1398
- canonicalize: bool = False
1399
- ):
1400
- """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
1401
- if x is None:
1402
- raise ValueError(f"Invalid argument to dtype: {x}.")
1403
- elif isinstance(x, type) and x in python_scalar_dtypes:
1404
- dt = python_scalar_dtypes[x]
1405
- elif type(x) in python_scalar_dtypes:
1406
- dt = python_scalar_dtypes[type(x)]
1407
- elif hasattr(x, 'dtype'):
1408
- dt = x.dtype
1409
- else:
1410
- dt = np.result_type(x)
1411
- return dtypes.canonicalize_dtype(dt) if canonicalize else dt
1412
-
1413
-
1414
def _const(example, val):
    """Build a constant holding ``val`` with a dtype derived from ``example``.

    For non-scalar examples the result is a NumPy array with the
    canonicalized ``example.dtype``. For scalar examples (as decided by
    ``_is_python_scalar``) ``val`` is first converted with the scalar type of
    ``example``; it is returned as that plain scalar when its canonical dtype
    already matches, and wrapped in a NumPy array otherwise.
    """
    if not _is_python_scalar(example):
        return np.array(val, dtypes.canonicalize_dtype(example.dtype))

    target_dtype = dtypes.canonicalize_dtype(type(example))
    val = dtypes.scalar_type_of(example)(val)
    if target_dtype == _dtype(val, canonicalize=True):
        return val
    return np.array(val, target_dtype)
1425
-
1426
-
1427
@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
    """Draw category indices by comparing uniforms with cumulative sums.

    Parameters
    ----------
    key : jax.Array
        Random key passed to ``jr.uniform``.
    p : jax.Array
        Values whose cumulative sum along the last axis is compared with the
        uniform draws.
    shape : tuple
        Sample shape (static under ``jit``); when empty, ``p.shape[:-1]`` is
        used instead.

    Returns
    -------
    jax.Array
        For each draw, the number of cumulative-sum entries strictly below
        the uniform value.
    """
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref: https://stackoverflow.com/a/34190035
    sample_shape = shape or p.shape[:-1]
    cumulative = jnp.cumsum(p, axis=-1)
    # A trailing singleton axis lets each draw broadcast against all categories.
    draws = jr.uniform(key, shape=sample_shape + (1,))
    return jnp.sum(cumulative < draws, axis=-1)
1439
-
1440
-
1441
def _scatter_add_one(operand, indices, updates):
    """Scatter-add ``updates`` into ``operand`` along its leading axis.

    Thin wrapper around ``lax.scatter_add`` with fixed dimension numbers: no
    update window dimensions, operand axis 0 inserted as a window dimension,
    and index component 0 mapped to operand axis 0.
    """
    dimension_numbers = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter_add(operand, indices, updates, dimension_numbers)
1456
-
1457
-
1458
def _reshape(x, shape):
    """Reshape ``x`` with NumPy for host values and with ``jnp`` otherwise.

    Python ints/floats, NumPy arrays and NumPy scalars go through
    ``np.reshape``; every other input goes through ``jnp.reshape``.
    """
    host_types = (int, float, np.ndarray, np.generic)
    reshape = np.reshape if isinstance(x, host_types) else jnp.reshape
    return reshape(x, shape)
1463
-
1464
-
1465
def _promote_shapes(*args, shape=()):
    """Left-pad the shapes of ``args`` with ones up to a common rank.

    The target rank is the rank of the broadcast of ``shape`` with every
    argument's shape. Arguments that already have that rank are passed
    through untouched. With fewer than two arguments and an empty ``shape``
    the input tuple is returned as is; otherwise a list is returned.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args

    shapes = [u.math.shape(arg) for arg in args]
    num_dims = len(lax.broadcast_shapes(shape, *shapes))

    promoted = []
    for arg, arg_shape in zip(args, shapes):
        missing = num_dims - len(arg_shape)
        if missing > 0:
            arg = _reshape(arg, (1,) * missing + arg_shape)
        promoted.append(arg)
    return promoted
1479
-
1480
-
1481
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by accumulating categorical draws.

    Parameters
    ----------
    key : jax.Array
        Random key forwarded to ``_categorical``.
    p : array_like
        Category weights along the last axis.
    n : array_like
        Number of trials; broadcast against the leading axes of ``p``.
    n_max : int
        Static number of categorical draws generated per batch element.
    shape : tuple, optional
        Static batch shape; when empty, the leading axes of ``p`` are used.

    Returns
    -------
    jax.Array
        Counts with shape ``shape + p.shape[-1:]``.
    """
    p_batch_shape = u.math.shape(p)[:-1]
    if u.math.shape(n) != p_batch_shape:
        broadcast_shape = lax.broadcast_shapes(u.math.shape(n), p_batch_shape)
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])

    shape = shape or p.shape[:-1]
    event_shape = p.shape[-1:]
    if n_max == 0:
        return jnp.zeros(shape + event_shape, dtype=jnp.result_type(int))

    # get indices from categorical distribution then gather the result
    indices = _categorical(key, p, (n_max,) + shape)

    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        within_count = jnp.arange(n_max) < jnp.expand_dims(n, -1)
        mask = _promote_shapes(within_count, shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws collapse to index 0, so category 0 is over-counted by
        # ``n_max - n``; ``excess`` removes exactly that surplus afterwards.
        surplus = jnp.expand_dims(n_max - n, -1)
        padding = jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))
        excess = jnp.concatenate([surplus, padding], -1)
    else:
        mask = 1
        excess = 0

    # NB: we transpose to move batch shape to the front
    indices_2D = jnp.reshape(indices * mask, (n_max, -1)).T
    counts_init = jnp.zeros(
        (indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype
    )
    samples_2D = vmap(_scatter_add_one)(
        counts_init,
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype),
    )
    return jnp.reshape(samples_2D, shape + event_shape) - excess
1516
-
1517
-
1518
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
def _von_mises_centered(key, concentration, shape, dtype=None):
    """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

    Parameters
    ----------
    key : jax.Array
        Random key; it is split on every rejection-sampling iteration.
    concentration : array_like
        Concentration parameter, converted to ``dtype`` and broadcast to
        ``shape``.
    shape : tuple
        Static output shape; when empty, the shape of ``concentration`` is
        used.
    dtype : dtype, optional
        Static floating dtype (float16, float32 or float64). Defaults to
        ``environ.dftype()`` when ``None``.

    Returns
    -------
    out: array_like
        centered samples from von Mises

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the supported floating dtypes.

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
           Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    """
    shape = shape or u.math.shape(concentration)
    # Test against None explicitly: ``dtype or default`` would discard a
    # valid dtype object that happens to evaluate as false (NumPy dtype
    # instances without fields can do so).
    if dtype is None:
        dtype = environ.dftype()
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)

    # Below this concentration the exact expression for ``s`` is replaced by
    # the approximation ``1 / concentration``; the threshold depends on the
    # precision of the dtype.
    if dtype == jnp.float16:
        s_cutoff = 1.8e-1
    elif dtype == jnp.float32:
        s_cutoff = 2e-2
    elif dtype == jnp.float64:
        s_cutoff = 1.2e-4
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(state):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = state
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(state):
        """Propose new candidates and accept them where not yet done."""
        i, key, done, _, w = state
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u_ = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u_)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u_, w

    init_done = jnp.zeros(shape, dtype=bool)
    # The loop carry must keep the same dtype across iterations, and the body
    # emits ``concentration.dtype`` values, so allocate the initial buffers
    # with that dtype instead of the default float dtype.
    init_u = jnp.zeros(shape, dtype=concentration.dtype)
    init_w = jnp.zeros(shape, dtype=concentration.dtype)

    _, _, done, uu, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(uu) * jnp.arccos(w)
1597
-
1598
-
1599
def _loc_scale(loc, scale, value):
    """Apply an optional scale and an optional shift to ``value``.

    Computes ``value * scale + loc``, skipping the multiplication when
    ``scale`` is ``None`` and the addition when ``loc`` is ``None``. With
    both ``None`` the input is returned unchanged.
    """
    if scale is not None:
        value = value * scale
    if loc is not None:
        value = value + loc
    return value
1614
-
1615
-
1616
def _check_py_seq(seq):
    """Convert a tuple or list with ``u.math.asarray``; pass anything else through."""
    if isinstance(seq, (tuple, list)):
        return u.math.asarray(seq)
    return seq