effectful-0.1.0-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +23 -24
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +297 -168
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +101 -77
- effectful/ops/syntax.py +813 -251
- effectful/ops/types.py +121 -29
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/METADATA +59 -56
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -263
- effectful-0.1.0.dist-info/RECORD +0 -18
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/handlers/indexed.py
CHANGED
@@ -1,18 +1,16 @@
 import functools
 import operator
-from typing import …
+from collections.abc import Iterable
+from typing import Any
 
 import torch
 
-from effectful.handlers.torch import Indexable, sizesof
+from effectful.handlers.torch import sizesof
 from effectful.ops.syntax import deffn, defop
 from effectful.ops.types import Operation
 
-K = TypeVar("K")
-T = TypeVar("T")
 
-
-class IndexSet(Dict[str, Set[int]]):
+class IndexSet(dict[str, set[int]]):
     """
     :class:`IndexSet` s represent the support of an indexed value,
     for which free variables correspond to single interventions and indices
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
     for which a value is defined::
 
     >>> IndexSet(x={0, 1}, y={2, 3})
-    IndexSet({x: {0, 1}, y: {2, 3}})
+    IndexSet({'x': {0, 1}, 'y': {2, 3}})
 
     :class:`IndexSet` 's constructor will automatically drop empty entries
     and attempt to convert input values to :class:`set` s::
 
     >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
-    IndexSet({x: {0, 1}, z: {2}})
+    IndexSet({'x': {0, 1}, 'z': {2}})
 
     :class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
 
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
     True
     """
 
-    def __init__(self, **mapping: Union[int, Iterable[int]]):
+    def __init__(self, **mapping: int | Iterable[int]):
         index_set = {}
         for k, vs in mapping.items():
             indexes = {vs} if isinstance(vs, int) else set(vs)
@@ -161,12 +159,12 @@ def indices_of(value: Any) -> IndexSet:
     )
 
 
-@functools.…
-def name_to_sym(name: str) -> Operation[[], …
-    return defop(…
+@functools.cache
+def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
+    return defop(torch.Tensor, name=name)
 
 
-def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
+def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
     """
     Selects entries from an indexed value at the indices in a :class:`IndexSet` .
     :func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
@@ -230,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
     """
     indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
     binding = {
-        k: functools.partial(
-            lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
-        )
+        k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
         for k in sizesof(value).keys()
         if k in indexset_vars
     }
@@ -241,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
 
 
 def stack(
-    values: …
+    values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
 ) -> torch.Tensor:
     """Stack a sequence of indexed values, creating a new dimension. The new
     dimension is indexed by `dim`. The indexed values in the stack must have
     identical shapes.
 
     """
-    …
+    values = torch.distributions.utils.broadcast_all(*values)
+    return torch.stack(values)[name_to_sym(name)()]
 
 
 def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
@@ -263,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
     Unlike a Python conditional expression, however, the case may be a tensor,
     and both branches are evaluated, as with :func:`torch.where` ::
 
-    >>> from effectful.…
-    >>> …
-    …
+    >>> from effectful.ops.syntax import defop
+    >>> from effectful.handlers.torch import bind_dims
+
+    >>> b = defop(torch.Tensor, name="b")
+    >>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
     >>> case = (fst < snd).all(-1)
     >>> x = cond(fst, snd, case)
-    >>> assert (…
+    >>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
 
     .. note::
 
@@ -286,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
     )
 
 
-def cond_n(values: Dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
+def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
     assert len(values) > 0
     assert all(isinstance(k, IndexSet) for k in values.keys())
-    result: Optional[torch.Tensor] = None
+    result: torch.Tensor | None = None
     for indices, value in values.items():
         tst = torch.as_tensor(
             functools.reduce(
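The hunks above are mostly a typing-modernization pass (Dict/Set/Union/Optional become builtin generics and X | Y unions), plus an API simplification: the Indexable wrapper is gone and tensors are indexed directly with named-dimension operations, gather drops its **kwargs, and stack now broadcasts its inputs before stacking. A minimal sketch of the resulting 0.2.0 usage, assuming only the signatures visible in these hunks (variable names are illustrative):

import torch

from effectful.handlers.indexed import IndexSet, gather, stack

# stack two "worlds" along a new named dimension "w"
worlds = stack((torch.zeros(3), torch.ones(3)), "w")

# select the entries belonging to world 0
factual = gather(worlds, IndexSet(w={0}))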
effectful/handlers/jax/__init__.py
ADDED
@@ -0,0 +1,14 @@
+try:
+    # Dummy import to check if jax is installed
+    import jax  # noqa: F401
+except ImportError:
+    raise ImportError("Jax is required to use effectful.handlers.jax")
+
+# side effect: register defdata for jax.Array
+import effectful.handlers.jax._terms  # noqa: F401
+
+from ._handlers import bind_dims as bind_dims
+from ._handlers import jax_getitem as jax_getitem
+from ._handlers import jit as jit
+from ._handlers import sizesof as sizesof
+from ._handlers import unbind_dims as unbind_dims
effectful/handlers/jax/_handlers.py
ADDED
@@ -0,0 +1,293 @@
+import functools
+import typing
+from collections.abc import Callable, Mapping, Sequence
+from types import EllipsisType
+from typing import Annotated
+
+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    raise ImportError("JAX is required to use effectful.handlers.jax")
+
+import tree
+
+from effectful.ops.semantics import fvsof, typeof
+from effectful.ops.syntax import (
+    Scoped,
+    _CustomSingleDispatchCallable,
+    defdata,
+    deffn,
+    defop,
+    defterm,
+    syntactic_eq,
+)
+from effectful.ops.types import Expr, NotHandled, Operation, Term
+
+# + An element of an array index expression.
+IndexElement = None | int | slice | Sequence[int] | EllipsisType | jax.Array
+
+
+def is_eager_array(x):
+    return isinstance(x, jax.Array) or (
+        isinstance(x, Term)
+        and x.op is jax_getitem
+        and isinstance(x.args[0], jax.Array)
+        and all(
+            (not isinstance(k, Term)) or (not k.args and not k.kwargs)
+            for k in x.args[1]
+        )
+        and not x.kwargs
+    )
+
+
+def sizesof(value) -> Mapping[Operation[[], jax.Array], int]:
+    """Return the sizes of named dimensions in an array expression.
+
+    Sizes are inferred from the array shape.
+
+    :param value: An array expression.
+    :return: A mapping from named dimensions to their sizes.
+
+    **Example usage**:
+
+    >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
+    >>> sizes = sizesof(jax_getitem(jnp.ones((2, 3)), [a(), b()]))
+    >>> assert sizes[a] == 2 and sizes[b] == 3
+    """
+    sizes: dict[Operation[[], jax.Array], int] = {}
+
+    def update_sizes(sizes, op, size):
+        old_size = sizes.get(op)
+        if old_size is not None and size != old_size:
+            raise ValueError(
+                f"Named index {op} used in incompatible dimensions of size {old_size} and {size}"
+            )
+        sizes[op] = size
+
+    def _getitem_sizeof(x: jax.Array, key: tuple[Expr[IndexElement], ...]):
+        if is_eager_array(x):
+            for i, k in enumerate(key):
+                if isinstance(k, Term) and len(k.args) == 0 and len(k.kwargs) == 0:
+                    update_sizes(sizes, k.op, x.shape[i])
+
+    def _sizesof(expr):
+        expr = defterm(expr)
+        if isinstance(expr, Term):
+            for x in tree.flatten((expr.args, expr.kwargs)):
+                _sizesof(x)
+            if expr.op is jax_getitem:
+                _getitem_sizeof(*expr.args)
+        elif tree.is_nested(expr):
+            for x in tree.flatten(expr):
+                _sizesof(x)
+
+    _sizesof(value)
+    return sizes
+
+
+def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]:
+    """Partially evaluate a term with respect to its sized free variables."""
+
+    sized_fvs = sizesof(t)
+    if not sized_fvs:
+        return t
+
+    def _is_eager(t):
+        return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t)
+
+    if not (
+        isinstance(t, Term)
+        and all(_is_eager(a) for a in tree.flatten((t.args, t.kwargs)))
+    ):
+        return t
+
+    tpe_jax_fn = jax.vmap(deffn(t, *sized_fvs.keys()))
+
+    # Create indices for each dimension
+    indices = jnp.meshgrid(
+        *[jnp.arange(size) for size in sized_fvs.values()], indexing="ij"
+    )
+
+    # Flatten indices for vmap
+    flat_indices = [idx.reshape(-1) for idx in indices]
+
+    # Apply vmap
+    flat_result = tpe_jax_fn(*flat_indices)
+
+    def reindex_flat_array(t):
+        if not isinstance(t, jax.Array):
+            return t
+
+        result_shape = indices[0].shape + t.shape[1:]
+        result = jnp.reshape(t, result_shape)
+        return jax_getitem(result, tuple(k() for k in sized_fvs.keys()))
+
+    result = tree.map_structure(reindex_flat_array, flat_result)
+    return result
+
+
+@functools.cache
+def _register_jax_op[**P, T](jax_fn: Callable[P, T]):
+    if getattr(jax_fn, "__name__", None) == "__getitem__":
+        return jax_getitem
+
+    @defop
+    def _jax_op(*args, **kwargs) -> jax.Array:
+        tm = defdata(_jax_op, *args, **kwargs)
+        sized_fvs = sizesof(tm)
+
+        if (
+            _jax_op is jax_getitem
+            and not isinstance(args[0], Term)
+            and sized_fvs
+            and args[1]
+            and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
+        ):
+            raise NotHandled
+        elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {jax_getitem, _jax_op}:
+            # note: this cast is a lie. partial_eval can return non-arrays, as
+            # can jax_fn. for example, some jax functions return tuples,
+            # which partial_eval handles.
+            return typing.cast(jax.Array, _partial_eval(tm))
+        elif not any(
+            tree.flatten(
+                tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
+            )
+        ):
+            return typing.cast(jax.Array, jax_fn(*args, **kwargs))
+        else:
+            raise NotHandled
+
+    functools.update_wrapper(_jax_op, jax_fn)
+    return _jax_op
+
+
+@functools.cache
+def _register_jax_op_no_partial_eval[**P, T](jax_fn: Callable[P, T]):
+    # FIXME: Presumably not all jax ops return arrays. In other cases, we won't
+    # get the right kind of term.
+    @defop
+    def _jax_op(*args, **kwargs) -> jax.Array:
+        if not any(
+            tree.flatten(
+                tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
+            )
+        ):
+            return typing.cast(jax.Array, jax_fn(*args, **kwargs))
+        else:
+            raise NotHandled
+
+    functools.update_wrapper(_jax_op, jax_fn)
+    return _jax_op
+
+
+@_register_jax_op
+def jax_getitem(x: jax.Array, key: tuple[IndexElement, ...]) -> jax.Array:
+    """Operation for indexing an array. Unlike the standard __getitem__ method,
+    this operation correctly handles indexing with terms.
+
+    """
+    return x[tuple(key)]
+
+
+@defop
+@_CustomSingleDispatchCallable
+def bind_dims[T, A, B](
+    __dispatch: Callable[[type], Callable[..., T]],
+    value: Annotated[T, Scoped[A | B]],
+    *names: Annotated[Operation[[], jax.Array], Scoped[B]],
+) -> Annotated[T, Scoped[A]]:
+    """Convert named dimensions to positional dimensions.
+
+    :param t: An array.
+    :param args: Named dimensions to convert to positional dimensions.
+        These positional dimensions will appear at the beginning of the
+        shape.
+    :return: An array with the named dimensions in ``args`` converted to positional dimensions.
+
+    **Example usage**:
+
+    >>> import jax.numpy as jnp
+    >>> from effectful.ops.syntax import defop
+    >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
+    >>> t = jax_getitem(jnp.ones((2, 3)), [a(), b()])
+    >>> bind_dims(t, b, a).shape
+    (3, 2)
+    """
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: bind_dims(v, *names), value)
+
+    semantic_type = typeof(value)
+    return __dispatch(semantic_type)(value, *names)
+
+
+@defop
+@_CustomSingleDispatchCallable
+def unbind_dims[T, A, B](
+    __dispatch: Callable[[type], Callable[..., T]],
+    value: Annotated[T, Scoped[A | B]],
+    *names: Annotated[Operation[[], jax.Array], Scoped[B]],
+) -> Annotated[T, Scoped[A | B]]:
+    """Convert positional dimensions to named dimensions."""
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: unbind_dims(v, *names), value)
+
+    semantic_type = typeof(value)
+    return __dispatch(semantic_type)(value, *names)
+
+
+def jit(f, *args, **kwargs):
+    f_noindex, f_reindex = _indexed_func_wrapper(f, jax_getitem, sizesof)
+    f_noindex_jitted = jax.jit(f_noindex, *args, **kwargs)
+    return lambda *args, **kwargs: f_reindex(f_noindex_jitted(*args, **kwargs))
+
+
+def _indexed_func_wrapper[**P, S, T](
+    func: Callable[P, T], getitem, sizesof
+) -> tuple[Callable[P, S], Callable[[S], T]]:
+    # index expressions for the result of the function
+    indexes = None
+
+    # hide index lists from tree.map_structure
+    class Indexes:
+        def __init__(self, sizes):
+            self.sizes = sizes
+            self.indexes = list(sizes.keys())
+
+    # strip named indexes from the result of the function and store them
+    def deindexed(*args, **kwargs):
+        nonlocal indexes
+
+        def deindex_tensor(t, i):
+            t_ = bind_dims(t, *i.sizes.keys())
+            assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
+            return t_
+
+        ret = func(*args, **kwargs)
+        indexes = tree.map_structure(lambda t: Indexes(sizesof(t)), ret)
+        tensors = tree.map_structure(lambda t, i: deindex_tensor(t, i), ret, indexes)
+        return tensors
+
+    # reapply the stored indexes to a result
+    def reindex(ret, starting_dim=0):
+        def index_expr(i):
+            return (slice(None),) * (starting_dim) + tuple(x() for x in i.indexes)
+
+        if tree.is_nested(ret):
+            indexed_ret = tree.map_structure(
+                lambda t, i: getitem(t, index_expr(i)), ret, indexes
+            )
+        else:
+            indexed_ret = getitem(ret, index_expr(indexes))
+
+        return indexed_ret
+
+    return deindexed, reindex
+
+
+@syntactic_eq.register
+def _(x: jax.typing.ArrayLike, other) -> bool:
+    return isinstance(other, jax.typing.ArrayLike) and bool(  # type: ignore[arg-type]
+        (jnp.asarray(x) == jnp.asarray(other)).all()
+    )
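jit is the one public entry point in this file without a doctest: _indexed_func_wrapper strips named dimensions from the wrapped function's result (via bind_dims), jax.jit compiles the stripped function, and the stored index expressions are reapplied afterwards. The round trip it relies on can be sketched with the public operations alone (the names a, b, t are illustrative):

import jax
import jax.numpy as jnp

from effectful.handlers.jax import bind_dims, jax_getitem, sizesof
from effectful.ops.syntax import defop

a, b = defop(jax.Array, name="a"), defop(jax.Array, name="b")
t = jax_getitem(jnp.ones((2, 3)), [a(), b()])

# deindex: named dims become leading positional dims (cf. deindex_tensor)
positional = bind_dims(t, a, b)
assert positional.shape == (2, 3)

# reindex: reapply the stored index expression (cf. index_expr in reindex)
t2 = jax_getitem(positional, (a(), b()))
assert sizesof(t2) == {a: 2, b: 3}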