effectful 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
effectful/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,320 @@
1
+ import functools
2
+ import operator
3
+ from typing import Any, Dict, Iterable, Optional, Set, TypeVar, Union
4
+
5
+ import torch
6
+
7
+ from effectful.handlers.torch import Indexable, sizesof
8
+ from effectful.ops.syntax import deffn, defop
9
+ from effectful.ops.types import Operation, Term
10
+
11
# Generic type variables for annotations in this module.
# NOTE(review): neither is referenced by the definitions visible in this file — confirm before removing.
K = TypeVar("K")
T = TypeVar("T")
13
+
14
+
15
class IndexSet(Dict[str, Set[int]]):
    """
    The support of an indexed value, keyed by free index variable.

    Free variables correspond to single interventions, and indices to the
    worlds where that intervention either did or did not happen.

    Conceptually an :class:`IndexSet` generalizes :class:`torch.Size`
    from multidimensional arrays to arbitrary values, from positional to
    named dimensions, and from bounded integer interval supports to finite
    sets of positive integers.

    Concretely it is a :class:`dict` whose keys are the :class:`str` names of
    free index variables and whose values are :class:`set` s of
    :class:`int` s giving the values of each variable at which the indexed
    value is defined::

        >>> IndexSet(x={0, 1}, y={2, 3})
        IndexSet({'x': {0, 1}, 'y': {2, 3}})

    The constructor converts each input to a :class:`set` (a bare
    :class:`int` becomes a singleton) and drops entries that end up empty::

        >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
        IndexSet({'x': {0, 1}, 'z': {2}})

    Instances are hashable and can be used as keys in :class:`dict` s::

        >>> indexset = IndexSet(x={0, 1}, y={2, 3})
        >>> indexset in {indexset: 1}
        True
    """

    def __init__(self, **mapping: Union[int, Iterable[int]]):
        # Normalize every value to a set first, then keep only non-empty ones.
        normalized = (
            (name, {spec} if isinstance(spec, int) else set(spec))
            for name, spec in mapping.items()
        )
        super().__init__(**{name: members for name, members in normalized if members})

    def __repr__(self):
        class_name = type(self).__name__
        contents = super().__repr__()
        return f"{class_name}({contents})"

    def __hash__(self):
        # Order-insensitive hash over (name, frozen index set) pairs, since
        # dict equality ignores insertion order as well.
        frozen_entries = frozenset(
            (name, frozenset(members)) for name, members in self.items()
        )
        return hash(frozen_entries)

    def _to_handler(self):
        """Build a mapping from each index variable's symbol to a zero-argument
        callable returning a tensor of that variable's index values.

        """
        handler = {}
        for name, members in self.items():
            sym = name_to_sym(name)
            index_tensor = torch.tensor(list(members))
            # partial binds the tensor now, avoiding late-binding closure issues.
            handler[sym] = functools.partial(lambda v: v, index_tensor)
        return handler
73
+
74
+
75
def union(*indexsets: IndexSet) -> IndexSet:
    """
    Compute the union of multiple :class:`IndexSet` s
    as the union of their keys and of value sets at shared keys.

    If :class:`IndexSet` may be viewed as a generalization of :class:`torch.Size`,
    then :func:`union` is a generalization of :func:`torch.broadcast_shapes`
    for the more abstract :class:`IndexSet` data structure.

    Example::

        >>> s = union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
        >>> s["a"]
        {0, 1, 2}
        >>> s["b"]
        {1}

    Calling :func:`union` with no arguments returns the empty
    :class:`IndexSet`, which is the identity element of the operation::

        >>> union() == IndexSet()
        True

    .. note::

        :func:`union` satisfies several algebraic equations for arbitrary inputs.
        In particular, it is associative, commutative, idempotent and absorbing::

            union(a, union(b, c)) == union(union(a, b), c)
            union(a, b) == union(b, a)
            union(a, a) == a
            union(a, union(a, b)) == union(a, b)

    :param indexsets: Zero or more :class:`IndexSet` s to merge.
    :return: A new :class:`IndexSet`; the inputs are not modified.
    """
    # Accumulate into fresh sets so the inputs are never mutated. Unlike
    # ``set.union(*...)`` this also works when ``indexsets`` is empty.
    merged: Dict[str, Set[int]] = {}
    for indexset in indexsets:
        for name, indices in indexset.items():
            merged.setdefault(name, set()).update(indices)
    return IndexSet(**merged)
