effectful 0.0.1__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {effectful-0.0.1 → effectful-0.1.0}/PKG-INFO +17 -12
- {effectful-0.0.1 → effectful-0.1.0}/README.rst +3 -6
- {effectful-0.0.1 → effectful-0.1.0}/effectful/handlers/indexed.py +4 -22
- {effectful-0.0.1 → effectful-0.1.0}/effectful/handlers/numbers.py +10 -6
- {effectful-0.0.1 → effectful-0.1.0}/effectful/handlers/pyro.py +2 -2
- {effectful-0.0.1 → effectful-0.1.0}/effectful/handlers/torch.py +33 -10
- {effectful-0.0.1 → effectful-0.1.0}/effectful/ops/semantics.py +25 -29
- effectful-0.1.0/effectful/ops/syntax.py +1070 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/ops/types.py +27 -13
- {effectful-0.0.1 → effectful-0.1.0}/effectful.egg-info/PKG-INFO +17 -12
- {effectful-0.0.1 → effectful-0.1.0}/effectful.egg-info/SOURCES.txt +0 -1
- {effectful-0.0.1 → effectful-0.1.0}/effectful.egg-info/requires.txt +11 -2
- {effectful-0.0.1 → effectful-0.1.0}/setup.py +16 -7
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_handlers_numbers.py +7 -4
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_ops_semantics.py +58 -8
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_ops_syntax.py +40 -8
- effectful-0.0.1/effectful/internals/base_impl.py +0 -259
- effectful-0.0.1/effectful/ops/syntax.py +0 -523
- {effectful-0.0.1 → effectful-0.1.0}/LICENSE.md +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/handlers/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/internals/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/internals/runtime.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/ops/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful/py.typed +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful.egg-info/dependency_links.txt +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/effectful.egg-info/top_level.txt +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/setup.cfg +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_examples_minipyro.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_handlers_indexed.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_handlers_pyro.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_handlers_pyro_dist.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_handlers_torch.py +0 -0
- {effectful-0.0.1 → effectful-0.1.0}/tests/test_semi_ring.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: effectful
|
3
|
-
Version: 0.0.1
|
3
|
+
Version: 0.1.0
|
4
4
|
Summary: Metaprogramming infrastructure
|
5
5
|
Home-page: https://www.basis.ai/
|
6
6
|
Author: Basis
|
@@ -28,6 +28,12 @@ Provides-Extra: dev
|
|
28
28
|
Requires-Dist: torch; extra == "dev"
|
29
29
|
Requires-Dist: pyro-ppl; extra == "dev"
|
30
30
|
Requires-Dist: torch; extra == "dev"
|
31
|
+
Requires-Dist: setuptools; extra == "dev"
|
32
|
+
Requires-Dist: sphinx; extra == "dev"
|
33
|
+
Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
|
34
|
+
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
35
|
+
Requires-Dist: myst-parser; extra == "dev"
|
36
|
+
Requires-Dist: nbsphinx; extra == "dev"
|
31
37
|
Requires-Dist: pytest; extra == "dev"
|
32
38
|
Requires-Dist: pytest-cov; extra == "dev"
|
33
39
|
Requires-Dist: pytest-xdist; extra == "dev"
|
@@ -36,13 +42,15 @@ Requires-Dist: mypy; extra == "dev"
|
|
36
42
|
Requires-Dist: black; extra == "dev"
|
37
43
|
Requires-Dist: flake8; extra == "dev"
|
38
44
|
Requires-Dist: isort; extra == "dev"
|
39
|
-
Requires-Dist: sphinx; extra == "dev"
|
40
|
-
Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
|
41
|
-
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
42
|
-
Requires-Dist: myst-parser; extra == "dev"
|
43
|
-
Requires-Dist: nbsphinx; extra == "dev"
|
44
45
|
Requires-Dist: nbval; extra == "dev"
|
45
46
|
Requires-Dist: nbqa; extra == "dev"
|
47
|
+
Provides-Extra: docs
|
48
|
+
Requires-Dist: setuptools; extra == "docs"
|
49
|
+
Requires-Dist: sphinx; extra == "docs"
|
50
|
+
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
|
51
|
+
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
52
|
+
Requires-Dist: myst-parser; extra == "docs"
|
53
|
+
Requires-Dist: nbsphinx; extra == "docs"
|
46
54
|
Dynamic: author
|
47
55
|
Dynamic: classifier
|
48
56
|
Dynamic: description
|
@@ -55,6 +63,7 @@ Dynamic: requires-dist
|
|
55
63
|
Dynamic: requires-python
|
56
64
|
Dynamic: summary
|
57
65
|
|
66
|
+
|
58
67
|
.. index-inclusion-marker
|
59
68
|
|
60
69
|
Effectful
|
@@ -75,7 +84,7 @@ Install From Source
|
|
75
84
|
git clone git@github.com:BasisResearch/effectful.git
|
76
85
|
cd effectful
|
77
86
|
git checkout master
|
78
|
-
pip install -e .[
|
87
|
+
pip install -e .[pyro]
|
79
88
|
|
80
89
|
Install With Optional PyTorch/Pyro Support
|
81
90
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
@@ -103,16 +112,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
103
112
|
|
104
113
|
.. code:: python
|
105
114
|
|
106
|
-
import operator
|
107
115
|
import functools
|
108
116
|
|
109
|
-
from effectful.handlers.operator import OPERATORS
|
110
117
|
from effectful.ops.types import Term
|
111
118
|
from effectful.ops.syntax import defop
|
112
119
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
113
|
-
|
114
|
-
|
115
|
-
add = OPERATORS[operator.add]
|
120
|
+
from effectful.handlers.numbers import add
|
116
121
|
|
117
122
|
def beta_add(x: int, y: int) -> int:
|
118
123
|
match x, y:
|
@@ -1,3 +1,4 @@
|
|
1
|
+
|
1
2
|
.. index-inclusion-marker
|
2
3
|
|
3
4
|
Effectful
|
@@ -18,7 +19,7 @@ Install From Source
|
|
18
19
|
git clone git@github.com:BasisResearch/effectful.git
|
19
20
|
cd effectful
|
20
21
|
git checkout master
|
21
|
-
pip install -e .[
|
22
|
+
pip install -e .[pyro]
|
22
23
|
|
23
24
|
Install With Optional PyTorch/Pyro Support
|
24
25
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
@@ -46,16 +47,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
46
47
|
|
47
48
|
.. code:: python
|
48
49
|
|
49
|
-
import operator
|
50
50
|
import functools
|
51
51
|
|
52
|
-
from effectful.handlers.operator import OPERATORS
|
53
52
|
from effectful.ops.types import Term
|
54
53
|
from effectful.ops.syntax import defop
|
55
54
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
56
|
-
|
57
|
-
|
58
|
-
add = OPERATORS[operator.add]
|
55
|
+
from effectful.handlers.numbers import add
|
59
56
|
|
60
57
|
def beta_add(x: int, y: int) -> int:
|
61
58
|
match x, y:
|
@@ -6,7 +6,7 @@ import torch
|
|
6
6
|
|
7
7
|
from effectful.handlers.torch import Indexable, sizesof
|
8
8
|
from effectful.ops.syntax import deffn, defop
|
9
|
-
from effectful.ops.types import Operation
|
9
|
+
from effectful.ops.types import Operation
|
10
10
|
|
11
11
|
K = TypeVar("K")
|
12
12
|
T = TypeVar("T")
|
@@ -61,16 +61,6 @@ class IndexSet(Dict[str, Set[int]]):
|
|
61
61
|
def __hash__(self):
|
62
62
|
return hash(frozenset((k, frozenset(vs)) for k, vs in self.items()))
|
63
63
|
|
64
|
-
def _to_handler(self):
|
65
|
-
"""Return an effectful handler that binds each index variable to a
|
66
|
-
tensor of its possible index values.
|
67
|
-
|
68
|
-
"""
|
69
|
-
return {
|
70
|
-
name_to_sym(k): functools.partial(lambda v: v, torch.tensor(list(v)))
|
71
|
-
for k, v in self.items()
|
72
|
-
}
|
73
|
-
|
74
64
|
|
75
65
|
def union(*indexsets: IndexSet) -> IndexSet:
|
76
66
|
"""
|
@@ -166,17 +156,9 @@ def indices_of(value: Any) -> IndexSet:
|
|
166
156
|
:param kwargs: Additional keyword arguments used by specific implementations.
|
167
157
|
:return: A :class:`IndexSet` containing the indices on which the value is supported.
|
168
158
|
"""
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
k.__name__: set(range(v)) # type:ignore
|
173
|
-
for (k, v) in sizesof(value).items()
|
174
|
-
}
|
175
|
-
)
|
176
|
-
elif isinstance(value, torch.distributions.Distribution):
|
177
|
-
return indices_of(value.sample())
|
178
|
-
|
179
|
-
return IndexSet()
|
159
|
+
return IndexSet(
|
160
|
+
**{getattr(k, "__name__"): set(range(v)) for (k, v) in sizesof(value).items()}
|
161
|
+
)
|
180
162
|
|
181
163
|
|
182
164
|
@functools.lru_cache(maxsize=None)
|
@@ -1,11 +1,15 @@
|
|
1
|
+
"""
|
2
|
+
This module provides a term representation for numbers and operations on them.
|
3
|
+
"""
|
4
|
+
|
1
5
|
import numbers
|
2
6
|
import operator
|
3
7
|
from typing import Any, TypeVar
|
4
8
|
|
5
9
|
from typing_extensions import ParamSpec
|
6
10
|
|
7
|
-
from effectful.ops.syntax import
|
8
|
-
from effectful.ops.types import Operation, Term
|
11
|
+
from effectful.ops.syntax import defdata, defop, syntactic_eq
|
12
|
+
from effectful.ops.types import Expr, Operation, Term
|
9
13
|
|
10
14
|
P = ParamSpec("P")
|
11
15
|
Q = ParamSpec("Q")
|
@@ -20,7 +24,7 @@ T_Number = TypeVar("T_Number", bound=numbers.Number)
|
|
20
24
|
@numbers.Number.register
|
21
25
|
class _NumberTerm(Term[numbers.Number]):
|
22
26
|
def __init__(
|
23
|
-
self, op: Operation[..., numbers.Number], args:
|
27
|
+
self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr
|
24
28
|
) -> None:
|
25
29
|
self._op = op
|
26
30
|
self._args = args
|
@@ -56,7 +60,7 @@ def _wrap_cmp(op):
|
|
56
60
|
if not any(isinstance(a, Term) for a in (x, y)):
|
57
61
|
return op(x, y)
|
58
62
|
else:
|
59
|
-
raise
|
63
|
+
raise NotImplementedError
|
60
64
|
|
61
65
|
_wrapped_op.__name__ = op.__name__
|
62
66
|
return _wrapped_op
|
@@ -67,7 +71,7 @@ def _wrap_binop(op):
|
|
67
71
|
if not any(isinstance(a, Term) for a in (x, y)):
|
68
72
|
return op(x, y)
|
69
73
|
else:
|
70
|
-
raise
|
74
|
+
raise NotImplementedError
|
71
75
|
|
72
76
|
_wrapped_op.__name__ = op.__name__
|
73
77
|
return _wrapped_op
|
@@ -78,7 +82,7 @@ def _wrap_unop(op):
|
|
78
82
|
if not isinstance(x, Term):
|
79
83
|
return op(x)
|
80
84
|
else:
|
81
|
-
raise
|
85
|
+
raise NotImplementedError
|
82
86
|
|
83
87
|
_wrapped_op.__name__ = op.__name__
|
84
88
|
return _wrapped_op
|
@@ -264,7 +264,7 @@ class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribu
|
|
264
264
|
self, base_dist: pyro.distributions.torch_distribution.TorchDistribution
|
265
265
|
):
|
266
266
|
self.base_dist = base_dist
|
267
|
-
self.indices = sizesof(base_dist
|
267
|
+
self.indices = sizesof(base_dist)
|
268
268
|
|
269
269
|
n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
|
270
270
|
self.naming = Naming.from_shape(self.indices.keys(), n_base)
|
@@ -361,7 +361,7 @@ class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution)
|
|
361
361
|
self.names = names
|
362
362
|
|
363
363
|
assert 1 <= len(names) <= len(base_dist.batch_shape)
|
364
|
-
base_indices = sizesof(base_dist
|
364
|
+
base_indices = sizesof(base_dist)
|
365
365
|
assert not any(n in base_indices for n in names)
|
366
366
|
|
367
367
|
n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
|
@@ -12,10 +12,9 @@ import tree
|
|
12
12
|
from typing_extensions import ParamSpec
|
13
13
|
|
14
14
|
import effectful.handlers.numbers # noqa: F401
|
15
|
-
from effectful.internals.base_impl import _BaseTerm
|
16
15
|
from effectful.internals.runtime import interpreter
|
17
16
|
from effectful.ops.semantics import apply, evaluate, fvsof, typeof
|
18
|
-
from effectful.ops.syntax import
|
17
|
+
from effectful.ops.syntax import defdata, defop
|
19
18
|
from effectful.ops.types import Expr, Operation, Term
|
20
19
|
|
21
20
|
P = ParamSpec("P")
|
@@ -90,6 +89,11 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
|
|
90
89
|
>>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
|
91
90
|
{a: 2, b: 3}
|
92
91
|
"""
|
92
|
+
if isinstance(value, torch.distributions.Distribution) and not isinstance(
|
93
|
+
value, Term
|
94
|
+
):
|
95
|
+
return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()}
|
96
|
+
|
93
97
|
sizes: dict[Operation[[], int], int] = {}
|
94
98
|
|
95
99
|
def _torch_getitem_sizeof(
|
@@ -111,12 +115,12 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
|
|
111
115
|
)
|
112
116
|
sizes[k.op] = shape[i]
|
113
117
|
|
114
|
-
return torch_getitem
|
118
|
+
return defdata(torch_getitem, x, key)
|
115
119
|
|
116
120
|
with interpreter(
|
117
121
|
{
|
118
122
|
torch_getitem: _torch_getitem_sizeof,
|
119
|
-
apply: lambda _, op, *a, **k: op
|
123
|
+
apply: lambda _, op, *a, **k: defdata(op, *a, **k),
|
120
124
|
}
|
121
125
|
):
|
122
126
|
evaluate(value)
|
@@ -204,7 +208,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
|
|
204
208
|
@defop
|
205
209
|
def _torch_op(*args, **kwargs) -> torch.Tensor:
|
206
210
|
|
207
|
-
tm = _torch_op
|
211
|
+
tm = defdata(_torch_op, *args, **kwargs)
|
208
212
|
sized_fvs = sizesof(tm)
|
209
213
|
|
210
214
|
if (
|
@@ -214,7 +218,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
|
|
214
218
|
and args[1]
|
215
219
|
and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
|
216
220
|
):
|
217
|
-
raise
|
221
|
+
raise NotImplementedError
|
218
222
|
elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
|
219
223
|
torch_getitem,
|
220
224
|
_torch_op,
|
@@ -230,7 +234,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
|
|
230
234
|
):
|
231
235
|
return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
|
232
236
|
else:
|
233
|
-
raise
|
237
|
+
raise NotImplementedError
|
234
238
|
|
235
239
|
functools.update_wrapper(_torch_op, torch_fn)
|
236
240
|
return _torch_op
|
@@ -315,7 +319,7 @@ class Indexable:
|
|
315
319
|
|
316
320
|
|
317
321
|
@defdata.register(torch.Tensor)
|
318
|
-
def _embed_tensor(op, args, kwargs):
|
322
|
+
def _embed_tensor(op, *args, **kwargs):
|
319
323
|
if (
|
320
324
|
op is torch_getitem
|
321
325
|
and not isinstance(args[0], Term)
|
@@ -328,10 +332,29 @@ def _embed_tensor(op, args, kwargs):
|
|
328
332
|
):
|
329
333
|
return _EagerTensorTerm(args[0], args[1])
|
330
334
|
else:
|
331
|
-
return _TensorTerm(op, args, kwargs)
|
335
|
+
return _TensorTerm(op, *args, **kwargs)
|
336
|
+
|
332
337
|
|
338
|
+
class _TensorTerm(Term[torch.Tensor]):
|
339
|
+
def __init__(
|
340
|
+
self, op: Operation[..., torch.Tensor], *args: Expr, **kwargs: Expr
|
341
|
+
) -> None:
|
342
|
+
self._op = op
|
343
|
+
self._args = args
|
344
|
+
self._kwargs = kwargs
|
345
|
+
|
346
|
+
@property
|
347
|
+
def op(self) -> Operation[..., torch.Tensor]:
|
348
|
+
return self._op
|
349
|
+
|
350
|
+
@property
|
351
|
+
def args(self) -> tuple:
|
352
|
+
return self._args
|
353
|
+
|
354
|
+
@property
|
355
|
+
def kwargs(self) -> dict:
|
356
|
+
return self._kwargs
|
333
357
|
|
334
|
-
class _TensorTerm(_BaseTerm[torch.Tensor]):
|
335
358
|
def __getitem__(
|
336
359
|
self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
|
337
360
|
) -> Expr[torch.Tensor]:
|
@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Set, Type, TypeVar
|
|
5
5
|
import tree
|
6
6
|
from typing_extensions import ParamSpec
|
7
7
|
|
8
|
-
from effectful.ops.syntax import
|
8
|
+
from effectful.ops.syntax import deffn, defop
|
9
9
|
from effectful.ops.types import Expr, Interpretation, Operation, Term
|
10
10
|
|
11
11
|
P = ParamSpec("P")
|
@@ -15,10 +15,8 @@ T = TypeVar("T")
|
|
15
15
|
V = TypeVar("V")
|
16
16
|
|
17
17
|
|
18
|
-
@defop
|
19
|
-
def apply(
|
20
|
-
intp: Interpretation[S, T], op: Operation[P, S], *args: P.args, **kwargs: P.kwargs
|
21
|
-
) -> T:
|
18
|
+
@defop
|
19
|
+
def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
|
22
20
|
"""Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
|
23
21
|
|
24
22
|
Handling :func:`apply` changes the evaluation strategy of terms.
|
@@ -50,7 +48,7 @@ def apply(
|
|
50
48
|
elif apply in intp:
|
51
49
|
return intp[apply](intp, op, *args, **kwargs)
|
52
50
|
else:
|
53
|
-
return op.__default_rule__(*args, **kwargs)
|
51
|
+
return op.__default_rule__(*args, **kwargs)
|
54
52
|
|
55
53
|
|
56
54
|
@defop # type: ignore
|
@@ -60,9 +58,6 @@ def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
|
60
58
|
This operation is invoked by the ``__call__`` method of a callable term.
|
61
59
|
|
62
60
|
"""
|
63
|
-
if not isinstance(fn, Term):
|
64
|
-
fn = defterm(fn)
|
65
|
-
|
66
61
|
if isinstance(fn, Term) and fn.op is deffn:
|
67
62
|
body: Expr[Callable[P, T]] = fn.args[0]
|
68
63
|
argvars: tuple[Operation, ...] = fn.args[1:]
|
@@ -73,8 +68,10 @@ def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
|
73
68
|
}
|
74
69
|
with handler(subs):
|
75
70
|
return evaluate(body)
|
71
|
+
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
|
72
|
+
return fn(*args, **kwargs)
|
76
73
|
else:
|
77
|
-
raise
|
74
|
+
raise NotImplementedError
|
78
75
|
|
79
76
|
|
80
77
|
@defop
|
@@ -93,9 +90,7 @@ def fwd(*args, **kwargs) -> Any:
|
|
93
90
|
raise RuntimeError("fwd should only be called in the context of a handler")
|
94
91
|
|
95
92
|
|
96
|
-
def coproduct(
|
97
|
-
intp: Interpretation[S, T], intp2: Interpretation[S, T]
|
98
|
-
) -> Interpretation[S, T]:
|
93
|
+
def coproduct(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
99
94
|
"""The coproduct of two interpretations handles any effect that is handled
|
100
95
|
by either. If both interpretations handle an effect, ``intp2`` takes
|
101
96
|
precedence.
|
@@ -151,7 +146,7 @@ def coproduct(
|
|
151
146
|
if op is fwd or op is _get_args:
|
152
147
|
res[op] = i2 # fast path for special cases, should be equivalent if removed
|
153
148
|
else:
|
154
|
-
i1 = intp.get(op, op.__default_rule__)
|
149
|
+
i1 = intp.get(op, op.__default_rule__)
|
155
150
|
|
156
151
|
# calling fwd in the right handler should dispatch to the left handler
|
157
152
|
res[op] = _set_prompt(fwd, _restore_args(_save_args(i1)), _save_args(i2))
|
@@ -159,9 +154,7 @@ def coproduct(
|
|
159
154
|
return res
|
160
155
|
|
161
156
|
|
162
|
-
def product(
|
163
|
-
intp: Interpretation[S, T], intp2: Interpretation[S, T]
|
164
|
-
) -> Interpretation[S, T]:
|
157
|
+
def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
|
165
158
|
"""The product of two interpretations handles any effect that is handled by
|
166
159
|
``intp2``. Handlers in ``intp2`` may override handlers in ``intp``, but
|
167
160
|
those changes are not visible to the handlers in ``intp``. In this way,
|
@@ -207,7 +200,7 @@ def product(
|
|
207
200
|
|
208
201
|
|
209
202
|
@contextlib.contextmanager
|
210
|
-
def runner(intp: Interpretation[S, T]):
|
203
|
+
def runner(intp: Interpretation):
|
211
204
|
"""Install an interpretation by taking a product with the current
|
212
205
|
interpretation.
|
213
206
|
|
@@ -223,7 +216,7 @@ def runner(intp: Interpretation[S, T]):
|
|
223
216
|
|
224
217
|
|
225
218
|
@contextlib.contextmanager
|
226
|
-
def handler(intp: Interpretation[S, T]):
|
219
|
+
def handler(intp: Interpretation):
|
227
220
|
"""Install an interpretation by taking a coproduct with the current
|
228
221
|
interpretation.
|
229
222
|
|
@@ -234,7 +227,7 @@ def handler(intp: Interpretation[S, T]):
|
|
234
227
|
yield intp
|
235
228
|
|
236
229
|
|
237
|
-
def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> Expr[T]:
|
230
|
+
def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]:
|
238
231
|
"""Evaluate expression ``expr`` using interpretation ``intp``. If no
|
239
232
|
interpretation is provided, uses the current interpretation.
|
240
233
|
|
@@ -245,7 +238,7 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> E
|
|
245
238
|
|
246
239
|
>>> @defop
|
247
240
|
... def add(x: int, y: int) -> int:
|
248
|
-
... raise
|
241
|
+
... raise NotImplementedError
|
249
242
|
>>> expr = add(1, add(2, 3))
|
250
243
|
>>> expr
|
251
244
|
add(1, add(2, 3))
|
@@ -258,13 +251,11 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> E
|
|
258
251
|
|
259
252
|
intp = get_interpretation()
|
260
253
|
|
261
|
-
expr = defterm(expr) if not isinstance(expr, Term) else expr
|
262
|
-
|
263
254
|
if isinstance(expr, Term):
|
264
255
|
(args, kwargs) = tree.map_structure(
|
265
256
|
functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
|
266
257
|
)
|
267
|
-
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
|
258
|
+
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
|
268
259
|
elif tree.is_nested(expr):
|
269
260
|
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
|
270
261
|
else:
|
@@ -280,7 +271,7 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
280
271
|
|
281
272
|
>>> @defop
|
282
273
|
... def cmp(x: int, y: int) -> bool:
|
283
|
-
... raise
|
274
|
+
... raise NotImplementedError
|
284
275
|
>>> typeof(cmp(1, 2))
|
285
276
|
<class 'bool'>
|
286
277
|
|
@@ -290,7 +281,7 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
290
281
|
>>> T = TypeVar('T')
|
291
282
|
>>> @defop
|
292
283
|
... def if_then_else(x: bool, a: T, b: T) -> T:
|
293
|
-
... raise
|
284
|
+
... raise NotImplementedError
|
294
285
|
>>> typeof(if_then_else(True, 0, 1))
|
295
286
|
<class 'int'>
|
296
287
|
|
@@ -298,7 +289,7 @@ def typeof(term: Expr[T]) -> Type[T]:
|
|
298
289
|
from effectful.internals.runtime import interpreter
|
299
290
|
|
300
291
|
with interpreter({apply: lambda _, op, *a, **k: op.__type_rule__(*a, **k)}):
|
301
|
-
return evaluate(term) # type: ignore
|
292
|
+
return evaluate(term) if isinstance(term, Term) else type(term) # type: ignore
|
302
293
|
|
303
294
|
|
304
295
|
def fvsof(term: Expr[S]) -> Set[Operation]:
|
@@ -308,7 +299,7 @@ def fvsof(term: Expr[S]) -> Set[Operation]:
|
|
308
299
|
|
309
300
|
>>> @defop
|
310
301
|
... def f(x: int, y: int) -> int:
|
311
|
-
... raise
|
302
|
+
... raise NotImplementedError
|
312
303
|
>>> fvsof(f(1, 2))
|
313
304
|
{f}
|
314
305
|
|
@@ -319,7 +310,12 @@ def fvsof(term: Expr[S]) -> Set[Operation]:
|
|
319
310
|
|
320
311
|
def _update_fvs(_, op, *args, **kwargs):
|
321
312
|
_fvs.add(op)
|
322
|
-
|
313
|
+
arg_ctxs, kwarg_ctxs = op.__fvs_rule__(*args, **kwargs)
|
314
|
+
bound_vars = set().union(
|
315
|
+
*(a for a in arg_ctxs),
|
316
|
+
*(k for k in kwarg_ctxs.values()),
|
317
|
+
)
|
318
|
+
for bound_var in bound_vars:
|
323
319
|
if bound_var in _fvs:
|
324
320
|
_fvs.remove(bound_var)
|
325
321
|
|