egglog 11.2.0__cp310-cp310-win_amd64.whl → 11.4.0__cp310-cp310-win_amd64.whl

This diff compares publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. More details were linked from the original page; the link target is not preserved in this text.

egglog/egraph_state.py CHANGED
@@ -8,6 +8,7 @@ import re
8
8
  from collections import defaultdict
9
9
  from dataclasses import dataclass, field, replace
10
10
  from typing import TYPE_CHECKING, Literal, overload
11
+ from uuid import UUID
11
12
 
12
13
  from typing_extensions import assert_never
13
14
 
@@ -67,6 +68,7 @@ class EGraphState:
67
68
 
68
69
  # Bidirectional mapping between egg sort names and python type references.
69
70
  type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
71
+ egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
70
72
 
71
73
  # Cache of egg expressions for converting to egg
72
74
  expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
@@ -85,22 +87,145 @@ class EGraphState:
85
87
  egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
86
88
  callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
87
89
  type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
90
+ egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(),
88
91
  expr_to_egg_cache=self.expr_to_egg_cache.copy(),
89
92
  cost_callables=self.cost_callables.copy(),
90
93
  )
91
94
 
92
- def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
95
+ def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
96
+ """
97
+ Turn a run schedule into an egg command.
98
+
99
+ If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise
100
+ will be a normal run command.
101
+ """
102
+ processed_schedule = self._process_schedule(schedule)
103
+ if processed_schedule is None:
104
+ return bindings.RunSchedule(self._schedule_to_egg(schedule))
105
+ top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
106
+ if len(top_level_schedules) == 1:
107
+ schedule_expr = top_level_schedules[0]
108
+ else:
109
+ schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
110
+ return bindings.UserDefined(span(), "run-schedule", [schedule_expr])
111
+
112
+ def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
113
+ """
114
+ Processes a schedule to determine if it contains any custom schedulers.
115
+
116
+ If it does, it returns a new schedule with all the required let bindings added to the other scope.
117
+ If not, returns none.
118
+
119
+ Also processes all rulesets in the schedule to make sure they are registered.
120
+ """
121
+ bound_schedulers: list[UUID] = []
122
+ unbound_schedulers: list[BackOffDecl] = []
123
+
124
+ def helper(s: ScheduleDecl) -> None:
125
+ match s:
126
+ case LetSchedulerDecl(scheduler, inner):
127
+ bound_schedulers.append(scheduler.id)
128
+ return helper(inner)
129
+ case RunDecl(ruleset_name, _, scheduler):
130
+ self.ruleset_to_egg(ruleset_name)
131
+ if scheduler and scheduler.id not in bound_schedulers:
132
+ unbound_schedulers.append(scheduler)
133
+ case SaturateDecl(inner) | RepeatDecl(inner, _):
134
+ return helper(inner)
135
+ case SequenceDecl(schedules):
136
+ for sc in schedules:
137
+ helper(sc)
138
+ case _:
139
+ assert_never(s)
140
+ return None
141
+
142
+ helper(schedule)
143
+ if not bound_schedulers and not unbound_schedulers:
144
+ return None
145
+ for scheduler in unbound_schedulers:
146
+ schedule = LetSchedulerDecl(scheduler, schedule)
147
+ return schedule
148
+
149
+ def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
150
+ msg = "Should never reach this, let schedulers should be handled by custom scheduler"
93
151
  match schedule:
94
152
  case SaturateDecl(schedule):
95
- return bindings.Saturate(span(), self.schedule_to_egg(schedule))
153
+ return bindings.Saturate(span(), self._schedule_to_egg(schedule))
96
154
  case RepeatDecl(schedule, times):
97
- return bindings.Repeat(span(), times, self.schedule_to_egg(schedule))
155
+ return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
98
156
  case SequenceDecl(schedules):
99
- return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules])
100
- case RunDecl(ruleset_name, until):
101
- self.ruleset_to_egg(ruleset_name)
157
+ return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
158
+ case RunDecl(ruleset_name, until, scheduler):
159
+ if scheduler is not None:
160
+ raise ValueError(msg)
102
161
  config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
103
162
  return bindings.Run(span(), config)
