egglog 0.4.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +5 -0
- egglog/bindings.cpython-312-x86_64-linux-gnu.so +0 -0
- egglog/bindings.pyi +415 -0
- egglog/builtins.py +345 -0
- egglog/config.py +8 -0
- egglog/declarations.py +934 -0
- egglog/egraph.py +1041 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +0 -0
- egglog/examples/eqsat_basic.py +43 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/lambda.py +310 -0
- egglog/examples/matrix.py +184 -0
- egglog/examples/ndarrays.py +159 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +33 -0
- egglog/ipython_magic.py +40 -0
- egglog/monkeypatch.py +33 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +304 -0
- egglog/type_constraint_solver.py +79 -0
- egglog-0.4.0.dist-info/METADATA +53 -0
- egglog-0.4.0.dist-info/RECORD +25 -0
- egglog-0.4.0.dist-info/WHEEL +4 -0
- egglog-0.4.0.dist-info/license_files/LICENSE +21 -0
egglog/declarations.py
ADDED
|
@@ -0,0 +1,934 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data only descriptions of the components of an egraph and the expressions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import itertools
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import ClassVar, Iterable, Optional, Union
|
|
12
|
+
|
|
13
|
+
from typing_extensions import assert_never
|
|
14
|
+
|
|
15
|
+
from . import bindings
|
|
16
|
+
|
|
17
|
+
# Names exported by a star-import of this module (each listed once).
__all__ = [
    "Declarations",
    "ModuleDeclarations",
    "JustTypeRef",
    "ClassTypeVarRef",
    "TypeRefWithVars",
    "TypeOrVarRef",
    "FunctionRef",
    "MethodRef",
    "ClassMethodRef",
    "ClassVariableRef",
    "ConstantCallableRef",
    "FunctionCallableRef",
    "CallableRef",
    "ConstantRef",
    "FunctionDecl",
    "VarDecl",
    "LitType",
    "LitDecl",
    "CallDecl",
    "ExprDecl",
    "TypedExprDecl",
    "ClassDecl",
    "Command",
    "Action",
    "ExprAction",
    "Fact",
    "Rewrite",
    "BiRewrite",
    "Eq",
    "ExprFact",
    "Rule",
    "Let",
    "Set",
    "Delete",
    "Union_",
    "Panic",
    "Schedule",
    "Sequence",
    "Run",
]
|
|
58
|
+
# Dunder methods that users may want to call like ordinary functions, each mapped to the
# infix operator used when pretty printing a call to it.
# See https://docs.python.org/3/reference/datamodel.html
BINARY_METHODS = dict(
    # Rich comparisons
    __lt__="<",
    __le__="<=",
    __eq__="==",
    __ne__="!=",
    __gt__=">",
    __ge__=">=",
    # Numeric operators
    __add__="+",
    __sub__="-",
    __mul__="*",
    __matmul__="@",
    __truediv__="/",
    __floordiv__="//",
    __mod__="%",
    # NOTE(review): `divmod` is a builtin function, not an infix operator, so a call printed
    # infix as `a divmod b` is not valid Python syntax.
    __divmod__="divmod",
    __pow__="**",
    __lshift__="<<",
    __rshift__=">>",
    __and__="&",
    __xor__="^",
    __or__="|",
)
# Unary dunder methods mapped to the prefix operator used when pretty printing them.
UNARY_METHODS = dict(
    __pos__="+",
    __neg__="-",
    __invert__="~",
)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class Declarations:
    """
    Registry of the functions, classes and constants declared for a module, together with the
    bidirectional mappings between Python references and egg function/sort names.
    """

    # Top-level function declarations, keyed by Python function name.
    _functions: dict[str, FunctionDecl] = field(default_factory=dict)
    # Class declarations, keyed by Python class name.
    _classes: dict[str, ClassDecl] = field(default_factory=dict)
    # Types of top-level constants, keyed by constant name.
    _constants: dict[str, JustTypeRef] = field(default_factory=dict)

    # Bidirectional mapping between egg function names and python callable references.
    # Note that there are possibly multiple callable references for a single egg function name, like `+`
    # for both int and rational classes.
    _egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
    _callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)

    # Bidirectional mapping between egg sort names and python type references.
    _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
    _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)

    def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
        """
        Sets a function declaration for the given function, method or classmethod reference.

        :raises ValueError: if a declaration is already set for ``ref``.
        :raises KeyError: if ``ref`` is a method/classmethod of a class not registered here.
        """
        if isinstance(ref, FunctionRef):
            if ref.name in self._functions:
                raise ValueError(f"Function {ref.name} already registered")
            self._functions[ref.name] = decl
        elif isinstance(ref, MethodRef):
            if ref.method_name in self._classes[ref.class_name].methods:
                raise ValueError(f"Method {ref.class_name}.{ref.method_name} already registered")
            self._classes[ref.class_name].methods[ref.method_name] = decl
        elif isinstance(ref, ClassMethodRef):
            if ref.method_name in self._classes[ref.class_name].class_methods:
                raise ValueError(f"Class method {ref.class_name}.{ref.method_name} already registered")
            self._classes[ref.class_name].class_methods[ref.method_name] = decl
        else:
            assert_never(ref)

    def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
        """
        Records the type of a top-level constant or a class variable.

        :raises ValueError: if a type is already recorded for ``ref``.
        :raises KeyError: if ``ref`` is a class variable of a class not registered here.
        """
        if isinstance(ref, ConstantRef):
            if ref.name in self._constants:
                raise ValueError(f"Constant {ref.name} already registered")
            self._constants[ref.name] = tp
        elif isinstance(ref, ClassVariableRef):
            if ref.variable_name in self._classes[ref.class_name].class_variables:
                raise ValueError(f"Class variable {ref.class_name}.{ref.variable_name} already registered")
            self._classes[ref.class_name].class_variables[ref.variable_name] = tp
        else:
            assert_never(ref)

    def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
        """
        Registers a callable reference under the given egg function name.

        Several references may share one egg name, but each reference may be registered only once.
        This does not check that a declaration has been set for ``ref``.

        :raises ValueError: if ``ref`` is already registered.
        """
        if ref in self._callable_ref_to_egg_fn:
            raise ValueError(f"Callable ref {ref} already registered")
        self._callable_ref_to_egg_fn[ref] = egg_name
        self._egg_fn_to_callable_refs[egg_name].add(ref)

    def get_function_decl(self, ref: FunctionCallableRef) -> FunctionDecl:
        """
        Returns the declaration set for a function, method or classmethod reference.

        :raises KeyError: if it is not registered here.
        """
        if isinstance(ref, FunctionRef):
            return self._functions[ref.name]
        elif isinstance(ref, MethodRef):
            return self._classes[ref.class_name].methods[ref.method_name]
        elif isinstance(ref, ClassMethodRef):
            return self._classes[ref.class_name].class_methods[ref.method_name]
        assert_never(ref)

    def get_constant_type(self, ref: ConstantCallableRef) -> JustTypeRef:
        """
        Returns the recorded type of a constant or class variable.

        :raises KeyError: if it is not registered here.
        """
        if isinstance(ref, ConstantRef):
            return self._constants[ref.name]
        elif isinstance(ref, ClassVariableRef):
            return self._classes[ref.class_name].class_variables[ref.variable_name]
        assert_never(ref)

    def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
        """
        Returns the callable references registered under ``egg_name`` (empty if there are none).

        Uses ``.get`` rather than indexing: indexing the defaultdict would insert an empty entry for
        every unknown name that is merely looked up.
        """
        return self._egg_fn_to_callable_refs.get(egg_name, set())

    def get_egg_fn(self, ref: CallableRef) -> str:
        """
        Returns the egg function name ``ref`` was registered under.

        :raises KeyError: if ``ref`` is not registered here.
        """
        return self._callable_ref_to_egg_fn[ref]

    def get_egg_sort(self, ref: JustTypeRef) -> str:
        """
        Returns the egg sort name registered for a type reference.

        :raises KeyError: if the type is not registered here.
        """
        return self._type_ref_to_egg_sort[ref]
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@dataclass
class ModuleDeclarations:
    """
    A set of working declarations for a module.

    New registrations are written to ``_decl``; lookups search ``_decl`` first and then each of
    ``_included_decls`` in order.
    """

    # The module's own declarations, which we can edit
    _decl: Declarations
    # A list of other declarations we can use, but not edit
    _included_decls: list[Declarations] = field(default_factory=list, repr=False)

    @property
    def all_decls(self) -> Iterable[Declarations]:
        """All declarations to search, starting with the module's own."""
        return itertools.chain([self._decl], self._included_decls)

    def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
        """
        Returns the function declaration for ``ref`` from the first declarations that contain it.

        Constants and class variables only store a type, so a nullary function declaration returning
        that type is synthesized for them.

        :raises KeyError: if no declarations contain ``ref``.
        """
        if isinstance(ref, (ClassVariableRef, ConstantRef)):
            for decls in self.all_decls:
                try:
                    return decls.get_constant_type(ref).to_constant_function_decl()
                except KeyError:
                    pass
            raise KeyError(f"Constant {ref} not found")
        elif isinstance(ref, (FunctionRef, MethodRef, ClassMethodRef)):
            for decls in self.all_decls:
                try:
                    return decls.get_function_decl(ref)
                except KeyError:
                    pass
            raise KeyError(f"Function {ref} not found")
        else:
            assert_never(ref)

    def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
        """Lazily yields every callable reference registered under ``egg_name`` in any declarations."""
        return itertools.chain.from_iterable(decls.get_callable_refs(egg_name) for decls in self.all_decls)

    def get_egg_fn(self, ref: CallableRef) -> str:
        """
        Returns the egg function name of ``ref`` from the first declarations that contain it.

        :raises KeyError: if no declarations contain ``ref``.
        """
        for decls in self.all_decls:
            try:
                return decls.get_egg_fn(ref)
            except KeyError:
                pass
        raise KeyError(f"Callable ref {ref} not found")

    def get_egg_sort(self, ref: JustTypeRef) -> str:
        """
        Returns the egg sort name of a type from the first declarations that contain it.

        :raises KeyError: if no declarations contain the type.
        """
        for decls in self.all_decls:
            try:
                return decls.get_egg_sort(ref)
            except KeyError:
                pass
        raise KeyError(f"Type {ref} not found")

    def get_class_decl(self, name: str) -> ClassDecl:
        """
        Returns the declaration of a class from the first declarations that contain it.

        :raises KeyError: if no declarations contain the class.
        """
        for decls in self.all_decls:
            try:
                return decls._classes[name]
            except KeyError:
                pass
        raise KeyError(f"Class {name} not found")

    def get_registered_class_args(self, cls_name: str) -> tuple[JustTypeRef, ...]:
        """
        Given a class name, returns the type arguments of the first registered sort of that class that
        has any, searching the module's own declarations first. Returns ``()`` if there is none.
        """
        for decl in self.all_decls:
            for tp in decl._type_ref_to_egg_sort.keys():
                if tp.name == cls_name and tp.args:
                    return tp.args
        return ()

    def register_class(self, name: str, n_type_vars: int, egg_sort: Optional[str]) -> Iterable[bindings._Command]:
        """
        Registers a new class in this module's declarations, together with its sort.

        :returns: the egg commands needed to declare the class's sort.
        :raises ValueError: if a class with this name is already registered in this module.
        """
        # Register class first
        if name in self._decl._classes:
            raise ValueError(f"Class {name} already registered")
        decl = ClassDecl(n_type_vars=n_type_vars)
        self._decl._classes[name] = decl
        _egg_sort, cmds = self.register_sort(JustTypeRef(name), egg_sort)
        return cmds

    def register_sort(
        self, ref: JustTypeRef, egg_name: Optional[str] = None
    ) -> tuple[str, Iterable[bindings._Command]]:
        """
        Register a sort with the given name. If no name is given, one is generated.

        If this is a type called with generic args, register the generic args as well.

        If the type is already registered in any declarations, its existing name is returned with no
        commands. Otherwise returns the new name and the (lazily generated) commands declaring it.

        :raises ValueError: if ``egg_name`` is already used by another type in this module's declarations.
        """
        # If the sort is already registered, do nothing
        try:
            egg_sort = self.get_egg_sort(ref)
        except KeyError:
            pass
        else:
            return (egg_sort, [])
        egg_name = egg_name or ref.generate_egg_name()
        if egg_name in self._decl._egg_sort_to_type_ref:
            raise ValueError(f"Sort {egg_name} is already registered.")
        self._decl._egg_sort_to_type_ref[egg_name] = ref
        self._decl._type_ref_to_egg_sort[ref] = egg_name
        return egg_name, ref.to_commands(self)

    def register_function_callable(
        self,
        ref: FunctionCallableRef,
        fn_decl: FunctionDecl,
        egg_name: Optional[str],
        cost: Optional[int],
        default: Optional[ExprDecl],
        merge: Optional[ExprDecl],
        merge_action: Iterable[Action],
    ) -> Iterable[bindings._Command]:
        """
        Registers a function, method or classmethod and its declaration under ``egg_name`` (generated from
        the reference if not given), and returns the egg commands declaring the function.

        :raises ValueError: if the reference or its declaration is already registered.
        """
        egg_name = egg_name or ref.generate_egg_name()
        self._decl.register_callable_ref(ref, egg_name)
        self._decl.set_function_decl(ref, fn_decl)
        return fn_decl.to_commands(self, egg_name, cost, default, merge, merge_action)

    def register_constant_callable(
        self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: Optional[str]
    ) -> Iterable[bindings._Command]:
        """
        Registers a constant or class variable of type ``type_ref`` under ``egg_name`` (generated from the
        reference if not given), and returns the egg commands declaring it.

        The same egg name is used for the callable-ref mapping and for the emitted function declaration.
        Otherwise an explicit ``egg_name`` would make later calls refer to a function egg never saw.

        :raises ValueError: if the reference is already registered.
        """
        egg_function = egg_name or ref.generate_egg_name()
        self._decl.register_callable_ref(ref, egg_function)
        self._decl.set_constant_type(ref, type_ref)
        # Create a function declaration for a constant function. This is similar to how egglog compiles
        # the `declare` command.
        return type_ref.to_constant_function_decl().to_commands(self, egg_function)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# Two kinds of type references exist: one whose arguments may (recursively) contain class type
