egglog 7.1.0__cp311-none-win_amd64.whl → 8.0.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +146 -8
- egglog/egraph.py +337 -203
- egglog/egraph_state.py +171 -64
- egglog/examples/higher_order_functions.py +45 -0
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +97 -43
- egglog/runtime.py +60 -44
- egglog/thunk.py +44 -20
- egglog/type_constraint_solver.py +5 -4
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/METADATA +31 -30
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.1.0.dist-info/RECORD +0 -39
- {egglog-7.1.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
import re
|
|
7
8
|
from collections import defaultdict
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
10
|
from typing import TYPE_CHECKING, overload
|
|
10
|
-
from weakref import WeakKeyDictionary
|
|
11
11
|
|
|
12
12
|
from typing_extensions import assert_never
|
|
13
13
|
|
|
@@ -52,7 +52,7 @@ class EGraphState:
|
|
|
52
52
|
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
53
53
|
|
|
54
54
|
# Cache of egg expressions for converting to egg
|
|
55
|
-
expr_to_egg_cache:
|
|
55
|
+
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
56
56
|
|
|
57
57
|
def copy(self) -> EGraphState:
|
|
58
58
|
"""
|
|
@@ -71,15 +71,15 @@ class EGraphState:
|
|
|
71
71
|
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
72
72
|
match schedule:
|
|
73
73
|
case SaturateDecl(schedule):
|
|
74
|
-
return bindings.Saturate(self.schedule_to_egg(schedule))
|
|
74
|
+
return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
|
|
75
75
|
case RepeatDecl(schedule, times):
|
|
76
|
-
return bindings.Repeat(times, self.schedule_to_egg(schedule))
|
|
76
|
+
return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
|
|
77
77
|
case SequenceDecl(schedules):
|
|
78
|
-
return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
|
|
78
|
+
return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
|
|
79
79
|
case RunDecl(ruleset_name, until):
|
|
80
80
|
self.ruleset_to_egg(ruleset_name)
|
|
81
81
|
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
|
|
82
|
-
return bindings.Run(config)
|
|
82
|
+
return bindings.Run(bindings.DUMMY_SPAN, config)
|
|
83
83
|
case _:
|
|
84
84
|
assert_never(schedule)
|
|
85
85
|
|
|
@@ -98,7 +98,8 @@ class EGraphState:
|
|
|
98
98
|
for rule in rules:
|
|
99
99
|
if rule in added_rules:
|
|
100
100
|
continue
|
|
101
|
-
self.
|
|
101
|
+
cmd = self.command_to_egg(rule, name)
|
|
102
|
+
self.egraph.run_program(cmd)
|
|
102
103
|
added_rules.add(rule)
|
|
103
104
|
case CombinedRulesetDecl(rulesets):
|
|
104
105
|
if name in self.rulesets:
|
|
@@ -115,8 +116,9 @@ class EGraphState:
|
|
|
115
116
|
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
116
117
|
self.type_ref_to_egg(tp)
|
|
117
118
|
rewrite = bindings.Rewrite(
|
|
118
|
-
|
|
119
|
-
self.
|
|
119
|
+
bindings.DUMMY_SPAN,
|
|
120
|
+
self._expr_to_egg(lhs),
|
|
121
|
+
self._expr_to_egg(rhs),
|
|
120
122
|
[self.fact_to_egg(c) for c in conditions],
|
|
121
123
|
)
|
|
122
124
|
return (
|
|
@@ -126,26 +128,44 @@ class EGraphState:
|
|
|
126
128
|
)
|
|
127
129
|
case RuleDecl(head, body, name):
|
|
128
130
|
rule = bindings.Rule(
|
|
131
|
+
bindings.DUMMY_SPAN,
|
|
129
132
|
[self.action_to_egg(a) for a in head],
|
|
130
133
|
[self.fact_to_egg(f) for f in body],
|
|
131
134
|
)
|
|
132
135
|
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
136
|
+
# TODO: Replace with just constants value and looking at REF of function
|
|
137
|
+
case DefaultRewriteDecl(ref, expr):
|
|
138
|
+
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
|
|
139
|
+
sig = decl.signature
|
|
140
|
+
assert isinstance(sig, FunctionSignature)
|
|
141
|
+
# Replace args with rule_var_name mapping
|
|
142
|
+
arg_mapping = tuple(
|
|
143
|
+
TypedExprDecl(tp.to_just(), VarDecl(name, False))
|
|
144
|
+
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
145
|
+
)
|
|
146
|
+
rewrite_decl = RewriteDecl(
|
|
147
|
+
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
|
|
148
|
+
)
|
|
149
|
+
return self.command_to_egg(rewrite_decl, ruleset)
|
|
133
150
|
case _:
|
|
134
151
|
assert_never(cmd)
|
|
135
152
|
|
|
136
153
|
def action_to_egg(self, action: ActionDecl) -> bindings._Action:
|
|
137
154
|
match action:
|
|
138
155
|
case LetDecl(name, typed_expr):
|
|
139
|
-
|
|
156
|
+
var_decl = VarDecl(name, True)
|
|
157
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
158
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
159
|
+
return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
140
160
|
case SetDecl(tp, call, rhs):
|
|
141
161
|
self.type_ref_to_egg(tp)
|
|
142
|
-
call_ = self.
|
|
143
|
-
return bindings.Set(call_.name, call_.args, self.
|
|
162
|
+
call_ = self._expr_to_egg(call)
|
|
163
|
+
return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
|
|
144
164
|
case ExprActionDecl(typed_expr):
|
|
145
|
-
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
165
|
+
return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
|
|
146
166
|
case ChangeDecl(tp, call, change):
|
|
147
167
|
self.type_ref_to_egg(tp)
|
|
148
|
-
call_ = self.
|
|
168
|
+
call_ = self._expr_to_egg(call)
|
|
149
169
|
egg_change: bindings._Change
|
|
150
170
|
match change:
|
|
151
171
|
case "delete":
|
|
@@ -154,12 +174,12 @@ class EGraphState:
|
|
|
154
174
|
egg_change = bindings.Subsume()
|
|
155
175
|
case _:
|
|
156
176
|
assert_never(change)
|
|
157
|
-
return bindings.Change(egg_change, call_.name, call_.args)
|
|
177
|
+
return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
|
|
158
178
|
case UnionDecl(tp, lhs, rhs):
|
|
159
179
|
self.type_ref_to_egg(tp)
|
|
160
|
-
return bindings.Union(self.
|
|
180
|
+
return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
161
181
|
case PanicDecl(name):
|
|
162
|
-
return bindings.Panic(name)
|
|
182
|
+
return bindings.Panic(bindings.DUMMY_SPAN, name)
|
|
163
183
|
case _:
|
|
164
184
|
assert_never(action)
|
|
165
185
|
|
|
@@ -167,9 +187,9 @@ class EGraphState:
|
|
|
167
187
|
match fact:
|
|
168
188
|
case EqDecl(tp, exprs):
|
|
169
189
|
self.type_ref_to_egg(tp)
|
|
170
|
-
return bindings.Eq([self.
|
|
190
|
+
return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
|
|
171
191
|
case ExprFactDecl(typed_expr):
|
|
172
|
-
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
192
|
+
return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
|
|
173
193
|
case _:
|
|
174
194
|
assert_never(fact)
|
|
175
195
|
|
|
@@ -180,7 +200,9 @@ class EGraphState:
|
|
|
180
200
|
if ref in self.callable_ref_to_egg_fn:
|
|
181
201
|
return self.callable_ref_to_egg_fn[ref]
|
|
182
202
|
decl = self.__egg_decls__.get_callable_decl(ref)
|
|
183
|
-
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or
|
|
203
|
+
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
|
|
204
|
+
self._generate_callable_egg_name(ref)
|
|
205
|
+
)
|
|
184
206
|
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
185
207
|
match decl:
|
|
186
208
|
case RelationDecl(arg_types, _, _):
|
|
@@ -201,8 +223,8 @@ class EGraphState:
|
|
|
201
223
|
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
202
224
|
self.type_ref_to_egg(signature.semantic_return_type.to_just()),
|
|
203
225
|
),
|
|
204
|
-
self.
|
|
205
|
-
self.
|
|
226
|
+
self._expr_to_egg(decl.default) if decl.default else None,
|
|
227
|
+
self._expr_to_egg(decl.merge) if decl.merge else None,
|
|
206
228
|
[self.action_to_egg(a) for a in decl.on_merge],
|
|
207
229
|
decl.cost,
|
|
208
230
|
decl.unextractable,
|
|
@@ -228,13 +250,14 @@ class EGraphState:
|
|
|
228
250
|
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
229
251
|
type_args: list[bindings._Expr] = [
|
|
230
252
|
bindings.Call(
|
|
253
|
+
bindings.DUMMY_SPAN,
|
|
231
254
|
self.type_ref_to_egg(ref.args[1]),
|
|
232
|
-
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
255
|
+
[bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
233
256
|
),
|
|
234
|
-
bindings.Var(self.type_ref_to_egg(ref.args[0])),
|
|
257
|
+
bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
|
|
235
258
|
]
|
|
236
259
|
else:
|
|
237
|
-
type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
|
|
260
|
+
type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
|
|
238
261
|
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
239
262
|
else:
|
|
240
263
|
args = None
|
|
@@ -245,6 +268,8 @@ class EGraphState:
|
|
|
245
268
|
if decl.builtin:
|
|
246
269
|
for method in decl.class_methods:
|
|
247
270
|
self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
|
|
271
|
+
if decl.init:
|
|
272
|
+
self.callable_ref_to_egg(InitRef(ref.name))
|
|
248
273
|
|
|
249
274
|
return egg_name
|
|
250
275
|
|
|
@@ -259,31 +284,60 @@ class EGraphState:
|
|
|
259
284
|
if len(v) == 1
|
|
260
285
|
}
|
|
261
286
|
|
|
262
|
-
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
287
|
+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
|
|
288
|
+
# transform all expressions with multiple parents into a let binding, so that fewer expressions
|
|
289
|
+
# are sent to egglog. Only for performance reasons.
|
|
290
|
+
if transform_let:
|
|
291
|
+
have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
|
|
292
|
+
for expr in reversed(have_multiple_parents):
|
|
293
|
+
self._transform_let(expr)
|
|
294
|
+
|
|
263
295
|
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
264
|
-
return self.
|
|
296
|
+
return self._expr_to_egg(typed_expr_decl.expr)
|
|
297
|
+
|
|
298
|
+
def _transform_let(self, typed_expr: TypedExprDecl) -> None:
    """
    Bind ``typed_expr`` to a generated let variable in the e-graph, unless that was already attempted.

    The variable name is derived from ``hash(typed_expr)``. When the let command runs
    successfully, later conversions of ``typed_expr.expr`` resolve to that variable through
    ``expr_to_egg_cache``. When the command is rejected, nothing is recorded for the expression.
    """
    binding = VarDecl(f"__expr_{hash(typed_expr)}", True)
    # A cached binding means this expression was already handled on an earlier call.
    if binding in self.expr_to_egg_cache:
        return

    egg_var = self._expr_to_egg(binding)
    let_action = bindings.Let(bindings.DUMMY_SPAN, egg_var.name, self.typed_expr_to_egg(typed_expr))
    let_command = bindings.ActionCommand(let_action)
    try:
        self.egraph.run_program(let_command)
    except bindings.EggSmolError:
        # errors when creating let bindings for things like `(vec-empty)`
        return

    # From now on both the expression and its let variable convert to the same egg variable.
    self.expr_to_egg_cache[typed_expr.expr] = egg_var
    self.expr_to_egg_cache[binding] = egg_var
|
|
265
316
|
|
|
266
317
|
@overload
|
|
267
|
-
def
|
|
318
|
+
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
268
319
|
|
|
269
320
|
@overload
|
|
270
|
-
def
|
|
321
|
+
def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
|
|
271
322
|
|
|
272
|
-
|
|
323
|
+
@overload
|
|
324
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
325
|
+
|
|
326
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
273
327
|
"""
|
|
274
328
|
Convert an ExprDecl to an egg expression.
|
|
275
|
-
|
|
276
|
-
Cached using weakrefs to avoid memory leaks.
|
|
277
329
|
"""
|
|
278
330
|
try:
|
|
279
331
|
return self.expr_to_egg_cache[expr_decl]
|
|
280
332
|
except KeyError:
|
|
281
333
|
pass
|
|
282
|
-
|
|
283
334
|
res: bindings._Expr
|
|
284
335
|
match expr_decl:
|
|
285
|
-
case VarDecl(name):
|
|
286
|
-
|
|
336
|
+
case VarDecl(name, is_let):
|
|
337
|
+
# prefix let bindings with % to avoid name conflicts with rewrites
|
|
338
|
+
if is_let:
|
|
339
|
+
name = f"%{name}"
|
|
340
|
+
res = bindings.Var(bindings.DUMMY_SPAN, name)
|
|
287
341
|
case LitDecl(value):
|
|
288
342
|
l: bindings._Literal
|
|
289
343
|
match value:
|
|
@@ -299,19 +353,22 @@ class EGraphState:
|
|
|
299
353
|
l = bindings.String(s)
|
|
300
354
|
case _:
|
|
301
355
|
assert_never(value)
|
|
302
|
-
res = bindings.Lit(l)
|
|
356
|
+
res = bindings.Lit(bindings.DUMMY_SPAN, l)
|
|
303
357
|
case CallDecl(ref, args, _):
|
|
304
358
|
egg_fn = self.callable_ref_to_egg(ref)
|
|
305
|
-
egg_args = [self.typed_expr_to_egg(a) for a in args]
|
|
306
|
-
res = bindings.Call(egg_fn, egg_args)
|
|
359
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in args]
|
|
360
|
+
res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
|
|
307
361
|
case PyObjectDecl(value):
|
|
308
362
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
309
363
|
case PartialCallDecl(call_decl):
|
|
310
|
-
egg_fn_call = self.
|
|
311
|
-
res = bindings.Call(
|
|
364
|
+
egg_fn_call = self._expr_to_egg(call_decl)
|
|
365
|
+
res = bindings.Call(
|
|
366
|
+
bindings.DUMMY_SPAN,
|
|
367
|
+
"unstable-fn",
|
|
368
|
+
[bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
369
|
+
)
|
|
312
370
|
case _:
|
|
313
371
|
assert_never(expr_decl.expr)
|
|
314
|
-
|
|
315
372
|
self.expr_to_egg_cache[expr_decl] = res
|
|
316
373
|
return res
|
|
317
374
|
|
|
@@ -330,6 +387,65 @@ class EGraphState:
|
|
|
330
387
|
"""
|
|
331
388
|
return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
|
|
332
389
|
|
|
390
|
+
def _generate_callable_egg_name(self, ref: CallableRef) -> str:
|
|
391
|
+
"""
|
|
392
|
+
Generates a valid egg function name for a callable reference.
|
|
393
|
+
"""
|
|
394
|
+
match ref:
|
|
395
|
+
case FunctionRef(name):
|
|
396
|
+
return name
|
|
397
|
+
|
|
398
|
+
case ConstantRef(name):
|
|
399
|
+
return name
|
|
400
|
+
case (
|
|
401
|
+
MethodRef(cls_name, name)
|
|
402
|
+
| ClassMethodRef(cls_name, name)
|
|
403
|
+
| ClassVariableRef(cls_name, name)
|
|
404
|
+
| PropertyRef(cls_name, name)
|
|
405
|
+
):
|
|
406
|
+
return f"{cls_name}.{name}"
|
|
407
|
+
case InitRef(cls_name):
|
|
408
|
+
return f"{cls_name}.__init__"
|
|
409
|
+
case UnnamedFunctionRef(args, val):
|
|
410
|
+
parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
|
|
411
|
+
str(self.typed_expr_to_egg(val, False))
|
|
412
|
+
]
|
|
413
|
+
return "_".join(parts)
|
|
414
|
+
case _:
|
|
415
|
+
assert_never(ref)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
|
|
419
|
+
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _sanitize_egg_ident(input_string: str) -> str:
|
|
423
|
+
"""
|
|
424
|
+
Replaces all invalid characters in an egg identifier with an underscore.
|
|
425
|
+
"""
|
|
426
|
+
return _EGGLOG_INVALID_IDENT.sub("_", input_string)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
    """
    Returns all expressions that have multiple parents (a list but semantically just an ordered set).

    Walks the expression starting at ``typed_expr``, following the arguments of ``CallDecl``
    and of the call inside ``PartialCallDecl``. An expression is reported when it is reached
    more than once; each reported expression appears exactly once, in the order in which its
    second occurrence was found.

    The worklist is a list rather than a set on purpose: a set would merge an expression that
    is queued by two parents before being processed, so whether it was detected would depend
    on set iteration order.
    """
    pending: list[TypedExprDecl] = [typed_expr]
    seen = set[TypedExprDecl]()
    # dict keys serve as an insertion-ordered set of the shared expressions
    shared: dict[TypedExprDecl, None] = {}
    while pending:
        current = pending.pop()
        if current in seen:
            shared[current] = None
            # children were already queued on the first visit, do not expand again
            continue
        seen.add(current)
        expr = current.expr
        if isinstance(expr, CallDecl):
            pending.extend(expr.args)
        elif isinstance(expr, PartialCallDecl):
            pending.extend(expr.call.args)
    return list(shared)
