effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/ops/semantics.py
CHANGED
@@ -1,24 +1,22 @@
|
|
1
|
+
import collections.abc
|
1
2
|
import contextlib
|
2
|
-
import
|
3
|
-
|
3
|
+
import dataclasses
|
4
|
+
import types
|
5
|
+
import typing
|
6
|
+
from typing import Any
|
4
7
|
|
5
|
-
import
|
6
|
-
from
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
S = TypeVar("S")
|
14
|
-
T = TypeVar("T")
|
15
|
-
V = TypeVar("V")
|
8
|
+
from effectful.ops.syntax import defop
|
9
|
+
from effectful.ops.types import (
|
10
|
+
Expr,
|
11
|
+
Interpretation,
|
12
|
+
NotHandled, # noqa: F401
|
13
|
+
Operation,
|
14
|
+
Term,
|
15
|
+
)
|
16
16
|
|
17
17
|
|
18
18
|
@defop # type: ignore
|
19
|
-
def apply(
|
20
|
-
intp: Interpretation[S, T], op: Operation[P, S], *args: P.args, **kwargs: P.kwargs
|
21
|
-
) -> T:
|
19
|
+
def apply[**P, T](op: Operation[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
22
20
|
"""Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
|
23
21
|
|
24
22
|
Handling :func:`apply` changes the evaluation strategy of terms.
|
@@ -39,44 +37,25 @@ def apply(
|
|
39
37
|
|
40
38
|
By installing an :func:`apply` handler, we capture the term instead:
|
41
39
|
|
42
|
-
>>>
|
40
|
+
>>> def default(*args, **kwargs):
|
41
|
+
... raise NotHandled
|
42
|
+
>>> with handler({apply: default }):
|
43
43
|
... term = mul(add(1, 2), 3)
|
44
|
-
>>> term
|
44
|
+
>>> print(str(term))
|
45
45
|
mul(add(1, 2), 3)
|
46
46
|
|
47
47
|
"""
|
48
|
+
from effectful.internals.runtime import get_interpretation
|
49
|
+
|
50
|
+
intp = get_interpretation()
|
48
51
|
if op in intp:
|
49
52
|
return intp[op](*args, **kwargs)
|
50
53
|
elif apply in intp:
|
51
|
-
return intp[apply](
|
54
|
+
return intp[apply](op, *args, **kwargs)
|
52
55
|
else:
|
53
56
|
return op.__default_rule__(*args, **kwargs) # type: ignore
|
54
57
|
|
55
58
|
|
56
|
-
@defop # type: ignore
|
57
|
-
def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
58
|
-
"""An operation that eliminates a callable term.
|
59
|
-
|
60
|
-
This operation is invoked by the ``__call__`` method of a callable term.
|
61
|
-
|
62
|
-
"""
|
63
|
-
if not isinstance(fn, Term):
|
64
|
-
fn = defterm(fn)
|
65
|
-
|
66
|
-
if isinstance(fn, Term) and fn.op is deffn:
|
67
|
-
body: Expr[Callable[P, T]] = fn.args[0]
|
68
|
-
argvars: tuple[Operation, ...] = fn.args[1:]
|
69
|
-
kwvars: dict[str, Operation] = fn.kwargs
|
70
|
-
subs = {
|
71
|
-
**{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
|
72
|
-
**{kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs},
|
73
|
-
}
|
74
|
-
with handler(subs):
|
75
|
-
return evaluate(body)
|
76
|
-
else:
|
77
|
-
raise NoDefaultRule
|
78
|
-
|
79
|
-
|
80
59
|
@defop
|
81
60
|
def fwd(*args, **kwargs) -> Any:
|
82
61
|
"""Forward execution to the next most enclosing handler.
|
@@ -93,9 +72,7 @@ def fwd(*args, **kwargs) -> Any:
|
|
93
72
|
raise RuntimeError("fwd should only be called in the context of a handler")
|
94
73
|
|
95
74
|
|
96
|
-
def coproduct(
|
97
|
-
intp: Interpretation[S, T], intp2: Interpretation[S, T]
|
98
|
-
) -> Interpretation[S, T]:
|
75
|
+
def coproduct(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
99
76
|
"""The coproduct of two interpretations handles any effect that is handled
|
100
77
|
by either. If both interpretations handle an effect, ``intp2`` takes
|
101
78
|
precedence.
|
@@ -151,7 +128,7 @@ def coproduct(
|
|
151
128
|
if op is fwd or op is _get_args:
|
152
129
|
res[op] = i2 # fast path for special cases, should be equivalent if removed
|
153
130
|
else:
|
154
|
-
i1 = intp.get(op, op.__default_rule__)
|
131
|
+
i1 = intp.get(op, op.__default_rule__)
|
155
132
|
|
156
133
|
# calling fwd in the right handler should dispatch to the left handler
|
157
134
|
res[op] = _set_prompt(fwd, _restore_args(_save_args(i1)), _save_args(i2))
|
@@ -159,9 +136,7 @@ def coproduct(
|
|
159
136
|
return res
|
160
137
|
|
161
138
|
|
162
|
-
def product(
|
163
|
-
intp: Interpretation[S, T], intp2: Interpretation[S, T]
|
164
|
-
) -> Interpretation[S, T]:
|
139
|
+
def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
165
140
|
"""The product of two interpretations handles any effect that is handled by
|
166
141
|
``intp2``. Handlers in ``intp2`` may override handlers in ``intp``, but
|
167
142
|
those changes are not visible to the handlers in ``intp``. In this way,
|
@@ -197,7 +172,7 @@ def product(
|
|
197
172
|
|
198
173
|
"""
|
199
174
|
if any(op in intp for op in intp2): # alpha-rename
|
200
|
-
renaming = {op: defop(op) for op in intp2 if op in intp}
|
175
|
+
renaming: Interpretation = {op: defop(op) for op in intp2 if op in intp}
|
201
176
|
intp_fresh = {renaming.get(op, op): handler(renaming)(intp[op]) for op in intp}
|
202
177
|
return product(intp_fresh, intp2)
|
203
178
|
else:
|
@@ -207,7 +182,7 @@ def product(
|
|
207
182
|
|
208
183
|
|
209
184
|
@contextlib.contextmanager
|
210
|
-
def runner(intp: Interpretation
|
185
|
+
def runner(intp: Interpretation):
|
211
186
|
"""Install an interpretation by taking a product with the current
|
212
187
|
interpretation.
|
213
188
|
|
@@ -215,7 +190,7 @@ def runner(intp: Interpretation[S, T]):
|
|
215
190
|
from effectful.internals.runtime import get_interpretation, interpreter
|
216
191
|
|
217
192
|
@interpreter(get_interpretation())
|
218
|
-
def _reapply
|
193
|
+
def _reapply[**P, S](op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
|
219
194
|
return op(*args, **kwargs)
|
220
195
|
|
221
196
|
with interpreter({apply: _reapply, **intp}):
|
@@ -223,7 +198,7 @@ def runner(intp: Interpretation[S, T]):
|
|
223
198
|
|
224
199
|
|
225
200
|
@contextlib.contextmanager
|
226
|
-
def handler(intp: Interpretation
|
201
|
+
def handler(intp: Interpretation):
|
227
202
|
"""Install an interpretation by taking a coproduct with the current
|
228
203
|
interpretation.
|
229
204
|
|
@@ -234,7 +209,7 @@ def handler(intp: Interpretation[S, T]):
|
|
234
209
|
yield intp
|
235
210
|
|
236
211
|
|
237
|
-
def evaluate(expr: Expr[T], *, intp:
|
212
|
+
def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]:
|
238
213
|
"""Evaluate expression ``expr`` using interpretation ``intp``. If no
|
239
214
|
interpretation is provided, uses the current interpretation.
|
240
215
|
|
@@ -245,33 +220,61 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> E
|
|
245
220
|
|
246
221
|
>>> @defop
|
247
222
|
... def add(x: int, y: int) -> int:
|
248
|
-
... raise
|
223
|
+
... raise NotHandled
|
249
224
|
>>> expr = add(1, add(2, 3))
|
250
|
-
>>> expr
|
225
|
+
>>> print(str(expr))
|
251
226
|
add(1, add(2, 3))
|
252
227
|
>>> evaluate(expr, intp={add: lambda x, y: x + y})
|
253
228
|
6
|
254
229
|
|
255
230
|
"""
|
256
|
-
|
257
|
-
from effectful.internals.runtime import get_interpretation
|
258
|
-
|
259
|
-
intp = get_interpretation()
|
231
|
+
from effectful.internals.runtime import get_interpretation, interpreter
|
260
232
|
|
261
|
-
|
233
|
+
if intp is not None:
|
234
|
+
return interpreter(intp)(evaluate)(expr)
|
262
235
|
|
263
236
|
if isinstance(expr, Term):
|
264
|
-
|
265
|
-
|
237
|
+
args = tuple(evaluate(arg) for arg in expr.args)
|
238
|
+
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
|
239
|
+
return expr.op(*args, **kwargs)
|
240
|
+
elif isinstance(expr, Operation):
|
241
|
+
op_intp = get_interpretation().get(expr, expr)
|
242
|
+
return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
|
243
|
+
elif isinstance(expr, collections.abc.Mapping):
|
244
|
+
if isinstance(expr, collections.defaultdict):
|
245
|
+
return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
|
246
|
+
elif isinstance(expr, types.MappingProxyType):
|
247
|
+
return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
|
248
|
+
else:
|
249
|
+
return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
|
250
|
+
elif isinstance(expr, collections.abc.Sequence):
|
251
|
+
if isinstance(expr, str | bytes):
|
252
|
+
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
|
253
|
+
else:
|
254
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
255
|
+
elif isinstance(expr, collections.abc.Set):
|
256
|
+
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
|
257
|
+
return {evaluate(item) for item in expr} # type: ignore
|
258
|
+
else:
|
259
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
260
|
+
elif isinstance(expr, collections.abc.ValuesView):
|
261
|
+
return [evaluate(item) for item in expr] # type: ignore
|
262
|
+
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
|
263
|
+
return typing.cast(
|
264
|
+
T,
|
265
|
+
dataclasses.replace(
|
266
|
+
expr,
|
267
|
+
**{
|
268
|
+
field.name: evaluate(getattr(expr, field.name))
|
269
|
+
for field in dataclasses.fields(expr)
|
270
|
+
},
|
271
|
+
),
|
266
272
|
)
|
267
|
-
return apply.__default_rule__(intp, expr.op, *args, **kwargs) # type: ignore
|
268
|
-
elif tree.is_nested(expr):
|
269
|
-
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
|
270
273
|
else:
|
271
|
-
return expr
|
274
|
+
return typing.cast(T, expr)
|
272
275
|
|
273
276
|
|
274
|
-
def typeof(term: Expr[T]) ->
|
277
|
+
def typeof[T](term: Expr[T]) -> type[T]:
|
275
278
|
"""Return the type of an expression.
|
276
279
|
|
277
280
|
**Example usage**:
|
@@ -280,46 +283,63 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
280
283
|
|
281
284
|
>>> @defop
|
282
285
|
... def cmp(x: int, y: int) -> bool:
|
283
|
-
... raise
|
286
|
+
... raise NotHandled
|
284
287
|
>>> typeof(cmp(1, 2))
|
285
288
|
<class 'bool'>
|
286
289
|
|
287
290
|
Types can be computed in the presence of type variables.
|
288
291
|
|
289
|
-
>>> from typing import TypeVar
|
290
|
-
>>> T = TypeVar('T')
|
291
292
|
>>> @defop
|
292
|
-
... def if_then_else(x: bool, a: T, b: T) -> T:
|
293
|
-
... raise
|
293
|
+
... def if_then_else[T](x: bool, a: T, b: T) -> T:
|
294
|
+
... raise NotHandled
|
294
295
|
>>> typeof(if_then_else(True, 0, 1))
|
295
296
|
<class 'int'>
|
296
297
|
|
297
298
|
"""
|
298
299
|
from effectful.internals.runtime import interpreter
|
299
300
|
|
300
|
-
with interpreter({apply: lambda
|
301
|
-
|
301
|
+
with interpreter({apply: lambda op, *a, **k: op.__type_rule__(*a, **k)}):
|
302
|
+
if isinstance(term, Term):
|
303
|
+
# If term is a Term, we evaluate it to get its type
|
304
|
+
tp = evaluate(term)
|
305
|
+
if isinstance(tp, typing.TypeVar):
|
306
|
+
tp = (
|
307
|
+
tp.__bound__
|
308
|
+
if tp.__bound__
|
309
|
+
else tp.__constraints__[0]
|
310
|
+
if tp.__constraints__
|
311
|
+
else object
|
312
|
+
)
|
313
|
+
if isinstance(tp, types.UnionType):
|
314
|
+
raise TypeError(
|
315
|
+
f"Cannot determine type of {term} because it is a union type: {tp}"
|
316
|
+
)
|
317
|
+
return typing.get_origin(tp) or tp # type: ignore
|
318
|
+
else:
|
319
|
+
return type(term)
|
302
320
|
|
303
321
|
|
304
|
-
def fvsof(term: Expr[S]) -> Set[Operation]:
|
322
|
+
def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]:
|
305
323
|
"""Return the free variables of an expression.
|
306
324
|
|
307
325
|
**Example usage**:
|
308
326
|
|
309
327
|
>>> @defop
|
310
328
|
... def f(x: int, y: int) -> int:
|
311
|
-
... raise
|
312
|
-
>>> fvsof(f(1, 2))
|
313
|
-
|
314
|
-
|
329
|
+
... raise NotHandled
|
330
|
+
>>> fvs = fvsof(f(1, 2))
|
331
|
+
>>> assert f in fvs
|
332
|
+
>>> assert len(fvs) == 1
|
315
333
|
"""
|
316
334
|
from effectful.internals.runtime import interpreter
|
317
335
|
|
318
|
-
_fvs:
|
336
|
+
_fvs: set[Operation] = set()
|
319
337
|
|
320
|
-
def _update_fvs(
|
338
|
+
def _update_fvs(op, *args, **kwargs):
|
321
339
|
_fvs.add(op)
|
322
|
-
|
340
|
+
bindings = op.__fvs_rule__(*args, **kwargs)
|
341
|
+
for bound_var in set().union(*(*bindings.args, *bindings.kwargs.values())):
|
342
|
+
assert isinstance(bound_var, Operation)
|
323
343
|
if bound_var in _fvs:
|
324
344
|
_fvs.remove(bound_var)
|
325
345
|
|