108
+
109
+
110
def indices_of(value: Any) -> IndexSet:
    """
    Get a :class:`IndexSet` of indices on which an indexed value is supported.
    :func:`indices_of` is useful in conjunction with :class:`MultiWorldCounterfactual`
    for identifying the worlds where an intervention happened upstream of a value.

    For example, in a model with an outcome variable ``Y`` and a treatment variable
    ``T`` that has been intervened on, ``T`` and ``Y`` are both indexed by ``"T_ax"``::

        >>> def example():
        ...     with MultiWorldCounterfactual():
        ...         X = pyro.sample("X", get_X_dist())
        ...         T = pyro.sample("T", get_T_dist(X))
        ...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
        ...         Y = pyro.sample("Y", get_Y_dist(X, T))
        ...         assert indices_of(X) == IndexSet()
        ...         assert indices_of(T) == IndexSet(T_ax={0, 1})
        ...         assert indices_of(Y) == IndexSet(T_ax={0, 1})
        >>> example() # doctest: +SKIP

    Just as multidimensional arrays can be expanded to shapes with new dimensions
    over which they are constant, :func:`indices_of` is defined extensionally,
    meaning that values are treated as constant functions of free variables
    not in their support.

    Dispatch is by type:

    * a :class:`Term` yields, for every variable reported by ``sizesof``,
      the full range ``{0, ..., size - 1}`` keyed by the variable's name;
    * a :class:`torch.distributions.Distribution` yields the indices of one
      sample drawn from it;
    * anything else yields the empty :class:`IndexSet`.

    .. note::

        Fully general versions of :func:`indices_of` , :func:`gather`
        and :func:`scatter` would require a dependent broadcasting semantics
        for indexed values, as is the case in sparse or masked array libraries
        like ``torch.sparse`` or relational databases.

        However, this is beyond the scope of this library as it currently exists.
        Instead, :func:`gather` currently binds free variables in its input indices
        when their indices there are a strict subset of the corresponding indices
        in ``value`` , so that they no longer appear as free in the result.

        For example, in the above snippet, applying :func:`gather` to select only
        the values of ``Y`` from worlds where no intervention on ``T`` happened
        would result in a value that no longer contains free variable ``"T_ax"``::

            >>> indices_of(Y) == IndexSet(T_ax={0, 1}) # doctest: +SKIP
            True
            >>> Y0 = gather(Y, IndexSet(T_ax={0})) # doctest: +SKIP
            >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0}) # doctest: +SKIP
            True

        The practical implications of this imprecision are limited
        since we rarely need to :func:`gather` along a variable twice.

    :param value: A value.
    :return: A :class:`IndexSet` containing the indices on which the value is supported.
    """
    if isinstance(value, Term):
        support = {}
        for var, size in sizesof(value).items():
            support[var.__name__] = set(range(size))  # type:ignore
        return IndexSet(**support)

    if isinstance(value, torch.distributions.Distribution):
        # A distribution is indexed the same way as the values it produces.
        return indices_of(value.sample())

    return IndexSet()
180
+
181
+
182
@functools.cache
def name_to_sym(name: str) -> Operation[[], int]:
    """Return the integer-valued operation that stands for index variable ``name``.

    Results are memoized without bound, so repeated calls with the same
    ``name`` return the same operation object rather than a fresh one.
    """
    return defop(int, name=name)
