effectful-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,466 @@
+ import typing
+ import warnings
+ from typing import Any, Collection, List, Mapping, Optional, Tuple
+
+ try:
+     import pyro
+ except ImportError:
+     raise ImportError("Pyro is required to use effectful.handlers.pyro.")
+
+ try:
+     import torch
+ except ImportError:
+     raise ImportError("PyTorch is required to use effectful.handlers.pyro.")
+
+ from typing_extensions import ParamSpec
+
+ from effectful.handlers.torch import Indexable, sizesof, to_tensor
+ from effectful.ops.syntax import defop
+ from effectful.ops.types import Operation
+
+ P = ParamSpec("P")
+
+
+ @defop
+ def pyro_sample(
+     name: str,
+     fn: pyro.distributions.torch_distribution.TorchDistributionMixin,
+     *args,
+     obs: Optional[torch.Tensor] = None,
+     obs_mask: Optional[torch.BoolTensor] = None,
+     mask: Optional[torch.BoolTensor] = None,
+     infer: Optional[pyro.poutine.runtime.InferDict] = None,
+     **kwargs,
+ ) -> torch.Tensor:
+     """
+     Operation to sample from a Pyro distribution. See :func:`pyro.sample`.
+     """
+     with pyro.poutine.mask(mask=mask if mask is not None else True):
+         return pyro.sample(
+             name, fn, *args, obs=obs, obs_mask=obs_mask, infer=infer, **kwargs
+         )
+
+
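Because pyro_sample is an ordinary effectful operation, it can be intercepted before the default behavior runs. A minimal sketch (not part of the packaged file), assuming only the handler/fwd API demonstrated in the PyroShim docstring below; the fixed value for site "x" is a hypothetical choice:

    import pyro
    import pyro.distributions as dist
    import torch

    from effectful.ops.semantics import fwd, handler

    def fix_x(name, fn, *args, **kwargs):
        # Handle the pyro_sample effect: pin site "x" to a constant and
        # forward every other site to the default implementation.
        if name == "x":
            return torch.tensor(0.5)
        return fwd()

    with PyroShim(), handler({pyro_sample: fix_x}):
        x = pyro.sample("x", dist.Normal(0.0, 1.0))  # x is tensor(0.5)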
+ class PyroShim(pyro.poutine.messenger.Messenger):
+     """Pyro handler that wraps all sample sites in a custom effectful type.
+
+     .. note::
+
+         This handler should be installed around any Pyro model that you want to
+         use effectful handlers with.
+
+     **Example usage**:
+
+     >>> import pyro.distributions as dist
+     >>> from effectful.ops.semantics import fwd, handler
+     >>> torch.distributions.Distribution.set_default_validate_args(False)
+
+     It can be used as a decorator:
+
+     >>> @PyroShim()
+     ... def model():
+     ...     return pyro.sample("x", dist.Normal(0, 1))
+
+     It can also be used as a context manager:
+
+     >>> with PyroShim():
+     ...     x = pyro.sample("x", dist.Normal(0, 1))
+
+     When :class:`PyroShim` is installed, all sample sites perform the
+     :func:`pyro_sample` effect, which can be handled by an effectful
+     interpretation.
+
+     >>> def log_sample(name, *args, **kwargs):
+     ...     print(f"Sampled {name}")
+     ...     return fwd()
+
+     >>> with PyroShim(), handler({pyro_sample: log_sample}):
+     ...     x = pyro.sample("x", dist.Normal(0, 1))
+     ...     y = pyro.sample("y", dist.Normal(0, 1))
+     Sampled x
+     Sampled y
+     """
+
+     _current_site: Optional[str]
+
+     def __enter__(self):
+         if any(isinstance(m, PyroShim) for m in pyro.poutine.runtime._PYRO_STACK):
+             warnings.warn("PyroShim should be installed at most once.")
+         return super().__enter__()
+
+     @staticmethod
+     def _broadcast_to_named(
+         t: torch.Tensor, shape: torch.Size, indices: Mapping[Operation[[], int], int]
+     ) -> Tuple[torch.Tensor, "Naming"]:
+         """Convert a tensor `t` to a fully positional tensor that is
+         broadcastable with the positional representation of tensors of shape
+         |shape|, |indices|.
+
+         """
+         t_indices = sizesof(t)
+
+         if len(t.shape) < len(shape):
+             t = t.expand(shape)
+
+         # create a positional dimension for every named index in the target shape
+         name_to_dim = {}
+         for i, (k, v) in enumerate(reversed(list(indices.items()))):
+             if k in t_indices:
+                 t = to_tensor(t, [k])
+             else:
+                 t = t.expand((v,) + t.shape)
+             name_to_dim[k] = -len(shape) - i - 1
+
+         # create a positional dimension for every remaining named index in `t`
+         n_batch_and_dist_named = len(t.shape)
+         for i, k in enumerate(reversed(list(sizesof(t).keys()))):
+             t = to_tensor(t, [k])
+             name_to_dim[k] = -n_batch_and_dist_named - i - 1
+
+         return t, Naming(name_to_dim)
+
+     def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
+         if typing.TYPE_CHECKING:
+             assert msg["type"] == "sample"
+             assert msg["name"] is not None
+             assert msg["infer"] is not None
+             assert isinstance(
+                 msg["fn"], pyro.distributions.torch_distribution.TorchDistributionMixin
+             )
+
+         if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
+             msg
+         ):
+             return
+
+         if getattr(self, "_current_site", None) == msg["name"]:
+             if "_markov_scope" in msg["infer"] and self._current_site:
+                 msg["infer"]["_markov_scope"].pop(self._current_site, None)
+
+             dist = msg["fn"]
+             obs = msg["value"] if msg["is_observed"] else None
+
+             # pdist shape: | named1 | batch_shape | event_shape |
+             # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1
+             pdist = PositionalDistribution(dist)
+             naming = pdist.naming
+
+             if msg["mask"] is None:
+                 mask = torch.tensor(True)
+             elif isinstance(msg["mask"], bool):
+                 mask = torch.tensor(msg["mask"])
+             else:
+                 mask = msg["mask"]
+
+             pos_mask, _ = PyroShim._broadcast_to_named(
+                 mask, dist.batch_shape, pdist.indices
+             )
+
+             pos_obs: Optional[torch.Tensor] = None
+             if obs is not None:
+                 pos_obs, naming = PyroShim._broadcast_to_named(
+                     obs, dist.shape(), pdist.indices
+                 )
+
+             for var, dim in naming.name_to_dim.items():
+                 frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
+                     name=str(var), dim=dim, size=-1, counter=0
+                 )
+                 msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
+
+             msg["fn"] = pdist
+             msg["value"] = pos_obs
+             msg["mask"] = pos_mask
+             msg["infer"]["_index_naming"] = naming  # type: ignore
+
+             assert sizesof(msg["value"]) == {}
+             assert sizesof(msg["mask"]) == {}
+
+             return
+
+         try:
+             self._current_site = msg["name"]
+             msg["value"] = pyro_sample(
+                 msg["name"],
+                 msg["fn"],
+                 obs=msg["value"] if msg["is_observed"] else None,
+                 infer=msg["infer"].copy(),
+             )
+         finally:
+             self._current_site = None
+
+         # flags to guarantee commutativity of condition, intervene, trace
+         msg["stop"] = True
+         msg["done"] = True
+         msg["mask"] = False
+         msg["is_observed"] = True
+         msg["infer"]["is_auxiliary"] = True
+         msg["infer"]["_do_not_trace"] = True
+
+     def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
+         infer = msg.get("infer")
+         if infer is None or "_index_naming" not in infer:
+             return
+
+         # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
+         naming = infer["_index_naming"]  # type: ignore
+
+         value = msg["value"]
+
+         if value is not None:
+             # note: is it safe to assume that msg['fn'] is a distribution?
+             assert isinstance(
+                 msg["fn"], pyro.distributions.torch_distribution.TorchDistribution
+             )
+             dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape
+             if len(value.shape) < len(dist_shape):
+                 value = value.broadcast_to(
+                     torch.broadcast_shapes(value.shape, dist_shape)
+                 )
+             value = naming.apply(value)
+             msg["value"] = value
+
+
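The class above handles each model site twice: the first pass intercepts the site and re-issues it through pyro_sample, and the re-entrant pass converts named dimensions to positional ones before _pyro_post_sample restores them. Because the outer site is marked stop/done, handlers installed outside the shim only see the re-issued site. A minimal sketch (not package code) of tracing a shimmed model with standard Pyro tooling:

    import pyro
    import pyro.distributions as dist

    @PyroShim()
    def model():
        return pyro.sample("x", dist.Normal(0.0, 1.0))

    # The shim sits directly around the model; outer handlers such as trace
    # record the re-issued site, with named dimensions already positional.
    trace = pyro.poutine.trace(model).get_trace()
    assert "x" in trace.nodes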
+ class Naming:
+     """
+     A mapping from dimensions (indexed from the right) to names.
+     """
+
+     def __init__(self, name_to_dim: Mapping[Operation[[], int], int]):
+         assert all(v < 0 for v in name_to_dim.values())
+         self.name_to_dim = name_to_dim
+
+     @staticmethod
+     def from_shape(names: Collection[Operation[[], int]], event_dims: int) -> "Naming":
+         """Create a naming from a set of indices and the number of event dimensions.
+
+         The resulting naming converts tensors of shape
+         ``| batch_shape | named | event_shape |``
+         to tensors of shape ``| batch_shape | event_shape |, | named |``.
+
+         """
+         assert event_dims >= 0
+         return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})
+
+     def apply(self, value: torch.Tensor) -> torch.Tensor:
+         indexes: List[Any] = [slice(None)] * (len(value.shape))
+         for n, d in self.name_to_dim.items():
+             indexes[len(value.shape) + d] = n()
+         return Indexable(value)[tuple(indexes)]
+
+     def __repr__(self):
+         return f"Naming({self.name_to_dim})"
+
+
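Naming.from_shape assigns each name a negative dimension just to the left of the event dimensions. A small self-contained check of that arithmetic (a sketch with strings standing in for the Operation[[], int] index objects the real class uses):

    def from_shape_dims(names, event_dims):
        # Mirrors the dictionary comprehension in Naming.from_shape.
        return {n: -event_dims - len(names) + i for i, n in enumerate(names)}

    # Two names and one event dimension: the names land at dims -3 and -2,
    # counted from the right, leaving the event dimension positional.
    assert from_shape_dims(["a", "b"], event_dims=1) == {"a": -3, "b": -2}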
+ class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribution):
+     """A distribution wrapper that lazily converts indexed dimensions to
+     positional.
+
+     """
+
+     indices: Mapping[Operation[[], int], int]
+
+     def __init__(
+         self, base_dist: pyro.distributions.torch_distribution.TorchDistribution
+     ):
+         self.base_dist = base_dist
+         self.indices = sizesof(base_dist.sample())
+
+         n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
+         self.naming = Naming.from_shape(self.indices.keys(), n_base)
+
+         super().__init__()
+
+     def _to_positional(self, value: torch.Tensor) -> torch.Tensor:
+         # self.base_dist has shape: | batch_shape | event_shape | & named
+         # assume value comes from base_dist with shape:
+         # | sample_shape | batch_shape | event_shape | & named
+         # return a tensor of shape | sample_shape | named | batch_shape | event_shape |
+         n_named = len(self.indices)
+         dims = list(range(n_named + len(value.shape)))
+
+         n_base = len(self.event_shape) + len(self.base_dist.batch_shape)
+         n_sample = len(value.shape) - n_base
+
+         base_dims = dims[len(dims) - n_base :]
+         named_dims = dims[:n_named]
+         sample_dims = dims[n_named : n_named + n_sample]
+
+         # shape: | named | sample_shape | batch_shape | event_shape |
+         # TODO: replace with something more efficient
+         pos_tensor = to_tensor(value, self.indices.keys())
+
+         # shape: | sample_shape | named | batch_shape | event_shape |
+         pos_tensor_r = torch.permute(pos_tensor, sample_dims + named_dims + base_dims)
+
+         return pos_tensor_r
+
+     def _from_positional(self, value: torch.Tensor) -> torch.Tensor:
+         # maximal value shape: | sample_shape | named | batch_shape | event_shape |
+         return self.naming.apply(value)
+
+     @property
+     def has_rsample(self):
+         return self.base_dist.has_rsample
+
+     @property
+     def batch_shape(self):
+         return (
+             torch.Size([s for s in self.indices.values()]) + self.base_dist.batch_shape
+         )
+
+     @property
+     def event_shape(self):
+         return self.base_dist.event_shape
+
+     @property
+     def has_enumerate_support(self):
+         return self.base_dist.has_enumerate_support
+
+     @property
+     def arg_constraints(self):
+         return self.base_dist.arg_constraints
+
+     @property
+     def support(self):
+         return self.base_dist.support
+
+     def __repr__(self):
+         return f"PositionalDistribution({self.base_dist})"
+
+     def sample(self, sample_shape=torch.Size()):
+         return self._to_positional(self.base_dist.sample(sample_shape))
+
+     def rsample(self, sample_shape=torch.Size()):
+         return self._to_positional(self.base_dist.rsample(sample_shape))
+
+     def log_prob(self, value):
+         return self._to_positional(
+             self.base_dist.log_prob(self._from_positional(value))
+         )
+
+     def enumerate_support(self, expand=True):
+         return self._to_positional(self.base_dist.enumerate_support(expand))
+
+
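As the wrapper above indicates, a distribution whose parameters carry named (indexed) dimensions exposes those dimensions as leading positional batch dimensions once wrapped. A sketch under two assumptions that go beyond this file: that a fresh index operation can be created by applying defop to a type, e.g. defop(int, name="i"), and that distribution parameters may be indexed tensors (with argument validation disabled, as in the PyroShim docstring):

    import pyro.distributions as dist
    import torch

    from effectful.handlers.torch import Indexable
    from effectful.ops.syntax import defop

    i = defop(int, name="i")              # assumed index constructor
    loc = Indexable(torch.zeros(3))[i()]  # loc indexed by `i` (size 3)

    d = PositionalDistribution(dist.Normal(loc, 1.0))
    assert d.batch_shape == torch.Size([3])     # `i` became a positional dim
    assert d.sample().shape == torch.Size([3])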
+ class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution):
+     """A distribution wrapper that lazily names leftmost dimensions."""
+
+     def __init__(
+         self,
+         base_dist: pyro.distributions.torch_distribution.TorchDistribution,
+         names: Collection[Operation[[], int]],
+     ):
+         """
+         :param base_dist: A distribution with batch dimensions.
+
+         :param names: A list of names.
+
+         """
+         self.base_dist = base_dist
+         self.names = names
+
+         assert 1 <= len(names) <= len(base_dist.batch_shape)
+         base_indices = sizesof(base_dist.sample())
+         assert not any(n in base_indices for n in names)
+
+         n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
+         self.naming = Naming.from_shape(names, n_base - len(names))
+         super().__init__()
+
+     def _to_named(self, value: torch.Tensor, offset=0) -> torch.Tensor:
+         return self.naming.apply(value)
+
+     def _from_named(self, value: torch.Tensor) -> torch.Tensor:
+         pos_value = to_tensor(value, self.names)
+
+         dims = list(range(len(pos_value.shape)))
+
+         n_base = len(self.event_shape) + len(self.batch_shape)
+         n_named = len(self.names)
+         n_sample = len(pos_value.shape) - n_base - n_named
+
+         base_dims = dims[len(dims) - n_base :]
+         named_dims = dims[:n_named]
+         sample_dims = dims[n_named : n_named + n_sample]
+
+         pos_tensor_r = torch.permute(pos_value, sample_dims + named_dims + base_dims)
+
+         return pos_tensor_r
+
+     @property
+     def has_rsample(self):
+         return self.base_dist.has_rsample
+
+     @property
+     def batch_shape(self):
+         return self.base_dist.batch_shape[len(self.names) :]
+
+     @property
+     def event_shape(self):
+         return self.base_dist.event_shape
+
+     @property
+     def has_enumerate_support(self):
+         return self.base_dist.has_enumerate_support
+
+     @property
+     def arg_constraints(self):
+         return self.base_dist.arg_constraints
+
+     @property
+     def support(self):
+         return self.base_dist.support
+
+     def __repr__(self):
+         return f"NamedDistribution({self.base_dist}, {self.names})"
+
+     def sample(self, sample_shape=torch.Size()):
+         t = self._to_named(
+             self.base_dist.sample(sample_shape), offset=len(sample_shape)
+         )
+         assert set(sizesof(t).keys()) == set(self.names)
+         assert t.shape == self.shape() + sample_shape
+         return t
+
+     def rsample(self, sample_shape=torch.Size()):
+         return self._to_named(
+             self.base_dist.rsample(sample_shape), offset=len(sample_shape)
+         )
+
+     def log_prob(self, value):
+         v1 = self._from_named(value)
+         v2 = self.base_dist.log_prob(v1)
+         v3 = self._to_named(v2)
+         return v3
+
+     def enumerate_support(self, expand=True):
+         return self._to_named(self.base_dist.enumerate_support(expand))
+
+
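NamedDistribution is the inverse direction: it hides leftmost batch dimensions behind names. A sketch under the same index-construction assumption as above:

    import pyro.distributions as dist
    import torch

    from effectful.handlers.torch import sizesof
    from effectful.ops.syntax import defop

    i = defop(int, name="i")              # assumed index constructor

    # Name the single batch dimension of a Normal with batch shape (3,).
    nd = NamedDistribution(dist.Normal(torch.zeros(3), torch.ones(3)), [i])
    assert nd.batch_shape == torch.Size([])

    x = nd.sample()                       # positionally scalar
    assert set(sizesof(x)) == {i}         # but indexed by `i`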
+ def pyro_module_shim(
+     module: type[pyro.nn.module.PyroModule],
+ ) -> type[pyro.nn.module.PyroModule]:
+     """Wrap a :class:`PyroModule` in a :class:`PyroShim`.
+
+     Returns a new subclass of :class:`PyroModule` that wraps calls to
+     :func:`forward` in a :class:`PyroShim`.
+
+     **Example usage**:
+
+     .. code-block:: python
+
+         class SimpleModel(PyroModule):
+             def forward(self):
+                 return pyro.sample("y", dist.Normal(0, 1))
+
+         SimpleModelShim = pyro_module_shim(SimpleModel)
+
+     """
+
+     class PyroModuleShim(module):  # type: ignore
+         def forward(self, *args, **kwargs):
+             with PyroShim():
+                 return super().forward(*args, **kwargs)
+
+     return PyroModuleShim
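A short usage sketch for the wrapper defined above (not part of the packaged file): instances of the returned class behave like the original module, but every forward call runs under PyroShim, so sample statements inside it can be intercepted with effectful handlers such as the log_sample example in the PyroShim docstring:

    import pyro
    import pyro.distributions as dist
    from pyro.nn.module import PyroModule

    class SimpleModel(PyroModule):
        def forward(self):
            return pyro.sample("y", dist.Normal(0.0, 1.0))

    SimpleModelShim = pyro_module_shim(SimpleModel)
    y = SimpleModelShim()()  # same as running SimpleModel()() under PyroShim()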