163
+ case LetSchedulerDecl():
164
+ raise ValueError(msg)
165
+ case _:
166
+ assert_never(schedule)
167
+
168
+ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
169
+ self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
170
+ ) -> list[bindings._Expr]:
171
+ """
172
+ Turns a scheduler into an egg expression, to be used with a custom extract command.
173
+
174
+ The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
175
+ """
176
+ match schedule:
177
+ case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
178
+ name = f"_scheduler_{len(bound_schedulers)}"
179
+ bound_schedulers.append(id)
180
+ args: list[bindings._Expr] = []
181
+ if match_limit is not None:
182
+ args.append(bindings.Var(span(), ":match-limit"))
183
+ args.append(bindings.Lit(span(), bindings.Int(match_limit)))
184
+ if ban_length is not None:
185
+ args.append(bindings.Var(span(), ":ban-length"))
186
+ args.append(bindings.Lit(span(), bindings.Int(ban_length)))
187
+ back_off_decl = bindings.Call(span(), "back-off", args)
188
+ let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
189
+ return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
190
+ case RunDecl(ruleset_name, until, scheduler):
191
+ args = [bindings.Var(span(), ruleset_name)]
192
+ if scheduler:
193
+ name = "run-with"
194
+ scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
195
+ args.insert(0, bindings.Var(span(), scheduler_name))
196
+ else:
197
+ name = "run"
198
+ if until:
199
+ if len(until) > 1:
200
+ msg = "Can only have one until fact with custom scheduler"
201
+ raise ValueError(msg)
202
+ args.append(bindings.Var(span(), ":until"))
203
+ fact_egg = self.fact_to_egg(until[0])
204
+ if isinstance(fact_egg, bindings.Eq):
205
+ msg = "Cannot use equality fact with custom scheduler"
206
+ raise ValueError(msg)
207
+ args.append(fact_egg.expr)
208
+ return [bindings.Call(span(), name, args)]
209
+ case SaturateDecl(inner):
210
+ return [
211
+ bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
212
+ ]
213
+ case RepeatDecl(inner, times):
214
+ return [
215
+ bindings.Call(
216
+ span(),
217
+ "repeat",
218
+ [
219
+ bindings.Lit(span(), bindings.Int(times)),
220
+ *self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
221
+ ],
222
+ )
223
+ ]
224
+ case SequenceDecl(schedules):
225
+ res = []
226
+ for s in schedules:
227
+ res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
228
+ return res
104
229
  case _:
105
230
  assert_never(schedule)
106
231
 
@@ -229,6 +354,7 @@ class EGraphState:
229
354
  Creates the egg cost table if needed and gets the name of the table.
