egglog 11.2.0__cp310-cp310-win_amd64.whl → 11.4.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +62 -1
- egglog/builtins.py +10 -0
- egglog/declarations.py +39 -4
- egglog/deconstruct.py +9 -7
- egglog/egraph.py +360 -26
- egglog/egraph_state.py +283 -12
- egglog/examples/jointree.py +0 -3
- egglog/exp/array_api_jit.py +2 -2
- egglog/pretty.py +38 -8
- egglog/runtime.py +22 -7
- egglog/type_constraint_solver.py +2 -2
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/METADATA +19 -1
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/RECORD +16 -16
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/WHEEL +0 -0
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/licenses/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -8,6 +8,7 @@ import re
|
|
|
8
8
|
from collections import defaultdict
|
|
9
9
|
from dataclasses import dataclass, field, replace
|
|
10
10
|
from typing import TYPE_CHECKING, Literal, overload
|
|
11
|
+
from uuid import UUID
|
|
11
12
|
|
|
12
13
|
from typing_extensions import assert_never
|
|
13
14
|
|
|
@@ -67,6 +68,7 @@ class EGraphState:
|
|
|
67
68
|
|
|
68
69
|
# Bidirectional mapping between egg sort names and python type references.
|
|
69
70
|
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
71
|
+
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
|
|
70
72
|
|
|
71
73
|
# Cache of egg expressions for converting to egg
|
|
72
74
|
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
@@ -85,22 +87,145 @@ class EGraphState:
|
|
|
85
87
|
egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
|
|
86
88
|
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
|
|
87
89
|
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
|
|
90
|
+
egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(),
|
|
88
91
|
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
|
|
89
92
|
cost_callables=self.cost_callables.copy(),
|
|
90
93
|
)
|
|
91
94
|
|
|
92
|
-
def
|
|
95
|
+
def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
|
|
96
|
+
"""
|
|
97
|
+
Turn a run schedule into an egg command.
|
|
98
|
+
|
|
99
|
+
If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise
|
|
100
|
+
will be a normal run command.
|
|
101
|
+
"""
|
|
102
|
+
processed_schedule = self._process_schedule(schedule)
|
|
103
|
+
if processed_schedule is None:
|
|
104
|
+
return bindings.RunSchedule(self._schedule_to_egg(schedule))
|
|
105
|
+
top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
|
|
106
|
+
if len(top_level_schedules) == 1:
|
|
107
|
+
schedule_expr = top_level_schedules[0]
|
|
108
|
+
else:
|
|
109
|
+
schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
|
|
110
|
+
return bindings.UserDefined(span(), "run-schedule", [schedule_expr])
|
|
111
|
+
|
|
112
|
+
def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
|
|
113
|
+
"""
|
|
114
|
+
Processes a schedule to determine if it contains any custom schedulers.
|
|
115
|
+
|
|
116
|
+
If it does, it returns a new schedule with all the required let bindings added to the other scope.
|
|
117
|
+
If not, returns none.
|
|
118
|
+
|
|
119
|
+
Also processes all rulesets in the schedule to make sure they are registered.
|
|
120
|
+
"""
|
|
121
|
+
bound_schedulers: list[UUID] = []
|
|
122
|
+
unbound_schedulers: list[BackOffDecl] = []
|
|
123
|
+
|
|
124
|
+
def helper(s: ScheduleDecl) -> None:
|
|
125
|
+
match s:
|
|
126
|
+
case LetSchedulerDecl(scheduler, inner):
|
|
127
|
+
bound_schedulers.append(scheduler.id)
|
|
128
|
+
return helper(inner)
|
|
129
|
+
case RunDecl(ruleset_name, _, scheduler):
|
|
130
|
+
self.ruleset_to_egg(ruleset_name)
|
|
131
|
+
if scheduler and scheduler.id not in bound_schedulers:
|
|
132
|
+
unbound_schedulers.append(scheduler)
|
|
133
|
+
case SaturateDecl(inner) | RepeatDecl(inner, _):
|
|
134
|
+
return helper(inner)
|
|
135
|
+
case SequenceDecl(schedules):
|
|
136
|
+
for sc in schedules:
|
|
137
|
+
helper(sc)
|
|
138
|
+
case _:
|
|
139
|
+
assert_never(s)
|
|
140
|
+
return None
|
|
141
|
+
|
|
142
|
+
helper(schedule)
|
|
143
|
+
if not bound_schedulers and not unbound_schedulers:
|
|
144
|
+
return None
|
|
145
|
+
for scheduler in unbound_schedulers:
|
|
146
|
+
schedule = LetSchedulerDecl(scheduler, schedule)
|
|
147
|
+
return schedule
|
|
148
|
+
|
|
149
|
+
def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
150
|
+
msg = "Should never reach this, let schedulers should be handled by custom scheduler"
|
|
93
151
|
match schedule:
|
|
94
152
|
case SaturateDecl(schedule):
|
|
95
|
-
return bindings.Saturate(span(), self.
|
|
153
|
+
return bindings.Saturate(span(), self._schedule_to_egg(schedule))
|
|
96
154
|
case RepeatDecl(schedule, times):
|
|
97
|
-
return bindings.Repeat(span(), times, self.
|
|
155
|
+
return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
|
|
98
156
|
case SequenceDecl(schedules):
|
|
99
|
-
return bindings.Sequence(span(), [self.
|
|
100
|
-
case RunDecl(ruleset_name, until):
|
|
101
|
-
|
|
157
|
+
return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
|
|
158
|
+
case RunDecl(ruleset_name, until, scheduler):
|
|
159
|
+
if scheduler is not None:
|
|
160
|
+
raise ValueError(msg)
|
|
102
161
|
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
|
|
103
162
|
return bindings.Run(span(), config)
|
|
163
|
+
case LetSchedulerDecl():
|
|
164
|
+
raise ValueError(msg)
|
|
165
|
+
case _:
|
|
166
|
+
assert_never(schedule)
|
|
167
|
+
|
|
168
|
+
def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
|
|
169
|
+
self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
|
|
170
|
+
) -> list[bindings._Expr]:
|
|
171
|
+
"""
|
|
172
|
+
Turns a scheduler into an egg expression, to be used with a custom extract command.
|
|
173
|
+
|
|
174
|
+
The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
|
|
175
|
+
"""
|
|
176
|
+
match schedule:
|
|
177
|
+
case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
|
|
178
|
+
name = f"_scheduler_{len(bound_schedulers)}"
|
|
179
|
+
bound_schedulers.append(id)
|
|
180
|
+
args: list[bindings._Expr] = []
|
|
181
|
+
if match_limit is not None:
|
|
182
|
+
args.append(bindings.Var(span(), ":match-limit"))
|
|
183
|
+
args.append(bindings.Lit(span(), bindings.Int(match_limit)))
|
|
184
|
+
if ban_length is not None:
|
|
185
|
+
args.append(bindings.Var(span(), ":ban-length"))
|
|
186
|
+
args.append(bindings.Lit(span(), bindings.Int(ban_length)))
|
|
187
|
+
back_off_decl = bindings.Call(span(), "back-off", args)
|
|
188
|
+
let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
|
|
189
|
+
return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
|
|
190
|
+
case RunDecl(ruleset_name, until, scheduler):
|
|
191
|
+
args = [bindings.Var(span(), ruleset_name)]
|
|
192
|
+
if scheduler:
|
|
193
|
+
name = "run-with"
|
|
194
|
+
scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
|
|
195
|
+
args.insert(0, bindings.Var(span(), scheduler_name))
|
|
196
|
+
else:
|
|
197
|
+
name = "run"
|
|
198
|
+
if until:
|
|
199
|
+
if len(until) > 1:
|
|
200
|
+
msg = "Can only have one until fact with custom scheduler"
|
|
201
|
+
raise ValueError(msg)
|
|
202
|
+
args.append(bindings.Var(span(), ":until"))
|
|
203
|
+
fact_egg = self.fact_to_egg(until[0])
|
|
204
|
+
if isinstance(fact_egg, bindings.Eq):
|
|
205
|
+
msg = "Cannot use equality fact with custom scheduler"
|
|
206
|
+
raise ValueError(msg)
|
|
207
|
+
args.append(fact_egg.expr)
|
|
208
|
+
return [bindings.Call(span(), name, args)]
|
|
209
|
+
case SaturateDecl(inner):
|
|
210
|
+
return [
|
|
211
|
+
bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
|
|
212
|
+
]
|
|
213
|
+
case RepeatDecl(inner, times):
|
|
214
|
+
return [
|
|
215
|
+
bindings.Call(
|
|
216
|
+
span(),
|
|
217
|
+
"repeat",
|
|
218
|
+
[
|
|
219
|
+
bindings.Lit(span(), bindings.Int(times)),
|
|
220
|
+
*self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
|
|
221
|
+
],
|
|
222
|
+
)
|
|
223
|
+
]
|
|
224
|
+
case SequenceDecl(schedules):
|
|
225
|
+
res = []
|
|
226
|
+
for s in schedules:
|
|
227
|
+
res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
|
|
228
|
+
return res
|
|
104
229
|
case _:
|
|
105
230
|
assert_never(schedule)
|
|
106
231
|
|
|
@@ -229,6 +354,7 @@ class EGraphState:
|
|
|
229
354
|
Creates the egg cost table if needed and gets the name of the table.
|
|
230
355
|
"""
|
|
231
356
|
name = self.cost_table_name(ref)
|
|
357
|
+
print(name, self.cost_callables)
|
|
232
358
|
if ref not in self.cost_callables:
|
|
233
359
|
self.cost_callables.add(ref)
|
|
234
360
|
signature = self.__egg_decls__.get_callable_decl(ref).signature
|
|
@@ -332,10 +458,14 @@ class EGraphState:
|
|
|
332
458
|
pass
|
|
333
459
|
decl = self.__egg_decls__._classes[ref.name]
|
|
334
460
|
self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
|
|
461
|
+
self.egg_sort_to_type_ref[egg_name] = ref
|
|
335
462
|
if not decl.builtin or ref.args:
|
|
336
463
|
if ref.args:
|
|
337
464
|
if ref.name == "UnstableFn":
|
|
338
465
|
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
466
|
+
if len(ref.args) < 2:
|
|
467
|
+
msg = "Zero argument higher order functions not supported"
|
|
468
|
+
raise NotImplementedError(msg)
|
|
339
469
|
type_args: list[bindings._Expr] = [
|
|
340
470
|
bindings.Call(
|
|
341
471
|
span(),
|
|
@@ -466,11 +596,9 @@ class EGraphState:
|
|
|
466
596
|
case _:
|
|
467
597
|
assert_never(value)
|
|
468
598
|
res = bindings.Lit(span(), l)
|
|
469
|
-
case CallDecl(
|
|
470
|
-
egg_fn,
|
|
471
|
-
egg_args = [self.typed_expr_to_egg(a, False) for a in
|
|
472
|
-
if reverse_args:
|
|
473
|
-
egg_args.reverse()
|
|
599
|
+
case CallDecl() | GetCostDecl():
|
|
600
|
+
egg_fn, typed_args = self.translate_call(expr_decl)
|
|
601
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args]
|
|
474
602
|
res = bindings.Call(span(), egg_fn, egg_args)
|
|
475
603
|
case PyObjectDecl(value):
|
|
476
604
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
@@ -481,11 +609,31 @@ class EGraphState:
|
|
|
481
609
|
"unstable-fn",
|
|
482
610
|
[bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
483
611
|
)
|
|
612
|
+
case ValueDecl():
|
|
613
|
+
msg = "Cannot turn a Value into an expression"
|
|
614
|
+
raise ValueError(msg)
|
|
484
615
|
case _:
|
|
485
616
|
assert_never(expr_decl.expr)
|
|
486
617
|
self.expr_to_egg_cache[expr_decl] = res
|
|
487
618
|
return res
|
|
488
619
|
|
|
620
|
+
def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]:
|
|
621
|
+
"""
|
|
622
|
+
Handle get cost and call decl, turn into egg table name and typed expr decls.
|
|
623
|
+
"""
|
|
624
|
+
match expr:
|
|
625
|
+
case CallDecl(ref, args, _):
|
|
626
|
+
egg_fn, reverse_args = self.callable_ref_to_egg(ref)
|
|
627
|
+
args_list = list(args)
|
|
628
|
+
if reverse_args:
|
|
629
|
+
args_list.reverse()
|
|
630
|
+
return egg_fn, args_list
|
|
631
|
+
case GetCostDecl(ref, args):
|
|
632
|
+
cost_table = self.create_cost_table(ref)
|
|
633
|
+
return cost_table, list(args)
|
|
634
|
+
case _:
|
|
635
|
+
assert_never(expr)
|
|
636
|
+
|
|
489
637
|
def exprs_from_egg(
|
|
490
638
|
self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
|
|
491
639
|
) -> Iterable[TypedExprDecl]:
|
|
@@ -529,6 +677,129 @@ class EGraphState:
|
|
|
529
677
|
case _:
|
|
530
678
|
assert_never(ref)
|
|
531
679
|
|
|
680
|
+
def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
|
|
681
|
+
egg_expr = self.typed_expr_to_egg(typed_expr, False)
|
|
682
|
+
return self.egraph.eval_expr(egg_expr)[1]
|
|
683
|
+
|
|
684
|
+
def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912
|
|
685
|
+
match tp.name:
|
|
686
|
+
# Should match list in egraph bindings
|
|
687
|
+
case "i64":
|
|
688
|
+
return LitDecl(self.egraph.value_to_i64(value))
|
|
689
|
+
case "f64":
|
|
690
|
+
return LitDecl(self.egraph.value_to_f64(value))
|
|
691
|
+
case "Bool":
|
|
692
|
+
return LitDecl(self.egraph.value_to_bool(value))
|
|
693
|
+
case "String":
|
|
694
|
+
return LitDecl(self.egraph.value_to_string(value))
|
|
695
|
+
case "Unit":
|
|
696
|
+
return LitDecl(None)
|
|
697
|
+
case "PyObject":
|
|
698
|
+
return PyObjectDecl(self.egraph.value_to_pyobject(GLOBAL_PY_OBJECT_SORT, value))
|
|
699
|
+
case "Rational":
|
|
700
|
+
fraction = self.egraph.value_to_rational(value)
|
|
701
|
+
return CallDecl(
|
|
702
|
+
InitRef("Rational"),
|
|
703
|
+
(
|
|
704
|
+
TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.numerator)),
|
|
705
|
+
TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.denominator)),
|
|
706
|
+
),
|
|
707
|
+
)
|
|
708
|
+
case "BigInt":
|
|
709
|
+
i = self.egraph.value_to_bigint(value)
|
|
710
|
+
return CallDecl(
|
|
711
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
712
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(i))),),
|
|
713
|
+
)
|
|
714
|
+
case "BigRat":
|
|
715
|
+
fraction = self.egraph.value_to_bigrat(value)
|
|
716
|
+
return CallDecl(
|
|
717
|
+
InitRef("BigRat"),
|
|
718
|
+
(
|
|
719
|
+
TypedExprDecl(
|
|
720
|
+
JustTypeRef("BigInt"),
|
|
721
|
+
CallDecl(
|
|
722
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
723
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.numerator))),),
|
|
724
|
+
),
|
|
725
|
+
),
|
|
726
|
+
TypedExprDecl(
|
|
727
|
+
JustTypeRef("BigInt"),
|
|
728
|
+
CallDecl(
|
|
729
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
730
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.denominator))),),
|
|
731
|
+
),
|
|
732
|
+
),
|
|
733
|
+
),
|
|
734
|
+
)
|
|
735
|
+
case "Map":
|
|
736
|
+
k_tp, v_tp = tp.args
|
|
737
|
+
expr = CallDecl(ClassMethodRef("Map", "empty"), (), (k_tp, v_tp))
|
|
738
|
+
for k, v in self.egraph.value_to_map(value).items():
|
|
739
|
+
expr = CallDecl(
|
|
740
|
+
MethodRef("Map", "insert"),
|
|
741
|
+
(
|
|
742
|
+
TypedExprDecl(tp, expr),
|
|
743
|
+
TypedExprDecl(k_tp, self.value_to_expr(k_tp, k)),
|
|
744
|
+
TypedExprDecl(v_tp, self.value_to_expr(v_tp, v)),
|
|
745
|
+
),
|
|
746
|
+
)
|
|
747
|
+
return expr
|
|
748
|
+
case "Set":
|
|
749
|
+
xs_ = self.egraph.value_to_set(value)
|
|
750
|
+
(v_tp,) = tp.args
|
|
751
|
+
return CallDecl(
|
|
752
|
+
InitRef("Set"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), (v_tp,)
|
|
753
|
+
)
|
|
754
|
+
case "Vec":
|
|
755
|
+
xs = self.egraph.value_to_vec(value)
|
|
756
|
+
(v_tp,) = tp.args
|
|
757
|
+
return CallDecl(
|
|
758
|
+
InitRef("Vec"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
|
|
759
|
+
)
|
|
760
|
+
case "MultiSet":
|
|
761
|
+
xs = self.egraph.value_to_multiset(value)
|
|
762
|
+
(v_tp,) = tp.args
|
|
763
|
+
return CallDecl(
|
|
764
|
+
InitRef("MultiSet"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
|
|
765
|
+
)
|
|
766
|
+
case "UnstableFn":
|
|
767
|
+
_names, _args = self.egraph.value_to_function(value)
|
|
768
|
+
return_tp, *arg_types = tp.args
|
|
769
|
+
return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types)
|
|
770
|
+
return ValueDecl(value)
|
|
771
|
+
|
|
772
|
+
def _unstable_fn_value_to_expr(
|
|
773
|
+
self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef]
|
|
774
|
+
) -> PartialCallDecl:
|
|
775
|
+
# Similar to FromEggState::from_call but accepts partial list of args and returns in values
|
|
776
|
+
# Find first callable ref whose return type matches and fill in arg types.
|
|
777
|
+
for callable_ref in self.egg_fn_to_callable_refs[name]:
|
|
778
|
+
signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
|
|
779
|
+
if not isinstance(signature, FunctionSignature):
|
|
780
|
+
continue
|
|
781
|
+
if signature.semantic_return_type.name != return_tp.name:
|
|
782
|
+
continue
|
|
783
|
+
tcs = TypeConstraintSolver(self.__egg_decls__)
|
|
784
|
+
|
|
785
|
+
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
786
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
args = tuple(
|
|
790
|
+
TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
call_decl = CallDecl(
|
|
794
|
+
callable_ref,
|
|
795
|
+
args,
|
|
796
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
797
|
+
# but dont need to store them
|
|
798
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
|
|
799
|
+
)
|
|
800
|
+
return PartialCallDecl(call_decl)
|
|
801
|
+
raise ValueError(f"Function '{name}' not found")
|
|
802
|
+
|
|
532
803
|
|
|
533
804
|
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
|
|
534
805
|
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
|
|
@@ -666,7 +937,7 @@ class FromEggState:
|
|
|
666
937
|
args,
|
|
667
938
|
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
668
939
|
# but dont need to store them
|
|
669
|
-
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else
|
|
940
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
|
|
670
941
|
)
|
|
671
942
|
raise ValueError(
|
|
672
943
|
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
|
egglog/examples/jointree.py
CHANGED
egglog/exp/array_api_jit.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import TypeVar, cast
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
|
-
from egglog import EGraph
|
|
7
|
+
from egglog import EGraph, greedy_dag_cost_model
|
|
8
8
|
from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
|
|
9
9
|
from egglog.exp.array_api_numba import array_api_numba_schedule
|
|
10
10
|
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
|
|
@@ -41,7 +41,7 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
|
|
|
41
41
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
42
42
|
egraph.register(res)
|
|
43
43
|
egraph.run(array_api_numba_schedule)
|
|
44
|
-
res_optimized = egraph.extract(res)
|
|
44
|
+
res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())
|
|
45
45
|
|
|
46
46
|
return (
|
|
47
47
|
egraph,
|
egglog/pretty.py
CHANGED
|
@@ -67,7 +67,9 @@ UNARY_METHODS = {
|
|
|
67
67
|
"__invert__": "~",
|
|
68
68
|
}
|
|
69
69
|
|
|
70
|
-
AllDecls: TypeAlias =
|
|
70
|
+
AllDecls: TypeAlias = (
|
|
71
|
+
RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl
|
|
72
|
+
)
|
|
71
73
|
|
|
72
74
|
|
|
73
75
|
def pretty_decl(
|
|
@@ -181,17 +183,19 @@ class TraverseContext:
|
|
|
181
183
|
if isinstance(de, DefaultRewriteDecl):
|
|
182
184
|
continue
|
|
183
185
|
self(de)
|
|
184
|
-
case CallDecl(ref, exprs, _):
|
|
186
|
+
case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs):
|
|
185
187
|
match ref:
|
|
186
188
|
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
187
189
|
self(res.expr)
|
|
188
190
|
case _:
|
|
189
191
|
for e in exprs:
|
|
190
192
|
self(e.expr)
|
|
191
|
-
case RunDecl(_, until):
|
|
193
|
+
case RunDecl(_, until, scheduler):
|
|
192
194
|
if until:
|
|
193
195
|
for f in until:
|
|
194
196
|
self(f)
|
|
197
|
+
if scheduler:
|
|
198
|
+
self(scheduler)
|
|
195
199
|
case PartialCallDecl(c):
|
|
196
200
|
self(c)
|
|
197
201
|
case CombinedRulesetDecl(_):
|
|
@@ -201,6 +205,13 @@ class TraverseContext:
|
|
|
201
205
|
case SetCostDecl(_, e, c):
|
|
202
206
|
self(e)
|
|
203
207
|
self(c)
|
|
208
|
+
case BackOffDecl() | ValueDecl():
|
|
209
|
+
pass
|
|
210
|
+
case LetSchedulerDecl(scheduler, schedule):
|
|
211
|
+
self(scheduler)
|
|
212
|
+
self(schedule)
|
|
213
|
+
case GetCostDecl(ref, args):
|
|
214
|
+
self(CallDecl(ref, args))
|
|
204
215
|
case _:
|
|
205
216
|
assert_never(decl)
|
|
206
217
|
|
|
@@ -238,7 +249,11 @@ class PrettyContext:
|
|
|
238
249
|
# it would take up is > than some constant (~ line length).
|
|
239
250
|
line_diff: int = len(expr) - LINE_DIFFERENCE
|
|
240
251
|
n_parents = self.parents[decl]
|
|
241
|
-
if n_parents > 1 and
|
|
252
|
+
if n_parents > 1 and (
|
|
253
|
+
n_parents * line_diff > MAX_LINE_LENGTH
|
|
254
|
+
# Schedulers with multiple parents need to be the same object, b/c are created with hidden UUIDs
|
|
255
|
+
or tp_name == "scheduler"
|
|
256
|
+
):
|
|
242
257
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
|
|
243
258
|
return expr_name
|
|
244
259
|
return expr
|
|
@@ -318,16 +333,31 @@ class PrettyContext:
|
|
|
318
333
|
return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
|
|
319
334
|
args = ", ".join(map(self, schedules))
|
|
320
335
|
return f"seq({args})", "schedule"
|
|
321
|
-
case
|
|
336
|
+
case LetSchedulerDecl(scheduler, schedule):
|
|
337
|
+
return f"{self(scheduler, parens=True)}.scope({self(schedule, parens=True)})", "schedule"
|
|
338
|
+
case RunDecl(ruleset_name, until, scheduler):
|
|
322
339
|
ruleset = self.decls._rulesets[ruleset_name]
|
|
323
340
|
ruleset_str = self(ruleset, ruleset_name=ruleset_name)
|
|
324
|
-
if not until:
|
|
341
|
+
if not until and not scheduler:
|
|
325
342
|
return ruleset_str, "schedule"
|
|
326
|
-
|
|
327
|
-
|
|
343
|
+
arg_lst = list(map(self, until or []))
|
|
344
|
+
if scheduler:
|
|
345
|
+
arg_lst.append(f"scheduler={self(scheduler)}")
|
|
346
|
+
return f"run({ruleset_str}, {', '.join(arg_lst)})", "schedule"
|
|
328
347
|
case DefaultRewriteDecl():
|
|
329
348
|
msg = "default rewrites should not be pretty printed"
|
|
330
349
|
raise TypeError(msg)
|
|
350
|
+
case BackOffDecl(_, match_limit, ban_length):
|
|
351
|
+
list_args: list[str] = []
|
|
352
|
+
if match_limit is not None:
|
|
353
|
+
list_args.append(f"match_limit={match_limit}")
|
|
354
|
+
if ban_length is not None:
|
|
355
|
+
list_args.append(f"ban_length={ban_length}")
|
|
356
|
+
return f"back_off({', '.join(list_args)})", "scheduler"
|
|
357
|
+
case ValueDecl(value):
|
|
358
|
+
return str(value), "value"
|
|
359
|
+
case GetCostDecl(ref, args):
|
|
360
|
+
return f"get_cost({self(CallDecl(ref, args))})", "get_cost"
|
|
331
361
|
assert_never(decl)
|
|
332
362
|
|
|
333
363
|
def _call(
|
egglog/runtime.py
CHANGED
|
@@ -457,7 +457,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
457
457
|
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
|
|
458
458
|
return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
|
|
459
459
|
bound_params = (
|
|
460
|
-
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else
|
|
460
|
+
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
|
|
461
461
|
)
|
|
462
462
|
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
463
463
|
if function_value:
|
|
@@ -584,11 +584,17 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
584
584
|
if (method := _get_expr_method(self, "__eq__")) is not None:
|
|
585
585
|
return method(other)
|
|
586
586
|
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
587
|
+
if not (isinstance(self, RuntimeExpr) and isinstance(other, RuntimeExpr)):
|
|
588
|
+
return NotImplemented
|
|
589
|
+
if self.__egg_typed_expr__.tp != other.__egg_typed_expr__.tp:
|
|
590
|
+
return NotImplemented
|
|
590
591
|
|
|
591
|
-
|
|
592
|
+
from .egraph import Fact # noqa: PLC0415
|
|
593
|
+
|
|
594
|
+
return Fact(
|
|
595
|
+
Declarations.create(self, other),
|
|
596
|
+
EqDecl(self.__egg_typed_expr__.tp, self.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr),
|
|
597
|
+
)
|
|
592
598
|
|
|
593
599
|
def __ne__(self, other: object) -> object: # type: ignore[override]
|
|
594
600
|
if (method := _get_expr_method(self, "__ne__")) is not None:
|
|
@@ -635,14 +641,22 @@ for name in TYPE_DEFINED_METHODS:
|
|
|
635
641
|
|
|
636
642
|
|
|
637
643
|
for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
|
|
644
|
+
method_name = f"__r{name[2:]}" if r_method else name
|
|
638
645
|
|
|
639
|
-
def _numeric_binary_method(
|
|
646
|
+
def _numeric_binary_method(
|
|
647
|
+
self: object, other: object, name: str = name, r_method: bool = r_method, method_name: str = method_name
|
|
648
|
+
) -> object:
|
|
640
649
|
"""
|
|
641
650
|
Implements numeric binary operations.
|
|
642
651
|
|
|
643
652
|
Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
|
|
644
653
|
the LHS or the RHS as exactly the right type and then upcasting the other to that type.
|
|
645
654
|
"""
|
|
655
|
+
# First check if we have a preserved method for this:
|
|
656
|
+
if isinstance(self, RuntimeExpr) and (
|
|
657
|
+
(preserved_method := self.__egg_class_decl__.preserved_methods.get(method_name)) is not None
|
|
658
|
+
):
|
|
659
|
+
return preserved_method.__get__(self)(other)
|
|
646
660
|
# 1. switch if reversed method
|
|
647
661
|
if r_method:
|
|
648
662
|
self, other = other, self
|
|
@@ -668,7 +682,6 @@ for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
|
|
|
668
682
|
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
|
|
669
683
|
return fn(other)
|
|
670
684
|
|
|
671
|
-
method_name = f"__r{name[2:]}" if r_method else name
|
|
672
685
|
setattr(RuntimeExpr, method_name, _numeric_binary_method)
|
|
673
686
|
|
|
674
687
|
|
|
@@ -688,6 +701,8 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
688
701
|
):
|
|
689
702
|
raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
|
|
690
703
|
return expr.callable, decl_thunk()
|
|
704
|
+
case types.MethodWrapperType() if isinstance((slf := callable.__self__), RuntimeClass):
|
|
705
|
+
return MethodRef(slf.__egg_tp__.name, callable.__name__), slf.__egg_decls__
|
|
691
706
|
case _:
|
|
692
707
|
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
|
|
693
708
|
|
egglog/type_constraint_solver.py
CHANGED
|
@@ -54,7 +54,7 @@ class TypeConstraintSolver:
|
|
|
54
54
|
fn_var_args: TypeOrVarRef | None,
|
|
55
55
|
return_: JustTypeRef,
|
|
56
56
|
cls_name: str | None,
|
|
57
|
-
) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]
|
|
57
|
+
) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]:
|
|
58
58
|
"""
|
|
59
59
|
Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
|
|
60
60
|
|
|
@@ -75,7 +75,7 @@ class TypeConstraintSolver:
|
|
|
75
75
|
)
|
|
76
76
|
)
|
|
77
77
|
if cls_name
|
|
78
|
-
else
|
|
78
|
+
else ()
|
|
79
79
|
)
|
|
80
80
|
return arg_types, bound_typevars
|
|
81
81
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: egglog
|
|
3
|
-
Version: 11.
|
|
3
|
+
Version: 11.4.0
|
|
4
4
|
Classifier: Environment :: MacOS X
|
|
5
5
|
Classifier: Environment :: Win32 (MS Windows)
|
|
6
6
|
Classifier: Intended Audience :: Developers
|
|
@@ -51,6 +51,7 @@ Requires-Dist: egglog[array] ; extra == 'docs'
|
|
|
51
51
|
Requires-Dist: line-profiler ; extra == 'docs'
|
|
52
52
|
Requires-Dist: sphinxcontrib-mermaid ; extra == 'docs'
|
|
53
53
|
Requires-Dist: ablog ; extra == 'docs'
|
|
54
|
+
Requires-Dist: jupytext ; extra == 'docs'
|
|
54
55
|
Provides-Extra: array
|
|
55
56
|
Provides-Extra: dev
|
|
56
57
|
Provides-Extra: test
|
|
@@ -71,4 +72,21 @@ allowing you to use e-graphs in Python for optimization, symbolic computation, a
|
|
|
71
72
|
Please see the [documentation](https://egglog-python.readthedocs.io/) for more information.
|
|
72
73
|
|
|
73
74
|
Come say hello [on the e-graphs Zulip](https://egraphs.zulipchat.com/#narrow/stream/375765-egglog/) or [open an issue](https://github.com/egraphs-good/egglog-python/issues/new/choose)!
|
|
75
|
+
|
|
76
|
+
## How to cite
|
|
77
|
+
|
|
78
|
+
If you use **egglog-python** in academic work, please cite the paper:
|
|
79
|
+
|
|
80
|
+
```bibtex
|
|
81
|
+
@misc{Shanabrook2023EgglogPython,
|
|
82
|
+
title = {Egglog Python: A Pythonic Library for E-graphs},
|
|
83
|
+
author = {Saul Shanabrook},
|
|
84
|
+
year = {2023},
|
|
85
|
+
eprint = {2305.04311},
|
|
86
|
+
archivePrefix = {arXiv},
|
|
87
|
+
primaryClass = {cs.PL},
|
|
88
|
+
doi = {10.48550/arXiv.2305.04311},
|
|
89
|
+
url = {https://arxiv.org/abs/2305.04311},
|
|
90
|
+
note = {Presented at EGRAPHS@PLDI 2023}
|
|
91
|
+
}
|
|
74
92
|
|