effectful-0.0.1-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/handlers/pyro.py
CHANGED
@@ -1,35 +1,46 @@
+import functools
 import typing
-import
-from typing import Any
+from collections.abc import Collection, Mapping
+from typing import Any
+
+import pyro.poutine.subsample_messenger

 try:
     import pyro
 except ImportError:
     raise ImportError("Pyro is required to use effectful.handlers.pyro.")

+import pyro.distributions as dist
+from pyro.distributions.torch_distribution import (
+    TorchDistribution,
+    TorchDistributionMixin,
+)
+
 try:
     import torch
 except ImportError:
     raise ImportError("PyTorch is required to use effectful.handlers.pyro.")

-from
-
-
-
-
-
-
+from effectful.handlers.torch import (
+    bind_dims,
+    sizesof,
+    unbind_dims,
+)
+from effectful.internals.runtime import interpreter
+from effectful.ops.semantics import apply, runner, typeof
+from effectful.ops.syntax import defdata, defop, defterm
+from effectful.ops.types import NotHandled, Operation, Term


 @defop
 def pyro_sample(
     name: str,
-    fn:
+    fn: TorchDistributionMixin,
     *args,
-    obs:
-    obs_mask:
-    mask:
-    infer:
+    obs: torch.Tensor | None = None,
+    obs_mask: torch.BoolTensor | None = None,
+    mask: torch.BoolTensor | None = None,
+    infer: pyro.poutine.runtime.InferDict | None = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -41,6 +52,39 @@ def pyro_sample(
     )


+class Naming:
+    """
+    A mapping from dimensions (indexed from the right) to names.
+    """
+
+    def __init__(self, name_to_dim: Mapping[Operation[[], torch.Tensor], int]):
+        assert all(v < 0 for v in name_to_dim.values())
+        self.name_to_dim = name_to_dim
+
+    @staticmethod
+    def from_shape(
+        names: Collection[Operation[[], torch.Tensor]], event_dims: int
+    ) -> "Naming":
+        """Create a naming from a set of indices and the number of event dimensions.
+
+        The resulting naming converts tensors of shape
+        ``| batch_shape | named | event_shape |``
+        to tensors of shape ``| batch_shape | event_shape |, | named |``.
+
+        """
+        assert event_dims >= 0
+        return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})
+
+    def apply(self, value: torch.Tensor) -> torch.Tensor:
+        indexes: list[Any] = [slice(None)] * (len(value.shape))
+        for n, d in self.name_to_dim.items():
+            indexes[len(value.shape) + d] = n()
+        return value[tuple(indexes)]
+
+    def __repr__(self):
+        return f"Naming({self.name_to_dim})"
+
+
 class PyroShim(pyro.poutine.messenger.Messenger):
     """Pyro handler that wraps all sample sites in a custom effectful type.

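The `Naming` class added in this hunk maps negative (right-indexed) positional dimensions to named index operations. A minimal sketch of how it behaves, assuming the `defop`-based index variables from `effectful.ops.syntax` that this release uses elsewhere (the `name` keyword and the expected values in comments are our assumptions, not output recorded in this diff):

    import torch
    from effectful.handlers.pyro import Naming
    from effectful.ops.syntax import defop

    # Hypothetical named indices: nullary operations returning tensors.
    i = defop(torch.Tensor, name="i")
    j = defop(torch.Tensor, name="j")

    # Name the two leftmost dims of a (2, 3, 4) tensor, keeping one event dim.
    naming = Naming.from_shape([i, j], event_dims=1)
    print(naming)  # expected: Naming({i: -3, j: -2}), per from_shape's formula

    # apply() indexes dims -3 and -2 with i() and j(), leaving shape (4,)
    # with free indices i and j.
    named = naming.apply(torch.randn(2, 3, 4))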
@@ -81,17 +125,24 @@ class PyroShim(pyro.poutine.messenger.Messenger):
     Sampled y
     """

-
+    # Tracks the named dimensions on any sample site that we have handled.
+    # Ideally, this information would be carried on the sample message itself.
+    # However, when using guides, sample sites are completely replaced by fresh
+    # guide sample sites that do not carry the same infer dict.
+    #
+    # We can only restore the named dimensions on samples that we have handled
+    # at least once in the shim.
+    _index_naming: dict[str, Naming]

-    def
-
-        warnings.warn("PyroShim should be installed at most once.")
-        return super().__enter__()
+    def __init__(self):
+        self._index_naming = {}

     @staticmethod
     def _broadcast_to_named(
-        t: torch.Tensor,
-
+        t: torch.Tensor,
+        shape: torch.Size,
+        indices: Mapping[Operation[[], torch.Tensor], int],
+    ) -> tuple[torch.Tensor, "Naming"]:
         """Convert a tensor `t` to a fully positional tensor that is
         broadcastable with the positional representation of tensors of shape
         |shape|, |indices|.
@@ -99,6 +150,9 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         """
         t_indices = sizesof(t)

+        if not isinstance(t, torch.Tensor):
+            t = torch.tensor(t)
+
         if len(t.shape) < len(shape):
             t = t.expand(shape)

@@ -106,7 +160,7 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         name_to_dim = {}
         for i, (k, v) in enumerate(reversed(list(indices.items()))):
             if k in t_indices:
-                t =
+                t = bind_dims(t, k)
             else:
                 t = t.expand((v,) + t.shape)
             name_to_dim[k] = -len(shape) - i - 1
@@ -114,7 +168,7 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         # create a positional dimension for every remaining named index in `t`
         n_batch_and_dist_named = len(t.shape)
         for i, k in enumerate(reversed(list(sizesof(t).keys()))):
-            t =
+            t = bind_dims(t, k)
             name_to_dim[k] = -n_batch_and_dist_named - i - 1

         return t, Naming(name_to_dim)
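Both of these hunks replace the old (truncated) conversion code with `bind_dims` from `effectful.handlers.torch`, which turns one named dimension of a tensor back into a positional one. A rough, hedged sketch of that primitive in isolation (the index `k` is hypothetical; the expected values in comments are our assumptions about the API, not output from this diff):

    import torch
    from effectful.handlers.torch import bind_dims, sizesof
    from effectful.ops.syntax import defop

    k = defop(torch.Tensor, name="k")  # hypothetical named index
    t = torch.randn(5)[k()]            # tensor with free index k
    print(sizesof(t))                  # expected: {k: 5}
    print(bind_dims(t, k).shape)       # k bound back positionally: torch.Size([5])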
@@ -124,26 +178,47 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         assert msg["type"] == "sample"
         assert msg["name"] is not None
         assert msg["infer"] is not None
-        assert isinstance(
-            msg["fn"], pyro.distributions.torch_distribution.TorchDistributionMixin
-        )
+        assert isinstance(msg["fn"], TorchDistributionMixin)

         if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
             msg
         ):
             return

-        if
-
-
+        if "pyro_shim_status" in msg["infer"]:
+            handler_id, handler_stage = msg["infer"]["pyro_shim_status"]  # type: ignore
+        else:
+            handler_id = id(self)
+            handler_stage = 0
+            msg["infer"]["pyro_shim_status"] = (handler_id, handler_stage)  # type: ignore
+
+        if handler_id != id(self):  # Never handle a message that is not ours.
+            return
+
+        assert handler_stage in (0, 1)
+
+        # PyroShim turns each call to pyro.sample into two calls. The first
+        # dispatches to pyro_sample and the effectful stack. The effectful stack
+        # eventually calls pyro.sample again. We use state in PyroShim to
+        # recognize that we've been called twice, and we dispatch to the pyro
+        # stack.
+        #
+        # This branch handles the second call, so it massages the message to be
+        # compatible with Pyro. In particular, it removes all named dimensions
+        # and stores naming information in the message. Names are replaced by
+        # _pyro_post_sample.
+        if handler_stage == 1:
+            if "_markov_scope" in msg["infer"]:
+                msg["infer"]["_markov_scope"].pop(msg["name"], None)

             dist = msg["fn"]
             obs = msg["value"] if msg["is_observed"] else None

             # pdist shape: | named1 | batch_shape | event_shape |
             # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1
-
-            naming =
+            indices = sizesof(dist)
+            naming = Naming.from_shape(indices, len(dist.shape()))
+            pdist = bind_dims(dist, *indices.keys())

             if msg["mask"] is None:
                 mask = torch.tensor(True)
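The comments added in this hunk describe the core trick: each `pyro.sample` call is dispatched twice, once into the effectful stack via `pyro_sample` (stage 0) and once back into Pyro with named dimensions stripped (stage 1), with `pyro_shim_status` on the infer dict telling the stages apart. From the outside, the shim is used like any other Pyro messenger; a minimal sketch, assuming only the public names visible in this diff:

    import pyro
    import pyro.distributions as dist
    import torch
    from effectful.handlers.pyro import PyroShim

    def model() -> torch.Tensor:
        return pyro.sample("x", dist.Normal(torch.tensor(0.0), torch.tensor(1.0)))

    # Installed as a context manager, the shim routes each sample site through
    # the effectful stack before handing a positional message back to Pyro.
    with PyroShim():
        x = model()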
@@ -152,290 +227,566 @@ class PyroShim(pyro.poutine.messenger.Messenger):
             else:
                 mask = msg["mask"]

-
-
+            assert set(sizesof(mask).keys()) <= (
+                set(indices.keys()) | set(sizesof(obs).keys())
             )
+            pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices)

-            pos_obs:
+            pos_obs: torch.Tensor | None = None
             if obs is not None:
                 pos_obs, naming = PyroShim._broadcast_to_named(
-                    obs, dist.shape(),
+                    obs, dist.shape(), indices
                 )

+            # Each of the batch dimensions on the distribution gets a
+            # cond_indep_stack frame.
             for var, dim in naming.name_to_dim.items():
-
-
-
-
+                # There can be additional batch dimensions on the observation
+                # that do not get frames, so only consider dimensions on the
+                # distribution.
+                if var in indices:
+                    frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
+                        name=f"__index_plate_{var}",
+                        # dims are indexed from the right of the batch shape
+                        dim=dim + len(pdist.event_shape),
+                        size=indices[var],
+                        counter=0,
+                    )
+                    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]

             msg["fn"] = pdist
             msg["value"] = pos_obs
             msg["mask"] = pos_mask
-
+
+            # stash the index naming on the sample message so that future
+            # consumers of the trace can get at it
+            msg["_index_naming"] = naming  # type: ignore
+
+            self._index_naming[msg["name"]] = naming

             assert sizesof(msg["value"]) == {}
             assert sizesof(msg["mask"]) == {}

-
+        # This branch handles the first call to pyro.sample by calling pyro_sample.
+        else:
+            infer = msg["infer"].copy()
+            infer["pyro_shim_status"] = (handler_id, 1)  # type: ignore

-        try:
-            self._current_site = msg["name"]
             msg["value"] = pyro_sample(
                 msg["name"],
                 msg["fn"],
                 obs=msg["value"] if msg["is_observed"] else None,
-                infer=
+                infer=infer,
             )
-        finally:
-            self._current_site = None

-
-
-
-
-
-
-
+            # flags to guarantee commutativity of condition, intervene, trace
+            msg["stop"] = True
+            msg["done"] = True
+            msg["mask"] = False
+            msg["is_observed"] = True
+            msg["infer"]["is_auxiliary"] = True
+            msg["infer"]["_do_not_trace"] = True

     def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
-
-
-
+        if typing.TYPE_CHECKING:
+            assert msg["name"] is not None
+            assert msg["value"] is not None
+            assert msg["infer"] is not None

-        #
-
+        # If there is no shim status, assume that we are looking at a guide sample.
+        # In this case, we should handle the sample and claim it as ours if we have naming
+        # information for it.
+        if "pyro_shim_status" not in msg["infer"]:
+            # Except, of course, for subsample messages, which we should ignore.
+            if (
+                pyro.poutine.util.site_is_subsample(msg)
+                or msg["name"] not in self._index_naming
+            ):
+                return
+            msg["infer"]["pyro_shim_status"] = (id(self), 1)  # type: ignore
+
+        # If this message has been handled already by a different pyro shim, ignore.
+        handler_id, handler_stage = msg["infer"]["pyro_shim_status"]  # type: ignore
+        if handler_id != id(self) or handler_stage < 1:
+            return

         value = msg["value"]

-
-
-
-
-
-        dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape
-        if len(value.shape) < len(dist_shape):
-            value = value.broadcast_to(
-                torch.broadcast_shapes(value.shape, dist_shape)
-            )
-        value = naming.apply(value)
-        msg["value"] = value
+        naming = self._index_naming.get(msg["name"], Naming({}))
+        infer = msg["infer"] if msg["infer"] is not None else {}
+        assert "enumerate" not in infer or len(naming.name_to_dim) == 0, (
+            "Enumeration is not currently supported in PyroShim."
+        )

+        # note: is it safe to assume that msg['fn'] is a distribution?
+        dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape  # type: ignore
+        if len(value.shape) < len(dist_shape):
+            value = value.broadcast_to(torch.broadcast_shapes(value.shape, dist_shape))

-
-
-    A mapping from dimensions (indexed from the right) to names.
-    """
+        value = naming.apply(value)
+        msg["value"] = value

-    def __init__(self, name_to_dim: Mapping[Operation[[], int], int]):
-        assert all(v < 0 for v in name_to_dim.values())
-        self.name_to_dim = name_to_dim

-
-
-
+PyroDistribution = (
+    pyro.distributions.torch_distribution.TorchDistribution
+    | pyro.distributions.torch_distribution.TorchDistributionMixin
+)

-        The resulting naming converts tensors of shape
-        ``| batch_shape | named | event_shape |``
-        to tensors of shape ``| batch_shape | event_shape |, | named |``.

-
-
-
+@unbind_dims.register(pyro.distributions.torch_distribution.TorchDistribution)  # type: ignore
+@unbind_dims.register(pyro.distributions.torch_distribution.TorchDistributionMixin)  # type: ignore
+def _unbind_dims_distribution(
+    value: pyro.distributions.torch_distribution.TorchDistribution,
+    *names: Operation[[], torch.Tensor],
+) -> pyro.distributions.torch_distribution.TorchDistribution:
+    d = value
+    batch_shape = None

-    def
-
-
-
-
+    def _validate_batch_shape(t):
+        nonlocal batch_shape
+        if len(t.shape) < len(names):
+            raise ValueError(
+                "All tensors must have at least as many dimensions as names"
+            )

-
-
+        if batch_shape is None:
+            batch_shape = t.shape[: len(names)]

+        if (
+            len(t.shape) < len(batch_shape)
+            or t.shape[: len(batch_shape)] != batch_shape
+        ):
+            raise ValueError("All tensors must have the same batch shape.")
+
+    def _to_named(a):
+        nonlocal batch_shape
+        if isinstance(a, torch.Tensor):
+            _validate_batch_shape(a)
+            return typing.cast(torch.Tensor, a)[tuple(n() for n in names)]
+        elif isinstance(a, TorchDistribution):
+            return unbind_dims(a, *names)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistribution
+        ) or issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistributionMixin
+        ):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)
+
+    with runner({apply: _apply}):
+        d = defterm(d)
+
+    if not (isinstance(d, Term) and typeof(d) is TorchDistribution):
+        raise NotHandled
+
+    new_d = d.op(
+        *[_to_named(a) for a in d.args],
+        **{k: _to_named(v) for (k, v) in d.kwargs.items()},
+    )
+    assert new_d.event_shape == d.event_shape
+    return new_d
+
+
+@bind_dims.register(pyro.distributions.torch_distribution.TorchDistribution)  # type: ignore
+@bind_dims.register(pyro.distributions.torch_distribution.TorchDistributionMixin)  # type: ignore
+def _bind_dims_distribution(
+    value: pyro.distributions.torch_distribution.TorchDistribution,
+    *names: Operation[[], torch.Tensor],
+) -> pyro.distributions.torch_distribution.TorchDistribution:
+    d = value
+
+    def _to_positional(a, indices):
+        if isinstance(a, torch.Tensor):
+            # broadcast to full indexed shape
+            existing_dims = set(sizesof(a).keys())
+            missing_dims = set(indices) - existing_dims
+
+            a_indexed = torch.broadcast_to(
+                a, torch.Size([indices[dim] for dim in missing_dims]) + a.shape
+            )[tuple(n() for n in missing_dims)]
+            return bind_dims(a_indexed, *names)
+        elif isinstance(a, TorchDistribution):
+            return bind_dims(a, *names)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistribution
+        ) or issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistributionMixin
+        ):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)

-
-
-        positional.
+    with runner({apply: _apply}):
+        d = defterm(d)

-
+    if not (isinstance(d, Term) and typeof(d) is TorchDistribution):
+        raise NotHandled

-
+    sizes = sizesof(d)
+    indices = {k: sizes[k] for k in names}

-
-
-        )
-        self.base_dist = base_dist
-        self.indices = sizesof(base_dist.sample())
+    pos_args = [_to_positional(a, indices) for a in d.args]
+    pos_kwargs = {k: _to_positional(v, indices) for (k, v) in d.kwargs.items()}
+    new_d = d.op(*pos_args, **pos_kwargs)

-
-
+    assert new_d.event_shape == d.event_shape
+    return new_d

-        super().__init__()

-
-
-
-
-
-
-
+@functools.cache
+def _register_distribution_op(
+    dist_constr: type[TorchDistribution],
+) -> Operation[Any, TorchDistribution]:
+    # introduce a wrapper so that we can control type annotations
+    def wrapper(*args, **kwargs) -> TorchDistribution:
+        return dist_constr(*args, **kwargs)

-
-        n_sample = len(value.shape) - n_base
+    return defop(wrapper)

-        base_dims = dims[len(dims) - n_base :]
-        named_dims = dims[:n_named]
-        sample_dims = dims[n_named : n_named + n_sample]

-
-
-
+@defdata.register(pyro.distributions.torch_distribution.TorchDistribution)
+@defdata.register(pyro.distributions.torch_distribution.TorchDistributionMixin)
+class _DistributionTerm(Term[TorchDistribution], TorchDistribution):
+    """A distribution wrapper that satisfies the Term interface.

-
-
+    Represented as a term of the form call(D, *args, **kwargs) where D is the
+    distribution constructor.

-
+    Note: When we construct instances of this class, we put distribution
+    parameters that can be expanded in the args list and those that cannot in
+    the kwargs list.

-
-
-
+    """
+
+    _op: Operation[Any, TorchDistribution]
+    _args: tuple
+    _kwargs: dict
+
+    def __init__(self, op: Operation[Any, TorchDistribution], *args, **kwargs):
+        self._op = op
+        self._args = tuple(defterm(a) for a in args)
+        self._kwargs = {k: defterm(v) for (k, v) in kwargs.items()}
+
+    @property
+    def op(self):
+        return self._op
+
+    @property
+    def args(self):
+        return self._args
+
+    @property
+    def kwargs(self):
+        return self._kwargs
+
+    @property
+    def _base_dist(self):
+        return self._op(*self.args, **self.kwargs)

     @property
     def has_rsample(self):
-        return self.
+        return self._base_dist.has_rsample

     @property
     def batch_shape(self):
-        return
-            torch.Size([s for s in self.indices.values()]) + self.base_dist.batch_shape
-        )
+        return self._base_dist.batch_shape

     @property
     def event_shape(self):
-        return self.
+        return self._base_dist.event_shape

     @property
     def has_enumerate_support(self):
-        return self.
+        return self._base_dist.has_enumerate_support

     @property
     def arg_constraints(self):
-        return self.
+        return self._base_dist.arg_constraints

     @property
     def support(self):
-        return self.
-
-    def __repr__(self):
-        return f"PositionalDistribution({self.base_dist})"
+        return self._base_dist.support

     def sample(self, sample_shape=torch.Size()):
-        return self.
+        return self._base_dist.sample(sample_shape)

     def rsample(self, sample_shape=torch.Size()):
-        return self.
+        return self._base_dist.rsample(sample_shape)

     def log_prob(self, value):
-        return self.
-            self.base_dist.log_prob(self._from_positional(value))
-        )
+        return self._base_dist.log_prob(value)

     def enumerate_support(self, expand=True):
-        return self.
+        return self._base_dist.enumerate_support(expand)


-
-
+@defterm.register(TorchDistribution)
+@defterm.register(TorchDistributionMixin)
+def _embed_distribution(dist: TorchDistribution) -> Term[TorchDistribution]:
+    raise ValueError(
+        f"No embedding provided for distribution of type {type(dist).__name__}."
+    )

-    def __init__(
-        self,
-        base_dist: pyro.distributions.torch_distribution.TorchDistribution,
-        names: Collection[Operation[[], int]],
-    ):
-        """
-        :param base_dist: A distribution with batch dimensions.

-
+@defterm.register
+def _embed_expanded(d: dist.ExpandedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        batch_shape = d._batch_shape
+        base_dist = d.base_dist
+        base_batch_shape = base_dist.batch_shape
+    if batch_shape == base_batch_shape:
+        return base_dist

-
-        self.base_dist = base_dist
-        self.names = names
+    raise ValueError("Nontrivial ExpandedDistribution not implemented.")

-        assert 1 <= len(names) <= len(base_dist.batch_shape)
-        base_indices = sizesof(base_dist.sample())
-        assert not any(n in base_indices for n in names)

-
-
-
+@defterm.register
+def _embed_independent(d: dist.Independent) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist
+        reinterpreted_batch_ndims = d.reinterpreted_batch_ndims

-
-        return self.naming.apply(value)
+    return _register_distribution_op(type(d))(base_dist, reinterpreted_batch_ndims)

-    def _from_named(self, value: torch.Tensor) -> torch.Tensor:
-        pos_value = to_tensor(value, self.names)

-
+@defterm.register
+def _embed_folded(d: dist.FoldedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist

-
-        n_named = len(self.names)
-        n_sample = len(pos_value.shape) - n_base - n_named
+    return _register_distribution_op(type(d))(base_dist)  # type: ignore

-        base_dims = dims[len(dims) - n_base :]
-        named_dims = dims[:n_named]
-        sample_dims = dims[n_named : n_named + n_sample]

-
+@defterm.register
+def _embed_masked(d: dist.MaskedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist
+        mask = d._mask

-
+    return _register_distribution_op(type(d))(base_dist, mask)

-    @property
-    def has_rsample(self):
-        return self.base_dist.has_rsample

-
-
-
+@defterm.register(dist.Cauchy)
+@defterm.register(dist.Gumbel)
+@defterm.register(dist.Laplace)
+@defterm.register(dist.LogNormal)
+@defterm.register(dist.Logistic)
+@defterm.register(dist.LogisticNormal)
+@defterm.register(dist.Normal)
+@defterm.register(dist.StudentT)
+def _embed_loc_scale(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        scale = d.scale

-
-    def event_shape(self):
-        return self.base_dist.event_shape
+    return _register_distribution_op(type(d))(loc, scale)

-    @property
-    def has_enumerate_support(self):
-        return self.base_dist.has_enumerate_support

-
-
-
+@defterm.register(dist.Bernoulli)
+@defterm.register(dist.Categorical)
+@defterm.register(dist.ContinuousBernoulli)
+@defterm.register(dist.Geometric)
+@defterm.register(dist.OneHotCategorical)
+@defterm.register(dist.OneHotCategoricalStraightThrough)
+def _embed_probs(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        probs = d.probs

-
-    def support(self):
-        return self.base_dist.support
+    return _register_distribution_op(type(d))(probs)

-    def __repr__(self):
-        return f"NamedDistribution({self.base_dist}, {self.names})"

-
-
-
-
-
-
-        return t
+@defterm.register(dist.Beta)
+@defterm.register(dist.Kumaraswamy)
+def _embed_beta(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration1 = d.concentration1
+        concentration0 = d.concentration0

-
-        return self._to_named(
-            self.base_dist.rsample(sample_shape), offset=len(sample_shape)
-        )
+    return _register_distribution_op(type(d))(concentration1, concentration0)

-    def log_prob(self, value):
-        v1 = self._from_named(value)
-        v2 = self.base_dist.log_prob(v1)
-        v3 = self._to_named(v2)
-        return v3

-
-
+@defterm.register
+def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.Binomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df = d.df
+
+    return _register_distribution_op(dist.Chi2)(df)
+
+
+@defterm.register
+def _embed_dirichlet(d: dist.Dirichlet) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.Dirichlet)(concentration)
+
+
+@defterm.register
+def _embed_exponential(d: dist.Exponential) -> Term[TorchDistribution]:
+    with interpreter({}):
+        rate = d.rate
+
+    return _register_distribution_op(dist.Exponential)(rate)
+
+
+@defterm.register
+def _embed_fisher_snedecor(d: dist.FisherSnedecor) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df1 = d.df1
+        df2 = d.df2
+
+    return _register_distribution_op(dist.FisherSnedecor)(df1, df2)
+
+
+@defterm.register
+def _embed_gamma(d: dist.Gamma) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration = d.concentration
+        rate = d.rate
+
+    return _register_distribution_op(dist.Gamma)(concentration, rate)
+
+
+@defterm.register(dist.HalfCauchy)
+@defterm.register(dist.HalfNormal)
+def _embed_half_cauchy(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+
+    return _register_distribution_op(type(d))(scale)
+
+
+@defterm.register
+def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[TorchDistribution]:
+    with interpreter({}):
+        dim = d.dim
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.LKJCholesky)(dim, concentration=concentration)
+
+
+@defterm.register
+def _embed_multinomial(d: dist.Multinomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.Multinomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_multivariate_normal(d: dist.MultivariateNormal) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        scale_tril = d.scale_tril
+
+    return _register_distribution_op(dist.MultivariateNormal)(
+        loc, scale_tril=scale_tril
+    )
+
+
+@defterm.register
+def _embed_negative_binomial(d: dist.NegativeBinomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.NegativeBinomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_pareto(d: dist.Pareto) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+        alpha = d.alpha
+
+    return _register_distribution_op(dist.Pareto)(scale, alpha)
+
+
+@defterm.register
+def _embed_poisson(d: dist.Poisson) -> Term[TorchDistribution]:
+    with interpreter({}):
+        rate = d.rate
+
+    return _register_distribution_op(dist.Poisson)(rate)
+
+
+@defterm.register(dist.RelaxedBernoulli)
+@defterm.register(dist.RelaxedOneHotCategorical)
+def _embed_relaxed(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        temperature = d.temperature
+        probs = d.probs
+
+    return _register_distribution_op(type(d))(temperature, probs)
+
+
+@defterm.register
+def _embed_uniform(d: dist.Uniform) -> Term[TorchDistribution]:
+    with interpreter({}):
+        low = d.low
+        high = d.high
+
+    return _register_distribution_op(dist.Uniform)(low, high)
+
+
+@defterm.register
+def _embed_von_mises(d: dist.VonMises) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.VonMises)(loc, concentration)
+
+
+@defterm.register
+def _embed_weibull(d: dist.Weibull) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.Weibull)(scale, concentration)
+
+
+@defterm.register
+def _embed_wishart(d: dist.Wishart) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df = d.df
+        scale_tril = d.scale_tril
+
+    return _register_distribution_op(dist.Wishart)(df, scale_tril)
+
+
+@defterm.register
+def _embed_delta(d: dist.Delta) -> Term[TorchDistribution]:
+    with interpreter({}):
+        v = d.v
+        log_density = d.log_density
+        event_dim = d.event_dim
+
+    return _register_distribution_op(dist.Delta)(
+        v, log_density=log_density, event_dim=event_dim
+    )


 def pyro_module_shim(
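Most of this hunk registers `bind_dims`/`unbind_dims` implementations for distributions and per-distribution `defterm` embeddings, so a distribution can be rebuilt with named or positional parameters. A hedged round-trip sketch (the index `b` and the expected values in comments are our assumptions about the API, not behavior recorded in this diff):

    import pyro.distributions as dist
    import torch
    from effectful.handlers.torch import bind_dims, sizesof, unbind_dims
    from effectful.ops.syntax import defop

    b = defop(torch.Tensor, name="b")  # hypothetical named batch index

    d = dist.Normal(torch.zeros(3), torch.ones(3))
    d_named = unbind_dims(d, b)        # batch dim 0 becomes the index b
    print(sizesof(d_named))            # expected: {b: 3}

    d_pos = bind_dims(d_named, b)      # and back to a positional batch dim
    print(d_pos.batch_shape)           # expected: torch.Size([3])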
|