egglog 0.4.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +5 -0
- egglog/bindings.cpython-312-x86_64-linux-gnu.so +0 -0
- egglog/bindings.pyi +415 -0
- egglog/builtins.py +345 -0
- egglog/config.py +8 -0
- egglog/declarations.py +934 -0
- egglog/egraph.py +1041 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +0 -0
- egglog/examples/eqsat_basic.py +43 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/lambda.py +310 -0
- egglog/examples/matrix.py +184 -0
- egglog/examples/ndarrays.py +159 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +33 -0
- egglog/ipython_magic.py +40 -0
- egglog/monkeypatch.py +33 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +304 -0
- egglog/type_constraint_solver.py +79 -0
- egglog-0.4.0.dist-info/METADATA +53 -0
- egglog-0.4.0.dist-info/RECORD +25 -0
- egglog-0.4.0.dist-info/WHEEL +4 -0
- egglog-0.4.0.dist-info/license_files/LICENSE +21 -0
egglog/egraph.py
ADDED
|
@@ -0,0 +1,1041 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from dataclasses import InitVar, dataclass, field
|
|
6
|
+
from inspect import Parameter, currentframe, signature
|
|
7
|
+
from types import FunctionType
|
|
8
|
+
from typing import _GenericAlias # type: ignore[attr-defined]
|
|
9
|
+
from typing import (
|
|
10
|
+
TYPE_CHECKING,
|
|
11
|
+
Any,
|
|
12
|
+
Callable,
|
|
13
|
+
ClassVar,
|
|
14
|
+
Generic,
|
|
15
|
+
Iterable,
|
|
16
|
+
Literal,
|
|
17
|
+
NoReturn,
|
|
18
|
+
Optional,
|
|
19
|
+
TypeVar,
|
|
20
|
+
Union,
|
|
21
|
+
cast,
|
|
22
|
+
get_type_hints,
|
|
23
|
+
overload,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
import graphviz
|
|
27
|
+
from egglog.declarations import Declarations
|
|
28
|
+
from typing_extensions import ParamSpec, get_args, get_origin
|
|
29
|
+
|
|
30
|
+
from . import bindings
|
|
31
|
+
from .declarations import *
|
|
32
|
+
from .monkeypatch import monkeypatch_forward_ref
|
|
33
|
+
from .runtime import *
|
|
34
|
+
from .runtime import _resolve_callable, class_to_ref
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from .builtins import String
|
|
38
|
+
|
|
39
|
+
# Install the forward-reference patch from `.monkeypatch` at import time, before any class or
# function in this module is registered. NOTE(review): the exact patch lives in monkeypatch.py.
monkeypatch_forward_ref()

# Names exported by `from egglog.egraph import *`.
__all__ = [
    "EGraph", "Module", "BUILTINS", "BaseExpr", "Unit",
    "rewrite", "eq", "panic", "let", "delete", "union", "set_", "rule",
    "var", "vars_", "Fact", "expr_parts",
    "Schedule", "run", "seq",
]
|
|
63
|
+
|
|
64
|
+
T = TypeVar("T")
|
|
65
|
+
P = ParamSpec("P")
|
|
66
|
+
TYPE = TypeVar("TYPE", bound="type[BaseExpr]")
|
|
67
|
+
CALLABLE = TypeVar("CALLABLE", bound=Callable)
|
|
68
|
+
EXPR = TypeVar("EXPR", bound="BaseExpr")
|
|
69
|
+
E1 = TypeVar("E1", bound="BaseExpr")
|
|
70
|
+
E2 = TypeVar("E2", bound="BaseExpr")
|
|
71
|
+
E3 = TypeVar("E3", bound="BaseExpr")
|
|
72
|
+
E4 = TypeVar("E4", bound="BaseExpr")
|
|
73
|
+
# Attributes which are sometimes added to classes by the interpreter or the dataclass decorator, or by ipython.
|
|
74
|
+
# We ignore these when inspecting the class.
|
|
75
|
+
|
|
76
|
+
IGNORED_ATTRIBUTES = {
|
|
77
|
+
"__module__",
|
|
78
|
+
"__doc__",
|
|
79
|
+
"__dict__",
|
|
80
|
+
"__weakref__",
|
|
81
|
+
"__orig_bases__",
|
|
82
|
+
"__annotations__",
|
|
83
|
+
"__hash__",
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Declarations of the builtins. Set exactly once, by `_Builtins.__post_init__` (which raises if it is
# already set); once set, every module created afterwards includes it in its declarations.
_BUILTIN_DECLS: Declarations | None = None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class _BaseModule(ABC):
    """
    Base Module which provides methods to register sorts, expressions, actions etc.

    Inherited by:
    - EGraph: Holds a live EGraph instance
    - Builtins: Stores a list of the builtins which have already been pre-registered
    - Module: Stores a list of commands and additional declarations

    Subclasses decide what happens to generated commands by implementing `_process_commands`.
    """

    # Any modules you want to depend on
    deps: InitVar[list[Module]] = []
    # All dependencies flattened
    _flatted_deps: list[Module] = field(init=False, default_factory=list)
    _mod_decls: ModuleDeclarations = field(init=False)

    def __post_init__(self, modules: list[Module] = []) -> None:
        """
        Flatten the dependency modules and build this module's declarations.

        `modules` receives the `deps` init var. For each dependency, its own flattened dependencies are
        visited before the dependency itself. A module already present in `_flatted_deps` (compared with
        `==`) is skipped, so shared dependencies are included only once.
        """
        included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else []
        # Traverse all the included modules to flatten all their dependencies and add to the included declarations
        for mod in modules:
            for child_mod in [*mod._flatted_deps, mod]:
                if child_mod not in self._flatted_deps:
                    self._flatted_deps.append(child_mod)
                    included_decls.append(child_mod._mod_decls._decl)
        self._mod_decls = ModuleDeclarations(Declarations(), included_decls)

    @abstractmethod
    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Process the commands generated by this module.
        """
        raise NotImplementedError

    @overload
    def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]:
        ...

    @overload
    def class_(self, cls: TYPE, /) -> TYPE:
        ...

    def class_(self, *args, **kwargs) -> Any:
        """
        Registers a class.

        Usable either directly as `@mod.class_` or with options as `@mod.class_(egg_sort=...)`.
        The caller's frame locals/globals are captured so that annotations referring to names local
        to the calling scope can be resolved.
        """
        frame = currentframe()
        assert frame
        prev_frame = frame.f_back
        assert prev_frame

        if kwargs:
            assert set(kwargs.keys()) == {"egg_sort"}
            return lambda cls: self._class(cls, prev_frame.f_locals, prev_frame.f_globals, kwargs["egg_sort"])
        assert len(args) == 1
        return self._class(args[0], prev_frame.f_locals, prev_frame.f_globals)

    def _class(
        self,
        cls: type[BaseExpr],
        hint_locals: dict[str, Any],
        hint_globals: dict[str, Any],
        egg_sort: Optional[str] = None,
    ) -> RuntimeClass:
        """
        Registers a class.

        First registers the sort, then each `ClassVar` annotation as a class constant, then every
        remaining attribute of the class body as a method (or classmethod).

        Raises NotImplementedError if a class attribute annotation is not a `ClassVar`.
        """
        cls_name = cls.__name__
        # Get all the methods from the class
        cls_dict: dict[str, Any] = {k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES}
        parameters: list[TypeVar] = cls_dict.pop("__parameters__", [])

        n_type_vars = len(parameters)
        self._process_commands(self._mod_decls.register_class(cls_name, n_type_vars, egg_sort))
        # The type ref of self is parameterized by the type vars
        slf_type_ref = TypeRefWithVars(cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars)))

        # First register any class vars as constants
        hint_globals = hint_globals.copy()
        hint_globals[cls_name] = cls
        for k, v in get_type_hints(cls, globalns=hint_globals, localns=hint_locals).items():
            # Use get_origin rather than `v.__origin__`, so a plain (non-generic) annotation such as
            # `x: int` reaches the NotImplementedError below instead of failing with an AttributeError.
            if get_origin(v) is ClassVar:
                (inner_tp,) = get_args(v)
                self._register_constant(ClassVariableRef(cls_name, k), inner_tp, None, (cls, cls_name))
            else:
                raise NotImplementedError("The only supported annotations on class attributes are class vars")

        # Then register each of its methods
        for method_name, method in cls_dict.items():
            is_init = method_name == "__init__"
            # Don't register the init methods for literals, since those don't use the type checking mechanisms
            if is_init and cls_name in LIT_CLASS_NAMES:
                continue
            if isinstance(method, _WrappedMethod):
                fn = method.fn
                egg_fn = method.egg_fn
                cost = method.cost
                default = method.default
                merge = method.merge
                on_merge = method.on_merge
            else:
                fn = method
                egg_fn, cost, default, merge, on_merge = None, None, None, None, None
            if isinstance(fn, classmethod):
                fn = fn.__func__
                is_classmethod = True
            else:
                # We count __init__ as a classmethod since it is called on the class
                is_classmethod = is_init

            ref: ClassMethodRef | MethodRef = (
                ClassMethodRef(cls_name, method_name) if is_classmethod else MethodRef(cls_name, method_name)
            )
            self._register_function(
                ref,
                egg_fn,
                fn,
                hint_locals,
                default,
                cost,
                merge,
                on_merge,
                "cls" if is_classmethod and not is_init else slf_type_ref,
                parameters,
                is_init,
                # If this is an i64, use the runtime class for the alias so that i64Like is resolved properly
                # Otherwise, this might be a Map in which case pass in the original cls so that we
                # can do Map[T, V] on it, which is not allowed on the runtime class
                cls_type_and_name=(
                    RuntimeClass(self._mod_decls, cls_name) if cls_name in {"i64", "String"} else cls,
                    cls_name,
                ),
            )

        # Register != as a method so we can print it as a string
        self._mod_decls._decl.register_callable_ref(MethodRef(cls_name, "__ne__"), "!=")
        return RuntimeClass(self._mod_decls, cls_name)

    # We separate the function and method overloads to make it simpler to know if we are modifying a function or method,
    # So that we can add the functions eagerly to the registry and wait on the methods till we process the class.

    # We have to separate method/function overloads for those that use the T params and those that don't
    # Otherwise, if you say just pass in `cost` then the T param is inferred as `Nothing` and
    # It will break the typing.
    @overload
    def method(  # type: ignore
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        merge: Optional[Callable[[Any, Any], Any]] = None,
        on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def method(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def method(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        """
        Decorator which attaches registration options to a method.

        The method is wrapped in a `_WrappedMethod`; it is only registered later, when its class is
        processed by `class_`.
        """
        return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn)

    @overload
    def function(self, fn: CALLABLE, /) -> CALLABLE:
        ...

    @overload
    def function(  # type: ignore
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        merge: Optional[Callable[[Any, Any], Any]] = None,
        on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def function(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def function(self, *args, **kwargs) -> Any:
        """
        Registers a function.

        Usable directly as `@mod.function` or with keyword options as `@mod.function(cost=...)`.
        """
        fn_locals = currentframe().f_back.f_locals  # type: ignore

        # If we have any positional args, then we are calling it directly on a function
        if args:
            assert len(args) == 1
            return self._function(args[0], fn_locals)
        # otherwise, we are passing some keyword args, so save those, and then return a partial
        return lambda fn: self._function(fn, fn_locals, **kwargs)

    def _function(
        self,
        fn: Callable[..., RuntimeExpr],
        hint_locals: dict[str, Any],
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[RuntimeExpr] = None,
        merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr]] = None,
        on_merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]]] = None,
    ) -> RuntimeFunction:
        """
        Uncurried version of function decorator
        """
        name = fn.__name__
        # Save function declaration
        self._register_function(FunctionRef(name), egg_fn, fn, hint_locals, default, cost, merge, on_merge)
        # Return a runtime function which will act like the declaration
        return RuntimeFunction(self._mod_decls, name)

    def _register_function(
        self,
        ref: FunctionCallableRef,
        egg_name: Optional[str],
        fn: Any,
        # Pass in the locals, retrieved from the frame when wrapping,
        # so that we support classes and function defined inside of other functions (which won't show up in the globals)
        hint_locals: dict[str, Any],
        default: Optional[RuntimeExpr],
        cost: Optional[int],
        merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr]],
        on_merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]]],
        # The first arg is either cls, for a classmethod, a self type, or none for a function
        first_arg: Literal["cls"] | TypeOrVarRef | None = None,
        cls_typevars: list[TypeVar] = [],
        is_init: bool = False,
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]] = None,
    ) -> None:
        """
        Build a `FunctionDecl` from `fn`'s signature and type hints, and register it under `ref`.

        If `first_arg` is given, the first parameter (self/cls) is dropped and must be unannotated.
        For `__init__`, the self type becomes the return type; for other methods, it is prepended to
        the argument types. A trailing `*args` parameter becomes the var-arg type. `merge` and
        `on_merge` are called with placeholder `old`/`new` variables of the return type, to build the
        merge expression and the merge actions.

        Raises NotImplementedError if `fn` is not a plain function, and ValueError for unsupported signatures.
        """
        if not isinstance(fn, FunctionType):
            raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")

        hint_globals = fn.__globals__.copy()

        if cls_type_and_name:
            hint_globals[cls_type_and_name[1]] = cls_type_and_name[0]
        hints = get_type_hints(fn, hint_globals, hint_locals)
        # If this is an init fn use the first arg as the return type
        if is_init:
            if not isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)):
                raise ValueError("Init function must have a self type")
            return_type = first_arg
        else:
            return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name)

        params = list(signature(fn).parameters.values())
        # Remove first arg if this is a classmethod or a method, since it won't have an annotation
        if first_arg is not None:
            first, *params = params
            if first.annotation != Parameter.empty:
                raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}")

        # Check that all the params are positional or keyword, and that there is only one var arg at the end
        found_var_arg = False
        for param in params:
            if found_var_arg:
                raise ValueError("Can only have a single var arg at the end")
            kind = param.kind
            if kind == Parameter.VAR_POSITIONAL:
                found_var_arg = True
            elif kind != Parameter.POSITIONAL_OR_KEYWORD:
                raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}")

        if found_var_arg:
            # The loop above guarantees the var arg is the *last* parameter, so split it off the end
            # (taking it from the front would misassign types whenever other parameters precede it).
            *params, var_arg_param = params
            var_arg_type = self._resolve_type_annotation(hints[var_arg_param.name], cls_typevars, cls_type_and_name)
        else:
            var_arg_type = None
        arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params)
        # If the first arg is a self, and this not an __init__ fn, add this as a typeref
        if isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)) and not is_init:
            arg_types = (first_arg,) + arg_types

        default_decl = None if default is None else default.__egg_typed_expr__.expr
        merge_decl = (
            None
            if merge is None
            else merge(
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
            ).__egg_typed_expr__.expr
        )
        merge_action = (
            []
            if on_merge is None
            else _action_likes(
                on_merge(
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
                )
            )
        )
        fn_decl = FunctionDecl(return_type=return_type, var_arg_type=var_arg_type, arg_types=arg_types)
        self._process_commands(
            self._mod_decls.register_function_callable(
                ref, fn_decl, egg_name, cost, default_decl, merge_decl, merge_action
            )
        )

    def _resolve_type_annotation(
        self,
        tp: object,
        cls_typevars: list[TypeVar],
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]],
    ) -> TypeOrVarRef:
        """
        Convert an evaluated type annotation into a type reference.

        - A TypeVar maps to its index in `cls_typevars` (ValueError from `.index` if it is not one of them).
        - A two-member Union with `int`, `str` or `float` resolves to the other member (type promotion).
        - The enclosing class, optionally parameterized, maps to its class name.
        - RuntimeClass / RuntimeParamaterizedClass values map through `class_to_ref`.

        Raises TypeError for anything else.
        """
        if isinstance(tp, TypeVar):
            return ClassTypeVarRef(cls_typevars.index(tp))
        # If there is a union, it should be of a literal and another type to allow type promotion
        if get_origin(tp) == Union:
            args = get_args(tp)
            if len(args) != 2:
                raise TypeError("Union types are only supported for type promotion")
            fst, snd = args
            if fst in {int, str, float}:
                return self._resolve_type_annotation(snd, cls_typevars, cls_type_and_name)
            if snd in {int, str, float}:
                return self._resolve_type_annotation(fst, cls_typevars, cls_type_and_name)
            raise TypeError("Union types are only supported for type promotion")

        # If this is the type for the class, use the class name
        if cls_type_and_name and tp == cls_type_and_name[0]:
            return TypeRefWithVars(cls_type_and_name[1])

        # If this is the class for this method and we have a parameterized class, recurse
        if (
            cls_type_and_name
            and isinstance(tp, _GenericAlias)
            and tp.__origin__ == cls_type_and_name[0]  # type: ignore
        ):
            return TypeRefWithVars(
                cls_type_and_name[1],
                tuple(
                    self._resolve_type_annotation(a, cls_typevars, cls_type_and_name)
                    for a in tp.__args__  # type: ignore
                ),
            )

        if isinstance(tp, (RuntimeClass, RuntimeParamaterizedClass)):
            return class_to_ref(tp).to_var()
        raise TypeError(f"Unexpected type annotation {tp}")

    def register(self, command_or_generator: CommandLike | CommandGenerator, *commands: CommandLike) -> None:
        """
        Registers any number of rewrites or rules.

        Accepts either command-likes, or a single plain function which is expanded via `_command_generator`.
        """
        if isinstance(command_or_generator, FunctionType):
            assert not commands
            commands = tuple(_command_generator(command_or_generator))
        else:
            commands = (cast(CommandLike, command_or_generator), *commands)
        self._process_commands(_command_like(command)._to_egg_command(self._mod_decls) for command in commands)

    def ruleset(self, name: str) -> Ruleset:
        """
        Add a ruleset with the given name and return a handle to it.
        """
        self._process_commands([bindings.AddRuleset(name)])
        return Ruleset(name)

    # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
    @overload
    def relation(
        self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], /
    ) -> Callable[[E1, E2, E3, E4], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[T], /, *, egg_fn: Optional[str] = None) -> Callable[[T], Unit]:
        ...

    @overload
    def relation(self, name: str, /, *, egg_fn: Optional[str] = None) -> Callable[[], Unit]:
        ...

    def relation(self, name: str, /, *tps: type, egg_fn: Optional[str] = None) -> Callable[..., Unit]:
        """
        Defines a relation, which is the same as a function which returns unit.
        """
        arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps)
        # Keywords, matching how `_register_function` builds its FunctionDecl
        fn_decl = FunctionDecl(arg_types=arg_types, return_type=TypeRefWithVars("Unit"))
        commands = self._mod_decls.register_function_callable(
            FunctionRef(name), fn_decl, egg_fn, cost=None, default=None, merge=None, merge_action=[]
        )
        self._process_commands(commands)
        return cast(Callable[..., Unit], RuntimeFunction(self._mod_decls, name))

    def input(self, fn: Callable[..., String], path: str) -> None:
        """
        Loads a CSV file and sets it as *input, output of the function.
        """
        fn_name = self._mod_decls.get_egg_fn(_resolve_callable(fn))
        self._process_commands([bindings.Input(fn_name, path)])

    def constant(self, name: str, tp: type[EXPR], egg_name: Optional[str] = None) -> EXPR:
        """
        Defines a named constant of a certain type.

        This is the same as defining a nullary function with a high cost.
        """
        ref = ConstantRef(name)
        type_ref = self._register_constant(ref, tp, egg_name, None)
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(type_ref, CallDecl(ref))))

    def _register_constant(
        self,
        ref: ConstantRef | ClassVariableRef,
        tp: object,
        egg_name: Optional[str],
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]],
    ) -> JustTypeRef:
        """
        Register a constant, returning its typeref().
        """
        type_ref = self._resolve_type_annotation(tp, [], cls_type_and_name).to_just()
        self._process_commands(self._mod_decls.register_constant_callable(ref, type_ref, egg_name))
        return type_ref

    def define(self, name: str, expr: EXPR) -> EXPR:
        """
        Define a new expression in the egraph and return a reference to it.
        """
        # Don't support cost and maybe will be removed in favor of let
        # https://github.com/egraphs-good/egglog/issues/128#issuecomment-1523760578
        typed_expr = expr_parts(expr)
        self._process_commands([bindings.Define(name, typed_expr.to_egg(self._mod_decls), None)])
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(typed_expr.tp, VarDecl(name))))
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@dataclass
class _Builtins(_BaseModule):
    """
    Module whose declarations are published as the process-wide builtin declarations.
    """

    def __post_init__(self, modules: list[Module] = []) -> None:
        """
        Register these declarations as builtins, so others can use them.

        Builtins take no dependencies. Raises RuntimeError if builtins were already initialized.
        """
        global _BUILTIN_DECLS
        assert not modules
        super().__post_init__(modules)
        already_initialized = _BUILTIN_DECLS is not None
        if already_initialized:
            raise RuntimeError("Builtins already initialized")
        _BUILTIN_DECLS = self._mod_decls._decl

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Discard the commands, since the builtins they would create are already registered.
        """
        return
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
@dataclass
class Module(_BaseModule):
    """
    A reusable set of declarations whose generated commands are buffered, not executed.

    The buffered commands in `_cmds` are replayed by any `EGraph` that depends on this module.
    """

    _cmds: list[bindings._Command] = field(default_factory=list, repr=False)

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """Append the commands to this module's buffer."""
        self._cmds += cmds
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
@dataclass
|
|
578
|
+
class EGraph(_BaseModule):
|
|
579
|
+
"""
|
|
580
|
+
Represents an EGraph instance at runtime
|
|
581
|
+
"""
|
|
582
|
+
|
|
583
|
+
_egraph: bindings.EGraph = field(repr=False, default_factory=bindings.EGraph)
|
|
584
|
+
# The current declarations which have been pushed to the stack
|
|
585
|
+
_decl_stack: list[Declarations] = field(default_factory=list, repr=False)
|
|
586
|
+
|
|
587
|
+
def __post_init__(self, modules: list[Module] = []) -> None:
|
|
588
|
+
super().__post_init__(modules)
|
|
589
|
+
for m in self._flatted_deps:
|
|
590
|
+
self._process_commands(m._cmds)
|
|
591
|
+
|
|
592
|
+
def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
|
|
593
|
+
self._egraph.run_program(*commands)
|
|
594
|
+
|
|
595
|
+
def _repr_mimebundle_(self, *args, **kwargs):
|
|
596
|
+
"""
|
|
597
|
+
Returns the graphviz representation of the e-graph.
|
|
598
|
+
"""
|
|
599
|
+
|
|
600
|
+
return self.graphviz._repr_mimebundle_(*args, **kwargs)
|
|
601
|
+
|
|
602
|
+
@property
|
|
603
|
+
def graphviz(self) -> graphviz.Source:
|
|
604
|
+
return graphviz.Source(self._egraph.to_graphviz_string())
|
|
605
|
+
|
|
606
|
+
def _repr_html_(self) -> str:
|
|
607
|
+
"""
|
|
608
|
+
Add a _repr_html_ to be an SVG to work with sphinx gallery
|
|
609
|
+
ala https://github.com/xflr6/graphviz/pull/121
|
|
610
|
+
until this PR is merged and released
|
|
611
|
+
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
|
|
612
|
+
"""
|
|
613
|
+
return self.graphviz.pipe(format="svg").decode()
|
|
614
|
+
|
|
615
|
+
def display(self):
|
|
616
|
+
"""
|
|
617
|
+
Displays the e-graph in the notebook.
|
|
618
|
+
"""
|
|
619
|
+
from IPython.display import display
|
|
620
|
+
|
|
621
|
+
display(self)
|
|
622
|
+
|
|
623
|
+
def simplify(self, expr: EXPR, limit: int, *until: Fact, ruleset: Optional[Ruleset] = None) -> EXPR:
|
|
624
|
+
"""
|
|
625
|
+
Simplifies the given expression.
|
|
626
|
+
"""
|
|
627
|
+
typed_expr = expr_parts(expr)
|
|
628
|
+
egg_expr = typed_expr.to_egg(self._mod_decls)
|
|
629
|
+
self._process_commands(
|
|
630
|
+
[bindings.Simplify(egg_expr, Run(limit, _ruleset_name(ruleset), until)._to_egg_config(self._mod_decls))]
|
|
631
|
+
)
|
|
632
|
+
extract_report = self._egraph.extract_report()
|
|
633
|
+
if not extract_report:
|
|
634
|
+
raise ValueError("No extract report saved")
|
|
635
|
+
new_typed_expr = TypedExprDecl.from_egg(self._mod_decls, extract_report.expr)
|
|
636
|
+
return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))
|
|
637
|
+
|
|
638
|
+
def include(self, path: str) -> None:
|
|
639
|
+
"""
|
|
640
|
+
Include a file of rules.
|
|
641
|
+
"""
|
|
642
|
+
raise NotImplementedError(
|
|
643
|
+
"Not implemented yet, because we don't have a way of registering the types with Python"
|
|
644
|
+
)
|
|
645
|
+
|
|
646
|
+
def output(self) -> None:
|
|
647
|
+
raise NotImplementedError("Not imeplemented yet, because there are no examples in the egglog repo")
|
|
648
|
+
|
|
649
|
+
@overload
|
|
650
|
+
def run(self, limit: int, /, *until: Fact, ruleset: Optional[Ruleset] = None) -> bindings.RunReport:
|
|
651
|
+
...
|
|
652
|
+
|
|
653
|
+
@overload
|
|
654
|
+
def run(self, schedule: Schedule, /) -> bindings.RunReport:
|
|
655
|
+
...
|
|
656
|
+
|
|
657
|
+
def run(
|
|
658
|
+
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Optional[Ruleset] = None
|
|
659
|
+
) -> bindings.RunReport:
|
|
660
|
+
"""
|
|
661
|
+
Run the egraph until the given limit or until the given facts are true.
|
|
662
|
+
"""
|
|
663
|
+
if isinstance(limit_or_schedule, int):
|
|
664
|
+
limit_or_schedule = run(ruleset, limit_or_schedule, *until)
|
|
665
|
+
return self._run_schedule(limit_or_schedule)
|
|
666
|
+
|
|
667
|
+
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
|
|
668
|
+
self._process_commands([bindings.RunScheduleCommand(schedule._to_egg_schedule(self._mod_decls))])
|
|
669
|
+
run_report = self._egraph.run_report()
|
|
670
|
+
if not run_report:
|
|
671
|
+
raise ValueError("No run report saved")
|
|
672
|
+
return run_report
|
|
673
|
+
|
|
674
|
+
def check(self, *facts: FactLike) -> None:
|
|
675
|
+
"""
|
|
676
|
+
Check if a fact is true in the egraph.
|
|
677
|
+
"""
|
|
678
|
+
self._process_commands([self._facts_to_check(facts)])
|
|
679
|
+
|
|
680
|
+
def check_fail(self, *facts: FactLike) -> None:
|
|
681
|
+
"""
|
|
682
|
+
Checks that one of the facts is not true
|
|
683
|
+
"""
|
|
684
|
+
self._process_commands([bindings.Fail(self._facts_to_check(facts))])
|
|
685
|
+
|
|
686
|
+
def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
|
|
687
|
+
egg_facts = [f._to_egg_fact(self._mod_decls) for f in _fact_likes(facts)]
|
|
688
|
+
return bindings.Check(egg_facts)
|
|
689
|
+
|
|
690
|
+
def extract(self, expr: EXPR) -> EXPR:
|
|
691
|
+
"""
|
|
692
|
+
Extract the lowest cost expression from the egraph.
|
|
693
|
+
"""
|
|
694
|
+
typed_expr = expr_parts(expr)
|
|
695
|
+
egg_expr = typed_expr.to_egg(self._mod_decls)
|
|
696
|
+
extract_report = self._run_extract(egg_expr, 0)
|
|
697
|
+
new_typed_expr = TypedExprDecl.from_egg(self._mod_decls, extract_report.expr)
|
|
698
|
+
if new_typed_expr.tp != typed_expr.tp:
|
|
699
|
+
raise RuntimeError(f"Type mismatch: {new_typed_expr.tp} != {typed_expr.tp}")
|
|
700
|
+
return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))
|
|
701
|
+
|
|
702
|
+
def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]:
    """
    Extract multiple expressions from the egraph.

    :param expr: The expression to extract variants of.
    :param n: The number of variants passed to egglog's ``Extract`` command.
    :return: One runtime expression per variant in the extract report.
    """
    egg_expr = expr_parts(expr).to_egg(self._mod_decls)
    variants = self._run_extract(egg_expr, n).variants
    # Convert every variant back to a typed declaration first, then wrap them all.
    typed_variants = [TypedExprDecl.from_egg(self._mod_decls, variant) for variant in variants]
    return [cast(EXPR, RuntimeExpr(self._mod_decls, typed)) for typed in typed_variants]
