effectful 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/__init__.py +0 -0
- effectful/handlers/__init__.py +0 -0
- effectful/handlers/indexed.py +320 -0
- effectful/handlers/numbers.py +259 -0
- effectful/handlers/pyro.py +466 -0
- effectful/handlers/torch.py +572 -0
- effectful/internals/__init__.py +0 -0
- effectful/internals/base_impl.py +259 -0
- effectful/internals/runtime.py +78 -0
- effectful/ops/__init__.py +0 -0
- effectful/ops/semantics.py +329 -0
- effectful/ops/syntax.py +523 -0
- effectful/ops/types.py +110 -0
- effectful/py.typed +0 -0
- effectful-0.0.1.dist-info/LICENSE.md +202 -0
- effectful-0.0.1.dist-info/METADATA +170 -0
- effectful-0.0.1.dist-info/RECORD +19 -0
- effectful-0.0.1.dist-info/WHEEL +5 -0
- effectful-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,259 @@
|
|
1
|
+
import collections
|
2
|
+
import functools
|
3
|
+
import inspect
|
4
|
+
import typing
|
5
|
+
from typing import Callable, Generic, Mapping, Sequence, Set, Type, TypeVar
|
6
|
+
|
7
|
+
import tree
|
8
|
+
from typing_extensions import ParamSpec
|
9
|
+
|
10
|
+
from effectful.ops.types import Expr, Operation, Term
|
11
|
+
|
12
|
+
P = ParamSpec("P")
|
13
|
+
Q = ParamSpec("Q")
|
14
|
+
S = TypeVar("S")
|
15
|
+
T = TypeVar("T")
|
16
|
+
V = TypeVar("V")
|
17
|
+
|
18
|
+
|
19
|
+
def rename(
    subs: Mapping[Operation[..., S], Operation[..., S]],
    leaf_value: V,  # Union[Term[V], Operation[..., V], V],
) -> V:  # Union[Term[V], Operation[..., V], V]:
    """Substitute operations inside ``leaf_value`` according to ``subs``.

    - An :class:`Operation` is replaced by ``subs[leaf_value]`` when present
      and returned unchanged otherwise.
    - A :class:`Term` is re-evaluated under an interpretation in which every
      application is rebuilt as a free term (``__free_rule__``), except that
      operations that are keys of ``subs`` are dispatched to their
      replacement.
    - Any other value is returned as-is.

    :param subs: Mapping from operations to the operations replacing them.
    :param leaf_value: The value in which to perform the substitution.
    :returns: The renamed value.
    """
    # Deferred imports (presumably to avoid an import cycle between modules).
    from effectful.internals.runtime import interpreter
    from effectful.ops.semantics import apply, evaluate

    if isinstance(leaf_value, Operation):
        return subs.get(leaf_value, leaf_value)  # type: ignore

    if not isinstance(leaf_value, Term):
        return leaf_value

    def _rebuild_free(_, op, *args, **kwargs):
        return op.__free_rule__(*args, **kwargs)

    renaming_intp = {apply: _rebuild_free, **subs}
    with interpreter(renaming_intp):
        return evaluate(leaf_value)  # type: ignore
