egglog 7.0.0__cp312-none-win_amd64.whl → 7.1.0__cp312-none-win_amd64.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +7 -0
- egglog/builtins.py +41 -1
- egglog/conversion.py +22 -17
- egglog/declarations.py +57 -31
- egglog/egraph.py +93 -18
- egglog/egraph_state.py +76 -37
- egglog/exp/array_api.py +8 -8
- egglog/pretty.py +56 -10
- egglog/runtime.py +112 -30
- egglog/thunk.py +1 -2
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/METADATA +20 -20
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/RECORD +15 -15
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -494,6 +494,12 @@ class Relation:
|
|
|
494
494
|
class PrintOverallStatistics:
|
|
495
495
|
def __init__(self) -> None: ...
|
|
496
496
|
|
|
497
|
+
@final
|
|
498
|
+
class UnstableCombinedRuleset:
|
|
499
|
+
name: str
|
|
500
|
+
rulesets: list[str]
|
|
501
|
+
def __init__(self, name: str, rulesets: list[str]) -> None: ...
|
|
502
|
+
|
|
497
503
|
_Command: TypeAlias = (
|
|
498
504
|
SetOption
|
|
499
505
|
| Datatype
|
|
@@ -521,6 +527,7 @@ _Command: TypeAlias = (
|
|
|
521
527
|
| CheckProof
|
|
522
528
|
| Relation
|
|
523
529
|
| PrintOverallStatistics
|
|
530
|
+
| UnstableCombinedRuleset
|
|
524
531
|
)
|
|
525
532
|
|
|
526
533
|
def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
|
egglog/builtins.py
CHANGED
|
@@ -5,10 +5,14 @@ Builtin sorts and function to egg.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
|
|
10
|
+
|
|
11
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
9
12
|
|
|
10
13
|
from .conversion import converter
|
|
11
14
|
from .egraph import Expr, Unit, function, method
|
|
15
|
+
from .runtime import RuntimeFunction
|
|
12
16
|
|
|
13
17
|
if TYPE_CHECKING:
|
|
14
18
|
from collections.abc import Callable
|
|
@@ -31,6 +35,7 @@ __all__ = [
|
|
|
31
35
|
"py_eval",
|
|
32
36
|
"py_exec",
|
|
33
37
|
"py_eval_fn",
|
|
38
|
+
"UnstableFn",
|
|
34
39
|
]
|
|
35
40
|
|
|
36
41
|
|
|
@@ -461,3 +466,38 @@ def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object
|
|
|
461
466
|
"""
|
|
462
467
|
Copies the locals, execs the Python code, and returns the locals with any updates.
|
|
463
468
|
"""
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
TS = TypeVarTuple("TS")
|
|
472
|
+
|
|
473
|
+
T1 = TypeVar("T1")
|
|
474
|
+
T2 = TypeVar("T2")
|
|
475
|
+
T3 = TypeVar("T3")
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
|
|
479
|
+
@overload
|
|
480
|
+
def __init__(self, f: Callable[[Unpack[TS]], T]) -> None: ...
|
|
481
|
+
|
|
482
|
+
@overload
|
|
483
|
+
def __init__(self, f: Callable[[T1, Unpack[TS]], T], _a: T1, /) -> None: ...
|
|
484
|
+
|
|
485
|
+
@overload
|
|
486
|
+
def __init__(self, f: Callable[[T1, T2, Unpack[TS]], T], _a: T1, _b: T2, /) -> None: ...
|
|
487
|
+
|
|
488
|
+
# Removing due to bug in MyPy
|
|
489
|
+
# https://github.com/python/mypy/issues/17212
|
|
490
|
+
# @overload
|
|
491
|
+
# def __init__(self, f: Callable[[T1, T2, T3, Unpack[TS]], T], _a: T1, _b: T2, _c: T3, /) -> None: ...
|
|
492
|
+
|
|
493
|
+
# etc, for partial application
|
|
494
|
+
|
|
495
|
+
@method(egg_fn="unstable-fn")
|
|
496
|
+
def __init__(self, f, *partial) -> None: ...
|
|
497
|
+
|
|
498
|
+
@method(egg_fn="unstable-app")
|
|
499
|
+
def __call__(self, *args: Unpack[TS]) -> T: ...
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
converter(RuntimeFunction, UnstableFn, UnstableFn)
|
|
503
|
+
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
|
egglog/conversion.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import TYPE_CHECKING, TypeVar, cast
|
|
4
|
+
from typing import TYPE_CHECKING, NewType, TypeVar, cast
|
|
5
5
|
|
|
6
6
|
from .declarations import *
|
|
7
7
|
from .pretty import *
|
|
@@ -16,7 +16,8 @@ if TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
|
|
18
18
|
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
|
|
19
|
-
|
|
19
|
+
TypeName = NewType("TypeName", str)
|
|
20
|
+
CONVERSIONS: dict[tuple[type | TypeName, TypeName], tuple[int, Callable]] = {}
|
|
20
21
|
# Global declerations to store all convertable types so we can query if they have certain methods or not
|
|
21
22
|
# Defer it as a thunk so we can register conversions without triggering type signature loading
|
|
22
23
|
CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations())
|
|
@@ -34,12 +35,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
|
|
|
34
35
|
Register a converter from some type to an egglog type.
|
|
35
36
|
"""
|
|
36
37
|
to_type_name = process_tp(to_type)
|
|
37
|
-
if not isinstance(to_type_name,
|
|
38
|
+
if not isinstance(to_type_name, str):
|
|
38
39
|
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
|
|
39
40
|
_register_converter(process_tp(from_type), to_type_name, fn, cost)
|
|
40
41
|
|
|
41
42
|
|
|
42
|
-
def _register_converter(a: type |
|
|
43
|
+
def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: int) -> None:
|
|
43
44
|
"""
|
|
44
45
|
Registers a converter from some type to an egglog type, if not already registered.
|
|
45
46
|
|
|
@@ -94,14 +95,17 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
|
94
95
|
return resolve_literal(tp.to_var(), source)
|
|
95
96
|
|
|
96
97
|
|
|
97
|
-
def process_tp(tp: type | RuntimeClass) ->
|
|
98
|
+
def process_tp(tp: type | RuntimeClass) -> TypeName | type:
|
|
98
99
|
"""
|
|
99
100
|
Process a type before converting it, to add it to the global declerations and resolve to a ref.
|
|
100
101
|
"""
|
|
101
102
|
global CONVERSIONS_DECLS
|
|
102
103
|
if isinstance(tp, RuntimeClass):
|
|
103
104
|
CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp)
|
|
104
|
-
|
|
105
|
+
egg_tp = tp.__egg_tp__
|
|
106
|
+
if egg_tp.args:
|
|
107
|
+
raise TypeError(f"Cannot register a converter for a generic type, got {tp}")
|
|
108
|
+
return TypeName(egg_tp.name)
|
|
105
109
|
return tp
|
|
106
110
|
|
|
107
111
|
|
|
@@ -109,7 +113,7 @@ def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declara
|
|
|
109
113
|
return Declarations.create(d(), x)
|
|
110
114
|
|
|
111
115
|
|
|
112
|
-
def min_convertable_tp(a: object, b: object, name: str) ->
|
|
116
|
+
def min_convertable_tp(a: object, b: object, name: str) -> TypeName:
|
|
113
117
|
"""
|
|
114
118
|
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
|
|
115
119
|
"""
|
|
@@ -117,14 +121,14 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
|
117
121
|
a_tp = _get_tp(a)
|
|
118
122
|
b_tp = _get_tp(b)
|
|
119
123
|
a_converts_to = {
|
|
120
|
-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to
|
|
124
|
+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to, name)
|
|
121
125
|
}
|
|
122
126
|
b_converts_to = {
|
|
123
|
-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to
|
|
127
|
+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to, name)
|
|
124
128
|
}
|
|
125
|
-
if isinstance(a_tp,
|
|
129
|
+
if isinstance(a_tp, str):
|
|
126
130
|
a_converts_to[a_tp] = 0
|
|
127
|
-
if isinstance(b_tp,
|
|
131
|
+
if isinstance(b_tp, str):
|
|
128
132
|
b_converts_to[b_tp] = 0
|
|
129
133
|
common = set(a_converts_to) & set(b_converts_to)
|
|
130
134
|
if not common:
|
|
@@ -143,28 +147,29 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
143
147
|
try:
|
|
144
148
|
tp_just = tp.to_just()
|
|
145
149
|
except NotImplementedError:
|
|
146
|
-
# If this is a var, it has to be a runtime
|
|
150
|
+
# If this is a var, it has to be a runtime expession
|
|
147
151
|
assert isinstance(arg, RuntimeExpr)
|
|
148
152
|
return arg
|
|
149
|
-
|
|
153
|
+
tp_name = TypeName(tp_just.name)
|
|
154
|
+
if arg_type == tp_name:
|
|
150
155
|
# If the type is an egg type, it has to be a runtime expr
|
|
151
156
|
assert isinstance(arg, RuntimeExpr)
|
|
152
157
|
return arg
|
|
153
158
|
# Try all parent types as well, if we are converting from a Python type
|
|
154
159
|
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
155
160
|
try:
|
|
156
|
-
fn = CONVERSIONS[(cast(
|
|
161
|
+
fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
|
|
157
162
|
except KeyError:
|
|
158
163
|
continue
|
|
159
164
|
break
|
|
160
165
|
else:
|
|
161
|
-
raise ConvertError(f"Cannot convert {arg_type} to {
|
|
166
|
+
raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
|
|
162
167
|
return fn(arg)
|
|
163
168
|
|
|
164
169
|
|
|
165
|
-
def _get_tp(x: object) ->
|
|
170
|
+
def _get_tp(x: object) -> TypeName | type:
|
|
166
171
|
if isinstance(x, RuntimeExpr):
|
|
167
|
-
return x.__egg_typed_expr__.tp
|
|
172
|
+
return TypeName(x.__egg_typed_expr__.tp.name)
|
|
168
173
|
tp = type(x)
|
|
169
174
|
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
170
175
|
if type(tp) != type:
|
egglog/declarations.py
CHANGED
|
@@ -39,6 +39,7 @@ __all__ = [
|
|
|
39
39
|
"CallableDecl",
|
|
40
40
|
"VarDecl",
|
|
41
41
|
"PyObjectDecl",
|
|
42
|
+
"PartialCallDecl",
|
|
42
43
|
"LitType",
|
|
43
44
|
"LitDecl",
|
|
44
45
|
"CallDecl",
|
|
@@ -46,6 +47,7 @@ __all__ = [
|
|
|
46
47
|
"TypedExprDecl",
|
|
47
48
|
"ClassDecl",
|
|
48
49
|
"RulesetDecl",
|
|
50
|
+
"CombinedRulesetDecl",
|
|
49
51
|
"SaturateDecl",
|
|
50
52
|
"RepeatDecl",
|
|
51
53
|
"SequenceDecl",
|
|
@@ -67,6 +69,8 @@ __all__ = [
|
|
|
67
69
|
"RewriteOrRuleDecl",
|
|
68
70
|
"ActionCommandDecl",
|
|
69
71
|
"CommandDecl",
|
|
72
|
+
"SpecialFunctions",
|
|
73
|
+
"FunctionSignature",
|
|
70
74
|
]
|
|
71
75
|
|
|
72
76
|
|
|
@@ -88,9 +92,6 @@ class HasDeclerations(Protocol):
|
|
|
88
92
|
DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
|
|
89
93
|
|
|
90
94
|
|
|
91
|
-
# TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
|
|
92
|
-
|
|
93
|
-
|
|
94
95
|
def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
|
|
95
96
|
d = []
|
|
96
97
|
for l in declerations_like:
|
|
@@ -110,7 +111,7 @@ class Declarations:
|
|
|
110
111
|
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
111
112
|
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
112
113
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
113
|
-
_rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
114
|
+
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
114
115
|
|
|
115
116
|
@classmethod
|
|
116
117
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
@@ -196,7 +197,7 @@ class ClassDecl:
|
|
|
196
197
|
preserved_methods: dict[str, Callable] = field(default_factory=dict)
|
|
197
198
|
|
|
198
199
|
|
|
199
|
-
@dataclass
|
|
200
|
+
@dataclass(frozen=True)
|
|
200
201
|
class RulesetDecl:
|
|
201
202
|
rules: list[RewriteOrRuleDecl]
|
|
202
203
|
|
|
@@ -206,6 +207,11 @@ class RulesetDecl:
|
|
|
206
207
|
return hash((type(self), tuple(self.rules)))
|
|
207
208
|
|
|
208
209
|
|
|
210
|
+
@dataclass(frozen=True)
|
|
211
|
+
class CombinedRulesetDecl:
|
|
212
|
+
rulesets: tuple[str, ...]
|
|
213
|
+
|
|
214
|
+
|
|
209
215
|
# Have two different types of type refs, one that can include vars recursively and one that cannot.
|
|
210
216
|
# We only use the one with vars for classmethods and methods, and the other one for egg references as
|
|
211
217
|
# well as runtime values.
|
|
@@ -316,10 +322,12 @@ class RelationDecl:
|
|
|
316
322
|
|
|
317
323
|
def to_function_decl(self) -> FunctionDecl:
|
|
318
324
|
return FunctionDecl(
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
325
|
+
FunctionSignature(
|
|
326
|
+
arg_types=tuple(a.to_var() for a in self.arg_types),
|
|
327
|
+
arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
|
|
328
|
+
arg_defaults=self.arg_defaults,
|
|
329
|
+
return_type=TypeRefWithVars("Unit"),
|
|
330
|
+
),
|
|
323
331
|
egg_name=self.egg_name,
|
|
324
332
|
default=LitDecl(None),
|
|
325
333
|
)
|
|
@@ -336,25 +344,41 @@ class ConstantDecl:
|
|
|
336
344
|
|
|
337
345
|
def to_function_decl(self) -> FunctionDecl:
|
|
338
346
|
return FunctionDecl(
|
|
339
|
-
|
|
340
|
-
arg_names=(),
|
|
341
|
-
arg_defaults=(),
|
|
342
|
-
return_type=self.type_ref.to_var(),
|
|
347
|
+
FunctionSignature(return_type=self.type_ref.to_var()),
|
|
343
348
|
egg_name=self.egg_name,
|
|
344
349
|
)
|
|
345
350
|
|
|
346
351
|
|
|
352
|
+
# special cases for partial function creation and application, which cannot use the normal python rules
|
|
353
|
+
SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
|
|
354
|
+
|
|
355
|
+
|
|
347
356
|
@dataclass(frozen=True)
|
|
348
|
-
class
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
arg_names: tuple[str, ...]
|
|
357
|
+
class FunctionSignature:
|
|
358
|
+
arg_types: tuple[TypeOrVarRef, ...] = ()
|
|
359
|
+
arg_names: tuple[str, ...] = ()
|
|
352
360
|
# List of defaults. None for any arg which doesn't have one.
|
|
353
|
-
arg_defaults: tuple[ExprDecl | None, ...]
|
|
361
|
+
arg_defaults: tuple[ExprDecl | None, ...] = ()
|
|
354
362
|
# If None, then the first arg is mutated and returned
|
|
355
|
-
return_type: TypeOrVarRef | None
|
|
363
|
+
return_type: TypeOrVarRef | None = None
|
|
356
364
|
var_arg_type: TypeOrVarRef | None = None
|
|
357
365
|
|
|
366
|
+
@property
|
|
367
|
+
def semantic_return_type(self) -> TypeOrVarRef:
|
|
368
|
+
"""
|
|
369
|
+
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
370
|
+
"""
|
|
371
|
+
return self.return_type or self.arg_types[0]
|
|
372
|
+
|
|
373
|
+
@property
|
|
374
|
+
def mutates(self) -> bool:
|
|
375
|
+
return self.return_type is None
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@dataclass(frozen=True)
|
|
379
|
+
class FunctionDecl:
|
|
380
|
+
signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
|
|
381
|
+
|
|
358
382
|
# Egg params
|
|
359
383
|
builtin: bool = False
|
|
360
384
|
egg_name: str | None = None
|
|
@@ -367,17 +391,6 @@ class FunctionDecl:
|
|
|
367
391
|
def to_function_decl(self) -> FunctionDecl:
|
|
368
392
|
return self
|
|
369
393
|
|
|
370
|
-
@property
|
|
371
|
-
def semantic_return_type(self) -> TypeOrVarRef:
|
|
372
|
-
"""
|
|
373
|
-
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
374
|
-
"""
|
|
375
|
-
return self.return_type or self.arg_types[0]
|
|
376
|
-
|
|
377
|
-
@property
|
|
378
|
-
def mutates(self) -> bool:
|
|
379
|
-
return self.return_type is None
|
|
380
|
-
|
|
381
394
|
|
|
382
395
|
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
383
396
|
|
|
@@ -463,7 +476,20 @@ class CallDecl:
|
|
|
463
476
|
return hash(self) == hash(other)
|
|
464
477
|
|
|
465
478
|
|
|
466
|
-
|
|
479
|
+
@dataclass(frozen=True)
|
|
480
|
+
class PartialCallDecl:
|
|
481
|
+
"""
|
|
482
|
+
A partially applied function aka a function sort.
|
|
483
|
+
|
|
484
|
+
Note it does not need to have any args, in which case it's just a function pointer.
|
|
485
|
+
|
|
486
|
+
Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
|
|
487
|
+
"""
|
|
488
|
+
|
|
489
|
+
call: CallDecl
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
|
|
467
493
|
|
|
468
494
|
|
|
469
495
|
@dataclass(frozen=True)
|
egglog/egraph.py
CHANGED
|
@@ -75,6 +75,7 @@ __all__ = [
|
|
|
75
75
|
"seq",
|
|
76
76
|
"Command",
|
|
77
77
|
"simplify",
|
|
78
|
+
"unstable_combine_rulesets",
|
|
78
79
|
"check",
|
|
79
80
|
"GraphvizKwargs",
|
|
80
81
|
"Ruleset",
|
|
@@ -88,6 +89,7 @@ __all__ = [
|
|
|
88
89
|
"Fact",
|
|
89
90
|
"Action",
|
|
90
91
|
"Command",
|
|
92
|
+
"check_eq",
|
|
91
93
|
]
|
|
92
94
|
|
|
93
95
|
T = TypeVar("T")
|
|
@@ -145,6 +147,23 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
|
145
147
|
return EGraph().extract(x)
|
|
146
148
|
|
|
147
149
|
|
|
150
|
+
def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
|
|
151
|
+
"""
|
|
152
|
+
Verifies that two expressions are equal after running the schedule.
|
|
153
|
+
"""
|
|
154
|
+
egraph = EGraph()
|
|
155
|
+
x_var = egraph.let("__check_eq_x", x)
|
|
156
|
+
y_var = egraph.let("__check_eq_y", y)
|
|
157
|
+
if schedule:
|
|
158
|
+
egraph.run(schedule)
|
|
159
|
+
fact = eq(x_var).to(y_var)
|
|
160
|
+
try:
|
|
161
|
+
egraph.check(fact)
|
|
162
|
+
except bindings.EggSmolError as err:
|
|
163
|
+
raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
|
|
164
|
+
return egraph
|
|
165
|
+
|
|
166
|
+
|
|
148
167
|
def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
|
|
149
168
|
"""
|
|
150
169
|
Verifies that the fact is true given some assumptions and after running the schedule.
|
|
@@ -456,7 +475,7 @@ class _ExprMetaclass(type):
|
|
|
456
475
|
return isinstance(instance, RuntimeExpr)
|
|
457
476
|
|
|
458
477
|
|
|
459
|
-
def _generate_class_decls(
|
|
478
|
+
def _generate_class_decls( # noqa: C901
|
|
460
479
|
namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
|
|
461
480
|
) -> Declarations:
|
|
462
481
|
"""
|
|
@@ -518,6 +537,16 @@ def _generate_class_decls(
|
|
|
518
537
|
locals = frame.f_locals
|
|
519
538
|
|
|
520
539
|
def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
|
|
540
|
+
special_function_name: SpecialFunctions | None = (
|
|
541
|
+
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
|
|
542
|
+
)
|
|
543
|
+
if special_function_name:
|
|
544
|
+
return FunctionDecl(
|
|
545
|
+
special_function_name,
|
|
546
|
+
builtin=True,
|
|
547
|
+
egg_name=egg_fn, # noqa: B023
|
|
548
|
+
)
|
|
549
|
+
|
|
521
550
|
return _fn_decl(
|
|
522
551
|
decls,
|
|
523
552
|
egg_fn, # noqa: B023
|
|
@@ -649,6 +678,10 @@ def _fn_decl(
|
|
|
649
678
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
650
679
|
|
|
651
680
|
hint_globals = fn.__globals__.copy()
|
|
681
|
+
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
682
|
+
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
683
|
+
if "Callable" not in hint_globals:
|
|
684
|
+
hint_globals["Callable"] = Callable
|
|
652
685
|
|
|
653
686
|
hints = get_type_hints(fn, hint_globals, hint_locals)
|
|
654
687
|
|
|
@@ -715,11 +748,13 @@ def _fn_decl(
|
|
|
715
748
|
)
|
|
716
749
|
decls.update(*merge_action)
|
|
717
750
|
return FunctionDecl(
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
751
|
+
FunctionSignature(
|
|
752
|
+
return_type=None if mutates_first_arg else return_type,
|
|
753
|
+
var_arg_type=var_arg_type,
|
|
754
|
+
arg_types=arg_types,
|
|
755
|
+
arg_names=tuple(t.name for t in params),
|
|
756
|
+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
757
|
+
),
|
|
723
758
|
cost=cost,
|
|
724
759
|
egg_name=egg_name,
|
|
725
760
|
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
@@ -933,13 +968,12 @@ class EGraph(_BaseModule):
|
|
|
933
968
|
"""
|
|
934
969
|
Displays the e-graph in the notebook.
|
|
935
970
|
"""
|
|
936
|
-
graphviz = self.graphviz(**kwargs)
|
|
937
971
|
if IN_IPYTHON:
|
|
938
972
|
from IPython.display import SVG, display
|
|
939
973
|
|
|
940
974
|
display(SVG(self.graphviz_svg(**kwargs)))
|
|
941
975
|
else:
|
|
942
|
-
graphviz.render(view=True, format="svg", quiet=True)
|
|
976
|
+
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
943
977
|
|
|
944
978
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
945
979
|
"""
|
|
@@ -1059,7 +1093,7 @@ class EGraph(_BaseModule):
|
|
|
1059
1093
|
runtime_expr = to_runtime_expr(expr)
|
|
1060
1094
|
self._add_decls(runtime_expr)
|
|
1061
1095
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1062
|
-
extract_report = self._run_extract(typed_expr
|
|
1096
|
+
extract_report = self._run_extract(typed_expr, 0)
|
|
1063
1097
|
|
|
1064
1098
|
if not isinstance(extract_report, bindings.Best):
|
|
1065
1099
|
msg = "No extract report saved"
|
|
@@ -1079,15 +1113,16 @@ class EGraph(_BaseModule):
|
|
|
1079
1113
|
self._add_decls(runtime_expr)
|
|
1080
1114
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1081
1115
|
|
|
1082
|
-
extract_report = self._run_extract(typed_expr
|
|
1116
|
+
extract_report = self._run_extract(typed_expr, n)
|
|
1083
1117
|
if not isinstance(extract_report, bindings.Variants):
|
|
1084
1118
|
msg = "Wrong extract report type"
|
|
1085
1119
|
raise ValueError(msg) # noqa: TRY004
|
|
1086
1120
|
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1087
1121
|
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1088
1122
|
|
|
1089
|
-
def _run_extract(self,
|
|
1090
|
-
|
|
1123
|
+
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1124
|
+
self._state.type_ref_to_egg(typed_expr.tp)
|
|
1125
|
+
expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1091
1126
|
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1092
1127
|
extract_report = self._egraph.extract_report()
|
|
1093
1128
|
if not extract_report:
|
|
@@ -1276,8 +1311,10 @@ def ruleset(
|
|
|
1276
1311
|
"""
|
|
1277
1312
|
Creates a ruleset with the following rules.
|
|
1278
1313
|
|
|
1279
|
-
If no name is provided,
|
|
1314
|
+
If no name is provided, try using the name of the funciton.
|
|
1280
1315
|
"""
|
|
1316
|
+
if isinstance(rule_or_generator, FunctionType):
|
|
1317
|
+
name = name or rule_or_generator.__name__
|
|
1281
1318
|
r = Ruleset(name)
|
|
1282
1319
|
if rule_or_generator is not None:
|
|
1283
1320
|
r.register(rule_or_generator, *rules, _increase_frame=True)
|
|
@@ -1388,12 +1425,48 @@ class Ruleset(Schedule):
|
|
|
1388
1425
|
def __repr__(self) -> str:
|
|
1389
1426
|
return str(self)
|
|
1390
1427
|
|
|
1428
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1429
|
+
return unstable_combine_rulesets(self, other)
|
|
1430
|
+
|
|
1391
1431
|
# Create a unique name if we didn't pass one from the user
|
|
1392
1432
|
@property
|
|
1393
1433
|
def __egg_name__(self) -> str:
|
|
1394
1434
|
return self.name or f"ruleset_{id(self)}"
|
|
1395
1435
|
|
|
1396
1436
|
|
|
1437
|
+
@dataclass
|
|
1438
|
+
class UnstableCombinedRuleset(Schedule):
|
|
1439
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1440
|
+
schedule: RunDecl = field(init=False)
|
|
1441
|
+
name: str | None
|
|
1442
|
+
rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
|
|
1443
|
+
|
|
1444
|
+
def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
|
|
1445
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1446
|
+
self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
|
|
1447
|
+
|
|
1448
|
+
@property
|
|
1449
|
+
def __egg_name__(self) -> str:
|
|
1450
|
+
return self.name or f"combined_ruleset_{id(self)}"
|
|
1451
|
+
|
|
1452
|
+
def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
|
|
1453
|
+
decls = Declarations.create(*rulesets)
|
|
1454
|
+
decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
|
|
1455
|
+
return decls
|
|
1456
|
+
|
|
1457
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1458
|
+
return unstable_combine_rulesets(self, other)
|
|
1459
|
+
|
|
1460
|
+
|
|
1461
|
+
def unstable_combine_rulesets(
|
|
1462
|
+
*rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
|
|
1463
|
+
) -> UnstableCombinedRuleset:
|
|
1464
|
+
"""
|
|
1465
|
+
Combine multiple rulesets into a single ruleset.
|
|
1466
|
+
"""
|
|
1467
|
+
return UnstableCombinedRuleset(name, list(rulesets))
|
|
1468
|
+
|
|
1469
|
+
|
|
1397
1470
|
@dataclass
|
|
1398
1471
|
class RewriteOrRule:
|
|
1399
1472
|
__egg_decls__: Declarations
|
|
@@ -1556,9 +1629,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
|
1556
1629
|
|
|
1557
1630
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1558
1631
|
"""Create a new variable with the given name and type."""
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
return RuntimeExpr.__from_value__(
|
|
1632
|
+
decls = Declarations()
|
|
1633
|
+
type_ref = resolve_type_annotation(decls, bound)
|
|
1634
|
+
return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
|
|
1562
1635
|
|
|
1563
1636
|
|
|
1564
1637
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1801,8 +1874,10 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1801
1874
|
"""
|
|
1802
1875
|
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
1803
1876
|
# but not in the globals
|
|
1804
|
-
|
|
1805
|
-
|
|
1877
|
+
globals = gen.__globals__.copy()
|
|
1878
|
+
if "Callable" not in globals:
|
|
1879
|
+
globals["Callable"] = Callable
|
|
1880
|
+
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1806
1881
|
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1807
1882
|
return list(gen(*args)) # type: ignore[misc]
|
|
1808
1883
|
|