effectful 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +23 -24
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +297 -168
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +101 -77
- effectful/ops/syntax.py +813 -251
- effectful/ops/types.py +121 -29
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/METADATA +59 -56
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -263
- effectful-0.1.0.dist-info/RECORD +0 -18
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/ops/semantics.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
|
+
import collections.abc
|
1
2
|
import contextlib
|
2
|
-
import
|
3
|
-
|
3
|
+
import dataclasses
|
4
|
+
import types
|
5
|
+
import typing
|
6
|
+
from typing import Any
|
4
7
|
|
5
|
-
import
|
6
|
-
from
|
8
|
+
from effectful.ops.syntax import defop
|
9
|
+
from effectful.ops.types import (
|
10
|
+
Expr,
|
11
|
+
Interpretation,
|
12
|
+
NotHandled, # noqa: F401
|
13
|
+
Operation,
|
14
|
+
Term,
|
15
|
+
)
|
7
16
|
|
8
|
-
from effectful.ops.syntax import deffn, defop
|
9
|
-
from effectful.ops.types import Expr, Interpretation, Operation, Term
|
10
17
|
|
11
|
-
|
12
|
-
|
13
|
-
S = TypeVar("S")
|
14
|
-
T = TypeVar("T")
|
15
|
-
V = TypeVar("V")
|
16
|
-
|
17
|
-
|
18
|
-
@defop
|
19
|
-
def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
|
18
|
+
@defop # type: ignore
|
19
|
+
def apply[**P, T](op: Operation[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
20
20
|
"""Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
|
21
21
|
|
22
22
|
Handling :func:`apply` changes the evaluation strategy of terms.
|
@@ -37,41 +37,23 @@ def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
|
|
37
37
|
|
38
38
|
By installing an :func:`apply` handler, we capture the term instead:
|
39
39
|
|
40
|
-
>>>
|
40
|
+
>>> def default(*args, **kwargs):
|
41
|
+
... raise NotHandled
|
42
|
+
>>> with handler({apply: default }):
|
41
43
|
... term = mul(add(1, 2), 3)
|
42
|
-
>>> term
|
44
|
+
>>> print(str(term))
|
43
45
|
mul(add(1, 2), 3)
|
44
46
|
|
45
47
|
"""
|
48
|
+
from effectful.internals.runtime import get_interpretation
|
49
|
+
|
50
|
+
intp = get_interpretation()
|
46
51
|
if op in intp:
|
47
52
|
return intp[op](*args, **kwargs)
|
48
53
|
elif apply in intp:
|
49
|
-
return intp[apply](
|
54
|
+
return intp[apply](op, *args, **kwargs)
|
50
55
|
else:
|
51
|
-
return op.__default_rule__(*args, **kwargs)
|
52
|
-
|
53
|
-
|
54
|
-
@defop # type: ignore
|
55
|
-
def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
56
|
-
"""An operation that eliminates a callable term.
|
57
|
-
|
58
|
-
This operation is invoked by the ``__call__`` method of a callable term.
|
59
|
-
|
60
|
-
"""
|
61
|
-
if isinstance(fn, Term) and fn.op is deffn:
|
62
|
-
body: Expr[Callable[P, T]] = fn.args[0]
|
63
|
-
argvars: tuple[Operation, ...] = fn.args[1:]
|
64
|
-
kwvars: dict[str, Operation] = fn.kwargs
|
65
|
-
subs = {
|
66
|
-
**{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
|
67
|
-
**{kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs},
|
68
|
-
}
|
69
|
-
with handler(subs):
|
70
|
-
return evaluate(body)
|
71
|
-
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
|
72
|
-
return fn(*args, **kwargs)
|
73
|
-
else:
|
74
|
-
raise NotImplementedError
|
56
|
+
return op.__default_rule__(*args, **kwargs) # type: ignore
|
75
57
|
|
76
58
|
|
77
59
|
@defop
|
@@ -190,7 +172,7 @@ def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
|
190
172
|
|
191
173
|
"""
|
192
174
|
if any(op in intp for op in intp2): # alpha-rename
|
193
|
-
renaming = {op: defop(op) for op in intp2 if op in intp}
|
175
|
+
renaming: Interpretation = {op: defop(op) for op in intp2 if op in intp}
|
194
176
|
intp_fresh = {renaming.get(op, op): handler(renaming)(intp[op]) for op in intp}
|
195
177
|
return product(intp_fresh, intp2)
|
196
178
|
else:
|
@@ -208,7 +190,7 @@ def runner(intp: Interpretation):
|
|
208
190
|
from effectful.internals.runtime import get_interpretation, interpreter
|
209
191
|
|
210
192
|
@interpreter(get_interpretation())
|
211
|
-
def _reapply
|
193
|
+
def _reapply[**P, S](op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
|
212
194
|
return op(*args, **kwargs)
|
213
195
|
|
214
196
|
with interpreter({apply: _reapply, **intp}):
|
@@ -227,7 +209,7 @@ def handler(intp: Interpretation):
|
|
227
209
|
yield intp
|
228
210
|
|
229
211
|
|
230
|
-
def evaluate(expr: Expr[T], *, intp:
|
212
|
+
def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]:
|
231
213
|
"""Evaluate expression ``expr`` using interpretation ``intp``. If no
|
232
214
|
interpretation is provided, uses the current interpretation.
|
233
215
|
|
@@ -238,31 +220,61 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]
|
|
238
220
|
|
239
221
|
>>> @defop
|
240
222
|
... def add(x: int, y: int) -> int:
|
241
|
-
... raise
|
223
|
+
... raise NotHandled
|
242
224
|
>>> expr = add(1, add(2, 3))
|
243
|
-
>>> expr
|
225
|
+
>>> print(str(expr))
|
244
226
|
add(1, add(2, 3))
|
245
227
|
>>> evaluate(expr, intp={add: lambda x, y: x + y})
|
246
228
|
6
|
247
229
|
|
248
230
|
"""
|
249
|
-
|
250
|
-
from effectful.internals.runtime import get_interpretation
|
231
|
+
from effectful.internals.runtime import get_interpretation, interpreter
|
251
232
|
|
252
|
-
|
233
|
+
if intp is not None:
|
234
|
+
return interpreter(intp)(evaluate)(expr)
|
253
235
|
|
254
236
|
if isinstance(expr, Term):
|
255
|
-
|
256
|
-
|
237
|
+
args = tuple(evaluate(arg) for arg in expr.args)
|
238
|
+
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
|
239
|
+
return expr.op(*args, **kwargs)
|
240
|
+
elif isinstance(expr, Operation):
|
241
|
+
op_intp = get_interpretation().get(expr, expr)
|
242
|
+
return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
|
243
|
+
elif isinstance(expr, collections.abc.Mapping):
|
244
|
+
if isinstance(expr, collections.defaultdict):
|
245
|
+
return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
|
246
|
+
elif isinstance(expr, types.MappingProxyType):
|
247
|
+
return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
|
248
|
+
else:
|
249
|
+
return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
|
250
|
+
elif isinstance(expr, collections.abc.Sequence):
|
251
|
+
if isinstance(expr, str | bytes):
|
252
|
+
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
|
253
|
+
else:
|
254
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
255
|
+
elif isinstance(expr, collections.abc.Set):
|
256
|
+
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
|
257
|
+
return {evaluate(item) for item in expr} # type: ignore
|
258
|
+
else:
|
259
|
+
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
260
|
+
elif isinstance(expr, collections.abc.ValuesView):
|
261
|
+
return [evaluate(item) for item in expr] # type: ignore
|
262
|
+
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
|
263
|
+
return typing.cast(
|
264
|
+
T,
|
265
|
+
dataclasses.replace(
|
266
|
+
expr,
|
267
|
+
**{
|
268
|
+
field.name: evaluate(getattr(expr, field.name))
|
269
|
+
for field in dataclasses.fields(expr)
|
270
|
+
},
|
271
|
+
),
|
257
272
|
)
|
258
|
-
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
|
259
|
-
elif tree.is_nested(expr):
|
260
|
-
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
|
261
273
|
else:
|
262
|
-
return expr
|
274
|
+
return typing.cast(T, expr)
|
263
275
|
|
264
276
|
|
265
|
-
def typeof(term: Expr[T]) ->
|
277
|
+
def typeof[T](term: Expr[T]) -> type[T]:
|
266
278
|
"""Return the type of an expression.
|
267
279
|
|
268
280
|
**Example usage**:
|
@@ -271,51 +283,63 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
271
283
|
|
272
284
|
>>> @defop
|
273
285
|
... def cmp(x: int, y: int) -> bool:
|
274
|
-
... raise
|
286
|
+
... raise NotHandled
|
275
287
|
>>> typeof(cmp(1, 2))
|
276
288
|
<class 'bool'>
|
277
289
|
|
278
290
|
Types can be computed in the presence of type variables.
|
279
291
|
|
280
|
-
>>> from typing import TypeVar
|
281
|
-
>>> T = TypeVar('T')
|
282
292
|
>>> @defop
|
283
|
-
... def if_then_else(x: bool, a: T, b: T) -> T:
|
284
|
-
... raise
|
293
|
+
... def if_then_else[T](x: bool, a: T, b: T) -> T:
|
294
|
+
... raise NotHandled
|
285
295
|
>>> typeof(if_then_else(True, 0, 1))
|
286
296
|
<class 'int'>
|
287
297
|
|
288
298
|
"""
|
289
299
|
from effectful.internals.runtime import interpreter
|
290
300
|
|
291
|
-
with interpreter({apply: lambda
|
292
|
-
|
301
|
+
with interpreter({apply: lambda op, *a, **k: op.__type_rule__(*a, **k)}):
|
302
|
+
if isinstance(term, Term):
|
303
|
+
# If term is a Term, we evaluate it to get its type
|
304
|
+
tp = evaluate(term)
|
305
|
+
if isinstance(tp, typing.TypeVar):
|
306
|
+
tp = (
|
307
|
+
tp.__bound__
|
308
|
+
if tp.__bound__
|
309
|
+
else tp.__constraints__[0]
|
310
|
+
if tp.__constraints__
|
311
|
+
else object
|
312
|
+
)
|
313
|
+
if isinstance(tp, types.UnionType):
|
314
|
+
raise TypeError(
|
315
|
+
f"Cannot determine type of {term} because it is a union type: {tp}"
|
316
|
+
)
|
317
|
+
return typing.get_origin(tp) or tp # type: ignore
|
318
|
+
else:
|
319
|
+
return type(term)
|
293
320
|
|
294
321
|
|
295
|
-
def fvsof(term: Expr[S]) -> Set[Operation]:
|
322
|
+
def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]:
|
296
323
|
"""Return the free variables of an expression.
|
297
324
|
|
298
325
|
**Example usage**:
|
299
326
|
|
300
327
|
>>> @defop
|
301
328
|
... def f(x: int, y: int) -> int:
|
302
|
-
... raise
|
303
|
-
>>> fvsof(f(1, 2))
|
304
|
-
|
305
|
-
|
329
|
+
... raise NotHandled
|
330
|
+
>>> fvs = fvsof(f(1, 2))
|
331
|
+
>>> assert f in fvs
|
332
|
+
>>> assert len(fvs) == 1
|
306
333
|
"""
|
307
334
|
from effectful.internals.runtime import interpreter
|
308
335
|
|
309
|
-
_fvs:
|
336
|
+
_fvs: set[Operation] = set()
|
310
337
|
|
311
|
-
def _update_fvs(
|
338
|
+
def _update_fvs(op, *args, **kwargs):
|
312
339
|
_fvs.add(op)
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
*(k for k in kwarg_ctxs.values()),
|
317
|
-
)
|
318
|
-
for bound_var in bound_vars:
|
340
|
+
bindings = op.__fvs_rule__(*args, **kwargs)
|
341
|
+
for bound_var in set().union(*(*bindings.args, *bindings.kwargs.values())):
|
342
|
+
assert isinstance(bound_var, Operation)
|
319
343
|
if bound_var in _fvs:
|
320
344
|
_fvs.remove(bound_var)
|
321
345
|
|