# variables, and one whose arguments may not. The variable-carrying kind is only used for methods and
# classmethods; this concrete kind is used for egg references and for runtime values.
@dataclass(frozen=True)
class JustTypeRef:
    """A concrete type reference: a class name plus (possibly nested) concrete type arguments."""

    name: str
    args: tuple[JustTypeRef, ...] = ()

    def generate_egg_name(self) -> str:
        """
        Linearizes this type into an egg sort name, e.g. ``Map[i64, String]``.
        """
        if self.args:
            inner = ", ".join(arg.generate_egg_name() for arg in self.args)
            return self.name + "[" + inner + "]"
        return self.name

    def to_commands(self, mod_decls: ModuleDeclarations) -> Iterable[bindings._Command]:
        """
        Yields the commands declaring this sort, preceded by the commands for any argument sorts that
        still need registering. This type itself must already have an egg name in ``mod_decls``.
        """
        own_sort = mod_decls.get_egg_sort(self)
        param_exprs: list[bindings._Expr] = []
        for param in self.args:
            param_sort, param_cmds = mod_decls.register_sort(param)
            param_exprs.append(bindings.Var(param_sort))
            yield from param_cmds
        sort_args = (self.name, param_exprs) if param_exprs else None
        yield bindings.Sort(own_sort, sort_args)

    def to_var(self) -> TypeRefWithVars:
        """Converts to the variable-capable representation (which will contain no variables)."""
        converted = tuple(map(JustTypeRef.to_var, self.args))
        return TypeRefWithVars(self.name, converted)

    def pretty(self) -> str:
        """Renders the type the way it is written in Python, e.g. ``Map[i64, Unit]``."""
        if self.args:
            rendered = ", ".join(arg.pretty() for arg in self.args)
            return self.name + "[" + rendered + "]"
        return self.name

    def to_constant_function_decl(self) -> FunctionDecl:
        """
        Builds a nullary function declaration returning this type, which is how constants are represented.
        This mirrors how egglog itself compiles its constant-declaration command.
        """
        return FunctionDecl(arg_types=(), return_type=self.to_var(), var_arg_type=None)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
