egglog 6.1.0__cp312-none-win_amd64.whl → 7.1.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +9 -0
- egglog/builtins.py +42 -2
- egglog/conversion.py +177 -0
- egglog/declarations.py +354 -734
- egglog/egraph.py +602 -800
- egglog/egraph_state.py +456 -0
- egglog/exp/array_api.py +100 -88
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +464 -0
- egglog/runtime.py +279 -431
- egglog/thunk.py +71 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/METADATA +7 -7
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
egglog/declarations.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Data only descriptions of the components of an egraph and the expressions.
|
|
3
|
+
|
|
4
|
+
We seperate it it into two pieces, the references the declerations, so that we can report mutually recursive types.
|
|
3
5
|
"""
|
|
4
6
|
|
|
5
7
|
from __future__ import annotations
|
|
6
8
|
|
|
7
|
-
from collections import defaultdict
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
|
-
from
|
|
10
|
-
from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable
|
|
11
12
|
|
|
12
13
|
from typing_extensions import Self, assert_never
|
|
13
14
|
|
|
14
|
-
from . import bindings
|
|
15
|
-
|
|
16
15
|
if TYPE_CHECKING:
|
|
17
16
|
from collections.abc import Callable, Iterable
|
|
18
17
|
|
|
@@ -20,84 +19,68 @@ if TYPE_CHECKING:
|
|
|
20
19
|
__all__ = [
|
|
21
20
|
"Declarations",
|
|
22
21
|
"DeclerationsLike",
|
|
23
|
-
"
|
|
22
|
+
"DelayedDeclerations",
|
|
23
|
+
"upcast_declerations",
|
|
24
|
+
"Declarations",
|
|
24
25
|
"JustTypeRef",
|
|
25
26
|
"ClassTypeVarRef",
|
|
26
27
|
"TypeRefWithVars",
|
|
27
28
|
"TypeOrVarRef",
|
|
28
|
-
"FunctionRef",
|
|
29
29
|
"MethodRef",
|
|
30
30
|
"ClassMethodRef",
|
|
31
|
+
"FunctionRef",
|
|
32
|
+
"ConstantRef",
|
|
31
33
|
"ClassVariableRef",
|
|
32
|
-
"FunctionCallableRef",
|
|
33
34
|
"PropertyRef",
|
|
34
35
|
"CallableRef",
|
|
35
|
-
"ConstantRef",
|
|
36
36
|
"FunctionDecl",
|
|
37
|
+
"RelationDecl",
|
|
38
|
+
"ConstantDecl",
|
|
39
|
+
"CallableDecl",
|
|
37
40
|
"VarDecl",
|
|
38
|
-
"LitType",
|
|
39
41
|
"PyObjectDecl",
|
|
42
|
+
"PartialCallDecl",
|
|
43
|
+
"LitType",
|
|
40
44
|
"LitDecl",
|
|
41
45
|
"CallDecl",
|
|
42
46
|
"ExprDecl",
|
|
43
47
|
"TypedExprDecl",
|
|
44
48
|
"ClassDecl",
|
|
45
|
-
"
|
|
46
|
-
"
|
|
49
|
+
"RulesetDecl",
|
|
50
|
+
"CombinedRulesetDecl",
|
|
51
|
+
"SaturateDecl",
|
|
52
|
+
"RepeatDecl",
|
|
53
|
+
"SequenceDecl",
|
|
54
|
+
"RunDecl",
|
|
55
|
+
"ScheduleDecl",
|
|
56
|
+
"EqDecl",
|
|
57
|
+
"ExprFactDecl",
|
|
58
|
+
"FactDecl",
|
|
59
|
+
"LetDecl",
|
|
60
|
+
"SetDecl",
|
|
61
|
+
"ExprActionDecl",
|
|
62
|
+
"ChangeDecl",
|
|
63
|
+
"UnionDecl",
|
|
64
|
+
"PanicDecl",
|
|
65
|
+
"ActionDecl",
|
|
66
|
+
"RewriteDecl",
|
|
67
|
+
"BiRewriteDecl",
|
|
68
|
+
"RuleDecl",
|
|
69
|
+
"RewriteOrRuleDecl",
|
|
70
|
+
"ActionCommandDecl",
|
|
71
|
+
"CommandDecl",
|
|
72
|
+
"SpecialFunctions",
|
|
73
|
+
"FunctionSignature",
|
|
47
74
|
]
|
|
48
75
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
"__lt__": "<",
|
|
58
|
-
"__le__": "<=",
|
|
59
|
-
"__eq__": "==",
|
|
60
|
-
"__ne__": "!=",
|
|
61
|
-
"__gt__": ">",
|
|
62
|
-
"__ge__": ">=",
|
|
63
|
-
# Numeric
|
|
64
|
-
"__add__": "+",
|
|
65
|
-
"__sub__": "-",
|
|
66
|
-
"__mul__": "*",
|
|
67
|
-
"__matmul__": "@",
|
|
68
|
-
"__truediv__": "/",
|
|
69
|
-
"__floordiv__": "//",
|
|
70
|
-
"__mod__": "%",
|
|
71
|
-
# TODO: Support divmod, with tuple return value
|
|
72
|
-
# "__divmod__": "divmod",
|
|
73
|
-
# TODO: Three arg power
|
|
74
|
-
"__pow__": "**",
|
|
75
|
-
"__lshift__": "<<",
|
|
76
|
-
"__rshift__": ">>",
|
|
77
|
-
"__and__": "&",
|
|
78
|
-
"__xor__": "^",
|
|
79
|
-
"__or__": "|",
|
|
80
|
-
}
|
|
81
|
-
REFLECTED_BINARY_METHODS = {
|
|
82
|
-
"__radd__": "__add__",
|
|
83
|
-
"__rsub__": "__sub__",
|
|
84
|
-
"__rmul__": "__mul__",
|
|
85
|
-
"__rmatmul__": "__matmul__",
|
|
86
|
-
"__rtruediv__": "__truediv__",
|
|
87
|
-
"__rfloordiv__": "__floordiv__",
|
|
88
|
-
"__rmod__": "__mod__",
|
|
89
|
-
"__rpow__": "__pow__",
|
|
90
|
-
"__rlshift__": "__lshift__",
|
|
91
|
-
"__rrshift__": "__rshift__",
|
|
92
|
-
"__rand__": "__and__",
|
|
93
|
-
"__rxor__": "__xor__",
|
|
94
|
-
"__ror__": "__or__",
|
|
95
|
-
}
|
|
96
|
-
UNARY_METHODS = {
|
|
97
|
-
"__pos__": "+",
|
|
98
|
-
"__neg__": "-",
|
|
99
|
-
"__invert__": "~",
|
|
100
|
-
}
|
|
76
|
+
|
|
77
|
+
@dataclass
|
|
78
|
+
class DelayedDeclerations:
|
|
79
|
+
__egg_decls_thunk__: Callable[[], Declarations]
|
|
80
|
+
|
|
81
|
+
@property
|
|
82
|
+
def __egg_decls__(self) -> Declarations:
|
|
83
|
+
return self.__egg_decls_thunk__()
|
|
101
84
|
|
|
102
85
|
|
|
103
86
|
@runtime_checkable
|
|
@@ -109,7 +92,7 @@ class HasDeclerations(Protocol):
|
|
|
109
92
|
DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
|
|
110
93
|
|
|
111
94
|
|
|
112
|
-
def
|
|
95
|
+
def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
|
|
113
96
|
d = []
|
|
114
97
|
for l in declerations_like:
|
|
115
98
|
if l is None:
|
|
@@ -125,30 +108,14 @@ def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[
|
|
|
125
108
|
|
|
126
109
|
@dataclass
|
|
127
110
|
class Declarations:
|
|
128
|
-
_functions: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
111
|
+
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
112
|
+
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
129
113
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
# Bidirectional mapping between egg function names and python callable references.
|
|
133
|
-
# Note that there are possibly mutliple callable references for a single egg function name, like `+`
|
|
134
|
-
# for both int and rational classes.
|
|
135
|
-
_egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
|
|
136
|
-
_callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)
|
|
137
|
-
|
|
138
|
-
# Bidirectional mapping between egg sort names and python type references.
|
|
139
|
-
_egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
|
|
140
|
-
_type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
141
|
-
|
|
142
|
-
# Mapping from egg name (of sort or function) to command to create it.
|
|
143
|
-
_cmds: dict[str, bindings._Command] = field(default_factory=dict)
|
|
144
|
-
|
|
145
|
-
def __post_init__(self) -> None:
|
|
146
|
-
if "!=" not in self._egg_fn_to_callable_refs:
|
|
147
|
-
self.register_callable_ref(FunctionRef("!="), "!=")
|
|
114
|
+
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
148
115
|
|
|
149
116
|
@classmethod
|
|
150
117
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
151
|
-
others =
|
|
118
|
+
others = upcast_declerations(others)
|
|
152
119
|
if not others:
|
|
153
120
|
return Declarations()
|
|
154
121
|
first, *rest = others
|
|
@@ -159,25 +126,9 @@ class Declarations:
|
|
|
159
126
|
return new
|
|
160
127
|
|
|
161
128
|
def copy(self) -> Declarations:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
_constants=self._constants.copy(),
|
|
166
|
-
_egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}),
|
|
167
|
-
_callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(),
|
|
168
|
-
_egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(),
|
|
169
|
-
_type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(),
|
|
170
|
-
_cmds=self._cmds.copy(),
|
|
171
|
-
)
|
|
172
|
-
|
|
173
|
-
def __deepcopy__(self, memo: dict) -> Declarations:
|
|
174
|
-
return self.copy()
|
|
175
|
-
|
|
176
|
-
def add_cmd(self, name: str, cmd: bindings._Command) -> None:
|
|
177
|
-
self._cmds[name] = cmd
|
|
178
|
-
|
|
179
|
-
def list_cmds(self) -> list[bindings._Command]:
|
|
180
|
-
return list(self._cmds.values())
|
|
129
|
+
new = Declarations()
|
|
130
|
+
new |= self
|
|
131
|
+
return new
|
|
181
132
|
|
|
182
133
|
def update(self, *others: DeclerationsLike) -> None:
|
|
183
134
|
for other in others:
|
|
@@ -200,82 +151,26 @@ class Declarations:
|
|
|
200
151
|
"""
|
|
201
152
|
Updates the other decl with these values in palce.
|
|
202
153
|
"""
|
|
203
|
-
# If cmds are == skip unioning for time savings
|
|
204
|
-
# if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds:
|
|
205
|
-
# return self
|
|
206
154
|
other._functions |= self._functions
|
|
207
155
|
other._classes |= self._classes
|
|
208
156
|
other._constants |= self._constants
|
|
209
|
-
other.
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
other._callable_ref_to_egg_fn |= self._callable_ref_to_egg_fn
|
|
213
|
-
for egg_fn, callable_refs in self._egg_fn_to_callable_refs.items():
|
|
214
|
-
other._egg_fn_to_callable_refs[egg_fn] |= callable_refs
|
|
215
|
-
|
|
216
|
-
def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
|
|
217
|
-
"""
|
|
218
|
-
Sets a function declaration for the given callable reference.
|
|
219
|
-
"""
|
|
157
|
+
other._rulesets |= self._rulesets
|
|
158
|
+
|
|
159
|
+
def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
|
|
220
160
|
match ref:
|
|
221
161
|
case FunctionRef(name):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
self.
|
|
162
|
+
return self._functions[name]
|
|
163
|
+
case ConstantRef(name):
|
|
164
|
+
return self._constants[name]
|
|
225
165
|
case MethodRef(class_name, method_name):
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
self._classes[class_name].
|
|
229
|
-
case ClassMethodRef(class_name,
|
|
230
|
-
|
|
231
|
-
raise ValueError(f"Class method {class_name}.{method_name} already registered")
|
|
232
|
-
self._classes[class_name].class_methods[method_name] = decl
|
|
166
|
+
return self._classes[class_name].methods[method_name]
|
|
167
|
+
case ClassVariableRef(class_name, name):
|
|
168
|
+
return self._classes[class_name].class_variables[name]
|
|
169
|
+
case ClassMethodRef(class_name, name):
|
|
170
|
+
return self._classes[class_name].class_methods[name]
|
|
233
171
|
case PropertyRef(class_name, property_name):
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
self._classes[class_name].properties[property_name] = decl
|
|
237
|
-
case _:
|
|
238
|
-
assert_never(ref)
|
|
239
|
-
|
|
240
|
-
def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
|
|
241
|
-
match ref:
|
|
242
|
-
case ConstantRef(name):
|
|
243
|
-
if name in self._constants:
|
|
244
|
-
raise ValueError(f"Constant {name} already registered")
|
|
245
|
-
self._constants[name] = tp
|
|
246
|
-
case ClassVariableRef(class_name, variable_name):
|
|
247
|
-
if variable_name in self._classes[class_name].class_variables:
|
|
248
|
-
raise ValueError(f"Class variable {class_name}.{variable_name} already registered")
|
|
249
|
-
self._classes[class_name].class_variables[variable_name] = tp
|
|
250
|
-
case _:
|
|
251
|
-
assert_never(ref)
|
|
252
|
-
|
|
253
|
-
def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
|
|
254
|
-
"""
|
|
255
|
-
Registers a callable reference with the given egg name.
|
|
256
|
-
|
|
257
|
-
The callable's function needs to be registered first.
|
|
258
|
-
"""
|
|
259
|
-
if ref in self._callable_ref_to_egg_fn:
|
|
260
|
-
raise ValueError(f"Callable ref {ref} already registered")
|
|
261
|
-
self._callable_ref_to_egg_fn[ref] = egg_name
|
|
262
|
-
self._egg_fn_to_callable_refs[egg_name].add(ref)
|
|
263
|
-
|
|
264
|
-
def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
|
|
265
|
-
return self._egg_fn_to_callable_refs[egg_name]
|
|
266
|
-
|
|
267
|
-
def get_egg_fn(self, ref: CallableRef) -> str:
|
|
268
|
-
return self._callable_ref_to_egg_fn[ref]
|
|
269
|
-
|
|
270
|
-
def get_egg_sort(self, ref: JustTypeRef) -> str:
|
|
271
|
-
return self._type_ref_to_egg_sort[ref]
|
|
272
|
-
|
|
273
|
-
def op_mapping(self) -> dict[str, str]:
|
|
274
|
-
"""
|
|
275
|
-
Create a mapping of egglog function name to Python function name, for use in the serialized format
|
|
276
|
-
for better visualization.
|
|
277
|
-
"""
|
|
278
|
-
return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1}
|
|
172
|
+
return self._classes[class_name].properties[property_name]
|
|
173
|
+
assert_never(ref)
|
|
279
174
|
|
|
280
175
|
def has_method(self, class_name: str, method_name: str) -> bool | None:
|
|
281
176
|
"""
|
|
@@ -285,138 +180,36 @@ class Declarations:
|
|
|
285
180
|
return method_name in self._classes[class_name].methods
|
|
286
181
|
return None
|
|
287
182
|
|
|
288
|
-
def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
|
|
289
|
-
match ref:
|
|
290
|
-
case ConstantRef(name):
|
|
291
|
-
return self._constants[name].to_constant_function_decl()
|
|
292
|
-
case ClassVariableRef(class_name, variable_name):
|
|
293
|
-
return self._classes[class_name].class_variables[variable_name].to_constant_function_decl()
|
|
294
|
-
case FunctionRef(name):
|
|
295
|
-
return self._functions[name]
|
|
296
|
-
case MethodRef(class_name, method_name):
|
|
297
|
-
return self._classes[class_name].methods[method_name]
|
|
298
|
-
case ClassMethodRef(class_name, method_name):
|
|
299
|
-
return self._classes[class_name].class_methods[method_name]
|
|
300
|
-
case PropertyRef(class_name, property_name):
|
|
301
|
-
return self._classes[class_name].properties[property_name]
|
|
302
|
-
assert_never(ref)
|
|
303
|
-
|
|
304
183
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
305
184
|
return self._classes[name]
|
|
306
185
|
|
|
307
|
-
def get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
|
|
308
|
-
"""
|
|
309
|
-
Given a class name, returns all possible registered types that it can be.
|
|
310
|
-
"""
|
|
311
|
-
return frozenset(tp for tp in self._type_ref_to_egg_sort if tp.name == cls_name)
|
|
312
186
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
187
|
+
@dataclass
|
|
188
|
+
class ClassDecl:
|
|
189
|
+
egg_name: str | None = None
|
|
190
|
+
type_vars: tuple[str, ...] = ()
|
|
191
|
+
builtin: bool = False
|
|
192
|
+
class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
193
|
+
# These have to be seperate from class_methods so that printing them can be done easily
|
|
194
|
+
class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
195
|
+
methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
196
|
+
properties: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
197
|
+
preserved_methods: dict[str, Callable] = field(default_factory=dict)
|
|
320
198
|
|
|
321
|
-
def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str:
|
|
322
|
-
"""
|
|
323
|
-
Register a sort with the given name. If no name is given, one is generated.
|
|
324
199
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
try:
|
|
329
|
-
egg_sort = self.get_egg_sort(ref)
|
|
330
|
-
except KeyError:
|
|
331
|
-
pass
|
|
332
|
-
else:
|
|
333
|
-
return egg_sort
|
|
334
|
-
egg_name = egg_name or ref.generate_egg_name()
|
|
335
|
-
if egg_name in self._egg_sort_to_type_ref:
|
|
336
|
-
raise ValueError(f"Sort {egg_name} is already registered.")
|
|
337
|
-
self._egg_sort_to_type_ref[egg_name] = ref
|
|
338
|
-
self._type_ref_to_egg_sort[ref] = egg_name
|
|
339
|
-
if not builtin:
|
|
340
|
-
self.add_cmd(
|
|
341
|
-
egg_name,
|
|
342
|
-
bindings.Sort(
|
|
343
|
-
egg_name,
|
|
344
|
-
(
|
|
345
|
-
self.get_egg_sort(JustTypeRef(ref.name)),
|
|
346
|
-
[bindings.Var(self.register_sort(arg, False)) for arg in ref.args],
|
|
347
|
-
)
|
|
348
|
-
if ref.args
|
|
349
|
-
else None,
|
|
350
|
-
),
|
|
351
|
-
)
|
|
352
|
-
|
|
353
|
-
return egg_name
|
|
354
|
-
|
|
355
|
-
def register_function_callable(
|
|
356
|
-
self,
|
|
357
|
-
ref: FunctionCallableRef,
|
|
358
|
-
fn_decl: FunctionDecl,
|
|
359
|
-
egg_name: str | None,
|
|
360
|
-
cost: int | None,
|
|
361
|
-
default: ExprDecl | None,
|
|
362
|
-
merge: ExprDecl | None,
|
|
363
|
-
merge_action: list[bindings._Action],
|
|
364
|
-
unextractable: bool,
|
|
365
|
-
builtin: bool,
|
|
366
|
-
is_relation: bool = False,
|
|
367
|
-
) -> None:
|
|
368
|
-
"""
|
|
369
|
-
Registers a callable with the given egg name.
|
|
200
|
+
@dataclass(frozen=True)
|
|
201
|
+
class RulesetDecl:
|
|
202
|
+
rules: list[RewriteOrRuleDecl]
|
|
370
203
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
self
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
if fn_decl.var_arg_type is not None:
|
|
382
|
-
msg = "egglog does not support variable arguments yet."
|
|
383
|
-
raise NotImplementedError(msg)
|
|
384
|
-
# Remove all vars from the type refs, raising an errory if we find one,
|
|
385
|
-
# since we cannot create egg functions with vars
|
|
386
|
-
arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types]
|
|
387
|
-
cmd: bindings._Command
|
|
388
|
-
if is_relation:
|
|
389
|
-
assert not default
|
|
390
|
-
assert not merge
|
|
391
|
-
assert not merge_action
|
|
392
|
-
assert not cost
|
|
393
|
-
cmd = bindings.Relation(egg_name, arg_sorts)
|
|
394
|
-
else:
|
|
395
|
-
egg_fn_decl = bindings.FunctionDecl(
|
|
396
|
-
egg_name,
|
|
397
|
-
bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)),
|
|
398
|
-
default.to_egg(self) if default else None,
|
|
399
|
-
merge.to_egg(self) if merge else None,
|
|
400
|
-
merge_action,
|
|
401
|
-
cost,
|
|
402
|
-
unextractable,
|
|
403
|
-
)
|
|
404
|
-
cmd = bindings.Function(egg_fn_decl)
|
|
405
|
-
self.add_cmd(egg_name, cmd)
|
|
406
|
-
|
|
407
|
-
def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None:
|
|
408
|
-
egg_name = egg_name or ref.generate_egg_name()
|
|
409
|
-
self.register_callable_ref(ref, egg_name)
|
|
410
|
-
self.set_constant_type(ref, type_ref)
|
|
411
|
-
egg_sort = self.register_sort(type_ref, False)
|
|
412
|
-
# self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref)))
|
|
413
|
-
# Use function decleration instead of constant b/c constants cannot be extracted
|
|
414
|
-
# https://github.com/egraphs-good/egglog/issues/334
|
|
415
|
-
fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort))
|
|
416
|
-
self.add_cmd(egg_name, bindings.Function(fn_decl))
|
|
417
|
-
|
|
418
|
-
def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
|
|
419
|
-
self._classes[class_].preserved_methods[method] = fn
|
|
204
|
+
# Make hashable so when traversing for pretty-fying we can know which rulesets we have already
|
|
205
|
+
# made into strings
|
|
206
|
+
def __hash__(self) -> int:
|
|
207
|
+
return hash((type(self), tuple(self.rules)))
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@dataclass(frozen=True)
|
|
211
|
+
class CombinedRulesetDecl:
|
|
212
|
+
rulesets: tuple[str, ...]
|
|
420
213
|
|
|
421
214
|
|
|
422
215
|
# Have two different types of type refs, one that can include vars recursively and one that cannot.
|
|
@@ -427,38 +220,18 @@ class JustTypeRef:
|
|
|
427
220
|
name: str
|
|
428
221
|
args: tuple[JustTypeRef, ...] = ()
|
|
429
222
|
|
|
430
|
-
def generate_egg_name(self) -> str:
|
|
431
|
-
"""
|
|
432
|
-
Generates an egg sort name for this type reference by linearizing the type.
|
|
433
|
-
"""
|
|
434
|
-
if not self.args:
|
|
435
|
-
return self.name
|
|
436
|
-
args = "_".join(a.generate_egg_name() for a in self.args)
|
|
437
|
-
return f"{self.name}_{args}"
|
|
438
|
-
|
|
439
223
|
def to_var(self) -> TypeRefWithVars:
|
|
440
224
|
return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args))
|
|
441
225
|
|
|
442
|
-
def
|
|
443
|
-
if
|
|
444
|
-
return self.name
|
|
445
|
-
|
|
446
|
-
return f"{self.name}[{args}]"
|
|
226
|
+
def __str__(self) -> str:
|
|
227
|
+
if self.args:
|
|
228
|
+
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
|
|
229
|
+
return self.name
|
|
447
230
|
|
|
448
|
-
def to_constant_function_decl(self) -> FunctionDecl:
|
|
449
|
-
"""
|
|
450
|
-
Create a function declaration for a constant function.
|
|
451
231
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
arg_types=(),
|
|
456
|
-
arg_names=(),
|
|
457
|
-
arg_defaults=(),
|
|
458
|
-
return_type=self.to_var(),
|
|
459
|
-
mutates_first_arg=False,
|
|
460
|
-
var_arg_type=None,
|
|
461
|
-
)
|
|
232
|
+
##
|
|
233
|
+
# Type references with vars
|
|
234
|
+
##
|
|
462
235
|
|
|
463
236
|
|
|
464
237
|
@dataclass(frozen=True)
|
|
@@ -473,7 +246,7 @@ class ClassTypeVarRef:
|
|
|
473
246
|
msg = "egglog does not support generic classes yet."
|
|
474
247
|
raise NotImplementedError(msg)
|
|
475
248
|
|
|
476
|
-
def
|
|
249
|
+
def __str__(self) -> str:
|
|
477
250
|
return self.name
|
|
478
251
|
|
|
479
252
|
|
|
@@ -485,30 +258,27 @@ class TypeRefWithVars:
|
|
|
485
258
|
def to_just(self) -> JustTypeRef:
|
|
486
259
|
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
|
|
487
260
|
|
|
488
|
-
def
|
|
489
|
-
if
|
|
490
|
-
return self.name
|
|
491
|
-
|
|
492
|
-
return f"{self.name}[{args}]"
|
|
261
|
+
def __str__(self) -> str:
|
|
262
|
+
if self.args:
|
|
263
|
+
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
|
|
264
|
+
return self.name
|
|
493
265
|
|
|
494
266
|
|
|
495
267
|
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
496
268
|
|
|
269
|
+
##
|
|
270
|
+
# Callables References
|
|
271
|
+
##
|
|
272
|
+
|
|
497
273
|
|
|
498
274
|
@dataclass(frozen=True)
|
|
499
275
|
class FunctionRef:
|
|
500
276
|
name: str
|
|
501
277
|
|
|
502
|
-
def generate_egg_name(self) -> str:
|
|
503
|
-
return self.name
|
|
504
|
-
|
|
505
|
-
def __str__(self) -> str:
|
|
506
|
-
return self.name
|
|
507
|
-
|
|
508
278
|
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
279
|
+
@dataclass(frozen=True)
|
|
280
|
+
class ConstantRef:
|
|
281
|
+
name: str
|
|
512
282
|
|
|
513
283
|
|
|
514
284
|
@dataclass(frozen=True)
|
|
@@ -516,123 +286,122 @@ class MethodRef:
|
|
|
516
286
|
class_name: str
|
|
517
287
|
method_name: str
|
|
518
288
|
|
|
519
|
-
def generate_egg_name(self) -> str:
|
|
520
|
-
return f"{self.class_name}_{self.method_name}"
|
|
521
|
-
|
|
522
|
-
def __str__(self) -> str: # noqa: PLR0911
|
|
523
|
-
match self.method_name:
|
|
524
|
-
case _ if self.method_name in UNARY_METHODS:
|
|
525
|
-
return f"{UNARY_METHODS[self.method_name]}{ARG}"
|
|
526
|
-
case _ if self.method_name in BINARY_METHODS:
|
|
527
|
-
return f"({ARG} {BINARY_METHODS[self.method_name]} {ARG})"
|
|
528
|
-
case "__getitem__":
|
|
529
|
-
return f"{ARG}[{ARG}]"
|
|
530
|
-
case "__call__":
|
|
531
|
-
return f"{ARG}({ARG})"
|
|
532
|
-
case "__delitem__":
|
|
533
|
-
return f"del {ARG}[{ARG}]"
|
|
534
|
-
case "__setitem__":
|
|
535
|
-
return f"{ARG}[{ARG}] = {ARG}"
|
|
536
|
-
return f"{ARG}.{self.method_name}"
|
|
537
|
-
|
|
538
289
|
|
|
539
290
|
@dataclass(frozen=True)
|
|
540
291
|
class ClassMethodRef:
|
|
541
292
|
class_name: str
|
|
542
293
|
method_name: str
|
|
543
294
|
|
|
544
|
-
def generate_egg_name(self) -> str:
|
|
545
|
-
return f"{self.class_name}_{self.method_name}"
|
|
546
295
|
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
296
|
+
@dataclass(frozen=True)
|
|
297
|
+
class ClassVariableRef:
|
|
298
|
+
class_name: str
|
|
299
|
+
var_name: str
|
|
551
300
|
|
|
552
301
|
|
|
553
302
|
@dataclass(frozen=True)
|
|
554
|
-
class
|
|
555
|
-
|
|
303
|
+
class PropertyRef:
|
|
304
|
+
class_name: str
|
|
305
|
+
property_name: str
|
|
556
306
|
|
|
557
|
-
def generate_egg_name(self) -> str:
|
|
558
|
-
return self.name
|
|
559
307
|
|
|
560
|
-
|
|
561
|
-
return self.name
|
|
308
|
+
CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
|
|
562
309
|
|
|
563
310
|
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
variable_name: str
|
|
311
|
+
##
|
|
312
|
+
# Callables
|
|
313
|
+
##
|
|
568
314
|
|
|
569
|
-
def generate_egg_name(self) -> str:
|
|
570
|
-
return f"{self.class_name}_{self.variable_name}"
|
|
571
315
|
|
|
572
|
-
|
|
573
|
-
|
|
316
|
+
@dataclass(frozen=True)
|
|
317
|
+
class RelationDecl:
|
|
318
|
+
arg_types: tuple[JustTypeRef, ...]
|
|
319
|
+
# List of defaults. None for any arg which doesn't have one.
|
|
320
|
+
arg_defaults: tuple[ExprDecl | None, ...]
|
|
321
|
+
egg_name: str | None
|
|
322
|
+
|
|
323
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
324
|
+
return FunctionDecl(
|
|
325
|
+
FunctionSignature(
|
|
326
|
+
arg_types=tuple(a.to_var() for a in self.arg_types),
|
|
327
|
+
arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
|
|
328
|
+
arg_defaults=self.arg_defaults,
|
|
329
|
+
return_type=TypeRefWithVars("Unit"),
|
|
330
|
+
),
|
|
331
|
+
egg_name=self.egg_name,
|
|
332
|
+
default=LitDecl(None),
|
|
333
|
+
)
|
|
574
334
|
|
|
575
335
|
|
|
576
336
|
@dataclass(frozen=True)
|
|
577
|
-
class
|
|
578
|
-
|
|
579
|
-
|
|
337
|
+
class ConstantDecl:
|
|
338
|
+
"""
|
|
339
|
+
Same as `(declare)` in egglog
|
|
340
|
+
"""
|
|
580
341
|
|
|
581
|
-
|
|
582
|
-
|
|
342
|
+
type_ref: JustTypeRef
|
|
343
|
+
egg_name: str | None = None
|
|
583
344
|
|
|
584
|
-
def
|
|
585
|
-
return
|
|
345
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
346
|
+
return FunctionDecl(
|
|
347
|
+
FunctionSignature(return_type=self.type_ref.to_var()),
|
|
348
|
+
egg_name=self.egg_name,
|
|
349
|
+
)
|
|
586
350
|
|
|
587
351
|
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
CallableRef: TypeAlias = ConstantCallableRef | FunctionCallableRef
|
|
352
|
+
# special cases for partial function creation and application, which cannot use the normal python rules
|
|
353
|
+
SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
|
|
591
354
|
|
|
592
355
|
|
|
593
356
|
@dataclass(frozen=True)
|
|
594
|
-
class
|
|
595
|
-
arg_types: tuple[TypeOrVarRef, ...]
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
arg_defaults: tuple[ExprDecl | None, ...]
|
|
599
|
-
|
|
600
|
-
|
|
357
|
+
class FunctionSignature:
|
|
358
|
+
arg_types: tuple[TypeOrVarRef, ...] = ()
|
|
359
|
+
arg_names: tuple[str, ...] = ()
|
|
360
|
+
# List of defaults. None for any arg which doesn't have one.
|
|
361
|
+
arg_defaults: tuple[ExprDecl | None, ...] = ()
|
|
362
|
+
# If None, then the first arg is mutated and returned
|
|
363
|
+
return_type: TypeOrVarRef | None = None
|
|
601
364
|
var_arg_type: TypeOrVarRef | None = None
|
|
602
365
|
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
Parameter.POSITIONAL_OR_KEYWORD,
|
|
614
|
-
default=transform_default(TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
|
|
615
|
-
)
|
|
616
|
-
for n, d, t in zip(arg_names, self.arg_defaults, self.arg_types, strict=True)
|
|
617
|
-
]
|
|
618
|
-
if self.var_arg_type is not None:
|
|
619
|
-
parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
|
|
620
|
-
return Signature(parameters)
|
|
366
|
+
@property
|
|
367
|
+
def semantic_return_type(self) -> TypeOrVarRef:
|
|
368
|
+
"""
|
|
369
|
+
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
370
|
+
"""
|
|
371
|
+
return self.return_type or self.arg_types[0]
|
|
372
|
+
|
|
373
|
+
@property
|
|
374
|
+
def mutates(self) -> bool:
|
|
375
|
+
return self.return_type is None
|
|
621
376
|
|
|
622
377
|
|
|
623
378
|
@dataclass(frozen=True)
|
|
624
|
-
class
|
|
625
|
-
|
|
379
|
+
class FunctionDecl:
|
|
380
|
+
signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
|
|
381
|
+
|
|
382
|
+
# Egg params
|
|
383
|
+
builtin: bool = False
|
|
384
|
+
egg_name: str | None = None
|
|
385
|
+
cost: int | None = None
|
|
386
|
+
default: ExprDecl | None = None
|
|
387
|
+
on_merge: tuple[ActionDecl, ...] = ()
|
|
388
|
+
merge: ExprDecl | None = None
|
|
389
|
+
unextractable: bool = False
|
|
390
|
+
|
|
391
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
392
|
+
return self
|
|
626
393
|
|
|
627
|
-
@classmethod
|
|
628
|
-
def from_egg(cls, var: bindings.TermVar) -> ExprDecl:
|
|
629
|
-
return cls(var.name)
|
|
630
394
|
|
|
631
|
-
|
|
632
|
-
return bindings.Var(self.name)
|
|
395
|
+
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
633
396
|
|
|
634
|
-
|
|
635
|
-
|
|
397
|
+
##
|
|
398
|
+
# Expressions
|
|
399
|
+
##
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
@dataclass(frozen=True)
|
|
403
|
+
class VarDecl:
|
|
404
|
+
name: str
|
|
636
405
|
|
|
637
406
|
|
|
638
407
|
@dataclass(frozen=True)
|
|
@@ -646,16 +415,14 @@ class PyObjectDecl:
|
|
|
646
415
|
except TypeError:
|
|
647
416
|
return id(self.value)
|
|
648
417
|
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
return
|
|
653
|
-
|
|
654
|
-
def to_egg(self, _decls: Declarations) -> bindings._Expr:
|
|
655
|
-
return GLOBAL_PY_OBJECT_SORT.store(self.value)
|
|
418
|
+
def __eq__(self, other: object) -> bool:
|
|
419
|
+
if not isinstance(other, PyObjectDecl):
|
|
420
|
+
return False
|
|
421
|
+
return self.parts == other.parts
|
|
656
422
|
|
|
657
|
-
|
|
658
|
-
|
|
423
|
+
@property
|
|
424
|
+
def parts(self) -> tuple[type, object]:
|
|
425
|
+
return (type(self.value), self.value)
|
|
659
426
|
|
|
660
427
|
|
|
661
428
|
LitType: TypeAlias = int | str | float | bool | None
|
|
@@ -665,53 +432,30 @@ LitType: TypeAlias = int | str | float | bool | None
|
|
|
665
432
|
class LitDecl:
|
|
666
433
|
value: LitType
|
|
667
434
|
|
|
668
|
-
|
|
669
|
-
def from_egg(cls, lit: bindings.TermLit) -> ExprDecl:
|
|
670
|
-
value = lit.value
|
|
671
|
-
if isinstance(value, bindings.Unit):
|
|
672
|
-
return cls(None)
|
|
673
|
-
return cls(value.value)
|
|
674
|
-
|
|
675
|
-
def to_egg(self, _decls: Declarations) -> bindings.Lit:
|
|
676
|
-
if self.value is None:
|
|
677
|
-
return bindings.Lit(bindings.Unit())
|
|
678
|
-
if isinstance(self.value, bool):
|
|
679
|
-
return bindings.Lit(bindings.Bool(self.value))
|
|
680
|
-
if isinstance(self.value, int):
|
|
681
|
-
return bindings.Lit(bindings.Int(self.value))
|
|
682
|
-
if isinstance(self.value, float):
|
|
683
|
-
return bindings.Lit(bindings.F64(self.value))
|
|
684
|
-
if isinstance(self.value, str):
|
|
685
|
-
return bindings.Lit(bindings.String(self.value))
|
|
686
|
-
assert_never(self.value)
|
|
687
|
-
|
|
688
|
-
def pretty(self, context: PrettyContext, unwrap_lit: bool = True, **kwargs) -> str:
|
|
435
|
+
def __hash__(self) -> int:
|
|
689
436
|
"""
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
:param wrap_lit: If True, wraps the literal in a call to the literal constructor.
|
|
437
|
+
Include type in has so that 1.0 != 1
|
|
693
438
|
"""
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
assert_never(self.value)
|
|
439
|
+
return hash(self.parts)
|
|
440
|
+
|
|
441
|
+
def __eq__(self, other: object) -> bool:
|
|
442
|
+
if not isinstance(other, LitDecl):
|
|
443
|
+
return False
|
|
444
|
+
return self.parts == other.parts
|
|
445
|
+
|
|
446
|
+
@property
|
|
447
|
+
def parts(self) -> tuple[type, LitType]:
|
|
448
|
+
return (type(self.value), self.value)
|
|
705
449
|
|
|
706
450
|
|
|
707
451
|
@dataclass(frozen=True)
|
|
708
452
|
class CallDecl:
|
|
709
453
|
callable: CallableRef
|
|
454
|
+
# TODO: Can I make these not typed expressions?
|
|
710
455
|
args: tuple[TypedExprDecl, ...] = ()
|
|
711
456
|
# type parameters that were bound to the callable, if it is a classmethod
|
|
712
457
|
# Used for pretty printing classmethod calls with type parameters
|
|
713
458
|
bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
|
714
|
-
_cached_hash: int | None = None
|
|
715
459
|
|
|
716
460
|
def __post_init__(self) -> None:
|
|
717
461
|
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
|
|
@@ -719,252 +463,33 @@ class CallDecl:
|
|
|
719
463
|
raise ValueError(msg)
|
|
720
464
|
|
|
721
465
|
def __hash__(self) -> int:
|
|
722
|
-
# Modified hash which will cache result for performance
|
|
723
|
-
if self._cached_hash is None:
|
|
724
|
-
res = hash((self.callable, self.args, self.bound_tp_params))
|
|
725
|
-
object.__setattr__(self, "_cached_hash", res)
|
|
726
|
-
return res
|
|
727
466
|
return self._cached_hash
|
|
728
467
|
|
|
468
|
+
@cached_property
|
|
469
|
+
def _cached_hash(self) -> int:
|
|
470
|
+
return hash((self.callable, self.args, self.bound_tp_params))
|
|
471
|
+
|
|
729
472
|
def __eq__(self, other: object) -> bool:
|
|
730
473
|
# Override eq to use cached hash for perf
|
|
731
474
|
if not isinstance(other, CallDecl):
|
|
732
475
|
return False
|
|
733
476
|
return hash(self) == hash(other)
|
|
734
477
|
|
|
735
|
-
@classmethod
|
|
736
|
-
def from_egg(
|
|
737
|
-
cls,
|
|
738
|
-
egraph: bindings.EGraph,
|
|
739
|
-
decls: Declarations,
|
|
740
|
-
return_tp: JustTypeRef,
|
|
741
|
-
termdag: bindings.TermDag,
|
|
742
|
-
term: bindings.TermApp,
|
|
743
|
-
cache: dict[int, TypedExprDecl],
|
|
744
|
-
) -> ExprDecl:
|
|
745
|
-
"""
|
|
746
|
-
Convert an egg expression into a typed expression by using the declerations.
|
|
747
478
|
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
"""
|
|
751
|
-
from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
752
|
-
|
|
753
|
-
# Find the first callable ref that matches the call
|
|
754
|
-
for callable_ref in decls.get_callable_refs(term.name):
|
|
755
|
-
# If this is a classmethod, we might need the type params that were bound for this type
|
|
756
|
-
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
757
|
-
possible_types: Iterable[JustTypeRef | None]
|
|
758
|
-
fn_decl = decls.get_function_decl(callable_ref)
|
|
759
|
-
if isinstance(callable_ref, ClassMethodRef):
|
|
760
|
-
possible_types = decls.get_possible_types(callable_ref.class_name)
|
|
761
|
-
cls_name = callable_ref.class_name
|
|
762
|
-
else:
|
|
763
|
-
possible_types = [None]
|
|
764
|
-
cls_name = None
|
|
765
|
-
for possible_type in possible_types:
|
|
766
|
-
tcs = TypeConstraintSolver(decls)
|
|
767
|
-
if possible_type and possible_type.args:
|
|
768
|
-
tcs.bind_class(possible_type)
|
|
769
|
-
|
|
770
|
-
try:
|
|
771
|
-
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
772
|
-
fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, return_tp, cls_name
|
|
773
|
-
)
|
|
774
|
-
except TypeConstraintError:
|
|
775
|
-
continue
|
|
776
|
-
args: list[TypedExprDecl] = []
|
|
777
|
-
for a, tp in zip(term.args, arg_types, strict=False):
|
|
778
|
-
if a in cache:
|
|
779
|
-
res = cache[a]
|
|
780
|
-
else:
|
|
781
|
-
res = TypedExprDecl.from_egg(egraph, decls, tp, termdag, termdag.nodes[a], cache)
|
|
782
|
-
cache[a] = res
|
|
783
|
-
args.append(res)
|
|
784
|
-
return cls(callable_ref, tuple(args), bound_tp_params)
|
|
785
|
-
raise ValueError(f"Could not find callable ref for call {term}")
|
|
786
|
-
|
|
787
|
-
def to_egg(self, decls: Declarations) -> bindings._Expr:
|
|
788
|
-
"""Convert a Call to an egg Call."""
|
|
789
|
-
# This was removed when we replaced declerations constants with our b/c of unextractable constants
|
|
790
|
-
# # If this is a constant, then emit it just as a var, not as a call
|
|
791
|
-
# if isinstance(self.callable, ConstantRef | ClassVariableRef):
|
|
792
|
-
# decls.get_egg_fn
|
|
793
|
-
# return bindings.Var(egg_fn)
|
|
794
|
-
if hasattr(self, "_cached_egg"):
|
|
795
|
-
return self._cached_egg
|
|
796
|
-
egg_fn = decls.get_egg_fn(self.callable)
|
|
797
|
-
res = bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args])
|
|
798
|
-
object.__setattr__(self, "_cached_egg", res)
|
|
799
|
-
return res
|
|
800
|
-
|
|
801
|
-
def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901
|
|
802
|
-
"""
|
|
803
|
-
Pretty print the call.
|
|
804
|
-
|
|
805
|
-
:param parens: If true, wrap the call in parens if it is a binary method call.
|
|
806
|
-
"""
|
|
807
|
-
if self in context.names:
|
|
808
|
-
return context.names[self]
|
|
809
|
-
ref, args = self.callable, [a.expr for a in self.args]
|
|
810
|
-
# Special case !=
|
|
811
|
-
if ref == FunctionRef("!="):
|
|
812
|
-
return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})"
|
|
813
|
-
function_decl = context.decls.get_function_decl(ref)
|
|
814
|
-
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
|
|
815
|
-
n_defaults = 0
|
|
816
|
-
for arg, default in zip(
|
|
817
|
-
reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
|
|
818
|
-
):
|
|
819
|
-
if arg != default:
|
|
820
|
-
break
|
|
821
|
-
n_defaults += 1
|
|
822
|
-
if n_defaults:
|
|
823
|
-
args = args[:-n_defaults]
|
|
824
|
-
if function_decl.mutates_first_arg:
|
|
825
|
-
first_arg = args[0]
|
|
826
|
-
expr_str = first_arg.pretty(context, parens=False)
|
|
827
|
-
# copy an identifer expression iff it has multiple parents (b/c then we can't mutate it directly)
|
|
828
|
-
has_multiple_parents = context.parents[first_arg] > 1
|
|
829
|
-
expr_name = context.name_expr(function_decl.arg_types[0], expr_str, copy_identifier=has_multiple_parents)
|
|
830
|
-
# Set the first arg to be the name of the mutated arg and return the name
|
|
831
|
-
args[0] = VarDecl(expr_name)
|
|
832
|
-
else:
|
|
833
|
-
expr_name = None
|
|
834
|
-
match ref:
|
|
835
|
-
case FunctionRef(name):
|
|
836
|
-
expr = _pretty_call(context, name, args)
|
|
837
|
-
case ClassMethodRef(class_name, method_name):
|
|
838
|
-
tp_ref = JustTypeRef(class_name, self.bound_tp_params or ())
|
|
839
|
-
fn_str = tp_ref.pretty() if method_name == "__init__" else f"{tp_ref.pretty()}.{method_name}"
|
|
840
|
-
expr = _pretty_call(context, fn_str, args)
|
|
841
|
-
case MethodRef(_class_name, method_name):
|
|
842
|
-
slf, *args = args
|
|
843
|
-
slf = slf.pretty(context, unwrap_lit=False)
|
|
844
|
-
match method_name:
|
|
845
|
-
case _ if method_name in UNARY_METHODS:
|
|
846
|
-
expr = f"{UNARY_METHODS[method_name]}{slf}"
|
|
847
|
-
case _ if method_name in BINARY_METHODS:
|
|
848
|
-
assert len(args) == 1
|
|
849
|
-
expr = f"{slf} {BINARY_METHODS[method_name]} {args[0].pretty(context)}"
|
|
850
|
-
if parens:
|
|
851
|
-
expr = f"({expr})"
|
|
852
|
-
case "__getitem__":
|
|
853
|
-
assert len(args) == 1
|
|
854
|
-
expr = f"{slf}[{args[0].pretty(context, parens=False)}]"
|
|
855
|
-
case "__call__":
|
|
856
|
-
expr = _pretty_call(context, slf, args)
|
|
857
|
-
case "__delitem__":
|
|
858
|
-
assert len(args) == 1
|
|
859
|
-
expr = f"del {slf}[{args[0].pretty(context, parens=False)}]"
|
|
860
|
-
case "__setitem__":
|
|
861
|
-
assert len(args) == 2
|
|
862
|
-
expr = (
|
|
863
|
-
f"{slf}[{args[0].pretty(context, parens=False)}] = {args[1].pretty(context, parens=False)}"
|
|
864
|
-
)
|
|
865
|
-
case _:
|
|
866
|
-
expr = _pretty_call(context, f"{slf}.{method_name}", args)
|
|
867
|
-
case ConstantRef(name):
|
|
868
|
-
expr = name
|
|
869
|
-
case ClassVariableRef(class_name, variable_name):
|
|
870
|
-
expr = f"{class_name}.{variable_name}"
|
|
871
|
-
case PropertyRef(_class_name, property_name):
|
|
872
|
-
expr = f"{args[0].pretty(context)}.{property_name}"
|
|
873
|
-
case _:
|
|
874
|
-
assert_never(ref)
|
|
875
|
-
# If we have a name, then we mutated
|
|
876
|
-
if expr_name:
|
|
877
|
-
context.statements.append(expr)
|
|
878
|
-
context.names[self] = expr_name
|
|
879
|
-
return expr_name
|
|
880
|
-
|
|
881
|
-
# We use a heuristic to decide whether to name this sub-expression as a variable
|
|
882
|
-
# The rough goal is to reduce the number of newlines, given our line length of ~180
|
|
883
|
-
# We determine it's worth making a new line for this expression if the total characters
|
|
884
|
-
# it would take up is > than some constant (~ line length).
|
|
885
|
-
n_parents = context.parents[self]
|
|
886
|
-
line_diff: int = len(expr) - LINE_DIFFERENCE
|
|
887
|
-
if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
|
|
888
|
-
expr_name = context.name_expr(function_decl.return_type, expr, copy_identifier=False)
|
|
889
|
-
context.names[self] = expr_name
|
|
890
|
-
return expr_name
|
|
891
|
-
return expr
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
MAX_LINE_LENGTH = 110
|
|
895
|
-
LINE_DIFFERENCE = 10
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
def _plot_line_length(expr: object):
|
|
899
|
-
"""
|
|
900
|
-
Plots the number of line lengths based on different max lengths
|
|
479
|
+
@dataclass(frozen=True)
|
|
480
|
+
class PartialCallDecl:
|
|
901
481
|
"""
|
|
902
|
-
|
|
903
|
-
import altair as alt
|
|
904
|
-
import pandas as pd
|
|
482
|
+
A partially applied function aka a function sort.
|
|
905
483
|
|
|
906
|
-
|
|
907
|
-
for line_length in range(40, 180, 10):
|
|
908
|
-
MAX_LINE_LENGTH = line_length
|
|
909
|
-
for diff in range(0, 40, 5):
|
|
910
|
-
LINE_DIFFERENCE = diff
|
|
911
|
-
new_l = len(str(expr).split())
|
|
912
|
-
sizes.append((line_length, diff, new_l))
|
|
913
|
-
|
|
914
|
-
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901
|
|
915
|
-
|
|
916
|
-
return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
def _pretty_call(context: PrettyContext, fn: str, args: Iterable[ExprDecl]) -> str:
|
|
920
|
-
return f"{fn}({', '.join(a.pretty(context, parens=False) for a in args)})"
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
@dataclass
|
|
924
|
-
class PrettyContext:
|
|
925
|
-
decls: Declarations
|
|
926
|
-
# List of statements of "context" setting variable for the expr
|
|
927
|
-
statements: list[str] = field(default_factory=list)
|
|
928
|
-
|
|
929
|
-
names: dict[ExprDecl, str] = field(default_factory=dict)
|
|
930
|
-
parents: dict[ExprDecl, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
|
931
|
-
_traversed_exprs: set[ExprDecl] = field(default_factory=set)
|
|
932
|
-
|
|
933
|
-
# Mapping of type to the number of times we have generated a name for that type, used to generate unique names
|
|
934
|
-
_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
|
935
|
-
|
|
936
|
-
def generate_name(self, typ: str) -> str:
|
|
937
|
-
self._gen_name_types[typ] += 1
|
|
938
|
-
return f"_{typ}_{self._gen_name_types[typ]}"
|
|
939
|
-
|
|
940
|
-
def name_expr(self, expr_type: TypeOrVarRef, expr_str: str, copy_identifier: bool) -> str:
|
|
941
|
-
tp_name = expr_type.to_just().name
|
|
942
|
-
# If the thing we are naming is already a variable, we don't need to name it
|
|
943
|
-
if expr_str.isidentifier():
|
|
944
|
-
if copy_identifier:
|
|
945
|
-
name = self.generate_name(tp_name)
|
|
946
|
-
self.statements.append(f"{name} = copy({expr_str})")
|
|
947
|
-
else:
|
|
948
|
-
name = expr_str
|
|
949
|
-
else:
|
|
950
|
-
name = self.generate_name(tp_name)
|
|
951
|
-
self.statements.append(f"{name} = {expr_str}")
|
|
952
|
-
return name
|
|
484
|
+
Note it does not need to have any args, in which case it's just a function pointer.
|
|
953
485
|
|
|
954
|
-
|
|
955
|
-
|
|
486
|
+
Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
|
|
487
|
+
"""
|
|
956
488
|
|
|
957
|
-
|
|
958
|
-
if expr in self._traversed_exprs:
|
|
959
|
-
return
|
|
960
|
-
self._traversed_exprs.add(expr)
|
|
961
|
-
if isinstance(expr, CallDecl):
|
|
962
|
-
for arg in set(expr.args):
|
|
963
|
-
self.parents[arg.expr] += 1
|
|
964
|
-
self.traverse_for_parents(arg.expr)
|
|
489
|
+
call: CallDecl
|
|
965
490
|
|
|
966
491
|
|
|
967
|
-
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
|
|
492
|
+
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
|
|
968
493
|
|
|
969
494
|
|
|
970
495
|
@dataclass(frozen=True)
|
|
@@ -972,33 +497,6 @@ class TypedExprDecl:
|
|
|
972
497
|
tp: JustTypeRef
|
|
973
498
|
expr: ExprDecl
|
|
974
499
|
|
|
975
|
-
@classmethod
|
|
976
|
-
def from_egg(
|
|
977
|
-
cls,
|
|
978
|
-
egraph: bindings.EGraph,
|
|
979
|
-
decls: Declarations,
|
|
980
|
-
tp: JustTypeRef,
|
|
981
|
-
termdag: bindings.TermDag,
|
|
982
|
-
term: bindings._Term,
|
|
983
|
-
cache: dict[int, TypedExprDecl],
|
|
984
|
-
) -> TypedExprDecl:
|
|
985
|
-
expr_decl: ExprDecl
|
|
986
|
-
if isinstance(term, bindings.TermVar):
|
|
987
|
-
expr_decl = VarDecl.from_egg(term)
|
|
988
|
-
elif isinstance(term, bindings.TermLit):
|
|
989
|
-
expr_decl = LitDecl.from_egg(term)
|
|
990
|
-
elif isinstance(term, bindings.TermApp):
|
|
991
|
-
if term.name == "py-object":
|
|
992
|
-
expr_decl = PyObjectDecl.from_egg(egraph, termdag, term)
|
|
993
|
-
else:
|
|
994
|
-
expr_decl = CallDecl.from_egg(egraph, decls, tp, termdag, term, cache)
|
|
995
|
-
else:
|
|
996
|
-
assert_never(term)
|
|
997
|
-
return cls(tp, expr_decl)
|
|
998
|
-
|
|
999
|
-
def to_egg(self, decls: Declarations) -> bindings._Expr:
|
|
1000
|
-
return self.expr.to_egg(decls)
|
|
1001
|
-
|
|
1002
500
|
def descendants(self) -> list[TypedExprDecl]:
|
|
1003
501
|
"""
|
|
1004
502
|
Returns a list of all the descendants of this expression.
|
|
@@ -1010,11 +508,133 @@ class TypedExprDecl:
|
|
|
1010
508
|
return l
|
|
1011
509
|
|
|
1012
510
|
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
511
|
+
##
|
|
512
|
+
# Schedules
|
|
513
|
+
##
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@dataclass(frozen=True)
|
|
517
|
+
class SaturateDecl:
|
|
518
|
+
schedule: ScheduleDecl
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
@dataclass(frozen=True)
|
|
522
|
+
class RepeatDecl:
|
|
523
|
+
schedule: ScheduleDecl
|
|
524
|
+
times: int
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@dataclass(frozen=True)
|
|
528
|
+
class SequenceDecl:
|
|
529
|
+
schedules: tuple[ScheduleDecl, ...]
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@dataclass(frozen=True)
|
|
533
|
+
class RunDecl:
|
|
534
|
+
ruleset: str
|
|
535
|
+
until: tuple[FactDecl, ...] | None
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
|
|
539
|
+
|
|
540
|
+
##
|
|
541
|
+
# Facts
|
|
542
|
+
##
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
@dataclass(frozen=True)
|
|
546
|
+
class EqDecl:
|
|
547
|
+
tp: JustTypeRef
|
|
548
|
+
exprs: tuple[ExprDecl, ...]
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
@dataclass(frozen=True)
|
|
552
|
+
class ExprFactDecl:
|
|
553
|
+
typed_expr: TypedExprDecl
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
FactDecl: TypeAlias = EqDecl | ExprFactDecl
|
|
557
|
+
|
|
558
|
+
##
|
|
559
|
+
# Actions
|
|
560
|
+
##
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
@dataclass(frozen=True)
|
|
564
|
+
class LetDecl:
|
|
565
|
+
name: str
|
|
566
|
+
typed_expr: TypedExprDecl
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
@dataclass(frozen=True)
|
|
570
|
+
class SetDecl:
|
|
571
|
+
tp: JustTypeRef
|
|
572
|
+
call: CallDecl
|
|
573
|
+
rhs: ExprDecl
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@dataclass(frozen=True)
|
|
577
|
+
class ExprActionDecl:
|
|
578
|
+
typed_expr: TypedExprDecl
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
@dataclass(frozen=True)
|
|
582
|
+
class ChangeDecl:
|
|
583
|
+
tp: JustTypeRef
|
|
584
|
+
call: CallDecl
|
|
585
|
+
change: Literal["delete", "subsume"]
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
@dataclass(frozen=True)
|
|
589
|
+
class UnionDecl:
|
|
590
|
+
tp: JustTypeRef
|
|
591
|
+
lhs: ExprDecl
|
|
592
|
+
rhs: ExprDecl
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
@dataclass(frozen=True)
|
|
596
|
+
class PanicDecl:
|
|
597
|
+
msg: str
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
##
|
|
604
|
+
# Commands
|
|
605
|
+
##
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
@dataclass(frozen=True)
|
|
609
|
+
class RewriteDecl:
|
|
610
|
+
tp: JustTypeRef
|
|
611
|
+
lhs: ExprDecl
|
|
612
|
+
rhs: ExprDecl
|
|
613
|
+
conditions: tuple[FactDecl, ...]
|
|
614
|
+
subsume: bool
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@dataclass(frozen=True)
|
|
618
|
+
class BiRewriteDecl:
|
|
619
|
+
tp: JustTypeRef
|
|
620
|
+
lhs: ExprDecl
|
|
621
|
+
rhs: ExprDecl
|
|
622
|
+
conditions: tuple[FactDecl, ...]
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
@dataclass(frozen=True)
|
|
626
|
+
class RuleDecl:
|
|
627
|
+
head: tuple[ActionDecl, ...]
|
|
628
|
+
body: tuple[FactDecl, ...]
|
|
629
|
+
name: str | None
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@dataclass(frozen=True)
|
|
636
|
+
class ActionCommandDecl:
|
|
637
|
+
action: ActionDecl
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl
|