egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import inspect
4
5
  import pathlib
5
6
  import tempfile
6
7
  from abc import abstractmethod
7
- from collections.abc import Callable, Iterable
8
+ from collections.abc import Callable, Generator, Iterable
8
9
  from contextvars import ContextVar, Token
9
10
  from dataclasses import InitVar, dataclass, field
10
11
  from inspect import Parameter, currentframe, signature
@@ -37,8 +38,6 @@ from .runtime import *
37
38
  from .thunk import *
38
39
 
39
40
  if TYPE_CHECKING:
40
- import ipywidgets
41
-
42
41
  from .builtins import Bool, PyObject, String, f64, i64
43
42
 
44
43
 
@@ -352,7 +351,7 @@ class _BaseModule:
352
351
  This is the same as defining a nullary function with a high cost.
353
352
  # TODO: Rename as declare to match eggglog?
354
353
  """
355
- return constant(name, tp, egg_name)
354
+ return constant(name, tp, egg_name=egg_name)
356
355
 
357
356
  def register(
358
357
  self,
@@ -452,6 +451,7 @@ class _ExprMetaclass(type):
452
451
  namespace: dict[str, Any],
453
452
  egg_sort: str | None = None,
454
453
  builtin: bool = False,
454
+ ruleset: Ruleset | None = None,
455
455
  ) -> RuntimeClass | type:
456
456
  # If this is the Expr subclass, just return the class
457
457
  if not bases:
@@ -463,20 +463,37 @@ class _ExprMetaclass(type):
463
463
  prev_frame = frame.f_back
464
464
  assert prev_frame
465
465
 
466
+ # Pass in an instance of the class so that when we are generating the decls
467
+ # we can update them eagerly so that we can access the methods in the class body
468
+ runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
469
+
466
470
  # Store frame so that we can get live access to updated locals/globals
467
471
  # Otherwise, f_locals returns a copy
468
472
  # https://peps.python.org/pep-0667/
469
- decls_thunk = Thunk.fn(
470
- _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
473
+ runtime_cls.__egg_decls_thunk__ = Thunk.fn(
474
+ _generate_class_decls,
475
+ namespace,
476
+ prev_frame,
477
+ builtin,
478
+ egg_sort,
479
+ name,
480
+ ruleset,
481
+ runtime_cls,
471
482
  )
472
- return RuntimeClass(decls_thunk, TypeRefWithVars(name))
483
+ return runtime_cls
473
484
 
474
485
  def __instancecheck__(cls, instance: object) -> bool:
475
486
  return isinstance(instance, RuntimeExpr)
476
487
 
477
488
 
478
489
  def _generate_class_decls( # noqa: C901
479
- namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
490
+ namespace: dict[str, Any],
491
+ frame: FrameType,
492
+ builtin: bool,
493
+ egg_sort: str | None,
494
+ cls_name: str,
495
+ ruleset: Ruleset | None,
496
+ runtime_cls: RuntimeClass,
480
497
  ) -> Declarations:
481
498
  """
