egglog 6.0.1__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +24 -4
- egglog/builtins.py +1 -1
- egglog/conversion.py +172 -0
- egglog/declarations.py +329 -735
- egglog/egraph.py +539 -803
- egglog/egraph_state.py +417 -0
- egglog/exp/array_api.py +96 -84
- egglog/exp/array_api_numba.py +13 -11
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +418 -0
- egglog/runtime.py +196 -430
- egglog/thunk.py +72 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.0.1.dist-info → egglog-7.0.0.dist-info}/METADATA +14 -14
- {egglog-6.0.1.dist-info → egglog-7.0.0.dist-info}/RECORD +19 -14
- {egglog-6.0.1.dist-info → egglog-7.0.0.dist-info}/WHEEL +0 -0
- {egglog-6.0.1.dist-info → egglog-7.0.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -3,12 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import pathlib
|
|
5
5
|
import tempfile
|
|
6
|
-
from abc import
|
|
6
|
+
from abc import abstractmethod
|
|
7
7
|
from collections.abc import Callable, Iterable
|
|
8
8
|
from contextvars import ContextVar, Token
|
|
9
|
-
from copy import deepcopy
|
|
10
9
|
from dataclasses import InitVar, dataclass, field
|
|
11
|
-
from functools import cached_property
|
|
12
10
|
from inspect import Parameter, currentframe, signature
|
|
13
11
|
from types import FrameType, FunctionType
|
|
14
12
|
from typing import (
|
|
@@ -18,6 +16,7 @@ from typing import (
|
|
|
18
16
|
Generic,
|
|
19
17
|
Literal,
|
|
20
18
|
NoReturn,
|
|
19
|
+
TypeAlias,
|
|
21
20
|
TypedDict,
|
|
22
21
|
TypeVar,
|
|
23
22
|
cast,
|
|
@@ -26,14 +25,16 @@ from typing import (
|
|
|
26
25
|
)
|
|
27
26
|
|
|
28
27
|
import graphviz
|
|
29
|
-
from typing_extensions import ParamSpec, Self, Unpack, deprecated
|
|
30
|
-
|
|
31
|
-
from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
|
|
28
|
+
from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
|
|
32
29
|
|
|
33
30
|
from . import bindings
|
|
31
|
+
from .conversion import *
|
|
34
32
|
from .declarations import *
|
|
33
|
+
from .egraph_state import *
|
|
35
34
|
from .ipython_magic import IN_IPYTHON
|
|
35
|
+
from .pretty import pretty_decl
|
|
36
36
|
from .runtime import *
|
|
37
|
+
from .thunk import *
|
|
37
38
|
|
|
38
39
|
if TYPE_CHECKING:
|
|
39
40
|
import ipywidgets
|
|
@@ -58,6 +59,7 @@ __all__ = [
|
|
|
58
59
|
"let",
|
|
59
60
|
"constant",
|
|
60
61
|
"delete",
|
|
62
|
+
"subsume",
|
|
61
63
|
"union",
|
|
62
64
|
"set_",
|
|
63
65
|
"rule",
|
|
@@ -65,6 +67,9 @@ __all__ = [
|
|
|
65
67
|
"vars_",
|
|
66
68
|
"Fact",
|
|
67
69
|
"expr_parts",
|
|
70
|
+
"expr_action",
|
|
71
|
+
"expr_fact",
|
|
72
|
+
"action_command",
|
|
68
73
|
"Schedule",
|
|
69
74
|
"run",
|
|
70
75
|
"seq",
|
|
@@ -79,11 +84,10 @@ __all__ = [
|
|
|
79
84
|
"_NeBuilder",
|
|
80
85
|
"_SetBuilder",
|
|
81
86
|
"_UnionBuilder",
|
|
82
|
-
"
|
|
83
|
-
"
|
|
84
|
-
"BiRewrite",
|
|
85
|
-
"Union_",
|
|
87
|
+
"RewriteOrRule",
|
|
88
|
+
"Fact",
|
|
86
89
|
"Action",
|
|
90
|
+
"Command",
|
|
87
91
|
]
|
|
88
92
|
|
|
89
93
|
T = TypeVar("T")
|
|
@@ -112,7 +116,24 @@ IGNORED_ATTRIBUTES = {
|
|
|
112
116
|
}
|
|
113
117
|
|
|
114
118
|
|
|
119
|
+
# special methods that return none and mutate self
|
|
115
120
|
ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
|
|
121
|
+
# special methods which must return real python values instead of lazy expressions
|
|
122
|
+
ALWAYS_PRESERVED = {
|
|
123
|
+
"__repr__",
|
|
124
|
+
"__str__",
|
|
125
|
+
"__bytes__",
|
|
126
|
+
"__format__",
|
|
127
|
+
"__hash__",
|
|
128
|
+
"__bool__",
|
|
129
|
+
"__len__",
|
|
130
|
+
"__length_hint__",
|
|
131
|
+
"__iter__",
|
|
132
|
+
"__reversed__",
|
|
133
|
+
"__contains__",
|
|
134
|
+
"__index__",
|
|
135
|
+
"__bufer__",
|
|
136
|
+
}
|
|
116
137
|
|
|
117
138
|
|
|
118
139
|
def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
@@ -124,7 +145,7 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
|
124
145
|
return EGraph().extract(x)
|
|
125
146
|
|
|
126
147
|
|
|
127
|
-
def check(x: FactLike, schedule: Schedule | None = None, *given:
|
|
148
|
+
def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
|
|
128
149
|
"""
|
|
129
150
|
Verifies that the fact is true given some assumptions and after running the schedule.
|
|
130
151
|
"""
|
|
@@ -136,9 +157,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
|
|
|
136
157
|
egraph.check(x)
|
|
137
158
|
|
|
138
159
|
|
|
139
|
-
# def extract(res: )
|
|
140
|
-
|
|
141
|
-
|
|
142
160
|
@dataclass
|
|
143
161
|
class _BaseModule:
|
|
144
162
|
"""
|
|
@@ -181,15 +199,8 @@ class _BaseModule:
|
|
|
181
199
|
Registers a class.
|
|
182
200
|
"""
|
|
183
201
|
if kwargs:
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
|
|
187
|
-
assert isinstance(cls, RuntimeClass)
|
|
188
|
-
assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
|
|
189
|
-
cls.lazy_decls.egg_sort = egg_sort
|
|
190
|
-
return cls
|
|
191
|
-
|
|
192
|
-
return _inner
|
|
202
|
+
msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
|
|
203
|
+
raise NotImplementedError(msg)
|
|
193
204
|
|
|
194
205
|
assert len(args) == 1
|
|
195
206
|
return args[0]
|
|
@@ -280,9 +291,9 @@ class _BaseModule:
|
|
|
280
291
|
# If we have any positional args, then we are calling it directly on a function
|
|
281
292
|
if args:
|
|
282
293
|
assert len(args) == 1
|
|
283
|
-
return
|
|
294
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
284
295
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
285
|
-
return
|
|
296
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
286
297
|
|
|
287
298
|
@deprecated("Use top level `ruleset` function instead")
|
|
288
299
|
def ruleset(self, name: str) -> Ruleset:
|
|
@@ -324,17 +335,26 @@ class _BaseModule:
|
|
|
324
335
|
"""
|
|
325
336
|
return constant(name, tp, egg_name)
|
|
326
337
|
|
|
327
|
-
def register(
|
|
338
|
+
def register(
|
|
339
|
+
self,
|
|
340
|
+
/,
|
|
341
|
+
command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
|
|
342
|
+
*command_likes: ActionLike | RewriteOrRule,
|
|
343
|
+
) -> None:
|
|
328
344
|
"""
|
|
329
345
|
Registers any number of rewrites or rules.
|
|
330
346
|
"""
|
|
331
347
|
if isinstance(command_or_generator, FunctionType):
|
|
332
348
|
assert not command_likes
|
|
333
|
-
|
|
349
|
+
current_frame = inspect.currentframe()
|
|
350
|
+
assert current_frame
|
|
351
|
+
original_frame = current_frame.f_back
|
|
352
|
+
assert original_frame
|
|
353
|
+
command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
|
|
334
354
|
else:
|
|
335
355
|
command_likes = (cast(CommandLike, command_or_generator), *command_likes)
|
|
336
|
-
|
|
337
|
-
self._register_commands(
|
|
356
|
+
commands = [_command_like(c) for c in command_likes]
|
|
357
|
+
self._register_commands(commands)
|
|
338
358
|
|
|
339
359
|
@abstractmethod
|
|
340
360
|
def _register_commands(self, cmds: list[Command]) -> None:
|
|
@@ -417,136 +437,116 @@ class _ExprMetaclass(type):
|
|
|
417
437
|
# If this is the Expr subclass, just return the class
|
|
418
438
|
if not bases:
|
|
419
439
|
return super().__new__(cls, name, bases, namespace)
|
|
440
|
+
# TODO: Raise error on subclassing or multiple inheritence
|
|
420
441
|
|
|
421
442
|
frame = currentframe()
|
|
422
443
|
assert frame
|
|
423
444
|
prev_frame = frame.f_back
|
|
424
445
|
assert prev_frame
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
cls_name=name,
|
|
434
|
-
).current_cls
|
|
446
|
+
|
|
447
|
+
# Store frame so that we can get live access to updated locals/globals
|
|
448
|
+
# Otherwise, f_locals returns a copy
|
|
449
|
+
# https://peps.python.org/pep-0667/
|
|
450
|
+
decls_thunk = Thunk.fn(
|
|
451
|
+
_generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
|
|
452
|
+
)
|
|
453
|
+
return RuntimeClass(decls_thunk, TypeRefWithVars(name))
|
|
435
454
|
|
|
436
455
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
437
456
|
return isinstance(instance, RuntimeExpr)
|
|
438
457
|
|
|
439
458
|
|
|
440
|
-
|
|
441
|
-
|
|
459
|
+
def _generate_class_decls(
|
|
460
|
+
namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
|
|
461
|
+
) -> Declarations:
|
|
442
462
|
"""
|
|
443
463
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
444
464
|
"""
|
|
465
|
+
parameters: list[TypeVar] = (
|
|
466
|
+
# Get the generic params from the orig bases generic class
|
|
467
|
+
namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
|
|
468
|
+
)
|
|
469
|
+
type_vars = tuple(p.__name__ for p in parameters)
|
|
470
|
+
del parameters
|
|
471
|
+
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
472
|
+
decls = Declarations(_classes={cls_name: cls_decl})
|
|
473
|
+
|
|
474
|
+
##
|
|
475
|
+
# Register class variables
|
|
476
|
+
##
|
|
477
|
+
# Create a dummy type to pass to get_type_hints to resolve the annotations we have
|
|
478
|
+
_Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
|
|
479
|
+
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
480
|
+
if getattr(v, "__origin__", None) == ClassVar:
|
|
481
|
+
(inner_tp,) = v.__args__
|
|
482
|
+
type_ref = resolve_type_annotation(decls, inner_tp).to_just()
|
|
483
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
445
484
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
egg_sort: str | None
|
|
450
|
-
cls_name: str
|
|
451
|
-
current_cls: RuntimeClass = field(init=False)
|
|
485
|
+
else:
|
|
486
|
+
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
487
|
+
raise NotImplementedError(msg)
|
|
452
488
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
_Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
|
|
477
|
-
# Make lazy update to locals, so we keep a live handle on them after class creation
|
|
478
|
-
locals = self.frame.f_locals.copy()
|
|
479
|
-
locals[self.cls_name] = self.current_cls
|
|
480
|
-
for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
|
|
481
|
-
if getattr(v, "__origin__", None) == ClassVar:
|
|
482
|
-
(inner_tp,) = v.__args__
|
|
483
|
-
_register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
|
|
484
|
-
else:
|
|
485
|
-
msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
486
|
-
raise NotImplementedError(msg)
|
|
487
|
-
|
|
488
|
-
# Then register each of its methods
|
|
489
|
-
for method_name, method in cls_dict.items():
|
|
490
|
-
is_init = method_name == "__init__"
|
|
491
|
-
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
492
|
-
if is_init and self.cls_name in LIT_CLASS_NAMES:
|
|
493
|
-
continue
|
|
494
|
-
if isinstance(method, _WrappedMethod):
|
|
495
|
-
fn = method.fn
|
|
496
|
-
egg_fn = method.egg_fn
|
|
497
|
-
cost = method.cost
|
|
498
|
-
default = method.default
|
|
499
|
-
merge = method.merge
|
|
500
|
-
on_merge = method.on_merge
|
|
501
|
-
mutates_first_arg = method.mutates_self
|
|
502
|
-
unextractable = method.unextractable
|
|
503
|
-
if method.preserve:
|
|
504
|
-
decls.register_preserved_method(self.cls_name, method_name, fn)
|
|
505
|
-
continue
|
|
506
|
-
else:
|
|
507
|
-
fn = method
|
|
489
|
+
##
|
|
490
|
+
# Register methods, classmethods, preserved methods, and properties
|
|
491
|
+
##
|
|
492
|
+
|
|
493
|
+
# The type ref of self is paramterized by the type vars
|
|
494
|
+
slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
495
|
+
|
|
496
|
+
# Get all the methods from the class
|
|
497
|
+
filtered_namespace: list[tuple[str, Any]] = [
|
|
498
|
+
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
499
|
+
]
|
|
500
|
+
|
|
501
|
+
# Then register each of its methods
|
|
502
|
+
for method_name, method in filtered_namespace:
|
|
503
|
+
is_init = method_name == "__init__"
|
|
504
|
+
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
505
|
+
if is_init and cls_name in LIT_CLASS_NAMES:
|
|
506
|
+
continue
|
|
507
|
+
match method:
|
|
508
|
+
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
|
|
509
|
+
pass
|
|
510
|
+
case _:
|
|
508
511
|
egg_fn, cost, default, merge, on_merge = None, None, None, None, None
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
fn = fn.fget
|
|
520
|
-
is_property = True
|
|
521
|
-
if is_classmethod:
|
|
522
|
-
msg = "Can't have a classmethod property"
|
|
523
|
-
raise NotImplementedError(msg)
|
|
524
|
-
else:
|
|
525
|
-
is_property = False
|
|
526
|
-
ref: FunctionCallableRef = (
|
|
527
|
-
ClassMethodRef(self.cls_name, method_name)
|
|
528
|
-
if is_classmethod
|
|
529
|
-
else PropertyRef(self.cls_name, method_name)
|
|
530
|
-
if is_property
|
|
531
|
-
else MethodRef(self.cls_name, method_name)
|
|
532
|
-
)
|
|
533
|
-
_register_function(
|
|
512
|
+
fn = method
|
|
513
|
+
unextractable, preserve = False, False
|
|
514
|
+
mutates = method_name in ALWAYS_MUTATES_SELF
|
|
515
|
+
if preserve:
|
|
516
|
+
cls_decl.preserved_methods[method_name] = fn
|
|
517
|
+
continue
|
|
518
|
+
locals = frame.f_locals
|
|
519
|
+
|
|
520
|
+
def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
|
|
521
|
+
return _fn_decl(
|
|
534
522
|
decls,
|
|
535
|
-
|
|
536
|
-
egg_fn,
|
|
523
|
+
egg_fn, # noqa: B023
|
|
537
524
|
fn,
|
|
538
|
-
locals,
|
|
539
|
-
default,
|
|
540
|
-
cost,
|
|
541
|
-
merge,
|
|
542
|
-
on_merge,
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
is_init,
|
|
547
|
-
unextractable
|
|
525
|
+
locals, # noqa: B023
|
|
526
|
+
default, # noqa: B023
|
|
527
|
+
cost, # noqa: B023
|
|
528
|
+
merge, # noqa: B023
|
|
529
|
+
on_merge, # noqa: B023
|
|
530
|
+
mutates, # noqa: B023
|
|
531
|
+
builtin,
|
|
532
|
+
first,
|
|
533
|
+
is_init, # noqa: B023
|
|
534
|
+
unextractable, # noqa: B023
|
|
548
535
|
)
|
|
549
536
|
|
|
537
|
+
match fn:
|
|
538
|
+
case classmethod():
|
|
539
|
+
cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
|
|
540
|
+
case property():
|
|
541
|
+
cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
|
|
542
|
+
case _:
|
|
543
|
+
if is_init:
|
|
544
|
+
cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
|
|
545
|
+
else:
|
|
546
|
+
cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
|
|
547
|
+
|
|
548
|
+
return decls
|
|
549
|
+
|
|
550
550
|
|
|
551
551
|
@overload
|
|
552
552
|
def function(fn: CALLABLE, /) -> CALLABLE: ...
|
|
@@ -589,48 +589,46 @@ def function(*args, **kwargs) -> Any:
|
|
|
589
589
|
# If we have any positional args, then we are calling it directly on a function
|
|
590
590
|
if args:
|
|
591
591
|
assert len(args) == 1
|
|
592
|
-
return
|
|
592
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
593
593
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
594
|
-
return
|
|
594
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
595
595
|
|
|
596
596
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
hint_locals: dict[str, Any]
|
|
600
|
-
builtin: bool = False
|
|
601
|
-
mutates_first_arg: bool = False
|
|
602
|
-
egg_fn: str | None = None
|
|
603
|
-
cost: int | None = None
|
|
604
|
-
default: RuntimeExpr | None = None
|
|
605
|
-
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
|
-
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
|
-
unextractable: bool = False
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
return RuntimeFunction(decls, name)
|
|
597
|
+
@dataclass
|
|
598
|
+
class _FunctionConstructor:
|
|
599
|
+
hint_locals: dict[str, Any]
|
|
600
|
+
builtin: bool = False
|
|
601
|
+
mutates_first_arg: bool = False
|
|
602
|
+
egg_fn: str | None = None
|
|
603
|
+
cost: int | None = None
|
|
604
|
+
default: RuntimeExpr | None = None
|
|
605
|
+
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
|
+
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
|
+
unextractable: bool = False
|
|
608
|
+
|
|
609
|
+
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
610
|
+
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
|
|
611
|
+
|
|
612
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
613
|
+
decls = Declarations()
|
|
614
|
+
decls._functions[fn.__name__] = _fn_decl(
|
|
615
|
+
decls,
|
|
616
|
+
self.egg_fn,
|
|
617
|
+
fn,
|
|
618
|
+
self.hint_locals,
|
|
619
|
+
self.default,
|
|
620
|
+
self.cost,
|
|
621
|
+
self.merge,
|
|
622
|
+
self.on_merge,
|
|
623
|
+
self.mutates_first_arg,
|
|
624
|
+
self.builtin,
|
|
625
|
+
unextractable=self.unextractable,
|
|
626
|
+
)
|
|
627
|
+
return decls
|
|
629
628
|
|
|
630
629
|
|
|
631
|
-
def
|
|
630
|
+
def _fn_decl(
|
|
632
631
|
decls: Declarations,
|
|
633
|
-
ref: FunctionCallableRef,
|
|
634
632
|
egg_name: str | None,
|
|
635
633
|
fn: object,
|
|
636
634
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
@@ -646,7 +644,7 @@ def _register_function(
|
|
|
646
644
|
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
647
645
|
is_init: bool = False,
|
|
648
646
|
unextractable: bool = False,
|
|
649
|
-
) ->
|
|
647
|
+
) -> FunctionDecl:
|
|
650
648
|
if not isinstance(fn, FunctionType):
|
|
651
649
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
652
650
|
|
|
@@ -699,8 +697,8 @@ def _register_function(
|
|
|
699
697
|
None
|
|
700
698
|
if merge is None
|
|
701
699
|
else merge(
|
|
702
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
703
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
700
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
701
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
704
702
|
)
|
|
705
703
|
)
|
|
706
704
|
decls |= merged
|
|
@@ -710,30 +708,25 @@ def _register_function(
|
|
|
710
708
|
if on_merge is None
|
|
711
709
|
else _action_likes(
|
|
712
710
|
on_merge(
|
|
713
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
714
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
711
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
712
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
715
713
|
)
|
|
716
714
|
)
|
|
717
715
|
)
|
|
718
716
|
decls.update(*merge_action)
|
|
719
|
-
|
|
720
|
-
return_type=return_type,
|
|
717
|
+
return FunctionDecl(
|
|
718
|
+
return_type=None if mutates_first_arg else return_type,
|
|
721
719
|
var_arg_type=var_arg_type,
|
|
722
720
|
arg_types=arg_types,
|
|
723
721
|
arg_names=tuple(t.name for t in params),
|
|
724
722
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
None if default is None else default.__egg_typed_expr__.expr,
|
|
733
|
-
merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
734
|
-
[a._to_egg_action() for a in merge_action],
|
|
735
|
-
unextractable,
|
|
736
|
-
is_builtin,
|
|
723
|
+
cost=cost,
|
|
724
|
+
egg_name=egg_name,
|
|
725
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
726
|
+
unextractable=unextractable,
|
|
727
|
+
builtin=is_builtin,
|
|
728
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
729
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
737
730
|
)
|
|
738
731
|
|
|
739
732
|
|
|
@@ -764,49 +757,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
764
757
|
"""
|
|
765
758
|
Creates a function whose return type is `Unit` and has a default value.
|
|
766
759
|
"""
|
|
760
|
+
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
761
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
767
765
|
decls = Declarations()
|
|
768
766
|
decls |= cast(RuntimeClass, Unit)
|
|
769
|
-
arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
|
|
770
|
-
|
|
771
|
-
decls
|
|
772
|
-
FunctionRef(name),
|
|
773
|
-
fn_decl,
|
|
774
|
-
egg_fn,
|
|
775
|
-
cost=None,
|
|
776
|
-
default=None,
|
|
777
|
-
merge=None,
|
|
778
|
-
merge_action=[],
|
|
779
|
-
unextractable=False,
|
|
780
|
-
builtin=False,
|
|
781
|
-
is_relation=True,
|
|
782
|
-
)
|
|
783
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls, name))
|
|
767
|
+
arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
|
|
768
|
+
decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
|
|
769
|
+
return decls
|
|
784
770
|
|
|
785
771
|
|
|
786
772
|
def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
|
|
787
773
|
"""
|
|
788
|
-
|
|
789
774
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
790
775
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
791
776
|
"""
|
|
792
|
-
|
|
793
|
-
decls = Declarations()
|
|
794
|
-
type_ref = _register_constant(decls, ref, tp, egg_name)
|
|
795
|
-
return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
|
|
777
|
+
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
|
|
796
778
|
|
|
797
779
|
|
|
798
|
-
def
|
|
799
|
-
decls
|
|
800
|
-
ref: ConstantRef | ClassVariableRef,
|
|
801
|
-
tp: object,
|
|
802
|
-
egg_name: str | None,
|
|
803
|
-
) -> JustTypeRef:
|
|
804
|
-
"""
|
|
805
|
-
Register a constant, returning its typeref().
|
|
806
|
-
"""
|
|
780
|
+
def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
|
|
781
|
+
decls = Declarations()
|
|
807
782
|
type_ref = resolve_type_annotation(decls, tp).to_just()
|
|
808
|
-
decls.
|
|
809
|
-
return type_ref
|
|
783
|
+
decls._constants[name] = ConstantDecl(type_ref, egg_name)
|
|
784
|
+
return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
|
|
810
785
|
|
|
811
786
|
|
|
812
787
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -858,29 +833,6 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
858
833
|
split_primitive_outputs: bool
|
|
859
834
|
|
|
860
835
|
|
|
861
|
-
@dataclass
|
|
862
|
-
class _EGraphState:
|
|
863
|
-
"""
|
|
864
|
-
State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
|
|
865
|
-
"""
|
|
866
|
-
|
|
867
|
-
# The decleratons we have added. The _cmds represent all the symbols we have added
|
|
868
|
-
decls: Declarations = field(default_factory=Declarations)
|
|
869
|
-
# List of rulesets already added, so we don't re-add them if they are passed again
|
|
870
|
-
added_rulesets: set[str] = field(default_factory=set)
|
|
871
|
-
|
|
872
|
-
def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
|
|
873
|
-
new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
|
|
874
|
-
self.decls |= new_decls
|
|
875
|
-
return new_cmds
|
|
876
|
-
|
|
877
|
-
def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
|
|
878
|
-
for ruleset in rulesets:
|
|
879
|
-
if ruleset.egg_name not in self.added_rulesets:
|
|
880
|
-
self.added_rulesets.add(ruleset.egg_name)
|
|
881
|
-
yield from ruleset._cmds
|
|
882
|
-
|
|
883
|
-
|
|
884
836
|
@dataclass
|
|
885
837
|
class EGraph(_BaseModule):
|
|
886
838
|
"""
|
|
@@ -892,56 +844,34 @@ class EGraph(_BaseModule):
|
|
|
892
844
|
seminaive: InitVar[bool] = True
|
|
893
845
|
save_egglog_string: InitVar[bool] = False
|
|
894
846
|
|
|
895
|
-
|
|
896
|
-
_egraph: bindings.EGraph = field(repr=False, init=False)
|
|
897
|
-
_egglog_string: str | None = field(default=None, repr=False, init=False)
|
|
898
|
-
_state: _EGraphState = field(default_factory=_EGraphState, repr=False)
|
|
847
|
+
_state: EGraphState = field(init=False)
|
|
899
848
|
# For pushing/popping with egglog
|
|
900
|
-
_state_stack: list[
|
|
849
|
+
_state_stack: list[EGraphState] = field(default_factory=list, repr=False)
|
|
901
850
|
# For storing the global "current" egraph
|
|
902
851
|
_token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
|
|
903
852
|
|
|
904
853
|
def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
|
|
905
|
-
|
|
854
|
+
egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
|
|
855
|
+
self._state = EGraphState(egraph)
|
|
906
856
|
super().__post_init__(modules)
|
|
907
857
|
|
|
908
858
|
for m in self._flatted_deps:
|
|
909
|
-
self._add_decls(*m.cmds)
|
|
910
859
|
self._register_commands(m.cmds)
|
|
911
|
-
if save_egglog_string:
|
|
912
|
-
self._egglog_string = ""
|
|
913
|
-
|
|
914
|
-
def _register_commands(self, commands: list[Command]) -> None:
|
|
915
|
-
for c in commands:
|
|
916
|
-
if c.ruleset:
|
|
917
|
-
self._add_schedule(c.ruleset)
|
|
918
|
-
|
|
919
|
-
self._add_decls(*commands)
|
|
920
|
-
self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
|
|
921
|
-
|
|
922
|
-
def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
|
|
923
|
-
commands = list(commands)
|
|
924
|
-
self._egraph.run_program(*commands)
|
|
925
|
-
if isinstance(self._egglog_string, str):
|
|
926
|
-
self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
|
|
927
860
|
|
|
928
861
|
def _add_decls(self, *decls: DeclerationsLike) -> None:
|
|
929
|
-
for d in
|
|
930
|
-
self.
|
|
931
|
-
|
|
932
|
-
def _add_schedule(self, schedule: Schedule) -> None:
|
|
933
|
-
self._add_decls(schedule)
|
|
934
|
-
self._process_commands(self._state.add_rulesets(schedule._rulesets()))
|
|
862
|
+
for d in decls:
|
|
863
|
+
self._state.__egg_decls__ |= d
|
|
935
864
|
|
|
936
865
|
@property
|
|
937
866
|
def as_egglog_string(self) -> str:
|
|
938
867
|
"""
|
|
939
868
|
Returns the egglog string for this module.
|
|
940
869
|
"""
|
|
941
|
-
|
|
870
|
+
cmds = self._egraph.commands()
|
|
871
|
+
if cmds is None:
|
|
942
872
|
msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
|
|
943
873
|
raise ValueError(msg)
|
|
944
|
-
return
|
|
874
|
+
return cmds
|
|
945
875
|
|
|
946
876
|
def _repr_mimebundle_(self, *args, **kwargs):
|
|
947
877
|
"""
|
|
@@ -954,7 +884,7 @@ class EGraph(_BaseModule):
|
|
|
954
884
|
kwargs.setdefault("split_primitive_outputs", True)
|
|
955
885
|
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
956
886
|
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
957
|
-
serialized.map_ops(self._state.
|
|
887
|
+
serialized.map_ops(self._state.op_mapping())
|
|
958
888
|
for _ in range(n_inline):
|
|
959
889
|
serialized.inline_leaves()
|
|
960
890
|
original = serialized.to_dot()
|
|
@@ -1016,17 +946,23 @@ class EGraph(_BaseModule):
|
|
|
1016
946
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1017
947
|
"""
|
|
1018
948
|
ref, decls = resolve_callable(fn)
|
|
1019
|
-
|
|
1020
|
-
self.
|
|
1021
|
-
self.
|
|
949
|
+
self._add_decls(decls)
|
|
950
|
+
fn_name = self._state.callable_ref_to_egg(ref)
|
|
951
|
+
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1022
952
|
|
|
1023
953
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
1024
954
|
"""
|
|
1025
955
|
Define a new expression in the egraph and return a reference to it.
|
|
1026
956
|
"""
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
957
|
+
action = let(name, expr)
|
|
958
|
+
self.register(action)
|
|
959
|
+
runtime_expr = to_runtime_expr(expr)
|
|
960
|
+
return cast(
|
|
961
|
+
EXPR,
|
|
962
|
+
RuntimeExpr.__from_value__(
|
|
963
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
964
|
+
),
|
|
965
|
+
)
|
|
1030
966
|
|
|
1031
967
|
@overload
|
|
1032
968
|
def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
|
|
@@ -1041,28 +977,19 @@ class EGraph(_BaseModule):
|
|
|
1041
977
|
Simplifies the given expression.
|
|
1042
978
|
"""
|
|
1043
979
|
schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
|
|
1044
|
-
del limit_or_schedule
|
|
1045
|
-
|
|
1046
|
-
self._add_decls(
|
|
1047
|
-
self.
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
self.
|
|
980
|
+
del limit_or_schedule, until, ruleset
|
|
981
|
+
runtime_expr = to_runtime_expr(expr)
|
|
982
|
+
self._add_decls(runtime_expr, schedule)
|
|
983
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
984
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
985
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
986
|
+
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1051
987
|
extract_report = self._egraph.extract_report()
|
|
1052
988
|
if not isinstance(extract_report, bindings.Best):
|
|
1053
989
|
msg = "No extract report saved"
|
|
1054
990
|
raise ValueError(msg) # noqa: TRY004
|
|
1055
|
-
new_typed_expr =
|
|
1056
|
-
|
|
1057
|
-
)
|
|
1058
|
-
return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1059
|
-
|
|
1060
|
-
@property
|
|
1061
|
-
def _default_ruleset_name(self) -> str:
|
|
1062
|
-
if self.default_ruleset:
|
|
1063
|
-
self._add_schedule(self.default_ruleset)
|
|
1064
|
-
return self.default_ruleset.egg_name
|
|
1065
|
-
return ""
|
|
991
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
992
|
+
return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1066
993
|
|
|
1067
994
|
def include(self, path: str) -> None:
|
|
1068
995
|
"""
|
|
@@ -1092,8 +1019,9 @@ class EGraph(_BaseModule):
|
|
|
1092
1019
|
return self._run_schedule(limit_or_schedule)
|
|
1093
1020
|
|
|
1094
1021
|
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
|
|
1095
|
-
self.
|
|
1096
|
-
self.
|
|
1022
|
+
self._add_decls(schedule)
|
|
1023
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1024
|
+
self._egraph.run_program(bindings.RunSchedule(egg_schedule))
|
|
1097
1025
|
run_report = self._egraph.run_report()
|
|
1098
1026
|
if not run_report:
|
|
1099
1027
|
msg = "No run report saved"
|
|
@@ -1104,18 +1032,18 @@ class EGraph(_BaseModule):
|
|
|
1104
1032
|
"""
|
|
1105
1033
|
Check if a fact is true in the egraph.
|
|
1106
1034
|
"""
|
|
1107
|
-
self.
|
|
1035
|
+
self._egraph.run_program(self._facts_to_check(facts))
|
|
1108
1036
|
|
|
1109
1037
|
def check_fail(self, *facts: FactLike) -> None:
|
|
1110
1038
|
"""
|
|
1111
1039
|
Checks that one of the facts is not true
|
|
1112
1040
|
"""
|
|
1113
|
-
self.
|
|
1041
|
+
self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
|
|
1114
1042
|
|
|
1115
|
-
def _facts_to_check(self,
|
|
1116
|
-
facts = _fact_likes(
|
|
1043
|
+
def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
|
|
1044
|
+
facts = _fact_likes(fact_likes)
|
|
1117
1045
|
self._add_decls(*facts)
|
|
1118
|
-
egg_facts = [f.
|
|
1046
|
+
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1119
1047
|
return bindings.Check(egg_facts)
|
|
1120
1048
|
|
|
1121
1049
|
@overload
|
|
@@ -1128,16 +1056,17 @@ class EGraph(_BaseModule):
|
|
|
1128
1056
|
"""
|
|
1129
1057
|
Extract the lowest cost expression from the egraph.
|
|
1130
1058
|
"""
|
|
1131
|
-
|
|
1132
|
-
self._add_decls(
|
|
1133
|
-
|
|
1059
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1060
|
+
self._add_decls(runtime_expr)
|
|
1061
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1062
|
+
extract_report = self._run_extract(typed_expr.expr, 0)
|
|
1063
|
+
|
|
1134
1064
|
if not isinstance(extract_report, bindings.Best):
|
|
1135
1065
|
msg = "No extract report saved"
|
|
1136
1066
|
raise ValueError(msg) # noqa: TRY004
|
|
1137
|
-
new_typed_expr =
|
|
1138
|
-
|
|
1139
|
-
)
|
|
1140
|
-
res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1067
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1068
|
+
|
|
1069
|
+
res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1141
1070
|
if include_cost:
|
|
1142
1071
|
return res, extract_report.cost
|
|
1143
1072
|
return res
|
|
@@ -1146,23 +1075,20 @@ class EGraph(_BaseModule):
|
|
|
1146
1075
|
"""
|
|
1147
1076
|
Extract multiple expressions from the egraph.
|
|
1148
1077
|
"""
|
|
1149
|
-
|
|
1150
|
-
self._add_decls(
|
|
1078
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1079
|
+
self._add_decls(runtime_expr)
|
|
1080
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1151
1081
|
|
|
1152
|
-
extract_report = self._run_extract(expr
|
|
1082
|
+
extract_report = self._run_extract(typed_expr.expr, n)
|
|
1153
1083
|
if not isinstance(extract_report, bindings.Variants):
|
|
1154
1084
|
msg = "Wrong extract report type"
|
|
1155
1085
|
raise ValueError(msg) # noqa: TRY004
|
|
1156
|
-
new_exprs =
|
|
1157
|
-
|
|
1158
|
-
self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
|
|
1159
|
-
)
|
|
1160
|
-
for term in extract_report.terms
|
|
1161
|
-
]
|
|
1162
|
-
return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
|
|
1086
|
+
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1087
|
+
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1163
1088
|
|
|
1164
|
-
def _run_extract(self, expr:
|
|
1165
|
-
self.
|
|
1089
|
+
def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
|
|
1090
|
+
expr = self._state.expr_to_egg(expr)
|
|
1091
|
+
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1166
1092
|
extract_report = self._egraph.extract_report()
|
|
1167
1093
|
if not extract_report:
|
|
1168
1094
|
msg = "No extract report saved"
|
|
@@ -1173,15 +1099,15 @@ class EGraph(_BaseModule):
|
|
|
1173
1099
|
"""
|
|
1174
1100
|
Push the current state of the egraph, so that it can be popped later and reverted back.
|
|
1175
1101
|
"""
|
|
1176
|
-
self.
|
|
1102
|
+
self._egraph.run_program(bindings.Push(1))
|
|
1177
1103
|
self._state_stack.append(self._state)
|
|
1178
|
-
self._state =
|
|
1104
|
+
self._state = self._state.copy()
|
|
1179
1105
|
|
|
1180
1106
|
def pop(self) -> None:
|
|
1181
1107
|
"""
|
|
1182
1108
|
Pop the current state of the egraph, reverting back to the previous state.
|
|
1183
1109
|
"""
|
|
1184
|
-
self.
|
|
1110
|
+
self._egraph.run_program(bindings.Pop(1))
|
|
1185
1111
|
self._state = self._state_stack.pop()
|
|
1186
1112
|
|
|
1187
1113
|
def __enter__(self) -> Self:
|
|
@@ -1217,9 +1143,10 @@ class EGraph(_BaseModule):
|
|
|
1217
1143
|
"""
|
|
1218
1144
|
Evaluates the given expression (which must be a primitive type), returning the result.
|
|
1219
1145
|
"""
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1146
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1147
|
+
self._add_decls(runtime_expr)
|
|
1148
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1149
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1223
1150
|
match typed_expr.tp:
|
|
1224
1151
|
case JustTypeRef("i64"):
|
|
1225
1152
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1231,7 +1158,7 @@ class EGraph(_BaseModule):
|
|
|
1231
1158
|
return self._egraph.eval_string(egg_expr)
|
|
1232
1159
|
case JustTypeRef("PyObject"):
|
|
1233
1160
|
return self._egraph.eval_py_object(egg_expr)
|
|
1234
|
-
raise
|
|
1161
|
+
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1235
1162
|
|
|
1236
1163
|
def saturate(
|
|
1237
1164
|
self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
|
|
@@ -1270,6 +1197,32 @@ class EGraph(_BaseModule):
|
|
|
1270
1197
|
"""
|
|
1271
1198
|
return CURRENT_EGRAPH.get()
|
|
1272
1199
|
|
|
1200
|
+
@property
|
|
1201
|
+
def _egraph(self) -> bindings.EGraph:
|
|
1202
|
+
return self._state.egraph
|
|
1203
|
+
|
|
1204
|
+
@property
|
|
1205
|
+
def __egg_decls__(self) -> Declarations:
|
|
1206
|
+
return self._state.__egg_decls__
|
|
1207
|
+
|
|
1208
|
+
def _register_commands(self, cmds: list[Command]) -> None:
|
|
1209
|
+
self._add_decls(*cmds)
|
|
1210
|
+
egg_cmds = list(map(self._command_to_egg, cmds))
|
|
1211
|
+
self._egraph.run_program(*egg_cmds)
|
|
1212
|
+
|
|
1213
|
+
def _command_to_egg(self, cmd: Command) -> bindings._Command:
|
|
1214
|
+
ruleset_name = ""
|
|
1215
|
+
cmd_decl: CommandDecl
|
|
1216
|
+
match cmd:
|
|
1217
|
+
case RewriteOrRule(_, cmd_decl, ruleset):
|
|
1218
|
+
if ruleset:
|
|
1219
|
+
ruleset_name = ruleset.__egg_name__
|
|
1220
|
+
case Action(_, action):
|
|
1221
|
+
cmd_decl = ActionCommandDecl(action)
|
|
1222
|
+
case _:
|
|
1223
|
+
assert_never(cmd)
|
|
1224
|
+
return self._state.command_to_egg(cmd_decl, ruleset_name)
|
|
1225
|
+
|
|
1273
1226
|
|
|
1274
1227
|
CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
|
|
1275
1228
|
|
|
@@ -1316,61 +1269,53 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
|
|
|
1316
1269
|
|
|
1317
1270
|
|
|
1318
1271
|
def ruleset(
|
|
1319
|
-
rule_or_generator:
|
|
1272
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
|
|
1273
|
+
*rules: RewriteOrRule,
|
|
1274
|
+
name: None | str = None,
|
|
1320
1275
|
) -> Ruleset:
|
|
1321
1276
|
"""
|
|
1322
1277
|
Creates a ruleset with the following rules.
|
|
1323
1278
|
|
|
1324
1279
|
If no name is provided, one is generated based on the current module
|
|
1325
1280
|
"""
|
|
1326
|
-
r = Ruleset(name
|
|
1281
|
+
r = Ruleset(name)
|
|
1327
1282
|
if rule_or_generator is not None:
|
|
1328
|
-
r.register(rule_or_generator, *rules)
|
|
1283
|
+
r.register(rule_or_generator, *rules, _increase_frame=True)
|
|
1329
1284
|
return r
|
|
1330
1285
|
|
|
1331
1286
|
|
|
1332
|
-
|
|
1287
|
+
@dataclass
|
|
1288
|
+
class Schedule(DelayedDeclerations):
|
|
1333
1289
|
"""
|
|
1334
1290
|
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
|
|
1335
1291
|
"""
|
|
1336
1292
|
|
|
1293
|
+
# Defer declerations so that we can have rule generators that used not yet defined yet
|
|
1294
|
+
schedule: ScheduleDecl
|
|
1295
|
+
|
|
1296
|
+
def __str__(self) -> str:
|
|
1297
|
+
return pretty_decl(self.__egg_decls__, self.schedule)
|
|
1298
|
+
|
|
1299
|
+
def __repr__(self) -> str:
|
|
1300
|
+
return str(self)
|
|
1301
|
+
|
|
1337
1302
|
def __mul__(self, length: int) -> Schedule:
|
|
1338
1303
|
"""
|
|
1339
1304
|
Repeat the schedule a number of times.
|
|
1340
1305
|
"""
|
|
1341
|
-
return
|
|
1306
|
+
return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
|
|
1342
1307
|
|
|
1343
1308
|
def saturate(self) -> Schedule:
|
|
1344
1309
|
"""
|
|
1345
1310
|
Run the schedule until the e-graph is saturated.
|
|
1346
1311
|
"""
|
|
1347
|
-
return
|
|
1312
|
+
return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
|
|
1348
1313
|
|
|
1349
1314
|
def __add__(self, other: Schedule) -> Schedule:
|
|
1350
1315
|
"""
|
|
1351
1316
|
Run two schedules in sequence.
|
|
1352
1317
|
"""
|
|
1353
|
-
return
|
|
1354
|
-
|
|
1355
|
-
@abstractmethod
|
|
1356
|
-
def __str__(self) -> str:
|
|
1357
|
-
raise NotImplementedError
|
|
1358
|
-
|
|
1359
|
-
@abstractmethod
|
|
1360
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1361
|
-
raise NotImplementedError
|
|
1362
|
-
|
|
1363
|
-
@abstractmethod
|
|
1364
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1365
|
-
"""
|
|
1366
|
-
Mapping of all the rulesets used to commands.
|
|
1367
|
-
"""
|
|
1368
|
-
raise NotImplementedError
|
|
1369
|
-
|
|
1370
|
-
@property
|
|
1371
|
-
@abstractmethod
|
|
1372
|
-
def __egg_decls__(self) -> Declarations:
|
|
1373
|
-
raise NotImplementedError
|
|
1318
|
+
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
|
|
1374
1319
|
|
|
1375
1320
|
|
|
1376
1321
|
@dataclass
|
|
@@ -1379,419 +1324,119 @@ class Ruleset(Schedule):
|
|
|
1379
1324
|
A collection of rules, which can be run as a schedule.
|
|
1380
1325
|
"""
|
|
1381
1326
|
|
|
1327
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1328
|
+
schedule: RunDecl = field(init=False)
|
|
1382
1329
|
name: str | None
|
|
1383
|
-
rules: list[Rule | Rewrite] = field(default_factory=list)
|
|
1384
1330
|
|
|
1385
|
-
|
|
1331
|
+
# Current declerations we have accumulated
|
|
1332
|
+
_current_egg_decls: Declarations = field(default_factory=Declarations)
|
|
1333
|
+
# Current rulesets we have accumulated
|
|
1334
|
+
__egg_ruleset__: RulesetDecl = field(init=False)
|
|
1335
|
+
# Rule generator functions that have been deferred, to allow for late type binding
|
|
1336
|
+
deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
|
|
1337
|
+
|
|
1338
|
+
def __post_init__(self) -> None:
|
|
1339
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1340
|
+
self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
|
|
1341
|
+
self.__egg_decls_thunk__ = self._update_egg_decls
|
|
1342
|
+
|
|
1343
|
+
def _update_egg_decls(self) -> Declarations:
|
|
1344
|
+
"""
|
|
1345
|
+
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1346
|
+
"""
|
|
1347
|
+
while self.deferred_rule_gens:
|
|
1348
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1349
|
+
self._current_egg_decls.update(*rules)
|
|
1350
|
+
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1351
|
+
return self._current_egg_decls
|
|
1352
|
+
|
|
1353
|
+
def append(self, rule: RewriteOrRule) -> None:
|
|
1386
1354
|
"""
|
|
1387
1355
|
Register a rule with the ruleset.
|
|
1388
1356
|
"""
|
|
1389
|
-
self.
|
|
1357
|
+
self._current_egg_decls |= rule
|
|
1358
|
+
self.__egg_ruleset__.rules.append(rule.decl)
|
|
1390
1359
|
|
|
1391
|
-
def register(
|
|
1360
|
+
def register(
|
|
1361
|
+
self,
|
|
1362
|
+
/,
|
|
1363
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
|
|
1364
|
+
*rules: RewriteOrRule,
|
|
1365
|
+
_increase_frame: bool = False,
|
|
1366
|
+
) -> None:
|
|
1392
1367
|
"""
|
|
1393
1368
|
Register rewrites or rules, either as a function or as values.
|
|
1394
1369
|
"""
|
|
1395
|
-
if isinstance(rule_or_generator,
|
|
1396
|
-
|
|
1397
|
-
|
|
1370
|
+
if isinstance(rule_or_generator, RewriteOrRule):
|
|
1371
|
+
self.append(rule_or_generator)
|
|
1372
|
+
for r in rules:
|
|
1373
|
+
self.append(r)
|
|
1398
1374
|
else:
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
def _cmds(self) -> list[bindings._Command]:
|
|
1409
|
-
cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
|
|
1410
|
-
if self.egg_name:
|
|
1411
|
-
cmds.insert(0, bindings.AddRuleset(self.egg_name))
|
|
1412
|
-
return cmds
|
|
1375
|
+
assert not rules
|
|
1376
|
+
current_frame = inspect.currentframe()
|
|
1377
|
+
assert current_frame
|
|
1378
|
+
original_frame = current_frame.f_back
|
|
1379
|
+
assert original_frame
|
|
1380
|
+
if _increase_frame:
|
|
1381
|
+
original_frame = original_frame.f_back
|
|
1382
|
+
assert original_frame
|
|
1383
|
+
self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
|
|
1413
1384
|
|
|
1414
1385
|
def __str__(self) -> str:
|
|
1415
|
-
return
|
|
1386
|
+
return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
|
|
1416
1387
|
|
|
1417
1388
|
def __repr__(self) -> str:
|
|
1418
|
-
|
|
1419
|
-
return str(self)
|
|
1420
|
-
rules = ", ".join(map(repr, self.rules))
|
|
1421
|
-
return f"ruleset({rules}, name={self.egg_name!r})"
|
|
1422
|
-
|
|
1423
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1424
|
-
return bindings.Run(self._to_egg_config())
|
|
1425
|
-
|
|
1426
|
-
def _to_egg_config(self) -> bindings.RunConfig:
|
|
1427
|
-
return bindings.RunConfig(self.egg_name, None)
|
|
1428
|
-
|
|
1429
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1430
|
-
yield self
|
|
1389
|
+
return str(self)
|
|
1431
1390
|
|
|
1391
|
+
# Create a unique name if we didn't pass one from the user
|
|
1432
1392
|
@property
|
|
1433
|
-
def
|
|
1434
|
-
return self.name or f"
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
class Command(ABC):
|
|
1438
|
-
"""
|
|
1439
|
-
A command that can be executed in the egg interpreter.
|
|
1440
|
-
|
|
1441
|
-
We only use this for commands which return no result and don't create new Python objects.
|
|
1442
|
-
|
|
1443
|
-
Anything that can be passed to the `register` function in a Module is a Command.
|
|
1444
|
-
"""
|
|
1445
|
-
|
|
1446
|
-
ruleset: Ruleset | None
|
|
1447
|
-
|
|
1448
|
-
@property
|
|
1449
|
-
@abstractmethod
|
|
1450
|
-
def __egg_decls__(self) -> Declarations:
|
|
1451
|
-
raise NotImplementedError
|
|
1452
|
-
|
|
1453
|
-
@abstractmethod
|
|
1454
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1455
|
-
raise NotImplementedError
|
|
1456
|
-
|
|
1457
|
-
@abstractmethod
|
|
1458
|
-
def __str__(self) -> str:
|
|
1459
|
-
raise NotImplementedError
|
|
1393
|
+
def __egg_name__(self) -> str:
|
|
1394
|
+
return self.name or f"ruleset_{id(self)}"
|
|
1460
1395
|
|
|
1461
1396
|
|
|
1462
1397
|
@dataclass
|
|
1463
|
-
class
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
_conditions: tuple[Fact, ...]
|
|
1468
|
-
_fn_name: ClassVar[str] = "rewrite"
|
|
1398
|
+
class RewriteOrRule:
|
|
1399
|
+
__egg_decls__: Declarations
|
|
1400
|
+
decl: RewriteOrRuleDecl
|
|
1401
|
+
ruleset: Ruleset | None = None
|
|
1469
1402
|
|
|
1470
1403
|
def __str__(self) -> str:
|
|
1471
|
-
|
|
1472
|
-
return f"{self._fn_name}({self._lhs}).to({args_str})"
|
|
1473
|
-
|
|
1474
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1475
|
-
return bindings.RewriteCommand(
|
|
1476
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
|
|
1477
|
-
)
|
|
1478
|
-
|
|
1479
|
-
def _to_egg_rewrite(self) -> bindings.Rewrite:
|
|
1480
|
-
return bindings.Rewrite(
|
|
1481
|
-
self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
|
|
1482
|
-
self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
|
|
1483
|
-
[c._to_egg_fact() for c in self._conditions],
|
|
1484
|
-
)
|
|
1485
|
-
|
|
1486
|
-
@cached_property
|
|
1487
|
-
def __egg_decls__(self) -> Declarations:
|
|
1488
|
-
return Declarations.create(self._lhs, self._rhs, *self._conditions)
|
|
1489
|
-
|
|
1490
|
-
def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
|
|
1491
|
-
return Rewrite(ruleset, self._lhs, self._rhs, self._conditions)
|
|
1492
|
-
|
|
1404
|
+
return pretty_decl(self.__egg_decls__, self.decl)
|
|
1493
1405
|
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
_fn_name: ClassVar[str] = "birewrite"
|
|
1497
|
-
|
|
1498
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1499
|
-
return bindings.BiRewriteCommand(
|
|
1500
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
|
|
1501
|
-
)
|
|
1406
|
+
def __repr__(self) -> str:
|
|
1407
|
+
return str(self)
|
|
1502
1408
|
|
|
1503
1409
|
|
|
1504
1410
|
@dataclass
|
|
1505
|
-
class Fact
|
|
1411
|
+
class Fact:
|
|
1506
1412
|
"""
|
|
1507
1413
|
A query on an EGraph, either by an expression or an equivalence between multiple expressions.
|
|
1508
1414
|
"""
|
|
1509
1415
|
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
raise NotImplementedError
|
|
1513
|
-
|
|
1514
|
-
@property
|
|
1515
|
-
@abstractmethod
|
|
1516
|
-
def __egg_decls__(self) -> Declarations:
|
|
1517
|
-
raise NotImplementedError
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
@dataclass
|
|
1521
|
-
class Eq(Fact):
|
|
1522
|
-
_exprs: list[RuntimeExpr]
|
|
1523
|
-
|
|
1524
|
-
def __str__(self) -> str:
|
|
1525
|
-
first, *rest = self._exprs
|
|
1526
|
-
args_str = ", ".join(map(str, rest))
|
|
1527
|
-
return f"eq({first}).to({args_str})"
|
|
1528
|
-
|
|
1529
|
-
def _to_egg_fact(self) -> bindings.Eq:
|
|
1530
|
-
return bindings.Eq([e.__egg__ for e in self._exprs])
|
|
1531
|
-
|
|
1532
|
-
@cached_property
|
|
1533
|
-
def __egg_decls__(self) -> Declarations:
|
|
1534
|
-
return Declarations.create(*self._exprs)
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
@dataclass
|
|
1538
|
-
class ExprFact(Fact):
|
|
1539
|
-
_expr: RuntimeExpr
|
|
1416
|
+
__egg_decls__: Declarations
|
|
1417
|
+
fact: FactDecl
|
|
1540
1418
|
|
|
1541
1419
|
def __str__(self) -> str:
|
|
1542
|
-
return
|
|
1420
|
+
return pretty_decl(self.__egg_decls__, self.fact)
|
|
1543
1421
|
|
|
1544
|
-
def
|
|
1545
|
-
return
|
|
1546
|
-
|
|
1547
|
-
@cached_property
|
|
1548
|
-
def __egg_decls__(self) -> Declarations:
|
|
1549
|
-
return self._expr.__egg_decls__
|
|
1422
|
+
def __repr__(self) -> str:
|
|
1423
|
+
return str(self)
|
|
1550
1424
|
|
|
1551
1425
|
|
|
1552
1426
|
@dataclass
|
|
1553
|
-
class
|
|
1554
|
-
head: tuple[Action, ...]
|
|
1555
|
-
body: tuple[Fact, ...]
|
|
1556
|
-
name: str
|
|
1557
|
-
ruleset: Ruleset | None
|
|
1558
|
-
|
|
1559
|
-
def __str__(self) -> str:
|
|
1560
|
-
head_str = ", ".join(map(str, self.head))
|
|
1561
|
-
body_str = ", ".join(map(str, self.body))
|
|
1562
|
-
return f"rule({body_str}).then({head_str})"
|
|
1563
|
-
|
|
1564
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
|
|
1565
|
-
return bindings.RuleCommand(
|
|
1566
|
-
self.name,
|
|
1567
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1568
|
-
bindings.Rule(
|
|
1569
|
-
[a._to_egg_action() for a in self.head],
|
|
1570
|
-
[f._to_egg_fact() for f in self.body],
|
|
1571
|
-
),
|
|
1572
|
-
)
|
|
1573
|
-
|
|
1574
|
-
@cached_property
|
|
1575
|
-
def __egg_decls__(self) -> Declarations:
|
|
1576
|
-
return Declarations.create(*self.head, *self.body)
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
class Action(Command, ABC):
|
|
1427
|
+
class Action:
|
|
1580
1428
|
"""
|
|
1581
1429
|
A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
|
|
1582
1430
|
"""
|
|
1583
1431
|
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
raise NotImplementedError
|
|
1587
|
-
|
|
1588
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1589
|
-
return bindings.ActionCommand(self._to_egg_action())
|
|
1590
|
-
|
|
1591
|
-
@property
|
|
1592
|
-
def ruleset(self) -> None | Ruleset: # type: ignore[override]
|
|
1593
|
-
return None
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
@dataclass
|
|
1597
|
-
class Let(Action):
|
|
1598
|
-
_name: str
|
|
1599
|
-
_value: RuntimeExpr
|
|
1432
|
+
__egg_decls__: Declarations
|
|
1433
|
+
action: ActionDecl
|
|
1600
1434
|
|
|
1601
1435
|
def __str__(self) -> str:
|
|
1602
|
-
return
|
|
1603
|
-
|
|
1604
|
-
def _to_egg_action(self) -> bindings.Let:
|
|
1605
|
-
return bindings.Let(self._name, self._value.__egg__)
|
|
1606
|
-
|
|
1607
|
-
@property
|
|
1608
|
-
def __egg_decls__(self) -> Declarations:
|
|
1609
|
-
return self._value.__egg_decls__
|
|
1610
|
-
|
|
1611
|
-
|
|
1612
|
-
@dataclass
|
|
1613
|
-
class Set(Action):
|
|
1614
|
-
"""
|
|
1615
|
-
Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions.
|
|
1616
|
-
"""
|
|
1617
|
-
|
|
1618
|
-
_call: RuntimeExpr
|
|
1619
|
-
_rhs: RuntimeExpr
|
|
1620
|
-
|
|
1621
|
-
def __str__(self) -> str:
|
|
1622
|
-
return f"set({self._call}).to({self._rhs})"
|
|
1623
|
-
|
|
1624
|
-
def _to_egg_action(self) -> bindings.Set:
|
|
1625
|
-
egg_call = self._call.__egg__
|
|
1626
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1627
|
-
raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1628
|
-
return bindings.Set(
|
|
1629
|
-
egg_call.name,
|
|
1630
|
-
egg_call.args,
|
|
1631
|
-
self._rhs.__egg__,
|
|
1632
|
-
)
|
|
1633
|
-
|
|
1634
|
-
@cached_property
|
|
1635
|
-
def __egg_decls__(self) -> Declarations:
|
|
1636
|
-
return Declarations.create(self._call, self._rhs)
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
@dataclass
|
|
1640
|
-
class ExprAction(Action):
|
|
1641
|
-
_expr: RuntimeExpr
|
|
1642
|
-
|
|
1643
|
-
def __str__(self) -> str:
|
|
1644
|
-
return str(self._expr)
|
|
1645
|
-
|
|
1646
|
-
def _to_egg_action(self) -> bindings.Expr_:
|
|
1647
|
-
return bindings.Expr_(self._expr.__egg__)
|
|
1648
|
-
|
|
1649
|
-
@property
|
|
1650
|
-
def __egg_decls__(self) -> Declarations:
|
|
1651
|
-
return self._expr.__egg_decls__
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
@dataclass
|
|
1655
|
-
class Delete(Action):
|
|
1656
|
-
"""
|
|
1657
|
-
Remove a function call from an EGraph.
|
|
1658
|
-
"""
|
|
1659
|
-
|
|
1660
|
-
_call: RuntimeExpr
|
|
1661
|
-
|
|
1662
|
-
def __str__(self) -> str:
|
|
1663
|
-
return f"delete({self._call})"
|
|
1664
|
-
|
|
1665
|
-
def _to_egg_action(self) -> bindings.Delete:
|
|
1666
|
-
egg_call = self._call.__egg__
|
|
1667
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1668
|
-
raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1669
|
-
return bindings.Delete(egg_call.name, egg_call.args)
|
|
1670
|
-
|
|
1671
|
-
@property
|
|
1672
|
-
def __egg_decls__(self) -> Declarations:
|
|
1673
|
-
return self._call.__egg_decls__
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
@dataclass
|
|
1677
|
-
class Union_(Action): # noqa: N801
|
|
1678
|
-
"""
|
|
1679
|
-
Merges two equivalence classes of two expressions.
|
|
1680
|
-
"""
|
|
1681
|
-
|
|
1682
|
-
_lhs: RuntimeExpr
|
|
1683
|
-
_rhs: RuntimeExpr
|
|
1684
|
-
|
|
1685
|
-
def __str__(self) -> str:
|
|
1686
|
-
return f"union({self._lhs}).with_({self._rhs})"
|
|
1687
|
-
|
|
1688
|
-
def _to_egg_action(self) -> bindings.Union:
|
|
1689
|
-
return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
|
|
1690
|
-
|
|
1691
|
-
@cached_property
|
|
1692
|
-
def __egg_decls__(self) -> Declarations:
|
|
1693
|
-
return Declarations.create(self._lhs, self._rhs)
|
|
1694
|
-
|
|
1695
|
-
|
|
1696
|
-
@dataclass
|
|
1697
|
-
class Panic(Action):
|
|
1698
|
-
message: str
|
|
1699
|
-
|
|
1700
|
-
def __str__(self) -> str:
|
|
1701
|
-
return f"panic({self.message})"
|
|
1702
|
-
|
|
1703
|
-
def _to_egg_action(self) -> bindings.Panic:
|
|
1704
|
-
return bindings.Panic(self.message)
|
|
1705
|
-
|
|
1706
|
-
@cached_property
|
|
1707
|
-
def __egg_decls__(self) -> Declarations:
|
|
1708
|
-
return Declarations()
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
@dataclass
|
|
1712
|
-
class Run(Schedule):
|
|
1713
|
-
"""Configuration of a run"""
|
|
1714
|
-
|
|
1715
|
-
# None if using default ruleset
|
|
1716
|
-
ruleset: Ruleset | None
|
|
1717
|
-
until: tuple[Fact, ...]
|
|
1718
|
-
|
|
1719
|
-
def __str__(self) -> str:
|
|
1720
|
-
args_str = ", ".join(map(str, [self.ruleset, *self.until]))
|
|
1721
|
-
return f"run({args_str})"
|
|
1722
|
-
|
|
1723
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1724
|
-
return bindings.Run(self._to_egg_config(default_ruleset_name))
|
|
1725
|
-
|
|
1726
|
-
def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig:
|
|
1727
|
-
return bindings.RunConfig(
|
|
1728
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1729
|
-
[fact._to_egg_fact() for fact in self.until] if self.until else None,
|
|
1730
|
-
)
|
|
1731
|
-
|
|
1732
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1733
|
-
if self.ruleset:
|
|
1734
|
-
yield self.ruleset
|
|
1735
|
-
|
|
1736
|
-
@cached_property
|
|
1737
|
-
def __egg_decls__(self) -> Declarations:
|
|
1738
|
-
return Declarations.create(self.ruleset, *self.until)
|
|
1436
|
+
return pretty_decl(self.__egg_decls__, self.action)
|
|
1739
1437
|
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
class Saturate(Schedule):
|
|
1743
|
-
schedule: Schedule
|
|
1744
|
-
|
|
1745
|
-
def __str__(self) -> str:
|
|
1746
|
-
return f"{self.schedule}.saturate()"
|
|
1747
|
-
|
|
1748
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1749
|
-
return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1750
|
-
|
|
1751
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1752
|
-
return self.schedule._rulesets()
|
|
1753
|
-
|
|
1754
|
-
@property
|
|
1755
|
-
def __egg_decls__(self) -> Declarations:
|
|
1756
|
-
return self.schedule.__egg_decls__
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
@dataclass
|
|
1760
|
-
class Repeat(Schedule):
|
|
1761
|
-
length: int
|
|
1762
|
-
schedule: Schedule
|
|
1763
|
-
|
|
1764
|
-
def __str__(self) -> str:
|
|
1765
|
-
return f"{self.schedule} * {self.length}"
|
|
1766
|
-
|
|
1767
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1768
|
-
return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1769
|
-
|
|
1770
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1771
|
-
return self.schedule._rulesets()
|
|
1772
|
-
|
|
1773
|
-
@property
|
|
1774
|
-
def __egg_decls__(self) -> Declarations:
|
|
1775
|
-
return self.schedule.__egg_decls__
|
|
1776
|
-
|
|
1777
|
-
|
|
1778
|
-
@dataclass
|
|
1779
|
-
class Sequence(Schedule):
|
|
1780
|
-
schedules: tuple[Schedule, ...]
|
|
1781
|
-
|
|
1782
|
-
def __str__(self) -> str:
|
|
1783
|
-
return f"sequence({', '.join(map(str, self.schedules))})"
|
|
1784
|
-
|
|
1785
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1786
|
-
return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
|
|
1787
|
-
|
|
1788
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1789
|
-
for s in self.schedules:
|
|
1790
|
-
yield from s._rulesets()
|
|
1791
|
-
|
|
1792
|
-
@cached_property
|
|
1793
|
-
def __egg_decls__(self) -> Declarations:
|
|
1794
|
-
return Declarations.create(*self.schedules)
|
|
1438
|
+
def __repr__(self) -> str:
|
|
1439
|
+
return str(self)
|
|
1795
1440
|
|
|
1796
1441
|
|
|
1797
1442
|
# We use these builders so that when creating these structures we can type check
|
|
@@ -1800,16 +1445,16 @@ class Sequence(Schedule):
|
|
|
1800
1445
|
|
|
1801
1446
|
@deprecated("Use <ruleset>.register(<rewrite>) instead of passing rulesets as arguments to rewrites.")
|
|
1802
1447
|
@overload
|
|
1803
|
-
def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: ...
|
|
1448
|
+
def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
|
|
1804
1449
|
|
|
1805
1450
|
|
|
1806
1451
|
@overload
|
|
1807
|
-
def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: ...
|
|
1452
|
+
def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
|
|
1808
1453
|
|
|
1809
1454
|
|
|
1810
|
-
def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
|
|
1455
|
+
def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]:
|
|
1811
1456
|
"""Rewrite the given expression to a new expression."""
|
|
1812
|
-
return _RewriteBuilder(lhs, ruleset)
|
|
1457
|
+
return _RewriteBuilder(lhs, ruleset, subsume)
|
|
1813
1458
|
|
|
1814
1459
|
|
|
1815
1460
|
@deprecated("Use <ruleset>.register(<birewrite>) instead of passing rulesets as arguments to birewrites.")
|
|
@@ -1838,25 +1483,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
|
|
|
1838
1483
|
|
|
1839
1484
|
def panic(message: str) -> Action:
|
|
1840
1485
|
"""Raise an error with the given message."""
|
|
1841
|
-
return
|
|
1486
|
+
return Action(Declarations(), PanicDecl(message))
|
|
1842
1487
|
|
|
1843
1488
|
|
|
1844
1489
|
def let(name: str, expr: Expr) -> Action:
|
|
1845
1490
|
"""Create a let binding."""
|
|
1846
|
-
|
|
1491
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1492
|
+
return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
|
|
1847
1493
|
|
|
1848
1494
|
|
|
1849
1495
|
def expr_action(expr: Expr) -> Action:
|
|
1850
|
-
|
|
1496
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1497
|
+
return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
|
|
1851
1498
|
|
|
1852
1499
|
|
|
1853
1500
|
def delete(expr: Expr) -> Action:
|
|
1854
1501
|
"""Create a delete expression."""
|
|
1855
|
-
|
|
1502
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1503
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1504
|
+
call_decl = typed_expr.expr
|
|
1505
|
+
assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
|
|
1506
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
|
|
1507
|
+
|
|
1508
|
+
|
|
1509
|
+
def subsume(expr: Expr) -> Action:
|
|
1510
|
+
"""Subsume an expression so it cannot be matched against or extracted"""
|
|
1511
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1512
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1513
|
+
call_decl = typed_expr.expr
|
|
1514
|
+
assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
|
|
1515
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
|
|
1856
1516
|
|
|
1857
1517
|
|
|
1858
1518
|
def expr_fact(expr: Expr) -> Fact:
|
|
1859
|
-
|
|
1519
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1520
|
+
return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
|
|
1860
1521
|
|
|
1861
1522
|
|
|
1862
1523
|
def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
|
|
@@ -1883,6 +1544,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
|
|
|
1883
1544
|
return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
|
|
1884
1545
|
|
|
1885
1546
|
|
|
1547
|
+
@deprecated("This function is now a no-op, you can remove it and use actions as commands")
|
|
1548
|
+
def action_command(action: Action) -> Action:
|
|
1549
|
+
return action
|
|
1550
|
+
|
|
1551
|
+
|
|
1886
1552
|
def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
1887
1553
|
"""Create a new variable with the given name and type."""
|
|
1888
1554
|
return cast(EXPR, _var(name, bound))
|
|
@@ -1890,9 +1556,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
|
1890
1556
|
|
|
1891
1557
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1892
1558
|
"""Create a new variable with the given name and type."""
|
|
1893
|
-
if not isinstance(bound, RuntimeClass
|
|
1559
|
+
if not isinstance(bound, RuntimeClass):
|
|
1894
1560
|
raise TypeError(f"Unexpected type {type(bound)}")
|
|
1895
|
-
return RuntimeExpr(bound.__egg_decls__, TypedExprDecl(
|
|
1561
|
+
return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
|
|
1896
1562
|
|
|
1897
1563
|
|
|
1898
1564
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1905,16 +1571,29 @@ def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
|
1905
1571
|
class _RewriteBuilder(Generic[EXPR]):
|
|
1906
1572
|
lhs: EXPR
|
|
1907
1573
|
ruleset: Ruleset | None
|
|
1574
|
+
subsume: bool
|
|
1908
1575
|
|
|
1909
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1576
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1910
1577
|
lhs = to_runtime_expr(self.lhs)
|
|
1911
|
-
|
|
1578
|
+
facts = _fact_likes(conditions)
|
|
1579
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1580
|
+
rule = RewriteOrRule(
|
|
1581
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1582
|
+
RewriteDecl(
|
|
1583
|
+
lhs.__egg_typed_expr__.tp,
|
|
1584
|
+
lhs.__egg_typed_expr__.expr,
|
|
1585
|
+
rhs.__egg_typed_expr__.expr,
|
|
1586
|
+
tuple(f.fact for f in facts),
|
|
1587
|
+
self.subsume,
|
|
1588
|
+
),
|
|
1589
|
+
)
|
|
1912
1590
|
if self.ruleset:
|
|
1913
1591
|
self.ruleset.append(rule)
|
|
1914
1592
|
return rule
|
|
1915
1593
|
|
|
1916
1594
|
def __str__(self) -> str:
|
|
1917
|
-
|
|
1595
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1596
|
+
return lhs.__egg_pretty__("rewrite")
|
|
1918
1597
|
|
|
1919
1598
|
|
|
1920
1599
|
@dataclass
|
|
@@ -1922,15 +1601,26 @@ class _BirewriteBuilder(Generic[EXPR]):
|
|
|
1922
1601
|
lhs: EXPR
|
|
1923
1602
|
ruleset: Ruleset | None
|
|
1924
1603
|
|
|
1925
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1604
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1926
1605
|
lhs = to_runtime_expr(self.lhs)
|
|
1927
|
-
|
|
1606
|
+
facts = _fact_likes(conditions)
|
|
1607
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1608
|
+
rule = RewriteOrRule(
|
|
1609
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1610
|
+
BiRewriteDecl(
|
|
1611
|
+
lhs.__egg_typed_expr__.tp,
|
|
1612
|
+
lhs.__egg_typed_expr__.expr,
|
|
1613
|
+
rhs.__egg_typed_expr__.expr,
|
|
1614
|
+
tuple(f.fact for f in facts),
|
|
1615
|
+
),
|
|
1616
|
+
)
|
|
1928
1617
|
if self.ruleset:
|
|
1929
1618
|
self.ruleset.append(rule)
|
|
1930
1619
|
return rule
|
|
1931
1620
|
|
|
1932
1621
|
def __str__(self) -> str:
|
|
1933
|
-
|
|
1622
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1623
|
+
return lhs.__egg_pretty__("birewrite")
|
|
1934
1624
|
|
|
1935
1625
|
|
|
1936
1626
|
@dataclass
|
|
@@ -1939,52 +1629,84 @@ class _EqBuilder(Generic[EXPR]):
|
|
|
1939
1629
|
|
|
1940
1630
|
def to(self, *exprs: EXPR) -> Fact:
|
|
1941
1631
|
expr = to_runtime_expr(self.expr)
|
|
1942
|
-
|
|
1632
|
+
args = [expr, *(convert_to_same_type(e, expr) for e in exprs)]
|
|
1633
|
+
return Fact(
|
|
1634
|
+
Declarations.create(*args),
|
|
1635
|
+
EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)),
|
|
1636
|
+
)
|
|
1637
|
+
|
|
1638
|
+
def __repr__(self) -> str:
|
|
1639
|
+
return str(self)
|
|
1943
1640
|
|
|
1944
1641
|
def __str__(self) -> str:
|
|
1945
|
-
|
|
1642
|
+
expr = to_runtime_expr(self.expr)
|
|
1643
|
+
return expr.__egg_pretty__("eq")
|
|
1946
1644
|
|
|
1947
1645
|
|
|
1948
1646
|
@dataclass
|
|
1949
1647
|
class _NeBuilder(Generic[EXPR]):
|
|
1950
|
-
|
|
1648
|
+
lhs: EXPR
|
|
1951
1649
|
|
|
1952
|
-
def to(self,
|
|
1953
|
-
|
|
1954
|
-
|
|
1955
|
-
|
|
1956
|
-
res = RuntimeExpr(
|
|
1957
|
-
|
|
1958
|
-
TypedExprDecl(
|
|
1650
|
+
def to(self, rhs: EXPR) -> Unit:
|
|
1651
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1652
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1653
|
+
assert isinstance(Unit, RuntimeClass)
|
|
1654
|
+
res = RuntimeExpr.__from_value__(
|
|
1655
|
+
Declarations.create(Unit, lhs, rhs),
|
|
1656
|
+
TypedExprDecl(
|
|
1657
|
+
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
1658
|
+
),
|
|
1959
1659
|
)
|
|
1960
1660
|
return cast(Unit, res)
|
|
1961
1661
|
|
|
1662
|
+
def __repr__(self) -> str:
|
|
1663
|
+
return str(self)
|
|
1664
|
+
|
|
1962
1665
|
def __str__(self) -> str:
|
|
1963
|
-
|
|
1666
|
+
expr = to_runtime_expr(self.lhs)
|
|
1667
|
+
return expr.__egg_pretty__("ne")
|
|
1964
1668
|
|
|
1965
1669
|
|
|
1966
1670
|
@dataclass
|
|
1967
1671
|
class _SetBuilder(Generic[EXPR]):
|
|
1968
|
-
lhs:
|
|
1672
|
+
lhs: EXPR
|
|
1969
1673
|
|
|
1970
|
-
def to(self, rhs: EXPR) ->
|
|
1674
|
+
def to(self, rhs: EXPR) -> Action:
|
|
1971
1675
|
lhs = to_runtime_expr(self.lhs)
|
|
1972
|
-
|
|
1676
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1677
|
+
lhs_expr = lhs.__egg_typed_expr__.expr
|
|
1678
|
+
assert isinstance(lhs_expr, CallDecl), "Can only set function calls"
|
|
1679
|
+
return Action(
|
|
1680
|
+
Declarations.create(lhs, rhs),
|
|
1681
|
+
SetDecl(lhs.__egg_typed_expr__.tp, lhs_expr, rhs.__egg_typed_expr__.expr),
|
|
1682
|
+
)
|
|
1683
|
+
|
|
1684
|
+
def __repr__(self) -> str:
|
|
1685
|
+
return str(self)
|
|
1973
1686
|
|
|
1974
1687
|
def __str__(self) -> str:
|
|
1975
|
-
|
|
1688
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1689
|
+
return lhs.__egg_pretty__("set_")
|
|
1976
1690
|
|
|
1977
1691
|
|
|
1978
1692
|
@dataclass
|
|
1979
1693
|
class _UnionBuilder(Generic[EXPR]):
|
|
1980
|
-
lhs:
|
|
1694
|
+
lhs: EXPR
|
|
1981
1695
|
|
|
1982
1696
|
def with_(self, rhs: EXPR) -> Action:
|
|
1983
1697
|
lhs = to_runtime_expr(self.lhs)
|
|
1984
|
-
|
|
1698
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1699
|
+
return Action(
|
|
1700
|
+
Declarations.create(lhs, rhs),
|
|
1701
|
+
UnionDecl(lhs.__egg_typed_expr__.tp, lhs.__egg_typed_expr__.expr, rhs.__egg_typed_expr__.expr),
|
|
1702
|
+
)
|
|
1703
|
+
|
|
1704
|
+
def __repr__(self) -> str:
|
|
1705
|
+
return str(self)
|
|
1985
1706
|
|
|
1986
1707
|
def __str__(self) -> str:
|
|
1987
|
-
|
|
1708
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1709
|
+
return lhs.__egg_pretty__("union")
|
|
1988
1710
|
|
|
1989
1711
|
|
|
1990
1712
|
@dataclass
|
|
@@ -1993,12 +1715,25 @@ class _RuleBuilder:
|
|
|
1993
1715
|
name: str | None
|
|
1994
1716
|
ruleset: Ruleset | None
|
|
1995
1717
|
|
|
1996
|
-
def then(self, *actions: ActionLike) ->
|
|
1997
|
-
|
|
1718
|
+
def then(self, *actions: ActionLike) -> RewriteOrRule:
|
|
1719
|
+
actions = _action_likes(actions)
|
|
1720
|
+
rule = RewriteOrRule(
|
|
1721
|
+
Declarations.create(self.ruleset, *actions, *self.facts),
|
|
1722
|
+
RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name),
|
|
1723
|
+
)
|
|
1998
1724
|
if self.ruleset:
|
|
1999
1725
|
self.ruleset.append(rule)
|
|
2000
1726
|
return rule
|
|
2001
1727
|
|
|
1728
|
+
def __str__(self) -> str:
|
|
1729
|
+
# TODO: Figure out how to stringify rulebuilder that preserves statements
|
|
1730
|
+
args = list(map(str, self.facts))
|
|
1731
|
+
if self.name is not None:
|
|
1732
|
+
args.append(f"name={self.name}")
|
|
1733
|
+
if ruleset is not None:
|
|
1734
|
+
args.append(f"ruleset={self.ruleset}")
|
|
1735
|
+
return f"rule({', '.join(args)})"
|
|
1736
|
+
|
|
2002
1737
|
|
|
2003
1738
|
def expr_parts(expr: Expr) -> TypedExprDecl:
|
|
2004
1739
|
"""
|
|
@@ -2015,60 +1750,61 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
|
|
|
2015
1750
|
return expr
|
|
2016
1751
|
|
|
2017
1752
|
|
|
2018
|
-
def run(ruleset: Ruleset | None = None, *until:
|
|
1753
|
+
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
|
|
2019
1754
|
"""
|
|
2020
1755
|
Create a run configuration.
|
|
2021
1756
|
"""
|
|
2022
|
-
|
|
1757
|
+
facts = _fact_likes(until)
|
|
1758
|
+
return Schedule(
|
|
1759
|
+
Thunk.fn(Declarations.create, ruleset, *facts),
|
|
1760
|
+
RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
|
|
1761
|
+
)
|
|
2023
1762
|
|
|
2024
1763
|
|
|
2025
1764
|
def seq(*schedules: Schedule) -> Schedule:
|
|
2026
1765
|
"""
|
|
2027
1766
|
Run a sequence of schedules.
|
|
2028
1767
|
"""
|
|
2029
|
-
return
|
|
1768
|
+
return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
|
|
2030
1769
|
|
|
2031
1770
|
|
|
2032
|
-
|
|
1771
|
+
ActionLike: TypeAlias = Action | Expr
|
|
2033
1772
|
|
|
2034
1773
|
|
|
2035
|
-
def
|
|
2036
|
-
|
|
2037
|
-
return expr_action(command_like)
|
|
2038
|
-
return command_like
|
|
1774
|
+
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
1775
|
+
return tuple(map(_action_like, action_likes))
|
|
2039
1776
|
|
|
2040
1777
|
|
|
2041
|
-
|
|
1778
|
+
def _action_like(action_like: ActionLike) -> Action:
|
|
1779
|
+
if isinstance(action_like, Expr):
|
|
1780
|
+
return expr_action(action_like)
|
|
1781
|
+
return action_like
|
|
2042
1782
|
|
|
2043
1783
|
|
|
2044
|
-
|
|
2045
|
-
"""
|
|
2046
|
-
Calls the function with variables of the type and name of the arguments.
|
|
2047
|
-
"""
|
|
2048
|
-
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
2049
|
-
# but not in the globals
|
|
2050
|
-
current_frame = inspect.currentframe()
|
|
2051
|
-
assert current_frame
|
|
2052
|
-
register_frame = current_frame.f_back
|
|
2053
|
-
assert register_frame
|
|
2054
|
-
original_frame = register_frame.f_back
|
|
2055
|
-
assert original_frame
|
|
2056
|
-
hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
|
|
2057
|
-
args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
|
|
2058
|
-
return gen(*args)
|
|
1784
|
+
Command: TypeAlias = Action | RewriteOrRule
|
|
2059
1785
|
|
|
1786
|
+
CommandLike: TypeAlias = ActionLike | RewriteOrRule
|
|
2060
1787
|
|
|
2061
|
-
ActionLike = Action | Expr
|
|
2062
1788
|
|
|
1789
|
+
def _command_like(command_like: CommandLike) -> Command:
|
|
1790
|
+
if isinstance(command_like, RewriteOrRule):
|
|
1791
|
+
return command_like
|
|
1792
|
+
return _action_like(command_like)
|
|
2063
1793
|
|
|
2064
|
-
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
2065
|
-
return tuple(map(_action_like, action_likes))
|
|
2066
1794
|
|
|
1795
|
+
RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]
|
|
2067
1796
|
|
|
2068
|
-
|
|
2069
|
-
|
|
2070
|
-
|
|
2071
|
-
|
|
1797
|
+
|
|
1798
|
+
def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
|
|
1799
|
+
"""
|
|
1800
|
+
Returns a thunk which will call the function with variables of the type and name of the arguments.
|
|
1801
|
+
"""
|
|
1802
|
+
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
1803
|
+
# but not in the globals
|
|
1804
|
+
|
|
1805
|
+
hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
|
|
1806
|
+
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1807
|
+
return list(gen(*args)) # type: ignore[misc]
|
|
2072
1808
|
|
|
2073
1809
|
|
|
2074
1810
|
FactLike = Fact | Expr
|