egglog 10.0.2__cp310-cp310-win_amd64.whl → 11.0.0__cp310-cp310-win_amd64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py CHANGED
@@ -5,7 +5,7 @@ import inspect
5
5
  import pathlib
6
6
  import tempfile
7
7
  from collections.abc import Callable, Generator, Iterable
8
- from contextvars import ContextVar
8
+ from contextvars import ContextVar, Token
9
9
  from dataclasses import InitVar, dataclass, field
10
10
  from functools import partial
11
11
  from inspect import Parameter, currentframe, signature
@@ -29,6 +29,7 @@ from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
29
29
 
30
30
  from . import bindings
31
31
  from .conversion import *
32
+ from .conversion import convert_to_same_type, resolve_literal
32
33
  from .declarations import *
33
34
  from .egraph_state import *
34
35
  from .ipython_magic import IN_IPYTHON
@@ -82,7 +83,6 @@ __all__ = [
82
83
  "run",
83
84
  "seq",
84
85
  "set_",
85
- "simplify",
86
86
  "subsume",
87
87
  "union",
88
88
  "unstable_combine_rulesets",
@@ -110,12 +110,12 @@ IGNORED_ATTRIBUTES = {
110
110
  "__weakref__",
111
111
  "__orig_bases__",
112
112
  "__annotations__",
113
- "__hash__",
114
113
  "__qualname__",
115
114
  "__firstlineno__",
116
115
  "__static_attributes__",
116
+ "__match_args__",
117
117
  # Ignore all reflected binary method
118
- *REFLECTED_BINARY_METHODS.keys(),
118
+ *(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
119
119
  }
120
120
 
121
121
 
@@ -139,15 +139,6 @@ ALWAYS_PRESERVED = {
139
139
  }
140
140
 
141
141
 
142
- def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
143
- """
144
- Simplify an expression by running the schedule.
145
- """
146
- if schedule:
147
- return EGraph().simplify(x, schedule)
148
- return EGraph().extract(x)
149
-
150
-
151
142
  def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph:
152
143
  """
153
144
  Verifies that two expressions are equal after running the schedule.
@@ -291,7 +282,6 @@ def function(
291
282
  mutates_first_arg: bool = ...,
292
283
  unextractable: bool = ...,
293
284
  ruleset: Ruleset | None = ...,
294
- use_body_as_name: bool = ...,
295
285
  subsume: bool = ...,
296
286
  ) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
297
287
 
@@ -371,6 +361,7 @@ class BaseExpr(metaclass=_ExprMetaclass):
371
361
 
372
362
  def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body]
373
363
 
364
+ # not currently dissalowing other types of equality https://github.com/python/typeshed/issues/8217#issuecomment-3140873292
374
365
  def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body]
375
366
 
376
367
 
@@ -404,7 +395,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
404
395
  )
405
396
  type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters)
406
397
  del parameters
407
- cls_decl = ClassDecl(egg_sort, type_vars, builtin)
398
+ cls_decl = ClassDecl(egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ()))
408
399
  decls = Declarations(_classes={cls_name: cls_decl})
409
400
  # Update class think eagerly when resolving so that lookups work in methods
410
401
  runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
@@ -456,6 +447,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
456
447
  continue
457
448
  locals = frame.f_locals
458
449
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
450
+ # TODO: Store deprecated message so we can print at runtime
451
+ if (getattr(fn, "__deprecated__", None)) is not None:
452
+ fn = fn.__wrapped__ # type: ignore[attr-defined]
459
453
  match fn:
460
454
  case classmethod():
461
455
  ref = ClassMethodRef(cls_name, method_name)
@@ -477,7 +471,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
477
471
  decls.set_function_decl(ref, decl)
478
472
  continue
479
473
  try:
480
- _, add_rewrite = _fn_decl(
474
+ add_rewrite = _fn_decl(
481
475
  decls,
482
476
  egg_fn,
483
477
  ref,
@@ -515,7 +509,6 @@ class _FunctionConstructor:
515
509
  merge: Callable[[object, object], object] | None = None
516
510
  unextractable: bool = False
517
511
  ruleset: Ruleset | None = None
518
- use_body_as_name: bool = False
519
512
  subsume: bool = False
520
513
 
521
514
  def __call__(self, fn: Callable) -> RuntimeFunction:
@@ -523,11 +516,10 @@ class _FunctionConstructor:
523
516
 
524
517
  def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
525
518
  decls = Declarations()
526
- ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
527
- ref, add_rewrite = _fn_decl(
519
+ add_rewrite = _fn_decl(
528
520
  decls,
529
521
  self.egg_fn,
530
- ref,
522
+ ref := FunctionRef(fn.__name__),
531
523
  fn,
532
524
  self.hint_locals,
533
525
  self.cost,
@@ -545,8 +537,7 @@ class _FunctionConstructor:
545
537
  def _fn_decl(
546
538
  decls: Declarations,
547
539
  egg_name: str | None,
548
- # If ref is Callable, then generate the ref from the function name
549
- ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
540
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
550
541
  fn: object,
551
542
  # Pass in the locals, retrieved from the frame when wrapping,
552
543
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -559,7 +550,7 @@ def _fn_decl(
559
550
  ruleset: Ruleset | None = None,
560
551
  unextractable: bool = False,
561
552
  reverse_args: bool = False,
562
- ) -> tuple[CallableRef, Callable[[], None]]:
553
+ ) -> Callable[[], None]:
563
554
  """
564
555
  Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
565
556
  """
@@ -619,8 +610,8 @@ def _fn_decl(
619
610
  else resolve_literal(
620
611
  return_type,
621
612
  merge(
622
- RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
623
- RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
613
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("old", "old"))),
614
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("new", "new"))),
624
615
  ),
625
616
  lambda: decls,
626
617
  )
@@ -628,51 +619,40 @@ def _fn_decl(
628
619
  decls |= merged
629
620
 
630
621
  # defer this in generator so it doesn't resolve for builtins eagerly
631
- args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
632
- res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
633
- res_thunk: Callable[[], object]
634
- # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
635
- if not ref:
636
- tuple_args = tuple(args)
637
- res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
638
- assert isinstance(res, RuntimeExpr)
639
- res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
640
- decls._unnamed_functions.add(res_ref)
641
- res_thunk = Thunk.value(res)
622
+ args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
642
623
 
624
+ return_type_is_eqsort = (
625
+ not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
626
+ )
627
+ is_constructor = not is_builtin and return_type_is_eqsort and merged is None
628
+ signature_ = FunctionSignature(
629
+ return_type=None if mutates_first_arg else return_type,
630
+ var_arg_type=var_arg_type,
631
+ arg_types=arg_types,
632
+ arg_names=arg_names,
633
+ arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
634
+ reverse_args=reverse_args,
635
+ )
636
+ decl: ConstructorDecl | FunctionDecl
637
+ if is_constructor:
638
+ decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
643
639
  else:
644
- return_type_is_eqsort = (
645
- not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
646
- )
647
- is_constructor = not is_builtin and return_type_is_eqsort and merged is None
648
- signature_ = FunctionSignature(
649
- return_type=None if mutates_first_arg else return_type,
650
- var_arg_type=var_arg_type,
651
- arg_types=arg_types,
652
- arg_names=arg_names,
653
- arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
654
- reverse_args=reverse_args,
640
+ if cost is not None:
641
+ msg = "Cost can only be set for constructors"
642
+ raise ValueError(msg)
643
+ if unextractable:
644
+ msg = "Unextractable can only be set for constructors"
645
+ raise ValueError(msg)
646
+ decl = FunctionDecl(
647
+ signature=signature_,
648
+ egg_name=egg_name,
649
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
650
+ builtin=is_builtin,
655
651
  )
656
- decl: ConstructorDecl | FunctionDecl
657
- if is_constructor:
658
- decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
659
- else:
660
- if cost is not None:
661
- msg = "Cost can only be set for constructors"
662
- raise ValueError(msg)
663
- if unextractable:
664
- msg = "Unextractable can only be set for constructors"
665
- raise ValueError(msg)
666
- decl = FunctionDecl(
667
- signature=signature_,
668
- egg_name=egg_name,
669
- merge=merged.__egg_typed_expr__.expr if merged is not None else None,
670
- builtin=is_builtin,
671
- )
672
- res_ref = ref
673
- decls.set_function_decl(ref, decl)
674
- res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
675
- return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
652
+ decls.set_function_decl(ref, decl)
653
+ return Thunk.fn(
654
+ _add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
655
+ )
676
656
 
677
657
 
678
658
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -707,7 +687,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
707
687
 
708
688
 
709
689
  def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
710
- from .builtins import Unit
690
+ from .builtins import Unit # noqa: PLC0415
711
691
 
712
692
  decls = Declarations()
713
693
  decls |= cast("RuntimeClass", Unit)
@@ -746,13 +726,15 @@ def _constant_thunk(
746
726
  return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
747
727
 
748
728
 
749
- def _create_default_value(
729
+ def _add_default_rewrite_function(
750
730
  decls: Declarations,
751
- ref: CallableRef | None,
731
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
752
732
  fn: Callable,
753
733
  args: Iterable[TypedExprDecl],
754
734
  ruleset: Ruleset | None,
755
- ) -> object:
735
+ subsume: bool,
736
+ res_type: TypeOrVarRef,
737
+ ) -> None:
756
738
  args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
757
739
 
758
740
  # If this is a classmethod, add the class as the first arg
@@ -760,21 +742,8 @@ def _create_default_value(
760
742
  tp = decls.get_paramaterized_class(ref.class_name)
761
743
  args.insert(0, RuntimeClass(Thunk.value(decls), tp))
762
744
  with set_current_ruleset(ruleset):
763
- return fn(*args)
764
-
765
-
766
- def _add_default_rewrite_function(
767
- decls: Declarations,
768
- ref: CallableRef,
769
- res_type: TypeOrVarRef,
770
- ruleset: Ruleset | None,
771
- value_thunk: Callable[[], object],
772
- subsume: bool,
773
- ) -> None:
774
- """
775
- Helper functions that resolves a value thunk to create the default value.
776
- """
777
- _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
745
+ res = fn(*args)
746
+ _add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
778
747
 
779
748
 
780
749
  def _add_default_rewrite(
@@ -794,6 +763,13 @@ def _add_default_rewrite(
794
763
  return
795
764
  resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
796
765
  rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
766
+ ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
767
+ ruleset_decls |= resolved_value
768
+
769
+
770
+ def _add_default_rewrite_inner(
771
+ decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
772
+ ) -> Declarations:
797
773
  if ruleset:
798
774
  ruleset_decls = ruleset._current_egg_decls
799
775
  ruleset_decl = ruleset.__egg_ruleset__
@@ -801,7 +777,7 @@ def _add_default_rewrite(
801
777
  ruleset_decls = decls
802
778
  ruleset_decl = decls.default_ruleset
803
779
  ruleset_decl.rules.append(rewrite_decl)
804
- ruleset_decls |= resolved_value
780
+ return ruleset_decls
805
781
 
806
782
 
807
783
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -882,6 +858,7 @@ class EGraph:
882
858
  self._add_decls(decls)
883
859
  return self._state.callable_ref_to_egg(ref)[0]
884
860
 
861
+ # TODO: Change let to be action...
885
862
  def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
886
863
  """
887
864
  Define a new expression in the egraph and return a reference to it.
@@ -893,38 +870,10 @@ class EGraph:
893
870
  return cast(
894
871
  "BASE_EXPR",
895
872
  RuntimeExpr.__from_values__(
896
- self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
873
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, LetRefDecl(name))
897
874
  ),
898
875
  )
899
876
 
900
- @overload
901
- def simplify(self, expr: BASE_EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> BASE_EXPR: ...
902
-
903
- @overload
904
- def simplify(self, expr: BASE_EXPR, schedule: Schedule, /) -> BASE_EXPR: ...
905
-
906
- def simplify(
907
- self, expr: BASE_EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
908
- ) -> BASE_EXPR:
909
- """
910
- Simplifies the given expression.
911
- """
912
- schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
913
- del limit_or_schedule, until, ruleset
914
- runtime_expr = to_runtime_expr(expr)
915
- self._add_decls(runtime_expr, schedule)
916
- egg_schedule = self._state.schedule_to_egg(schedule.schedule)
917
- typed_expr = runtime_expr.__egg_typed_expr__
918
- # Must also register type
919
- egg_expr = self._state.typed_expr_to_egg(typed_expr)
920
- self._egraph.run_program(bindings.Simplify(span(1), egg_expr, egg_schedule))
921
- extract_report = self._egraph.extract_report()
922
- if not isinstance(extract_report, bindings.Best):
923
- msg = "No extract report saved"
924
- raise ValueError(msg) # noqa: TRY004
925
- (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
926
- return cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
927
-
928
877
  def include(self, path: str) -> None:
929
878
  """
930
879
  Include a file of rules.
@@ -1036,9 +985,7 @@ class EGraph:
1036
985
  self._add_decls(expr)
1037
986
  expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
1038
987
  try:
1039
- self._egraph.run_program(
1040
- bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
1041
- )
988
+ self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
1042
989
  except BaseException as e:
1043
990
  raise add_note("Extracting: " + str(expr), e) # noqa: B904
1044
991
  extract_report = self._egraph.extract_report()
@@ -1138,9 +1085,9 @@ class EGraph:
1138
1085
 
1139
1086
  If in IPython it will display it inline, otherwise it will write it to a file and open it.
1140
1087
  """
1141
- from IPython.display import SVG, display
1088
+ from IPython.display import SVG, display # noqa: PLC0415
1142
1089
 
1143
- from .visualizer_widget import VisualizerWidget
1090
+ from .visualizer_widget import VisualizerWidget # noqa: PLC0415
1144
1091
 
1145
1092
  if graphviz:
1146
1093
  if IN_IPYTHON:
@@ -1167,7 +1114,7 @@ class EGraph:
1167
1114
 
1168
1115
  If an `expr` is passed, it's also extracted after each run and printed
1169
1116
  """
1170
- from .visualizer_widget import VisualizerWidget
1117
+ from .visualizer_widget import VisualizerWidget # noqa: PLC0415
1171
1118
 
1172
1119
  def to_json() -> str:
1173
1120
  if expr is not None:
@@ -1559,16 +1506,17 @@ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _Ru
1559
1506
  return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
1560
1507
 
1561
1508
 
1562
- def var(name: str, bound: type[T]) -> T:
1509
+ def var(name: str, bound: type[T], egg_name: str | None = None) -> T:
1563
1510
  """Create a new variable with the given name and type."""
1564
- return cast("T", _var(name, bound))
1511
+ return cast("T", _var(name, bound, egg_name=egg_name))
1565
1512
 
1566
1513
 
1567
- def _var(name: str, bound: object) -> RuntimeExpr:
1514
+ def _var(name: str, bound: object, egg_name: str | None) -> RuntimeExpr:
1568
1515
  """Create a new variable with the given name and type."""
1569
1516
  decls_like, type_ref = resolve_type_annotation(bound)
1570
1517
  return RuntimeExpr(
1571
- Thunk.fn(Declarations.create, decls_like), Thunk.value(TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
1518
+ Thunk.fn(Declarations.create, decls_like),
1519
+ Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name, egg_name))),
1572
1520
  )
1573
1521
 
1574
1522
 
@@ -1659,7 +1607,7 @@ class _NeBuilder(Generic[BASE_EXPR]):
1659
1607
  lhs: BASE_EXPR
1660
1608
 
1661
1609
  def to(self, rhs: BASE_EXPR) -> Unit:
1662
- from .builtins import Unit
1610
+ from .builtins import Unit # noqa: PLC0415
1663
1611
 
1664
1612
  lhs = to_runtime_expr(self.lhs)
1665
1613
  rhs = convert_to_same_type(rhs, lhs)
@@ -1818,7 +1766,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1818
1766
  # python/tests/test_no_import_star.py::test_no_import_star_rulesset
1819
1767
  combined = {**gen.__globals__, **frame.f_locals}
1820
1768
  hints = get_type_hints(gen, combined, combined)
1821
- args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1769
+ args = [_var(p.name, hints[p.name], egg_name=None) for p in signature(gen).parameters.values()]
1822
1770
  return list(gen(*args)) # type: ignore[misc]
1823
1771
 
1824
1772
 
@@ -1844,7 +1792,7 @@ def get_current_ruleset() -> Ruleset | None:
1844
1792
 
1845
1793
  @contextlib.contextmanager
1846
1794
  def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1847
- token = _CURRENT_RULESET.set(r)
1795
+ token: Token[Ruleset | None] = _CURRENT_RULESET.set(r)
1848
1796
  try:
1849
1797
  yield
1850
1798
  finally:
egglog/egraph_state.py CHANGED
@@ -108,7 +108,7 @@ class EGraphState:
108
108
  case RulesetDecl(rules):
109
109
  if name not in self.rulesets:
110
110
  if name:
111
- self.egraph.run_program(bindings.AddRuleset(name))
111
+ self.egraph.run_program(bindings.AddRuleset(span(), name))
112
112
  added_rules = self.rulesets[name] = set()
113
113
  else:
114
114
  added_rules = self.rulesets[name]
@@ -125,7 +125,7 @@ class EGraphState:
125
125
  self.rulesets[name] = set()
126
126
  for ruleset in rulesets:
127
127
  self.ruleset_to_egg(ruleset)
128
- self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
128
+ self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), name, list(rulesets)))
129
129
 
130
130
  def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
131
131
  match cmd:
@@ -160,7 +160,7 @@ class EGraphState:
160
160
  assert isinstance(sig, FunctionSignature)
161
161
  # Replace args with rule_var_name mapping
162
162
  arg_mapping = tuple(
163
- TypedExprDecl(tp.to_just(), VarDecl(name, False))
163
+ TypedExprDecl(tp.to_just(), UnboundVarDecl(name))
164
164
  for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
165
165
  )
166
166
  rewrite_decl = RewriteDecl(
@@ -179,7 +179,7 @@ class EGraphState:
179
179
  def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
180
180
  match action:
181
181
  case LetDecl(name, typed_expr):
182
- var_decl = VarDecl(name, True)
182
+ var_decl = LetRefDecl(name)
183
183
  var_egg = self._expr_to_egg(var_decl)
184
184
  self.expr_to_egg_cache[var_decl] = var_egg
185
185
  return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
@@ -369,7 +369,8 @@ class EGraphState:
369
369
  """
370
370
  Rewrites this expression as a let binding if it's not already a let binding.
371
371
  """
372
- var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
372
+ # TODO: Replace with counter so that it works with hash collisions and is more stable
373
+ var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}")
373
374
  if var_decl in self.expr_to_egg_cache:
374
375
  return None
375
376
  var_egg = self._expr_to_egg(var_decl)
@@ -387,7 +388,7 @@ class EGraphState:
387
388
  def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
388
389
 
389
390
  @overload
390
- def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
391
+ def _expr_to_egg(self, expr_decl: UnboundVarDecl | LetRefDecl) -> bindings.Var: ...
391
392
 
392
393
  @overload
393
394
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
@@ -402,11 +403,10 @@ class EGraphState:
402
403
  pass
403
404
  res: bindings._Expr
404
405
  match expr_decl:
405
- case VarDecl(name, is_let):
406
- # prefix let bindings with % to avoid name conflicts with rewrites
407
- if is_let:
408
- name = f"%{name}"
409
- res = bindings.Var(span(), name)
406
+ case LetRefDecl(name):
407
+ res = bindings.Var(span(), f"{name}")
408
+ case UnboundVarDecl(name, egg_name):
409
+ res = bindings.Var(span(), egg_name or f"_{name}")
410
410
  case LitDecl(value):
411
411
  l: bindings._Literal
412
412
  match value:
@@ -467,7 +467,8 @@ class EGraphState:
467
467
  return name
468
468
 
469
469
  case ConstantRef(name):
470
- return name
470
+ # Prefix to avoid name collisions with local vars
471
+ return f"%{name}"
471
472
  case (
472
473
  MethodRef(cls_name, name)
473
474
  | ClassMethodRef(cls_name, name)
@@ -549,7 +550,7 @@ class FromEggState:
549
550
  """
550
551
  expr_decl: ExprDecl
551
552
  if isinstance(term, bindings.TermVar):
552
- expr_decl = VarDecl(term.name, True)
553
+ expr_decl = LetRefDecl(term.name)
553
554
  elif isinstance(term, bindings.TermLit):
554
555
  value = term.value
555
556
  expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
@@ -624,7 +625,9 @@ class FromEggState:
624
625
  # but dont need to store them
625
626
  bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
626
627
  )
627
- raise ValueError(f"Could not find callable ref for call {term}")
628
+ raise ValueError(
629
+ f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
630
+ )
628
631
 
629
632
  def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
630
633
  try:
egglog/examples/bignum.py CHANGED
@@ -14,7 +14,7 @@ z = BigRat(x, y)
14
14
 
15
15
  egraph = EGraph()
16
16
 
17
- assert egraph.extract(z.numer.to_string()).eval() == "-617"
17
+ assert egraph.extract(z.numer.to_string()).value == "-617"
18
18
 
19
19
 
20
20
  @function
@@ -32,7 +32,7 @@ egraph.register(xs)
32
32
  egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33
33
  egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
34
34
 
35
- assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
35
+ assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
36
36
 
37
37
 
38
38
  inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
@@ -45,7 +45,7 @@ assert Math(4) not in xs
45
45
 
46
46
  egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
47
47
 
48
- assert egraph.extract(xs.length()).eval() == 3
48
+ assert egraph.extract(xs.length()).value == 3
49
49
  assert len(xs) == 3
50
50
 
51
51
  egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
egglog/exp/array_api.py CHANGED
@@ -154,6 +154,18 @@ class Int(Expr, ruleset=array_api_ruleset):
154
154
  def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
155
155
  ...
156
156
 
157
+ # add a hash so that this test can pass
158
+ # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
159
+ @method(preserve=True)
160
+ def __hash__(self) -> int:
161
+ egraph = _get_current_egraph()
162
+ egraph.register(self)
163
+ egraph.run(array_api_schedule)
164
+ simplified = egraph.extract(self)
165
+ return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
166
+
167
+ def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...
168
+
157
169
  # TODO: Fix this?
158
170
  # Make != always return a Bool, so that numpy.unique works on a tuple of ints
159
171
  # In _unique1d
@@ -280,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
280
292
  yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
281
293
  yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
282
294
 
295
+ yield rewrite(o.__round__(OptionalInt.none)).to(o)
296
+
283
297
  # Never cannot be equal to anything real
284
298
  yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
285
299
 
@@ -354,8 +368,14 @@ class Float(Expr, ruleset=array_api_ruleset):
354
368
  def __sub__(self, other: FloatLike) -> Float: ...
355
369
 
356
370
  def __pow__(self, other: FloatLike) -> Float: ...
371
+ def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
357
372
 
358
373
  def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
374
+ def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
375
+ def __lt__(self, other: FloatLike) -> Boolean: ...
376
+ def __le__(self, other: FloatLike) -> Boolean: ...
377
+ def __gt__(self, other: FloatLike) -> Boolean: ...
378
+ def __ge__(self, other: FloatLike) -> Boolean: ...
359
379
 
360
380
 
361
381
  converter(float, Float, lambda x: Float(x))
@@ -366,9 +386,10 @@ FloatLike: TypeAlias = Float | float | IntLike
366
386
 
367
387
 
368
388
  @array_api_ruleset.register
369
- def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
389
+ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
370
390
  return [
371
391
  rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
392
+ rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
372
393
  rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
373
394
  rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
374
395
  # Convert from float to rationl, if its a whole number i.e. can be converted to int
@@ -383,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
383
404
  rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
384
405
  rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
385
406
  rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
386
- # ==
407
+ # comparisons
387
408
  rewrite(Float(f) == Float(f)).to(TRUE),
388
409
  rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
410
+ rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
411
+ rewrite(Float(f) != Float(f)).to(FALSE),
412
+ rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
413
+ rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
414
+ rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
415
+ rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
416
+ rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
417
+ rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
418
+ rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
389
419
  rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
390
420
  rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
421
+ # round
422
+ rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
391
423
  ]
392
424
 
393
425
 
@@ -671,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
671
703
  def some(cls, value: Int) -> OptionalInt: ...
672
704
 
673
705
 
706
+ OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
707
+
674
708
  converter(type(None), OptionalInt, lambda _: OptionalInt.none)
675
709
  converter(Int, OptionalInt, OptionalInt.some)
676
710
 
@@ -1982,4 +2016,4 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
1982
2016
  except BaseException as e:
1983
2017
  # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1984
2018
  raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
1985
- return extracted.eval() # type: ignore[attr-defined]
2019
+ return extracted.value # type: ignore[attr-defined]
egglog/exp/array_api_jit.py CHANGED
@@ -14,16 +14,22 @@ from .program_gen import Program
14
14
  X = TypeVar("X", bound=Callable)
15
15
 
16
16
 
17
- def jit(fn: X) -> X:
17
+ def jit(
18
+ fn: X,
19
+ *,
20
+ handle_expr: Callable[[NDArray], None] | None = None,
21
+ handle_optimized_expr: Callable[[NDArray], None] | None = None,
22
+ ) -> X:
18
23
  """
19
24
  Jit compiles a function
20
25
  """
21
26
  egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
27
+ if handle_expr:
28
+ handle_expr(res)
29
+ if handle_optimized_expr:
30
+ handle_optimized_expr(res_optimized)
22
31
  fn_program = EvalProgram(program, {"np": np})
23
- fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
24
- fn.initial_expr = res # type: ignore[attr-defined]
25
- fn.expr = res_optimized # type: ignore[attr-defined]
26
- return fn
32
+ return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
27
33
 
28
34
 
29
35
  def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]: