egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +2 -0
- egglog/builtins.py +1 -1
- egglog/conversion.py +172 -0
- egglog/declarations.py +329 -735
- egglog/egraph.py +531 -804
- egglog/egraph_state.py +417 -0
- egglog/exp/array_api.py +92 -80
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +418 -0
- egglog/runtime.py +196 -430
- egglog/thunk.py +72 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/METADATA +19 -19
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -3,12 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import pathlib
|
|
5
5
|
import tempfile
|
|
6
|
-
from abc import
|
|
6
|
+
from abc import abstractmethod
|
|
7
7
|
from collections.abc import Callable, Iterable
|
|
8
8
|
from contextvars import ContextVar, Token
|
|
9
|
-
from copy import deepcopy
|
|
10
9
|
from dataclasses import InitVar, dataclass, field
|
|
11
|
-
from functools import cached_property
|
|
12
10
|
from inspect import Parameter, currentframe, signature
|
|
13
11
|
from types import FrameType, FunctionType
|
|
14
12
|
from typing import (
|
|
@@ -18,6 +16,7 @@ from typing import (
|
|
|
18
16
|
Generic,
|
|
19
17
|
Literal,
|
|
20
18
|
NoReturn,
|
|
19
|
+
TypeAlias,
|
|
21
20
|
TypedDict,
|
|
22
21
|
TypeVar,
|
|
23
22
|
cast,
|
|
@@ -26,14 +25,16 @@ from typing import (
|
|
|
26
25
|
)
|
|
27
26
|
|
|
28
27
|
import graphviz
|
|
29
|
-
from typing_extensions import ParamSpec, Self, Unpack, deprecated
|
|
30
|
-
|
|
31
|
-
from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
|
|
28
|
+
from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
|
|
32
29
|
|
|
33
30
|
from . import bindings
|
|
31
|
+
from .conversion import *
|
|
34
32
|
from .declarations import *
|
|
33
|
+
from .egraph_state import *
|
|
35
34
|
from .ipython_magic import IN_IPYTHON
|
|
35
|
+
from .pretty import pretty_decl
|
|
36
36
|
from .runtime import *
|
|
37
|
+
from .thunk import *
|
|
37
38
|
|
|
38
39
|
if TYPE_CHECKING:
|
|
39
40
|
import ipywidgets
|
|
@@ -58,6 +59,7 @@ __all__ = [
|
|
|
58
59
|
"let",
|
|
59
60
|
"constant",
|
|
60
61
|
"delete",
|
|
62
|
+
"subsume",
|
|
61
63
|
"union",
|
|
62
64
|
"set_",
|
|
63
65
|
"rule",
|
|
@@ -65,6 +67,9 @@ __all__ = [
|
|
|
65
67
|
"vars_",
|
|
66
68
|
"Fact",
|
|
67
69
|
"expr_parts",
|
|
70
|
+
"expr_action",
|
|
71
|
+
"expr_fact",
|
|
72
|
+
"action_command",
|
|
68
73
|
"Schedule",
|
|
69
74
|
"run",
|
|
70
75
|
"seq",
|
|
@@ -79,11 +84,10 @@ __all__ = [
|
|
|
79
84
|
"_NeBuilder",
|
|
80
85
|
"_SetBuilder",
|
|
81
86
|
"_UnionBuilder",
|
|
82
|
-
"
|
|
83
|
-
"
|
|
84
|
-
"BiRewrite",
|
|
85
|
-
"Union_",
|
|
87
|
+
"RewriteOrRule",
|
|
88
|
+
"Fact",
|
|
86
89
|
"Action",
|
|
90
|
+
"Command",
|
|
87
91
|
]
|
|
88
92
|
|
|
89
93
|
T = TypeVar("T")
|
|
@@ -112,7 +116,24 @@ IGNORED_ATTRIBUTES = {
|
|
|
112
116
|
}
|
|
113
117
|
|
|
114
118
|
|
|
119
|
+
# special methods that return none and mutate self
|
|
115
120
|
ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
|
|
121
|
+
# special methods which must return real python values instead of lazy expressions
|
|
122
|
+
ALWAYS_PRESERVED = {
|
|
123
|
+
"__repr__",
|
|
124
|
+
"__str__",
|
|
125
|
+
"__bytes__",
|
|
126
|
+
"__format__",
|
|
127
|
+
"__hash__",
|
|
128
|
+
"__bool__",
|
|
129
|
+
"__len__",
|
|
130
|
+
"__length_hint__",
|
|
131
|
+
"__iter__",
|
|
132
|
+
"__reversed__",
|
|
133
|
+
"__contains__",
|
|
134
|
+
"__index__",
|
|
135
|
+
"__bufer__",
|
|
136
|
+
}
|
|
116
137
|
|
|
117
138
|
|
|
118
139
|
def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
@@ -124,7 +145,7 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
|
124
145
|
return EGraph().extract(x)
|
|
125
146
|
|
|
126
147
|
|
|
127
|
-
def check(x: FactLike, schedule: Schedule | None = None, *given:
|
|
148
|
+
def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
|
|
128
149
|
"""
|
|
129
150
|
Verifies that the fact is true given some assumptions and after running the schedule.
|
|
130
151
|
"""
|
|
@@ -136,9 +157,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
|
|
|
136
157
|
egraph.check(x)
|
|
137
158
|
|
|
138
159
|
|
|
139
|
-
# def extract(res: )
|
|
140
|
-
|
|
141
|
-
|
|
142
160
|
@dataclass
|
|
143
161
|
class _BaseModule:
|
|
144
162
|
"""
|
|
@@ -181,15 +199,8 @@ class _BaseModule:
|
|
|
181
199
|
Registers a class.
|
|
182
200
|
"""
|
|
183
201
|
if kwargs:
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
|
|
187
|
-
assert isinstance(cls, RuntimeClass)
|
|
188
|
-
assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
|
|
189
|
-
cls.lazy_decls.egg_sort = egg_sort
|
|
190
|
-
return cls
|
|
191
|
-
|
|
192
|
-
return _inner
|
|
202
|
+
msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
|
|
203
|
+
raise NotImplementedError(msg)
|
|
193
204
|
|
|
194
205
|
assert len(args) == 1
|
|
195
206
|
return args[0]
|
|
@@ -280,9 +291,9 @@ class _BaseModule:
|
|
|
280
291
|
# If we have any positional args, then we are calling it directly on a function
|
|
281
292
|
if args:
|
|
282
293
|
assert len(args) == 1
|
|
283
|
-
return
|
|
294
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
284
295
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
285
|
-
return
|
|
296
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
286
297
|
|
|
287
298
|
@deprecated("Use top level `ruleset` function instead")
|
|
288
299
|
def ruleset(self, name: str) -> Ruleset:
|
|
@@ -324,17 +335,26 @@ class _BaseModule:
|
|
|
324
335
|
"""
|
|
325
336
|
return constant(name, tp, egg_name)
|
|
326
337
|
|
|
327
|
-
def register(
|
|
338
|
+
def register(
|
|
339
|
+
self,
|
|
340
|
+
/,
|
|
341
|
+
command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
|
|
342
|
+
*command_likes: ActionLike | RewriteOrRule,
|
|
343
|
+
) -> None:
|
|
328
344
|
"""
|
|
329
345
|
Registers any number of rewrites or rules.
|
|
330
346
|
"""
|
|
331
347
|
if isinstance(command_or_generator, FunctionType):
|
|
332
348
|
assert not command_likes
|
|
333
|
-
|
|
349
|
+
current_frame = inspect.currentframe()
|
|
350
|
+
assert current_frame
|
|
351
|
+
original_frame = current_frame.f_back
|
|
352
|
+
assert original_frame
|
|
353
|
+
command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
|
|
334
354
|
else:
|
|
335
355
|
command_likes = (cast(CommandLike, command_or_generator), *command_likes)
|
|
336
|
-
|
|
337
|
-
self._register_commands(
|
|
356
|
+
commands = [_command_like(c) for c in command_likes]
|
|
357
|
+
self._register_commands(commands)
|
|
338
358
|
|
|
339
359
|
@abstractmethod
|
|
340
360
|
def _register_commands(self, cmds: list[Command]) -> None:
|
|
@@ -417,136 +437,116 @@ class _ExprMetaclass(type):
|
|
|
417
437
|
# If this is the Expr subclass, just return the class
|
|
418
438
|
if not bases:
|
|
419
439
|
return super().__new__(cls, name, bases, namespace)
|
|
440
|
+
# TODO: Raise error on subclassing or multiple inheritence
|
|
420
441
|
|
|
421
442
|
frame = currentframe()
|
|
422
443
|
assert frame
|
|
423
444
|
prev_frame = frame.f_back
|
|
424
445
|
assert prev_frame
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
cls_name=name,
|
|
434
|
-
).current_cls
|
|
446
|
+
|
|
447
|
+
# Store frame so that we can get live access to updated locals/globals
|
|
448
|
+
# Otherwise, f_locals returns a copy
|
|
449
|
+
# https://peps.python.org/pep-0667/
|
|
450
|
+
decls_thunk = Thunk.fn(
|
|
451
|
+
_generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
|
|
452
|
+
)
|
|
453
|
+
return RuntimeClass(decls_thunk, TypeRefWithVars(name))
|
|
435
454
|
|
|
436
455
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
437
456
|
return isinstance(instance, RuntimeExpr)
|
|
438
457
|
|
|
439
458
|
|
|
440
|
-
|
|
441
|
-
|
|
459
|
+
def _generate_class_decls(
|
|
460
|
+
namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
|
|
461
|
+
) -> Declarations:
|
|
442
462
|
"""
|
|
443
463
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
444
464
|
"""
|
|
465
|
+
parameters: list[TypeVar] = (
|
|
466
|
+
# Get the generic params from the orig bases generic class
|
|
467
|
+
namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
|
|
468
|
+
)
|
|
469
|
+
type_vars = tuple(p.__name__ for p in parameters)
|
|
470
|
+
del parameters
|
|
471
|
+
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
472
|
+
decls = Declarations(_classes={cls_name: cls_decl})
|
|
473
|
+
|
|
474
|
+
##
|
|
475
|
+
# Register class variables
|
|
476
|
+
##
|
|
477
|
+
# Create a dummy type to pass to get_type_hints to resolve the annotations we have
|
|
478
|
+
_Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
|
|
479
|
+
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
480
|
+
if getattr(v, "__origin__", None) == ClassVar:
|
|
481
|
+
(inner_tp,) = v.__args__
|
|
482
|
+
type_ref = resolve_type_annotation(decls, inner_tp).to_just()
|
|
483
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
445
484
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
egg_sort: str | None
|
|
450
|
-
cls_name: str
|
|
451
|
-
current_cls: RuntimeClass = field(init=False)
|
|
485
|
+
else:
|
|
486
|
+
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
487
|
+
raise NotImplementedError(msg)
|
|
452
488
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
_Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
|
|
477
|
-
# Make lazy update to locals, so we keep a live handle on them after class creation
|
|
478
|
-
locals = self.frame.f_locals.copy()
|
|
479
|
-
locals[self.cls_name] = self.current_cls
|
|
480
|
-
for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
|
|
481
|
-
if getattr(v, "__origin__", None) == ClassVar:
|
|
482
|
-
(inner_tp,) = v.__args__
|
|
483
|
-
_register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
|
|
484
|
-
else:
|
|
485
|
-
msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
486
|
-
raise NotImplementedError(msg)
|
|
487
|
-
|
|
488
|
-
# Then register each of its methods
|
|
489
|
-
for method_name, method in cls_dict.items():
|
|
490
|
-
is_init = method_name == "__init__"
|
|
491
|
-
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
492
|
-
if is_init and self.cls_name in LIT_CLASS_NAMES:
|
|
493
|
-
continue
|
|
494
|
-
if isinstance(method, _WrappedMethod):
|
|
495
|
-
fn = method.fn
|
|
496
|
-
egg_fn = method.egg_fn
|
|
497
|
-
cost = method.cost
|
|
498
|
-
default = method.default
|
|
499
|
-
merge = method.merge
|
|
500
|
-
on_merge = method.on_merge
|
|
501
|
-
mutates_first_arg = method.mutates_self
|
|
502
|
-
unextractable = method.unextractable
|
|
503
|
-
if method.preserve:
|
|
504
|
-
decls.register_preserved_method(self.cls_name, method_name, fn)
|
|
505
|
-
continue
|
|
506
|
-
else:
|
|
507
|
-
fn = method
|
|
489
|
+
##
|
|
490
|
+
# Register methods, classmethods, preserved methods, and properties
|
|
491
|
+
##
|
|
492
|
+
|
|
493
|
+
# The type ref of self is paramterized by the type vars
|
|
494
|
+
slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
495
|
+
|
|
496
|
+
# Get all the methods from the class
|
|
497
|
+
filtered_namespace: list[tuple[str, Any]] = [
|
|
498
|
+
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
499
|
+
]
|
|
500
|
+
|
|
501
|
+
# Then register each of its methods
|
|
502
|
+
for method_name, method in filtered_namespace:
|
|
503
|
+
is_init = method_name == "__init__"
|
|
504
|
+
# Don't register the init methods for literals, since those don't use the type checking mechanisms
|
|
505
|
+
if is_init and cls_name in LIT_CLASS_NAMES:
|
|
506
|
+
continue
|
|
507
|
+
match method:
|
|
508
|
+
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
|
|
509
|
+
pass
|
|
510
|
+
case _:
|
|
508
511
|
egg_fn, cost, default, merge, on_merge = None, None, None, None, None
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
fn = fn.fget
|
|
520
|
-
is_property = True
|
|
521
|
-
if is_classmethod:
|
|
522
|
-
msg = "Can't have a classmethod property"
|
|
523
|
-
raise NotImplementedError(msg)
|
|
524
|
-
else:
|
|
525
|
-
is_property = False
|
|
526
|
-
ref: FunctionCallableRef = (
|
|
527
|
-
ClassMethodRef(self.cls_name, method_name)
|
|
528
|
-
if is_classmethod
|
|
529
|
-
else PropertyRef(self.cls_name, method_name)
|
|
530
|
-
if is_property
|
|
531
|
-
else MethodRef(self.cls_name, method_name)
|
|
532
|
-
)
|
|
533
|
-
_register_function(
|
|
512
|
+
fn = method
|
|
513
|
+
unextractable, preserve = False, False
|
|
514
|
+
mutates = method_name in ALWAYS_MUTATES_SELF
|
|
515
|
+
if preserve:
|
|
516
|
+
cls_decl.preserved_methods[method_name] = fn
|
|
517
|
+
continue
|
|
518
|
+
locals = frame.f_locals
|
|
519
|
+
|
|
520
|
+
def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
|
|
521
|
+
return _fn_decl(
|
|
534
522
|
decls,
|
|
535
|
-
|
|
536
|
-
egg_fn,
|
|
523
|
+
egg_fn, # noqa: B023
|
|
537
524
|
fn,
|
|
538
|
-
locals,
|
|
539
|
-
default,
|
|
540
|
-
cost,
|
|
541
|
-
merge,
|
|
542
|
-
on_merge,
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
is_init,
|
|
547
|
-
unextractable
|
|
525
|
+
locals, # noqa: B023
|
|
526
|
+
default, # noqa: B023
|
|
527
|
+
cost, # noqa: B023
|
|
528
|
+
merge, # noqa: B023
|
|
529
|
+
on_merge, # noqa: B023
|
|
530
|
+
mutates, # noqa: B023
|
|
531
|
+
builtin,
|
|
532
|
+
first,
|
|
533
|
+
is_init, # noqa: B023
|
|
534
|
+
unextractable, # noqa: B023
|
|
548
535
|
)
|
|
549
536
|
|
|
537
|
+
match fn:
|
|
538
|
+
case classmethod():
|
|
539
|
+
cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
|
|
540
|
+
case property():
|
|
541
|
+
cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
|
|
542
|
+
case _:
|
|
543
|
+
if is_init:
|
|
544
|
+
cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
|
|
545
|
+
else:
|
|
546
|
+
cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
|
|
547
|
+
|
|
548
|
+
return decls
|
|
549
|
+
|
|
550
550
|
|
|
551
551
|
@overload
|
|
552
552
|
def function(fn: CALLABLE, /) -> CALLABLE: ...
|
|
@@ -589,48 +589,46 @@ def function(*args, **kwargs) -> Any:
|
|
|
589
589
|
# If we have any positional args, then we are calling it directly on a function
|
|
590
590
|
if args:
|
|
591
591
|
assert len(args) == 1
|
|
592
|
-
return
|
|
592
|
+
return _FunctionConstructor(fn_locals)(args[0])
|
|
593
593
|
# otherwise, we are passing some keyword args, so save those, and then return a partial
|
|
594
|
-
return
|
|
594
|
+
return _FunctionConstructor(fn_locals, **kwargs)
|
|
595
595
|
|
|
596
596
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
hint_locals: dict[str, Any]
|
|
600
|
-
builtin: bool = False
|
|
601
|
-
mutates_first_arg: bool = False
|
|
602
|
-
egg_fn: str | None = None
|
|
603
|
-
cost: int | None = None
|
|
604
|
-
default: RuntimeExpr | None = None
|
|
605
|
-
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
|
-
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
|
-
unextractable: bool = False
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
return RuntimeFunction(decls, name)
|
|
597
|
+
@dataclass
|
|
598
|
+
class _FunctionConstructor:
|
|
599
|
+
hint_locals: dict[str, Any]
|
|
600
|
+
builtin: bool = False
|
|
601
|
+
mutates_first_arg: bool = False
|
|
602
|
+
egg_fn: str | None = None
|
|
603
|
+
cost: int | None = None
|
|
604
|
+
default: RuntimeExpr | None = None
|
|
605
|
+
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
|
+
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
|
+
unextractable: bool = False
|
|
608
|
+
|
|
609
|
+
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
610
|
+
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
|
|
611
|
+
|
|
612
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
613
|
+
decls = Declarations()
|
|
614
|
+
decls._functions[fn.__name__] = _fn_decl(
|
|
615
|
+
decls,
|
|
616
|
+
self.egg_fn,
|
|
617
|
+
fn,
|
|
618
|
+
self.hint_locals,
|
|
619
|
+
self.default,
|
|
620
|
+
self.cost,
|
|
621
|
+
self.merge,
|
|
622
|
+
self.on_merge,
|
|
623
|
+
self.mutates_first_arg,
|
|
624
|
+
self.builtin,
|
|
625
|
+
unextractable=self.unextractable,
|
|
626
|
+
)
|
|
627
|
+
return decls
|
|
629
628
|
|
|
630
629
|
|
|
631
|
-
def
|
|
630
|
+
def _fn_decl(
|
|
632
631
|
decls: Declarations,
|
|
633
|
-
ref: FunctionCallableRef,
|
|
634
632
|
egg_name: str | None,
|
|
635
633
|
fn: object,
|
|
636
634
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
@@ -646,7 +644,7 @@ def _register_function(
|
|
|
646
644
|
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
647
645
|
is_init: bool = False,
|
|
648
646
|
unextractable: bool = False,
|
|
649
|
-
) ->
|
|
647
|
+
) -> FunctionDecl:
|
|
650
648
|
if not isinstance(fn, FunctionType):
|
|
651
649
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
652
650
|
|
|
@@ -699,8 +697,8 @@ def _register_function(
|
|
|
699
697
|
None
|
|
700
698
|
if merge is None
|
|
701
699
|
else merge(
|
|
702
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
703
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
700
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
701
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
704
702
|
)
|
|
705
703
|
)
|
|
706
704
|
decls |= merged
|
|
@@ -710,30 +708,25 @@ def _register_function(
|
|
|
710
708
|
if on_merge is None
|
|
711
709
|
else _action_likes(
|
|
712
710
|
on_merge(
|
|
713
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
714
|
-
RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
711
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
|
|
712
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
|
|
715
713
|
)
|
|
716
714
|
)
|
|
717
715
|
)
|
|
718
716
|
decls.update(*merge_action)
|
|
719
|
-
|
|
720
|
-
return_type=return_type,
|
|
717
|
+
return FunctionDecl(
|
|
718
|
+
return_type=None if mutates_first_arg else return_type,
|
|
721
719
|
var_arg_type=var_arg_type,
|
|
722
720
|
arg_types=arg_types,
|
|
723
721
|
arg_names=tuple(t.name for t in params),
|
|
724
722
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
None if default is None else default.__egg_typed_expr__.expr,
|
|
733
|
-
merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
734
|
-
[a._to_egg_action() for a in merge_action],
|
|
735
|
-
unextractable,
|
|
736
|
-
is_builtin,
|
|
723
|
+
cost=cost,
|
|
724
|
+
egg_name=egg_name,
|
|
725
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
726
|
+
unextractable=unextractable,
|
|
727
|
+
builtin=is_builtin,
|
|
728
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
729
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
737
730
|
)
|
|
738
731
|
|
|
739
732
|
|
|
@@ -764,49 +757,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
764
757
|
"""
|
|
765
758
|
Creates a function whose return type is `Unit` and has a default value.
|
|
766
759
|
"""
|
|
760
|
+
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
761
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
767
765
|
decls = Declarations()
|
|
768
766
|
decls |= cast(RuntimeClass, Unit)
|
|
769
|
-
arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
|
|
770
|
-
|
|
771
|
-
decls
|
|
772
|
-
FunctionRef(name),
|
|
773
|
-
fn_decl,
|
|
774
|
-
egg_fn,
|
|
775
|
-
cost=None,
|
|
776
|
-
default=None,
|
|
777
|
-
merge=None,
|
|
778
|
-
merge_action=[],
|
|
779
|
-
unextractable=False,
|
|
780
|
-
builtin=False,
|
|
781
|
-
is_relation=True,
|
|
782
|
-
)
|
|
783
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls, name))
|
|
767
|
+
arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
|
|
768
|
+
decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
|
|
769
|
+
return decls
|
|
784
770
|
|
|
785
771
|
|
|
786
772
|
def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
|
|
787
773
|
"""
|
|
788
|
-
|
|
789
774
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
790
775
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
791
776
|
"""
|
|
792
|
-
|
|
793
|
-
decls = Declarations()
|
|
794
|
-
type_ref = _register_constant(decls, ref, tp, egg_name)
|
|
795
|
-
return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
|
|
777
|
+
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
|
|
796
778
|
|
|
797
779
|
|
|
798
|
-
def
|
|
799
|
-
decls
|
|
800
|
-
ref: ConstantRef | ClassVariableRef,
|
|
801
|
-
tp: object,
|
|
802
|
-
egg_name: str | None,
|
|
803
|
-
) -> JustTypeRef:
|
|
804
|
-
"""
|
|
805
|
-
Register a constant, returning its typeref().
|
|
806
|
-
"""
|
|
780
|
+
def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
|
|
781
|
+
decls = Declarations()
|
|
807
782
|
type_ref = resolve_type_annotation(decls, tp).to_just()
|
|
808
|
-
decls.
|
|
809
|
-
return type_ref
|
|
783
|
+
decls._constants[name] = ConstantDecl(type_ref, egg_name)
|
|
784
|
+
return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
|
|
810
785
|
|
|
811
786
|
|
|
812
787
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -858,29 +833,6 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
858
833
|
split_primitive_outputs: bool
|
|
859
834
|
|
|
860
835
|
|
|
861
|
-
@dataclass
|
|
862
|
-
class _EGraphState:
|
|
863
|
-
"""
|
|
864
|
-
State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
|
|
865
|
-
"""
|
|
866
|
-
|
|
867
|
-
# The decleratons we have added. The _cmds represent all the symbols we have added
|
|
868
|
-
decls: Declarations = field(default_factory=Declarations)
|
|
869
|
-
# List of rulesets already added, so we don't re-add them if they are passed again
|
|
870
|
-
added_rulesets: set[str] = field(default_factory=set)
|
|
871
|
-
|
|
872
|
-
def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
|
|
873
|
-
new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
|
|
874
|
-
self.decls |= new_decls
|
|
875
|
-
return new_cmds
|
|
876
|
-
|
|
877
|
-
def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
|
|
878
|
-
for ruleset in rulesets:
|
|
879
|
-
if ruleset.egg_name not in self.added_rulesets:
|
|
880
|
-
self.added_rulesets.add(ruleset.egg_name)
|
|
881
|
-
yield from ruleset._cmds
|
|
882
|
-
|
|
883
|
-
|
|
884
836
|
@dataclass
|
|
885
837
|
class EGraph(_BaseModule):
|
|
886
838
|
"""
|
|
@@ -892,56 +844,34 @@ class EGraph(_BaseModule):
|
|
|
892
844
|
seminaive: InitVar[bool] = True
|
|
893
845
|
save_egglog_string: InitVar[bool] = False
|
|
894
846
|
|
|
895
|
-
|
|
896
|
-
_egraph: bindings.EGraph = field(repr=False, init=False)
|
|
897
|
-
_egglog_string: str | None = field(default=None, repr=False, init=False)
|
|
898
|
-
_state: _EGraphState = field(default_factory=_EGraphState, repr=False)
|
|
847
|
+
_state: EGraphState = field(init=False)
|
|
899
848
|
# For pushing/popping with egglog
|
|
900
|
-
_state_stack: list[
|
|
849
|
+
_state_stack: list[EGraphState] = field(default_factory=list, repr=False)
|
|
901
850
|
# For storing the global "current" egraph
|
|
902
851
|
_token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
|
|
903
852
|
|
|
904
853
|
def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
|
|
905
|
-
|
|
854
|
+
egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
|
|
855
|
+
self._state = EGraphState(egraph)
|
|
906
856
|
super().__post_init__(modules)
|
|
907
857
|
|
|
908
858
|
for m in self._flatted_deps:
|
|
909
|
-
self._add_decls(*m.cmds)
|
|
910
859
|
self._register_commands(m.cmds)
|
|
911
|
-
if save_egglog_string:
|
|
912
|
-
self._egglog_string = ""
|
|
913
|
-
|
|
914
|
-
def _register_commands(self, commands: list[Command]) -> None:
|
|
915
|
-
for c in commands:
|
|
916
|
-
if c.ruleset:
|
|
917
|
-
self._add_schedule(c.ruleset)
|
|
918
|
-
|
|
919
|
-
self._add_decls(*commands)
|
|
920
|
-
self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
|
|
921
|
-
|
|
922
|
-
def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
|
|
923
|
-
commands = list(commands)
|
|
924
|
-
self._egraph.run_program(*commands)
|
|
925
|
-
if isinstance(self._egglog_string, str):
|
|
926
|
-
self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
|
|
927
860
|
|
|
928
861
|
def _add_decls(self, *decls: DeclerationsLike) -> None:
|
|
929
|
-
for d in
|
|
930
|
-
self.
|
|
931
|
-
|
|
932
|
-
def _add_schedule(self, schedule: Schedule) -> None:
|
|
933
|
-
self._add_decls(schedule)
|
|
934
|
-
self._process_commands(self._state.add_rulesets(schedule._rulesets()))
|
|
862
|
+
for d in decls:
|
|
863
|
+
self._state.__egg_decls__ |= d
|
|
935
864
|
|
|
936
865
|
@property
|
|
937
866
|
def as_egglog_string(self) -> str:
|
|
938
867
|
"""
|
|
939
868
|
Returns the egglog string for this module.
|
|
940
869
|
"""
|
|
941
|
-
|
|
870
|
+
cmds = self._egraph.commands()
|
|
871
|
+
if cmds is None:
|
|
942
872
|
msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
|
|
943
873
|
raise ValueError(msg)
|
|
944
|
-
return
|
|
874
|
+
return cmds
|
|
945
875
|
|
|
946
876
|
def _repr_mimebundle_(self, *args, **kwargs):
|
|
947
877
|
"""
|
|
@@ -954,7 +884,7 @@ class EGraph(_BaseModule):
|
|
|
954
884
|
kwargs.setdefault("split_primitive_outputs", True)
|
|
955
885
|
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
956
886
|
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
957
|
-
serialized.map_ops(self._state.
|
|
887
|
+
serialized.map_ops(self._state.op_mapping())
|
|
958
888
|
for _ in range(n_inline):
|
|
959
889
|
serialized.inline_leaves()
|
|
960
890
|
original = serialized.to_dot()
|
|
@@ -1016,17 +946,23 @@ class EGraph(_BaseModule):
|
|
|
1016
946
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1017
947
|
"""
|
|
1018
948
|
ref, decls = resolve_callable(fn)
|
|
1019
|
-
|
|
1020
|
-
self.
|
|
1021
|
-
self.
|
|
949
|
+
self._add_decls(decls)
|
|
950
|
+
fn_name = self._state.callable_ref_to_egg(ref)
|
|
951
|
+
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1022
952
|
|
|
1023
953
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
1024
954
|
"""
|
|
1025
955
|
Define a new expression in the egraph and return a reference to it.
|
|
1026
956
|
"""
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
957
|
+
action = let(name, expr)
|
|
958
|
+
self.register(action)
|
|
959
|
+
runtime_expr = to_runtime_expr(expr)
|
|
960
|
+
return cast(
|
|
961
|
+
EXPR,
|
|
962
|
+
RuntimeExpr.__from_value__(
|
|
963
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
964
|
+
),
|
|
965
|
+
)
|
|
1030
966
|
|
|
1031
967
|
@overload
|
|
1032
968
|
def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
|
|
@@ -1041,28 +977,19 @@ class EGraph(_BaseModule):
|
|
|
1041
977
|
Simplifies the given expression.
|
|
1042
978
|
"""
|
|
1043
979
|
schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
|
|
1044
|
-
del limit_or_schedule
|
|
1045
|
-
|
|
1046
|
-
self._add_decls(
|
|
1047
|
-
self.
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
self.
|
|
980
|
+
del limit_or_schedule, until, ruleset
|
|
981
|
+
runtime_expr = to_runtime_expr(expr)
|
|
982
|
+
self._add_decls(runtime_expr, schedule)
|
|
983
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
984
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
985
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
986
|
+
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1051
987
|
extract_report = self._egraph.extract_report()
|
|
1052
988
|
if not isinstance(extract_report, bindings.Best):
|
|
1053
989
|
msg = "No extract report saved"
|
|
1054
990
|
raise ValueError(msg) # noqa: TRY004
|
|
1055
|
-
new_typed_expr =
|
|
1056
|
-
|
|
1057
|
-
)
|
|
1058
|
-
return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1059
|
-
|
|
1060
|
-
@property
|
|
1061
|
-
def _default_ruleset_name(self) -> str:
|
|
1062
|
-
if self.default_ruleset:
|
|
1063
|
-
self._add_schedule(self.default_ruleset)
|
|
1064
|
-
return self.default_ruleset.egg_name
|
|
1065
|
-
return ""
|
|
991
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
992
|
+
return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1066
993
|
|
|
1067
994
|
def include(self, path: str) -> None:
|
|
1068
995
|
"""
|
|
@@ -1092,8 +1019,9 @@ class EGraph(_BaseModule):
|
|
|
1092
1019
|
return self._run_schedule(limit_or_schedule)
|
|
1093
1020
|
|
|
1094
1021
|
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
|
|
1095
|
-
self.
|
|
1096
|
-
self.
|
|
1022
|
+
self._add_decls(schedule)
|
|
1023
|
+
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1024
|
+
self._egraph.run_program(bindings.RunSchedule(egg_schedule))
|
|
1097
1025
|
run_report = self._egraph.run_report()
|
|
1098
1026
|
if not run_report:
|
|
1099
1027
|
msg = "No run report saved"
|
|
@@ -1104,18 +1032,18 @@ class EGraph(_BaseModule):
|
|
|
1104
1032
|
"""
|
|
1105
1033
|
Check if a fact is true in the egraph.
|
|
1106
1034
|
"""
|
|
1107
|
-
self.
|
|
1035
|
+
self._egraph.run_program(self._facts_to_check(facts))
|
|
1108
1036
|
|
|
1109
1037
|
def check_fail(self, *facts: FactLike) -> None:
|
|
1110
1038
|
"""
|
|
1111
1039
|
Checks that one of the facts is not true
|
|
1112
1040
|
"""
|
|
1113
|
-
self.
|
|
1041
|
+
self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
|
|
1114
1042
|
|
|
1115
|
-
def _facts_to_check(self,
|
|
1116
|
-
facts = _fact_likes(
|
|
1043
|
+
def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
|
|
1044
|
+
facts = _fact_likes(fact_likes)
|
|
1117
1045
|
self._add_decls(*facts)
|
|
1118
|
-
egg_facts = [f.
|
|
1046
|
+
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1119
1047
|
return bindings.Check(egg_facts)
|
|
1120
1048
|
|
|
1121
1049
|
@overload
|
|
@@ -1128,16 +1056,17 @@ class EGraph(_BaseModule):
|
|
|
1128
1056
|
"""
|
|
1129
1057
|
Extract the lowest cost expression from the egraph.
|
|
1130
1058
|
"""
|
|
1131
|
-
|
|
1132
|
-
self._add_decls(
|
|
1133
|
-
|
|
1059
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1060
|
+
self._add_decls(runtime_expr)
|
|
1061
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1062
|
+
extract_report = self._run_extract(typed_expr.expr, 0)
|
|
1063
|
+
|
|
1134
1064
|
if not isinstance(extract_report, bindings.Best):
|
|
1135
1065
|
msg = "No extract report saved"
|
|
1136
1066
|
raise ValueError(msg) # noqa: TRY004
|
|
1137
|
-
new_typed_expr =
|
|
1138
|
-
|
|
1139
|
-
)
|
|
1140
|
-
res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
|
|
1067
|
+
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1068
|
+
|
|
1069
|
+
res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
|
|
1141
1070
|
if include_cost:
|
|
1142
1071
|
return res, extract_report.cost
|
|
1143
1072
|
return res
|
|
@@ -1146,23 +1075,20 @@ class EGraph(_BaseModule):
|
|
|
1146
1075
|
"""
|
|
1147
1076
|
Extract multiple expressions from the egraph.
|
|
1148
1077
|
"""
|
|
1149
|
-
|
|
1150
|
-
self._add_decls(
|
|
1078
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1079
|
+
self._add_decls(runtime_expr)
|
|
1080
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1151
1081
|
|
|
1152
|
-
extract_report = self._run_extract(expr
|
|
1082
|
+
extract_report = self._run_extract(typed_expr.expr, n)
|
|
1153
1083
|
if not isinstance(extract_report, bindings.Variants):
|
|
1154
1084
|
msg = "Wrong extract report type"
|
|
1155
1085
|
raise ValueError(msg) # noqa: TRY004
|
|
1156
|
-
new_exprs =
|
|
1157
|
-
|
|
1158
|
-
self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
|
|
1159
|
-
)
|
|
1160
|
-
for term in extract_report.terms
|
|
1161
|
-
]
|
|
1162
|
-
return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
|
|
1086
|
+
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1087
|
+
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1163
1088
|
|
|
1164
|
-
def _run_extract(self, expr:
|
|
1165
|
-
self.
|
|
1089
|
+
def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
|
|
1090
|
+
expr = self._state.expr_to_egg(expr)
|
|
1091
|
+
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1166
1092
|
extract_report = self._egraph.extract_report()
|
|
1167
1093
|
if not extract_report:
|
|
1168
1094
|
msg = "No extract report saved"
|
|
@@ -1173,15 +1099,15 @@ class EGraph(_BaseModule):
|
|
|
1173
1099
|
"""
|
|
1174
1100
|
Push the current state of the egraph, so that it can be popped later and reverted back.
|
|
1175
1101
|
"""
|
|
1176
|
-
self.
|
|
1102
|
+
self._egraph.run_program(bindings.Push(1))
|
|
1177
1103
|
self._state_stack.append(self._state)
|
|
1178
|
-
self._state =
|
|
1104
|
+
self._state = self._state.copy()
|
|
1179
1105
|
|
|
1180
1106
|
def pop(self) -> None:
|
|
1181
1107
|
"""
|
|
1182
1108
|
Pop the current state of the egraph, reverting back to the previous state.
|
|
1183
1109
|
"""
|
|
1184
|
-
self.
|
|
1110
|
+
self._egraph.run_program(bindings.Pop(1))
|
|
1185
1111
|
self._state = self._state_stack.pop()
|
|
1186
1112
|
|
|
1187
1113
|
def __enter__(self) -> Self:
|
|
@@ -1217,9 +1143,10 @@ class EGraph(_BaseModule):
|
|
|
1217
1143
|
"""
|
|
1218
1144
|
Evaluates the given expression (which must be a primitive type), returning the result.
|
|
1219
1145
|
"""
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1146
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1147
|
+
self._add_decls(runtime_expr)
|
|
1148
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1149
|
+
egg_expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1223
1150
|
match typed_expr.tp:
|
|
1224
1151
|
case JustTypeRef("i64"):
|
|
1225
1152
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1231,7 +1158,7 @@ class EGraph(_BaseModule):
|
|
|
1231
1158
|
return self._egraph.eval_string(egg_expr)
|
|
1232
1159
|
case JustTypeRef("PyObject"):
|
|
1233
1160
|
return self._egraph.eval_py_object(egg_expr)
|
|
1234
|
-
raise
|
|
1161
|
+
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1235
1162
|
|
|
1236
1163
|
def saturate(
|
|
1237
1164
|
self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
|
|
@@ -1270,6 +1197,32 @@ class EGraph(_BaseModule):
|
|
|
1270
1197
|
"""
|
|
1271
1198
|
return CURRENT_EGRAPH.get()
|
|
1272
1199
|
|
|
1200
|
+
@property
|
|
1201
|
+
def _egraph(self) -> bindings.EGraph:
|
|
1202
|
+
return self._state.egraph
|
|
1203
|
+
|
|
1204
|
+
@property
|
|
1205
|
+
def __egg_decls__(self) -> Declarations:
|
|
1206
|
+
return self._state.__egg_decls__
|
|
1207
|
+
|
|
1208
|
+
def _register_commands(self, cmds: list[Command]) -> None:
|
|
1209
|
+
self._add_decls(*cmds)
|
|
1210
|
+
egg_cmds = list(map(self._command_to_egg, cmds))
|
|
1211
|
+
self._egraph.run_program(*egg_cmds)
|
|
1212
|
+
|
|
1213
|
+
def _command_to_egg(self, cmd: Command) -> bindings._Command:
|
|
1214
|
+
ruleset_name = ""
|
|
1215
|
+
cmd_decl: CommandDecl
|
|
1216
|
+
match cmd:
|
|
1217
|
+
case RewriteOrRule(_, cmd_decl, ruleset):
|
|
1218
|
+
if ruleset:
|
|
1219
|
+
ruleset_name = ruleset.__egg_name__
|
|
1220
|
+
case Action(_, action):
|
|
1221
|
+
cmd_decl = ActionCommandDecl(action)
|
|
1222
|
+
case _:
|
|
1223
|
+
assert_never(cmd)
|
|
1224
|
+
return self._state.command_to_egg(cmd_decl, ruleset_name)
|
|
1225
|
+
|
|
1273
1226
|
|
|
1274
1227
|
CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
|
|
1275
1228
|
|
|
@@ -1316,61 +1269,53 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
|
|
|
1316
1269
|
|
|
1317
1270
|
|
|
1318
1271
|
def ruleset(
|
|
1319
|
-
rule_or_generator:
|
|
1272
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
|
|
1273
|
+
*rules: RewriteOrRule,
|
|
1274
|
+
name: None | str = None,
|
|
1320
1275
|
) -> Ruleset:
|
|
1321
1276
|
"""
|
|
1322
1277
|
Creates a ruleset with the following rules.
|
|
1323
1278
|
|
|
1324
1279
|
If no name is provided, one is generated based on the current module
|
|
1325
1280
|
"""
|
|
1326
|
-
r = Ruleset(name
|
|
1281
|
+
r = Ruleset(name)
|
|
1327
1282
|
if rule_or_generator is not None:
|
|
1328
|
-
r.register(rule_or_generator, *rules)
|
|
1283
|
+
r.register(rule_or_generator, *rules, _increase_frame=True)
|
|
1329
1284
|
return r
|
|
1330
1285
|
|
|
1331
1286
|
|
|
1332
|
-
|
|
1287
|
+
@dataclass
|
|
1288
|
+
class Schedule(DelayedDeclerations):
|
|
1333
1289
|
"""
|
|
1334
1290
|
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
|
|
1335
1291
|
"""
|
|
1336
1292
|
|
|
1293
|
+
# Defer declerations so that we can have rule generators that used not yet defined yet
|
|
1294
|
+
schedule: ScheduleDecl
|
|
1295
|
+
|
|
1296
|
+
def __str__(self) -> str:
|
|
1297
|
+
return pretty_decl(self.__egg_decls__, self.schedule)
|
|
1298
|
+
|
|
1299
|
+
def __repr__(self) -> str:
|
|
1300
|
+
return str(self)
|
|
1301
|
+
|
|
1337
1302
|
def __mul__(self, length: int) -> Schedule:
|
|
1338
1303
|
"""
|
|
1339
1304
|
Repeat the schedule a number of times.
|
|
1340
1305
|
"""
|
|
1341
|
-
return
|
|
1306
|
+
return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
|
|
1342
1307
|
|
|
1343
1308
|
def saturate(self) -> Schedule:
|
|
1344
1309
|
"""
|
|
1345
1310
|
Run the schedule until the e-graph is saturated.
|
|
1346
1311
|
"""
|
|
1347
|
-
return
|
|
1312
|
+
return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
|
|
1348
1313
|
|
|
1349
1314
|
def __add__(self, other: Schedule) -> Schedule:
|
|
1350
1315
|
"""
|
|
1351
1316
|
Run two schedules in sequence.
|
|
1352
1317
|
"""
|
|
1353
|
-
return
|
|
1354
|
-
|
|
1355
|
-
@abstractmethod
|
|
1356
|
-
def __str__(self) -> str:
|
|
1357
|
-
raise NotImplementedError
|
|
1358
|
-
|
|
1359
|
-
@abstractmethod
|
|
1360
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1361
|
-
raise NotImplementedError
|
|
1362
|
-
|
|
1363
|
-
@abstractmethod
|
|
1364
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1365
|
-
"""
|
|
1366
|
-
Mapping of all the rulesets used to commands.
|
|
1367
|
-
"""
|
|
1368
|
-
raise NotImplementedError
|
|
1369
|
-
|
|
1370
|
-
@property
|
|
1371
|
-
@abstractmethod
|
|
1372
|
-
def __egg_decls__(self) -> Declarations:
|
|
1373
|
-
raise NotImplementedError
|
|
1318
|
+
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
|
|
1374
1319
|
|
|
1375
1320
|
|
|
1376
1321
|
@dataclass
|
|
@@ -1379,422 +1324,119 @@ class Ruleset(Schedule):
|
|
|
1379
1324
|
A collection of rules, which can be run as a schedule.
|
|
1380
1325
|
"""
|
|
1381
1326
|
|
|
1327
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1328
|
+
schedule: RunDecl = field(init=False)
|
|
1382
1329
|
name: str | None
|
|
1383
|
-
rules: list[Rule | Rewrite] = field(default_factory=list)
|
|
1384
1330
|
|
|
1385
|
-
|
|
1331
|
+
# Current declerations we have accumulated
|
|
1332
|
+
_current_egg_decls: Declarations = field(default_factory=Declarations)
|
|
1333
|
+
# Current rulesets we have accumulated
|
|
1334
|
+
__egg_ruleset__: RulesetDecl = field(init=False)
|
|
1335
|
+
# Rule generator functions that have been deferred, to allow for late type binding
|
|
1336
|
+
deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
|
|
1337
|
+
|
|
1338
|
+
def __post_init__(self) -> None:
|
|
1339
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1340
|
+
self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
|
|
1341
|
+
self.__egg_decls_thunk__ = self._update_egg_decls
|
|
1342
|
+
|
|
1343
|
+
def _update_egg_decls(self) -> Declarations:
|
|
1344
|
+
"""
|
|
1345
|
+
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1346
|
+
"""
|
|
1347
|
+
while self.deferred_rule_gens:
|
|
1348
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1349
|
+
self._current_egg_decls.update(*rules)
|
|
1350
|
+
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1351
|
+
return self._current_egg_decls
|
|
1352
|
+
|
|
1353
|
+
def append(self, rule: RewriteOrRule) -> None:
|
|
1386
1354
|
"""
|
|
1387
1355
|
Register a rule with the ruleset.
|
|
1388
1356
|
"""
|
|
1389
|
-
self.
|
|
1357
|
+
self._current_egg_decls |= rule
|
|
1358
|
+
self.__egg_ruleset__.rules.append(rule.decl)
|
|
1390
1359
|
|
|
1391
|
-
def register(
|
|
1360
|
+
def register(
|
|
1361
|
+
self,
|
|
1362
|
+
/,
|
|
1363
|
+
rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
|
|
1364
|
+
*rules: RewriteOrRule,
|
|
1365
|
+
_increase_frame: bool = False,
|
|
1366
|
+
) -> None:
|
|
1392
1367
|
"""
|
|
1393
1368
|
Register rewrites or rules, either as a function or as values.
|
|
1394
1369
|
"""
|
|
1395
|
-
if isinstance(rule_or_generator,
|
|
1396
|
-
|
|
1397
|
-
|
|
1370
|
+
if isinstance(rule_or_generator, RewriteOrRule):
|
|
1371
|
+
self.append(rule_or_generator)
|
|
1372
|
+
for r in rules:
|
|
1373
|
+
self.append(r)
|
|
1398
1374
|
else:
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
def _cmds(self) -> list[bindings._Command]:
|
|
1409
|
-
cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
|
|
1410
|
-
if self.egg_name:
|
|
1411
|
-
cmds.insert(0, bindings.AddRuleset(self.egg_name))
|
|
1412
|
-
return cmds
|
|
1375
|
+
assert not rules
|
|
1376
|
+
current_frame = inspect.currentframe()
|
|
1377
|
+
assert current_frame
|
|
1378
|
+
original_frame = current_frame.f_back
|
|
1379
|
+
assert original_frame
|
|
1380
|
+
if _increase_frame:
|
|
1381
|
+
original_frame = original_frame.f_back
|
|
1382
|
+
assert original_frame
|
|
1383
|
+
self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
|
|
1413
1384
|
|
|
1414
1385
|
def __str__(self) -> str:
|
|
1415
|
-
return
|
|
1386
|
+
return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
|
|
1416
1387
|
|
|
1417
1388
|
def __repr__(self) -> str:
|
|
1418
|
-
|
|
1419
|
-
return str(self)
|
|
1420
|
-
rules = ", ".join(map(repr, self.rules))
|
|
1421
|
-
return f"ruleset({rules}, name={self.egg_name!r})"
|
|
1422
|
-
|
|
1423
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1424
|
-
return bindings.Run(self._to_egg_config())
|
|
1425
|
-
|
|
1426
|
-
def _to_egg_config(self) -> bindings.RunConfig:
|
|
1427
|
-
return bindings.RunConfig(self.egg_name, None)
|
|
1428
|
-
|
|
1429
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1430
|
-
yield self
|
|
1431
|
-
|
|
1432
|
-
@property
|
|
1433
|
-
def egg_name(self) -> str:
|
|
1434
|
-
return self.name or f"_ruleset_{id(self)}"
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
class Command(ABC):
|
|
1438
|
-
"""
|
|
1439
|
-
A command that can be executed in the egg interpreter.
|
|
1440
|
-
|
|
1441
|
-
We only use this for commands which return no result and don't create new Python objects.
|
|
1442
|
-
|
|
1443
|
-
Anything that can be passed to the `register` function in a Module is a Command.
|
|
1444
|
-
"""
|
|
1445
|
-
|
|
1446
|
-
ruleset: Ruleset | None
|
|
1389
|
+
return str(self)
|
|
1447
1390
|
|
|
1391
|
+
# Create a unique name if we didn't pass one from the user
|
|
1448
1392
|
@property
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
raise NotImplementedError
|
|
1452
|
-
|
|
1453
|
-
@abstractmethod
|
|
1454
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1455
|
-
raise NotImplementedError
|
|
1456
|
-
|
|
1457
|
-
@abstractmethod
|
|
1458
|
-
def __str__(self) -> str:
|
|
1459
|
-
raise NotImplementedError
|
|
1393
|
+
def __egg_name__(self) -> str:
|
|
1394
|
+
return self.name or f"ruleset_{id(self)}"
|
|
1460
1395
|
|
|
1461
1396
|
|
|
1462
1397
|
@dataclass
|
|
1463
|
-
class
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
_conditions: tuple[Fact, ...]
|
|
1468
|
-
_subsume: bool
|
|
1469
|
-
_fn_name: ClassVar[str] = "rewrite"
|
|
1398
|
+
class RewriteOrRule:
|
|
1399
|
+
__egg_decls__: Declarations
|
|
1400
|
+
decl: RewriteOrRuleDecl
|
|
1401
|
+
ruleset: Ruleset | None = None
|
|
1470
1402
|
|
|
1471
1403
|
def __str__(self) -> str:
|
|
1472
|
-
|
|
1473
|
-
return f"{self._fn_name}({self._lhs}).to({args_str})"
|
|
1474
|
-
|
|
1475
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1476
|
-
return bindings.RewriteCommand(
|
|
1477
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
|
|
1478
|
-
)
|
|
1479
|
-
|
|
1480
|
-
def _to_egg_rewrite(self) -> bindings.Rewrite:
|
|
1481
|
-
return bindings.Rewrite(
|
|
1482
|
-
self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
|
|
1483
|
-
self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
|
|
1484
|
-
[c._to_egg_fact() for c in self._conditions],
|
|
1485
|
-
)
|
|
1486
|
-
|
|
1487
|
-
@cached_property
|
|
1488
|
-
def __egg_decls__(self) -> Declarations:
|
|
1489
|
-
return Declarations.create(self._lhs, self._rhs, *self._conditions)
|
|
1490
|
-
|
|
1491
|
-
def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
|
|
1492
|
-
return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)
|
|
1404
|
+
return pretty_decl(self.__egg_decls__, self.decl)
|
|
1493
1405
|
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
class BiRewrite(Rewrite):
|
|
1497
|
-
_fn_name: ClassVar[str] = "birewrite"
|
|
1498
|
-
|
|
1499
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1500
|
-
return bindings.BiRewriteCommand(
|
|
1501
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
|
|
1502
|
-
)
|
|
1406
|
+
def __repr__(self) -> str:
|
|
1407
|
+
return str(self)
|
|
1503
1408
|
|
|
1504
1409
|
|
|
1505
1410
|
@dataclass
|
|
1506
|
-
class Fact
|
|
1411
|
+
class Fact:
|
|
1507
1412
|
"""
|
|
1508
1413
|
A query on an EGraph, either by an expression or an equivalence between multiple expressions.
|
|
1509
1414
|
"""
|
|
1510
1415
|
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
raise NotImplementedError
|
|
1514
|
-
|
|
1515
|
-
@property
|
|
1516
|
-
@abstractmethod
|
|
1517
|
-
def __egg_decls__(self) -> Declarations:
|
|
1518
|
-
raise NotImplementedError
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
@dataclass
|
|
1522
|
-
class Eq(Fact):
|
|
1523
|
-
_exprs: list[RuntimeExpr]
|
|
1416
|
+
__egg_decls__: Declarations
|
|
1417
|
+
fact: FactDecl
|
|
1524
1418
|
|
|
1525
1419
|
def __str__(self) -> str:
|
|
1526
|
-
|
|
1527
|
-
args_str = ", ".join(map(str, rest))
|
|
1528
|
-
return f"eq({first}).to({args_str})"
|
|
1529
|
-
|
|
1530
|
-
def _to_egg_fact(self) -> bindings.Eq:
|
|
1531
|
-
return bindings.Eq([e.__egg__ for e in self._exprs])
|
|
1420
|
+
return pretty_decl(self.__egg_decls__, self.fact)
|
|
1532
1421
|
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
return Declarations.create(*self._exprs)
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
@dataclass
|
|
1539
|
-
class ExprFact(Fact):
|
|
1540
|
-
_expr: RuntimeExpr
|
|
1541
|
-
|
|
1542
|
-
def __str__(self) -> str:
|
|
1543
|
-
return str(self._expr)
|
|
1544
|
-
|
|
1545
|
-
def _to_egg_fact(self) -> bindings.Fact:
|
|
1546
|
-
return bindings.Fact(self._expr.__egg__)
|
|
1547
|
-
|
|
1548
|
-
@cached_property
|
|
1549
|
-
def __egg_decls__(self) -> Declarations:
|
|
1550
|
-
return self._expr.__egg_decls__
|
|
1422
|
+
def __repr__(self) -> str:
|
|
1423
|
+
return str(self)
|
|
1551
1424
|
|
|
1552
1425
|
|
|
1553
1426
|
@dataclass
|
|
1554
|
-
class
|
|
1555
|
-
head: tuple[Action, ...]
|
|
1556
|
-
body: tuple[Fact, ...]
|
|
1557
|
-
name: str
|
|
1558
|
-
ruleset: Ruleset | None
|
|
1559
|
-
|
|
1560
|
-
def __str__(self) -> str:
|
|
1561
|
-
head_str = ", ".join(map(str, self.head))
|
|
1562
|
-
body_str = ", ".join(map(str, self.body))
|
|
1563
|
-
return f"rule({body_str}).then({head_str})"
|
|
1564
|
-
|
|
1565
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
|
|
1566
|
-
return bindings.RuleCommand(
|
|
1567
|
-
self.name,
|
|
1568
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1569
|
-
bindings.Rule(
|
|
1570
|
-
[a._to_egg_action() for a in self.head],
|
|
1571
|
-
[f._to_egg_fact() for f in self.body],
|
|
1572
|
-
),
|
|
1573
|
-
)
|
|
1574
|
-
|
|
1575
|
-
@cached_property
|
|
1576
|
-
def __egg_decls__(self) -> Declarations:
|
|
1577
|
-
return Declarations.create(*self.head, *self.body)
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
class Action(Command, ABC):
|
|
1427
|
+
class Action:
|
|
1581
1428
|
"""
|
|
1582
1429
|
A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
|
|
1583
1430
|
"""
|
|
1584
1431
|
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
raise NotImplementedError
|
|
1588
|
-
|
|
1589
|
-
def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
|
|
1590
|
-
return bindings.ActionCommand(self._to_egg_action())
|
|
1591
|
-
|
|
1592
|
-
@property
|
|
1593
|
-
def ruleset(self) -> None | Ruleset: # type: ignore[override]
|
|
1594
|
-
return None
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
@dataclass
|
|
1598
|
-
class Let(Action):
|
|
1599
|
-
_name: str
|
|
1600
|
-
_value: RuntimeExpr
|
|
1601
|
-
|
|
1602
|
-
def __str__(self) -> str:
|
|
1603
|
-
return f"let({self._name}, {self._value})"
|
|
1604
|
-
|
|
1605
|
-
def _to_egg_action(self) -> bindings.Let:
|
|
1606
|
-
return bindings.Let(self._name, self._value.__egg__)
|
|
1607
|
-
|
|
1608
|
-
@property
|
|
1609
|
-
def __egg_decls__(self) -> Declarations:
|
|
1610
|
-
return self._value.__egg_decls__
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
@dataclass
|
|
1614
|
-
class Set(Action):
|
|
1615
|
-
"""
|
|
1616
|
-
Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions.
|
|
1617
|
-
"""
|
|
1618
|
-
|
|
1619
|
-
_call: RuntimeExpr
|
|
1620
|
-
_rhs: RuntimeExpr
|
|
1621
|
-
|
|
1622
|
-
def __str__(self) -> str:
|
|
1623
|
-
return f"set({self._call}).to({self._rhs})"
|
|
1624
|
-
|
|
1625
|
-
def _to_egg_action(self) -> bindings.Set:
|
|
1626
|
-
egg_call = self._call.__egg__
|
|
1627
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1628
|
-
raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1629
|
-
return bindings.Set(
|
|
1630
|
-
egg_call.name,
|
|
1631
|
-
egg_call.args,
|
|
1632
|
-
self._rhs.__egg__,
|
|
1633
|
-
)
|
|
1634
|
-
|
|
1635
|
-
@cached_property
|
|
1636
|
-
def __egg_decls__(self) -> Declarations:
|
|
1637
|
-
return Declarations.create(self._call, self._rhs)
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
@dataclass
|
|
1641
|
-
class ExprAction(Action):
|
|
1642
|
-
_expr: RuntimeExpr
|
|
1643
|
-
|
|
1644
|
-
def __str__(self) -> str:
|
|
1645
|
-
return str(self._expr)
|
|
1646
|
-
|
|
1647
|
-
def _to_egg_action(self) -> bindings.Expr_:
|
|
1648
|
-
return bindings.Expr_(self._expr.__egg__)
|
|
1649
|
-
|
|
1650
|
-
@property
|
|
1651
|
-
def __egg_decls__(self) -> Declarations:
|
|
1652
|
-
return self._expr.__egg_decls__
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
@dataclass
|
|
1656
|
-
class Change(Action):
|
|
1657
|
-
"""
|
|
1658
|
-
Change a function call in an EGraph.
|
|
1659
|
-
"""
|
|
1660
|
-
|
|
1661
|
-
change: Literal["delete", "subsume"]
|
|
1662
|
-
_call: RuntimeExpr
|
|
1663
|
-
|
|
1664
|
-
def __str__(self) -> str:
|
|
1665
|
-
return f"{self.change}({self._call})"
|
|
1666
|
-
|
|
1667
|
-
def _to_egg_action(self) -> bindings.Change:
|
|
1668
|
-
egg_call = self._call.__egg__
|
|
1669
|
-
if not isinstance(egg_call, bindings.Call):
|
|
1670
|
-
raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
|
|
1671
|
-
change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
|
|
1672
|
-
return bindings.Change(change, egg_call.name, egg_call.args)
|
|
1673
|
-
|
|
1674
|
-
@property
|
|
1675
|
-
def __egg_decls__(self) -> Declarations:
|
|
1676
|
-
return self._call.__egg_decls__
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
@dataclass
|
|
1680
|
-
class Union_(Action): # noqa: N801
|
|
1681
|
-
"""
|
|
1682
|
-
Merges two equivalence classes of two expressions.
|
|
1683
|
-
"""
|
|
1684
|
-
|
|
1685
|
-
_lhs: RuntimeExpr
|
|
1686
|
-
_rhs: RuntimeExpr
|
|
1687
|
-
|
|
1688
|
-
def __str__(self) -> str:
|
|
1689
|
-
return f"union({self._lhs}).with_({self._rhs})"
|
|
1690
|
-
|
|
1691
|
-
def _to_egg_action(self) -> bindings.Union:
|
|
1692
|
-
return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
|
|
1693
|
-
|
|
1694
|
-
@cached_property
|
|
1695
|
-
def __egg_decls__(self) -> Declarations:
|
|
1696
|
-
return Declarations.create(self._lhs, self._rhs)
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
@dataclass
|
|
1700
|
-
class Panic(Action):
|
|
1701
|
-
message: str
|
|
1702
|
-
|
|
1703
|
-
def __str__(self) -> str:
|
|
1704
|
-
return f"panic({self.message})"
|
|
1705
|
-
|
|
1706
|
-
def _to_egg_action(self) -> bindings.Panic:
|
|
1707
|
-
return bindings.Panic(self.message)
|
|
1708
|
-
|
|
1709
|
-
@cached_property
|
|
1710
|
-
def __egg_decls__(self) -> Declarations:
|
|
1711
|
-
return Declarations()
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
@dataclass
|
|
1715
|
-
class Run(Schedule):
|
|
1716
|
-
"""Configuration of a run"""
|
|
1717
|
-
|
|
1718
|
-
# None if using default ruleset
|
|
1719
|
-
ruleset: Ruleset | None
|
|
1720
|
-
until: tuple[Fact, ...]
|
|
1721
|
-
|
|
1722
|
-
def __str__(self) -> str:
|
|
1723
|
-
args_str = ", ".join(map(str, [self.ruleset, *self.until]))
|
|
1724
|
-
return f"run({args_str})"
|
|
1725
|
-
|
|
1726
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1727
|
-
return bindings.Run(self._to_egg_config(default_ruleset_name))
|
|
1728
|
-
|
|
1729
|
-
def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig:
|
|
1730
|
-
return bindings.RunConfig(
|
|
1731
|
-
self.ruleset.egg_name if self.ruleset else default_ruleset_name,
|
|
1732
|
-
[fact._to_egg_fact() for fact in self.until] if self.until else None,
|
|
1733
|
-
)
|
|
1734
|
-
|
|
1735
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1736
|
-
if self.ruleset:
|
|
1737
|
-
yield self.ruleset
|
|
1738
|
-
|
|
1739
|
-
@cached_property
|
|
1740
|
-
def __egg_decls__(self) -> Declarations:
|
|
1741
|
-
return Declarations.create(self.ruleset, *self.until)
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
@dataclass
|
|
1745
|
-
class Saturate(Schedule):
|
|
1746
|
-
schedule: Schedule
|
|
1747
|
-
|
|
1748
|
-
def __str__(self) -> str:
|
|
1749
|
-
return f"{self.schedule}.saturate()"
|
|
1750
|
-
|
|
1751
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1752
|
-
return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1753
|
-
|
|
1754
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1755
|
-
return self.schedule._rulesets()
|
|
1756
|
-
|
|
1757
|
-
@property
|
|
1758
|
-
def __egg_decls__(self) -> Declarations:
|
|
1759
|
-
return self.schedule.__egg_decls__
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
@dataclass
|
|
1763
|
-
class Repeat(Schedule):
|
|
1764
|
-
length: int
|
|
1765
|
-
schedule: Schedule
|
|
1766
|
-
|
|
1767
|
-
def __str__(self) -> str:
|
|
1768
|
-
return f"{self.schedule} * {self.length}"
|
|
1769
|
-
|
|
1770
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1771
|
-
return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
|
|
1772
|
-
|
|
1773
|
-
def _rulesets(self) -> Iterable[Ruleset]:
|
|
1774
|
-
return self.schedule._rulesets()
|
|
1775
|
-
|
|
1776
|
-
@property
|
|
1777
|
-
def __egg_decls__(self) -> Declarations:
|
|
1778
|
-
return self.schedule.__egg_decls__
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
@dataclass
|
|
1782
|
-
class Sequence(Schedule):
|
|
1783
|
-
schedules: tuple[Schedule, ...]
|
|
1432
|
+
__egg_decls__: Declarations
|
|
1433
|
+
action: ActionDecl
|
|
1784
1434
|
|
|
1785
1435
|
def __str__(self) -> str:
|
|
1786
|
-
return
|
|
1787
|
-
|
|
1788
|
-
def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
|
|
1789
|
-
return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
|
|
1436
|
+
return pretty_decl(self.__egg_decls__, self.action)
|
|
1790
1437
|
|
|
1791
|
-
def
|
|
1792
|
-
|
|
1793
|
-
yield from s._rulesets()
|
|
1794
|
-
|
|
1795
|
-
@cached_property
|
|
1796
|
-
def __egg_decls__(self) -> Declarations:
|
|
1797
|
-
return Declarations.create(*self.schedules)
|
|
1438
|
+
def __repr__(self) -> str:
|
|
1439
|
+
return str(self)
|
|
1798
1440
|
|
|
1799
1441
|
|
|
1800
1442
|
# We use these builders so that when creating these structures we can type check
|
|
@@ -1841,30 +1483,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
|
|
|
1841
1483
|
|
|
1842
1484
|
def panic(message: str) -> Action:
|
|
1843
1485
|
"""Raise an error with the given message."""
|
|
1844
|
-
return
|
|
1486
|
+
return Action(Declarations(), PanicDecl(message))
|
|
1845
1487
|
|
|
1846
1488
|
|
|
1847
1489
|
def let(name: str, expr: Expr) -> Action:
|
|
1848
1490
|
"""Create a let binding."""
|
|
1849
|
-
|
|
1491
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1492
|
+
return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
|
|
1850
1493
|
|
|
1851
1494
|
|
|
1852
1495
|
def expr_action(expr: Expr) -> Action:
|
|
1853
|
-
|
|
1496
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1497
|
+
return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
|
|
1854
1498
|
|
|
1855
1499
|
|
|
1856
1500
|
def delete(expr: Expr) -> Action:
|
|
1857
1501
|
"""Create a delete expression."""
|
|
1858
|
-
|
|
1502
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1503
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1504
|
+
call_decl = typed_expr.expr
|
|
1505
|
+
assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
|
|
1506
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
|
|
1859
1507
|
|
|
1860
1508
|
|
|
1861
1509
|
def subsume(expr: Expr) -> Action:
|
|
1862
|
-
"""Subsume an expression
|
|
1863
|
-
|
|
1510
|
+
"""Subsume an expression so it cannot be matched against or extracted"""
|
|
1511
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1512
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1513
|
+
call_decl = typed_expr.expr
|
|
1514
|
+
assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
|
|
1515
|
+
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
|
|
1864
1516
|
|
|
1865
1517
|
|
|
1866
1518
|
def expr_fact(expr: Expr) -> Fact:
|
|
1867
|
-
|
|
1519
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1520
|
+
return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
|
|
1868
1521
|
|
|
1869
1522
|
|
|
1870
1523
|
def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
|
|
@@ -1891,6 +1544,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
|
|
|
1891
1544
|
return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
|
|
1892
1545
|
|
|
1893
1546
|
|
|
1547
|
+
@deprecated("This function is now a no-op, you can remove it and use actions as commands")
|
|
1548
|
+
def action_command(action: Action) -> Action:
|
|
1549
|
+
return action
|
|
1550
|
+
|
|
1551
|
+
|
|
1894
1552
|
def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
1895
1553
|
"""Create a new variable with the given name and type."""
|
|
1896
1554
|
return cast(EXPR, _var(name, bound))
|
|
@@ -1898,9 +1556,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
|
1898
1556
|
|
|
1899
1557
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1900
1558
|
"""Create a new variable with the given name and type."""
|
|
1901
|
-
if not isinstance(bound, RuntimeClass
|
|
1559
|
+
if not isinstance(bound, RuntimeClass):
|
|
1902
1560
|
raise TypeError(f"Unexpected type {type(bound)}")
|
|
1903
|
-
return RuntimeExpr(bound.__egg_decls__, TypedExprDecl(
|
|
1561
|
+
return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
|
|
1904
1562
|
|
|
1905
1563
|
|
|
1906
1564
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1915,15 +1573,27 @@ class _RewriteBuilder(Generic[EXPR]):
|
|
|
1915
1573
|
ruleset: Ruleset | None
|
|
1916
1574
|
subsume: bool
|
|
1917
1575
|
|
|
1918
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1576
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1919
1577
|
lhs = to_runtime_expr(self.lhs)
|
|
1920
|
-
|
|
1578
|
+
facts = _fact_likes(conditions)
|
|
1579
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1580
|
+
rule = RewriteOrRule(
|
|
1581
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1582
|
+
RewriteDecl(
|
|
1583
|
+
lhs.__egg_typed_expr__.tp,
|
|
1584
|
+
lhs.__egg_typed_expr__.expr,
|
|
1585
|
+
rhs.__egg_typed_expr__.expr,
|
|
1586
|
+
tuple(f.fact for f in facts),
|
|
1587
|
+
self.subsume,
|
|
1588
|
+
),
|
|
1589
|
+
)
|
|
1921
1590
|
if self.ruleset:
|
|
1922
1591
|
self.ruleset.append(rule)
|
|
1923
1592
|
return rule
|
|
1924
1593
|
|
|
1925
1594
|
def __str__(self) -> str:
|
|
1926
|
-
|
|
1595
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1596
|
+
return lhs.__egg_pretty__("rewrite")
|
|
1927
1597
|
|
|
1928
1598
|
|
|
1929
1599
|
@dataclass
|
|
@@ -1931,15 +1601,26 @@ class _BirewriteBuilder(Generic[EXPR]):
|
|
|
1931
1601
|
lhs: EXPR
|
|
1932
1602
|
ruleset: Ruleset | None
|
|
1933
1603
|
|
|
1934
|
-
def to(self, rhs: EXPR, *conditions: FactLike) ->
|
|
1604
|
+
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
|
|
1935
1605
|
lhs = to_runtime_expr(self.lhs)
|
|
1936
|
-
|
|
1606
|
+
facts = _fact_likes(conditions)
|
|
1607
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1608
|
+
rule = RewriteOrRule(
|
|
1609
|
+
Declarations.create(lhs, rhs, *facts, self.ruleset),
|
|
1610
|
+
BiRewriteDecl(
|
|
1611
|
+
lhs.__egg_typed_expr__.tp,
|
|
1612
|
+
lhs.__egg_typed_expr__.expr,
|
|
1613
|
+
rhs.__egg_typed_expr__.expr,
|
|
1614
|
+
tuple(f.fact for f in facts),
|
|
1615
|
+
),
|
|
1616
|
+
)
|
|
1937
1617
|
if self.ruleset:
|
|
1938
1618
|
self.ruleset.append(rule)
|
|
1939
1619
|
return rule
|
|
1940
1620
|
|
|
1941
1621
|
def __str__(self) -> str:
|
|
1942
|
-
|
|
1622
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1623
|
+
return lhs.__egg_pretty__("birewrite")
|
|
1943
1624
|
|
|
1944
1625
|
|
|
1945
1626
|
@dataclass
|
|
@@ -1948,52 +1629,84 @@ class _EqBuilder(Generic[EXPR]):
|
|
|
1948
1629
|
|
|
1949
1630
|
def to(self, *exprs: EXPR) -> Fact:
|
|
1950
1631
|
expr = to_runtime_expr(self.expr)
|
|
1951
|
-
|
|
1632
|
+
args = [expr, *(convert_to_same_type(e, expr) for e in exprs)]
|
|
1633
|
+
return Fact(
|
|
1634
|
+
Declarations.create(*args),
|
|
1635
|
+
EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)),
|
|
1636
|
+
)
|
|
1637
|
+
|
|
1638
|
+
def __repr__(self) -> str:
|
|
1639
|
+
return str(self)
|
|
1952
1640
|
|
|
1953
1641
|
def __str__(self) -> str:
|
|
1954
|
-
|
|
1642
|
+
expr = to_runtime_expr(self.expr)
|
|
1643
|
+
return expr.__egg_pretty__("eq")
|
|
1955
1644
|
|
|
1956
1645
|
|
|
1957
1646
|
@dataclass
|
|
1958
1647
|
class _NeBuilder(Generic[EXPR]):
|
|
1959
|
-
|
|
1648
|
+
lhs: EXPR
|
|
1960
1649
|
|
|
1961
|
-
def to(self,
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
res = RuntimeExpr(
|
|
1966
|
-
|
|
1967
|
-
TypedExprDecl(
|
|
1650
|
+
def to(self, rhs: EXPR) -> Unit:
|
|
1651
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1652
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1653
|
+
assert isinstance(Unit, RuntimeClass)
|
|
1654
|
+
res = RuntimeExpr.__from_value__(
|
|
1655
|
+
Declarations.create(Unit, lhs, rhs),
|
|
1656
|
+
TypedExprDecl(
|
|
1657
|
+
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
1658
|
+
),
|
|
1968
1659
|
)
|
|
1969
1660
|
return cast(Unit, res)
|
|
1970
1661
|
|
|
1662
|
+
def __repr__(self) -> str:
|
|
1663
|
+
return str(self)
|
|
1664
|
+
|
|
1971
1665
|
def __str__(self) -> str:
|
|
1972
|
-
|
|
1666
|
+
expr = to_runtime_expr(self.lhs)
|
|
1667
|
+
return expr.__egg_pretty__("ne")
|
|
1973
1668
|
|
|
1974
1669
|
|
|
1975
1670
|
@dataclass
|
|
1976
1671
|
class _SetBuilder(Generic[EXPR]):
|
|
1977
|
-
lhs:
|
|
1672
|
+
lhs: EXPR
|
|
1978
1673
|
|
|
1979
|
-
def to(self, rhs: EXPR) ->
|
|
1674
|
+
def to(self, rhs: EXPR) -> Action:
|
|
1980
1675
|
lhs = to_runtime_expr(self.lhs)
|
|
1981
|
-
|
|
1676
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1677
|
+
lhs_expr = lhs.__egg_typed_expr__.expr
|
|
1678
|
+
assert isinstance(lhs_expr, CallDecl), "Can only set function calls"
|
|
1679
|
+
return Action(
|
|
1680
|
+
Declarations.create(lhs, rhs),
|
|
1681
|
+
SetDecl(lhs.__egg_typed_expr__.tp, lhs_expr, rhs.__egg_typed_expr__.expr),
|
|
1682
|
+
)
|
|
1683
|
+
|
|
1684
|
+
def __repr__(self) -> str:
|
|
1685
|
+
return str(self)
|
|
1982
1686
|
|
|
1983
1687
|
def __str__(self) -> str:
|
|
1984
|
-
|
|
1688
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1689
|
+
return lhs.__egg_pretty__("set_")
|
|
1985
1690
|
|
|
1986
1691
|
|
|
1987
1692
|
@dataclass
|
|
1988
1693
|
class _UnionBuilder(Generic[EXPR]):
|
|
1989
|
-
lhs:
|
|
1694
|
+
lhs: EXPR
|
|
1990
1695
|
|
|
1991
1696
|
def with_(self, rhs: EXPR) -> Action:
|
|
1992
1697
|
lhs = to_runtime_expr(self.lhs)
|
|
1993
|
-
|
|
1698
|
+
rhs = convert_to_same_type(rhs, lhs)
|
|
1699
|
+
return Action(
|
|
1700
|
+
Declarations.create(lhs, rhs),
|
|
1701
|
+
UnionDecl(lhs.__egg_typed_expr__.tp, lhs.__egg_typed_expr__.expr, rhs.__egg_typed_expr__.expr),
|
|
1702
|
+
)
|
|
1703
|
+
|
|
1704
|
+
def __repr__(self) -> str:
|
|
1705
|
+
return str(self)
|
|
1994
1706
|
|
|
1995
1707
|
def __str__(self) -> str:
|
|
1996
|
-
|
|
1708
|
+
lhs = to_runtime_expr(self.lhs)
|
|
1709
|
+
return lhs.__egg_pretty__("union")
|
|
1997
1710
|
|
|
1998
1711
|
|
|
1999
1712
|
@dataclass
|
|
@@ -2002,12 +1715,25 @@ class _RuleBuilder:
|
|
|
2002
1715
|
name: str | None
|
|
2003
1716
|
ruleset: Ruleset | None
|
|
2004
1717
|
|
|
2005
|
-
def then(self, *actions: ActionLike) ->
|
|
2006
|
-
|
|
1718
|
+
def then(self, *actions: ActionLike) -> RewriteOrRule:
|
|
1719
|
+
actions = _action_likes(actions)
|
|
1720
|
+
rule = RewriteOrRule(
|
|
1721
|
+
Declarations.create(self.ruleset, *actions, *self.facts),
|
|
1722
|
+
RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name),
|
|
1723
|
+
)
|
|
2007
1724
|
if self.ruleset:
|
|
2008
1725
|
self.ruleset.append(rule)
|
|
2009
1726
|
return rule
|
|
2010
1727
|
|
|
1728
|
+
def __str__(self) -> str:
|
|
1729
|
+
# TODO: Figure out how to stringify rulebuilder that preserves statements
|
|
1730
|
+
args = list(map(str, self.facts))
|
|
1731
|
+
if self.name is not None:
|
|
1732
|
+
args.append(f"name={self.name}")
|
|
1733
|
+
if ruleset is not None:
|
|
1734
|
+
args.append(f"ruleset={self.ruleset}")
|
|
1735
|
+
return f"rule({', '.join(args)})"
|
|
1736
|
+
|
|
2011
1737
|
|
|
2012
1738
|
def expr_parts(expr: Expr) -> TypedExprDecl:
|
|
2013
1739
|
"""
|
|
@@ -2024,60 +1750,61 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
|
|
|
2024
1750
|
return expr
|
|
2025
1751
|
|
|
2026
1752
|
|
|
2027
|
-
def run(ruleset: Ruleset | None = None, *until:
|
|
1753
|
+
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
|
|
2028
1754
|
"""
|
|
2029
1755
|
Create a run configuration.
|
|
2030
1756
|
"""
|
|
2031
|
-
|
|
1757
|
+
facts = _fact_likes(until)
|
|
1758
|
+
return Schedule(
|
|
1759
|
+
Thunk.fn(Declarations.create, ruleset, *facts),
|
|
1760
|
+
RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
|
|
1761
|
+
)
|
|
2032
1762
|
|
|
2033
1763
|
|
|
2034
1764
|
def seq(*schedules: Schedule) -> Schedule:
|
|
2035
1765
|
"""
|
|
2036
1766
|
Run a sequence of schedules.
|
|
2037
1767
|
"""
|
|
2038
|
-
return
|
|
1768
|
+
return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
|
|
2039
1769
|
|
|
2040
1770
|
|
|
2041
|
-
|
|
1771
|
+
ActionLike: TypeAlias = Action | Expr
|
|
2042
1772
|
|
|
2043
1773
|
|
|
2044
|
-
def
|
|
2045
|
-
|
|
2046
|
-
return expr_action(command_like)
|
|
2047
|
-
return command_like
|
|
1774
|
+
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
1775
|
+
return tuple(map(_action_like, action_likes))
|
|
2048
1776
|
|
|
2049
1777
|
|
|
2050
|
-
|
|
1778
|
+
def _action_like(action_like: ActionLike) -> Action:
|
|
1779
|
+
if isinstance(action_like, Expr):
|
|
1780
|
+
return expr_action(action_like)
|
|
1781
|
+
return action_like
|
|
2051
1782
|
|
|
2052
1783
|
|
|
2053
|
-
|
|
2054
|
-
"""
|
|
2055
|
-
Calls the function with variables of the type and name of the arguments.
|
|
2056
|
-
"""
|
|
2057
|
-
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
2058
|
-
# but not in the globals
|
|
2059
|
-
current_frame = inspect.currentframe()
|
|
2060
|
-
assert current_frame
|
|
2061
|
-
register_frame = current_frame.f_back
|
|
2062
|
-
assert register_frame
|
|
2063
|
-
original_frame = register_frame.f_back
|
|
2064
|
-
assert original_frame
|
|
2065
|
-
hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
|
|
2066
|
-
args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
|
|
2067
|
-
return gen(*args)
|
|
1784
|
+
Command: TypeAlias = Action | RewriteOrRule
|
|
2068
1785
|
|
|
1786
|
+
CommandLike: TypeAlias = ActionLike | RewriteOrRule
|
|
2069
1787
|
|
|
2070
|
-
ActionLike = Action | Expr
|
|
2071
1788
|
|
|
1789
|
+
def _command_like(command_like: CommandLike) -> Command:
|
|
1790
|
+
if isinstance(command_like, RewriteOrRule):
|
|
1791
|
+
return command_like
|
|
1792
|
+
return _action_like(command_like)
|
|
2072
1793
|
|
|
2073
|
-
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
|
|
2074
|
-
return tuple(map(_action_like, action_likes))
|
|
2075
1794
|
|
|
1795
|
+
RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]
|
|
2076
1796
|
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
1797
|
+
|
|
1798
|
+
def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
|
|
1799
|
+
"""
|
|
1800
|
+
Returns a thunk which will call the function with variables of the type and name of the arguments.
|
|
1801
|
+
"""
|
|
1802
|
+
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
1803
|
+
# but not in the globals
|
|
1804
|
+
|
|
1805
|
+
hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
|
|
1806
|
+
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1807
|
+
return list(gen(*args)) # type: ignore[misc]
|
|
2081
1808
|
|
|
2082
1809
|
|
|
2083
1810
|
FactLike = Fact | Expr
|