effectful 0.0.1-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,35 +1,46 @@
+import functools
 import typing
-import warnings
-from typing import Any, Collection, List, Mapping, Optional, Tuple
+from collections.abc import Collection, Mapping
+from typing import Any
+
+import pyro.poutine.subsample_messenger

 try:
     import pyro
 except ImportError:
     raise ImportError("Pyro is required to use effectful.handlers.pyro.")

+import pyro.distributions as dist
+from pyro.distributions.torch_distribution import (
+    TorchDistribution,
+    TorchDistributionMixin,
+)
+
 try:
     import torch
 except ImportError:
     raise ImportError("PyTorch is required to use effectful.handlers.pyro.")

-from typing_extensions import ParamSpec
-
-from effectful.handlers.torch import Indexable, sizesof, to_tensor
-from effectful.ops.syntax import defop
-from effectful.ops.types import Operation
-
-P = ParamSpec("P")
+from effectful.handlers.torch import (
+    bind_dims,
+    sizesof,
+    unbind_dims,
+)
+from effectful.internals.runtime import interpreter
+from effectful.ops.semantics import apply, runner, typeof
+from effectful.ops.syntax import defdata, defop, defterm
+from effectful.ops.types import NotHandled, Operation, Term


 @defop
 def pyro_sample(
     name: str,
-    fn: pyro.distributions.torch_distribution.TorchDistributionMixin,
+    fn: TorchDistributionMixin,
     *args,
-    obs: Optional[torch.Tensor] = None,
-    obs_mask: Optional[torch.BoolTensor] = None,
-    mask: Optional[torch.BoolTensor] = None,
-    infer: Optional[pyro.poutine.runtime.InferDict] = None,
+    obs: torch.Tensor | None = None,
+    obs_mask: torch.BoolTensor | None = None,
+    mask: torch.BoolTensor | None = None,
+    infer: pyro.poutine.runtime.InferDict | None = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -41,6 +52,39 @@ def pyro_sample(
     )


+class Naming:
+    """
+    A mapping from dimensions (indexed from the right) to names.
+    """
+
+    def __init__(self, name_to_dim: Mapping[Operation[[], torch.Tensor], int]):
+        assert all(v < 0 for v in name_to_dim.values())
+        self.name_to_dim = name_to_dim
+
+    @staticmethod
+    def from_shape(
+        names: Collection[Operation[[], torch.Tensor]], event_dims: int
+    ) -> "Naming":
+        """Create a naming from a set of indices and the number of event dimensions.
+
+        The resulting naming converts tensors of shape
+        ``| batch_shape | named | event_shape |``
+        to tensors of shape ``| batch_shape | event_shape |, | named |``.
+
+        """
+        assert event_dims >= 0
+        return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})
+
+    def apply(self, value: torch.Tensor) -> torch.Tensor:
+        indexes: list[Any] = [slice(None)] * (len(value.shape))
+        for n, d in self.name_to_dim.items():
+            indexes[len(value.shape) + d] = n()
+        return value[tuple(indexes)]
+
+    def __repr__(self):
+        return f"Naming({self.name_to_dim})"
+
+
 class PyroShim(pyro.poutine.messenger.Messenger):
     """Pyro handler that wraps all sample sites in a custom effectful type.

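The `Naming` helper added in this hunk maps free index variables to dimensions counted from the right. A minimal sketch of its behaviour (not part of the diff, and assuming an index variable created with `defop` as elsewhere in this module):

import torch

from effectful.handlers.pyro import Naming
from effectful.handlers.torch import sizesof
from effectful.ops.syntax import defop

i = defop(torch.Tensor)  # a free index variable standing for a named dimension

# With one name and one event dimension, the name is assigned to dim -2.
naming = Naming.from_shape([i], event_dims=1)

t = torch.randn(3, 2)      # | named (3) | event_shape (2) |
named_t = naming.apply(t)  # indexes dim -2 with i(), leaving the event dim positional

# The size of the named dimension is now reported by sizesof rather than by the shape.
assert sizesof(named_t) == {i: 3}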
@@ -81,17 +125,24 @@ class PyroShim(pyro.poutine.messenger.Messenger):
     Sampled y
     """

-    _current_site: Optional[str]
+    # Tracks the named dimensions on any sample site that we have handled.
+    # Ideally, this information would be carried on the sample message itself.
+    # However, when using guides, sample sites are completely replaced by fresh
+    # guide sample sites that do not carry the same infer dict.
+    #
+    # We can only restore the named dimensions on samples that we have handled
+    # at least once in the shim.
+    _index_naming: dict[str, Naming]

-    def __enter__(self):
-        if any(isinstance(m, PyroShim) for m in pyro.poutine.runtime._PYRO_STACK):
-            warnings.warn("PyroShim should be installed at most once.")
-        return super().__enter__()
+    def __init__(self):
+        self._index_naming = {}

     @staticmethod
     def _broadcast_to_named(
-        t: torch.Tensor, shape: torch.Size, indices: Mapping[Operation[[], int], int]
-    ) -> Tuple[torch.Tensor, "Naming"]:
+        t: torch.Tensor,
+        shape: torch.Size,
+        indices: Mapping[Operation[[], torch.Tensor], int],
+    ) -> tuple[torch.Tensor, "Naming"]:
         """Convert a tensor `t` to a fully positional tensor that is
         broadcastable with the positional representation of tensors of shape
         |shape|, |indices|.
@@ -99,6 +150,9 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         """
         t_indices = sizesof(t)

+        if not isinstance(t, torch.Tensor):
+            t = torch.tensor(t)
+
         if len(t.shape) < len(shape):
             t = t.expand(shape)

@@ -106,7 +160,7 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         name_to_dim = {}
         for i, (k, v) in enumerate(reversed(list(indices.items()))):
             if k in t_indices:
-                t = to_tensor(t, [k])
+                t = bind_dims(t, k)
             else:
                 t = t.expand((v,) + t.shape)
             name_to_dim[k] = -len(shape) - i - 1
@@ -114,7 +168,7 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         # create a positional dimension for every remaining named index in `t`
         n_batch_and_dist_named = len(t.shape)
         for i, k in enumerate(reversed(list(sizesof(t).keys()))):
-            t = to_tensor(t, [k])
+            t = bind_dims(t, k)
             name_to_dim[k] = -n_batch_and_dist_named - i - 1

         return t, Naming(name_to_dim)
@@ -124,26 +178,47 @@ class PyroShim(pyro.poutine.messenger.Messenger):
         assert msg["type"] == "sample"
         assert msg["name"] is not None
         assert msg["infer"] is not None
-        assert isinstance(
-            msg["fn"], pyro.distributions.torch_distribution.TorchDistributionMixin
-        )
+        assert isinstance(msg["fn"], TorchDistributionMixin)

         if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
             msg
         ):
             return

-        if getattr(self, "_current_site", None) == msg["name"]:
-            if "_markov_scope" in msg["infer"] and self._current_site:
-                msg["infer"]["_markov_scope"].pop(self._current_site, None)
+        if "pyro_shim_status" in msg["infer"]:
+            handler_id, handler_stage = msg["infer"]["pyro_shim_status"]  # type: ignore
+        else:
+            handler_id = id(self)
+            handler_stage = 0
+            msg["infer"]["pyro_shim_status"] = (handler_id, handler_stage)  # type: ignore
+
+        if handler_id != id(self):  # Never handle a message that is not ours.
+            return
+
+        assert handler_stage in (0, 1)
+
+        # PyroShim turns each call to pyro.sample into two calls. The first
+        # dispatches to pyro_sample and the effectful stack. The effectful stack
+        # eventually calls pyro.sample again. We use state in PyroShim to
+        # recognize that we've been called twice, and we dispatch to the pyro
+        # stack.
+        #
+        # This branch handles the second call, so it massages the message to be
+        # compatible with Pyro. In particular, it removes all named dimensions
+        # and stores naming information in the message. Names are replaced by
+        # _pyro_post_sample.
+        if handler_stage == 1:
+            if "_markov_scope" in msg["infer"]:
+                msg["infer"]["_markov_scope"].pop(msg["name"], None)

             dist = msg["fn"]
             obs = msg["value"] if msg["is_observed"] else None

             # pdist shape: | named1 | batch_shape | event_shape |
             # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1
-            pdist = PositionalDistribution(dist)
-            naming = pdist.naming
+            indices = sizesof(dist)
+            naming = Naming.from_shape(indices, len(dist.shape()))
+            pdist = bind_dims(dist, *indices.keys())

             if msg["mask"] is None:
                 mask = torch.tensor(True)
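The comments in this hunk describe the two-stage dispatch that replaces the old `_current_site` flag: the first time a site is seen it is routed through the `pyro_sample` operation, and the second time (tagged in `infer["pyro_shim_status"]`) it is handed to Pyro's own stack with named dimensions bound to positional ones. A minimal usage sketch, not taken from the package but using only names defined in this module:

import pyro
import pyro.distributions as dist
import torch

from effectful.handlers.pyro import PyroShim

def model():
    # Under PyroShim, this statement is dispatched to the pyro_sample
    # operation first and only then to Pyro's handler stack.
    return pyro.sample("x", dist.Normal(torch.tensor(0.0), torch.tensor(1.0)))

with PyroShim():
    x = model()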
@@ -152,290 +227,566 @@ class PyroShim(pyro.poutine.messenger.Messenger):
             else:
                 mask = msg["mask"]

-            pos_mask, _ = PyroShim._broadcast_to_named(
-                mask, dist.batch_shape, pdist.indices
+            assert set(sizesof(mask).keys()) <= (
+                set(indices.keys()) | set(sizesof(obs).keys())
             )
+            pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices)

-            pos_obs: Optional[torch.Tensor] = None
+            pos_obs: torch.Tensor | None = None
             if obs is not None:
                 pos_obs, naming = PyroShim._broadcast_to_named(
-                    obs, dist.shape(), pdist.indices
+                    obs, dist.shape(), indices
                 )

+            # Each of the batch dimensions on the distribution gets a
+            # cond_indep_stack frame.
             for var, dim in naming.name_to_dim.items():
-                frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
-                    name=str(var), dim=dim, size=-1, counter=0
-                )
-                msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
+                # There can be additional batch dimensions on the observation
+                # that do not get frames, so only consider dimensions on the
+                # distribution.
+                if var in indices:
+                    frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
+                        name=f"__index_plate_{var}",
+                        # dims are indexed from the right of the batch shape
+                        dim=dim + len(pdist.event_shape),
+                        size=indices[var],
+                        counter=0,
+                    )
+                    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]

             msg["fn"] = pdist
             msg["value"] = pos_obs
             msg["mask"] = pos_mask
-            msg["infer"]["_index_naming"] = naming  # type: ignore
+
+            # stash the index naming on the sample message so that future
+            # consumers of the trace can get at it
+            msg["_index_naming"] = naming  # type: ignore
+
+            self._index_naming[msg["name"]] = naming

             assert sizesof(msg["value"]) == {}
             assert sizesof(msg["mask"]) == {}

-            return
+        # This branch handles the first call to pyro.sample by calling pyro_sample.
+        else:
+            infer = msg["infer"].copy()
+            infer["pyro_shim_status"] = (handler_id, 1)  # type: ignore

-        try:
-            self._current_site = msg["name"]
             msg["value"] = pyro_sample(
                 msg["name"],
                 msg["fn"],
                 obs=msg["value"] if msg["is_observed"] else None,
-                infer=msg["infer"].copy(),
+                infer=infer,
             )
-        finally:
-            self._current_site = None

-        # flags to guarantee commutativity of condition, intervene, trace
-        msg["stop"] = True
-        msg["done"] = True
-        msg["mask"] = False
-        msg["is_observed"] = True
-        msg["infer"]["is_auxiliary"] = True
-        msg["infer"]["_do_not_trace"] = True
+            # flags to guarantee commutativity of condition, intervene, trace
+            msg["stop"] = True
+            msg["done"] = True
+            msg["mask"] = False
+            msg["is_observed"] = True
+            msg["infer"]["is_auxiliary"] = True
+            msg["infer"]["_do_not_trace"] = True

     def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
-        infer = msg.get("infer")
-        if infer is None or "_index_naming" not in infer:
-            return
+        if typing.TYPE_CHECKING:
+            assert msg["name"] is not None
+            assert msg["value"] is not None
+            assert msg["infer"] is not None

-        # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
-        naming = infer["_index_naming"]  # type: ignore
+        # If there is no shim status, assume that we are looking at a guide sample.
+        # In this case, we should handle the sample and claim it as ours if we have naming
+        # information for it.
+        if "pyro_shim_status" not in msg["infer"]:
+            # Except, of course, for subsample messages, which we should ignore.
+            if (
+                pyro.poutine.util.site_is_subsample(msg)
+                or msg["name"] not in self._index_naming
+            ):
+                return
+            msg["infer"]["pyro_shim_status"] = (id(self), 1)  # type: ignore
+
+        # If this message has been handled already by a different pyro shim, ignore.
+        handler_id, handler_stage = msg["infer"]["pyro_shim_status"]  # type: ignore
+        if handler_id != id(self) or handler_stage < 1:
+            return

         value = msg["value"]

-        if value is not None:
-            # note: is it safe to assume that msg['fn'] is a distribution?
-            assert isinstance(
-                msg["fn"], pyro.distributions.torch_distribution.TorchDistribution
-            )
-            dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape
-            if len(value.shape) < len(dist_shape):
-                value = value.broadcast_to(
-                    torch.broadcast_shapes(value.shape, dist_shape)
-                )
-            value = naming.apply(value)
-            msg["value"] = value
+        naming = self._index_naming.get(msg["name"], Naming({}))
+        infer = msg["infer"] if msg["infer"] is not None else {}
+        assert "enumerate" not in infer or len(naming.name_to_dim) == 0, (
+            "Enumeration is not currently supported in PyroShim."
+        )

+        # note: is it safe to assume that msg['fn'] is a distribution?
+        dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape  # type: ignore
+        if len(value.shape) < len(dist_shape):
+            value = value.broadcast_to(torch.broadcast_shapes(value.shape, dist_shape))

-class Naming:
-    """
-    A mapping from dimensions (indexed from the right) to names.
-    """
+        value = naming.apply(value)
+        msg["value"] = value

-    def __init__(self, name_to_dim: Mapping[Operation[[], int], int]):
-        assert all(v < 0 for v in name_to_dim.values())
-        self.name_to_dim = name_to_dim

-    @staticmethod
-    def from_shape(names: Collection[Operation[[], int]], event_dims: int) -> "Naming":
-        """Create a naming from a set of indices and the number of event dimensions.
+PyroDistribution = (
+    pyro.distributions.torch_distribution.TorchDistribution
+    | pyro.distributions.torch_distribution.TorchDistributionMixin
+)

-        The resulting naming converts tensors of shape
-        ``| batch_shape | named | event_shape |``
-        to tensors of shape ``| batch_shape | event_shape |, | named |``.

-        """
-        assert event_dims >= 0
-        return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})
+@unbind_dims.register(pyro.distributions.torch_distribution.TorchDistribution)  # type: ignore
+@unbind_dims.register(pyro.distributions.torch_distribution.TorchDistributionMixin)  # type: ignore
+def _unbind_dims_distribution(
+    value: pyro.distributions.torch_distribution.TorchDistribution,
+    *names: Operation[[], torch.Tensor],
+) -> pyro.distributions.torch_distribution.TorchDistribution:
+    d = value
+    batch_shape = None

-    def apply(self, value: torch.Tensor) -> torch.Tensor:
-        indexes: List[Any] = [slice(None)] * (len(value.shape))
-        for n, d in self.name_to_dim.items():
-            indexes[len(value.shape) + d] = n()
-        return Indexable(value)[tuple(indexes)]
+    def _validate_batch_shape(t):
+        nonlocal batch_shape
+        if len(t.shape) < len(names):
+            raise ValueError(
+                "All tensors must have at least as many dimensions as names"
+            )

-    def __repr__(self):
-        return f"Naming({self.name_to_dim})"
+        if batch_shape is None:
+            batch_shape = t.shape[: len(names)]

+        if (
+            len(t.shape) < len(batch_shape)
+            or t.shape[: len(batch_shape)] != batch_shape
+        ):
+            raise ValueError("All tensors must have the same batch shape.")
+
+    def _to_named(a):
+        nonlocal batch_shape
+        if isinstance(a, torch.Tensor):
+            _validate_batch_shape(a)
+            return typing.cast(torch.Tensor, a)[tuple(n() for n in names)]
+        elif isinstance(a, TorchDistribution):
+            return unbind_dims(a, *names)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistribution
+        ) or issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistributionMixin
+        ):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)
+
+    with runner({apply: _apply}):
+        d = defterm(d)
+
+    if not (isinstance(d, Term) and typeof(d) is TorchDistribution):
+        raise NotHandled
+
+    new_d = d.op(
+        *[_to_named(a) for a in d.args],
+        **{k: _to_named(v) for (k, v) in d.kwargs.items()},
+    )
+    assert new_d.event_shape == d.event_shape
+    return new_d
+
+
+@bind_dims.register(pyro.distributions.torch_distribution.TorchDistribution)  # type: ignore
+@bind_dims.register(pyro.distributions.torch_distribution.TorchDistributionMixin)  # type: ignore
+def _bind_dims_distribution(
+    value: pyro.distributions.torch_distribution.TorchDistribution,
+    *names: Operation[[], torch.Tensor],
+) -> pyro.distributions.torch_distribution.TorchDistribution:
+    d = value
+
+    def _to_positional(a, indices):
+        if isinstance(a, torch.Tensor):
+            # broadcast to full indexed shape
+            existing_dims = set(sizesof(a).keys())
+            missing_dims = set(indices) - existing_dims
+
+            a_indexed = torch.broadcast_to(
+                a, torch.Size([indices[dim] for dim in missing_dims]) + a.shape
+            )[tuple(n() for n in missing_dims)]
+            return bind_dims(a_indexed, *names)
+        elif isinstance(a, TorchDistribution):
+            return bind_dims(a, *names)
+        else:
+            return a
+
+    # Convert to a term in a context that does not evaluate distribution constructors.
+    def _apply(op, *args, **kwargs):
+        typ = op.__type_rule__(*args, **kwargs)
+        if issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistribution
+        ) or issubclass(
+            typ, pyro.distributions.torch_distribution.TorchDistributionMixin
+        ):
+            return defdata(op, *args, **kwargs)
+        return op.__default_rule__(*args, **kwargs)

-class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribution):
-    """A distribution wrapper that lazily converts indexed dimensions to
-    positional.
+    with runner({apply: _apply}):
+        d = defterm(d)

-    """
+    if not (isinstance(d, Term) and typeof(d) is TorchDistribution):
+        raise NotHandled

-    indices: Mapping[Operation[[], int], int]
+    sizes = sizesof(d)
+    indices = {k: sizes[k] for k in names}

-    def __init__(
-        self, base_dist: pyro.distributions.torch_distribution.TorchDistribution
-    ):
-        self.base_dist = base_dist
-        self.indices = sizesof(base_dist.sample())
+    pos_args = [_to_positional(a, indices) for a in d.args]
+    pos_kwargs = {k: _to_positional(v, indices) for (k, v) in d.kwargs.items()}
+    new_d = d.op(*pos_args, **pos_kwargs)

-        n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
-        self.naming = Naming.from_shape(self.indices.keys(), n_base)
+    assert new_d.event_shape == d.event_shape
+    return new_d

-        super().__init__()

-    def _to_positional(self, value: torch.Tensor) -> torch.Tensor:
-        # self.base_dist has shape: | batch_shape | event_shape | & named
-        # assume value comes from base_dist with shape:
-        # | sample_shape | batch_shape | event_shape | & named
-        # return a tensor of shape | sample_shape | named | batch_shape | event_shape |
-        n_named = len(self.indices)
-        dims = list(range(n_named + len(value.shape)))
+@functools.cache
+def _register_distribution_op(
+    dist_constr: type[TorchDistribution],
+) -> Operation[Any, TorchDistribution]:
+    # introduce a wrapper so that we can control type annotations
+    def wrapper(*args, **kwargs) -> TorchDistribution:
+        return dist_constr(*args, **kwargs)

-        n_base = len(self.event_shape) + len(self.base_dist.batch_shape)
-        n_sample = len(value.shape) - n_base
+    return defop(wrapper)

-        base_dims = dims[len(dims) - n_base :]
-        named_dims = dims[:n_named]
-        sample_dims = dims[n_named : n_named + n_sample]

-        # shape: | named | sample_shape | batch_shape | event_shape |
-        # TODO: replace with something more efficient
-        pos_tensor = to_tensor(value, self.indices.keys())
+@defdata.register(pyro.distributions.torch_distribution.TorchDistribution)
+@defdata.register(pyro.distributions.torch_distribution.TorchDistributionMixin)
+class _DistributionTerm(Term[TorchDistribution], TorchDistribution):
+    """A distribution wrapper that satisfies the Term interface.

-        # shape: | sample_shape | named | batch_shape | event_shape |
-        pos_tensor_r = torch.permute(pos_tensor, sample_dims + named_dims + base_dims)
+    Represented as a term of the form call(D, *args, **kwargs) where D is the
+    distribution constructor.

-        return pos_tensor_r
+    Note: When we construct instances of this class, we put distribution
+    parameters that can be expanded in the args list and those that cannot in
+    the kwargs list.

-    def _from_positional(self, value: torch.Tensor) -> torch.Tensor:
-        # maximal value shape: | sample_shape | named | batch_shape | event_shape |
-        return self.naming.apply(value)
+    """
+
+    _op: Operation[Any, TorchDistribution]
+    _args: tuple
+    _kwargs: dict
+
+    def __init__(self, op: Operation[Any, TorchDistribution], *args, **kwargs):
+        self._op = op
+        self._args = tuple(defterm(a) for a in args)
+        self._kwargs = {k: defterm(v) for (k, v) in kwargs.items()}
+
+    @property
+    def op(self):
+        return self._op
+
+    @property
+    def args(self):
+        return self._args
+
+    @property
+    def kwargs(self):
+        return self._kwargs
+
+    @property
+    def _base_dist(self):
+        return self._op(*self.args, **self.kwargs)

     @property
     def has_rsample(self):
-        return self.base_dist.has_rsample
+        return self._base_dist.has_rsample

     @property
     def batch_shape(self):
-        return (
-            torch.Size([s for s in self.indices.values()]) + self.base_dist.batch_shape
-        )
+        return self._base_dist.batch_shape

     @property
     def event_shape(self):
-        return self.base_dist.event_shape
+        return self._base_dist.event_shape

     @property
     def has_enumerate_support(self):
-        return self.base_dist.has_enumerate_support
+        return self._base_dist.has_enumerate_support

     @property
     def arg_constraints(self):
-        return self.base_dist.arg_constraints
+        return self._base_dist.arg_constraints

     @property
     def support(self):
-        return self.base_dist.support
-
-    def __repr__(self):
-        return f"PositionalDistribution({self.base_dist})"
+        return self._base_dist.support

     def sample(self, sample_shape=torch.Size()):
-        return self._to_positional(self.base_dist.sample(sample_shape))
+        return self._base_dist.sample(sample_shape)

     def rsample(self, sample_shape=torch.Size()):
-        return self._to_positional(self.base_dist.rsample(sample_shape))
+        return self._base_dist.rsample(sample_shape)

     def log_prob(self, value):
-        return self._to_positional(
-            self.base_dist.log_prob(self._from_positional(value))
-        )
+        return self._base_dist.log_prob(value)

     def enumerate_support(self, expand=True):
-        return self._to_positional(self.base_dist.enumerate_support(expand))
+        return self._base_dist.enumerate_support(expand)


-class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution):
-    """A distribution wrapper that lazily names leftmost dimensions."""
+@defterm.register(TorchDistribution)
+@defterm.register(TorchDistributionMixin)
+def _embed_distribution(dist: TorchDistribution) -> Term[TorchDistribution]:
+    raise ValueError(
+        f"No embedding provided for distribution of type {type(dist).__name__}."
+    )

-    def __init__(
-        self,
-        base_dist: pyro.distributions.torch_distribution.TorchDistribution,
-        names: Collection[Operation[[], int]],
-    ):
-        """
-        :param base_dist: A distribution with batch dimensions.

-        :param names: A list of names.
+@defterm.register
+def _embed_expanded(d: dist.ExpandedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        batch_shape = d._batch_shape
+        base_dist = d.base_dist
+        base_batch_shape = base_dist.batch_shape
+        if batch_shape == base_batch_shape:
+            return base_dist

-        """
-        self.base_dist = base_dist
-        self.names = names
+    raise ValueError("Nontrivial ExpandedDistribution not implemented.")

-        assert 1 <= len(names) <= len(base_dist.batch_shape)
-        base_indices = sizesof(base_dist.sample())
-        assert not any(n in base_indices for n in names)

-        n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
-        self.naming = Naming.from_shape(names, n_base - len(names))
-        super().__init__()
+@defterm.register
+def _embed_independent(d: dist.Independent) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist
+        reinterpreted_batch_ndims = d.reinterpreted_batch_ndims

-    def _to_named(self, value: torch.Tensor, offset=0) -> torch.Tensor:
-        return self.naming.apply(value)
+    return _register_distribution_op(type(d))(base_dist, reinterpreted_batch_ndims)

-    def _from_named(self, value: torch.Tensor) -> torch.Tensor:
-        pos_value = to_tensor(value, self.names)

-        dims = list(range(len(pos_value.shape)))
+@defterm.register
+def _embed_folded(d: dist.FoldedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist

-        n_base = len(self.event_shape) + len(self.batch_shape)
-        n_named = len(self.names)
-        n_sample = len(pos_value.shape) - n_base - n_named
+    return _register_distribution_op(type(d))(base_dist)  # type: ignore

-        base_dims = dims[len(dims) - n_base :]
-        named_dims = dims[:n_named]
-        sample_dims = dims[n_named : n_named + n_sample]

-        pos_tensor_r = torch.permute(pos_value, sample_dims + named_dims + base_dims)
+@defterm.register
+def _embed_masked(d: dist.MaskedDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        base_dist = d.base_dist
+        mask = d._mask

-        return pos_tensor_r
+    return _register_distribution_op(type(d))(base_dist, mask)

-    @property
-    def has_rsample(self):
-        return self.base_dist.has_rsample

-    @property
-    def batch_shape(self):
-        return self.base_dist.batch_shape[len(self.names) :]
+@defterm.register(dist.Cauchy)
+@defterm.register(dist.Gumbel)
+@defterm.register(dist.Laplace)
+@defterm.register(dist.LogNormal)
+@defterm.register(dist.Logistic)
+@defterm.register(dist.LogisticNormal)
+@defterm.register(dist.Normal)
+@defterm.register(dist.StudentT)
+def _embed_loc_scale(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        scale = d.scale

-    @property
-    def event_shape(self):
-        return self.base_dist.event_shape
+    return _register_distribution_op(type(d))(loc, scale)

-    @property
-    def has_enumerate_support(self):
-        return self.base_dist.has_enumerate_support

-    @property
-    def arg_constraints(self):
-        return self.base_dist.arg_constraints
+@defterm.register(dist.Bernoulli)
+@defterm.register(dist.Categorical)
+@defterm.register(dist.ContinuousBernoulli)
+@defterm.register(dist.Geometric)
+@defterm.register(dist.OneHotCategorical)
+@defterm.register(dist.OneHotCategoricalStraightThrough)
+def _embed_probs(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        probs = d.probs

-    @property
-    def support(self):
-        return self.base_dist.support
+    return _register_distribution_op(type(d))(probs)

-    def __repr__(self):
-        return f"NamedDistribution({self.base_dist}, {self.names})"

-    def sample(self, sample_shape=torch.Size()):
-        t = self._to_named(
-            self.base_dist.sample(sample_shape), offset=len(sample_shape)
-        )
-        assert set(sizesof(t).keys()) == set(self.names)
-        assert t.shape == self.shape() + sample_shape
-        return t
+@defterm.register(dist.Beta)
+@defterm.register(dist.Kumaraswamy)
+def _embed_beta(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration1 = d.concentration1
+        concentration0 = d.concentration0

-    def rsample(self, sample_shape=torch.Size()):
-        return self._to_named(
-            self.base_dist.rsample(sample_shape), offset=len(sample_shape)
-        )
+    return _register_distribution_op(type(d))(concentration1, concentration0)

-    def log_prob(self, value):
-        v1 = self._from_named(value)
-        v2 = self.base_dist.log_prob(v1)
-        v3 = self._to_named(v2)
-        return v3

-    def enumerate_support(self, expand=True):
-        return self._to_named(self.base_dist.enumerate_support(expand))
+@defterm.register
+def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.Binomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df = d.df
+
+    return _register_distribution_op(dist.Chi2)(df)
+
+
+@defterm.register
+def _embed_dirichlet(d: dist.Dirichlet) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.Dirichlet)(concentration)
+
+
+@defterm.register
+def _embed_exponential(d: dist.Exponential) -> Term[TorchDistribution]:
+    with interpreter({}):
+        rate = d.rate
+
+    return _register_distribution_op(dist.Exponential)(rate)
+
+
+@defterm.register
+def _embed_fisher_snedecor(d: dist.FisherSnedecor) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df1 = d.df1
+        df2 = d.df2
+
+    return _register_distribution_op(dist.FisherSnedecor)(df1, df2)
+
+
+@defterm.register
+def _embed_gamma(d: dist.Gamma) -> Term[TorchDistribution]:
+    with interpreter({}):
+        concentration = d.concentration
+        rate = d.rate
+
+    return _register_distribution_op(dist.Gamma)(concentration, rate)
+
+
+@defterm.register(dist.HalfCauchy)
+@defterm.register(dist.HalfNormal)
+def _embed_half_cauchy(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+
+    return _register_distribution_op(type(d))(scale)
+
+
+@defterm.register
+def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[TorchDistribution]:
+    with interpreter({}):
+        dim = d.dim
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.LKJCholesky)(dim, concentration=concentration)
+
+
+@defterm.register
+def _embed_multinomial(d: dist.Multinomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.Multinomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_multivariate_normal(d: dist.MultivariateNormal) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        scale_tril = d.scale_tril
+
+    return _register_distribution_op(dist.MultivariateNormal)(
+        loc, scale_tril=scale_tril
+    )
+
+
+@defterm.register
+def _embed_negative_binomial(d: dist.NegativeBinomial) -> Term[TorchDistribution]:
+    with interpreter({}):
+        total_count = d.total_count
+        probs = d.probs
+
+    return _register_distribution_op(dist.NegativeBinomial)(total_count, probs)
+
+
+@defterm.register
+def _embed_pareto(d: dist.Pareto) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+        alpha = d.alpha
+
+    return _register_distribution_op(dist.Pareto)(scale, alpha)
+
+
+@defterm.register
+def _embed_poisson(d: dist.Poisson) -> Term[TorchDistribution]:
+    with interpreter({}):
+        rate = d.rate
+
+    return _register_distribution_op(dist.Poisson)(rate)
+
+
+@defterm.register(dist.RelaxedBernoulli)
+@defterm.register(dist.RelaxedOneHotCategorical)
+def _embed_relaxed(d: TorchDistribution) -> Term[TorchDistribution]:
+    with interpreter({}):
+        temperature = d.temperature
+        probs = d.probs
+
+    return _register_distribution_op(type(d))(temperature, probs)
+
+
+@defterm.register
+def _embed_uniform(d: dist.Uniform) -> Term[TorchDistribution]:
+    with interpreter({}):
+        low = d.low
+        high = d.high
+
+    return _register_distribution_op(dist.Uniform)(low, high)
+
+
+@defterm.register
+def _embed_von_mises(d: dist.VonMises) -> Term[TorchDistribution]:
+    with interpreter({}):
+        loc = d.loc
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.VonMises)(loc, concentration)
+
+
+@defterm.register
+def _embed_weibull(d: dist.Weibull) -> Term[TorchDistribution]:
+    with interpreter({}):
+        scale = d.scale
+        concentration = d.concentration
+
+    return _register_distribution_op(dist.Weibull)(scale, concentration)
+
+
+@defterm.register
+def _embed_wishart(d: dist.Wishart) -> Term[TorchDistribution]:
+    with interpreter({}):
+        df = d.df
+        scale_tril = d.scale_tril
+
+    return _register_distribution_op(dist.Wishart)(df, scale_tril)
+
+
+@defterm.register
+def _embed_delta(d: dist.Delta) -> Term[TorchDistribution]:
+    with interpreter({}):
+        v = d.v
+        log_density = d.log_density
+        event_dim = d.event_dim
+
+    return _register_distribution_op(dist.Delta)(
+        v, log_density=log_density, event_dim=event_dim
+    )


 def pyro_module_shim(
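Taken together, the `bind_dims`/`unbind_dims` registrations and the `defterm` embeddings above replace the old `PositionalDistribution` and `NamedDistribution` wrappers: a distribution is converted by rebuilding it from its constructor term with renamed or repositioned parameters. A minimal round-trip sketch (an assumption about intended usage, not taken from the package's documentation):

import pyro.distributions as dist
import torch

import effectful.handlers.pyro  # noqa: F401  (importing the module installs the registrations in this diff)
from effectful.handlers.torch import bind_dims, sizesof, unbind_dims
from effectful.ops.syntax import defop

i = defop(torch.Tensor)  # index variable naming the batch dimension of size 3

d_pos = dist.Normal(torch.zeros(3), torch.ones(3))
d_named = unbind_dims(d_pos, i)  # batch dimension 0 becomes the free index i
# sizesof(d_named) == {i: 3}; the positional batch_shape no longer includes it

d_back = bind_dims(d_named, i)   # the named dimension is bound back to a positional one
# d_back.batch_shape == torch.Size([3])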