effectful 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +23 -24
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +297 -168
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +901 -0
- effectful/ops/semantics.py +109 -77
- effectful/ops/syntax.py +821 -250
- effectful/ops/types.py +121 -29
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/METADATA +59 -56
- effectful-0.2.1.dist-info/RECORD +26 -0
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -263
- effectful-0.1.0.dist-info/RECORD +0 -18
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/top_level.txt +0 -0
effectful/ops/semantics.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
|
+
import collections.abc
|
1
2
|
import contextlib
|
2
|
-
import
|
3
|
-
|
3
|
+
import dataclasses
|
4
|
+
import types
|
5
|
+
import typing
|
6
|
+
from typing import Any
|
4
7
|
|
5
|
-
import
|
6
|
-
from
|
8
|
+
from effectful.ops.syntax import defop
|
9
|
+
from effectful.ops.types import (
|
10
|
+
Expr,
|
11
|
+
Interpretation,
|
12
|
+
NotHandled, # noqa: F401
|
13
|
+
Operation,
|
14
|
+
Term,
|
15
|
+
)
|
7
16
|
|
8
|
-
from effectful.ops.syntax import deffn, defop
|
9
|
-
from effectful.ops.types import Expr, Interpretation, Operation, Term
|
10
17
|
|
11
|
-
|
12
|
-
|
13
|
-
S = TypeVar("S")
|
14
|
-
T = TypeVar("T")
|
15
|
-
V = TypeVar("V")
|
16
|
-
|
17
|
-
|
18
|
-
@defop
|
19
|
-
def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
|
18
|
+
@defop # type: ignore
|
19
|
+
def apply[**P, T](op: Operation[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
20
20
|
"""Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
|
21
21
|
|
22
22
|
Handling :func:`apply` changes the evaluation strategy of terms.
|
@@ -37,41 +37,23 @@ def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
|
|
37
37
|
|
38
38
|
By installing an :func:`apply` handler, we capture the term instead:
|
39
39
|
|
40
|
-
>>>
|
40
|
+
>>> def default(*args, **kwargs):
|
41
|
+
... raise NotHandled
|
42
|
+
>>> with handler({apply: default }):
|
41
43
|
... term = mul(add(1, 2), 3)
|
42
|
-
>>> term
|
44
|
+
>>> print(str(term))
|
43
45
|
mul(add(1, 2), 3)
|
44
46
|
|
45
47
|
"""
|
48
|
+
from effectful.internals.runtime import get_interpretation
|
49
|
+
|
50
|
+
intp = get_interpretation()
|
46
51
|
if op in intp:
|
47
52
|
return intp[op](*args, **kwargs)
|
48
53
|
elif apply in intp:
|
49
|
-
return intp[apply](
|
54
|
+
return intp[apply](op, *args, **kwargs)
|
50
55
|
else:
|
51
|
-
return op.__default_rule__(*args, **kwargs)
|
52
|
-
|
53
|
-
|
54
|
-
@defop # type: ignore
|
55
|
-
def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
56
|
-
"""An operation that eliminates a callable term.
|
57
|
-
|
58
|
-
This operation is invoked by the ``__call__`` method of a callable term.
|
59
|
-
|
60
|
-
"""
|
61
|
-
if isinstance(fn, Term) and fn.op is deffn:
|
62
|
-
body: Expr[Callable[P, T]] = fn.args[0]
|
63
|
-
argvars: tuple[Operation, ...] = fn.args[1:]
|
64
|
-
kwvars: dict[str, Operation] = fn.kwargs
|
65
|
-
subs = {
|
66
|
-
**{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
|
67
|
-
**{kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs},
|
68
|
-
}
|
69
|
-
with handler(subs):
|
70
|
-
return evaluate(body)
|
71
|
-
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
|
72
|
-
return fn(*args, **kwargs)
|
73
|
-
else:
|
74
|
-
raise NotImplementedError
|
56
|
+
return op.__default_rule__(*args, **kwargs) # type: ignore
|
75
57
|
|
76
58
|
|
77
59
|
@defop
|
@@ -190,7 +172,7 @@ def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
|
190
172
|
|
191
173
|
"""
|
192
174
|
if any(op in intp for op in intp2): # alpha-rename
|
193
|
-
renaming = {op: defop(op) for op in intp2 if op in intp}
|
175
|
+
renaming: Interpretation = {op: defop(op) for op in intp2 if op in intp}
|
194
176
|
intp_fresh = {renaming.get(op, op): handler(renaming)(intp[op]) for op in intp}
|
195
177
|
return product(intp_fresh, intp2)
|
196
178
|
else:
|
@@ -208,7 +190,7 @@ def runner(intp: Interpretation):
|
|
208
190
|
from effectful.internals.runtime import get_interpretation, interpreter
|
209
191
|
|
210
192
|
@interpreter(get_interpretation())
|
211
|
-
def _reapply
|
193
|
+
def _reapply[**P, S](op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
|
212
194
|
return op(*args, **kwargs)
|
213
195
|
|
214
196
|
with interpreter({apply: _reapply, **intp}):
|
@@ -227,7 +209,7 @@ def handler(intp: Interpretation):
|
|
227
209
|
yield intp
|
228
210
|
|
229
211
|
|
230
|
-
def evaluate(expr: Expr[T], *, intp:
|
212
|
+
def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]:
|
231
213
|
"""Evaluate expression ``expr`` using interpretation ``intp``. If no
|
232
214
|
interpretation is provided, uses the current interpretation.
|
233
215
|
|
@@ -238,31 +220,69 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]
|
|
238
220
|
|
239
221
|
>>> @defop
|
240
222
|
... def add(x: int, y: int) -> int:
|
241
|
-
... raise
|
223
|
+
... raise NotHandled
|
242
224
|
>>> expr = add(1, add(2, 3))
|
243
|
-
>>> expr
|
225
|
+
>>> print(str(expr))
|
244
226
|
add(1, add(2, 3))
|
245
227
|
>>> evaluate(expr, intp={add: lambda x, y: x + y})
|
246
228
|
6
|
247
229
|
|
248
230
|
"""
|
249
|
-
|
250
|
-
from effectful.internals.runtime import get_interpretation
|
231
|
+
from effectful.internals.runtime import get_interpretation, interpreter
|
251
232
|
|
252
|
-
|
233
|
+
if intp is not None:
|
234
|
+
return interpreter(intp)(evaluate)(expr)
|
253
235
|
|
254
236
|
if isinstance(expr, Term):
|
255
|
-
|
256
|
-
|
237
|
+
args = tuple(evaluate(arg) for arg in expr.args)
|
238
|
+
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
|
239
|
+
return expr.op(*args, **kwargs)
|
240
|
+
elif isinstance(expr, Operation):
|
241
|
+
op_intp = get_interpretation().get(expr, expr)
|
242
|
+
return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
|
243
|
+
elif isinstance(expr, collections.abc.Mapping):
|
244
|
+
if isinstance(expr, collections.defaultdict):
|
245
|
+
return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
|
246
|
+
elif isinstance(expr, types.MappingProxyType):
|
247
|
+
return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
|
248
|
+
else:
|
249
|
+
return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
|
250
|
+
elif isinstance(expr, collections.abc.Sequence):
|
251
|
+
if isinstance(expr, str | bytes):
|
252
|
+
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
|
253
|
+
elif (
|
254
|
+
isinstance(expr, tuple)
|
255
|
+
and hasattr(expr, "_fields")
|
256
|
+
and all(hasattr(expr, field) for field in getattr(expr, "_fields"))
|
257
|
+
): # namedtuple
|
258
|
+
return type(expr)(
|
259
|
+
**{field: evaluate(getattr(expr, field)) for field in expr._fields}
|
260
|
+
)
|
261
|
+
else:
|
262
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
263
|
+
elif isinstance(expr, collections.abc.Set):
|
264
|
+
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
|
265
|
+
return {evaluate(item) for item in expr} # type: ignore
|
266
|
+
else:
|
267
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
268
|
+
elif isinstance(expr, collections.abc.ValuesView):
|
269
|
+
return [evaluate(item) for item in expr] # type: ignore
|
270
|
+
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
|
271
|
+
return typing.cast(
|
272
|
+
T,
|
273
|
+
dataclasses.replace(
|
274
|
+
expr,
|
275
|
+
**{
|
276
|
+
field.name: evaluate(getattr(expr, field.name))
|
277
|
+
for field in dataclasses.fields(expr)
|
278
|
+
},
|
279
|
+
),
|
257
280
|
)
|
258
|
-
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
|
259
|
-
elif tree.is_nested(expr):
|
260
|
-
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
|
261
281
|
else:
|
262
|
-
return expr
|
282
|
+
return typing.cast(T, expr)
|
263
283
|
|
264
284
|
|
265
|
-
def typeof(term: Expr[T]) ->
|
285
|
+
def typeof[T](term: Expr[T]) -> type[T]:
|
266
286
|
"""Return the type of an expression.
|
267
287
|
|
268
288
|
**Example usage**:
|
@@ -271,51 +291,63 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
271
291
|
|
272
292
|
>>> @defop
|
273
293
|
... def cmp(x: int, y: int) -> bool:
|
274
|
-
... raise
|
294
|
+
... raise NotHandled
|
275
295
|
>>> typeof(cmp(1, 2))
|
276
296
|
<class 'bool'>
|
277
297
|
|
278
298
|
Types can be computed in the presence of type variables.
|
279
299
|
|
280
|
-
>>> from typing import TypeVar
|
281
|
-
>>> T = TypeVar('T')
|
282
300
|
>>> @defop
|
283
|
-
... def if_then_else(x: bool, a: T, b: T) -> T:
|
284
|
-
... raise
|
301
|
+
... def if_then_else[T](x: bool, a: T, b: T) -> T:
|
302
|
+
... raise NotHandled
|
285
303
|
>>> typeof(if_then_else(True, 0, 1))
|
286
304
|
<class 'int'>
|
287
305
|
|
288
306
|
"""
|
289
307
|
from effectful.internals.runtime import interpreter
|
290
308
|
|
291
|
-
with interpreter({apply: lambda
|
292
|
-
|
309
|
+
with interpreter({apply: lambda op, *a, **k: op.__type_rule__(*a, **k)}):
|
310
|
+
if isinstance(term, Term):
|
311
|
+
# If term is a Term, we evaluate it to get its type
|
312
|
+
tp = evaluate(term)
|
313
|
+
if isinstance(tp, typing.TypeVar):
|
314
|
+
tp = (
|
315
|
+
tp.__bound__
|
316
|
+
if tp.__bound__
|
317
|
+
else tp.__constraints__[0]
|
318
|
+
if tp.__constraints__
|
319
|
+
else object
|
320
|
+
)
|
321
|
+
if isinstance(tp, types.UnionType):
|
322
|
+
raise TypeError(
|
323
|
+
f"Cannot determine type of {term} because it is a union type: {tp}"
|
324
|
+
)
|
325
|
+
return typing.get_origin(tp) or tp # type: ignore
|
326
|
+
else:
|
327
|
+
return type(term)
|
293
328
|
|
294
329
|
|
295
|
-
def fvsof(term: Expr[S]) -> Set[Operation]:
|
330
|
+
def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]:
|
296
331
|
"""Return the free variables of an expression.
|
297
332
|
|
298
333
|
**Example usage**:
|
299
334
|
|
300
335
|
>>> @defop
|
301
336
|
... def f(x: int, y: int) -> int:
|
302
|
-
... raise
|
303
|
-
>>> fvsof(f(1, 2))
|
304
|
-
|
305
|
-
|
337
|
+
... raise NotHandled
|
338
|
+
>>> fvs = fvsof(f(1, 2))
|
339
|
+
>>> assert f in fvs
|
340
|
+
>>> assert len(fvs) == 1
|
306
341
|
"""
|
307
342
|
from effectful.internals.runtime import interpreter
|
308
343
|
|
309
|
-
_fvs:
|
344
|
+
_fvs: set[Operation] = set()
|
310
345
|
|
311
|
-
def _update_fvs(
|
346
|
+
def _update_fvs(op, *args, **kwargs):
|
312
347
|
_fvs.add(op)
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
*(k for k in kwarg_ctxs.values()),
|
317
|
-
)
|
318
|
-
for bound_var in bound_vars:
|
348
|
+
bindings = op.__fvs_rule__(*args, **kwargs)
|
349
|
+
for bound_var in set().union(*(*bindings.args, *bindings.kwargs.values())):
|
350
|
+
assert isinstance(bound_var, Operation)
|
319
351
|
if bound_var in _fvs:
|
320
352
|
_fvs.remove(bound_var)
|
321
353
|
|