egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.1__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the package's page on the registry for more details.

egglog/egraph.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import inspect
4
5
  import pathlib
5
6
  import tempfile
6
7
  from abc import abstractmethod
7
- from collections.abc import Callable, Iterable
8
+ from collections.abc import Callable, Generator, Iterable
8
9
  from contextvars import ContextVar, Token
9
10
  from dataclasses import InitVar, dataclass, field
10
11
  from inspect import Parameter, currentframe, signature
@@ -37,8 +38,6 @@ from .runtime import *
37
38
  from .thunk import *
38
39
 
39
40
  if TYPE_CHECKING:
40
- import ipywidgets
41
-
42
41
  from .builtins import Bool, PyObject, String, f64, i64
43
42
 
44
43
 
@@ -464,10 +463,14 @@ class _ExprMetaclass(type):
464
463
  prev_frame = frame.f_back
465
464
  assert prev_frame
466
465
 
466
+ # Pass in an instance of the class so that when we are generating the decls
467
+ # we can update them eagerly so that we can access the methods in the class body
468
+ runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
469
+
467
470
  # Store frame so that we can get live access to updated locals/globals
468
471
  # Otherwise, f_locals returns a copy
469
472
  # https://peps.python.org/pep-0667/
470
- decls_thunk = Thunk.fn(
473
+ runtime_cls.__egg_decls_thunk__ = Thunk.fn(
471
474
  _generate_class_decls,
472
475
  namespace,
473
476
  prev_frame,
@@ -475,21 +478,22 @@ class _ExprMetaclass(type):
475
478
  egg_sort,
476
479
  name,
477
480
  ruleset,
478
- fallback=Declarations,
481
+ runtime_cls,
479
482
  )
480
- return RuntimeClass(decls_thunk, TypeRefWithVars(name))
483
+ return runtime_cls
481
484
 
482
485
  def __instancecheck__(cls, instance: object) -> bool:
483
486
  return isinstance(instance, RuntimeExpr)
484
487
 
485
488
 
486
- def _generate_class_decls(
489
+ def _generate_class_decls( # noqa: C901,PLR0912
487
490
  namespace: dict[str, Any],
488
491
  frame: FrameType,
489
492
  builtin: bool,
490
493
  egg_sort: str | None,
491
494
  cls_name: str,
492
495
  ruleset: Ruleset | None,
496
+ runtime_cls: RuntimeClass,
493
497
  ) -> Declarations:
494
498
  """
495
499
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
@@ -502,6 +506,8 @@ def _generate_class_decls(
502
506
  del parameters
503
507
  cls_decl = ClassDecl(egg_sort, type_vars, builtin)
504
508
  decls = Declarations(_classes={cls_name: cls_decl})
509
+ # Update class think eagerly when resolving so that lookups work in methods
510
+ runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
505
511
 
506
512
  ##
507
513
  # Register class variables
@@ -522,16 +528,13 @@ def _generate_class_decls(
522
528
  # Register methods, classmethods, preserved methods, and properties
523
529
  ##
524
530
 
525
- # The type ref of self is paramterized by the type vars
526
- TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
527
-
528
531
  # Get all the methods from the class
529
532
  filtered_namespace: list[tuple[str, Any]] = [
530
533
  (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
531
534
  ]
532
535
 
533
536
  # all methods we should try adding default functions for
534
- default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
537
+ add_default_funcs: list[Callable[[], None]] = []
535
538
  # Then register each of its methods
536
539
  for method_name, method in filtered_namespace:
537
540
  is_init = method_name == "__init__"
@@ -550,7 +553,6 @@ def _generate_class_decls(
550
553
  cls_decl.preserved_methods[method_name] = fn
551
554
  continue
552
555
  locals = frame.f_locals
553
-
554
556
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
555
557
  match fn:
556
558
  case classmethod():
@@ -561,16 +563,25 @@ def _generate_class_decls(
561
563
  fn = fn.fget
562
564
  case _:
563
565
  ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
566
+ special_function_name: SpecialFunctions | None = (
567
+ "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
568
+ )
569
+ if special_function_name:
570
+ decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
571
+ decls.set_function_decl(ref, decl)
572
+ continue
564
573
 
565
- _fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
574
+ _, add_rewrite = _fn_decl(
575
+ decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
576
+ )
566
577
 
567
578
  if not builtin and not isinstance(ref, InitRef) and not mutates:
568
- default_function_refs[ref] = fn
579
+ add_default_funcs.append(add_rewrite)
569
580
 
570
581
  # Add all rewrite methods at the end so that all methods are registered first and can be accessed
571
582
  # in the bodies
572
- for ref, fn in default_function_refs.items():
573
- _add_default_rewrite_function(decls, ref, fn, ruleset)
583
+ for add_rewrite in add_default_funcs:
584
+ add_rewrite()
574
585
 
575
586
  return decls
576
587
 
@@ -590,6 +601,7 @@ def function(
590
601
  unextractable: bool = False,
591
602
  builtin: bool = False,
592
603
  ruleset: Ruleset | None = None,
604
+ use_body_as_name: bool = False,
593
605
  ) -> Callable[[CALLABLE], CALLABLE]: ...
594
606
 
595
607
 
@@ -604,6 +616,7 @@ def function(
604
616
  mutates_first_arg: bool = False,
605
617
  unextractable: bool = False,
606
618
  ruleset: Ruleset | None = None,
619
+ use_body_as_name: bool = False,
607
620
  ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
608
621
 
609
622
 
@@ -635,16 +648,18 @@ class _FunctionConstructor:
635
648
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
636
649
  unextractable: bool = False
637
650
  ruleset: Ruleset | None = None
651
+ use_body_as_name: bool = False
638
652
 
639
653
  def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
640
- return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
654
+ return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
641
655
 
642
- def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
656
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
643
657
  decls = Declarations()
644
- _fn_decl(
658
+ ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
659
+ ref, add_rewrite = _fn_decl(
645
660
  decls,
646
661
  self.egg_fn,
647
- (ref := FunctionRef(fn.__name__)),
662
+ ref,
648
663
  fn,
649
664
  self.hint_locals,
650
665
  self.default,
@@ -653,16 +668,18 @@ class _FunctionConstructor:
653
668
  self.on_merge,
654
669
  self.mutates_first_arg,
655
670
  self.builtin,
671
+ self.ruleset,
656
672
  unextractable=self.unextractable,
657
673
  )
658
- _add_default_rewrite_function(decls, ref, fn, self.ruleset)
659
- return decls
674
+ add_rewrite()
675
+ return decls, ref
660
676
 
661
677
 
662
678
  def _fn_decl(
663
679
  decls: Declarations,
664
680
  egg_name: str | None,
665
- ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
681
+ # If ref is Callable, then generate the ref from the function name
682
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
666
683
  fn: object,
667
684
  # Pass in the locals, retrieved from the frame when wrapping,
668
685
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -673,30 +690,26 @@ def _fn_decl(
673
690
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
674
691
  mutates_first_arg: bool,
675
692
  is_builtin: bool,
693
+ ruleset: Ruleset | None = None,
676
694
  unextractable: bool = False,
677
- ) -> None:
695
+ ) -> tuple[CallableRef, Callable[[], None]]:
678
696
  """
679
- Sets the function decl for the function object.
697
+ Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
680
698
  """
681
699
  if not isinstance(fn, FunctionType):
682
700
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
683
701
 
684
- # partial function creation and calling are handled with a special case in the type checker, so don't
685
- # use the normal logic
686
- special_function_name: SpecialFunctions | None = (
687
- "fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
688
- )
689
- if special_function_name:
690
- decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
691
- return
692
-
693
702
  hint_globals = fn.__globals__.copy()
694
703
  # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
695
704
  # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
696
705
  if "Callable" not in hint_globals:
697
706
  hint_globals["Callable"] = Callable
698
-
699
- hints = get_type_hints(fn, hint_globals, hint_locals)
707
+ # Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
708
+ # won't be resolved correctly
709
+ # We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
710
+ # https://github.com/egraphs-good/egglog-python/issues/210
711
+ hint_globals.update(hint_locals)
712
+ hints = get_type_hints(fn, hint_globals)
700
713
 
701
714
  params = list(signature(fn).parameters.values())
702
715
 
@@ -719,7 +732,7 @@ def _fn_decl(
719
732
 
720
733
  # Resolve all default values as arg types
721
734
  arg_defaults = [
722
- resolve_literal(t, p.default) if p.default is not Parameter.empty else None
735
+ resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
723
736
  for (t, p) in zip(arg_types, params, strict=True)
724
737
  ]
725
738
 
@@ -740,8 +753,8 @@ def _fn_decl(
740
753
  None
741
754
  if merge is None
742
755
  else merge(
743
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
744
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
756
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
757
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
745
758
  )
746
759
  )
747
760
  decls |= merged
@@ -751,29 +764,47 @@ def _fn_decl(
751
764
  if on_merge is None
752
765
  else _action_likes(
753
766
  on_merge(
754
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
755
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
767
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
768
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
756
769
  )
757
770
  )
758
771
  )
759
772
  decls.update(*merge_action)
760
- decl = FunctionDecl(
761
- FunctionSignature(
773
+ # defer this in generator so it doesnt resolve for builtins eagerly
774
+ args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
775
+ res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
776
+ res_thunk: Callable[[], object]
777
+ # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
778
+ if not ref:
779
+ tuple_args = tuple(args)
780
+ res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
781
+ assert isinstance(res, RuntimeExpr)
782
+ res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
783
+ decls._unnamed_functions.add(res_ref)
784
+ res_thunk = Thunk.value(res)
785
+
786
+ else:
787
+ signature_ = FunctionSignature(
762
788
  return_type=None if mutates_first_arg else return_type,
763
789
  var_arg_type=var_arg_type,
764
790
  arg_types=arg_types,
765
791
  arg_names=arg_names,
766
792
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
767
- ),
768
- cost=cost,
769
- egg_name=egg_name,
770
- merge=merged.__egg_typed_expr__.expr if merged is not None else None,
771
- unextractable=unextractable,
772
- builtin=is_builtin,
773
- default=None if default is None else default.__egg_typed_expr__.expr,
774
- on_merge=tuple(a.action for a in merge_action),
775
- )
776
- decls.set_function_decl(ref, decl)
793
+ )
794
+ decl = FunctionDecl(
795
+ signature=signature_,
796
+ cost=cost,
797
+ egg_name=egg_name,
798
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
799
+ unextractable=unextractable,
800
+ builtin=is_builtin,
801
+ default=None if default is None else default.__egg_typed_expr__.expr,
802
+ on_merge=tuple(a.action for a in merge_action),
803
+ )
804
+ res_ref = ref
805
+ decls.set_function_decl(ref, decl)
806
+ res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
807
+ return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
777
808
 
778
809
 
779
810
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -804,7 +835,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
804
835
  Creates a function whose return type is `Unit` and has a default value.
805
836
  """
806
837
  decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
807
- return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
838
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
808
839
 
809
840
 
810
841
  def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
@@ -828,7 +859,9 @@ def constant(
828
859
  A "constant" is implemented as the instantiation of a value that takes no args.
829
860
  This creates a function with `name` and return type `tp` and returns a value of it being called.
830
861
  """
831
- return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
862
+ return cast(
863
+ EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
864
+ )
832
865
 
833
866
 
834
867
  def _constant_thunk(
@@ -842,24 +875,34 @@ def _constant_thunk(
842
875
  return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
843
876
 
844
877
 
845
- def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
846
- """
847
- Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
848
- """
849
- callable_decl = decls.get_callable_decl(ref)
850
- assert isinstance(callable_decl, FunctionDecl)
851
- signature = callable_decl.signature
852
- assert isinstance(signature, FunctionSignature)
853
-
854
- var_args: list[object] = [
855
- RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
856
- for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
857
- ]
878
+ def _create_default_value(
879
+ decls: Declarations,
880
+ ref: CallableRef | None,
881
+ fn: Callable,
882
+ args: Iterable[TypedExprDecl],
883
+ ruleset: Ruleset | None,
884
+ ) -> object:
885
+ args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
886
+
858
887
  # If this is a classmethod, add the class as the first arg
859
888
  if isinstance(ref, ClassMethodRef):
860
889
  tp = decls.get_paramaterized_class(ref.class_name)
861
- var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
862
- _add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
890
+ args.insert(0, RuntimeClass(Thunk.value(decls), tp))
891
+ with set_current_ruleset(ruleset):
892
+ return fn(*args)
893
+
894
+
895
+ def _add_default_rewrite_function(
896
+ decls: Declarations,
897
+ ref: CallableRef,
898
+ res_type: TypeOrVarRef,
899
+ ruleset: Ruleset | None,
900
+ value_thunk: Callable[[], object],
901
+ ) -> None:
902
+ """
903
+ Helper functions that resolves a value thunk to create the default value.
904
+ """
905
+ _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
863
906
 
864
907
 
865
908
  def _add_default_rewrite(
@@ -872,7 +915,7 @@ def _add_default_rewrite(
872
915
  """
873
916
  if default_rewrite is None:
874
917
  return
875
- resolved_value = resolve_literal(type_ref, default_rewrite)
918
+ resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
876
919
  rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
877
920
  if ruleset:
878
921
  ruleset_decls = ruleset._current_egg_decls
@@ -931,6 +974,8 @@ class GraphvizKwargs(TypedDict, total=False):
931
974
  max_calls_per_function: int | None
932
975
  n_inline_leaves: int
933
976
  split_primitive_outputs: bool
977
+ split_functions: list[object]
978
+ include_temporary_functions: bool
934
979
 
935
980
 
936
981
  @dataclass
@@ -973,81 +1018,19 @@ class EGraph(_BaseModule):
973
1018
  raise ValueError(msg)
974
1019
  return cmds
975
1020
 
976
- def _repr_mimebundle_(self, *args, **kwargs):
977
- """
978
- Returns the graphviz representation of the e-graph.
979
- """
980
- return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
981
-
982
- def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
983
- # By default we want to split primitive outputs
984
- kwargs.setdefault("split_primitive_outputs", True)
985
- n_inline = kwargs.pop("n_inline_leaves", 0)
986
- serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
987
- serialized.map_ops(self._state.op_mapping())
988
- for _ in range(n_inline):
989
- serialized.inline_leaves()
990
- original = serialized.to_dot()
991
- # Add link to stylesheet to the graph, so that edges light up on hover
992
- # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
993
- styles = """/* the lines within the edges */
994
- .edge:active path,
995
- .edge:hover path {
996
- stroke: fuchsia;
997
- stroke-width: 3;
998
- stroke-opacity: 1;
999
- }
1000
- /* arrows are typically drawn with a polygon */
1001
- .edge:active polygon,
1002
- .edge:hover polygon {
1003
- stroke: fuchsia;
1004
- stroke-width: 3;
1005
- fill: fuchsia;
1006
- stroke-opacity: 1;
1007
- fill-opacity: 1;
1008
- }
1009
- /* If you happen to have text and want to color that as well... */
1010
- .edge:active text,
1011
- .edge:hover text {
1012
- fill: fuchsia;
1013
- }"""
1014
- p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
1015
- p.write_text(styles)
1016
- with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
1017
- return graphviz.Source(with_stylesheet)
1018
-
1019
- def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
1020
- return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
1021
-
1022
- def _repr_html_(self) -> str:
1023
- """
1024
- Add a _repr_html_ to be an SVG to work with sphinx gallery.
1025
-
1026
- ala https://github.com/xflr6/graphviz/pull/121
1027
- until this PR is merged and released
1028
- https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
1029
- """
1030
- return self.graphviz_svg()
1031
-
1032
- def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
1033
- """
1034
- Displays the e-graph in the notebook.
1035
- """
1036
- if IN_IPYTHON:
1037
- from IPython.display import SVG, display
1038
-
1039
- display(SVG(self.graphviz_svg(**kwargs)))
1040
- else:
1041
- self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1021
+ def _ipython_display_(self) -> None:
1022
+ self.display()
1042
1023
 
1043
1024
  def input(self, fn: Callable[..., String], path: str) -> None:
1044
1025
  """
1045
1026
  Loads a CSV file and sets it as *input, output of the function.
1046
1027
  """
1028
+ self._egraph.run_program(bindings.Input(bindings.DUMMY_SPAN, self._callable_to_egg(fn), path))
1029
+
1030
+ def _callable_to_egg(self, fn: object) -> str:
1047
1031
  ref, decls = resolve_callable(fn)
1048
1032
  self._add_decls(decls)
1049
- fn_name = self._state.callable_ref_to_egg(ref)
1050
- self._egraph.run_program(bindings.Input(fn_name, path))
1033
+ return self._state.callable_ref_to_egg(ref)
1051
1034
 
1052
1035
  def let(self, name: str, expr: EXPR) -> EXPR:
1053
1036
  """
@@ -1059,8 +1042,8 @@ class EGraph(_BaseModule):
1059
1042
  self._add_decls(runtime_expr)
1060
1043
  return cast(
1061
1044
  EXPR,
1062
- RuntimeExpr.__from_value__(
1063
- self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
1045
+ RuntimeExpr.__from_values__(
1046
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
1064
1047
  ),
1065
1048
  )
1066
1049
 
@@ -1084,13 +1067,13 @@ class EGraph(_BaseModule):
1084
1067
  typed_expr = runtime_expr.__egg_typed_expr__
1085
1068
  # Must also register type
1086
1069
  egg_expr = self._state.typed_expr_to_egg(typed_expr)
1087
- self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1070
+ self._egraph.run_program(bindings.Simplify(bindings.DUMMY_SPAN, egg_expr, egg_schedule))
1088
1071
  extract_report = self._egraph.extract_report()
1089
1072
  if not isinstance(extract_report, bindings.Best):
1090
1073
  msg = "No extract report saved"
1091
1074
  raise ValueError(msg) # noqa: TRY004
1092
1075
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1093
- return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1076
+ return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1094
1077
 
1095
1078
  def include(self, path: str) -> None:
1096
1079
  """
@@ -1139,13 +1122,13 @@ class EGraph(_BaseModule):
1139
1122
  """
1140
1123
  Checks that one of the facts is not true
1141
1124
  """
1142
- self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
1125
+ self._egraph.run_program(bindings.Fail(bindings.DUMMY_SPAN, self._facts_to_check(facts)))
1143
1126
 
1144
1127
  def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
1145
1128
  facts = _fact_likes(fact_likes)
1146
1129
  self._add_decls(*facts)
1147
1130
  egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1148
- return bindings.Check(egg_facts)
1131
+ return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
1149
1132
 
1150
1133
  @overload
1151
1134
  def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
@@ -1167,7 +1150,7 @@ class EGraph(_BaseModule):
1167
1150
  raise ValueError(msg) # noqa: TRY004
1168
1151
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1169
1152
 
1170
- res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1153
+ res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1171
1154
  if include_cost:
1172
1155
  return res, extract_report.cost
1173
1156
  return res
@@ -1185,11 +1168,15 @@ class EGraph(_BaseModule):
1185
1168
  msg = "Wrong extract report type"
1186
1169
  raise ValueError(msg) # noqa: TRY004
1187
1170
  new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1188
- return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1171
+ return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
1189
1172
 
1190
1173
  def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1191
1174
  expr = self._state.typed_expr_to_egg(typed_expr)
1192
- self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1175
+ self._egraph.run_program(
1176
+ bindings.ActionCommand(
1177
+ bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
1178
+ )
1179
+ )
1193
1180
  extract_report = self._egraph.extract_report()
1194
1181
  if not extract_report:
1195
1182
  msg = "No extract report saved"
@@ -1208,7 +1195,7 @@ class EGraph(_BaseModule):
1208
1195
  """
1209
1196
  Pop the current state of the egraph, reverting back to the previous state.
1210
1197
  """
1211
- self._egraph.run_program(bindings.Pop(1))
1198
+ self._egraph.run_program(bindings.Pop(bindings.DUMMY_SPAN, 1))
1212
1199
  self._state = self._state_stack.pop()
1213
1200
 
1214
1201
  def __enter__(self) -> Self:
@@ -1221,7 +1208,7 @@ class EGraph(_BaseModule):
1221
1208
  self.push()
1222
1209
  return self
1223
1210
 
1224
- def __exit__(self, exc_type, exc, exc_tb) -> None: # noqa: ANN001
1211
+ def __exit__(self, exc_type, exc, exc_tb) -> None:
1225
1212
  CURRENT_EGRAPH.reset(self._token_stack.pop())
1226
1213
  self.pop()
1227
1214
 
@@ -1261,42 +1248,119 @@ class EGraph(_BaseModule):
1261
1248
  return self._egraph.eval_py_object(egg_expr)
1262
1249
  raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1263
1250
 
1264
- def saturate(
1265
- self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1266
- ) -> ipywidgets.Widget:
1267
- from .graphviz_widget import graphviz_widget_with_slider
1251
+ def _serialize(
1252
+ self,
1253
+ **kwargs: Unpack[GraphvizKwargs],
1254
+ ) -> bindings.SerializedEGraph:
1255
+ max_functions = kwargs.pop("max_functions", None)
1256
+ max_calls_per_function = kwargs.pop("max_calls_per_function", None)
1257
+ split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
1258
+ split_functions = kwargs.pop("split_functions", [])
1259
+ include_temporary_functions = kwargs.pop("include_temporary_functions", False)
1260
+ n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
1261
+ serialized = self._egraph.serialize(
1262
+ [],
1263
+ max_functions=max_functions,
1264
+ max_calls_per_function=max_calls_per_function,
1265
+ include_temporary_functions=include_temporary_functions,
1266
+ )
1267
+ if split_primitive_outputs or split_functions:
1268
+ additional_ops = set(map(self._callable_to_egg, split_functions))
1269
+ serialized.split_classes(self._egraph, additional_ops)
1270
+ serialized.map_ops(self._state.op_mapping())
1268
1271
 
1269
- dots = [str(self.graphviz(**kwargs))]
1270
- i = 0
1271
- while self.run(1).updated and i < max:
1272
- i += 1
1273
- dots.append(str(self.graphviz(**kwargs)))
1274
- return graphviz_widget_with_slider(dots, performance=performance)
1272
+ for _ in range(n_inline_leaves):
1273
+ serialized.inline_leaves()
1275
1274
 
1276
- def saturate_to_html(
1277
- self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1278
- ) -> None:
1279
- # raise NotImplementedError("Upstream bugs prevent rendering to HTML")
1275
+ return serialized
1276
+
1277
+ def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
1278
+ serialized = self._serialize(**kwargs)
1279
+
1280
+ original = serialized.to_dot()
1281
+ # Add link to stylesheet to the graph, so that edges light up on hover
1282
+ # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
1283
+ styles = """/* the lines within the edges */
1284
+ .edge:active path,
1285
+ .edge:hover path {
1286
+ stroke: fuchsia;
1287
+ stroke-width: 3;
1288
+ stroke-opacity: 1;
1289
+ }
1290
+ /* arrows are typically drawn with a polygon */
1291
+ .edge:active polygon,
1292
+ .edge:hover polygon {
1293
+ stroke: fuchsia;
1294
+ stroke-width: 3;
1295
+ fill: fuchsia;
1296
+ stroke-opacity: 1;
1297
+ fill-opacity: 1;
1298
+ }
1299
+ /* If you happen to have text and want to color that as well... */
1300
+ .edge:active text,
1301
+ .edge:hover text {
1302
+ fill: fuchsia;
1303
+ }"""
1304
+ p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
1305
+ p.write_text(styles)
1306
+ with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
1307
+ return graphviz.Source(with_stylesheet)
1308
+
1309
+ def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
1310
+ """
1311
+ Displays the e-graph.
1312
+
1313
+ If in IPython it will display it inline, otherwise it will write it to a file and open it.
1314
+ """
1315
+ from IPython.display import SVG, display
1280
1316
 
1281
- # import panel
1317
+ from .visualizer_widget import VisualizerWidget
1282
1318
 
1283
- # panel.extension("ipywidgets")
1319
+ if graphviz:
1320
+ if IN_IPYTHON:
1321
+ svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
1322
+ display(SVG(svg))
1323
+ else:
1324
+ self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1325
+ else:
1326
+ serialized = self._serialize(**kwargs)
1327
+ VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
1284
1328
 
1285
- widget = self.saturate(performance=performance, **kwargs)
1286
- # panel.panel(widget).save(file)
1329
+ def saturate(
1330
+ self,
1331
+ schedule: Schedule | None = None,
1332
+ *,
1333
+ expr: Expr | None = None,
1334
+ max: int = 1000,
1335
+ **kwargs: Unpack[GraphvizKwargs],
1336
+ ) -> None:
1337
+ """
1338
+ Saturate the egraph, running the given schedule until the egraph is saturated.
1339
+ It serializes the egraph at each step and returns a widget to visualize the egraph.
1340
+ """
1341
+ from .visualizer_widget import VisualizerWidget
1287
1342
 
1288
- from ipywidgets.embed import embed_minimal_html
1343
+ def to_json() -> str:
1344
+ if expr:
1345
+ print(self.extract(expr))
1346
+ return self._serialize(**kwargs).to_json()
1289
1347
 
1290
- embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
1291
- # Use panel while this issue persists
1292
- # https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436
1348
+ egraphs = [to_json()]
1349
+ i = 0
1350
+ while self.run(schedule or 1).updated and i < max:
1351
+ i += 1
1352
+ egraphs.append(to_json())
1353
+ VisualizerWidget(egraphs=egraphs).display_or_open()
1293
1354
 
1294
1355
  @classmethod
1295
1356
  def current(cls) -> EGraph:
1296
1357
  """
1297
1358
  Returns the current egraph, which is the one in the context.
1298
1359
  """
1299
- return CURRENT_EGRAPH.get()
1360
+ try:
1361
+ return CURRENT_EGRAPH.get()
1362
+ except LookupError:
1363
+ return cls(save_egglog_string=True)
1300
1364
 
1301
1365
  @property
1302
1366
  def _egraph(self) -> bindings.EGraph:
@@ -1448,7 +1512,8 @@ class Ruleset(Schedule):
1448
1512
  To return the egg decls, we go through our deferred rules and add any we haven't yet
1449
1513
  """
1450
1514
  while self.deferred_rule_gens:
1451
- rules = self.deferred_rule_gens.pop()()
1515
+ with set_current_ruleset(self):
1516
+ rules = self.deferred_rule_gens.pop()()
1452
1517
  self._current_egg_decls.update(*rules)
1453
1518
  self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1454
1519
  return self._current_egg_decls
@@ -1688,16 +1753,16 @@ def action_command(action: Action) -> Action:
1688
1753
  return action
1689
1754
 
1690
1755
 
1691
- def var(name: str, bound: type[EXPR]) -> EXPR:
1756
+ def var(name: str, bound: type[T]) -> T:
1692
1757
  """Create a new variable with the given name and type."""
1693
- return cast(EXPR, _var(name, bound))
1758
+ return cast(T, _var(name, bound))
1694
1759
 
1695
1760
 
1696
1761
  def _var(name: str, bound: object) -> RuntimeExpr:
1697
1762
  """Create a new variable with the given name and type."""
1698
1763
  decls = Declarations()
1699
1764
  type_ref = resolve_type_annotation(decls, bound)
1700
- return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
1765
+ return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
1701
1766
 
1702
1767
 
1703
1768
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1790,7 +1855,7 @@ class _NeBuilder(Generic[EXPR]):
1790
1855
  lhs = to_runtime_expr(self.lhs)
1791
1856
  rhs = convert_to_same_type(rhs, lhs)
1792
1857
  assert isinstance(Unit, RuntimeClass)
1793
- res = RuntimeExpr.__from_value__(
1858
+ res = RuntimeExpr.__from_values__(
1794
1859
  Declarations.create(Unit, lhs, rhs),
1795
1860
  TypedExprDecl(
1796
1861
  JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
@@ -1944,7 +2009,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1944
2009
  if "Callable" not in globals:
1945
2010
  globals["Callable"] = Callable
1946
2011
  hints = get_type_hints(gen, globals, frame.f_locals)
1947
- args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
2012
+ args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1948
2013
  return list(gen(*args)) # type: ignore[misc]
1949
2014
 
1950
2015
 
@@ -1959,3 +2024,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
1959
2024
  if isinstance(fact_like, Expr):
1960
2025
  return expr_fact(fact_like)
1961
2026
  return fact_like
2027
+
2028
+
2029
+ _CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
2030
+
2031
+
2032
+ def get_current_ruleset() -> Ruleset | None:
2033
+ return _CURRENT_RULESET.get()
2034
+
2035
+
2036
+ @contextlib.contextmanager
2037
+ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
2038
+ token = _CURRENT_RULESET.set(r)
2039
+ try:
2040
+ yield
2041
+ finally:
2042
+ _CURRENT_RULESET.reset(token)