egglog 10.0.1__cp313-cp313-win_amd64.whl → 11.0.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp313-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +231 -182
- egglog/conversion.py +61 -43
- egglog/declarations.py +104 -18
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +86 -144
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +46 -12
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/exp/program_gen.py +4 -3
- egglog/pretty.py +11 -25
- egglog/runtime.py +203 -151
- egglog/thunk.py +6 -4
- egglog/type_constraint_solver.py +1 -1
- egglog/version_compat.py +87 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/METADATA +1 -1
- egglog-11.0.0.dist-info/RECORD +45 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- egglog-10.0.1.dist-info/RECORD +0 -44
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/runtime.py
CHANGED
|
@@ -11,30 +11,32 @@ so they are not mangled by Python and can be accessed by the user.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
+
import itertools
|
|
14
15
|
import operator
|
|
16
|
+
import types
|
|
15
17
|
from collections.abc import Callable
|
|
16
|
-
from dataclasses import dataclass, replace
|
|
18
|
+
from dataclasses import InitVar, dataclass, replace
|
|
17
19
|
from inspect import Parameter, Signature
|
|
18
20
|
from itertools import zip_longest
|
|
19
|
-
from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin
|
|
21
|
+
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
|
|
20
22
|
|
|
21
23
|
from .declarations import *
|
|
22
24
|
from .pretty import *
|
|
23
25
|
from .thunk import Thunk
|
|
24
26
|
from .type_constraint_solver import *
|
|
27
|
+
from .version_compat import *
|
|
25
28
|
|
|
26
29
|
if TYPE_CHECKING:
|
|
27
30
|
from collections.abc import Iterable
|
|
28
31
|
|
|
29
|
-
from .egraph import Fact
|
|
30
|
-
|
|
31
32
|
|
|
32
33
|
__all__ = [
|
|
33
34
|
"LIT_CLASS_NAMES",
|
|
34
|
-
"
|
|
35
|
+
"NUMERIC_BINARY_METHODS",
|
|
35
36
|
"RuntimeClass",
|
|
36
37
|
"RuntimeExpr",
|
|
37
38
|
"RuntimeFunction",
|
|
39
|
+
"define_expr_method",
|
|
38
40
|
"resolve_callable",
|
|
39
41
|
"resolve_type_annotation",
|
|
40
42
|
"resolve_type_annotation_mutate",
|
|
@@ -45,24 +47,34 @@ UNIT_CLASS_NAME = "Unit"
|
|
|
45
47
|
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
|
|
46
48
|
LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
|
|
47
49
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
"
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
"
|
|
60
|
-
"
|
|
61
|
-
"
|
|
50
|
+
# All methods which should return NotImplemented if they fail to resolve and are reflected as well
|
|
51
|
+
# From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
|
|
52
|
+
|
|
53
|
+
NUMERIC_BINARY_METHODS = {
|
|
54
|
+
"__add__",
|
|
55
|
+
"__sub__",
|
|
56
|
+
"__mul__",
|
|
57
|
+
"__matmul__",
|
|
58
|
+
"__truediv__",
|
|
59
|
+
"__floordiv__",
|
|
60
|
+
"__mod__",
|
|
61
|
+
"__divmod__",
|
|
62
|
+
"__pow__",
|
|
63
|
+
"__lshift__",
|
|
64
|
+
"__rshift__",
|
|
65
|
+
"__and__",
|
|
66
|
+
"__xor__",
|
|
67
|
+
"__or__",
|
|
68
|
+
"__lt__",
|
|
69
|
+
"__le__",
|
|
70
|
+
"__gt__",
|
|
71
|
+
"__ge__",
|
|
62
72
|
}
|
|
63
73
|
|
|
64
|
-
|
|
65
|
-
|
|
74
|
+
|
|
75
|
+
# Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods.
|
|
76
|
+
|
|
77
|
+
TYPE_DEFINED_METHODS = {
|
|
66
78
|
"__bool__",
|
|
67
79
|
"__len__",
|
|
68
80
|
"__complex__",
|
|
@@ -70,9 +82,15 @@ PRESERVED_METHODS = [
|
|
|
70
82
|
"__float__",
|
|
71
83
|
"__iter__",
|
|
72
84
|
"__index__",
|
|
73
|
-
"
|
|
74
|
-
"
|
|
75
|
-
|
|
85
|
+
"__call__",
|
|
86
|
+
"__getitem__",
|
|
87
|
+
"__setitem__",
|
|
88
|
+
"__delitem__",
|
|
89
|
+
"__pos__",
|
|
90
|
+
"__neg__",
|
|
91
|
+
"__invert__",
|
|
92
|
+
"__round__",
|
|
93
|
+
}
|
|
76
94
|
|
|
77
95
|
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
|
|
78
96
|
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
|
|
@@ -134,11 +152,48 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp:
|
|
|
134
152
|
##
|
|
135
153
|
|
|
136
154
|
|
|
137
|
-
|
|
138
|
-
|
|
155
|
+
class BaseClassFactoryMeta(type):
|
|
156
|
+
"""
|
|
157
|
+
Base metaclass for all runtime classes created by ClassFactory
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
def __instancecheck__(cls, instance: object) -> bool:
|
|
161
|
+
assert isinstance(cls, RuntimeClass)
|
|
162
|
+
return isinstance(instance, RuntimeExpr) and cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ClassFactory(type):
|
|
166
|
+
"""
|
|
167
|
+
A metaclass for types which should create `type` objects when instantiated.
|
|
168
|
+
|
|
169
|
+
That's so that they work with `isinstance` and can be placed in `match ClassName()`.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
def __call__(cls, *args, **kwargs) -> type:
|
|
173
|
+
# If we have params, don't inherit from `type` because we don't need to match against this and also
|
|
174
|
+
# this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`.
|
|
175
|
+
if kwargs.pop("_egg_has_params", False):
|
|
176
|
+
return super().__call__(*args, **kwargs)
|
|
177
|
+
namespace: dict[str, Any] = {}
|
|
178
|
+
for m in reversed(cls.__mro__):
|
|
179
|
+
namespace.update(m.__dict__)
|
|
180
|
+
init = namespace.pop("__init__")
|
|
181
|
+
meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, lambda ns: ns.update(**namespace))
|
|
182
|
+
tp = types.new_class("RuntimeClass", (), {"metaclass": meta})
|
|
183
|
+
init(tp, *args, **kwargs)
|
|
184
|
+
return tp
|
|
185
|
+
|
|
186
|
+
def __instancecheck__(cls, instance: object) -> bool:
|
|
187
|
+
return isinstance(instance, BaseClassFactoryMeta)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass(match_args=False)
|
|
191
|
+
class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
|
|
139
192
|
__egg_tp__: TypeRefWithVars
|
|
193
|
+
# True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly.
|
|
194
|
+
_egg_has_params: InitVar[bool] = False
|
|
140
195
|
|
|
141
|
-
def __post_init__(self) -> None:
|
|
196
|
+
def __post_init__(self, _egg_has_params: bool) -> None:
|
|
142
197
|
global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
|
|
143
198
|
if (name := self.__egg_tp__.name) == "PyObject":
|
|
144
199
|
_PY_OBJECT_CLASS = self
|
|
@@ -228,7 +283,7 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
228
283
|
else:
|
|
229
284
|
final_args = new_args
|
|
230
285
|
tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
|
|
231
|
-
return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp)
|
|
286
|
+
return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True)
|
|
232
287
|
|
|
233
288
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
|
|
234
289
|
if name == "__origin__" and self.__egg_tp__.args:
|
|
@@ -249,8 +304,7 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
249
304
|
try:
|
|
250
305
|
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
|
|
251
306
|
except Exception as e:
|
|
252
|
-
|
|
253
|
-
raise
|
|
307
|
+
raise add_note(f"Error processing class {self.__egg_tp__.name}", e) from None
|
|
254
308
|
|
|
255
309
|
preserved_methods = cls_decl.preserved_methods
|
|
256
310
|
if name in preserved_methods:
|
|
@@ -281,10 +335,21 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
281
335
|
def __str__(self) -> str:
|
|
282
336
|
return str(self.__egg_tp__)
|
|
283
337
|
|
|
338
|
+
def __repr__(self) -> str:
|
|
339
|
+
return str(self)
|
|
340
|
+
|
|
284
341
|
# Make hashable so can go in Union
|
|
285
342
|
def __hash__(self) -> int:
|
|
286
343
|
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
|
|
287
344
|
|
|
345
|
+
def __eq__(self, other: object) -> bool:
|
|
346
|
+
"""
|
|
347
|
+
Support equality for runtime comparison of egglog classes.
|
|
348
|
+
"""
|
|
349
|
+
if not isinstance(other, RuntimeClass):
|
|
350
|
+
return NotImplemented
|
|
351
|
+
return self.__egg_tp__ == other.__egg_tp__
|
|
352
|
+
|
|
288
353
|
# Support unioning like types
|
|
289
354
|
def __or__(self, value: type) -> object:
|
|
290
355
|
return Union[self, value] # noqa: UP007
|
|
@@ -296,6 +361,10 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
296
361
|
"""
|
|
297
362
|
return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)
|
|
298
363
|
|
|
364
|
+
@property
|
|
365
|
+
def __match_args__(self) -> tuple[str, ...]:
|
|
366
|
+
return self.__egg_decls__._classes[self.__egg_tp__.name].match_args
|
|
367
|
+
|
|
299
368
|
|
|
300
369
|
@dataclass
|
|
301
370
|
class RuntimeFunction(DelayedDeclerations):
|
|
@@ -303,20 +372,30 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
303
372
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
304
373
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
305
374
|
|
|
375
|
+
def __eq__(self, other: object) -> bool:
|
|
376
|
+
"""
|
|
377
|
+
Support equality for runtime comparison of egglog functions.
|
|
378
|
+
"""
|
|
379
|
+
if not isinstance(other, RuntimeFunction):
|
|
380
|
+
return NotImplemented
|
|
381
|
+
return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__)
|
|
382
|
+
|
|
383
|
+
def __hash__(self) -> int:
|
|
384
|
+
return hash((self.__egg_ref__, self.__egg_bound__))
|
|
385
|
+
|
|
306
386
|
@property
|
|
307
387
|
def __egg_ref__(self) -> CallableRef:
|
|
308
388
|
return self.__egg_ref_thunk__()
|
|
309
389
|
|
|
310
390
|
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
311
|
-
from .conversion import resolve_literal
|
|
391
|
+
from .conversion import resolve_literal # noqa: PLC0415
|
|
312
392
|
|
|
313
393
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
314
394
|
args = (self.__egg_bound__, *args)
|
|
315
395
|
try:
|
|
316
396
|
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
|
|
317
397
|
except Exception as e:
|
|
318
|
-
|
|
319
|
-
raise
|
|
398
|
+
raise add_note(f"Failed to find callable {self}", e) # noqa: B904
|
|
320
399
|
decls = self.__egg_decls__.copy()
|
|
321
400
|
# Special case function application bc we dont support variadic generics yet generally
|
|
322
401
|
if signature == "fn-app":
|
|
@@ -344,7 +423,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
344
423
|
try:
|
|
345
424
|
bound = py_signature.bind(*args, **kwargs)
|
|
346
425
|
except TypeError as err:
|
|
347
|
-
raise TypeError(f"Failed to
|
|
426
|
+
raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err
|
|
348
427
|
del kwargs
|
|
349
428
|
bound.apply_defaults()
|
|
350
429
|
assert not bound.kwargs
|
|
@@ -413,9 +492,7 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
413
492
|
Parameter(
|
|
414
493
|
n,
|
|
415
494
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
416
|
-
default=RuntimeExpr.__from_values__(
|
|
417
|
-
decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
|
|
418
|
-
)
|
|
495
|
+
default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d or LetRefDecl(n)))
|
|
419
496
|
if d is not None or optional_args
|
|
420
497
|
else Parameter.empty,
|
|
421
498
|
)
|
|
@@ -426,32 +503,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
426
503
|
return Signature(parameters)
|
|
427
504
|
|
|
428
505
|
|
|
429
|
-
# All methods which should return NotImplemented if they fail to resolve
|
|
430
|
-
# From https://docs.python.org/3/reference/datamodel.html
|
|
431
|
-
PARTIAL_METHODS = {
|
|
432
|
-
"__lt__",
|
|
433
|
-
"__le__",
|
|
434
|
-
"__eq__",
|
|
435
|
-
"__ne__",
|
|
436
|
-
"__gt__",
|
|
437
|
-
"__ge__",
|
|
438
|
-
"__add__",
|
|
439
|
-
"__sub__",
|
|
440
|
-
"__mul__",
|
|
441
|
-
"__matmul__",
|
|
442
|
-
"__truediv__",
|
|
443
|
-
"__floordiv__",
|
|
444
|
-
"__mod__",
|
|
445
|
-
"__divmod__",
|
|
446
|
-
"__pow__",
|
|
447
|
-
"__lshift__",
|
|
448
|
-
"__rshift__",
|
|
449
|
-
"__and__",
|
|
450
|
-
"__xor__",
|
|
451
|
-
"__or__",
|
|
452
|
-
}
|
|
453
|
-
|
|
454
|
-
|
|
455
506
|
@dataclass
|
|
456
507
|
class RuntimeExpr(DelayedDeclerations):
|
|
457
508
|
__egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
|
|
@@ -468,17 +519,14 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
468
519
|
return self.__egg_typed_expr_thunk__()
|
|
469
520
|
|
|
470
521
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
if name in class_decl.properties:
|
|
480
|
-
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
|
|
481
|
-
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
522
|
+
if (method := _get_expr_method(self, name)) is not None:
|
|
523
|
+
return method
|
|
524
|
+
if name in self.__egg_class_decl__.properties:
|
|
525
|
+
fn = RuntimeFunction(
|
|
526
|
+
self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self
|
|
527
|
+
)
|
|
528
|
+
return fn()
|
|
529
|
+
raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None
|
|
482
530
|
|
|
483
531
|
def __repr__(self) -> str:
|
|
484
532
|
"""
|
|
@@ -493,7 +541,7 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
493
541
|
return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
|
|
494
542
|
|
|
495
543
|
def _ipython_display_(self) -> None:
|
|
496
|
-
from IPython.display import Code, display
|
|
544
|
+
from IPython.display import Code, display # noqa: PLC0415
|
|
497
545
|
|
|
498
546
|
display(Code(str(self), language="python"))
|
|
499
547
|
|
|
@@ -509,13 +557,6 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
509
557
|
def __egg_class_decl__(self) -> ClassDecl:
|
|
510
558
|
return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
|
|
511
559
|
|
|
512
|
-
# These both will be overriden below in the special methods section, but add these here for type hinting purposes
|
|
513
|
-
def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body]
|
|
514
|
-
...
|
|
515
|
-
|
|
516
|
-
def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body]
|
|
517
|
-
...
|
|
518
|
-
|
|
519
560
|
# Implement these so that copy() works on this object
|
|
520
561
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
521
562
|
|
|
@@ -527,91 +568,102 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
527
568
|
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
|
|
528
569
|
|
|
529
570
|
def __hash__(self) -> int:
|
|
571
|
+
if (method := _get_expr_method(self, "__hash__")) is not None:
|
|
572
|
+
return cast("int", cast("Any", method()))
|
|
530
573
|
return hash(self.__egg_typed_expr__)
|
|
531
574
|
|
|
575
|
+
# Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
|
|
576
|
+
# preserved method or defined on the class
|
|
577
|
+
def __eq__(self, other: object) -> object: # type: ignore[override]
|
|
578
|
+
if (method := _get_expr_method(self, "__eq__")) is not None:
|
|
579
|
+
return method(other)
|
|
532
580
|
|
|
533
|
-
#
|
|
534
|
-
|
|
581
|
+
# TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
|
|
582
|
+
# expr gets a chance to resolve __eq__ which could be a preserved method.
|
|
583
|
+
from .egraph import BaseExpr, eq # noqa: PLC0415
|
|
535
584
|
|
|
536
|
-
|
|
537
|
-
self: RuntimeExpr,
|
|
538
|
-
*args: object,
|
|
539
|
-
__name: str = name,
|
|
540
|
-
**kwargs: object,
|
|
541
|
-
) -> RuntimeExpr | Fact | None:
|
|
542
|
-
from .conversion import ConvertError
|
|
585
|
+
return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
|
|
543
586
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
except KeyError:
|
|
550
|
-
pass
|
|
551
|
-
else:
|
|
552
|
-
return method(self, *args, **kwargs)
|
|
553
|
-
# If this is a "partial" method meaning that it can return NotImplemented,
|
|
554
|
-
# we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
|
|
555
|
-
# using the arg type of the self arg.
|
|
556
|
-
# This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa.
|
|
557
|
-
if __name in PARTIAL_METHODS:
|
|
558
|
-
try:
|
|
559
|
-
return call_method_min_conversion(self, args[0], __name)
|
|
560
|
-
except ConvertError:
|
|
561
|
-
# Defer raising not imeplemented in case the dunder method is not symmetrical, then
|
|
562
|
-
# we use the standard process
|
|
563
|
-
pass
|
|
564
|
-
if __name in class_decl.methods:
|
|
565
|
-
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
|
|
566
|
-
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
567
|
-
# Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly.
|
|
568
|
-
if __name == "__eq__":
|
|
569
|
-
from .egraph import BaseExpr, eq
|
|
570
|
-
|
|
571
|
-
return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))
|
|
572
|
-
if __name == "__ne__":
|
|
573
|
-
from .egraph import BaseExpr, ne
|
|
574
|
-
|
|
575
|
-
return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])))
|
|
576
|
-
|
|
577
|
-
if __name in PARTIAL_METHODS:
|
|
578
|
-
return NotImplemented
|
|
579
|
-
raise TypeError(f"{class_name!r} object does not support {__name}")
|
|
587
|
+
def __ne__(self, other: object) -> object: # type: ignore[override]
|
|
588
|
+
if (method := _get_expr_method(self, "__ne__")) is not None:
|
|
589
|
+
return method(other)
|
|
590
|
+
|
|
591
|
+
from .egraph import BaseExpr, ne # noqa: PLC0415
|
|
580
592
|
|
|
581
|
-
|
|
593
|
+
return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other))
|
|
582
594
|
|
|
583
|
-
|
|
584
|
-
|
|
595
|
+
def __call__(
|
|
596
|
+
self, *args: object, **kwargs: object
|
|
597
|
+
) -> object: # define it here only for type checking, it will be overriden below
|
|
598
|
+
...
|
|
585
599
|
|
|
586
|
-
def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None:
|
|
587
|
-
# All binary methods are also "partial" meaning we should try to upcast first.
|
|
588
|
-
return call_method_min_conversion(other, self, __non_reflected)
|
|
589
600
|
|
|
590
|
-
|
|
601
|
+
def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
602
|
+
if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods):
|
|
603
|
+
return preserved_methods[name].__get__(expr)
|
|
591
604
|
|
|
605
|
+
if name in expr.__egg_class_decl__.methods:
|
|
606
|
+
return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr)
|
|
607
|
+
return None
|
|
592
608
|
|
|
593
|
-
def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
|
|
594
|
-
from .conversion import min_convertable_tp, resolve_literal
|
|
595
609
|
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
slf = resolve_literal(min_tp, slf)
|
|
600
|
-
other = resolve_literal(min_tp, other)
|
|
601
|
-
method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
|
|
602
|
-
return method(other)
|
|
610
|
+
def define_expr_method(name: str) -> None:
|
|
611
|
+
"""
|
|
612
|
+
Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method.
|
|
603
613
|
|
|
614
|
+
Call this if you need a method to be defined on the type itself where overrindg with `__getattr__` does not suffice,
|
|
615
|
+
like for NumPy's `__array_ufunc__`.
|
|
616
|
+
"""
|
|
604
617
|
|
|
605
|
-
|
|
618
|
+
def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs):
|
|
619
|
+
fn = _get_expr_method(self, __name)
|
|
620
|
+
if fn is None:
|
|
621
|
+
raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}")
|
|
622
|
+
return fn(*args, **kwargs)
|
|
606
623
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
624
|
+
setattr(RuntimeExpr, name, _defined_method)
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
for name in TYPE_DEFINED_METHODS:
|
|
628
|
+
define_expr_method(name)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
|
|
632
|
+
|
|
633
|
+
def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
|
|
634
|
+
"""
|
|
635
|
+
Implements numeric binary operations.
|
|
636
|
+
|
|
637
|
+
Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
|
|
638
|
+
the LHS or the RHS as exactly the right type and then upcasting the other to that type.
|
|
639
|
+
"""
|
|
640
|
+
# 1. switch if reversed method
|
|
641
|
+
if r_method:
|
|
642
|
+
self, other = other, self
|
|
643
|
+
# If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion
|
|
644
|
+
if not (
|
|
645
|
+
isinstance(self, RuntimeExpr)
|
|
646
|
+
and isinstance(other, RuntimeExpr)
|
|
647
|
+
and (
|
|
648
|
+
self.__egg_decls__.check_binary_method_with_types(
|
|
649
|
+
name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp
|
|
650
|
+
)
|
|
651
|
+
)
|
|
652
|
+
):
|
|
653
|
+
from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
|
|
654
|
+
|
|
655
|
+
best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
|
|
656
|
+
|
|
657
|
+
if not best_method:
|
|
658
|
+
raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
|
|
659
|
+
self, other = best_method[0](self), best_method[1](other)
|
|
660
|
+
|
|
661
|
+
method_ref = MethodRef(self.__egg_class_name__, name)
|
|
662
|
+
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
|
|
663
|
+
return fn(other)
|
|
613
664
|
|
|
614
|
-
|
|
665
|
+
method_name = f"__r{name[2:]}" if r_method else name
|
|
666
|
+
setattr(RuntimeExpr, method_name, _numeric_binary_method)
|
|
615
667
|
|
|
616
668
|
|
|
617
669
|
def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
egglog/thunk.py
CHANGED
|
@@ -41,13 +41,13 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
41
41
|
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
|
|
42
42
|
|
|
43
43
|
@classmethod
|
|
44
|
-
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
|
|
44
|
+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], context: str | None = None) -> Thunk[T, Unpack[TS]]:
|
|
45
45
|
"""
|
|
46
46
|
Create a thunk based on some functions and some partial args.
|
|
47
47
|
|
|
48
48
|
If the function is called while it is being resolved recursively it will raise an exception.
|
|
49
49
|
"""
|
|
50
|
-
return cls(Unresolved(fn, args))
|
|
50
|
+
return cls(Unresolved(fn, args, context))
|
|
51
51
|
|
|
52
52
|
@classmethod
|
|
53
53
|
def value(cls, value: T) -> Thunk[T]:
|
|
@@ -57,12 +57,12 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
57
57
|
match self.state:
|
|
58
58
|
case Resolved(value):
|
|
59
59
|
return value
|
|
60
|
-
case Unresolved(fn, args):
|
|
60
|
+
case Unresolved(fn, args, context):
|
|
61
61
|
self.state = Resolving()
|
|
62
62
|
try:
|
|
63
63
|
res = fn(*args)
|
|
64
64
|
except Exception as e:
|
|
65
|
-
self.state = Error(e)
|
|
65
|
+
self.state = Error(e, context)
|
|
66
66
|
raise e from None
|
|
67
67
|
else:
|
|
68
68
|
self.state = Resolved(res)
|
|
@@ -83,6 +83,7 @@ class Resolved(Generic[T]):
|
|
|
83
83
|
class Unresolved(Generic[T, Unpack[TS]]):
|
|
84
84
|
fn: Callable[[Unpack[TS]], T]
|
|
85
85
|
args: tuple[Unpack[TS]]
|
|
86
|
+
context: str | None
|
|
86
87
|
|
|
87
88
|
|
|
88
89
|
@dataclass
|
|
@@ -93,3 +94,4 @@ class Resolving:
|
|
|
93
94
|
@dataclass
|
|
94
95
|
class Error:
|
|
95
96
|
e: Exception
|
|
97
|
+
context: str | None
|
egglog/type_constraint_solver.py
CHANGED
|
@@ -107,7 +107,7 @@ class TypeConstraintSolver:
|
|
|
107
107
|
try:
|
|
108
108
|
return self._cls_typevar_index_to_type[cls_name][tp]
|
|
109
109
|
except KeyError as e:
|
|
110
|
-
raise TypeConstraintError(f"Not enough bound typevars for {tp} in class {cls_name}") from e
|
|
110
|
+
raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_name}") from e
|
|
111
111
|
case TypeRefWithVars(name, args):
|
|
112
112
|
return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_name) for arg in args))
|
|
113
113
|
assert_never(tp)
|
egglog/version_compat.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import sys
|
|
3
|
+
import types
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
BEFORE_3_11 = sys.version_info < (3, 11)
|
|
7
|
+
|
|
8
|
+
__all__ = ["add_note"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def add_note(message: str, exc: BaseException) -> BaseException:
|
|
12
|
+
"""
|
|
13
|
+
Backwards compatible add_note for Python <= 3.10
|
|
14
|
+
"""
|
|
15
|
+
if BEFORE_3_11:
|
|
16
|
+
return exc
|
|
17
|
+
exc.add_note(message)
|
|
18
|
+
return exc
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# For Python version 3.10 need to monkeypatch this function so that RuntimeClass type parameters
|
|
22
|
+
# will be collected as typevars
|
|
23
|
+
if BEFORE_3_11:
|
|
24
|
+
|
|
25
|
+
@typing.no_type_check
|
|
26
|
+
def _collect_type_vars_monkeypatch(types_, typevar_types=None):
|
|
27
|
+
"""
|
|
28
|
+
Collect all type variable contained
|
|
29
|
+
in types in order of first appearance (lexicographic order). For example::
|
|
30
|
+
|
|
31
|
+
_collect_type_vars((T, List[S, T])) == (T, S)
|
|
32
|
+
"""
|
|
33
|
+
from .runtime import RuntimeClass # noqa: PLC0415
|
|
34
|
+
|
|
35
|
+
if typevar_types is None:
|
|
36
|
+
typevar_types = typing.TypeVar
|
|
37
|
+
tvars = []
|
|
38
|
+
for t in types_:
|
|
39
|
+
if isinstance(t, typevar_types) and t not in tvars:
|
|
40
|
+
tvars.append(t)
|
|
41
|
+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
|
|
42
|
+
if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)): # type: ignore[name-defined]
|
|
43
|
+
tvars.extend([t for t in t.__parameters__ if t not in tvars])
|
|
44
|
+
return tuple(tvars)
|
|
45
|
+
|
|
46
|
+
typing._collect_type_vars = _collect_type_vars_monkeypatch # type: ignore[attr-defined]
|
|
47
|
+
|
|
48
|
+
@typing.no_type_check
|
|
49
|
+
@typing._tp_cache
|
|
50
|
+
def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
|
|
51
|
+
from .runtime import RuntimeClass # noqa: PLC0415
|
|
52
|
+
|
|
53
|
+
if self.__origin__ in (typing.Generic, typing.Protocol):
|
|
54
|
+
# Can't subscript Generic[...] or Protocol[...].
|
|
55
|
+
raise TypeError(f"Cannot subscript already-subscripted {self}")
|
|
56
|
+
if not isinstance(params, tuple):
|
|
57
|
+
params = (params,)
|
|
58
|
+
params = tuple(typing._type_convert(p) for p in params)
|
|
59
|
+
if self._paramspec_tvars and any(isinstance(t, typing.ParamSpec) for t in self.__parameters__):
|
|
60
|
+
params = typing._prepare_paramspec_params(self, params)
|
|
61
|
+
else:
|
|
62
|
+
typing._check_generic(self, params, len(self.__parameters__))
|
|
63
|
+
|
|
64
|
+
subst = dict(zip(self.__parameters__, params, strict=False))
|
|
65
|
+
new_args = []
|
|
66
|
+
for arg in self.__args__:
|
|
67
|
+
if isinstance(arg, self._typevar_types):
|
|
68
|
+
if isinstance(arg, typing.ParamSpec):
|
|
69
|
+
arg = subst[arg] # noqa: PLW2901
|
|
70
|
+
if not typing._is_param_expr(arg):
|
|
71
|
+
raise TypeError(f"Expected a list of types, an ellipsis, ParamSpec, or Concatenate. Got {arg}")
|
|
72
|
+
else:
|
|
73
|
+
arg = subst[arg] # noqa: PLW2901
|
|
74
|
+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
|
|
75
|
+
elif isinstance(arg, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
|
|
76
|
+
subparams = arg.__parameters__
|
|
77
|
+
if subparams:
|
|
78
|
+
subargs = tuple(subst[x] for x in subparams)
|
|
79
|
+
arg = arg[subargs] # noqa: PLW2901
|
|
80
|
+
# Required to flatten out the args for CallableGenericAlias
|
|
81
|
+
if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple):
|
|
82
|
+
new_args.extend(arg)
|
|
83
|
+
else:
|
|
84
|
+
new_args.append(arg)
|
|
85
|
+
return self.copy_with(tuple(new_args))
|
|
86
|
+
|
|
87
|
+
typing._GenericAlias.__getitem__ = __getitem__monkeypatch # type: ignore[attr-defined]
|