egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +83 -4
- egglog/egraph.py +241 -173
- egglog/egraph_state.py +137 -61
- egglog/examples/higher_order_functions.py +3 -8
- egglog/exp/array_api.py +274 -92
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +84 -40
- egglog/runtime.py +52 -39
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/METADATA +33 -32
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import inspect
|
|
4
5
|
import pathlib
|
|
5
6
|
import tempfile
|
|
6
7
|
from abc import abstractmethod
|
|
7
|
-
from collections.abc import Callable, Iterable
|
|
8
|
+
from collections.abc import Callable, Generator, Iterable
|
|
8
9
|
from contextvars import ContextVar, Token
|
|
9
10
|
from dataclasses import InitVar, dataclass, field
|
|
10
11
|
from inspect import Parameter, currentframe, signature
|
|
@@ -37,8 +38,6 @@ from .runtime import *
|
|
|
37
38
|
from .thunk import *
|
|
38
39
|
|
|
39
40
|
if TYPE_CHECKING:
|
|
40
|
-
import ipywidgets
|
|
41
|
-
|
|
42
41
|
from .builtins import Bool, PyObject, String, f64, i64
|
|
43
42
|
|
|
44
43
|
|
|
@@ -464,10 +463,14 @@ class _ExprMetaclass(type):
|
|
|
464
463
|
prev_frame = frame.f_back
|
|
465
464
|
assert prev_frame
|
|
466
465
|
|
|
466
|
+
# Pass in an instance of the class so that when we are generating the decls
|
|
467
|
+
# we can update them eagerly so that we can access the methods in the class body
|
|
468
|
+
runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
|
|
469
|
+
|
|
467
470
|
# Store frame so that we can get live access to updated locals/globals
|
|
468
471
|
# Otherwise, f_locals returns a copy
|
|
469
472
|
# https://peps.python.org/pep-0667/
|
|
470
|
-
|
|
473
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.fn(
|
|
471
474
|
_generate_class_decls,
|
|
472
475
|
namespace,
|
|
473
476
|
prev_frame,
|
|
@@ -475,21 +478,22 @@ class _ExprMetaclass(type):
|
|
|
475
478
|
egg_sort,
|
|
476
479
|
name,
|
|
477
480
|
ruleset,
|
|
478
|
-
|
|
481
|
+
runtime_cls,
|
|
479
482
|
)
|
|
480
|
-
return
|
|
483
|
+
return runtime_cls
|
|
481
484
|
|
|
482
485
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
483
486
|
return isinstance(instance, RuntimeExpr)
|
|
484
487
|
|
|
485
488
|
|
|
486
|
-
def _generate_class_decls(
|
|
489
|
+
def _generate_class_decls( # noqa: C901
|
|
487
490
|
namespace: dict[str, Any],
|
|
488
491
|
frame: FrameType,
|
|
489
492
|
builtin: bool,
|
|
490
493
|
egg_sort: str | None,
|
|
491
494
|
cls_name: str,
|
|
492
495
|
ruleset: Ruleset | None,
|
|
496
|
+
runtime_cls: RuntimeClass,
|
|
493
497
|
) -> Declarations:
|
|
494
498
|
"""
|
|
495
499
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
@@ -502,6 +506,8 @@ def _generate_class_decls(
|
|
|
502
506
|
del parameters
|
|
503
507
|
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
504
508
|
decls = Declarations(_classes={cls_name: cls_decl})
|
|
509
|
+
# Update class think eagerly when resolving so that lookups work in methods
|
|
510
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
|
|
505
511
|
|
|
506
512
|
##
|
|
507
513
|
# Register class variables
|
|
@@ -522,16 +528,13 @@ def _generate_class_decls(
|
|
|
522
528
|
# Register methods, classmethods, preserved methods, and properties
|
|
523
529
|
##
|
|
524
530
|
|
|
525
|
-
# The type ref of self is paramterized by the type vars
|
|
526
|
-
TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
527
|
-
|
|
528
531
|
# Get all the methods from the class
|
|
529
532
|
filtered_namespace: list[tuple[str, Any]] = [
|
|
530
533
|
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
531
534
|
]
|
|
532
535
|
|
|
533
536
|
# all methods we should try adding default functions for
|
|
534
|
-
|
|
537
|
+
add_default_funcs: list[Callable[[], None]] = []
|
|
535
538
|
# Then register each of its methods
|
|
536
539
|
for method_name, method in filtered_namespace:
|
|
537
540
|
is_init = method_name == "__init__"
|
|
@@ -550,7 +553,6 @@ def _generate_class_decls(
|
|
|
550
553
|
cls_decl.preserved_methods[method_name] = fn
|
|
551
554
|
continue
|
|
552
555
|
locals = frame.f_locals
|
|
553
|
-
|
|
554
556
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
555
557
|
match fn:
|
|
556
558
|
case classmethod():
|
|
@@ -561,16 +563,25 @@ def _generate_class_decls(
|
|
|
561
563
|
fn = fn.fget
|
|
562
564
|
case _:
|
|
563
565
|
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
|
|
566
|
+
special_function_name: SpecialFunctions | None = (
|
|
567
|
+
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
|
|
568
|
+
)
|
|
569
|
+
if special_function_name:
|
|
570
|
+
decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
|
|
571
|
+
decls.set_function_decl(ref, decl)
|
|
572
|
+
continue
|
|
564
573
|
|
|
565
|
-
|
|
574
|
+
_, add_rewrite = _fn_decl(
|
|
575
|
+
decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
|
|
576
|
+
)
|
|
566
577
|
|
|
567
578
|
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
568
|
-
|
|
579
|
+
add_default_funcs.append(add_rewrite)
|
|
569
580
|
|
|
570
581
|
# Add all rewrite methods at the end so that all methods are registered first and can be accessed
|
|
571
582
|
# in the bodies
|
|
572
|
-
for
|
|
573
|
-
|
|
583
|
+
for add_rewrite in add_default_funcs:
|
|
584
|
+
add_rewrite()
|
|
574
585
|
|
|
575
586
|
return decls
|
|
576
587
|
|
|
@@ -590,6 +601,7 @@ def function(
|
|
|
590
601
|
unextractable: bool = False,
|
|
591
602
|
builtin: bool = False,
|
|
592
603
|
ruleset: Ruleset | None = None,
|
|
604
|
+
use_body_as_name: bool = False,
|
|
593
605
|
) -> Callable[[CALLABLE], CALLABLE]: ...
|
|
594
606
|
|
|
595
607
|
|
|
@@ -604,6 +616,7 @@ def function(
|
|
|
604
616
|
mutates_first_arg: bool = False,
|
|
605
617
|
unextractable: bool = False,
|
|
606
618
|
ruleset: Ruleset | None = None,
|
|
619
|
+
use_body_as_name: bool = False,
|
|
607
620
|
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
|
|
608
621
|
|
|
609
622
|
|
|
@@ -635,16 +648,18 @@ class _FunctionConstructor:
|
|
|
635
648
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
636
649
|
unextractable: bool = False
|
|
637
650
|
ruleset: Ruleset | None = None
|
|
651
|
+
use_body_as_name: bool = False
|
|
638
652
|
|
|
639
653
|
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
640
|
-
return RuntimeFunction(Thunk.fn(self.create_decls, fn)
|
|
654
|
+
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
|
|
641
655
|
|
|
642
|
-
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
656
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
|
|
643
657
|
decls = Declarations()
|
|
644
|
-
|
|
658
|
+
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
|
|
659
|
+
ref, add_rewrite = _fn_decl(
|
|
645
660
|
decls,
|
|
646
661
|
self.egg_fn,
|
|
647
|
-
|
|
662
|
+
ref,
|
|
648
663
|
fn,
|
|
649
664
|
self.hint_locals,
|
|
650
665
|
self.default,
|
|
@@ -653,16 +668,18 @@ class _FunctionConstructor:
|
|
|
653
668
|
self.on_merge,
|
|
654
669
|
self.mutates_first_arg,
|
|
655
670
|
self.builtin,
|
|
671
|
+
self.ruleset,
|
|
656
672
|
unextractable=self.unextractable,
|
|
657
673
|
)
|
|
658
|
-
|
|
659
|
-
return decls
|
|
674
|
+
add_rewrite()
|
|
675
|
+
return decls, ref
|
|
660
676
|
|
|
661
677
|
|
|
662
678
|
def _fn_decl(
|
|
663
679
|
decls: Declarations,
|
|
664
680
|
egg_name: str | None,
|
|
665
|
-
ref
|
|
681
|
+
# If ref is Callable, then generate the ref from the function name
|
|
682
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
|
|
666
683
|
fn: object,
|
|
667
684
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
668
685
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -673,23 +690,15 @@ def _fn_decl(
|
|
|
673
690
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
|
|
674
691
|
mutates_first_arg: bool,
|
|
675
692
|
is_builtin: bool,
|
|
693
|
+
ruleset: Ruleset | None = None,
|
|
676
694
|
unextractable: bool = False,
|
|
677
|
-
) -> None:
|
|
695
|
+
) -> tuple[CallableRef, Callable[[], None]]:
|
|
678
696
|
"""
|
|
679
|
-
Sets the function decl for the function object.
|
|
697
|
+
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
|
|
680
698
|
"""
|
|
681
699
|
if not isinstance(fn, FunctionType):
|
|
682
700
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
683
701
|
|
|
684
|
-
# partial function creation and calling are handled with a special case in the type checker, so don't
|
|
685
|
-
# use the normal logic
|
|
686
|
-
special_function_name: SpecialFunctions | None = (
|
|
687
|
-
"fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
|
|
688
|
-
)
|
|
689
|
-
if special_function_name:
|
|
690
|
-
decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
|
|
691
|
-
return
|
|
692
|
-
|
|
693
702
|
hint_globals = fn.__globals__.copy()
|
|
694
703
|
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
695
704
|
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
@@ -719,7 +728,7 @@ def _fn_decl(
|
|
|
719
728
|
|
|
720
729
|
# Resolve all default values as arg types
|
|
721
730
|
arg_defaults = [
|
|
722
|
-
resolve_literal(t, p.default) if p.default is not Parameter.empty else None
|
|
731
|
+
resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
|
|
723
732
|
for (t, p) in zip(arg_types, params, strict=True)
|
|
724
733
|
]
|
|
725
734
|
|
|
@@ -740,8 +749,8 @@ def _fn_decl(
|
|
|
740
749
|
None
|
|
741
750
|
if merge is None
|
|
742
751
|
else merge(
|
|
743
|
-
RuntimeExpr.
|
|
744
|
-
RuntimeExpr.
|
|
752
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
753
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
745
754
|
)
|
|
746
755
|
)
|
|
747
756
|
decls |= merged
|
|
@@ -751,29 +760,47 @@ def _fn_decl(
|
|
|
751
760
|
if on_merge is None
|
|
752
761
|
else _action_likes(
|
|
753
762
|
on_merge(
|
|
754
|
-
RuntimeExpr.
|
|
755
|
-
RuntimeExpr.
|
|
763
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
764
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
756
765
|
)
|
|
757
766
|
)
|
|
758
767
|
)
|
|
759
768
|
decls.update(*merge_action)
|
|
760
|
-
|
|
761
|
-
|
|
769
|
+
# defer this in generator so it doesnt resolve for builtins eagerly
|
|
770
|
+
args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
|
|
771
|
+
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
|
|
772
|
+
res_thunk: Callable[[], object]
|
|
773
|
+
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
|
|
774
|
+
if not ref:
|
|
775
|
+
tuple_args = tuple(args)
|
|
776
|
+
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
|
|
777
|
+
assert isinstance(res, RuntimeExpr)
|
|
778
|
+
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
|
|
779
|
+
decls._unnamed_functions.add(res_ref)
|
|
780
|
+
res_thunk = Thunk.value(res)
|
|
781
|
+
|
|
782
|
+
else:
|
|
783
|
+
signature_ = FunctionSignature(
|
|
762
784
|
return_type=None if mutates_first_arg else return_type,
|
|
763
785
|
var_arg_type=var_arg_type,
|
|
764
786
|
arg_types=arg_types,
|
|
765
787
|
arg_names=arg_names,
|
|
766
788
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
767
|
-
)
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
789
|
+
)
|
|
790
|
+
decl = FunctionDecl(
|
|
791
|
+
signature=signature_,
|
|
792
|
+
cost=cost,
|
|
793
|
+
egg_name=egg_name,
|
|
794
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
795
|
+
unextractable=unextractable,
|
|
796
|
+
builtin=is_builtin,
|
|
797
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
798
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
799
|
+
)
|
|
800
|
+
res_ref = ref
|
|
801
|
+
decls.set_function_decl(ref, decl)
|
|
802
|
+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
|
|
803
|
+
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
|
|
777
804
|
|
|
778
805
|
|
|
779
806
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -804,7 +831,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
804
831
|
Creates a function whose return type is `Unit` and has a default value.
|
|
805
832
|
"""
|
|
806
833
|
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
807
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
834
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
|
|
808
835
|
|
|
809
836
|
|
|
810
837
|
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
@@ -828,7 +855,9 @@ def constant(
|
|
|
828
855
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
829
856
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
830
857
|
"""
|
|
831
|
-
return cast(
|
|
858
|
+
return cast(
|
|
859
|
+
EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
|
|
860
|
+
)
|
|
832
861
|
|
|
833
862
|
|
|
834
863
|
def _constant_thunk(
|
|
@@ -842,24 +871,34 @@ def _constant_thunk(
|
|
|
842
871
|
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
843
872
|
|
|
844
873
|
|
|
845
|
-
def
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
var_args: list[object] = [
|
|
855
|
-
RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
|
|
856
|
-
for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
|
|
857
|
-
]
|
|
874
|
+
def _create_default_value(
|
|
875
|
+
decls: Declarations,
|
|
876
|
+
ref: CallableRef | None,
|
|
877
|
+
fn: Callable,
|
|
878
|
+
args: Iterable[TypedExprDecl],
|
|
879
|
+
ruleset: Ruleset | None,
|
|
880
|
+
) -> object:
|
|
881
|
+
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
|
|
882
|
+
|
|
858
883
|
# If this is a classmethod, add the class as the first arg
|
|
859
884
|
if isinstance(ref, ClassMethodRef):
|
|
860
885
|
tp = decls.get_paramaterized_class(ref.class_name)
|
|
861
|
-
|
|
862
|
-
|
|
886
|
+
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
887
|
+
with set_current_ruleset(ruleset):
|
|
888
|
+
return fn(*args)
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
def _add_default_rewrite_function(
|
|
892
|
+
decls: Declarations,
|
|
893
|
+
ref: CallableRef,
|
|
894
|
+
res_type: TypeOrVarRef,
|
|
895
|
+
ruleset: Ruleset | None,
|
|
896
|
+
value_thunk: Callable[[], object],
|
|
897
|
+
) -> None:
|
|
898
|
+
"""
|
|
899
|
+
Helper functions that resolves a value thunk to create the default value.
|
|
900
|
+
"""
|
|
901
|
+
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
|
|
863
902
|
|
|
864
903
|
|
|
865
904
|
def _add_default_rewrite(
|
|
@@ -872,7 +911,7 @@ def _add_default_rewrite(
|
|
|
872
911
|
"""
|
|
873
912
|
if default_rewrite is None:
|
|
874
913
|
return
|
|
875
|
-
resolved_value = resolve_literal(type_ref, default_rewrite)
|
|
914
|
+
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
|
|
876
915
|
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
|
|
877
916
|
if ruleset:
|
|
878
917
|
ruleset_decls = ruleset._current_egg_decls
|
|
@@ -931,6 +970,8 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
931
970
|
max_calls_per_function: int | None
|
|
932
971
|
n_inline_leaves: int
|
|
933
972
|
split_primitive_outputs: bool
|
|
973
|
+
split_functions: list[object]
|
|
974
|
+
include_temporary_functions: bool
|
|
934
975
|
|
|
935
976
|
|
|
936
977
|
@dataclass
|
|
@@ -973,81 +1014,19 @@ class EGraph(_BaseModule):
|
|
|
973
1014
|
raise ValueError(msg)
|
|
974
1015
|
return cmds
|
|
975
1016
|
|
|
976
|
-
def
|
|
977
|
-
|
|
978
|
-
Returns the graphviz representation of the e-graph.
|
|
979
|
-
"""
|
|
980
|
-
return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
|
|
981
|
-
|
|
982
|
-
def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
983
|
-
# By default we want to split primitive outputs
|
|
984
|
-
kwargs.setdefault("split_primitive_outputs", True)
|
|
985
|
-
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
986
|
-
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
987
|
-
serialized.map_ops(self._state.op_mapping())
|
|
988
|
-
for _ in range(n_inline):
|
|
989
|
-
serialized.inline_leaves()
|
|
990
|
-
original = serialized.to_dot()
|
|
991
|
-
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
992
|
-
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
993
|
-
styles = """/* the lines within the edges */
|
|
994
|
-
.edge:active path,
|
|
995
|
-
.edge:hover path {
|
|
996
|
-
stroke: fuchsia;
|
|
997
|
-
stroke-width: 3;
|
|
998
|
-
stroke-opacity: 1;
|
|
999
|
-
}
|
|
1000
|
-
/* arrows are typically drawn with a polygon */
|
|
1001
|
-
.edge:active polygon,
|
|
1002
|
-
.edge:hover polygon {
|
|
1003
|
-
stroke: fuchsia;
|
|
1004
|
-
stroke-width: 3;
|
|
1005
|
-
fill: fuchsia;
|
|
1006
|
-
stroke-opacity: 1;
|
|
1007
|
-
fill-opacity: 1;
|
|
1008
|
-
}
|
|
1009
|
-
/* If you happen to have text and want to color that as well... */
|
|
1010
|
-
.edge:active text,
|
|
1011
|
-
.edge:hover text {
|
|
1012
|
-
fill: fuchsia;
|
|
1013
|
-
}"""
|
|
1014
|
-
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
1015
|
-
p.write_text(styles)
|
|
1016
|
-
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
1017
|
-
return graphviz.Source(with_stylesheet)
|
|
1018
|
-
|
|
1019
|
-
def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
|
|
1020
|
-
return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
1021
|
-
|
|
1022
|
-
def _repr_html_(self) -> str:
|
|
1023
|
-
"""
|
|
1024
|
-
Add a _repr_html_ to be an SVG to work with sphinx gallery.
|
|
1025
|
-
|
|
1026
|
-
ala https://github.com/xflr6/graphviz/pull/121
|
|
1027
|
-
until this PR is merged and released
|
|
1028
|
-
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
|
|
1029
|
-
"""
|
|
1030
|
-
return self.graphviz_svg()
|
|
1031
|
-
|
|
1032
|
-
def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1033
|
-
"""
|
|
1034
|
-
Displays the e-graph in the notebook.
|
|
1035
|
-
"""
|
|
1036
|
-
if IN_IPYTHON:
|
|
1037
|
-
from IPython.display import SVG, display
|
|
1038
|
-
|
|
1039
|
-
display(SVG(self.graphviz_svg(**kwargs)))
|
|
1040
|
-
else:
|
|
1041
|
-
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1017
|
+
def _ipython_display_(self) -> None:
|
|
1018
|
+
self.display()
|
|
1042
1019
|
|
|
1043
1020
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
1044
1021
|
"""
|
|
1045
1022
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1046
1023
|
"""
|
|
1024
|
+
self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
|
|
1025
|
+
|
|
1026
|
+
def _callable_to_egg(self, fn: object) -> str:
|
|
1047
1027
|
ref, decls = resolve_callable(fn)
|
|
1048
1028
|
self._add_decls(decls)
|
|
1049
|
-
|
|
1050
|
-
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1029
|
+
return self._state.callable_ref_to_egg(ref)
|
|
1051
1030
|
|
|
1052
1031
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
1053
1032
|
"""
|
|
@@ -1059,8 +1038,8 @@ class EGraph(_BaseModule):
|
|
|
1059
1038
|
self._add_decls(runtime_expr)
|
|
1060
1039
|
return cast(
|
|
1061
1040
|
EXPR,
|
|
1062
|
-
RuntimeExpr.
|
|
1063
|
-
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
1041
|
+
RuntimeExpr.__from_values__(
|
|
1042
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
|
|
1064
1043
|
),
|
|
1065
1044
|
)
|
|
1066
1045
|
|
|
@@ -1090,7 +1069,7 @@ class EGraph(_BaseModule):
|
|
|
1090
1069
|
msg = "No extract report saved"
|
|
1091
1070
|
raise ValueError(msg) # noqa: TRY004
|
|
1092
1071
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1093
|
-
return cast(EXPR, RuntimeExpr.
|
|
1072
|
+
return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1094
1073
|
|
|
1095
1074
|
def include(self, path: str) -> None:
|
|
1096
1075
|
"""
|
|
@@ -1145,7 +1124,7 @@ class EGraph(_BaseModule):
|
|
|
1145
1124
|
facts = _fact_likes(fact_likes)
|
|
1146
1125
|
self._add_decls(*facts)
|
|
1147
1126
|
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1148
|
-
return bindings.Check(egg_facts)
|
|
1127
|
+
return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
|
|
1149
1128
|
|
|
1150
1129
|
@overload
|
|
1151
1130
|
def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
|
|
@@ -1167,7 +1146,7 @@ class EGraph(_BaseModule):
|
|
|
1167
1146
|
raise ValueError(msg) # noqa: TRY004
|
|
1168
1147
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1169
1148
|
|
|
1170
|
-
res = cast(EXPR, RuntimeExpr.
|
|
1149
|
+
res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1171
1150
|
if include_cost:
|
|
1172
1151
|
return res, extract_report.cost
|
|
1173
1152
|
return res
|
|
@@ -1185,11 +1164,15 @@ class EGraph(_BaseModule):
|
|
|
1185
1164
|
msg = "Wrong extract report type"
|
|
1186
1165
|
raise ValueError(msg) # noqa: TRY004
|
|
1187
1166
|
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1188
|
-
return [cast(EXPR, RuntimeExpr.
|
|
1167
|
+
return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1189
1168
|
|
|
1190
1169
|
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1191
1170
|
expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1192
|
-
self._egraph.run_program(
|
|
1171
|
+
self._egraph.run_program(
|
|
1172
|
+
bindings.ActionCommand(
|
|
1173
|
+
bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
|
|
1174
|
+
)
|
|
1175
|
+
)
|
|
1193
1176
|
extract_report = self._egraph.extract_report()
|
|
1194
1177
|
if not extract_report:
|
|
1195
1178
|
msg = "No extract report saved"
|
|
@@ -1261,42 +1244,110 @@ class EGraph(_BaseModule):
|
|
|
1261
1244
|
return self._egraph.eval_py_object(egg_expr)
|
|
1262
1245
|
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1263
1246
|
|
|
1264
|
-
def
|
|
1265
|
-
self,
|
|
1266
|
-
|
|
1267
|
-
|
|
1247
|
+
def _serialize(
|
|
1248
|
+
self,
|
|
1249
|
+
**kwargs: Unpack[GraphvizKwargs],
|
|
1250
|
+
) -> bindings.SerializedEGraph:
|
|
1251
|
+
max_functions = kwargs.pop("max_functions", None)
|
|
1252
|
+
max_calls_per_function = kwargs.pop("max_calls_per_function", None)
|
|
1253
|
+
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
|
|
1254
|
+
split_functions = kwargs.pop("split_functions", [])
|
|
1255
|
+
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
|
|
1256
|
+
n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
|
|
1257
|
+
serialized = self._egraph.serialize(
|
|
1258
|
+
[],
|
|
1259
|
+
max_functions=max_functions,
|
|
1260
|
+
max_calls_per_function=max_calls_per_function,
|
|
1261
|
+
include_temporary_functions=include_temporary_functions,
|
|
1262
|
+
)
|
|
1263
|
+
if split_primitive_outputs or split_functions:
|
|
1264
|
+
additional_ops = set(map(self._callable_to_egg, split_functions))
|
|
1265
|
+
serialized.split_e_classes(self._egraph, additional_ops)
|
|
1266
|
+
serialized.map_ops(self._state.op_mapping())
|
|
1268
1267
|
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
while self.run(1).updated and i < max:
|
|
1272
|
-
i += 1
|
|
1273
|
-
dots.append(str(self.graphviz(**kwargs)))
|
|
1274
|
-
return graphviz_widget_with_slider(dots, performance=performance)
|
|
1268
|
+
for _ in range(n_inline_leaves):
|
|
1269
|
+
serialized.inline_leaves()
|
|
1275
1270
|
|
|
1276
|
-
|
|
1277
|
-
self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
|
|
1278
|
-
) -> None:
|
|
1279
|
-
# raise NotImplementedError("Upstream bugs prevent rendering to HTML")
|
|
1271
|
+
return serialized
|
|
1280
1272
|
|
|
1281
|
-
|
|
1273
|
+
def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
1274
|
+
serialized = self._serialize(**kwargs)
|
|
1282
1275
|
|
|
1283
|
-
|
|
1276
|
+
original = serialized.to_dot()
|
|
1277
|
+
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
1278
|
+
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
1279
|
+
styles = """/* the lines within the edges */
|
|
1280
|
+
.edge:active path,
|
|
1281
|
+
.edge:hover path {
|
|
1282
|
+
stroke: fuchsia;
|
|
1283
|
+
stroke-width: 3;
|
|
1284
|
+
stroke-opacity: 1;
|
|
1285
|
+
}
|
|
1286
|
+
/* arrows are typically drawn with a polygon */
|
|
1287
|
+
.edge:active polygon,
|
|
1288
|
+
.edge:hover polygon {
|
|
1289
|
+
stroke: fuchsia;
|
|
1290
|
+
stroke-width: 3;
|
|
1291
|
+
fill: fuchsia;
|
|
1292
|
+
stroke-opacity: 1;
|
|
1293
|
+
fill-opacity: 1;
|
|
1294
|
+
}
|
|
1295
|
+
/* If you happen to have text and want to color that as well... */
|
|
1296
|
+
.edge:active text,
|
|
1297
|
+
.edge:hover text {
|
|
1298
|
+
fill: fuchsia;
|
|
1299
|
+
}"""
|
|
1300
|
+
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
1301
|
+
p.write_text(styles)
|
|
1302
|
+
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
1303
|
+
return graphviz.Source(with_stylesheet)
|
|
1304
|
+
|
|
1305
|
+
def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1306
|
+
"""
|
|
1307
|
+
Displays the e-graph.
|
|
1284
1308
|
|
|
1285
|
-
|
|
1286
|
-
|
|
1309
|
+
If in IPython it will display it inline, otherwise it will write it to a file and open it.
|
|
1310
|
+
"""
|
|
1311
|
+
from IPython.display import SVG, display
|
|
1312
|
+
|
|
1313
|
+
from .visualizer_widget import VisualizerWidget
|
|
1314
|
+
|
|
1315
|
+
if graphviz:
|
|
1316
|
+
if IN_IPYTHON:
|
|
1317
|
+
svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
1318
|
+
display(SVG(svg))
|
|
1319
|
+
else:
|
|
1320
|
+
self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1321
|
+
else:
|
|
1322
|
+
serialized = self._serialize(**kwargs)
|
|
1323
|
+
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
|
|
1287
1324
|
|
|
1288
|
-
|
|
1325
|
+
def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1326
|
+
"""
|
|
1327
|
+
Saturate the egraph, running the given schedule until the egraph is saturated.
|
|
1328
|
+
It serializes the egraph at each step and returns a widget to visualize the egraph.
|
|
1329
|
+
"""
|
|
1330
|
+
from .visualizer_widget import VisualizerWidget
|
|
1289
1331
|
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1332
|
+
def to_json() -> str:
|
|
1333
|
+
return self._serialize(**kwargs).to_json()
|
|
1334
|
+
|
|
1335
|
+
egraphs = [to_json()]
|
|
1336
|
+
i = 0
|
|
1337
|
+
while self.run(schedule or 1).updated and i < max:
|
|
1338
|
+
i += 1
|
|
1339
|
+
egraphs.append(to_json())
|
|
1340
|
+
VisualizerWidget(egraphs=egraphs).display_or_open()
|
|
1293
1341
|
|
|
1294
1342
|
@classmethod
|
|
1295
1343
|
def current(cls) -> EGraph:
|
|
1296
1344
|
"""
|
|
1297
1345
|
Returns the current egraph, which is the one in the context.
|
|
1298
1346
|
"""
|
|
1299
|
-
|
|
1347
|
+
try:
|
|
1348
|
+
return CURRENT_EGRAPH.get()
|
|
1349
|
+
except LookupError:
|
|
1350
|
+
return cls(save_egglog_string=True)
|
|
1300
1351
|
|
|
1301
1352
|
@property
|
|
1302
1353
|
def _egraph(self) -> bindings.EGraph:
|
|
@@ -1448,7 +1499,8 @@ class Ruleset(Schedule):
|
|
|
1448
1499
|
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1449
1500
|
"""
|
|
1450
1501
|
while self.deferred_rule_gens:
|
|
1451
|
-
|
|
1502
|
+
with set_current_ruleset(self):
|
|
1503
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1452
1504
|
self._current_egg_decls.update(*rules)
|
|
1453
1505
|
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1454
1506
|
return self._current_egg_decls
|
|
@@ -1688,16 +1740,16 @@ def action_command(action: Action) -> Action:
|
|
|
1688
1740
|
return action
|
|
1689
1741
|
|
|
1690
1742
|
|
|
1691
|
-
def var(name: str, bound: type[
|
|
1743
|
+
def var(name: str, bound: type[T]) -> T:
|
|
1692
1744
|
"""Create a new variable with the given name and type."""
|
|
1693
|
-
return cast(
|
|
1745
|
+
return cast(T, _var(name, bound))
|
|
1694
1746
|
|
|
1695
1747
|
|
|
1696
1748
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1697
1749
|
"""Create a new variable with the given name and type."""
|
|
1698
1750
|
decls = Declarations()
|
|
1699
1751
|
type_ref = resolve_type_annotation(decls, bound)
|
|
1700
|
-
return RuntimeExpr.
|
|
1752
|
+
return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
|
|
1701
1753
|
|
|
1702
1754
|
|
|
1703
1755
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1790,7 +1842,7 @@ class _NeBuilder(Generic[EXPR]):
|
|
|
1790
1842
|
lhs = to_runtime_expr(self.lhs)
|
|
1791
1843
|
rhs = convert_to_same_type(rhs, lhs)
|
|
1792
1844
|
assert isinstance(Unit, RuntimeClass)
|
|
1793
|
-
res = RuntimeExpr.
|
|
1845
|
+
res = RuntimeExpr.__from_values__(
|
|
1794
1846
|
Declarations.create(Unit, lhs, rhs),
|
|
1795
1847
|
TypedExprDecl(
|
|
1796
1848
|
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
@@ -1944,7 +1996,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1944
1996
|
if "Callable" not in globals:
|
|
1945
1997
|
globals["Callable"] = Callable
|
|
1946
1998
|
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1947
|
-
args = [_var(
|
|
1999
|
+
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1948
2000
|
return list(gen(*args)) # type: ignore[misc]
|
|
1949
2001
|
|
|
1950
2002
|
|
|
@@ -1959,3 +2011,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
|
|
|
1959
2011
|
if isinstance(fact_like, Expr):
|
|
1960
2012
|
return expr_fact(fact_like)
|
|
1961
2013
|
return fact_like
|
|
2014
|
+
|
|
2015
|
+
|
|
2016
|
+
_CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
|
|
2017
|
+
|
|
2018
|
+
|
|
2019
|
+
def get_current_ruleset() -> Ruleset | None:
|
|
2020
|
+
return _CURRENT_RULESET.get()
|
|
2021
|
+
|
|
2022
|
+
|
|
2023
|
+
@contextlib.contextmanager
|
|
2024
|
+
def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
2025
|
+
token = _CURRENT_RULESET.set(r)
|
|
2026
|
+
try:
|
|
2027
|
+
yield
|
|
2028
|
+
finally:
|
|
2029
|
+
_CURRENT_RULESET.reset(token)
|