230
355
  """
231
356
  name = self.cost_table_name(ref)
357
+ print(name, self.cost_callables)
232
358
  if ref not in self.cost_callables:
233
359
  self.cost_callables.add(ref)
234
360
  signature = self.__egg_decls__.get_callable_decl(ref).signature
@@ -332,10 +458,14 @@ class EGraphState:
332
458
  pass
333
459
  decl = self.__egg_decls__._classes[ref.name]
334
460
  self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
461
+ self.egg_sort_to_type_ref[egg_name] = ref
335
462
  if not decl.builtin or ref.args:
336
463
  if ref.args:
337
464
  if ref.name == "UnstableFn":
338
465
  # UnstableFn is a special case, where the rest of args are collected into a call
466
+ if len(ref.args) < 2:
467
+ msg = "Zero argument higher order functions not supported"
468
+ raise NotImplementedError(msg)
339
469
  type_args: list[bindings._Expr] = [
340
470
  bindings.Call(
341
471
  span(),
@@ -466,11 +596,9 @@ class EGraphState:
466
596
  case _:
467
597
  assert_never(value)
468
598
  res = bindings.Lit(span(), l)
469
- case CallDecl(ref, args, _):
470
- egg_fn, reverse_args = self.callable_ref_to_egg(ref)
471
- egg_args = [self.typed_expr_to_egg(a, False) for a in args]
472
- if reverse_args:
473
- egg_args.reverse()
599
+ case CallDecl() | GetCostDecl():
600
+ egg_fn, typed_args = self.translate_call(expr_decl)
601
+ egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args]
474
602
  res = bindings.Call(span(), egg_fn, egg_args)
475
603
  case PyObjectDecl(value):
476
604
  res = GLOBAL_PY_OBJECT_SORT.store(value)
@@ -481,11 +609,31 @@ class EGraphState:
481
609
  "unstable-fn",
482
610
  [bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
483
611
  )
612
+ case ValueDecl():
613
+ msg = "Cannot turn a Value into an expression"
614
+ raise ValueError(msg)
484
615
  case _:
485
616
  assert_never(expr_decl.expr)
486
617
  self.expr_to_egg_cache[expr_decl] = res
487
618
  return res
488
619
 
620
+ def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]:
621
+ """
622
+ Handle get cost and call decl, turn into egg table name and typed expr decls.
623
+ """
624
+ match expr:
625
+ case CallDecl(ref, args, _):
626
+ egg_fn, reverse_args = self.callable_ref_to_egg(ref)
627
+ args_list = list(args)
628
+ if reverse_args:
629
+ args_list.reverse()
630
+ return egg_fn, args_list
631
+ case GetCostDecl(ref, args):
632
+ cost_table = self.create_cost_table(ref)
633
+ return cost_table, list(args)
634
+ case _:
635
+ assert_never(expr)
636
+
489
637
  def exprs_from_egg(
490
638
  self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
491
639
  ) -> Iterable[TypedExprDecl]:
@@ -529,6 +677,129 @@ class EGraphState:
529
677
  case _:
530
678
  assert_never(ref)
531
679
 
680
+ def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
681
+ egg_expr = self.typed_expr_to_egg(typed_expr, False)
682
+ return self.egraph.eval_expr(egg_expr)[1]
683
+
684
+ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912
685
+ match tp.name:
686
+ # Should match list in egraph bindings
687
+ case "i64":
688
+ return LitDecl(self.egraph.value_to_i64(value))
689
+ case "f64":
690
+ return LitDecl(self.egraph.value_to_f64(value))
691
+ case "Bool":
692
+ return LitDecl(self.egraph.value_to_bool(value))
693
+ case "String":
694
+ return LitDecl(self.egraph.value_to_string(value))
695
+ case "Unit":
696
+ return LitDecl(None)
697
+ case "PyObject":
698
+ return PyObjectDecl(self.egraph.value_to_pyobject(GLOBAL_PY_OBJECT_SORT, value))
699
+ case "Rational":
700
+ fraction = self.egraph.value_to_rational(value)
701
+ return CallDecl(
702
+ InitRef("Rational"),
703
+ (
704
+ TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.numerator)),
705
+ TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.denominator)),
706
+ ),
707
+ )
708
+ case "BigInt":
709
+ i = self.egraph.value_to_bigint(value)
710
+ return CallDecl(
711
+ ClassMethodRef("BigInt", "from_string"),
712
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(i))),),
713
+ )
714
+ case "BigRat":
715
+ fraction = self.egraph.value_to_bigrat(value)
716
+ return CallDecl(
717
+ InitRef("BigRat"),
718
+ (
719
+ TypedExprDecl(
720
+ JustTypeRef("BigInt"),
721
+ CallDecl(
722
+ ClassMethodRef("BigInt", "from_string"),
723
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.numerator))),),
724
+ ),
725
+ ),
726
+ TypedExprDecl(
727
+ JustTypeRef("BigInt"),
728
+ CallDecl(
729
+ ClassMethodRef("BigInt", "from_string"),
730
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.denominator))),),
731
+ ),
732
+ ),
733
+ ),
734
+ )
735
+ case "Map":
736
+ k_tp, v_tp = tp.args
737
+ expr = CallDecl(ClassMethodRef("Map", "empty"), (), (k_tp, v_tp))
738
+ for k, v in self.egraph.value_to_map(value).items():
739
+ expr = CallDecl(
740
+ MethodRef("Map", "insert"),
741
+ (
742
+ TypedExprDecl(tp, expr),
743
+ TypedExprDecl(k_tp, self.value_to_expr(k_tp, k)),
744
+ TypedExprDecl(v_tp, self.value_to_expr(v_tp, v)),
745
+ ),
746
+ )
747
+ return expr
748
+ case "Set":
749
+ xs_ = self.egraph.value_to_set(value)
750
+ (v_tp,) = tp.args
751
+ return CallDecl(
752
+ InitRef("Set"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), (v_tp,)
753
+ )
754
+ case "Vec":
755
+ xs = self.egraph.value_to_vec(value)
756
+ (v_tp,) = tp.args
757
+ return CallDecl(
758
+ InitRef("Vec"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
759
+ )
760
+ case "MultiSet":
761
+ xs = self.egraph.value_to_multiset(value)
762
+ (v_tp,) = tp.args
763
+ return CallDecl(
764
+ InitRef("MultiSet"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
765
+ )
766
+ case "UnstableFn":
767
+ _names, _args = self.egraph.value_to_function(value)
768
+ return_tp, *arg_types = tp.args
769
+ return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types)
770
+ return ValueDecl(value)
771
+
772
+ def _unstable_fn_value_to_expr(
773
+ self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef]
774
+ ) -> PartialCallDecl:
775
+ # Similar to FromEggState::from_call but accepts partial list of args and returns in values
776
+ # Find first callable ref whose return type matches and fill in arg types.
777
+ for callable_ref in self.egg_fn_to_callable_refs[name]:
778
+ signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
779
+ if not isinstance(signature, FunctionSignature):
780
+ continue
781
+ if signature.semantic_return_type.name != return_tp.name:
782
+ continue
783
+ tcs = TypeConstraintSolver(self.__egg_decls__)
784
+
785
+ arg_types, bound_tp_params = tcs.infer_arg_types(
786
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None
787
+ )
788
+
789
+ args = tuple(
790
+ TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
791
+ )
792
+
793
+ call_decl = CallDecl(
794
+ callable_ref,
795
+ args,
796
+ # Don't include bound type params if this is just a method, we only needed them for type resolution
797
+ # but dont need to store them
798
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
799
+ )
800
+ return PartialCallDecl(call_decl)
801
+ raise ValueError(f"Function '{name}' not found")
802
+
532
803
 
533
804
  # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
534
805
  _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
@@ -666,7 +937,7 @@ class FromEggState:
666
937
  args,
667
938
  # Don't include bound type params if this is just a method, we only needed them for type resolution
668
939
  # but dont need to store them
669
- bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
940
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
670
941
  )
671
942
  raise ValueError(
672
943
  f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
@@ -62,6 +62,3 @@ egraph.register(query)
62
62
  egraph.run(1000)
63
63
  print(egraph.extract(query))
64
64
  print(egraph.extract(query.size))
65
-
66
-
67
- egraph
@@ -4,7 +4,7 @@ from typing import TypeVar, cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from egglog import EGraph
7
+ from egglog import EGraph, greedy_dag_cost_model
8
8
  from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
9
  from egglog.exp.array_api_numba import array_api_numba_schedule
10
10
  from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
@@ -41,7 +41,7 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
41
41
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
42
42
  egraph.register(res)
43
43
  egraph.run(array_api_numba_schedule)
44
- res_optimized = egraph.extract(res)
44
+ res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())
45
45
 
46
46
  return (
47
47
  egraph,
egglog/pretty.py CHANGED
@@ -67,7 +67,9 @@ UNARY_METHODS = {
67
67
  "__invert__": "~",
68
68
  }
69
69
 
70
- AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
70
+ AllDecls: TypeAlias = (
71
+ RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl
72
+ )
71
73
 
72
74
 
73
75
  def pretty_decl(
@@ -181,17 +183,19 @@ class TraverseContext:
181
183
  if isinstance(de, DefaultRewriteDecl):
182
184
  continue
183
185
  self(de)
184
- case CallDecl(ref, exprs, _):
186
+ case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs):
185
187
  match ref:
186
188
  case FunctionRef(UnnamedFunctionRef(_, res)):
187
189
  self(res.expr)
188
190
  case _:
189
191
  for e in exprs:
190
192
  self(e.expr)
191
- case RunDecl(_, until):
193
+ case RunDecl(_, until, scheduler):
192
194
  if until:
193
195
  for f in until:
194
196
  self(f)
197
+ if scheduler:
198
+ self(scheduler)
195
199
  case PartialCallDecl(c):
196
200
  self(c)
197
201
  case CombinedRulesetDecl(_):
@@ -201,6 +205,13 @@ class TraverseContext:
201
205
  case SetCostDecl(_, e, c):
202
206
  self(e)
203
207
  self(c)
208
+ case BackOffDecl() | ValueDecl():
209
+ pass
210
+ case LetSchedulerDecl(scheduler, schedule):
211
+ self(scheduler)
212
+ self(schedule)
213
+ case GetCostDecl(ref, args):
214
+ self(CallDecl(ref, args))
204
215
  case _:
205
216
  assert_never(decl)
206
217
 
@@ -238,7 +249,11 @@ class PrettyContext:
238
249
  # it would take up is > than some constant (~ line length).
239
250
  line_diff: int = len(expr) - LINE_DIFFERENCE
240
251
  n_parents = self.parents[decl]
241
- if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
252
+ if n_parents > 1 and (
253
+ n_parents * line_diff > MAX_LINE_LENGTH
254
+ # Schedulers with multiple parents need to be the same object, b/c are created with hidden UUIDs
255
+ or tp_name == "scheduler"
256
+ ):
242
257
  self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
243
258
  return expr_name
244
259
  return expr
@@ -318,16 +333,31 @@ class PrettyContext:
318
333
  return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
319
334
  args = ", ".join(map(self, schedules))
320
335
  return f"seq({args})", "schedule"
321
- case RunDecl(ruleset_name, until):
336
+ case LetSchedulerDecl(scheduler, schedule):
337
+ return f"{self(scheduler, parens=True)}.scope({self(schedule, parens=True)})", "schedule"
338
+ case RunDecl(ruleset_name, until, scheduler):
322
339
  ruleset = self.decls._rulesets[ruleset_name]
323
340
  ruleset_str = self(ruleset, ruleset_name=ruleset_name)
324
- if not until:
341
+ if not until and not scheduler:
325
342
  return ruleset_str, "schedule"
326
- args = ", ".join(map(self, until))
327
- return f"run({ruleset_str}, {args})", "schedule"
343
+ arg_lst = list(map(self, until or []))
344
+ if scheduler:
345
+ arg_lst.append(f"scheduler={self(scheduler)}")
346
+ return f"run({ruleset_str}, {', '.join(arg_lst)})", "schedule"
328
347
  case DefaultRewriteDecl():
329
348
  msg = "default rewrites should not be pretty printed"
330
349
  raise TypeError(msg)
350
+ case BackOffDecl(_, match_limit, ban_length):
351
+ list_args: list[str] = []
352
+ if match_limit is not None:
353
+ list_args.append(f"match_limit={match_limit}")
354
+ if ban_length is not None:
355
+ list_args.append(f"ban_length={ban_length}")
356
+ return f"back_off({', '.join(list_args)})", "scheduler"
357
+ case ValueDecl(value):
358
+ return str(value), "value"
359
+ case GetCostDecl(ref, args):
360
+ return f"get_cost({self(CallDecl(ref, args))})", "get_cost"
331
361
  assert_never(decl)
332
362
 
333
363
  def _call(
egglog/runtime.py CHANGED
@@ -457,7 +457,7 @@ class RuntimeFunction(DelayedDeclerations):
457
457
  arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
458
458
  return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
459
459
  bound_params = (
460
- cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
460
+ cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
461
461
  )
462
462
  # If we were using unstable-app to call a funciton, add that function back as the first arg.
463
463
  if function_value:
@@ -584,11 +584,17 @@ class RuntimeExpr(DelayedDeclerations):
584
584
  if (method := _get_expr_method(self, "__eq__")) is not None:
585
585
  return method(other)
586
586
 
587
- # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
588
- # expr gets a chance to resolve __eq__ which could be a preserved method.
589
- from .egraph import BaseExpr, eq # noqa: PLC0415
587
+ if not (isinstance(self, RuntimeExpr) and isinstance(other, RuntimeExpr)):
588
+ return NotImplemented
589
+ if self.__egg_typed_expr__.tp != other.__egg_typed_expr__.tp:
590
+ return NotImplemented
590
591
 
591
- return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
592
+ from .egraph import Fact # noqa: PLC0415
593
+
594
+ return Fact(
595
+ Declarations.create(self, other),
596
+ EqDecl(self.__egg_typed_expr__.tp, self.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr),
597
+ )
592
598
 
593
599
  def __ne__(self, other: object) -> object: # type: ignore[override]
594
600
  if (method := _get_expr_method(self, "__ne__")) is not None:
@@ -635,14 +641,22 @@ for name in TYPE_DEFINED_METHODS:
635
641
 
636
642
 
637
643
  for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
644
+ method_name = f"__r{name[2:]}" if r_method else name
638
645
 
639
- def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
646
+ def _numeric_binary_method(
647
+ self: object, other: object, name: str = name, r_method: bool = r_method, method_name: str = method_name
648
+ ) -> object:
640
649
  """
