egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/egraph_state.py
ADDED
|
@@ -0,0 +1,978 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implement conversion to/from egglog.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
from base64 import standard_b64decode, standard_b64encode
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from dataclasses import dataclass, field, replace
|
|
11
|
+
from typing import TYPE_CHECKING, Literal, assert_never, overload
|
|
12
|
+
from uuid import UUID
|
|
13
|
+
|
|
14
|
+
import cloudpickle
|
|
15
|
+
|
|
16
|
+
from . import bindings
|
|
17
|
+
from .declarations import *
|
|
18
|
+
from .declarations import ConstructorDecl
|
|
19
|
+
from .pretty import *
|
|
20
|
+
from .type_constraint_solver import *
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from collections.abc import Iterable
|
|
24
|
+
|
|
25
|
+
__all__ = ["EGraphState", "span"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def span(frame_index: int = 0) -> bindings.RustSpan:
    """
    Return the source span attached to generated egglog nodes.

    `frame_index` names the stack frame the span is meant to describe, where 0 is the frame this is called in
    and 1 is its parent. The argument is currently ignored: looking the frame up through `inspect.stack()` was
    too expensive, so the same empty span is returned for every call.
    """
    # Frame inspection is disabled for performance, so build a span with an empty name and zeroed numbers.
    empty_span = bindings.RustSpan("", 0, 0)
    return empty_span
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
|
|
43
|
+
class EGraphState:
|
|
44
|
+
"""
|
|
45
|
+
State of the EGraph declarations and rulesets, so when we pop/push the stack we know what's defined.
|
|
46
|
+
|
|
47
|
+
Used for converting to/from egg and for pretty printing.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
egraph: bindings.EGraph
|
|
51
|
+
# The declarations we have added.
|
|
52
|
+
__egg_decls__: Declarations = field(default_factory=Declarations)
|
|
53
|
+
# Mapping of added rulesets to the added rules
|
|
54
|
+
rulesets: dict[Ident, set[RewriteOrRuleDecl]] = field(default_factory=dict)
|
|
55
|
+
|
|
56
|
+
# Bidirectional mapping between egg function names and python callable references.
|
|
57
|
+
# Note that there are possibly multiple callable references for a single egg function name, like `+`
|
|
58
|
+
# for both int and rational classes.
|
|
59
|
+
egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field(
|
|
60
|
+
default_factory=lambda: defaultdict(set, {"!=": {FunctionRef(Ident.builtin("!="))}})
|
|
61
|
+
)
|
|
62
|
+
callable_ref_to_egg_fn: dict[CallableRef, tuple[str, bool]] = field(
|
|
63
|
+
default_factory=lambda: {FunctionRef(Ident.builtin("!=")): ("!=", False)}
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
# Bidirectional mapping between egg sort names and python type references.
|
|
67
|
+
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
68
|
+
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
|
|
69
|
+
|
|
70
|
+
# Cache of egg expressions for converting to egg
|
|
71
|
+
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
72
|
+
|
|
73
|
+
# Callables which have cost tables associated with them
|
|
74
|
+
cost_callables: set[CallableRef] = field(default_factory=set)
|
|
75
|
+
|
|
76
|
+
def copy(self) -> EGraphState:
    """
    Return a snapshot of this state. Used for pushing/popping.

    The `egraph` reference is shared with the snapshot. Every other field is copied with its own `copy()`
    method; for `rulesets` and `egg_fn_to_callable_refs` the per-key sets are copied as well, so adding to a
    set in the snapshot does not change this state.
    """
    rulesets_snapshot = {ident: added.copy() for ident, added in self.rulesets.items()}
    # Rebuild as a defaultdict so unknown egg names still start out as an empty set.
    fn_refs_snapshot = defaultdict(set)
    for egg_fn, refs in self.egg_fn_to_callable_refs.items():
        fn_refs_snapshot[egg_fn] = refs.copy()
    return EGraphState(
        egraph=self.egraph,
        __egg_decls__=self.__egg_decls__.copy(),
        rulesets=rulesets_snapshot,
        egg_fn_to_callable_refs=fn_refs_snapshot,
        callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
        type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
        egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(),
        expr_to_egg_cache=self.expr_to_egg_cache.copy(),
        cost_callables=self.cost_callables.copy(),
    )
|
|
91
|
+
|
|
92
|
+
def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
    """
    Convert a run schedule into the egg command that executes it.

    When the schedule uses no custom schedulers a regular `RunSchedule` command is produced. Otherwise the
    schedule is translated to expressions and wrapped in a user-defined `run-schedule` command.
    """
    with_schedulers = self._process_schedule(schedule)
    if with_schedulers is None:
        return bindings.RunSchedule(self._schedule_to_egg(schedule))
    exprs = self._schedule_with_scheduler_to_egg(with_schedulers, [])
    # A single expression is passed through as is; any other number is grouped under a `seq` call.
    body = exprs[0] if len(exprs) == 1 else bindings.Call(span(), "seq", exprs)
    return bindings.UserDefined(span(), "run-schedule", [body])
|
|
108
|
+
|
|
109
|
+
def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
|
|
110
|
+
"""
|
|
111
|
+
Processes a schedule to determine if it contains any custom schedulers.
|
|
112
|
+
|
|
113
|
+
If it does, it returns a new schedule with all the required let bindings added to the other scope.
|
|
114
|
+
If not, returns none.
|
|
115
|
+
|
|
116
|
+
Also processes all rulesets in the schedule to make sure they are registered.
|
|
117
|
+
"""
|
|
118
|
+
bound_schedulers: list[UUID] = []
|
|
119
|
+
unbound_schedulers: list[BackOffDecl] = []
|
|
120
|
+
|
|
121
|
+
def helper(s: ScheduleDecl) -> None:
|
|
122
|
+
match s:
|
|
123
|
+
case LetSchedulerDecl(scheduler, inner):
|
|
124
|
+
bound_schedulers.append(scheduler.id)
|
|
125
|
+
return helper(inner)
|
|
126
|
+
case RunDecl(ruleset_name, _, scheduler):
|
|
127
|
+
self.ruleset_to_egg(ruleset_name)
|
|
128
|
+
if scheduler and scheduler.id not in bound_schedulers:
|
|
129
|
+
unbound_schedulers.append(scheduler)
|
|
130
|
+
case SaturateDecl(inner) | RepeatDecl(inner, _):
|
|
131
|
+
return helper(inner)
|
|
132
|
+
case SequenceDecl(schedules):
|
|
133
|
+
for sc in schedules:
|
|
134
|
+
helper(sc)
|
|
135
|
+
case _:
|
|
136
|
+
assert_never(s)
|
|
137
|
+
return None
|
|
138
|
+
|
|
139
|
+
helper(schedule)
|
|
140
|
+
if not bound_schedulers and not unbound_schedulers:
|
|
141
|
+
return None
|
|
142
|
+
for scheduler in unbound_schedulers:
|
|
143
|
+
schedule = LetSchedulerDecl(scheduler, schedule)
|
|
144
|
+
return schedule
|
|
145
|
+
|
|
146
|
+
def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
147
|
+
msg = "Should never reach this, let schedulers should be handled by custom scheduler"
|
|
148
|
+
match schedule:
|
|
149
|
+
case SaturateDecl(schedule):
|
|
150
|
+
return bindings.Saturate(span(), self._schedule_to_egg(schedule))
|
|
151
|
+
case RepeatDecl(schedule, times):
|
|
152
|
+
return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
|
|
153
|
+
case SequenceDecl(schedules):
|
|
154
|
+
return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
|
|
155
|
+
case RunDecl(ruleset_ident, until, scheduler):
|
|
156
|
+
if scheduler is not None:
|
|
157
|
+
raise ValueError(msg)
|
|
158
|
+
config = bindings.RunConfig(
|
|
159
|
+
str(ruleset_ident), None if not until else list(map(self.fact_to_egg, until))
|
|
160
|
+
)
|
|
161
|
+
return bindings.Run(span(), config)
|
|
162
|
+
case LetSchedulerDecl():
|
|
163
|
+
raise ValueError(msg)
|
|
164
|
+
case _:
|
|
165
|
+
assert_never(schedule)
|
|
166
|
+
|
|
167
|
+
def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
|
|
168
|
+
self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
|
|
169
|
+
) -> list[bindings._Expr]:
|
|
170
|
+
"""
|
|
171
|
+
Turns a scheduler into an egg expression, to be used with a custom extract command.
|
|
172
|
+
|
|
173
|
+
The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
|
|
174
|
+
"""
|
|
175
|
+
match schedule:
|
|
176
|
+
case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
|
|
177
|
+
name = f"_scheduler_{len(bound_schedulers)}"
|
|
178
|
+
bound_schedulers.append(id)
|
|
179
|
+
args: list[bindings._Expr] = []
|
|
180
|
+
if match_limit is not None:
|
|
181
|
+
args.append(bindings.Var(span(), ":match-limit"))
|
|
182
|
+
args.append(bindings.Lit(span(), bindings.Int(match_limit)))
|
|
183
|
+
if ban_length is not None:
|
|
184
|
+
args.append(bindings.Var(span(), ":ban-length"))
|
|
185
|
+
args.append(bindings.Lit(span(), bindings.Int(ban_length)))
|
|
186
|
+
back_off_decl = bindings.Call(span(), "back-off", args)
|
|
187
|
+
let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
|
|
188
|
+
return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
|
|
189
|
+
case RunDecl(ruleset_ident, until, scheduler):
|
|
190
|
+
args = [bindings.Var(span(), str(ruleset_ident))]
|
|
191
|
+
if scheduler:
|
|
192
|
+
name = "run-with"
|
|
193
|
+
scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
|
|
194
|
+
args.insert(0, bindings.Var(span(), scheduler_name))
|
|
195
|
+
else:
|
|
196
|
+
name = "run"
|
|
197
|
+
if until:
|
|
198
|
+
if len(until) > 1:
|
|
199
|
+
msg = "Can only have one until fact with custom scheduler"
|
|
200
|
+
raise ValueError(msg)
|
|
201
|
+
args.append(bindings.Var(span(), ":until"))
|
|
202
|
+
fact_egg = self.fact_to_egg(until[0])
|
|
203
|
+
if isinstance(fact_egg, bindings.Eq):
|
|
204
|
+
msg = "Cannot use equality fact with custom scheduler"
|
|
205
|
+
raise ValueError(msg)
|
|
206
|
+
args.append(fact_egg.expr)
|
|
207
|
+
return [bindings.Call(span(), name, args)]
|
|
208
|
+
case SaturateDecl(inner):
|
|
209
|
+
return [
|
|
210
|
+
bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
|
|
211
|
+
]
|
|
212
|
+
case RepeatDecl(inner, times):
|
|
213
|
+
return [
|
|
214
|
+
bindings.Call(
|
|
215
|
+
span(),
|
|
216
|
+
"repeat",
|
|
217
|
+
[
|
|
218
|
+
bindings.Lit(span(), bindings.Int(times)),
|
|
219
|
+
*self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
|
|
220
|
+
],
|
|
221
|
+
)
|
|
222
|
+
]
|
|
223
|
+
case SequenceDecl(schedules):
|
|
224
|
+
res = []
|
|
225
|
+
for s in schedules:
|
|
226
|
+
res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
|
|
227
|
+
return res
|
|
228
|
+
case _:
|
|
229
|
+
assert_never(schedule)
|
|
230
|
+
|
|
231
|
+
def ruleset_to_egg(self, ident: Ident) -> None:
|
|
232
|
+
"""
|
|
233
|
+
Registers a ruleset if it's not already registered.
|
|
234
|
+
"""
|
|
235
|
+
match self.__egg_decls__._rulesets[ident]:
|
|
236
|
+
case RulesetDecl(rules):
|
|
237
|
+
if ident not in self.rulesets:
|
|
238
|
+
if str(ident):
|
|
239
|
+
self.egraph.run_program(bindings.AddRuleset(span(), str(ident)))
|
|
240
|
+
added_rules = self.rulesets[ident] = set()
|
|
241
|
+
else:
|
|
242
|
+
added_rules = self.rulesets[ident]
|
|
243
|
+
for rule in rules:
|
|
244
|
+
if rule in added_rules:
|
|
245
|
+
continue
|
|
246
|
+
cmd = self.command_to_egg(rule, ident)
|
|
247
|
+
if cmd is not None:
|
|
248
|
+
self.egraph.run_program(cmd)
|
|
249
|
+
added_rules.add(rule)
|
|
250
|
+
case CombinedRulesetDecl(rulesets):
|
|
251
|
+
if ident in self.rulesets:
|
|
252
|
+
return
|
|
253
|
+
self.rulesets[ident] = set()
|
|
254
|
+
for ruleset in rulesets:
|
|
255
|
+
self.ruleset_to_egg(ruleset)
|
|
256
|
+
self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), str(ident), list(map(str, rulesets))))
|
|
257
|
+
|
|
258
|
+
def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | None:
|
|
259
|
+
match cmd:
|
|
260
|
+
case ActionCommandDecl(action):
|
|
261
|
+
action_egg = self.action_to_egg(action, expr_to_let=True)
|
|
262
|
+
if not action_egg:
|
|
263
|
+
return None
|
|
264
|
+
return bindings.ActionCommand(action_egg)
|
|
265
|
+
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
266
|
+
self.type_ref_to_egg(tp)
|
|
267
|
+
rewrite = bindings.Rewrite(
|
|
268
|
+
span(),
|
|
269
|
+
self._expr_to_egg(lhs),
|
|
270
|
+
self._expr_to_egg(rhs),
|
|
271
|
+
[self.fact_to_egg(c) for c in conditions],
|
|
272
|
+
)
|
|
273
|
+
return (
|
|
274
|
+
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
|
|
275
|
+
if isinstance(cmd, RewriteDecl)
|
|
276
|
+
else bindings.BiRewriteCommand(str(ruleset), rewrite)
|
|
277
|
+
)
|
|
278
|
+
case RuleDecl(head, body, name):
|
|
279
|
+
return bindings.RuleCommand(
|
|
280
|
+
bindings.Rule(
|
|
281
|
+
span(),
|
|
282
|
+
[self.action_to_egg(a) for a in head],
|
|
283
|
+
[self.fact_to_egg(f) for f in body],
|
|
284
|
+
name or "",
|
|
285
|
+
str(ruleset),
|
|
286
|
+
)
|
|
287
|
+
)
|
|
288
|
+
# TODO: Replace with just constants value and looking at REF of function
|
|
289
|
+
case DefaultRewriteDecl(ref, expr, subsume):
|
|
290
|
+
sig = self.__egg_decls__.get_callable_decl(ref).signature
|
|
291
|
+
assert isinstance(sig, FunctionSignature)
|
|
292
|
+
# Replace args with rule_var_name mapping
|
|
293
|
+
arg_mapping = tuple(
|
|
294
|
+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name))
|
|
295
|
+
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
296
|
+
)
|
|
297
|
+
rewrite_decl = RewriteDecl(
|
|
298
|
+
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), subsume
|
|
299
|
+
)
|
|
300
|
+
return self.command_to_egg(rewrite_decl, ruleset)
|
|
301
|
+
case _:
|
|
302
|
+
assert_never(cmd)
|
|
303
|
+
|
|
304
|
+
@overload
|
|
305
|
+
def action_to_egg(self, action: ActionDecl) -> bindings._Action: ...
|
|
306
|
+
|
|
307
|
+
@overload
|
|
308
|
+
def action_to_egg(self, action: ActionDecl, expr_to_let: Literal[True] = ...) -> bindings._Action | None: ...
|
|
309
|
+
|
|
310
|
+
def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
|
|
311
|
+
match action:
|
|
312
|
+
case LetDecl(name, typed_expr):
|
|
313
|
+
var_decl = LetRefDecl(name)
|
|
314
|
+
var_egg = self._expr_to_egg(var_decl)
|
|
315
|
+
self.expr_to_egg_cache[var_decl] = var_egg
|
|
316
|
+
return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
317
|
+
case SetDecl(tp, call, rhs):
|
|
318
|
+
self.type_ref_to_egg(tp)
|
|
319
|
+
call_ = self._expr_to_egg(call)
|
|
320
|
+
return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs))
|
|
321
|
+
case ExprActionDecl(typed_expr):
|
|
322
|
+
if expr_to_let:
|
|
323
|
+
maybe_typed_expr = self._transform_let(typed_expr)
|
|
324
|
+
if maybe_typed_expr:
|
|
325
|
+
typed_expr = maybe_typed_expr
|
|
326
|
+
else:
|
|
327
|
+
return None
|
|
328
|
+
return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr))
|
|
329
|
+
case ChangeDecl(tp, call, change):
|
|
330
|
+
self.type_ref_to_egg(tp)
|
|
331
|
+
call_ = self._expr_to_egg(call)
|
|
332
|
+
egg_change: bindings._Change
|
|
333
|
+
match change:
|
|
334
|
+
case "delete":
|
|
335
|
+
egg_change = bindings.Delete()
|
|
336
|
+
case "subsume":
|
|
337
|
+
egg_change = bindings.Subsume()
|
|
338
|
+
case _:
|
|
339
|
+
assert_never(change)
|
|
340
|
+
return bindings.Change(span(), egg_change, call_.name, call_.args)
|
|
341
|
+
case UnionDecl(tp, lhs, rhs):
|
|
342
|
+
self.type_ref_to_egg(tp)
|
|
343
|
+
return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
344
|
+
case PanicDecl(name):
|
|
345
|
+
return bindings.Panic(span(), name)
|
|
346
|
+
case SetCostDecl(tp, expr, cost):
|
|
347
|
+
self.type_ref_to_egg(tp)
|
|
348
|
+
cost_table = self.create_cost_table(expr.callable)
|
|
349
|
+
args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args]
|
|
350
|
+
return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost))
|
|
351
|
+
case _:
|
|
352
|
+
assert_never(action)
|
|
353
|
+
|
|
354
|
+
def create_cost_table(self, ref: CallableRef) -> str:
    """
    Return the name of the egg cost table for `ref`, creating the table the first time it is requested.

    The table is declared as a function with the callable's arguments and an `i64` return type.
    """
    table_name = self.cost_table_name(ref)
    if ref in self.cost_callables:
        return table_name
    self.cost_callables.add(ref)
    signature = self.__egg_decls__.get_callable_decl(ref).signature
    assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions"
    cost_signature = replace(signature, return_type=TypeRefWithVars(Ident.builtin("i64")))
    schema = self._signature_to_egg_schema(cost_signature)
    self.egraph.run_program(bindings.FunctionCommand(span(), table_name, schema, None))
    return table_name
|
|
368
|
+
|
|
369
|
+
def cost_table_name(self, ref: CallableRef) -> str:
    """
    Return the egg name of the cost table for `ref`: the callable's egg function name prefixed with
    `cost_table_`.
    """
    egg_fn = self.callable_ref_to_egg(ref)[0]
    return f"cost_table_{egg_fn}"
|
|
371
|
+
|
|
372
|
+
def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
|
|
373
|
+
match fact:
|
|
374
|
+
case EqDecl(tp, left, right):
|
|
375
|
+
self.type_ref_to_egg(tp)
|
|
376
|
+
return bindings.Eq(span(), self._expr_to_egg(left), self._expr_to_egg(right))
|
|
377
|
+
case ExprFactDecl(typed_expr):
|
|
378
|
+
return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
|
|
379
|
+
case _:
|
|
380
|
+
assert_never(fact)
|
|
381
|
+
|
|
382
|
+
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912
|
|
383
|
+
"""
|
|
384
|
+
Returns the egg function name for a callable reference, registering it if it is not already registered.
|
|
385
|
+
|
|
386
|
+
Also returns whether the args should be reversed
|
|
387
|
+
"""
|
|
388
|
+
if ref in self.callable_ref_to_egg_fn:
|
|
389
|
+
return self.callable_ref_to_egg_fn[ref]
|
|
390
|
+
decl = self.__egg_decls__.get_callable_decl(ref)
|
|
391
|
+
egg_name = decl.egg_name or _sanitize_egg_ident(self._generate_callable_egg_name(ref))
|
|
392
|
+
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
393
|
+
reverse_args = False
|
|
394
|
+
match decl:
|
|
395
|
+
case RelationDecl(arg_types, _, _):
|
|
396
|
+
self.egraph.run_program(
|
|
397
|
+
bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types])
|
|
398
|
+
)
|
|
399
|
+
case ConstantDecl(tp, _):
|
|
400
|
+
# Use constructor decleration instead of constant b/c constants cannot be extracted
|
|
401
|
+
# https://github.com/egraphs-good/egglog/issues/334
|
|
402
|
+
is_function = self.__egg_decls__._classes[tp.ident].builtin
|
|
403
|
+
schema = bindings.Schema([], self.type_ref_to_egg(tp))
|
|
404
|
+
if is_function:
|
|
405
|
+
self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None))
|
|
406
|
+
else:
|
|
407
|
+
self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False))
|
|
408
|
+
case FunctionDecl(signature, builtin, _, merge):
|
|
409
|
+
if isinstance(signature, FunctionSignature):
|
|
410
|
+
reverse_args = signature.reverse_args
|
|
411
|
+
if not builtin:
|
|
412
|
+
assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
|
|
413
|
+
# Compile functions that return unit to relations, because these show up in methods where you
|
|
414
|
+
# cant use the relation helper
|
|
415
|
+
schema = self._signature_to_egg_schema(signature)
|
|
416
|
+
if signature.return_type == TypeRefWithVars(Ident.builtin("Unit")):
|
|
417
|
+
if merge:
|
|
418
|
+
msg = "Cannot specify a merge function for a function that returns unit"
|
|
419
|
+
raise ValueError(msg)
|
|
420
|
+
self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input))
|
|
421
|
+
else:
|
|
422
|
+
self.egraph.run_program(
|
|
423
|
+
bindings.FunctionCommand(
|
|
424
|
+
span(),
|
|
425
|
+
egg_name,
|
|
426
|
+
self._signature_to_egg_schema(signature),
|
|
427
|
+
self._expr_to_egg(merge) if merge else None,
|
|
428
|
+
)
|
|
429
|
+
)
|
|
430
|
+
case ConstructorDecl(signature, _, cost, unextractable):
|
|
431
|
+
self.egraph.run_program(
|
|
432
|
+
bindings.Constructor(
|
|
433
|
+
span(),
|
|
434
|
+
egg_name,
|
|
435
|
+
self._signature_to_egg_schema(signature),
|
|
436
|
+
cost,
|
|
437
|
+
unextractable,
|
|
438
|
+
)
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
case _:
|
|
442
|
+
assert_never(decl)
|
|
443
|
+
self.callable_ref_to_egg_fn[ref] = egg_name, reverse_args
|
|
444
|
+
return egg_name, reverse_args
|
|
445
|
+
|
|
446
|
+
def _signature_to_egg_schema(self, signature: FunctionSignature) -> bindings.Schema:
    """
    Build the egg schema for a function signature from the egg sorts of its argument types and of its
    semantic return type, registering those sorts when needed.
    """
    # Arguments are converted before the return type.
    input_sorts = [self.type_ref_to_egg(arg_tp.to_just()) for arg_tp in signature.arg_types]
    output_sort = self.type_ref_to_egg(signature.semantic_return_type.to_just())
    return bindings.Schema(input_sorts, output_sort)
|
|
451
|
+
|
|
452
|
+
def type_ref_to_egg(self, ref: JustTypeRef) -> str:  # noqa: C901, PLR0912
    """
    Returns the egg sort name for a type reference, registering it if it is not already registered.

    Raises NotImplementedError for an `UnstableFn` type with fewer than two type arguments.
    """
    # Fast path: the sort was registered before.
    try:
        return self.type_ref_to_egg_sort[ref]
    except KeyError:
        pass
    decl = self.__egg_decls__._classes[ref.ident]
    # Record both directions of the mapping before anything else, so that the recursive calls below that
    # reach this same type reference hit the fast path above.
    # A declared egg name is only used for the unparametrized type; otherwise a name is generated.
    self.type_ref_to_egg_sort[ref] = egg_name = (not ref.args and decl.egg_name) or _generate_type_egg_name(ref)
    self.egg_sort_to_type_ref[egg_name] = ref
    # A sort command is run for everything except unparametrized builtin classes.
    if not decl.builtin or ref.args:
        if ref.args:
            if ref.ident == Ident.builtin("UnstableFn"):
                # UnstableFn is a special case, where the rest of args are collected into a call
                if len(ref.args) < 2:
                    msg = "Zero argument higher order functions not supported"
                    raise NotImplementedError(msg)
                # First a call named after the sort of args[1] applied to the sorts of args[2:],
                # followed by the sort of args[0].
                type_args: list[bindings._Expr] = [
                    bindings.Call(
                        span(),
                        self.type_ref_to_egg(ref.args[1]),
                        [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args[2:]],
                    ),
                    bindings.Var(span(), self.type_ref_to_egg(ref.args[0])),
                ]
            else:
                # If any of methods have another type ref in them process all those first with substituted vars
                # so that things like multiset - mapp will be added. Function type must be added first.
                # Find all args of all methods and find any with type args themselves that are not this type and add them
                tcs = TypeConstraintSolver(self.__egg_decls__)
                tcs.bind_class(ref)
                for method in decl.methods.values():
                    # Only methods with a regular function signature are inspected.
                    if not isinstance((signature := method.signature), FunctionSignature):
                        continue
                    for arg_tp in signature.arg_types:
                        if isinstance(arg_tp, TypeRefWithVars) and arg_tp.args and arg_tp.ident != ref.ident:
                            self.type_ref_to_egg(tcs.substitute_typevars(arg_tp, ref.ident))

                type_args = [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args]
            # The parametrized sort is declared from the sort of the unparametrized class plus the type args.
            args = (self.type_ref_to_egg(JustTypeRef(ref.ident)), type_args)
        else:
            args = None
        self.egraph.run_program(bindings.Sort(span(), egg_name, args))
    # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
    # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
    # even if you never use that function.
    if decl.builtin:
        for method_name in decl.class_methods:
            self.callable_ref_to_egg(ClassMethodRef(ref.ident, method_name))
        if decl.init:
            self.callable_ref_to_egg(InitRef(ref.ident))

    return egg_name
|
|
506
|
+
|
|
507
|
+
def op_mapping(self) -> dict[str, str]:
    """
    Map egglog function names to Python function names, for use in the serialized format for better
    visualization.

    Only egg names that refer to exactly one callable are included. Every cost table is included as well,
    displayed as `cost(...)` around its callable.
    """
    mapping: dict[str, str] = {}
    for egg_fn, refs in self.egg_fn_to_callable_refs.items():
        # Names with zero or several callables are skipped.
        if len(refs) != 1:
            continue
        (only_ref,) = refs
        mapping[egg_fn] = pretty_callable_ref(self.__egg_decls__, only_ref)
    for ref in self.cost_callables:
        # The table name is resolved before the label is formatted.
        table_name = self.cost_table_name(ref)
        mapping[table_name] = f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})"
    return mapping
|
|
522
|
+
|
|
523
|
+
def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
    """
    Given a list of egglog functions, yields all the possible Python function strings.

    Names that have no registered callable contribute nothing. The lookup goes through `.get` instead of
    indexing: `egg_fn_to_callable_refs` is a `defaultdict` by default, and indexing it would insert an empty
    set for every unknown name, so a read-only query would change the state.
    """
    for name in names:
        for ref in self.egg_fn_to_callable_refs.get(name, ()):
            yield pretty_callable_ref(self.__egg_decls__, ref)
|
|
530
|
+
|
|
531
|
+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
    """
    Convert a typed expression into an egg expression, registering its type.

    With `transform_let` set, every sub-expression that has multiple parents is first turned into a let
    binding so that less expressions are sent to egglog. This is only done for performance reasons.
    """
    if transform_let:
        shared_exprs = _exprs_multiple_parents(typed_expr_decl)
        # Bound in reverse order of the returned list.
        for shared_expr in reversed(shared_exprs):
            self._transform_let(shared_expr)

    self.type_ref_to_egg(typed_expr_decl.tp)
    egg_expr = self._expr_to_egg(typed_expr_decl.expr)
    return egg_expr
|
|
541
|
+
|
|
542
|
+
def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl | None:
    """
    Rewrites this expression as a let binding if it's not already a let binding.

    Returns None when the binding already existed or was created successfully. Returns `typed_expr`
    unchanged when the e-graph rejected the let command, so the caller can use the expression directly.
    """
    # TODO: Replace with counter so that it works with hash collisions and is more stable
    # The variable name is derived from the hash of the typed expression.
    var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}")
    # A cache entry for the variable means this expression was processed before.
    if var_decl in self.expr_to_egg_cache:
        return None
    # NOTE(review): `_expr_to_egg` stores its result in `expr_to_egg_cache`, so `var_decl` appears to be
    # cached from here on even if the let command below fails; confirm this is intended.
    var_egg = self._expr_to_egg(var_decl)
    cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)))
    try:
        self.egraph.run_program(cmd)
    # errors when creating let bindings for things like `(vec-empty)`
    except bindings.EggSmolError:
        return typed_expr
    # From now on the expression itself translates to a reference to the let variable.
    self.expr_to_egg_cache[typed_expr.expr] = var_egg
    self.expr_to_egg_cache[var_decl] = var_egg
    return None
|
|
560
|
+
|
|
561
|
+
@overload
|
|
562
|
+
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
563
|
+
|
|
564
|
+
@overload
|
|
565
|
+
def _expr_to_egg(self, expr_decl: UnboundVarDecl | LetRefDecl) -> bindings.Var: ...
|
|
566
|
+
|
|
567
|
+
@overload
|
|
568
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
569
|
+
|
|
570
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
|
|
571
|
+
"""
|
|
572
|
+
Convert an ExprDecl to an egg expression.
|
|
573
|
+
"""
|
|
574
|
+
try:
|
|
575
|
+
return self.expr_to_egg_cache[expr_decl]
|
|
576
|
+
except KeyError:
|
|
577
|
+
pass
|
|
578
|
+
res: bindings._Expr
|
|
579
|
+
match expr_decl:
|
|
580
|
+
case LetRefDecl(name):
|
|
581
|
+
res = bindings.Var(span(), f"{name}")
|
|
582
|
+
case UnboundVarDecl(name, egg_name):
|
|
583
|
+
res = bindings.Var(span(), egg_name or f"_{name}")
|
|
584
|
+
case LitDecl(value):
|
|
585
|
+
l: bindings._Literal
|
|
586
|
+
match value:
|
|
587
|
+
case None:
|
|
588
|
+
l = bindings.Unit()
|
|
589
|
+
case bool(i):
|
|
590
|
+
l = bindings.Bool(i)
|
|
591
|
+
case int(i):
|
|
592
|
+
l = bindings.Int(i)
|
|
593
|
+
case float(f):
|
|
594
|
+
l = bindings.Float(f)
|
|
595
|
+
case str(s):
|
|
596
|
+
l = bindings.String(s)
|
|
597
|
+
case _:
|
|
598
|
+
assert_never(value)
|
|
599
|
+
res = bindings.Lit(span(), l)
|
|
600
|
+
case CallDecl() | GetCostDecl():
|
|
601
|
+
egg_fn, typed_args = self.translate_call(expr_decl)
|
|
602
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args]
|
|
603
|
+
res = bindings.Call(span(), egg_fn, egg_args)
|
|
604
|
+
case PyObjectDecl(value):
|
|
605
|
+
res = bindings.Call(
|
|
606
|
+
span(),
|
|
607
|
+
"py-object",
|
|
608
|
+
[bindings.Lit(span(), bindings.String(standard_b64encode(value).decode("utf-8")))],
|
|
609
|
+
)
|
|
610
|
+
case PartialCallDecl(call_decl):
|
|
611
|
+
egg_fn_call = self._expr_to_egg(call_decl)
|
|
612
|
+
res = bindings.Call(
|
|
613
|
+
span(),
|
|
614
|
+
"unstable-fn",
|
|
615
|
+
[bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
616
|
+
)
|
|
617
|
+
case ValueDecl():
|
|
618
|
+
msg = "Cannot turn a Value into an expression"
|
|
619
|
+
raise ValueError(msg)
|
|
620
|
+
case _:
|
|
621
|
+
assert_never(expr_decl.expr)
|
|
622
|
+
self.expr_to_egg_cache[expr_decl] = res
|
|
623
|
+
return res
|
|
624
|
+
|
|
625
|
+
def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]:
|
|
626
|
+
"""
|
|
627
|
+
Handle get cost and call decl, turn into egg table name and typed expr decls.
|
|
628
|
+
"""
|
|
629
|
+
match expr:
|
|
630
|
+
case CallDecl(ref, args, _):
|
|
631
|
+
egg_fn, reverse_args = self.callable_ref_to_egg(ref)
|
|
632
|
+
args_list = list(args)
|
|
633
|
+
if reverse_args:
|
|
634
|
+
args_list.reverse()
|
|
635
|
+
return egg_fn, args_list
|
|
636
|
+
case GetCostDecl(ref, args):
|
|
637
|
+
cost_table = self.create_cost_table(ref)
|
|
638
|
+
return cost_table, list(args)
|
|
639
|
+
case _:
|
|
640
|
+
assert_never(expr)
|
|
641
|
+
|
|
642
|
+
def exprs_from_egg(
    self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
) -> Iterable[TypedExprDecl]:
    """
    Convert egg terms from `termdag` into typed expressions of type `tp`.

    All terms are converted through one shared `FromEggState`, in the order given.
    """
    converter = FromEggState(self, termdag)
    converted = []
    for term in terms:
        converted.append(converter.from_expr(tp, term))
    return converted
|
|
650
|
+
|
|
651
|
+
def _get_possible_types(self, cls_ident: Ident) -> frozenset[JustTypeRef]:
    """
    Collect every type ref registered in ``type_ref_to_egg_sort`` whose ident equals ``cls_ident``.
    """
    registered_for_class = filter(lambda registered: registered.ident == cls_ident, self.type_ref_to_egg_sort)
    return frozenset(registered_for_class)
|
|
656
|
+
|
|
657
|
+
def _generate_callable_egg_name(self, ref: CallableRef) -> str:
|
|
658
|
+
"""
|
|
659
|
+
Generates a valid egg function name for a callable reference.
|
|
660
|
+
"""
|
|
661
|
+
match ref:
|
|
662
|
+
case FunctionRef(ident):
|
|
663
|
+
return str(ident)
|
|
664
|
+
|
|
665
|
+
case ConstantRef(ident):
|
|
666
|
+
# Prefix to avoid name collisions with local vars
|
|
667
|
+
return f"%{ident}"
|
|
668
|
+
case (
|
|
669
|
+
MethodRef(cls_ident, name)
|
|
670
|
+
| ClassMethodRef(cls_ident, name)
|
|
671
|
+
| ClassVariableRef(cls_ident, name)
|
|
672
|
+
| PropertyRef(cls_ident, name)
|
|
673
|
+
):
|
|
674
|
+
return f"{cls_ident}.{name}"
|
|
675
|
+
case InitRef(cls_ident):
|
|
676
|
+
return f"{cls_ident}.__init__"
|
|
677
|
+
case UnnamedFunctionRef(args, val):
|
|
678
|
+
parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
|
|
679
|
+
str(self.typed_expr_to_egg(val, False))
|
|
680
|
+
]
|
|
681
|
+
return "_".join(parts)
|
|
682
|
+
case _:
|
|
683
|
+
assert_never(ref)
|
|
684
|
+
|
|
685
|
+
def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
    """
    Evaluate a typed expr in the e-graph and return the resulting value.

    The expr is translated with ``typed_expr_to_egg`` (passing ``False`` as the second argument) and handed
    to ``egraph.eval_expr``; the element at index 1 of that result is returned.
    """
    translated = self.typed_expr_to_egg(typed_expr, False)
    evaluation = self.egraph.eval_expr(translated)
    return evaluation[1]
|
|
688
|
+
|
|
689
|
+
def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl:  # noqa: C901, PLR0911, PLR0912
    """
    Turn a runtime value of type ``tp`` back into an expression declaration.

    A value whose type ident is not in the builtin module is wrapped unchanged in a ``ValueDecl``.
    Builtin primitives become literals, a ``PyObject`` is re-pickled with ``cloudpickle``, and builtin
    numeric and container types are rebuilt as constructor calls whose elements are converted recursively.

    Raises ``NotImplementedError`` for a builtin type name that has no conversion below.
    """
    if tp.ident.module != Ident.builtin("").module:
        return ValueDecl(value)

    def builtin_tp(name: str) -> JustTypeRef:
        # Type ref for a builtin sort without type arguments.
        return JustTypeRef(Ident.builtin(name))

    def lit(name: str, raw) -> TypedExprDecl:
        # Literal of the builtin sort ``name``.
        return TypedExprDecl(builtin_tp(name), LitDecl(raw))

    def bigint_call(number) -> CallDecl:
        # ``BigInt.from_string(str(number))``
        return CallDecl(ClassMethodRef(Ident.builtin("BigInt"), "from_string"), (lit("String", str(number)),))

    def collection_call(name: str, items) -> CallDecl:
        # ``name(*items)`` for a single-parameter builtin collection, converting each element recursively.
        (item_tp,) = tp.args
        elems = tuple(TypedExprDecl(item_tp, self.value_to_expr(item_tp, item)) for item in items)
        return CallDecl(InitRef(Ident.builtin(name)), elems, (item_tp,))

    # Should match list in egraph bindings
    kind = tp.ident.name
    if kind == "i64":
        return LitDecl(self.egraph.value_to_i64(value))
    if kind == "f64":
        return LitDecl(self.egraph.value_to_f64(value))
    if kind == "Bool":
        return LitDecl(self.egraph.value_to_bool(value))
    if kind == "String":
        return LitDecl(self.egraph.value_to_string(value))
    if kind == "Unit":
        return LitDecl(None)
    if kind == "PyObject":
        py_value = self.egraph.value_to_pyobject(value)
        return PyObjectDecl(cloudpickle.dumps(py_value))
    if kind == "Rational":
        ratio = self.egraph.value_to_rational(value)
        return CallDecl(
            InitRef(Ident.builtin("Rational")),
            (lit("i64", ratio.numerator), lit("i64", ratio.denominator)),
        )
    if kind == "BigInt":
        return bigint_call(self.egraph.value_to_bigint(value))
    if kind == "BigRat":
        big_ratio = self.egraph.value_to_bigrat(value)
        numerator = TypedExprDecl(builtin_tp("BigInt"), bigint_call(big_ratio.numerator))
        denominator = TypedExprDecl(builtin_tp("BigInt"), bigint_call(big_ratio.denominator))
        return CallDecl(InitRef(Ident.builtin("BigRat")), (numerator, denominator))
    if kind == "Map":
        key_tp, val_tp = tp.args
        # Start from the empty map and wrap it in one ``insert`` call per entry.
        built = CallDecl(ClassMethodRef(Ident.builtin("Map"), "empty"), (), (key_tp, val_tp))
        for key, val in self.egraph.value_to_map(value).items():
            insert_args = (
                TypedExprDecl(tp, built),
                TypedExprDecl(key_tp, self.value_to_expr(key_tp, key)),
                TypedExprDecl(val_tp, self.value_to_expr(val_tp, val)),
            )
            built = CallDecl(MethodRef(Ident.builtin("Map"), "insert"), insert_args)
        return built
    if kind == "Set":
        return collection_call("Set", self.egraph.value_to_set(value))
    if kind == "Vec":
        return collection_call("Vec", self.egraph.value_to_vec(value))
    if kind == "MultiSet":
        return collection_call("MultiSet", self.egraph.value_to_multiset(value))
    if kind == "UnstableFn":
        fn_name, bound_values = self.egraph.value_to_function(value)
        ret_tp, *param_tps = tp.args
        return self._unstable_fn_value_to_expr(fn_name, bound_values, ret_tp, param_tps)
    raise NotImplementedError(f"Value to expr not implemented for type {tp.ident}")
|
|
795
|
+
|
|
796
|
+
def _unstable_fn_value_to_expr(
    self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef]
) -> PartialCallDecl:
    """
    Rebuild a partially applied function value as a ``PartialCallDecl``.

    Similar to ``FromEggState.from_call``, but the already-applied arguments arrive as values and may be
    fewer than the callable's parameters. The first callable ref registered under ``name`` that has a
    ``FunctionSignature`` whose semantic return type has the same ident as ``return_tp`` is used.
    ``_arg_types`` is not read.

    Raises ``ValueError`` when no registered callable ref qualifies.
    """
    for candidate in self.egg_fn_to_callable_refs[name]:
        sig = self.__egg_decls__.get_callable_decl(candidate).signature
        if not isinstance(sig, FunctionSignature) or sig.semantic_return_type.ident != return_tp.ident:
            continue

        solver = TypeConstraintSolver(self.__egg_decls__)
        inferred_tps, bound_tp_params = solver.infer_arg_types(
            sig.arg_types, sig.semantic_return_type, sig.var_arg_type, return_tp, None
        )

        # Non-strict zip: only as many parameters as there are applied values get converted.
        applied = tuple(
            TypedExprDecl(arg_tp, self.value_to_expr(arg_tp, raw))
            for arg_tp, raw in zip(inferred_tps, partial_args, strict=False)
        )

        # Bound type params were only needed for type resolution; they are stored solely for
        # class methods and initializers.
        keeps_params = isinstance(candidate, ClassMethodRef | InitRef)
        return PartialCallDecl(CallDecl(candidate, applied, bound_tp_params if keeps_params else ()))
    raise ValueError(f"Function '{name}' not found")
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
# Matches any single character that is neither a word character nor one of: - + * / ? ! = < > & | ^ %
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")


def _sanitize_egg_ident(input_string: str) -> str:
    """
    Return ``input_string`` with every character matched by ``_EGGLOG_INVALID_IDENT`` replaced by ``_``.
    """
    return re.sub(_EGGLOG_INVALID_IDENT, "_", input_string)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
    """
    Returns all expressions that have multiple parents (a list but semantically just an ordered set).

    Walks the expression DAG rooted at ``typed_expr`` through the args of ``CallDecl`` and
    ``PartialCallDecl`` nodes. An expression is reported when it is reached again after it has already been
    visited, i.e. when more than one distinct parent refers to it. Each such expression appears once, in the
    order it was first found to be shared.

    The worklist is a list rather than a set: with a set, a child queued by two parents before it was popped
    was stored only once and could be missed depending on the arbitrary order of ``set.pop()``.
    """
    pending: list[TypedExprDecl] = [typed_expr]
    visited: set[TypedExprDecl] = set()
    # dict used as an insertion-ordered set so the result has no duplicates
    shared: dict[TypedExprDecl, None] = {}
    while pending:
        current = pending.pop()
        if current in visited:
            shared[current] = None
            continue
        visited.add(current)
        expr = current.expr
        if isinstance(expr, CallDecl):
            children = expr.args
        elif isinstance(expr, PartialCallDecl):
            children = expr.call.args
        else:
            continue
        # The same child passed several times to one call still has a single parent, so repeats
        # within one argument list are collapsed before queueing.
        pending.extend(dict.fromkeys(children))
    return list(shared)
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
def _generate_type_egg_name(ref: JustTypeRef) -> str:
    """
    Linearize a type reference into the name of its egg sort.

    A type without arguments is named by its ident alone; otherwise the ident is followed by the
    comma-separated names of its arguments in square brackets, each built recursively.
    """
    base = ref.ident
    if ref.args:
        inner = ",".join(_generate_type_egg_name(arg) for arg in ref.args)
        return f"{base}[{inner}]"
    return str(base)
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
@dataclass
|
|
871
|
+
class FromEggState:
|
|
872
|
+
"""
|
|
873
|
+
Dataclass containing state used when converting from an egg term to a typed expr.
|
|
874
|
+
"""
|
|
875
|
+
|
|
876
|
+
state: EGraphState
|
|
877
|
+
termdag: bindings.TermDag
|
|
878
|
+
# Cache of termdag ID to TypedExprDecl
|
|
879
|
+
cache: dict[int, TypedExprDecl] = field(default_factory=dict)
|
|
880
|
+
|
|
881
|
+
@property
|
|
882
|
+
def decls(self) -> Declarations:
|
|
883
|
+
return self.state.__egg_decls__
|
|
884
|
+
|
|
885
|
+
def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
|
|
886
|
+
"""
|
|
887
|
+
Convert an egg term to a typed expr.
|
|
888
|
+
"""
|
|
889
|
+
expr_decl: ExprDecl
|
|
890
|
+
if isinstance(term, bindings.TermVar):
|
|
891
|
+
expr_decl = LetRefDecl(term.name)
|
|
892
|
+
elif isinstance(term, bindings.TermLit):
|
|
893
|
+
value = term.value
|
|
894
|
+
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
895
|
+
elif isinstance(term, bindings.TermApp):
|
|
896
|
+
if term.name == "py-object":
|
|
897
|
+
(str_term,) = term.args
|
|
898
|
+
call = self.termdag.get(str_term)
|
|
899
|
+
assert isinstance(call, bindings.TermLit)
|
|
900
|
+
assert isinstance(call.value, bindings.String)
|
|
901
|
+
expr_decl = PyObjectDecl(standard_b64decode(call.value.value))
|
|
902
|
+
elif term.name == "unstable-fn":
|
|
903
|
+
# Get function name
|
|
904
|
+
fn_term, *arg_terms = term.args
|
|
905
|
+
fn_value = self.resolve_term(fn_term, JustTypeRef(Ident.builtin("String")))
|
|
906
|
+
assert isinstance(fn_value.expr, LitDecl)
|
|
907
|
+
fn_name = fn_value.expr.value
|
|
908
|
+
assert isinstance(fn_name, str)
|
|
909
|
+
|
|
910
|
+
# Resolve what types the partially applied args are
|
|
911
|
+
assert tp.ident == Ident.builtin("UnstableFn")
|
|
912
|
+
call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
|
|
913
|
+
expr_decl = PartialCallDecl(call_decl)
|
|
914
|
+
else:
|
|
915
|
+
expr_decl = self.from_call(tp, term)
|
|
916
|
+
else:
|
|
917
|
+
assert_never(term)
|
|
918
|
+
return TypedExprDecl(tp, expr_decl)
|
|
919
|
+
|
|
920
|
+
def from_call(
|
|
921
|
+
self,
|
|
922
|
+
tp: JustTypeRef,
|
|
923
|
+
term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
|
|
924
|
+
) -> CallDecl:
|
|
925
|
+
"""
|
|
926
|
+
Convert a call to a CallDecl.
|
|
927
|
+
|
|
928
|
+
There could be Python call refs which match the call, so we need to find the correct one.
|
|
929
|
+
|
|
930
|
+
The additional_arg_tps are known types for arguments that come after the term args, used to infer types
|
|
931
|
+
for partially applied functions, where we know the types of the later args, but not of the earlier ones where
|
|
932
|
+
we have values for.
|
|
933
|
+
"""
|
|
934
|
+
# Find the first callable ref that matches the call
|
|
935
|
+
for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
|
|
936
|
+
# If this is a classmethod, we might need the type params that were bound for this type
|
|
937
|
+
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
938
|
+
possible_types: Iterable[JustTypeRef | None]
|
|
939
|
+
signature = self.decls.get_callable_decl(callable_ref).signature
|
|
940
|
+
assert isinstance(signature, FunctionSignature)
|
|
941
|
+
if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
|
|
942
|
+
# Need OR in case we have class method whose class whas never added as a sort, which would happen
|
|
943
|
+
# if the class method didn't return that type and no other function did. In this case, we don't need
|
|
944
|
+
# to care about the type vars and we we don't need to bind any possible type.
|
|
945
|
+
possible_types = self.state._get_possible_types(callable_ref.ident) or [None]
|
|
946
|
+
cls_name = callable_ref.ident
|
|
947
|
+
else:
|
|
948
|
+
possible_types = [None]
|
|
949
|
+
cls_name = None
|
|
950
|
+
for possible_type in possible_types:
|
|
951
|
+
tcs = TypeConstraintSolver(self.decls)
|
|
952
|
+
if possible_type and possible_type.args:
|
|
953
|
+
tcs.bind_class(possible_type)
|
|
954
|
+
try:
|
|
955
|
+
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
956
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
|
|
957
|
+
)
|
|
958
|
+
except TypeConstraintError:
|
|
959
|
+
continue
|
|
960
|
+
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
|
|
961
|
+
|
|
962
|
+
return CallDecl(
|
|
963
|
+
callable_ref,
|
|
964
|
+
args,
|
|
965
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
966
|
+
# but dont need to store them
|
|
967
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
|
|
968
|
+
)
|
|
969
|
+
raise ValueError(
|
|
970
|
+
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
|
|
971
|
+
)
|
|
972
|
+
|
|
973
|
+
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
974
|
+
try:
|
|
975
|
+
return self.cache[term_id]
|
|
976
|
+
except KeyError:
|
|
977
|
+
res = self.cache[term_id] = self.from_expr(tp, self.termdag.get(term_id))
|
|
978
|
+
return res
|