effectful 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +4 -22
- effectful/handlers/numbers.py +10 -6
- effectful/handlers/pyro.py +2 -2
- effectful/handlers/torch.py +33 -10
- effectful/ops/semantics.py +25 -29
- effectful/ops/syntax.py +633 -86
- effectful/ops/types.py +27 -13
- {effectful-0.0.1.dist-info → effectful-0.1.0.dist-info}/METADATA +17 -12
- effectful-0.1.0.dist-info/RECORD +18 -0
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.1.0.dist-info}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.1.0.dist-info}/WHEEL +0 -0
- {effectful-0.0.1.dist-info → effectful-0.1.0.dist-info}/top_level.txt +0 -0
effectful/ops/syntax.py
CHANGED
@@ -1,22 +1,16 @@
|
|
1
|
-
import collections
|
1
|
+
import collections.abc
|
2
2
|
import dataclasses
|
3
3
|
import functools
|
4
|
+
import inspect
|
5
|
+
import random
|
6
|
+
import types
|
4
7
|
import typing
|
5
|
-
from typing import
|
6
|
-
Annotated,
|
7
|
-
Callable,
|
8
|
-
Generic,
|
9
|
-
Mapping,
|
10
|
-
Optional,
|
11
|
-
Sequence,
|
12
|
-
Type,
|
13
|
-
TypeVar,
|
14
|
-
)
|
8
|
+
from typing import Annotated, Callable, Generic, Optional, Type, TypeVar
|
15
9
|
|
16
10
|
import tree
|
17
11
|
from typing_extensions import Concatenate, ParamSpec
|
18
12
|
|
19
|
-
from effectful.ops.types import
|
13
|
+
from effectful.ops.types import Annotation, Expr, Interpretation, Operation, Term
|
20
14
|
|
21
15
|
P = ParamSpec("P")
|
22
16
|
Q = ParamSpec("Q")
|
@@ -26,34 +20,357 @@ V = TypeVar("V")
|
|
26
20
|
|
27
21
|
|
28
22
|
@dataclasses.dataclass
|
29
|
-
class
|
30
|
-
|
23
|
+
class Scoped(Annotation):
|
24
|
+
"""
|
25
|
+
A special type annotation that indicates the relative scope of a parameter
|
26
|
+
in the signature of an :class:`Operation` created with :func:`defop` .
|
31
27
|
|
28
|
+
:class:`Scoped` makes it easy to describe higher-order :class:`Operation` s
|
29
|
+
that take other :class:`Term` s and :class:`Operation` s as arguments,
|
30
|
+
inspired by a number of recent proposals to view syntactic variables
|
31
|
+
as algebraic effects and environments as effect handlers.
|
32
32
|
|
33
|
-
|
34
|
-
|
35
|
-
|
33
|
+
As a result, in ``effectful`` many complex higher-order programming constructs,
|
34
|
+
such as lambda-abstraction, let-binding, loops, try-catch exception handling,
|
35
|
+
nondeterminism, capture-avoiding substitution and algebraic effect handling,
|
36
|
+
can be expressed uniformly using :func:`defop` as ordinary :class:`Operation` s
|
37
|
+
and evaluated or transformed using generalized effect handlers that respect
|
38
|
+
the scoping semantics of the operations.
|
39
|
+
|
40
|
+
.. warning::
|
41
|
+
|
42
|
+
:class:`Scoped` instances are typically constructed using indexing
|
43
|
+
syntactic sugar borrowed from generic types like :class:`typing.Generic` .
|
44
|
+
For example, ``Scoped[A]`` desugars to a :class:`Scoped` instances
|
45
|
+
with ``ordinal={A}``, and ``Scoped[A | B]`` desugars to a :class:`Scoped`
|
46
|
+
instance with ``ordinal={A, B}`` .
|
47
|
+
|
48
|
+
However, :class:`Scoped` is not a generic type, and the set of :class:`typing.TypeVar` s
|
49
|
+
used for the :class:`Scoped` annotations in a given operation must be disjoint
|
50
|
+
from the set of :class:`typing.TypeVar` s used for generic types of the parameters.
|
51
|
+
|
52
|
+
**Example usage**:
|
53
|
+
|
54
|
+
We illustrate the use of :class:`Scoped` with a few case studies of classical
|
55
|
+
syntactic variable binding constructs expressed as :class:`Operation` s.
|
56
|
+
|
57
|
+
>>> from typing import Annotated, TypeVar
|
58
|
+
>>> from effectful.ops.syntax import Scoped, defop
|
59
|
+
>>> from effectful.ops.semantics import fvsof
|
60
|
+
>>> from effectful.handlers.numbers import add
|
61
|
+
>>> A, B, S, T = TypeVar('A'), TypeVar('B'), TypeVar('S'), TypeVar('T')
|
62
|
+
>>> x, y = defop(int, name='x'), defop(int, name='y')
|
63
|
+
|
64
|
+
* For example, we can define a higher-order operation :func:`Lambda`
|
65
|
+
that takes an :class:`Operation` representing a bound syntactic variable
|
66
|
+
and a :class:`Term` representing the body of an anonymous function,
|
67
|
+
and returns a :class:`Term` representing a lambda function:
|
36
68
|
|
69
|
+
>>> @defop
|
70
|
+
... def Lambda(
|
71
|
+
... var: Annotated[Operation[[], S], Scoped[A]],
|
72
|
+
... body: Annotated[T, Scoped[A | B]]
|
73
|
+
... ) -> Annotated[Callable[[S], T], Scoped[B]]:
|
74
|
+
... raise NotImplementedError
|
75
|
+
|
76
|
+
* The :class:`Scoped` annotation is used here to indicate that the argument ``var``
|
77
|
+
passed to :func:`Lambda` may appear free in ``body``, but not in the resulting function.
|
78
|
+
In other words, it is bound by :func:`Lambda`:
|
79
|
+
|
80
|
+
>>> assert x not in fvsof(Lambda(x, add(x(), 1)))
|
37
81
|
|
38
|
-
|
39
|
-
"""Raised in an operation's signature to indicate that the operation has no default rule."""
|
82
|
+
However, variables in ``body`` other than ``var`` still appear free in the result:
|
40
83
|
|
41
|
-
|
84
|
+
>>> assert y in fvsof(Lambda(x, add(x(), y())))
|
42
85
|
|
86
|
+
* :class:`Scoped` can also be used with variadic arguments and keyword arguments.
|
87
|
+
For example, we can define a generalized :func:`LambdaN` that takes a variable
|
88
|
+
number of arguments and keyword arguments:
|
43
89
|
|
44
|
-
@
|
45
|
-
def
|
90
|
+
>>> @defop
|
91
|
+
... def LambdaN(
|
92
|
+
... body: Annotated[T, Scoped[A | B]]
|
93
|
+
... *args: Annotated[Operation[[], S], Scoped[A]],
|
94
|
+
... **kwargs: Annotated[Operation[[], S], Scoped[A]]
|
95
|
+
... ) -> Annotated[Callable[..., T], Scoped[B]]:
|
96
|
+
... raise NotImplementedError
|
97
|
+
|
98
|
+
This is equivalent to the built-in :class:`Operation` :func:`deffn`:
|
46
99
|
|
100
|
+
>>> assert not {x, y} & fvsof(LambdaN(add(x(), y()), x, y))
|
47
101
|
|
48
|
-
|
49
|
-
|
102
|
+
* :class:`Scoped` and :func:`defop` can also express more complex scoping semantics.
|
103
|
+
For example, we can define a :func:`Let` operation that binds a variable in
|
104
|
+
a :class:`Term` ``body`` to a ``value`` that may be another possibly open :class:`Term` :
|
50
105
|
|
106
|
+
>>> @defop
|
107
|
+
... def Let(
|
108
|
+
... var: Annotated[Operation[[], S], Scoped[A]],
|
109
|
+
... val: Annotated[S, Scoped[B]],
|
110
|
+
... body: Annotated[T, Scoped[A | B]]
|
111
|
+
... ) -> Annotated[T, Scoped[B]]:
|
112
|
+
... raise NotImplementedError
|
113
|
+
|
114
|
+
Here the variable ``var`` is bound by :func:`Let` in `body` but not in ``val`` :
|
115
|
+
|
116
|
+
>>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())))
|
117
|
+
>>> assert {x, y} in fvsof(Let(x, add(y(), x()), add(x(), y())))
|
118
|
+
|
119
|
+
This is reflected in the free variables of subterms of the result:
|
120
|
+
|
121
|
+
>>> assert x in fvsof(Let(x, add(x(), y()), add(x(), y())).args[1])
|
122
|
+
>>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())).args[2])
|
123
|
+
"""
|
51
124
|
|
52
|
-
|
53
|
-
|
125
|
+
ordinal: collections.abc.Set
|
126
|
+
|
127
|
+
def __class_getitem__(cls, item: TypeVar | typing._SpecialForm):
|
128
|
+
assert not isinstance(item, tuple), "can only be in one scope"
|
129
|
+
if isinstance(item, typing.TypeVar):
|
130
|
+
return cls(ordinal=frozenset({item}))
|
131
|
+
elif typing.get_origin(item) is typing.Union and typing.get_args(item):
|
132
|
+
return cls(ordinal=frozenset(typing.get_args(item)))
|
133
|
+
else:
|
134
|
+
raise TypeError(
|
135
|
+
f"expected TypeVar or non-empty Union of TypeVars, but got {item}"
|
136
|
+
)
|
137
|
+
|
138
|
+
@staticmethod
|
139
|
+
def _param_is_var(param: type | inspect.Parameter) -> bool:
|
140
|
+
"""
|
141
|
+
Helper function that checks if a parameter is annotated as an :class:`Operation` .
|
142
|
+
|
143
|
+
:param param: The parameter to check.
|
144
|
+
:returns: ``True`` if the parameter is an :class:`Operation` , ``False`` otherwise.
|
145
|
+
"""
|
146
|
+
if isinstance(param, inspect.Parameter):
|
147
|
+
param = param.annotation
|
148
|
+
if typing.get_origin(param) is Annotated:
|
149
|
+
param = typing.get_args(param)[0]
|
150
|
+
if typing.get_origin(param) is not None:
|
151
|
+
param = typing.cast(type, typing.get_origin(param))
|
152
|
+
return isinstance(param, type) and issubclass(param, Operation)
|
153
|
+
|
154
|
+
@classmethod
|
155
|
+
def _get_param_ordinal(cls, param: type | inspect.Parameter) -> collections.abc.Set:
|
156
|
+
"""
|
157
|
+
Given a type or parameter, extracts the ordinal from its :class:`Scoped` annotation.
|
158
|
+
|
159
|
+
:param param: The type or signature parameter to extract the ordinal from.
|
160
|
+
:returns: The ordinal typevars.
|
161
|
+
"""
|
162
|
+
if isinstance(param, inspect.Parameter):
|
163
|
+
return cls._get_param_ordinal(param.annotation)
|
164
|
+
elif typing.get_origin(param) is Annotated:
|
165
|
+
for a in typing.get_args(param)[1:]:
|
166
|
+
if isinstance(a, cls):
|
167
|
+
return a.ordinal
|
168
|
+
return set()
|
169
|
+
else:
|
170
|
+
return set()
|
171
|
+
|
172
|
+
@classmethod
|
173
|
+
def _get_root_ordinal(cls, sig: inspect.Signature) -> collections.abc.Set:
|
174
|
+
"""
|
175
|
+
Given a signature, computes the intersection of all :class:`Scoped` annotations.
|
176
|
+
|
177
|
+
:param sig: The signature to check.
|
178
|
+
:returns: The intersection of the `ordinal`s of all :class:`Scoped` annotations.
|
179
|
+
"""
|
180
|
+
return set(cls._get_param_ordinal(sig.return_annotation)).intersection(
|
181
|
+
*(cls._get_param_ordinal(p) for p in sig.parameters.values())
|
182
|
+
)
|
183
|
+
|
184
|
+
@classmethod
|
185
|
+
def _get_fresh_ordinal(cls, *, name: str = "RootScope") -> collections.abc.Set:
|
186
|
+
return {TypeVar(name)}
|
187
|
+
|
188
|
+
@classmethod
|
189
|
+
def _check_has_single_scope(cls, sig: inspect.Signature) -> bool:
|
190
|
+
"""
|
191
|
+
Checks if each parameter has at most one :class:`Scoped` annotation.
|
192
|
+
|
193
|
+
:param sig: The signature to check.
|
194
|
+
:returns: True if each parameter has at most one :class:`Scoped` annotation, False otherwise.
|
195
|
+
"""
|
196
|
+
# invariant: at most one Scope annotation per parameter
|
197
|
+
return not any(
|
198
|
+
len([a for a in p.annotation.__metadata__ if isinstance(a, cls)]) > 1
|
199
|
+
for p in sig.parameters.values()
|
200
|
+
if typing.get_origin(p.annotation) is Annotated
|
201
|
+
)
|
202
|
+
|
203
|
+
@classmethod
|
204
|
+
def _check_no_typevar_overlap(cls, sig: inspect.Signature) -> bool:
|
205
|
+
"""
|
206
|
+
Checks if there is no overlap between ordinal typevars and generic ones.
|
207
|
+
|
208
|
+
:param sig: The signature to check.
|
209
|
+
:returns: True if there is no overlap between ordinal typevars and generic ones, False otherwise.
|
210
|
+
"""
|
211
|
+
|
212
|
+
def _get_free_type_vars(
|
213
|
+
tp: type | typing._SpecialForm | inspect.Parameter | tuple | list,
|
214
|
+
) -> collections.abc.Set[TypeVar]:
|
215
|
+
if isinstance(tp, TypeVar):
|
216
|
+
return {tp}
|
217
|
+
elif isinstance(tp, (tuple, list)):
|
218
|
+
return set().union(*map(_get_free_type_vars, tp))
|
219
|
+
elif isinstance(tp, inspect.Parameter):
|
220
|
+
return _get_free_type_vars(tp.annotation)
|
221
|
+
elif typing.get_origin(tp) is Annotated:
|
222
|
+
return _get_free_type_vars(typing.get_args(tp)[0])
|
223
|
+
elif typing.get_origin(tp) is not None:
|
224
|
+
return _get_free_type_vars(typing.get_args(tp))
|
225
|
+
else:
|
226
|
+
return set()
|
227
|
+
|
228
|
+
# invariant: no overlap between ordinal typevars and generic ones
|
229
|
+
free_type_vars = _get_free_type_vars(
|
230
|
+
(sig.return_annotation, *sig.parameters.values())
|
231
|
+
)
|
232
|
+
return all(
|
233
|
+
free_type_vars.isdisjoint(cls._get_param_ordinal(p))
|
234
|
+
for p in (
|
235
|
+
sig.return_annotation,
|
236
|
+
*sig.parameters.values(),
|
237
|
+
)
|
238
|
+
)
|
239
|
+
|
240
|
+
@classmethod
|
241
|
+
def _check_no_boundvars_in_result(cls, sig: inspect.Signature) -> bool:
|
242
|
+
"""
|
243
|
+
Checks that no bound variables would appear free in the return value.
|
244
|
+
|
245
|
+
:param sig: The signature to check.
|
246
|
+
:returns: True if no bound variables would appear free in the return value, False otherwise.
|
247
|
+
|
248
|
+
.. note::
|
249
|
+
|
250
|
+
This is used as a post-condition for :func:`infer_annotations`.
|
251
|
+
However, it is not a necessary condition for the correctness of the
|
252
|
+
`Scope` annotations of an operation - our current implementation
|
253
|
+
merely does not extend to cases where this condition is true.
|
254
|
+
"""
|
255
|
+
root_ordinal = cls._get_root_ordinal(sig)
|
256
|
+
return_ordinal = cls._get_param_ordinal(sig.return_annotation)
|
257
|
+
return not any(
|
258
|
+
root_ordinal < cls._get_param_ordinal(p) <= return_ordinal
|
259
|
+
for p in sig.parameters.values()
|
260
|
+
if cls._param_is_var(p)
|
261
|
+
)
|
54
262
|
|
263
|
+
@classmethod
|
264
|
+
def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature:
|
265
|
+
"""
|
266
|
+
Given a :class:`inspect.Signature` for an :class:`Operation` for which
|
267
|
+
only some :class:`inspect.Parameter` s have manual :class:`Scoped` annotations,
|
268
|
+
computes a new signature with :class:`Scoped` annotations attached to each parameter,
|
269
|
+
including the return type annotation.
|
270
|
+
|
271
|
+
The new annotations are inferred by joining the manual annotations with a
|
272
|
+
fresh root scope. The root scope is the intersection of all :class:`Scoped`
|
273
|
+
annotations in the resulting :class:`inspect.Signature` object.
|
274
|
+
|
275
|
+
:class`Operation` s in this root scope are free in the result and in all arguments.
|
276
|
+
|
277
|
+
:param sig: The signature of the operation.
|
278
|
+
:returns: A new signature with inferred :class:`Scoped` annotations.
|
279
|
+
"""
|
280
|
+
# pre-conditions
|
281
|
+
assert cls._check_has_single_scope(sig)
|
282
|
+
assert cls._check_no_typevar_overlap(sig)
|
283
|
+
assert cls._check_no_boundvars_in_result(sig)
|
284
|
+
|
285
|
+
root_ordinal = cls._get_root_ordinal(sig)
|
286
|
+
if not root_ordinal:
|
287
|
+
root_ordinal = cls._get_fresh_ordinal()
|
288
|
+
|
289
|
+
# add missing Scoped annotations and join everything with the root scope
|
290
|
+
new_annos: list[type | typing._SpecialForm] = []
|
291
|
+
for anno in (
|
292
|
+
sig.return_annotation,
|
293
|
+
*(p.annotation for p in sig.parameters.values()),
|
294
|
+
):
|
295
|
+
new_scope = cls(ordinal=cls._get_param_ordinal(anno) | root_ordinal)
|
296
|
+
if typing.get_origin(anno) is Annotated:
|
297
|
+
new_anno = typing.get_args(anno)[0]
|
298
|
+
new_anno = Annotated[new_anno, new_scope]
|
299
|
+
for other in typing.get_args(anno)[1:]:
|
300
|
+
if not isinstance(other, cls):
|
301
|
+
new_anno = Annotated[new_anno, other]
|
302
|
+
else:
|
303
|
+
new_anno = Annotated[anno, new_scope]
|
304
|
+
|
305
|
+
new_annos.append(new_anno)
|
306
|
+
|
307
|
+
# construct a new Signature structure with the inferred annotations
|
308
|
+
new_return_anno, new_annos = new_annos[0], new_annos[1:]
|
309
|
+
inferred_sig = sig.replace(
|
310
|
+
parameters=[
|
311
|
+
p.replace(annotation=a)
|
312
|
+
for p, a in zip(sig.parameters.values(), new_annos)
|
313
|
+
],
|
314
|
+
return_annotation=new_return_anno,
|
315
|
+
)
|
55
316
|
|
56
|
-
|
317
|
+
# post-conditions
|
318
|
+
assert cls._get_root_ordinal(inferred_sig) == root_ordinal != set()
|
319
|
+
return inferred_sig
|
320
|
+
|
321
|
+
def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]:
|
322
|
+
"""
|
323
|
+
Computes a set of bound variables given a signature with bound arguments.
|
324
|
+
|
325
|
+
The :func:`analyze` methods of :class:`Scoped` annotations that appear on
|
326
|
+
the signature of an :class:`Operation` are used by :func:`defop` to generate
|
327
|
+
implementations of :func:`Operation.__fvs_rule__` underlying alpha-renaming
|
328
|
+
in :func:`defterm` and :func:`defdata` and free variable sets in :func:`fvsof` .
|
329
|
+
|
330
|
+
Specifically, the :func:`analyze` method of the :class:`Scoped` annotation
|
331
|
+
of a parameter computes the set of bound variables in that parameter's value.
|
332
|
+
The :func:`Operation.__fvs_rule__` method generated by :func:`defop` simply
|
333
|
+
extracts the annotation of each parameter, calls :func:`analyze` on the value
|
334
|
+
given for the corresponding parameter in ``bound_sig`` , and returns the results.
|
335
|
+
|
336
|
+
:param bound_sig: The :class:`inspect.Signature` of an :class:`Operation`
|
337
|
+
together with values for all of its arguments.
|
338
|
+
:returns: A set of bound variables.
|
339
|
+
"""
|
340
|
+
bound_vars: frozenset[Operation] = frozenset()
|
341
|
+
return_ordinal = self._get_param_ordinal(bound_sig.signature.return_annotation)
|
342
|
+
for name, param in bound_sig.signature.parameters.items():
|
343
|
+
param_ordinal = self._get_param_ordinal(param)
|
344
|
+
if (
|
345
|
+
self._param_is_var(param)
|
346
|
+
and param_ordinal <= self.ordinal
|
347
|
+
and not param_ordinal <= return_ordinal
|
348
|
+
):
|
349
|
+
if param.kind is inspect.Parameter.VAR_POSITIONAL:
|
350
|
+
# pre-condition: all bound variables should be distinct
|
351
|
+
assert len(bound_sig.arguments[name]) == len(
|
352
|
+
set(bound_sig.arguments[name])
|
353
|
+
)
|
354
|
+
param_bound_vars = {*bound_sig.arguments[name]}
|
355
|
+
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
356
|
+
# pre-condition: all bound variables should be distinct
|
357
|
+
assert len(bound_sig.arguments[name].values()) == len(
|
358
|
+
set(bound_sig.arguments[name].values())
|
359
|
+
)
|
360
|
+
param_bound_vars = {*bound_sig.arguments[name].values()}
|
361
|
+
else:
|
362
|
+
param_bound_vars = {bound_sig.arguments[name]}
|
363
|
+
|
364
|
+
# pre-condition: all bound variables should be distinct
|
365
|
+
assert not bound_vars & param_bound_vars
|
366
|
+
|
367
|
+
bound_vars |= param_bound_vars
|
368
|
+
|
369
|
+
return bound_vars
|
370
|
+
|
371
|
+
|
372
|
+
@functools.singledispatch
|
373
|
+
def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
57
374
|
"""Creates a fresh :class:`Operation`.
|
58
375
|
|
59
376
|
:param t: May be a type, callable, or :class:`Operation`. If a type, the
|
@@ -94,12 +411,12 @@ def defop(t, *, name=None):
|
|
94
411
|
* Defining an operation with no default rule:
|
95
412
|
|
96
413
|
We can use :func:`defop` and the
|
97
|
-
:exc:`
|
414
|
+
:exc:`NotImplementedError` exception to define an
|
98
415
|
operation with no default rule:
|
99
416
|
|
100
417
|
>>> @defop
|
101
418
|
... def add(x: int, y: int) -> int:
|
102
|
-
... raise
|
419
|
+
... raise NotImplementedError
|
103
420
|
>>> add(1, 2)
|
104
421
|
add(1, 2)
|
105
422
|
|
@@ -184,45 +501,167 @@ def defop(t, *, name=None):
|
|
184
501
|
1 2
|
185
502
|
|
186
503
|
"""
|
504
|
+
raise NotImplementedError(f"expected type or callable, got {t}")
|
187
505
|
|
188
|
-
if isinstance(t, Operation):
|
189
506
|
|
190
|
-
|
191
|
-
|
507
|
+
@defop.register(typing.cast(Type[collections.abc.Callable], collections.abc.Callable))
|
508
|
+
class _BaseOperation(Generic[Q, V], Operation[Q, V]):
|
509
|
+
__signature__: inspect.Signature
|
510
|
+
__name__: str
|
192
511
|
|
193
|
-
|
194
|
-
return defop(func, name=name)
|
195
|
-
elif isinstance(t, type):
|
512
|
+
_default: Callable[Q, V]
|
196
513
|
|
197
|
-
|
198
|
-
|
514
|
+
def __init__(self, default: Callable[Q, V], *, name: Optional[str] = None):
|
515
|
+
functools.update_wrapper(self, default)
|
516
|
+
self._default = default
|
517
|
+
self.__name__ = name or default.__name__
|
518
|
+
self.__signature__ = inspect.signature(default)
|
199
519
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
520
|
+
def __eq__(self, other):
|
521
|
+
if not isinstance(other, Operation):
|
522
|
+
return NotImplemented
|
523
|
+
return self is other
|
204
524
|
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
525
|
+
def __hash__(self):
|
526
|
+
return hash(self._default)
|
527
|
+
|
528
|
+
def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
|
529
|
+
try:
|
530
|
+
return self._default(*args, **kwargs)
|
531
|
+
except NotImplementedError:
|
532
|
+
return typing.cast(
|
533
|
+
Callable[Concatenate[Operation[Q, V], Q], Expr[V]], defdata
|
534
|
+
)(self, *args, **kwargs)
|
535
|
+
|
536
|
+
def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> tuple[
|
537
|
+
tuple[collections.abc.Set[Operation], ...],
|
538
|
+
dict[str, collections.abc.Set[Operation]],
|
539
|
+
]:
|
540
|
+
sig = Scoped.infer_annotations(self.__signature__)
|
541
|
+
bound_sig = sig.bind(*args, **kwargs)
|
542
|
+
bound_sig.apply_defaults()
|
543
|
+
|
544
|
+
result_sig = sig.bind(
|
545
|
+
*(frozenset() for _ in bound_sig.args),
|
546
|
+
**{k: frozenset() for k in bound_sig.kwargs},
|
547
|
+
)
|
548
|
+
for name, param in sig.parameters.items():
|
549
|
+
if typing.get_origin(param.annotation) is typing.Annotated:
|
550
|
+
for anno in typing.get_args(param.annotation)[1:]:
|
551
|
+
if isinstance(anno, Scoped):
|
552
|
+
param_bound_vars = anno.analyze(bound_sig)
|
553
|
+
if param.kind is inspect.Parameter.VAR_POSITIONAL:
|
554
|
+
result_sig.arguments[name] = tuple(
|
555
|
+
param_bound_vars for _ in bound_sig.arguments[name]
|
556
|
+
)
|
557
|
+
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
558
|
+
result_sig.kwargs[name] = {
|
559
|
+
k: param_bound_vars for k in bound_sig.arguments[name]
|
560
|
+
}
|
561
|
+
else:
|
562
|
+
result_sig.arguments[name] = param_bound_vars
|
563
|
+
|
564
|
+
return tuple(result_sig.args), dict(result_sig.kwargs)
|
565
|
+
|
566
|
+
def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Type[V]:
|
567
|
+
sig = inspect.signature(self._default)
|
568
|
+
bound_sig = sig.bind(*args, **kwargs)
|
569
|
+
bound_sig.apply_defaults()
|
570
|
+
|
571
|
+
anno = sig.return_annotation
|
572
|
+
if anno is inspect.Signature.empty:
|
573
|
+
return typing.cast(Type[V], object)
|
574
|
+
elif isinstance(anno, typing.TypeVar):
|
575
|
+
# rudimentary but sound special-case type inference sufficient for syntax ops:
|
576
|
+
# if the return type annotation is a TypeVar,
|
577
|
+
# look for a parameter with the same annotation and return its type,
|
578
|
+
# otherwise give up and return Any/object
|
579
|
+
for name, param in bound_sig.signature.parameters.items():
|
580
|
+
if param.annotation is anno and param.kind not in (
|
581
|
+
inspect.Parameter.VAR_POSITIONAL,
|
582
|
+
inspect.Parameter.VAR_KEYWORD,
|
583
|
+
):
|
584
|
+
arg = bound_sig.arguments[name]
|
585
|
+
tp: Type[V] = type(arg) if not isinstance(arg, type) else arg
|
586
|
+
return tp
|
587
|
+
return typing.cast(Type[V], object)
|
588
|
+
elif typing.get_origin(anno) is typing.Annotated:
|
589
|
+
tp = typing.get_args(anno)[0]
|
590
|
+
if not typing.TYPE_CHECKING:
|
591
|
+
tp = tp if typing.get_origin(tp) is None else typing.get_origin(tp)
|
592
|
+
return tp
|
593
|
+
elif typing.get_origin(anno) is not None:
|
594
|
+
return typing.get_origin(anno)
|
595
|
+
else:
|
596
|
+
return anno
|
597
|
+
|
598
|
+
def __repr_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> str:
|
599
|
+
args_str = ", ".join(map(str, args)) if args else ""
|
600
|
+
kwargs_str = (
|
601
|
+
", ".join(f"{k}={str(v)}" for k, v in kwargs.items()) if kwargs else ""
|
602
|
+
)
|
603
|
+
|
604
|
+
ret = f"{self.__name__}({args_str}"
|
605
|
+
if kwargs:
|
606
|
+
ret += f"{', ' if args else ''}"
|
607
|
+
ret += f"{kwargs_str})"
|
608
|
+
return ret
|
609
|
+
|
610
|
+
def __repr__(self):
|
611
|
+
return self.__name__
|
612
|
+
|
613
|
+
|
614
|
+
@defop.register(Operation)
|
615
|
+
def _(t: Operation[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
616
|
+
|
617
|
+
@functools.wraps(t)
|
618
|
+
def func(*args, **kwargs):
|
619
|
+
raise NotImplementedError
|
620
|
+
|
621
|
+
if name is None:
|
622
|
+
name = (
|
623
|
+
getattr(t, "__name__", str(t))[:10000] + f"__{random.randint(0, 1 << 32)}"
|
624
|
+
)
|
625
|
+
return defop(func, name=name)
|
626
|
+
|
627
|
+
|
628
|
+
@defop.register(type)
|
629
|
+
def _(t: Type[T], *, name: Optional[str] = None) -> Operation[[], T]:
|
630
|
+
def func() -> t: # type: ignore
|
631
|
+
raise NotImplementedError
|
632
|
+
|
633
|
+
if name is None:
|
634
|
+
name = t.__name__ + f"__{random.randint(0, 1 << 32)}"
|
635
|
+
return typing.cast(Operation[[], T], defop(func, name=name))
|
636
|
+
|
637
|
+
|
638
|
+
@defop.register(types.BuiltinFunctionType)
|
639
|
+
def _(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
640
|
+
|
641
|
+
@functools.wraps(t)
|
642
|
+
def func(*args, **kwargs):
|
643
|
+
if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))):
|
644
|
+
return t(*args, **kwargs)
|
645
|
+
else:
|
646
|
+
raise NotImplementedError
|
647
|
+
|
648
|
+
return defop(func, name=name)
|
210
649
|
|
211
650
|
|
212
651
|
@defop
|
213
652
|
def deffn(
|
214
|
-
body: T,
|
215
|
-
*args: Annotated[Operation,
|
216
|
-
**kwargs: Annotated[Operation,
|
653
|
+
body: Annotated[T, Scoped[S]],
|
654
|
+
*args: Annotated[Operation, Scoped[S]],
|
655
|
+
**kwargs: Annotated[Operation, Scoped[S]],
|
217
656
|
) -> Callable[..., T]:
|
218
657
|
"""An operation that represents a lambda function.
|
219
658
|
|
220
659
|
:param body: The body of the function.
|
221
660
|
:type body: T
|
222
661
|
:param args: Operations representing the positional arguments of the function.
|
223
|
-
:type args:
|
662
|
+
:type args: Operation
|
224
663
|
:param kwargs: Operations representing the keyword arguments of the function.
|
225
|
-
:type kwargs:
|
664
|
+
:type kwargs: Operation
|
226
665
|
:returns: A callable term.
|
227
666
|
:rtype: Callable[..., T]
|
228
667
|
|
@@ -249,12 +688,12 @@ def deffn(
|
|
249
688
|
automatically create the right free variables.
|
250
689
|
|
251
690
|
"""
|
252
|
-
raise
|
691
|
+
raise NotImplementedError
|
253
692
|
|
254
693
|
|
255
|
-
class _CustomSingleDispatchCallable(Generic[P, T]):
|
694
|
+
class _CustomSingleDispatchCallable(Generic[P, Q, S, T]):
|
256
695
|
def __init__(
|
257
|
-
self, func: Callable[Concatenate[Callable[[type], Callable[
|
696
|
+
self, func: Callable[Concatenate[Callable[[type], Callable[Q, S]], P], T]
|
258
697
|
):
|
259
698
|
self._func = func
|
260
699
|
self._registry = functools.singledispatch(func)
|
@@ -273,7 +712,7 @@ class _CustomSingleDispatchCallable(Generic[P, T]):
|
|
273
712
|
|
274
713
|
|
275
714
|
@_CustomSingleDispatchCallable
|
276
|
-
def defterm(
|
715
|
+
def defterm(__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
|
277
716
|
"""Convert a value to a term, using the type of the value to dispatch.
|
278
717
|
|
279
718
|
:param value: The value to convert.
|
@@ -298,21 +737,22 @@ def defterm(dispatch, value: T) -> Expr[T]:
|
|
298
737
|
if isinstance(value, Term):
|
299
738
|
return value
|
300
739
|
else:
|
301
|
-
return
|
740
|
+
return __dispatch(type(value))(value)
|
302
741
|
|
303
742
|
|
304
743
|
@_CustomSingleDispatchCallable
|
305
|
-
def defdata(
|
306
|
-
|
744
|
+
def defdata(
|
745
|
+
__dispatch: Callable[[type], Callable[..., Expr[T]]],
|
746
|
+
op: Operation[..., T],
|
747
|
+
*args,
|
748
|
+
**kwargs,
|
749
|
+
) -> Expr[T]:
|
750
|
+
"""Constructs a Term that is an instance of its semantic type.
|
307
751
|
|
308
|
-
:param expr: The term to convert.
|
309
|
-
:type expr: Term[T]
|
310
752
|
:returns: An instance of ``T``.
|
311
753
|
:rtype: Expr[T]
|
312
754
|
|
313
|
-
This function is
|
314
|
-
resgistered with :func:`defdata` are automatically applied when terms are
|
315
|
-
constructed.
|
755
|
+
This function is the only way to construct a :class:`Term` from an :class:`Operation`.
|
316
756
|
|
317
757
|
.. note::
|
318
758
|
|
@@ -326,57 +766,164 @@ def defdata(dispatch, expr: Term[T]) -> Expr[T]:
|
|
326
766
|
|
327
767
|
.. code-block:: python
|
328
768
|
|
329
|
-
class _CallableTerm(Generic[P, T],
|
769
|
+
class _CallableTerm(Generic[P, T], Term[collections.abc.Callable[P, T]]):
|
770
|
+
def __init__(
|
771
|
+
self,
|
772
|
+
op: Operation[..., T],
|
773
|
+
*args: Expr,
|
774
|
+
**kwargs: Expr,
|
775
|
+
):
|
776
|
+
self._op = op
|
777
|
+
self._args = args
|
778
|
+
self._kwargs = kwargs
|
779
|
+
|
780
|
+
@property
|
781
|
+
def op(self):
|
782
|
+
return self._op
|
783
|
+
|
784
|
+
@property
|
785
|
+
def args(self):
|
786
|
+
return self._args
|
787
|
+
|
788
|
+
@property
|
789
|
+
def kwargs(self):
|
790
|
+
return self._kwargs
|
791
|
+
|
330
792
|
def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
|
331
793
|
from effectful.ops.semantics import call
|
332
794
|
|
333
795
|
return call(self, *args, **kwargs)
|
334
796
|
|
335
797
|
@defdata.register(collections.abc.Callable)
|
336
|
-
def _(op, args, kwargs):
|
337
|
-
return _CallableTerm(op, args, kwargs)
|
338
|
-
|
339
|
-
When a :class:`Callable` term is passed to :func:`defdata`, it is
|
340
|
-
reconstructed as a :class:`_CallableTerm`, which implements the
|
341
|
-
:func:`__call__` method.
|
798
|
+
def _(op, *args, **kwargs):
|
799
|
+
return _CallableTerm(op, *args, **kwargs)
|
342
800
|
|
801
|
+
When an Operation whose return type is `Callable` is passed to :func:`defdata`,
|
802
|
+
it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
|
343
803
|
"""
|
344
|
-
from effectful.ops.semantics import typeof
|
804
|
+
from effectful.ops.semantics import apply, evaluate, typeof
|
805
|
+
|
806
|
+
arg_ctxs, kwarg_ctxs = op.__fvs_rule__(*args, **kwargs)
|
807
|
+
renaming = {
|
808
|
+
var: defop(var)
|
809
|
+
for bound_vars in (*arg_ctxs, *kwarg_ctxs.values())
|
810
|
+
for var in bound_vars
|
811
|
+
}
|
812
|
+
|
813
|
+
args_, kwargs_ = list(args), dict(kwargs)
|
814
|
+
for i, (v, c) in (
|
815
|
+
*enumerate(zip(args, arg_ctxs)),
|
816
|
+
*{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(),
|
817
|
+
):
|
818
|
+
if c:
|
819
|
+
v = tree.map_structure(
|
820
|
+
lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v
|
821
|
+
)
|
822
|
+
res = evaluate(
|
823
|
+
v,
|
824
|
+
intp={
|
825
|
+
apply: lambda _, op, *a, **k: defdata(op, *a, **k),
|
826
|
+
**{op: renaming[op] for op in c},
|
827
|
+
},
|
828
|
+
)
|
829
|
+
if isinstance(i, int):
|
830
|
+
args_[i] = res
|
831
|
+
elif isinstance(i, str):
|
832
|
+
kwargs_[i] = res
|
345
833
|
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
else:
|
351
|
-
return expr
|
834
|
+
tp: Type[T] = typeof(
|
835
|
+
__dispatch(typing.cast(Type[T], object))(op, *args_, **kwargs_)
|
836
|
+
)
|
837
|
+
return __dispatch(tp)(op, *args_, **kwargs_)
|
352
838
|
|
353
839
|
|
354
840
|
@defterm.register(object)
|
355
841
|
@defterm.register(Operation)
|
356
842
|
@defterm.register(Term)
|
843
|
+
@defterm.register(type)
|
844
|
+
@defterm.register(types.BuiltinFunctionType)
|
357
845
|
def _(value: T) -> T:
|
358
846
|
return value
|
359
847
|
|
360
848
|
|
361
849
|
@defdata.register(object)
|
362
|
-
|
363
|
-
|
850
|
+
class _BaseTerm(Generic[T], Term[T]):
|
851
|
+
_op: Operation[..., T]
|
852
|
+
_args: collections.abc.Sequence[Expr]
|
853
|
+
_kwargs: collections.abc.Mapping[str, Expr]
|
364
854
|
|
365
|
-
|
855
|
+
def __init__(
|
856
|
+
self,
|
857
|
+
op: Operation[..., T],
|
858
|
+
*args: Expr,
|
859
|
+
**kwargs: Expr,
|
860
|
+
):
|
861
|
+
self._op = op
|
862
|
+
self._args = args
|
863
|
+
self._kwargs = kwargs
|
864
|
+
|
865
|
+
def __eq__(self, other) -> bool:
|
866
|
+
from effectful.ops.syntax import syntactic_eq
|
867
|
+
|
868
|
+
return syntactic_eq(self, other)
|
869
|
+
|
870
|
+
@property
|
871
|
+
def op(self):
|
872
|
+
return self._op
|
873
|
+
|
874
|
+
@property
|
875
|
+
def args(self):
|
876
|
+
return self._args
|
877
|
+
|
878
|
+
@property
|
879
|
+
def kwargs(self):
|
880
|
+
return self._kwargs
|
366
881
|
|
367
882
|
|
368
883
|
@defdata.register(collections.abc.Callable)
|
369
|
-
|
370
|
-
|
884
|
+
class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]):
|
885
|
+
def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
|
886
|
+
from effectful.ops.semantics import call
|
371
887
|
|
372
|
-
|
888
|
+
return call(self, *args, **kwargs) # type: ignore
|
373
889
|
|
374
890
|
|
375
891
|
@defterm.register(collections.abc.Callable)
|
376
|
-
def _(
|
377
|
-
from effectful.internals.
|
892
|
+
def _(value: Callable[P, T]) -> Expr[Callable[P, T]]:
|
893
|
+
from effectful.internals.runtime import interpreter
|
894
|
+
from effectful.ops.semantics import apply, call
|
895
|
+
|
896
|
+
assert not isinstance(value, Term)
|
897
|
+
|
898
|
+
try:
|
899
|
+
sig = inspect.signature(value)
|
900
|
+
except ValueError:
|
901
|
+
return value
|
902
|
+
|
903
|
+
for name, param in sig.parameters.items():
|
904
|
+
if param.kind in (
|
905
|
+
inspect.Parameter.VAR_POSITIONAL,
|
906
|
+
inspect.Parameter.VAR_KEYWORD,
|
907
|
+
):
|
908
|
+
raise ValueError(f"cannot unembed {value}: parameter {name} is variadic")
|
909
|
+
|
910
|
+
bound_sig = sig.bind(
|
911
|
+
**{name: defop(param.annotation) for name, param in sig.parameters.items()}
|
912
|
+
)
|
913
|
+
bound_sig.apply_defaults()
|
914
|
+
|
915
|
+
with interpreter(
|
916
|
+
{
|
917
|
+
apply: lambda _, op, *a, **k: defdata(op, *a, **k),
|
918
|
+
call: call.__default_rule__,
|
919
|
+
}
|
920
|
+
):
|
921
|
+
body = value(
|
922
|
+
*[a() for a in bound_sig.args],
|
923
|
+
**{k: v() for k, v in bound_sig.kwargs.items()},
|
924
|
+
)
|
378
925
|
|
379
|
-
return
|
926
|
+
return deffn(body, *bound_sig.args, **bound_sig.kwargs)
|
380
927
|
|
381
928
|
|
382
929
|
def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool:
|