185
+
186
+
187
def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
    """
    Selects entries from an indexed value at the indices in a :class:`IndexSet` .
    :func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
    for selecting components of a value corresponding to specific counterfactual worlds.

    For example, in a model with an outcome variable ``Y`` and a treatment variable
    ``T`` that has been intervened on, we can use :func:`gather` to define quantities
    like treatment effects that require comparison of different potential outcomes::

        >>> def example():
        ...     with MultiWorldCounterfactual():
        ...         X = pyro.sample("X", get_X_dist())
        ...         T = pyro.sample("T", get_T_dist(X))
        ...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
        ...         Y = pyro.sample("Y", get_Y_dist(X, T))
        ...         Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
        ...         Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
        ...         treatment_effect = Y_counterfactual - Y_factual
        >>> example() # doctest: +SKIP

    Like :func:`torch.gather` and substitution in term rewriting,
    :func:`gather` is defined extensionally, meaning that values
    are treated as constant functions of variables not in their support.

    :func:`gather` will accordingly ignore variables in ``indexset``
    that are not among the variables reported by ``sizesof(value)`` .

    .. note::

        Fully general versions of :func:`indices_of` , :func:`gather`
        and :func:`scatter` would require a dependent broadcasting semantics
        for indexed values, as is the case in sparse or masked array libraries
        like ``scipy.sparse`` or ``xarray`` or in relational databases.

        However, this is beyond the scope of this library as it currently exists.
        Instead, :func:`gather` currently binds free variables in ``indexset``
        when their indices there are a strict subset of the corresponding indices
        in ``value`` , so that they no longer appear as free in the result.

        For example, in the above snippet, applying :func:`gather` to select only
        the values of ``Y`` from worlds where no intervention on ``T`` happened
        would result in a value that no longer contains free variable ``"T_ax"``::

            >>> indices_of(Y) == IndexSet(T_ax={0, 1}) # doctest: +SKIP
            True
            >>> Y0 = gather(Y, IndexSet(T_ax={0})) # doctest: +SKIP
            >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0}) # doctest: +SKIP
            True

        The practical implications of this imprecision are limited
        since we rarely need to :func:`gather` along a variable twice.

    :param value: The value to gather.
    :param IndexSet indexset: The :class:`IndexSet` of entries to select from ``value``.
    :param kwargs: Accepted but not used by this implementation.
    :return: A new value containing entries of ``value`` from ``indexset``.
    """
    requested = {name_to_sym(name): inds for name, inds in indexset.items()}

    # Only variables that actually index ``value`` are rebound.
    bound_vars = [var for var in sizesof(value).keys() if var in requested]

    # Each bound variable is replaced by a tensor of the selected indices,
    # itself indexed by that same variable.
    replacements = [
        Indexable(torch.tensor(list(requested[var])))[var()] for var in bound_vars
    ]

    # Abstract ``value`` over the bound variables, then apply it to the replacements.
    return deffn(value, *bound_vars)(*replacements)
259
+
260
+
261
def stack(
    values: Union[tuple[torch.Tensor, ...], list[torch.Tensor]], name: str, **kwargs
) -> torch.Tensor:
    """Stack a sequence of indexed values along a new leading dimension.

    The values are combined with :func:`torch.stack`, and the resulting new
    dimension is indexed by the index variable called ``name``. The values
    in the stack must have identical shapes.

    :param values: The values to stack.
    :param name: Name of the index variable for the new dimension.
    :param kwargs: Accepted but not used by this implementation.
    :return: The stacked value, indexed along its first dimension by ``name``.
    """
    stacked = Indexable(torch.stack(values))
    dim_var = name_to_sym(name)
    return stacked[dim_var()]
270
+
271
+
272
def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
    """
    Selection operation that is the sum-type analogue of :func:`scatter`
    in the sense that where :func:`scatter` propagates both of its arguments,
    :func:`cond` propagates only one, depending on the value of a boolean ``case_`` .

    For a given ``fst`` , ``snd`` , and ``case_`` , :func:`cond` returns
    ``snd`` if ``case_`` is true, and ``fst`` otherwise,
    analogous to a Python conditional expression ``snd if case_ else fst`` .
    Unlike a Python conditional expression, however, the case may be a tensor,
    and both branches are evaluated, as with :func:`torch.where` ::

        >>> cond(torch.zeros(2), torch.ones(2), torch.tensor(True))
        tensor([1., 1.])

    ``case_`` is reshaped with extra trailing singleton dimensions, one per
    dimension of the lower-rank branch, so that it broadcasts against the
    branches' dimensions.

    :param fst: The value to return if ``case_`` is ``False`` .
    :param snd: The value to return if ``case_`` is ``True`` .
    :param case_: A boolean tensor. Should have event shape ``()`` .
    :return: The elementwise selection between ``snd`` and ``fst`` .
    """
    trailing_dims = min(len(snd.shape), len(fst.shape))
    mask_shape = case_.shape + (1,) * trailing_dims
    mask = case_.reshape(mask_shape)
    return torch.where(mask, snd, fst)