|
|
448
|
+
|
|
333
449
|
|
|
334
450
|
def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
335
451
|
"""
|
|
@@ -341,24 +457,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
|
341
457
|
return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
|
|
342
458
|
|
|
343
459
|
|
|
344
|
-
def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
345
|
-
"""
|
|
346
|
-
Generates a valid egg function name for a callable reference.
|
|
347
|
-
"""
|
|
348
|
-
match ref:
|
|
349
|
-
case FunctionRef(name) | ConstantRef(name):
|
|
350
|
-
return name
|
|
351
|
-
case (
|
|
352
|
-
MethodRef(cls_name, name)
|
|
353
|
-
| ClassMethodRef(cls_name, name)
|
|
354
|
-
| ClassVariableRef(cls_name, name)
|
|
355
|
-
| PropertyRef(cls_name, name)
|
|
356
|
-
):
|
|
357
|
-
return f"{cls_name}_{name}"
|
|
358
|
-
case _:
|
|
359
|
-
assert_never(ref)
|
|
360
|
-
|
|
361
|
-
|
|
362
460
|
@dataclass
|
|
363
461
|
class FromEggState:
|
|
364
462
|
"""
|
|
@@ -380,7 +478,7 @@ class FromEggState:
|
|
|
380
478
|
"""
|
|
381
479
|
expr_decl: ExprDecl
|
|
382
480
|
if isinstance(term, bindings.TermVar):
|
|
383
|
-
expr_decl = VarDecl(term.name)
|
|
481
|
+
expr_decl = VarDecl(term.name, True)
|
|
384
482
|
elif isinstance(term, bindings.TermLit):
|
|
385
483
|
value = term.value
|
|
386
484
|
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
@@ -427,8 +525,11 @@ class FromEggState:
|
|
|
427
525
|
possible_types: Iterable[JustTypeRef | None]
|
|
428
526
|
signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
|
|
429
527
|
assert isinstance(signature, FunctionSignature)
|
|
430
|
-
if isinstance(callable_ref, ClassMethodRef):
|
|
431
|
-
|
|
528
|
+
if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
|
|
529
|
+
# Need OR in case we have a class method whose class was never added as a sort, which would happen
|
|
530
|
+
# if the class method didn't return that type and no other function did. In this case, we don't need
|
|
531
|
+
# to care about the type vars and we don't need to bind any possible type.
|
|
532
|
+
possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
|
|
432
533
|
cls_name = callable_ref.class_name
|
|
433
534
|
else:
|
|
434
535
|
possible_types = [None]
|
|
@@ -437,7 +538,6 @@ class FromEggState:
|
|
|
437
538
|
tcs = TypeConstraintSolver(self.decls)
|
|
438
539
|
if possible_type and possible_type.args:
|
|
439
540
|
tcs.bind_class(possible_type)
|
|
440
|
-
|
|
441
541
|
try:
|
|
442
542
|
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
443
543
|
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
|
|
@@ -445,7 +545,14 @@ class FromEggState:
|
|
|
445
545
|
except TypeConstraintError:
|
|
446
546
|
continue
|
|
447
547
|
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
|
|
448
|
-
|
|
548
|
+
|
|
549
|
+
return CallDecl(
|
|
550
|
+
callable_ref,
|
|
551
|
+
args,
|
|
552
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
553
|
+
# but don't need to store them
|
|
554
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
|
|
555
|
+
)
|
|
449
556
|
raise ValueError(f"Could not find callable ref for call {term}")
|
|
450
557
|
|
|
451
558
|
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
Higher Order Functions
======================
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from egglog import *

if TYPE_CHECKING:
    from collections.abc import Callable


# Expression type wrapping an integer; both methods are declaration stubs (bodies are `...`).
class Math(Expr):
    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: Math) -> Math: ...


# Expression type for a list of Math values, built with `append`; `map` takes a
# callable from Math to Math. All methods are declaration stubs.
class MathList(Expr):
    def __init__(self) -> None: ...
    def append(self, i: Math) -> MathList: ...
    def map(self, f: Callable[[Math], Math]) -> MathList: ...


# Rewrite rules over the two types above; the parameters are the variables used in the rules.
@ruleset
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):  # noqa: ANN201
    # Adding two wrapped integers rewrites to the wrapped sum.
    yield rewrite(Math(i) + Math(j)).to(Math(i + j))
    # Mapping over an appended list rewrites to mapping the prefix, then appending f(x).
    yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
    # Mapping over the empty list rewrites to the empty list.
    yield rewrite(MathList().map(f)).to(MathList())


# Function whose body maps a lambda adding Math(1) over the list; it is registered with
# `math_ruleset` via the decorator argument.
@function(ruleset=math_ruleset)
def incr_list(xs: MathList) -> MathList:
    return xs.map(lambda x: x + Math(1))


# Bind incr_list([Math(1), Math(2)]) to "y", saturate the ruleset, then check that
# y equals the list [Math(2), Math(3)].
egraph = EGraph()
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
egraph.run(math_ruleset.saturate())
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))

# Bare expression as the last statement -- presumably so a notebook/doc renderer displays
# the e-graph; TODO confirm.
egraph
|