egglog 10.0.2__cp312-cp312-win_amd64.whl → 11.1.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +231 -183
- egglog/conversion.py +61 -43
- egglog/declarations.py +103 -17
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +78 -130
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +37 -3
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/exp/program_gen.py +2 -2
- egglog/pretty.py +11 -25
- egglog/runtime.py +197 -147
- egglog/version_compat.py +3 -3
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/METADATA +1 -1
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/RECORD +22 -22
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- {egglog-10.0.2.dist-info → egglog-11.1.0.dist-info}/licenses/LICENSE +0 -0
egglog/conversion.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
+
from collections.abc import Callable
|
|
4
5
|
from contextlib import contextmanager
|
|
5
6
|
from contextvars import ContextVar
|
|
6
7
|
from dataclasses import dataclass
|
|
7
|
-
from typing import TYPE_CHECKING, TypeVar, cast
|
|
8
|
+
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
|
8
9
|
|
|
9
10
|
from .declarations import *
|
|
10
11
|
from .pretty import *
|
|
@@ -13,14 +14,14 @@ from .thunk import *
|
|
|
13
14
|
from .type_constraint_solver import TypeConstraintError
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
|
-
from collections.abc import
|
|
17
|
+
from collections.abc import Generator
|
|
17
18
|
|
|
18
19
|
from .egraph import BaseExpr
|
|
19
20
|
from .type_constraint_solver import TypeConstraintSolver
|
|
20
21
|
|
|
21
|
-
__all__ = ["ConvertError", "convert", "
|
|
22
|
+
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
|
|
22
23
|
# Mapping from (source type, target type) to a (cost, function) pair, where the function takes the runtime value of the source and returns the target
|
|
23
|
-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
|
|
24
|
+
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
|
|
24
25
|
# Global declarations storing all convertible types, so we can query whether they have certain methods or not
|
|
25
26
|
_CONVERSION_DECLS = Declarations.create()
|
|
26
27
|
# Defer a list of declarations to be added to the global declarations, so that we do not trigger their processing
|
|
@@ -28,7 +29,7 @@ _CONVERSION_DECLS = Declarations.create()
|
|
|
28
29
|
_TO_PROCESS_DECLS: list[DeclerationsLike] = []
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
def
|
|
32
|
+
def retrieve_conversion_decls() -> Declarations:
    """
    Return the global conversion declarations, merging in any deferred ones first.

    Everything queued in `_TO_PROCESS_DECLS` is folded into `_CONVERSION_DECLS` and the
    queue is then emptied, so each deferred entry is merged only once.
    """
    pending = _TO_PROCESS_DECLS
    _CONVERSION_DECLS.update(*pending)
    pending.clear()
    return _CONVERSION_DECLS
|
|
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
|
|
|
49
50
|
to_type_name = process_tp(to_type)
|
|
50
51
|
if not isinstance(to_type_name, JustTypeRef):
|
|
51
52
|
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
|
|
52
|
-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
|
|
53
|
+
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)
|
|
53
54
|
|
|
54
55
|
|
|
55
|
-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
|
|
56
|
+
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
|
|
56
57
|
"""
|
|
57
58
|
Registers a converter from some type to an egglog type, if not already registered.
|
|
58
59
|
|
|
@@ -97,15 +98,15 @@ class _ComposedConverter:
|
|
|
97
98
|
We use the dataclass instead of the lambda to make it easier to debug.
|
|
98
99
|
"""
|
|
99
100
|
|
|
100
|
-
a_b: Callable
|
|
101
|
-
b_c: Callable
|
|
101
|
+
a_b: Callable[[Any], RuntimeExpr]
|
|
102
|
+
b_c: Callable[[Any], RuntimeExpr]
|
|
102
103
|
b_args: tuple[JustTypeRef, ...]
|
|
103
104
|
|
|
104
|
-
def __call__(self, x:
|
|
105
|
+
def __call__(self, x: Any) -> RuntimeExpr:
|
|
105
106
|
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
|
|
106
107
|
# when converting from A -> B
|
|
107
108
|
if self.b_args:
|
|
108
|
-
with with_type_args(self.b_args,
|
|
109
|
+
with with_type_args(self.b_args, retrieve_conversion_decls):
|
|
109
110
|
first_res = self.a_b(x)
|
|
110
111
|
else:
|
|
111
112
|
first_res = self.a_b(x)
|
|
@@ -142,36 +143,53 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
|
|
|
142
143
|
return tp
|
|
143
144
|
|
|
144
145
|
|
|
145
|
-
def
|
|
146
|
+
def min_binary_conversion(
    method_name: str, lhs: type | JustTypeRef, rhs: type | JustTypeRef
) -> tuple[Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None:
    """
    Pick the cheapest pair of converters that makes the binary method `method_name` applicable.

    Returns `(convert_lhs, convert_rhs)`, or None when no registered conversion works.

    A side that is a Python type may first be converted to any egglog type it has a registered
    conversion to; a side that is already an egglog type is only considered as-is (cost 0,
    identity). On equal total cost the first candidate found wins.
    """
    decls = retrieve_conversion_decls()

    def starting_points(tp: type | JustTypeRef) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
        # Python types fan out to every registered target; egglog types stay put.
        return _all_conversions_from(tp) if isinstance(tp, type) else [(0, tp, identity)]

    # Both candidate lists are computed up front, before any method lookups.
    lhs_options = starting_points(lhs)
    rhs_options = starting_points(rhs)

    # Each candidate is (total cost, convert lhs, convert rhs)
    candidates: list[tuple[int, Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]]] = []

    # First: fix the (possibly converted) LHS as the receiver and convert the RHS to what it expects.
    for lhs_cost, lhs_tp, lhs_fn in lhs_options:
        wanted_other = decls.check_binary_method_with_self_type(method_name, lhs_tp)
        if not wanted_other:
            continue
        found = CONVERSIONS.get((rhs, wanted_other))
        if found:
            candidates.append((lhs_cost + found[0], lhs_fn, found[1]))

    # Then: keep the (possibly converted) RHS and convert the LHS to a receiver type that accepts it.
    for rhs_cost, rhs_tp, rhs_fn in rhs_options:
        for wanted_self in decls.check_binary_method_with_other_type(method_name, rhs_tp):
            found = CONVERSIONS.get((lhs, wanted_self))
            if found:
                candidates.append((rhs_cost + found[0], found[1], rhs_fn))

    # min() returns the first minimal element, matching a strict "cheaper than" update rule.
    best = min(candidates, key=lambda candidate: candidate[0], default=None)
    if best is None:
        return None
    _, convert_lhs, convert_rhs = best
    return convert_lhs, convert_rhs
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _all_conversions_from(tp: JustTypeRef | type) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
    """
    List every registered conversion whose source type equals `tp`.

    Each entry is a `(cost, target type, conversion function)` tuple, in the iteration
    order of `CONVERSIONS`.
    """
    return [
        (cost, dst, convert)
        for (src, dst), (cost, convert) in CONVERSIONS.items()
        if src == tp
    ]
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def identity(x: Any) -> Any:
    """Return `x` unchanged (the no-op conversion)."""
    return x
|
|
176
194
|
|
|
177
195
|
|
|
@@ -197,7 +215,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
|
|
|
197
215
|
def resolve_literal(
|
|
198
216
|
tp: TypeOrVarRef,
|
|
199
217
|
arg: object,
|
|
200
|
-
decls: Callable[[], Declarations] =
|
|
218
|
+
decls: Callable[[], Declarations] = retrieve_conversion_decls,
|
|
201
219
|
tcs: TypeConstraintSolver | None = None,
|
|
202
220
|
cls_name: str | None = None,
|
|
203
221
|
) -> RuntimeExpr:
|
|
@@ -208,12 +226,12 @@ def resolve_literal(
|
|
|
208
226
|
|
|
209
227
|
If it cannot be resolved, we assume that the value passed in will resolve it.
|
|
210
228
|
"""
|
|
211
|
-
arg_type =
|
|
229
|
+
arg_type = resolve_type(arg)
|
|
212
230
|
|
|
213
231
|
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
|
|
214
232
|
try:
|
|
215
233
|
tp_just = tp.to_just()
|
|
216
|
-
except
|
|
234
|
+
except TypeVarError:
|
|
217
235
|
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
|
|
218
236
|
# args first based on the existing type constraint solver
|
|
219
237
|
if tcs:
|
|
@@ -258,7 +276,7 @@ def _debug_print_converers():
|
|
|
258
276
|
source_to_targets[source].append(target)
|
|
259
277
|
|
|
260
278
|
|
|
261
|
-
def
|
|
279
|
+
def resolve_type(x: object) -> JustTypeRef | type:
|
|
262
280
|
if isinstance(x, RuntimeExpr):
|
|
263
281
|
return x.__egg_typed_expr__.tp
|
|
264
282
|
tp = type(x)
|
egglog/declarations.py
CHANGED
|
@@ -51,6 +51,7 @@ __all__ = [
|
|
|
51
51
|
"InitRef",
|
|
52
52
|
"JustTypeRef",
|
|
53
53
|
"LetDecl",
|
|
54
|
+
"LetRefDecl",
|
|
54
55
|
"LitDecl",
|
|
55
56
|
"LitType",
|
|
56
57
|
"MethodRef",
|
|
@@ -72,16 +73,18 @@ __all__ = [
|
|
|
72
73
|
"SpecialFunctions",
|
|
73
74
|
"TypeOrVarRef",
|
|
74
75
|
"TypeRefWithVars",
|
|
76
|
+
"TypeVarError",
|
|
75
77
|
"TypedExprDecl",
|
|
78
|
+
"UnboundVarDecl",
|
|
76
79
|
"UnionDecl",
|
|
77
80
|
"UnnamedFunctionRef",
|
|
78
|
-
"
|
|
81
|
+
"collect_unbound_vars",
|
|
79
82
|
"replace_typed_expr",
|
|
80
83
|
"upcast_declerations",
|
|
81
84
|
]
|
|
82
85
|
|
|
83
86
|
|
|
84
|
-
@dataclass
|
|
87
|
+
@dataclass(match_args=False)
|
|
85
88
|
class DelayedDeclerations:
|
|
86
89
|
__egg_decls_thunk__: Callable[[], Declarations] = field(repr=False)
|
|
87
90
|
|
|
@@ -93,7 +96,7 @@ class DelayedDeclerations:
|
|
|
93
96
|
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
94
97
|
# instead raise explicitly
|
|
95
98
|
except AttributeError as err:
|
|
96
|
-
msg = f"Cannot resolve declarations for {self}"
|
|
99
|
+
msg = f"Cannot resolve declarations for {self}: {err}"
|
|
97
100
|
raise RuntimeError(msg) from err
|
|
98
101
|
|
|
99
102
|
|
|
@@ -223,14 +226,46 @@ class Declarations:
|
|
|
223
226
|
case _:
|
|
224
227
|
assert_never(ref)
|
|
225
228
|
|
|
226
|
-
def
|
|
229
|
+
def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
|
|
230
|
+
"""
|
|
231
|
+
Checks if the class has a binary method compatible with the given types.
|
|
232
|
+
"""
|
|
233
|
+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
|
|
234
|
+
if callable_decl := self._classes[self_type.name].methods.get(method_name):
|
|
235
|
+
match callable_decl.signature:
|
|
236
|
+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
|
|
237
|
+
vars, self_type
|
|
238
|
+
) and other_arg_type.matches_just(vars, other_type):
|
|
239
|
+
return True
|
|
240
|
+
return False
|
|
241
|
+
|
|
242
|
+
def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
|
|
227
243
|
"""
|
|
228
|
-
|
|
244
|
+
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
|
|
229
245
|
"""
|
|
230
|
-
|
|
231
|
-
|
|
246
|
+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
|
|
247
|
+
class_decl = self._classes.get(self_type.name)
|
|
248
|
+
if class_decl is None:
|
|
249
|
+
return None
|
|
250
|
+
if callable_decl := class_decl.methods.get(method_name):
|
|
251
|
+
match callable_decl.signature:
|
|
252
|
+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
|
|
253
|
+
return other_arg_type.to_just(vars)
|
|
232
254
|
return None
|
|
233
255
|
|
|
256
|
+
def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
|
|
257
|
+
"""
|
|
258
|
+
Returns the types which are compatible with the given binary method name and other type.
|
|
259
|
+
"""
|
|
260
|
+
for class_decl in self._classes.values():
|
|
261
|
+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
|
|
262
|
+
if callable_decl := class_decl.methods.get(method_name):
|
|
263
|
+
match callable_decl.signature:
|
|
264
|
+
case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
|
|
265
|
+
vars, other_type
|
|
266
|
+
):
|
|
267
|
+
yield self_arg_type.to_just(vars)
|
|
268
|
+
|
|
234
269
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
235
270
|
return self._classes[name]
|
|
236
271
|
|
|
@@ -254,6 +289,7 @@ class ClassDecl:
|
|
|
254
289
|
methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
|
|
255
290
|
properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
|
|
256
291
|
preserved_methods: dict[str, Callable] = field(default_factory=dict)
|
|
292
|
+
match_args: tuple[str, ...] = field(default=())
|
|
257
293
|
|
|
258
294
|
|
|
259
295
|
@dataclass(frozen=True)
|
|
@@ -298,6 +334,10 @@ class JustTypeRef:
|
|
|
298
334
|
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}
|
|
299
335
|
|
|
300
336
|
|
|
337
|
+
class TypeVarError(RuntimeError):
    """Raised when a type variable cannot be resolved to a concrete type."""
|
|
339
|
+
|
|
340
|
+
|
|
301
341
|
@dataclass(frozen=True)
|
|
302
342
|
class ClassTypeVarRef:
|
|
303
343
|
"""
|
|
@@ -307,9 +347,10 @@ class ClassTypeVarRef:
|
|
|
307
347
|
name: str
|
|
308
348
|
module: str
|
|
309
349
|
|
|
310
|
-
def to_just(self) -> JustTypeRef:
|
|
311
|
-
|
|
312
|
-
|
|
350
|
+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
    """
    Resolve this type variable to the concrete type bound to it in `vars`.

    Raises `TypeVarError` when `vars` is None or holds no binding for this variable.
    """
    if vars is not None and self in vars:
        return vars[self]
    raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings")
|
|
313
354
|
|
|
314
355
|
def __str__(self) -> str:
|
|
315
356
|
return str(self.to_type_var())
|
|
@@ -323,20 +364,39 @@ class ClassTypeVarRef:
|
|
|
323
364
|
def to_type_var(self) -> TypeVar:
|
|
324
365
|
return _RESOLVED_TYPEVARS[self]
|
|
325
366
|
|
|
367
|
+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
    """
    Unify this type variable with the concrete type `other`.

    If the variable is not yet in `vars`, it is bound to `other` (mutating `vars`) and the
    match succeeds. If it is already bound, the match succeeds only when the existing
    binding equals `other`.
    """
    try:
        bound = vars[self]
    except KeyError:
        # First sighting of this variable: record the binding.
        vars[self] = other
        return True
    return bound == other
|
|
375
|
+
|
|
326
376
|
|
|
327
377
|
@dataclass(frozen=True)
|
|
328
378
|
class TypeRefWithVars:
|
|
329
379
|
name: str
|
|
330
380
|
args: tuple[TypeOrVarRef, ...] = ()
|
|
331
381
|
|
|
332
|
-
def to_just(self) -> JustTypeRef:
|
|
333
|
-
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
|
|
382
|
+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
|
|
383
|
+
return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))
|
|
334
384
|
|
|
335
385
|
def __str__(self) -> str:
|
|
336
386
|
if self.args:
|
|
337
387
|
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
|
|
338
388
|
return self.name
|
|
339
389
|
|
|
390
|
+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
|
|
391
|
+
"""
|
|
392
|
+
Checks if this type reference matches the given JustTypeRef, including type variables.
|
|
393
|
+
"""
|
|
394
|
+
return (
|
|
395
|
+
self.name == other.name
|
|
396
|
+
and len(self.args) == len(other.args)
|
|
397
|
+
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
|
|
398
|
+
)
|
|
399
|
+
|
|
340
400
|
|
|
341
401
|
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
342
402
|
|
|
@@ -361,7 +421,7 @@ class UnnamedFunctionRef:
|
|
|
361
421
|
arg_names = []
|
|
362
422
|
for a in self.args:
|
|
363
423
|
arg_types.append(a.tp.to_var())
|
|
364
|
-
assert isinstance(a.expr,
|
|
424
|
+
assert isinstance(a.expr, UnboundVarDecl)
|
|
365
425
|
arg_names.append(a.expr.name)
|
|
366
426
|
return FunctionSignature(
|
|
367
427
|
arg_types=tuple(arg_types),
|
|
@@ -514,10 +574,14 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | Construct
|
|
|
514
574
|
|
|
515
575
|
|
|
516
576
|
@dataclass(frozen=True)
|
|
517
|
-
class
|
|
577
|
+
class UnboundVarDecl:
|
|
578
|
+
name: str
|
|
579
|
+
egg_name: str | None = None
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
@dataclass(frozen=True)
|
|
583
|
+
class LetRefDecl:
|
|
518
584
|
name: str
|
|
519
|
-
# Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
|
|
520
|
-
is_let: bool
|
|
521
585
|
|
|
522
586
|
|
|
523
587
|
@dataclass(frozen=True)
|
|
@@ -628,7 +692,7 @@ class PartialCallDecl:
|
|
|
628
692
|
call: CallDecl
|
|
629
693
|
|
|
630
694
|
|
|
631
|
-
ExprDecl: TypeAlias =
|
|
695
|
+
ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
|
|
632
696
|
|
|
633
697
|
|
|
634
698
|
@dataclass(frozen=True)
|
|
@@ -678,6 +742,28 @@ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExp
|
|
|
678
742
|
return _inner(typed_expr)
|
|
679
743
|
|
|
680
744
|
|
|
745
|
+
def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
|
|
746
|
+
"""
|
|
747
|
+
Returns the set of all unbound vars
|
|
748
|
+
"""
|
|
749
|
+
seen = set[TypedExprDecl]()
|
|
750
|
+
unbound_vars = set[TypedExprDecl]()
|
|
751
|
+
|
|
752
|
+
def visit(typed_expr: TypedExprDecl) -> None:
|
|
753
|
+
if typed_expr in seen:
|
|
754
|
+
return
|
|
755
|
+
seen.add(typed_expr)
|
|
756
|
+
match typed_expr.expr:
|
|
757
|
+
case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)):
|
|
758
|
+
for arg in args:
|
|
759
|
+
visit(arg)
|
|
760
|
+
case UnboundVarDecl(_):
|
|
761
|
+
unbound_vars.add(typed_expr)
|
|
762
|
+
|
|
763
|
+
visit(typed_expr)
|
|
764
|
+
return unbound_vars
|
|
765
|
+
|
|
766
|
+
|
|
681
767
|
##
|
|
682
768
|
# Schedules
|
|
683
769
|
##
|
egglog/deconstruct.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions to deconstruct expressions in Python.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING, TypeVar, overload
|
|
10
|
+
|
|
11
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
12
|
+
|
|
13
|
+
from .declarations import *
|
|
14
|
+
from .egraph import BaseExpr
|
|
15
|
+
from .runtime import *
|
|
16
|
+
from .thunk import *
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .builtins import Bool, PyObject, String, UnstableFn, f64, i64
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
T = TypeVar("T", bound=BaseExpr)
|
|
23
|
+
TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]])
|
|
24
|
+
|
|
25
|
+
__all__ = ["get_callable_args", "get_callable_fn", "get_let_name", "get_literal_value", "get_var_name"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@overload
|
|
29
|
+
def get_literal_value(x: String) -> str | None: ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@overload
|
|
33
|
+
def get_literal_value(x: Bool) -> bool | None: ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@overload
|
|
37
|
+
def get_literal_value(x: i64) -> int | None: ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def get_literal_value(x: f64) -> float | None: ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@overload
|
|
45
|
+
def get_literal_value(x: PyObject) -> object: ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@overload
|
|
49
|
+
def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
|
|
53
|
+
"""
|
|
54
|
+
Returns the literal value of an expression if it is a literal.
|
|
55
|
+
If it is not a literal, returns None.
|
|
56
|
+
"""
|
|
57
|
+
if not isinstance(x, RuntimeExpr):
|
|
58
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
59
|
+
match x.__egg_typed_expr__.expr:
|
|
60
|
+
case LitDecl(v):
|
|
61
|
+
return v
|
|
62
|
+
case PyObjectDecl(obj):
|
|
63
|
+
return obj
|
|
64
|
+
case PartialCallDecl(call):
|
|
65
|
+
fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
|
|
66
|
+
if not args:
|
|
67
|
+
return fn
|
|
68
|
+
return partial(fn, *args)
|
|
69
|
+
return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_let_name(x: BaseExpr) -> str | None:
    """
    Return the bound name if `x` is a reference to a `let` binding, otherwise None.

    Raises `TypeError` if `x` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    expr_decl = x.__egg_typed_expr__.expr
    return expr_decl.name if isinstance(expr_decl, LetRefDecl) else None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_var_name(x: BaseExpr) -> str | None:
    """
    Return the variable's name if `x` is a variable expression, otherwise None.

    Raises `TypeError` if `x` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    expr_decl = x.__egg_typed_expr__.expr
    return expr_decl.name if isinstance(expr_decl, UnboundVarDecl) else None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_callable_fn(x: T) -> Callable[..., T] | None:
    """
    Return the function that `x` is a call to, or None if `x` is not a call expression.

    Non-call expressions (a property, a primitive value, constants, classvars, a let value)
    give None. Those can be checked by comparing them directly with equality, or for
    primitives by calling `.eval()` to get the Python value.

    Raises `TypeError` if `x` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    expr_decl = x.__egg_typed_expr__.expr
    if not isinstance(expr_decl, CallDecl):
        return None
    fn, _args = _deconstruct_call_decl(x.__egg_decls_thunk__, expr_decl)
    return fn
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@overload
def get_callable_args(x: T, fn: None = ...) -> tuple[BaseExpr, ...]: ...


@overload
def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T]) -> tuple[Unpack[TS]] | None: ...


def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T] | None = None) -> tuple[Unpack[TS]] | None:
    """
    Return the arguments of a call expression.

    Without `fn`, the arguments of any call expression are returned. With `fn`, they are
    returned only when the expression is a call to that function (or class); otherwise the
    result is None. A non-call expression always gives None.

    Note that recursively calling this on the arguments is the safe way to walk the expression tree.

    Raises `TypeError` if `x` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    expr_decl = x.__egg_typed_expr__.expr
    if not isinstance(expr_decl, CallDecl):
        return None
    actual_fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, expr_decl)
    if fn is None:
        return args
    # Functions and classes are compared ignoring bound type parameters, so that passing a
    # binding like Vec[i64] matches both Vec[i64](...) and Vec(...) calls.
    same_function = (
        isinstance(actual_fn, RuntimeFunction)
        and isinstance(fn, RuntimeFunction)
        and actual_fn.__egg_ref__ == fn.__egg_ref__
    )
    if same_function:
        return args
    same_class = (
        isinstance(actual_fn, RuntimeClass)
        and isinstance(fn, RuntimeClass)
        and actual_fn.__egg_tp__.name == fn.__egg_tp__.name
    )
    return args if same_class else None
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _deconstruct_call_decl(
    decls_thunk: Callable[[], Declarations], call: CallDecl
) -> tuple[Callable, tuple[object, ...]]:
    """
    Turn a `CallDecl` into the runtime callable it invokes plus its arguments as runtime expressions.

    An `InitRef` callable becomes a `RuntimeClass` for the referenced class, parameterized by the
    call's bound type parameters. Anything else becomes a `RuntimeFunction`, which additionally
    carries the bound class type when the callable is a `ClassMethodRef`.
    """
    ref = call.callable
    arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(arg)) for arg in call.args)

    if isinstance(ref, InitRef):
        tp_params = tuple(tp.to_var() for tp in (call.bound_tp_params or []))
        runtime_cls = RuntimeClass(decls_thunk, TypeRefWithVars(ref.class_name, tp_params))
        return runtime_cls, arg_exprs

    if isinstance(ref, ClassMethodRef):
        egg_bound = JustTypeRef(ref.class_name, call.bound_tp_params or ())
    else:
        egg_bound = None
    return RuntimeFunction(decls_thunk, Thunk.value(ref), egg_bound), arg_exprs
|