|
35
|
+
|
36
|
+
|
37
|
+
class _BaseOperation(Generic[Q, V], Operation[Q, V]):
    """Default :class:`Operation` implementation wrapping a signature callable.

    ``signature`` plays three roles: :func:`functools.update_wrapper` copies
    its metadata (name, docstring, ...) onto the operation; calling it is the
    operation's default behaviour; and its parameter/return annotations drive
    bound-variable renaming (``__free_rule__``), bound-variable discovery
    (``__fvs_rule__``) and result-type inference (``__type_rule__``).
    Equality and hashing are delegated to ``signature``.
    """

    signature: Callable[Q, V]

    def __init__(self, signature: Callable[Q, V]):
        functools.update_wrapper(self, signature)
        self.signature = signature

    def __eq__(self, other):
        # NOTE(review): assumes every Operation exposes ``signature``.
        if isinstance(other, Operation):
            return self.signature == other.signature
        return NotImplemented

    def __hash__(self):
        return hash(self.signature)

    def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
        """Call ``signature``; if it raises ``NoDefaultRule``, build a free term."""
        from effectful.ops.syntax import NoDefaultRule

        try:
            return self.signature(*args, **kwargs)
        except NoDefaultRule:
            return self.__free_rule__(*args, **kwargs)

    def __free_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
        """Build an unevaluated term for this application.

        Arguments to parameters annotated ``Annotated[..., Bound(scope)]`` are
        bound variables of ``scope``. Those parameters, and parameters annotated
        ``Scoped(scope)``, are the arguments in which the variables of ``scope``
        are renamed to fresh operations (``defop(var)``), scope by scope from the
        lowest scope number upwards.
        """
        from effectful.ops.syntax import Bound, Scoped, defdata, defop

        call_args = inspect.signature(self.signature).bind(*args, **kwargs)
        call_args.apply_defaults()

        binders_by_scope: dict[int, set[Operation]] = collections.defaultdict(set)
        names_by_scope: dict[int, set[str]] = collections.defaultdict(set)
        unannotated: set[str] = set()

        for name, param in call_args.signature.parameters.items():
            if typing.get_origin(param.annotation) is not typing.Annotated:
                unannotated.add(name)
                continue
            # apply_defaults() guarantees every parameter has an entry here.
            value = call_args.arguments[name]
            for marker in param.annotation.__metadata__:
                if isinstance(marker, Bound):
                    names_by_scope[marker.scope].add(name)
                    if param.kind is inspect.Parameter.VAR_POSITIONAL:
                        assert isinstance(value, tuple)
                        binders_by_scope[marker.scope].update(value)
                    elif param.kind is inspect.Parameter.VAR_KEYWORD:
                        assert isinstance(value, dict)
                        binders_by_scope[marker.scope].update(value.values())
                    else:
                        binders_by_scope[marker.scope].add(value)
                elif isinstance(marker, Scoped):
                    names_by_scope[marker.scope].add(name)

        # TODO replace this temporary check with more general scope level propagation
        if binders_by_scope:
            # Unannotated arguments are renamed together with the lowest scope.
            names_by_scope[min(binders_by_scope.keys(), default=0)] |= unannotated
            outermost = max(binders_by_scope.keys(), default=0)
            assert all(
                s in binders_by_scope or s > outermost for s in names_by_scope.keys()
            )

        # Rename bound variables from innermost (lowest number) to outermost scope.
        for scope in sorted(binders_by_scope):
            fresh = {var: defop(var) for var in binders_by_scope[scope]}
            for name in names_by_scope[scope]:
                call_args.arguments[name] = tree.map_structure(
                    functools.partial(rename, fresh), call_args.arguments[name]
                )

        return defdata(_BaseTerm(self, call_args.args, call_args.kwargs))

    def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Type[V]:
        """Infer the result type of applying this operation to the arguments.

        Based on the return annotation of ``signature``:

        - missing: ``object``;
        - a TypeVar: the type of the first non-variadic argument annotated with
          that same TypeVar (the argument itself if it is a class), else
          ``object``;
        - ``Annotated[X, ...]``: ``X``, reduced to its origin if subscripted;
        - any other subscripted alias: its origin;
        - anything else: the annotation itself.
        """
        sig = inspect.signature(self.signature)
        call_args = sig.bind(*args, **kwargs)
        call_args.apply_defaults()

        ret = sig.return_annotation
        if ret is inspect.Signature.empty:
            return typing.cast(Type[V], object)

        if isinstance(ret, typing.TypeVar):
            # rudimentary but sound special-case type inference sufficient for
            # syntax ops: reuse the type of an argument sharing the TypeVar.
            variadic = (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            )
            for name, param in call_args.signature.parameters.items():
                if param.annotation is ret and param.kind not in variadic:
                    value = call_args.arguments[name]
                    return value if isinstance(value, type) else type(value)
            return typing.cast(Type[V], object)

        if typing.get_origin(ret) is typing.Annotated:
            inner = typing.get_args(ret)[0]
            if not typing.TYPE_CHECKING:
                inner_origin = typing.get_origin(inner)
                inner = inner if inner_origin is None else inner_origin
            return inner

        origin = typing.get_origin(ret)
        return ret if origin is None else origin

    def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Set[Operation]:
        """Return the variables *bound* by this application.

        These are the arguments (or the elements of variadic arguments) passed
        to parameters annotated ``Annotated[..., Bound(...)]``.
        """
        from effectful.ops.syntax import Bound

        call_args = inspect.signature(self.signature).bind(*args, **kwargs)
        call_args.apply_defaults()

        binders: Set[Operation] = set()
        for name, param in call_args.signature.parameters.items():
            if typing.get_origin(param.annotation) is not typing.Annotated:
                continue
            for marker in param.annotation.__metadata__:
                if not isinstance(marker, Bound):
                    continue
                value = call_args.arguments[name]
                if param.kind is inspect.Parameter.VAR_POSITIONAL:
                    binders.update(value)
                elif param.kind is inspect.Parameter.VAR_KEYWORD:
                    binders.update(value.values())
                else:
                    binders.add(value)
        return binders

    def __repr_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> str:
        """Render an application as ``name(arg, ..., key=value, ...)``."""
        pieces = [str(a) for a in args]
        pieces.extend(f"{k}={str(v)}" for k, v in kwargs.items())
        return f"{self.signature.__name__}({', '.join(pieces)})"

    def __repr__(self):
        return self.signature.__name__