641
650
  Implements numeric binary operations.
642
651
 
643
652
  Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
644
653
  the LHS or the RHS as exactly the right type and then upcasting the other to that type.
645
654
  """
655
+ # First check if we have a preserved method for this:
656
+ if isinstance(self, RuntimeExpr) and (
657
+ (preserved_method := self.__egg_class_decl__.preserved_methods.get(method_name)) is not None
658
+ ):
659
+ return preserved_method.__get__(self)(other)
646
660
  # 1. switch if reversed method
647
661
  if r_method:
648
662
  self, other = other, self
@@ -668,7 +682,6 @@ for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
668
682
  fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
669
683
  return fn(other)
670
684
 
671
- method_name = f"__r{name[2:]}" if r_method else name
672
685
  setattr(RuntimeExpr, method_name, _numeric_binary_method)
673
686
 
674
687
 
@@ -688,6 +701,8 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
688
701
  ):
689
702
  raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
690
703
  return expr.callable, decl_thunk()
704
+ case types.MethodWrapperType() if isinstance((slf := callable.__self__), RuntimeClass):
705
+ return MethodRef(slf.__egg_tp__.name, callable.__name__), slf.__egg_decls__
691
706
  case _:
692
707
  raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
693
708
 
@@ -54,7 +54,7 @@ class TypeConstraintSolver:
54
54
  fn_var_args: TypeOrVarRef | None,
55
55
  return_: JustTypeRef,
56
56
  cls_name: str | None,
57
- ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...] | None]:
57
+ ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]:
58
58
  """
