effectful-0.0.1-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/handlers/numpyro.py (new file)
@@ -0,0 +1,562 @@
+try:
+    import numpyro.distributions as dist
+except ImportError:
+    raise ImportError("Numpyro is required to use effectful.handlers.numpyro")
+
+
+import functools
+from collections.abc import Collection, Hashable, Mapping
+from typing import Any, cast
+
+import jax
+import tree
+
+import effectful.handlers.jax.numpy as jnp
+from effectful.handlers.jax import bind_dims, jax_getitem, sizesof, unbind_dims
+from effectful.handlers.jax._handlers import _register_jax_op, is_eager_array
+from effectful.ops.semantics import apply, runner, typeof
+from effectful.ops.syntax import defdata, defop, defterm
+from effectful.ops.types import NotHandled, Operation, Term
+
+
+class Naming(dict[Operation[[], jax.Array], int]):
+    """
+    A mapping from dimensions (indexed from the right) to names.
+    """
+
+    def __init__(self, name_to_dim: Mapping[Operation[[], jax.Array], int]):
+        assert all(v < 0 for v in name_to_dim.values())
+        super().__init__(name_to_dim)
+
+    @staticmethod
+    def from_shape(
+        names: Collection[Operation[[], jax.Array]], event_dims: int
+    ) -> "Naming":
+        """Create a naming from a set of indices and the number of event dimensions.
+
+        The resulting naming converts tensors of shape
+        ``| batch_shape | named | event_shape |``
+        to tensors of shape ``| batch_shape | event_shape |, | named |``.
+
+        """
+        assert event_dims >= 0
+        return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})
+
+    def apply(self, value: jax.Array) -> jax.Array:
+        indexes: list[Any] = [slice(None)] * (len(value.shape))
+        for n, d in self.items():
+            indexes[len(value.shape) + d] = n()
+        return jax_getitem(value, tuple(indexes))
+
+    def __repr__(self):
+        return f"Naming({super().__repr__()})"
+
+
+@unbind_dims.register  # type: ignore
+def _unbind_distribution(
+    d: dist.Distribution, *names: Operation[[], jax.Array]
+) -> dist.Distribution:
+    batch_shape = None
+
+    def _validate_batch_shape(t):
+        nonlocal batch_shape
+        if len(t.shape) < len(names):
+            raise ValueError(
+                "All tensors must have at least as many dimensions as names"
+            )
+
+        if batch_shape is None:
+            batch_shape = t.shape[: len(names)]
+
+        if (
+            len(t.shape) < len(batch_shape)
+            or t.shape[: len(batch_shape)] != batch_shape
+        ):
+            raise ValueError("All tensors must have the same batch shape.")
+
+    def _to_named(a):
+        nonlocal batch_shape
+        if isinstance(a, jax.Array):
+            _validate_batch_shape(a)
+            return unbind_dims(a, *names)
+        elif isinstance(a, dist.Distribution):
+            return unbind_dims(a, *names)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(typ, dist.Distribution):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)
+
+    with runner({apply: _apply}):
+        d = defterm(d)
+
+    if not (isinstance(d, Term) and typeof(d) is dist.Distribution):
+        raise NotHandled
+
+    # TODO: this is a hack to avoid mangling arguments that are array-valued, but not batched
+    aux_kwargs = set(["total_count"])
+
+    new_d = d.op(
+        *[_to_named(a) for a in d.args],
+        **{k: v if k in aux_kwargs else _to_named(v) for (k, v) in d.kwargs.items()},
+    )
+    return new_d
+
+
+@bind_dims.register  # type: ignore
+def _bind_dims_distribution(
+    d: dist.Distribution, *names: Operation[[], jax.Array]
+) -> dist.Distribution:
+    def _to_positional(a, indices):
+        typ = typeof(a)
+        if issubclass(typ, jax.Array):
+            # broadcast to full indexed shape
+            existing_dims = set(sizesof(a).keys())
+            missing_dims = set(indices) - existing_dims
+
+            a_indexed = unbind_dims(
+                jnp.broadcast_to(
+                    a, tuple(indices[dim] for dim in missing_dims) + a.shape
+                ),
+                *missing_dims,
+            )
+            return bind_dims(a_indexed, *indices)
+        elif issubclass(typ, dist.Distribution):
+            # We assume that only one distribution appears in our arguments. This
+            # is sufficient for cases like Independent and TransformedDistribution.
+            return bind_dims(a, *indices)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(typ, dist.Distribution):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)
+
+    with runner({apply: _apply}):
+        d = defterm(d)
+
+    if not (isinstance(d, Term) and typeof(d) is dist.Distribution):
+        raise NotHandled
+
+    sizes = sizesof(d)
+    indices = {k: sizes[k] for k in names}
+
+    pos_args = [_to_positional(a, indices) for a in d.args]
+    pos_kwargs = {k: _to_positional(v, indices) for (k, v) in d.kwargs.items()}
+    new_d = d.op(*pos_args, **pos_kwargs)
+
+    return new_d
+
+
+@functools.cache
+def _register_distribution_op(
+    dist_constr: type[dist.Distribution],
+) -> Operation[Any, dist.Distribution]:
+    # introduce a wrapper so that we can control type annotations
+    def wrapper(*args, **kwargs) -> dist.Distribution:
+        if any(isinstance(a, Term) for a in tree.flatten((args, kwargs))):
+            raise NotHandled
+        return dist_constr(*args, **kwargs)
+
+    return defop(wrapper, name=dist_constr.__name__)
+
+
+@defdata.register(dist.Distribution)
+def _(op, *args, **kwargs):
+    if all(
+        not isinstance(a, Term) or is_eager_array(a) or isinstance(a, dist.Distribution)
+        for a in tree.flatten((args, kwargs))
+    ):
+        return _DistributionTerm(op, *args, **kwargs)
+    else:
+        return defdata.dispatch(object)(op, *args, **kwargs)
+
+
+def _broadcast_to_named(t, sizes):
+    missing_dims = set(sizes) - set(sizesof(t))
+    t_broadcast = jnp.broadcast_to(
+        t, tuple(sizes[dim] for dim in missing_dims) + t.shape
+    )
+    return jax_getitem(t_broadcast, tuple(dim() for dim in missing_dims))
+
+
+def expand_to_batch_shape(tensor, batch_ndims, expanded_batch_shape):
+    """
+    Expands a tensor of shape batch_shape + remaining_shape to
+    expanded_batch_shape + remaining_shape.
+
+    Args:
+        tensor: JAX array with shape batch_shape + event_shape
+        batch_ndims: number of leading batch dimensions in tensor's shape
+        expanded_batch_shape: tuple of the desired expanded batch dimensions
+
+    Returns:
+        A JAX array with shape expanded_batch_shape + event_shape
+    """
+    # Split the shape into batch and event parts
+    assert len(tensor.shape) >= batch_ndims
+
+    batch_shape = tensor.shape[:batch_ndims] if batch_ndims > 0 else ()
+    remaining_shape = tensor.shape[batch_ndims:]
+
+    # Ensure the expanded batch shape is compatible with the current batch shape
+    if len(expanded_batch_shape) < batch_ndims:
+        raise ValueError(
+            "Expanded batch shape must have at least as many dimensions as current batch shape"
+        )
+    new_batch_shape = jnp.broadcast_shapes(batch_shape, expanded_batch_shape)
+
+    # Create the new shape
+    new_shape = new_batch_shape + remaining_shape
+
+    # Broadcast the tensor to the new shape
+    expanded_tensor = jnp.broadcast_to(tensor, new_shape)
+
+    return expanded_tensor
+
+
+class _DistributionTerm(dist.Distribution):
+    """A distribution wrapper that satisfies the Term interface.
+
+    Represented as a term of the form D(*args, **kwargs) where D is the
+    distribution constructor.
+
+    Note: When we construct instances of this class, we put distribution
+    parameters that can be expanded in the args list and those that cannot in
+    the kwargs list.
+
+    """
+
+    _op: Operation[Any, dist.Distribution]
+    _args: tuple
+    _kwargs: dict
+
+    def __init__(self, op: Operation[Any, dist.Distribution], *args, **kwargs):
+        self._op = op
+        self._args = args
+        self._kwargs = kwargs
+        self.__indices = None
+        self.__pos_base_dist = None
+
+    @property
+    def _indices(self):
+        if self.__indices is None:
+            self.__indices = sizesof(self)
+        return self.__indices
+
+    @property
+    def _pos_base_dist(self):
+        if self.__pos_base_dist is None:
+            self.__pos_base_dist = bind_dims(self, *self._indices)
+        return self.__pos_base_dist
+
+    @property
+    def op(self):
+        return self._op
+
+    @property
+    def args(self):
+        return self._args
+
+    @property
+    def kwargs(self):
+        return self._kwargs
+
+    @property
+    def batch_shape(self):
+        return self._pos_base_dist.batch_shape[len(self._indices) :]
+
+    @property
+    def has_rsample(self) -> bool:
+        return self._pos_base_dist.has_rsample
+
+    @property
+    def event_shape(self):
+        return self._pos_base_dist.event_shape
+
+    def rsample(self, key, sample_shape=()):
+        return self._reindex_sample(
+            self._pos_base_dist.rsample(key, sample_shape), sample_shape
+        )
+
+    def sample(self, key, sample_shape=()):
+        return self._reindex_sample(
+            self._pos_base_dist.sample(key, sample_shape), sample_shape
+        )
+
+    def _reindex_sample(self, value, sample_shape):
+        index = (slice(None),) * len(sample_shape) + tuple(i() for i in self._indices)
+        ret = jax_getitem(value, index)
+        return ret
+
+    def log_prob(self, value):
+        # value has shape named_batch_shape + sample_shape + batch_shape + event_shape
+        n_batch_event = len(self.batch_shape) + len(self.event_shape)
+        sample_shape = (
+            value.shape if n_batch_event == 0 else value.shape[:-n_batch_event]
+        )
+        value = bind_dims(_broadcast_to_named(value, self._indices), *self._indices)
+        dims = list(range(len(value.shape)))
+        n_named_batch = len(self._indices)
+        perm = (
+            dims[n_named_batch : n_named_batch + len(sample_shape)]
+            + dims[:n_named_batch]
+            + dims[n_named_batch + len(sample_shape) :]
+        )
+        assert len(perm) == len(value.shape)
+
+        # perm_value has shape sample_shape + named_batch_shape + batch_shape + event_shape
+        perm_value = jnp.permute_dims(value, perm)
+        pos_log_prob = _register_jax_op(self._pos_base_dist.log_prob)(perm_value)
+        ind_log_prob = self._reindex_sample(pos_log_prob, sample_shape)
+        return ind_log_prob
+
+    @property
+    def mean(self):
+        return self._reindex_sample(self._pos_base_dist.mean, ())
+
+    @property
+    def variance(self):
+        return self._reindex_sample(self._pos_base_dist.variance, ())
+
+    def enumerate_support(self, expand=True):
+        return self._reindex_sample(self._pos_base_dist.enumerate_support(expand), ())
+
+    def entropy(self):
+        return self._pos_base_dist.entropy()
+
+    def to_event(self, reinterpreted_batch_ndims=None):
+        raise NotHandled
+
+    def expand(self, batch_shape):
+        def expand_arg(a, batch_shape):
+            if is_eager_array(a):
+                return expand_to_batch_shape(a, len(self.batch_shape), batch_shape)
+            return a
+
+        if self.batch_shape == batch_shape:
+            return self
+
+        expanded_args = [expand_arg(a, batch_shape) for a in self.args]
+        expanded_kwargs = {
+            k: expand_arg(v, batch_shape) for (k, v) in self.kwargs.items()
+        }
+        return self.op(*expanded_args, **expanded_kwargs)
+
+    def __repr__(self):
+        return Term.__repr__(self)
+
+    def __str__(self):
+        return Term.__str__(self)
+
+
+Term.register(_DistributionTerm)
+
+
+@defterm.register(dist.Distribution)
+def _embed_distribution(dist: dist.Distribution) -> Term[dist.Distribution]:
+    raise ValueError(
+        f"No embedding provided for distribution of type {type(dist).__name__}."
+    )
+
+
+@defterm.register(dist.Cauchy)
+@defterm.register(dist.Gumbel)
+@defterm.register(dist.Laplace)
+@defterm.register(dist.LogNormal)
+@defterm.register(dist.Logistic)
+@defterm.register(dist.Normal)
+@defterm.register(dist.StudentT)
+def _embed_loc_scale(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.loc, d.scale)
+
+
+@defterm.register(dist.BernoulliProbs)
+@defterm.register(dist.CategoricalProbs)
+@defterm.register(dist.GeometricProbs)
+def _embed_probs(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.probs)
+
+
+@defterm.register(dist.BernoulliLogits)
+@defterm.register(dist.CategoricalLogits)
+@defterm.register(dist.GeometricLogits)
+def _embed_logits(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.logits)
+
+
+@defterm.register(dist.Beta)
+@defterm.register(dist.Kumaraswamy)
+def _embed_beta(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.concentration1, d.concentration0
+    )
+
+
+@defterm.register(dist.BinomialProbs)
+@defterm.register(dist.NegativeBinomialProbs)
+@defterm.register(dist.MultinomialProbs)
+def _embed_binomial_probs(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.probs, d.total_count)
+
+
+@defterm.register(dist.BinomialLogits)
+@defterm.register(dist.NegativeBinomialLogits)
+@defterm.register(dist.MultinomialLogits)
+def _embed_binomial_logits(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.logits, d.total_count)
+
+
+@defterm.register
+def _embed_chi2(d: dist.Chi2) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.df)
+
+
+@defterm.register
+def _embed_dirichlet(d: dist.Dirichlet) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.concentration)
+
+
+@defterm.register
+def _embed_dirichlet_multinomial(
+    d: dist.DirichletMultinomial,
+) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.concentration, total_count=d.total_count
+    )
+
+
+@defterm.register(dist.Exponential)
+@defterm.register(dist.Poisson)
+def _embed_exponential(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.rate)
+
+
+@defterm.register
+def _embed_gamma(d: dist.Gamma) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.concentration, d.rate)
+
+
+@defterm.register(dist.HalfCauchy)
+@defterm.register(dist.HalfNormal)
+def _embed_half_cauchy(d: dist.Distribution) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.scale)
+
+
+@defterm.register
+def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.dim, concentration=d.concentration
+    )
+
+
+@defterm.register
+def _embed_multivariate_normal(d: dist.MultivariateNormal) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.loc, scale_tril=d.scale_tril
+    )
+
+
+@defterm.register
+def _embed_pareto(d: dist.Pareto) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.scale, d.alpha)
+
+
+@defterm.register
+def _embed_uniform(d: dist.Uniform) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.low, d.high)
+
+
+@defterm.register
+def _embed_von_mises(d: dist.VonMises) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.loc, d.concentration)
+
+
+@defterm.register
+def _embed_weibull(d: dist.Weibull) -> Term[dist.Distribution]:
+    return _register_distribution_op(dist.Weibull)(d.scale, d.concentration)
+
+
+@defterm.register
+def _embed_wishart(d: dist.Wishart) -> Term[dist.Distribution]:
+    return _register_distribution_op(dist.Wishart)(d.df, d.scale_tril)
+
+
+@defterm.register
+def _embed_delta(d: dist.Delta) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.v, log_density=d.log_density, event_dim=d.event_dim
+    )
+
+
+@defterm.register
+def _embed_low_rank_multivariate_normal(
+    d: dist.LowRankMultivariateNormal,
+) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.loc, d.cov_factor, d.cov_diag
+    )
+
+
+@defterm.register
+def _embed_relaxed_bernoulli_logits(
+    d: dist.RelaxedBernoulliLogits,
+) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(d.temperature, d.logits)
+
+
+@defterm.register
+def _embed_independent(d: dist.Independent) -> Term[dist.Distribution]:
+    return _register_distribution_op(cast(Hashable, type(d)))(
+        d.base_dist, d.reinterpreted_batch_ndims
+    )
+
+
+BernoulliLogits = _register_distribution_op(dist.BernoulliLogits)
+BernoulliProbs = _register_distribution_op(dist.BernoulliProbs)
+Beta = _register_distribution_op(dist.Beta)
+BinomialProbs = _register_distribution_op(dist.BinomialProbs)
+BinomialLogits = _register_distribution_op(dist.BinomialLogits)
+CategoricalLogits = _register_distribution_op(dist.CategoricalLogits)
+CategoricalProbs = _register_distribution_op(dist.CategoricalProbs)
+Cauchy = _register_distribution_op(dist.Cauchy)
+Chi2 = _register_distribution_op(dist.Chi2)
+Delta = _register_distribution_op(dist.Delta)
+Dirichlet = _register_distribution_op(dist.Dirichlet)
+DirichletMultinomial = _register_distribution_op(dist.DirichletMultinomial)
+Distribution = _register_distribution_op(dist.Distribution)
+Exponential = _register_distribution_op(dist.Exponential)
+Gamma = _register_distribution_op(dist.Gamma)
+GeometricLogits = _register_distribution_op(dist.GeometricLogits)
+GeometricProbs = _register_distribution_op(dist.GeometricProbs)
+Gumbel = _register_distribution_op(dist.Gumbel)
+HalfCauchy = _register_distribution_op(dist.HalfCauchy)
+HalfNormal = _register_distribution_op(dist.HalfNormal)
+Independent = _register_distribution_op(dist.Independent)
+Kumaraswamy = _register_distribution_op(dist.Kumaraswamy)
+LKJCholesky = _register_distribution_op(dist.LKJCholesky)
+Laplace = _register_distribution_op(dist.Laplace)
+LogNormal = _register_distribution_op(dist.LogNormal)
+Logistic = _register_distribution_op(dist.Logistic)
+LowRankMultivariateNormal = _register_distribution_op(dist.LowRankMultivariateNormal)
+MultinomialProbs = _register_distribution_op(dist.MultinomialProbs)
+MultinomialLogits = _register_distribution_op(dist.MultinomialLogits)
+MultivariateNormal = _register_distribution_op(dist.MultivariateNormal)
+NegativeBinomialProbs = _register_distribution_op(dist.NegativeBinomialProbs)
+NegativeBinomialLogits = _register_distribution_op(dist.NegativeBinomialLogits)
+Normal = _register_distribution_op(dist.Normal)
+Pareto = _register_distribution_op(dist.Pareto)
+Poisson = _register_distribution_op(dist.Poisson)
+RelaxedBernoulliLogits = _register_distribution_op(dist.RelaxedBernoulliLogits)
+StudentT = _register_distribution_op(dist.StudentT)
+Uniform = _register_distribution_op(dist.Uniform)
+VonMises = _register_distribution_op(dist.VonMises)
+Weibull = _register_distribution_op(dist.Weibull)
+Wishart = _register_distribution_op(dist.Wishart)