|
|
711
|
+
|
|
712
|
+
def _run_extract(self, expr: bindings._Expr, n: int) -> bindings.ExtractReport:
    """
    Run an egglog ``Extract`` command for ``expr`` and return the report it produced.

    :raises ValueError: if the egraph did not record an extract report.
    """
    self._process_commands([bindings.Extract(n, expr)])
    report = self._egraph.extract_report()
    if report:
        return report
    raise ValueError("No extract report saved")
|
|
718
|
+
|
|
719
|
+
def push(self) -> None:
    """
    Push the current state of the egraph, so that it can be popped later and reverted back.

    The egglog state is pushed with a ``Push`` command. The current Python-side declarations object
    is saved on ``_decl_stack`` and replaced with a deep copy. Declarations added after the push
    therefore only modify the copy, and ``pop`` can restore the saved, untouched original.
    """
    self._process_commands([bindings.Push(1)])
    self._decl_stack.append(self._mod_decls._decl)
    # Swap in a copy so that the object saved on the stack is no longer mutated. Assigning the copy
    # to a separate attribute (which ``pop`` never reads) would leave the saved object live, and
    # ``pop`` could not revert any declarations.
    self._mod_decls._decl = deepcopy(self._mod_decls._decl)
|
|
726
|
+
|
|
727
|
+
def pop(self) -> None:
    """
    Revert the egraph to the state saved by the most recent ``push``.

    Sends an egglog ``Pop`` command, then restores the declarations saved on ``_decl_stack``.
    """
    self._process_commands([bindings.Pop(1)])
    saved_decls = self._decl_stack.pop()
    self._mod_decls._decl = saved_decls
