egglog-7.2.0-cp311-none-win_amd64.whl → egglog-8.0.0-cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog has been flagged as potentially problematic; see the package's registry page for details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +83 -4
- egglog/egraph.py +241 -173
- egglog/egraph_state.py +137 -61
- egglog/examples/higher_order_functions.py +3 -8
- egglog/exp/array_api.py +274 -92
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +84 -40
- egglog/runtime.py +52 -39
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/METADATA +33 -32
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
import re
|
|
7
8
|
from collections import defaultdict
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
10
|
from typing import TYPE_CHECKING, overload
|
|
10
|
-
from weakref import WeakKeyDictionary
|
|
11
11
|
|
|
12
12
|
from typing_extensions import assert_never
|
|
13
13
|
|
|
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Iterable
|
|
21
21
|
|
|
22
|
-
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"
|
|
22
|
+
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
|
|
23
23
|
|
|
24
24
|
# Create a global sort for python objects, so we can store them without an e-graph instance
|
|
25
25
|
# Needed when serializing commands to egg commands when creating modules
|
|
@@ -52,7 +52,7 @@ class EGraphState:
|
|
|
52
52
|
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
53
53
|
|
|
54
54
|
# Cache of egg expressions for converting to egg
|
|
55
|
-
expr_to_egg_cache:
|
|
55
|
+
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
56
56
|
|
|
57
57
|
def copy(self) -> EGraphState:
|
|
58
58
|
"""
|
|
@@ -71,15 +71,15 @@ class EGraphState:
|
|
|
71
71
|
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
72
72
|
match schedule:
|
|
73
73
|
case SaturateDecl(schedule):
|
|
74
|
-
return bindings.Saturate(self.schedule_to_egg(schedule))
|
|
74
|
+
return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
|
|
75
75
|
case RepeatDecl(schedule, times):
|
|
76
|
-
return bindings.Repeat(times, self.schedule_to_egg(schedule))
|
|
76
|
+
return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
|
|
77
77
|
case SequenceDecl(schedules):
|
|
78
|
-
return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
|
|
78
|
+
return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
|
|
79
79
|
case RunDecl(ruleset_name, until):
|
|
80
80
|
self.ruleset_to_egg(ruleset_name)
|
|
81
81
|
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
|
|
82
|
-
return bindings.Run(config)
|
|
82
|
+
return bindings.Run(bindings.DUMMY_SPAN, config)
|
|
83
83
|
case _:
|
|
84
84
|
assert_never(schedule)
|
|
85
85
|
|
|
@@ -116,6 +116,7 @@ class EGraphState:
|
|
|
116
116
|
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
117
117
|
self.type_ref_to_egg(tp)
|
|
118
118
|
rewrite = bindings.Rewrite(
|
|
119
|
+
bindings.DUMMY_SPAN,
|
|
119
120
|
self._expr_to_egg(lhs),
|
|
120
121
|
self._expr_to_egg(rhs),
|
|
121
122
|
[self.fact_to_egg(c) for c in conditions],
|
|
@@ -127,19 +128,24 @@ class EGraphState:
|
|
|
127
128
|
)
|
|
128
129
|
case RuleDecl(head, body, name):
|
|
129
130
|
rule = bindings.Rule(
|
|
131
|
+
bindings.DUMMY_SPAN,
|
|
130
132
|
[self.action_to_egg(a) for a in head],
|
|
131
133
|
[self.fact_to_egg(f) for f in body],
|
|
132
134
|
)
|
|
133
135
|
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
136
|
+
# TODO: Replace with just constants value and looking at REF of function
|
|
134
137
|
case DefaultRewriteDecl(ref, expr):
|
|
135
138
|
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
|
|
136
139
|
sig = decl.signature
|
|
137
140
|
assert isinstance(sig, FunctionSignature)
|
|
138
|
-
args
|
|
139
|
-
|
|
141
|
+
# Replace args with rule_var_name mapping
|
|
142
|
+
arg_mapping = tuple(
|
|
143
|
+
TypedExprDecl(tp.to_just(), VarDecl(name, False))
|
|
140
144
|
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
141
145
|
)
|
|
142
|
-
rewrite_decl = RewriteDecl(
|
|
146
|
+
rewrite_decl = RewriteDecl(
|
|
147
|
+
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
|
|
148
|
+
)
|
|
143
149
|
return self.command_to_egg(rewrite_decl, ruleset)
|
|
144
150
|
case _:
|
|
145
151
|
assert_never(cmd)
|
|
@@ -147,13 +153,16 @@ class EGraphState:
|
|
|
147
153
|
def action_to_egg(self, action: ActionDecl) -> bindings._Action:
|
|
148
154
|
match action:
|
|
149
155
|
case LetDecl(name, typed_expr):
|
|
150
|
-
|
|
156
|
+
var_decl = VarDecl(name, True)
|
|
157
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
158
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
159
|
+
return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
151
160
|
case SetDecl(tp, call, rhs):
|
|
152
161
|
self.type_ref_to_egg(tp)
|
|
153
162
|
call_ = self._expr_to_egg(call)
|
|
154
|
-
return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
|
|
163
|
+
return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
|
|
155
164
|
case ExprActionDecl(typed_expr):
|
|
156
|
-
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
165
|
+
return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
|
|
157
166
|
case ChangeDecl(tp, call, change):
|
|
158
167
|
self.type_ref_to_egg(tp)
|
|
159
168
|
call_ = self._expr_to_egg(call)
|
|
@@ -165,12 +174,12 @@ class EGraphState:
|
|
|
165
174
|
egg_change = bindings.Subsume()
|
|
166
175
|
case _:
|
|
167
176
|
assert_never(change)
|
|
168
|
-
return bindings.Change(egg_change, call_.name, call_.args)
|
|
177
|
+
return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
|
|
169
178
|
case UnionDecl(tp, lhs, rhs):
|
|
170
179
|
self.type_ref_to_egg(tp)
|
|
171
|
-
return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
180
|
+
return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
172
181
|
case PanicDecl(name):
|
|
173
|
-
return bindings.Panic(name)
|
|
182
|
+
return bindings.Panic(bindings.DUMMY_SPAN, name)
|
|
174
183
|
case _:
|
|
175
184
|
assert_never(action)
|
|
176
185
|
|
|
@@ -178,9 +187,9 @@ class EGraphState:
|
|
|
178
187
|
match fact:
|
|
179
188
|
case EqDecl(tp, exprs):
|
|
180
189
|
self.type_ref_to_egg(tp)
|
|
181
|
-
return bindings.Eq([self._expr_to_egg(e) for e in exprs])
|
|
190
|
+
return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
|
|
182
191
|
case ExprFactDecl(typed_expr):
|
|
183
|
-
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
192
|
+
return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
|
|
184
193
|
case _:
|
|
185
194
|
assert_never(fact)
|
|
186
195
|
|
|
@@ -191,7 +200,9 @@ class EGraphState:
|
|
|
191
200
|
if ref in self.callable_ref_to_egg_fn:
|
|
192
201
|
return self.callable_ref_to_egg_fn[ref]
|
|
193
202
|
decl = self.__egg_decls__.get_callable_decl(ref)
|
|
194
|
-
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or
|
|
203
|
+
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
|
|
204
|
+
self._generate_callable_egg_name(ref)
|
|
205
|
+
)
|
|
195
206
|
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
196
207
|
match decl:
|
|
197
208
|
case RelationDecl(arg_types, _, _):
|
|
@@ -239,13 +250,14 @@ class EGraphState:
|
|
|
239
250
|
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
240
251
|
type_args: list[bindings._Expr] = [
|
|
241
252
|
bindings.Call(
|
|
253
|
+
bindings.DUMMY_SPAN,
|
|
242
254
|
self.type_ref_to_egg(ref.args[1]),
|
|
243
|
-
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
255
|
+
[bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
244
256
|
),
|
|
245
|
-
bindings.Var(self.type_ref_to_egg(ref.args[0])),
|
|
257
|
+
bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
|
|
246
258
|
]
|
|
247
259
|
else:
|
|
248
|
-
type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
|
|
260
|
+
type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
|
|
249
261
|
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
250
262
|
else:
|
|
251
263
|
args = None
|
|
@@ -272,31 +284,60 @@ class EGraphState:
|
|
|
272
284
|
if len(v) == 1
|
|
273
285
|
}
|
|
274
286
|
|
|
275
|
-
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
287
|
+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
|
|
288
|
+
# transform all expressions with multiple parents into a let binding, so that less expressions
|
|
289
|
+
# are sent to egglog. Only for performance reasons.
|
|
290
|
+
if transform_let:
|
|
291
|
+
have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
|
|
292
|
+
for expr in reversed(have_multiple_parents):
|
|
293
|
+
self._transform_let(expr)
|
|
294
|
+
|
|
276
295
|
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
277
296
|
return self._expr_to_egg(typed_expr_decl.expr)
|
|
278
297
|
|
|
298
|
+
def _transform_let(self, typed_expr: TypedExprDecl) -> None:
|
|
299
|
+
"""
|
|
300
|
+
Rewrites this expression as a let binding if it's not already a let binding.
|
|
301
|
+
"""
|
|
302
|
+
var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
|
|
303
|
+
if var_decl in self.expr_to_egg_cache:
|
|
304
|
+
return
|
|
305
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
306
|
+
cmd = bindings.ActionCommand(
|
|
307
|
+
bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
308
|
+
)
|
|
309
|
+
try:
|
|
310
|
+
self.egraph.run_program(cmd)
|
|
311
|
+
# errors when creating let bindings for things like `(vec-empty)`
|
|
312
|
+
except bindings.EggSmolError:
|
|
313
|
+
return
|
|
314
|
+
self.expr_to_egg_cache[typed_expr.expr] = var_egg
|
|
315
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
316
|
+
|
|
279
317
|
@overload
|
|
280
318
|
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
281
319
|
|
|
320
|
+
@overload
|
|
321
|
+
def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
|
|
322
|
+
|
|
282
323
|
@overload
|
|
283
324
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
284
325
|
|
|
285
326
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
286
327
|
"""
|
|
287
328
|
Convert an ExprDecl to an egg expression.
|
|
288
|
-
|
|
289
|
-
Cached using weakrefs to avoid memory leaks.
|
|
290
329
|
"""
|
|
291
330
|
try:
|
|
292
331
|
return self.expr_to_egg_cache[expr_decl]
|
|
293
332
|
except KeyError:
|
|
294
333
|
pass
|
|
295
|
-
|
|
296
334
|
res: bindings._Expr
|
|
297
335
|
match expr_decl:
|
|
298
|
-
case VarDecl(name):
|
|
299
|
-
|
|
336
|
+
case VarDecl(name, is_let):
|
|
337
|
+
# prefix let bindings with % to avoid name conflicts with rewrites
|
|
338
|
+
if is_let:
|
|
339
|
+
name = f"%{name}"
|
|
340
|
+
res = bindings.Var(bindings.DUMMY_SPAN, name)
|
|
300
341
|
case LitDecl(value):
|
|
301
342
|
l: bindings._Literal
|
|
302
343
|
match value:
|
|
@@ -312,19 +353,22 @@ class EGraphState:
|
|
|
312
353
|
l = bindings.String(s)
|
|
313
354
|
case _:
|
|
314
355
|
assert_never(value)
|
|
315
|
-
res = bindings.Lit(l)
|
|
356
|
+
res = bindings.Lit(bindings.DUMMY_SPAN, l)
|
|
316
357
|
case CallDecl(ref, args, _):
|
|
317
358
|
egg_fn = self.callable_ref_to_egg(ref)
|
|
318
|
-
egg_args = [self.typed_expr_to_egg(a) for a in args]
|
|
319
|
-
res = bindings.Call(egg_fn, egg_args)
|
|
359
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in args]
|
|
360
|
+
res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
|
|
320
361
|
case PyObjectDecl(value):
|
|
321
362
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
322
363
|
case PartialCallDecl(call_decl):
|
|
323
364
|
egg_fn_call = self._expr_to_egg(call_decl)
|
|
324
|
-
res = bindings.Call(
|
|
365
|
+
res = bindings.Call(
|
|
366
|
+
bindings.DUMMY_SPAN,
|
|
367
|
+
"unstable-fn",
|
|
368
|
+
[bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
369
|
+
)
|
|
325
370
|
case _:
|
|
326
371
|
assert_never(expr_decl.expr)
|
|
327
|
-
|
|
328
372
|
self.expr_to_egg_cache[expr_decl] = res
|
|
329
373
|
return res
|
|
330
374
|
|
|
@@ -343,6 +387,65 @@ class EGraphState:
|
|
|
343
387
|
"""
|
|
344
388
|
return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
|
|
345
389
|
|
|
390
|
+
def _generate_callable_egg_name(self, ref: CallableRef) -> str:
|
|
391
|
+
"""
|
|
392
|
+
Generates a valid egg function name for a callable reference.
|
|
393
|
+
"""
|
|
394
|
+
match ref:
|
|
395
|
+
case FunctionRef(name):
|
|
396
|
+
return name
|
|
397
|
+
|
|
398
|
+
case ConstantRef(name):
|
|
399
|
+
return name
|
|
400
|
+
case (
|
|
401
|
+
MethodRef(cls_name, name)
|
|
402
|
+
| ClassMethodRef(cls_name, name)
|
|
403
|
+
| ClassVariableRef(cls_name, name)
|
|
404
|
+
| PropertyRef(cls_name, name)
|
|
405
|
+
):
|
|
406
|
+
return f"{cls_name}.{name}"
|
|
407
|
+
case InitRef(cls_name):
|
|
408
|
+
return f"{cls_name}.__init__"
|
|
409
|
+
case UnnamedFunctionRef(args, val):
|
|
410
|
+
parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
|
|
411
|
+
str(self.typed_expr_to_egg(val, False))
|
|
412
|
+
]
|
|
413
|
+
return "_".join(parts)
|
|
414
|
+
case _:
|
|
415
|
+
assert_never(ref)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
|
|
419
|
+
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _sanitize_egg_ident(input_string: str) -> str:
|
|
423
|
+
"""
|
|
424
|
+
Replaces all invalid characters in an egg identifier with an underscore.
|
|
425
|
+
"""
|
|
426
|
+
return _EGGLOG_INVALID_IDENT.sub("_", input_string)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
|
|
430
|
+
"""
|
|
431
|
+
Returns all expressions that have multiple parents (a list but semantically just an ordered set).
|
|
432
|
+
"""
|
|
433
|
+
to_traverse = {typed_expr}
|
|
434
|
+
traversed = set[TypedExprDecl]()
|
|
435
|
+
traversed_twice = list[TypedExprDecl]()
|
|
436
|
+
while to_traverse:
|
|
437
|
+
typed_expr = to_traverse.pop()
|
|
438
|
+
if typed_expr in traversed:
|
|
439
|
+
traversed_twice.append(typed_expr)
|
|
440
|
+
continue
|
|
441
|
+
traversed.add(typed_expr)
|
|
442
|
+
expr = typed_expr.expr
|
|
443
|
+
if isinstance(expr, CallDecl):
|
|
444
|
+
to_traverse.update(expr.args)
|
|
445
|
+
elif isinstance(expr, PartialCallDecl):
|
|
446
|
+
to_traverse.update(expr.call.args)
|
|
447
|
+
return traversed_twice
|
|
448
|
+
|
|
346
449
|
|
|
347
450
|
def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
348
451
|
"""
|
|
@@ -354,26 +457,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
|
354
457
|
return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
|
|
355
458
|
|
|
356
459
|
|
|
357
|
-
def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
358
|
-
"""
|
|
359
|
-
Generates a valid egg function name for a callable reference.
|
|
360
|
-
"""
|
|
361
|
-
match ref:
|
|
362
|
-
case FunctionRef(name) | ConstantRef(name):
|
|
363
|
-
return name
|
|
364
|
-
case (
|
|
365
|
-
MethodRef(cls_name, name)
|
|
366
|
-
| ClassMethodRef(cls_name, name)
|
|
367
|
-
| ClassVariableRef(cls_name, name)
|
|
368
|
-
| PropertyRef(cls_name, name)
|
|
369
|
-
):
|
|
370
|
-
return f"{cls_name}_{name}"
|
|
371
|
-
case InitRef(cls_name):
|
|
372
|
-
return f"{cls_name}___init__"
|
|
373
|
-
case _:
|
|
374
|
-
assert_never(ref)
|
|
375
|
-
|
|
376
|
-
|
|
377
460
|
@dataclass
|
|
378
461
|
class FromEggState:
|
|
379
462
|
"""
|
|
@@ -395,7 +478,7 @@ class FromEggState:
|
|
|
395
478
|
"""
|
|
396
479
|
expr_decl: ExprDecl
|
|
397
480
|
if isinstance(term, bindings.TermVar):
|
|
398
|
-
expr_decl = VarDecl(term.name)
|
|
481
|
+
expr_decl = VarDecl(term.name, True)
|
|
399
482
|
elif isinstance(term, bindings.TermLit):
|
|
400
483
|
value = term.value
|
|
401
484
|
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
@@ -478,10 +561,3 @@ class FromEggState:
|
|
|
478
561
|
except KeyError:
|
|
479
562
|
res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
|
|
480
563
|
return res
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
def _rule_var_name(s: str) -> str:
|
|
484
|
-
"""
|
|
485
|
-
Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
|
|
486
|
-
"""
|
|
487
|
-
return f"__var__{s}"
|
|
egglog/examples/higher_order_functions.py
CHANGED
|
@@ -16,15 +16,12 @@ if TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
class Math(Expr):
|
|
18
18
|
def __init__(self, i: i64Like) -> None: ...
|
|
19
|
-
|
|
20
19
|
def __add__(self, other: Math) -> Math: ...
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class MathList(Expr):
|
|
24
23
|
def __init__(self) -> None: ...
|
|
25
|
-
|
|
26
24
|
def append(self, i: Math) -> MathList: ...
|
|
27
|
-
|
|
28
25
|
def map(self, f: Callable[[Math], Math]) -> MathList: ...
|
|
29
26
|
|
|
30
27
|
|
|
@@ -36,15 +33,13 @@ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math
|
|
|
36
33
|
|
|
37
34
|
|
|
38
35
|
@function(ruleset=math_ruleset)
|
|
39
|
-
def
|
|
40
|
-
return x + Math(1)
|
|
36
|
+
def incr_list(xs: MathList) -> MathList:
|
|
37
|
+
return xs.map(lambda x: x + Math(1))
|
|
41
38
|
|
|
42
39
|
|
|
43
40
|
egraph = EGraph()
|
|
44
|
-
|
|
45
|
-
y = egraph.let("y", x.map(increment_by_one))
|
|
41
|
+
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
|
|
46
42
|
egraph.run(math_ruleset.saturate())
|
|
47
|
-
|
|
48
43
|
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
|
|
49
44
|
|
|
50
45
|
egraph
|