effectful-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/__init__.py +0 -0
- effectful/handlers/__init__.py +0 -0
- effectful/handlers/indexed.py +320 -0
- effectful/handlers/numbers.py +259 -0
- effectful/handlers/pyro.py +466 -0
- effectful/handlers/torch.py +572 -0
- effectful/internals/__init__.py +0 -0
- effectful/internals/base_impl.py +259 -0
- effectful/internals/runtime.py +78 -0
- effectful/ops/__init__.py +0 -0
- effectful/ops/semantics.py +329 -0
- effectful/ops/syntax.py +523 -0
- effectful/ops/types.py +110 -0
- effectful/py.typed +0 -0
- effectful-0.0.1.dist-info/LICENSE.md +202 -0
- effectful-0.0.1.dist-info/METADATA +170 -0
- effectful-0.0.1.dist-info/RECORD +19 -0
- effectful-0.0.1.dist-info/WHEEL +5 -0
- effectful-0.0.1.dist-info/top_level.txt +1 -0
effectful/handlers/pyro.py
@@ -0,0 +1,466 @@
import typing
import warnings
from typing import Any, Collection, List, Mapping, Optional, Tuple

try:
    import pyro
except ImportError:
    raise ImportError("Pyro is required to use effectful.handlers.pyro.")

try:
    import torch
except ImportError:
    raise ImportError("PyTorch is required to use effectful.handlers.pyro.")

from typing_extensions import ParamSpec

from effectful.handlers.torch import Indexable, sizesof, to_tensor
from effectful.ops.syntax import defop
from effectful.ops.types import Operation

P = ParamSpec("P")


@defop
def pyro_sample(
    name: str,
    fn: pyro.distributions.torch_distribution.TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    mask: Optional[torch.BoolTensor] = None,
    infer: Optional[pyro.poutine.runtime.InferDict] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Operation to sample from a Pyro distribution. See :func:`pyro.sample`.
    """
    with pyro.poutine.mask(mask=mask if mask is not None else True):
        return pyro.sample(
            name, fn, *args, obs=obs, obs_mask=obs_mask, infer=infer, **kwargs
        )

class PyroShim(pyro.poutine.messenger.Messenger):
    """Pyro handler that wraps all sample sites in a custom effectful type.

    .. note::

        This handler should be installed around any Pyro model that you want to
        use effectful handlers with.

    **Example usage**:

    >>> import pyro.distributions as dist
    >>> from effectful.ops.semantics import fwd, handler
    >>> torch.distributions.Distribution.set_default_validate_args(False)

    It can be used as a decorator:

    >>> @PyroShim()
    ... def model():
    ...     return pyro.sample("x", dist.Normal(0, 1))

    It can also be used as a context manager:

    >>> with PyroShim():
    ...     x = pyro.sample("x", dist.Normal(0, 1))

    When :class:`PyroShim` is installed, all sample sites perform the
    :func:`pyro_sample` effect, which can be handled by an effectful
    interpretation.

    >>> def log_sample(name, *args, **kwargs):
    ...     print(f"Sampled {name}")
    ...     return fwd()

    >>> with PyroShim(), handler({pyro_sample: log_sample}):
    ...     x = pyro.sample("x", dist.Normal(0, 1))
    ...     y = pyro.sample("y", dist.Normal(0, 1))
    Sampled x
    Sampled y
    """

    _current_site: Optional[str]

    def __enter__(self):
        if any(isinstance(m, PyroShim) for m in pyro.poutine.runtime._PYRO_STACK):
            warnings.warn("PyroShim should be installed at most once.")
        return super().__enter__()

    @staticmethod
    def _broadcast_to_named(
        t: torch.Tensor, shape: torch.Size, indices: Mapping[Operation[[], int], int]
    ) -> Tuple[torch.Tensor, "Naming"]:
        """Convert a tensor `t` to a fully positional tensor that is
        broadcastable with the positional representation of tensors of shape
        |shape|, |indices|.

        """
        t_indices = sizesof(t)

        if len(t.shape) < len(shape):
            t = t.expand(shape)

        # create a positional dimension for every named index in the target shape
        name_to_dim = {}
        for i, (k, v) in enumerate(reversed(list(indices.items()))):
            if k in t_indices:
                t = to_tensor(t, [k])
            else:
                t = t.expand((v,) + t.shape)
            name_to_dim[k] = -len(shape) - i - 1

        # create a positional dimension for every remaining named index in `t`
        n_batch_and_dist_named = len(t.shape)
        for i, k in enumerate(reversed(list(sizesof(t).keys()))):
            t = to_tensor(t, [k])
            name_to_dim[k] = -n_batch_and_dist_named - i - 1

        return t, Naming(name_to_dim)

    def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
        if typing.TYPE_CHECKING:
            assert msg["type"] == "sample"
            assert msg["name"] is not None
            assert msg["infer"] is not None
            assert isinstance(
                msg["fn"], pyro.distributions.torch_distribution.TorchDistributionMixin
            )

        if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
            msg
        ):
            return

        if getattr(self, "_current_site", None) == msg["name"]:
            if "_markov_scope" in msg["infer"] and self._current_site:
                msg["infer"]["_markov_scope"].pop(self._current_site, None)

            dist = msg["fn"]
            obs = msg["value"] if msg["is_observed"] else None

            # pdist shape: | named1 | batch_shape | event_shape |
            # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1
            pdist = PositionalDistribution(dist)
            naming = pdist.naming

            if msg["mask"] is None:
                mask = torch.tensor(True)
            elif isinstance(msg["mask"], bool):
                mask = torch.tensor(msg["mask"])
            else:
                mask = msg["mask"]

            pos_mask, _ = PyroShim._broadcast_to_named(
                mask, dist.batch_shape, pdist.indices
            )

            pos_obs: Optional[torch.Tensor] = None
            if obs is not None:
                pos_obs, naming = PyroShim._broadcast_to_named(
                    obs, dist.shape(), pdist.indices
                )

            for var, dim in naming.name_to_dim.items():
                frame = pyro.poutine.indep_messenger.CondIndepStackFrame(
                    name=str(var), dim=dim, size=-1, counter=0
                )
                msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]

            msg["fn"] = pdist
            msg["value"] = pos_obs
            msg["mask"] = pos_mask
            msg["infer"]["_index_naming"] = naming  # type: ignore

            assert sizesof(msg["value"]) == {}
            assert sizesof(msg["mask"]) == {}

            return

        try:
            self._current_site = msg["name"]
            msg["value"] = pyro_sample(
                msg["name"],
                msg["fn"],
                obs=msg["value"] if msg["is_observed"] else None,
                infer=msg["infer"].copy(),
            )
        finally:
            self._current_site = None

        # flags to guarantee commutativity of condition, intervene, trace
        msg["stop"] = True
        msg["done"] = True
        msg["mask"] = False
        msg["is_observed"] = True
        msg["infer"]["is_auxiliary"] = True
        msg["infer"]["_do_not_trace"] = True

    def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
        infer = msg.get("infer")
        if infer is None or "_index_naming" not in infer:
            return

        # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
        naming = infer["_index_naming"]  # type: ignore

        value = msg["value"]

        if value is not None:
            # note: is it safe to assume that msg['fn'] is a distribution?
            assert isinstance(
                msg["fn"], pyro.distributions.torch_distribution.TorchDistribution
            )
            dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape
            if len(value.shape) < len(dist_shape):
                value = value.broadcast_to(
                    torch.broadcast_shapes(value.shape, dist_shape)
                )
            value = naming.apply(value)
            msg["value"] = value

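Illustrative sketch, not part of the packaged file: a pyro_sample handler that overrides the value of one site instead of just logging it. Because PyroShim assigns whatever the handler returns to the site's value, returning a tensor from the handler fixes that site; the helper name fix_x below is hypothetical.

# Sketch only: fix the value of site "x"; defer every other site to the
# default sampling behavior via fwd().
import pyro
import pyro.distributions as dist
import torch

from effectful.ops.semantics import fwd, handler

def fix_x(name, fn, *args, **kwargs):
    if name == "x":
        return torch.tensor(0.5)  # becomes the value of site "x"
    return fwd()                  # all other sites sample as usual

with PyroShim(), handler({pyro_sample: fix_x}):
    x = pyro.sample("x", dist.Normal(0.0, 1.0))  # tensor(0.5000)
    y = pyro.sample("y", dist.Normal(x, 1.0))    # drawn from Normal(0.5, 1)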
class Naming:
    """
    A mapping from dimensions (indexed from the right) to names.
    """

    def __init__(self, name_to_dim: Mapping[Operation[[], int], int]):
        assert all(v < 0 for v in name_to_dim.values())
        self.name_to_dim = name_to_dim

    @staticmethod
    def from_shape(names: Collection[Operation[[], int]], event_dims: int) -> "Naming":
        """Create a naming from a set of indices and the number of event dimensions.

        The resulting naming converts tensors of shape
        ``| batch_shape | named | event_shape |``
        to tensors of shape ``| batch_shape | event_shape |, | named |``.

        """
        assert event_dims >= 0
        return Naming({n: -event_dims - len(names) + i for i, n in enumerate(names)})

    def apply(self, value: torch.Tensor) -> torch.Tensor:
        indexes: List[Any] = [slice(None)] * (len(value.shape))
        for n, d in self.name_to_dim.items():
            indexes[len(value.shape) + d] = n()
        return Indexable(value)[tuple(indexes)]

    def __repr__(self):
        return f"Naming({self.name_to_dim})"

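Illustrative sketch, not part of the packaged file: a small worked example of the dimension arithmetic in Naming.from_shape and Naming.apply. It assumes that defop(int) creates a fresh Operation[[], int] index variable of the kind used throughout this module; that call is an assumption about the effectful API rather than something shown in this file.

# Sketch only: two named indices sitting just left of one event dimension.
import torch

from effectful.handlers.torch import sizesof
from effectful.ops.syntax import defop

i = defop(int)  # assumed API: fresh Operation[[], int] index variable
j = defop(int)

naming = Naming.from_shape([i, j], event_dims=1)
# name_to_dim == {i: -3, j: -2}, counted from the right.

t = torch.randn(4, 2, 3, 5)  # | batch 4 | named 2, 3 | event 5 |
named_t = naming.apply(t)    # positional shape (4, 5); i and j become free indices
# sizesof(named_t) == {i: 2, j: 3}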
class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribution):
    """A distribution wrapper that lazily converts indexed dimensions to
    positional.

    """

    indices: Mapping[Operation[[], int], int]

    def __init__(
        self, base_dist: pyro.distributions.torch_distribution.TorchDistribution
    ):
        self.base_dist = base_dist
        self.indices = sizesof(base_dist.sample())

        n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
        self.naming = Naming.from_shape(self.indices.keys(), n_base)

        super().__init__()

    def _to_positional(self, value: torch.Tensor) -> torch.Tensor:
        # self.base_dist has shape: | batch_shape | event_shape | & named
        # assume value comes from base_dist with shape:
        # | sample_shape | batch_shape | event_shape | & named
        # return a tensor of shape | sample_shape | named | batch_shape | event_shape |
        n_named = len(self.indices)
        dims = list(range(n_named + len(value.shape)))

        n_base = len(self.event_shape) + len(self.base_dist.batch_shape)
        n_sample = len(value.shape) - n_base

        base_dims = dims[len(dims) - n_base :]
        named_dims = dims[:n_named]
        sample_dims = dims[n_named : n_named + n_sample]

        # shape: | named | sample_shape | batch_shape | event_shape |
        # TODO: replace with something more efficient
        pos_tensor = to_tensor(value, self.indices.keys())

        # shape: | sample_shape | named | batch_shape | event_shape |
        pos_tensor_r = torch.permute(pos_tensor, sample_dims + named_dims + base_dims)

        return pos_tensor_r

    def _from_positional(self, value: torch.Tensor) -> torch.Tensor:
        # maximal value shape: | sample_shape | named | batch_shape | event_shape |
        return self.naming.apply(value)

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def batch_shape(self):
        return (
            torch.Size([s for s in self.indices.values()]) + self.base_dist.batch_shape
        )

    @property
    def event_shape(self):
        return self.base_dist.event_shape

    @property
    def has_enumerate_support(self):
        return self.base_dist.has_enumerate_support

    @property
    def arg_constraints(self):
        return self.base_dist.arg_constraints

    @property
    def support(self):
        return self.base_dist.support

    def __repr__(self):
        return f"PositionalDistribution({self.base_dist})"

    def sample(self, sample_shape=torch.Size()):
        return self._to_positional(self.base_dist.sample(sample_shape))

    def rsample(self, sample_shape=torch.Size()):
        return self._to_positional(self.base_dist.rsample(sample_shape))

    def log_prob(self, value):
        return self._to_positional(
            self.base_dist.log_prob(self._from_positional(value))
        )

    def enumerate_support(self, expand=True):
        return self._to_positional(self.base_dist.enumerate_support(expand))


class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution):
    """A distribution wrapper that lazily names leftmost dimensions."""

    def __init__(
        self,
        base_dist: pyro.distributions.torch_distribution.TorchDistribution,
        names: Collection[Operation[[], int]],
    ):
        """
        :param base_dist: A distribution with batch dimensions.

        :param names: A list of names.

        """
        self.base_dist = base_dist
        self.names = names

        assert 1 <= len(names) <= len(base_dist.batch_shape)
        base_indices = sizesof(base_dist.sample())
        assert not any(n in base_indices for n in names)

        n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
        self.naming = Naming.from_shape(names, n_base - len(names))
        super().__init__()

    def _to_named(self, value: torch.Tensor, offset=0) -> torch.Tensor:
        return self.naming.apply(value)

    def _from_named(self, value: torch.Tensor) -> torch.Tensor:
        pos_value = to_tensor(value, self.names)

        dims = list(range(len(pos_value.shape)))

        n_base = len(self.event_shape) + len(self.batch_shape)
        n_named = len(self.names)
        n_sample = len(pos_value.shape) - n_base - n_named

        base_dims = dims[len(dims) - n_base :]
        named_dims = dims[:n_named]
        sample_dims = dims[n_named : n_named + n_sample]

        pos_tensor_r = torch.permute(pos_value, sample_dims + named_dims + base_dims)

        return pos_tensor_r

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def batch_shape(self):
        return self.base_dist.batch_shape[len(self.names) :]

    @property
    def event_shape(self):
        return self.base_dist.event_shape

    @property
    def has_enumerate_support(self):
        return self.base_dist.has_enumerate_support

    @property
    def arg_constraints(self):
        return self.base_dist.arg_constraints

    @property
    def support(self):
        return self.base_dist.support

    def __repr__(self):
        return f"NamedDistribution({self.base_dist}, {self.names})"

    def sample(self, sample_shape=torch.Size()):
        t = self._to_named(
            self.base_dist.sample(sample_shape), offset=len(sample_shape)
        )
        assert set(sizesof(t).keys()) == set(self.names)
        assert t.shape == self.shape() + sample_shape
        return t

    def rsample(self, sample_shape=torch.Size()):
        return self._to_named(
            self.base_dist.rsample(sample_shape), offset=len(sample_shape)
        )

    def log_prob(self, value):
        v1 = self._from_named(value)
        v2 = self.base_dist.log_prob(v1)
        v3 = self._to_named(v2)
        return v3

    def enumerate_support(self, expand=True):
        return self._to_named(self.base_dist.enumerate_support(expand))

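Illustrative sketch, not part of the packaged file: how the two wrappers relate. NamedDistribution turns leftmost batch dimensions into named indices, and PositionalDistribution turns free indices back into leftmost batch dimensions. As in the earlier sketch, defop(int) is assumed to create a fresh index variable.

# Sketch only: name one batch dimension, then make it positional again.
import pyro.distributions as dist
import torch

from effectful.ops.syntax import defop

plate = defop(int)  # assumed API for a fresh index variable

base = dist.Normal(torch.zeros(3, 4), 1.0).to_event(1)  # batch (3,), event (4,)
named = NamedDistribution(base, [plate])                 # batch (),   event (4,)
x = named.sample()                                       # shape (4,), free index `plate` of size 3

pos = PositionalDistribution(named)                      # batch (3,), event (4,)
# pos.sample() has shape (3, 4) and no free indices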
def pyro_module_shim(
    module: type[pyro.nn.module.PyroModule],
) -> type[pyro.nn.module.PyroModule]:
    """Wrap a :class:`PyroModule` in a :class:`PyroShim`.

    Returns a new subclass of :class:`PyroModule` that wraps calls to
    :func:`forward` in a :class:`PyroShim`.

    **Example usage**:

    .. code-block:: python

        class SimpleModel(PyroModule):
            def forward(self):
                return pyro.sample("y", dist.Normal(0, 1))

        SimpleModelShim = pyro_module_shim(SimpleModel)

    """

    class PyroModuleShim(module):  # type: ignore
        def forward(self, *args, **kwargs):
            with PyroShim():
                return super().forward(*args, **kwargs)

    return PyroModuleShim
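Illustrative sketch, not part of the packaged file: an end-to-end use of pyro_module_shim, mirroring the docstring example above. The shimmed module installs PyroShim around forward(), so a pyro_sample handler installed around the call intercepts its sites.

# Sketch only: log every sample site reached through the shimmed module.
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule

from effectful.ops.semantics import fwd, handler

class SimpleModel(PyroModule):
    def forward(self):
        return pyro.sample("y", dist.Normal(0.0, 1.0))

SimpleModelShim = pyro_module_shim(SimpleModel)

def log_sample(name, *args, **kwargs):
    print(f"Sampled {name}")
    return fwd()

with handler({pyro_sample: log_sample}):
    SimpleModelShim()()  # prints "Sampled y"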