brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/random/_rand_state.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
+
from __future__ import annotations
|
17
18
|
|
18
19
|
from collections import namedtuple
|
19
20
|
from functools import partial
|
@@ -30,7 +31,7 @@ from jax import lax, core, dtypes
|
|
30
31
|
|
31
32
|
from brainstate import environ
|
32
33
|
from brainstate._state import State
|
33
|
-
from brainstate.
|
34
|
+
from brainstate.compile._error_if import jit_error_if
|
34
35
|
from brainstate.typing import DTypeLike, Size, SeedOrKey
|
35
36
|
from ._random_for_unit import uniform_for_unit, permutation_for_unit
|
36
37
|
|
@@ -38,1064 +39,1098 @@ __all__ = ['RandomState', 'DEFAULT', ]
|
|
38
39
|
|
39
40
|
|
40
41
|
class RandomState(State):
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
42
|
+
"""RandomState that track the random generator state. """
|
43
|
+
|
44
|
+
# __slots__ = ('_backup', '_value')
|
45
|
+
|
46
|
+
def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
|
47
|
+
"""RandomState constructor.
|
48
|
+
|
49
|
+
Parameters
|
50
|
+
----------
|
51
|
+
seed_or_key: int, Array, optional
|
52
|
+
It can be an integer for initial seed of the random number generator,
|
53
|
+
or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
|
54
|
+
"""
|
55
|
+
with jax.ensure_compile_time_eval():
|
56
|
+
if seed_or_key is None:
|
57
|
+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
58
|
+
if isinstance(seed_or_key, int):
|
59
|
+
key = jr.PRNGKey(seed_or_key)
|
60
|
+
else:
|
61
|
+
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
|
62
|
+
raise ValueError('key must be an array with dtype uint32. '
|
63
|
+
f'But we got {seed_or_key}')
|
64
|
+
key = seed_or_key
|
65
|
+
super().__init__(key)
|
66
|
+
|
67
|
+
self._backup = None
|
68
|
+
|
69
|
+
def __repr__(self):
|
70
|
+
return f'{self.__class__.__name__}({self.value})'
|
71
|
+
|
72
|
+
def check_if_deleted(self):
|
73
|
+
if (
|
74
|
+
isinstance(self._value, jax.Array) and
|
75
|
+
not isinstance(self._value, jax.core.Tracer) and
|
76
|
+
self._value.is_deleted()
|
77
|
+
):
|
78
|
+
self.seed()
|
79
|
+
|
80
|
+
# ------------------- #
|
81
|
+
# seed and random key #
|
82
|
+
# ------------------- #
|
83
|
+
|
84
|
+
def backup_key(self):
|
85
|
+
if self._backup is not None:
|
86
|
+
raise ValueError('The random key has been backed up, and has not been restored.')
|
87
|
+
self._backup = self.value
|
88
|
+
|
89
|
+
def restore_key(self):
|
90
|
+
if self._backup is None:
|
91
|
+
raise ValueError('The random key has not been backed up.')
|
92
|
+
self.value = self._backup
|
93
|
+
self._backup = None
|
94
|
+
|
95
|
+
def clone(self):
|
96
|
+
return type(self)(self.split_key())
|
97
|
+
|
98
|
+
def set_key(self, key: SeedOrKey):
|
99
|
+
self.value = key
|
100
|
+
|
101
|
+
def seed(self, seed_or_key: Optional[SeedOrKey] = None):
|
102
|
+
"""Sets a new random seed.
|
103
|
+
|
104
|
+
Parameters
|
105
|
+
----------
|
106
|
+
seed_or_key: int, ArrayLike, optional
|
107
|
+
It can be an integer for initial seed of the random number generator,
|
108
|
+
or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
|
109
|
+
"""
|
110
|
+
with jax.ensure_compile_time_eval():
|
111
|
+
if seed_or_key is None:
|
112
|
+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
113
|
+
if np.size(seed_or_key) == 1:
|
114
|
+
key = jr.PRNGKey(seed_or_key)
|
115
|
+
else:
|
116
|
+
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
|
117
|
+
raise ValueError('key must be an array with dtype uint32. '
|
118
|
+
f'But we got {seed_or_key}')
|
119
|
+
key = seed_or_key
|
120
|
+
self.value = key
|
121
|
+
|
122
|
+
def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
|
123
|
+
"""
|
124
|
+
Create a new seed from the current seed.
|
125
|
+
|
126
|
+
Parameters
|
127
|
+
----------
|
128
|
+
n: int, optional
|
129
|
+
The number of seeds to generate.
|
130
|
+
backup : bool, optional
|
131
|
+
Whether to backup the current key.
|
132
|
+
|
133
|
+
Returns
|
134
|
+
-------
|
135
|
+
key : SeedOrKey
|
136
|
+
The new seed or a tuple of JAX random keys.
|
137
|
+
"""
|
138
|
+
if n is not None:
|
139
|
+
assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
|
140
|
+
|
141
|
+
if not isinstance(self.value, jax.Array):
|
142
|
+
self.value = jnp.asarray(self.value, dtype=jnp.uint32)
|
143
|
+
keys = jr.split(self.value, num=2 if n is None else n + 1)
|
144
|
+
self.value = keys[0]
|
145
|
+
if backup:
|
146
|
+
self.backup_key()
|
147
|
+
if n is None:
|
148
|
+
return keys[1]
|
149
|
+
else:
|
150
|
+
return keys[1:]
|
151
|
+
|
152
|
+
def self_assign_multi_keys(self, n: int, backup: bool = True):
|
153
|
+
"""
|
154
|
+
Self-assign multiple keys to the current random state.
|
155
|
+
"""
|
156
|
+
if backup:
|
157
|
+
keys = jr.split(self.value, n + 1)
|
158
|
+
self.value = keys[0]
|
159
|
+
self.backup_key()
|
160
|
+
self.value = keys[1:]
|
161
|
+
else:
|
162
|
+
self.value = jr.split(self.value, n)
|
163
|
+
|
164
|
+
# ---------------- #
|
165
|
+
# random functions #
|
166
|
+
# ---------------- #
|
167
|
+
|
168
|
+
def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
169
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
170
|
+
dtype = dtype or environ.dftype()
|
171
|
+
r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
|
172
|
+
return r
|
173
|
+
|
174
|
+
def randint(
|
175
|
+
self,
|
176
|
+
low,
|
177
|
+
high=None,
|
178
|
+
size: Optional[Size] = None,
|
179
|
+
dtype: DTypeLike = None,
|
180
|
+
key: Optional[SeedOrKey] = None
|
181
|
+
):
|
182
|
+
if high is None:
|
183
|
+
high = low
|
184
|
+
low = 0
|
185
|
+
high = _check_py_seq(high)
|
186
|
+
low = _check_py_seq(low)
|
187
|
+
if size is None:
|
188
|
+
size = lax.broadcast_shapes(jnp.shape(low),
|
189
|
+
jnp.shape(high))
|
190
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
191
|
+
dtype = dtype or environ.ditype()
|
192
|
+
r = jr.randint(key,
|
193
|
+
shape=_size2shape(size),
|
194
|
+
minval=low, maxval=high, dtype=dtype)
|
195
|
+
return r
|
196
|
+
|
197
|
+
def random_integers(
|
198
|
+
self,
|
199
|
+
low,
|
200
|
+
high=None,
|
201
|
+
size: Optional[Size] = None,
|
202
|
+
key: Optional[SeedOrKey] = None,
|
203
|
+
dtype: DTypeLike = None,
|
204
|
+
):
|
205
|
+
low = _check_py_seq(low)
|
206
|
+
high = _check_py_seq(high)
|
207
|
+
if high is None:
|
208
|
+
high = low
|
209
|
+
low = 1
|
210
|
+
high += 1
|
211
|
+
if size is None:
|
212
|
+
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
|
213
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
214
|
+
dtype = dtype or environ.ditype()
|
215
|
+
r = jr.randint(key,
|
216
|
+
shape=_size2shape(size),
|
217
|
+
minval=low,
|
218
|
+
maxval=high,
|
219
|
+
dtype=dtype)
|
220
|
+
return r
|
221
|
+
|
222
|
+
def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
223
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
224
|
+
dtype = dtype or environ.dftype()
|
225
|
+
r = jr.normal(key, shape=dn, dtype=dtype)
|
226
|
+
return r
|
227
|
+
|
228
|
+
def random(self,
|
229
|
+
size: Optional[Size] = None,
|
230
|
+
key: Optional[SeedOrKey] = None,
|
231
|
+
dtype: DTypeLike = None):
|
232
|
+
dtype = dtype or environ.dftype()
|
233
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
234
|
+
r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
|
235
|
+
return r
|
132
236
|
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
def randint(
|
140
|
-
self,
|
141
|
-
low,
|
142
|
-
high=None,
|
143
|
-
size: Optional[Size] = None,
|
144
|
-
dtype: DTypeLike = None,
|
145
|
-
key: Optional[SeedOrKey] = None
|
146
|
-
):
|
147
|
-
if high is None:
|
148
|
-
high = low
|
149
|
-
low = 0
|
150
|
-
high = _check_py_seq(high)
|
151
|
-
low = _check_py_seq(low)
|
152
|
-
if size is None:
|
153
|
-
size = lax.broadcast_shapes(jnp.shape(low),
|
154
|
-
jnp.shape(high))
|
155
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
156
|
-
dtype = dtype or environ.ditype()
|
157
|
-
r = jr.randint(key,
|
158
|
-
shape=_size2shape(size),
|
159
|
-
minval=low, maxval=high, dtype=dtype)
|
160
|
-
return r
|
161
|
-
|
162
|
-
def random_integers(
|
163
|
-
self,
|
164
|
-
low,
|
165
|
-
high=None,
|
166
|
-
size: Optional[Size] = None,
|
167
|
-
key: Optional[SeedOrKey] = None,
|
168
|
-
dtype: DTypeLike = None,
|
169
|
-
):
|
170
|
-
low = _check_py_seq(low)
|
171
|
-
high = _check_py_seq(high)
|
172
|
-
if high is None:
|
173
|
-
high = low
|
174
|
-
low = 1
|
175
|
-
high += 1
|
176
|
-
if size is None:
|
177
|
-
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
|
178
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
179
|
-
dtype = dtype or environ.ditype()
|
180
|
-
r = jr.randint(key,
|
181
|
-
shape=_size2shape(size),
|
182
|
-
minval=low,
|
183
|
-
maxval=high,
|
184
|
-
dtype=dtype)
|
185
|
-
return r
|
186
|
-
|
187
|
-
def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
188
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
189
|
-
dtype = dtype or environ.dftype()
|
190
|
-
r = jr.normal(key, shape=dn, dtype=dtype)
|
191
|
-
return r
|
237
|
+
def random_sample(self,
|
238
|
+
size: Optional[Size] = None,
|
239
|
+
key: Optional[SeedOrKey] = None,
|
240
|
+
dtype: DTypeLike = None):
|
241
|
+
r = self.random(size=size, key=key, dtype=dtype)
|
242
|
+
return r
|
192
243
|
|
193
|
-
|
244
|
+
def ranf(self,
|
194
245
|
size: Optional[Size] = None,
|
195
246
|
key: Optional[SeedOrKey] = None,
|
196
247
|
dtype: DTypeLike = None):
|
197
|
-
|
198
|
-
|
199
|
-
r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
|
200
|
-
return r
|
201
|
-
|
202
|
-
def random_sample(self,
|
203
|
-
size: Optional[Size] = None,
|
204
|
-
key: Optional[SeedOrKey] = None,
|
205
|
-
dtype: DTypeLike = None):
|
206
|
-
r = self.random(size=size, key=key, dtype=dtype)
|
207
|
-
return r
|
248
|
+
r = self.random(size=size, key=key, dtype=dtype)
|
249
|
+
return r
|
208
250
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
251
|
+
def sample(self,
|
252
|
+
size: Optional[Size] = None,
|
253
|
+
key: Optional[SeedOrKey] = None,
|
254
|
+
dtype: DTypeLike = None):
|
255
|
+
r = self.random(size=size, key=key, dtype=dtype)
|
256
|
+
return r
|
215
257
|
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
258
|
+
def choice(self,
|
259
|
+
a,
|
260
|
+
size: Optional[Size] = None,
|
261
|
+
replace=True,
|
262
|
+
p=None,
|
263
|
+
key: Optional[SeedOrKey] = None):
|
264
|
+
a = _check_py_seq(a)
|
265
|
+
p = _check_py_seq(p)
|
266
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
267
|
+
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
|
268
|
+
return r
|
269
|
+
|
270
|
+
def permutation(self,
|
271
|
+
x,
|
272
|
+
axis: int = 0,
|
273
|
+
independent: bool = False,
|
274
|
+
key: Optional[SeedOrKey] = None):
|
275
|
+
x = _check_py_seq(x)
|
276
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
277
|
+
r = permutation_for_unit(key, x, axis=axis, independent=independent)
|
278
|
+
return r
|
279
|
+
|
280
|
+
def shuffle(self,
|
281
|
+
x,
|
282
|
+
axis=0,
|
283
|
+
key: Optional[SeedOrKey] = None):
|
284
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
285
|
+
x = permutation_for_unit(key, x, axis=axis)
|
286
|
+
return x
|
222
287
|
|
223
|
-
|
288
|
+
def beta(self,
|
224
289
|
a,
|
225
|
-
|
226
|
-
replace=True,
|
227
|
-
p=None,
|
228
|
-
key: Optional[SeedOrKey] = None):
|
229
|
-
a = _check_py_seq(a)
|
230
|
-
p = _check_py_seq(p)
|
231
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
232
|
-
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
|
233
|
-
return r
|
234
|
-
|
235
|
-
def permutation(self,
|
236
|
-
x,
|
237
|
-
axis: int = 0,
|
238
|
-
independent: bool = False,
|
239
|
-
key: Optional[SeedOrKey] = None):
|
240
|
-
x = _check_py_seq(x)
|
241
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
242
|
-
r = permutation_for_unit(key, x, axis=axis, independent=independent)
|
243
|
-
return r
|
244
|
-
|
245
|
-
def shuffle(self,
|
246
|
-
x,
|
247
|
-
axis=0,
|
248
|
-
key: Optional[SeedOrKey] = None):
|
249
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
250
|
-
x = permutation_for_unit(key, x, axis=axis)
|
251
|
-
return x
|
252
|
-
|
253
|
-
def beta(self,
|
254
|
-
a,
|
255
|
-
b,
|
256
|
-
size: Optional[Size] = None,
|
257
|
-
key: Optional[SeedOrKey] = None,
|
258
|
-
dtype: DTypeLike = None):
|
259
|
-
a = _check_py_seq(a)
|
260
|
-
b = _check_py_seq(b)
|
261
|
-
if size is None:
|
262
|
-
size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
|
263
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
264
|
-
dtype = dtype or environ.dftype()
|
265
|
-
r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
|
266
|
-
return r
|
267
|
-
|
268
|
-
def exponential(self,
|
269
|
-
scale=None,
|
270
|
-
size: Optional[Size] = None,
|
271
|
-
key: Optional[SeedOrKey] = None,
|
272
|
-
dtype: DTypeLike = None):
|
273
|
-
if size is None:
|
274
|
-
size = jnp.shape(scale)
|
275
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
276
|
-
dtype = dtype or environ.dftype()
|
277
|
-
scale = jnp.asarray(scale, dtype=dtype)
|
278
|
-
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
279
|
-
if scale is not None:
|
280
|
-
r = r / scale
|
281
|
-
return r
|
282
|
-
|
283
|
-
def gamma(self,
|
284
|
-
shape,
|
285
|
-
scale=None,
|
286
|
-
size: Optional[Size] = None,
|
287
|
-
key: Optional[SeedOrKey] = None,
|
288
|
-
dtype: DTypeLike = None):
|
289
|
-
shape = _check_py_seq(shape)
|
290
|
-
scale = _check_py_seq(scale)
|
291
|
-
if size is None:
|
292
|
-
size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
|
293
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
294
|
-
dtype = dtype or environ.dftype()
|
295
|
-
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
296
|
-
if scale is not None:
|
297
|
-
r = r * scale
|
298
|
-
return r
|
299
|
-
|
300
|
-
def gumbel(self,
|
301
|
-
loc=None,
|
302
|
-
scale=None,
|
290
|
+
b,
|
303
291
|
size: Optional[Size] = None,
|
304
292
|
key: Optional[SeedOrKey] = None,
|
305
293
|
dtype: DTypeLike = None):
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
294
|
+
a = _check_py_seq(a)
|
295
|
+
b = _check_py_seq(b)
|
296
|
+
if size is None:
|
297
|
+
size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
|
298
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
299
|
+
dtype = dtype or environ.dftype()
|
300
|
+
r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
|
301
|
+
return r
|
302
|
+
|
303
|
+
def exponential(self,
|
304
|
+
scale=None,
|
305
|
+
size: Optional[Size] = None,
|
306
|
+
key: Optional[SeedOrKey] = None,
|
307
|
+
dtype: DTypeLike = None):
|
308
|
+
if size is None:
|
309
|
+
size = jnp.shape(scale)
|
310
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
311
|
+
dtype = dtype or environ.dftype()
|
312
|
+
scale = jnp.asarray(scale, dtype=dtype)
|
313
|
+
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
314
|
+
if scale is not None:
|
315
|
+
r = r / scale
|
316
|
+
return r
|
317
|
+
|
318
|
+
def gamma(self,
|
319
|
+
shape,
|
317
320
|
scale=None,
|
318
321
|
size: Optional[Size] = None,
|
319
322
|
key: Optional[SeedOrKey] = None,
|
320
323
|
dtype: DTypeLike = None):
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
324
|
+
shape = _check_py_seq(shape)
|
325
|
+
scale = _check_py_seq(scale)
|
326
|
+
if size is None:
|
327
|
+
size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
|
328
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
329
|
+
dtype = dtype or environ.dftype()
|
330
|
+
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
331
|
+
if scale is not None:
|
332
|
+
r = r * scale
|
333
|
+
return r
|
334
|
+
|
335
|
+
def gumbel(self,
|
331
336
|
loc=None,
|
332
337
|
scale=None,
|
333
338
|
size: Optional[Size] = None,
|
334
339
|
key: Optional[SeedOrKey] = None,
|
335
340
|
dtype: DTypeLike = None):
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
key: Optional[SeedOrKey] = None,
|
364
|
-
dtype: DTypeLike = None):
|
365
|
-
if size is None:
|
366
|
-
size = jnp.shape(a)
|
367
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
368
|
-
dtype = dtype or environ.dftype()
|
369
|
-
a = jnp.asarray(a, dtype=dtype)
|
370
|
-
r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
|
371
|
-
return r
|
372
|
-
|
373
|
-
def poisson(self,
|
374
|
-
lam=1.0,
|
375
|
-
size: Optional[Size] = None,
|
376
|
-
key: Optional[SeedOrKey] = None,
|
377
|
-
dtype: DTypeLike = None):
|
378
|
-
lam = _check_py_seq(lam)
|
379
|
-
if size is None:
|
380
|
-
size = jnp.shape(lam)
|
381
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
382
|
-
dtype = dtype or environ.ditype()
|
383
|
-
r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
|
384
|
-
return r
|
385
|
-
|
386
|
-
def standard_cauchy(self,
|
387
|
-
size: Optional[Size] = None,
|
388
|
-
key: Optional[SeedOrKey] = None,
|
389
|
-
dtype: DTypeLike = None):
|
390
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
391
|
-
dtype = dtype or environ.dftype()
|
392
|
-
r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
|
393
|
-
return r
|
394
|
-
|
395
|
-
def standard_exponential(self,
|
396
|
-
size: Optional[Size] = None,
|
397
|
-
key: Optional[SeedOrKey] = None,
|
398
|
-
dtype: DTypeLike = None):
|
399
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
400
|
-
dtype = dtype or environ.dftype()
|
401
|
-
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
402
|
-
return r
|
403
|
-
|
404
|
-
def standard_gamma(self,
|
405
|
-
shape,
|
406
|
-
size: Optional[Size] = None,
|
407
|
-
key: Optional[SeedOrKey] = None,
|
408
|
-
dtype: DTypeLike = None):
|
409
|
-
shape = _check_py_seq(shape)
|
410
|
-
if size is None:
|
411
|
-
size = jnp.shape(shape)
|
412
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
413
|
-
dtype = dtype or environ.dftype()
|
414
|
-
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
415
|
-
return r
|
416
|
-
|
417
|
-
def standard_normal(self,
|
418
|
-
size: Optional[Size] = None,
|
419
|
-
key: Optional[SeedOrKey] = None,
|
420
|
-
dtype: DTypeLike = None):
|
421
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
422
|
-
dtype = dtype or environ.dftype()
|
423
|
-
r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
424
|
-
return r
|
425
|
-
|
426
|
-
def standard_t(self, df,
|
341
|
+
loc = _check_py_seq(loc)
|
342
|
+
scale = _check_py_seq(scale)
|
343
|
+
if size is None:
|
344
|
+
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
|
345
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
346
|
+
dtype = dtype or environ.dftype()
|
347
|
+
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
|
348
|
+
return r
|
349
|
+
|
350
|
+
def laplace(self,
|
351
|
+
loc=None,
|
352
|
+
scale=None,
|
353
|
+
size: Optional[Size] = None,
|
354
|
+
key: Optional[SeedOrKey] = None,
|
355
|
+
dtype: DTypeLike = None):
|
356
|
+
loc = _check_py_seq(loc)
|
357
|
+
scale = _check_py_seq(scale)
|
358
|
+
if size is None:
|
359
|
+
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
|
360
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
361
|
+
dtype = dtype or environ.dftype()
|
362
|
+
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
|
363
|
+
return r
|
364
|
+
|
365
|
+
def logistic(self,
|
366
|
+
loc=None,
|
367
|
+
scale=None,
|
427
368
|
size: Optional[Size] = None,
|
428
369
|
key: Optional[SeedOrKey] = None,
|
429
370
|
dtype: DTypeLike = None):
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
key: Optional[SeedOrKey] = None,
|
443
|
-
dtype: DTypeLike = None):
|
444
|
-
low = _check_py_seq(low)
|
445
|
-
high = _check_py_seq(high)
|
446
|
-
if size is None:
|
447
|
-
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
|
448
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
449
|
-
dtype = dtype or environ.dftype()
|
450
|
-
r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
|
451
|
-
return r
|
452
|
-
|
453
|
-
def __norm_cdf(self, x, sqrt2, dtype):
|
454
|
-
# Computes standard normal cumulative distribution function
|
455
|
-
return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
|
456
|
-
|
457
|
-
def truncated_normal(
|
458
|
-
self,
|
459
|
-
lower,
|
460
|
-
upper,
|
461
|
-
size: Optional[Size] = None,
|
462
|
-
loc=0.,
|
463
|
-
scale=1.,
|
464
|
-
key: Optional[SeedOrKey] = None,
|
465
|
-
dtype: DTypeLike = None
|
466
|
-
):
|
467
|
-
lower = _check_py_seq(lower)
|
468
|
-
upper = _check_py_seq(upper)
|
469
|
-
loc = _check_py_seq(loc)
|
470
|
-
scale = _check_py_seq(scale)
|
471
|
-
dtype = dtype or environ.dftype()
|
472
|
-
|
473
|
-
lower = u.math.asarray(lower, dtype=dtype)
|
474
|
-
upper = u.math.asarray(upper, dtype=dtype)
|
475
|
-
loc = u.math.asarray(loc, dtype=dtype)
|
476
|
-
scale = u.math.asarray(scale, dtype=dtype)
|
477
|
-
unit = u.get_unit(lower)
|
478
|
-
lower, upper, loc, scale = (
|
479
|
-
lower.mantissa if isinstance(lower, u.Quantity) else lower,
|
480
|
-
u.Quantity(upper).in_unit(unit).mantissa,
|
481
|
-
u.Quantity(loc).in_unit(unit).mantissa,
|
482
|
-
u.Quantity(scale).in_unit(unit).mantissa
|
483
|
-
)
|
484
|
-
|
485
|
-
jit_error_if(
|
486
|
-
u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
487
|
-
"mean is more than 2 std from [lower, upper] in truncated_normal. "
|
488
|
-
"The distribution of values may be incorrect."
|
489
|
-
)
|
490
|
-
|
491
|
-
if size is None:
|
492
|
-
size = u.math.broadcast_shapes(jnp.shape(lower),
|
493
|
-
jnp.shape(upper),
|
494
|
-
jnp.shape(loc),
|
495
|
-
jnp.shape(scale))
|
496
|
-
|
497
|
-
# Values are generated by using a truncated uniform distribution and
|
498
|
-
# then using the inverse CDF for the normal distribution.
|
499
|
-
# Get upper and lower cdf values
|
500
|
-
sqrt2 = np.array(np.sqrt(2), dtype=dtype)
|
501
|
-
l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
|
502
|
-
u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
|
503
|
-
|
504
|
-
# Uniformly fill tensor with values from [l, u], then translate to
|
505
|
-
# [2l-1, 2u-1].
|
506
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
507
|
-
out = uniform_for_unit(
|
508
|
-
key, size, dtype,
|
509
|
-
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
510
|
-
maxval=lax.nextafter(2 * u_- 1, np.array(-np.inf, dtype=dtype))
|
511
|
-
)
|
512
|
-
|
513
|
-
# Use inverse cdf transform for normal distribution to get truncated
|
514
|
-
# standard normal
|
515
|
-
out = lax.erf_inv(out)
|
516
|
-
|
517
|
-
# Transform to proper mean, std
|
518
|
-
out = out * scale * sqrt2 + loc
|
519
|
-
|
520
|
-
# Clamp to ensure it's in the proper range
|
521
|
-
out = jnp.clip(
|
522
|
-
out,
|
523
|
-
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
524
|
-
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
|
525
|
-
)
|
526
|
-
return out if unit.is_unitless else u.Quantity(out, unit=unit)
|
527
|
-
|
528
|
-
def _check_p(self, p):
|
529
|
-
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
|
530
|
-
|
531
|
-
def bernoulli(self,
|
532
|
-
p,
|
533
|
-
size: Optional[Size] = None,
|
534
|
-
key: Optional[SeedOrKey] = None):
|
535
|
-
p = _check_py_seq(p)
|
536
|
-
jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
|
537
|
-
if size is None:
|
538
|
-
size = jnp.shape(p)
|
539
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
540
|
-
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
541
|
-
return r
|
542
|
-
|
543
|
-
def lognormal(
|
544
|
-
self,
|
545
|
-
mean=None,
|
546
|
-
sigma=None,
|
547
|
-
size: Optional[Size] = None,
|
548
|
-
key: Optional[SeedOrKey] = None,
|
549
|
-
dtype: DTypeLike = None
|
550
|
-
):
|
551
|
-
mean = _check_py_seq(mean)
|
552
|
-
sigma = _check_py_seq(sigma)
|
553
|
-
mean = u.math.asarray(mean, dtype=dtype)
|
554
|
-
sigma = u.math.asarray(sigma, dtype=dtype)
|
555
|
-
unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
|
556
|
-
mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
|
557
|
-
sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
|
558
|
-
|
559
|
-
if size is None:
|
560
|
-
size = jnp.broadcast_shapes(
|
561
|
-
jnp.shape(mean),
|
562
|
-
jnp.shape(sigma)
|
563
|
-
)
|
564
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
565
|
-
dtype = dtype or environ.dftype()
|
566
|
-
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
567
|
-
samples = _loc_scale(mean, sigma, samples)
|
568
|
-
samples = jnp.exp(samples)
|
569
|
-
return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
|
570
|
-
|
571
|
-
def binomial(self,
|
572
|
-
n,
|
573
|
-
p,
|
371
|
+
loc = _check_py_seq(loc)
|
372
|
+
scale = _check_py_seq(scale)
|
373
|
+
if size is None:
|
374
|
+
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
|
375
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
376
|
+
dtype = dtype or environ.dftype()
|
377
|
+
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
|
378
|
+
return r
|
379
|
+
|
380
|
+
def normal(self,
|
381
|
+
loc=None,
|
382
|
+
scale=None,
|
574
383
|
size: Optional[Size] = None,
|
575
384
|
key: Optional[SeedOrKey] = None,
|
576
385
|
dtype: DTypeLike = None):
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
else:
|
602
|
-
dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
|
603
|
-
dist = dist.sum(axis=0)
|
604
|
-
return dist
|
605
|
-
|
606
|
-
def dirichlet(self,
|
607
|
-
alpha,
|
386
|
+
loc = _check_py_seq(loc)
|
387
|
+
scale = _check_py_seq(scale)
|
388
|
+
if size is None:
|
389
|
+
size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
|
390
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
391
|
+
dtype = dtype or environ.dftype()
|
392
|
+
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
|
393
|
+
return r
|
394
|
+
|
395
|
+
def pareto(self,
|
396
|
+
a,
|
397
|
+
size: Optional[Size] = None,
|
398
|
+
key: Optional[SeedOrKey] = None,
|
399
|
+
dtype: DTypeLike = None):
|
400
|
+
if size is None:
|
401
|
+
size = jnp.shape(a)
|
402
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
403
|
+
dtype = dtype or environ.dftype()
|
404
|
+
a = jnp.asarray(a, dtype=dtype)
|
405
|
+
r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
|
406
|
+
return r
|
407
|
+
|
408
|
+
def poisson(self,
|
409
|
+
lam=1.0,
|
608
410
|
size: Optional[Size] = None,
|
609
411
|
key: Optional[SeedOrKey] = None,
|
610
412
|
dtype: DTypeLike = None):
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
413
|
+
lam = _check_py_seq(lam)
|
414
|
+
if size is None:
|
415
|
+
size = jnp.shape(lam)
|
416
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
417
|
+
dtype = dtype or environ.ditype()
|
418
|
+
r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
|
419
|
+
return r
|
420
|
+
|
421
|
+
def standard_cauchy(self,
|
422
|
+
size: Optional[Size] = None,
|
423
|
+
key: Optional[SeedOrKey] = None,
|
424
|
+
dtype: DTypeLike = None):
|
425
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
426
|
+
dtype = dtype or environ.dftype()
|
427
|
+
r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
|
428
|
+
return r
|
429
|
+
|
430
|
+
def standard_exponential(self,
|
431
|
+
size: Optional[Size] = None,
|
432
|
+
key: Optional[SeedOrKey] = None,
|
433
|
+
dtype: DTypeLike = None):
|
434
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
435
|
+
dtype = dtype or environ.dftype()
|
436
|
+
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
437
|
+
return r
|
438
|
+
|
439
|
+
def standard_gamma(self,
|
440
|
+
shape,
|
441
|
+
size: Optional[Size] = None,
|
442
|
+
key: Optional[SeedOrKey] = None,
|
443
|
+
dtype: DTypeLike = None):
|
444
|
+
shape = _check_py_seq(shape)
|
445
|
+
if size is None:
|
446
|
+
size = jnp.shape(shape)
|
447
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
448
|
+
dtype = dtype or environ.dftype()
|
449
|
+
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
450
|
+
return r
|
451
|
+
|
452
|
+
def standard_normal(self,
|
453
|
+
size: Optional[Size] = None,
|
454
|
+
key: Optional[SeedOrKey] = None,
|
455
|
+
dtype: DTypeLike = None):
|
456
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
457
|
+
dtype = dtype or environ.dftype()
|
458
|
+
r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
459
|
+
return r
|
616
460
|
|
617
|
-
|
618
|
-
|
461
|
+
def standard_t(self, df,
|
462
|
+
size: Optional[Size] = None,
|
463
|
+
key: Optional[SeedOrKey] = None,
|
464
|
+
dtype: DTypeLike = None):
|
465
|
+
df = _check_py_seq(df)
|
466
|
+
if size is None:
|
467
|
+
size = jnp.shape(size)
|
468
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
469
|
+
dtype = dtype or environ.dftype()
|
470
|
+
r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
|
471
|
+
return r
|
472
|
+
|
473
|
+
def uniform(self,
|
474
|
+
low=0.0,
|
475
|
+
high=1.0,
|
619
476
|
size: Optional[Size] = None,
|
620
477
|
key: Optional[SeedOrKey] = None,
|
621
478
|
dtype: DTypeLike = None):
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
479
|
+
low = _check_py_seq(low)
|
480
|
+
high = _check_py_seq(high)
|
481
|
+
if size is None:
|
482
|
+
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
|
483
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
484
|
+
dtype = dtype or environ.dftype()
|
485
|
+
r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
|
486
|
+
return r
|
487
|
+
|
488
|
+
def __norm_cdf(self, x, sqrt2, dtype):
|
489
|
+
# Computes standard normal cumulative distribution function
|
490
|
+
return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
|
491
|
+
|
492
|
+
def truncated_normal(
|
493
|
+
self,
|
494
|
+
lower,
|
495
|
+
upper,
|
496
|
+
size: Optional[Size] = None,
|
497
|
+
loc=0.,
|
498
|
+
scale=1.,
|
499
|
+
key: Optional[SeedOrKey] = None,
|
500
|
+
dtype: DTypeLike = None
|
501
|
+
):
|
502
|
+
lower = _check_py_seq(lower)
|
503
|
+
upper = _check_py_seq(upper)
|
504
|
+
loc = _check_py_seq(loc)
|
505
|
+
scale = _check_py_seq(scale)
|
506
|
+
dtype = dtype or environ.dftype()
|
507
|
+
|
508
|
+
lower = u.math.asarray(lower, dtype=dtype)
|
509
|
+
upper = u.math.asarray(upper, dtype=dtype)
|
510
|
+
loc = u.math.asarray(loc, dtype=dtype)
|
511
|
+
scale = u.math.asarray(scale, dtype=dtype)
|
512
|
+
unit = u.get_unit(lower)
|
513
|
+
lower, upper, loc, scale = (
|
514
|
+
lower.mantissa if isinstance(lower, u.Quantity) else lower,
|
515
|
+
u.Quantity(upper).in_unit(unit).mantissa,
|
516
|
+
u.Quantity(loc).in_unit(unit).mantissa,
|
517
|
+
u.Quantity(scale).in_unit(unit).mantissa
|
518
|
+
)
|
519
|
+
|
520
|
+
jit_error_if(
|
521
|
+
u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
522
|
+
"mean is more than 2 std from [lower, upper] in truncated_normal. "
|
523
|
+
"The distribution of values may be incorrect."
|
524
|
+
)
|
525
|
+
|
526
|
+
if size is None:
|
527
|
+
size = u.math.broadcast_shapes(jnp.shape(lower),
|
528
|
+
jnp.shape(upper),
|
529
|
+
jnp.shape(loc),
|
530
|
+
jnp.shape(scale))
|
531
|
+
|
532
|
+
# Values are generated by using a truncated uniform distribution and
|
533
|
+
# then using the inverse CDF for the normal distribution.
|
534
|
+
# Get upper and lower cdf values
|
535
|
+
sqrt2 = np.array(np.sqrt(2), dtype=dtype)
|
536
|
+
l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
|
537
|
+
u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
|
538
|
+
|
539
|
+
# Uniformly fill tensor with values from [l, u], then translate to
|
540
|
+
# [2l-1, 2u-1].
|
541
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
542
|
+
out = uniform_for_unit(
|
543
|
+
key, size, dtype,
|
544
|
+
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
545
|
+
maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
|
546
|
+
)
|
547
|
+
|
548
|
+
# Use inverse cdf transform for normal distribution to get truncated
|
549
|
+
# standard normal
|
550
|
+
out = lax.erf_inv(out)
|
551
|
+
|
552
|
+
# Transform to proper mean, std
|
553
|
+
out = out * scale * sqrt2 + loc
|
554
|
+
|
555
|
+
# Clamp to ensure it's in the proper range
|
556
|
+
out = jnp.clip(
|
557
|
+
out,
|
558
|
+
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
559
|
+
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
|
560
|
+
)
|
561
|
+
return out if unit.is_unitless else u.Quantity(out, unit=unit)
|
562
|
+
|
563
|
+
def _check_p(self, p):
|
564
|
+
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
|
565
|
+
|
566
|
+
def bernoulli(self,
|
567
|
+
p,
|
568
|
+
size: Optional[Size] = None,
|
569
|
+
key: Optional[SeedOrKey] = None):
|
570
|
+
p = _check_py_seq(p)
|
571
|
+
jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
|
572
|
+
if size is None:
|
573
|
+
size = jnp.shape(p)
|
574
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
575
|
+
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
576
|
+
return r
|
577
|
+
|
578
|
+
def lognormal(
|
579
|
+
self,
|
580
|
+
mean=None,
|
581
|
+
sigma=None,
|
582
|
+
size: Optional[Size] = None,
|
583
|
+
key: Optional[SeedOrKey] = None,
|
584
|
+
dtype: DTypeLike = None
|
585
|
+
):
|
586
|
+
mean = _check_py_seq(mean)
|
587
|
+
sigma = _check_py_seq(sigma)
|
588
|
+
mean = u.math.asarray(mean, dtype=dtype)
|
589
|
+
sigma = u.math.asarray(sigma, dtype=dtype)
|
590
|
+
unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
|
591
|
+
mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
|
592
|
+
sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
|
593
|
+
|
594
|
+
if size is None:
|
595
|
+
size = jnp.broadcast_shapes(
|
596
|
+
jnp.shape(mean),
|
597
|
+
jnp.shape(sigma)
|
598
|
+
)
|
599
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
600
|
+
dtype = dtype or environ.dftype()
|
601
|
+
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
602
|
+
samples = _loc_scale(mean, sigma, samples)
|
603
|
+
samples = jnp.exp(samples)
|
604
|
+
return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
|
605
|
+
|
606
|
+
def binomial(self,
|
607
|
+
n,
|
608
|
+
p,
|
609
|
+
size: Optional[Size] = None,
|
610
|
+
key: Optional[SeedOrKey] = None,
|
611
|
+
dtype: DTypeLike = None):
|
612
|
+
n = _check_py_seq(n)
|
613
|
+
p = _check_py_seq(p)
|
614
|
+
jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
|
615
|
+
if size is None:
|
616
|
+
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
|
617
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
618
|
+
r = _binomial(key, p, n, shape=_size2shape(size))
|
619
|
+
dtype = dtype or environ.ditype()
|
620
|
+
return jnp.asarray(r, dtype=dtype)
|
621
|
+
|
622
|
+
def chisquare(self,
|
623
|
+
df,
|
637
624
|
size: Optional[Size] = None,
|
638
625
|
key: Optional[SeedOrKey] = None,
|
639
626
|
dtype: DTypeLike = None):
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
cov,
|
657
|
-
size: Optional[Size] = None,
|
658
|
-
method: str = 'cholesky',
|
659
|
-
key: Optional[SeedOrKey] = None,
|
660
|
-
dtype: DTypeLike = None
|
661
|
-
):
|
662
|
-
if method not in {'svd', 'eigh', 'cholesky'}:
|
663
|
-
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
|
664
|
-
dtype = dtype or environ.dftype()
|
665
|
-
mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
|
666
|
-
cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
|
667
|
-
if isinstance(mean, u.Quantity):
|
668
|
-
assert isinstance(cov, u.Quantity)
|
669
|
-
assert mean.unit ** 2 == cov.unit
|
670
|
-
mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
|
671
|
-
cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
|
672
|
-
unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
|
673
|
-
|
674
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
675
|
-
if not jnp.ndim(mean) >= 1:
|
676
|
-
raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
|
677
|
-
if not jnp.ndim(cov) >= 2:
|
678
|
-
raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
|
679
|
-
n = mean.shape[-1]
|
680
|
-
if jnp.shape(cov)[-2:] != (n, n):
|
681
|
-
raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
|
682
|
-
f"but got cov.shape == {jnp.shape(cov)}.")
|
683
|
-
if size is None:
|
684
|
-
size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
|
685
|
-
else:
|
686
|
-
size = _size2shape(size)
|
687
|
-
_check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
|
688
|
-
|
689
|
-
if method == 'svd':
|
690
|
-
(u_, s, _) = jnp.linalg.svd(cov)
|
691
|
-
factor = u_ * jnp.sqrt(s[..., None, :])
|
692
|
-
elif method == 'eigh':
|
693
|
-
(w, v) = jnp.linalg.eigh(cov)
|
694
|
-
factor = v * jnp.sqrt(w[..., None, :])
|
695
|
-
else: # 'cholesky'
|
696
|
-
factor = jnp.linalg.cholesky(cov)
|
697
|
-
normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
|
698
|
-
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
|
699
|
-
return r if unit.is_unitless else u.Quantity(r, unit=unit)
|
700
|
-
|
701
|
-
def rayleigh(self,
|
702
|
-
scale=1.0,
|
703
|
-
size: Optional[Size] = None,
|
704
|
-
key: Optional[SeedOrKey] = None,
|
705
|
-
dtype: DTypeLike = None):
|
706
|
-
scale = _check_py_seq(scale)
|
707
|
-
if size is None:
|
708
|
-
size = jnp.shape(scale)
|
709
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
710
|
-
dtype = dtype or environ.dftype()
|
711
|
-
x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
|
712
|
-
r = x * scale
|
713
|
-
return r
|
714
|
-
|
715
|
-
def triangular(self,
|
716
|
-
size: Optional[Size] = None,
|
717
|
-
key: Optional[SeedOrKey] = None):
|
718
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
719
|
-
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
|
720
|
-
r = 2 * bernoulli_samples - 1
|
721
|
-
return r
|
722
|
-
|
723
|
-
def vonmises(self,
|
724
|
-
mu,
|
725
|
-
kappa,
|
726
|
-
size: Optional[Size] = None,
|
727
|
-
key: Optional[SeedOrKey] = None,
|
728
|
-
dtype: DTypeLike = None):
|
729
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
730
|
-
dtype = dtype or environ.dftype()
|
731
|
-
mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
|
732
|
-
kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
|
733
|
-
if size is None:
|
734
|
-
size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
|
735
|
-
size = _size2shape(size)
|
736
|
-
samples = _von_mises_centered(key, kappa, size, dtype=dtype)
|
737
|
-
samples = samples + mu
|
738
|
-
samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
|
739
|
-
return samples
|
740
|
-
|
741
|
-
def weibull(self,
|
742
|
-
a,
|
743
|
-
size: Optional[Size] = None,
|
744
|
-
key: Optional[SeedOrKey] = None,
|
745
|
-
dtype: DTypeLike = None):
|
746
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
747
|
-
a = _check_py_seq(a)
|
748
|
-
if size is None:
|
749
|
-
size = jnp.shape(a)
|
750
|
-
else:
|
751
|
-
if jnp.size(a) > 1:
|
752
|
-
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
753
|
-
size = _size2shape(size)
|
754
|
-
dtype = dtype or environ.dftype()
|
755
|
-
random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
|
756
|
-
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
|
757
|
-
return r
|
758
|
-
|
759
|
-
def weibull_min(self,
|
760
|
-
a,
|
761
|
-
scale=None,
|
627
|
+
df = _check_py_seq(df)
|
628
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
629
|
+
dtype = dtype or environ.dftype()
|
630
|
+
if size is None:
|
631
|
+
if jnp.ndim(df) == 0:
|
632
|
+
dist = jr.normal(key, (df,), dtype=dtype) ** 2
|
633
|
+
dist = dist.sum()
|
634
|
+
else:
|
635
|
+
raise NotImplementedError('Do not support non-scale "df" when "size" is None')
|
636
|
+
else:
|
637
|
+
dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
|
638
|
+
dist = dist.sum(axis=0)
|
639
|
+
return dist
|
640
|
+
|
641
|
+
def dirichlet(self,
|
642
|
+
alpha,
|
762
643
|
size: Optional[Size] = None,
|
763
644
|
key: Optional[SeedOrKey] = None,
|
764
645
|
dtype: DTypeLike = None):
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
def wald(self,
|
813
|
-
mean,
|
814
|
-
scale,
|
815
|
-
size: Optional[Size] = None,
|
816
|
-
key: Optional[SeedOrKey] = None,
|
817
|
-
dtype: DTypeLike = None):
|
818
|
-
dtype = dtype or environ.dftype()
|
819
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
820
|
-
mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
|
821
|
-
scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
|
822
|
-
if size is None:
|
823
|
-
size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
|
824
|
-
size = _size2shape(size)
|
825
|
-
sampled_chi2 = jnp.square(self.randn(*size))
|
826
|
-
sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
|
827
|
-
# Wikipedia defines an intermediate x with the formula
|
828
|
-
# x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
|
829
|
-
# where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
|
830
|
-
# Let us write
|
831
|
-
# w = loc * y / (2 * conc)
|
832
|
-
# Then we can extract the common factor in the last two terms to obtain
|
833
|
-
# x = loc + loc * w * (1 - sqrt(2 / w + 1))
|
834
|
-
# Now we see that the Wikipedia formula suffers from catastrphic
|
835
|
-
# cancellation for large w (e.g., if conc << loc).
|
836
|
-
#
|
837
|
-
# Fortunately, we can fix this by multiplying both sides
|
838
|
-
# by 1 + sqrt(2 / w + 1). We get
|
839
|
-
# x * (1 + sqrt(2 / w + 1)) =
|
840
|
-
# = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
|
841
|
-
# = loc * (sqrt(2 / w + 1) - 1)
|
842
|
-
# The term sqrt(2 / w + 1) + 1 no longer presents numerical
|
843
|
-
# difficulties for large w, and sqrt(2 / w + 1) - 1 is just
|
844
|
-
# sqrt1pm1(2 / w), which we know how to compute accurately.
|
845
|
-
# This just leaves the matter of small w, where 2 / w may
|
846
|
-
# overflow. In the limit a w -> 0, x -> loc, so we just mask
|
847
|
-
# that case.
|
848
|
-
sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
|
849
|
-
safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
|
850
|
-
denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
|
851
|
-
ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
|
852
|
-
sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
|
853
|
-
res = jnp.where(sampled_uniform <= mean / (mean + sampled),
|
854
|
-
sampled,
|
855
|
-
jnp.square(mean) / sampled)
|
856
|
-
return res
|
857
|
-
|
858
|
-
def t(self,
|
859
|
-
df,
|
646
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
647
|
+
alpha = _check_py_seq(alpha)
|
648
|
+
dtype = dtype or environ.dftype()
|
649
|
+
r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
|
650
|
+
return r
|
651
|
+
|
652
|
+
def geometric(self,
|
653
|
+
p,
|
654
|
+
size: Optional[Size] = None,
|
655
|
+
key: Optional[SeedOrKey] = None,
|
656
|
+
dtype: DTypeLike = None):
|
657
|
+
p = _check_py_seq(p)
|
658
|
+
if size is None:
|
659
|
+
size = jnp.shape(p)
|
660
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
661
|
+
dtype = dtype or environ.dftype()
|
662
|
+
u_ = uniform_for_unit(key, size, dtype=dtype)
|
663
|
+
r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
|
664
|
+
return r
|
665
|
+
|
666
|
+
def _check_p2(self, p):
|
667
|
+
raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
|
668
|
+
|
669
|
+
def multinomial(self,
|
670
|
+
n,
|
671
|
+
pvals,
|
672
|
+
size: Optional[Size] = None,
|
673
|
+
key: Optional[SeedOrKey] = None,
|
674
|
+
dtype: DTypeLike = None):
|
675
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
676
|
+
n = _check_py_seq(n)
|
677
|
+
pvals = _check_py_seq(pvals)
|
678
|
+
jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
|
679
|
+
if isinstance(n, jax.core.Tracer):
|
680
|
+
raise ValueError("The total count parameter `n` should not be a jax abstract array.")
|
681
|
+
size = _size2shape(size)
|
682
|
+
n_max = int(np.max(jax.device_get(n)))
|
683
|
+
batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
|
684
|
+
r = _multinomial(key, pvals, n, n_max, batch_shape + size)
|
685
|
+
dtype = dtype or environ.ditype()
|
686
|
+
return jnp.asarray(r, dtype=dtype)
|
687
|
+
|
688
|
+
def multivariate_normal(
|
689
|
+
self,
|
690
|
+
mean,
|
691
|
+
cov,
|
860
692
|
size: Optional[Size] = None,
|
693
|
+
method: str = 'cholesky',
|
861
694
|
key: Optional[SeedOrKey] = None,
|
862
|
-
dtype: DTypeLike = None
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
695
|
+
dtype: DTypeLike = None
|
696
|
+
):
|
697
|
+
if method not in {'svd', 'eigh', 'cholesky'}:
|
698
|
+
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
|
699
|
+
dtype = dtype or environ.dftype()
|
700
|
+
mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
|
701
|
+
cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
|
702
|
+
if isinstance(mean, u.Quantity):
|
703
|
+
assert isinstance(cov, u.Quantity)
|
704
|
+
assert mean.unit ** 2 == cov.unit
|
705
|
+
mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
|
706
|
+
cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
|
707
|
+
unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
|
708
|
+
|
709
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
710
|
+
if not jnp.ndim(mean) >= 1:
|
711
|
+
raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
|
712
|
+
if not jnp.ndim(cov) >= 2:
|
713
|
+
raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
|
714
|
+
n = mean.shape[-1]
|
715
|
+
if jnp.shape(cov)[-2:] != (n, n):
|
716
|
+
raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
|
717
|
+
f"but got cov.shape == {jnp.shape(cov)}.")
|
718
|
+
if size is None:
|
719
|
+
size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
|
720
|
+
else:
|
721
|
+
size = _size2shape(size)
|
722
|
+
_check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
|
723
|
+
|
724
|
+
if method == 'svd':
|
725
|
+
(u_, s, _) = jnp.linalg.svd(cov)
|
726
|
+
factor = u_ * jnp.sqrt(s[..., None, :])
|
727
|
+
elif method == 'eigh':
|
728
|
+
(w, v) = jnp.linalg.eigh(cov)
|
729
|
+
factor = v * jnp.sqrt(w[..., None, :])
|
730
|
+
else: # 'cholesky'
|
731
|
+
factor = jnp.linalg.cholesky(cov)
|
732
|
+
normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
|
733
|
+
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
|
734
|
+
return r if unit.is_unitless else u.Quantity(r, unit=unit)
|
735
|
+
|
736
|
+
def rayleigh(self,
|
737
|
+
scale=1.0,
|
883
738
|
size: Optional[Size] = None,
|
884
739
|
key: Optional[SeedOrKey] = None,
|
885
740
|
dtype: DTypeLike = None):
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
size: Optional[Size] = None,
|
924
|
-
key: Optional[SeedOrKey] = None,
|
925
|
-
dtype: DTypeLike = None):
|
926
|
-
dtype = dtype or environ.dftype()
|
927
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
928
|
-
a = _check_py_seq(a)
|
929
|
-
if size is None:
|
930
|
-
size = jnp.shape(a)
|
931
|
-
r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
|
932
|
-
return r
|
933
|
-
|
934
|
-
def categorical(self,
|
935
|
-
logits,
|
936
|
-
axis: int = -1,
|
937
|
-
size: Optional[Size] = None,
|
938
|
-
key: Optional[SeedOrKey] = None):
|
939
|
-
key = self.split_key() if key is None else _formalize_key(key)
|
940
|
-
logits = _check_py_seq(logits)
|
941
|
-
if size is None:
|
942
|
-
size = list(jnp.shape(logits))
|
943
|
-
size.pop(axis)
|
944
|
-
r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
|
945
|
-
return r
|
946
|
-
|
947
|
-
def zipf(self,
|
948
|
-
a,
|
949
|
-
size: Optional[Size] = None,
|
950
|
-
key: Optional[SeedOrKey] = None,
|
951
|
-
dtype: DTypeLike = None):
|
952
|
-
a = _check_py_seq(a)
|
953
|
-
if size is None:
|
954
|
-
size = jnp.shape(a)
|
955
|
-
dtype = dtype or environ.ditype()
|
956
|
-
r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
|
957
|
-
jax.ShapeDtypeStruct(size, dtype),
|
958
|
-
a)
|
959
|
-
return r
|
960
|
-
|
961
|
-
def power(self,
|
962
|
-
a,
|
963
|
-
size: Optional[Size] = None,
|
964
|
-
key: Optional[SeedOrKey] = None,
|
965
|
-
dtype: DTypeLike = None):
|
966
|
-
a = _check_py_seq(a)
|
967
|
-
if size is None:
|
968
|
-
size = jnp.shape(a)
|
969
|
-
size = _size2shape(size)
|
970
|
-
dtype = dtype or environ.dftype()
|
971
|
-
r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
|
972
|
-
jax.ShapeDtypeStruct(size, dtype),
|
973
|
-
a)
|
974
|
-
return r
|
975
|
-
|
976
|
-
def f(self,
|
977
|
-
dfnum,
|
978
|
-
dfden,
|
979
|
-
size: Optional[Size] = None,
|
980
|
-
key: Optional[SeedOrKey] = None,
|
981
|
-
dtype: DTypeLike = None):
|
982
|
-
dfnum = _check_py_seq(dfnum)
|
983
|
-
dfden = _check_py_seq(dfden)
|
984
|
-
if size is None:
|
985
|
-
size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
|
986
|
-
size = _size2shape(size)
|
987
|
-
d = {'dfnum': dfnum, 'dfden': dfden}
|
988
|
-
dtype = dtype or environ.dftype()
|
989
|
-
r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
|
990
|
-
dfden=dfden_,
|
991
|
-
size=size).astype(dtype),
|
992
|
-
jax.ShapeDtypeStruct(size, dtype),
|
993
|
-
dfnum, dfden)
|
994
|
-
return r
|
995
|
-
|
996
|
-
def hypergeometric(
|
997
|
-
self,
|
998
|
-
ngood,
|
999
|
-
nbad,
|
1000
|
-
nsample,
|
1001
|
-
size: Optional[Size] = None,
|
1002
|
-
key: Optional[SeedOrKey] = None,
|
1003
|
-
dtype: DTypeLike = None
|
1004
|
-
):
|
1005
|
-
ngood = _check_py_seq(ngood)
|
1006
|
-
nbad = _check_py_seq(nbad)
|
1007
|
-
nsample = _check_py_seq(nsample)
|
1008
|
-
|
1009
|
-
if size is None:
|
1010
|
-
size = lax.broadcast_shapes(jnp.shape(ngood),
|
1011
|
-
jnp.shape(nbad),
|
1012
|
-
jnp.shape(nsample))
|
1013
|
-
size = _size2shape(size)
|
1014
|
-
dtype = dtype or environ.ditype()
|
1015
|
-
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
|
1016
|
-
r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
|
1017
|
-
nbad=d['nbad'],
|
1018
|
-
nsample=d['nsample'],
|
1019
|
-
size=size).astype(dtype),
|
1020
|
-
jax.ShapeDtypeStruct(size, dtype),
|
1021
|
-
d)
|
1022
|
-
return r
|
1023
|
-
|
1024
|
-
def logseries(self,
|
1025
|
-
p,
|
741
|
+
scale = _check_py_seq(scale)
|
742
|
+
if size is None:
|
743
|
+
size = jnp.shape(scale)
|
744
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
745
|
+
dtype = dtype or environ.dftype()
|
746
|
+
x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
|
747
|
+
r = x * scale
|
748
|
+
return r
|
749
|
+
|
750
|
+
def triangular(self,
|
751
|
+
size: Optional[Size] = None,
|
752
|
+
key: Optional[SeedOrKey] = None):
|
753
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
754
|
+
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
|
755
|
+
r = 2 * bernoulli_samples - 1
|
756
|
+
return r
|
757
|
+
|
758
|
+
def vonmises(self,
|
759
|
+
mu,
|
760
|
+
kappa,
|
761
|
+
size: Optional[Size] = None,
|
762
|
+
key: Optional[SeedOrKey] = None,
|
763
|
+
dtype: DTypeLike = None):
|
764
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
765
|
+
dtype = dtype or environ.dftype()
|
766
|
+
mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
|
767
|
+
kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
|
768
|
+
if size is None:
|
769
|
+
size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
|
770
|
+
size = _size2shape(size)
|
771
|
+
samples = _von_mises_centered(key, kappa, size, dtype=dtype)
|
772
|
+
samples = samples + mu
|
773
|
+
samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
|
774
|
+
return samples
|
775
|
+
|
776
|
+
def weibull(self,
|
777
|
+
a,
|
1026
778
|
size: Optional[Size] = None,
|
1027
779
|
key: Optional[SeedOrKey] = None,
|
1028
780
|
dtype: DTypeLike = None):
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
781
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
782
|
+
a = _check_py_seq(a)
|
783
|
+
if size is None:
|
784
|
+
size = jnp.shape(a)
|
785
|
+
else:
|
786
|
+
if jnp.size(a) > 1:
|
787
|
+
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
788
|
+
size = _size2shape(size)
|
789
|
+
dtype = dtype or environ.dftype()
|
790
|
+
random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
|
791
|
+
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
|
792
|
+
return r
|
793
|
+
|
794
|
+
def weibull_min(self,
|
795
|
+
a,
|
796
|
+
scale=None,
|
797
|
+
size: Optional[Size] = None,
|
798
|
+
key: Optional[SeedOrKey] = None,
|
799
|
+
dtype: DTypeLike = None):
|
800
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
801
|
+
a = _check_py_seq(a)
|
802
|
+
scale = _check_py_seq(scale)
|
803
|
+
if size is None:
|
804
|
+
size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
|
805
|
+
else:
|
806
|
+
if jnp.size(a) > 1:
|
807
|
+
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
|
808
|
+
size = _size2shape(size)
|
809
|
+
dtype = dtype or environ.dftype()
|
810
|
+
random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
|
811
|
+
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
|
812
|
+
if scale is not None:
|
813
|
+
r /= scale
|
814
|
+
return r
|
815
|
+
|
816
|
+
def maxwell(self,
|
817
|
+
size: Optional[Size] = None,
|
818
|
+
key: Optional[SeedOrKey] = None,
|
819
|
+
dtype: DTypeLike = None):
|
820
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
821
|
+
shape = _size2shape(size) + (3,)
|
822
|
+
dtype = dtype or environ.dftype()
|
823
|
+
norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
|
824
|
+
r = jnp.linalg.norm(norm_rvs, axis=-1)
|
825
|
+
return r
|
826
|
+
|
827
|
+
def negative_binomial(self,
|
828
|
+
n,
|
829
|
+
p,
|
830
|
+
size: Optional[Size] = None,
|
831
|
+
key: Optional[SeedOrKey] = None,
|
832
|
+
dtype: DTypeLike = None):
|
833
|
+
n = _check_py_seq(n)
|
834
|
+
p = _check_py_seq(p)
|
835
|
+
if size is None:
|
836
|
+
size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
|
837
|
+
size = _size2shape(size)
|
838
|
+
logits = jnp.log(p) - jnp.log1p(-p)
|
839
|
+
if key is None:
|
840
|
+
keys = self.split_key(2)
|
841
|
+
else:
|
842
|
+
keys = jr.split(_formalize_key(key), 2)
|
843
|
+
rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
|
844
|
+
r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
|
845
|
+
return r
|
846
|
+
|
847
|
+
def wald(self,
|
848
|
+
mean,
|
849
|
+
scale,
|
850
|
+
size: Optional[Size] = None,
|
851
|
+
key: Optional[SeedOrKey] = None,
|
852
|
+
dtype: DTypeLike = None):
|
853
|
+
dtype = dtype or environ.dftype()
|
854
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
855
|
+
mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
|
856
|
+
scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
|
857
|
+
if size is None:
|
858
|
+
size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
|
859
|
+
size = _size2shape(size)
|
860
|
+
sampled_chi2 = jnp.square(self.randn(*size))
|
861
|
+
sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
|
862
|
+
# Wikipedia defines an intermediate x with the formula
|
863
|
+
# x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
|
864
|
+
# where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
|
865
|
+
# Let us write
|
866
|
+
# w = loc * y / (2 * conc)
|
867
|
+
# Then we can extract the common factor in the last two terms to obtain
|
868
|
+
# x = loc + loc * w * (1 - sqrt(2 / w + 1))
|
869
|
+
# Now we see that the Wikipedia formula suffers from catastrphic
|
870
|
+
# cancellation for large w (e.g., if conc << loc).
|
871
|
+
#
|
872
|
+
# Fortunately, we can fix this by multiplying both sides
|
873
|
+
# by 1 + sqrt(2 / w + 1). We get
|
874
|
+
# x * (1 + sqrt(2 / w + 1)) =
|
875
|
+
# = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
|
876
|
+
# = loc * (sqrt(2 / w + 1) - 1)
|
877
|
+
# The term sqrt(2 / w + 1) + 1 no longer presents numerical
|
878
|
+
# difficulties for large w, and sqrt(2 / w + 1) - 1 is just
|
879
|
+
# sqrt1pm1(2 / w), which we know how to compute accurately.
|
880
|
+
# This just leaves the matter of small w, where 2 / w may
|
881
|
+
# overflow. In the limit a w -> 0, x -> loc, so we just mask
|
882
|
+
# that case.
|
883
|
+
sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
|
884
|
+
safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
|
885
|
+
denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
|
886
|
+
ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
|
887
|
+
sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
|
888
|
+
res = jnp.where(sampled_uniform <= mean / (mean + sampled),
|
889
|
+
sampled,
|
890
|
+
jnp.square(mean) / sampled)
|
891
|
+
return res
|
892
|
+
|
893
|
+
def t(self,
|
894
|
+
df,
|
895
|
+
size: Optional[Size] = None,
|
896
|
+
key: Optional[SeedOrKey] = None,
|
897
|
+
dtype: DTypeLike = None):
|
898
|
+
dtype = dtype or environ.dftype()
|
899
|
+
df = jnp.asarray(_check_py_seq(df), dtype=dtype)
|
900
|
+
if size is None:
|
901
|
+
size = np.shape(df)
|
902
|
+
else:
|
903
|
+
size = _size2shape(size)
|
904
|
+
_check_shape("t", size, np.shape(df))
|
905
|
+
if key is None:
|
906
|
+
keys = self.split_key(2)
|
907
|
+
else:
|
908
|
+
keys = jr.split(_formalize_key(key), 2)
|
909
|
+
n = jr.normal(keys[0], size, dtype=dtype)
|
910
|
+
two = _const(n, 2)
|
911
|
+
half_df = lax.div(df, two)
|
912
|
+
g = jr.gamma(keys[1], half_df, size, dtype=dtype)
|
913
|
+
r = n * jnp.sqrt(half_df / g)
|
914
|
+
return r
|
915
|
+
|
916
|
+
def orthogonal(self,
|
917
|
+
n: int,
|
1043
918
|
size: Optional[Size] = None,
|
1044
919
|
key: Optional[SeedOrKey] = None,
|
1045
920
|
dtype: DTypeLike = None):
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
921
|
+
dtype = dtype or environ.dftype()
|
922
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
923
|
+
size = _size2shape(size)
|
924
|
+
_check_shape("orthogonal", size)
|
925
|
+
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
926
|
+
z = jr.normal(key, size + (n, n), dtype=dtype)
|
927
|
+
q, r = jnp.linalg.qr(z)
|
928
|
+
d = jnp.diagonal(r, 0, -2, -1)
|
929
|
+
r = q * jnp.expand_dims(d / abs(d), -2)
|
930
|
+
return r
|
931
|
+
|
932
|
+
def noncentral_chisquare(self,
|
933
|
+
df,
|
934
|
+
nonc,
|
935
|
+
size: Optional[Size] = None,
|
936
|
+
key: Optional[SeedOrKey] = None,
|
937
|
+
dtype: DTypeLike = None):
|
938
|
+
dtype = dtype or environ.dftype()
|
939
|
+
df = jnp.asarray(_check_py_seq(df), dtype=dtype)
|
940
|
+
nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
|
941
|
+
if size is None:
|
942
|
+
size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
|
943
|
+
size = _size2shape(size)
|
944
|
+
if key is None:
|
945
|
+
keys = self.split_key(3)
|
946
|
+
else:
|
947
|
+
keys = jr.split(_formalize_key(key), 3)
|
948
|
+
i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
|
949
|
+
n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
|
950
|
+
cond = jnp.greater(df, 1.0)
|
951
|
+
df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
|
952
|
+
chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
|
953
|
+
r = jnp.where(cond, chi2 + n * n, chi2)
|
954
|
+
return r
|
955
|
+
|
956
|
+
def loggamma(self,
|
957
|
+
a,
|
958
|
+
size: Optional[Size] = None,
|
959
|
+
key: Optional[SeedOrKey] = None,
|
960
|
+
dtype: DTypeLike = None):
|
961
|
+
dtype = dtype or environ.dftype()
|
962
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
963
|
+
a = _check_py_seq(a)
|
964
|
+
if size is None:
|
965
|
+
size = jnp.shape(a)
|
966
|
+
r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
|
967
|
+
return r
|
968
|
+
|
969
|
+
def categorical(self,
|
970
|
+
logits,
|
971
|
+
axis: int = -1,
|
972
|
+
size: Optional[Size] = None,
|
973
|
+
key: Optional[SeedOrKey] = None):
|
974
|
+
key = self.split_key() if key is None else _formalize_key(key)
|
975
|
+
logits = _check_py_seq(logits)
|
976
|
+
if size is None:
|
977
|
+
size = list(jnp.shape(logits))
|
978
|
+
size.pop(axis)
|
979
|
+
r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
|
980
|
+
return r
|
981
|
+
|
982
|
+
def zipf(self,
|
983
|
+
a,
|
984
|
+
size: Optional[Size] = None,
|
985
|
+
key: Optional[SeedOrKey] = None,
|
986
|
+
dtype: DTypeLike = None):
|
987
|
+
a = _check_py_seq(a)
|
988
|
+
if size is None:
|
989
|
+
size = jnp.shape(a)
|
990
|
+
dtype = dtype or environ.ditype()
|
991
|
+
r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
|
992
|
+
jax.ShapeDtypeStruct(size, dtype),
|
993
|
+
a)
|
994
|
+
return r
|
995
|
+
|
996
|
+
def power(self,
|
997
|
+
a,
|
998
|
+
size: Optional[Size] = None,
|
999
|
+
key: Optional[SeedOrKey] = None,
|
1000
|
+
dtype: DTypeLike = None):
|
1001
|
+
a = _check_py_seq(a)
|
1002
|
+
if size is None:
|
1003
|
+
size = jnp.shape(a)
|
1004
|
+
size = _size2shape(size)
|
1005
|
+
dtype = dtype or environ.dftype()
|
1006
|
+
r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
|
1007
|
+
jax.ShapeDtypeStruct(size, dtype),
|
1008
|
+
a)
|
1009
|
+
return r
|
1010
|
+
|
1011
|
+
def f(self,
|
1012
|
+
dfnum,
|
1013
|
+
dfden,
|
1014
|
+
size: Optional[Size] = None,
|
1015
|
+
key: Optional[SeedOrKey] = None,
|
1016
|
+
dtype: DTypeLike = None):
|
1017
|
+
dfnum = _check_py_seq(dfnum)
|
1018
|
+
dfden = _check_py_seq(dfden)
|
1019
|
+
if size is None:
|
1020
|
+
size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
|
1021
|
+
size = _size2shape(size)
|
1022
|
+
d = {'dfnum': dfnum, 'dfden': dfden}
|
1023
|
+
dtype = dtype or environ.dftype()
|
1024
|
+
r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
|
1025
|
+
dfden=dfden_,
|
1026
|
+
size=size).astype(dtype),
|
1027
|
+
jax.ShapeDtypeStruct(size, dtype),
|
1028
|
+
dfnum, dfden)
|
1029
|
+
return r
|
1030
|
+
|
1031
|
+
def hypergeometric(
|
1032
|
+
self,
|
1033
|
+
ngood,
|
1034
|
+
nbad,
|
1035
|
+
nsample,
|
1036
|
+
size: Optional[Size] = None,
|
1037
|
+
key: Optional[SeedOrKey] = None,
|
1038
|
+
dtype: DTypeLike = None
|
1039
|
+
):
|
1040
|
+
ngood = _check_py_seq(ngood)
|
1041
|
+
nbad = _check_py_seq(nbad)
|
1042
|
+
nsample = _check_py_seq(nsample)
|
1043
|
+
|
1044
|
+
if size is None:
|
1045
|
+
size = lax.broadcast_shapes(jnp.shape(ngood),
|
1046
|
+
jnp.shape(nbad),
|
1047
|
+
jnp.shape(nsample))
|
1048
|
+
size = _size2shape(size)
|
1049
|
+
dtype = dtype or environ.ditype()
|
1050
|
+
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
|
1051
|
+
r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
|
1052
|
+
nbad=d['nbad'],
|
1053
|
+
nsample=d['nsample'],
|
1054
|
+
size=size).astype(dtype),
|
1055
|
+
jax.ShapeDtypeStruct(size, dtype),
|
1056
|
+
d)
|
1057
|
+
return r
|
1058
|
+
|
1059
|
+
def logseries(self,
|
1060
|
+
p,
|
1061
|
+
size: Optional[Size] = None,
|
1062
|
+
key: Optional[SeedOrKey] = None,
|
1063
|
+
dtype: DTypeLike = None):
|
1064
|
+
p = _check_py_seq(p)
|
1065
|
+
if size is None:
|
1066
|
+
size = jnp.shape(p)
|
1067
|
+
size = _size2shape(size)
|
1068
|
+
dtype = dtype or environ.ditype()
|
1069
|
+
r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
|
1070
|
+
jax.ShapeDtypeStruct(size, dtype),
|
1071
|
+
p)
|
1072
|
+
return r
|
1073
|
+
|
1074
|
+
def noncentral_f(self,
|
1075
|
+
dfnum,
|
1076
|
+
dfden,
|
1077
|
+
nonc,
|
1078
|
+
size: Optional[Size] = None,
|
1079
|
+
key: Optional[SeedOrKey] = None,
|
1080
|
+
dtype: DTypeLike = None):
|
1081
|
+
dfnum = _check_py_seq(dfnum)
|
1082
|
+
dfden = _check_py_seq(dfden)
|
1083
|
+
nonc = _check_py_seq(nonc)
|
1084
|
+
if size is None:
|
1085
|
+
size = lax.broadcast_shapes(jnp.shape(dfnum),
|
1086
|
+
jnp.shape(dfden),
|
1087
|
+
jnp.shape(nonc))
|
1088
|
+
size = _size2shape(size)
|
1089
|
+
d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
|
1090
|
+
dtype = dtype or environ.dftype()
|
1091
|
+
r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
|
1092
|
+
dfden=x['dfden'],
|
1093
|
+
nonc=x['nonc'],
|
1094
|
+
size=size).astype(dtype),
|
1095
|
+
jax.ShapeDtypeStruct(size, dtype),
|
1096
|
+
d)
|
1097
|
+
return r
|
1098
|
+
|
1099
|
+
# PyTorch compatibility #
|
1100
|
+
# --------------------- #
|
1101
|
+
|
1102
|
+
def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
|
1103
|
+
"""Returns a tensor with the same size as input that is filled with random
|
1104
|
+
numbers from a uniform distribution on the interval ``[0, 1)``.
|
1105
|
+
|
1106
|
+
Args:
|
1107
|
+
input: the ``size`` of input will determine size of the output tensor.
|
1108
|
+
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
|
1109
|
+
key: the seed or key for the random.
|
1110
|
+
|
1111
|
+
Returns:
|
1112
|
+
The random data.
|
1113
|
+
"""
|
1114
|
+
return self.random(jnp.shape(input), key=key).astype(dtype)
|
1115
|
+
|
1116
|
+
def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
|
1117
|
+
"""Returns a tensor with the same size as ``input`` that is filled with
|
1118
|
+
random numbers from a normal distribution with mean 0 and variance 1.
|
1119
|
+
|
1120
|
+
Args:
|
1121
|
+
input: the ``size`` of input will determine size of the output tensor.
|
1122
|
+
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
|
1123
|
+
key: the seed or key for the random.
|
1124
|
+
|
1125
|
+
Returns:
|
1126
|
+
The random data.
|
1127
|
+
"""
|
1128
|
+
return self.randn(*jnp.shape(input), key=key).astype(dtype)
|
1129
|
+
|
1130
|
+
def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
|
1131
|
+
if high is None:
|
1132
|
+
high = max(input)
|
1133
|
+
return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
|
1099
1134
|
|
1100
1135
|
|
1101
1136
|
# default random generator
|
@@ -1106,393 +1141,393 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
|
|
1106
1141
|
|
1107
1142
|
|
1108
1143
|
def _formalize_key(key):
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1144
|
+
if isinstance(key, int):
|
1145
|
+
return jr.PRNGKey(key)
|
1146
|
+
elif isinstance(key, (jax.Array, np.ndarray)):
|
1147
|
+
if key.dtype != jnp.uint32:
|
1148
|
+
raise TypeError('key must be a int or an array with two uint32.')
|
1149
|
+
if key.size != 2:
|
1150
|
+
raise TypeError('key must be a int or an array with two uint32.')
|
1151
|
+
return jnp.asarray(key, dtype=jnp.uint32)
|
1152
|
+
else:
|
1153
|
+
raise TypeError('key must be a int or an array with two uint32.')
|
1119
1154
|
|
1120
1155
|
|
1121
1156
|
def _size2shape(size):
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1157
|
+
if size is None:
|
1158
|
+
return ()
|
1159
|
+
elif isinstance(size, (tuple, list)):
|
1160
|
+
return tuple(size)
|
1161
|
+
else:
|
1162
|
+
return (size,)
|
1128
1163
|
|
1129
1164
|
|
1130
1165
|
def _check_shape(name, shape, *param_shapes):
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1166
|
+
if param_shapes:
|
1167
|
+
shape_ = lax.broadcast_shapes(shape, *param_shapes)
|
1168
|
+
if shape != shape_:
|
1169
|
+
msg = ("{} parameter shapes must be broadcast-compatible with shape "
|
1170
|
+
"argument, and the result of broadcasting the shapes must equal "
|
1171
|
+
"the shape argument, but got result {} for shape argument {}.")
|
1172
|
+
raise ValueError(msg.format(name, shape_, shape))
|
1138
1173
|
|
1139
1174
|
|
1140
1175
|
def _is_python_scalar(x):
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1176
|
+
if hasattr(x, 'aval'):
|
1177
|
+
return x.aval.weak_type
|
1178
|
+
elif np.ndim(x) == 0:
|
1179
|
+
return True
|
1180
|
+
elif isinstance(x, (bool, int, float, complex)):
|
1181
|
+
return True
|
1182
|
+
else:
|
1183
|
+
return False
|
1149
1184
|
|
1150
1185
|
|
1151
1186
|
python_scalar_dtypes = {
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1187
|
+
bool: np.dtype('bool'),
|
1188
|
+
int: np.dtype('int64'),
|
1189
|
+
float: np.dtype('float64'),
|
1190
|
+
complex: np.dtype('complex128'),
|
1156
1191
|
}
|
1157
1192
|
|
1158
1193
|
|
1159
1194
|
def _dtype(x, *, canonicalize: bool = False):
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1195
|
+
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
|
1196
|
+
if x is None:
|
1197
|
+
raise ValueError(f"Invalid argument to dtype: {x}.")
|
1198
|
+
elif isinstance(x, type) and x in python_scalar_dtypes:
|
1199
|
+
dt = python_scalar_dtypes[x]
|
1200
|
+
elif type(x) in python_scalar_dtypes:
|
1201
|
+
dt = python_scalar_dtypes[type(x)]
|
1202
|
+
elif hasattr(x, 'dtype'):
|
1203
|
+
dt = x.dtype
|
1204
|
+
else:
|
1205
|
+
dt = np.result_type(x)
|
1206
|
+
return dtypes.canonicalize_dtype(dt) if canonicalize else dt
|
1172
1207
|
|
1173
1208
|
|
1174
1209
|
def _const(example, val):
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1210
|
+
if _is_python_scalar(example):
|
1211
|
+
dtype = dtypes.canonicalize_dtype(type(example))
|
1212
|
+
val = dtypes.scalar_type_of(example)(val)
|
1213
|
+
return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
|
1214
|
+
else:
|
1215
|
+
dtype = dtypes.canonicalize_dtype(example.dtype)
|
1216
|
+
return np.array(val, dtype)
|
1182
1217
|
|
1183
1218
|
|
1184
1219
|
_tr_params = namedtuple(
|
1185
|
-
|
1220
|
+
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
|
1186
1221
|
)
|
1187
1222
|
|
1188
1223
|
|
1189
1224
|
def _get_tr_params(n, p):
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1225
|
+
# See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
|
1226
|
+
# constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
|
1227
|
+
mu = n * p
|
1228
|
+
spq = jnp.sqrt(mu * (1 - p))
|
1229
|
+
c = mu + 0.5
|
1230
|
+
b = 1.15 + 2.53 * spq
|
1231
|
+
a = -0.0873 + 0.0248 * b + 0.01 * p
|
1232
|
+
alpha = (2.83 + 5.1 / b) * spq
|
1233
|
+
u_r = 0.43
|
1234
|
+
v_r = 0.92 - 4.2 / b
|
1235
|
+
m = jnp.floor((n + 1) * p).astype(n.dtype)
|
1236
|
+
log_p = jnp.log(p)
|
1237
|
+
log1_p = jnp.log1p(-p)
|
1238
|
+
log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
|
1239
|
+
_stirling_approx_tail(m) + _stirling_approx_tail(n - m))
|
1240
|
+
return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
|
1206
1241
|
|
1207
1242
|
|
1208
1243
|
def _stirling_approx_tail(k):
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1244
|
+
precomputed = jnp.array([0.08106146679532726,
|
1245
|
+
0.04134069595540929,
|
1246
|
+
0.02767792568499834,
|
1247
|
+
0.02079067210376509,
|
1248
|
+
0.01664469118982119,
|
1249
|
+
0.01387612882307075,
|
1250
|
+
0.01189670994589177,
|
1251
|
+
0.01041126526197209,
|
1252
|
+
0.009255462182712733,
|
1253
|
+
0.008330563433362871],
|
1254
|
+
dtype=environ.dftype())
|
1255
|
+
kp1 = k + 1
|
1256
|
+
kp1sq = (k + 1) ** 2
|
1257
|
+
return jnp.where(k < 10,
|
1258
|
+
precomputed[k],
|
1259
|
+
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
|
1225
1260
|
|
1226
1261
|
|
1227
1262
|
def _binomial_btrs(key, p, n):
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
Hormann, "The Generation of Binonmial Random Variates"
|
1233
|
-
(https://core.ac.uk/download/pdf/11007254.pdf)
|
1234
|
-
"""
|
1235
|
-
|
1236
|
-
def _btrs_body_fn(val):
|
1237
|
-
_, key, _, _ = val
|
1238
|
-
key, key_u, key_v = jr.split(key, 3)
|
1239
|
-
u = jr.uniform(key_u)
|
1240
|
-
v = jr.uniform(key_v)
|
1241
|
-
u = u - 0.5
|
1242
|
-
k = jnp.floor(
|
1243
|
-
(2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
|
1244
|
-
).astype(n.dtype)
|
1245
|
-
return k, key, u, v
|
1246
|
-
|
1247
|
-
def _btrs_cond_fn(val):
|
1248
|
-
def accept_fn(k, u, v):
|
1249
|
-
# See acceptance condition in Step 3. (Page 3) of TRS algorithm
|
1250
|
-
# v <= f(k) * g_grad(u) / alpha
|
1251
|
-
|
1252
|
-
m = tr_params.m
|
1253
|
-
log_p = tr_params.log_p
|
1254
|
-
log1_p = tr_params.log1_p
|
1255
|
-
# See: formula for log(f(k)) at bottom of Page 5.
|
1256
|
-
log_f = (
|
1257
|
-
(n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
|
1258
|
-
+ (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
|
1259
|
-
+ (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
|
1260
|
-
+ tr_params.log_h
|
1261
|
-
)
|
1262
|
-
g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
|
1263
|
-
return jnp.log((v * tr_params.alpha) / g) <= log_f
|
1264
|
-
|
1265
|
-
k, key, u, v = val
|
1266
|
-
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
|
1267
|
-
early_reject = (k < 0) | (k > n)
|
1268
|
-
return lax.cond(
|
1269
|
-
early_accept | early_reject,
|
1270
|
-
(),
|
1271
|
-
lambda _: ~early_accept,
|
1272
|
-
(k, u, v),
|
1273
|
-
lambda x: ~accept_fn(*x),
|
1274
|
-
)
|
1263
|
+
"""
|
1264
|
+
Based on the transformed rejection sampling algorithm (BTRS) from the
|
1265
|
+
following reference:
|
1275
1266
|
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1267
|
+
Hormann, "The Generation of Binonmial Random Variates"
|
1268
|
+
(https://core.ac.uk/download/pdf/11007254.pdf)
|
1269
|
+
"""
|
1270
|
+
|
1271
|
+
def _btrs_body_fn(val):
|
1272
|
+
_, key, _, _ = val
|
1273
|
+
key, key_u, key_v = jr.split(key, 3)
|
1274
|
+
u = jr.uniform(key_u)
|
1275
|
+
v = jr.uniform(key_v)
|
1276
|
+
u = u - 0.5
|
1277
|
+
k = jnp.floor(
|
1278
|
+
(2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
|
1279
|
+
).astype(n.dtype)
|
1280
|
+
return k, key, u, v
|
1281
|
+
|
1282
|
+
def _btrs_cond_fn(val):
|
1283
|
+
def accept_fn(k, u, v):
|
1284
|
+
# See acceptance condition in Step 3. (Page 3) of TRS algorithm
|
1285
|
+
# v <= f(k) * g_grad(u) / alpha
|
1286
|
+
|
1287
|
+
m = tr_params.m
|
1288
|
+
log_p = tr_params.log_p
|
1289
|
+
log1_p = tr_params.log1_p
|
1290
|
+
# See: formula for log(f(k)) at bottom of Page 5.
|
1291
|
+
log_f = (
|
1292
|
+
(n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
|
1293
|
+
+ (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
|
1294
|
+
+ (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
|
1295
|
+
+ tr_params.log_h
|
1296
|
+
)
|
1297
|
+
g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
|
1298
|
+
return jnp.log((v * tr_params.alpha) / g) <= log_f
|
1299
|
+
|
1300
|
+
k, key, u, v = val
|
1301
|
+
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
|
1302
|
+
early_reject = (k < 0) | (k > n)
|
1303
|
+
return lax.cond(
|
1304
|
+
early_accept | early_reject,
|
1305
|
+
(),
|
1306
|
+
lambda _: ~early_accept,
|
1307
|
+
(k, u, v),
|
1308
|
+
lambda x: ~accept_fn(*x),
|
1309
|
+
)
|
1310
|
+
|
1311
|
+
tr_params = _get_tr_params(n, p)
|
1312
|
+
ret = lax.while_loop(
|
1313
|
+
_btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
|
1314
|
+
) # use k=-1 initially so that cond_fn returns True
|
1315
|
+
return ret[0]
|
1281
1316
|
|
1282
1317
|
|
1283
1318
|
def _binomial_inversion(key, p, n):
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1319
|
+
def _binom_inv_body_fn(val):
|
1320
|
+
i, key, geom_acc = val
|
1321
|
+
key, key_u = jr.split(key)
|
1322
|
+
u = jr.uniform(key_u)
|
1323
|
+
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
|
1324
|
+
geom_acc = geom_acc + geom
|
1325
|
+
return i + 1, key, geom_acc
|
1291
1326
|
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1327
|
+
def _binom_inv_cond_fn(val):
|
1328
|
+
i, _, geom_acc = val
|
1329
|
+
return geom_acc <= n
|
1295
1330
|
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1331
|
+
log1_p = jnp.log1p(-p)
|
1332
|
+
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
|
1333
|
+
return ret[0]
|
1299
1334
|
|
1300
1335
|
|
1301
1336
|
def _binomial_dispatch(key, p, n):
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1337
|
+
def dispatch(key, p, n):
|
1338
|
+
is_le_mid = p <= 0.5
|
1339
|
+
pq = jnp.where(is_le_mid, p, 1 - p)
|
1340
|
+
mu = n * pq
|
1341
|
+
k = lax.cond(
|
1342
|
+
mu < 10,
|
1343
|
+
(key, pq, n),
|
1344
|
+
lambda x: _binomial_inversion(*x),
|
1345
|
+
(key, pq, n),
|
1346
|
+
lambda x: _binomial_btrs(*x),
|
1347
|
+
)
|
1348
|
+
return jnp.where(is_le_mid, k, n - k)
|
1349
|
+
|
1350
|
+
# Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
|
1351
|
+
cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
|
1352
|
+
return lax.cond(
|
1353
|
+
cond0 & (p < 1),
|
1354
|
+
(key, p, n),
|
1355
|
+
lambda x: dispatch(*x),
|
1356
|
+
(),
|
1357
|
+
lambda _: jnp.where(cond0, n, 0),
|
1312
1358
|
)
|
1313
|
-
return jnp.where(is_le_mid, k, n - k)
|
1314
|
-
|
1315
|
-
# Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
|
1316
|
-
cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
|
1317
|
-
return lax.cond(
|
1318
|
-
cond0 & (p < 1),
|
1319
|
-
(key, p, n),
|
1320
|
-
lambda x: dispatch(*x),
|
1321
|
-
(),
|
1322
|
-
lambda _: jnp.where(cond0, n, 0),
|
1323
|
-
)
|
1324
1359
|
|
1325
1360
|
|
1326
1361
|
@partial(jit, static_argnums=(3,))
|
1327
1362
|
def _binomial(key, p, n, shape):
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1363
|
+
shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
|
1364
|
+
# reshape to map over axis 0
|
1365
|
+
p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
|
1366
|
+
n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
|
1367
|
+
key = jr.split(key, jnp.size(p))
|
1368
|
+
if jax.default_backend() == "cpu":
|
1369
|
+
ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
|
1370
|
+
else:
|
1371
|
+
ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
|
1372
|
+
return jnp.reshape(ret, shape)
|
1338
1373
|
|
1339
1374
|
|
1340
1375
|
@partial(jit, static_argnums=(2,))
|
1341
1376
|
def _categorical(key, p, shape):
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1377
|
+
# this implementation is fast when event shape is small, and slow otherwise
|
1378
|
+
# Ref: https://stackoverflow.com/a/34190035
|
1379
|
+
shape = shape or p.shape[:-1]
|
1380
|
+
s = jnp.cumsum(p, axis=-1)
|
1381
|
+
r = jr.uniform(key, shape=shape + (1,))
|
1382
|
+
return jnp.sum(s < r, axis=-1)
|
1348
1383
|
|
1349
1384
|
|
1350
1385
|
def _scatter_add_one(operand, indices, updates):
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1386
|
+
return lax.scatter_add(
|
1387
|
+
operand,
|
1388
|
+
indices,
|
1389
|
+
updates,
|
1390
|
+
lax.ScatterDimensionNumbers(
|
1391
|
+
update_window_dims=(),
|
1392
|
+
inserted_window_dims=(0,),
|
1393
|
+
scatter_dims_to_operand_dims=(0,),
|
1394
|
+
),
|
1395
|
+
)
|
1361
1396
|
|
1362
1397
|
|
1363
1398
|
def _reshape(x, shape):
|
1364
|
-
|
1365
|
-
|
1366
|
-
|
1367
|
-
|
1399
|
+
if isinstance(x, (int, float, np.ndarray, np.generic)):
|
1400
|
+
return np.reshape(x, shape)
|
1401
|
+
else:
|
1402
|
+
return jnp.reshape(x, shape)
|
1368
1403
|
|
1369
1404
|
|
1370
1405
|
def _promote_shapes(*args, shape=()):
|
1371
|
-
|
1372
|
-
|
1373
|
-
|
1374
|
-
|
1375
|
-
|
1376
|
-
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1380
|
-
|
1406
|
+
# adapted from lax.lax_numpy
|
1407
|
+
if len(args) < 2 and not shape:
|
1408
|
+
return args
|
1409
|
+
else:
|
1410
|
+
shapes = [jnp.shape(arg) for arg in args]
|
1411
|
+
num_dims = len(lax.broadcast_shapes(shape, *shapes))
|
1412
|
+
return [
|
1413
|
+
_reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
|
1414
|
+
for arg, s in zip(args, shapes)
|
1415
|
+
]
|
1381
1416
|
|
1382
1417
|
|
1383
1418
|
@partial(jit, static_argnums=(3, 4))
|
1384
1419
|
def _multinomial(key, p, n, n_max, shape=()):
|
1385
|
-
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1420
|
+
if jnp.shape(n) != jnp.shape(p)[:-1]:
|
1421
|
+
broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
|
1422
|
+
n = jnp.broadcast_to(n, broadcast_shape)
|
1423
|
+
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
|
1424
|
+
shape = shape or p.shape[:-1]
|
1425
|
+
if n_max == 0:
|
1426
|
+
return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
|
1427
|
+
# get indices from categorical distribution then gather the result
|
1428
|
+
indices = _categorical(key, p, (n_max,) + shape)
|
1429
|
+
# mask out values when counts is heterogeneous
|
1430
|
+
if jnp.ndim(n) > 0:
|
1431
|
+
mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
|
1432
|
+
mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
|
1433
|
+
excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
|
1434
|
+
jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
|
1435
|
+
-1)
|
1436
|
+
else:
|
1437
|
+
mask = 1
|
1438
|
+
excess = 0
|
1439
|
+
# NB: we transpose to move batch shape to the front
|
1440
|
+
indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
|
1441
|
+
samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
|
1442
|
+
jnp.expand_dims(indices_2D, axis=-1),
|
1443
|
+
jnp.ones(indices_2D.shape, dtype=indices.dtype))
|
1444
|
+
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
|
1410
1445
|
|
1411
1446
|
|
1412
1447
|
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
|
1413
1448
|
def _von_mises_centered(key, concentration, shape, dtype=None):
|
1414
|
-
|
1415
|
-
|
1416
|
-
Returns
|
1417
|
-
-------
|
1418
|
-
out: array_like
|
1419
|
-
centered samples from von Mises
|
1420
|
-
|
1421
|
-
References
|
1422
|
-
----------
|
1423
|
-
.. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
|
1424
|
-
Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
|
1425
|
-
|
1426
|
-
"""
|
1427
|
-
shape = shape or jnp.shape(concentration)
|
1428
|
-
dtype = dtype or environ.dftype()
|
1429
|
-
concentration = lax.convert_element_type(concentration, dtype)
|
1430
|
-
concentration = jnp.broadcast_to(concentration, shape)
|
1431
|
-
|
1432
|
-
if dtype == jnp.float16:
|
1433
|
-
s_cutoff = 1.8e-1
|
1434
|
-
elif dtype == jnp.float32:
|
1435
|
-
s_cutoff = 2e-2
|
1436
|
-
elif dtype == jnp.float64:
|
1437
|
-
s_cutoff = 1.2e-4
|
1438
|
-
else:
|
1439
|
-
raise ValueError(f"Unsupported dtype: {dtype}")
|
1440
|
-
|
1441
|
-
r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
|
1442
|
-
rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
|
1443
|
-
s_exact = (1.0 + rho ** 2) / (2.0 * rho)
|
1444
|
-
|
1445
|
-
s_approximate = 1.0 / concentration
|
1446
|
-
|
1447
|
-
s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
|
1448
|
-
|
1449
|
-
def cond_fn(*args):
|
1450
|
-
"""check if all are done or reached max number of iterations"""
|
1451
|
-
i, _, done, _, _ = args[0]
|
1452
|
-
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
|
1453
|
-
|
1454
|
-
def body_fn(*args):
|
1455
|
-
i, key, done, _, w = args[0]
|
1456
|
-
uni_ukey, uni_vkey, key = jr.split(key, 3)
|
1457
|
-
u = jr.uniform(
|
1458
|
-
key=uni_ukey,
|
1459
|
-
shape=shape,
|
1460
|
-
dtype=concentration.dtype,
|
1461
|
-
minval=-1.0,
|
1462
|
-
maxval=1.0,
|
1463
|
-
)
|
1464
|
-
z = jnp.cos(jnp.pi * u)
|
1465
|
-
w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
|
1466
|
-
y = concentration * (s - w)
|
1467
|
-
v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
|
1468
|
-
accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
|
1469
|
-
return i + 1, key, accept | done, u, w
|
1449
|
+
"""Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
|
1470
1450
|
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1451
|
+
Returns
|
1452
|
+
-------
|
1453
|
+
out: array_like
|
1454
|
+
centered samples from von Mises
|
1474
1455
|
|
1475
|
-
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
)
|
1456
|
+
References
|
1457
|
+
----------
|
1458
|
+
.. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
|
1459
|
+
Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
|
1480
1460
|
|
1481
|
-
|
1461
|
+
"""
|
1462
|
+
shape = shape or jnp.shape(concentration)
|
1463
|
+
dtype = dtype or environ.dftype()
|
1464
|
+
concentration = lax.convert_element_type(concentration, dtype)
|
1465
|
+
concentration = jnp.broadcast_to(concentration, shape)
|
1466
|
+
|
1467
|
+
if dtype == jnp.float16:
|
1468
|
+
s_cutoff = 1.8e-1
|
1469
|
+
elif dtype == jnp.float32:
|
1470
|
+
s_cutoff = 2e-2
|
1471
|
+
elif dtype == jnp.float64:
|
1472
|
+
s_cutoff = 1.2e-4
|
1473
|
+
else:
|
1474
|
+
raise ValueError(f"Unsupported dtype: {dtype}")
|
1475
|
+
|
1476
|
+
r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
|
1477
|
+
rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
|
1478
|
+
s_exact = (1.0 + rho ** 2) / (2.0 * rho)
|
1479
|
+
|
1480
|
+
s_approximate = 1.0 / concentration
|
1481
|
+
|
1482
|
+
s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
|
1483
|
+
|
1484
|
+
def cond_fn(*args):
|
1485
|
+
"""check if all are done or reached max number of iterations"""
|
1486
|
+
i, _, done, _, _ = args[0]
|
1487
|
+
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
|
1488
|
+
|
1489
|
+
def body_fn(*args):
|
1490
|
+
i, key, done, _, w = args[0]
|
1491
|
+
uni_ukey, uni_vkey, key = jr.split(key, 3)
|
1492
|
+
u = jr.uniform(
|
1493
|
+
key=uni_ukey,
|
1494
|
+
shape=shape,
|
1495
|
+
dtype=concentration.dtype,
|
1496
|
+
minval=-1.0,
|
1497
|
+
maxval=1.0,
|
1498
|
+
)
|
1499
|
+
z = jnp.cos(jnp.pi * u)
|
1500
|
+
w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
|
1501
|
+
y = concentration * (s - w)
|
1502
|
+
v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
|
1503
|
+
accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
|
1504
|
+
return i + 1, key, accept | done, u, w
|
1505
|
+
|
1506
|
+
init_done = jnp.zeros(shape, dtype=bool)
|
1507
|
+
init_u = jnp.zeros(shape)
|
1508
|
+
init_w = jnp.zeros(shape)
|
1509
|
+
|
1510
|
+
_, _, done, u, w = lax.while_loop(
|
1511
|
+
cond_fun=cond_fn,
|
1512
|
+
body_fun=body_fn,
|
1513
|
+
init_val=(jnp.array(0), key, init_done, init_u, init_w),
|
1514
|
+
)
|
1515
|
+
|
1516
|
+
return jnp.sign(u) * jnp.arccos(w)
|
1482
1517
|
|
1483
1518
|
|
1484
1519
|
def _loc_scale(loc, scale, value):
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
else:
|
1491
|
-
if scale is None:
|
1492
|
-
return value + loc
|
1520
|
+
if loc is None:
|
1521
|
+
if scale is None:
|
1522
|
+
return value
|
1523
|
+
else:
|
1524
|
+
return value * scale
|
1493
1525
|
else:
|
1494
|
-
|
1526
|
+
if scale is None:
|
1527
|
+
return value + loc
|
1528
|
+
else:
|
1529
|
+
return value * scale + loc
|
1495
1530
|
|
1496
1531
|
|
1497
1532
|
def _check_py_seq(seq):
|
1498
|
-
|
1533
|
+
return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
|