59
59
  Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
60
60
 
@@ -75,7 +75,7 @@ class TypeConstraintSolver:
75
75
  )
76
76
  )
77
77
  if cls_name
78
- else None
78
+ else ()
79
79
  )
80
80
  return arg_types, bound_typevars
81
81
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 11.2.0
3
+ Version: 11.4.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -51,6 +51,7 @@ Requires-Dist: egglog[array] ; extra == 'docs'
51
51
  Requires-Dist: line-profiler ; extra == 'docs'
52
52
  Requires-Dist: sphinxcontrib-mermaid ; extra == 'docs'
53
53
  Requires-Dist: ablog ; extra == 'docs'
54
+ Requires-Dist: jupytext ; extra == 'docs'
54
55
  Provides-Extra: array
55
56
  Provides-Extra: dev
56
57
  Provides-Extra: test
@@ -71,4 +72,21 @@ allowing you to use e-graphs in Python for optimization, symbolic computation, a
71
72
  Please see the [documentation](https://egglog-python.readthedocs.io/) for more information.
72
73
 
73
74
  Come say hello [on the e-graphs Zulip](https://egraphs.zulipchat.com/#narrow/stream/375765-egglog/) or [open an issue](https://github.com/egraphs-good/egglog-python/issues/new/choose)!
75
+
76
+ ## How to cite
77
+
78
+ If you use **egglog-python** in academic work, please cite the paper:
79
+
80
+ ```bibtex
81
+ @misc{Shanabrook2023EgglogPython,
82
+ title = {Egglog Python: A Pythonic Library for E-graphs},
83
+ author = {Saul Shanabrook},
84
+ year = {2023},
85
+ eprint = {2305.04311},
86
+ archivePrefix = {arXiv},
87
+ primaryClass = {cs.PL},
88
+ doi = {10.48550/arXiv.2305.04311},
89
+ url = {https://arxiv.org/abs/2305.04311},
90
+ note = {Presented at EGRAPHS@PLDI 2023}
91
+ }
74
92