effectful 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +23 -24
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +297 -168
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +901 -0
- effectful/ops/semantics.py +109 -77
- effectful/ops/syntax.py +821 -250
- effectful/ops/types.py +121 -29
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/METADATA +59 -56
- effectful-0.2.1.dist-info/RECORD +26 -0
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -263
- effectful-0.1.0.dist-info/RECORD +0 -18
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.1.0.dist-info → effectful-0.2.1.dist-info}/top_level.txt +0 -0
effectful/ops/syntax.py
CHANGED
@@ -2,21 +2,16 @@ import collections.abc
|
|
2
2
|
import dataclasses
|
3
3
|
import functools
|
4
4
|
import inspect
|
5
|
+
import numbers
|
6
|
+
import operator
|
5
7
|
import random
|
6
8
|
import types
|
7
9
|
import typing
|
8
|
-
|
10
|
+
import warnings
|
11
|
+
from collections.abc import Callable, Iterable, Mapping
|
12
|
+
from typing import Annotated, Any, Concatenate
|
9
13
|
|
10
|
-
import
|
11
|
-
from typing_extensions import Concatenate, ParamSpec
|
12
|
-
|
13
|
-
from effectful.ops.types import Annotation, Expr, Interpretation, Operation, Term
|
14
|
-
|
15
|
-
P = ParamSpec("P")
|
16
|
-
Q = ParamSpec("Q")
|
17
|
-
S = TypeVar("S")
|
18
|
-
T = TypeVar("T")
|
19
|
-
V = TypeVar("V")
|
14
|
+
from effectful.ops.types import Annotation, Expr, NotHandled, Operation, Term
|
20
15
|
|
21
16
|
|
22
17
|
@dataclasses.dataclass
|
@@ -54,11 +49,9 @@ class Scoped(Annotation):
|
|
54
49
|
We illustrate the use of :class:`Scoped` with a few case studies of classical
|
55
50
|
syntactic variable binding constructs expressed as :class:`Operation` s.
|
56
51
|
|
57
|
-
>>> from typing import Annotated
|
52
|
+
>>> from typing import Annotated
|
58
53
|
>>> from effectful.ops.syntax import Scoped, defop
|
59
54
|
>>> from effectful.ops.semantics import fvsof
|
60
|
-
>>> from effectful.handlers.numbers import add
|
61
|
-
>>> A, B, S, T = TypeVar('A'), TypeVar('B'), TypeVar('S'), TypeVar('T')
|
62
55
|
>>> x, y = defop(int, name='x'), defop(int, name='y')
|
63
56
|
|
64
57
|
* For example, we can define a higher-order operation :func:`Lambda`
|
@@ -67,64 +60,66 @@ class Scoped(Annotation):
|
|
67
60
|
and returns a :class:`Term` representing a lambda function:
|
68
61
|
|
69
62
|
>>> @defop
|
70
|
-
... def Lambda(
|
63
|
+
... def Lambda[S, T, A, B](
|
71
64
|
... var: Annotated[Operation[[], S], Scoped[A]],
|
72
65
|
... body: Annotated[T, Scoped[A | B]]
|
73
66
|
... ) -> Annotated[Callable[[S], T], Scoped[B]]:
|
74
|
-
... raise
|
67
|
+
... raise NotHandled
|
75
68
|
|
76
69
|
* The :class:`Scoped` annotation is used here to indicate that the argument ``var``
|
77
70
|
passed to :func:`Lambda` may appear free in ``body``, but not in the resulting function.
|
78
71
|
In other words, it is bound by :func:`Lambda`:
|
79
72
|
|
80
|
-
>>> assert x not in fvsof(Lambda(x,
|
73
|
+
>>> assert x not in fvsof(Lambda(x, x() + 1))
|
81
74
|
|
82
75
|
However, variables in ``body`` other than ``var`` still appear free in the result:
|
83
76
|
|
84
|
-
>>> assert y in fvsof(Lambda(x,
|
77
|
+
>>> assert y in fvsof(Lambda(x, x() + y()))
|
85
78
|
|
86
79
|
* :class:`Scoped` can also be used with variadic arguments and keyword arguments.
|
87
80
|
For example, we can define a generalized :func:`LambdaN` that takes a variable
|
88
81
|
number of arguments and keyword arguments:
|
89
82
|
|
90
83
|
>>> @defop
|
91
|
-
... def LambdaN(
|
92
|
-
... body: Annotated[T, Scoped[A | B]]
|
84
|
+
... def LambdaN[S, T, A, B](
|
85
|
+
... body: Annotated[T, Scoped[A | B]],
|
93
86
|
... *args: Annotated[Operation[[], S], Scoped[A]],
|
94
87
|
... **kwargs: Annotated[Operation[[], S], Scoped[A]]
|
95
88
|
... ) -> Annotated[Callable[..., T], Scoped[B]]:
|
96
|
-
... raise
|
89
|
+
... raise NotHandled
|
97
90
|
|
98
91
|
This is equivalent to the built-in :class:`Operation` :func:`deffn`:
|
99
92
|
|
100
|
-
>>> assert not {x, y} & fvsof(LambdaN(
|
93
|
+
>>> assert not {x, y} & fvsof(LambdaN(x() + y(), x, y))
|
101
94
|
|
102
95
|
* :class:`Scoped` and :func:`defop` can also express more complex scoping semantics.
|
103
96
|
For example, we can define a :func:`Let` operation that binds a variable in
|
104
97
|
a :class:`Term` ``body`` to a ``value`` that may be another possibly open :class:`Term` :
|
105
98
|
|
106
99
|
>>> @defop
|
107
|
-
... def Let(
|
100
|
+
... def Let[S, T, A, B](
|
108
101
|
... var: Annotated[Operation[[], S], Scoped[A]],
|
109
102
|
... val: Annotated[S, Scoped[B]],
|
110
103
|
... body: Annotated[T, Scoped[A | B]]
|
111
104
|
... ) -> Annotated[T, Scoped[B]]:
|
112
|
-
... raise
|
105
|
+
... raise NotHandled
|
113
106
|
|
114
107
|
Here the variable ``var`` is bound by :func:`Let` in `body` but not in ``val`` :
|
115
108
|
|
116
|
-
>>> assert x not in fvsof(Let(x,
|
117
|
-
|
109
|
+
>>> assert x not in fvsof(Let(x, y() + 1, x() + y()))
|
110
|
+
|
111
|
+
>>> fvs = fvsof(Let(x, y() + x(), x() + y()))
|
112
|
+
>>> assert x in fvs and y in fvs
|
118
113
|
|
119
114
|
This is reflected in the free variables of subterms of the result:
|
120
115
|
|
121
|
-
>>> assert x in fvsof(Let(x,
|
122
|
-
>>> assert x not in fvsof(Let(x,
|
116
|
+
>>> assert x in fvsof(Let(x, x() + y(), x() + y()).args[1])
|
117
|
+
>>> assert x not in fvsof(Let(x, y() + 1, x() + y()).args[2])
|
123
118
|
"""
|
124
119
|
|
125
120
|
ordinal: collections.abc.Set
|
126
121
|
|
127
|
-
def __class_getitem__(cls, item: TypeVar | typing._SpecialForm):
|
122
|
+
def __class_getitem__(cls, item: typing.TypeVar | typing._SpecialForm):
|
128
123
|
assert not isinstance(item, tuple), "can only be in one scope"
|
129
124
|
if isinstance(item, typing.TypeVar):
|
130
125
|
return cls(ordinal=frozenset({item}))
|
@@ -183,7 +178,7 @@ class Scoped(Annotation):
|
|
183
178
|
|
184
179
|
@classmethod
|
185
180
|
def _get_fresh_ordinal(cls, *, name: str = "RootScope") -> collections.abc.Set:
|
186
|
-
return {TypeVar(name)}
|
181
|
+
return {typing.TypeVar(name)}
|
187
182
|
|
188
183
|
@classmethod
|
189
184
|
def _check_has_single_scope(cls, sig: inspect.Signature) -> bool:
|
@@ -211,10 +206,10 @@ class Scoped(Annotation):
|
|
211
206
|
|
212
207
|
def _get_free_type_vars(
|
213
208
|
tp: type | typing._SpecialForm | inspect.Parameter | tuple | list,
|
214
|
-
) -> collections.abc.Set[TypeVar]:
|
215
|
-
if isinstance(tp, TypeVar):
|
209
|
+
) -> collections.abc.Set[typing.TypeVar]:
|
210
|
+
if isinstance(tp, typing.TypeVar):
|
216
211
|
return {tp}
|
217
|
-
elif isinstance(tp,
|
212
|
+
elif isinstance(tp, tuple | list):
|
218
213
|
return set().union(*map(_get_free_type_vars, tp))
|
219
214
|
elif isinstance(tp, inspect.Parameter):
|
220
215
|
return _get_free_type_vars(tp.annotation)
|
@@ -341,36 +336,52 @@ class Scoped(Annotation):
|
|
341
336
|
return_ordinal = self._get_param_ordinal(bound_sig.signature.return_annotation)
|
342
337
|
for name, param in bound_sig.signature.parameters.items():
|
343
338
|
param_ordinal = self._get_param_ordinal(param)
|
344
|
-
if
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
set(
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
339
|
+
if param_ordinal <= self.ordinal and not param_ordinal <= return_ordinal:
|
340
|
+
param_value = bound_sig.arguments[name]
|
341
|
+
param_bound_vars = set()
|
342
|
+
|
343
|
+
if self._param_is_var(param):
|
344
|
+
# Handle individual Operation parameters (existing behavior)
|
345
|
+
if param.kind is inspect.Parameter.VAR_POSITIONAL:
|
346
|
+
# pre-condition: all bound variables should be distinct
|
347
|
+
assert len(param_value) == len(set(param_value))
|
348
|
+
param_bound_vars = set(param_value)
|
349
|
+
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
350
|
+
# pre-condition: all bound variables should be distinct
|
351
|
+
assert len(param_value.values()) == len(
|
352
|
+
set(param_value.values())
|
353
|
+
)
|
354
|
+
param_bound_vars = set(param_value.values())
|
355
|
+
else:
|
356
|
+
param_bound_vars = {param_value}
|
357
|
+
elif param_ordinal: # Only process if there's a Scoped annotation
|
358
|
+
# We can't use flatten here because we want to be able
|
359
|
+
# to see dict keys
|
360
|
+
def extract_operations(obj):
|
361
|
+
if isinstance(obj, Operation):
|
362
|
+
param_bound_vars.add(obj)
|
363
|
+
elif isinstance(obj, dict):
|
364
|
+
for k, v in obj.items():
|
365
|
+
extract_operations(k)
|
366
|
+
extract_operations(v)
|
367
|
+
elif isinstance(obj, list | set | tuple):
|
368
|
+
for v in obj:
|
369
|
+
extract_operations(v)
|
370
|
+
|
371
|
+
extract_operations(param_value)
|
363
372
|
|
364
373
|
# pre-condition: all bound variables should be distinct
|
365
|
-
|
366
|
-
|
367
|
-
|
374
|
+
if param_bound_vars:
|
375
|
+
assert not bound_vars & param_bound_vars
|
376
|
+
bound_vars |= param_bound_vars
|
368
377
|
|
369
378
|
return bound_vars
|
370
379
|
|
371
380
|
|
372
381
|
@functools.singledispatch
|
373
|
-
def defop
|
382
|
+
def defop[**P, T](
|
383
|
+
t: Callable[P, T], *, name: str | None = None, freshening=list[int] | None
|
384
|
+
) -> Operation[P, T]:
|
374
385
|
"""Creates a fresh :class:`Operation`.
|
375
386
|
|
376
387
|
:param t: May be a type, callable, or :class:`Operation`. If a type, the
|
@@ -411,13 +422,13 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
411
422
|
* Defining an operation with no default rule:
|
412
423
|
|
413
424
|
We can use :func:`defop` and the
|
414
|
-
:exc:`
|
425
|
+
:exc:`NotHandled` exception to define an
|
415
426
|
operation with no default rule:
|
416
427
|
|
417
428
|
>>> @defop
|
418
429
|
... def add(x: int, y: int) -> int:
|
419
|
-
... raise
|
420
|
-
>>> add(1, 2)
|
430
|
+
... raise NotHandled
|
431
|
+
>>> print(str(add(1, 2)))
|
421
432
|
add(1, 2)
|
422
433
|
|
423
434
|
When an operation has no default rule, the free rule is used instead, which
|
@@ -428,15 +439,14 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
428
439
|
|
429
440
|
Passing :func:`defop` a type is a handy way to create a free variable.
|
430
441
|
|
431
|
-
>>> import effectful.handlers.operator
|
432
442
|
>>> from effectful.ops.semantics import evaluate
|
433
443
|
>>> x = defop(int, name='x')
|
434
444
|
>>> y = x() + 1
|
435
445
|
|
436
446
|
``y`` is free in ``x``, so it is not fully evaluated:
|
437
447
|
|
438
|
-
>>> y
|
439
|
-
|
448
|
+
>>> print(str(y))
|
449
|
+
__add__(x(), 1)
|
440
450
|
|
441
451
|
We bind ``x`` by installing a handler for it:
|
442
452
|
|
@@ -449,7 +459,8 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
449
459
|
Because the result of :func:`defop` is always fresh, it's important to
|
450
460
|
be careful with variable identity.
|
451
461
|
|
452
|
-
Two
|
462
|
+
Two operations with the same name that come from different calls to
|
463
|
+
``defop`` are not equal:
|
453
464
|
|
454
465
|
>>> x1 = defop(int, name='x')
|
455
466
|
>>> x2 = defop(int, name='x')
|
@@ -460,24 +471,22 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
460
471
|
operation object. In this example, ``scale`` returns a term with a free
|
461
472
|
variable ``x``:
|
462
473
|
|
463
|
-
>>>
|
474
|
+
>>> x = defop(float, name='x')
|
464
475
|
>>> def scale(a: float) -> float:
|
465
|
-
... x = defop(float, name='x')
|
466
476
|
... return x() * a
|
467
477
|
|
468
|
-
Binding the variable ``x``
|
478
|
+
Binding the variable ``x`` as follows does not work:
|
469
479
|
|
470
480
|
>>> term = scale(3.0)
|
471
|
-
>>>
|
472
|
-
>>> with handler({
|
473
|
-
... print(evaluate(term))
|
474
|
-
|
481
|
+
>>> fresh_x = defop(float, name='x')
|
482
|
+
>>> with handler({fresh_x: lambda: 2.0}):
|
483
|
+
... print(str(evaluate(term)))
|
484
|
+
__mul__(x(), 3.0)
|
475
485
|
|
476
|
-
|
486
|
+
Only the original operation object will work:
|
477
487
|
|
478
488
|
>>> from effectful.ops.semantics import fvsof
|
479
|
-
>>>
|
480
|
-
>>> with handler({correct_x: lambda: 2.0}):
|
489
|
+
>>> with handler({x: lambda: 2.0}):
|
481
490
|
... print(evaluate(term))
|
482
491
|
6.0
|
483
492
|
|
@@ -487,7 +496,7 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
487
496
|
the same name and signature, but no default rule.
|
488
497
|
|
489
498
|
>>> fresh_select = defop(select)
|
490
|
-
>>> fresh_select(1, 2)
|
499
|
+
>>> print(str(fresh_select(1, 2)))
|
491
500
|
select(1, 2)
|
492
501
|
|
493
502
|
The new operation is distinct from the original:
|
@@ -504,17 +513,24 @@ def defop(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]:
|
|
504
513
|
raise NotImplementedError(f"expected type or callable, got {t}")
|
505
514
|
|
506
515
|
|
507
|
-
@defop.register(typing.cast(
|
508
|
-
class _BaseOperation
|
516
|
+
@defop.register(typing.cast(type[collections.abc.Callable], collections.abc.Callable))
|
517
|
+
class _BaseOperation[**Q, V](Operation[Q, V]):
|
509
518
|
__signature__: inspect.Signature
|
510
519
|
__name__: str
|
511
520
|
|
512
521
|
_default: Callable[Q, V]
|
513
522
|
|
514
|
-
def __init__(
|
523
|
+
def __init__(
|
524
|
+
self,
|
525
|
+
default: Callable[Q, V],
|
526
|
+
*,
|
527
|
+
name: str | None = None,
|
528
|
+
freshening: list[int] | None = None,
|
529
|
+
):
|
515
530
|
functools.update_wrapper(self, default)
|
516
531
|
self._default = default
|
517
532
|
self.__name__ = name or default.__name__
|
533
|
+
self._freshening = freshening or []
|
518
534
|
self.__signature__ = inspect.signature(default)
|
519
535
|
|
520
536
|
def __eq__(self, other):
|
@@ -522,21 +538,30 @@ class _BaseOperation(Generic[Q, V], Operation[Q, V]):
|
|
522
538
|
return NotImplemented
|
523
539
|
return self is other
|
524
540
|
|
541
|
+
def __lt__(self, other):
|
542
|
+
if not isinstance(other, Operation):
|
543
|
+
return NotImplemented
|
544
|
+
return id(self) < id(other)
|
545
|
+
|
525
546
|
def __hash__(self):
|
526
547
|
return hash(self._default)
|
527
548
|
|
528
549
|
def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
|
529
550
|
try:
|
530
|
-
|
531
|
-
|
551
|
+
try:
|
552
|
+
return self._default(*args, **kwargs)
|
553
|
+
except NotImplementedError:
|
554
|
+
warnings.warn(
|
555
|
+
"Operations should raise effectful.ops.types.NotHandled instead of NotImplementedError.",
|
556
|
+
DeprecationWarning,
|
557
|
+
)
|
558
|
+
raise NotHandled
|
559
|
+
except NotHandled:
|
532
560
|
return typing.cast(
|
533
561
|
Callable[Concatenate[Operation[Q, V], Q], Expr[V]], defdata
|
534
562
|
)(self, *args, **kwargs)
|
535
563
|
|
536
|
-
def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) ->
|
537
|
-
tuple[collections.abc.Set[Operation], ...],
|
538
|
-
dict[str, collections.abc.Set[Operation]],
|
539
|
-
]:
|
564
|
+
def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> inspect.BoundArguments:
|
540
565
|
sig = Scoped.infer_annotations(self.__signature__)
|
541
566
|
bound_sig = sig.bind(*args, **kwargs)
|
542
567
|
bound_sig.apply_defaults()
|
@@ -555,115 +580,214 @@ class _BaseOperation(Generic[Q, V], Operation[Q, V]):
|
|
555
580
|
param_bound_vars for _ in bound_sig.arguments[name]
|
556
581
|
)
|
557
582
|
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
558
|
-
|
559
|
-
|
560
|
-
}
|
583
|
+
for k in bound_sig.arguments[name]:
|
584
|
+
result_sig.arguments[name][k] = param_bound_vars
|
561
585
|
else:
|
562
586
|
result_sig.arguments[name] = param_bound_vars
|
563
587
|
|
564
|
-
return
|
588
|
+
return result_sig
|
565
589
|
|
566
|
-
def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) ->
|
567
|
-
|
568
|
-
|
569
|
-
|
590
|
+
def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:
|
591
|
+
from effectful.internals.unification import (
|
592
|
+
freetypevars,
|
593
|
+
nested_type,
|
594
|
+
substitute,
|
595
|
+
unify,
|
596
|
+
)
|
570
597
|
|
571
|
-
|
572
|
-
if
|
573
|
-
|
574
|
-
elif isinstance(anno, typing.TypeVar):
|
575
|
-
# rudimentary but sound special-case type inference sufficient for syntax ops:
|
576
|
-
# if the return type annotation is a TypeVar,
|
577
|
-
# look for a parameter with the same annotation and return its type,
|
578
|
-
# otherwise give up and return Any/object
|
579
|
-
for name, param in bound_sig.signature.parameters.items():
|
580
|
-
if param.annotation is anno and param.kind not in (
|
581
|
-
inspect.Parameter.VAR_POSITIONAL,
|
582
|
-
inspect.Parameter.VAR_KEYWORD,
|
583
|
-
):
|
584
|
-
arg = bound_sig.arguments[name]
|
585
|
-
tp: Type[V] = type(arg) if not isinstance(arg, type) else arg
|
586
|
-
return tp
|
587
|
-
return typing.cast(Type[V], object)
|
588
|
-
elif typing.get_origin(anno) is typing.Annotated:
|
589
|
-
tp = typing.get_args(anno)[0]
|
590
|
-
if not typing.TYPE_CHECKING:
|
591
|
-
tp = tp if typing.get_origin(tp) is None else typing.get_origin(tp)
|
592
|
-
return tp
|
593
|
-
elif typing.get_origin(anno) is not None:
|
594
|
-
return typing.get_origin(anno)
|
595
|
-
else:
|
596
|
-
return anno
|
598
|
+
return_anno = self.__signature__.return_annotation
|
599
|
+
if typing.get_origin(return_anno) is typing.Annotated:
|
600
|
+
return_anno = typing.get_args(return_anno)[0]
|
597
601
|
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
)
|
602
|
+
if return_anno is inspect.Parameter.empty:
|
603
|
+
return typing.cast(type[V], object)
|
604
|
+
elif return_anno is None:
|
605
|
+
return type(None) # type: ignore
|
606
|
+
elif not freetypevars(return_anno):
|
607
|
+
return return_anno
|
603
608
|
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
return ret
|
609
|
+
type_args = tuple(nested_type(a) for a in args)
|
610
|
+
type_kwargs = {k: nested_type(v) for k, v in kwargs.items()}
|
611
|
+
bound_sig = self.__signature__.bind(*type_args, **type_kwargs)
|
612
|
+
return substitute(return_anno, unify(self.__signature__, bound_sig)) # type: ignore
|
609
613
|
|
610
614
|
def __repr__(self):
|
615
|
+
return f"_BaseOperation({self._default}, name={self.__name__}, freshening={self._freshening})"
|
616
|
+
|
617
|
+
def __str__(self):
|
611
618
|
return self.__name__
|
612
619
|
|
620
|
+
def __get__(self, instance, owner):
|
621
|
+
if instance is not None:
|
622
|
+
# This is an instance-level operation, so we need to bind the instance
|
623
|
+
return functools.partial(self, instance)
|
624
|
+
else:
|
625
|
+
# This is a static operation, so we return the operation itself
|
626
|
+
return self
|
627
|
+
|
613
628
|
|
614
629
|
@defop.register(Operation)
|
615
|
-
def _(t: Operation[P, T], *, name:
|
616
|
-
|
630
|
+
def _[**P, T](t: Operation[P, T], *, name: str | None = None) -> Operation[P, T]:
|
617
631
|
@functools.wraps(t)
|
618
632
|
def func(*args, **kwargs):
|
619
|
-
raise
|
633
|
+
raise NotHandled
|
620
634
|
|
621
635
|
if name is None:
|
622
|
-
name = (
|
623
|
-
|
624
|
-
|
625
|
-
return defop(func, name=name)
|
636
|
+
name = getattr(t, "__name__", str(t))
|
637
|
+
freshening = getattr(t, "_freshening", []) + [random.randint(0, 1 << 32)]
|
638
|
+
|
639
|
+
return defop(func, name=name, freshening=freshening)
|
626
640
|
|
627
641
|
|
628
642
|
@defop.register(type)
|
629
|
-
|
643
|
+
@defop.register(typing.cast(type, types.GenericAlias))
|
644
|
+
@defop.register(typing.cast(type, typing._GenericAlias)) # type: ignore
|
645
|
+
@defop.register(typing.cast(type, types.UnionType))
|
646
|
+
def _[T](t: type[T], *, name: str | None = None) -> Operation[[], T]:
|
630
647
|
def func() -> t: # type: ignore
|
631
|
-
raise
|
648
|
+
raise NotHandled
|
632
649
|
|
650
|
+
freshening = []
|
633
651
|
if name is None:
|
634
|
-
name = t.__name__
|
635
|
-
|
652
|
+
name = t.__name__
|
653
|
+
freshening = [random.randint(0, 1 << 32)]
|
654
|
+
|
655
|
+
return typing.cast(
|
656
|
+
Operation[[], T],
|
657
|
+
defop(func, name=name, freshening=freshening),
|
658
|
+
)
|
636
659
|
|
637
660
|
|
638
661
|
@defop.register(types.BuiltinFunctionType)
|
639
|
-
def _(t: Callable[P, T], *, name:
|
640
|
-
|
662
|
+
def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]:
|
641
663
|
@functools.wraps(t)
|
642
664
|
def func(*args, **kwargs):
|
643
|
-
|
665
|
+
from effectful.ops.semantics import fvsof
|
666
|
+
|
667
|
+
if not fvsof((args, kwargs)):
|
644
668
|
return t(*args, **kwargs)
|
645
669
|
else:
|
646
|
-
raise
|
670
|
+
raise NotHandled
|
647
671
|
|
648
672
|
return defop(func, name=name)
|
649
673
|
|
650
674
|
|
675
|
+
@defop.register(classmethod)
|
676
|
+
def _[**P, S, T]( # type: ignore
|
677
|
+
t: classmethod, *, name: str | None = None
|
678
|
+
) -> Operation[Concatenate[type[S], P], T]:
|
679
|
+
raise NotImplementedError("classmethod operations are not yet supported")
|
680
|
+
|
681
|
+
|
682
|
+
@defop.register(staticmethod)
|
683
|
+
class _StaticMethodOperation[**P, S, T](_BaseOperation[P, T]):
|
684
|
+
def __init__(self, default: staticmethod, **kwargs):
|
685
|
+
super().__init__(default=default.__func__, **kwargs)
|
686
|
+
|
687
|
+
def __get__(self, instance: S, owner: type[S] | None = None) -> Callable[P, T]:
|
688
|
+
return self
|
689
|
+
|
690
|
+
|
691
|
+
@defop.register(property)
|
692
|
+
class _PropertyOperation[S, T](_BaseOperation[[S], T]):
|
693
|
+
def __init__(self, default: property, **kwargs): # type: ignore
|
694
|
+
assert not default.fset, "property with setter is not supported"
|
695
|
+
assert not default.fdel, "property with deleter is not supported"
|
696
|
+
super().__init__(default=typing.cast(Callable[[S], T], default.fget), **kwargs)
|
697
|
+
|
698
|
+
@typing.overload
|
699
|
+
def __get__(
|
700
|
+
self, instance: None, owner: type[S] | None = None
|
701
|
+
) -> "_PropertyOperation[S, T]": ...
|
702
|
+
|
703
|
+
@typing.overload
|
704
|
+
def __get__(self, instance: S, owner: type[S] | None = None) -> T: ...
|
705
|
+
|
706
|
+
def __get__(self, instance, owner: type[S] | None = None):
|
707
|
+
if instance is not None:
|
708
|
+
return self(instance)
|
709
|
+
else:
|
710
|
+
return self
|
711
|
+
|
712
|
+
|
713
|
+
@defop.register(functools.singledispatchmethod)
|
714
|
+
class _SingleDispatchMethodOperation[**P, S, T](_BaseOperation[Concatenate[S, P], T]):
|
715
|
+
_default: Callable[Concatenate[S, P], T]
|
716
|
+
|
717
|
+
def __init__(self, default: functools.singledispatchmethod, **kwargs): # type: ignore
|
718
|
+
if isinstance(default.func, classmethod):
|
719
|
+
raise NotImplementedError("Operations as classmethod are not yet supported")
|
720
|
+
|
721
|
+
@functools.wraps(default.func)
|
722
|
+
def _wrapper(obj: S, *args: P.args, **kwargs: P.kwargs) -> T:
|
723
|
+
return default.__get__(obj)(*args, **kwargs)
|
724
|
+
|
725
|
+
self._registry: functools.singledispatchmethod = default
|
726
|
+
super().__init__(_wrapper, **kwargs)
|
727
|
+
|
728
|
+
@typing.overload
|
729
|
+
def __get__(
|
730
|
+
self, instance: None, owner: type[S] | None = None
|
731
|
+
) -> "_SingleDispatchMethodOperation[P, S, T]": ...
|
732
|
+
|
733
|
+
@typing.overload
|
734
|
+
def __get__(self, instance: S, owner: type[S] | None = None) -> Callable[P, T]: ...
|
735
|
+
|
736
|
+
def __get__(self, instance, owner: type[S] | None = None):
|
737
|
+
if instance is not None:
|
738
|
+
return functools.partial(self, instance)
|
739
|
+
else:
|
740
|
+
return self
|
741
|
+
|
742
|
+
@property
|
743
|
+
def register(self):
|
744
|
+
return self._registry.register
|
745
|
+
|
746
|
+
@property
|
747
|
+
def __isabstractmethod__(self):
|
748
|
+
return self._registry.__isabstractmethod__
|
749
|
+
|
750
|
+
|
751
|
+
class _SingleDispatchOperation[**P, S, T](_BaseOperation[Concatenate[S, P], T]):
|
752
|
+
_default: "functools._SingleDispatchCallable[T]"
|
753
|
+
|
754
|
+
@property
|
755
|
+
def register(self):
|
756
|
+
return self._default.register
|
757
|
+
|
758
|
+
@property
|
759
|
+
def dispatch(self):
|
760
|
+
return self._default.dispatch
|
761
|
+
|
762
|
+
|
763
|
+
if typing.TYPE_CHECKING:
|
764
|
+
defop.register(functools._SingleDispatchCallable)(_SingleDispatchOperation)
|
765
|
+
else:
|
766
|
+
|
767
|
+
@typing.runtime_checkable
|
768
|
+
class _SingleDispatchCallable(typing.Protocol):
|
769
|
+
registry: types.MappingProxyType[object, Callable]
|
770
|
+
|
771
|
+
def dispatch(self, cls: type) -> Callable: ...
|
772
|
+
def register(self, cls: type, func: Callable | None = None) -> Callable: ...
|
773
|
+
def _clear_cache(self) -> None: ...
|
774
|
+
def __call__(self, /, *args, **kwargs): ...
|
775
|
+
|
776
|
+
defop.register(_SingleDispatchCallable)(_SingleDispatchOperation)
|
777
|
+
|
778
|
+
|
651
779
|
@defop
|
652
|
-
def deffn(
|
653
|
-
body: Annotated[T, Scoped[
|
654
|
-
*args: Annotated[Operation, Scoped[
|
655
|
-
**kwargs: Annotated[Operation, Scoped[
|
656
|
-
) -> Callable[..., T]:
|
780
|
+
def deffn[T, A, B](
|
781
|
+
body: Annotated[T, Scoped[A | B]],
|
782
|
+
*args: Annotated[Operation, Scoped[A]],
|
783
|
+
**kwargs: Annotated[Operation, Scoped[A]],
|
784
|
+
) -> Annotated[Callable[..., T], Scoped[B]]:
|
657
785
|
"""An operation that represents a lambda function.
|
658
786
|
|
659
787
|
:param body: The body of the function.
|
660
|
-
:type body: T
|
661
788
|
:param args: Operations representing the positional arguments of the function.
|
662
|
-
:type args: Operation
|
663
789
|
:param kwargs: Operations representing the keyword arguments of the function.
|
664
|
-
:type kwargs: Operation
|
665
790
|
:returns: A callable term.
|
666
|
-
:rtype: Callable[..., T]
|
667
791
|
|
668
792
|
:func:`deffn` terms are eliminated by the :func:`call` operation, which
|
669
793
|
performs beta-reduction.
|
@@ -673,11 +797,13 @@ def deffn(
|
|
673
797
|
Here :func:`deffn` is used to define a term that represents the function
|
674
798
|
``lambda x, y=1: 2 * x + y``:
|
675
799
|
|
676
|
-
>>> import
|
800
|
+
>>> import random
|
801
|
+
>>> random.seed(0)
|
802
|
+
|
677
803
|
>>> x, y = defop(int, name='x'), defop(int, name='y')
|
678
804
|
>>> term = deffn(2 * x() + y(), x, y=y)
|
679
|
-
>>> term
|
680
|
-
deffn(
|
805
|
+
>>> print(str(term)) # doctest: +ELLIPSIS
|
806
|
+
deffn(...)
|
681
807
|
>>> term(3, y=4)
|
682
808
|
10
|
683
809
|
|
@@ -688,14 +814,14 @@ def deffn(
|
|
688
814
|
automatically create the right free variables.
|
689
815
|
|
690
816
|
"""
|
691
|
-
raise
|
817
|
+
raise NotHandled
|
692
818
|
|
693
819
|
|
694
|
-
class _CustomSingleDispatchCallable
|
820
|
+
class _CustomSingleDispatchCallable[**P, **Q, S, T]:
|
695
821
|
def __init__(
|
696
822
|
self, func: Callable[Concatenate[Callable[[type], Callable[Q, S]], P], T]
|
697
823
|
):
|
698
|
-
self.
|
824
|
+
self.func = func
|
699
825
|
self._registry = functools.singledispatch(func)
|
700
826
|
functools.update_wrapper(self, func)
|
701
827
|
|
@@ -708,31 +834,32 @@ class _CustomSingleDispatchCallable(Generic[P, Q, S, T]):
|
|
708
834
|
return self._registry.register
|
709
835
|
|
710
836
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
711
|
-
return self.
|
837
|
+
return self.func(self.dispatch, *args, **kwargs)
|
712
838
|
|
713
839
|
|
714
|
-
@_CustomSingleDispatchCallable
|
715
|
-
|
716
|
-
|
840
|
+
@defop.register(_CustomSingleDispatchCallable)
|
841
|
+
class _CustomSingleDispatchOperation[**P, **Q, S, T](_BaseOperation[P, T]):
|
842
|
+
_default: _CustomSingleDispatchCallable[P, Q, S, T]
|
717
843
|
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
:rtype: Expr[T]
|
844
|
+
def __init__(self, default: _CustomSingleDispatchCallable[P, Q, S, T], **kwargs):
|
845
|
+
super().__init__(default, **kwargs)
|
846
|
+
self.__signature__ = inspect.signature(functools.partial(default.func, None)) # type: ignore
|
722
847
|
|
723
|
-
|
848
|
+
@property
|
849
|
+
def dispatch(self):
|
850
|
+
return self._registry.dispatch
|
724
851
|
|
725
|
-
|
726
|
-
|
852
|
+
@property
|
853
|
+
def register(self):
|
854
|
+
return self._registry.register
|
727
855
|
|
728
|
-
>>> def incr(x: int) -> int:
|
729
|
-
... return x + 1
|
730
|
-
>>> term = defterm(incr)
|
731
|
-
>>> term
|
732
|
-
deffn(add(int(), 1), int)
|
733
|
-
>>> term(2)
|
734
|
-
3
|
735
856
|
|
857
|
+
@_CustomSingleDispatchCallable
|
858
|
+
def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
|
859
|
+
"""Convert a value to a term, using the type of the value to dispatch.
|
860
|
+
|
861
|
+
:param value: The value to convert.
|
862
|
+
:returns: A term.
|
736
863
|
"""
|
737
864
|
if isinstance(value, Term):
|
738
865
|
return value
|
@@ -741,7 +868,7 @@ def defterm(__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
|
|
741
868
|
|
742
869
|
|
743
870
|
@_CustomSingleDispatchCallable
|
744
|
-
def defdata(
|
871
|
+
def defdata[T](
|
745
872
|
__dispatch: Callable[[type], Callable[..., Expr[T]]],
|
746
873
|
op: Operation[..., T],
|
747
874
|
*args,
|
@@ -766,7 +893,8 @@ def defdata(
|
|
766
893
|
|
767
894
|
.. code-block:: python
|
768
895
|
|
769
|
-
|
896
|
+
@defdata.register(collections.abc.Callable)
|
897
|
+
class _CallableTerm[**P, T](Term[collections.abc.Callable[P, T]]):
|
770
898
|
def __init__(
|
771
899
|
self,
|
772
900
|
op: Operation[..., T],
|
@@ -789,52 +917,40 @@ def defdata(
|
|
789
917
|
def kwargs(self):
|
790
918
|
return self._kwargs
|
791
919
|
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
return call(self, *args, **kwargs)
|
796
|
-
|
797
|
-
@defdata.register(collections.abc.Callable)
|
798
|
-
def _(op, *args, **kwargs):
|
799
|
-
return _CallableTerm(op, *args, **kwargs)
|
920
|
+
@defop
|
921
|
+
def __call__(self: collections.abc.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
922
|
+
...
|
800
923
|
|
801
924
|
When an Operation whose return type is `Callable` is passed to :func:`defdata`,
|
802
925
|
it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
|
803
926
|
"""
|
804
927
|
from effectful.ops.semantics import apply, evaluate, typeof
|
805
928
|
|
806
|
-
|
929
|
+
bindings: inspect.BoundArguments = op.__fvs_rule__(*args, **kwargs)
|
807
930
|
renaming = {
|
808
931
|
var: defop(var)
|
809
|
-
for bound_vars in (*
|
932
|
+
for bound_vars in (*bindings.args, *bindings.kwargs.values())
|
810
933
|
for var in bound_vars
|
811
934
|
}
|
812
935
|
|
813
|
-
|
814
|
-
|
815
|
-
*enumerate(zip(args, arg_ctxs)),
|
816
|
-
*{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(),
|
817
|
-
):
|
818
|
-
if c:
|
819
|
-
v = tree.map_structure(
|
820
|
-
lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v
|
821
|
-
)
|
822
|
-
res = evaluate(
|
823
|
-
v,
|
824
|
-
intp={
|
825
|
-
apply: lambda _, op, *a, **k: defdata(op, *a, **k),
|
826
|
-
**{op: renaming[op] for op in c},
|
827
|
-
},
|
828
|
-
)
|
829
|
-
if isinstance(i, int):
|
830
|
-
args_[i] = res
|
831
|
-
elif isinstance(i, str):
|
832
|
-
kwargs_[i] = res
|
936
|
+
renamed_args: inspect.BoundArguments = op.__signature__.bind(*args, **kwargs)
|
937
|
+
renamed_args.apply_defaults()
|
833
938
|
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
939
|
+
args_ = [
|
940
|
+
evaluate(
|
941
|
+
arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.args[i]}}
|
942
|
+
)
|
943
|
+
for i, arg in enumerate(renamed_args.args)
|
944
|
+
]
|
945
|
+
kwargs_ = {
|
946
|
+
k: evaluate(
|
947
|
+
arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.kwargs[k]}}
|
948
|
+
)
|
949
|
+
for k, arg in renamed_args.kwargs.items()
|
950
|
+
}
|
951
|
+
|
952
|
+
base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_)
|
953
|
+
return __dispatch(typeof(base_term))(op, *args_, **kwargs_)
|
838
954
|
|
839
955
|
|
840
956
|
@defterm.register(object)
|
@@ -842,12 +958,12 @@ def defdata(
|
|
842
958
|
@defterm.register(Term)
|
843
959
|
@defterm.register(type)
|
844
960
|
@defterm.register(types.BuiltinFunctionType)
|
845
|
-
def _(value: T) -> T:
|
961
|
+
def _[T](value: T) -> T:
|
846
962
|
return value
|
847
963
|
|
848
964
|
|
849
965
|
@defdata.register(object)
|
850
|
-
class _BaseTerm
|
966
|
+
class _BaseTerm[T](Term[T]):
|
851
967
|
_op: Operation[..., T]
|
852
968
|
_args: collections.abc.Sequence[Expr]
|
853
969
|
_kwargs: collections.abc.Mapping[str, Expr]
|
@@ -881,17 +997,55 @@ class _BaseTerm(Generic[T], Term[T]):
|
|
881
997
|
|
882
998
|
|
883
999
|
@defdata.register(collections.abc.Callable)
|
884
|
-
class _CallableTerm
|
885
|
-
|
886
|
-
|
1000
|
+
class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]):
|
1001
|
+
@defop
|
1002
|
+
def __call__(
|
1003
|
+
self: collections.abc.Callable[P, T], *args: P.args, **kwargs: P.kwargs
|
1004
|
+
) -> T:
|
1005
|
+
from effectful.ops.semantics import evaluate, fvsof, handler
|
1006
|
+
|
1007
|
+
if isinstance(self, Term) and self.op is deffn:
|
1008
|
+
body: Expr[Callable[P, T]] = self.args[0]
|
1009
|
+
argvars: tuple[Operation, ...] = self.args[1:]
|
1010
|
+
kwvars: dict[str, Operation] = self.kwargs
|
1011
|
+
subs = {
|
1012
|
+
**{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
|
1013
|
+
**{
|
1014
|
+
kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs
|
1015
|
+
},
|
1016
|
+
}
|
1017
|
+
with handler(subs):
|
1018
|
+
return evaluate(body)
|
1019
|
+
elif not fvsof((self, args, kwargs)):
|
1020
|
+
return self(*args, **kwargs)
|
1021
|
+
else:
|
1022
|
+
raise NotHandled
|
1023
|
+
|
1024
|
+
|
1025
|
+
def trace[**P, T](value: Callable[P, T]) -> Callable[P, T]:
|
1026
|
+
"""Convert a callable to a term by calling it with appropriately typed free variables.
|
1027
|
+
|
1028
|
+
**Example usage**:
|
1029
|
+
|
1030
|
+
:func:`trace` can be passed a function, and it will convert that function to
|
1031
|
+
a term by calling it with appropriately typed free variables:
|
887
1032
|
|
888
|
-
|
1033
|
+
>>> def incr(x: int) -> int:
|
1034
|
+
... return x + 1
|
1035
|
+
>>> term = trace(incr)
|
889
1036
|
|
1037
|
+
>>> print(str(term))
|
1038
|
+
deffn(__add__(int(), 1), int)
|
890
1039
|
|
891
|
-
|
892
|
-
|
1040
|
+
>>> term(2)
|
1041
|
+
3
|
1042
|
+
|
1043
|
+
"""
|
893
1044
|
from effectful.internals.runtime import interpreter
|
894
|
-
from effectful.ops.semantics import apply
|
1045
|
+
from effectful.ops.semantics import apply
|
1046
|
+
|
1047
|
+
call = defdata.dispatch(collections.abc.Callable).__call__
|
1048
|
+
assert isinstance(call, Operation)
|
895
1049
|
|
896
1050
|
assert not isinstance(value, Term)
|
897
1051
|
|
@@ -912,12 +1066,7 @@ def _(value: Callable[P, T]) -> Expr[Callable[P, T]]:
|
|
912
1066
|
)
|
913
1067
|
bound_sig.apply_defaults()
|
914
1068
|
|
915
|
-
with interpreter(
|
916
|
-
{
|
917
|
-
apply: lambda _, op, *a, **k: defdata(op, *a, **k),
|
918
|
-
call: call.__default_rule__,
|
919
|
-
}
|
920
|
-
):
|
1069
|
+
with interpreter({apply: defdata, call: call.__default_rule__}):
|
921
1070
|
body = value(
|
922
1071
|
*[a() for a in bound_sig.args],
|
923
1072
|
**{k: v() for k, v in bound_sig.kwargs.items()},
|
@@ -926,38 +1075,115 @@ def _(value: Callable[P, T]) -> Expr[Callable[P, T]]:
|
|
926
1075
|
return deffn(body, *bound_sig.args, **bound_sig.kwargs)
|
927
1076
|
|
928
1077
|
|
929
|
-
|
1078
|
+
@defop
|
1079
|
+
def defstream[S, T, A, B](
|
1080
|
+
body: Annotated[T, Scoped[A | B]],
|
1081
|
+
streams: Annotated[Mapping[Operation[[], S], Iterable[S]], Scoped[B]],
|
1082
|
+
) -> Annotated[Iterable[T], Scoped[A]]:
|
1083
|
+
"""A higher-order operation that represents a for-expression."""
|
1084
|
+
raise NotHandled
|
1085
|
+
|
1086
|
+
|
1087
|
+
@defdata.register(collections.abc.Iterable)
|
1088
|
+
class _IterableTerm[T](_BaseTerm[collections.abc.Iterable[T]]):
|
1089
|
+
@defop
|
1090
|
+
def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]:
|
1091
|
+
if not isinstance(self, Term):
|
1092
|
+
return iter(self)
|
1093
|
+
else:
|
1094
|
+
raise NotHandled
|
1095
|
+
|
1096
|
+
|
1097
|
+
@defdata.register(collections.abc.Iterator)
|
1098
|
+
class _IteratorTerm[T](_IterableTerm[T]):
|
1099
|
+
@defop
|
1100
|
+
def __next__(self: collections.abc.Iterator[T]) -> T:
|
1101
|
+
if not isinstance(self, Term):
|
1102
|
+
return next(self)
|
1103
|
+
else:
|
1104
|
+
raise NotHandled
|
1105
|
+
|
1106
|
+
|
1107
|
+
iter_ = _IterableTerm.__iter__
|
1108
|
+
next_ = _IteratorTerm.__next__
|
1109
|
+
|
1110
|
+
|
1111
|
+
@_CustomSingleDispatchCallable
|
1112
|
+
def syntactic_eq(
|
1113
|
+
__dispatch: Callable[[type], Callable[[Any, Any], bool]], x, other
|
1114
|
+
) -> bool:
|
930
1115
|
"""Syntactic equality, ignoring the interpretation of the terms.
|
931
1116
|
|
932
1117
|
:param x: A term.
|
933
|
-
:type x: Expr[T]
|
934
1118
|
:param other: Another term.
|
935
|
-
:type other: Expr[T]
|
936
1119
|
:returns: ``True`` if the terms are syntactically equal and ``False`` otherwise.
|
937
1120
|
"""
|
938
|
-
if
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2)
|
951
|
-
)
|
952
|
-
)
|
1121
|
+
if (
|
1122
|
+
dataclasses.is_dataclass(x)
|
1123
|
+
and not isinstance(x, type)
|
1124
|
+
and dataclasses.is_dataclass(other)
|
1125
|
+
and not isinstance(other, type)
|
1126
|
+
):
|
1127
|
+
return type(x) == type(other) and syntactic_eq(
|
1128
|
+
{field.name: getattr(x, field.name) for field in dataclasses.fields(x)},
|
1129
|
+
{
|
1130
|
+
field.name: getattr(other, field.name)
|
1131
|
+
for field in dataclasses.fields(other)
|
1132
|
+
},
|
953
1133
|
)
|
954
|
-
|
1134
|
+
else:
|
1135
|
+
return __dispatch(type(x))(x, other)
|
1136
|
+
|
1137
|
+
|
1138
|
+
@syntactic_eq.register
|
1139
|
+
def _(x: Term, other) -> bool:
|
1140
|
+
if not isinstance(other, Term):
|
955
1141
|
return False
|
1142
|
+
|
1143
|
+
op, args, kwargs = x.op, x.args, x.kwargs
|
1144
|
+
op2, args2, kwargs2 = other.op, other.args, other.kwargs
|
1145
|
+
return (
|
1146
|
+
op == op2
|
1147
|
+
and len(args) == len(args2)
|
1148
|
+
and set(kwargs) == set(kwargs2)
|
1149
|
+
and all(syntactic_eq(a, b) for a, b in zip(args, args2))
|
1150
|
+
and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs)
|
1151
|
+
)
|
1152
|
+
|
1153
|
+
|
1154
|
+
@syntactic_eq.register
|
1155
|
+
def _(x: collections.abc.Mapping, other) -> bool:
|
1156
|
+
return isinstance(other, collections.abc.Mapping) and all(
|
1157
|
+
k in x and k in other and syntactic_eq(x[k], other[k])
|
1158
|
+
for k in set(x) | set(other)
|
1159
|
+
)
|
1160
|
+
|
1161
|
+
|
1162
|
+
@syntactic_eq.register
|
1163
|
+
def _(x: collections.abc.Sequence, other) -> bool:
|
1164
|
+
if (
|
1165
|
+
isinstance(x, tuple)
|
1166
|
+
and hasattr(x, "_fields")
|
1167
|
+
and all(hasattr(x, f) for f in x._fields)
|
1168
|
+
):
|
1169
|
+
return type(other) == type(x) and all(
|
1170
|
+
syntactic_eq(getattr(x, f), getattr(other, f)) for f in x._fields
|
1171
|
+
)
|
956
1172
|
else:
|
957
|
-
return
|
1173
|
+
return (
|
1174
|
+
isinstance(other, collections.abc.Sequence)
|
1175
|
+
and len(x) == len(other)
|
1176
|
+
and all(syntactic_eq(a, b) for a, b in zip(x, other))
|
1177
|
+
)
|
1178
|
+
|
958
1179
|
|
1180
|
+
@syntactic_eq.register(object)
|
1181
|
+
@syntactic_eq.register(str | bytes)
|
1182
|
+
def _(x: object, other) -> bool:
|
1183
|
+
return x == other
|
959
1184
|
|
960
|
-
|
1185
|
+
|
1186
|
+
class ObjectInterpretation[T, V](collections.abc.Mapping):
|
961
1187
|
"""A helper superclass for defining an ``Interpretation`` of many
|
962
1188
|
:class:`~effectful.ops.types.Operation` instances with shared state or behavior.
|
963
1189
|
|
@@ -1034,8 +1260,8 @@ class ObjectInterpretation(Generic[T, V], Interpretation[T, V]):
|
|
1034
1260
|
return self.implementations[item].__get__(self, type(self))
|
1035
1261
|
|
1036
1262
|
|
1037
|
-
class _ImplementedOperation
|
1038
|
-
impl:
|
1263
|
+
class _ImplementedOperation[**P, **Q, T, V]:
|
1264
|
+
impl: Callable[Q, V] | None
|
1039
1265
|
op: Operation[P, T]
|
1040
1266
|
|
1041
1267
|
def __init__(self, op: Operation[P, T]):
|
@@ -1059,7 +1285,7 @@ class _ImplementedOperation(Generic[P, Q, T, V]):
|
|
1059
1285
|
owner._temporary_implementations[self.op] = self.impl
|
1060
1286
|
|
1061
1287
|
|
1062
|
-
def implements(op: Operation[P, V]):
|
1288
|
+
def implements[**P, V](op: Operation[P, V]):
|
1063
1289
|
"""Marks a method in an :class:`ObjectInterpretation` as the implementation of a
|
1064
1290
|
particular abstract :class:`Operation`.
|
1065
1291
|
|
@@ -1068,3 +1294,348 @@ def implements(op: Operation[P, V]):
|
|
1068
1294
|
|
1069
1295
|
"""
|
1070
1296
|
return _ImplementedOperation(op)
|
1297
|
+
|
1298
|
+
|
1299
|
+
@defdata.register(numbers.Number)
|
1300
|
+
@functools.total_ordering
|
1301
|
+
class _NumberTerm[T: numbers.Number](_BaseTerm[T], numbers.Number):
|
1302
|
+
def __hash__(self):
|
1303
|
+
return id(self)
|
1304
|
+
|
1305
|
+
def __complex__(self) -> complex:
|
1306
|
+
raise ValueError("Cannot convert term to complex number")
|
1307
|
+
|
1308
|
+
def __float__(self) -> float:
|
1309
|
+
raise ValueError("Cannot convert term to float")
|
1310
|
+
|
1311
|
+
def __int__(self) -> int:
|
1312
|
+
raise ValueError("Cannot convert term to int")
|
1313
|
+
|
1314
|
+
def __bool__(self) -> bool:
|
1315
|
+
raise ValueError("Cannot convert term to bool")
|
1316
|
+
|
1317
|
+
@defop # type: ignore[prop-decorator]
|
1318
|
+
@property
|
1319
|
+
def real(self) -> float:
|
1320
|
+
if not isinstance(self, Term):
|
1321
|
+
return self.real
|
1322
|
+
else:
|
1323
|
+
raise NotHandled
|
1324
|
+
|
1325
|
+
@defop # type: ignore[prop-decorator]
|
1326
|
+
@property
|
1327
|
+
def imag(self) -> float:
|
1328
|
+
if not isinstance(self, Term):
|
1329
|
+
return self.imag
|
1330
|
+
else:
|
1331
|
+
raise NotHandled
|
1332
|
+
|
1333
|
+
@defop
|
1334
|
+
def conjugate(self) -> complex:
|
1335
|
+
if not isinstance(self, Term):
|
1336
|
+
return self.conjugate()
|
1337
|
+
else:
|
1338
|
+
raise NotHandled
|
1339
|
+
|
1340
|
+
@defop # type: ignore[prop-decorator]
|
1341
|
+
@property
|
1342
|
+
def numerator(self) -> int:
|
1343
|
+
if not isinstance(self, Term):
|
1344
|
+
return self.numerator
|
1345
|
+
else:
|
1346
|
+
raise NotHandled
|
1347
|
+
|
1348
|
+
@defop # type: ignore[prop-decorator]
|
1349
|
+
@property
|
1350
|
+
def denominator(self) -> int:
|
1351
|
+
if not isinstance(self, Term):
|
1352
|
+
return self.denominator
|
1353
|
+
else:
|
1354
|
+
raise NotHandled
|
1355
|
+
|
1356
|
+
@defop
|
1357
|
+
def __abs__(self) -> float:
|
1358
|
+
"""Return the absolute value of the term."""
|
1359
|
+
if not isinstance(self, Term):
|
1360
|
+
return self.__abs__()
|
1361
|
+
else:
|
1362
|
+
raise NotHandled
|
1363
|
+
|
1364
|
+
@defop
|
1365
|
+
def __neg__(self: T) -> T:
|
1366
|
+
if not isinstance(self, Term):
|
1367
|
+
return self.__neg__() # type: ignore
|
1368
|
+
else:
|
1369
|
+
raise NotHandled
|
1370
|
+
|
1371
|
+
@defop
|
1372
|
+
def __pos__(self: T) -> T:
|
1373
|
+
if not isinstance(self, Term):
|
1374
|
+
return self.__pos__() # type: ignore
|
1375
|
+
else:
|
1376
|
+
raise NotHandled
|
1377
|
+
|
1378
|
+
@defop
|
1379
|
+
def __trunc__(self) -> int:
|
1380
|
+
if not isinstance(self, Term):
|
1381
|
+
return self.__trunc__()
|
1382
|
+
else:
|
1383
|
+
raise NotHandled
|
1384
|
+
|
1385
|
+
@defop
|
1386
|
+
def __floor__(self) -> int:
|
1387
|
+
if not isinstance(self, Term):
|
1388
|
+
return self.__floor__()
|
1389
|
+
else:
|
1390
|
+
raise NotHandled
|
1391
|
+
|
1392
|
+
@defop
|
1393
|
+
def __ceil__(self) -> int:
|
1394
|
+
if not isinstance(self, Term):
|
1395
|
+
return self.__ceil__()
|
1396
|
+
else:
|
1397
|
+
raise NotHandled
|
1398
|
+
|
1399
|
+
@defop
|
1400
|
+
def __round__(self, ndigits: int | None = None) -> numbers.Real:
|
1401
|
+
if not isinstance(self, Term) and not isinstance(ndigits, Term):
|
1402
|
+
return self.__round__(ndigits)
|
1403
|
+
else:
|
1404
|
+
raise NotHandled
|
1405
|
+
|
1406
|
+
@defop
|
1407
|
+
def __invert__(self) -> int:
|
1408
|
+
if not isinstance(self, Term):
|
1409
|
+
return self.__invert__()
|
1410
|
+
else:
|
1411
|
+
raise NotHandled
|
1412
|
+
|
1413
|
+
@defop
|
1414
|
+
def __index__(self) -> int:
|
1415
|
+
if not isinstance(self, Term):
|
1416
|
+
return self.__index__()
|
1417
|
+
else:
|
1418
|
+
raise NotHandled
|
1419
|
+
|
1420
|
+
@defop
|
1421
|
+
def __eq__(self, other) -> bool: # type: ignore[override]
|
1422
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1423
|
+
return self.__eq__(other)
|
1424
|
+
else:
|
1425
|
+
return syntactic_eq(self, other)
|
1426
|
+
|
1427
|
+
@defop
|
1428
|
+
def __lt__(self, other) -> bool:
|
1429
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1430
|
+
return self.__lt__(other)
|
1431
|
+
else:
|
1432
|
+
raise NotHandled
|
1433
|
+
|
1434
|
+
@defop
|
1435
|
+
def __add__(self, other: T) -> T:
|
1436
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1437
|
+
return operator.__add__(self, other)
|
1438
|
+
else:
|
1439
|
+
raise NotHandled
|
1440
|
+
|
1441
|
+
def __radd__(self, other):
|
1442
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1443
|
+
return other.__add__(self)
|
1444
|
+
elif not isinstance(other, Term):
|
1445
|
+
return type(self).__add__(other, self)
|
1446
|
+
else:
|
1447
|
+
return NotImplemented
|
1448
|
+
|
1449
|
+
@defop
|
1450
|
+
def __sub__(self, other: T) -> T:
|
1451
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1452
|
+
return operator.__sub__(self, other)
|
1453
|
+
else:
|
1454
|
+
raise NotHandled
|
1455
|
+
|
1456
|
+
def __rsub__(self, other):
|
1457
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1458
|
+
return other.__sub__(self)
|
1459
|
+
elif not isinstance(other, Term):
|
1460
|
+
return type(self).__sub__(other, self)
|
1461
|
+
else:
|
1462
|
+
return NotImplemented
|
1463
|
+
|
1464
|
+
@defop
|
1465
|
+
def __mul__(self, other: T) -> T:
|
1466
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1467
|
+
return operator.__mul__(self, other)
|
1468
|
+
else:
|
1469
|
+
raise NotHandled
|
1470
|
+
|
1471
|
+
def __rmul__(self, other):
|
1472
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1473
|
+
return other.__mul__(self)
|
1474
|
+
elif not isinstance(other, Term):
|
1475
|
+
return type(self).__mul__(other, self)
|
1476
|
+
else:
|
1477
|
+
return NotImplemented
|
1478
|
+
|
1479
|
+
@defop
|
1480
|
+
def __truediv__(self, other: T) -> T:
|
1481
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1482
|
+
return operator.__truediv__(self, other)
|
1483
|
+
else:
|
1484
|
+
raise NotHandled
|
1485
|
+
|
1486
|
+
def __rtruediv__(self, other):
|
1487
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1488
|
+
return other.__truediv__(self)
|
1489
|
+
elif not isinstance(other, Term):
|
1490
|
+
return type(self).__truediv__(other, self)
|
1491
|
+
else:
|
1492
|
+
return NotImplemented
|
1493
|
+
|
1494
|
+
@defop
|
1495
|
+
def __floordiv__(self, other: T) -> T:
|
1496
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1497
|
+
return operator.__floordiv__(self, other)
|
1498
|
+
else:
|
1499
|
+
raise NotHandled
|
1500
|
+
|
1501
|
+
def __rfloordiv__(self, other):
|
1502
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1503
|
+
return other.__floordiv__(self)
|
1504
|
+
elif not isinstance(other, Term):
|
1505
|
+
return type(self).__floordiv__(other, self)
|
1506
|
+
else:
|
1507
|
+
return NotImplemented
|
1508
|
+
|
1509
|
+
@defop
|
1510
|
+
def __mod__(self, other: T) -> T:
|
1511
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1512
|
+
return operator.__mod__(self, other)
|
1513
|
+
else:
|
1514
|
+
raise NotHandled
|
1515
|
+
|
1516
|
+
def __rmod__(self, other):
|
1517
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1518
|
+
return other.__mod__(self)
|
1519
|
+
elif not isinstance(other, Term):
|
1520
|
+
return type(self).__mod__(other, self)
|
1521
|
+
else:
|
1522
|
+
return NotImplemented
|
1523
|
+
|
1524
|
+
@defop
|
1525
|
+
def __pow__(self, other: T) -> T:
|
1526
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1527
|
+
return operator.__pow__(self, other)
|
1528
|
+
else:
|
1529
|
+
raise NotHandled
|
1530
|
+
|
1531
|
+
def __rpow__(self, other):
|
1532
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1533
|
+
return other.__pow__(self)
|
1534
|
+
elif not isinstance(other, Term):
|
1535
|
+
return type(self).__pow__(other, self)
|
1536
|
+
else:
|
1537
|
+
return NotImplemented
|
1538
|
+
|
1539
|
+
@defop
|
1540
|
+
def __lshift__(self, other: T) -> T:
|
1541
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1542
|
+
return operator.__lshift__(self, other)
|
1543
|
+
else:
|
1544
|
+
raise NotHandled
|
1545
|
+
|
1546
|
+
def __rlshift__(self, other):
|
1547
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1548
|
+
return other.__lshift__(self)
|
1549
|
+
elif not isinstance(other, Term):
|
1550
|
+
return type(self).__lshift__(other, self)
|
1551
|
+
else:
|
1552
|
+
return NotImplemented
|
1553
|
+
|
1554
|
+
@defop
|
1555
|
+
def __rshift__(self, other: T) -> T:
|
1556
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1557
|
+
return operator.__rshift__(self, other)
|
1558
|
+
else:
|
1559
|
+
raise NotHandled
|
1560
|
+
|
1561
|
+
def __rrshift__(self, other):
|
1562
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1563
|
+
return other.__rshift__(self)
|
1564
|
+
elif not isinstance(other, Term):
|
1565
|
+
return type(self).__rshift__(other, self)
|
1566
|
+
else:
|
1567
|
+
return NotImplemented
|
1568
|
+
|
1569
|
+
@defop
|
1570
|
+
def __and__(self, other: T) -> T:
|
1571
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1572
|
+
return operator.__and__(self, other)
|
1573
|
+
else:
|
1574
|
+
raise NotHandled
|
1575
|
+
|
1576
|
+
def __rand__(self, other):
|
1577
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1578
|
+
return other.__and__(self)
|
1579
|
+
elif not isinstance(other, Term):
|
1580
|
+
return type(self).__and__(other, self)
|
1581
|
+
else:
|
1582
|
+
return NotImplemented
|
1583
|
+
|
1584
|
+
@defop
|
1585
|
+
def __xor__(self, other: T) -> T:
|
1586
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1587
|
+
return operator.__xor__(self, other)
|
1588
|
+
else:
|
1589
|
+
raise NotHandled
|
1590
|
+
|
1591
|
+
def __rxor__(self, other):
|
1592
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1593
|
+
return other.__xor__(self)
|
1594
|
+
elif not isinstance(other, Term):
|
1595
|
+
return type(self).__xor__(other, self)
|
1596
|
+
else:
|
1597
|
+
return NotImplemented
|
1598
|
+
|
1599
|
+
@defop
|
1600
|
+
def __or__(self, other: T) -> T:
|
1601
|
+
if not isinstance(self, Term) and not isinstance(other, Term):
|
1602
|
+
return operator.__or__(self, other)
|
1603
|
+
else:
|
1604
|
+
raise NotHandled
|
1605
|
+
|
1606
|
+
def __ror__(self, other):
|
1607
|
+
if isinstance(other, Term) and isinstance(other, type(self)):
|
1608
|
+
return other.__or__(self)
|
1609
|
+
elif not isinstance(other, Term):
|
1610
|
+
return type(self).__or__(other, self)
|
1611
|
+
else:
|
1612
|
+
return NotImplemented
|
1613
|
+
|
1614
|
+
|
1615
|
+
@defdata.register(numbers.Complex)
|
1616
|
+
@numbers.Complex.register
|
1617
|
+
class _ComplexTerm[T: numbers.Complex](_NumberTerm[T]):
|
1618
|
+
pass
|
1619
|
+
|
1620
|
+
|
1621
|
+
@defdata.register(numbers.Real)
|
1622
|
+
@numbers.Real.register
|
1623
|
+
class _RealTerm[T: numbers.Real](_ComplexTerm[T]):
|
1624
|
+
pass
|
1625
|
+
|
1626
|
+
|
1627
|
+
@defdata.register(numbers.Rational)
|
1628
|
+
@numbers.Rational.register
|
1629
|
+
class _RationalTerm[T: numbers.Rational](_RealTerm[T]):
|
1630
|
+
pass
|
1631
|
+
|
1632
|
+
|
1633
|
+
@defdata.register(numbers.Integral)
|
1634
|
+
@numbers.Integral.register
|
1635
|
+
class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]):
|
1636
|
+
pass
|
1637
|
+
|
1638
|
+
|
1639
|
+
@defdata.register(bool)
|
1640
|
+
class _BoolTerm[T: bool](_IntegralTerm[T]): # type: ignore
|
1641
|
+
pass
|