@dataclass(frozen=True)
class ClassTypeVarRef:
    """
    Stands for one of the type parameters of a generic class, identified by its position.
    """

    # Position of the referenced type parameter within the class's type parameters.
    index: int

    def to_just(self) -> JustTypeRef:
        """Always raises: class type variables cannot be resolved to concrete types yet."""
        message = "egglog does not support generic classes yet."
        raise NotImplementedError(message)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@dataclass(frozen=True)
class TypeRefWithVars:
    """A type reference whose arguments may include class type variables."""

    name: str
    args: tuple[TypeOrVarRef, ...] = ()

    def to_just(self) -> JustTypeRef:
        """
        Converts to a concrete type by converting every argument. This raises NotImplementedError if
        any argument is a class type variable.
        """
        concrete_args = tuple(map(lambda arg: arg.to_just(), self.args))
        return JustTypeRef(self.name, concrete_args)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# Either a class type variable or a type reference whose arguments may contain them.
TypeOrVarRef = Union[
    ClassTypeVarRef,
    TypeRefWithVars,
]
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
@dataclass(frozen=True)
class FunctionRef:
    """Reference to a top-level Python function, by name."""

    name: str

    def generate_egg_name(self) -> str:
        """The default egg function name is the Python name, unchanged."""
        egg_name = self.name
        return egg_name
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
@dataclass(frozen=True)
class MethodRef:
    """Reference to an instance method of a Python class."""

    class_name: str
    method_name: str

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class_name>.<method_name>``."""
        return "{}.{}".format(self.class_name, self.method_name)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@dataclass(frozen=True)
class ClassMethodRef:
    """Reference to a classmethod of a Python class (``__init__`` denotes the constructor)."""

    class_name: str
    method_name: str

    def to_egg(self, decls: Declarations) -> str:
        """Looks up the egg function name this classmethod was registered under."""
        egg_fn = decls.get_egg_fn(self)
        return egg_fn

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class_name>.<method_name>``."""
        return "{}.{}".format(self.class_name, self.method_name)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@dataclass(frozen=True)
