egglog 6.1.0__cp311-none-win_amd64.whl → 7.1.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +9 -0
- egglog/builtins.py +42 -2
- egglog/conversion.py +177 -0
- egglog/declarations.py +354 -734
- egglog/egraph.py +602 -800
- egglog/egraph_state.py +456 -0
- egglog/exp/array_api.py +100 -88
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +464 -0
- egglog/runtime.py +279 -431
- egglog/thunk.py +71 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/METADATA +7 -7
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -3,12 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import pathlib
|
|
5
5
|
import tempfile
|
|
6
|
-
from abc import
|
|
6
|
+
from abc import abstractmethod
|
|
7
7
|
from collections.abc import Callable, Iterable
|
|
8
8
|
from contextvars import ContextVar, Token
|
|
9
|
-
from copy import deepcopy
|
|
10
9
|
from dataclasses import InitVar, dataclass, field
|
|
11
|
-
from functools import cached_property
|
|
12
10
|
from inspect import Parameter, currentframe, signature
|
|
13
11
|
from types import FrameType, FunctionType
|
|
14
12
|
from typing import (
|
|
@@ -18,6 +16,7 @@ from typing import (
|
|
|
18
16
|
Generic,
|
|
19
17
|
Literal,
|
|
20
18
|
NoReturn,
|
|
19
|
+
TypeAlias,
|
|
21
20
|
TypedDict,
|
|
22
21
|
TypeVar,
|
|
23
22
|
cast,
|
|
@@ -26,14 +25,16 @@ from typing import (
|
|
|
26
25
|
)
|
|
27
26
|
|
|
28
27
|
import graphviz
|
|
29
|
-
from typing_extensions import ParamSpec, Self, Unpack, deprecated
|
|
30
|
-
|
|
31
|
-
from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
|
|
28
|
+
from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
|
|
32
29
|
|
|
33
30
|
from . import bindings
|
|
31
|
+
from .conversion import *
|
|
34
32
|
from .declarations import *
|
|
33
|
+
from .egraph_state import *
|
|
35
34
|
from .ipython_magic import IN_IPYTHON
|
|
35
|
+
from .pretty import pretty_decl
|
|
36
36
|
from .runtime import *
|
|
37
|
+
from .thunk import *
|
|
37
38
|
|
|
38
39
|
if TYPE_CHECKING:
|
|
39
40
|
import ipywidgets
|
|
@@ -58,6 +59,7 @@ __all__ = [
|
|
|
58
59
|
"let",
|
|
59
60
|
"constant",
|
|
60
61
|
"delete",
|
|
62
|
+
"subsume",
|
|
61
63
|
"union",
|
|
62
64
|
"set_",
|
|
63
65
|
"rule",
|
|
@@ -65,11 +67,15 @@ __all__ = [
|
|
|
65
67
|
"vars_",
|
|
66
68
|
"Fact",
|
|
67
69
|
"expr_parts",
|
|
70
|
+
"expr_action",
|
|
71
|
+
"expr_fact",
|
|
72
|
+
"action_command",
|
|
68
73
|
"Schedule",
|
|
69
74
|
"run",
|
|
70
75
|
"seq",
|
|
71
76
|
"Command",
|
|
72
77
|
"simplify",
|
|
78
|
+
"unstable_combine_rulesets",
|
|
73
79
|
"check",
|
|
74
80
|
"GraphvizKwargs",
|
|
75
81
|
"Ruleset",
|
|
@@ -79,11 +85,11 @@ __all__ = [
|
|
|
79
85
|
"_NeBuilder",
|
|
80
86
|
"_SetBuilder",
|
|
81
87
|
"_UnionBuilder",
|
|
82
|
-
"
|
|
83
|
-
"
|
|
84
|
-
"BiRewrite",
|
|
85
|
-
"Union_",
|
|
88
|
+
"RewriteOrRule",
|
|
89
|
+
"Fact",
|
|
86
90
|
"Action",
|
|
91
|
+
"Command",
|
|
92
|
+
"check_eq",
|
|
87
93
|
]
|
|
88
94
|
|
|
89
95
|
T = TypeVar("T")
|
|
@@ -112,7 +118,24 @@ IGNORED_ATTRIBUTES = {
|
|
|
112
118
|
}
|
|
113
119
|
|
|
114
120
|
|
|
121
|
+
# special methods that return none and mutate self
|
|
115
122
|
ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
|
|
123
|
+
# special methods which must return real python values instead of lazy expressions
|
|
124
|
+
ALWAYS_PRESERVED = {
|
|
125
|
+
"__repr__",
|
|
126
|
+
"__str__",
|
|
127
|
+
"__bytes__",
|
|
128
|
+
"__format__",
|
|
129
|
+
"__hash__",
|
|
130
|
+
"__bool__",
|
|
131
|
+
"__len__",
|
|
132
|
+
"__length_hint__",
|
|
133
|
+
"__iter__",
|
|
134
|
+
"__reversed__",
|
|
135
|
+
"__contains__",
|
|
136
|
+
"__index__",
|
|
137
|
+
"__bufer__",
|
|
138
|
+
}
|
|
116
139
|
|
|
117
140
|
|
|
118
141
|
def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
@@ -124,7 +147,24 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
|
124
147
|
return EGraph().extract(x)
|
|
125
148
|
|
|
126
149
|
|
|
127
|
-
def
|
|
150
|
+
def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
|
|
151
|
+
"""
|
|
152
|
+
Verifies that two expressions are equal after running the schedule.
|
|
153
|
+
"""
|
|
154
|
+
egraph = EGraph()
|
|
155
|
+
x_var = egraph.let("__check_eq_x", x)
|
|
156
|
+
y_var = egraph.let("__check_eq_y", y)
|
|
157
|
+
if schedule:
|
|
158
|
+
egraph.run(schedule)
|
|
159
|
+
fact = eq(x_var).to(y_var)
|
|
160
|
+
try:
|
|
161
|
+
egraph.check(fact)
|
|
162
|
+
except bindings.EggSmolError as err:
|
|
163
|
+
raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
|
|
164
|
+
return egraph
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
|
|
128
168
|
"""
|
|
129
169
|
Verifies that the fact is true given some assumptions and after running the schedule.
|
|
130
170
|
"""
|
|
@@ -136,9 +176,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
|
|
|
136
176
|
egraph.check(x)
|
|
137
177
|
|
|
138
178
|
|
|
139
|
-
# def extract(res: )
|
|
140
|
-
|
|
141
|
-
|
|
142
179
|
@dataclass
|
|
143
180
|
class _BaseModule:
|
|
144
181
|
"""
|
|
@@ -181,15 +218,8 @@ class _BaseModule:
|
|
|
181
218
|
Registers a class.
|
|
182
219
|
"""
|
|
183
220
|
if kwargs:
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
|
|
187
|
-
assert isinstance(cls, RuntimeClass)
|
|
188
|
-
assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
|
|
189
|
-
cls.lazy_decls.egg_sort = egg_sort
|
|
190
|
-
return cls
|
|
191
|
-
|
|
192
|
-
return _inner
|
|
221
|
+
msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
|
|
222
|
+
raise NotImplementedError(msg)
|
|
193
223
|
|
|
194
224
|
assert len(args) == 1
|
|
195
225
|
return args[0]
|
|
@@ -280,9 +310,9 @@ class _BaseModule:
|
|
|
280
310
|
# If we have any positional args, then we are calling it directly on a function
|
|
281
311
|
if args:
|
|
282
312
|
assert len(args) == 1
|
|
283
|
-
return
|
|
313
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
284
314
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
285
|
-
return
|
|
315
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
286
316
|
|
|
287
317
|
@deprecated("Use top level `ruleset` function instead")
|
|
288
318
|
def ruleset(self, name: str) -> Ruleset:
|
|
@@ -324,17 +354,26 @@ class _BaseModule:
|
|
|
324
354
|
"""
|
|
325
355
|
return constant(name, tp, egg_name)
|
|
326
356
|
|
|
327
|
-
def register(
|
|
357
|
+
def register(
|
|
358
|
+
self,
|
|
359
|
+
/,
|
|
360
|
+
command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
|
|
361
|
+
*command_likes: ActionLike | RewriteOrRule,
|
|
362
|
+
) -> None:
|
|
328
363
|
"""
|
|
329
364
|
Registers any number of rewrites or rules.
|
|
330
365
|
"""
|
|
331
366
|
if isinstance(command_or_generator, FunctionType):
|
|
332
367
|
assert not command_likes
|
|
333
|
-
|
|
368
|
+
current_frame = inspect.currentframe()
|
|
369
|
+
assert current_frame
|
|
370
|
+
original_frame = current_frame.f_back
|
|
371
|
+
assert original_frame
|
|
372
|
+
command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
|
|
334
373
|
else:
|
|
335
374
|
command_likes = (cast(CommandLike, command_or_generator), *command_likes)
|
|
336
|
-
|
|
337
|
-
self._register_commands(
|
|
375
|
+
commands = [_command_like(c) for c in command_likes]
|
|
376
|
+
self._register_commands(commands)
|
|
338
377
|
|
|
339
378
|
@abstractmethod
|
|
340
379
|
def _register_commands(self, cmds: list[Command]) -> None:
|
|
@@ -417,136 +456,126 @@ class _ExprMetaclass(type):
|
|
|
417
456
|
# If this is the Expr subclass, just return the class
|
|
418
457
|
if not bases:
|
|
419
458
|
return super().__new__(cls, name, bases, namespace)
|
|
459
|
+
# TODO: Raise error on subclassing or multiple inheritence
|
|
420
460
|
|
|
421
461
|
frame = currentframe()
|
|
422
462
|
assert frame
|
|
423
463
|
prev_frame = frame.f_back
|
|
424
464
|
assert prev_frame
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
cls_name=name,
|
|
434
|
-
).current_cls
|
|
465
|
+
|
|
466
|
+
# Store frame so that we can get live access to updated locals/globals
|
|
467
|
+
# Otherwise, f_locals returns a copy
|
|
468
|
+
# https://peps.python.org/pep-0667/
|
|
469
|
+
decls_thunk = Thunk.fn(
|
|
470
|
+
_generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
|
|
471
|
+
)
|
|
472
|
+
return RuntimeClass(decls_thunk, TypeRefWithVars(name))
|
|
435
473
|
|
|
436
474
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
437
475
|
return isinstance(instance, RuntimeExpr)
|
|
438
476
|
|
|
439
477
|
|
|
440
|
-
|
|
441
|
-
|
|
478
|
+
def _generate_class_decls( # noqa: C901
|
|
479
|
+
namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
|
|
480
|
+
) -> Declarations:
|
|
442
481
|
"""
|
|
443
482
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
444
483
|
"""
|
|
484
|
+
parameters: list[TypeVar] = (
|
|
485
|
+
# Get the generic params from the orig bases generic class
|
|
486
|
+
namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
|
|
487
|
+
)
|
|
488
|
+
type_vars = tuple(p.__name__ for p in parameters)
|
|
489
|
+
del parameters
|
|
490
|
+
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
491
|
+
decls = Declarations(_classes={cls_name: cls_decl})
|
|
492
|
+
|
|
493
|
+
##
|
|
494
|
+
# Register class variables
|
|
495
|
+
##
|
|
496
|
+
# Create a dummy type to pass to get_type_hints to resolve the annotations we have
|
|
497
|
+
_Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
|
|
498
|
+
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
499
|
+
if getattr(v, "__origin__", None) == ClassVar:
|
|
500
|
+
(inner_tp,) = v.__args__
|
|
501
|
+
type_ref = resolve_type_annotation(decls, inner_tp).to_just()
|
|
502
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
445
503
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
egg_sort: str | None
|
|
450
|
-
cls_name: str
|
|
451
|
-
current_cls: RuntimeClass = field(init=False)
|
|
504
|
+
else:
|
|
505
|
+
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
506
|
+
raise NotImplementedError(msg)
|
|
452
507
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
_Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
|
|
477
|
-
# Make lazy update to locals, so we keep a live handle on them after class creation
|
|
478
|
-
locals = self.frame.f_locals.copy()
|
|
479
|
-
locals[self.cls_name] = self.current_cls
|
|
480
|
-
for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
|
|
481
|
-
if getattr(v, "__origin__", None) == ClassVar:
|
|
482
|
-
(inner_tp,) = v.__args__
|
|
483
|
-
_register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
|
|
484
|
-
else:
|
|
485
|
-
msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
486
|
-
raise NotImplementedError(msg)
|
|
487
|
-
|
|
488
|
-
# Then register each of its methods
|
|
489
|
-
for method_name, method in cls_dict.items():
|
|
490
|
-
is_init = method_name == "__init__"
|
|
491
|
-
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
492
|
-
if is_init and self.cls_name in LIT_CLASS_NAMES:
|
|
493
|
-
continue
|
|
494
|
-
if isinstance(method, _WrappedMethod):
|
|
495
|
-
fn = method.fn
|
|
496
|
-
egg_fn = method.egg_fn
|
|
497
|
-
cost = method.cost
|
|
498
|
-
default = method.default
|
|
499
|
-
merge = method.merge
|
|
500
|
-
on_merge = method.on_merge
|
|
501
|
-
mutates_first_arg = method.mutates_self
|
|
502
|
-
unextractable = method.unextractable
|
|
503
|
-
if method.preserve:
|
|
504
|
-
decls.register_preserved_method(self.cls_name, method_name, fn)
|
|
505
|
-
continue
|
|
506
|
-
else:
|
|
507
|
-
fn = method
|
|
508
|
+
##
|
|
509
|
+
# Register methods, classmethods, preserved methods, and properties
|
|
510
|
+
##
|
|
511
|
+
|
|
512
|
+
# The type ref of self is paramterized by the type vars
|
|
513
|
+
slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
514
|
+
|
|
515
|
+
# Get all the methods from the class
|
|
516
|
+
filtered_namespace: list[tuple[str, Any]] = [
|
|
517
|
+
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
518
|
+
]
|
|
519
|
+
|
|
520
|
+
# Then register each of its methods
|
|
521
|
+
for method_name, method in filtered_namespace:
|
|
522
|
+
is_init = method_name == "__init__"
|
|
523
|
+
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
524
|
+
if is_init and cls_name in LIT_CLASS_NAMES:
|
|
525
|
+
continue
|
|
526
|
+
match method:
|
|
527
|
+
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
|
|
528
|
+
pass
|
|
529
|
+
case _:
|
|
508
530
|
egg_fn, cost, default, merge, on_merge = None, None, None, None, None
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
fn
|
|
520
|
-
is_property = True
|
|
521
|
-
if is_classmethod:
|
|
522
|
-
msg = "Can't have a classmethod property"
|
|
523
|
-
raise NotImplementedError(msg)
|
|
524
|
-
else:
|
|
525
|
-
is_property = False
|
|
526
|
-
ref: FunctionCallableRef = (
|
|
527
|
-
ClassMethodRef(self.cls_name, method_name)
|
|
528
|
-
if is_classmethod
|
|
529
|
-
else PropertyRef(self.cls_name, method_name)
|
|
530
|
-
if is_property
|
|
531
|
-
else MethodRef(self.cls_name, method_name)
|
|
531
|
+
fn = method
|
|
532
|
+
unextractable, preserve = False, False
|
|
533
|
+
mutates = method_name in ALWAYS_MUTATES_SELF
|
|
534
|
+
if preserve:
|
|
535
|
+
cls_decl.preserved_methods[method_name] = fn
|
|
536
|
+
continue
|
|
537
|
+
locals = frame.f_locals
|
|
538
|
+
|
|
539
|
+
def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
|
|
540
|
+
special_function_name: SpecialFunctions | None = (
|
|
541
|
+
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
|
|
532
542
|
)
|
|
533
|
-
|
|
543
|
+
if special_function_name:
|
|
544
|
+
return FunctionDecl(
|
|
545
|
+
special_function_name,
|
|
546
|
+
builtin=True,
|
|
547
|
+
egg_name=egg_fn, # noqa: B023
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
return _fn_decl(
|
|
534
551
|
decls,
|
|
535
|
-
|
|
536
|
-
egg_fn,
|
|
552
|
+
egg_fn, # noqa: B023
|
|
537
553
|
fn,
|
|
538
|
-
locals,
|
|
539
|
-
default,
|
|
540
|
-
cost,
|
|
541
|
-
merge,
|
|
542
|
-
on_merge,
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
is_init,
|
|
547
|
-
unextractable
|
|
554
|
+
locals, # noqa: B023
|
|
555
|
+
default, # noqa: B023
|
|
556
|
+
cost, # noqa: B023
|
|
557
|
+
merge, # noqa: B023
|
|
558
|
+
on_merge, # noqa: B023
|
|
559
|
+
mutates, # noqa: B023
|
|
560
|
+
builtin,
|
|
561
|
+
first,
|
|
562
|
+
is_init, # noqa: B023
|
|
563
|
+
unextractable, # noqa: B023
|
|
548
564
|
)
|
|
549
565
|
|
|
566
|
+
match fn:
|
|
567
|
+
case classmethod():
|
|
568
|
+
cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
|
|
569
|
+
case property():
|
|
570
|
+
cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
|
|
571
|
+
case _:
|
|
572
|
+
if is_init:
|
|
573
|
+
cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
|
|
574
|
+
else:
|
|
575
|
+
cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
|
|
576
|
+
|
|
577
|
+
return decls
|
|
578
|
+
|
|
550
579
|
|
|
551
580
|
@overload
|
|
552
581
|
def function(fn: CALLABLE, /) -> CALLABLE: ...
|
|
@@ -589,48 +618,46 @@ def function(*args, **kwargs) -> Any:
|
|
|
589
618
|
# If we have any positional args, then we are calling it directly on a function
|
|
590
619
|
if args:
|
|
591
620
|
assert len(args) == 1
|
|
592
|
-
return
|
|
621
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
593
622
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
594
|
-
return
|
|
623
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
595
624
|
|
|
596
625
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
hint_locals: dict[str, Any]
|
|
600
|
-
builtin: bool = False
|
|
601
|
-
mutates_first_arg: bool = False
|
|
602
|
-
egg_fn: str | None = None
|
|
603
|
-
cost: int | None = None
|
|
604
|
-
default: RuntimeExpr | None = None
|
|
605
|
-
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
|
-
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
|
-
unextractable: bool = False
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
return RuntimeFunction(decls, name)
|
|
626
|
+
@dataclass
|
|
627
|
+
class _FunctionConstructor:
|
|
628
|
+
hint_locals: dict[str, Any]
|
|
629
|
+
builtin: bool = False
|
|
630
|
+
mutates_first_arg: bool = False
|
|
631
|
+
egg_fn: str | None = None
|
|
632
|
+
cost: int | None = None
|
|
633
|
+
default: RuntimeExpr | None = None
|
|
634
|
+
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
635
|
+
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
636
|
+
unextractable: bool = False
|
|
637
|
+
|
|
638
|
+
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
639
|
+
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
|
|
640
|
+
|
|
641
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
642
|
+
decls = Declarations()
|
|
643
|
+
decls._functions[fn.__name__] = _fn_decl(
|
|
644
|
+
decls,
|
|
645
|
+
self.egg_fn,
|
|
646
|
+
fn,
|
|
647
|
+
self.hint_locals,
|
|
648
|
+
self.default,
|
|
649
|
+
self.cost,
|
|
650
|
+
self.merge,
|
|
651
|
+
self.on_merge,
|
|
652
|
+
self.mutates_first_arg,
|
|
653
|
+
self.builtin,
|
|
654
|
+
unextractable=self.unextractable,
|
|
655
|
+
)
|
|
656
|
+
return decls
|
|
629
657
|
|
|
630
658
|
|
|
631
|
-
def
|
|
659
|
+
def _fn_decl(
|
|
632
660
|
decls: Declarations,
|
|
633
|
-
ref: FunctionCallableRef,
|
|
634
661
|
egg_name: str | None,
|
|
635
662
|
fn: object,
|
|
636
663
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
@@ -646,11 +673,15 @@ def _register_function(
|
|
|
646
673
|
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
647
674
|
is_init: bool = False,
|
|
648
675
|
unextractable: bool = False,
|
|
649
|
-
) ->
|
|
676
|
+
) -> FunctionDecl:
|
|
650
677
|
if not isinstance(fn, FunctionType):
|
|
651
678
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
652
679
|
|
|
653
680
|
hint_globals = fn.__globals__.copy()
|
|
681
|
+
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
682
|
+
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
683
|
+
if "Callable" not in hint_globals:
|
|
684
|
+
hint_globals["Callable"] = Callable
|
|
654
685
|
|
|
655
686
|
hints = get_type_hints(fn, hint_globals, hint_locals)
|
|
656
687
|
|
|
@@ -699,8 +730,8 @@ def _register_function(
|
|
|
699
730
|
None
|
|
700
731
|
if merge is None
|
|
701
732
|
else merge(
|
|
702
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
703
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
733
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
734
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
704
735
|
)
|
|
705
736
|
)
|
|
706
737
|
decls |= merged
|
|
@@ -710,30 +741,27 @@ def _register_function(
|
|
|
710
741
|
if on_merge is None
|
|
711
742
|
else _action_likes(
|
|
712
743
|
on_merge(
|
|
713
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
714
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
744
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
745
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
715
746
|
)
|
|
716
747
|
)
|
|
717
748
|
)
|
|
718
749
|
decls.update(*merge_action)
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
None if default is None else default.__egg_typed_expr__.expr,
|
|
733
|
-
|
|
734
|
-
[a._to_egg_action() for a in merge_action],
|
|
735
|
-
unextractable,
|
|
736
|
-
is_builtin,
|
|
750
|
+
return FunctionDecl(
|
|
751
|
+
FunctionSignature(
|
|
752
|
+
return_type=None if mutates_first_arg else return_type,
|
|
753
|
+
var_arg_type=var_arg_type,
|
|
754
|
+
arg_types=arg_types,
|
|
755
|
+
arg_names=tuple(t.name for t in params),
|
|
756
|
+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
757
|
+
),
|
|
758
|
+
cost=cost,
|
|
759
|
+
egg_name=egg_name,
|
|
760
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
761
|
+
unextractable=unextractable,
|
|
762
|
+
builtin=is_builtin,
|
|
763
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
764
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
737
765
|
)
|
|
738
766
|
|
|
739
767
|
|
|
@@ -764,49 +792,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
764
792
|
"""
|
|
765
793
|
Creates a function whose return type is `Unit` and has a default value.
|
|
766
794
|
"""
|
|
795
|
+
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
796
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
767
800
|
decls = Declarations()
|
|
768
801
|
decls |= cast(RuntimeClass, Unit)
|
|
769
|
-
arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
|
|
770
|
-
|
|
771
|
-
decls
|
|
772
|
-
FunctionRef(name),
|
|
773
|
-
fn_decl,
|
|
774
|
-
egg_fn,
|
|
775
|
-
cost=None,
|
|
776
|
-
default=None,
|
|
777
|
-
merge=None,
|
|
778
|
-
merge_action=[],
|
|
779
|
-
unextractable=False,
|
|
780
|
-
builtin=False,
|
|
781
|
-
is_relation=True,
|
|
782
|
-
)
|
|
783
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls, name))
|
|
802
|
+
arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
|
|
803
|
+
decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
|
|
804
|
+
return decls
|
|
784
805
|
|
|
785
806
|
|
|
786
807
|
def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
|
|
787
808
|
"""
|
|
788
|
-
|
|
789
809
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
790
810
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
791
811
|
"""
|
|
792
|
-
|
|
793
|
-
decls = Declarations()
|
|
794
|
-
type_ref = _register_constant(decls, ref, tp, egg_name)
|
|
795
|
-
return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
|
|
812
|
+
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
|
|
796
813
|
|
|
797
814
|
|
|
798
|
-
def
|
|
799
|
-
decls
|
|
800
|
-
ref: ConstantRef | ClassVariableRef,
|
|
801
|
-
tp: object,
|
|
802
|
-
egg_name: str | None,
|
|
803
|
-
) -> JustTypeRef:
|
|
804
|
-
"""
|
|
805
|
-
Register a constant, returning its typeref().
|
|
806
|
-
"""
|
|
815
|
+
def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
|
|
816
|
+
decls = Declarations()
|
|
807
817
|
type_ref = resolve_type_annotation(decls, tp).to_just()
|
|
808
|
-
decls.
|
|
809
|
-
return type_ref
|
|
818
|
+
decls._constants[name] = ConstantDecl(type_ref, egg_name)
|
|
819
|
+
return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
|
|
810
820
|
|
|
811
821
|
|
|
812
822
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -858,29 +868,6 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
858
868
|
split_primitive_outputs: bool
|
|
859
869
|
|
|
860
870
|
|
|
861
|
-
@dataclass
|
|
862
|
-
class _EGraphState:
|
|
863
|
-
"""
|
|
864
|
-
State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
|
|
865
|
-
"""
|
|
866
|
-
|
|
867
|
-
# The decleratons we have added. The _cmds represent all the symbols we have added
|
|
868
|
-
decls: Declarations = field(default_factory=Declarations)
|
|
869
|
-
# List of rulesets already added, so we don't re-add them if they are passed again
|
|
870
|
-
added_rulesets: set[str] = field(default_factory=set)
|
|
871
|
-
|
|
872
|
-
def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
|
|
873
|
-
new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
|
|
874
|
-
self.decls |= new_decls
|
|
875
|
-
return new_cmds
|
|
876
|
-
|
|
877
|
-
def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
|
|
878
|
-
for ruleset in rulesets:
|
|
879
|
-
if ruleset.egg_name not in self.added_rulesets:
|
|
880
|
-
self.added_rulesets.add(ruleset.egg_name)
|
|
881
|
-
yield from ruleset._cmds
|
|
882
|
-
|
|
883
|
-
|
|
884
871
|
@dataclass
|
|
885
872
|
class EGraph(_BaseModule):
|
|
886
873
|
"""
|
|
@@ -892,56 +879,34 @@ class EGraph(_BaseModule):
|
|
|
892
879
|
seminaive: InitVar[bool] = True
|
|
893
880
|
save_egglog_string: InitVar[bool] = False
|
|
894
881
|
|
|
895
|
-
|
|
896
|
-
_egraph: bindings.EGraph = field(repr=False, init=False)
|
|
897
|
-
_egglog_string: str | None = field(default=None, repr=False, init=False)
|
|
898
|
-
_state: _EGraphState = field(default_factory=_EGraphState, repr=False)
|
|
882
|
+
_state: EGraphState = field(init=False)
|
|
899
883
|
# For pushing/popping with egglog
|
|
900
|
-
_state_stack: list[
|
|
884
|
+
_state_stack: list[EGraphState] = field(default_factory=list, repr=False)
|
|
901
885
|
# For storing the global "current" egraph
|
|
902
886
|
_token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
|
|
903
887
|
|
|
904
888
|
def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
|
|
905
|
-
|
|
889
|
+
egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
|
|
890
|
+
self._state = EGraphState(egraph)
|
|
906
891
|
super().__post_init__(modules)
|
|
907
892
|
|
|
908
893
|
for m in self._flatted_deps:
|
|
909
|
-
self._add_decls(*m.cmds)
|
|
910
894
|
self._register_commands(m.cmds)
|
|
911
|
-
if save_egglog_string:
|
|
912
|
-
self._egglog_string = ""
|
|
913
|
-
|
|
914
|
-
def _register_commands(self, commands: list[Command]) -> None:
|
|
915
|
-
for c in commands:
|
|
916
|
-
if c.ruleset:
|
|
917
|
-
self._add_schedule(c.ruleset)
|
|
918
|
-
|
|
919
|
-
self._add_decls(*commands)
|
|
920
|
-
self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
|
|
921
|
-
|
|
922
|
-
def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
|
|
923
|
-
commands = list(commands)
|
|
924
|
-
self._egraph.run_program(*commands)
|
|
925
|
-
if isinstance(self._egglog_string, str):
|
|
926
|
-
self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
|
|
927
895
|
|
|
928
896
|
def _add_decls(self, *decls: DeclerationsLike) -> None:
|
|
929
|
-
for d in
|
|
930
|
-
self.
|
|
931
|
-
|
|
932
|
-
def _add_schedule(self, schedule: Schedule) -> None:
|
|
933
|
-
self._add_decls(schedule)
|
|
934
|
-
self._process_commands(self._state.add_rulesets(schedule._rulesets()))
|
|
897
|
+
for d in decls:
|
|
898
|
+
self._state.__egg_decls__ |= d
|
|
935
899
|
|
|
936
900
|
@property
|
|
937
901
|
def as_egglog_string(self) -> str:
|
|
938
902
|
"""
|
|
939
903
|
Returns the egglog string for this module.
|
|
940
904
|
"""
|
|
941
|
-
|
|
905
|
+
cmds = self._egraph.commands()
|
|
906
|
+
if cmds is None:
|
|
942
907
|
msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
|
|
943
908
|
raise ValueError(msg)
|
|
944
|
-
return
|
|
909
|
+
return cmds
|
|
945
910
|
|
|
946
911
|
def _repr_mimebundle_(self, *args, **kwargs):
|
|
947
912
|
"""
|
|
@@ -954,7 +919,7 @@ class EGraph(_BaseModule):
|
|
|
954
919
|
kwargs.setdefault("split_primitive_outputs", True)
|
|
955
920
|
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
956
921
|
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
957
|
-
serialized.map_ops(self._state.
|
|
922
|
+
serialized.map_ops(self._state.op_mapping())
|
|
958
923
|
for _ in range(n_inline):
|
|
959
924
|
serialized.inline_leaves()
|
|
960
925
|
original = serialized.to_dot()
|
|
@@ -1003,30 +968,35 @@ class EGraph(_BaseModule):
|
|
|
1003
968
|
"""
|
|
1004
969
|
Displays the e-graph in the notebook.
|
|
1005
970
|
"""
|
|
1006
|
-
graphviz = self.graphviz(**kwargs)
|
|
1007
971
|
if IN_IPYTHON:
|
|
1008
972
|
from IPython.display import SVG, display
|
|
1009
973
|
|
|
1010
974
|
display(SVG(self.graphviz_svg(**kwargs)))
|
|
1011
975
|
else:
|
|
1012
|
-
graphviz.render(view=True, format="svg", quiet=True)
|
|
976
|
+
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1013
977
|
|
|
1014
978
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
1015
979
|
"""
|
|
1016
980
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1017
981
|
"""
|
|
1018
982
|
ref, decls = resolve_callable(fn)
|
|
1019
|
-
|
|
1020
|
-
self.
|
|
1021
|
-
self.
|
|
983
|
+
self._add_decls(decls)
|
|
984
|
+
fn_name = self._state.callable_ref_to_egg(ref)
|
|
985
|
+
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1022
986
|
|
|
1023
987
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
1024
988
|
"""
|
|
1025
989
|
Define a new expression in the egraph and return a reference to it.
|
|
1026
990
|
"""
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
991
|
+
action = let(name, expr)
|
|
992
|
+
self.register(action)
|
|
993
|
+
runtime_expr = to_runtime_expr(expr)
|
|
994
|
+
return cast(
|
|
995
|
+
EXPR,
|
|
996
|
+
RuntimeExpr.__from_value__(
|
|
997
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
998
|
+
),
|
|
999
|
+
)
|
|
1030
1000
|
|
|
1031
1001
|
@overload
|
|
1032
1002
|
def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
|
|
@@ -1041,28 +1011,19 @@ class EGraph(_BaseModule):
|
|
|
1041
1011
|
Simplifies the given expression.
|
|
1042
1012
|
"""
|
|
1043
1013
|
schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
|
|
1044
|
-
del limit_or_schedule
|
|
1045
|
-
|
|
1046
|
-
self._add_decls(
|
|
1047
|
-
self.
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
self.
|
|
1014
|
+
del limit_or_schedule, until, ruleset
|
|
1015
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1016
|
+
self._add_decls(runtime_expr, schedule)
|
|
1017
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1018
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1019
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1020
|
+
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1051
1021
|
extract_report = self._egraph.extract_report()
|
|
1052
1022
|
if not isinstance(extract_report, bindings.Best):
|
|
1053
1023
|
msg = "No extract report saved"
|
|
1054
1024
|
raise ValueError(msg) # noqa: TRY004
|
|
1055
|
-
new_typed_expr =
|
|
1056
|
-
|
|
1057
|
-
)
|
|
1058
|
-
return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1059
|
-
|
|
1060
|
-
@property
|
|
1061
|
-
def _default_ruleset_name(self) -> str:
|
|
1062
|
-
if self.default_ruleset:
|
|
1063
|
-
self._add_schedule(self.default_ruleset)
|
|
1064
|
-
return self.default_ruleset.egg_name
|
|
1065
|
-
return ""
|
|
1025
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1026
|
+
return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1066
1027
|
|
|
1067
1028
|
def include(self, path: str) -> None:
|
|
1068
1029
|
"""
|
|
@@ -1092,8 +1053,9 @@ class EGraph(_BaseModule):
|
|
|
1092
1053
|
return self._run_schedule(limit_or_schedule)
|
|
1093
1054
|
|
|
1094
1055
|
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
|
|
1095
|
-
self.
|
|
1096
|
-
self.
|
|
1056
|
+
self._add_decls(schedule)
|
|
1057
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1058
|
+
self._egraph.run_program(bindings.RunSchedule(egg_schedule))
|
|
1097
1059
|
run_report = self._egraph.run_report()
|
|
1098
1060
|
if not run_report:
|
|
1099
1061
|
msg = "No run report saved"
|
|
@@ -1104,18 +1066,18 @@ class EGraph(_BaseModule):
|
|
|
1104
1066
|
"""
|
|
1105
1067
|
Check if a fact is true in the egraph.
|
|
1106
1068
|
"""
|
|
1107
|
-
self.
|
|
1069
|
+
self._egraph.run_program(self._facts_to_check(facts))
|
|
1108
1070
|
|
|
1109
1071
|
def check_fail(self, *facts: FactLike) -> None:
|
|
1110
1072
|
"""
|
|
1111
1073
|
Checks that one of the facts is not true
|
|
1112
1074
|
"""
|
|
1113
|
-
self.
|
|
1075
|
+
self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
|
|
1114
1076
|
|
|
1115
|
-
def _facts_to_check(self,
|
|
1116
|
-
facts = _fact_likes(
|
|
1077
|
+
def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
|
|
1078
|
+
facts = _fact_likes(fact_likes)
|
|
1117
1079
|
self._add_decls(*facts)
|
|
1118
|
-
egg_facts = [f.
|
|
1080
|
+
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1119
1081
|
return bindings.Check(egg_facts)
|
|
1120
1082
|
|
|
1121
1083
|
@overload
|
|
@@ -1128,16 +1090,17 @@ class EGraph(_BaseModule):
|
|
|
1128
1090
|
"""
|
|
1129
1091
|
Extract the lowest cost expression from the egraph.
|
|
1130
1092
|
"""
|
|
1131
|
-
|
|
1132
|
-
self._add_decls(
|
|
1133
|
-
|
|
1093
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1094
|
+
self._add_decls(runtime_expr)
|
|
1095
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1096
|
+
extract_report = self._run_extract(typed_expr, 0)
|
|
1097
|
+
|
|
1134
1098
|
if not isinstance(extract_report, bindings.Best):
|
|
1135
1099
|
msg = "No extract report saved"
|
|
1136
1100
|
raise ValueError(msg) # noqa: TRY004
|
|
1137
|
-
new_typed_expr =
|
|
1138
|
-
|
|
1139
|
-
)
|
|
1140
|
-
res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1101
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1102
|
+
|
|
1103
|
+
res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1141
1104
|
if include_cost:
|
|
1142
1105
|
return res, extract_report.cost
|
|
1143
1106
|
return res
|
|
@@ -1146,23 +1109,21 @@ class EGraph(_BaseModule):
|
|
|
1146
1109
|
"""
|
|
1147
1110
|
Extract multiple expressions from the egraph.
|
|
1148
1111
|
"""
|
|
1149
|
-
|
|
1150
|
-
self._add_decls(
|
|
1112
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1113
|
+
self._add_decls(runtime_expr)
|
|
1114
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1151
1115
|
|
|
1152
|
-
extract_report = self._run_extract(
|
|
1116
|
+
extract_report = self._run_extract(typed_expr, n)
|
|
1153
1117
|
if not isinstance(extract_report, bindings.Variants):
|
|
1154
1118
|
msg = "Wrong extract report type"
|
|
1155
1119
|
raise ValueError(msg) # noqa: TRY004
|
|
1156
|
-
new_exprs =
|
|
1157
|
-
|
|
1158
|
-
self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
|
|
1159
|
-
)
|
|
1160
|
-
for term in extract_report.terms
|
|
1161
|
-
]
|
|
1162
|
-
return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
|
|
1120
|
+
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1121
|
+
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1163
1122
|
|
|
1164
|
-
def _run_extract(self,
|
|
1165
|
-
self.
|
|
1123
|
+
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1124
|
+
self._state.type_ref_to_egg(typed_expr.tp)
|
|
1125
|
+
expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1126
|
+
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1166
1127
|
extract_report = self._egraph.extract_report()
|
|
1167
1128
|
if not extract_report:
|
|
1168
1129
|
msg = "No extract report saved"
|
|
@@ -1173,15 +1134,15 @@ class EGraph(_BaseModule):
|
|
|
1173
1134
|
"""
|
|
1174
1135
|
Push the current state of the egraph, so that it can be popped later and reverted back.
|
|
1175
1136
|
"""
|
|
1176
|
-
self.
|
|
1137
|
+
self._egraph.run_program(bindings.Push(1))
|
|
1177
1138
|
self._state_stack.append(self._state)
|
|
1178
|
-
self._state =
|
|
1139
|
+
self._state = self._state.copy()
|
|
1179
1140
|
|
|
1180
1141
|
def pop(self) -> None:
|
|
1181
1142
|
"""
|
|
1182
1143
|
Pop the current state of the egraph, reverting back to the previous state.
|
|
1183
1144
|
"""
|
|
1184
|
-
self.
|
|
1145
|
+
self._egraph.run_program(bindings.Pop(1))
|
|
1185
1146
|
self._state = self._state_stack.pop()
|
|
1186
1147
|
|
|
1187
1148
|
def __enter__(self) -> Self:
|
|
@@ -1217,9 +1178,10 @@ class EGraph(_BaseModule):
|
|
|
1217
1178
|
"""
|
|
1218
1179
|
Evaluates the given expression (which must be a primitive type), returning the result.
|
|
1219
1180
|
"""
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1181
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1182
|
+
self._add_decls(runtime_expr)
|
|
1183
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1184
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1223
1185
|
match typed_expr.tp:
|
|
1224
1186
|
case JustTypeRef("i64"):
|
|
1225
1187
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1231,7 +1193,7 @@ class EGraph(_BaseModule):
|
|
|
1231
1193
|
return self._egraph.eval_string(egg_expr)
|
|
1232
1194
|
case JustTypeRef("PyObject"):
|
|
1233
1195
|
return self._egraph.eval_py_object(egg_expr)
|
|
1234
|
-
raise
|
|
1196
|
+
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1235
1197
|
|
|
1236
1198
|
def saturate(
|
|
1237
1199
|
self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
|
|
@@ -1270,6 +1232,32 @@ class EGraph(_BaseModule):
|
|
|
1270
1232
|
"""
|
|
1271
1233
|
return CURRENT_EGRAPH.get()
|
|
1272
1234
|
|
|
1235
|
+
@property
|
|
1236
|
+
def _egraph(self) -> bindings.EGraph:
|
|
1237
|
+
return self._state.egraph
|
|
1238
|
+
|
|
1239
|
+
@property
|
|
1240
|
+
def __egg_decls__(self) -> Declarations:
|
|
1241
|
+
return self._state.__egg_decls__
|
|
1242
|
+
|
|
1243
|
+
def _register_commands(self, cmds: list[Command]) -> None:
|
|
1244
|
+
self._add_decls(*cmds)
|
|
1245
|
+
egg_cmds = list(map(self._command_to_egg, cmds))
|
|
1246
|
+
self._egraph.run_program(*egg_cmds)
|
|
1247
|
+
|
|
1248
|
+
def _command_to_egg(self, cmd: Command) -> bindings._Command:
|
|
1249
|
+
ruleset_name = ""
|
|
1250
|
+
cmd_decl: CommandDecl
|
|
1251
|
+
match cmd:
|
|
1252
|
+
case RewriteOrRule(_, cmd_decl, ruleset):
|
|
1253
|
+
if ruleset:
|
|
1254
|
+
ruleset_name = ruleset.__egg_name__
|
|
1255
|
+
case Action(_, action):
|
|
1256
|
+
cmd_decl = ActionCommandDecl(action)
|
|
1257
|
+
case _:
|
|
1258
|
+
assert_never(cmd)
|
|
1259
|
+
return self._state.command_to_egg(cmd_decl, ruleset_name)
|
|
1260
|
+
|
|
1273
1261
|
|
|
1274
1262
|
CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
|
|
1275
1263
|
|
|
@@ -1316,61 +1304,55 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
|
|
|
1316
1304
|
|
|
1317
1305
|
|
|
1318
1306
|
def ruleset(
|
|
1319
|
-
rule_or_generator:
|
|
1307
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
|
|
1308
|
+
*rules: RewriteOrRule,
|
|
1309
|
+
name: None | str = None,
|
|
1320
1310
|
) -> Ruleset:
|
|
1321
1311
|
"""
|
|
1322
1312
|
Creates a ruleset with the following rules.
|
|
1323
1313
|
|
|
1324
|
-
If no name is provided,
|
|
1314
|
+
If no name is provided, try using the name of the funciton.
|
|
1325
1315
|
"""
|
|
1326
|
-
|
|
1316
|
+
if isinstance(rule_or_generator, FunctionType):
|
|
1317
|
+
name = name or rule_or_generator.__name__
|
|
1318
|
+
r = Ruleset(name)
|
|
1327
1319
|
if rule_or_generator is not None:
|
|
1328
|
-
r.register(rule_or_generator, *rules)
|
|
1320
|
+
r.register(rule_or_generator, *rules, _increase_frame=True)
|
|
1329
1321
|
return r
|
|
1330
1322
|
|
|
1331
1323
|
|
|
1332
|
-
|
|
1324
|
+
@dataclass
|
|
1325
|
+
class Schedule(DelayedDeclerations):
|
|
1333
1326
|
"""
|
|
1334
1327
|
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
|
|
1335
1328
|
"""
|
|
1336
1329
|
|
|
1330
|
+
# Defer declerations so that we can have rule generators that used not yet defined yet
|
|
1331
|
+
schedule: ScheduleDecl
|
|
1332
|
+
|
|
1333
|
+
def __str__(self) -> str:
|
|
1334
|
+
return pretty_decl(self.__egg_decls__, self.schedule)
|
|
1335
|
+
|
|
1336
|
+
def __repr__(self) -> str:
|
|
1337
|
+
return str(self)
|
|
1338
|
+
|
|
1337
1339
|
def __mul__(self, length: int) -> Schedule:
|
|
1338
1340
|
"""
|
|
1339
1341
|
Repeat the schedule a number of times.
|
|
1340
1342
|
"""
|
|
1341
|
-
return
|
|
1343
|
+
return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
|
|
1342
1344
|
|
|
1343
1345
|
def saturate(self) -> Schedule:
|
|
1344
1346
|
"""
|
|
1345
1347
|
Run the schedule until the e-graph is saturated.
|
|
1346
1348
|
"""
|
|
1347
|
-
return
|
|
1349
|
+
return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
|
|
1348
1350
|
|
|
1349
1351
|
def __add__(self, other: Schedule) -> Schedule:
|
|
1350
1352
|
"""
|
|
1351
1353
|
Run two schedules in sequence.
|
|
1352
1354
|
"""
|
|
1353
|
-
return
|
|
1354
|
-
|
|
1355
|
-
@abstractmethod
|
|
1356
|
-
def __str__(self) -> str:
|
|
1357
|
-
raise NotImplementedError
|
|
1358
|
-
|
|
1359
|
-
@abstractmethod
|
|
1360
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1361
|
-
raise NotImplementedError
|
|
1362
|
-
|
|
1363
|
-
@abstractmethod
|
|
1364
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1365
|
-
"""
|
|
1366
|
-
Mapping of all the rulesets used to commands.
|
|
1367
|
-
"""
|
|
1368
|
-
raise NotImplementedError
|
|
1369
|
-
|
|
1370
|
-
@property
|
|
1371
|
-
@abstractmethod
|
|
1372
|
-
def __egg_decls__(self) -> Declarations:
|
|
1373
|
-
raise NotImplementedError
|
|
1355
|
+
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
|
|
1374
1356
|
|
|
1375
1357
|
|
|
1376
1358
|
@dataclass
|
|
@@ -1379,422 +1361,155 @@ class Ruleset(Schedule):
|
|
|
1379
1361
|
A collection of rules, which can be run as a schedule.
|
|
1380
1362
|
"""
|
|
1381
1363
|
|
|
1364
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1365
|
+
schedule: RunDecl = field(init=False)
|
|
1382
1366
|
name: str | None
|
|
1383
|
-
rules: list[Rule | Rewrite] = field(default_factory=list)
|
|
1384
1367
|
|
|
1385
|
-
|
|
1368
|
+
# Current declerations we have accumulated
|
|
1369
|
+
_current_egg_decls: Declarations = field(default_factory=Declarations)
|
|
1370
|
+
# Current rulesets we have accumulated
|
|
1371
|
+
__egg_ruleset__: RulesetDecl = field(init=False)
|
|
1372
|
+
# Rule generator functions that have been deferred, to allow for late type binding
|
|
1373
|
+
deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
|
|
1374
|
+
|
|
1375
|
+
def __post_init__(self) -> None:
|
|
1376
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1377
|
+
self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
|
|
1378
|
+
self.__egg_decls_thunk__ = self._update_egg_decls
|
|
1379
|
+
|
|
1380
|
+
def _update_egg_decls(self) -> Declarations:
|
|
1381
|
+
"""
|
|
1382
|
+
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1383
|
+
"""
|
|
1384
|
+
while self.deferred_rule_gens:
|
|
1385
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1386
|
+
self._current_egg_decls.update(*rules)
|
|
1387
|
+
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1388
|
+
return self._current_egg_decls
|
|
1389
|
+
|
|
1390
|
+
def append(self, rule: RewriteOrRule) -> None:
|
|
1386
1391
|
"""
|
|
1387
1392
|
Register a rule with the ruleset.
|
|
1388
1393
|
"""
|
|
1389
|
-
self.
|
|
1394
|
+
self._current_egg_decls |= rule
|
|
1395
|
+
self.__egg_ruleset__.rules.append(rule.decl)
|
|
1390
1396
|
|
|
1391
|
-
def register(
|
|
1397
|
+
def register(
|
|
1398
|
+
self,
|
|
1399
|
+
/,
|
|
1400
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
|
|
1401
|
+
*rules: RewriteOrRule,
|
|
1402
|
+
_increase_frame: bool = False,
|
|
1403
|
+
) -> None:
|
|
1392
1404
|
"""
|
|
1393
1405
|
Register rewrites or rules, either as a function or as values.
|
|
1394
1406
|
"""
|
|
1395
|
-
if isinstance(rule_or_generator,
|
|
1396
|
-
|
|
1397
|
-
|
|
1407
|
+
if isinstance(rule_or_generator, RewriteOrRule):
|
|
1408
|
+
self.append(rule_or_generator)
|
|
1409
|
+
for r in rules:
|
|
1410
|
+
self.append(r)
|
|
1398
1411
|
else:
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
def _cmds(self) -> list[bindings._Command]:
|
|
1409
|
-
cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
|
|
1410
|
-
if self.egg_name:
|
|
1411
|
-
cmds.insert(0, bindings.AddRuleset(self.egg_name))
|
|
1412
|
-
return cmds
|
|
1412
|
+
assert not rules
|
|
1413
|
+
current_frame = inspect.currentframe()
|
|
1414
|
+
assert current_frame
|
|
1415
|
+
original_frame = current_frame.f_back
|
|
1416
|
+
assert original_frame
|
|
1417
|
+
if _increase_frame:
|
|
1418
|
+
original_frame = original_frame.f_back
|
|
1419
|
+
assert original_frame
|
|
1420
|
+
self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
|
|
1413
1421
|
|
|
1414
1422
|
def __str__(self) -> str:
|
|
1415
|
-
return
|
|
1423
|
+
return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
|
|
1416
1424
|
|
|
1417
1425
|
def __repr__(self) -> str:
|
|
1418
|
-
|
|
1419
|
-
return str(self)
|
|
1420
|
-
rules = ", ".join(map(repr, self.rules))
|
|
1421
|
-
return f"ruleset({rules}, name={self.egg_name!r})"
|
|
1422
|
-
|
|
1423
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1424
|
-
return bindings.Run(self._to_egg_config())
|
|
1425
|
-
|
|
1426
|
-
def _to_egg_config(self) -> bindings.RunConfig:
|
|
1427
|
-
return bindings.RunConfig(self.egg_name, None)
|
|
1426
|
+
return str(self)
|
|
1428
1427
|
|
|
1429
|
-
def
|
|
1430
|
-
|
|
1428
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1429
|
+
return unstable_combine_rulesets(self, other)
|
|
1431
1430
|
|
|
1431
|
+
# Create a unique name if we didn't pass one from the user
|
|
1432
1432
|
@property
|
|
1433
|
-
def
|
|
1434
|
-
return self.name or f"
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
class Command(ABC):
|
|
1438
|
-
"""
|
|
1439
|
-
A command that can be executed in the egg interpreter.
|
|
1440
|
-
|
|
1441
|
-
We only use this for commands which return no result and don't create new Python objects.
|
|
1442
|
-
|
|
1443
|
-
Anything that can be passed to the `register` function in a Module is a Command.
|
|
1444
|
-
"""
|
|
1445
|
-
|
|
1446
|
-
ruleset: Ruleset | None
|
|
1447
|
-
|
|
1448
|
-
@property
|
|
1449
|
-
@abstractmethod
|
|
1450
|
-
def __egg_decls__(self) -> Declarations:
|
|
1451
|
-
raise NotImplementedError
|
|
1452
|
-
|
|
1453
|
-
@abstractmethod
|
|
1454
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1455
|
-
raise NotImplementedError
|
|
1456
|
-
|
|
1457
|
-
@abstractmethod
|
|
1458
|
-
def __str__(self) -> str:
|
|
1459
|
-
raise NotImplementedError
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
@dataclass
|
|
1463
|
-
class Rewrite(Command):
|
|
1464
|
-
ruleset: Ruleset | None
|
|
1465
|
-
_lhs: RuntimeExpr
|
|
1466
|
-
_rhs: RuntimeExpr
|
|
1467
|
-
_conditions: tuple[Fact, ...]
|
|
1468
|
-
_subsume: bool
|
|
1469
|
-
_fn_name: ClassVar[str] = "rewrite"
|
|
1470
|
-
|
|
1471
|
-
def __str__(self) -> str:
|
|
1472
|
-
args_str = ", ".join(map(str, [self._rhs, *self._conditions]))
|
|
1473
|
-
return f"{self._fn_name}({self._lhs}).to({args_str})"
|
|
1474
|
-
|
|
1475
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1476
|
-
return bindings.RewriteCommand(
|
|
1477
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
|
|
1478
|
-
)
|
|
1479
|
-
|
|
1480
|
-
def _to_egg_rewrite(self) -> bindings.Rewrite:
|
|
1481
|
-
return bindings.Rewrite(
|
|
1482
|
-
self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
|
|
1483
|
-
self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
|
|
1484
|
-
[c._to_egg_fact() for c in self._conditions],
|
|
1485
|
-
)
|
|
1486
|
-
|
|
1487
|
-
@cached_property
|
|
1488
|
-
def __egg_decls__(self) -> Declarations:
|
|
1489
|
-
return Declarations.create(self._lhs, self._rhs, *self._conditions)
|
|
1490
|
-
|
|
1491
|
-
def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
|
|
1492
|
-
return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)
|
|
1433
|
+
def __egg_name__(self) -> str:
|
|
1434
|
+
return self.name or f"ruleset_{id(self)}"
|
|
1493
1435
|
|
|
1494
1436
|
|
|
1495
1437
|
@dataclass
|
|
1496
|
-
class
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
|
|
1502
|
-
)
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
@dataclass
|
|
1506
|
-
class Fact(ABC):
|
|
1507
|
-
"""
|
|
1508
|
-
A query on an EGraph, either by an expression or an equivalence between multiple expressions.
|
|
1509
|
-
"""
|
|
1510
|
-
|
|
1511
|
-
@abstractmethod
|
|
1512
|
-
def _to_egg_fact(self) -> bindings._Fact:
|
|
1513
|
-
raise NotImplementedError
|
|
1514
|
-
|
|
1515
|
-
@property
|
|
1516
|
-
@abstractmethod
|
|
1517
|
-
def __egg_decls__(self) -> Declarations:
|
|
1518
|
-
raise NotImplementedError
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
@dataclass
|
|
1522
|
-
class Eq(Fact):
|
|
1523
|
-
_exprs: list[RuntimeExpr]
|
|
1524
|
-
|
|
1525
|
-
def __str__(self) -> str:
|
|
1526
|
-
first, *rest = self._exprs
|
|
1527
|
-
args_str = ", ".join(map(str, rest))
|
|
1528
|
-
return f"eq({first}).to({args_str})"
|
|
1529
|
-
|
|
1530
|
-
def _to_egg_fact(self) -> bindings.Eq:
|
|
1531
|
-
return bindings.Eq([e.__egg__ for e in self._exprs])
|
|
1532
|
-
|
|
1533
|
-
@cached_property
|
|
1534
|
-
def __egg_decls__(self) -> Declarations:
|
|
1535
|
-
return Declarations.create(*self._exprs)
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
@dataclass
|
|
1539
|
-
class ExprFact(Fact):
|
|
1540
|
-
_expr: RuntimeExpr
|
|
1541
|
-
|
|
1542
|
-
def __str__(self) -> str:
|
|
1543
|
-
return str(self._expr)
|
|
1544
|
-
|
|
1545
|
-
def _to_egg_fact(self) -> bindings.Fact:
|
|
1546
|
-
return bindings.Fact(self._expr.__egg__)
|
|
1547
|
-
|
|
1548
|
-
@cached_property
|
|
1549
|
-
def __egg_decls__(self) -> Declarations:
|
|
1550
|
-
return self._expr.__egg_decls__
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
@dataclass
|
|
1554
|
-
class Rule(Command):
|
|
1555
|
-
head: tuple[Action, ...]
|
|
1556
|
-
body: tuple[Fact, ...]
|
|
1557
|
-
name: str
|
|
1558
|
-
ruleset: Ruleset | None
|
|
1559
|
-
|
|
1560
|
-
def __str__(self) -> str:
|
|
1561
|
-
head_str = ", ".join(map(str, self.head))
|
|
1562
|
-
body_str = ", ".join(map(str, self.body))
|
|
1563
|
-
return f"rule({body_str}).then({head_str})"
|
|
1564
|
-
|
|
1565
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
|
|
1566
|
-
return bindings.RuleCommand(
|
|
1567
|
-
self.name,
|
|
1568
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1569
|
-
bindings.Rule(
|
|
1570
|
-
[a._to_egg_action() for a in self.head],
|
|
1571
|
-
[f._to_egg_fact() for f in self.body],
|
|
1572
|
-
),
|
|
1573
|
-
)
|
|
1574
|
-
|
|
1575
|
-
@cached_property
|
|
1576
|
-
def __egg_decls__(self) -> Declarations:
|
|
1577
|
-
return Declarations.create(*self.head, *self.body)
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
class Action(Command, ABC):
|
|
1581
|
-
"""
|
|
1582
|
-
A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
|
|
1583
|
-
"""
|
|
1584
|
-
|
|
1585
|
-
@abstractmethod
|
|
1586
|
-
def _to_egg_action(self) -> bindings._Action:
|
|
1587
|
-
raise NotImplementedError
|
|
1438
|
+
class UnstableCombinedRuleset(Schedule):
|
|
1439
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1440
|
+
schedule: RunDecl = field(init=False)
|
|
1441
|
+
name: str | None
|
|
1442
|
+
rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
|
|
1588
1443
|
|
|
1589
|
-
def
|
|
1590
|
-
|
|
1444
|
+
def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
|
|
1445
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1446
|
+
self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
|
|
1591
1447
|
|
|
1592
1448
|
@property
|
|
1593
|
-
def
|
|
1594
|
-
return
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
@dataclass
|
|
1598
|
-
class Let(Action):
|
|
1599
|
-
_name: str
|
|
1600
|
-
_value: RuntimeExpr
|
|
1601
|
-
|
|
1602
|
-
def __str__(self) -> str:
|
|
1603
|
-
return f"let({self._name}, {self._value})"
|
|
1449
|
+
def __egg_name__(self) -> str:
|
|
1450
|
+
return self.name or f"combined_ruleset_{id(self)}"
|
|
1604
1451
|
|
|
1605
|
-
def
|
|
1606
|
-
|
|
1452
|
+
def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
|
|
1453
|
+
decls = Declarations.create(*rulesets)
|
|
1454
|
+
decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
|
|
1455
|
+
return decls
|
|
1607
1456
|
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
return self._value.__egg_decls__
|
|
1457
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1458
|
+
return unstable_combine_rulesets(self, other)
|
|
1611
1459
|
|
|
1612
1460
|
|
|
1613
|
-
|
|
1614
|
-
|
|
1461
|
+
def unstable_combine_rulesets(
|
|
1462
|
+
*rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
|
|
1463
|
+
) -> UnstableCombinedRuleset:
|
|
1615
1464
|
"""
|
|
1616
|
-
|
|
1465
|
+
Combine multiple rulesets into a single ruleset.
|
|
1617
1466
|
"""
|
|
1618
|
-
|
|
1619
|
-
_call: RuntimeExpr
|
|
1620
|
-
_rhs: RuntimeExpr
|
|
1621
|
-
|
|
1622
|
-
def __str__(self) -> str:
|
|
1623
|
-
return f"set({self._call}).to({self._rhs})"
|
|
1624
|
-
|
|
1625
|
-
def _to_egg_action(self) -> bindings.Set:
|
|
1626
|
-
egg_call = self._call.__egg__
|
|
1627
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1628
|
-
raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1629
|
-
return bindings.Set(
|
|
1630
|
-
egg_call.name,
|
|
1631
|
-
egg_call.args,
|
|
1632
|
-
self._rhs.__egg__,
|
|
1633
|
-
)
|
|
1634
|
-
|
|
1635
|
-
@cached_property
|
|
1636
|
-
def __egg_decls__(self) -> Declarations:
|
|
1637
|
-
return Declarations.create(self._call, self._rhs)
|
|
1467
|
+
return UnstableCombinedRuleset(name, list(rulesets))
|
|
1638
1468
|
|
|
1639
1469
|
|
|
1640
1470
|
@dataclass
|
|
1641
|
-
class
|
|
1642
|
-
|
|
1471
|
+
class RewriteOrRule:
|
|
1472
|
+
__egg_decls__: Declarations
|
|
1473
|
+
decl: RewriteOrRuleDecl
|
|
1474
|
+
ruleset: Ruleset | None = None
|
|
1643
1475
|
|
|
1644
1476
|
def __str__(self) -> str:
|
|
1645
|
-
return
|
|
1646
|
-
|
|
1647
|
-
def _to_egg_action(self) -> bindings.Expr_:
|
|
1648
|
-
return bindings.Expr_(self._expr.__egg__)
|
|
1477
|
+
return pretty_decl(self.__egg_decls__, self.decl)
|
|
1649
1478
|
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
return self._expr.__egg_decls__
|
|
1479
|
+
def __repr__(self) -> str:
|
|
1480
|
+
return str(self)
|
|
1653
1481
|
|
|
1654
1482
|
|
|
1655
1483
|
@dataclass
|
|
1656
|
-
class
|
|
1484
|
+
class Fact:
|
|
1657
1485
|
"""
|
|
1658
|
-
|
|
1486
|
+
A query on an EGraph, either by an expression or an equivalence between multiple expressions.
|
|
1659
1487
|
"""
|
|
1660
1488
|
|
|
1661
|
-
|
|
1662
|
-
|
|
1489
|
+
__egg_decls__: Declarations
|
|
1490
|
+
fact: FactDecl
|
|
1663
1491
|
|
|
1664
1492
|
def __str__(self) -> str:
|
|
1665
|
-
return
|
|
1666
|
-
|
|
1667
|
-
def _to_egg_action(self) -> bindings.Change:
|
|
1668
|
-
egg_call = self._call.__egg__
|
|
1669
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1670
|
-
raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1671
|
-
change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
|
|
1672
|
-
return bindings.Change(change, egg_call.name, egg_call.args)
|
|
1493
|
+
return pretty_decl(self.__egg_decls__, self.fact)
|
|
1673
1494
|
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
return self._call.__egg_decls__
|
|
1495
|
+
def __repr__(self) -> str:
|
|
1496
|
+
return str(self)
|
|
1677
1497
|
|
|
1678
1498
|
|
|
1679
1499
|
@dataclass
|
|
1680
|
-
class
|
|
1500
|
+
class Action:
|
|
1681
1501
|
"""
|
|
1682
|
-
|
|
1502
|
+
A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
|
|
1683
1503
|
"""
|
|
1684
1504
|
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
|
|
1688
|
-
def __str__(self) -> str:
|
|
1689
|
-
return f"union({self._lhs}).with_({self._rhs})"
|
|
1690
|
-
|
|
1691
|
-
def _to_egg_action(self) -> bindings.Union:
|
|
1692
|
-
return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
|
|
1693
|
-
|
|
1694
|
-
@cached_property
|
|
1695
|
-
def __egg_decls__(self) -> Declarations:
|
|
1696
|
-
return Declarations.create(self._lhs, self._rhs)
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
@dataclass
|
|
1700
|
-
class Panic(Action):
|
|
1701
|
-
message: str
|
|
1702
|
-
|
|
1703
|
-
def __str__(self) -> str:
|
|
1704
|
-
return f"panic({self.message})"
|
|
1705
|
-
|
|
1706
|
-
def _to_egg_action(self) -> bindings.Panic:
|
|
1707
|
-
return bindings.Panic(self.message)
|
|
1708
|
-
|
|
1709
|
-
@cached_property
|
|
1710
|
-
def __egg_decls__(self) -> Declarations:
|
|
1711
|
-
return Declarations()
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
@dataclass
|
|
1715
|
-
class Run(Schedule):
|
|
1716
|
-
"""Configuration of a run"""
|
|
1717
|
-
|
|
1718
|
-
# None if using default ruleset
|
|
1719
|
-
ruleset: Ruleset | None
|
|
1720
|
-
until: tuple[Fact, ...]
|
|
1505
|
+
__egg_decls__: Declarations
|
|
1506
|
+
action: ActionDecl
|
|
1721
1507
|
|
|
1722
1508
|
def __str__(self) -> str:
|
|
1723
|
-
|
|
1724
|
-
return f"run({args_str})"
|
|
1725
|
-
|
|
1726
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1727
|
-
return bindings.Run(self._to_egg_config(default_ruleset_name))
|
|
1509
|
+
return pretty_decl(self.__egg_decls__, self.action)
|
|
1728
1510
|
|
|
1729
|
-
def
|
|
1730
|
-
return
|
|
1731
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1732
|
-
[fact._to_egg_fact() for fact in self.until] if self.until else None,
|
|
1733
|
-
)
|
|
1734
|
-
|
|
1735
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1736
|
-
if self.ruleset:
|
|
1737
|
-
yield self.ruleset
|
|
1738
|
-
|
|
1739
|
-
@cached_property
|
|
1740
|
-
def __egg_decls__(self) -> Declarations:
|
|
1741
|
-
return Declarations.create(self.ruleset, *self.until)
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
@dataclass
|
|
1745
|
-
class Saturate(Schedule):
|
|
1746
|
-
schedule: Schedule
|
|
1747
|
-
|
|
1748
|
-
def __str__(self) -> str:
|
|
1749
|
-
return f"{self.schedule}.saturate()"
|
|
1750
|
-
|
|
1751
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1752
|
-
return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1753
|
-
|
|
1754
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1755
|
-
return self.schedule._rulesets()
|
|
1756
|
-
|
|
1757
|
-
@property
|
|
1758
|
-
def __egg_decls__(self) -> Declarations:
|
|
1759
|
-
return self.schedule.__egg_decls__
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
@dataclass
|
|
1763
|
-
class Repeat(Schedule):
|
|
1764
|
-
length: int
|
|
1765
|
-
schedule: Schedule
|
|
1766
|
-
|
|
1767
|
-
def __str__(self) -> str:
|
|
1768
|
-
return f"{self.schedule} * {self.length}"
|
|
1769
|
-
|
|
1770
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1771
|
-
return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1772
|
-
|
|
1773
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1774
|
-
return self.schedule._rulesets()
|
|
1775
|
-
|
|
1776
|
-
@property
|
|
1777
|
-
def __egg_decls__(self) -> Declarations:
|
|
1778
|
-
return self.schedule.__egg_decls__
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
@dataclass
|
|
1782
|
-
class Sequence(Schedule):
|
|
1783
|
-
schedules: tuple[Schedule, ...]
|
|
1784
|
-
|
|
1785
|
-
def __str__(self) -> str:
|
|
1786
|
-
return f"sequence({', '.join(map(str, self.schedules))})"
|
|
1787
|
-
|
|
1788
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1789
|
-
return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
|
|
1790
|
-
|
|
1791
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1792
|
-
for s in self.schedules:
|
|
1793
|
-
yield from s._rulesets()
|
|
1794
|
-
|
|
1795
|
-
@cached_property
|
|
1796
|
-
def __egg_decls__(self) -> Declarations:
|
|
1797
|
-
return Declarations.create(*self.schedules)
|
|
1511
|
+
def __repr__(self) -> str:
|
|
1512
|
+
return str(self)
|
|
1798
1513
|
|
|
1799
1514
|
|
|
1800
1515
|
# We use these builders so that when creating these structures we can type check
|
|
@@ -1841,30 +1556,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
|
|
|
1841
1556
|
|
|
1842
1557
|
def panic(message: str) -> Action:
|
|
1843
1558
|
"""Raise an error with the given message."""
|
|
1844
|
-
return
|
|
1559
|
+
return Action(Declarations(), PanicDecl(message))
|
|
1845
1560
|
|
|
1846
1561
|
|
|
1847
1562
|
def let(name: str, expr: Expr) -> Action:
|
|
1848
1563
|
"""Create a let binding."""
|
|
1849
|
-
|
|
1564
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1565
|
+
return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
|
|
1850
1566
|
|
|
1851
1567
|
|
|
1852
1568
|
def expr_action(expr: Expr) -> Action:
|
|
1853
|
-
|
|
1569
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1570
|
+
return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
|
|
1854
1571
|
|
|
1855
1572
|
|
|
1856
1573
|
def delete(expr: Expr) -> Action:
|
|
1857
1574
|
"""Create a delete expression."""
|
|
1858
|
-
|
|
1575
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1576
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1577
|
+
call_decl = typed_expr.expr
|
|
1578
|
+
assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
|
|
1579
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
|
|
1859
1580
|
|
|
1860
1581
|
|
|
1861
1582
|
def subsume(expr: Expr) -> Action:
|
|
1862
|
-
"""Subsume an expression
|
|
1863
|
-
|
|
1583
|
+
"""Subsume an expression so it cannot be matched against or extracted"""
|
|
1584
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1585
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1586
|
+
call_decl = typed_expr.expr
|
|
1587
|
+
assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
|
|
1588
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
|
|
1864
1589
|
|
|
1865
1590
|
|
|
1866
1591
|
def expr_fact(expr: Expr) -> Fact:
|
|
1867
|
-
|
|
1592
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1593
|
+
return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
|
|
1868
1594
|
|
|
1869
1595
|
|
|
1870
1596
|
def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
|
|
@@ -1891,6 +1617,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
|
|
|
1891
1617
|
return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
|
|
1892
1618
|
|
|
1893
1619
|
|
|
1620
|
+
@deprecated("This function is now a no-op, you can remove it and use actions as commands")
|
|
1621
|
+
def action_command(action: Action) -> Action:
|
|
1622
|
+
return action
|
|
1623
|
+
|
|
1624
|
+
|
|
1894
1625
|
def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
1895
1626
|
"""Create a new variable with the given name and type."""
|
|
1896
1627
|
return cast(EXPR, _var(name, bound))
|
|
@@ -1898,9 +1629,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
|
1898
1629
|
|
|
1899
1630
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1900
1631
|
"""Create a new variable with the given name and type."""
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
return RuntimeExpr(
|
|
1632
|
+
decls = Declarations()
|
|
1633
|
+
type_ref = resolve_type_annotation(decls, bound)
|
|
1634
|
+
return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
|
|
1904
1635
|
|
|
1905
1636
|
|
|
1906
1637
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1915,15 +1646,27 @@ class _RewriteBuilder(Generic[EXPR]):
|
|
|
1915
1646
|
ruleset: Ruleset | None
|
|
1916
1647
|
subsume: bool
|
|
1917
1648
|
|
|
1918
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1649
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1919
1650
|
lhs = to_runtime_expr(self.lhs)
|
|
1920
|
-
|
|
1651
|
+
facts = _fact_likes(conditions)
|
|
1652
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1653
|
+
rule = RewriteOrRule(
|
|
1654
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1655
|
+
RewriteDecl(
|
|
1656
|
+
lhs.__egg_typed_expr__.tp,
|
|
1657
|
+
lhs.__egg_typed_expr__.expr,
|
|
1658
|
+
rhs.__egg_typed_expr__.expr,
|
|
1659
|
+
tuple(f.fact for f in facts),
|
|
1660
|
+
self.subsume,
|
|
1661
|
+
),
|
|
1662
|
+
)
|
|
1921
1663
|
if self.ruleset:
|
|
1922
1664
|
self.ruleset.append(rule)
|
|
1923
1665
|
return rule
|
|
1924
1666
|
|
|
1925
1667
|
def __str__(self) -> str:
|
|
1926
|
-
|
|
1668
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1669
|
+
return lhs.__egg_pretty__("rewrite")
|
|
1927
1670
|
|
|
1928
1671
|
|
|
1929
1672
|
@dataclass
|
|
@@ -1931,15 +1674,26 @@ class _BirewriteBuilder(Generic[EXPR]):
|
|
|
1931
1674
|
lhs: EXPR
|
|
1932
1675
|
ruleset: Ruleset | None
|
|
1933
1676
|
|
|
1934
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1677
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1935
1678
|
lhs = to_runtime_expr(self.lhs)
|
|
1936
|
-
|
|
1679
|
+
facts = _fact_likes(conditions)
|
|
1680
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1681
|
+
rule = RewriteOrRule(
|
|
1682
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1683
|
+
BiRewriteDecl(
|
|
1684
|
+
lhs.__egg_typed_expr__.tp,
|
|
1685
|
+
lhs.__egg_typed_expr__.expr,
|
|
1686
|
+
rhs.__egg_typed_expr__.expr,
|
|
1687
|
+
tuple(f.fact for f in facts),
|
|
1688
|
+
),
|
|
1689
|
+
)
|
|
1937
1690
|
if self.ruleset:
|
|
1938
1691
|
self.ruleset.append(rule)
|
|
1939
1692
|
return rule
|
|
1940
1693
|
|
|
1941
1694
|
def __str__(self) -> str:
|
|
1942
|
-
|
|
1695
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1696
|
+
return lhs.__egg_pretty__("birewrite")
|
|
1943
1697
|
|
|
1944
1698
|
|
|
1945
1699
|
@dataclass
|
|
@@ -1948,52 +1702,84 @@ class _EqBuilder(Generic[EXPR]):
|
|
|
1948
1702
|
|
|
1949
1703
|
def to(self, *exprs: EXPR) -> Fact:
|
|
1950
1704
|
expr = to_runtime_expr(self.expr)
|
|
1951
|
-
|
|
1705
|
+
args = [expr, *(convert_to_same_type(e, expr) for e in exprs)]
|
|
1706
|
+
return Fact(
|
|
1707
|
+
Declarations.create(*args),
|
|
1708
|
+
EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)),
|
|
1709
|
+
)
|
|
1710
|
+
|
|
1711
|
+
def __repr__(self) -> str:
|
|
1712
|
+
return str(self)
|
|
1952
1713
|
|
|
1953
1714
|
def __str__(self) -> str:
|
|
1954
|
-
|
|
1715
|
+
expr = to_runtime_expr(self.expr)
|
|
1716
|
+
return expr.__egg_pretty__("eq")
|
|
1955
1717
|
|
|
1956
1718
|
|
|
1957
1719
|
@dataclass
|
|
1958
1720
|
class _NeBuilder(Generic[EXPR]):
|
|
1959
|
-
|
|
1721
|
+
lhs: EXPR
|
|
1960
1722
|
|
|
1961
|
-
def to(self,
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
res = RuntimeExpr(
|
|
1966
|
-
|
|
1967
|
-
TypedExprDecl(
|
|
1723
|
+
def to(self, rhs: EXPR) -> Unit:
|
|
1724
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1725
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1726
|
+
assert isinstance(Unit, RuntimeClass)
|
|
1727
|
+
res = RuntimeExpr.__from_value__(
|
|
1728
|
+
Declarations.create(Unit, lhs, rhs),
|
|
1729
|
+
TypedExprDecl(
|
|
1730
|
+
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
1731
|
+
),
|
|
1968
1732
|
)
|
|
1969
1733
|
return cast(Unit, res)
|
|
1970
1734
|
|
|
1735
|
+
def __repr__(self) -> str:
|
|
1736
|
+
return str(self)
|
|
1737
|
+
|
|
1971
1738
|
def __str__(self) -> str:
|
|
1972
|
-
|
|
1739
|
+
expr = to_runtime_expr(self.lhs)
|
|
1740
|
+
return expr.__egg_pretty__("ne")
|
|
1973
1741
|
|
|
1974
1742
|
|
|
1975
1743
|
@dataclass
|
|
1976
1744
|
class _SetBuilder(Generic[EXPR]):
|
|
1977
|
-
lhs:
|
|
1745
|
+
lhs: EXPR
|
|
1978
1746
|
|
|
1979
|
-
def to(self, rhs: EXPR) ->
|
|
1747
|
+
def to(self, rhs: EXPR) -> Action:
|
|
1980
1748
|
lhs = to_runtime_expr(self.lhs)
|
|
1981
|
-
|
|
1749
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1750
|
+
lhs_expr = lhs.__egg_typed_expr__.expr
|
|
1751
|
+
assert isinstance(lhs_expr, CallDecl), "Can only set function calls"
|
|
1752
|
+
return Action(
|
|
1753
|
+
Declarations.create(lhs, rhs),
|
|
1754
|
+
SetDecl(lhs.__egg_typed_expr__.tp, lhs_expr, rhs.__egg_typed_expr__.expr),
|
|
1755
|
+
)
|
|
1756
|
+
|
|
1757
|
+
def __repr__(self) -> str:
|
|
1758
|
+
return str(self)
|
|
1982
1759
|
|
|
1983
1760
|
def __str__(self) -> str:
|
|
1984
|
-
|
|
1761
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1762
|
+
return lhs.__egg_pretty__("set_")
|
|
1985
1763
|
|
|
1986
1764
|
|
|
1987
1765
|
@dataclass
|
|
1988
1766
|
class _UnionBuilder(Generic[EXPR]):
|
|
1989
|
-
lhs:
|
|
1767
|
+
lhs: EXPR
|
|
1990
1768
|
|
|
1991
1769
|
def with_(self, rhs: EXPR) -> Action:
|
|
1992
1770
|
lhs = to_runtime_expr(self.lhs)
|
|
1993
|
-
|
|
1771
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1772
|
+
return Action(
|
|
1773
|
+
Declarations.create(lhs, rhs),
|
|
1774
|
+
UnionDecl(lhs.__egg_typed_expr__.tp, lhs.__egg_typed_expr__.expr, rhs.__egg_typed_expr__.expr),
|
|
1775
|
+
)
|
|
1776
|
+
|
|
1777
|
+
def __repr__(self) -> str:
|
|
1778
|
+
return str(self)
|
|
1994
1779
|
|
|
1995
1780
|
def __str__(self) -> str:
|
|
1996
|
-
|
|
1781
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1782
|
+
return lhs.__egg_pretty__("union")
|
|
1997
1783
|
|
|
1998
1784
|
|
|
1999
1785
|
@dataclass
|
|
@@ -2002,12 +1788,25 @@ class _RuleBuilder:
|
|
|
2002
1788
|
name: str | None
|
|
2003
1789
|
ruleset: Ruleset | None
|
|
2004
1790
|
|
|
2005
|
-
def then(self, *actions: ActionLike) ->
|
|
2006
|
-
|
|
1791
|
+
def then(self, *actions: ActionLike) -> RewriteOrRule:
|
|
1792
|
+
actions = _action_likes(actions)
|
|
1793
|
+
rule = RewriteOrRule(
|
|
1794
|
+
Declarations.create(self.ruleset, *actions, *self.facts),
|
|
1795
|
+
RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name),
|
|
1796
|
+
)
|
|
2007
1797
|
if self.ruleset:
|
|
2008
1798
|
self.ruleset.append(rule)
|
|
2009
1799
|
return rule
|
|
2010
1800
|
|
|
1801
|
+
def __str__(self) -> str:
|
|
1802
|
+
# TODO: Figure out how to stringify rulebuilder that preserves statements
|
|
1803
|
+
args = list(map(str, self.facts))
|
|
1804
|
+
if self.name is not None:
|
|
1805
|
+
args.append(f"name={self.name}")
|
|
1806
|
+
if ruleset is not None:
|
|
1807
|
+
args.append(f"ruleset={self.ruleset}")
|
|
1808
|
+
return f"rule({', '.join(args)})"
|
|
1809
|
+
|
|
2011
1810
|
|
|
2012
1811
|
def expr_parts(expr: Expr) -> TypedExprDecl:
|
|
2013
1812
|
"""
|
|
@@ -2024,60 +1823,63 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
|
|
|
2024
1823
|
return expr
|
|
2025
1824
|
|
|
2026
1825
|
|
|
2027
|
-
def run(ruleset: Ruleset | None = None, *until:
|
|
1826
|
+
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
|
|
2028
1827
|
"""
|
|
2029
1828
|
Create a run configuration.
|
|
2030
1829
|
"""
|
|
2031
|
-
|
|
1830
|
+
facts = _fact_likes(until)
|
|
1831
|
+
return Schedule(
|
|
1832
|
+
Thunk.fn(Declarations.create, ruleset, *facts),
|
|
1833
|
+
RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
|
|
1834
|
+
)
|
|
2032
1835
|
|
|
2033
1836
|
|
|
2034
1837
|
def seq(*schedules: Schedule) -> Schedule:
|
|
2035
1838
|
"""
|
|
2036
1839
|
Run a sequence of schedules.
|
|
2037
1840
|
"""
|
|
2038
|
-
return
|
|
1841
|
+
return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
|
|
2039
1842
|
|
|
2040
1843
|
|
|
2041
|
-
|
|
1844
|
+
ActionLike: TypeAlias = Action | Expr
|
|
2042
1845
|
|
|
2043
1846
|
|
|
2044
|
-
def
|
|
2045
|
-
|
|
2046
|
-
return expr_action(command_like)
|
|
2047
|
-
return command_like
|
|
1847
|
+
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
1848
|
+
return tuple(map(_action_like, action_likes))
|
|
2048
1849
|
|
|
2049
1850
|
|
|
2050
|
-
|
|
1851
|
+
def _action_like(action_like: ActionLike) -> Action:
|
|
1852
|
+
if isinstance(action_like, Expr):
|
|
1853
|
+
return expr_action(action_like)
|
|
1854
|
+
return action_like
|
|
2051
1855
|
|
|
2052
1856
|
|
|
2053
|
-
|
|
2054
|
-
"""
|
|
2055
|
-
Calls the function with variables of the type and name of the arguments.
|
|
2056
|
-
"""
|
|
2057
|
-
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
2058
|
-
# but not in the globals
|
|
2059
|
-
current_frame = inspect.currentframe()
|
|
2060
|
-
assert current_frame
|
|
2061
|
-
register_frame = current_frame.f_back
|
|
2062
|
-
assert register_frame
|
|
2063
|
-
original_frame = register_frame.f_back
|
|
2064
|
-
assert original_frame
|
|
2065
|
-
hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
|
|
2066
|
-
args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
|
|
2067
|
-
return gen(*args)
|
|
1857
|
+
Command: TypeAlias = Action | RewriteOrRule
|
|
2068
1858
|
|
|
1859
|
+
CommandLike: TypeAlias = ActionLike | RewriteOrRule
|
|
2069
1860
|
|
|
2070
|
-
ActionLike = Action | Expr
|
|
2071
1861
|
|
|
1862
|
+
def _command_like(command_like: CommandLike) -> Command:
|
|
1863
|
+
if isinstance(command_like, RewriteOrRule):
|
|
1864
|
+
return command_like
|
|
1865
|
+
return _action_like(command_like)
|
|
2072
1866
|
|
|
2073
|
-
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
2074
|
-
return tuple(map(_action_like, action_likes))
|
|
2075
1867
|
|
|
1868
|
+
RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]
|
|
2076
1869
|
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
1870
|
+
|
|
1871
|
+
def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
|
|
1872
|
+
"""
|
|
1873
|
+
Returns a thunk which will call the function with variables of the type and name of the arguments.
|
|
1874
|
+
"""
|
|
1875
|
+
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
1876
|
+
# but not in the globals
|
|
1877
|
+
globals = gen.__globals__.copy()
|
|
1878
|
+
if "Callable" not in globals:
|
|
1879
|
+
globals["Callable"] = Callable
|
|
1880
|
+
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1881
|
+
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1882
|
+
return list(gen(*args)) # type: ignore[misc]
|
|
2081
1883
|
|
|
2082
1884
|
|
|
2083
1885
|
FactLike = Fact | Expr
|