|
180
|
+
|
181
|
+
|
182
|
+
class _BaseTerm(Generic[T], Term[T]):
    """Concrete :class:`Term`: an operation applied to arguments.

    Equality is syntactic and delegated to ``syntactic_eq``. Because this class
    defines ``__eq__`` without ``__hash__``, Python sets ``__hash__`` to
    ``None`` here, so instances are unhashable.
    """

    _op: Operation[..., T]
    _args: Sequence[Expr]
    _kwargs: Mapping[str, Expr]

    def __init__(
        self,
        op: Operation[..., T],
        args: Sequence[Expr],
        kwargs: Mapping[str, Expr],
    ):
        self._op, self._args, self._kwargs = op, args, kwargs

    def __eq__(self, other) -> bool:
        from effectful.ops.syntax import syntactic_eq

        return syntactic_eq(self, other)

    @property
    def op(self):
        """The operation at the head of this term."""
        return self._op

    @property
    def args(self):
        """The positional arguments of this term."""
        return self._args

    @property
    def kwargs(self):
        """The keyword arguments of this term."""
        return self._kwargs
|
213
|
+
|
214
|
+
|
215
|
+
class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]):
    """A term whose value is a callable.

    Calling it applies the ``call`` operation to the term and the arguments.
    """

    def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
        # Deferred import of the ``call`` operation.
        from effectful.ops.semantics import call

        return call(self, *args, **kwargs)  # type: ignore
