brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1498 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241010.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt +0 -0
@@ -16,31 +16,14 @@
16
16
 
17
17
  # -*- coding: utf-8 -*-
18
18
 
19
- from collections import namedtuple
20
- from contextlib import contextmanager
21
- from functools import partial
22
- from operator import index
23
19
  from typing import Optional
24
20
 
25
- import brainunit as bu
26
- import jax
27
- import jax.numpy as jnp
28
- import jax.random as jr
29
21
  import numpy as np
30
- from jax import jit, vmap
31
- from jax import lax, core, dtypes
32
22
 
33
- from brainstate import environ
34
- from ._random_for_unit import uniform_for_unit, permutation_for_unit
35
- from ._state import State
36
- from .transform._error_if import jit_error_if
37
- from .typing import DTypeLike, Size, SeedOrKey
23
+ from brainstate.typing import DTypeLike, Size, SeedOrKey
24
+ from ._rand_state import RandomState, DEFAULT
38
25
 
39
26
  __all__ = [
40
- 'RandomState', 'DEFAULT',
41
-
42
- 'seed', 'default_rng', 'split_key', 'split_keys', 'seed_context',
43
-
44
27
  # numpy compatibility
45
28
  'rand', 'randint', 'random_integers', 'randn', 'random',
46
29
  'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta',
@@ -58,1186 +41,6 @@ __all__ = [
58
41
  ]
59
42
 
60
43
 
61
- class RandomState(State):
62
- """RandomState that track the random generator state. """
63
- __slots__ = ()
64
-
65
- def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
66
- """RandomState constructor.
67
-
68
- Parameters
69
- ----------
70
- seed_or_key: int, Array, optional
71
- It can be an integer for initial seed of the random number generator,
72
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
73
- """
74
- with jax.ensure_compile_time_eval():
75
- if seed_or_key is None:
76
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
77
- if isinstance(seed_or_key, int):
78
- key = jr.PRNGKey(seed_or_key)
79
- else:
80
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
81
- raise ValueError('key must be an array with dtype uint32. '
82
- f'But we got {seed_or_key}')
83
- key = seed_or_key
84
- super().__init__(key)
85
-
86
- def __repr__(self) -> str:
87
- print_code = repr(self.value)
88
- i = print_code.index('(')
89
- return f'{self.__class__.__name__}(key={print_code[i:]})'
90
-
91
- def _check_if_deleted(self):
92
- if isinstance(self._value, jax.Array) and not isinstance(self._value, jax.core.Tracer) and self._value.is_deleted():
93
- self.seed()
94
-
95
- # ------------------- #
96
- # seed and random key #
97
- # ------------------- #
98
-
99
- def clone(self):
100
- return type(self)(self.split_key())
101
-
102
- def seed(self, seed_or_key: Optional[SeedOrKey] = None):
103
- """Sets a new random seed.
104
-
105
- Parameters
106
- ----------
107
- seed_or_key: int, ArrayLike, optional
108
- It can be an integer for initial seed of the random number generator,
109
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
110
- """
111
- if seed_or_key is None:
112
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
113
- if isinstance(seed_or_key, int):
114
- key = jr.PRNGKey(seed_or_key)
115
- else:
116
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
117
- raise ValueError('key must be an array with dtype uint32. '
118
- f'But we got {seed_or_key}')
119
- key = seed_or_key
120
- self.value = key
121
-
122
- def split_key(self):
123
- """Create a new seed from the current seed.
124
- """
125
- if not isinstance(self.value, jax.Array):
126
- self.value = jnp.asarray(self.value, dtype=jnp.uint32)
127
- keys = jr.split(self.value, num=2)
128
- self.value = keys[0]
129
- return keys[1]
130
-
131
- def split_keys(self, n: int):
132
- """Create multiple seeds from the current seed. This is used
133
- internally by `pmap` and `vmap` to ensure that random numbers
134
- are different in parallel threads.
135
-
136
- Parameters
137
- ----------
138
- n : int
139
- The number of seeds to generate.
140
- """
141
- keys = jr.split(self.value, n + 1)
142
- self.value = keys[0]
143
- return keys[1:]
144
-
145
- # ---------------- #
146
- # random functions #
147
- # ---------------- #
148
-
149
- def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
150
- key = self.split_key() if key is None else _formalize_key(key)
151
- dtype = dtype or environ.dftype()
152
- r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
153
- return r
154
-
155
- def randint(
156
- self,
157
- low,
158
- high=None,
159
- size: Optional[Size] = None,
160
- dtype: DTypeLike = None,
161
- key: Optional[SeedOrKey] = None
162
- ):
163
- if high is None:
164
- high = low
165
- low = 0
166
- high = _check_py_seq(high)
167
- low = _check_py_seq(low)
168
- if size is None:
169
- size = lax.broadcast_shapes(jnp.shape(low),
170
- jnp.shape(high))
171
- key = self.split_key() if key is None else _formalize_key(key)
172
- dtype = dtype or environ.ditype()
173
- r = jr.randint(key,
174
- shape=_size2shape(size),
175
- minval=low, maxval=high, dtype=dtype)
176
- return r
177
-
178
- def random_integers(
179
- self,
180
- low,
181
- high=None,
182
- size: Optional[Size] = None,
183
- key: Optional[SeedOrKey] = None,
184
- dtype: DTypeLike = None,
185
- ):
186
- low = _check_py_seq(low)
187
- high = _check_py_seq(high)
188
- if high is None:
189
- high = low
190
- low = 1
191
- high += 1
192
- if size is None:
193
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
194
- key = self.split_key() if key is None else _formalize_key(key)
195
- dtype = dtype or environ.ditype()
196
- r = jr.randint(key,
197
- shape=_size2shape(size),
198
- minval=low,
199
- maxval=high,
200
- dtype=dtype)
201
- return r
202
-
203
- def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
204
- key = self.split_key() if key is None else _formalize_key(key)
205
- dtype = dtype or environ.dftype()
206
- r = jr.normal(key, shape=dn, dtype=dtype)
207
- return r
208
-
209
- def random(self,
210
- size: Optional[Size] = None,
211
- key: Optional[SeedOrKey] = None,
212
- dtype: DTypeLike = None):
213
- dtype = dtype or environ.dftype()
214
- key = self.split_key() if key is None else _formalize_key(key)
215
- r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
216
- return r
217
-
218
- def random_sample(self,
219
- size: Optional[Size] = None,
220
- key: Optional[SeedOrKey] = None,
221
- dtype: DTypeLike = None):
222
- r = self.random(size=size, key=key, dtype=dtype)
223
- return r
224
-
225
- def ranf(self,
226
- size: Optional[Size] = None,
227
- key: Optional[SeedOrKey] = None,
228
- dtype: DTypeLike = None):
229
- r = self.random(size=size, key=key, dtype=dtype)
230
- return r
231
-
232
- def sample(self,
233
- size: Optional[Size] = None,
234
- key: Optional[SeedOrKey] = None,
235
- dtype: DTypeLike = None):
236
- r = self.random(size=size, key=key, dtype=dtype)
237
- return r
238
-
239
- def choice(self,
240
- a,
241
- size: Optional[Size] = None,
242
- replace=True,
243
- p=None,
244
- key: Optional[SeedOrKey] = None):
245
- a = _check_py_seq(a)
246
- p = _check_py_seq(p)
247
- key = self.split_key() if key is None else _formalize_key(key)
248
- r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
249
- return r
250
-
251
- def permutation(self,
252
- x,
253
- axis: int = 0,
254
- independent: bool = False,
255
- key: Optional[SeedOrKey] = None):
256
- x = _check_py_seq(x)
257
- key = self.split_key() if key is None else _formalize_key(key)
258
- r = permutation_for_unit(key, x, axis=axis, independent=independent)
259
- return r
260
-
261
- def shuffle(self,
262
- x,
263
- axis=0,
264
- key: Optional[SeedOrKey] = None):
265
- key = self.split_key() if key is None else _formalize_key(key)
266
- x = permutation_for_unit(key, x, axis=axis)
267
- return x
268
-
269
- def beta(self,
270
- a,
271
- b,
272
- size: Optional[Size] = None,
273
- key: Optional[SeedOrKey] = None,
274
- dtype: DTypeLike = None):
275
- a = _check_py_seq(a)
276
- b = _check_py_seq(b)
277
- if size is None:
278
- size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
279
- key = self.split_key() if key is None else _formalize_key(key)
280
- dtype = dtype or environ.dftype()
281
- r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
282
- return r
283
-
284
- def exponential(self,
285
- scale=None,
286
- size: Optional[Size] = None,
287
- key: Optional[SeedOrKey] = None,
288
- dtype: DTypeLike = None):
289
- if size is None:
290
- size = jnp.shape(scale)
291
- key = self.split_key() if key is None else _formalize_key(key)
292
- dtype = dtype or environ.dftype()
293
- scale = jnp.asarray(scale, dtype=dtype)
294
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
295
- if scale is not None:
296
- r = r / scale
297
- return r
298
-
299
- def gamma(self,
300
- shape,
301
- scale=None,
302
- size: Optional[Size] = None,
303
- key: Optional[SeedOrKey] = None,
304
- dtype: DTypeLike = None):
305
- shape = _check_py_seq(shape)
306
- scale = _check_py_seq(scale)
307
- if size is None:
308
- size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
309
- key = self.split_key() if key is None else _formalize_key(key)
310
- dtype = dtype or environ.dftype()
311
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
312
- if scale is not None:
313
- r = r * scale
314
- return r
315
-
316
- def gumbel(self,
317
- loc=None,
318
- scale=None,
319
- size: Optional[Size] = None,
320
- key: Optional[SeedOrKey] = None,
321
- dtype: DTypeLike = None):
322
- loc = _check_py_seq(loc)
323
- scale = _check_py_seq(scale)
324
- if size is None:
325
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
326
- key = self.split_key() if key is None else _formalize_key(key)
327
- dtype = dtype or environ.dftype()
328
- r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
329
- return r
330
-
331
- def laplace(self,
332
- loc=None,
333
- scale=None,
334
- size: Optional[Size] = None,
335
- key: Optional[SeedOrKey] = None,
336
- dtype: DTypeLike = None):
337
- loc = _check_py_seq(loc)
338
- scale = _check_py_seq(scale)
339
- if size is None:
340
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
341
- key = self.split_key() if key is None else _formalize_key(key)
342
- dtype = dtype or environ.dftype()
343
- r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
344
- return r
345
-
346
- def logistic(self,
347
- loc=None,
348
- scale=None,
349
- size: Optional[Size] = None,
350
- key: Optional[SeedOrKey] = None,
351
- dtype: DTypeLike = None):
352
- loc = _check_py_seq(loc)
353
- scale = _check_py_seq(scale)
354
- if size is None:
355
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
356
- key = self.split_key() if key is None else _formalize_key(key)
357
- dtype = dtype or environ.dftype()
358
- r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
359
- return r
360
-
361
- def normal(self,
362
- loc=None,
363
- scale=None,
364
- size: Optional[Size] = None,
365
- key: Optional[SeedOrKey] = None,
366
- dtype: DTypeLike = None):
367
- loc = _check_py_seq(loc)
368
- scale = _check_py_seq(scale)
369
- if size is None:
370
- size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
371
- key = self.split_key() if key is None else _formalize_key(key)
372
- dtype = dtype or environ.dftype()
373
- r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
374
- return r
375
-
376
- def pareto(self,
377
- a,
378
- size: Optional[Size] = None,
379
- key: Optional[SeedOrKey] = None,
380
- dtype: DTypeLike = None):
381
- if size is None:
382
- size = jnp.shape(a)
383
- key = self.split_key() if key is None else _formalize_key(key)
384
- dtype = dtype or environ.dftype()
385
- a = jnp.asarray(a, dtype=dtype)
386
- r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
387
- return r
388
-
389
- def poisson(self,
390
- lam=1.0,
391
- size: Optional[Size] = None,
392
- key: Optional[SeedOrKey] = None,
393
- dtype: DTypeLike = None):
394
- lam = _check_py_seq(lam)
395
- if size is None:
396
- size = jnp.shape(lam)
397
- key = self.split_key() if key is None else _formalize_key(key)
398
- dtype = dtype or environ.ditype()
399
- r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
400
- return r
401
-
402
- def standard_cauchy(self,
403
- size: Optional[Size] = None,
404
- key: Optional[SeedOrKey] = None,
405
- dtype: DTypeLike = None):
406
- key = self.split_key() if key is None else _formalize_key(key)
407
- dtype = dtype or environ.dftype()
408
- r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
409
- return r
410
-
411
- def standard_exponential(self,
412
- size: Optional[Size] = None,
413
- key: Optional[SeedOrKey] = None,
414
- dtype: DTypeLike = None):
415
- key = self.split_key() if key is None else _formalize_key(key)
416
- dtype = dtype or environ.dftype()
417
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
418
- return r
419
-
420
- def standard_gamma(self,
421
- shape,
422
- size: Optional[Size] = None,
423
- key: Optional[SeedOrKey] = None,
424
- dtype: DTypeLike = None):
425
- shape = _check_py_seq(shape)
426
- if size is None:
427
- size = jnp.shape(shape)
428
- key = self.split_key() if key is None else _formalize_key(key)
429
- dtype = dtype or environ.dftype()
430
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
431
- return r
432
-
433
- def standard_normal(self,
434
- size: Optional[Size] = None,
435
- key: Optional[SeedOrKey] = None,
436
- dtype: DTypeLike = None):
437
- key = self.split_key() if key is None else _formalize_key(key)
438
- dtype = dtype or environ.dftype()
439
- r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
440
- return r
441
-
442
- def standard_t(self, df,
443
- size: Optional[Size] = None,
444
- key: Optional[SeedOrKey] = None,
445
- dtype: DTypeLike = None):
446
- df = _check_py_seq(df)
447
- if size is None:
448
- size = jnp.shape(size)
449
- key = self.split_key() if key is None else _formalize_key(key)
450
- dtype = dtype or environ.dftype()
451
- r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
452
- return r
453
-
454
- def uniform(self,
455
- low=0.0,
456
- high=1.0,
457
- size: Optional[Size] = None,
458
- key: Optional[SeedOrKey] = None,
459
- dtype: DTypeLike = None):
460
- low = _check_py_seq(low)
461
- high = _check_py_seq(high)
462
- if size is None:
463
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
464
- key = self.split_key() if key is None else _formalize_key(key)
465
- dtype = dtype or environ.dftype()
466
- r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
467
- return r
468
-
469
- def __norm_cdf(self, x, sqrt2, dtype):
470
- # Computes standard normal cumulative distribution function
471
- return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
472
-
473
- def truncated_normal(
474
- self,
475
- lower,
476
- upper,
477
- size: Optional[Size] = None,
478
- loc=0.,
479
- scale=1.,
480
- key: Optional[SeedOrKey] = None,
481
- dtype: DTypeLike = None
482
- ):
483
- lower = _check_py_seq(lower)
484
- upper = _check_py_seq(upper)
485
- loc = _check_py_seq(loc)
486
- scale = _check_py_seq(scale)
487
- dtype = dtype or environ.dftype()
488
-
489
- lower = bu.math.asarray(lower, dtype=dtype)
490
- upper = bu.math.asarray(upper, dtype=dtype)
491
- loc = bu.math.asarray(loc, dtype=dtype)
492
- scale = bu.math.asarray(scale, dtype=dtype)
493
- unit = bu.get_unit(lower)
494
- lower, upper, loc, scale = (
495
- lower.mantissa if isinstance(lower, bu.Quantity) else lower,
496
- bu.Quantity(upper).in_unit(unit).mantissa,
497
- bu.Quantity(loc).in_unit(unit).mantissa,
498
- bu.Quantity(scale).in_unit(unit).mantissa
499
- )
500
-
501
- jit_error_if(
502
- bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
503
- "mean is more than 2 std from [lower, upper] in truncated_normal. "
504
- "The distribution of values may be incorrect."
505
- )
506
-
507
- if size is None:
508
- size = bu.math.broadcast_shapes(jnp.shape(lower),
509
- jnp.shape(upper),
510
- jnp.shape(loc),
511
- jnp.shape(scale))
512
-
513
- # Values are generated by using a truncated uniform distribution and
514
- # then using the inverse CDF for the normal distribution.
515
- # Get upper and lower cdf values
516
- sqrt2 = np.array(np.sqrt(2), dtype=dtype)
517
- l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
518
- u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
519
-
520
- # Uniformly fill tensor with values from [l, u], then translate to
521
- # [2l-1, 2u-1].
522
- key = self.split_key() if key is None else _formalize_key(key)
523
- out = uniform_for_unit(
524
- key, size, dtype,
525
- minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
526
- maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))
527
- )
528
-
529
- # Use inverse cdf transform for normal distribution to get truncated
530
- # standard normal
531
- out = lax.erf_inv(out)
532
-
533
- # Transform to proper mean, std
534
- out = out * scale * sqrt2 + loc
535
-
536
- # Clamp to ensure it's in the proper range
537
- out = jnp.clip(
538
- out,
539
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
541
- )
542
- return out if unit.is_unitless else bu.Quantity(out, unit=unit)
543
-
544
- def _check_p(self, p):
545
- raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
546
-
547
- def bernoulli(self,
548
- p,
549
- size: Optional[Size] = None,
550
- key: Optional[SeedOrKey] = None):
551
- p = _check_py_seq(p)
552
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
553
- if size is None:
554
- size = jnp.shape(p)
555
- key = self.split_key() if key is None else _formalize_key(key)
556
- r = jr.bernoulli(key, p=p, shape=_size2shape(size))
557
- return r
558
-
559
- def lognormal(
560
- self,
561
- mean=None,
562
- sigma=None,
563
- size: Optional[Size] = None,
564
- key: Optional[SeedOrKey] = None,
565
- dtype: DTypeLike = None
566
- ):
567
- mean = _check_py_seq(mean)
568
- sigma = _check_py_seq(sigma)
569
- mean = bu.math.asarray(mean, dtype=dtype)
570
- sigma = bu.math.asarray(sigma, dtype=dtype)
571
- unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
572
- mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
573
- sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
574
-
575
- if size is None:
576
- size = jnp.broadcast_shapes(
577
- jnp.shape(mean),
578
- jnp.shape(sigma)
579
- )
580
- key = self.split_key() if key is None else _formalize_key(key)
581
- dtype = dtype or environ.dftype()
582
- samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
583
- samples = _loc_scale(mean, sigma, samples)
584
- samples = jnp.exp(samples)
585
- return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
586
-
587
- def binomial(self,
588
- n,
589
- p,
590
- size: Optional[Size] = None,
591
- key: Optional[SeedOrKey] = None,
592
- dtype: DTypeLike = None):
593
- n = _check_py_seq(n)
594
- p = _check_py_seq(p)
595
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
596
- if size is None:
597
- size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
598
- key = self.split_key() if key is None else _formalize_key(key)
599
- r = _binomial(key, p, n, shape=_size2shape(size))
600
- dtype = dtype or environ.ditype()
601
- return jnp.asarray(r, dtype=dtype)
602
-
603
- def chisquare(self,
604
- df,
605
- size: Optional[Size] = None,
606
- key: Optional[SeedOrKey] = None,
607
- dtype: DTypeLike = None):
608
- df = _check_py_seq(df)
609
- key = self.split_key() if key is None else _formalize_key(key)
610
- dtype = dtype or environ.dftype()
611
- if size is None:
612
- if jnp.ndim(df) == 0:
613
- dist = jr.normal(key, (df,), dtype=dtype) ** 2
614
- dist = dist.sum()
615
- else:
616
- raise NotImplementedError('Do not support non-scale "df" when "size" is None')
617
- else:
618
- dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
619
- dist = dist.sum(axis=0)
620
- return dist
621
-
622
- def dirichlet(self,
623
- alpha,
624
- size: Optional[Size] = None,
625
- key: Optional[SeedOrKey] = None,
626
- dtype: DTypeLike = None):
627
- key = self.split_key() if key is None else _formalize_key(key)
628
- alpha = _check_py_seq(alpha)
629
- dtype = dtype or environ.dftype()
630
- r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
631
- return r
632
-
633
- def geometric(self,
634
- p,
635
- size: Optional[Size] = None,
636
- key: Optional[SeedOrKey] = None,
637
- dtype: DTypeLike = None):
638
- p = _check_py_seq(p)
639
- if size is None:
640
- size = jnp.shape(p)
641
- key = self.split_key() if key is None else _formalize_key(key)
642
- dtype = dtype or environ.dftype()
643
- u = uniform_for_unit(key, size, dtype=dtype)
644
- r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
645
- return r
646
-
647
- def _check_p2(self, p):
648
- raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
649
-
650
- def multinomial(self,
651
- n,
652
- pvals,
653
- size: Optional[Size] = None,
654
- key: Optional[SeedOrKey] = None,
655
- dtype: DTypeLike = None):
656
- key = self.split_key() if key is None else _formalize_key(key)
657
- n = _check_py_seq(n)
658
- pvals = _check_py_seq(pvals)
659
- jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
660
- if isinstance(n, jax.core.Tracer):
661
- raise ValueError("The total count parameter `n` should not be a jax abstract array.")
662
- size = _size2shape(size)
663
- n_max = int(np.max(jax.device_get(n)))
664
- batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
665
- r = _multinomial(key, pvals, n, n_max, batch_shape + size)
666
- dtype = dtype or environ.ditype()
667
- return jnp.asarray(r, dtype=dtype)
668
-
669
- def multivariate_normal(
670
- self,
671
- mean,
672
- cov,
673
- size: Optional[Size] = None,
674
- method: str = 'cholesky',
675
- key: Optional[SeedOrKey] = None,
676
- dtype: DTypeLike = None
677
- ):
678
- if method not in {'svd', 'eigh', 'cholesky'}:
679
- raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
680
- dtype = dtype or environ.dftype()
681
- mean = bu.math.asarray(_check_py_seq(mean), dtype=dtype)
682
- cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
683
- if isinstance(mean, bu.Quantity):
684
- assert isinstance(cov, bu.Quantity)
685
- assert mean.unit ** 2 == cov.unit
686
- mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
687
- cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
688
- unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
689
-
690
- key = self.split_key() if key is None else _formalize_key(key)
691
- if not jnp.ndim(mean) >= 1:
692
- raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
693
- if not jnp.ndim(cov) >= 2:
694
- raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
695
- n = mean.shape[-1]
696
- if jnp.shape(cov)[-2:] != (n, n):
697
- raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
698
- f"but got cov.shape == {jnp.shape(cov)}.")
699
- if size is None:
700
- size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
701
- else:
702
- size = _size2shape(size)
703
- _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
704
-
705
- if method == 'svd':
706
- (u, s, _) = jnp.linalg.svd(cov)
707
- factor = u * jnp.sqrt(s[..., None, :])
708
- elif method == 'eigh':
709
- (w, v) = jnp.linalg.eigh(cov)
710
- factor = v * jnp.sqrt(w[..., None, :])
711
- else: # 'cholesky'
712
- factor = jnp.linalg.cholesky(cov)
713
- normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
714
- r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
715
- return r if unit.is_unitless else bu.Quantity(r, unit=unit)
716
-
717
- def rayleigh(self,
718
- scale=1.0,
719
- size: Optional[Size] = None,
720
- key: Optional[SeedOrKey] = None,
721
- dtype: DTypeLike = None):
722
- scale = _check_py_seq(scale)
723
- if size is None:
724
- size = jnp.shape(scale)
725
- key = self.split_key() if key is None else _formalize_key(key)
726
- dtype = dtype or environ.dftype()
727
- x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
728
- r = x * scale
729
- return r
730
-
731
- def triangular(self,
732
- size: Optional[Size] = None,
733
- key: Optional[SeedOrKey] = None):
734
- key = self.split_key() if key is None else _formalize_key(key)
735
- bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
736
- r = 2 * bernoulli_samples - 1
737
- return r
738
-
739
- def vonmises(self,
740
- mu,
741
- kappa,
742
- size: Optional[Size] = None,
743
- key: Optional[SeedOrKey] = None,
744
- dtype: DTypeLike = None):
745
- key = self.split_key() if key is None else _formalize_key(key)
746
- dtype = dtype or environ.dftype()
747
- mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
748
- kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
749
- if size is None:
750
- size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
751
- size = _size2shape(size)
752
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
753
- samples = samples + mu
754
- samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
755
- return samples
756
-
757
- def weibull(self,
758
- a,
759
- size: Optional[Size] = None,
760
- key: Optional[SeedOrKey] = None,
761
- dtype: DTypeLike = None):
762
- key = self.split_key() if key is None else _formalize_key(key)
763
- a = _check_py_seq(a)
764
- if size is None:
765
- size = jnp.shape(a)
766
- else:
767
- if jnp.size(a) > 1:
768
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
769
- size = _size2shape(size)
770
- dtype = dtype or environ.dftype()
771
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
772
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
773
- return r
774
-
775
- def weibull_min(self,
776
- a,
777
- scale=None,
778
- size: Optional[Size] = None,
779
- key: Optional[SeedOrKey] = None,
780
- dtype: DTypeLike = None):
781
- key = self.split_key() if key is None else _formalize_key(key)
782
- a = _check_py_seq(a)
783
- scale = _check_py_seq(scale)
784
- if size is None:
785
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
786
- else:
787
- if jnp.size(a) > 1:
788
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
789
- size = _size2shape(size)
790
- dtype = dtype or environ.dftype()
791
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
792
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
793
- if scale is not None:
794
- r /= scale
795
- return r
796
-
797
- def maxwell(self,
798
- size: Optional[Size] = None,
799
- key: Optional[SeedOrKey] = None,
800
- dtype: DTypeLike = None):
801
- key = self.split_key() if key is None else _formalize_key(key)
802
- shape = _size2shape(size) + (3,)
803
- dtype = dtype or environ.dftype()
804
- norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
805
- r = jnp.linalg.norm(norm_rvs, axis=-1)
806
- return r
807
-
808
- def negative_binomial(self,
809
- n,
810
- p,
811
- size: Optional[Size] = None,
812
- key: Optional[SeedOrKey] = None,
813
- dtype: DTypeLike = None):
814
- n = _check_py_seq(n)
815
- p = _check_py_seq(p)
816
- if size is None:
817
- size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
818
- size = _size2shape(size)
819
- logits = jnp.log(p) - jnp.log1p(-p)
820
- if key is None:
821
- keys = self.split_keys(2)
822
- else:
823
- keys = jr.split(_formalize_key(key), 2)
824
- rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
825
- r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
826
- return r
827
-
828
- def wald(self,
829
- mean,
830
- scale,
831
- size: Optional[Size] = None,
832
- key: Optional[SeedOrKey] = None,
833
- dtype: DTypeLike = None):
834
- dtype = dtype or environ.dftype()
835
- key = self.split_key() if key is None else _formalize_key(key)
836
- mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
837
- scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
838
- if size is None:
839
- size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
840
- size = _size2shape(size)
841
- sampled_chi2 = jnp.square(self.randn(*size))
842
- sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
843
- # Wikipedia defines an intermediate x with the formula
844
- # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
845
- # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
846
- # Let us write
847
- # w = loc * y / (2 * conc)
848
- # Then we can extract the common factor in the last two terms to obtain
849
- # x = loc + loc * w * (1 - sqrt(2 / w + 1))
850
- # Now we see that the Wikipedia formula suffers from catastrphic
851
- # cancellation for large w (e.g., if conc << loc).
852
- #
853
- # Fortunately, we can fix this by multiplying both sides
854
- # by 1 + sqrt(2 / w + 1). We get
855
- # x * (1 + sqrt(2 / w + 1)) =
856
- # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
857
- # = loc * (sqrt(2 / w + 1) - 1)
858
- # The term sqrt(2 / w + 1) + 1 no longer presents numerical
859
- # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
860
- # sqrt1pm1(2 / w), which we know how to compute accurately.
861
- # This just leaves the matter of small w, where 2 / w may
862
- # overflow. In the limit a w -> 0, x -> loc, so we just mask
863
- # that case.
864
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
865
- safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
866
- denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
867
- ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
868
- sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
869
- res = jnp.where(sampled_uniform <= mean / (mean + sampled),
870
- sampled,
871
- jnp.square(mean) / sampled)
872
- return res
873
-
874
- def t(self,
875
- df,
876
- size: Optional[Size] = None,
877
- key: Optional[SeedOrKey] = None,
878
- dtype: DTypeLike = None):
879
- dtype = dtype or environ.dftype()
880
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
881
- if size is None:
882
- size = np.shape(df)
883
- else:
884
- size = _size2shape(size)
885
- _check_shape("t", size, np.shape(df))
886
- if key is None:
887
- keys = self.split_keys(2)
888
- else:
889
- keys = jr.split(_formalize_key(key), 2)
890
- n = jr.normal(keys[0], size, dtype=dtype)
891
- two = _const(n, 2)
892
- half_df = lax.div(df, two)
893
- g = jr.gamma(keys[1], half_df, size, dtype=dtype)
894
- r = n * jnp.sqrt(half_df / g)
895
- return r
896
-
897
- def orthogonal(self,
898
- n: int,
899
- size: Optional[Size] = None,
900
- key: Optional[SeedOrKey] = None,
901
- dtype: DTypeLike = None):
902
- dtype = dtype or environ.dftype()
903
- key = self.split_key() if key is None else _formalize_key(key)
904
- size = _size2shape(size)
905
- _check_shape("orthogonal", size)
906
- n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
907
- z = jr.normal(key, size + (n, n), dtype=dtype)
908
- q, r = jnp.linalg.qr(z)
909
- d = jnp.diagonal(r, 0, -2, -1)
910
- r = q * jnp.expand_dims(d / abs(d), -2)
911
- return r
912
-
913
- def noncentral_chisquare(self,
914
- df,
915
- nonc,
916
- size: Optional[Size] = None,
917
- key: Optional[SeedOrKey] = None,
918
- dtype: DTypeLike = None):
919
- dtype = dtype or environ.dftype()
920
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
921
- nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
922
- if size is None:
923
- size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
924
- size = _size2shape(size)
925
- if key is None:
926
- keys = self.split_keys(3)
927
- else:
928
- keys = jr.split(_formalize_key(key), 3)
929
- i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
930
- n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
931
- cond = jnp.greater(df, 1.0)
932
- df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
933
- chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
934
- r = jnp.where(cond, chi2 + n * n, chi2)
935
- return r
936
-
937
- def loggamma(self,
938
- a,
939
- size: Optional[Size] = None,
940
- key: Optional[SeedOrKey] = None,
941
- dtype: DTypeLike = None):
942
- dtype = dtype or environ.dftype()
943
- key = self.split_key() if key is None else _formalize_key(key)
944
- a = _check_py_seq(a)
945
- if size is None:
946
- size = jnp.shape(a)
947
- r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
948
- return r
949
-
950
- def categorical(self,
951
- logits,
952
- axis: int = -1,
953
- size: Optional[Size] = None,
954
- key: Optional[SeedOrKey] = None):
955
- key = self.split_key() if key is None else _formalize_key(key)
956
- logits = _check_py_seq(logits)
957
- if size is None:
958
- size = list(jnp.shape(logits))
959
- size.pop(axis)
960
- r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
961
- return r
962
-
963
- def zipf(self,
964
- a,
965
- size: Optional[Size] = None,
966
- key: Optional[SeedOrKey] = None,
967
- dtype: DTypeLike = None):
968
- a = _check_py_seq(a)
969
- if size is None:
970
- size = jnp.shape(a)
971
- dtype = dtype or environ.ditype()
972
- r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
973
- jax.ShapeDtypeStruct(size, dtype),
974
- a)
975
- return r
976
-
977
- def power(self,
978
- a,
979
- size: Optional[Size] = None,
980
- key: Optional[SeedOrKey] = None,
981
- dtype: DTypeLike = None):
982
- a = _check_py_seq(a)
983
- if size is None:
984
- size = jnp.shape(a)
985
- size = _size2shape(size)
986
- dtype = dtype or environ.dftype()
987
- r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
988
- jax.ShapeDtypeStruct(size, dtype),
989
- a)
990
- return r
991
-
992
- def f(self,
993
- dfnum,
994
- dfden,
995
- size: Optional[Size] = None,
996
- key: Optional[SeedOrKey] = None,
997
- dtype: DTypeLike = None):
998
- dfnum = _check_py_seq(dfnum)
999
- dfden = _check_py_seq(dfden)
1000
- if size is None:
1001
- size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
1002
- size = _size2shape(size)
1003
- d = {'dfnum': dfnum, 'dfden': dfden}
1004
- dtype = dtype or environ.dftype()
1005
- r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1006
- dfden=dfden_,
1007
- size=size).astype(dtype),
1008
- jax.ShapeDtypeStruct(size, dtype),
1009
- dfnum, dfden)
1010
- return r
1011
-
1012
- def hypergeometric(
1013
- self,
1014
- ngood,
1015
- nbad,
1016
- nsample,
1017
- size: Optional[Size] = None,
1018
- key: Optional[SeedOrKey] = None,
1019
- dtype: DTypeLike = None
1020
- ):
1021
- ngood = _check_py_seq(ngood)
1022
- nbad = _check_py_seq(nbad)
1023
- nsample = _check_py_seq(nsample)
1024
-
1025
- if size is None:
1026
- size = lax.broadcast_shapes(jnp.shape(ngood),
1027
- jnp.shape(nbad),
1028
- jnp.shape(nsample))
1029
- size = _size2shape(size)
1030
- dtype = dtype or environ.ditype()
1031
- d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1032
- r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
1033
- nbad=d['nbad'],
1034
- nsample=d['nsample'],
1035
- size=size).astype(dtype),
1036
- jax.ShapeDtypeStruct(size, dtype),
1037
- d)
1038
- return r
1039
-
1040
- def logseries(self,
1041
- p,
1042
- size: Optional[Size] = None,
1043
- key: Optional[SeedOrKey] = None,
1044
- dtype: DTypeLike = None):
1045
- p = _check_py_seq(p)
1046
- if size is None:
1047
- size = jnp.shape(p)
1048
- size = _size2shape(size)
1049
- dtype = dtype or environ.ditype()
1050
- r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1051
- jax.ShapeDtypeStruct(size, dtype),
1052
- p)
1053
- return r
1054
-
1055
- def noncentral_f(self,
1056
- dfnum,
1057
- dfden,
1058
- nonc,
1059
- size: Optional[Size] = None,
1060
- key: Optional[SeedOrKey] = None,
1061
- dtype: DTypeLike = None):
1062
- dfnum = _check_py_seq(dfnum)
1063
- dfden = _check_py_seq(dfden)
1064
- nonc = _check_py_seq(nonc)
1065
- if size is None:
1066
- size = lax.broadcast_shapes(jnp.shape(dfnum),
1067
- jnp.shape(dfden),
1068
- jnp.shape(nonc))
1069
- size = _size2shape(size)
1070
- d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1071
- dtype = dtype or environ.dftype()
1072
- r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1073
- dfden=x['dfden'],
1074
- nonc=x['nonc'],
1075
- size=size).astype(dtype),
1076
- jax.ShapeDtypeStruct(size, dtype),
1077
- d)
1078
- return r
1079
-
1080
- # PyTorch compatibility #
1081
- # --------------------- #
1082
-
1083
- def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1084
- """Returns a tensor with the same size as input that is filled with random
1085
- numbers from a uniform distribution on the interval ``[0, 1)``.
1086
-
1087
- Args:
1088
- input: the ``size`` of input will determine size of the output tensor.
1089
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1090
- key: the seed or key for the random.
1091
-
1092
- Returns:
1093
- The random data.
1094
- """
1095
- return self.random(jnp.shape(input), key=key).astype(dtype)
1096
-
1097
- def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1098
- """Returns a tensor with the same size as ``input`` that is filled with
1099
- random numbers from a normal distribution with mean 0 and variance 1.
1100
-
1101
- Args:
1102
- input: the ``size`` of input will determine size of the output tensor.
1103
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1104
- key: the seed or key for the random.
1105
-
1106
- Returns:
1107
- The random data.
1108
- """
1109
- return self.randn(*jnp.shape(input), key=key).astype(dtype)
1110
-
1111
- def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
1112
- if high is None:
1113
- high = max(input)
1114
- return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
1115
-
1116
-
1117
- # default random generator
1118
- DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1119
-
1120
-
1121
- def split_key():
1122
- """Create a new seed from the current seed.
1123
-
1124
- This function is useful for the consistency with JAX's random paradigm."""
1125
- return DEFAULT.split_key()
1126
-
1127
-
1128
- def split_keys(n):
1129
- """Create multiple seeds from the current seed. This is used
1130
- internally by `pmap` and `vmap` to ensure that random numbers
1131
- are different in parallel threads.
1132
-
1133
- Parameters
1134
- ----------
1135
- n : int
1136
- The number of seeds to generate.
1137
- """
1138
- return DEFAULT.split_keys(n)
1139
-
1140
-
1141
- def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
1142
- """Clone the random state according to the given setting.
1143
-
1144
- Args:
1145
- seed_or_key: The seed (an integer) or the random key.
1146
- clone: Bool. Whether clone the default random state.
1147
-
1148
- Returns:
1149
- The random state.
1150
- """
1151
- if seed_or_key is None:
1152
- return DEFAULT.clone() if clone else DEFAULT
1153
- else:
1154
- return RandomState(seed_or_key)
1155
-
1156
-
1157
- def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
1158
- """
1159
- Get the default random state.
1160
-
1161
- Args:
1162
- seed_or_key: The seed (an integer) or the jax random key.
1163
- clone: Bool. Whether clone the default random state.
1164
-
1165
- Returns:
1166
- The random state.
1167
- """
1168
- if seed_or_key is None:
1169
- return DEFAULT.clone() if clone else DEFAULT
1170
- else:
1171
- return RandomState(seed_or_key)
1172
-
1173
-
1174
- def seed(seed_or_key: int = None):
1175
- """Sets a new random seed.
1176
-
1177
- Parameters
1178
- ----------
1179
- seed_or_key: int, optional
1180
- The random seed (an integer) or jax random key.
1181
- """
1182
- with jax.ensure_compile_time_eval():
1183
- if seed_or_key is None:
1184
- seed_or_key = np.random.randint(0, 100000)
1185
-
1186
- # numpy random seed
1187
- if np.size(seed_or_key) == 1: # seed
1188
- np.random.seed(seed_or_key)
1189
- elif np.size(seed_or_key) == 2: # jax random key
1190
- np.random.seed(seed_or_key[0])
1191
- else:
1192
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1193
-
1194
- # jax random seed
1195
- DEFAULT.seed(seed_or_key)
1196
-
1197
-
1198
- @contextmanager
1199
- def seed_context(seed_or_key: SeedOrKey):
1200
- """
1201
- A context manager that sets the random seed for the duration of the block.
1202
-
1203
- Examples:
1204
-
1205
- >>> import brainstate as bst
1206
- >>> print(bst.random.rand(2))
1207
- [0.57721865 0.9820676 ]
1208
- >>> print(bst.random.rand(2))
1209
- [0.8511752 0.95312667]
1210
- >>> with bst.random.seed_context(42):
1211
- ... print(bst.random.rand(2))
1212
- [0.95598125 0.4032725 ]
1213
- >>> with bst.random.seed_context(42):
1214
- ... print(bst.random.rand(2))
1215
- [0.95598125 0.4032725 ]
1216
-
1217
- .. note::
1218
-
1219
- The context manager does not only set the seed for the AX random state, but also for the numpy random state.
1220
-
1221
- Args:
1222
- seed_or_key: The seed (an integer) or jax random key.
1223
-
1224
- """
1225
- old_jrand_key = DEFAULT.value
1226
- old_np_state = np.random.get_state()
1227
- try:
1228
- if np.size(seed_or_key) == 1: # seed
1229
- np.random.seed(seed_or_key)
1230
- elif np.size(seed_or_key) == 2: # jax random key
1231
- np.random.seed(seed_or_key[0])
1232
- else:
1233
- raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1234
- DEFAULT.seed(seed_or_key)
1235
- yield
1236
- finally:
1237
- np.random.set_state(old_np_state)
1238
- DEFAULT.seed(old_jrand_key)
1239
-
1240
-
1241
44
  def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1242
