egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +2 -0
- egglog/builtins.py +1 -1
- egglog/conversion.py +172 -0
- egglog/declarations.py +329 -735
- egglog/egraph.py +531 -804
- egglog/egraph_state.py +417 -0
- egglog/exp/array_api.py +92 -80
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +418 -0
- egglog/runtime.py +196 -430
- egglog/thunk.py +72 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/METADATA +19 -19
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/license_files/LICENSE +0 -0
egglog/declarations.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Data only descriptions of the components of an egraph and the expressions.
|
|
3
|
+
|
|
4
|
+
We seperate it it into two pieces, the references the declerations, so that we can report mutually recursive types.
|
|
3
5
|
"""
|
|
4
6
|
|
|
5
7
|
from __future__ import annotations
|
|
6
8
|
|
|
7
|
-
from collections import defaultdict
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
|
-
from
|
|
10
|
-
from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable
|
|
11
12
|
|
|
12
13
|
from typing_extensions import Self, assert_never
|
|
13
14
|
|
|
14
|
-
from . import bindings
|
|
15
|
-
|
|
16
15
|
if TYPE_CHECKING:
|
|
17
16
|
from collections.abc import Callable, Iterable
|
|
18
17
|
|
|
@@ -20,84 +19,64 @@ if TYPE_CHECKING:
|
|
|
20
19
|
__all__ = [
|
|
21
20
|
"Declarations",
|
|
22
21
|
"DeclerationsLike",
|
|
23
|
-
"
|
|
22
|
+
"DelayedDeclerations",
|
|
23
|
+
"upcast_declerations",
|
|
24
|
+
"Declarations",
|
|
24
25
|
"JustTypeRef",
|
|
25
26
|
"ClassTypeVarRef",
|
|
26
27
|
"TypeRefWithVars",
|
|
27
28
|
"TypeOrVarRef",
|
|
28
|
-
"FunctionRef",
|
|
29
29
|
"MethodRef",
|
|
30
30
|
"ClassMethodRef",
|
|
31
|
+
"FunctionRef",
|
|
32
|
+
"ConstantRef",
|
|
31
33
|
"ClassVariableRef",
|
|
32
|
-
"FunctionCallableRef",
|
|
33
34
|
"PropertyRef",
|
|
34
35
|
"CallableRef",
|
|
35
|
-
"ConstantRef",
|
|
36
36
|
"FunctionDecl",
|
|
37
|
+
"RelationDecl",
|
|
38
|
+
"ConstantDecl",
|
|
39
|
+
"CallableDecl",
|
|
37
40
|
"VarDecl",
|
|
38
|
-
"LitType",
|
|
39
41
|
"PyObjectDecl",
|
|
42
|
+
"LitType",
|
|
40
43
|
"LitDecl",
|
|
41
44
|
"CallDecl",
|
|
42
45
|
"ExprDecl",
|
|
43
46
|
"TypedExprDecl",
|
|
44
47
|
"ClassDecl",
|
|
45
|
-
"
|
|
46
|
-
"
|
|
48
|
+
"RulesetDecl",
|
|
49
|
+
"SaturateDecl",
|
|
50
|
+
"RepeatDecl",
|
|
51
|
+
"SequenceDecl",
|
|
52
|
+
"RunDecl",
|
|
53
|
+
"ScheduleDecl",
|
|
54
|
+
"EqDecl",
|
|
55
|
+
"ExprFactDecl",
|
|
56
|
+
"FactDecl",
|
|
57
|
+
"LetDecl",
|
|
58
|
+
"SetDecl",
|
|
59
|
+
"ExprActionDecl",
|
|
60
|
+
"ChangeDecl",
|
|
61
|
+
"UnionDecl",
|
|
62
|
+
"PanicDecl",
|
|
63
|
+
"ActionDecl",
|
|
64
|
+
"RewriteDecl",
|
|
65
|
+
"BiRewriteDecl",
|
|
66
|
+
"RuleDecl",
|
|
67
|
+
"RewriteOrRuleDecl",
|
|
68
|
+
"ActionCommandDecl",
|
|
69
|
+
"CommandDecl",
|
|
47
70
|
]
|
|
48
71
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
"__lt__": "<",
|
|
58
|
-
"__le__": "<=",
|
|
59
|
-
"__eq__": "==",
|
|
60
|
-
"__ne__": "!=",
|
|
61
|
-
"__gt__": ">",
|
|
62
|
-
"__ge__": ">=",
|
|
63
|
-
# Numeric
|
|
64
|
-
"__add__": "+",
|
|
65
|
-
"__sub__": "-",
|
|
66
|
-
"__mul__": "*",
|
|
67
|
-
"__matmul__": "@",
|
|
68
|
-
"__truediv__": "/",
|
|
69
|
-
"__floordiv__": "//",
|
|
70
|
-
"__mod__": "%",
|
|
71
|
-
# TODO: Support divmod, with tuple return value
|
|
72
|
-
# "__divmod__": "divmod",
|
|
73
|
-
# TODO: Three arg power
|
|
74
|
-
"__pow__": "**",
|
|
75
|
-
"__lshift__": "<<",
|
|
76
|
-
"__rshift__": ">>",
|
|
77
|
-
"__and__": "&",
|
|
78
|
-
"__xor__": "^",
|
|
79
|
-
"__or__": "|",
|
|
80
|
-
}
|
|
81
|
-
REFLECTED_BINARY_METHODS = {
|
|
82
|
-
"__radd__": "__add__",
|
|
83
|
-
"__rsub__": "__sub__",
|
|
84
|
-
"__rmul__": "__mul__",
|
|
85
|
-
"__rmatmul__": "__matmul__",
|
|
86
|
-
"__rtruediv__": "__truediv__",
|
|
87
|
-
"__rfloordiv__": "__floordiv__",
|
|
88
|
-
"__rmod__": "__mod__",
|
|
89
|
-
"__rpow__": "__pow__",
|
|
90
|
-
"__rlshift__": "__lshift__",
|
|
91
|
-
"__rrshift__": "__rshift__",
|
|
92
|
-
"__rand__": "__and__",
|
|
93
|
-
"__rxor__": "__xor__",
|
|
94
|
-
"__ror__": "__or__",
|
|
95
|
-
}
|
|
96
|
-
UNARY_METHODS = {
|
|
97
|
-
"__pos__": "+",
|
|
98
|
-
"__neg__": "-",
|
|
99
|
-
"__invert__": "~",
|
|
100
|
-
}
|
|
72
|
+
|
|
73
|
+
@dataclass
|
|
74
|
+
class DelayedDeclerations:
|
|
75
|
+
__egg_decls_thunk__: Callable[[], Declarations]
|
|
76
|
+
|
|
77
|
+
@property
|
|
78
|
+
def __egg_decls__(self) -> Declarations:
|
|
79
|
+
return self.__egg_decls_thunk__()
|
|
101
80
|
|
|
102
81
|
|
|
103
82
|
@runtime_checkable
|
|
@@ -109,7 +88,10 @@ class HasDeclerations(Protocol):
|
|
|
109
88
|
DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
|
|
110
89
|
|
|
111
90
|
|
|
112
|
-
|
|
91
|
+
# TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
|
|
113
95
|
d = []
|
|
114
96
|
for l in declerations_like:
|
|
115
97
|
if l is None:
|
|
@@ -125,30 +107,14 @@ def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[
|
|
|
125
107
|
|
|
126
108
|
@dataclass
|
|
127
109
|
class Declarations:
|
|
128
|
-
_functions: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
110
|
+
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
111
|
+
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
129
112
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
# Bidirectional mapping between egg function names and python callable references.
|
|
133
|
-
# Note that there are possibly mutliple callable references for a single egg function name, like `+`
|
|
134
|
-
# for both int and rational classes.
|
|
135
|
-
_egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
|
|
136
|
-
_callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)
|
|
137
|
-
|
|
138
|
-
# Bidirectional mapping between egg sort names and python type references.
|
|
139
|
-
_egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
|
|
140
|
-
_type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
141
|
-
|
|
142
|
-
# Mapping from egg name (of sort or function) to command to create it.
|
|
143
|
-
_cmds: dict[str, bindings._Command] = field(default_factory=dict)
|
|
144
|
-
|
|
145
|
-
def __post_init__(self) -> None:
|
|
146
|
-
if "!=" not in self._egg_fn_to_callable_refs:
|
|
147
|
-
self.register_callable_ref(FunctionRef("!="), "!=")
|
|
113
|
+
_rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
148
114
|
|
|
149
115
|
@classmethod
|
|
150
116
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
151
|
-
others =
|
|
117
|
+
others = upcast_declerations(others)
|
|
152
118
|
if not others:
|
|
153
119
|
return Declarations()
|
|
154
120
|
first, *rest = others
|
|
@@ -159,25 +125,9 @@ class Declarations:
|
|
|
159
125
|
return new
|
|
160
126
|
|
|
161
127
|
def copy(self) -> Declarations:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
_constants=self._constants.copy(),
|
|
166
|
-
_egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}),
|
|
167
|
-
_callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(),
|
|
168
|
-
_egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(),
|
|
169
|
-
_type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(),
|
|
170
|
-
_cmds=self._cmds.copy(),
|
|
171
|
-
)
|
|
172
|
-
|
|
173
|
-
def __deepcopy__(self, memo: dict) -> Declarations:
|
|
174
|
-
return self.copy()
|
|
175
|
-
|
|
176
|
-
def add_cmd(self, name: str, cmd: bindings._Command) -> None:
|
|
177
|
-
self._cmds[name] = cmd
|
|
178
|
-
|
|
179
|
-
def list_cmds(self) -> list[bindings._Command]:
|
|
180
|
-
return list(self._cmds.values())
|
|
128
|
+
new = Declarations()
|
|
129
|
+
new |= self
|
|
130
|
+
return new
|
|
181
131
|
|
|
182
132
|
def update(self, *others: DeclerationsLike) -> None:
|
|
183
133
|
for other in others:
|
|
@@ -200,82 +150,26 @@ class Declarations:
|
|
|
200
150
|
"""
|
|
201
151
|
Updates the other decl with these values in palce.
|
|
202
152
|
"""
|
|
203
|
-
# If cmds are == skip unioning for time savings
|
|
204
|
-
# if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds:
|
|
205
|
-
# return self
|
|
206
153
|
other._functions |= self._functions
|
|
207
154
|
other._classes |= self._classes
|
|
208
155
|
other._constants |= self._constants
|
|
209
|
-
other.
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
other._callable_ref_to_egg_fn |= self._callable_ref_to_egg_fn
|
|
213
|
-
for egg_fn, callable_refs in self._egg_fn_to_callable_refs.items():
|
|
214
|
-
other._egg_fn_to_callable_refs[egg_fn] |= callable_refs
|
|
215
|
-
|
|
216
|
-
def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
|
|
217
|
-
"""
|
|
218
|
-
Sets a function declaration for the given callable reference.
|
|
219
|
-
"""
|
|
156
|
+
other._rulesets |= self._rulesets
|
|
157
|
+
|
|
158
|
+
def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
|
|
220
159
|
match ref:
|
|
221
160
|
case FunctionRef(name):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
self.
|
|
161
|
+
return self._functions[name]
|
|
162
|
+
case ConstantRef(name):
|
|
163
|
+
return self._constants[name]
|
|
225
164
|
case MethodRef(class_name, method_name):
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
self._classes[class_name].
|
|
229
|
-
case ClassMethodRef(class_name,
|
|
230
|
-
|
|
231
|
-
raise ValueError(f"Class method {class_name}.{method_name} already registered")
|
|
232
|
-
self._classes[class_name].class_methods[method_name] = decl
|
|
165
|
+
return self._classes[class_name].methods[method_name]
|
|
166
|
+
case ClassVariableRef(class_name, name):
|
|
167
|
+
return self._classes[class_name].class_variables[name]
|
|
168
|
+
case ClassMethodRef(class_name, name):
|
|
169
|
+
return self._classes[class_name].class_methods[name]
|
|
233
170
|
case PropertyRef(class_name, property_name):
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
self._classes[class_name].properties[property_name] = decl
|
|
237
|
-
case _:
|
|
238
|
-
assert_never(ref)
|
|
239
|
-
|
|
240
|
-
def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
|
|
241
|
-
match ref:
|
|
242
|
-
case ConstantRef(name):
|
|
243
|
-
if name in self._constants:
|
|
244
|
-
raise ValueError(f"Constant {name} already registered")
|
|
245
|
-
self._constants[name] = tp
|
|
246
|
-
case ClassVariableRef(class_name, variable_name):
|
|
247
|
-
if variable_name in self._classes[class_name].class_variables:
|
|
248
|
-
raise ValueError(f"Class variable {class_name}.{variable_name} already registered")
|
|
249
|
-
self._classes[class_name].class_variables[variable_name] = tp
|
|
250
|
-
case _:
|
|
251
|
-
assert_never(ref)
|
|
252
|
-
|
|
253
|
-
def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
|
|
254
|
-
"""
|
|
255
|
-
Registers a callable reference with the given egg name.
|
|
256
|
-
|
|
257
|
-
The callable's function needs to be registered first.
|
|
258
|
-
"""
|
|
259
|
-
if ref in self._callable_ref_to_egg_fn:
|
|
260
|
-
raise ValueError(f"Callable ref {ref} already registered")
|
|
261
|
-
self._callable_ref_to_egg_fn[ref] = egg_name
|
|
262
|
-
self._egg_fn_to_callable_refs[egg_name].add(ref)
|
|
263
|
-
|
|
264
|
-
def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
|
|
265
|
-
return self._egg_fn_to_callable_refs[egg_name]
|
|
266
|
-
|
|
267
|
-
def get_egg_fn(self, ref: CallableRef) -> str:
|
|
268
|
-
return self._callable_ref_to_egg_fn[ref]
|
|
269
|
-
|
|
270
|
-
def get_egg_sort(self, ref: JustTypeRef) -> str:
|
|
271
|
-
return self._type_ref_to_egg_sort[ref]
|
|
272
|
-
|
|
273
|
-
def op_mapping(self) -> dict[str, str]:
|
|
274
|
-
"""
|
|
275
|
-
Create a mapping of egglog function name to Python function name, for use in the serialized format
|
|
276
|
-
for better visualization.
|
|
277
|
-
"""
|
|
278
|
-
return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1}
|
|
171
|
+
return self._classes[class_name].properties[property_name]
|
|
172
|
+
assert_never(ref)
|
|
279
173
|
|
|
280
174
|
def has_method(self, class_name: str, method_name: str) -> bool | None:
|
|
281
175
|
"""
|
|
@@ -285,138 +179,31 @@ class Declarations:
|
|
|
285
179
|
return method_name in self._classes[class_name].methods
|
|
286
180
|
return None
|
|
287
181
|
|
|
288
|
-
def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
|
|
289
|
-
match ref:
|
|
290
|
-
case ConstantRef(name):
|
|
291
|
-
return self._constants[name].to_constant_function_decl()
|
|
292
|
-
case ClassVariableRef(class_name, variable_name):
|
|
293
|
-
return self._classes[class_name].class_variables[variable_name].to_constant_function_decl()
|
|
294
|
-
case FunctionRef(name):
|
|
295
|
-
return self._functions[name]
|
|
296
|
-
case MethodRef(class_name, method_name):
|
|
297
|
-
return self._classes[class_name].methods[method_name]
|
|
298
|
-
case ClassMethodRef(class_name, method_name):
|
|
299
|
-
return self._classes[class_name].class_methods[method_name]
|
|
300
|
-
case PropertyRef(class_name, property_name):
|
|
301
|
-
return self._classes[class_name].properties[property_name]
|
|
302
|
-
assert_never(ref)
|
|
303
|
-
|
|
304
182
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
305
183
|
return self._classes[name]
|
|
306
184
|
|
|
307
|
-
def get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
|
|
308
|
-
"""
|
|
309
|
-
Given a class name, returns all possible registered types that it can be.
|
|
310
|
-
"""
|
|
311
|
-
return frozenset(tp for tp in self._type_ref_to_egg_sort if tp.name == cls_name)
|
|
312
185
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
186
|
+
@dataclass
|
|
187
|
+
class ClassDecl:
|
|
188
|
+
egg_name: str | None = None
|
|
189
|
+
type_vars: tuple[str, ...] = ()
|
|
190
|
+
builtin: bool = False
|
|
191
|
+
class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
192
|
+
# These have to be seperate from class_methods so that printing them can be done easily
|
|
193
|
+
class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
194
|
+
methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
195
|
+
properties: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
196
|
+
preserved_methods: dict[str, Callable] = field(default_factory=dict)
|
|
320
197
|
|
|
321
|
-
def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str:
|
|
322
|
-
"""
|
|
323
|
-
Register a sort with the given name. If no name is given, one is generated.
|
|
324
198
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
try:
|
|
329
|
-
egg_sort = self.get_egg_sort(ref)
|
|
330
|
-
except KeyError:
|
|
331
|
-
pass
|
|
332
|
-
else:
|
|
333
|
-
return egg_sort
|
|
334
|
-
egg_name = egg_name or ref.generate_egg_name()
|
|
335
|
-
if egg_name in self._egg_sort_to_type_ref:
|
|
336
|
-
raise ValueError(f"Sort {egg_name} is already registered.")
|
|
337
|
-
self._egg_sort_to_type_ref[egg_name] = ref
|
|
338
|
-
self._type_ref_to_egg_sort[ref] = egg_name
|
|
339
|
-
if not builtin:
|
|
340
|
-
self.add_cmd(
|
|
341
|
-
egg_name,
|
|
342
|
-
bindings.Sort(
|
|
343
|
-
egg_name,
|
|
344
|
-
(
|
|
345
|
-
self.get_egg_sort(JustTypeRef(ref.name)),
|
|
346
|
-
[bindings.Var(self.register_sort(arg, False)) for arg in ref.args],
|
|
347
|
-
)
|
|
348
|
-
if ref.args
|
|
349
|
-
else None,
|
|
350
|
-
),
|
|
351
|
-
)
|
|
352
|
-
|
|
353
|
-
return egg_name
|
|
354
|
-
|
|
355
|
-
def register_function_callable(
|
|
356
|
-
self,
|
|
357
|
-
ref: FunctionCallableRef,
|
|
358
|
-
fn_decl: FunctionDecl,
|
|
359
|
-
egg_name: str | None,
|
|
360
|
-
cost: int | None,
|
|
361
|
-
default: ExprDecl | None,
|
|
362
|
-
merge: ExprDecl | None,
|
|
363
|
-
merge_action: list[bindings._Action],
|
|
364
|
-
unextractable: bool,
|
|
365
|
-
builtin: bool,
|
|
366
|
-
is_relation: bool = False,
|
|
367
|
-
) -> None:
|
|
368
|
-
"""
|
|
369
|
-
Registers a callable with the given egg name.
|
|
199
|
+
@dataclass
|
|
200
|
+
class RulesetDecl:
|
|
201
|
+
rules: list[RewriteOrRuleDecl]
|
|
370
202
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
self
|
|
375
|
-
self.set_function_decl(ref, fn_decl)
|
|
376
|
-
|
|
377
|
-
# Skip generating the cmds if we don't want to record them, like for the builtins
|
|
378
|
-
if builtin:
|
|
379
|
-
return
|
|
380
|
-
|
|
381
|
-
if fn_decl.var_arg_type is not None:
|
|
382
|
-
msg = "egglog does not support variable arguments yet."
|
|
383
|
-
raise NotImplementedError(msg)
|
|
384
|
-
# Remove all vars from the type refs, raising an errory if we find one,
|
|
385
|
-
# since we cannot create egg functions with vars
|
|
386
|
-
arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types]
|
|
387
|
-
cmd: bindings._Command
|
|
388
|
-
if is_relation:
|
|
389
|
-
assert not default
|
|
390
|
-
assert not merge
|
|
391
|
-
assert not merge_action
|
|
392
|
-
assert not cost
|
|
393
|
-
cmd = bindings.Relation(egg_name, arg_sorts)
|
|
394
|
-
else:
|
|
395
|
-
egg_fn_decl = bindings.FunctionDecl(
|
|
396
|
-
egg_name,
|
|
397
|
-
bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)),
|
|
398
|
-
default.to_egg(self) if default else None,
|
|
399
|
-
merge.to_egg(self) if merge else None,
|
|
400
|
-
merge_action,
|
|
401
|
-
cost,
|
|
402
|
-
unextractable,
|
|
403
|
-
)
|
|
404
|
-
cmd = bindings.Function(egg_fn_decl)
|
|
405
|
-
self.add_cmd(egg_name, cmd)
|
|
406
|
-
|
|
407
|
-
def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None:
|
|
408
|
-
egg_name = egg_name or ref.generate_egg_name()
|
|
409
|
-
self.register_callable_ref(ref, egg_name)
|
|
410
|
-
self.set_constant_type(ref, type_ref)
|
|
411
|
-
egg_sort = self.register_sort(type_ref, False)
|
|
412
|
-
# self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref)))
|
|
413
|
-
# Use function decleration instead of constant b/c constants cannot be extracted
|
|
414
|
-
# https://github.com/egraphs-good/egglog/issues/334
|
|
415
|
-
fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort))
|
|
416
|
-
self.add_cmd(egg_name, bindings.Function(fn_decl))
|
|
417
|
-
|
|
418
|
-
def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
|
|
419
|
-
self._classes[class_].preserved_methods[method] = fn
|
|
203
|
+
# Make hashable so when traversing for pretty-fying we can know which rulesets we have already
|
|
204
|
+
# made into strings
|
|
205
|
+
def __hash__(self) -> int:
|
|
206
|
+
return hash((type(self), tuple(self.rules)))
|
|
420
207
|
|
|
421
208
|
|
|
422
209
|
# Have two different types of type refs, one that can include vars recursively and one that cannot.
|
|
@@ -427,38 +214,18 @@ class JustTypeRef:
|
|
|
427
214
|
name: str
|
|
428
215
|
args: tuple[JustTypeRef, ...] = ()
|
|
429
216
|
|
|
430
|
-
def generate_egg_name(self) -> str:
|
|
431
|
-
"""
|
|
432
|
-
Generates an egg sort name for this type reference by linearizing the type.
|
|
433
|
-
"""
|
|
434
|
-
if not self.args:
|
|
435
|
-
return self.name
|
|
436
|
-
args = "_".join(a.generate_egg_name() for a in self.args)
|
|
437
|
-
return f"{self.name}_{args}"
|
|
438
|
-
|
|
439
217
|
def to_var(self) -> TypeRefWithVars:
|
|
440
218
|
return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args))
|
|
441
219
|
|
|
442
|
-
def
|
|
443
|
-
if
|
|
444
|
-
return self.name
|
|
445
|
-
|
|
446
|
-
return f"{self.name}[{args}]"
|
|
220
|
+
def __str__(self) -> str:
|
|
221
|
+
if self.args:
|
|
222
|
+
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
|
|
223
|
+
return self.name
|
|
447
224
|
|
|
448
|
-
def to_constant_function_decl(self) -> FunctionDecl:
|
|
449
|
-
"""
|
|
450
|
-
Create a function declaration for a constant function.
|
|
451
225
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
arg_types=(),
|
|
456
|
-
arg_names=(),
|
|
457
|
-
arg_defaults=(),
|
|
458
|
-
return_type=self.to_var(),
|
|
459
|
-
mutates_first_arg=False,
|
|
460
|
-
var_arg_type=None,
|
|
461
|
-
)
|
|
226
|
+
##
|
|
227
|
+
# Type references with vars
|
|
228
|
+
##
|
|
462
229
|
|
|
463
230
|
|
|
464
231
|
@dataclass(frozen=True)
|
|
@@ -473,7 +240,7 @@ class ClassTypeVarRef:
|
|
|
473
240
|
msg = "egglog does not support generic classes yet."
|
|
474
241
|
raise NotImplementedError(msg)
|
|
475
242
|
|
|
476
|
-
def
|
|
243
|
+
def __str__(self) -> str:
|
|
477
244
|
return self.name
|
|
478
245
|
|
|
479
246
|
|
|
@@ -485,30 +252,27 @@ class TypeRefWithVars:
|
|
|
485
252
|
def to_just(self) -> JustTypeRef:
|
|
486
253
|
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
|
|
487
254
|
|
|
488
|
-
def
|
|
489
|
-
if
|
|
490
|
-
return self.name
|
|
491
|
-
|
|
492
|
-
return f"{self.name}[{args}]"
|
|
255
|
+
def __str__(self) -> str:
|
|
256
|
+
if self.args:
|
|
257
|
+
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
|
|
258
|
+
return self.name
|
|
493
259
|
|
|
494
260
|
|
|
495
261
|
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
496
262
|
|
|
263
|
+
##
|
|
264
|
+
# Callables References
|
|
265
|
+
##
|
|
266
|
+
|
|
497
267
|
|
|
498
268
|
@dataclass(frozen=True)
|
|
499
269
|
class FunctionRef:
|
|
500
270
|
name: str
|
|
501
271
|
|
|
502
|
-
def generate_egg_name(self) -> str:
|
|
503
|
-
return self.name
|
|
504
|
-
|
|
505
|
-
def __str__(self) -> str:
|
|
506
|
-
return self.name
|
|
507
272
|
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
ARG = "·"
|
|
273
|
+
@dataclass(frozen=True)
|
|
274
|
+
class ConstantRef:
|
|
275
|
+
name: str
|
|
512
276
|
|
|
513
277
|
|
|
514
278
|
@dataclass(frozen=True)
|
|
@@ -516,123 +280,115 @@ class MethodRef:
|
|
|
516
280
|
class_name: str
|
|
517
281
|
method_name: str
|
|
518
282
|
|
|
519
|
-
def generate_egg_name(self) -> str:
|
|
520
|
-
return f"{self.class_name}_{self.method_name}"
|
|
521
|
-
|
|
522
|
-
def __str__(self) -> str: # noqa: PLR0911
|
|
523
|
-
match self.method_name:
|
|
524
|
-
case _ if self.method_name in UNARY_METHODS:
|
|
525
|
-
return f"{UNARY_METHODS[self.method_name]}{ARG}"
|
|
526
|
-
case _ if self.method_name in BINARY_METHODS:
|
|
527
|
-
return f"({ARG} {BINARY_METHODS[self.method_name]} {ARG})"
|
|
528
|
-
case "__getitem__":
|
|
529
|
-
return f"{ARG}[{ARG}]"
|
|
530
|
-
case "__call__":
|
|
531
|
-
return f"{ARG}({ARG})"
|
|
532
|
-
case "__delitem__":
|
|
533
|
-
return f"del {ARG}[{ARG}]"
|
|
534
|
-
case "__setitem__":
|
|
535
|
-
return f"{ARG}[{ARG}] = {ARG}"
|
|
536
|
-
return f"{ARG}.{self.method_name}"
|
|
537
|
-
|
|
538
283
|
|
|
539
284
|
@dataclass(frozen=True)
|
|
540
285
|
class ClassMethodRef:
|
|
541
286
|
class_name: str
|
|
542
287
|
method_name: str
|
|
543
288
|
|
|
544
|
-
def generate_egg_name(self) -> str:
|
|
545
|
-
return f"{self.class_name}_{self.method_name}"
|
|
546
|
-
|
|
547
|
-
def __str__(self) -> str:
|
|
548
|
-
if self.method_name == "__init__":
|
|
549
|
-
return self.class_name
|
|
550
|
-
return f"{self.class_name}.{self.method_name}"
|
|
551
|
-
|
|
552
289
|
|
|
553
290
|
@dataclass(frozen=True)
|
|
554
|
-
class
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
def generate_egg_name(self) -> str:
|
|
558
|
-
return self.name
|
|
559
|
-
|
|
560
|
-
def __str__(self) -> str:
|
|
561
|
-
return self.name
|
|
291
|
+
class ClassVariableRef:
|
|
292
|
+
class_name: str
|
|
293
|
+
var_name: str
|
|
562
294
|
|
|
563
295
|
|
|
564
296
|
@dataclass(frozen=True)
|
|
565
|
-
class
|
|
297
|
+
class PropertyRef:
|
|
566
298
|
class_name: str
|
|
567
|
-
|
|
299
|
+
property_name: str
|
|
568
300
|
|
|
569
|
-
def generate_egg_name(self) -> str:
|
|
570
|
-
return f"{self.class_name}_{self.variable_name}"
|
|
571
301
|
|
|
572
|
-
|
|
573
|
-
|
|
302
|
+
CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
##
|
|
306
|
+
# Callables
|
|
307
|
+
##
|
|
574
308
|
|
|
575
309
|
|
|
576
310
|
@dataclass(frozen=True)
|
|
577
|
-
class
|
|
578
|
-
|
|
579
|
-
|
|
311
|
+
class RelationDecl:
|
|
312
|
+
arg_types: tuple[JustTypeRef, ...]
|
|
313
|
+
# List of defaults. None for any arg which doesn't have one.
|
|
314
|
+
arg_defaults: tuple[ExprDecl | None, ...]
|
|
315
|
+
egg_name: str | None
|
|
580
316
|
|
|
581
|
-
def
|
|
582
|
-
return
|
|
317
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
318
|
+
return FunctionDecl(
|
|
319
|
+
arg_types=tuple(a.to_var() for a in self.arg_types),
|
|
320
|
+
arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
|
|
321
|
+
arg_defaults=self.arg_defaults,
|
|
322
|
+
return_type=TypeRefWithVars("Unit"),
|
|
323
|
+
egg_name=self.egg_name,
|
|
324
|
+
default=LitDecl(None),
|
|
325
|
+
)
|
|
583
326
|
|
|
584
|
-
def __str__(self) -> str:
|
|
585
|
-
return f"{ARG}.{self.property_name}"
|
|
586
327
|
|
|
328
|
+
@dataclass(frozen=True)
|
|
329
|
+
class ConstantDecl:
|
|
330
|
+
"""
|
|
331
|
+
Same as `(declare)` in egglog
|
|
332
|
+
"""
|
|
587
333
|
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
334
|
+
type_ref: JustTypeRef
|
|
335
|
+
egg_name: str | None = None
|
|
336
|
+
|
|
337
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
338
|
+
return FunctionDecl(
|
|
339
|
+
arg_types=(),
|
|
340
|
+
arg_names=(),
|
|
341
|
+
arg_defaults=(),
|
|
342
|
+
return_type=self.type_ref.to_var(),
|
|
343
|
+
egg_name=self.egg_name,
|
|
344
|
+
)
|
|
591
345
|
|
|
592
346
|
|
|
593
347
|
@dataclass(frozen=True)
|
|
594
348
|
class FunctionDecl:
|
|
349
|
+
# All args are delayed except for relations converted to function decls
|
|
595
350
|
arg_types: tuple[TypeOrVarRef, ...]
|
|
596
|
-
|
|
597
|
-
|
|
351
|
+
arg_names: tuple[str, ...]
|
|
352
|
+
# List of defaults. None for any arg which doesn't have one.
|
|
598
353
|
arg_defaults: tuple[ExprDecl | None, ...]
|
|
599
|
-
|
|
600
|
-
|
|
354
|
+
# If None, then the first arg is mutated and returned
|
|
355
|
+
return_type: TypeOrVarRef | None
|
|
601
356
|
var_arg_type: TypeOrVarRef | None = None
|
|
602
357
|
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
Parameter(
|
|
612
|
-
n,
|
|
613
|
-
Parameter.POSITIONAL_OR_KEYWORD,
|
|
614
|
-
default=transform_default(TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
|
|
615
|
-
)
|
|
616
|
-
for n, d, t in zip(arg_names, self.arg_defaults, self.arg_types, strict=True)
|
|
617
|
-
]
|
|
618
|
-
if self.var_arg_type is not None:
|
|
619
|
-
parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
|
|
620
|
-
return Signature(parameters)
|
|
358
|
+
# Egg params
|
|
359
|
+
builtin: bool = False
|
|
360
|
+
egg_name: str | None = None
|
|
361
|
+
cost: int | None = None
|
|
362
|
+
default: ExprDecl | None = None
|
|
363
|
+
on_merge: tuple[ActionDecl, ...] = ()
|
|
364
|
+
merge: ExprDecl | None = None
|
|
365
|
+
unextractable: bool = False
|
|
621
366
|
|
|
367
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
368
|
+
return self
|
|
622
369
|
|
|
623
|
-
@
|
|
624
|
-
|
|
625
|
-
|
|
370
|
+
@property
|
|
371
|
+
def semantic_return_type(self) -> TypeOrVarRef:
|
|
372
|
+
"""
|
|
373
|
+
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
374
|
+
"""
|
|
375
|
+
return self.return_type or self.arg_types[0]
|
|
626
376
|
|
|
627
|
-
@
|
|
628
|
-
def
|
|
629
|
-
return
|
|
377
|
+
@property
|
|
378
|
+
def mutates(self) -> bool:
|
|
379
|
+
return self.return_type is None
|
|
630
380
|
|
|
631
|
-
def to_egg(self, _decls: Declarations) -> bindings.Var:
|
|
632
|
-
return bindings.Var(self.name)
|
|
633
381
|
|
|
634
|
-
|
|
635
|
-
|
|
382
|
+
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
383
|
+
|
|
384
|
+
##
|
|
385
|
+
# Expressions
|
|
386
|
+
##
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
@dataclass(frozen=True)
|
|
390
|
+
class VarDecl:
|
|
391
|
+
name: str
|
|
636
392
|
|
|
637
393
|
|
|
638
394
|
@dataclass(frozen=True)
|
|
@@ -646,16 +402,14 @@ class PyObjectDecl:
|
|
|
646
402
|
except TypeError:
|
|
647
403
|
return id(self.value)
|
|
648
404
|
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
return
|
|
653
|
-
|
|
654
|
-
def to_egg(self, _decls: Declarations) -> bindings._Expr:
|
|
655
|
-
return GLOBAL_PY_OBJECT_SORT.store(self.value)
|
|
405
|
+
def __eq__(self, other: object) -> bool:
|
|
406
|
+
if not isinstance(other, PyObjectDecl):
|
|
407
|
+
return False
|
|
408
|
+
return self.parts == other.parts
|
|
656
409
|
|
|
657
|
-
|
|
658
|
-
|
|
410
|
+
@property
|
|
411
|
+
def parts(self) -> tuple[type, object]:
|
|
412
|
+
return (type(self.value), self.value)
|
|
659
413
|
|
|
660
414
|
|
|
661
415
|
LitType: TypeAlias = int | str | float | bool | None
|
|
@@ -665,53 +419,30 @@ LitType: TypeAlias = int | str | float | bool | None
|
|
|
665
419
|
class LitDecl:
|
|
666
420
|
value: LitType
|
|
667
421
|
|
|
668
|
-
|
|
669
|
-
def from_egg(cls, lit: bindings.TermLit) -> ExprDecl:
|
|
670
|
-
value = lit.value
|
|
671
|
-
if isinstance(value, bindings.Unit):
|
|
672
|
-
return cls(None)
|
|
673
|
-
return cls(value.value)
|
|
674
|
-
|
|
675
|
-
def to_egg(self, _decls: Declarations) -> bindings.Lit:
|
|
676
|
-
if self.value is None:
|
|
677
|
-
return bindings.Lit(bindings.Unit())
|
|
678
|
-
if isinstance(self.value, bool):
|
|
679
|
-
return bindings.Lit(bindings.Bool(self.value))
|
|
680
|
-
if isinstance(self.value, int):
|
|
681
|
-
return bindings.Lit(bindings.Int(self.value))
|
|
682
|
-
if isinstance(self.value, float):
|
|
683
|
-
return bindings.Lit(bindings.F64(self.value))
|
|
684
|
-
if isinstance(self.value, str):
|
|
685
|
-
return bindings.Lit(bindings.String(self.value))
|
|
686
|
-
assert_never(self.value)
|
|
687
|
-
|
|
688
|
-
def pretty(self, context: PrettyContext, unwrap_lit: bool = True, **kwargs) -> str:
|
|
422
|
+
def __hash__(self) -> int:
|
|
689
423
|
"""
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
:param wrap_lit: If True, wraps the literal in a call to the literal constructor.
|
|
424
|
+
Include type in has so that 1.0 != 1
|
|
693
425
|
"""
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
assert_never(self.value)
|
|
426
|
+
return hash(self.parts)
|
|
427
|
+
|
|
428
|
+
def __eq__(self, other: object) -> bool:
|
|
429
|
+
if not isinstance(other, LitDecl):
|
|
430
|
+
return False
|
|
431
|
+
return self.parts == other.parts
|
|
432
|
+
|
|
433
|
+
@property
|
|
434
|
+
def parts(self) -> tuple[type, LitType]:
|
|
435
|
+
return (type(self.value), self.value)
|
|
705
436
|
|
|
706
437
|
|
|
707
438
|
@dataclass(frozen=True)
|
|
708
439
|
class CallDecl:
|
|
709
440
|
callable: CallableRef
|
|
441
|
+
# TODO: Can I make these not typed expressions?
|
|
710
442
|
args: tuple[TypedExprDecl, ...] = ()
|
|
711
443
|
# type parameters that were bound to the callable, if it is a classmethod
|
|
712
444
|
# Used for pretty printing classmethod calls with type parameters
|
|
713
445
|
bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
|
714
|
-
_cached_hash: int | None = None
|
|
715
446
|
|
|
716
447
|
def __post_init__(self) -> None:
|
|
717
448
|
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
|
|
@@ -719,302 +450,165 @@ class CallDecl:
|
|
|
719
450
|
raise ValueError(msg)
|
|
720
451
|
|
|
721
452
|
def __hash__(self) -> int:
|
|
722
|
-
# Modified hash which will cache result for performance
|
|
723
|
-
if self._cached_hash is None:
|
|
724
|
-
res = hash((self.callable, self.args, self.bound_tp_params))
|
|
725
|
-
object.__setattr__(self, "_cached_hash", res)
|
|
726
|
-
return res
|
|
727
453
|
return self._cached_hash
|
|
728
454
|
|
|
455
|
+
@cached_property
|
|
456
|
+
def _cached_hash(self) -> int:
|
|
457
|
+
return hash((self.callable, self.args, self.bound_tp_params))
|
|
458
|
+
|
|
729
459
|
def __eq__(self, other: object) -> bool:
|
|
730
460
|
# Override eq to use cached hash for perf
|
|
731
461
|
if not isinstance(other, CallDecl):
|
|
732
462
|
return False
|
|
733
463
|
return hash(self) == hash(other)
|
|
734
464
|
|
|
735
|
-
@classmethod
|
|
736
|
-
def from_egg(
|
|
737
|
-
cls,
|
|
738
|
-
egraph: bindings.EGraph,
|
|
739
|
-
decls: Declarations,
|
|
740
|
-
return_tp: JustTypeRef,
|
|
741
|
-
termdag: bindings.TermDag,
|
|
742
|
-
term: bindings.TermApp,
|
|
743
|
-
cache: dict[int, TypedExprDecl],
|
|
744
|
-
) -> ExprDecl:
|
|
745
|
-
"""
|
|
746
|
-
Convert an egg expression into a typed expression by using the declerations.
|
|
747
465
|
|
|
748
|
-
|
|
749
|
-
|
|
466
|
+
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@dataclass(frozen=True)
|
|
470
|
+
class TypedExprDecl:
|
|
471
|
+
tp: JustTypeRef
|
|
472
|
+
expr: ExprDecl
|
|
473
|
+
|
|
474
|
+
def descendants(self) -> list[TypedExprDecl]:
|
|
750
475
|
"""
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
# Find the first callable ref that matches the call
|
|
754
|
-
for callable_ref in decls.get_callable_refs(term.name):
|
|
755
|
-
# If this is a classmethod, we might need the type params that were bound for this type
|
|
756
|
-
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
757
|
-
possible_types: Iterable[JustTypeRef | None]
|
|
758
|
-
fn_decl = decls.get_function_decl(callable_ref)
|
|
759
|
-
if isinstance(callable_ref, ClassMethodRef):
|
|
760
|
-
possible_types = decls.get_possible_types(callable_ref.class_name)
|
|
761
|
-
cls_name = callable_ref.class_name
|
|
762
|
-
else:
|
|
763
|
-
possible_types = [None]
|
|
764
|
-
cls_name = None
|
|
765
|
-
for possible_type in possible_types:
|
|
766
|
-
tcs = TypeConstraintSolver(decls)
|
|
767
|
-
if possible_type and possible_type.args:
|
|
768
|
-
tcs.bind_class(possible_type)
|
|
769
|
-
|
|
770
|
-
try:
|
|
771
|
-
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
772
|
-
fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, return_tp, cls_name
|
|
773
|
-
)
|
|
774
|
-
except TypeConstraintError:
|
|
775
|
-
continue
|
|
776
|
-
args: list[TypedExprDecl] = []
|
|
777
|
-
for a, tp in zip(term.args, arg_types, strict=False):
|
|
778
|
-
if a in cache:
|
|
779
|
-
res = cache[a]
|
|
780
|
-
else:
|
|
781
|
-
res = TypedExprDecl.from_egg(egraph, decls, tp, termdag, termdag.nodes[a], cache)
|
|
782
|
-
cache[a] = res
|
|
783
|
-
args.append(res)
|
|
784
|
-
return cls(callable_ref, tuple(args), bound_tp_params)
|
|
785
|
-
raise ValueError(f"Could not find callable ref for call {term}")
|
|
786
|
-
|
|
787
|
-
def to_egg(self, decls: Declarations) -> bindings._Expr:
|
|
788
|
-
"""Convert a Call to an egg Call."""
|
|
789
|
-
# This was removed when we replaced declerations constants with our b/c of unextractable constants
|
|
790
|
-
# # If this is a constant, then emit it just as a var, not as a call
|
|
791
|
-
# if isinstance(self.callable, ConstantRef | ClassVariableRef):
|
|
792
|
-
# decls.get_egg_fn
|
|
793
|
-
# return bindings.Var(egg_fn)
|
|
794
|
-
if hasattr(self, "_cached_egg"):
|
|
795
|
-
return self._cached_egg
|
|
796
|
-
egg_fn = decls.get_egg_fn(self.callable)
|
|
797
|
-
res = bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args])
|
|
798
|
-
object.__setattr__(self, "_cached_egg", res)
|
|
799
|
-
return res
|
|
800
|
-
|
|
801
|
-
def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901
|
|
476
|
+
Returns a list of all the descendants of this expression.
|
|
802
477
|
"""
|
|
803
|
-
|
|
478
|
+
l = [self]
|
|
479
|
+
if isinstance(self.expr, CallDecl):
|
|
480
|
+
for a in self.expr.args:
|
|
481
|
+
l.extend(a.descendants())
|
|
482
|
+
return l
|
|
804
483
|
|
|
805
|
-
:param parens: If true, wrap the call in parens if it is a binary method call.
|
|
806
|
-
"""
|
|
807
|
-
if self in context.names:
|
|
808
|
-
return context.names[self]
|
|
809
|
-
ref, args = self.callable, [a.expr for a in self.args]
|
|
810
|
-
# Special case !=
|
|
811
|
-
if ref == FunctionRef("!="):
|
|
812
|
-
return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})"
|
|
813
|
-
function_decl = context.decls.get_function_decl(ref)
|
|
814
|
-
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
|
|
815
|
-
n_defaults = 0
|
|
816
|
-
for arg, default in zip(
|
|
817
|
-
reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
|
|
818
|
-
):
|
|
819
|
-
if arg != default:
|
|
820
|
-
break
|
|
821
|
-
n_defaults += 1
|
|
822
|
-
if n_defaults:
|
|
823
|
-
args = args[:-n_defaults]
|
|
824
|
-
if function_decl.mutates_first_arg:
|
|
825
|
-
first_arg = args[0]
|
|
826
|
-
expr_str = first_arg.pretty(context, parens=False)
|
|
827
|
-
# copy an identifer expression iff it has multiple parents (b/c then we can't mutate it directly)
|
|
828
|
-
has_multiple_parents = context.parents[first_arg] > 1
|
|
829
|
-
expr_name = context.name_expr(function_decl.arg_types[0], expr_str, copy_identifier=has_multiple_parents)
|
|
830
|
-
# Set the first arg to be the name of the mutated arg and return the name
|
|
831
|
-
args[0] = VarDecl(expr_name)
|
|
832
|
-
else:
|
|
833
|
-
expr_name = None
|
|
834
|
-
match ref:
|
|
835
|
-
case FunctionRef(name):
|
|
836
|
-
expr = _pretty_call(context, name, args)
|
|
837
|
-
case ClassMethodRef(class_name, method_name):
|
|
838
|
-
tp_ref = JustTypeRef(class_name, self.bound_tp_params or ())
|
|
839
|
-
fn_str = tp_ref.pretty() if method_name == "__init__" else f"{tp_ref.pretty()}.{method_name}"
|
|
840
|
-
expr = _pretty_call(context, fn_str, args)
|
|
841
|
-
case MethodRef(_class_name, method_name):
|
|
842
|
-
slf, *args = args
|
|
843
|
-
slf = slf.pretty(context, unwrap_lit=False)
|
|
844
|
-
match method_name:
|
|
845
|
-
case _ if method_name in UNARY_METHODS:
|
|
846
|
-
expr = f"{UNARY_METHODS[method_name]}{slf}"
|
|
847
|
-
case _ if method_name in BINARY_METHODS:
|
|
848
|
-
assert len(args) == 1
|
|
849
|
-
expr = f"{slf} {BINARY_METHODS[method_name]} {args[0].pretty(context)}"
|
|
850
|
-
if parens:
|
|
851
|
-
expr = f"({expr})"
|
|
852
|
-
case "__getitem__":
|
|
853
|
-
assert len(args) == 1
|
|
854
|
-
expr = f"{slf}[{args[0].pretty(context, parens=False)}]"
|
|
855
|
-
case "__call__":
|
|
856
|
-
expr = _pretty_call(context, slf, args)
|
|
857
|
-
case "__delitem__":
|
|
858
|
-
assert len(args) == 1
|
|
859
|
-
expr = f"del {slf}[{args[0].pretty(context, parens=False)}]"
|
|
860
|
-
case "__setitem__":
|
|
861
|
-
assert len(args) == 2
|
|
862
|
-
expr = (
|
|
863
|
-
f"{slf}[{args[0].pretty(context, parens=False)}] = {args[1].pretty(context, parens=False)}"
|
|
864
|
-
)
|
|
865
|
-
case _:
|
|
866
|
-
expr = _pretty_call(context, f"{slf}.{method_name}", args)
|
|
867
|
-
case ConstantRef(name):
|
|
868
|
-
expr = name
|
|
869
|
-
case ClassVariableRef(class_name, variable_name):
|
|
870
|
-
expr = f"{class_name}.{variable_name}"
|
|
871
|
-
case PropertyRef(_class_name, property_name):
|
|
872
|
-
expr = f"{args[0].pretty(context)}.{property_name}"
|
|
873
|
-
case _:
|
|
874
|
-
assert_never(ref)
|
|
875
|
-
# If we have a name, then we mutated
|
|
876
|
-
if expr_name:
|
|
877
|
-
context.statements.append(expr)
|
|
878
|
-
context.names[self] = expr_name
|
|
879
|
-
return expr_name
|
|
880
|
-
|
|
881
|
-
# We use a heuristic to decide whether to name this sub-expression as a variable
|
|
882
|
-
# The rough goal is to reduce the number of newlines, given our line length of ~180
|
|
883
|
-
# We determine it's worth making a new line for this expression if the total characters
|
|
884
|
-
# it would take up is > than some constant (~ line length).
|
|
885
|
-
n_parents = context.parents[self]
|
|
886
|
-
line_diff: int = len(expr) - LINE_DIFFERENCE
|
|
887
|
-
if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
|
|
888
|
-
expr_name = context.name_expr(function_decl.return_type, expr, copy_identifier=False)
|
|
889
|
-
context.names[self] = expr_name
|
|
890
|
-
return expr_name
|
|
891
|
-
return expr
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
MAX_LINE_LENGTH = 110
|
|
895
|
-
LINE_DIFFERENCE = 10
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
def _plot_line_length(expr: object):
|
|
899
|
-
"""
|
|
900
|
-
Plots the number of line lengths based on different max lengths
|
|
901
|
-
"""
|
|
902
|
-
global MAX_LINE_LENGTH, LINE_DIFFERENCE
|
|
903
|
-
import altair as alt
|
|
904
|
-
import pandas as pd
|
|
905
484
|
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
for diff in range(0, 40, 5):
|
|
910
|
-
LINE_DIFFERENCE = diff
|
|
911
|
-
new_l = len(str(expr).split())
|
|
912
|
-
sizes.append((line_length, diff, new_l))
|
|
485
|
+
##
|
|
486
|
+
# Schedules
|
|
487
|
+
##
|
|
913
488
|
|
|
914
|
-
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901
|
|
915
489
|
|
|
916
|
-
|
|
490
|
+
@dataclass(frozen=True)
|
|
491
|
+
class SaturateDecl:
|
|
492
|
+
schedule: ScheduleDecl
|
|
917
493
|
|
|
918
494
|
|
|
919
|
-
|
|
920
|
-
|
|
495
|
+
@dataclass(frozen=True)
|
|
496
|
+
class RepeatDecl:
|
|
497
|
+
schedule: ScheduleDecl
|
|
498
|
+
times: int
|
|
921
499
|
|
|
922
500
|
|
|
923
|
-
@dataclass
|
|
924
|
-
class
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
statements: list[str] = field(default_factory=list)
|
|
928
|
-
|
|
929
|
-
names: dict[ExprDecl, str] = field(default_factory=dict)
|
|
930
|
-
parents: dict[ExprDecl, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
|
931
|
-
_traversed_exprs: set[ExprDecl] = field(default_factory=set)
|
|
932
|
-
|
|
933
|
-
# Mapping of type to the number of times we have generated a name for that type, used to generate unique names
|
|
934
|
-
_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
|
935
|
-
|
|
936
|
-
def generate_name(self, typ: str) -> str:
|
|
937
|
-
self._gen_name_types[typ] += 1
|
|
938
|
-
return f"_{typ}_{self._gen_name_types[typ]}"
|
|
939
|
-
|
|
940
|
-
def name_expr(self, expr_type: TypeOrVarRef, expr_str: str, copy_identifier: bool) -> str:
|
|
941
|
-
tp_name = expr_type.to_just().name
|
|
942
|
-
# If the thing we are naming is already a variable, we don't need to name it
|
|
943
|
-
if expr_str.isidentifier():
|
|
944
|
-
if copy_identifier:
|
|
945
|
-
name = self.generate_name(tp_name)
|
|
946
|
-
self.statements.append(f"{name} = copy({expr_str})")
|
|
947
|
-
else:
|
|
948
|
-
name = expr_str
|
|
949
|
-
else:
|
|
950
|
-
name = self.generate_name(tp_name)
|
|
951
|
-
self.statements.append(f"{name} = {expr_str}")
|
|
952
|
-
return name
|
|
501
|
+
@dataclass(frozen=True)
|
|
502
|
+
class SequenceDecl:
|
|
503
|
+
schedules: tuple[ScheduleDecl, ...]
|
|
504
|
+
|
|
953
505
|
|
|
954
|
-
|
|
955
|
-
|
|
506
|
+
@dataclass(frozen=True)
|
|
507
|
+
class RunDecl:
|
|
508
|
+
ruleset: str
|
|
509
|
+
until: tuple[FactDecl, ...] | None
|
|
956
510
|
|
|
957
|
-
def traverse_for_parents(self, expr: ExprDecl) -> None:
|
|
958
|
-
if expr in self._traversed_exprs:
|
|
959
|
-
return
|
|
960
|
-
self._traversed_exprs.add(expr)
|
|
961
|
-
if isinstance(expr, CallDecl):
|
|
962
|
-
for arg in set(expr.args):
|
|
963
|
-
self.parents[arg.expr] += 1
|
|
964
|
-
self.traverse_for_parents(arg.expr)
|
|
965
511
|
|
|
512
|
+
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
|
|
966
513
|
|
|
967
|
-
|
|
514
|
+
##
|
|
515
|
+
# Facts
|
|
516
|
+
##
|
|
968
517
|
|
|
969
518
|
|
|
970
519
|
@dataclass(frozen=True)
|
|
971
|
-
class
|
|
520
|
+
class EqDecl:
|
|
972
521
|
tp: JustTypeRef
|
|
973
|
-
|
|
522
|
+
exprs: tuple[ExprDecl, ...]
|
|
974
523
|
|
|
975
|
-
@classmethod
|
|
976
|
-
def from_egg(
|
|
977
|
-
cls,
|
|
978
|
-
egraph: bindings.EGraph,
|
|
979
|
-
decls: Declarations,
|
|
980
|
-
tp: JustTypeRef,
|
|
981
|
-
termdag: bindings.TermDag,
|
|
982
|
-
term: bindings._Term,
|
|
983
|
-
cache: dict[int, TypedExprDecl],
|
|
984
|
-
) -> TypedExprDecl:
|
|
985
|
-
expr_decl: ExprDecl
|
|
986
|
-
if isinstance(term, bindings.TermVar):
|
|
987
|
-
expr_decl = VarDecl.from_egg(term)
|
|
988
|
-
elif isinstance(term, bindings.TermLit):
|
|
989
|
-
expr_decl = LitDecl.from_egg(term)
|
|
990
|
-
elif isinstance(term, bindings.TermApp):
|
|
991
|
-
if term.name == "py-object":
|
|
992
|
-
expr_decl = PyObjectDecl.from_egg(egraph, termdag, term)
|
|
993
|
-
else:
|
|
994
|
-
expr_decl = CallDecl.from_egg(egraph, decls, tp, termdag, term, cache)
|
|
995
|
-
else:
|
|
996
|
-
assert_never(term)
|
|
997
|
-
return cls(tp, expr_decl)
|
|
998
524
|
|
|
999
|
-
|
|
1000
|
-
|
|
525
|
+
@dataclass(frozen=True)
|
|
526
|
+
class ExprFactDecl:
|
|
527
|
+
typed_expr: TypedExprDecl
|
|
1001
528
|
|
|
1002
|
-
def descendants(self) -> list[TypedExprDecl]:
|
|
1003
|
-
"""
|
|
1004
|
-
Returns a list of all the descendants of this expression.
|
|
1005
|
-
"""
|
|
1006
|
-
l = [self]
|
|
1007
|
-
if isinstance(self.expr, CallDecl):
|
|
1008
|
-
for a in self.expr.args:
|
|
1009
|
-
l.extend(a.descendants())
|
|
1010
|
-
return l
|
|
1011
529
|
|
|
530
|
+
FactDecl: TypeAlias = EqDecl | ExprFactDecl
|
|
1012
531
|
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
532
|
+
##
|
|
533
|
+
# Actions
|
|
534
|
+
##
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@dataclass(frozen=True)
|
|
538
|
+
class LetDecl:
|
|
539
|
+
name: str
|
|
540
|
+
typed_expr: TypedExprDecl
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
@dataclass(frozen=True)
|
|
544
|
+
class SetDecl:
|
|
545
|
+
tp: JustTypeRef
|
|
546
|
+
call: CallDecl
|
|
547
|
+
rhs: ExprDecl
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@dataclass(frozen=True)
|
|
551
|
+
class ExprActionDecl:
|
|
552
|
+
typed_expr: TypedExprDecl
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
@dataclass(frozen=True)
|
|
556
|
+
class ChangeDecl:
|
|
557
|
+
tp: JustTypeRef
|
|
558
|
+
call: CallDecl
|
|
559
|
+
change: Literal["delete", "subsume"]
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
@dataclass(frozen=True)
|
|
563
|
+
class UnionDecl:
|
|
564
|
+
tp: JustTypeRef
|
|
565
|
+
lhs: ExprDecl
|
|
566
|
+
rhs: ExprDecl
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
@dataclass(frozen=True)
|
|
570
|
+
class PanicDecl:
|
|
571
|
+
msg: str
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
##
|
|
578
|
+
# Commands
|
|
579
|
+
##
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
@dataclass(frozen=True)
|
|
583
|
+
class RewriteDecl:
|
|
584
|
+
tp: JustTypeRef
|
|
585
|
+
lhs: ExprDecl
|
|
586
|
+
rhs: ExprDecl
|
|
587
|
+
conditions: tuple[FactDecl, ...]
|
|
588
|
+
subsume: bool
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
@dataclass(frozen=True)
|
|
592
|
+
class BiRewriteDecl:
|
|
593
|
+
tp: JustTypeRef
|
|
594
|
+
lhs: ExprDecl
|
|
595
|
+
rhs: ExprDecl
|
|
596
|
+
conditions: tuple[FactDecl, ...]
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
@dataclass(frozen=True)
|
|
600
|
+
class RuleDecl:
|
|
601
|
+
head: tuple[ActionDecl, ...]
|
|
602
|
+
body: tuple[FactDecl, ...]
|
|
603
|
+
name: str | None
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
@dataclass(frozen=True)
|
|
610
|
+
class ActionCommandDecl:
|
|
611
|
+
action: ActionDecl
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl
|