egglog 10.0.1__cp313-cp313-win_amd64.whl → 11.0.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py CHANGED
@@ -5,7 +5,7 @@ import inspect
5
5
  import pathlib
6
6
  import tempfile
7
7
  from collections.abc import Callable, Generator, Iterable
8
- from contextvars import ContextVar
8
+ from contextvars import ContextVar, Token
9
9
  from dataclasses import InitVar, dataclass, field
10
10
  from functools import partial
11
11
  from inspect import Parameter, currentframe, signature
@@ -16,7 +16,6 @@ from typing import (
16
16
  ClassVar,
17
17
  Generic,
18
18
  Literal,
19
- Never,
20
19
  TypeAlias,
21
20
  TypedDict,
22
21
  TypeVar,
@@ -26,16 +25,18 @@ from typing import (
26
25
  )
27
26
 
28
27
  import graphviz
29
- from typing_extensions import ParamSpec, Self, Unpack, assert_never
28
+ from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
30
29
 
31
30
  from . import bindings
32
31
  from .conversion import *
32
+ from .conversion import convert_to_same_type, resolve_literal
33
33
  from .declarations import *
34
34
  from .egraph_state import *
35
35
  from .ipython_magic import IN_IPYTHON
36
36
  from .pretty import pretty_decl
37
37
  from .runtime import *
38
38
  from .thunk import *
39
+ from .version_compat import *
39
40
 
40
41
  if TYPE_CHECKING:
41
42
  from .builtins import String, Unit
@@ -82,7 +83,6 @@ __all__ = [
82
83
  "run",
83
84
  "seq",
84
85
  "set_",
85
- "simplify",
86
86
  "subsume",
87
87
  "union",
88
88
  "unstable_combine_rulesets",
@@ -110,12 +110,12 @@ IGNORED_ATTRIBUTES = {
110
110
  "__weakref__",
111
111
  "__orig_bases__",
112
112
  "__annotations__",
113
- "__hash__",
114
113
  "__qualname__",
115
114
  "__firstlineno__",
116
115
  "__static_attributes__",
116
+ "__match_args__",
117
117
  # Ignore all reflected binary method
118
- *REFLECTED_BINARY_METHODS.keys(),
118
+ *(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
119
119
  }
120
120
 
121
121
 
@@ -139,15 +139,6 @@ ALWAYS_PRESERVED = {
139
139
  }
140
140
 
141
141
 
142
- def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
143
- """
144
- Simplify an expression by running the schedule.
145
- """
146
- if schedule:
147
- return EGraph().simplify(x, schedule)
148
- return EGraph().extract(x)
149
-
150
-
151
142
  def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph:
152
143
  """
153
144
  Verifies that two expressions are equal after running the schedule.
@@ -169,8 +160,9 @@ def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, ad
169
160
  except bindings.EggSmolError as err:
170
161
  if display:
171
162
  egraph.display()
172
- err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})")
173
- raise
163
+ raise add_note(
164
+ f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})", err
165
+ ) from None
174
166
  return egraph
175
167
 
176
168
 
@@ -290,7 +282,6 @@ def function(
290
282
  mutates_first_arg: bool = ...,
291
283
  unextractable: bool = ...,
292
284
  ruleset: Ruleset | None = ...,
293
- use_body_as_name: bool = ...,
294
285
  subsume: bool = ...,
295
286
  ) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
296
287
 
@@ -370,6 +361,7 @@ class BaseExpr(metaclass=_ExprMetaclass):
370
361
 
371
362
  def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body]
372
363
 
364
+ # not currently dissalowing other types of equality https://github.com/python/typeshed/issues/8217#issuecomment-3140873292
373
365
  def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body]
374
366
 
375
367
 
@@ -403,7 +395,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
403
395
  )
404
396
  type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters)
405
397
  del parameters
406
- cls_decl = ClassDecl(egg_sort, type_vars, builtin)
398
+ cls_decl = ClassDecl(egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ()))
407
399
  decls = Declarations(_classes={cls_name: cls_decl})
408
400
  # Update class think eagerly when resolving so that lookups work in methods
409
401
  runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
@@ -455,6 +447,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
455
447
  continue
456
448
  locals = frame.f_locals
457
449
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
450
+ # TODO: Store deprecated message so we can print at runtime
451
+ if (getattr(fn, "__deprecated__", None)) is not None:
452
+ fn = fn.__wrapped__ # type: ignore[attr-defined]
458
453
  match fn:
459
454
  case classmethod():
460
455
  ref = ClassMethodRef(cls_name, method_name)
@@ -476,7 +471,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
476
471
  decls.set_function_decl(ref, decl)
477
472
  continue
478
473
  try:
479
- _, add_rewrite = _fn_decl(
474
+ add_rewrite = _fn_decl(
480
475
  decls,
481
476
  egg_fn,
482
477
  ref,
@@ -492,8 +487,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
492
487
  reverse_args=reverse_args,
493
488
  )
494
489
  except Exception as e:
495
- e.add_note(f"Error processing {cls_name}.{method_name}")
496
- raise
490
+ raise add_note(f"Error processing {cls_name}.{method_name}", e) from None
497
491
 
498
492
  if not builtin and not isinstance(ref, InitRef) and not mutates:
499
493
  add_default_funcs.append(add_rewrite)
@@ -515,7 +509,6 @@ class _FunctionConstructor:
515
509
  merge: Callable[[object, object], object] | None = None
516
510
  unextractable: bool = False
517
511
  ruleset: Ruleset | None = None
518
- use_body_as_name: bool = False
519
512
  subsume: bool = False
520
513
 
521
514
  def __call__(self, fn: Callable) -> RuntimeFunction:
@@ -523,11 +516,10 @@ class _FunctionConstructor:
523
516
 
524
517
  def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
525
518
  decls = Declarations()
526
- ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
527
- ref, add_rewrite = _fn_decl(
519
+ add_rewrite = _fn_decl(
528
520
  decls,
529
521
  self.egg_fn,
530
- ref,
522
+ ref := FunctionRef(fn.__name__),
531
523
  fn,
532
524
  self.hint_locals,
533
525
  self.cost,
@@ -545,8 +537,7 @@ class _FunctionConstructor:
545
537
  def _fn_decl(
546
538
  decls: Declarations,
547
539
  egg_name: str | None,
548
- # If ref is Callable, then generate the ref from the function name
549
- ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
540
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
550
541
  fn: object,
551
542
  # Pass in the locals, retrieved from the frame when wrapping,
552
543
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -559,7 +550,7 @@ def _fn_decl(
559
550
  ruleset: Ruleset | None = None,
560
551
  unextractable: bool = False,
561
552
  reverse_args: bool = False,
562
- ) -> tuple[CallableRef, Callable[[], None]]:
553
+ ) -> Callable[[], None]:
563
554
  """
564
555
  Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
565
556
  """
@@ -569,16 +560,11 @@ def _fn_decl(
569
560
  if not isinstance(fn, FunctionType):
570
561
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
571
562
 
572
- hint_globals = fn.__globals__.copy()
573
- # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
574
- # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
575
- if "Callable" not in hint_globals:
576
- hint_globals["Callable"] = Callable
577
563
  # Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
578
564
  # won't be resolved correctly
579
565
  # We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
580
566
  # https://github.com/egraphs-good/egglog-python/issues/210
581
- hint_globals.update(hint_locals)
567
+ hint_globals = {**fn.__globals__, **hint_locals}
582
568
  hints = get_type_hints(fn, hint_globals)
583
569
 
584
570
  params = list(signature(fn).parameters.values())
@@ -624,60 +610,49 @@ def _fn_decl(
624
610
  else resolve_literal(
625
611
  return_type,
626
612
  merge(
627
- RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
628
- RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
613
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("old", "old"))),
614
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("new", "new"))),
629
615
  ),
630
616
  lambda: decls,
631
617
  )
632
618
  )
633
619
  decls |= merged
634
620
 
635
- # defer this in generator so it doesnt resolve for builtins eagerly
636
- args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
637
- res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
638
- res_thunk: Callable[[], object]
639
- # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
640
- if not ref:
641
- tuple_args = tuple(args)
642
- res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
643
- assert isinstance(res, RuntimeExpr)
644
- res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
645
- decls._unnamed_functions.add(res_ref)
646
- res_thunk = Thunk.value(res)
621
+ # defer this in generator so it doesn't resolve for builtins eagerly
622
+ args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
647
623
 
624
+ return_type_is_eqsort = (
625
+ not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
626
+ )
627
+ is_constructor = not is_builtin and return_type_is_eqsort and merged is None
628
+ signature_ = FunctionSignature(
629
+ return_type=None if mutates_first_arg else return_type,
630
+ var_arg_type=var_arg_type,
631
+ arg_types=arg_types,
632
+ arg_names=arg_names,
633
+ arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
634
+ reverse_args=reverse_args,
635
+ )
636
+ decl: ConstructorDecl | FunctionDecl
637
+ if is_constructor:
638
+ decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
648
639
  else:
649
- return_type_is_eqsort = (
650
- not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
651
- )
652
- is_constructor = not is_builtin and return_type_is_eqsort and merged is None
653
- signature_ = FunctionSignature(
654
- return_type=None if mutates_first_arg else return_type,
655
- var_arg_type=var_arg_type,
656
- arg_types=arg_types,
657
- arg_names=arg_names,
658
- arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
659
- reverse_args=reverse_args,
640
+ if cost is not None:
641
+ msg = "Cost can only be set for constructors"
642
+ raise ValueError(msg)
643
+ if unextractable:
644
+ msg = "Unextractable can only be set for constructors"
645
+ raise ValueError(msg)
646
+ decl = FunctionDecl(
647
+ signature=signature_,
648
+ egg_name=egg_name,
649
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
650
+ builtin=is_builtin,
660
651
  )
661
- decl: ConstructorDecl | FunctionDecl
662
- if is_constructor:
663
- decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
664
- else:
665
- if cost is not None:
666
- msg = "Cost can only be set for constructors"
667
- raise ValueError(msg)
668
- if unextractable:
669
- msg = "Unextractable can only be set for constructors"
670
- raise ValueError(msg)
671
- decl = FunctionDecl(
672
- signature=signature_,
673
- egg_name=egg_name,
674
- merge=merged.__egg_typed_expr__.expr if merged is not None else None,
675
- builtin=is_builtin,
676
- )
677
- res_ref = ref
678
- decls.set_function_decl(ref, decl)
679
- res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
680
- return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
652
+ decls.set_function_decl(ref, decl)
653
+ return Thunk.fn(
654
+ _add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
655
+ )
681
656
 
682
657
 
683
658
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -712,7 +687,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
712
687
 
713
688
 
714
689
  def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
715
- from .builtins import Unit
690
+ from .builtins import Unit # noqa: PLC0415
716
691
 
717
692
  decls = Declarations()
718
693
  decls |= cast("RuntimeClass", Unit)
@@ -751,13 +726,15 @@ def _constant_thunk(
751
726
  return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
752
727
 
753
728
 
754
- def _create_default_value(
729
+ def _add_default_rewrite_function(
755
730
  decls: Declarations,
756
- ref: CallableRef | None,
731
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
757
732
  fn: Callable,
758
733
  args: Iterable[TypedExprDecl],
759
734
  ruleset: Ruleset | None,
760
- ) -> object:
735
+ subsume: bool,
736
+ res_type: TypeOrVarRef,
737
+ ) -> None:
761
738
  args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
762
739
 
763
740
  # If this is a classmethod, add the class as the first arg
@@ -765,21 +742,8 @@ def _create_default_value(
765
742
  tp = decls.get_paramaterized_class(ref.class_name)
766
743
  args.insert(0, RuntimeClass(Thunk.value(decls), tp))
767
744
  with set_current_ruleset(ruleset):
768
- return fn(*args)
769
-
770
-
771
- def _add_default_rewrite_function(
772
- decls: Declarations,
773
- ref: CallableRef,
774
- res_type: TypeOrVarRef,
775
- ruleset: Ruleset | None,
776
- value_thunk: Callable[[], object],
777
- subsume: bool,
778
- ) -> None:
779
- """
780
- Helper functions that resolves a value thunk to create the default value.
781
- """
782
- _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
745
+ res = fn(*args)
746
+ _add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
783
747
 
784
748
 
785
749
  def _add_default_rewrite(
@@ -799,6 +763,13 @@ def _add_default_rewrite(
799
763
  return
800
764
  resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
801
765
  rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
766
+ ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
767
+ ruleset_decls |= resolved_value
768
+
769
+
770
+ def _add_default_rewrite_inner(
771
+ decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
772
+ ) -> Declarations:
802
773
  if ruleset:
803
774
  ruleset_decls = ruleset._current_egg_decls
804
775
  ruleset_decl = ruleset.__egg_ruleset__
@@ -806,7 +777,7 @@ def _add_default_rewrite(
806
777
  ruleset_decls = decls
807
778
  ruleset_decl = decls.default_ruleset
808
779
  ruleset_decl.rules.append(rewrite_decl)
809
- ruleset_decls |= resolved_value
780
+ return ruleset_decls
810
781
 
811
782
 
812
783
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -887,6 +858,7 @@ class EGraph:
887
858
  self._add_decls(decls)
888
859
  return self._state.callable_ref_to_egg(ref)[0]
889
860
 
861
+ # TODO: Change let to be action...
890
862
  def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
891
863
  """
892
864
  Define a new expression in the egraph and return a reference to it.
@@ -898,38 +870,10 @@ class EGraph:
898
870
  return cast(
899
871
  "BASE_EXPR",
900
872
  RuntimeExpr.__from_values__(
901
- self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
873
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, LetRefDecl(name))
902
874
  ),
903
875
  )
904
876
 
905
- @overload
906
- def simplify(self, expr: BASE_EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> BASE_EXPR: ...
907
-
908
- @overload
909
- def simplify(self, expr: BASE_EXPR, schedule: Schedule, /) -> BASE_EXPR: ...
910
-
911
- def simplify(
912
- self, expr: BASE_EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
913
- ) -> BASE_EXPR:
914
- """
915
- Simplifies the given expression.
916
- """
917
- schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
918
- del limit_or_schedule, until, ruleset
919
- runtime_expr = to_runtime_expr(expr)
920
- self._add_decls(runtime_expr, schedule)
921
- egg_schedule = self._state.schedule_to_egg(schedule.schedule)
922
- typed_expr = runtime_expr.__egg_typed_expr__
923
- # Must also register type
924
- egg_expr = self._state.typed_expr_to_egg(typed_expr)
925
- self._egraph.run_program(bindings.Simplify(span(1), egg_expr, egg_schedule))
926
- extract_report = self._egraph.extract_report()
927
- if not isinstance(extract_report, bindings.Best):
928
- msg = "No extract report saved"
929
- raise ValueError(msg) # noqa: TRY004
930
- (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
931
- return cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
932
-
933
877
  def include(self, path: str) -> None:
934
878
  """
935
879
  Include a file of rules.
@@ -1041,12 +985,9 @@ class EGraph:
1041
985
  self._add_decls(expr)
1042
986
  expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
1043
987
  try:
1044
- self._egraph.run_program(
1045
- bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
1046
- )
988
+ self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
1047
989
  except BaseException as e:
1048
- e.add_note("Extracting: " + str(expr))
1049
- raise
990
+ raise add_note("Extracting: " + str(expr), e) # noqa: B904
1050
991
  extract_report = self._egraph.extract_report()
1051
992
  if not extract_report:
1052
993
  msg = "No extract report saved"
@@ -1144,9 +1085,9 @@ class EGraph:
1144
1085
 
1145
1086
  If in IPython it will display it inline, otherwise it will write it to a file and open it.
1146
1087
  """
1147
- from IPython.display import SVG, display
1088
+ from IPython.display import SVG, display # noqa: PLC0415
1148
1089
 
1149
- from .visualizer_widget import VisualizerWidget
1090
+ from .visualizer_widget import VisualizerWidget # noqa: PLC0415
1150
1091
 
1151
1092
  if graphviz:
1152
1093
  if IN_IPYTHON:
@@ -1173,7 +1114,7 @@ class EGraph:
1173
1114
 
1174
1115
  If an `expr` is passed, it's also extracted after each run and printed
1175
1116
  """
1176
- from .visualizer_widget import VisualizerWidget
1117
+ from .visualizer_widget import VisualizerWidget # noqa: PLC0415
1177
1118
 
1178
1119
  def to_json() -> str:
1179
1120
  if expr is not None:
@@ -1565,16 +1506,17 @@ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _Ru
1565
1506
  return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
1566
1507
 
1567
1508
 
1568
- def var(name: str, bound: type[T]) -> T:
1509
+ def var(name: str, bound: type[T], egg_name: str | None = None) -> T:
1569
1510
  """Create a new variable with the given name and type."""
1570
- return cast("T", _var(name, bound))
1511
+ return cast("T", _var(name, bound, egg_name=egg_name))
1571
1512
 
1572
1513
 
1573
- def _var(name: str, bound: object) -> RuntimeExpr:
1514
+ def _var(name: str, bound: object, egg_name: str | None) -> RuntimeExpr:
1574
1515
  """Create a new variable with the given name and type."""
1575
1516
  decls_like, type_ref = resolve_type_annotation(bound)
1576
1517
  return RuntimeExpr(
1577
- Thunk.fn(Declarations.create, decls_like), Thunk.value(TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
1518
+ Thunk.fn(Declarations.create, decls_like),
1519
+ Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name, egg_name))),
1578
1520
  )
1579
1521
 
1580
1522
 
@@ -1665,7 +1607,7 @@ class _NeBuilder(Generic[BASE_EXPR]):
1665
1607
  lhs: BASE_EXPR
1666
1608
 
1667
1609
  def to(self, rhs: BASE_EXPR) -> Unit:
1668
- from .builtins import Unit
1610
+ from .builtins import Unit # noqa: PLC0415
1669
1611
 
1670
1612
  lhs = to_runtime_expr(self.lhs)
1671
1613
  rhs = convert_to_same_type(rhs, lhs)
@@ -1824,7 +1766,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1824
1766
  # python/tests/test_no_import_star.py::test_no_import_star_rulesset
1825
1767
  combined = {**gen.__globals__, **frame.f_locals}
1826
1768
  hints = get_type_hints(gen, combined, combined)
1827
- args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1769
+ args = [_var(p.name, hints[p.name], egg_name=None) for p in signature(gen).parameters.values()]
1828
1770
  return list(gen(*args)) # type: ignore[misc]
1829
1771
 
1830
1772
 
@@ -1850,7 +1792,7 @@ def get_current_ruleset() -> Ruleset | None:
1850
1792
 
1851
1793
  @contextlib.contextmanager
1852
1794
  def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1853
- token = _CURRENT_RULESET.set(r)
1795
+ token: Token[Ruleset | None] = _CURRENT_RULESET.set(r)
1854
1796
  try:
1855
1797
  yield
1856
1798
  finally:
egglog/egraph_state.py CHANGED
@@ -108,7 +108,7 @@ class EGraphState:
108
108
  case RulesetDecl(rules):
109
109
  if name not in self.rulesets:
110
110
  if name:
111
- self.egraph.run_program(bindings.AddRuleset(name))
111
+ self.egraph.run_program(bindings.AddRuleset(span(), name))
112
112
  added_rules = self.rulesets[name] = set()
113
113
  else:
114
114
  added_rules = self.rulesets[name]
@@ -125,7 +125,7 @@ class EGraphState:
125
125
  self.rulesets[name] = set()
126
126
  for ruleset in rulesets:
127
127
  self.ruleset_to_egg(ruleset)
128
- self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
128
+ self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), name, list(rulesets)))
129
129
 
130
130
  def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
131
131
  match cmd:
@@ -160,7 +160,7 @@ class EGraphState:
160
160
  assert isinstance(sig, FunctionSignature)
161
161
  # Replace args with rule_var_name mapping
162
162
  arg_mapping = tuple(
163
- TypedExprDecl(tp.to_just(), VarDecl(name, False))
163
+ TypedExprDecl(tp.to_just(), UnboundVarDecl(name))
164
164
  for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
165
165
  )
166
166
  rewrite_decl = RewriteDecl(
@@ -179,7 +179,7 @@ class EGraphState:
179
179
  def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
180
180
  match action:
181
181
  case LetDecl(name, typed_expr):
182
- var_decl = VarDecl(name, True)
182
+ var_decl = LetRefDecl(name)
183
183
  var_egg = self._expr_to_egg(var_decl)
184
184
  self.expr_to_egg_cache[var_decl] = var_egg
185
185
  return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
@@ -369,7 +369,8 @@ class EGraphState:
369
369
  """
370
370
  Rewrites this expression as a let binding if it's not already a let binding.
371
371
  """
372
- var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
372
+ # TODO: Replace with counter so that it works with hash collisions and is more stable
373
+ var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}")
373
374
  if var_decl in self.expr_to_egg_cache:
374
375
  return None
375
376
  var_egg = self._expr_to_egg(var_decl)
@@ -387,7 +388,7 @@ class EGraphState:
387
388
  def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
388
389
 
389
390
  @overload
390
- def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
391
+ def _expr_to_egg(self, expr_decl: UnboundVarDecl | LetRefDecl) -> bindings.Var: ...
391
392
 
392
393
  @overload
393
394
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
@@ -402,11 +403,10 @@ class EGraphState:
402
403
  pass
403
404
  res: bindings._Expr
404
405
  match expr_decl:
405
- case VarDecl(name, is_let):
406
- # prefix let bindings with % to avoid name conflicts with rewrites
407
- if is_let:
408
- name = f"%{name}"
409
- res = bindings.Var(span(), name)
406
+ case LetRefDecl(name):
407
+ res = bindings.Var(span(), f"{name}")
408
+ case UnboundVarDecl(name, egg_name):
409
+ res = bindings.Var(span(), egg_name or f"_{name}")
410
410
  case LitDecl(value):
411
411
  l: bindings._Literal
412
412
  match value:
@@ -467,7 +467,8 @@ class EGraphState:
467
467
  return name
468
468
 
469
469
  case ConstantRef(name):
470
- return name
470
+ # Prefix to avoid name collisions with local vars
471
+ return f"%{name}"
471
472
  case (
472
473
  MethodRef(cls_name, name)
473
474
  | ClassMethodRef(cls_name, name)
@@ -549,7 +550,7 @@ class FromEggState:
549
550
  """
550
551
  expr_decl: ExprDecl
551
552
  if isinstance(term, bindings.TermVar):
552
- expr_decl = VarDecl(term.name, True)
553
+ expr_decl = LetRefDecl(term.name)
553
554
  elif isinstance(term, bindings.TermLit):
554
555
  value = term.value
555
556
  expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
@@ -624,7 +625,9 @@ class FromEggState:
624
625
  # but dont need to store them
625
626
  bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
626
627
  )
627
- raise ValueError(f"Could not find callable ref for call {term}")
628
+ raise ValueError(
629
+ f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
630
+ )
628
631
 
629
632
  def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
630
633
  try:
egglog/examples/bignum.py CHANGED
@@ -14,7 +14,7 @@ z = BigRat(x, y)
14
14
 
15
15
  egraph = EGraph()
16
16
 
17
- assert egraph.extract(z.numer.to_string()).eval() == "-617"
17
+ assert egraph.extract(z.numer.to_string()).value == "-617"
18
18
 
19
19
 
20
20
  @function
@@ -32,7 +32,7 @@ egraph.register(xs)
32
32
  egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33
33
  egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
34
34
 
35
- assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
35
+ assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
36
36
 
37
37
 
38
38
  inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
@@ -45,7 +45,7 @@ assert Math(4) not in xs
45
45
 
46
46
  egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
47
47
 
48
- assert egraph.extract(xs.length()).eval() == 3
48
+ assert egraph.extract(xs.length()).value == 3
49
49
  assert len(xs) == 3
50
50
 
51
51
  egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))