egglog 7.2.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their public registries.

Potentially problematic release.


This version of egglog might be problematic. See the package's registry page for more details.

egglog/egraph.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import inspect
4
5
  import pathlib
5
6
  import tempfile
6
7
  from abc import abstractmethod
7
- from collections.abc import Callable, Iterable
8
+ from collections.abc import Callable, Generator, Iterable
8
9
  from contextvars import ContextVar, Token
9
10
  from dataclasses import InitVar, dataclass, field
10
11
  from inspect import Parameter, currentframe, signature
@@ -37,8 +38,6 @@ from .runtime import *
37
38
  from .thunk import *
38
39
 
39
40
  if TYPE_CHECKING:
40
- import ipywidgets
41
-
42
41
  from .builtins import Bool, PyObject, String, f64, i64
43
42
 
44
43
 
@@ -464,10 +463,14 @@ class _ExprMetaclass(type):
464
463
  prev_frame = frame.f_back
465
464
  assert prev_frame
466
465
 
466
+ # Pass in an instance of the class so that when we are generating the decls
467
+ # we can update them eagerly so that we can access the methods in the class body
468
+ runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
469
+
467
470
  # Store frame so that we can get live access to updated locals/globals
468
471
  # Otherwise, f_locals returns a copy
469
472
  # https://peps.python.org/pep-0667/