|
|
733
|
+
|
|
734
|
+
def __enter__(self):
    """
    Push the egraph state on entering a ``with`` block, so that it is reverted back to the
    original state when the block exits.

    Returns the egraph itself, so ``with egraph as eg:`` binds the egraph rather than ``None``.
    """
    self.push()
    return self
|
|
739
|
+
|
|
740
|
+
def __exit__(self, exc_type, exc, exc_tb):
    """
    Pop the state pushed when the ``with`` block was entered.

    Returns ``None``, so exceptions raised inside the block are not suppressed.
    """
    self.pop()
    return None
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
@dataclass(frozen=True)
class _WrappedMethod(Generic[P, EXPR]):
    """
    Placeholder that stores a method together with extra egglog options.

    The options are kept here until the owning class is processed. Calling a placeholder directly
    is always an error.
    """

    # Optional egglog function name for the method.
    egg_fn: Optional[str]
    # Optional cost for the method.
    cost: Optional[int]
    # Optional default value.
    default: Optional[EXPR]
    # Optional merge function combining two values.
    merge: Optional[Callable[[EXPR, EXPR], EXPR]]
    # Optional callback producing actions from two values.
    on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]]
    # The original, undecorated method.
    fn: Callable[P, EXPR]

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
        message = "We should never call a wrapped method. Did you forget to wrap the class?"
        raise NotImplementedError(message)
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
class _BaseExprMetaclass(type):
    """
    Metaclass of ``BaseExpr`` that overrides ``isinstance`` checks.

    At runtime, expressions are ``RuntimeExpr`` objects rather than instances of the declared
    classes. Any ``RuntimeExpr`` is therefore reported as an instance of ``BaseExpr``, and of every
    class that inherits this metaclass.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        return isinstance(instance, RuntimeExpr)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
class BaseExpr(metaclass=_BaseExprMetaclass):
    """
    Base class for all expression types, which adds support for ``!=`` to every expression.
    """

    def __eq__(self, other: NoReturn) -> NoReturn:  # type: ignore[override, empty-body]
        """
        Equality is currently not supported. This method exists only so that MyPy warns
        if you try to use it.
        """
        ...

    def __ne__(self: EXPR, other_expr: EXPR) -> Unit:  # type: ignore[override, empty-body]
        """
        Compare whether two expressions are not equal.

        :param self: The expression to compare.
        :param other_expr: The other expression to compare to, which must be of the same type.
        :meta public:
        """
        ...
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
# Module object on which builtin sorts are declared; ``Unit`` below is registered via ``BUILTINS.class_``.
BUILTINS = _Builtins()
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
@BUILTINS.class_(egg_sort="Unit")
class Unit(BaseExpr):
    """
    The unit type.

    It is also used to represent whether a value exists, i.e. whether it is resolved or not.
    """

    def __init__(self) -> None:
        """Construct the unit value."""
        ...
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
@dataclass(frozen=True)
class Ruleset:
    """
    An immutable, named ruleset.

    Its ``name`` is what rewrite, rule and run declarations record to refer to the ruleset.
    """

    # Name of the ruleset.
    name: str
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
def _ruleset_name(ruleset: Optional[Ruleset]) -> str:
    """
    Return the name of ``ruleset``, or the empty string when no ruleset is given.
    """
    if not ruleset:
        return ""
    return ruleset.name
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
# We use these builders so that, when creating these structures, the type checker can verify
# that the arguments are expressions of the same type.
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def rewrite(lhs: EXPR, ruleset: Optional[Ruleset] = None) -> _RewriteBuilder[EXPR]:
    """
    Rewrite the given expression to a new expression.

    Finish the rewrite by calling ``.to(rhs, *conditions)`` on the result.

    :param lhs: The expression to rewrite from.
    :param ruleset: Optional ruleset whose name is recorded on the rewrite.
    """
    builder = _RewriteBuilder(lhs=lhs, ruleset=ruleset)
    return builder
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def birewrite(lhs: EXPR, ruleset: Optional[Ruleset] = None) -> _BirewriteBuilder[EXPR]:
    """
    Rewrite the given expression to a new expression and vice versa.

    Finish the rewrite by calling ``.to(rhs, *conditions)`` on the result.

    :param lhs: One side of the bidirectional rewrite.
    :param ruleset: Optional ruleset whose name is recorded on the rewrite.
    """
    builder = _BirewriteBuilder(lhs=lhs, ruleset=ruleset)
    return builder
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def eq(expr: EXPR) -> _EqBuilder[EXPR]:
    """
    Check if the given expression is equal to other expressions.

    Finish the fact by calling ``.to(*others)`` on the result.
    """
    return _EqBuilder(expr=expr)
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
def panic(message: str) -> Action:
    """
    Create an action that raises an error with the given message.
    """
    action = Panic(message)
    return action
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
def let(name: str, expr: BaseExpr) -> Action:
    """
    Create a let binding that binds ``expr`` to ``name``.
    """
    bound_decl = expr_parts(expr).expr
    return Let(name, bound_decl)
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def expr_action(expr: BaseExpr) -> Action:
    """
    Wrap an expression as an action (an ``ExprAction`` holding the expression's declaration).
    """
    return ExprAction(expr_parts(expr).expr)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
def delete(expr: BaseExpr) -> Action:
    """
    Create a delete action for the given expression.

    :raises ValueError: if the expression is not a call.
    """
    decl = expr_parts(expr).expr
    if isinstance(decl, CallDecl):
        return Delete(decl)
    raise ValueError(f"Can only delete calls not {decl}")
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def expr_fact(expr: BaseExpr) -> Fact:
    """
    Wrap an expression as a fact (an ``ExprFact`` holding the expression's declaration).
    """
    decl = expr_parts(expr).expr
    return ExprFact(decl)
|
|
860
|
+
|
|
861
|
+
|
|
862
|
+
def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
    """
    Create a union of the given expression with another one.

    Finish the action by calling ``.with_(rhs)`` on the result.
    """
    return _UnionBuilder(lhs)
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
def set_(lhs: EXPR) -> _SetBuilder[EXPR]:
    """
    Create a set action for the given expression.

    Finish the action by calling ``.to(rhs)`` on the result.
    """
    return _SetBuilder(lhs)
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
def rule(*facts: FactLike, ruleset: Optional[Ruleset] = None, name: Optional[str] = None) -> _RuleBuilder:
    """
    Create a rule with the given facts.

    Finish the rule by calling ``.then(*actions)`` on the result.

    :param facts: The facts of the rule; ``Unit`` expressions are converted to facts.
    :param ruleset: Optional ruleset whose name is recorded on the rule.
    :param name: Optional rule name.
    """
    converted_facts = _fact_likes(facts)
    return _RuleBuilder(converted_facts, name, ruleset)
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
def var(name: str, bound: type[EXPR]) -> EXPR:
    """
    Create a new variable with the given name and type.

    :raises TypeError: if ``bound`` is not an expression class.
    """
    variable = _var(name, bound)
    return cast(EXPR, variable)
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def _var(name: str, bound: Any) -> RuntimeExpr:
    """
    Untyped implementation of ``var``: build a runtime variable expression of type ``bound``.

    :raises TypeError: if ``bound`` is neither a ``RuntimeClass`` nor a ``RuntimeParamaterizedClass``.
    """
    if isinstance(bound, (RuntimeClass, RuntimeParamaterizedClass)):
        decls = bound.__egg_decls__
        typed_var = TypedExprDecl(class_to_ref(bound), VarDecl(name))
        return RuntimeExpr(decls, typed_var)
    raise TypeError(f"Unexpected type {type(bound)}")
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
    """
    Create variables with the given names and type.

    :param names: Whitespace-separated variable names, e.g. ``"x y z"``.
    :param bound: The type given to every created variable.

    ``str.split()`` with no separator is used, so repeated, leading or trailing whitespace does
    not produce variables with empty names.
    """
    for name in names.split():
        yield var(name, bound)
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
@dataclass
class _RewriteBuilder(Generic[EXPR]):
    """
    Result of ``rewrite``: holds the left-hand side until ``to`` supplies the right-hand side.
    """

    lhs: EXPR
    ruleset: Optional[Ruleset]

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """
        Finish the rewrite from ``lhs`` to ``rhs``, attaching ``conditions`` as facts.
        """
        ruleset_name = _ruleset_name(self.ruleset)
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        condition_facts = _fact_likes(conditions)
        return Rewrite(ruleset_name, lhs_decl, rhs_decl, condition_facts)

    def __str__(self) -> str:
        return f"rewrite({self.lhs})"
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
@dataclass
class _BirewriteBuilder(Generic[EXPR]):
    """
    Result of ``birewrite``: holds one side of the rewrite until ``to`` supplies the other side.
    """

    lhs: EXPR
    ruleset: Optional[Ruleset]

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """
        Finish the bidirectional rewrite between ``lhs`` and ``rhs``, attaching ``conditions`` as facts.
        """
        ruleset_name = _ruleset_name(self.ruleset)
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        condition_facts = _fact_likes(conditions)
        return BiRewrite(ruleset_name, lhs_decl, rhs_decl, condition_facts)

    def __str__(self) -> str:
        return f"birewrite({self.lhs})"
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
@dataclass
class _EqBuilder(Generic[EXPR]):
    """
    Result of ``eq``: holds the first expression until ``to`` supplies the others.
    """

    expr: EXPR

    def to(self, *exprs: EXPR) -> Fact:
        """
        Finish the equality fact over ``expr`` followed by ``exprs``, in that order.
        """
        decls = [expr_parts(self.expr).expr]
        decls.extend(expr_parts(other).expr for other in exprs)
        return Eq(tuple(decls))

    def __str__(self) -> str:
        return f"eq({self.expr})"
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
@dataclass
class _SetBuilder(Generic[EXPR]):
    """
    Result of ``set_``: holds the call being set until ``to`` supplies its value.
    """

    lhs: BaseExpr

    def to(self, rhs: EXPR) -> Action:
        """
        Finish the set action, assigning ``rhs`` to the call ``lhs``.

        :raises ValueError: if ``lhs`` is not a call.
        """
        target = expr_parts(self.lhs).expr
        if not isinstance(target, CallDecl):
            raise ValueError(f"Can only create a call with a call for the lhs, got {target}")
        value = expr_parts(rhs).expr
        return Set(target, value)

    def __str__(self) -> str:
        return f"set_({self.lhs})"
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
@dataclass
class _UnionBuilder(Generic[EXPR]):
    """
    Result of ``union``: holds the left-hand side until ``with_`` supplies the right-hand side.
    """

    lhs: BaseExpr

    def with_(self, rhs: EXPR) -> Action:
        """
        Finish the union action between ``lhs`` and ``rhs``.
        """
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        return Union_(lhs_decl, rhs_decl)

    def __str__(self) -> str:
        return f"union({self.lhs})"
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
@dataclass
class _RuleBuilder:
    """
    Result of ``rule``: holds the rule's facts, name and ruleset until ``then`` supplies the actions.
    """

    facts: tuple[Fact, ...]
    name: Optional[str]
    ruleset: Optional[Ruleset]

    def then(self, *actions: ActionLike) -> Command:
        """
        Finish the rule with the given actions; a missing name becomes ``""``.
        """
        rule_actions = _action_likes(actions)
        rule_name = self.name if self.name else ""
        return Rule(rule_actions, self.facts, rule_name, _ruleset_name(self.ruleset))
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
def expr_parts(expr: BaseExpr) -> TypedExprDecl:
    """
    Returns the underlying type and declaration of the expression. Useful for testing structural
    equality or debugging.

    :param expr: A runtime expression.
    :raises TypeError: if ``expr`` is not a runtime expression.
    """
    # An explicit check rather than ``assert``, so the validation still runs under ``python -O``.
    if not isinstance(expr, RuntimeExpr):
        raise TypeError(f"Expected a runtime expression, got {type(expr)}")
    return expr.__egg_typed_expr__
|
|
981
|
+
|
|
982
|
+
|
|
983
|
+
def run(ruleset: Optional[Ruleset] = None, limit: int = 1, *until: Fact) -> Run:
    """
    Create a run configuration.

    :param ruleset: The ruleset to run; ``None`` maps to the empty ruleset name ``""``.
    :param limit: The limit passed to the run.
    :param until: Facts that, once true, end the run.
    """
    ruleset_name = _ruleset_name(ruleset)
    # ``until`` is already a tuple, since it is collected from ``*until``.
    return Run(limit, ruleset_name, until)
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
def seq(*schedules: Schedule) -> Schedule:
    """
    Run a sequence of schedules, in the order they are given.
    """
    # ``schedules`` is already a tuple, since it is collected from ``*schedules``.
    return Sequence(schedules)
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
# Anything accepted where a command is expected: a ``Command``, or an expression that
# ``_command_like`` turns into an expression action.
CommandLike = Union[Command, BaseExpr]
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
def _command_like(command_like: CommandLike) -> Command:
    """
    Normalize a command-like value: expressions become expression actions, and commands are
    returned unchanged.
    """
    if not isinstance(command_like, BaseExpr):
        return command_like
    return expr_action(command_like)
|
|
1004
|
+
|
|
1005
|
+
|
|
1006
|
+
# A callable returning commands; ``_command_generator`` calls it with one variable per parameter.
CommandGenerator = Callable[..., Iterable[Command]]
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
    """
    Call ``gen`` with one variable per parameter, named after the parameter and typed by its
    annotation.

    :raises KeyError: if a parameter has no type annotation.
    """
    hints = get_type_hints(gen)
    parameter_names = signature(gen).parameters
    variables = [_var(param_name, hints[param_name]) for param_name in parameter_names]
    return gen(*variables)
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
# Anything accepted where an action is expected: an ``Action``, or an expression that
# ``_action_like`` turns into an expression action.
ActionLike = Union[Action, BaseExpr]
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
    """Convert each action-like into an action, preserving order."""
    return tuple(_action_like(item) for item in action_likes)
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
def _action_like(action_like: ActionLike) -> Action:
    """Turn an expression into an expression action; actions are returned unchanged."""
    return expr_action(action_like) if isinstance(action_like, BaseExpr) else action_like
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
# Anything accepted where a fact is expected: a ``Fact``, or a ``Unit`` expression that
# ``_fact_like`` turns into an expression fact.
FactLike = Union[Fact, Unit]
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]:
    """Convert each fact-like into a fact, preserving order."""
    return tuple(_fact_like(item) for item in fact_likes)
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
def _fact_like(fact_like: FactLike) -> Fact:
    """Turn an expression into an expression fact; facts are returned unchanged."""
    return expr_fact(fact_like) if isinstance(fact_like, BaseExpr) else fact_like
|