482
499
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
@@ -489,6 +506,8 @@ def _generate_class_decls( # noqa: C901
489
506
  del parameters
490
507
  cls_decl = ClassDecl(egg_sort, type_vars, builtin)
491
508
  decls = Declarations(_classes={cls_name: cls_decl})
509
+ # Update class think eagerly when resolving so that lookups work in methods
510
+ runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
492
511
 
493
512
  ##
494
513
  # Register class variables
@@ -498,9 +517,9 @@ def _generate_class_decls( # noqa: C901
498
517
  for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
499
518
  if getattr(v, "__origin__", None) == ClassVar:
500
519
  (inner_tp,) = v.__args__
501
- type_ref = resolve_type_annotation(decls, inner_tp).to_just()
502
- cls_decl.class_variables[k] = ConstantDecl(type_ref)
503
-
520
+ type_ref = resolve_type_annotation(decls, inner_tp)
521
+ cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
522
+ _add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
504
523
  else:
505
524
  msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
506
525
  raise NotImplementedError(msg)
@@ -509,14 +528,13 @@ def _generate_class_decls( # noqa: C901
509
528
  # Register methods, classmethods, preserved methods, and properties
510
529
  ##
511
530
 
512
- # The type ref of self is paramterized by the type vars
513
- slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
514
-
515
531
  # Get all the methods from the class
516
532
  filtered_namespace: list[tuple[str, Any]] = [
517
533
  (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
518
534
  ]
519
535
 
536
+ # all methods we should try adding default functions for
537
+ add_default_funcs: list[Callable[[], None]] = []
520
538
  # Then register each of its methods
521
539
  for method_name, method in filtered_namespace:
522
540
  is_init = method_name == "__init__"
@@ -535,44 +553,35 @@ def _generate_class_decls( # noqa: C901
535
553
  cls_decl.preserved_methods[method_name] = fn
536
554
  continue
537
555
  locals = frame.f_locals
538
-
539
- def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
540
- special_function_name: SpecialFunctions | None = (
541
- "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
542
- )
543
- if special_function_name:
544
- return FunctionDecl(
545
- special_function_name,
546
- builtin=True,
547
- egg_name=egg_fn, # noqa: B023
548
- )
549
-
550
- return _fn_decl(
551
- decls,
552
- egg_fn, # noqa: B023
553
- fn,
554
- locals, # noqa: B023
555
- default, # noqa: B023
556
- cost, # noqa: B023
557
- merge, # noqa: B023
558
- on_merge, # noqa: B023
559
- mutates, # noqa: B023
560
- builtin,
561
- first,
562
- is_init, # noqa: B023
563
- unextractable, # noqa: B023
564
- )
565
-
556
+ ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
566
557
  match fn:
567
558
  case classmethod():
568
- cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
559
+ ref = ClassMethodRef(cls_name, method_name)
560
+ fn = fn.__func__
569
561
  case property():
570
- cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
562
+ ref = PropertyRef(cls_name, method_name)
563
+ fn = fn.fget
571
564
  case _:
572
- if is_init:
573
- cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
574
- else:
575
- cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
565
+ ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
566
+ special_function_name: SpecialFunctions | None = (
567
+ "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
568
+ )
569
+ if special_function_name:
570
+ decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
571
+ decls.set_function_decl(ref, decl)
572
+ continue
573
+
574
+ _, add_rewrite = _fn_decl(
575
+ decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
576
+ )
577
+
578
+ if not builtin and not isinstance(ref, InitRef) and not mutates:
579
+ add_default_funcs.append(add_rewrite)
580
+
581
+ # Add all rewrite methods at the end so that all methods are registered first and can be accessed
582
+ # in the bodies
583
+ for add_rewrite in add_default_funcs:
584
+ add_rewrite()
576
585
 
577
586
  return decls
578
587
 
@@ -591,6 +600,8 @@ def function(
591
600
  mutates_first_arg: bool = False,
592
601
  unextractable: bool = False,
593
602
  builtin: bool = False,
603
+ ruleset: Ruleset | None = None,
604
+ use_body_as_name: bool = False,
594
605
  ) -> Callable[[CALLABLE], CALLABLE]: ...
595
606
 
596
607
 
@@ -604,6 +615,8 @@ def function(
604
615
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
605
616
  mutates_first_arg: bool = False,
606
617
  unextractable: bool = False,
618
+ ruleset: Ruleset | None = None,
619
+ use_body_as_name: bool = False,
607
620
  ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
608
621
 
609
622
 
@@ -634,15 +647,19 @@ class _FunctionConstructor:
634
647
  merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
635
648
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
636
649
  unextractable: bool = False
650
+ ruleset: Ruleset | None = None
651
+ use_body_as_name: bool = False
637
652
 
638
653
  def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
639
- return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
654
+ return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
640
655
 
641
- def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
656
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
642
657
  decls = Declarations()
643
- decls._functions[fn.__name__] = _fn_decl(
658
+ ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
659
+ ref, add_rewrite = _fn_decl(
644
660
  decls,
645
661
  self.egg_fn,
662
+ ref,
646
663
  fn,
647
664
  self.hint_locals,
648
665
  self.default,
@@ -651,14 +668,18 @@ class _FunctionConstructor:
651
668
  self.on_merge,
652
669
  self.mutates_first_arg,
653
670
  self.builtin,
671
+ self.ruleset,
654
672
  unextractable=self.unextractable,
655
673
  )
656
- return decls
674
+ add_rewrite()
675
+ return decls, ref
657
676
 
658
677
 
659
678
  def _fn_decl(
660
679
  decls: Declarations,
661
680
  egg_name: str | None,
681
+ # If ref is Callable, then generate the ref from the function name
682
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
662
683
  fn: object,
663
684
  # Pass in the locals, retrieved from the frame when wrapping,
664
685
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -669,11 +690,12 @@ def _fn_decl(
669
690
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
670
691
  mutates_first_arg: bool,
671
692
  is_builtin: bool,
672
- # The first arg is either cls, for a classmethod, a self type, or none for a function
673
- first_arg: Literal["cls"] | TypeOrVarRef | None = None,
674
- is_init: bool = False,
693
+ ruleset: Ruleset | None = None,
675
694
  unextractable: bool = False,
676
- ) -> FunctionDecl:
695
+ ) -> tuple[CallableRef, Callable[[], None]]:
696
+ """
697
+ Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
698
+ """
677
699
  if not isinstance(fn, FunctionType):
678
700
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
679
701
 
@@ -687,8 +709,8 @@ def _fn_decl(
687
709
 
688
710
  params = list(signature(fn).parameters.values())
689
711
 
690
- # If this is an init function, or a classmethod, remove the first arg name
691
- if is_init or first_arg == "cls":
712
+ # If this is an init function, or a classmethod, the first arg is not used
713
+ if isinstance(ref, ClassMethodRef | InitRef):
692
714
  params = params[1:]
693
715
 
694
716
  if _last_param_variable(params):
@@ -698,40 +720,37 @@ def _fn_decl(
698
720
  else:
699
721
  var_arg_type = None
700
722
  arg_types = tuple(
701
- first_arg
702
- # If the first arg is a self, and this not an __init__ fn, add this as a typeref
703
- if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
723
+ decls.get_paramaterized_class(ref.class_name)
724
+ if i == 0 and isinstance(ref, MethodRef | PropertyRef)
704
725
  else resolve_type_annotation(decls, hints[t.name])
705
726
  for i, t in enumerate(params)
706
727
  )
707
728
 
708
729
  # Resolve all default values as arg types
709
730
  arg_defaults = [
710
- resolve_literal(t, p.default) if p.default is not Parameter.empty else None
731
+ resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
711
732
  for (t, p) in zip(arg_types, params, strict=True)
712
733
  ]
713
734
 
714
735
  decls.update(*arg_defaults)
715
736
 
716
- # If this is an init fn use the first arg as the return type
717
- if is_init:
718
- assert not mutates_first_arg
719
- if not isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars):
720
- msg = "Init function must have a self type"
721
- raise ValueError(msg)
722
- return_type = first_arg
723
- elif mutates_first_arg:
724
- return_type = arg_types[0]
725
- else:
726
- return_type = resolve_type_annotation(decls, hints["return"])
737
+ return_type = (
738
+ decls.get_paramaterized_class(ref.class_name)
739
+ if isinstance(ref, InitRef)
740
+ else arg_types[0]
741
+ if mutates_first_arg
742
+ else resolve_type_annotation(decls, hints["return"])
743
+ )
744
+
745
+ arg_names = tuple(t.name for t in params)
727
746
 
728
747
  decls |= default
729
748
  merged = (
730
749
  None
731
750
  if merge is None
732
751
  else merge(
733
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
734
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
752
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
753
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
735
754
  )
736
755
  )
737
756
  decls |= merged
@@ -741,28 +760,47 @@ def _fn_decl(
741
760
  if on_merge is None
742
761
  else _action_likes(
743
762
  on_merge(
744
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
745
- RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
763
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
764
+ RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
746
765
  )
747
766
  )
748
767
  )
749
768
  decls.update(*merge_action)
750
- return FunctionDecl(
751
- FunctionSignature(
769
+ # defer this in generator so it doesnt resolve for builtins eagerly
770
+ args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
771
+ res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
772
+ res_thunk: Callable[[], object]
773
+ # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
774
+ if not ref:
775
+ tuple_args = tuple(args)
776
+ res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
777
+ assert isinstance(res, RuntimeExpr)
778
+ res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
779
+ decls._unnamed_functions.add(res_ref)
780
+ res_thunk = Thunk.value(res)
781
+
782
+ else:
783
+ signature_ = FunctionSignature(
752
784
  return_type=None if mutates_first_arg else return_type,
753
785
  var_arg_type=var_arg_type,
754
786
  arg_types=arg_types,
755
- arg_names=tuple(t.name for t in params),
787
+ arg_names=arg_names,
756
788
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
757
- ),
758
- cost=cost,
759
- egg_name=egg_name,
760
- merge=merged.__egg_typed_expr__.expr if merged is not None else None,
761
- unextractable=unextractable,
762
- builtin=is_builtin,
763
- default=None if default is None else default.__egg_typed_expr__.expr,
764
- on_merge=tuple(a.action for a in merge_action),
765
- )
789
+ )
790
+ decl = FunctionDecl(
791
+ signature=signature_,
792
+ cost=cost,
793
+ egg_name=egg_name,
794
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
795
+ unextractable=unextractable,
796
+ builtin=is_builtin,
797
+ default=None if default is None else default.__egg_typed_expr__.expr,
798
+ on_merge=tuple(a.action for a in merge_action),
799
+ )
800
+ res_ref = ref
801
+ decls.set_function_decl(ref, decl)
802
+ res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
803
+ return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
766
804
 
767
805
 
768
806
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -793,7 +831,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
793
831
  Creates a function whose return type is `Unit` and has a default value.
794
832
  """
795
833
  decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
796
- return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
834
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
797
835
 
798
836
 
799
837
  def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
@@ -804,19 +842,85 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
804
842
  return decls
805
843
 
806
844
 
807
- def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
845
+ def constant(
846
+ name: str,
847
+ tp: type[EXPR],
848
+ default_replacement: EXPR | None = None,
849
+ /,
850
+ *,
851
+ egg_name: str | None = None,
852
+ ruleset: Ruleset | None = None,
853
+ ) -> EXPR:
808
854
  """
809
855
  A "constant" is implemented as the instantiation of a value that takes no args.
810
856
  This creates a function with `name` and return type `tp` and returns a value of it being called.
811
857
  """
812
- return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
858
+ return cast(
859
+ EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
860
+ )
813
861
 
814
862
 
815
- def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
863
+ def _constant_thunk(
864
+ name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
865
+ ) -> tuple[Declarations, TypedExprDecl]:
816
866
  decls = Declarations()
817
- type_ref = resolve_type_annotation(decls, tp).to_just()
818
- decls._constants[name] = ConstantDecl(type_ref, egg_name)
819
- return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
867
+ type_ref = resolve_type_annotation(decls, tp)
868
+ callable_ref = ConstantRef(name)
869
+ decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
870
+ _add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
871
+ return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
872
+
873
+
874
+ def _create_default_value(
875
+ decls: Declarations,
876
+ ref: CallableRef | None,
877
+ fn: Callable,
878
+ args: Iterable[TypedExprDecl],
879
+ ruleset: Ruleset | None,
880
+ ) -> object:
881
+ args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
882
+
883
+ # If this is a classmethod, add the class as the first arg
884
+ if isinstance(ref, ClassMethodRef):
885
+ tp = decls.get_paramaterized_class(ref.class_name)
886
+ args.insert(0, RuntimeClass(Thunk.value(decls), tp))
887
+ with set_current_ruleset(ruleset):
888
+ return fn(*args)
889
+
890
+
891
+ def _add_default_rewrite_function(
892
+ decls: Declarations,
893
+ ref: CallableRef,
894
+ res_type: TypeOrVarRef,
895
+ ruleset: Ruleset | None,
896
+ value_thunk: Callable[[], object],
897
+ ) -> None:
898
+ """
899
+ Helper functions that resolves a value thunk to create the default value.
900
+ """
901
+ _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
902
+
903
+
904
+ def _add_default_rewrite(
905
+ decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
906
+ ) -> None:
907
+ """
908
+ Adds a default rewrite for the callable, if the default rewrite is not None
909
+
910
+ Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
911
+ """
912
+ if default_rewrite is None:
913
+ return
914
+ resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
915
+ rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
916
+ if ruleset:
917
+ ruleset_decls = ruleset._current_egg_decls
918
+ ruleset_decl = ruleset.__egg_ruleset__
919
+ else:
920
+ ruleset_decls = decls
921
+ ruleset_decl = decls.default_ruleset
922
+ ruleset_decl.rules.append(rewrite_decl)
923
+ ruleset_decls |= resolved_value
820
924
 
821
925
 
822
926
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -866,6 +970,8 @@ class GraphvizKwargs(TypedDict, total=False):
866
970
  max_calls_per_function: int | None
867
971
  n_inline_leaves: int
868
972
  split_primitive_outputs: bool
973
+ split_functions: list[object]
974
+ include_temporary_functions: bool
869
975
 
870
976
 
871
977
  @dataclass
@@ -908,81 +1014,19 @@ class EGraph(_BaseModule):
908
1014
  raise ValueError(msg)
909
1015
  return cmds
910
1016
 
911
- def _repr_mimebundle_(self, *args, **kwargs):
912
- """
913
- Returns the graphviz representation of the e-graph.
914
- """
915
- return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
916
-
917
- def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
918
- # By default we want to split primitive outputs
919
- kwargs.setdefault("split_primitive_outputs", True)
920
- n_inline = kwargs.pop("n_inline_leaves", 0)
921
- serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
922
- serialized.map_ops(self._state.op_mapping())
923
- for _ in range(n_inline):
924
- serialized.inline_leaves()
925
- original = serialized.to_dot()
926
- # Add link to stylesheet to the graph, so that edges light up on hover
927
- # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
928
- styles = """/* the lines within the edges */
929
- .edge:active path,
930
- .edge:hover path {
931
- stroke: fuchsia;
932
- stroke-width: 3;
933
- stroke-opacity: 1;
934
- }
935
- /* arrows are typically drawn with a polygon */
936
- .edge:active polygon,
937
- .edge:hover polygon {
938
- stroke: fuchsia;
939
- stroke-width: 3;
940
- fill: fuchsia;
941
- stroke-opacity: 1;
942
- fill-opacity: 1;
943
- }
944
- /* If you happen to have text and want to color that as well... */
945
- .edge:active text,
946
- .edge:hover text {
947
- fill: fuchsia;
948
- }"""
949
- p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
950
- p.write_text(styles)
951
- with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
952
- return graphviz.Source(with_stylesheet)
953
-
954
- def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
955
- return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
956
-
957
- def _repr_html_(self) -> str:
958
- """
959
- Add a _repr_html_ to be an SVG to work with sphinx gallery.
960
-
961
- ala https://github.com/xflr6/graphviz/pull/121
962
- until this PR is merged and released
963
- https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
964
- """
965
- return self.graphviz_svg()
966
-
967
- def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
968
- """
969
- Displays the e-graph in the notebook.
970
- """
971
- if IN_IPYTHON:
972
- from IPython.display import SVG, display
973
-
974
- display(SVG(self.graphviz_svg(**kwargs)))
975
- else:
976
- self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1017
+ def _ipython_display_(self) -> None:
1018
+ self.display()
977
1019
 
978
1020
  def input(self, fn: Callable[..., String], path: str) -> None:
979
1021
  """
980
1022
  Loads a CSV file and sets it as *input, output of the function.
981
1023
  """
1024
+ self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
1025
+
1026
+ def _callable_to_egg(self, fn: object) -> str:
982
1027
  ref, decls = resolve_callable(fn)
983
1028
  self._add_decls(decls)
984
- fn_name = self._state.callable_ref_to_egg(ref)
985
- self._egraph.run_program(bindings.Input(fn_name, path))
1029
+ return self._state.callable_ref_to_egg(ref)
986
1030
 
987
1031
  def let(self, name: str, expr: EXPR) -> EXPR:
988
1032
  """
@@ -991,10 +1035,11 @@ class EGraph(_BaseModule):
991
1035
  action = let(name, expr)
992
1036
  self.register(action)
993
1037
  runtime_expr = to_runtime_expr(expr)
1038
+ self._add_decls(runtime_expr)
994
1039
  return cast(
995
1040
  EXPR,
996
- RuntimeExpr.__from_value__(
997
- self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
1041
+ RuntimeExpr.__from_values__(
1042
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
998
1043
  ),
999
1044
  )
1000
1045
 
@@ -1016,14 +1061,15 @@ class EGraph(_BaseModule):
1016
1061
  self._add_decls(runtime_expr, schedule)
1017
1062
  egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1018
1063
  typed_expr = runtime_expr.__egg_typed_expr__
1019
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1064
+ # Must also register type
1065
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
1020
1066
  self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1021
1067
  extract_report = self._egraph.extract_report()
1022
1068
  if not isinstance(extract_report, bindings.Best):
1023
1069
  msg = "No extract report saved"
1024
1070
  raise ValueError(msg) # noqa: TRY004
1025
1071
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1026
- return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1072
+ return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1027
1073
 
1028
1074
  def include(self, path: str) -> None:
1029
1075
  """
@@ -1078,7 +1124,7 @@ class EGraph(_BaseModule):
1078
1124
  facts = _fact_likes(fact_likes)
1079
1125
  self._add_decls(*facts)
1080
1126
  egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1081
- return bindings.Check(egg_facts)
1127
+ return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
1082
1128
 
1083
1129
  @overload
1084
1130
  def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
@@ -1100,7 +1146,7 @@ class EGraph(_BaseModule):
1100
1146
  raise ValueError(msg) # noqa: TRY004
1101
1147
  (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1102
1148
 
1103
- res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1149
+ res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
1104
1150
  if include_cost:
1105
1151
  return res, extract_report.cost
1106
1152
  return res
@@ -1118,12 +1164,15 @@ class EGraph(_BaseModule):
1118
1164
  msg = "Wrong extract report type"
1119
1165
  raise ValueError(msg) # noqa: TRY004
1120
1166
  new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1121
- return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1167
+ return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
1122
1168
 
1123
1169
  def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1124
- self._state.type_ref_to_egg(typed_expr.tp)
1125
- expr = self._state.expr_to_egg(typed_expr.expr)
1126
- self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1170
+ expr = self._state.typed_expr_to_egg(typed_expr)
1171
+ self._egraph.run_program(
1172
+ bindings.ActionCommand(
1173
+ bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
1174
+ )
1175
+ )
1127
1176
  extract_report = self._egraph.extract_report()
1128
1177
  if not extract_report:
1129
1178
  msg = "No extract report saved"
@@ -1181,7 +1230,7 @@ class EGraph(_BaseModule):
1181
1230
  runtime_expr = to_runtime_expr(expr)
1182
1231
  self._add_decls(runtime_expr)
1183
1232
  typed_expr = runtime_expr.__egg_typed_expr__
1184
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1233
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
1185
1234
  match typed_expr.tp:
1186
1235
  case JustTypeRef("i64"):
1187
1236
  return self._egraph.eval_i64(egg_expr)
@@ -1195,42 +1244,110 @@ class EGraph(_BaseModule):
1195
1244
  return self._egraph.eval_py_object(egg_expr)
1196
1245
  raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1197
1246
 
1198
- def saturate(
1199
- self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1200
- ) -> ipywidgets.Widget:
1201
- from .graphviz_widget import graphviz_widget_with_slider
1247
+ def _serialize(
1248
+ self,
1249
+ **kwargs: Unpack[GraphvizKwargs],
1250
+ ) -> bindings.SerializedEGraph:
1251
+ max_functions = kwargs.pop("max_functions", None)
1252
+ max_calls_per_function = kwargs.pop("max_calls_per_function", None)
1253
+ split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
1254
+ split_functions = kwargs.pop("split_functions", [])
1255
+ include_temporary_functions = kwargs.pop("include_temporary_functions", False)
1256
+ n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
1257
+ serialized = self._egraph.serialize(
1258
+ [],
1259
+ max_functions=max_functions,
1260
+ max_calls_per_function=max_calls_per_function,
1261
+ include_temporary_functions=include_temporary_functions,
1262
+ )
1263
+ if split_primitive_outputs or split_functions:
1264
+ additional_ops = set(map(self._callable_to_egg, split_functions))
1265
+ serialized.split_e_classes(self._egraph, additional_ops)
1266
+ serialized.map_ops(self._state.op_mapping())
1202
1267
 
1203
- dots = [str(self.graphviz(**kwargs))]
1204
- i = 0
1205
- while self.run(1).updated and i < max:
1206
- i += 1
1207
- dots.append(str(self.graphviz(**kwargs)))
1208
- return graphviz_widget_with_slider(dots, performance=performance)
1268
+ for _ in range(n_inline_leaves):
1269
+ serialized.inline_leaves()
1209
1270
 
1210
- def saturate_to_html(
1211
- self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
1212
- ) -> None:
1213
- # raise NotImplementedError("Upstream bugs prevent rendering to HTML")
1271
+ return serialized
1272
+
1273
+ def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
1274
+ serialized = self._serialize(**kwargs)
1275
+
1276
+ original = serialized.to_dot()
1277
+ # Add link to stylesheet to the graph, so that edges light up on hover
1278
+ # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
1279
+ styles = """/* the lines within the edges */
1280
+ .edge:active path,
1281
+ .edge:hover path {
1282
+ stroke: fuchsia;
1283
+ stroke-width: 3;
1284
+ stroke-opacity: 1;
1285
+ }
1286
+ /* arrows are typically drawn with a polygon */
1287
+ .edge:active polygon,
1288
+ .edge:hover polygon {
1289
+ stroke: fuchsia;
1290
+ stroke-width: 3;
1291
+ fill: fuchsia;
1292
+ stroke-opacity: 1;
1293
+ fill-opacity: 1;
1294
+ }
1295
+ /* If you happen to have text and want to color that as well... */
1296
+ .edge:active text,
1297
+ .edge:hover text {
1298
+ fill: fuchsia;
1299
+ }"""
1300
+ p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
1301
+ p.write_text(styles)
1302
+ with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
1303
+ return graphviz.Source(with_stylesheet)
1214
1304
 
1215
- # import panel
1305
+ def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
1306
+ """
1307
+ Displays the e-graph.
1216
1308
 
1217
- # panel.extension("ipywidgets")
1309
+ If in IPython it will display it inline, otherwise it will write it to a file and open it.
1310
+ """
1311
+ from IPython.display import SVG, display
1218
1312
 
1219
- widget = self.saturate(performance=performance, **kwargs)
1220
- # panel.panel(widget).save(file)
1313
+ from .visualizer_widget import VisualizerWidget
1221
1314
 
1222
- from ipywidgets.embed import embed_minimal_html
1315
+ if graphviz:
1316
+ if IN_IPYTHON:
1317
+ svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
1318
+ display(SVG(svg))
1319
+ else:
1320
+ self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1321
+ else:
1322
+ serialized = self._serialize(**kwargs)
1323
+ VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
1223
1324
 
1224
- embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
1225
- # Use panel while this issue persists
1226
- # https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436
1325
+ def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
1326
+ """
1327
+ Saturate the egraph, running the given schedule until the egraph is saturated.
1328
+ It serializes the egraph at each step and returns a widget to visualize the egraph.
1329
+ """
1330
+ from .visualizer_widget import VisualizerWidget
1331
+
1332
+ def to_json() -> str:
1333
+ return self._serialize(**kwargs).to_json()
1334
+
1335
+ egraphs = [to_json()]
1336
+ i = 0
1337
+ while self.run(schedule or 1).updated and i < max:
1338
+ i += 1
1339
+ egraphs.append(to_json())
1340
+ VisualizerWidget(egraphs=egraphs).display_or_open()
1227
1341
 
1228
1342
  @classmethod
1229
1343
  def current(cls) -> EGraph:
1230
1344
  """
1231
1345
  Returns the current egraph, which is the one in the context.
1232
1346
  """
1233
- return CURRENT_EGRAPH.get()
1347
+ try:
1348
+ return CURRENT_EGRAPH.get()
1349
+ except LookupError:
1350
+ return cls(save_egglog_string=True)
1234
1351
 
1235
1352
  @property
1236
1353
  def _egraph(self) -> bindings.EGraph:
@@ -1382,7 +1499,8 @@ class Ruleset(Schedule):
1382
1499
  To return the egg decls, we go through our deferred rules and add any we haven't yet
1383
1500
  """
1384
1501
  while self.deferred_rule_gens:
1385
- rules = self.deferred_rule_gens.pop()()
1502
+ with set_current_ruleset(self):
1503
+ rules = self.deferred_rule_gens.pop()()
1386
1504
  self._current_egg_decls.update(*rules)
1387
1505
  self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1388
1506
  return self._current_egg_decls
@@ -1622,16 +1740,16 @@ def action_command(action: Action) -> Action:
1622
1740
  return action
1623
1741
 
1624
1742
 
1625
- def var(name: str, bound: type[EXPR]) -> EXPR:
1743
+ def var(name: str, bound: type[T]) -> T:
1626
1744
  """Create a new variable with the given name and type."""
1627
- return cast(EXPR, _var(name, bound))
1745
+ return cast(T, _var(name, bound))
1628
1746
 
1629
1747
 
1630
1748
  def _var(name: str, bound: object) -> RuntimeExpr:
1631
1749
  """Create a new variable with the given name and type."""
1632
1750
  decls = Declarations()
1633
1751
  type_ref = resolve_type_annotation(decls, bound)
1634
- return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
1752
+ return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
1635
1753
 
1636
1754
 
1637
1755
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1724,7 +1842,7 @@ class _NeBuilder(Generic[EXPR]):
1724
1842
  lhs = to_runtime_expr(self.lhs)
1725
1843
  rhs = convert_to_same_type(rhs, lhs)
1726
1844
  assert isinstance(Unit, RuntimeClass)
1727
- res = RuntimeExpr.__from_value__(
1845
+ res = RuntimeExpr.__from_values__(
1728
1846
  Declarations.create(Unit, lhs, rhs),
1729
1847
  TypedExprDecl(
1730
1848
  JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
@@ -1893,3 +2011,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
1893
2011
  if isinstance(fact_like, Expr):
1894
2012
  return expr_fact(fact_like)
1895
2013
  return fact_like
2014
+
2015
+
2016
+ _CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
2017
+
2018
+
2019
+ def get_current_ruleset() -> Ruleset | None:
2020
+ return _CURRENT_RULESET.get()
2021
+
2022
+
2023
+ @contextlib.contextmanager
2024
+ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
2025
+ token = _CURRENT_RULESET.set(r)
2026
+ try:
2027
+ yield
2028
+ finally:
2029
+ _CURRENT_RULESET.reset(token)