470
- decls_thunk = Thunk.fn(
473
+ runtime_cls.__egg_decls_thunk__ = Thunk.fn(
471
474
  _generate_class_decls,
472
475
  namespace,
473
476
  prev_frame,
@@ -475,21 +478,22 @@ class _ExprMetaclass(type):
475
478
  egg_sort,
476
479
  name,
477
480
  ruleset,
478
- fallback=Declarations,
481
+ runtime_cls,
479
482
  )
480
- return RuntimeClass(decls_thunk, TypeRefWithVars(name))
483
+ return runtime_cls
481
484
 
482
485
  def __instancecheck__(cls, instance: object) -> bool:
483
486
  return isinstance(instance, RuntimeExpr)
484
487
 
485
488
 
486
- def _generate_class_decls(
489
+ def _generate_class_decls( # noqa: C901
487
490
  namespace: dict[str, Any],
488
491
  frame: FrameType,
489
492
  builtin: bool,
490
493
  egg_sort: str | None,
491
494
  cls_name: str,
492
495
  ruleset: Ruleset | None,
496
+ runtime_cls: RuntimeClass,
493
497
  ) -> Declarations:
494
498
  """
495
499
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
@@ -502,6 +506,8 @@ def _generate_class_decls(
502
506
  del parameters
503
507
  cls_decl = ClassDecl(egg_sort, type_vars, builtin)
504
508
  decls = Declarations(_classes={cls_name: cls_decl})
509
+ # Update class think eagerly when resolving so that lookups work in methods
510
+ runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
505
511
 
506
512
  ##
507
513
  # Register class variables
@@ -522,16 +528,13 @@ def _generate_class_decls(
522
528
  # Register methods, classmethods, preserved methods, and properties
523
529
  ##
524
530
 
525
- # The type ref of self is paramterized by the type vars
526
- TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
527
-
528
531
  # Get all the methods from the class
529
532
  filtered_namespace: list[tuple[str, Any]] = [
530
533
  (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
531
534
  ]
532
535
 
533
536
  # all methods we should try adding default functions for
534
- default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
537
+ add_default_funcs: list[Callable[[], None]] = []
535
538
  # Then register each of its methods
536
539
  for method_name, method in filtered_namespace:
537
540
  is_init = method_name == "__init__"
@@ -550,7 +553,6 @@ def _generate_class_decls(
550
553
  cls_decl.preserved_methods[method_name] = fn
551
554
  continue
552
555
  locals = frame.f_locals
553
-
554
556
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
555
557
  match fn:
556
558
  case classmethod():
@@ -561,16 +563,25 @@ def _generate_class_decls(
561
563
  fn = fn.fget
562
564
  case _:
563
565
  ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
566
+ special_function_name: SpecialFunctions | None = (
567
+ "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
568
+ )
569
+ if special_function_name:
570
+ decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
571
+ decls.set_function_decl(ref, decl)
572
+ continue
564
573
 
565
- _fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
574
+ _, add_rewrite = _fn_decl(
575
+ decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
576
+ )
566
577
 
567
578
  if not builtin and not isinstance(ref, InitRef) and not mutates:
568
- default_function_refs[ref] = fn
579
+ add_default_funcs.append(add_rewrite)
569
580
 
570
581
  # Add all rewrite methods at the end so that all methods are registered first and can be accessed
571
582
  # in the bodies
572
- for ref, fn in default_function_refs.items():
573
- _add_default_rewrite_function(decls, ref, fn, ruleset)
583
+ for add_rewrite in add_default_funcs:
584
+ add_rewrite()
574
585
 
575
586
  return decls
576
587
 
@@ -590,6 +601,7 @@ def function(
590
601
  unextractable: bool = False,
591
602
  builtin: bool = False,
592
603
  ruleset: Ruleset | None = None,
604
+ use_body_as_name: bool = False,
593
605
  ) -> Callable[[CALLABLE], CALLABLE]: ...
594
606
 
595
607
 
@@ -604,6 +616,7 @@ def function(
604
616
  mutates_first_arg: bool = False,
605
617
  unextractable: bool = False,
606
618
  ruleset: Ruleset | None = None,
619
+ use_body_as_name: bool = False,
607
620
  ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
608
621
 
609
622
 
@@ -635,16 +648,18 @@ class _FunctionConstructor:
635
648
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
636
649
  unextractable: bool = False
637
650
  ruleset: Ruleset | None = None
651
+ use_body_as_name: bool = False
638
652
 
639
653
  def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
640
- return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
654
+ return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
641
655
 
642
- def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
656
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
643
657
  decls = Declarations()
644
- _fn_decl(
658
+ ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
659
+ ref, add_rewrite = _fn_decl(
645
660
  decls,
646
661
  self.egg_fn,
647
- (ref := FunctionRef(fn.__name__)),
662
+ ref,
648
663
  fn,
649
664
  self.hint_locals,
650
665
  self.default,
@@ -653,16 +668,18 @@ class _FunctionConstructor:
653
668
  self.on_merge,
654
669
  self.mutates_first_arg,
655
670
  self.builtin,
671
+ self.ruleset,
656
672
  unextractable=self.unextractable,
657
673
  )
658
- _add_default_rewrite_function(decls, ref, fn, self.ruleset)
659
- return decls
674
+ add_rewrite()
675
+ return decls, ref
660
676
 
661
677
 
662
678
  def _fn_decl(
663
679
  decls: Declarations,
664
680
  egg_name: str | None,
665
- ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
681
+ # If ref is Callable, then generate the ref from the function name
682
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
666
683
  fn: object,
667
684
  # Pass in the locals, retrieved from the frame when wrapping,
668
685
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -673,23 +690,15 @@ def _fn_decl(
673
690
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
674
691
  mutates_first_arg: bool,
675
692
  is_builtin: bool,
693
+ ruleset: Ruleset | None = None,
676
694
  unextractable: bool = False,
677
- ) -> None:
695
+ ) -> tuple[CallableRef, Callable[[], None]]:
678
696
  """
679
- Sets the function decl for the function object.
697
+ Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
680
698
  """
681
699
  if not isinstance(fn, FunctionType):
682
700
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
683
701
 
684
- # partial function creation and calling are handled with a special case in the type checker, so don't
685
- # use the normal logic
686
- special_function_name: SpecialFunctions | None = (
687
- "fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
688
- )
689
- if special_function_name:
690
- decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
691
- return
692
-
693
702
  hint_globals = fn.__globals__.copy()
694
703
  # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
695
704
  # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
@@ -719,7 +728,7 @@ def _fn_decl(
719
728
 
720
729
  # Resolve all default values as arg types
721
730
  arg_defaults = [
722
- resolve_literal(t, p.default) if p.default is not Parameter.empty else None
731
+ resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
723
732
  for (t, p) in zip(arg_types, params, strict=True)
724
733
  ]
725
734
 
@@ -740,8 +749,8 @@ def _fn_decl(
740
749
  None
741
750
  if merge is None
742
751
  else merge(
743
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
744
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
752
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
753
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
745
754
  )
746
755
  )
747
756
  decls |= merged
@@ -751,29 +760,47 @@ def _fn_decl(
751
760
  if on_merge is None
752
761
  else _action_likes(
753
762
  on_merge(
754
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
755
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
763
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
764
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
756
765
  )
757
766
  )
758
767
  )
759
768
  decls.update(*merge_action)
760
- decl = FunctionDecl(
761
- FunctionSignature(
769
+ # defer this in generator so it doesnt resolve for builtins eagerly
770
+ args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
771
+ res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
772
+ res_thunk: Callable[[], object]
773
+ # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
774
+ if not ref:
775
+ tuple_args = tuple(args)
776
+ res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
777
+ assert isinstance(res, RuntimeExpr)
778
+ res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
779
+ decls._unnamed_functions.add(res_ref)
780
+ res_thunk = Thunk.value(res)
781
+
782
+ else:
783
+ signature_ = FunctionSignature(
762
784
  return_type=None if mutates_first_arg else return_type,
763
785
  var_arg_type=var_arg_type,
764
786
  arg_types=arg_types,
765
787
  arg_names=arg_names,
766
788
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
767
- ),
768
- cost=cost,
769
- egg_name=egg_name,
770
- merge=merged.__egg_typed_expr__.expr if merged is not None else None,
771
- unextractable=unextractable,
772
- builtin=is_builtin,
773
- default=None if default is None else default.__egg_typed_expr__.expr,
774
- on_merge=tuple(a.action for a in merge_action),
775
- )
776
- decls.set_function_decl(ref, decl)
789
+ )
790
+ decl = FunctionDecl(
791
+ signature=signature_,
792
+ cost=cost,
793
+ egg_name=egg_name,
794
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
795
+ unextractable=unextractable,
796
+ builtin=is_builtin,
797
+ default=None if default is None else default.__egg_typed_expr__.expr,
798
+ on_merge=tuple(a.action for a in merge_action),
799
+ )
800
+ res_ref = ref
801
+ decls.set_function_decl(ref, decl)
802
+ res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
803
+ return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
777
804
 
778
805
 
779
806
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -804,7 +831,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
804
831
  Creates a function whose return type is `Unit` and has a default value.
805
832
  """
806
833
  decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
807
- return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
834
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
808
835
 
809
836
 
810
837
  def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
@@ -828,7 +855,9 @@ def constant(
828
855
  A "constant" is implemented as the instantiation of a value that takes no args.
829
856
  This creates a function with `name` and return type `tp` and returns a value of it being called.
830
857
  """
831
- return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
858
+ return cast(
859
+ EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
860
+ )
832
861
 
833
862
 
834
863
  def _constant_thunk(
@@ -842,24 +871,34 @@ def _constant_thunk(
842
871
  return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
843
872
 
844
873
 
845
- def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
846
- """
847
- Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
848
- """
849
- callable_decl = decls.get_callable_decl(ref)
850
- assert isinstance(callable_decl, FunctionDecl)
851
- signature = callable_decl.signature
852
- assert isinstance(signature, FunctionSignature)
853
-
854
- var_args: list[object] = [
855
- RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
856
- for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
857
- ]
874
+ def _create_default_value(
875
+ decls: Declarations,
876
+ ref: CallableRef | None,
877
+ fn: Callable,
878
+ args: Iterable[TypedExprDecl],
879
+ ruleset: Ruleset | None,
880
+ ) -> object:
881
+ args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
882
+
858
883
  # If this is a classmethod, add the class as the first arg
859
884
  if isinstance(ref, ClassMethodRef):
860
885
  tp = decls.get_paramaterized_class(ref.class_name)
861
- var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
862
- _add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
886
+ args.insert(0, RuntimeClass(Thunk.value(decls), tp))
887
+ with set_current_ruleset(ruleset):
888
+ return fn(*args)
889
+
890
+
891
+ def _add_default_rewrite_function(
892
+ decls: Declarations,
893
+ ref: CallableRef,
894
+ res_type: TypeOrVarRef,
895
+ ruleset: Ruleset | None,
896
+ value_thunk: Callable[[], object],
897
+ ) -> None:
898
+ """
899
+ Helper functions that resolves a value thunk to create the default value.
900
+ """
901
+ _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
863
902
 
864
903
 
865
904
  def _add_default_rewrite(
@@ -872,7 +911,7 @@ def _add_default_rewrite(
872
911
  """
873
912
  if default_rewrite is None:
874
913
  return
875
- resolved_value = resolve_literal(type_ref, default_rewrite)
914
+ resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
876
915
  rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
877
916
  if ruleset:
878
917
  ruleset_decls = ruleset._current_egg_decls
@@ -931,6 +970,8 @@ class GraphvizKwargs(TypedDict, total=False):
931
970
  max_calls_per_function: int | None
932
971
  n_inline_leaves: int
933
972
  split_primitive_outputs: bool
973
+ split_functions: list[object]
974
+ include_temporary_functions: bool
934
975
 
935
976
 
936
977
  @dataclass
@@ -973,81 +1014,19 @@ class EGraph(_BaseModule):
973
1014
  raise ValueError(msg)
974
1015
  return cmds
975
1016
 
976
- def _repr_mimebundle_(self, *args, **kwargs):
977
- """
978
- Returns the graphviz representation of the e-graph.
979
- """
980
- return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
981
-
982
- def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
983
- # By default we want to split primitive outputs
984
- kwargs.setdefault("split_primitive_outputs", True)
985
- n_inline = kwargs.pop("n_inline_leaves", 0)
986
- serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
987
- serialized.map_ops(self._state.op_mapping())
988
- for _ in range(n_inline):
989
- serialized.inline_leaves()
990
- original = serialized.to_dot()
991
- # Add link to stylesheet to the graph, so that edges light up on hover
992
- # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
993
- styles = """/* the lines within the edges */
994
- .edge:active path,
995
- .edge:hover path {
996
- stroke: fuchsia;
997
- stroke-width: 3;
998
- stroke-opacity: 1;
999
- }
1000
- /* arrows are typically drawn with a polygon */
1001
- .edge:active polygon,
1002
- .edge:hover polygon {
1003
- stroke: fuchsia;
1004
- stroke-width: 3;
1005
- fill: fuchsia;
1006
- stroke-opacity: 1;
1007
- fill-opacity: 1;
1008
- }
1009
- /* If you happen to have text and want to color that as well... */
1010
- .edge:active text,
1011
- .edge:hover text {
1012
- fill: fuchsia;
1013
- }"""
1014
- p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
1015
- p.write_text(styles)
1016
- with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
1017
- return graphviz.Source(with_stylesheet)
1018
-
1019
- def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
1020
- return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
1021
-
1022
- def _repr_html_(self) -> str:
1023
- """
1024
- Add a _repr_html_ to be an SVG to work with sphinx gallery.
1025
-
1026
- ala https://github.com/xflr6/graphviz/pull/121
1027
- until this PR is merged and released
1028
- https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
1029
- """
1030
- return self.graphviz_svg()
1031
-
1032
- def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
1033
- """
1034
- Displays the e-graph in the notebook.
1035
- """
1036
- if IN_IPYTHON:
1037
- from IPython.display import SVG, display
1038
-
1039
- display(SVG(self.graphviz_svg(**kwargs)))
1040
- else:
1041
- self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1017
+ def _ipython_display_(self) -> None:
1018
+ self.display()
1042
1019
 
1043
1020
  def input(self, fn: Callable[..., String], path: str) -> None:
1044
1021
  """
1045
1022
  Loads a CSV file and sets it as *input, output of the function.
1046
1023
  """
1024
+ self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
1025
+
1026
+ def _callable_to_egg(self, fn: object) -> str:
1047
1027
  ref, decls = resolve_callable(fn)
1048
1028
  self._add_decls(decls)
1049
- fn_name = self._state.callable_ref_to_egg(ref)
1050
- self._egraph.run_program(bindings.Input(fn_name, path))
1029
+ return self._state.callable_ref_to_egg(ref)
1051
1030
 
1052
1031
  def let(self, name: str, expr: EXPR) -> EXPR:
1053
1032
  """
@@ -1059,8 +1038,8 @@ class EGraph(_BaseModule):
1059
1038
  self._add_decls(runtime_expr)
1060
1039
  return cast(
1061
1040
  EXPR,
1062
- RuntimeExpr.__from_value__(
1063
- self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
1041
+ RuntimeExpr.__from_values__(
1042
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
1064
1043
  ),
1065
1044
  )
1066
1045
 
@@ -1090,7 +1069,7 @@ class EGraph(_BaseModule):
1090
1069
  msg = "No extract report saved"
1091
1070
  raise ValueError(msg) # noqa: TRY004
1092
1071
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1093
- return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1072
+ return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1094
1073
 
1095
1074
  def include(self, path: str) -> None:
1096
1075
  """
@@ -1145,7 +1124,7 @@ class EGraph(_BaseModule):
1145
1124
  facts = _fact_likes(fact_likes)
1146
1125
  self._add_decls(*facts)
1147
1126
  egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1148
- return bindings.Check(egg_facts)
1127
+ return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
1149
1128
 
1150
1129
  @overload
1151
1130
  def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
@@ -1167,7 +1146,7 @@ class EGraph(_BaseModule):
1167
1146
  raise ValueError(msg) # noqa: TRY004
1168
1147
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1169
1148
 
1170
- res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1149
+ res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1171
1150
  if include_cost:
1172
1151
  return res, extract_report.cost
1173
1152
  return res
@@ -1185,11 +1164,15 @@ class EGraph(_BaseModule):
1185
1164
  msg = "Wrong extract report type"
1186
1165
  raise ValueError(msg) # noqa: TRY004
1187
1166
  new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1188
- return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1167
+ return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
1189
1168
 
1190
1169
  def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1191
1170
  expr = self._state.typed_expr_to_egg(typed_expr)
1192
- self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1171
+ self._egraph.run_program(
1172
+ bindings.ActionCommand(
1173
+ bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
1174
+ )
1175
+ )
1193
1176
  extract_report = self._egraph.extract_report()
1194
1177
  if not extract_report:
1195
1178
  msg = "No extract report saved"
@@ -1261,42 +1244,110 @@ class EGraph(_BaseModule):
1261
1244
  return self._egraph.eval_py_object(egg_expr)
1262
1245
  raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1263
1246
 
1264
- def saturate(
1265
- self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1266
- ) -> ipywidgets.Widget:
1267
- from .graphviz_widget import graphviz_widget_with_slider
1247
+ def _serialize(
1248
+ self,
1249
+ **kwargs: Unpack[GraphvizKwargs],
1250
+ ) -> bindings.SerializedEGraph:
1251
+ max_functions = kwargs.pop("max_functions", None)
1252
+ max_calls_per_function = kwargs.pop("max_calls_per_function", None)
1253
+ split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
1254
+ split_functions = kwargs.pop("split_functions", [])
1255
+ include_temporary_functions = kwargs.pop("include_temporary_functions", False)
1256
+ n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
1257
+ serialized = self._egraph.serialize(
1258
+ [],
1259
+ max_functions=max_functions,
1260
+ max_calls_per_function=max_calls_per_function,
1261
+ include_temporary_functions=include_temporary_functions,
1262
+ )
1263
+ if split_primitive_outputs or split_functions:
1264
+ additional_ops = set(map(self._callable_to_egg, split_functions))
1265
+ serialized.split_e_classes(self._egraph, additional_ops)
1266
+ serialized.map_ops(self._state.op_mapping())
1268
1267
 
1269
- dots = [str(self.graphviz(**kwargs))]
1270
- i = 0
1271
- while self.run(1).updated and i < max:
1272
- i += 1
1273
- dots.append(str(self.graphviz(**kwargs)))
1274
- return graphviz_widget_with_slider(dots, performance=performance)
1268
+ for _ in range(n_inline_leaves):
1269
+ serialized.inline_leaves()
1275
1270
 
1276
- def saturate_to_html(
1277
- self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1278
- ) -> None:
1279
- # raise NotImplementedError("Upstream bugs prevent rendering to HTML")
1271
+ return serialized
1280
1272
 
1281
- # import panel
1273
+ def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
1274
+ serialized = self._serialize(**kwargs)
1282
1275
 
1283
- # panel.extension("ipywidgets")
1276
+ original = serialized.to_dot()
1277
+ # Add link to stylesheet to the graph, so that edges light up on hover
1278
+ # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
1279
+ styles = """/* the lines within the edges */
1280
+ .edge:active path,
1281
+ .edge:hover path {
1282
+ stroke: fuchsia;
1283
+ stroke-width: 3;
1284
+ stroke-opacity: 1;
1285
+ }
1286
+ /* arrows are typically drawn with a polygon */
1287
+ .edge:active polygon,
1288
+ .edge:hover polygon {
1289
+ stroke: fuchsia;
1290
+ stroke-width: 3;
1291
+ fill: fuchsia;
1292
+ stroke-opacity: 1;
1293
+ fill-opacity: 1;
1294
+ }
1295
+ /* If you happen to have text and want to color that as well... */
1296
+ .edge:active text,
1297
+ .edge:hover text {
1298
+ fill: fuchsia;
1299
+ }"""
1300
+ p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
1301
+ p.write_text(styles)
1302
+ with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
1303
+ return graphviz.Source(with_stylesheet)
1304
+
1305
+ def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
1306
+ """
1307
+ Displays the e-graph.
1284
1308
 
1285
- widget = self.saturate(performance=performance, **kwargs)
1286
- # panel.panel(widget).save(file)
1309
+ If in IPython it will display it inline, otherwise it will write it to a file and open it.
1310
+ """
1311
+ from IPython.display import SVG, display
1312
+
1313
+ from .visualizer_widget import VisualizerWidget
1314
+
1315
+ if graphviz:
1316
+ if IN_IPYTHON:
1317
+ svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
1318
+ display(SVG(svg))
1319
+ else:
1320
+ self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1321
+ else:
1322
+ serialized = self._serialize(**kwargs)
1323
+ VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
1287
1324
 
1288
- from ipywidgets.embed import embed_minimal_html
1325
+ def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
1326
+ """
1327
+ Saturate the egraph, running the given schedule until the egraph is saturated.
1328
+ It serializes the egraph at each step and returns a widget to visualize the egraph.
1329
+ """
1330
+ from .visualizer_widget import VisualizerWidget
1289
1331
 
1290
- embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
1291
- # Use panel while this issue persists
1292
- # https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436
1332
+ def to_json() -> str:
1333
+ return self._serialize(**kwargs).to_json()
1334
+
1335
+ egraphs = [to_json()]
1336
+ i = 0
1337
+ while self.run(schedule or 1).updated and i < max:
1338
+ i += 1
1339
+ egraphs.append(to_json())
1340
+ VisualizerWidget(egraphs=egraphs).display_or_open()
1293
1341
 
1294
1342
  @classmethod
1295
1343
  def current(cls) -> EGraph:
1296
1344
  """
1297
1345
  Returns the current egraph, which is the one in the context.
1298
1346
  """
1299
- return CURRENT_EGRAPH.get()
1347
+ try:
1348
+ return CURRENT_EGRAPH.get()
1349
+ except LookupError:
1350
+ return cls(save_egglog_string=True)
1300
1351
 
1301
1352
  @property
1302
1353
  def _egraph(self) -> bindings.EGraph:
@@ -1448,7 +1499,8 @@ class Ruleset(Schedule):
1448
1499
  To return the egg decls, we go through our deferred rules and add any we haven't yet
1449
1500
  """
1450
1501
  while self.deferred_rule_gens:
1451
- rules = self.deferred_rule_gens.pop()()
1502
+ with set_current_ruleset(self):
1503
+ rules = self.deferred_rule_gens.pop()()
1452
1504
  self._current_egg_decls.update(*rules)
1453
1505
  self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1454
1506
  return self._current_egg_decls
@@ -1688,16 +1740,16 @@ def action_command(action: Action) -> Action:
1688
1740
  return action
1689
1741
 
1690
1742
 
1691
- def var(name: str, bound: type[EXPR]) -> EXPR:
1743
+ def var(name: str, bound: type[T]) -> T:
1692
1744
  """Create a new variable with the given name and type."""
1693
- return cast(EXPR, _var(name, bound))
1745
+ return cast(T, _var(name, bound))
1694
1746
 
1695
1747
 
1696
1748
  def _var(name: str, bound: object) -> RuntimeExpr:
1697
1749
  """Create a new variable with the given name and type."""
1698
1750
  decls = Declarations()
1699
1751
  type_ref = resolve_type_annotation(decls, bound)
1700
- return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
1752
+ return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
1701
1753
 
1702
1754
 
1703
1755
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1790,7 +1842,7 @@ class _NeBuilder(Generic[EXPR]):
1790
1842
  lhs = to_runtime_expr(self.lhs)
1791
1843
  rhs = convert_to_same_type(rhs, lhs)
1792
1844
  assert isinstance(Unit, RuntimeClass)
1793
- res = RuntimeExpr.__from_value__(
1845
+ res = RuntimeExpr.__from_values__(
1794
1846
  Declarations.create(Unit, lhs, rhs),
1795
1847
  TypedExprDecl(
1796
1848
  JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
@@ -1944,7 +1996,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1944
1996
  if "Callable" not in globals:
1945
1997
  globals["Callable"] = Callable
1946
1998
  hints = get_type_hints(gen, globals, frame.f_locals)
1947
- args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
1999
+ args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1948
2000
  return list(gen(*args)) # type: ignore[misc]
1949
2001
 
1950
2002
 
@@ -1959,3 +2011,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
1959
2011
  if isinstance(fact_like, Expr):
1960
2012
  return expr_fact(fact_like)
1961
2013
  return fact_like
2014
+
2015
+
2016
+ _CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
2017
+
2018
+
2019
+ def get_current_ruleset() -> Ruleset | None:
2020
+ return _CURRENT_RULESET.get()
2021
+
2022
+
2023
+ @contextlib.contextmanager
2024
+ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
2025
+ token = _CURRENT_RULESET.set(r)
2026
+ try:
2027
+ yield
2028
+ finally:
2029
+ _CURRENT_RULESET.reset(token)