class ConstantRef:
    """Reference to a top-level constant, by name."""

    name: str

    def generate_egg_name(self) -> str:
        """The default egg function name is the constant's name, unchanged."""
        egg_name = self.name
        return egg_name
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
@dataclass(frozen=True)
class ClassVariableRef:
    """Reference to a class-level variable of a Python class."""

    class_name: str
    variable_name: str

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class_name>.<variable_name>``."""
        return "{}.{}".format(self.class_name, self.variable_name)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
# References to nullary callables whose declaration is just a type.
ConstantCallableRef = Union[
    ConstantRef,
    ClassVariableRef,
]
# References to callables backed by a stored function declaration.
FunctionCallableRef = Union[
    FunctionRef,
    MethodRef,
    ClassMethodRef,
]
# Any callable reference.
CallableRef = Union[
    ConstantCallableRef,
    FunctionCallableRef,
]
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@dataclass(frozen=True)
class FunctionDecl:
    """
    Signature of a function-like callable: its positional argument types, return type and an optional
    variadic argument type.
    """

    # TODO: Add arg name to arg so can call with keyword arg
    arg_types: tuple[TypeOrVarRef, ...]
    return_type: TypeOrVarRef
    var_arg_type: Optional[TypeOrVarRef] = None

    def to_commands(
        self,
        mod_decls: ModuleDeclarations,
        egg_name: str,
        cost: Optional[int] = None,
        default: Optional[ExprDecl] = None,
        merge: Optional[ExprDecl] = None,
        merge_action: Iterable[Action] = (),
    ) -> Iterable[bindings._Command]:
        """
        Yields the egg commands declaring this function as ``egg_name``.

        Commands for argument and return sorts that still need registering come first, followed by the
        final ``Function`` command. Because this is a generator, errors surface only on iteration.

        :raises NotImplementedError: if a variadic argument type is set, or if any argument/return type
            contains a class type variable.
        """
        if self.var_arg_type is not None:
            raise NotImplementedError("egglog does not support variable arguments yet.")

        input_sorts: list[str] = []
        for arg_type in self.arg_types:
            # Egg functions cannot be generic, so each type must be made concrete (raising if it
            # contains a type variable) before its sort can be registered.
            sort_name, sort_cmds = mod_decls.register_sort(arg_type.to_just())
            yield from sort_cmds
            input_sorts.append(sort_name)
        output_sort, sort_cmds = mod_decls.register_sort(self.return_type.to_just())
        yield from sort_cmds

        schema = bindings.Schema(input_sorts, output_sort)
        egg_default = None if not default else default.to_egg(mod_decls)
        egg_merge = None if not merge else merge.to_egg(mod_decls)
        egg_merge_actions = [action._to_egg_action(mod_decls) for action in merge_action]
        yield bindings.Function(
            bindings.FunctionDecl(egg_name, schema, egg_default, egg_merge, egg_merge_actions, cost)
        )
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
@dataclass(frozen=True)
class VarDecl:
    """A variable referenced by name inside an expression."""

    name: str

    @classmethod
    def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
        """Unsupported: an egg variable carries no type, so no TypedExprDecl can be built from it."""
        raise NotImplementedError("Cannot turn var into egg type because typing unknown.")

    def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var:
        """Converts to an egg ``Var`` of the same name; no declarations are needed."""
        egg_var = bindings.Var(self.name)
        return egg_var

    def pretty(self, **_options) -> str:
        """Renders as the bare variable name; any formatting options are ignored."""
        return self.name
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
# Python values a literal may hold: int, str, float, or None (the Unit value).
LitType = Optional[Union[int, str, float]]
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
@dataclass(frozen=True)
class LitDecl:
    """A literal value: an i64 (``int``), f64 (``float``), String (``str``) or Unit (``None``)."""

    value: LitType

    @classmethod
    def from_egg(cls, lit: bindings.Lit) -> TypedExprDecl:
        """Converts an egg literal into a literal declaration tagged with its builtin type."""
        payload = lit.value
        if isinstance(payload, bindings.Int):
            type_name, raw = "i64", payload.value
        elif isinstance(payload, bindings.String):
            type_name, raw = "String", payload.value
        elif isinstance(payload, bindings.F64):
            type_name, raw = "f64", payload.value
        elif isinstance(payload, bindings.Unit):
            type_name, raw = "Unit", None
        else:
            assert_never(payload)
        return TypedExprDecl(JustTypeRef(type_name), cls(raw))

    def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
        """Converts to the egg literal matching the Python type of ``value``."""
        raw = self.value
        if raw is None:
            return bindings.Lit(bindings.Unit())
        if isinstance(raw, int):
            egg_value = bindings.Int(raw)
        elif isinstance(raw, float):
            egg_value = bindings.F64(raw)
        elif isinstance(raw, str):
            egg_value = bindings.String(raw)
        else:
            assert_never(raw)
        return bindings.Lit(egg_value)

    def pretty(self, wrap_lit=True, **kwargs) -> str:
        """
        Returns a string representation of the literal.

        :param wrap_lit: If True, wraps the literal in a call to the literal constructor. ``None`` is
            always rendered as ``Unit()``; strings are rendered with ``repr``.
        """
        raw = self.value
        if raw is None:
            return "Unit()"
        if isinstance(raw, int):
            constructor, text = "i64", str(raw)
        elif isinstance(raw, float):
            constructor, text = "f64", str(raw)
        elif isinstance(raw, str):
            constructor, text = "String", repr(raw)
        else:
            assert_never(raw)
        return f"{constructor}({text})" if wrap_lit else text
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@dataclass(frozen=True)
class CallDecl:
    """A call of a callable reference with typed argument expressions."""

    callable: CallableRef
    args: tuple[TypedExprDecl, ...] = ()
    # Type parameters bound to the callable, only allowed when it is a classmethod.
    # Used solely to pretty print classmethod calls together with their type parameters.
    bound_tp_params: Optional[tuple[JustTypeRef, ...]] = None

    def __post_init__(self):
        # Only classmethods can carry bound type parameters.
        if not self.bound_tp_params:
            return
        if not isinstance(self.callable, ClassMethodRef):
            raise ValueError("Cannot bind type parameters to a non-class method callable.")

    @classmethod
    def from_egg(cls, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedExprDecl:
        """
        Converts an egg call into a typed call declaration, inferring the call's return type.

        The first callable reference registered under the call's egg name is used.
        NOTE(review): when several overloads share one egg name (e.g. ``+``), the others are not tried if
        inference fails for the first — confirm this is intended.

        :raises ValueError: if no callable reference is registered under the call's name.
        """
        # Imported here rather than at module level, presumably to avoid a circular import — confirm.
        from .type_constraint_solver import TypeConstraintSolver

        typed_args = tuple(TypedExprDecl.from_egg(mod_decls, arg) for arg in call.args)
        arg_types = tuple(typed.tp for typed in typed_args)

        callable_ref = next(iter(mod_decls.get_callable_refs(call.name)), None)
        if callable_ref is None:
            raise ValueError(f"Could not find callable ref for call {call}")

        # egglog currently allows only one instantiation of any generic sort per program, so a
        # classmethod's type params are taken from whatever args were registered for its class.
        if isinstance(callable_ref, ClassMethodRef):
            solver = TypeConstraintSolver.from_type_parameters(
                mod_decls.get_registered_class_args(callable_ref.class_name)
            )
        else:
            solver = TypeConstraintSolver()
        fn_decl = mod_decls.get_function_decl(callable_ref)
        return_tp = solver.infer_return_type(
            fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types
        )
        return TypedExprDecl(return_tp, cls(callable_ref, typed_args))

    def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
        """Convert a Call to an egg Call."""
        fn_name = mod_decls.get_egg_fn(self.callable)
        egg_args = [typed.to_egg(mod_decls) for typed in self.args]
        return bindings.Call(fn_name, egg_args)

    def pretty(self, parens=True, **kwargs) -> str:
        """
        Pretty print the call as Python source.

        :param parens: If true, wrap the call in parens if it is a binary method call.
        """
        ref = self.callable
        exprs = [typed.expr for typed in self.args]

        def render_call(head: str, call_args) -> str:
            # ``head(arg, ...)`` with literal arguments shown unwrapped.
            return head + "(" + ", ".join(a.pretty(wrap_lit=False) for a in call_args) + ")"

        if isinstance(ref, ConstantRef):
            return ref.name
        if isinstance(ref, ClassVariableRef):
            return f"{ref.class_name}.{ref.variable_name}"
        if isinstance(ref, FunctionRef):
            return render_call(ref.name, exprs)
        if isinstance(ref, ClassMethodRef):
            class_str = JustTypeRef(ref.class_name, self.bound_tp_params or ()).pretty()
            # ``__init__`` is shown as a plain constructor call.
            head = class_str if ref.method_name == "__init__" else f"{class_str}.{ref.method_name}"
            return render_call(head, exprs)
        if isinstance(ref, MethodRef):
            method = ref.method_name
            receiver, *rest = exprs
            if method in UNARY_METHODS:
                return f"{UNARY_METHODS[method]}{receiver.pretty()}"
            if method in BINARY_METHODS:
                assert len(rest) == 1
                infix = f"{receiver.pretty()} {BINARY_METHODS[method]} {rest[0].pretty(wrap_lit=False)}"
                return f"({infix})" if parens else infix
            if method == "__getitem__":
                assert len(rest) == 1
                return f"{receiver.pretty()}[{rest[0].pretty(wrap_lit=False)}]"
            if method == "__call__":
                return render_call(receiver.pretty(), rest)
            return render_call(f"{receiver.pretty()}.{method}", rest)
        assert_never(ref)
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def test_expr_pretty():
    """Sanity checks that the expression declarations pretty print as the expected Python source."""
    assert VarDecl("x").pretty() == "x"
    assert LitDecl(42).pretty() == "i64(42)"
    # String literals are rendered with ``repr``, which uses single quotes.
    assert LitDecl("foo").pretty() == "String('foo')"
    assert LitDecl(None).pretty() == "Unit()"

    def v(x: str) -> TypedExprDecl:
        return TypedExprDecl(JustTypeRef(""), VarDecl(x))

    assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty() == "foo(x)"
    assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty() == "foo(x, y, z)"
    # Binary operators are parenthesized by default; ``parens=False`` omits the parentheses.
    assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty() == "(x + y)"
    assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(parens=False) == "x + y"
    assert CallDecl(MethodRef("foo", "__neg__"), (v("x"),)).pretty() == "-x"
    assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty() == "x[y]"
    assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty() == "foo(x, y)"
    assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty() == "foo.bar(x, y)"
    assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty() == "x(y)"
    assert (
        CallDecl(
            ClassMethodRef("Map", "__init__"),
            (),
            (JustTypeRef("i64"), JustTypeRef("Unit")),
        ).pretty()
        == "Map[i64, Unit]()"
    )
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
# Any untyped expression: a variable, a literal, or a call.
ExprDecl = Union[
    VarDecl,
    LitDecl,
    CallDecl,
]
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@dataclass(frozen=True)
class TypedExprDecl:
    """An expression paired with its concrete type."""

    tp: JustTypeRef
    expr: ExprDecl

    @classmethod
    def from_egg(cls, mod_decls: ModuleDeclarations, expr: bindings._Expr) -> TypedExprDecl:
        """
        Converts an egg expression by dispatching on its kind (variable, literal or call).
        Variables are unsupported and raise NotImplementedError.
        """
        if isinstance(expr, bindings.Var):
            converted = VarDecl.from_egg(expr)
        elif isinstance(expr, bindings.Lit):
            converted = LitDecl.from_egg(expr)
        elif isinstance(expr, bindings.Call):
            converted = CallDecl.from_egg(mod_decls, expr)
        else:
            assert_never(expr)
        return converted

    def to_egg(self, decls: ModuleDeclarations) -> bindings._Expr:
        """Converts the wrapped expression to egg; the type itself is not emitted."""
        inner = self.expr
        return inner.to_egg(decls)
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
@dataclass
class ClassDecl:
    """Everything declared on a single Python class."""

    # Instance method declarations, keyed by method name.
    methods: dict[str, FunctionDecl] = field(default_factory=dict)
    # Classmethod declarations, keyed by method name.
    class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
    # Types of class variables, keyed by variable name.
    class_variables: dict[str, JustTypeRef] = field(default_factory=dict)
    # Number of generic type parameters the class takes.
    n_type_vars: int = 0
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
class Command(ABC):
    """
    A command that can be executed by the egg interpreter.

    Only used for commands that return no result and create no new Python objects. Anything that can be
    passed to a Module's `register` function is a Command.
    """

    @abstractmethod
    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        """Converts this command into the corresponding egg binding command."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Renders the command; subclasses render it as the Python API call that builds it."""
        raise NotImplementedError()
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
@dataclass(frozen=True)
class Rewrite(Command):
    """A rewrite of ``_lhs`` into ``_rhs`` in ``_ruleset``, guarded by the facts in ``_conditions``."""

    _ruleset: str
    _lhs: ExprDecl
    _rhs: ExprDecl
    _conditions: tuple[Fact, ...]
    # Python API function name shown by __str__; subclasses override it.
    _fn_name: ClassVar[str] = "rewrite"

    def __str__(self) -> str:
        to_args = [self._rhs.pretty()] + [str(condition) for condition in self._conditions]
        return f"{self._fn_name}({self._lhs.pretty()}).to({', '.join(to_args)})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        """Emits a one-way egg ``RewriteCommand``."""
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.RewriteCommand(self._ruleset, egg_rewrite)

    def _to_egg_rewrite(self, mod_decls: ModuleDeclarations) -> bindings.Rewrite:
        """Builds the egg ``Rewrite`` payload, also used by bidirectional rewrites."""
        lhs = self._lhs.to_egg(mod_decls)
        rhs = self._rhs.to_egg(mod_decls)
        conditions = [condition._to_egg_fact(mod_decls) for condition in self._conditions]
        return bindings.Rewrite(lhs, rhs, conditions)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
@dataclass(frozen=True)
class BiRewrite(Rewrite):
    """A rewrite applied in both directions; rendered as ``birewrite(...)``."""

    _fn_name: ClassVar[str] = "birewrite"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        """Emits an egg ``BiRewriteCommand`` rather than a one-way rewrite command."""
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.BiRewriteCommand(self._ruleset, egg_rewrite)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
class Fact(ABC):
    """
    An e-graph fact: either an equality between expressions or a single (unit) expression.
    """

    @abstractmethod
    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings._Fact:
        """Converts this fact into the corresponding egg binding fact."""
        raise NotImplementedError()
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
@dataclass(frozen=True)
class Eq(Fact):
    """Fact stating that all of ``_exprs`` are equal."""

    _exprs: tuple[ExprDecl, ...]

    def __str__(self) -> str:
        leader, *others = [expr.pretty() for expr in self._exprs]
        return f"eq({leader}).to({', '.join(others)})"

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Eq:
        """Converts every expression and wraps them in an egg ``Eq`` fact."""
        egg_exprs = list(map(lambda expr: expr.to_egg(mod_decls), self._exprs))
        return bindings.Eq(egg_exprs)
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@dataclass(frozen=True)
class ExprFact(Fact):
    """
    Fact made of a single expression, converted to ``bindings.Fact``.
    """

    _expr: ExprDecl

    def __str__(self) -> str:
        rendered = self._expr.pretty()
        return rendered

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Fact:
        egg_expr = self._expr.to_egg(mod_decls)
        return bindings.Fact(egg_expr)
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
@dataclass(frozen=True)
class Rule(Command):
    """
    A rule with a ``body`` of facts and a ``head`` of actions, registered under ``name`` in ``ruleset``.

    Printed in the same shape as the Python builder API: ``rule(<facts>).then(<actions>)``.
    """

    # Actions of the rule (what is done when it fires).
    head: tuple[Action, ...]
    # Facts of the rule (its query).
    body: tuple[Fact, ...]
    name: str
    ruleset: str

    def __str__(self) -> str:
        # The facts (body) go inside ``rule(...)`` and the actions (head) inside
        # ``.then(...)``. The previous version printed these two swapped.
        body_str = ", ".join(map(str, self.body))
        head_str = ", ".join(map(str, self.head))
        return f"rule({body_str}).then({head_str})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings.RuleCommand:
        """
        Convert to a ``bindings.RuleCommand`` wrapping ``bindings.Rule(actions, facts)``.
        """
        return bindings.RuleCommand(
            self.name,
            self.ruleset,
            bindings.Rule(
                [a._to_egg_action(mod_decls) for a in self.head],
                [f._to_egg_fact(mod_decls) for f in self.body],
            ),
        )
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
class Action(Command, ABC):
    """
    Abstract action. Any action can also be issued on its own as a command.
    """

    @abstractmethod
    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings._Action:
        """
        Convert this action into its bindings-level representation.
        """
        raise NotImplementedError

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        # Wrap the converted action so it can be run as a top-level command.
        egg_action = self._to_egg_action(mod_decls)
        return bindings.ActionCommand(egg_action)
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
@dataclass(frozen=True)
class Let(Action):
    """
    Action binding the name ``_name`` to the expression ``_value``.
    """

    _name: str
    _value: ExprDecl

    def __str__(self) -> str:
        value_str = self._value.pretty()
        return f"let({self._name}, {value_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Let:
        egg_value = self._value.to_egg(mod_decls)
        return bindings.Let(self._name, egg_value)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
@dataclass(frozen=True)
class Set(Action):
    """
    Action setting the function call ``_call`` to the value ``_rhs``.
    """

    _call: CallDecl
    _rhs: ExprDecl

    def __str__(self) -> str:
        call_str = self._call.pretty()
        rhs_str = self._rhs.pretty()
        return f"set({call_str}).to({rhs_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set:
        # Resolve the callable to its egg function name, then convert the args and value.
        egg_fn = mod_decls.get_egg_fn(self._call.callable)
        egg_args = [arg.to_egg(mod_decls) for arg in self._call.args]
        egg_rhs = self._rhs.to_egg(mod_decls)
        return bindings.Set(egg_fn, egg_args, egg_rhs)
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
@dataclass(frozen=True)
class ExprAction(Action):
    """
    Action wrapping a single expression, converted to ``bindings.Expr_``.
    """

    _expr: ExprDecl

    def __str__(self) -> str:
        rendered = self._expr.pretty()
        return rendered

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Expr_:
        egg_expr = self._expr.to_egg(mod_decls)
        return bindings.Expr_(egg_expr)
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
@dataclass(frozen=True)
class Delete(Action):
    """
    Action deleting the entry identified by ``_call``'s callable and arguments.
    """

    _call: CallDecl

    def __str__(self) -> str:
        call_str = self._call.pretty()
        return f"delete({call_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Delete:
        egg_fn = mod_decls.get_egg_fn(self._call.callable)
        egg_args = [arg.to_egg(mod_decls) for arg in self._call.args]
        return bindings.Delete(egg_fn, egg_args)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
@dataclass(frozen=True)
class Union_(Action):
    """
    Action unioning ``_lhs`` with ``_rhs``.

    The trailing underscore avoids clashing with ``typing.Union``, which this module imports.
    """

    _lhs: ExprDecl
    _rhs: ExprDecl

    def __str__(self) -> str:
        lhs_str = self._lhs.pretty()
        rhs_str = self._rhs.pretty()
        return f"union({lhs_str}).with_({rhs_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Union:
        egg_lhs = self._lhs.to_egg(mod_decls)
        egg_rhs = self._rhs.to_egg(mod_decls)
        return bindings.Union(egg_lhs, egg_rhs)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
@dataclass(frozen=True)
class Panic(Action):
    """
    Action that panics with ``message``, converted to ``bindings.Panic``.
    """

    message: str

    def __str__(self) -> str:
        # ``!r`` quotes the message so the output reads as the Python call
        # ``panic("...")``, like the other ``__str__`` renderings, and stays
        # unambiguous when the message contains commas or parentheses.
        return f"panic({self.message!r})"

    def _to_egg_action(self, _decls: ModuleDeclarations) -> bindings.Panic:
        # The message needs no conversion, so the declarations are unused.
        return bindings.Panic(self.message)
|
|
850
|
+
|
|
851
|
+
|
|
852
|
+
# def action_decl_to_egg(decls: Declarations, action: ActionDecl) -> bindings._Action:
|
|
853
|
+
# if isinstance(action, (CallDecl, LitDecl, VarDecl)):
|
|
854
|
+
# return bindings.Expr_(action.to_egg(decls))
|
|
855
|
+
# return action.to_egg(decls)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
class Schedule(ABC):
    """
    Abstract run schedule. Schedules compose via ``*`` (repeat) and ``.saturate()``.
    """

    def __mul__(self, length: int) -> Schedule:
        """
        Return a schedule that repeats this one ``length`` times.
        """
        repeated = Repeat(length, self)
        return repeated

    def saturate(self) -> Schedule:
        """
        Return a schedule that runs this one until the e-graph is saturated.
        """
        saturated = Saturate(self)
        return saturated

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        """
        Convert this schedule into its bindings-level representation.
        """
        raise NotImplementedError
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
@dataclass
class Run(Schedule):
    """
    Configuration of a run: ruleset name, limit, and optional ``until`` facts.
    """

    limit: int
    ruleset: str
    until: tuple[Fact, ...]

    def __str__(self) -> str:
        parts = [self.ruleset, self.limit, *self.until]
        return "run(" + ", ".join(str(part) for part in parts) + ")"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        config = self._to_egg_config(mod_decls)
        return bindings.Run(config)

    def _to_egg_config(self, mod_decls: ModuleDeclarations) -> bindings.RunConfig:
        # An empty ``until`` is passed as None rather than an empty list.
        egg_until = None
        if self.until:
            egg_until = [fact._to_egg_fact(mod_decls) for fact in self.until]
        return bindings.RunConfig(self.ruleset, self.limit, egg_until)
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
@dataclass
class Saturate(Schedule):
    """
    Schedule that runs ``schedule`` until the e-graph is saturated.
    """

    schedule: Schedule

    def __str__(self) -> str:
        return str(self.schedule) + ".saturate()"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Saturate(inner)
|
|
912
|
+
|
|
913
|
+
|
|
914
|
+
@dataclass
class Repeat(Schedule):
    """
    Schedule that repeats ``schedule`` ``length`` times.
    """

    length: int
    schedule: Schedule

    def __str__(self) -> str:
        return str(self.schedule) + " * " + str(self.length)

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Repeat(self.length, inner)
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
@dataclass
class Sequence(Schedule):
    """
    Schedule composed of ``schedules`` in the given order, converted to ``bindings.Sequence``.
    """

    schedules: tuple[Schedule, ...]

    def __str__(self) -> str:
        inner = ", ".join(str(sched) for sched in self.schedules)
        return "sequence(" + inner + ")"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        egg_schedules = [sched._to_egg_schedule(mod_decls) for sched in self.schedules]
        return bindings.Sequence(egg_schedules)
|