egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
import re
|
|
7
8
|
from collections import defaultdict
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
10
|
from typing import TYPE_CHECKING, overload
|
|
10
|
-
from weakref import WeakKeyDictionary
|
|
11
11
|
|
|
12
12
|
from typing_extensions import assert_never
|
|
13
13
|
|
|
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Iterable
|
|
21
21
|
|
|
22
|
-
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"
|
|
22
|
+
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
|
|
23
23
|
|
|
24
24
|
# Create a global sort for python objects, so we can store them without an e-graph instance
|
|
25
25
|
# Needed when serializing commands to egg commands when creating modules
|
|
@@ -52,7 +52,7 @@ class EGraphState:
|
|
|
52
52
|
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
53
53
|
|
|
54
54
|
# Cache of egg expressions for converting to egg
|
|
55
|
-
expr_to_egg_cache:
|
|
55
|
+
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
56
56
|
|
|
57
57
|
def copy(self) -> EGraphState:
|
|
58
58
|
"""
|
|
@@ -71,15 +71,15 @@ class EGraphState:
|
|
|
71
71
|
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
72
72
|
match schedule:
|
|
73
73
|
case SaturateDecl(schedule):
|
|
74
|
-
return bindings.Saturate(self.schedule_to_egg(schedule))
|
|
74
|
+
return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
|
|
75
75
|
case RepeatDecl(schedule, times):
|
|
76
|
-
return bindings.Repeat(times, self.schedule_to_egg(schedule))
|
|
76
|
+
return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
|
|
77
77
|
case SequenceDecl(schedules):
|
|
78
|
-
return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
|
|
78
|
+
return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
|
|
79
79
|
case RunDecl(ruleset_name, until):
|
|
80
80
|
self.ruleset_to_egg(ruleset_name)
|
|
81
81
|
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
|
|
82
|
-
return bindings.Run(config)
|
|
82
|
+
return bindings.Run(bindings.DUMMY_SPAN, config)
|
|
83
83
|
case _:
|
|
84
84
|
assert_never(schedule)
|
|
85
85
|
|
|
@@ -116,6 +116,7 @@ class EGraphState:
|
|
|
116
116
|
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
117
117
|
self.type_ref_to_egg(tp)
|
|
118
118
|
rewrite = bindings.Rewrite(
|
|
119
|
+
bindings.DUMMY_SPAN,
|
|
119
120
|
self._expr_to_egg(lhs),
|
|
120
121
|
self._expr_to_egg(rhs),
|
|
121
122
|
[self.fact_to_egg(c) for c in conditions],
|
|
@@ -127,19 +128,24 @@ class EGraphState:
|
|
|
127
128
|
)
|
|
128
129
|
case RuleDecl(head, body, name):
|
|
129
130
|
rule = bindings.Rule(
|
|
131
|
+
bindings.DUMMY_SPAN,
|
|
130
132
|
[self.action_to_egg(a) for a in head],
|
|
131
133
|
[self.fact_to_egg(f) for f in body],
|
|
132
134
|
)
|
|
133
135
|
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
136
|
+
# TODO: Replace with just constants value and looking at REF of function
|
|
134
137
|
case DefaultRewriteDecl(ref, expr):
|
|
135
138
|
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
|
|
136
139
|
sig = decl.signature
|
|
137
140
|
assert isinstance(sig, FunctionSignature)
|
|
138
|
-
args
|
|
139
|
-
|
|
141
|
+
# Replace args with rule_var_name mapping
|
|
142
|
+
arg_mapping = tuple(
|
|
143
|
+
TypedExprDecl(tp.to_just(), VarDecl(name, False))
|
|
140
144
|
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
141
145
|
)
|
|
142
|
-
rewrite_decl = RewriteDecl(
|
|
146
|
+
rewrite_decl = RewriteDecl(
|
|
147
|
+
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
|
|
148
|
+
)
|
|
143
149
|
return self.command_to_egg(rewrite_decl, ruleset)
|
|
144
150
|
case _:
|
|
145
151
|
assert_never(cmd)
|
|
@@ -147,13 +153,16 @@ class EGraphState:
|
|
|
147
153
|
def action_to_egg(self, action: ActionDecl) -> bindings._Action:
|
|
148
154
|
match action:
|
|
149
155
|
case LetDecl(name, typed_expr):
|
|
150
|
-
|
|
156
|
+
var_decl = VarDecl(name, True)
|
|
157
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
158
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
159
|
+
return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
151
160
|
case SetDecl(tp, call, rhs):
|
|
152
161
|
self.type_ref_to_egg(tp)
|
|
153
162
|
call_ = self._expr_to_egg(call)
|
|
154
|
-
return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
|
|
163
|
+
return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
|
|
155
164
|
case ExprActionDecl(typed_expr):
|
|
156
|
-
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
165
|
+
return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
|
|
157
166
|
case ChangeDecl(tp, call, change):
|
|
158
167
|
self.type_ref_to_egg(tp)
|
|
159
168
|
call_ = self._expr_to_egg(call)
|
|
@@ -165,12 +174,12 @@ class EGraphState:
|
|
|
165
174
|
egg_change = bindings.Subsume()
|
|
166
175
|
case _:
|
|
167
176
|
assert_never(change)
|
|
168
|
-
return bindings.Change(egg_change, call_.name, call_.args)
|
|
177
|
+
return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
|
|
169
178
|
case UnionDecl(tp, lhs, rhs):
|
|
170
179
|
self.type_ref_to_egg(tp)
|
|
171
|
-
return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
180
|
+
return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
172
181
|
case PanicDecl(name):
|
|
173
|
-
return bindings.Panic(name)
|
|
182
|
+
return bindings.Panic(bindings.DUMMY_SPAN, name)
|
|
174
183
|
case _:
|
|
175
184
|
assert_never(action)
|
|
176
185
|
|
|
@@ -178,9 +187,9 @@ class EGraphState:
|
|
|
178
187
|
match fact:
|
|
179
188
|
case EqDecl(tp, exprs):
|
|
180
189
|
self.type_ref_to_egg(tp)
|
|
181
|
-
return bindings.Eq([self._expr_to_egg(e) for e in exprs])
|
|
190
|
+
return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
|
|
182
191
|
case ExprFactDecl(typed_expr):
|
|
183
|
-
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
192
|
+
return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
|
|
184
193
|
case _:
|
|
185
194
|
assert_never(fact)
|
|
186
195
|
|
|
@@ -191,22 +200,31 @@ class EGraphState:
|
|
|
191
200
|
if ref in self.callable_ref_to_egg_fn:
|
|
192
201
|
return self.callable_ref_to_egg_fn[ref]
|
|
193
202
|
decl = self.__egg_decls__.get_callable_decl(ref)
|
|
194
|
-
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or
|
|
203
|
+
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
|
|
204
|
+
self._generate_callable_egg_name(ref)
|
|
205
|
+
)
|
|
195
206
|
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
196
207
|
match decl:
|
|
197
208
|
case RelationDecl(arg_types, _, _):
|
|
198
|
-
self.egraph.run_program(
|
|
209
|
+
self.egraph.run_program(
|
|
210
|
+
bindings.Relation(bindings.DUMMY_SPAN, egg_name, [self.type_ref_to_egg(a) for a in arg_types])
|
|
211
|
+
)
|
|
199
212
|
case ConstantDecl(tp, _):
|
|
200
213
|
# Use function decleration instead of constant b/c constants cannot be extracted
|
|
201
214
|
# https://github.com/egraphs-good/egglog/issues/334
|
|
202
215
|
self.egraph.run_program(
|
|
203
|
-
bindings.Function(
|
|
216
|
+
bindings.Function(
|
|
217
|
+
bindings.FunctionDecl(
|
|
218
|
+
bindings.DUMMY_SPAN, egg_name, bindings.Schema([], self.type_ref_to_egg(tp))
|
|
219
|
+
)
|
|
220
|
+
)
|
|
204
221
|
)
|
|
205
222
|
case FunctionDecl():
|
|
206
223
|
if not decl.builtin:
|
|
207
224
|
signature = decl.signature
|
|
208
225
|
assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
|
|
209
226
|
egg_fn_decl = bindings.FunctionDecl(
|
|
227
|
+
bindings.DUMMY_SPAN,
|
|
210
228
|
egg_name,
|
|
211
229
|
bindings.Schema(
|
|
212
230
|
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
@@ -239,17 +257,18 @@ class EGraphState:
|
|
|
239
257
|
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
240
258
|
type_args: list[bindings._Expr] = [
|
|
241
259
|
bindings.Call(
|
|
260
|
+
bindings.DUMMY_SPAN,
|
|
242
261
|
self.type_ref_to_egg(ref.args[1]),
|
|
243
|
-
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
262
|
+
[bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
244
263
|
),
|
|
245
|
-
bindings.Var(self.type_ref_to_egg(ref.args[0])),
|
|
264
|
+
bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
|
|
246
265
|
]
|
|
247
266
|
else:
|
|
248
|
-
type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
|
|
267
|
+
type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
|
|
249
268
|
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
250
269
|
else:
|
|
251
270
|
args = None
|
|
252
|
-
self.egraph.run_program(bindings.Sort(egg_name, args))
|
|
271
|
+
self.egraph.run_program(bindings.Sort(bindings.DUMMY_SPAN, egg_name, args))
|
|
253
272
|
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
|
|
254
273
|
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
|
|
255
274
|
# even if you never use that function.
|
|
@@ -272,31 +291,60 @@ class EGraphState:
|
|
|
272
291
|
if len(v) == 1
|
|
273
292
|
}
|
|
274
293
|
|
|
275
|
-
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
294
|
+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
|
|
295
|
+
# transform all expressions with multiple parents into a let binding, so that less expressions
|
|
296
|
+
# are sent to egglog. Only for performance reasons.
|
|
297
|
+
if transform_let:
|
|
298
|
+
have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
|
|
299
|
+
for expr in reversed(have_multiple_parents):
|
|
300
|
+
self._transform_let(expr)
|
|
301
|
+
|
|
276
302
|
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
277
303
|
return self._expr_to_egg(typed_expr_decl.expr)
|
|
278
304
|
|
|
305
|
+
def _transform_let(self, typed_expr: TypedExprDecl) -> None:
|
|
306
|
+
"""
|
|
307
|
+
Rewrites this expression as a let binding if it's not already a let binding.
|
|
308
|
+
"""
|
|
309
|
+
var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
|
|
310
|
+
if var_decl in self.expr_to_egg_cache:
|
|
311
|
+
return
|
|
312
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
313
|
+
cmd = bindings.ActionCommand(
|
|
314
|
+
bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
315
|
+
)
|
|
316
|
+
try:
|
|
317
|
+
self.egraph.run_program(cmd)
|
|
318
|
+
# errors when creating let bindings for things like `(vec-empty)`
|
|
319
|
+
except bindings.EggSmolError:
|
|
320
|
+
return
|
|
321
|
+
self.expr_to_egg_cache[typed_expr.expr] = var_egg
|
|
322
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
323
|
+
|
|
279
324
|
@overload
|
|
280
325
|
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
281
326
|
|
|
327
|
+
@overload
|
|
328
|
+
def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
|
|
329
|
+
|
|
282
330
|
@overload
|
|
283
331
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
284
332
|
|
|
285
|
-
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
333
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
|
|
286
334
|
"""
|
|
287
335
|
Convert an ExprDecl to an egg expression.
|
|
288
|
-
|
|
289
|
-
Cached using weakrefs to avoid memory leaks.
|
|
290
336
|
"""
|
|
291
337
|
try:
|
|
292
338
|
return self.expr_to_egg_cache[expr_decl]
|
|
293
339
|
except KeyError:
|
|
294
340
|
pass
|
|
295
|
-
|
|
296
341
|
res: bindings._Expr
|
|
297
342
|
match expr_decl:
|
|
298
|
-
case VarDecl(name):
|
|
299
|
-
|
|
343
|
+
case VarDecl(name, is_let):
|
|
344
|
+
# prefix let bindings with % to avoid name conflicts with rewrites
|
|
345
|
+
if is_let:
|
|
346
|
+
name = f"%{name}"
|
|
347
|
+
res = bindings.Var(bindings.DUMMY_SPAN, name)
|
|
300
348
|
case LitDecl(value):
|
|
301
349
|
l: bindings._Literal
|
|
302
350
|
match value:
|
|
@@ -312,19 +360,22 @@ class EGraphState:
|
|
|
312
360
|
l = bindings.String(s)
|
|
313
361
|
case _:
|
|
314
362
|
assert_never(value)
|
|
315
|
-
res = bindings.Lit(l)
|
|
363
|
+
res = bindings.Lit(bindings.DUMMY_SPAN, l)
|
|
316
364
|
case CallDecl(ref, args, _):
|
|
317
365
|
egg_fn = self.callable_ref_to_egg(ref)
|
|
318
|
-
egg_args = [self.typed_expr_to_egg(a) for a in args]
|
|
319
|
-
res = bindings.Call(egg_fn, egg_args)
|
|
366
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in args]
|
|
367
|
+
res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
|
|
320
368
|
case PyObjectDecl(value):
|
|
321
369
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
322
370
|
case PartialCallDecl(call_decl):
|
|
323
371
|
egg_fn_call = self._expr_to_egg(call_decl)
|
|
324
|
-
res = bindings.Call(
|
|
372
|
+
res = bindings.Call(
|
|
373
|
+
bindings.DUMMY_SPAN,
|
|
374
|
+
"unstable-fn",
|
|
375
|
+
[bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
376
|
+
)
|
|
325
377
|
case _:
|
|
326
378
|
assert_never(expr_decl.expr)
|
|
327
|
-
|
|
328
379
|
self.expr_to_egg_cache[expr_decl] = res
|
|
329
380
|
return res
|
|
330
381
|
|
|
@@ -343,6 +394,65 @@ class EGraphState:
|
|
|
343
394
|
"""
|
|
344
395
|
return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
|
|
345
396
|
|
|
397
|
+
def _generate_callable_egg_name(self, ref: CallableRef) -> str:
|
|
398
|
+
"""
|
|
399
|
+
Generates a valid egg function name for a callable reference.
|
|
400
|
+
"""
|
|
401
|
+
match ref:
|
|
402
|
+
case FunctionRef(name):
|
|
403
|
+
return name
|
|
404
|
+
|
|
405
|
+
case ConstantRef(name):
|
|
406
|
+
return name
|
|
407
|
+
case (
|
|
408
|
+
MethodRef(cls_name, name)
|
|
409
|
+
| ClassMethodRef(cls_name, name)
|
|
410
|
+
| ClassVariableRef(cls_name, name)
|
|
411
|
+
| PropertyRef(cls_name, name)
|
|
412
|
+
):
|
|
413
|
+
return f"{cls_name}.{name}"
|
|
414
|
+
case InitRef(cls_name):
|
|
415
|
+
return f"{cls_name}.__init__"
|
|
416
|
+
case UnnamedFunctionRef(args, val):
|
|
417
|
+
parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
|
|
418
|
+
str(self.typed_expr_to_egg(val, False))
|
|
419
|
+
]
|
|
420
|
+
return "_".join(parts)
|
|
421
|
+
case _:
|
|
422
|
+
assert_never(ref)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
|
|
426
|
+
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _sanitize_egg_ident(input_string: str) -> str:
|
|
430
|
+
"""
|
|
431
|
+
Replaces all invalid characters in an egg identifier with an underscore.
|
|
432
|
+
"""
|
|
433
|
+
return _EGGLOG_INVALID_IDENT.sub("_", input_string)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
|
|
437
|
+
"""
|
|
438
|
+
Returns all expressions that have multiple parents (a list but semantically just an ordered set).
|
|
439
|
+
"""
|
|
440
|
+
to_traverse = {typed_expr}
|
|
441
|
+
traversed = set[TypedExprDecl]()
|
|
442
|
+
traversed_twice = list[TypedExprDecl]()
|
|
443
|
+
while to_traverse:
|
|
444
|
+
typed_expr = to_traverse.pop()
|
|
445
|
+
if typed_expr in traversed:
|
|
446
|
+
traversed_twice.append(typed_expr)
|
|
447
|
+
continue
|
|
448
|
+
traversed.add(typed_expr)
|
|
449
|
+
expr = typed_expr.expr
|
|
450
|
+
if isinstance(expr, CallDecl):
|
|
451
|
+
to_traverse.update(expr.args)
|
|
452
|
+
elif isinstance(expr, PartialCallDecl):
|
|
453
|
+
to_traverse.update(expr.call.args)
|
|
454
|
+
return traversed_twice
|
|
455
|
+
|
|
346
456
|
|
|
347
457
|
def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
348
458
|
"""
|
|
@@ -354,26 +464,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
|
354
464
|
return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
|
|
355
465
|
|
|
356
466
|
|
|
357
|
-
def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
358
|
-
"""
|
|
359
|
-
Generates a valid egg function name for a callable reference.
|
|
360
|
-
"""
|
|
361
|
-
match ref:
|
|
362
|
-
case FunctionRef(name) | ConstantRef(name):
|
|
363
|
-
return name
|
|
364
|
-
case (
|
|
365
|
-
MethodRef(cls_name, name)
|
|
366
|
-
| ClassMethodRef(cls_name, name)
|
|
367
|
-
| ClassVariableRef(cls_name, name)
|
|
368
|
-
| PropertyRef(cls_name, name)
|
|
369
|
-
):
|
|
370
|
-
return f"{cls_name}_{name}"
|
|
371
|
-
case InitRef(cls_name):
|
|
372
|
-
return f"{cls_name}___init__"
|
|
373
|
-
case _:
|
|
374
|
-
assert_never(ref)
|
|
375
|
-
|
|
376
|
-
|
|
377
467
|
@dataclass
|
|
378
468
|
class FromEggState:
|
|
379
469
|
"""
|
|
@@ -395,7 +485,7 @@ class FromEggState:
|
|
|
395
485
|
"""
|
|
396
486
|
expr_decl: ExprDecl
|
|
397
487
|
if isinstance(term, bindings.TermVar):
|
|
398
|
-
expr_decl = VarDecl(term.name)
|
|
488
|
+
expr_decl = VarDecl(term.name, True)
|
|
399
489
|
elif isinstance(term, bindings.TermLit):
|
|
400
490
|
value = term.value
|
|
401
491
|
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
@@ -403,7 +493,7 @@ class FromEggState:
|
|
|
403
493
|
if term.name == "py-object":
|
|
404
494
|
call = bindings.termdag_term_to_expr(self.termdag, term)
|
|
405
495
|
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
|
|
406
|
-
|
|
496
|
+
elif term.name == "unstable-fn":
|
|
407
497
|
# Get function name
|
|
408
498
|
fn_term, *arg_terms = term.args
|
|
409
499
|
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
|
|
@@ -478,10 +568,3 @@ class FromEggState:
|
|
|
478
568
|
except KeyError:
|
|
479
569
|
res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
|
|
480
570
|
return res
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
def _rule_var_name(s: str) -> str:
|
|
484
|
-
"""
|
|
485
|
-
Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
|
|
486
|
-
"""
|
|
487
|
-
return f"__var__{s}"
|
|
@@ -16,35 +16,30 @@ if TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
class Math(Expr):
|
|
18
18
|
def __init__(self, i: i64Like) -> None: ...
|
|
19
|
-
|
|
20
19
|
def __add__(self, other: Math) -> Math: ...
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class MathList(Expr):
|
|
24
23
|
def __init__(self) -> None: ...
|
|
25
|
-
|
|
26
24
|
def append(self, i: Math) -> MathList: ...
|
|
27
|
-
|
|
28
25
|
def map(self, f: Callable[[Math], Math]) -> MathList: ...
|
|
29
26
|
|
|
30
27
|
|
|
31
28
|
@ruleset
|
|
32
|
-
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
|
|
29
|
+
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
|
|
33
30
|
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
|
|
34
31
|
yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
|
|
35
32
|
yield rewrite(MathList().map(f)).to(MathList())
|
|
36
33
|
|
|
37
34
|
|
|
38
35
|
@function(ruleset=math_ruleset)
|
|
39
|
-
def
|
|
40
|
-
return x + Math(1)
|
|
36
|
+
def incr_list(xs: MathList) -> MathList:
|
|
37
|
+
return xs.map(lambda x: x + Math(1))
|
|
41
38
|
|
|
42
39
|
|
|
43
40
|
egraph = EGraph()
|
|
44
|
-
|
|
45
|
-
y = egraph.let("y", x.map(increment_by_one))
|
|
41
|
+
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
|
|
46
42
|
egraph.run(math_ruleset.saturate())
|
|
47
|
-
|
|
48
43
|
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
|
|
49
44
|
|
|
50
45
|
egraph
|