|
220
|
+
|
221
|
+
|
222
|
+
def _unembed_callable(value: Callable[P, T]) -> Expr[Callable[P, T]]:
    """Convert a Python callable into a ``deffn`` term.

    A fresh variable (``defop(param.annotation)``) is created for every
    parameter. ``value`` is then called on those variables under an
    interpretation in which operations build free terms instead of running.
    The resulting body is wrapped with :func:`deffn` together with the
    variables.

    :param value: A callable that is not already a :class:`Term`.
    :returns: A ``deffn`` term, or ``value`` unchanged when
        :func:`inspect.signature` cannot produce a signature for it.
    :raises NotImplementedError: If ``value`` takes ``*args`` or ``**kwargs``.
    """
    from effectful.internals.runtime import interpreter
    from effectful.ops.semantics import apply, call
    from effectful.ops.syntax import deffn, defop

    assert not isinstance(value, Term)

    try:
        sig = inspect.signature(value)
    except ValueError:
        return value

    for name, param in sig.parameters.items():
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            raise NotImplementedError(
                f"cannot unembed {value}: parameter {name} is variadic"
            )

    # One fresh variable per parameter, created in declaration order.
    fresh_vars = {
        name: defop(param.annotation) for name, param in sig.parameters.items()
    }
    # Positional-only parameters cannot be bound by keyword (Signature.bind
    # raises TypeError), so pass them positionally. They always come first in
    # a signature. All other parameters are bound by name as before.
    positional_only = [
        fresh_vars[name]
        for name, param in sig.parameters.items()
        if param.kind is inspect.Parameter.POSITIONAL_ONLY
    ]
    by_name = {
        name: var
        for name, var in fresh_vars.items()
        if sig.parameters[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    }
    bound_sig = sig.bind(*positional_only, **by_name)
    bound_sig.apply_defaults()

    # Build the body symbolically: every operation application becomes a free
    # term, while ``call`` keeps its default rule.
    with interpreter(
        {
            apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
            call: call.__default_rule__,
        }
    ):
        body = value(
            *[a() for a in bound_sig.args],
            **{k: v() for k, v in bound_sig.kwargs.items()},
        )

    return deffn(body, *bound_sig.args, **bound_sig.kwargs)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
import contextlib
|
2
|
+
import dataclasses
|
3
|
+
import functools
|
4
|
+
from typing import Callable, Generic, Mapping, Tuple, TypeVar
|
5
|
+
|
6
|
+
from typing_extensions import ParamSpec
|
7
|
+
|
8
|
+
from effectful.ops.syntax import defop
|
9
|
+
from effectful.ops.types import Interpretation, Operation
|
10
|
+
|
11
|
+
P = ParamSpec("P")
|
12
|
+
S = TypeVar("S")
|
13
|
+
T = TypeVar("T")
|
14
|
+
|
15
|
+
|
16
|
+
@dataclasses.dataclass
|
17
|
+
class Runtime(Generic[S, T]):
|
18
|
+
interpretation: "Interpretation[S, T]"
|
19
|
+
|
20
|
+
|
21
|
+
@functools.lru_cache(maxsize=1)
|
22
|
+
def get_runtime() -> Runtime:
|
23
|
+
return Runtime(interpretation={})
|
24
|
+
|
25
|
+
|
26
|
+
def get_interpretation():
|
27
|
+
return get_runtime().interpretation
|
28
|
+
|
29
|
+
|
30
|
+
@contextlib.contextmanager
|
31
|
+
def interpreter(intp: "Interpretation"):
|
32
|
+
|
33
|
+
r = get_runtime()
|
34
|
+
old_intp = r.interpretation
|
35
|
+
try:
|
36
|
+
old_intp, r.interpretation = r.interpretation, dict(intp)
|
37
|
+
yield intp
|
38
|
+
finally:
|
39
|
+
r.interpretation = old_intp
|
40
|
+
|
41
|
+
|
42
|
+
@defop
def _get_args() -> Tuple[Tuple, Mapping]:
    """Return the ``(args, kwargs)`` saved by the innermost ``_save_args`` wrapper.

    This default rule applies when no such wrapper is active and yields no
    positional and no keyword arguments.
    """
    no_args: Tuple = ()
    no_kwargs: Mapping = {}
    return no_args, no_kwargs
|
45
|
+
|
46
|
+
|
47
|
+
def _restore_args(fn: Callable[P, T]) -> Callable[P, T]:
    """Wrap ``fn`` so that an argument-less call reuses the saved arguments.

    Any positional or keyword arguments given to the wrapper are passed to
    ``fn`` unchanged. If none are given, the ``(args, kwargs)`` returned by
    ``_get_args()`` are used instead.
    """

    @functools.wraps(fn)
    def _cont_wrapper(*a: P.args, **k: P.kwargs) -> T:
        if not (a or k):
            a, k = _get_args()  # type: ignore
        return fn(*a, **k)

    return _cont_wrapper
|
54
|
+
|
55
|
+
|
56
|
+
def _save_args(fn: Callable[P, T]) -> Callable[P, T]:
    """Wrap ``fn`` so its call arguments are visible through ``_get_args()``.

    While ``fn`` runs, a handler for ``_get_args`` returns the
    ``(args, kwargs)`` the wrapper was called with.
    """
    from effectful.ops.semantics import handler

    @functools.wraps(fn)
    def _cont_wrapper(*a: P.args, **k: P.kwargs) -> T:
        def _saved():
            return a, k

        with handler({_get_args: _saved}):
            return fn(*a, **k)

    return _cont_wrapper
|
65
|
+
|
66
|
+
|
67
|
+
def _set_prompt(
    prompt: Operation[P, T], cont: Callable[P, T], body: Callable[P, T]
) -> Callable[P, T]:
    """Wrap ``body`` so that calling ``prompt`` inside it invokes ``cont``.

    On each call, the implementation currently installed for ``prompt`` (or
    its default rule) is captured. While ``body`` runs, ``prompt`` is handled
    by ``cont``. While ``cont`` runs, ``prompt`` is handled by the captured
    implementation again. ``coproduct`` uses this with ``prompt=fwd``.
    """
    from effectful.ops.semantics import handler

    @functools.wraps(body)
    def bound_body(*a: P.args, **k: P.kwargs) -> T:
        outer_impl = get_interpretation().get(prompt, prompt.__default_rule__)
        cont_with_outer_prompt = handler({prompt: outer_impl})(cont)
        with handler({prompt: cont_with_outer_prompt}):
            return body(*a, **k)

    return bound_body
|
File without changes
|
@@ -0,0 +1,329 @@
|
|
1
|
+
import contextlib
|
2
|
+
import functools
|
3
|
+
from typing import Any, Callable, Optional, Set, Type, TypeVar
|
4
|
+
|
5
|
+
import tree
|
6
|
+
from typing_extensions import ParamSpec
|
7
|
+
|
8
|
+
from effectful.ops.syntax import NoDefaultRule, deffn, defop, defterm
|
9
|
+
from effectful.ops.types import Expr, Interpretation, Operation, Term
|
10
|
+
|
11
|
+
P = ParamSpec("P")
|
12
|
+
Q = ParamSpec("Q")
|
13
|
+
S = TypeVar("S")
|
14
|
+
T = TypeVar("T")
|
15
|
+
V = TypeVar("V")
|
16
|
+
|
17
|
+
|
18
|
+
@defop  # type: ignore
def apply(
    intp: Interpretation[S, T], op: Operation[P, S], *args: P.args, **kwargs: P.kwargs
) -> T:
    """Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.

    Handling :func:`apply` changes the evaluation strategy of terms.

    Dispatch order:

    1. a handler for ``op`` in ``intp`` is called with the arguments;
    2. otherwise, a handler for :func:`apply` in ``intp`` is called with
       ``intp``, ``op`` and the arguments;
    3. otherwise, ``op``'s default rule runs.

    **Example usage**:

    >>> @defop
    ... def add(x: int, y: int) -> int:
    ...     return x + y
    >>> @defop
    ... def mul(x: int, y: int) -> int:
    ...     return x * y

    ``add`` and ``mul`` have default rules, so this term evaluates:

    >>> mul(add(1, 2), 3)
    9

    By installing an :func:`apply` handler, we capture the term instead:

    >>> with handler({apply: lambda _, op, *args, **kwargs: op.__free_rule__(*args, **kwargs) }):
    ...     term = mul(add(1, 2), 3)
    >>> term
    mul(add(1, 2), 3)

    """
    if op in intp:
        return intp[op](*args, **kwargs)
    if apply in intp:
        return intp[apply](intp, op, *args, **kwargs)
    return op.__default_rule__(*args, **kwargs)  # type: ignore
|
54
|
+
|
55
|
+
|
56
|
+
@defop  # type: ignore
def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """An operation that eliminates a callable term.

    This operation is invoked by the ``__call__`` method of a callable term.

    The default rule first converts ``fn`` with :func:`defterm` if it is not
    already a term. For a ``deffn(body, *argvars, **kwvars)`` term, each
    positional variable is bound to the positional argument at the same
    position, and each keyword variable to the keyword argument of the same
    name. ``body`` is then evaluated under those bindings. Any other callable
    has no default behaviour, so :class:`NoDefaultRule` is raised.
    """
    if not isinstance(fn, Term):
        fn = defterm(fn)

    if not (isinstance(fn, Term) and fn.op is deffn):
        raise NoDefaultRule

    body: Expr[Callable[P, T]] = fn.args[0]
    argvars: tuple[Operation, ...] = fn.args[1:]
    kwvars: dict[str, Operation] = fn.kwargs

    def _const(value):
        # ``partial`` pins ``value`` immediately, avoiding late-binding closures.
        return functools.partial(lambda x: x, value)

    subs = {var: _const(arg) for var, arg in zip(argvars, args)}
    subs.update({kwvars[name]: _const(kwargs[name]) for name in kwargs})
    with handler(subs):
        return evaluate(body)
|
78
|
+
|
79
|
+
|
80
|
+
@defop
def fwd(*args, **kwargs) -> Any:
    """Forward execution to the next most enclosing handler.

    :func:`fwd` should only be called in the context of a handler.

    :param args: Positional arguments.
    :param kwargs: Keyword arguments.

    If no positional or keyword arguments are provided, :func:`fwd` will forward
    the current arguments to the next handler.

    Outside a handler there is nothing to forward to, so this default rule
    raises :class:`RuntimeError`.
    """
    message = "fwd should only be called in the context of a handler"
    raise RuntimeError(message)
|
94
|
+
|
95
|
+
|
96
|
+
def coproduct(
    intp: Interpretation[S, T], intp2: Interpretation[S, T]
) -> Interpretation[S, T]:
    """The coproduct of two interpretations handles any effect that is handled
    by either. If both interpretations handle an effect, ``intp2`` takes
    precedence.

    Handlers in ``intp2`` that override a handler in ``intp`` may call the
    overridden handler using :func:`fwd`. This allows handlers to be written
    that extend or wrap other handlers.

    **Example usage**:

    The ``message`` effect produces a welcome message using two helper effects:
    ``greeting`` and ``name``. By handling these helper effects, we can customize the
    message.

    >>> message, greeting, name = defop(str), defop(str), defop(str)
    >>> i1 = {message: lambda: f"{greeting()} {name()}!", greeting: lambda: "Hi"}
    >>> i2 = {name: lambda: "Jack"}

    The coproduct of ``i1`` and ``i2`` handles all three effects.

    >>> i3 = coproduct(i1, i2)
    >>> with handler(i3):
    ...     print(f'{message()}')
    Hi Jack!

    We can delegate to an enclosing handler by calling :func:`fwd`. Here we
    override the ``name`` handler to format the name differently.

    >>> i4 = coproduct(i3, {name: lambda: f'*{fwd()}*'})
    >>> with handler(i4):
    ...     print(f'{message()}')
    Hi *Jack*!

    .. note::

        :func:`coproduct` allows effects to be overridden in a pervasive way, but
        this is not always desirable. In particular, an interpretation with
        handlers that call "internal" private effects may be broken if coproducted
        with an interpretation that handles those effects. It is dangerous to take
        the coproduct of arbitrary interpretations. For an alternate form of
        interpretation composition, see :func:`product`.

    """
    from effectful.internals.runtime import (
        _get_args,
        _restore_args,
        _save_args,
        _set_prompt,
    )

    combined = dict(intp)
    for op, right in intp2.items():
        if op is fwd or op is _get_args:
            # fast path for special cases, should be equivalent if removed
            combined[op] = right
            continue

        left = intp.get(op, op.__default_rule__)  # type: ignore
        # calling fwd in the right handler should dispatch to the left handler
        combined[op] = _set_prompt(
            fwd, _restore_args(_save_args(left)), _save_args(right)
        )

    return combined
|
160
|
+
|
161
|
+
|
162
|
+
def product(
    intp: Interpretation[S, T], intp2: Interpretation[S, T]
) -> Interpretation[S, T]:
    """The product of two interpretations handles any effect that is handled by
    ``intp2``. Handlers in ``intp2`` may override handlers in ``intp``, but
    those changes are not visible to the handlers in ``intp``. In this way,
    ``intp`` is isolated from ``intp2``.

    When the two interpretations share operations, the shared operations of
    ``intp`` are first renamed to fresh operations (``defop(op)``), and the
    product is taken with the renamed interpretation.

    **Example usage**:

    In this example, ``i1`` has a ``param`` effect that defines some hyperparameter and
    an effect ``f1`` that uses it. ``i2`` redefines ``param`` and uses it in a new effect
    ``f2``, which calls ``f1``.

    >>> param, f1, f2 = defop(int), defop(dict), defop(dict)
    >>> i1 = {param: lambda: 1, f1: lambda: {'inner': param()}}
    >>> i2 = {param: lambda: 2, f2: lambda: f1() | {'outer': param()}}

    Using :func:`product`, ``i2``'s override of ``param`` is not visible to ``i1``.

    >>> with handler(product(i1, i2)):
    ...     print(f2())
    {'inner': 1, 'outer': 2}

    However, if we use :func:`coproduct`, ``i1`` is not isolated from ``i2``.

    >>> with handler(coproduct(i1, i2)):
    ...     print(f2())
    {'inner': 2, 'outer': 2}

    **References**

    [1] Ahman, D., & Bauer, A. (2020, April). Runners in action. In European
    Symposium on Programming (pp. 29-55). Cham: Springer International
    Publishing.

    """
    shared = [op for op in intp2 if op in intp]
    if shared:  # alpha-rename
        renaming = {op: defop(op) for op in shared}
        intp_fresh = {
            renaming.get(op, op): handler(renaming)(intp[op]) for op in intp
        }
        return product(intp_fresh, intp2)

    reflections = {op: op.__default_rule__ for op in intp2}
    intp_ = coproduct({}, {op: runner(reflections)(intp[op]) for op in intp})
    return {op: runner(intp_)(intp2[op]) for op in intp2}
|
207
|
+
|
208
|
+
|
209
|
+
@contextlib.contextmanager
def runner(intp: Interpretation[S, T]):
    """Install an interpretation by taking a product with the current
    interpretation.

    Inside the block, the active interpretation is ``intp`` plus an
    :func:`apply` rule. That rule re-invokes any other operation under the
    interpretation that was active when the block was entered.
    """
    from effectful.internals.runtime import get_interpretation, interpreter

    enclosing = get_interpretation()

    def _reapply(_, op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
        return op(*args, **kwargs)

    # Every call of the rule runs under the interpretation captured above.
    reapply_in_enclosing = interpreter(enclosing)(_reapply)

    with interpreter({apply: reapply_in_enclosing, **intp}):
        yield intp
|
223
|
+
|
224
|
+
|
225
|
+
@contextlib.contextmanager
def handler(intp: Interpretation[S, T]):
    """Install an interpretation by taking a coproduct with the current
    interpretation.

    Handlers in ``intp`` take precedence over the current ones and can reach
    them with :func:`fwd`. The previous interpretation is restored on exit.
    """
    from effectful.internals.runtime import get_interpretation, interpreter

    extended = coproduct(get_interpretation(), intp)
    with interpreter(extended):
        yield intp
|
235
|
+
|
236
|
+
|
237
|
+
def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> Expr[T]:
    """Evaluate expression ``expr`` using interpretation ``intp``. If no
    interpretation is provided, uses the current interpretation.

    Non-term values are first converted with :func:`defterm`. For a term, its
    arguments are evaluated first. Then the term's operation is dispatched
    through the default rule of :func:`apply`. Nested containers are evaluated
    element-wise. Anything else is returned unchanged.

    :param expr: The expression to evaluate.
    :param intp: Optional interpretation for evaluating ``expr``.

    **Example usage**:

    >>> @defop
    ... def add(x: int, y: int) -> int:
    ...     raise NoDefaultRule
    >>> expr = add(1, add(2, 3))
    >>> expr
    add(1, add(2, 3))
    >>> evaluate(expr, intp={add: lambda x, y: x + y})
    6

    """
    if intp is None:
        from effectful.internals.runtime import get_interpretation

        intp = get_interpretation()

    recurse = functools.partial(evaluate, intp=intp)

    if not isinstance(expr, Term):
        expr = defterm(expr)

    if isinstance(expr, Term):
        args, kwargs = tree.map_structure(recurse, (expr.args, expr.kwargs))
        return apply.__default_rule__(intp, expr.op, *args, **kwargs)  # type: ignore
    if tree.is_nested(expr):
        return tree.map_structure(recurse, expr)
    return expr
|
272
|
+
|
273
|
+
|
274
|
+
def typeof(term: Expr[T]) -> Type[T]:
    """Return the type of an expression.

    ``term`` is evaluated under an interpretation whose :func:`apply` rule
    returns each operation's ``__type_rule__`` for its (already typed)
    arguments. The type is therefore computed bottom-up from the signatures.

    NOTE(review): a non-term ``term`` is passed through ``evaluate`` and
    returned as-is rather than its type; confirm callers only pass terms.

    **Example usage**:

    Type signatures are used to infer the types of expressions.

    >>> @defop
    ... def cmp(x: int, y: int) -> bool:
    ...     raise NoDefaultRule
    >>> typeof(cmp(1, 2))
    <class 'bool'>

    Types can be computed in the presence of type variables.

    >>> from typing import TypeVar
    >>> T = TypeVar('T')
    >>> @defop
    ... def if_then_else(x: bool, a: T, b: T) -> T:
    ...     raise NoDefaultRule
    >>> typeof(if_then_else(True, 0, 1))
    <class 'int'>

    """
    from effectful.internals.runtime import interpreter

    def _type_of_application(_, op, *args, **kwargs):
        return op.__type_rule__(*args, **kwargs)

    with interpreter({apply: _type_of_application}):
        return evaluate(term)  # type: ignore
|
302
|
+
|
303
|
+
|
304
|
+
def fvsof(term: Expr[S]) -> Set[Operation]:
    """Return the free variables of an expression.

    Every operation applied in ``term`` counts as a variable. An application
    whose operation binds variables (as reported by ``__fvs_rule__``) removes
    them from the free variables of its own arguments only. A variable that is
    bound in one subterm but also occurs outside that binder is still
    reported as free.

    **Example usage**:

    >>> @defop
    ... def f(x: int, y: int) -> int:
    ...     raise NoDefaultRule
    >>> fvsof(f(1, 2))
    {f}

    """
    from effectful.internals.runtime import interpreter

    class _FreeVars(frozenset):
        """Free variables of an already-evaluated subterm."""

    def _gather(structure) -> Set[Operation]:
        # Union of the _FreeVars sets found among the leaves of ``structure``.
        # Plain (non-term) leaves contribute nothing.
        found: Set[Operation] = set()
        for leaf in tree.flatten(structure):
            if isinstance(leaf, _FreeVars):
                found |= leaf
        return found

    def _update_fvs(_, op, *args, **kwargs):
        # ``evaluate`` has already replaced every subterm in args/kwargs with
        # its _FreeVars, so free variables are computed bottom-up. Previously a
        # single global set was used, and a binder erased its variables even
        # where they occurred free outside of it.
        inner = _gather((args, kwargs))
        bound = set(op.__fvs_rule__(*args, **kwargs))
        return _FreeVars(({op} | inner) - bound)

    with interpreter({apply: _update_fvs}):
        result = evaluate(term)

    return _gather(result)
|