305
+
306
+
307
def cond_n(values: Dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
    """N-way generalization of :func:`cond` keyed by :class:`IndexSet` .

    Entries of ``values`` are visited in dictionary order. For each entry a
    boolean mask is built that is true wherever ``case`` equals any index in
    the key, and :func:`cond` overwrites the running result with that entry's
    value where the mask is true. The first entry's value seeds the result,
    so it is also what remains wherever no later mask matches.

    .. note::

        Only the index set of the *first* variable in each key is consulted;
        any further variables in a key are ignored. A key with no variables
        at all makes ``next`` raise :class:`StopIteration` .

    :param values: Non-empty mapping from :class:`IndexSet` keys to candidate values.
    :param case: Tensor of index values selecting among the candidates.
    :return: The value assembled from the selected candidates.
    """
    assert len(values) > 0
    assert all(isinstance(k, IndexSet) for k in values.keys())
    result: Optional[torch.Tensor] = None
    for indices, value in values.items():
        first_var_indices = next(iter(indices.values()))
        matches = [case == index for index in first_var_indices]
        tst = torch.as_tensor(functools.reduce(operator.or_, matches), dtype=torch.bool)
        previous = value if result is None else result
        result = cond(previous, value, tst)
    assert result is not None
    return result
@@ -0,0 +1,259 @@
1
+ import numbers
2
+ import operator
3
+ from typing import Any, TypeVar
4
+
5
+ from typing_extensions import ParamSpec
6
+
7
+ from effectful.ops.syntax import NoDefaultRule, defdata, defop, syntactic_eq
8
+ from effectful.ops.types import Operation, Term
9
+
10
# Generic parameter-spec and type variables for annotations in this module.
# NOTE(review): P, Q, S, T and V are not referenced by the definitions visible
# in this file — confirm before removing.
P = ParamSpec("P")
Q = ParamSpec("Q")
S = TypeVar("S")
T = TypeVar("T")
V = TypeVar("V")

# Type variable bounded by numbers.Number; annotates the operands and results
# of the numeric operations defined below.
T_Number = TypeVar("T_Number", bound=numbers.Number)
17
+
18
+
19
@defdata.register(numbers.Number)
@numbers.Number.register
class _NumberTerm(Term[numbers.Number]):
    """Term whose value is a :class:`numbers.Number` .

    Stores the operation together with the positional and keyword arguments
    it was applied to, and exposes them read-only. The class is registered
    as a virtual subclass of :class:`numbers.Number` and with ``defdata``
    for that type.
    """

    def __init__(
        self, op: Operation[..., numbers.Number], args: tuple, kwargs: dict
    ) -> None:
        self._op, self._args, self._kwargs = op, args, kwargs

    @property
    def op(self) -> Operation[..., numbers.Number]:
        """The operation at the head of this term."""
        return self._op

    @property
    def args(self) -> tuple:
        """Positional arguments of the operation."""
        return self._args

    @property
    def kwargs(self) -> dict:
        """Keyword arguments of the operation."""
        return self._kwargs

    def __hash__(self):
        # Hash structurally on (operation, positional args, keyword items).
        structure = (self.op, tuple(self.args), tuple(self.kwargs.items()))
        return hash(structure)
43
+
44
+
45
+ # Complex specific methods
46
@defop
def eq(x: T_Number, y: T_Number) -> bool:
    """Equality operation on numbers.

    If either operand is a :class:`Term` the comparison is delegated to
    ``syntactic_eq``; otherwise the operands are compared with
    :func:`operator.eq` .
    """
    if isinstance(x, Term) or isinstance(y, Term):
        return syntactic_eq(x, y)
    return operator.eq(x, y)
52
+
53
+
54
def _wrap_cmp(op):
    """Wrap a binary comparison so that it only evaluates on non-term operands.

    The returned function calls ``op(x, y)`` when neither operand is a
    :class:`Term` and raises ``NoDefaultRule`` otherwise. It takes over
    ``op.__name__`` .
    """

    def _wrapped_op(x: T_Number, y: T_Number) -> bool:
        if isinstance(x, Term) or isinstance(y, Term):
            raise NoDefaultRule
        return op(x, y)

    _wrapped_op.__name__ = op.__name__
    return _wrapped_op
63
+
64
+
65
def _wrap_binop(op):
    """Wrap a binary numeric operator so that it only evaluates on non-term operands.

    The returned function calls ``op(x, y)`` when neither operand is a
    :class:`Term` and raises ``NoDefaultRule`` otherwise. It takes over
    ``op.__name__`` .
    """

    def _wrapped_op(x: T_Number, y: T_Number) -> T_Number:
        if isinstance(x, Term) or isinstance(y, Term):
            raise NoDefaultRule
        return op(x, y)

    _wrapped_op.__name__ = op.__name__
    return _wrapped_op
74
+
75
+
76
def _wrap_unop(op):
    """Wrap a unary numeric operator so that it only evaluates on a non-term operand.

    The returned function calls ``op(x)`` when ``x`` is not a :class:`Term`
    and raises ``NoDefaultRule`` otherwise. It takes over ``op.__name__`` .
    """

    def _wrapped_op(x: T_Number) -> T_Number:
        if isinstance(x, Term):
            raise NoDefaultRule
        return op(x)

    _wrapped_op.__name__ = op.__name__
    return _wrapped_op
85
+
86
+
87
# Arithmetic operations used by the term classes' operator methods. Each one
# wraps the matching function from the ``operator`` module so that it
# evaluates directly on non-term operands and raises NoDefaultRule on terms.
# NOTE: ``pow`` and ``abs`` intentionally shadow the builtins of the same
# name within this module.
add = defop(_wrap_binop(operator.add))
neg = defop(_wrap_unop(operator.neg))
pos = defop(_wrap_unop(operator.pos))
sub = defop(_wrap_binop(operator.sub))
mul = defop(_wrap_binop(operator.mul))
truediv = defop(_wrap_binop(operator.truediv))
pow = defop(_wrap_binop(operator.pow))
abs = defop(_wrap_unop(operator.abs))
95
+
96
+
97
@defdata.register(numbers.Complex)
@numbers.Complex.register
class _ComplexTerm(_NumberTerm, Term[numbers.Complex]):
    """Term whose value is a :class:`numbers.Complex` .

    Arithmetic operators build new terms by calling the module-level
    operations (``add``, ``sub``, ``mul``, ``truediv``, ``pow``, ``neg``,
    ``pos``, ``abs``), and ``==`` is routed through ``eq``. Truth-testing a
    term raises :class:`ValueError` .
    """

    def __bool__(self) -> bool:
        raise ValueError("Cannot convert term to bool")

    def __add__(self, other: Any) -> numbers.Real:
        return add(self, other)

    def __radd__(self, other: Any) -> numbers.Real:
        return add(other, self)

    def __neg__(self):
        return neg(self)

    def __pos__(self):
        return pos(self)

    def __sub__(self, other: Any) -> numbers.Real:
        return sub(self, other)

    def __rsub__(self, other: Any) -> numbers.Real:
        return sub(other, self)

    def __mul__(self, other: Any) -> numbers.Real:
        return mul(self, other)

    def __rmul__(self, other: Any) -> numbers.Real:
        return mul(other, self)

    def __truediv__(self, other: Any) -> numbers.Real:
        return truediv(self, other)

    def __rtruediv__(self, other: Any) -> numbers.Real:
        return truediv(other, self)

    def __pow__(self, other: Any) -> numbers.Real:
        return pow(self, other)

    def __rpow__(self, other: Any) -> numbers.Real:
        return pow(other, self)

    def __abs__(self) -> numbers.Real:
        return abs(self)

    def __eq__(self, other: Any) -> bool:
        return eq(self, other)

    # A class that defines __eq__ without __hash__ gets __hash__ = None, which
    # would make this class and all of its subclasses unhashable. Restate the
    # structural hash inherited from _NumberTerm to keep terms hashable.
    __hash__ = _NumberTerm.__hash__
144
+
145
+
146
# Real specific methods
# Floor division, modulo and the four ordering comparisons, each wrapped so
# that it evaluates directly on non-term operands and raises NoDefaultRule
# on terms.
floordiv = defop(_wrap_binop(operator.floordiv))
mod = defop(_wrap_binop(operator.mod))
lt = defop(_wrap_cmp(operator.lt))
le = defop(_wrap_cmp(operator.le))
gt = defop(_wrap_cmp(operator.gt))
ge = defop(_wrap_cmp(operator.ge))
153
+
154
+
155
@defdata.register(numbers.Real)
@numbers.Real.register
class _RealTerm(_ComplexTerm, Term[numbers.Real]):
    """Term whose value is a :class:`numbers.Real` .

    Adds floor division, modulo and ordering comparisons on top of the
    complex-level arithmetic; each builds a new term through the matching
    module-level operation. Conversion to :class:`float` raises
    :class:`ValueError`, and the rounding protocols are not implemented.
    """

    # Real specific methods
    def __float__(self) -> float:
        raise ValueError("Cannot convert term to float")

    def __trunc__(self) -> numbers.Integral:
        raise NotImplementedError

    def __floor__(self) -> numbers.Integral:
        raise NotImplementedError

    def __ceil__(self) -> numbers.Integral:
        raise NotImplementedError

    def __round__(self, ndigits=None) -> numbers.Integral:
        raise NotImplementedError

    def __floordiv__(self, other):
        return floordiv(self, other)

    def __rfloordiv__(self, other):
        return floordiv(other, self)

    def __mod__(self, other):
        return mod(self, other)

    def __rmod__(self, other):
        return mod(other, self)

    def __lt__(self, other):
        return lt(self, other)

    def __le__(self, other):
        return le(self, other)

    # __gt__/__ge__ are needed both for ``term > x`` and as the reflected
    # methods Python falls back to for ``x < term`` / ``x <= term``; without
    # them those expressions raise TypeError.
    def __gt__(self, other):
        return gt(self, other)

    def __ge__(self, other):
        return ge(self, other)
191
+
192
+
193
@defdata.register(numbers.Rational)
@numbers.Rational.register
class _RationalTerm(_RealTerm, Term[numbers.Rational]):
    """Term whose value is a :class:`numbers.Rational` .

    The numerator and denominator of a term are not available; accessing
    either property raises :class:`NotImplementedError` .
    """

    @property
    def numerator(self):
        """Not available for terms; always raises :class:`NotImplementedError` ."""
        raise NotImplementedError()

    @property
    def denominator(self):
        """Not available for terms; always raises :class:`NotImplementedError` ."""
        raise NotImplementedError()
203
+
204
+
205
# Integral specific methods
# Index conversion, shifts and bitwise operations, each wrapped so that it
# evaluates directly on non-term operands and raises NoDefaultRule on terms.
index = defop(_wrap_unop(operator.index))
lshift = defop(_wrap_binop(operator.lshift))
rshift = defop(_wrap_binop(operator.rshift))
and_ = defop(_wrap_binop(operator.and_))
xor = defop(_wrap_binop(operator.xor))
or_ = defop(_wrap_binop(operator.or_))
invert = defop(_wrap_unop(operator.invert))
213
+
214
+
215
@defdata.register(numbers.Integral)
@numbers.Integral.register
class _IntegralTerm(_RationalTerm, Term[numbers.Integral]):
    """Term whose value is a :class:`numbers.Integral` .

    Adds shifts and bitwise operators on top of the real-level arithmetic;
    each builds a new term through the matching module-level operation.
    Conversion to :class:`int` raises :class:`ValueError` .
    """

    # Integral specific methods
    def __int__(self) -> int:
        raise ValueError("Cannot convert term to int")

    def __index__(self) -> numbers.Integral:
        return index(self)

    def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral:
        """Build the term ``self ** exponent`` .

        :raises NotImplementedError: if ``modulus`` is given. The ``pow``
            operation takes two operands, so a three-argument power cannot
            be represented and must not be silently reduced to a
            two-argument one.
        """
        if modulus is not None:
            raise NotImplementedError(
                "Three-argument pow with a modulus is not supported for terms"
            )
        return pow(self, exponent)

    def __lshift__(self, other):
        return lshift(self, other)

    def __rlshift__(self, other):
        return lshift(other, self)

    def __rshift__(self, other):
        return rshift(self, other)

    def __rrshift__(self, other):
        return rshift(other, self)

    def __and__(self, other):
        return and_(self, other)

    def __rand__(self, other):
        return and_(other, self)

    def __xor__(self, other):
        return xor(self, other)

    def __rxor__(self, other):
        return xor(other, self)

    def __or__(self, other):
        return or_(self, other)

    def __ror__(self, other):
        return or_(other, self)

    def __invert__(self):
        return invert(self)