egglog 10.0.2__cp312-cp312-win_amd64.whl → 11.1.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +231 -183
- egglog/conversion.py +61 -43
- egglog/declarations.py +103 -17
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +78 -130
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +37 -3
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/exp/program_gen.py +2 -2
- egglog/pretty.py +11 -25
- egglog/runtime.py +197 -147
- egglog/version_compat.py +3 -3
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/METADATA +1 -1
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/RECORD +22 -22
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -505,6 +505,6 @@ def _ndarray_program(
|
|
|
505
505
|
yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
|
|
506
506
|
|
|
507
507
|
# asarray
|
|
508
|
-
yield rewrite(ndarray_program(asarray(x, odtype))).to(
|
|
508
|
+
yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
|
|
509
509
|
Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
|
|
510
510
|
)
|
egglog/exp/program_gen.py
CHANGED
|
@@ -79,7 +79,7 @@ class Program(Expr):
|
|
|
79
79
|
Triggers compilation of the program.
|
|
80
80
|
"""
|
|
81
81
|
|
|
82
|
-
@method(merge=lambda old, _new: old) # type: ignore[
|
|
82
|
+
@method(merge=lambda old, _new: old) # type: ignore[prop-decorator]
|
|
83
83
|
@property
|
|
84
84
|
def parent(self) -> Program:
|
|
85
85
|
"""
|
|
@@ -108,7 +108,7 @@ class EvalProgram(Expr):
|
|
|
108
108
|
"""
|
|
109
109
|
|
|
110
110
|
# Only allow it to be set once, b/c hash of functions not stable
|
|
111
|
-
@method(merge=lambda old, _new: old) # type: ignore[
|
|
111
|
+
@method(merge=lambda old, _new: old) # type: ignore[prop-decorator]
|
|
112
112
|
@property
|
|
113
113
|
def as_py_object(self) -> PyObject:
|
|
114
114
|
"""
|
egglog/pretty.py
CHANGED
|
@@ -107,7 +107,7 @@ def pretty_callable_ref(
|
|
|
107
107
|
"""
|
|
108
108
|
# Pass in three dummy args, which are the max used for any operation that
|
|
109
109
|
# is not a generic function call
|
|
110
|
-
args: list[ExprDecl] = [
|
|
110
|
+
args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * 3
|
|
111
111
|
if first_arg:
|
|
112
112
|
args.insert(0, first_arg)
|
|
113
113
|
context = PrettyContext(decls, defaultdict(lambda: 0))
|
|
@@ -166,7 +166,7 @@ class TraverseContext:
|
|
|
166
166
|
self(d.expr)
|
|
167
167
|
case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
|
|
168
168
|
self(d)
|
|
169
|
-
case PanicDecl(_) |
|
|
169
|
+
case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
|
|
170
170
|
pass
|
|
171
171
|
case SequenceDecl(decls) | RulesetDecl(decls):
|
|
172
172
|
for de in decls:
|
|
@@ -233,6 +233,10 @@ class PrettyContext:
|
|
|
233
233
|
return expr
|
|
234
234
|
|
|
235
235
|
def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
|
|
236
|
+
"""
|
|
237
|
+
Returns a tuple of a string value of the declaration and the "type" to use when creating a memoized cached version
|
|
238
|
+
for de-duplication.
|
|
239
|
+
"""
|
|
236
240
|
match decl:
|
|
237
241
|
case LitDecl(value):
|
|
238
242
|
match value:
|
|
@@ -247,7 +251,7 @@ class PrettyContext:
|
|
|
247
251
|
case str(s):
|
|
248
252
|
return repr(s) if unwrap_lit else f"String({s!r})", "String"
|
|
249
253
|
assert_never(value)
|
|
250
|
-
case
|
|
254
|
+
case UnboundVarDecl(name) | LetRefDecl(name):
|
|
251
255
|
return name, name
|
|
252
256
|
case CallDecl(_, _, _):
|
|
253
257
|
return self._call(decl, parens)
|
|
@@ -357,7 +361,7 @@ class PrettyContext:
|
|
|
357
361
|
has_multiple_parents = self.parents[first_arg] > 1
|
|
358
362
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
359
363
|
# Set the first arg to be the name of the mutated arg and return the name
|
|
360
|
-
args[0] =
|
|
364
|
+
args[0] = LetRefDecl(expr_name)
|
|
361
365
|
else:
|
|
362
366
|
expr_name = None
|
|
363
367
|
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
@@ -390,6 +394,7 @@ class PrettyContext:
|
|
|
390
394
|
return f"{tp_ref}.{method_name}", args
|
|
391
395
|
case MethodRef(_class_name, method_name):
|
|
392
396
|
slf, *args = args
|
|
397
|
+
non_str_slf = slf
|
|
393
398
|
slf = self(slf, parens=True)
|
|
394
399
|
match method_name:
|
|
395
400
|
case _ if method_name in UNARY_METHODS:
|
|
@@ -406,6 +411,8 @@ class PrettyContext:
|
|
|
406
411
|
return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
|
|
407
412
|
case "__setitem__":
|
|
408
413
|
return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
|
|
414
|
+
case "__round__":
|
|
415
|
+
return "round", [non_str_slf, *args]
|
|
409
416
|
case _:
|
|
410
417
|
return f"{slf}.{method_name}", args
|
|
411
418
|
case ConstantRef(name):
|
|
@@ -487,24 +494,3 @@ class PrettyContext:
|
|
|
487
494
|
if arg_names:
|
|
488
495
|
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
|
|
489
496
|
return f"{prefix}: {self(res.expr)}"
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
def _plot_line_length(expr: object): # pragma: no cover
|
|
493
|
-
"""
|
|
494
|
-
Plots the number of line lengths based on different max lengths
|
|
495
|
-
"""
|
|
496
|
-
global MAX_LINE_LENGTH, LINE_DIFFERENCE
|
|
497
|
-
import altair as alt
|
|
498
|
-
import pandas as pd
|
|
499
|
-
|
|
500
|
-
sizes = []
|
|
501
|
-
for line_length in range(40, 180, 10):
|
|
502
|
-
MAX_LINE_LENGTH = line_length
|
|
503
|
-
for diff in range(0, 40, 5):
|
|
504
|
-
LINE_DIFFERENCE = diff
|
|
505
|
-
new_l = len(str(expr).split())
|
|
506
|
-
sizes.append((line_length, diff, new_l))
|
|
507
|
-
|
|
508
|
-
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
|
|
509
|
-
|
|
510
|
-
return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
|
egglog/runtime.py
CHANGED
|
@@ -11,12 +11,14 @@ so they are not mangled by Python and can be accessed by the user.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
+
import itertools
|
|
14
15
|
import operator
|
|
16
|
+
import types
|
|
15
17
|
from collections.abc import Callable
|
|
16
|
-
from dataclasses import dataclass, replace
|
|
18
|
+
from dataclasses import InitVar, dataclass, replace
|
|
17
19
|
from inspect import Parameter, Signature
|
|
18
20
|
from itertools import zip_longest
|
|
19
|
-
from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin
|
|
21
|
+
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
|
|
20
22
|
|
|
21
23
|
from .declarations import *
|
|
22
24
|
from .pretty import *
|
|
@@ -27,15 +29,14 @@ from .version_compat import *
|
|
|
27
29
|
if TYPE_CHECKING:
|
|
28
30
|
from collections.abc import Iterable
|
|
29
31
|
|
|
30
|
-
from .egraph import Fact
|
|
31
|
-
|
|
32
32
|
|
|
33
33
|
__all__ = [
|
|
34
34
|
"LIT_CLASS_NAMES",
|
|
35
|
-
"
|
|
35
|
+
"NUMERIC_BINARY_METHODS",
|
|
36
36
|
"RuntimeClass",
|
|
37
37
|
"RuntimeExpr",
|
|
38
38
|
"RuntimeFunction",
|
|
39
|
+
"define_expr_method",
|
|
39
40
|
"resolve_callable",
|
|
40
41
|
"resolve_type_annotation",
|
|
41
42
|
"resolve_type_annotation_mutate",
|
|
@@ -46,24 +47,34 @@ UNIT_CLASS_NAME = "Unit"
|
|
|
46
47
|
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
|
|
47
48
|
LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
|
|
48
49
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
"
|
|
60
|
-
"
|
|
61
|
-
"
|
|
62
|
-
"
|
|
50
|
+
# All methods which should return NotImplemented if they fail to resolve and are reflected as well
|
|
51
|
+
# From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
|
|
52
|
+
|
|
53
|
+
NUMERIC_BINARY_METHODS = {
|
|
54
|
+
"__add__",
|
|
55
|
+
"__sub__",
|
|
56
|
+
"__mul__",
|
|
57
|
+
"__matmul__",
|
|
58
|
+
"__truediv__",
|
|
59
|
+
"__floordiv__",
|
|
60
|
+
"__mod__",
|
|
61
|
+
"__divmod__",
|
|
62
|
+
"__pow__",
|
|
63
|
+
"__lshift__",
|
|
64
|
+
"__rshift__",
|
|
65
|
+
"__and__",
|
|
66
|
+
"__xor__",
|
|
67
|
+
"__or__",
|
|
68
|
+
"__lt__",
|
|
69
|
+
"__le__",
|
|
70
|
+
"__gt__",
|
|
71
|
+
"__ge__",
|
|
63
72
|
}
|
|
64
73
|
|
|
65
|
-
|
|
66
|
-
|
|
74
|
+
|
|
75
|
+
# Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods.
|
|
76
|
+
|
|
77
|
+
TYPE_DEFINED_METHODS = {
|
|
67
78
|
"__bool__",
|
|
68
79
|
"__len__",
|
|
69
80
|
"__complex__",
|
|
@@ -71,9 +82,15 @@ PRESERVED_METHODS = [
|
|
|
71
82
|
"__float__",
|
|
72
83
|
"__iter__",
|
|
73
84
|
"__index__",
|
|
74
|
-
"
|
|
75
|
-
"
|
|
76
|
-
|
|
85
|
+
"__call__",
|
|
86
|
+
"__getitem__",
|
|
87
|
+
"__setitem__",
|
|
88
|
+
"__delitem__",
|
|
89
|
+
"__pos__",
|
|
90
|
+
"__neg__",
|
|
91
|
+
"__invert__",
|
|
92
|
+
"__round__",
|
|
93
|
+
}
|
|
77
94
|
|
|
78
95
|
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
|
|
79
96
|
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
|
|
@@ -135,11 +152,48 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp:
|
|
|
135
152
|
##
|
|
136
153
|
|
|
137
154
|
|
|
138
|
-
|
|
139
|
-
|
|
155
|
+
class BaseClassFactoryMeta(type):
|
|
156
|
+
"""
|
|
157
|
+
Base metaclass for all runtime classes created by ClassFactory
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
def __instancecheck__(cls, instance: object) -> bool:
|
|
161
|
+
assert isinstance(cls, RuntimeClass)
|
|
162
|
+
return isinstance(instance, RuntimeExpr) and cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ClassFactory(type):
|
|
166
|
+
"""
|
|
167
|
+
A metaclass for types which should create `type` objects when instantiated.
|
|
168
|
+
|
|
169
|
+
That's so that they work with `isinstance` and can be placed in `match ClassName()`.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
def __call__(cls, *args, **kwargs) -> type:
|
|
173
|
+
# If we have params, don't inherit from `type` because we don't need to match against this and also
|
|
174
|
+
# this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`.
|
|
175
|
+
if kwargs.pop("_egg_has_params", False):
|
|
176
|
+
return super().__call__(*args, **kwargs)
|
|
177
|
+
namespace: dict[str, Any] = {}
|
|
178
|
+
for m in reversed(cls.__mro__):
|
|
179
|
+
namespace.update(m.__dict__)
|
|
180
|
+
init = namespace.pop("__init__")
|
|
181
|
+
meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, lambda ns: ns.update(**namespace))
|
|
182
|
+
tp = types.new_class("RuntimeClass", (), {"metaclass": meta})
|
|
183
|
+
init(tp, *args, **kwargs)
|
|
184
|
+
return tp
|
|
185
|
+
|
|
186
|
+
def __instancecheck__(cls, instance: object) -> bool:
|
|
187
|
+
return isinstance(instance, BaseClassFactoryMeta)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass(match_args=False)
|
|
191
|
+
class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
|
|
140
192
|
__egg_tp__: TypeRefWithVars
|
|
193
|
+
# True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly.
|
|
194
|
+
_egg_has_params: InitVar[bool] = False
|
|
141
195
|
|
|
142
|
-
def __post_init__(self) -> None:
|
|
196
|
+
def __post_init__(self, _egg_has_params: bool) -> None:
|
|
143
197
|
global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
|
|
144
198
|
if (name := self.__egg_tp__.name) == "PyObject":
|
|
145
199
|
_PY_OBJECT_CLASS = self
|
|
@@ -229,7 +283,7 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
229
283
|
else:
|
|
230
284
|
final_args = new_args
|
|
231
285
|
tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
|
|
232
|
-
return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp)
|
|
286
|
+
return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True)
|
|
233
287
|
|
|
234
288
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
|
|
235
289
|
if name == "__origin__" and self.__egg_tp__.args:
|
|
@@ -288,6 +342,14 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
288
342
|
def __hash__(self) -> int:
|
|
289
343
|
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
|
|
290
344
|
|
|
345
|
+
def __eq__(self, other: object) -> bool:
|
|
346
|
+
"""
|
|
347
|
+
Support equality for runtime comparison of egglog classes.
|
|
348
|
+
"""
|
|
349
|
+
if not isinstance(other, RuntimeClass):
|
|
350
|
+
return NotImplemented
|
|
351
|
+
return self.__egg_tp__ == other.__egg_tp__
|
|
352
|
+
|
|
291
353
|
# Support unioning like types
|
|
292
354
|
def __or__(self, value: type) -> object:
|
|
293
355
|
return Union[self, value] # noqa: UP007
|
|
@@ -299,6 +361,10 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
299
361
|
"""
|
|
300
362
|
return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)
|
|
301
363
|
|
|
364
|
+
@property
|
|
365
|
+
def __match_args__(self) -> tuple[str, ...]:
|
|
366
|
+
return self.__egg_decls__._classes[self.__egg_tp__.name].match_args
|
|
367
|
+
|
|
302
368
|
|
|
303
369
|
@dataclass
|
|
304
370
|
class RuntimeFunction(DelayedDeclerations):
|
|
@@ -306,12 +372,23 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
306
372
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
307
373
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
308
374
|
|
|
375
|
+
def __eq__(self, other: object) -> bool:
|
|
376
|
+
"""
|
|
377
|
+
Support equality for runtime comparison of egglog functions.
|
|
378
|
+
"""
|
|
379
|
+
if not isinstance(other, RuntimeFunction):
|
|
380
|
+
return NotImplemented
|
|
381
|
+
return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__)
|
|
382
|
+
|
|
383
|
+
def __hash__(self) -> int:
|
|
384
|
+
return hash((self.__egg_ref__, self.__egg_bound__))
|
|
385
|
+
|
|
309
386
|
@property
|
|
310
387
|
def __egg_ref__(self) -> CallableRef:
|
|
311
388
|
return self.__egg_ref_thunk__()
|
|
312
389
|
|
|
313
390
|
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
314
|
-
from .conversion import resolve_literal
|
|
391
|
+
from .conversion import resolve_literal # noqa: PLC0415
|
|
315
392
|
|
|
316
393
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
317
394
|
args = (self.__egg_bound__, *args)
|
|
@@ -346,7 +423,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
346
423
|
try:
|
|
347
424
|
bound = py_signature.bind(*args, **kwargs)
|
|
348
425
|
except TypeError as err:
|
|
349
|
-
raise TypeError(f"Failed to
|
|
426
|
+
raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err
|
|
350
427
|
del kwargs
|
|
351
428
|
bound.apply_defaults()
|
|
352
429
|
assert not bound.kwargs
|
|
@@ -415,9 +492,7 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
415
492
|
Parameter(
|
|
416
493
|
n,
|
|
417
494
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
418
|
-
default=RuntimeExpr.__from_values__(
|
|
419
|
-
decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
|
|
420
|
-
)
|
|
495
|
+
default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d or LetRefDecl(n)))
|
|
421
496
|
if d is not None or optional_args
|
|
422
497
|
else Parameter.empty,
|
|
423
498
|
)
|
|
@@ -428,32 +503,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
428
503
|
return Signature(parameters)
|
|
429
504
|
|
|
430
505
|
|
|
431
|
-
# All methods which should return NotImplemented if they fail to resolve
|
|
432
|
-
# From https://docs.python.org/3/reference/datamodel.html
|
|
433
|
-
PARTIAL_METHODS = {
|
|
434
|
-
"__lt__",
|
|
435
|
-
"__le__",
|
|
436
|
-
"__eq__",
|
|
437
|
-
"__ne__",
|
|
438
|
-
"__gt__",
|
|
439
|
-
"__ge__",
|
|
440
|
-
"__add__",
|
|
441
|
-
"__sub__",
|
|
442
|
-
"__mul__",
|
|
443
|
-
"__matmul__",
|
|
444
|
-
"__truediv__",
|
|
445
|
-
"__floordiv__",
|
|
446
|
-
"__mod__",
|
|
447
|
-
"__divmod__",
|
|
448
|
-
"__pow__",
|
|
449
|
-
"__lshift__",
|
|
450
|
-
"__rshift__",
|
|
451
|
-
"__and__",
|
|
452
|
-
"__xor__",
|
|
453
|
-
"__or__",
|
|
454
|
-
}
|
|
455
|
-
|
|
456
|
-
|
|
457
506
|
@dataclass
|
|
458
507
|
class RuntimeExpr(DelayedDeclerations):
|
|
459
508
|
__egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
|
|
@@ -470,17 +519,14 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
470
519
|
return self.__egg_typed_expr_thunk__()
|
|
471
520
|
|
|
472
521
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
if name in class_decl.properties:
|
|
482
|
-
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
|
|
483
|
-
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
522
|
+
if (method := _get_expr_method(self, name)) is not None:
|
|
523
|
+
return method
|
|
524
|
+
if name in self.__egg_class_decl__.properties:
|
|
525
|
+
fn = RuntimeFunction(
|
|
526
|
+
self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self
|
|
527
|
+
)
|
|
528
|
+
return fn()
|
|
529
|
+
raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None
|
|
484
530
|
|
|
485
531
|
def __repr__(self) -> str:
|
|
486
532
|
"""
|
|
@@ -495,7 +541,7 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
495
541
|
return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
|
|
496
542
|
|
|
497
543
|
def _ipython_display_(self) -> None:
|
|
498
|
-
from IPython.display import Code, display
|
|
544
|
+
from IPython.display import Code, display # noqa: PLC0415
|
|
499
545
|
|
|
500
546
|
display(Code(str(self), language="python"))
|
|
501
547
|
|
|
@@ -511,13 +557,6 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
511
557
|
def __egg_class_decl__(self) -> ClassDecl:
|
|
512
558
|
return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
|
|
513
559
|
|
|
514
|
-
# These both will be overriden below in the special methods section, but add these here for type hinting purposes
|
|
515
|
-
def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body]
|
|
516
|
-
...
|
|
517
|
-
|
|
518
|
-
def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body]
|
|
519
|
-
...
|
|
520
|
-
|
|
521
560
|
# Implement these so that copy() works on this object
|
|
522
561
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause infinite recursion
|
|
523
562
|
|
|
@@ -529,91 +568,102 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
529
568
|
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
|
|
530
569
|
|
|
531
570
|
def __hash__(self) -> int:
|
|
571
|
+
if (method := _get_expr_method(self, "__hash__")) is not None:
|
|
572
|
+
return cast("int", cast("Any", method()))
|
|
532
573
|
return hash(self.__egg_typed_expr__)
|
|
533
574
|
|
|
575
|
+
# Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
|
|
576
|
+
# preserved method or defined on the class
|
|
577
|
+
def __eq__(self, other: object) -> object: # type: ignore[override]
|
|
578
|
+
if (method := _get_expr_method(self, "__eq__")) is not None:
|
|
579
|
+
return method(other)
|
|
534
580
|
|
|
535
|
-
#
|
|
536
|
-
|
|
581
|
+
# TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
|
|
582
|
+
# expr gets a chance to resolve __eq__ which could be a preserved method.
|
|
583
|
+
from .egraph import BaseExpr, eq # noqa: PLC0415
|
|
537
584
|
|
|
538
|
-
|
|
539
|
-
self: RuntimeExpr,
|
|
540
|
-
*args: object,
|
|
541
|
-
__name: str = name,
|
|
542
|
-
**kwargs: object,
|
|
543
|
-
) -> RuntimeExpr | Fact | None:
|
|
544
|
-
from .conversion import ConvertError
|
|
585
|
+
return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
|
|
545
586
|
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
try:
|
|
550
|
-
method = class_decl.preserved_methods[__name]
|
|
551
|
-
except KeyError:
|
|
552
|
-
pass
|
|
553
|
-
else:
|
|
554
|
-
return method(self, *args, **kwargs)
|
|
555
|
-
# If this is a "partial" method meaning that it can return NotImplemented,
|
|
556
|
-
# we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
|
|
557
|
-
# using the arg type of the self arg.
|
|
558
|
-
# This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa.
|
|
559
|
-
if __name in PARTIAL_METHODS:
|
|
560
|
-
try:
|
|
561
|
-
return call_method_min_conversion(self, args[0], __name)
|
|
562
|
-
except ConvertError:
|
|
563
|
-
# Defer raising not imeplemented in case the dunder method is not symmetrical, then
|
|
564
|
-
# we use the standard process
|
|
565
|
-
pass
|
|
566
|
-
if __name in class_decl.methods:
|
|
567
|
-
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
|
|
568
|
-
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
569
|
-
# Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly.
|
|
570
|
-
if __name == "__eq__":
|
|
571
|
-
from .egraph import BaseExpr, eq
|
|
572
|
-
|
|
573
|
-
return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))
|
|
574
|
-
if __name == "__ne__":
|
|
575
|
-
from .egraph import BaseExpr, ne
|
|
576
|
-
|
|
577
|
-
return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])))
|
|
578
|
-
|
|
579
|
-
if __name in PARTIAL_METHODS:
|
|
580
|
-
return NotImplemented
|
|
581
|
-
raise TypeError(f"{class_name!r} object does not support {__name}")
|
|
587
|
+
def __ne__(self, other: object) -> object: # type: ignore[override]
|
|
588
|
+
if (method := _get_expr_method(self, "__ne__")) is not None:
|
|
589
|
+
return method(other)
|
|
582
590
|
|
|
583
|
-
|
|
591
|
+
from .egraph import BaseExpr, ne # noqa: PLC0415
|
|
584
592
|
|
|
585
|
-
|
|
586
|
-
for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
|
|
593
|
+
return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other))
|
|
587
594
|
|
|
588
|
-
def
|
|
589
|
-
|
|
590
|
-
|
|
595
|
+
def __call__(
|
|
596
|
+
self, *args: object, **kwargs: object
|
|
597
|
+
) -> object: # define it here only for type checking, it will be overriden below
|
|
598
|
+
...
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
602
|
+
if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods):
|
|
603
|
+
return preserved_methods[name].__get__(expr)
|
|
591
604
|
|
|
592
|
-
|
|
605
|
+
if name in expr.__egg_class_decl__.methods:
|
|
606
|
+
return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr)
|
|
607
|
+
return None
|
|
593
608
|
|
|
594
609
|
|
|
595
|
-
def
|
|
596
|
-
|
|
610
|
+
def define_expr_method(name: str) -> None:
|
|
611
|
+
"""
|
|
612
|
+
Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method.
|
|
597
613
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
slf = resolve_literal(min_tp, slf)
|
|
602
|
-
other = resolve_literal(min_tp, other)
|
|
603
|
-
method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
|
|
604
|
-
return method(other)
|
|
614
|
+
Call this if you need a method to be defined on the type itself where overriding with `__getattr__` does not suffice,
|
|
615
|
+
like for NumPy's `__array_ufunc__`.
|
|
616
|
+
"""
|
|
605
617
|
|
|
618
|
+
def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs):
|
|
619
|
+
fn = _get_expr_method(self, __name)
|
|
620
|
+
if fn is None:
|
|
621
|
+
raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}")
|
|
622
|
+
return fn(*args, **kwargs)
|
|
606
623
|
|
|
607
|
-
|
|
624
|
+
setattr(RuntimeExpr, name, _defined_method)
|
|
608
625
|
|
|
609
|
-
def _preserved_method(self: RuntimeExpr, __name: str = name):
|
|
610
|
-
try:
|
|
611
|
-
method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name]
|
|
612
|
-
except KeyError as e:
|
|
613
|
-
raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e
|
|
614
|
-
return method(self)
|
|
615
626
|
|
|
616
|
-
|
|
627
|
+
for name in TYPE_DEFINED_METHODS:
|
|
628
|
+
define_expr_method(name)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
|
|
632
|
+
|
|
633
|
+
def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
|
|
634
|
+
"""
|
|
635
|
+
Implements numeric binary operations.
|
|
636
|
+
|
|
637
|
+
Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
|
|
638
|
+
the LHS or the RHS as exactly the right type and then upcasting the other to that type.
|
|
639
|
+
"""
|
|
640
|
+
# 1. switch if reversed method
|
|
641
|
+
if r_method:
|
|
642
|
+
self, other = other, self
|
|
643
|
+
# If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion
|
|
644
|
+
if not (
|
|
645
|
+
isinstance(self, RuntimeExpr)
|
|
646
|
+
and isinstance(other, RuntimeExpr)
|
|
647
|
+
and (
|
|
648
|
+
self.__egg_decls__.check_binary_method_with_types(
|
|
649
|
+
name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp
|
|
650
|
+
)
|
|
651
|
+
)
|
|
652
|
+
):
|
|
653
|
+
from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
|
|
654
|
+
|
|
655
|
+
best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
|
|
656
|
+
|
|
657
|
+
if not best_method:
|
|
658
|
+
raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
|
|
659
|
+
self, other = best_method[0](self), best_method[1](other)
|
|
660
|
+
|
|
661
|
+
method_ref = MethodRef(self.__egg_class_name__, name)
|
|
662
|
+
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
|
|
663
|
+
return fn(other)
|
|
664
|
+
|
|
665
|
+
method_name = f"__r{name[2:]}" if r_method else name
|
|
666
|
+
setattr(RuntimeExpr, method_name, _numeric_binary_method)
|
|
617
667
|
|
|
618
668
|
|
|
619
669
|
def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
egglog/version_compat.py
CHANGED
|
@@ -30,7 +30,7 @@ if BEFORE_3_11:
|
|
|
30
30
|
|
|
31
31
|
_collect_type_vars((T, List[S, T])) == (T, S)
|
|
32
32
|
"""
|
|
33
|
-
from .runtime import RuntimeClass
|
|
33
|
+
from .runtime import RuntimeClass # noqa: PLC0415
|
|
34
34
|
|
|
35
35
|
if typevar_types is None:
|
|
36
36
|
typevar_types = typing.TypeVar
|
|
@@ -39,7 +39,7 @@ if BEFORE_3_11:
|
|
|
39
39
|
if isinstance(t, typevar_types) and t not in tvars:
|
|
40
40
|
tvars.append(t)
|
|
41
41
|
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
|
|
42
|
-
if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
|
|
42
|
+
if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
|
|
43
43
|
tvars.extend([t for t in t.__parameters__ if t not in tvars])
|
|
44
44
|
return tuple(tvars)
|
|
45
45
|
|
|
@@ -48,7 +48,7 @@ if BEFORE_3_11:
|
|
|
48
48
|
@typing.no_type_check
|
|
49
49
|
@typing._tp_cache
|
|
50
50
|
def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
|
|
51
|
-
from .runtime import RuntimeClass
|
|
51
|
+
from .runtime import RuntimeClass # noqa: PLC0415
|
|
52
52
|
|
|
53
53
|
if self.__origin__ in (typing.Generic, typing.Protocol):
|
|
54
54
|
# Can't subscript Generic[...] or Protocol[...].
|