effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
effectful/ops/syntax.py CHANGED
@@ -1,59 +1,387 @@
1
- import collections
1
+ import collections.abc
2
2
  import dataclasses
3
3
  import functools
4
+ import inspect
5
+ import numbers
6
+ import operator
7
+ import random
8
+ import types
4
9
  import typing
5
- from typing import (
6
- Annotated,
7
- Callable,
8
- Generic,
9
- Mapping,
10
- Optional,
11
- Sequence,
12
- Type,
13
- TypeVar,
14
- )
10
+ import warnings
11
+ from collections.abc import Callable, Iterable, Mapping
12
+ from typing import Annotated, Any, Concatenate
15
13
 
16
- import tree
17
- from typing_extensions import Concatenate, ParamSpec
14
+ from effectful.ops.types import Annotation, Expr, NotHandled, Operation, Term
18
15
 
19
- from effectful.ops.types import ArgAnnotation, Expr, Interpretation, Operation, Term
20
16
 
21
- P = ParamSpec("P")
22
- Q = ParamSpec("Q")
23
- S = TypeVar("S")
24
- T = TypeVar("T")
25
- V = TypeVar("V")
17
+ @dataclasses.dataclass
18
+ class Scoped(Annotation):
19
+ """
20
+ A special type annotation that indicates the relative scope of a parameter
21
+ in the signature of an :class:`Operation` created with :func:`defop` .
26
22
 
23
+ :class:`Scoped` makes it easy to describe higher-order :class:`Operation` s
24
+ that take other :class:`Term` s and :class:`Operation` s as arguments,
25
+ inspired by a number of recent proposals to view syntactic variables
26
+ as algebraic effects and environments as effect handlers.
27
27
 
28
- @dataclasses.dataclass
29
- class Bound(ArgAnnotation):
30
- scope: int = 0
28
+ As a result, in ``effectful`` many complex higher-order programming constructs,
29
+ such as lambda-abstraction, let-binding, loops, try-catch exception handling,
30
+ nondeterminism, capture-avoiding substitution and algebraic effect handling,
31
+ can be expressed uniformly using :func:`defop` as ordinary :class:`Operation` s
32
+ and evaluated or transformed using generalized effect handlers that respect
33
+ the scoping semantics of the operations.
31
34
 
35
+ .. warning::
32
36
 
33
- @dataclasses.dataclass
34
- class Scoped(ArgAnnotation):
35
- scope: int = 0
37
+ :class:`Scoped` instances are typically constructed using indexing
38
+ syntactic sugar borrowed from generic types like :class:`typing.Generic` .
39
+ For example, ``Scoped[A]`` desugars to a :class:`Scoped` instance
40
+ with ``ordinal={A}``, and ``Scoped[A | B]`` desugars to a :class:`Scoped`
41
+ instance with ``ordinal={A, B}`` .
36
42
 
43
+ However, :class:`Scoped` is not a generic type, and the set of :class:`typing.TypeVar` s
44
+ used for the :class:`Scoped` annotations in a given operation must be disjoint
45
+ from the set of :class:`typing.TypeVar` s used for generic types of the parameters.
37
46
 
38
- class NoDefaultRule(Exception):
39
- """Raised in an operation's signature to indicate that the operation has no default rule."""
47
+ **Example usage**:
40
48
 
41
- pass
49
+ We illustrate the use of :class:`Scoped` with a few case studies of classical
50
+ syntactic variable binding constructs expressed as :class:`Operation` s.
51
+
52
+ >>> from typing import Annotated
53
+ >>> from effectful.ops.syntax import Scoped, defop
54
+ >>> from effectful.ops.semantics import fvsof
55
+ >>> x, y = defop(int, name='x'), defop(int, name='y')
56
+
57
+ * For example, we can define a higher-order operation :func:`Lambda`
58
+ that takes an :class:`Operation` representing a bound syntactic variable
59
+ and a :class:`Term` representing the body of an anonymous function,
60
+ and returns a :class:`Term` representing a lambda function:
61
+
62
+ >>> @defop
63
+ ... def Lambda[S, T, A, B](
64
+ ... var: Annotated[Operation[[], S], Scoped[A]],
65
+ ... body: Annotated[T, Scoped[A | B]]
66
+ ... ) -> Annotated[Callable[[S], T], Scoped[B]]:
67
+ ... raise NotHandled
68
+
69
+ * The :class:`Scoped` annotation is used here to indicate that the argument ``var``
70
+ passed to :func:`Lambda` may appear free in ``body``, but not in the resulting function.
71
+ In other words, it is bound by :func:`Lambda`:
72
+
73
+ >>> assert x not in fvsof(Lambda(x, x() + 1))
74
+
75
+ However, variables in ``body`` other than ``var`` still appear free in the result:
76
+
77
+ >>> assert y in fvsof(Lambda(x, x() + y()))
78
+
79
+ * :class:`Scoped` can also be used with variadic arguments and keyword arguments.
80
+ For example, we can define a generalized :func:`LambdaN` that takes a variable
81
+ number of arguments and keyword arguments:
82
+
83
+ >>> @defop
84
+ ... def LambdaN[S, T, A, B](
85
+ ... body: Annotated[T, Scoped[A | B]],
86
+ ... *args: Annotated[Operation[[], S], Scoped[A]],
87
+ ... **kwargs: Annotated[Operation[[], S], Scoped[A]]
88
+ ... ) -> Annotated[Callable[..., T], Scoped[B]]:
89
+ ... raise NotHandled
90
+
91
+ This is equivalent to the built-in :class:`Operation` :func:`deffn`:
92
+
93
+ >>> assert not {x, y} & fvsof(LambdaN(x() + y(), x, y))
94
+
95
+ * :class:`Scoped` and :func:`defop` can also express more complex scoping semantics.
96
+ For example, we can define a :func:`Let` operation that binds a variable in
97
+ a :class:`Term` ``body`` to a value ``val`` that may itself be a possibly open :class:`Term` :
98
+
99
+ >>> @defop
100
+ ... def Let[S, T, A, B](
101
+ ... var: Annotated[Operation[[], S], Scoped[A]],
102
+ ... val: Annotated[S, Scoped[B]],
103
+ ... body: Annotated[T, Scoped[A | B]]
104
+ ... ) -> Annotated[T, Scoped[B]]:
105
+ ... raise NotHandled
106
+
107
+ Here the variable ``var`` is bound by :func:`Let` in ``body`` but not in ``val`` :
108
+
109
+ >>> assert x not in fvsof(Let(x, y() + 1, x() + y()))
110
+
111
+ >>> fvs = fvsof(Let(x, y() + x(), x() + y()))
112
+ >>> assert x in fvs and y in fvs
113
+
114
+ This is reflected in the free variables of subterms of the result:
115
+
116
+ >>> assert x in fvsof(Let(x, x() + y(), x() + y()).args[1])
117
+ >>> assert x not in fvsof(Let(x, y() + 1, x() + y()).args[2])
118
+ """
119
+
120
+ ordinal: collections.abc.Set
121
+
122
+ def __class_getitem__(cls, item: typing.TypeVar | typing._SpecialForm):
123
+ assert not isinstance(item, tuple), "can only be in one scope"
124
+ if isinstance(item, typing.TypeVar):
125
+ return cls(ordinal=frozenset({item}))
126
+ elif typing.get_origin(item) is typing.Union and typing.get_args(item):
127
+ return cls(ordinal=frozenset(typing.get_args(item)))
128
+ else:
129
+ raise TypeError(
130
+ f"expected TypeVar or non-empty Union of TypeVars, but got {item}"
131
+ )
42
132
 
133
+ @staticmethod
134
+ def _param_is_var(param: type | inspect.Parameter) -> bool:
135
+ """
136
+ Helper function that checks if a parameter is annotated as an :class:`Operation` .
137
+
138
+ :param param: The parameter to check.
139
+ :returns: ``True`` if the parameter is an :class:`Operation` , ``False`` otherwise.
140
+ """
141
+ if isinstance(param, inspect.Parameter):
142
+ param = param.annotation
143
+ if typing.get_origin(param) is Annotated:
144
+ param = typing.get_args(param)[0]
145
+ if typing.get_origin(param) is not None:
146
+ param = typing.cast(type, typing.get_origin(param))
147
+ return isinstance(param, type) and issubclass(param, Operation)
43
148
 
44
- @typing.overload
45
- def defop(t: Type[T], *, name: Optional[str] = None) -> Operation[[], T]: ...
149
+ @classmethod
150
+ def _get_param_ordinal(cls, param: type | inspect.Parameter) -> collections.abc.Set:
151
+ """
152
+ Given a type or parameter, extracts the ordinal from its :class:`Scoped` annotation.
153
+
154
+ :param param: The type or signature parameter to extract the ordinal from.
155
+ :returns: The ordinal typevars.
156
+ """
157
+ if isinstance(param, inspect.Parameter):
158
+ return cls._get_param_ordinal(param.annotation)
159
+ elif typing.get_origin(param) is Annotated:
160
+ for a in typing.get_args(param)[1:]:
161
+ if isinstance(a, cls):
162
+ return a.ordinal
163
+ return set()
164
+ else:
165
+ return set()
166
+
167
+ @classmethod
168
+ def _get_root_ordinal(cls, sig: inspect.Signature) -> collections.abc.Set:
169
+ """
170
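+ Given a signature, computes the intersection of the ordinals of all :class:`Scoped` annotations on its parameters and return type.
171
+
172
+ :param sig: The signature to check.
173
+ :returns: The intersection of the ``ordinal`` sets of all :class:`Scoped` annotations (empty if any parameter or the return type lacks one).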
174
+ """
175
+ return set(cls._get_param_ordinal(sig.return_annotation)).intersection(
176
+ *(cls._get_param_ordinal(p) for p in sig.parameters.values())
177
+ )
46
178
 
179
+ @classmethod
180
+ def _get_fresh_ordinal(cls, *, name: str = "RootScope") -> collections.abc.Set:
181
+ return {typing.TypeVar(name)}
47
182
 
48
- @typing.overload
49
- def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]: ...
183
+ @classmethod
184
+ def _check_has_single_scope(cls, sig: inspect.Signature) -> bool:
185
+ """
186
+ Checks if each parameter has at most one :class:`Scoped` annotation.
187
+
188
+ :param sig: The signature to check.
189
+ :returns: True if each parameter has at most one :class:`Scoped` annotation, False otherwise.
190
+ """
191
+ # invariant: at most one Scoped annotation per parameter
192
+ return not any(
193
+ len([a for a in p.annotation.__metadata__ if isinstance(a, cls)]) > 1
194
+ for p in sig.parameters.values()
195
+ if typing.get_origin(p.annotation) is Annotated
196
+ )
50
197
 
198
+ @classmethod
199
+ def _check_no_typevar_overlap(cls, sig: inspect.Signature) -> bool:
200
+ """
201
+ Checks if there is no overlap between ordinal typevars and generic ones.
202
+
203
+ :param sig: The signature to check.
204
+ :returns: True if there is no overlap between ordinal typevars and generic ones, False otherwise.
205
+ """
206
+
207
+ def _get_free_type_vars(
208
+ tp: type | typing._SpecialForm | inspect.Parameter | tuple | list,
209
+ ) -> collections.abc.Set[typing.TypeVar]:
210
+ if isinstance(tp, typing.TypeVar):
211
+ return {tp}
212
+ elif isinstance(tp, tuple | list):
213
+ return set().union(*map(_get_free_type_vars, tp))
214
+ elif isinstance(tp, inspect.Parameter):
215
+ return _get_free_type_vars(tp.annotation)
216
+ elif typing.get_origin(tp) is Annotated:
217
+ return _get_free_type_vars(typing.get_args(tp)[0])
218
+ elif typing.get_origin(tp) is not None:
219
+ return _get_free_type_vars(typing.get_args(tp))
220
+ else:
221
+ return set()
222
+
223
+ # invariant: no overlap between ordinal typevars and generic ones
224
+ free_type_vars = _get_free_type_vars(
225
+ (sig.return_annotation, *sig.parameters.values())
226
+ )
227
+ return all(
228
+ free_type_vars.isdisjoint(cls._get_param_ordinal(p))
229
+ for p in (
230
+ sig.return_annotation,
231
+ *sig.parameters.values(),
232
+ )
233
+ )
51
234
 
52
- @typing.overload
53
- def defop(t: Operation[P, T], *, name: Optional[str] = None) -> Operation[P, T]: ...
235
+ @classmethod
236
+ def _check_no_boundvars_in_result(cls, sig: inspect.Signature) -> bool:
237
+ """
238
+ Checks that no bound variables would appear free in the return value.
239
+
240
+ :param sig: The signature to check.
241
+ :returns: True if no bound variables would appear free in the return value, False otherwise.
242
+
243
+ .. note::
244
+
245
+ This is used as a pre-condition for :func:`infer_annotations`.
246
+ However, it is not a necessary condition for the correctness of the
247
+ ``Scoped`` annotations of an operation - our current implementation
248
+ merely does not extend to cases where this condition is violated.
249
+ """
250
+ root_ordinal = cls._get_root_ordinal(sig)
251
+ return_ordinal = cls._get_param_ordinal(sig.return_annotation)
252
+ return not any(
253
+ root_ordinal < cls._get_param_ordinal(p) <= return_ordinal
254
+ for p in sig.parameters.values()
255
+ if cls._param_is_var(p)
256
+ )
54
257
 
258
+ @classmethod
259
+ def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature:
260
+ """
261
+ Given a :class:`inspect.Signature` for an :class:`Operation` for which
262
+ only some :class:`inspect.Parameter` s have manual :class:`Scoped` annotations,
263
+ computes a new signature with :class:`Scoped` annotations attached to each parameter,
264
+ including the return type annotation.
265
+
266
+ The new annotations are inferred by joining the manual annotations with a
267
+ fresh root scope. The root scope is the intersection of all :class:`Scoped`
268
+ annotations in the resulting :class:`inspect.Signature` object.
269
+
270
+ :class:`Operation` s in this root scope are free in the result and in all arguments.
271
+
272
+ :param sig: The signature of the operation.
273
+ :returns: A new signature with inferred :class:`Scoped` annotations.
274
+ """
275
+ # pre-conditions
276
+ assert cls._check_has_single_scope(sig)
277
+ assert cls._check_no_typevar_overlap(sig)
278
+ assert cls._check_no_boundvars_in_result(sig)
279
+
280
+ root_ordinal = cls._get_root_ordinal(sig)
281
+ if not root_ordinal:
282
+ root_ordinal = cls._get_fresh_ordinal()
283
+
284
+ # add missing Scoped annotations and join everything with the root scope
285
+ new_annos: list[type | typing._SpecialForm] = []
286
+ for anno in (
287
+ sig.return_annotation,
288
+ *(p.annotation for p in sig.parameters.values()),
289
+ ):
290
+ new_scope = cls(ordinal=cls._get_param_ordinal(anno) | root_ordinal)
291
+ if typing.get_origin(anno) is Annotated:
292
+ new_anno = typing.get_args(anno)[0]
293
+ new_anno = Annotated[new_anno, new_scope]
294
+ for other in typing.get_args(anno)[1:]:
295
+ if not isinstance(other, cls):
296
+ new_anno = Annotated[new_anno, other]
297
+ else:
298
+ new_anno = Annotated[anno, new_scope]
299
+
300
+ new_annos.append(new_anno)
301
+
302
+ # construct a new Signature structure with the inferred annotations
303
+ new_return_anno, new_annos = new_annos[0], new_annos[1:]
304
+ inferred_sig = sig.replace(
305
+ parameters=[
306
+ p.replace(annotation=a)
307
+ for p, a in zip(sig.parameters.values(), new_annos)
308
+ ],
309
+ return_annotation=new_return_anno,
310
+ )
55
311
 
56
- def defop(t, *, name=None):
312
+ # post-conditions
313
+ assert cls._get_root_ordinal(inferred_sig) == root_ordinal != set()
314
+ return inferred_sig
315
+
316
+ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]:
317
+ """
318
+ Computes a set of bound variables given a signature with bound arguments.
319
+
320
+ The :func:`analyze` methods of :class:`Scoped` annotations that appear on
321
+ the signature of an :class:`Operation` are used by :func:`defop` to generate
322
+ implementations of :func:`Operation.__fvs_rule__` underlying alpha-renaming
323
+ in :func:`defterm` and :func:`defdata` and free variable sets in :func:`fvsof` .
324
+
325
+ Specifically, the :func:`analyze` method of the :class:`Scoped` annotation
326
+ of a parameter computes the set of bound variables in that parameter's value.
327
+ The :func:`Operation.__fvs_rule__` method generated by :func:`defop` simply
328
+ extracts the annotation of each parameter, calls :func:`analyze` on the value
329
+ given for the corresponding parameter in ``bound_sig`` , and returns the results.
330
+
331
+ :param bound_sig: The :class:`inspect.Signature` of an :class:`Operation`
332
+ together with values for all of its arguments.
333
+ :returns: A set of bound variables.
334
+ """
335
+ bound_vars: frozenset[Operation] = frozenset()
336
+ return_ordinal = self._get_param_ordinal(bound_sig.signature.return_annotation)
337
+ for name, param in bound_sig.signature.parameters.items():
338
+ param_ordinal = self._get_param_ordinal(param)
339
+ if param_ordinal <= self.ordinal and not param_ordinal <= return_ordinal:
340
+ param_value = bound_sig.arguments[name]
341
+ param_bound_vars = set()
342
+
343
+ if self._param_is_var(param):
344
+ # Handle individual Operation parameters (existing behavior)
345
+ if param.kind is inspect.Parameter.VAR_POSITIONAL:
346
+ # pre-condition: all bound variables should be distinct
347
+ assert len(param_value) == len(set(param_value))
348
+ param_bound_vars = set(param_value)
349
+ elif param.kind is inspect.Parameter.VAR_KEYWORD:
350
+ # pre-condition: all bound variables should be distinct
351
+ assert len(param_value.values()) == len(
352
+ set(param_value.values())
353
+ )
354
+ param_bound_vars = set(param_value.values())
355
+ else:
356
+ param_bound_vars = {param_value}
357
+ elif param_ordinal: # Only process if there's a Scoped annotation
358
+ # We can't use flatten here because we want to be able
359
+ # to see dict keys
360
+ def extract_operations(obj):
361
+ if isinstance(obj, Operation):
362
+ param_bound_vars.add(obj)
363
+ elif isinstance(obj, dict):
364
+ for k, v in obj.items():
365
+ extract_operations(k)
366
+ extract_operations(v)
367
+ elif isinstance(obj, list | set | tuple):
368
+ for v in obj:
369
+ extract_operations(v)
370
+
371
+ extract_operations(param_value)
372
+
373
+ # pre-condition: all bound variables should be distinct
374
+ if param_bound_vars:
375
+ assert not bound_vars & param_bound_vars
376
+ bound_vars |= param_bound_vars
377
+
378
+ return bound_vars
379
+
380
+
381
+ @functools.singledispatch
382
+ def defop[**P, T](
383
+ t: Callable[P, T], *, name: str | None = None, freshening=list[int] | None
384
+ ) -> Operation[P, T]:
57
385
  """Creates a fresh :class:`Operation`.
58
386
 
59
387
  :param t: May be a type, callable, or :class:`Operation`. If a type, the
@@ -94,13 +422,13 @@ def defop(t, *, name=None):
94
422
  * Defining an operation with no default rule:
95
423
 
96
424
  We can use :func:`defop` and the
97
- :exc:`effectful.internals.sugar.NoDefaultRule` exception to define an
425
+ :exc:`NotHandled` exception to define an
98
426
  operation with no default rule:
99
427
 
100
428
  >>> @defop
101
429
  ... def add(x: int, y: int) -> int:
102
- ... raise NoDefaultRule
103
- >>> add(1, 2)
430
+ ... raise NotHandled
431
+ >>> print(str(add(1, 2)))
104
432
  add(1, 2)
105
433
 
106
434
  When an operation has no default rule, the free rule is used instead, which
@@ -111,15 +439,14 @@ def defop(t, *, name=None):
111
439
 
112
440
  Passing :func:`defop` a type is a handy way to create a free variable.
113
441
 
114
- >>> import effectful.handlers.operator
115
442
  >>> from effectful.ops.semantics import evaluate
116
443
  >>> x = defop(int, name='x')
117
444
  >>> y = x() + 1
118
445
 
119
446
  ``y`` is free in ``x``, so it is not fully evaluated:
120
447
 
121
- >>> y
122
- add(x(), 1)
448
+ >>> print(str(y))
449
+ __add__(x(), 1)
123
450
 
124
451
  We bind ``x`` by installing a handler for it:
125
452
 
@@ -132,7 +459,8 @@ def defop(t, *, name=None):
132
459
  Because the result of :func:`defop` is always fresh, it's important to
133
460
  be careful with variable identity.
134
461
 
135
- Two variables with the same name are not equal:
462
+ Two operations with the same name that come from different calls to
463
+ ``defop`` are not equal:
136
464
 
137
465
  >>> x1 = defop(int, name='x')
138
466
  >>> x2 = defop(int, name='x')
@@ -143,24 +471,22 @@ def defop(t, *, name=None):
143
471
  operation object. In this example, ``scale`` returns a term with a free
144
472
  variable ``x``:
145
473
 
146
- >>> import effectful.handlers.operator
474
+ >>> x = defop(float, name='x')
147
475
  >>> def scale(a: float) -> float:
148
- ... x = defop(float, name='x')
149
476
  ... return x() * a
150
477
 
151
- Binding the variable ``x`` by creating a fresh operation object does not
478
+ Binding the variable ``x`` as follows does not work:
152
479
 
153
480
  >>> term = scale(3.0)
154
- >>> x = defop(float, name='x')
155
- >>> with handler({x: lambda: 2.0}):
156
- ... print(evaluate(term))
157
- mul(x(), 3.0)
481
+ >>> fresh_x = defop(float, name='x')
482
+ >>> with handler({fresh_x: lambda: 2.0}):
483
+ ... print(str(evaluate(term)))
484
+ __mul__(x(), 3.0)
158
485
 
159
- This does:
486
+ Only the original operation object will work:
160
487
 
161
488
  >>> from effectful.ops.semantics import fvsof
162
- >>> correct_x = [v for v in fvsof(term) if str(x) == 'x'][0]
163
- >>> with handler({correct_x: lambda: 2.0}):
489
+ >>> with handler({x: lambda: 2.0}):
164
490
  ... print(evaluate(term))
165
491
  6.0
166
492
 
@@ -170,7 +496,7 @@ def defop(t, *, name=None):
170
496
  the same name and signature, but no default rule.
171
497
 
172
498
  >>> fresh_select = defop(select)
173
- >>> fresh_select(1, 2)
499
+ >>> print(str(fresh_select(1, 2)))
174
500
  select(1, 2)
175
501
 
176
502
  The new operation is distinct from the original:
@@ -184,47 +510,284 @@ def defop(t, *, name=None):
184
510
  1 2
185
511
 
186
512
  """
513
+ raise NotImplementedError(f"expected type or callable, got {t}")
514
+
515
+
516
+ @defop.register(typing.cast(type[collections.abc.Callable], collections.abc.Callable))
517
+ class _BaseOperation[**Q, V](Operation[Q, V]):
518
+ __signature__: inspect.Signature
519
+ __name__: str
520
+
521
+ _default: Callable[Q, V]
522
+
523
+ def __init__(
524
+ self,
525
+ default: Callable[Q, V],
526
+ *,
527
+ name: str | None = None,
528
+ freshening: list[int] | None = None,
529
+ ):
530
+ functools.update_wrapper(self, default)
531
+ self._default = default
532
+ self.__name__ = name or default.__name__
533
+ self._freshening = freshening or []
534
+ self.__signature__ = inspect.signature(default)
535
+
536
+ def __eq__(self, other):
537
+ if not isinstance(other, Operation):
538
+ return NotImplemented
539
+ return self is other
540
+
541
+ def __lt__(self, other):
542
+ if not isinstance(other, Operation):
543
+ return NotImplemented
544
+ return id(self) < id(other)
545
+
546
+ def __hash__(self):
547
+ return hash(self._default)
548
+
549
+ def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
550
+ try:
551
+ try:
552
+ return self._default(*args, **kwargs)
553
+ except NotImplementedError:
554
+ warnings.warn(
555
+ "Operations should raise effectful.ops.types.NotHandled instead of NotImplementedError.",
556
+ DeprecationWarning,
557
+ )
558
+ raise NotHandled
559
+ except NotHandled:
560
+ return typing.cast(
561
+ Callable[Concatenate[Operation[Q, V], Q], Expr[V]], defdata
562
+ )(self, *args, **kwargs)
563
+
564
+ def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> inspect.BoundArguments:
565
+ sig = Scoped.infer_annotations(self.__signature__)
566
+ bound_sig = sig.bind(*args, **kwargs)
567
+ bound_sig.apply_defaults()
568
+
569
+ result_sig = sig.bind(
570
+ *(frozenset() for _ in bound_sig.args),
571
+ **{k: frozenset() for k in bound_sig.kwargs},
572
+ )
573
+ for name, param in sig.parameters.items():
574
+ if typing.get_origin(param.annotation) is typing.Annotated:
575
+ for anno in typing.get_args(param.annotation)[1:]:
576
+ if isinstance(anno, Scoped):
577
+ param_bound_vars = anno.analyze(bound_sig)
578
+ if param.kind is inspect.Parameter.VAR_POSITIONAL:
579
+ result_sig.arguments[name] = tuple(
580
+ param_bound_vars for _ in bound_sig.arguments[name]
581
+ )
582
+ elif param.kind is inspect.Parameter.VAR_KEYWORD:
583
+ for k in bound_sig.arguments[name]:
584
+ result_sig.arguments[name][k] = param_bound_vars
585
+ else:
586
+ result_sig.arguments[name] = param_bound_vars
587
+
588
+ return result_sig
589
+
590
+ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:
591
+ from effectful.internals.unification import (
592
+ freetypevars,
593
+ nested_type,
594
+ substitute,
595
+ unify,
596
+ )
187
597
 
188
- if isinstance(t, Operation):
598
+ return_anno = self.__signature__.return_annotation
599
+ if typing.get_origin(return_anno) is typing.Annotated:
600
+ return_anno = typing.get_args(return_anno)[0]
189
601
 
190
- def func(*args, **kwargs):
191
- raise NoDefaultRule
602
+ if return_anno is inspect.Parameter.empty:
603
+ return typing.cast(type[V], object)
604
+ elif return_anno is None:
605
+ return type(None) # type: ignore
606
+ elif not freetypevars(return_anno):
607
+ return return_anno
192
608
 
193
- functools.update_wrapper(func, t)
194
- return defop(func, name=name)
195
- elif isinstance(t, type):
609
+ type_args = tuple(nested_type(a) for a in args)
610
+ type_kwargs = {k: nested_type(v) for k, v in kwargs.items()}
611
+ bound_sig = self.__signature__.bind(*type_args, **type_kwargs)
612
+ return substitute(return_anno, unify(self.__signature__, bound_sig)) # type: ignore
196
613
 
197
- def func() -> t: # type: ignore
198
- raise NoDefaultRule
614
+ def __repr__(self):
615
+ return f"_BaseOperation({self._default}, name={self.__name__}, freshening={self._freshening})"
199
616
 
200
- func.__name__ = name or t.__name__
201
- return typing.cast(Operation[[], T], defop(func, name=name))
202
- elif isinstance(t, collections.abc.Callable):
203
- from effectful.internals.base_impl import _BaseOperation
617
+ def __str__(self):
618
+ return self.__name__
204
619
 
205
- op = _BaseOperation(t)
206
- op.__name__ = name or t.__name__
207
- return op
208
- else:
209
- raise ValueError(f"expected type or callable, got {t}")
620
+ def __get__(self, instance, owner):
621
+ if instance is not None:
622
+ # This is an instance-level operation, so we need to bind the instance
623
+ return functools.partial(self, instance)
624
+ else:
625
+ # This is a static operation, so we return the operation itself
626
+ return self
627
+
628
+
629
+ @defop.register(Operation)
630
+ def _[**P, T](t: Operation[P, T], *, name: str | None = None) -> Operation[P, T]:
631
+ @functools.wraps(t)
632
+ def func(*args, **kwargs):
633
+ raise NotHandled
634
+
635
+ if name is None:
636
+ name = getattr(t, "__name__", str(t))
637
+ freshening = getattr(t, "_freshening", []) + [random.randint(0, 1 << 32)]
638
+
639
+ return defop(func, name=name, freshening=freshening)
640
+
641
+
642
+ @defop.register(type)
643
+ @defop.register(typing.cast(type, types.GenericAlias))
644
+ @defop.register(typing.cast(type, typing._GenericAlias)) # type: ignore
645
+ @defop.register(typing.cast(type, types.UnionType))
646
+ def _[T](t: type[T], *, name: str | None = None) -> Operation[[], T]:
647
+ def func() -> t: # type: ignore
648
+ raise NotHandled
649
+
650
+ freshening = []
651
+ if name is None:
652
+ name = t.__name__
653
+ freshening = [random.randint(0, 1 << 32)]
654
+
655
+ return typing.cast(
656
+ Operation[[], T],
657
+ defop(func, name=name, freshening=freshening),
658
+ )
659
+
660
+
661
+ @defop.register(types.BuiltinFunctionType)
662
+ def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]:
663
+ @functools.wraps(t)
664
+ def func(*args, **kwargs):
665
+ from effectful.ops.semantics import fvsof
666
+
667
+ if not fvsof((args, kwargs)):
668
+ return t(*args, **kwargs)
669
+ else:
670
+ raise NotHandled
671
+
672
+ return defop(func, name=name)
673
+
674
+
675
+ @defop.register(classmethod)
676
+ def _[**P, S, T]( # type: ignore
677
+ t: classmethod, *, name: str | None = None
678
+ ) -> Operation[Concatenate[type[S], P], T]:
679
+ raise NotImplementedError("classmethod operations are not yet supported")
680
+
681
+
682
+ @defop.register(staticmethod)
683
+ class _StaticMethodOperation[**P, S, T](_BaseOperation[P, T]):
684
+ def __init__(self, default: staticmethod, **kwargs):
685
+ super().__init__(default=default.__func__, **kwargs)
686
+
687
+ def __get__(self, instance: S, owner: type[S] | None = None) -> Callable[P, T]:
688
+ return self
689
+
690
+
691
+ @defop.register(property)
692
+ class _PropertyOperation[S, T](_BaseOperation[[S], T]):
693
+ def __init__(self, default: property, **kwargs): # type: ignore
694
+ assert not default.fset, "property with setter is not supported"
695
+ assert not default.fdel, "property with deleter is not supported"
696
+ super().__init__(default=typing.cast(Callable[[S], T], default.fget), **kwargs)
697
+
698
+ @typing.overload
699
+ def __get__(
700
+ self, instance: None, owner: type[S] | None = None
701
+ ) -> "_PropertyOperation[S, T]": ...
702
+
703
+ @typing.overload
704
+ def __get__(self, instance: S, owner: type[S] | None = None) -> T: ...
705
+
706
+ def __get__(self, instance, owner: type[S] | None = None):
707
+ if instance is not None:
708
+ return self(instance)
709
+ else:
710
+ return self
711
+
712
+
713
+ @defop.register(functools.singledispatchmethod)
714
+ class _SingleDispatchMethodOperation[**P, S, T](_BaseOperation[Concatenate[S, P], T]):
715
+ _default: Callable[Concatenate[S, P], T]
716
+
717
+ def __init__(self, default: functools.singledispatchmethod, **kwargs): # type: ignore
718
+ if isinstance(default.func, classmethod):
719
+ raise NotImplementedError("Operations as classmethod are not yet supported")
720
+
721
+ @functools.wraps(default.func)
722
+ def _wrapper(obj: S, *args: P.args, **kwargs: P.kwargs) -> T:
723
+ return default.__get__(obj)(*args, **kwargs)
724
+
725
+ self._registry: functools.singledispatchmethod = default
726
+ super().__init__(_wrapper, **kwargs)
727
+
728
+ @typing.overload
729
+ def __get__(
730
+ self, instance: None, owner: type[S] | None = None
731
+ ) -> "_SingleDispatchMethodOperation[P, S, T]": ...
732
+
733
+ @typing.overload
734
+ def __get__(self, instance: S, owner: type[S] | None = None) -> Callable[P, T]: ...
735
+
736
+ def __get__(self, instance, owner: type[S] | None = None):
737
+ if instance is not None:
738
+ return functools.partial(self, instance)
739
+ else:
740
+ return self
741
+
742
+ @property
743
+ def register(self):
744
+ return self._registry.register
745
+
746
+ @property
747
+ def __isabstractmethod__(self):
748
+ return self._registry.__isabstractmethod__
749
+
750
+
751
+ class _SingleDispatchOperation[**P, S, T](_BaseOperation[Concatenate[S, P], T]):
752
+ _default: "functools._SingleDispatchCallable[T]"
753
+
754
+ @property
755
+ def register(self):
756
+ return self._default.register
757
+
758
+ @property
759
+ def dispatch(self):
760
+ return self._default.dispatch
761
+
762
+
763
+ if typing.TYPE_CHECKING:
764
+ defop.register(functools._SingleDispatchCallable)(_SingleDispatchOperation)
765
+ else:
766
+
767
+ @typing.runtime_checkable
768
+ class _SingleDispatchCallable(typing.Protocol):
769
+ registry: types.MappingProxyType[object, Callable]
770
+
771
+ def dispatch(self, cls: type) -> Callable: ...
772
+ def register(self, cls: type, func: Callable | None = None) -> Callable: ...
773
+ def _clear_cache(self) -> None: ...
774
+ def __call__(self, /, *args, **kwargs): ...
775
+
776
+ defop.register(_SingleDispatchCallable)(_SingleDispatchOperation)
210
777
 
211
778
 
212
779
  @defop
213
- def deffn(
214
- body: T,
215
- *args: Annotated[Operation, Bound()],
216
- **kwargs: Annotated[Operation, Bound()],
217
- ) -> Callable[..., T]:
780
+ def deffn[T, A, B](
781
+ body: Annotated[T, Scoped[A | B]],
782
+ *args: Annotated[Operation, Scoped[A]],
783
+ **kwargs: Annotated[Operation, Scoped[A]],
784
+ ) -> Annotated[Callable[..., T], Scoped[B]]:
218
785
  """An operation that represents a lambda function.
219
786
 
220
787
  :param body: The body of the function.
221
- :type body: T
222
788
  :param args: Operations representing the positional arguments of the function.
223
- :type args: Annotated[Operation, Bound()]
224
789
  :param kwargs: Operations representing the keyword arguments of the function.
225
- :type kwargs: Annotated[Operation, Bound()]
226
790
  :returns: A callable term.
227
- :rtype: Callable[..., T]
228
791
 
229
792
  :func:`deffn` terms are eliminated by the :func:`call` operation, which
230
793
  performs beta-reduction.
@@ -234,11 +797,13 @@ def deffn(
234
797
  Here :func:`deffn` is used to define a term that represents the function
235
798
  ``lambda x, y=1: 2 * x + y``:
236
799
 
237
- >>> import effectful.handlers.operator
800
+ >>> import random
801
+ >>> random.seed(0)
802
+
238
803
  >>> x, y = defop(int, name='x'), defop(int, name='y')
239
804
  >>> term = deffn(2 * x() + y(), x, y=y)
240
- >>> term
241
- deffn(add(mul(2, x()), y()), x, y=y)
805
+ >>> print(str(term)) # doctest: +ELLIPSIS
806
+ deffn(...)
242
807
  >>> term(3, y=4)
243
808
  10
244
809
 
@@ -249,14 +814,14 @@ def deffn(
249
814
  automatically create the right free variables.
250
815
 
251
816
  """
252
- raise NoDefaultRule
817
+ raise NotHandled
253
818
 
254
819
 
255
- class _CustomSingleDispatchCallable(Generic[P, T]):
820
+ class _CustomSingleDispatchCallable[**P, **Q, S, T]:
256
821
  def __init__(
257
- self, func: Callable[Concatenate[Callable[[type], Callable[P, T]], P], T]
822
+ self, func: Callable[Concatenate[Callable[[type], Callable[Q, S]], P], T]
258
823
  ):
259
- self._func = func
824
+ self.func = func
260
825
  self._registry = functools.singledispatch(func)
261
826
  functools.update_wrapper(self, func)
262
827
 
@@ -269,50 +834,52 @@ class _CustomSingleDispatchCallable(Generic[P, T]):
269
834
  return self._registry.register
270
835
 
271
836
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
272
- return self._func(self.dispatch, *args, **kwargs)
837
+ return self.func(self.dispatch, *args, **kwargs)
273
838
 
274
839
 
275
- @_CustomSingleDispatchCallable
276
- def defterm(dispatch, value: T) -> Expr[T]:
277
- """Convert a value to a term, using the type of the value to dispatch.
840
+ @defop.register(_CustomSingleDispatchCallable)
841
+ class _CustomSingleDispatchOperation[**P, **Q, S, T](_BaseOperation[P, T]):
842
+ _default: _CustomSingleDispatchCallable[P, Q, S, T]
278
843
 
279
- :param value: The value to convert.
280
- :type value: T
281
- :returns: A term.
282
- :rtype: Expr[T]
844
+ def __init__(self, default: _CustomSingleDispatchCallable[P, Q, S, T], **kwargs):
845
+ super().__init__(default, **kwargs)
846
+ self.__signature__ = inspect.signature(functools.partial(default.func, None)) # type: ignore
283
847
 
284
- **Example usage**:
848
+ @property
849
+ def dispatch(self):
850
+ return self._registry.dispatch
851
+
852
+ @property
853
+ def register(self):
854
+ return self._registry.register
285
855
 
286
- :func:`defterm` can be passed a function, and it will convert that function
287
- to a term by calling it with appropriately typed free variables:
288
856
 
289
- >>> def incr(x: int) -> int:
290
- ... return x + 1
291
- >>> term = defterm(incr)
292
- >>> term
293
- deffn(add(int(), 1), int)
294
- >>> term(2)
295
- 3
857
+ @_CustomSingleDispatchCallable
858
+ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
859
+ """Convert a value to a term, using the type of the value to dispatch.
296
860
 
861
+ :param value: The value to convert.
862
+ :returns: A term.
297
863
  """
298
864
  if isinstance(value, Term):
299
865
  return value
300
866
  else:
301
- return dispatch(type(value))(value)
867
+ return __dispatch(type(value))(value)
302
868
 
303
869
 
304
870
  @_CustomSingleDispatchCallable
305
- def defdata(dispatch, expr: Term[T]) -> Expr[T]:
306
- """Converts a term so that it is an instance of its inferred type.
871
+ def defdata[T](
872
+ __dispatch: Callable[[type], Callable[..., Expr[T]]],
873
+ op: Operation[..., T],
874
+ *args,
875
+ **kwargs,
876
+ ) -> Expr[T]:
877
+ """Constructs a Term that is an instance of its semantic type.
307
878
 
308
- :param expr: The term to convert.
309
- :type expr: Term[T]
310
879
  :returns: An instance of ``T``.
311
880
  :rtype: Expr[T]
312
881
 
313
- This function is called by :func:`__free_rule__`, so conversions
314
- resgistered with :func:`defdata` are automatically applied when terms are
315
- constructed.
882
+ This function is the only way to construct a :class:`Term` from an :class:`Operation`.
316
883
 
317
884
  .. note::
318
885
 
@@ -326,91 +893,288 @@ def defdata(dispatch, expr: Term[T]) -> Expr[T]:
326
893
 
327
894
  .. code-block:: python
328
895
 
329
- class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]):
330
- def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
331
- from effectful.ops.semantics import call
332
-
333
- return call(self, *args, **kwargs)
334
-
335
896
  @defdata.register(collections.abc.Callable)
336
- def _(op, args, kwargs):
337
- return _CallableTerm(op, args, kwargs)
897
+ class _CallableTerm[**P, T](Term[collections.abc.Callable[P, T]]):
898
+ def __init__(
899
+ self,
900
+ op: Operation[..., T],
901
+ *args: Expr,
902
+ **kwargs: Expr,
903
+ ):
904
+ self._op = op
905
+ self._args = args
906
+ self._kwargs = kwargs
907
+
908
+ @property
909
+ def op(self):
910
+ return self._op
911
+
912
+ @property
913
+ def args(self):
914
+ return self._args
915
+
916
+ @property
917
+ def kwargs(self):
918
+ return self._kwargs
919
+
920
+ @defop
921
+ def __call__(self: collections.abc.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
922
+ ...
923
+
924
+ When an Operation whose return type is `Callable` is passed to :func:`defdata`,
925
+ it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
926
+ """
927
+ from effectful.ops.semantics import apply, evaluate, typeof
338
928
 
339
- When a :class:`Callable` term is passed to :func:`defdata`, it is
340
- reconstructed as a :class:`_CallableTerm`, which implements the
341
- :func:`__call__` method.
929
+ bindings: inspect.BoundArguments = op.__fvs_rule__(*args, **kwargs)
930
+ renaming = {
931
+ var: defop(var)
932
+ for bound_vars in (*bindings.args, *bindings.kwargs.values())
933
+ for var in bound_vars
934
+ }
342
935
 
343
- """
344
- from effectful.ops.semantics import typeof
936
+ renamed_args: inspect.BoundArguments = op.__signature__.bind(*args, **kwargs)
937
+ renamed_args.apply_defaults()
345
938
 
346
- if isinstance(expr, Term):
347
- impl: Callable[[Operation[..., T], Sequence, Mapping[str, object]], Expr[T]]
348
- impl = dispatch(typeof(expr)) # type: ignore
349
- return impl(expr.op, expr.args, expr.kwargs)
350
- else:
351
- return expr
939
+ args_ = [
940
+ evaluate(
941
+ arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.args[i]}}
942
+ )
943
+ for i, arg in enumerate(renamed_args.args)
944
+ ]
945
+ kwargs_ = {
946
+ k: evaluate(
947
+ arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.kwargs[k]}}
948
+ )
949
+ for k, arg in renamed_args.kwargs.items()
950
+ }
951
+
952
+ base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_)
953
+ return __dispatch(typeof(base_term))(op, *args_, **kwargs_)
352
954
 
353
955
 
354
956
  @defterm.register(object)
355
957
  @defterm.register(Operation)
356
958
  @defterm.register(Term)
357
- def _(value: T) -> T:
959
+ @defterm.register(type)
960
+ @defterm.register(types.BuiltinFunctionType)
961
+ def _[T](value: T) -> T:
358
962
  return value
359
963
 
360
964
 
361
965
  @defdata.register(object)
362
- def _(op, args, kwargs):
363
- from effectful.internals.base_impl import _BaseTerm
966
+ class _BaseTerm[T](Term[T]):
967
+ _op: Operation[..., T]
968
+ _args: collections.abc.Sequence[Expr]
969
+ _kwargs: collections.abc.Mapping[str, Expr]
970
+
971
+ def __init__(
972
+ self,
973
+ op: Operation[..., T],
974
+ *args: Expr,
975
+ **kwargs: Expr,
976
+ ):
977
+ self._op = op
978
+ self._args = args
979
+ self._kwargs = kwargs
980
+
981
+ def __eq__(self, other) -> bool:
982
+ from effectful.ops.syntax import syntactic_eq
983
+
984
+ return syntactic_eq(self, other)
985
+
986
+ @property
987
+ def op(self):
988
+ return self._op
989
+
990
+ @property
991
+ def args(self):
992
+ return self._args
364
993
 
365
- return _BaseTerm(op, args, kwargs)
994
+ @property
995
+ def kwargs(self):
996
+ return self._kwargs
366
997
 
367
998
 
368
999
  @defdata.register(collections.abc.Callable)
369
- def _(op, args, kwargs):
370
- from effectful.internals.base_impl import _CallableTerm
1000
+ class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]):
1001
+ @defop
1002
+ def __call__(
1003
+ self: collections.abc.Callable[P, T], *args: P.args, **kwargs: P.kwargs
1004
+ ) -> T:
1005
+ from effectful.ops.semantics import evaluate, fvsof, handler
1006
+
1007
+ if isinstance(self, Term) and self.op is deffn:
1008
+ body: Expr[Callable[P, T]] = self.args[0]
1009
+ argvars: tuple[Operation, ...] = self.args[1:]
1010
+ kwvars: dict[str, Operation] = self.kwargs
1011
+ subs = {
1012
+ **{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
1013
+ **{
1014
+ kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs
1015
+ },
1016
+ }
1017
+ with handler(subs):
1018
+ return evaluate(body)
1019
+ elif not fvsof((self, args, kwargs)):
1020
+ return self(*args, **kwargs)
1021
+ else:
1022
+ raise NotHandled
1023
+
1024
+
1025
+ def trace[**P, T](value: Callable[P, T]) -> Callable[P, T]:
1026
+ """Convert a callable to a term by calling it with appropriately typed free variables.
1027
+
1028
+ **Example usage**:
1029
+
1030
+ :func:`trace` can be passed a function, and it will convert that function to
1031
+ a term by calling it with appropriately typed free variables:
1032
+
1033
+ >>> def incr(x: int) -> int:
1034
+ ... return x + 1
1035
+ >>> term = trace(incr)
371
1036
 
372
- return _CallableTerm(op, args, kwargs)
1037
+ >>> print(str(term))
1038
+ deffn(__add__(int(), 1), int)
373
1039
 
1040
+ >>> term(2)
1041
+ 3
1042
+
1043
+ """
1044
+ from effectful.internals.runtime import interpreter
1045
+ from effectful.ops.semantics import apply
374
1046
 
375
- @defterm.register(collections.abc.Callable)
376
- def _(fn: Callable[P, T]):
377
- from effectful.internals.base_impl import _unembed_callable
1047
+ call = defdata.dispatch(collections.abc.Callable).__call__
1048
+ assert isinstance(call, Operation)
378
1049
 
379
- return _unembed_callable(fn)
1050
+ assert not isinstance(value, Term)
380
1051
 
1052
+ try:
1053
+ sig = inspect.signature(value)
1054
+ except ValueError:
1055
+ return value
1056
+
1057
+ for name, param in sig.parameters.items():
1058
+ if param.kind in (
1059
+ inspect.Parameter.VAR_POSITIONAL,
1060
+ inspect.Parameter.VAR_KEYWORD,
1061
+ ):
1062
+ raise ValueError(f"cannot unembed {value}: parameter {name} is variadic")
1063
+
1064
+ bound_sig = sig.bind(
1065
+ **{name: defop(param.annotation) for name, param in sig.parameters.items()}
1066
+ )
1067
+ bound_sig.apply_defaults()
1068
+
1069
+ with interpreter({apply: defdata, call: call.__default_rule__}):
1070
+ body = value(
1071
+ *[a() for a in bound_sig.args],
1072
+ **{k: v() for k, v in bound_sig.kwargs.items()},
1073
+ )
381
1074
 
382
- def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool:
1075
+ return deffn(body, *bound_sig.args, **bound_sig.kwargs)
1076
+
1077
+
1078
+ @defop
1079
+ def defstream[S, T, A, B](
1080
+ body: Annotated[T, Scoped[A | B]],
1081
+ streams: Annotated[Mapping[Operation[[], S], Iterable[S]], Scoped[B]],
1082
+ ) -> Annotated[Iterable[T], Scoped[A]]:
1083
+ """A higher-order operation that represents a for-expression."""
1084
+ raise NotHandled
1085
+
1086
+
1087
+ @defdata.register(collections.abc.Iterable)
1088
+ class _IterableTerm[T](_BaseTerm[collections.abc.Iterable[T]]):
1089
+ @defop
1090
+ def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]:
1091
+ if not isinstance(self, Term):
1092
+ return iter(self)
1093
+ else:
1094
+ raise NotHandled
1095
+
1096
+
1097
+ @defdata.register(collections.abc.Iterator)
1098
+ class _IteratorTerm[T](_IterableTerm[T]):
1099
+ @defop
1100
+ def __next__(self: collections.abc.Iterator[T]) -> T:
1101
+ if not isinstance(self, Term):
1102
+ return next(self)
1103
+ else:
1104
+ raise NotHandled
1105
+
1106
+
1107
+ iter_ = _IterableTerm.__iter__
1108
+ next_ = _IteratorTerm.__next__
1109
+
1110
+
1111
+ @_CustomSingleDispatchCallable
1112
+ def syntactic_eq(
1113
+ __dispatch: Callable[[type], Callable[[Any, Any], bool]], x, other
1114
+ ) -> bool:
383
1115
  """Syntactic equality, ignoring the interpretation of the terms.
384
1116
 
385
1117
  :param x: A term.
386
- :type x: Expr[T]
387
1118
  :param other: Another term.
388
- :type other: Expr[T]
389
1119
  :returns: ``True`` if the terms are syntactically equal and ``False`` otherwise.
390
1120
  """
391
- if isinstance(x, Term) and isinstance(other, Term):
392
- op, args, kwargs = x.op, x.args, x.kwargs
393
- op2, args2, kwargs2 = other.op, other.args, other.kwargs
394
- try:
395
- tree.assert_same_structure(
396
- (op, args, kwargs), (op2, args2, kwargs2), check_types=True
397
- )
398
- except (TypeError, ValueError):
399
- return False
400
- return all(
401
- tree.flatten(
402
- tree.map_structure(
403
- syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2)
404
- )
405
- )
1121
+ if (
1122
+ dataclasses.is_dataclass(x)
1123
+ and not isinstance(x, type)
1124
+ and dataclasses.is_dataclass(other)
1125
+ and not isinstance(other, type)
1126
+ ):
1127
+ return type(x) == type(other) and syntactic_eq(
1128
+ {field.name: getattr(x, field.name) for field in dataclasses.fields(x)},
1129
+ {
1130
+ field.name: getattr(other, field.name)
1131
+ for field in dataclasses.fields(other)
1132
+ },
406
1133
  )
407
- elif isinstance(x, Term) or isinstance(other, Term):
408
- return False
409
1134
  else:
410
- return x == other
1135
+ return __dispatch(type(x))(x, other)
1136
+
1137
+
1138
+ @syntactic_eq.register
1139
+ def _(x: Term, other) -> bool:
1140
+ if not isinstance(other, Term):
1141
+ return False
1142
+
1143
+ op, args, kwargs = x.op, x.args, x.kwargs
1144
+ op2, args2, kwargs2 = other.op, other.args, other.kwargs
1145
+ return (
1146
+ op == op2
1147
+ and len(args) == len(args2)
1148
+ and set(kwargs) == set(kwargs2)
1149
+ and all(syntactic_eq(a, b) for a, b in zip(args, args2))
1150
+ and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs)
1151
+ )
1152
+
1153
+
1154
+ @syntactic_eq.register
1155
+ def _(x: collections.abc.Mapping, other) -> bool:
1156
+ return isinstance(other, collections.abc.Mapping) and all(
1157
+ k in x and k in other and syntactic_eq(x[k], other[k])
1158
+ for k in set(x) | set(other)
1159
+ )
1160
+
1161
+
1162
+ @syntactic_eq.register
1163
+ def _(x: collections.abc.Sequence, other) -> bool:
1164
+ return (
1165
+ isinstance(other, collections.abc.Sequence)
1166
+ and len(x) == len(other)
1167
+ and all(syntactic_eq(a, b) for a, b in zip(x, other))
1168
+ )
1169
+
411
1170
 
1171
+ @syntactic_eq.register(object)
1172
+ @syntactic_eq.register(str | bytes)
1173
+ def _(x: object, other) -> bool:
1174
+ return x == other
412
1175
 
413
- class ObjectInterpretation(Generic[T, V], Interpretation[T, V]):
1176
+
1177
+ class ObjectInterpretation[T, V](collections.abc.Mapping):
414
1178
  """A helper superclass for defining an ``Interpretation`` of many
415
1179
  :class:`~effectful.ops.types.Operation` instances with shared state or behavior.
416
1180
 
@@ -487,8 +1251,8 @@ class ObjectInterpretation(Generic[T, V], Interpretation[T, V]):
487
1251
  return self.implementations[item].__get__(self, type(self))
488
1252
 
489
1253
 
490
- class _ImplementedOperation(Generic[P, Q, T, V]):
491
- impl: Optional[Callable[Q, V]]
1254
+ class _ImplementedOperation[**P, **Q, T, V]:
1255
+ impl: Callable[Q, V] | None
492
1256
  op: Operation[P, T]
493
1257
 
494
1258
  def __init__(self, op: Operation[P, T]):
@@ -512,7 +1276,7 @@ class _ImplementedOperation(Generic[P, Q, T, V]):
512
1276
  owner._temporary_implementations[self.op] = self.impl
513
1277
 
514
1278
 
515
- def implements(op: Operation[P, V]):
1279
+ def implements[**P, V](op: Operation[P, V]):
516
1280
  """Marks a method in an :class:`ObjectInterpretation` as the implementation of a
517
1281
  particular abstract :class:`Operation`.
518
1282
 
@@ -521,3 +1285,348 @@ def implements(op: Operation[P, V]):
521
1285
 
522
1286
  """
523
1287
  return _ImplementedOperation(op)
1288
+
1289
+
1290
+ @defdata.register(numbers.Number)
1291
+ @functools.total_ordering
1292
+ class _NumberTerm[T: numbers.Number](_BaseTerm[T], numbers.Number):
1293
+ def __hash__(self):
1294
+ return id(self)
1295
+
1296
+ def __complex__(self) -> complex:
1297
+ raise ValueError("Cannot convert term to complex number")
1298
+
1299
+ def __float__(self) -> float:
1300
+ raise ValueError("Cannot convert term to float")
1301
+
1302
+ def __int__(self) -> int:
1303
+ raise ValueError("Cannot convert term to int")
1304
+
1305
+ def __bool__(self) -> bool:
1306
+ raise ValueError("Cannot convert term to bool")
1307
+
1308
+ @defop # type: ignore[prop-decorator]
1309
+ @property
1310
+ def real(self) -> float:
1311
+ if not isinstance(self, Term):
1312
+ return self.real
1313
+ else:
1314
+ raise NotHandled
1315
+
1316
+ @defop # type: ignore[prop-decorator]
1317
+ @property
1318
+ def imag(self) -> float:
1319
+ if not isinstance(self, Term):
1320
+ return self.imag
1321
+ else:
1322
+ raise NotHandled
1323
+
1324
+ @defop
1325
+ def conjugate(self) -> complex:
1326
+ if not isinstance(self, Term):
1327
+ return self.conjugate()
1328
+ else:
1329
+ raise NotHandled
1330
+
1331
+ @defop # type: ignore[prop-decorator]
1332
+ @property
1333
+ def numerator(self) -> int:
1334
+ if not isinstance(self, Term):
1335
+ return self.numerator
1336
+ else:
1337
+ raise NotHandled
1338
+
1339
+ @defop # type: ignore[prop-decorator]
1340
+ @property
1341
+ def denominator(self) -> int:
1342
+ if not isinstance(self, Term):
1343
+ return self.denominator
1344
+ else:
1345
+ raise NotHandled
1346
+
1347
+ @defop
1348
+ def __abs__(self) -> float:
1349
+ """Return the absolute value of the term."""
1350
+ if not isinstance(self, Term):
1351
+ return self.__abs__()
1352
+ else:
1353
+ raise NotHandled
1354
+
1355
+ @defop
1356
+ def __neg__(self: T) -> T:
1357
+ if not isinstance(self, Term):
1358
+ return self.__neg__() # type: ignore
1359
+ else:
1360
+ raise NotHandled
1361
+
1362
+ @defop
1363
+ def __pos__(self: T) -> T:
1364
+ if not isinstance(self, Term):
1365
+ return self.__pos__() # type: ignore
1366
+ else:
1367
+ raise NotHandled
1368
+
1369
+ @defop
1370
+ def __trunc__(self) -> int:
1371
+ if not isinstance(self, Term):
1372
+ return self.__trunc__()
1373
+ else:
1374
+ raise NotHandled
1375
+
1376
+ @defop
1377
+ def __floor__(self) -> int:
1378
+ if not isinstance(self, Term):
1379
+ return self.__floor__()
1380
+ else:
1381
+ raise NotHandled
1382
+
1383
+ @defop
1384
+ def __ceil__(self) -> int:
1385
+ if not isinstance(self, Term):
1386
+ return self.__ceil__()
1387
+ else:
1388
+ raise NotHandled
1389
+
1390
+ @defop
1391
+ def __round__(self, ndigits: int | None = None) -> numbers.Real:
1392
+ if not isinstance(self, Term) and not isinstance(ndigits, Term):
1393
+ return self.__round__(ndigits)
1394
+ else:
1395
+ raise NotHandled
1396
+
1397
+ @defop
1398
+ def __invert__(self) -> int:
1399
+ if not isinstance(self, Term):
1400
+ return self.__invert__()
1401
+ else:
1402
+ raise NotHandled
1403
+
1404
+ @defop
1405
+ def __index__(self) -> int:
1406
+ if not isinstance(self, Term):
1407
+ return self.__index__()
1408
+ else:
1409
+ raise NotHandled
1410
+
1411
+ @defop
1412
+ def __eq__(self, other) -> bool: # type: ignore[override]
1413
+ if not isinstance(self, Term) and not isinstance(other, Term):
1414
+ return self.__eq__(other)
1415
+ else:
1416
+ return syntactic_eq(self, other)
1417
+
1418
+ @defop
1419
+ def __lt__(self, other) -> bool:
1420
+ if not isinstance(self, Term) and not isinstance(other, Term):
1421
+ return self.__lt__(other)
1422
+ else:
1423
+ raise NotHandled
1424
+
1425
+ @defop
1426
+ def __add__(self, other: T) -> T:
1427
+ if not isinstance(self, Term) and not isinstance(other, Term):
1428
+ return operator.__add__(self, other)
1429
+ else:
1430
+ raise NotHandled
1431
+
1432
+ def __radd__(self, other):
1433
+ if isinstance(other, Term) and isinstance(other, type(self)):
1434
+ return other.__add__(self)
1435
+ elif not isinstance(other, Term):
1436
+ return type(self).__add__(other, self)
1437
+ else:
1438
+ return NotImplemented
1439
+
1440
+ @defop
1441
+ def __sub__(self, other: T) -> T:
1442
+ if not isinstance(self, Term) and not isinstance(other, Term):
1443
+ return operator.__sub__(self, other)
1444
+ else:
1445
+ raise NotHandled
1446
+
1447
+ def __rsub__(self, other):
1448
+ if isinstance(other, Term) and isinstance(other, type(self)):
1449
+ return other.__sub__(self)
1450
+ elif not isinstance(other, Term):
1451
+ return type(self).__sub__(other, self)
1452
+ else:
1453
+ return NotImplemented
1454
+
1455
+ @defop
1456
+ def __mul__(self, other: T) -> T:
1457
+ if not isinstance(self, Term) and not isinstance(other, Term):
1458
+ return operator.__mul__(self, other)
1459
+ else:
1460
+ raise NotHandled
1461
+
1462
+ def __rmul__(self, other):
1463
+ if isinstance(other, Term) and isinstance(other, type(self)):
1464
+ return other.__mul__(self)
1465
+ elif not isinstance(other, Term):
1466
+ return type(self).__mul__(other, self)
1467
+ else:
1468
+ return NotImplemented
1469
+
1470
+ @defop
1471
+ def __truediv__(self, other: T) -> T:
1472
+ if not isinstance(self, Term) and not isinstance(other, Term):
1473
+ return operator.__truediv__(self, other)
1474
+ else:
1475
+ raise NotHandled
1476
+
1477
+ def __rtruediv__(self, other):
1478
+ if isinstance(other, Term) and isinstance(other, type(self)):
1479
+ return other.__truediv__(self)
1480
+ elif not isinstance(other, Term):
1481
+ return type(self).__truediv__(other, self)
1482
+ else:
1483
+ return NotImplemented
1484
+
1485
+ @defop
1486
+ def __floordiv__(self, other: T) -> T:
1487
+ if not isinstance(self, Term) and not isinstance(other, Term):
1488
+ return operator.__floordiv__(self, other)
1489
+ else:
1490
+ raise NotHandled
1491
+
1492
+ def __rfloordiv__(self, other):
1493
+ if isinstance(other, Term) and isinstance(other, type(self)):
1494
+ return other.__floordiv__(self)
1495
+ elif not isinstance(other, Term):
1496
+ return type(self).__floordiv__(other, self)
1497
+ else:
1498
+ return NotImplemented
1499
+
1500
+ @defop
1501
+ def __mod__(self, other: T) -> T:
1502
+ if not isinstance(self, Term) and not isinstance(other, Term):
1503
+ return operator.__mod__(self, other)
1504
+ else:
1505
+ raise NotHandled
1506
+
1507
+ def __rmod__(self, other):
1508
+ if isinstance(other, Term) and isinstance(other, type(self)):
1509
+ return other.__mod__(self)
1510
+ elif not isinstance(other, Term):
1511
+ return type(self).__mod__(other, self)
1512
+ else:
1513
+ return NotImplemented
1514
+
1515
+ @defop
1516
+ def __pow__(self, other: T) -> T:
1517
+ if not isinstance(self, Term) and not isinstance(other, Term):
1518
+ return operator.__pow__(self, other)
1519
+ else:
1520
+ raise NotHandled
1521
+
1522
+ def __rpow__(self, other):
1523
+ if isinstance(other, Term) and isinstance(other, type(self)):
1524
+ return other.__pow__(self)
1525
+ elif not isinstance(other, Term):
1526
+ return type(self).__pow__(other, self)
1527
+ else:
1528
+ return NotImplemented
1529
+
1530
+ @defop
1531
+ def __lshift__(self, other: T) -> T:
1532
+ if not isinstance(self, Term) and not isinstance(other, Term):
1533
+ return operator.__lshift__(self, other)
1534
+ else:
1535
+ raise NotHandled
1536
+
1537
+ def __rlshift__(self, other):
1538
+ if isinstance(other, Term) and isinstance(other, type(self)):
1539
+ return other.__lshift__(self)
1540
+ elif not isinstance(other, Term):
1541
+ return type(self).__lshift__(other, self)
1542
+ else:
1543
+ return NotImplemented
1544
+
1545
+ @defop
1546
+ def __rshift__(self, other: T) -> T:
1547
+ if not isinstance(self, Term) and not isinstance(other, Term):
1548
+ return operator.__rshift__(self, other)
1549
+ else:
1550
+ raise NotHandled
1551
+
1552
+ def __rrshift__(self, other):
1553
+ if isinstance(other, Term) and isinstance(other, type(self)):
1554
+ return other.__rshift__(self)
1555
+ elif not isinstance(other, Term):
1556
+ return type(self).__rshift__(other, self)
1557
+ else:
1558
+ return NotImplemented
1559
+
1560
+ @defop
1561
+ def __and__(self, other: T) -> T:
1562
+ if not isinstance(self, Term) and not isinstance(other, Term):
1563
+ return operator.__and__(self, other)
1564
+ else:
1565
+ raise NotHandled
1566
+
1567
+ def __rand__(self, other):
1568
+ if isinstance(other, Term) and isinstance(other, type(self)):
1569
+ return other.__and__(self)
1570
+ elif not isinstance(other, Term):
1571
+ return type(self).__and__(other, self)
1572
+ else:
1573
+ return NotImplemented
1574
+
1575
+ @defop
1576
+ def __xor__(self, other: T) -> T:
1577
+ if not isinstance(self, Term) and not isinstance(other, Term):
1578
+ return operator.__xor__(self, other)
1579
+ else:
1580
+ raise NotHandled
1581
+
1582
+ def __rxor__(self, other):
1583
+ if isinstance(other, Term) and isinstance(other, type(self)):
1584
+ return other.__xor__(self)
1585
+ elif not isinstance(other, Term):
1586
+ return type(self).__xor__(other, self)
1587
+ else:
1588
+ return NotImplemented
1589
+
1590
+ @defop
1591
+ def __or__(self, other: T) -> T:
1592
+ if not isinstance(self, Term) and not isinstance(other, Term):
1593
+ return operator.__or__(self, other)
1594
+ else:
1595
+ raise NotHandled
1596
+
1597
+ def __ror__(self, other):
1598
+ if isinstance(other, Term) and isinstance(other, type(self)):
1599
+ return other.__or__(self)
1600
+ elif not isinstance(other, Term):
1601
+ return type(self).__or__(other, self)
1602
+ else:
1603
+ return NotImplemented
1604
+
1605
+
1606
+ @defdata.register(numbers.Complex)
1607
+ @numbers.Complex.register
1608
+ class _ComplexTerm[T: numbers.Complex](_NumberTerm[T]):
1609
+ pass
1610
+
1611
+
1612
+ @defdata.register(numbers.Real)
1613
+ @numbers.Real.register
1614
+ class _RealTerm[T: numbers.Real](_ComplexTerm[T]):
1615
+ pass
1616
+
1617
+
1618
+ @defdata.register(numbers.Rational)
1619
+ @numbers.Rational.register
1620
+ class _RationalTerm[T: numbers.Rational](_RealTerm[T]):
1621
+ pass
1622
+
1623
+
1624
+ @defdata.register(numbers.Integral)
1625
+ @numbers.Integral.register
1626
+ class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]):
1627
+ pass
1628
+
1629
+
1630
+ @defdata.register(bool)
1631
+ class _BoolTerm[T: bool](_IntegralTerm[T]): # type: ignore
1632
+ pass