45
  r"""
1243
46
  Random values in a given shape.
@@ -1260,6 +63,9 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1260
63
  dtype : dtype, optional
1261
64
  Desired dtype of the result. Byteorder must be native.
1262
65
  The default value is float.
66
+ key : PRNGKey, optional
67
+ The key for the random number generator. If not given, the
68
+ default random number generator is used.
1263
69
 
1264
70
  Returns
1265
71
  -------
@@ -1272,7 +78,8 @@ def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1272
78
 
1273
79
  Examples
1274
80
  --------
1275
- >>> brainstate.random.rand(3,2)
81
+ >>> import brainstate as bst
82
+ >>> bst.random.rand(3,2)
1276
83
  array([[ 0.14022471, 0.96360618], #random
1277
84
  [ 0.37601032, 0.25528411], #random
1278
85
  [ 0.49313049, 0.94909878]]) #random
@@ -4790,402 +3597,6 @@ def randint_like(input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey
4790
3597
  # ---------------------------------------------------------------------------------------------------------------
4791
3598
 
4792
3599
 
4793
- def _formalize_key(key):
4794
- if isinstance(key, int):
4795
- return jr.PRNGKey(key)
4796
- elif isinstance(key, (jax.Array, np.ndarray)):
4797
- if key.dtype != jnp.uint32:
4798
- raise TypeError('key must be a int or an array with two uint32.')
4799
- if key.size != 2:
4800
- raise TypeError('key must be a int or an array with two uint32.')
4801
- return jnp.asarray(key, dtype=jnp.uint32)
4802
- else:
4803
- raise TypeError('key must be a int or an array with two uint32.')
4804
-
4805
-
4806
- def _size2shape(size):
4807
- if size is None:
4808
- return ()
4809
- elif isinstance(size, (tuple, list)):
4810
- return tuple(size)
4811
- else:
4812
- return (size,)
4813
-
4814
-
4815
- def _check_shape(name, shape, *param_shapes):
4816
- if param_shapes:
4817
- shape_ = lax.broadcast_shapes(shape, *param_shapes)
4818
- if shape != shape_:
4819
- msg = ("{} parameter shapes must be broadcast-compatible with shape "
4820
- "argument, and the result of broadcasting the shapes must equal "
4821
- "the shape argument, but got result {} for shape argument {}.")
4822
- raise ValueError(msg.format(name, shape_, shape))
4823
-
4824
-
4825
- def _is_python_scalar(x):
4826
- if hasattr(x, 'aval'):
4827
- return x.aval.weak_type
4828
- elif np.ndim(x) == 0:
4829
- return True
4830
- elif isinstance(x, (bool, int, float, complex)):
4831
- return True
4832
- else:
4833
- return False
4834
-
4835
-
4836
- python_scalar_dtypes = {
4837
- bool: np.dtype('bool'),
4838
- int: np.dtype('int64'),
4839
- float: np.dtype('float64'),
4840
- complex: np.dtype('complex128'),
4841
- }
4842
-
4843
-
4844
- def _dtype(x, *, canonicalize: bool = False):
4845
- """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
4846
- if x is None:
4847
- raise ValueError(f"Invalid argument to dtype: {x}.")
4848
- elif isinstance(x, type) and x in python_scalar_dtypes:
4849
- dt = python_scalar_dtypes[x]
4850
- elif type(x) in python_scalar_dtypes:
4851
- dt = python_scalar_dtypes[type(x)]
4852
- elif hasattr(x, 'dtype'):
4853
- dt = x.dtype
4854
- else:
4855
- dt = np.result_type(x)
4856
- return dtypes.canonicalize_dtype(dt) if canonicalize else dt
4857
-
4858
-
4859
- def _const(example, val):
4860
- if _is_python_scalar(example):
4861
- dtype = dtypes.canonicalize_dtype(type(example))
4862
- val = dtypes.scalar_type_of(example)(val)
4863
- return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
4864
- else:
4865
- dtype = dtypes.canonicalize_dtype(example.dtype)
4866
- return np.array(val, dtype)
4867
-
4868
-
4869
- _tr_params = namedtuple(
4870
- "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
4871
- )
4872
-
4873
-
4874
- def _get_tr_params(n, p):
4875
- # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
4876
- # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
4877
- mu = n * p
4878
- spq = jnp.sqrt(mu * (1 - p))
4879
- c = mu + 0.5
4880
- b = 1.15 + 2.53 * spq
4881
- a = -0.0873 + 0.0248 * b + 0.01 * p
4882
- alpha = (2.83 + 5.1 / b) * spq
4883
- u_r = 0.43
4884
- v_r = 0.92 - 4.2 / b
4885
- m = jnp.floor((n + 1) * p).astype(n.dtype)
4886
- log_p = jnp.log(p)
4887
- log1_p = jnp.log1p(-p)
4888
- log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
4889
- _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
4890
- return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
4891
-
4892
-
4893
- def _stirling_approx_tail(k):
4894
- precomputed = jnp.array([0.08106146679532726,
4895
- 0.04134069595540929,
4896
- 0.02767792568499834,
4897
- 0.02079067210376509,
4898
- 0.01664469118982119,
4899
- 0.01387612882307075,
4900
- 0.01189670994589177,
4901
- 0.01041126526197209,
4902
- 0.009255462182712733,
4903
- 0.008330563433362871],
4904
- dtype=environ.dftype())
4905
- kp1 = k + 1
4906
- kp1sq = (k + 1) ** 2
4907
- return jnp.where(k < 10,
4908
- precomputed[k],
4909
- (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
4910
-
4911
-
4912
- def _binomial_btrs(key, p, n):
4913
- """
4914
- Based on the transformed rejection sampling algorithm (BTRS) from the
4915
- following reference:
4916
-
4917
- Hormann, "The Generation of Binonmial Random Variates"
4918
- (https://core.ac.uk/download/pdf/11007254.pdf)
4919
- """
4920
-
4921
- def _btrs_body_fn(val):
4922
- _, key, _, _ = val
4923
- key, key_u, key_v = jr.split(key, 3)
4924
- u = jr.uniform(key_u)
4925
- v = jr.uniform(key_v)
4926
- u = u - 0.5
4927
- k = jnp.floor(
4928
- (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
4929
- ).astype(n.dtype)
4930
- return k, key, u, v
4931
-
4932
- def _btrs_cond_fn(val):
4933
- def accept_fn(k, u, v):
4934
- # See acceptance condition in Step 3. (Page 3) of TRS algorithm
4935
- # v <= f(k) * g_grad(u) / alpha
4936
-
4937
- m = tr_params.m
4938
- log_p = tr_params.log_p
4939
- log1_p = tr_params.log1_p
4940
- # See: formula for log(f(k)) at bottom of Page 5.
4941
- log_f = (
4942
- (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
4943
- + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
4944
- + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
4945
- + tr_params.log_h
4946
- )
4947
- g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
4948
- return jnp.log((v * tr_params.alpha) / g) <= log_f
4949
-
4950
- k, key, u, v = val
4951
- early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
4952
- early_reject = (k < 0) | (k > n)
4953
- return lax.cond(
4954
- early_accept | early_reject,
4955
- (),
4956
- lambda _: ~early_accept,
4957
- (k, u, v),
4958
- lambda x: ~accept_fn(*x),
4959
- )
4960
-
4961
- tr_params = _get_tr_params(n, p)
4962
- ret = lax.while_loop(
4963
- _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
4964
- ) # use k=-1 initially so that cond_fn returns True
4965
- return ret[0]
4966
-
4967
-
4968
- def _binomial_inversion(key, p, n):
4969
- def _binom_inv_body_fn(val):
4970
- i, key, geom_acc = val
4971
- key, key_u = jr.split(key)
4972
- u = jr.uniform(key_u)
4973
- geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
4974
- geom_acc = geom_acc + geom
4975
- return i + 1, key, geom_acc
4976
-
4977
- def _binom_inv_cond_fn(val):
4978
- i, _, geom_acc = val
4979
- return geom_acc <= n
4980
-
4981
- log1_p = jnp.log1p(-p)
4982
- ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
4983
- return ret[0]
4984
-
4985
-
4986
- def _binomial_dispatch(key, p, n):
4987
- def dispatch(key, p, n):
4988
- is_le_mid = p <= 0.5
4989
- pq = jnp.where(is_le_mid, p, 1 - p)
4990
- mu = n * pq
4991
- k = lax.cond(
4992
- mu < 10,
4993
- (key, pq, n),
4994
- lambda x: _binomial_inversion(*x),
4995
- (key, pq, n),
4996
- lambda x: _binomial_btrs(*x),
4997
- )
4998
- return jnp.where(is_le_mid, k, n - k)
4999
-
5000
- # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
5001
- cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
5002
- return lax.cond(
5003
- cond0 & (p < 1),
5004
- (key, p, n),
5005
- lambda x: dispatch(*x),
5006
- (),
5007
- lambda _: jnp.where(cond0, n, 0),
5008
- )
5009
-
5010
-
5011
- @partial(jit, static_argnums=(3,))
5012
- def _binomial(key, p, n, shape):
5013
- shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
5014
- # reshape to map over axis 0
5015
- p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
5016
- n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
5017
- key = jr.split(key, jnp.size(p))
5018
- if jax.default_backend() == "cpu":
5019
- ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
5020
- else:
5021
- ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
5022
- return jnp.reshape(ret, shape)
5023
-
5024
-
5025
- @partial(jit, static_argnums=(2,))
5026
- def _categorical(key, p, shape):
5027
- # this implementation is fast when event shape is small, and slow otherwise
5028
- # Ref: https://stackoverflow.com/a/34190035
5029
- shape = shape or p.shape[:-1]
5030
- s = jnp.cumsum(p, axis=-1)
5031
- r = jr.uniform(key, shape=shape + (1,))
5032
- return jnp.sum(s < r, axis=-1)
5033
-
5034
-
5035
- def _scatter_add_one(operand, indices, updates):
5036
- return lax.scatter_add(
5037
- operand,
5038
- indices,
5039
- updates,
5040
- lax.ScatterDimensionNumbers(
5041
- update_window_dims=(),
5042
- inserted_window_dims=(0,),
5043
- scatter_dims_to_operand_dims=(0,),
5044
- ),
5045
- )
5046
-
5047
-
5048
- def _reshape(x, shape):
5049
- if isinstance(x, (int, float, np.ndarray, np.generic)):
5050
- return np.reshape(x, shape)
5051
- else:
5052
- return jnp.reshape(x, shape)
5053
-
5054
-
5055
- def _promote_shapes(*args, shape=()):
5056
- # adapted from lax.lax_numpy
5057
- if len(args) < 2 and not shape:
5058
- return args
5059
- else:
5060
- shapes = [jnp.shape(arg) for arg in args]
5061
- num_dims = len(lax.broadcast_shapes(shape, *shapes))
5062
- return [
5063
- _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
5064
- for arg, s in zip(args, shapes)
5065
- ]
5066
-
5067
-
5068
- @partial(jit, static_argnums=(3, 4))
5069
- def _multinomial(key, p, n, n_max, shape=()):
5070
- if jnp.shape(n) != jnp.shape(p)[:-1]:
5071
- broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
5072
- n = jnp.broadcast_to(n, broadcast_shape)
5073
- p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
5074
- shape = shape or p.shape[:-1]
5075
- if n_max == 0:
5076
- return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
5077
- # get indices from categorical distribution then gather the result
5078
- indices = _categorical(key, p, (n_max,) + shape)
5079
- # mask out values when counts is heterogeneous
5080
- if jnp.ndim(n) > 0:
5081
- mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
5082
- mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
5083
- excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
5084
- jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
5085
- -1)
5086
- else:
5087
- mask = 1
5088
- excess = 0
5089
- # NB: we transpose to move batch shape to the front
5090
- indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
5091
- samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
5092
- jnp.expand_dims(indices_2D, axis=-1),
5093
- jnp.ones(indices_2D.shape, dtype=indices.dtype))
5094
- return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
5095
-
5096
-
5097
- @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
5098
- def _von_mises_centered(key, concentration, shape, dtype=None):
5099
- """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
5100
-
5101
- Returns
5102
- -------
5103
- out: array_like
5104
- centered samples from von Mises
5105
-
5106
- References
5107
- ----------
5108
- .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
5109
- Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
5110
-
5111
- """
5112
- shape = shape or jnp.shape(concentration)
5113
- dtype = dtype or environ.dftype()
5114
- concentration = lax.convert_element_type(concentration, dtype)
5115
- concentration = jnp.broadcast_to(concentration, shape)
5116
-
5117
- if dtype == jnp.float16:
5118
- s_cutoff = 1.8e-1
5119
- elif dtype == jnp.float32:
5120
- s_cutoff = 2e-2
5121
- elif dtype == jnp.float64:
5122
- s_cutoff = 1.2e-4
5123
- else:
5124
- raise ValueError(f"Unsupported dtype: {dtype}")
5125
-
5126
- r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
5127
- rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
5128
- s_exact = (1.0 + rho ** 2) / (2.0 * rho)
5129
-
5130
- s_approximate = 1.0 / concentration
5131
-
5132
- s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
5133
-
5134
- def cond_fn(*args):
5135
- """check if all are done or reached max number of iterations"""
5136
- i, _, done, _, _ = args[0]
5137
- return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
5138
-
5139
- def body_fn(*args):
5140
- i, key, done, _, w = args[0]
5141
- uni_ukey, uni_vkey, key = jr.split(key, 3)
5142
- u = jr.uniform(
5143
- key=uni_ukey,
5144
- shape=shape,
5145
- dtype=concentration.dtype,
5146
- minval=-1.0,
5147
- maxval=1.0,
5148
- )
5149
- z = jnp.cos(jnp.pi * u)
5150
- w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
5151
- y = concentration * (s - w)
5152
- v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
5153
- accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
5154
- return i + 1, key, accept | done, u, w
5155
-
5156
- init_done = jnp.zeros(shape, dtype=bool)
5157
- init_u = jnp.zeros(shape)
5158
- init_w = jnp.zeros(shape)
5159
-
5160
- _, _, done, u, w = lax.while_loop(
5161
- cond_fun=cond_fn,
5162
- body_fun=body_fn,
5163
- init_val=(jnp.array(0), key, init_done, init_u, init_w),
5164
- )
5165
-
5166
- return jnp.sign(u) * jnp.arccos(w)
5167
-
5168
-
5169
- def _loc_scale(loc, scale, value):
5170
- if loc is None:
5171
- if scale is None:
5172
- return value
5173
- else:
5174
- return value * scale
5175
- else:
5176
- if scale is None:
5177
- return value + loc
5178
- else:
5179
- return value * scale + loc
5180
-
5181
-
5182
- def _check_py_seq(seq):
5183
- return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
5184
-
5185
-
5186
- # ---------------------------------------------------------------------------------------------------------------
5187
-
5188
-
5189
3600
  for __k in dir(RandomState):
5190
3601
  __t = getattr(RandomState, __k)
5191
3602
  if not __k.startswith('__') and callable(__t) and (not __t.__doc__):