egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import inspect
|
|
4
5
|
import pathlib
|
|
5
6
|
import tempfile
|
|
6
7
|
from abc import abstractmethod
|
|
7
|
-
from collections.abc import Callable, Iterable
|
|
8
|
+
from collections.abc import Callable, Generator, Iterable
|
|
8
9
|
from contextvars import ContextVar, Token
|
|
9
10
|
from dataclasses import InitVar, dataclass, field
|
|
10
11
|
from inspect import Parameter, currentframe, signature
|
|
@@ -37,8 +38,6 @@ from .runtime import *
|
|
|
37
38
|
from .thunk import *
|
|
38
39
|
|
|
39
40
|
if TYPE_CHECKING:
|
|
40
|
-
import ipywidgets
|
|
41
|
-
|
|
42
41
|
from .builtins import Bool, PyObject, String, f64, i64
|
|
43
42
|
|
|
44
43
|
|
|
@@ -464,10 +463,14 @@ class _ExprMetaclass(type):
|
|
|
464
463
|
prev_frame = frame.f_back
|
|
465
464
|
assert prev_frame
|
|
466
465
|
|
|
466
|
+
# Pass in an instance of the class so that when we are generating the decls
|
|
467
|
+
# we can update them eagerly so that we can access the methods in the class body
|
|
468
|
+
runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
|
|
469
|
+
|
|
467
470
|
# Store frame so that we can get live access to updated locals/globals
|
|
468
471
|
# Otherwise, f_locals returns a copy
|
|
469
472
|
# https://peps.python.org/pep-0667/
|
|
470
|
-
|
|
473
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.fn(
|
|
471
474
|
_generate_class_decls,
|
|
472
475
|
namespace,
|
|
473
476
|
prev_frame,
|
|
@@ -475,21 +478,22 @@ class _ExprMetaclass(type):
|
|
|
475
478
|
egg_sort,
|
|
476
479
|
name,
|
|
477
480
|
ruleset,
|
|
478
|
-
|
|
481
|
+
runtime_cls,
|
|
479
482
|
)
|
|
480
|
-
return
|
|
483
|
+
return runtime_cls
|
|
481
484
|
|
|
482
485
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
483
486
|
return isinstance(instance, RuntimeExpr)
|
|
484
487
|
|
|
485
488
|
|
|
486
|
-
def _generate_class_decls(
|
|
489
|
+
def _generate_class_decls( # noqa: C901,PLR0912
|
|
487
490
|
namespace: dict[str, Any],
|
|
488
491
|
frame: FrameType,
|
|
489
492
|
builtin: bool,
|
|
490
493
|
egg_sort: str | None,
|
|
491
494
|
cls_name: str,
|
|
492
495
|
ruleset: Ruleset | None,
|
|
496
|
+
runtime_cls: RuntimeClass,
|
|
493
497
|
) -> Declarations:
|
|
494
498
|
"""
|
|
495
499
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
@@ -502,6 +506,8 @@ def _generate_class_decls(
|
|
|
502
506
|
del parameters
|
|
503
507
|
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
504
508
|
decls = Declarations(_classes={cls_name: cls_decl})
|
|
509
|
+
# Update class think eagerly when resolving so that lookups work in methods
|
|
510
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
|
|
505
511
|
|
|
506
512
|
##
|
|
507
513
|
# Register class variables
|
|
@@ -522,16 +528,13 @@ def _generate_class_decls(
|
|
|
522
528
|
# Register methods, classmethods, preserved methods, and properties
|
|
523
529
|
##
|
|
524
530
|
|
|
525
|
-
# The type ref of self is paramterized by the type vars
|
|
526
|
-
TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
527
|
-
|
|
528
531
|
# Get all the methods from the class
|
|
529
532
|
filtered_namespace: list[tuple[str, Any]] = [
|
|
530
533
|
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
531
534
|
]
|
|
532
535
|
|
|
533
536
|
# all methods we should try adding default functions for
|
|
534
|
-
|
|
537
|
+
add_default_funcs: list[Callable[[], None]] = []
|
|
535
538
|
# Then register each of its methods
|
|
536
539
|
for method_name, method in filtered_namespace:
|
|
537
540
|
is_init = method_name == "__init__"
|
|
@@ -550,7 +553,6 @@ def _generate_class_decls(
|
|
|
550
553
|
cls_decl.preserved_methods[method_name] = fn
|
|
551
554
|
continue
|
|
552
555
|
locals = frame.f_locals
|
|
553
|
-
|
|
554
556
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
555
557
|
match fn:
|
|
556
558
|
case classmethod():
|
|
@@ -561,16 +563,25 @@ def _generate_class_decls(
|
|
|
561
563
|
fn = fn.fget
|
|
562
564
|
case _:
|
|
563
565
|
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
|
|
566
|
+
special_function_name: SpecialFunctions | None = (
|
|
567
|
+
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
|
|
568
|
+
)
|
|
569
|
+
if special_function_name:
|
|
570
|
+
decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
|
|
571
|
+
decls.set_function_decl(ref, decl)
|
|
572
|
+
continue
|
|
564
573
|
|
|
565
|
-
|
|
574
|
+
_, add_rewrite = _fn_decl(
|
|
575
|
+
decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
|
|
576
|
+
)
|
|
566
577
|
|
|
567
578
|
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
568
|
-
|
|
579
|
+
add_default_funcs.append(add_rewrite)
|
|
569
580
|
|
|
570
581
|
# Add all rewrite methods at the end so that all methods are registered first and can be accessed
|
|
571
582
|
# in the bodies
|
|
572
|
-
for
|
|
573
|
-
|
|
583
|
+
for add_rewrite in add_default_funcs:
|
|
584
|
+
add_rewrite()
|
|
574
585
|
|
|
575
586
|
return decls
|
|
576
587
|
|
|
@@ -590,6 +601,7 @@ def function(
|
|
|
590
601
|
unextractable: bool = False,
|
|
591
602
|
builtin: bool = False,
|
|
592
603
|
ruleset: Ruleset | None = None,
|
|
604
|
+
use_body_as_name: bool = False,
|
|
593
605
|
) -> Callable[[CALLABLE], CALLABLE]: ...
|
|
594
606
|
|
|
595
607
|
|
|
@@ -604,6 +616,7 @@ def function(
|
|
|
604
616
|
mutates_first_arg: bool = False,
|
|
605
617
|
unextractable: bool = False,
|
|
606
618
|
ruleset: Ruleset | None = None,
|
|
619
|
+
use_body_as_name: bool = False,
|
|
607
620
|
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
|
|
608
621
|
|
|
609
622
|
|
|
@@ -635,16 +648,18 @@ class _FunctionConstructor:
|
|
|
635
648
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
636
649
|
unextractable: bool = False
|
|
637
650
|
ruleset: Ruleset | None = None
|
|
651
|
+
use_body_as_name: bool = False
|
|
638
652
|
|
|
639
653
|
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
640
|
-
return RuntimeFunction(Thunk.fn(self.create_decls, fn)
|
|
654
|
+
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
|
|
641
655
|
|
|
642
|
-
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
656
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
|
|
643
657
|
decls = Declarations()
|
|
644
|
-
|
|
658
|
+
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
|
|
659
|
+
ref, add_rewrite = _fn_decl(
|
|
645
660
|
decls,
|
|
646
661
|
self.egg_fn,
|
|
647
|
-
|
|
662
|
+
ref,
|
|
648
663
|
fn,
|
|
649
664
|
self.hint_locals,
|
|
650
665
|
self.default,
|
|
@@ -653,16 +668,18 @@ class _FunctionConstructor:
|
|
|
653
668
|
self.on_merge,
|
|
654
669
|
self.mutates_first_arg,
|
|
655
670
|
self.builtin,
|
|
671
|
+
self.ruleset,
|
|
656
672
|
unextractable=self.unextractable,
|
|
657
673
|
)
|
|
658
|
-
|
|
659
|
-
return decls
|
|
674
|
+
add_rewrite()
|
|
675
|
+
return decls, ref
|
|
660
676
|
|
|
661
677
|
|
|
662
678
|
def _fn_decl(
|
|
663
679
|
decls: Declarations,
|
|
664
680
|
egg_name: str | None,
|
|
665
|
-
ref
|
|
681
|
+
# If ref is Callable, then generate the ref from the function name
|
|
682
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
|
|
666
683
|
fn: object,
|
|
667
684
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
668
685
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -673,30 +690,26 @@ def _fn_decl(
|
|
|
673
690
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
|
|
674
691
|
mutates_first_arg: bool,
|
|
675
692
|
is_builtin: bool,
|
|
693
|
+
ruleset: Ruleset | None = None,
|
|
676
694
|
unextractable: bool = False,
|
|
677
|
-
) -> None:
|
|
695
|
+
) -> tuple[CallableRef, Callable[[], None]]:
|
|
678
696
|
"""
|
|
679
|
-
Sets the function decl for the function object.
|
|
697
|
+
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
|
|
680
698
|
"""
|
|
681
699
|
if not isinstance(fn, FunctionType):
|
|
682
700
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
683
701
|
|
|
684
|
-
# partial function creation and calling are handled with a special case in the type checker, so don't
|
|
685
|
-
# use the normal logic
|
|
686
|
-
special_function_name: SpecialFunctions | None = (
|
|
687
|
-
"fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
|
|
688
|
-
)
|
|
689
|
-
if special_function_name:
|
|
690
|
-
decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
|
|
691
|
-
return
|
|
692
|
-
|
|
693
702
|
hint_globals = fn.__globals__.copy()
|
|
694
703
|
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
695
704
|
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
696
705
|
if "Callable" not in hint_globals:
|
|
697
706
|
hint_globals["Callable"] = Callable
|
|
698
|
-
|
|
699
|
-
|
|
707
|
+
# Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
|
|
708
|
+
# won't be resolved correctly
|
|
709
|
+
# We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
|
|
710
|
+
# https://github.com/egraphs-good/egglog-python/issues/210
|
|
711
|
+
hint_globals.update(hint_locals)
|
|
712
|
+
hints = get_type_hints(fn, hint_globals)
|
|
700
713
|
|
|
701
714
|
params = list(signature(fn).parameters.values())
|
|
702
715
|
|
|
@@ -719,7 +732,7 @@ def _fn_decl(
|
|
|
719
732
|
|
|
720
733
|
# Resolve all default values as arg types
|
|
721
734
|
arg_defaults = [
|
|
722
|
-
resolve_literal(t, p.default) if p.default is not Parameter.empty else None
|
|
735
|
+
resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
|
|
723
736
|
for (t, p) in zip(arg_types, params, strict=True)
|
|
724
737
|
]
|
|
725
738
|
|
|
@@ -740,8 +753,8 @@ def _fn_decl(
|
|
|
740
753
|
None
|
|
741
754
|
if merge is None
|
|
742
755
|
else merge(
|
|
743
|
-
RuntimeExpr.
|
|
744
|
-
RuntimeExpr.
|
|
756
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
757
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
745
758
|
)
|
|
746
759
|
)
|
|
747
760
|
decls |= merged
|
|
@@ -751,29 +764,47 @@ def _fn_decl(
|
|
|
751
764
|
if on_merge is None
|
|
752
765
|
else _action_likes(
|
|
753
766
|
on_merge(
|
|
754
|
-
RuntimeExpr.
|
|
755
|
-
RuntimeExpr.
|
|
767
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
768
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
756
769
|
)
|
|
757
770
|
)
|
|
758
771
|
)
|
|
759
772
|
decls.update(*merge_action)
|
|
760
|
-
|
|
761
|
-
|
|
773
|
+
# defer this in generator so it doesnt resolve for builtins eagerly
|
|
774
|
+
args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
|
|
775
|
+
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
|
|
776
|
+
res_thunk: Callable[[], object]
|
|
777
|
+
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
|
|
778
|
+
if not ref:
|
|
779
|
+
tuple_args = tuple(args)
|
|
780
|
+
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
|
|
781
|
+
assert isinstance(res, RuntimeExpr)
|
|
782
|
+
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
|
|
783
|
+
decls._unnamed_functions.add(res_ref)
|
|
784
|
+
res_thunk = Thunk.value(res)
|
|
785
|
+
|
|
786
|
+
else:
|
|
787
|
+
signature_ = FunctionSignature(
|
|
762
788
|
return_type=None if mutates_first_arg else return_type,
|
|
763
789
|
var_arg_type=var_arg_type,
|
|
764
790
|
arg_types=arg_types,
|
|
765
791
|
arg_names=arg_names,
|
|
766
792
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
767
|
-
)
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
793
|
+
)
|
|
794
|
+
decl = FunctionDecl(
|
|
795
|
+
signature=signature_,
|
|
796
|
+
cost=cost,
|
|
797
|
+
egg_name=egg_name,
|
|
798
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
799
|
+
unextractable=unextractable,
|
|
800
|
+
builtin=is_builtin,
|
|
801
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
802
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
803
|
+
)
|
|
804
|
+
res_ref = ref
|
|
805
|
+
decls.set_function_decl(ref, decl)
|
|
806
|
+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
|
|
807
|
+
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
|
|
777
808
|
|
|
778
809
|
|
|
779
810
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -804,7 +835,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
804
835
|
Creates a function whose return type is `Unit` and has a default value.
|
|
805
836
|
"""
|
|
806
837
|
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
807
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
838
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
|
|
808
839
|
|
|
809
840
|
|
|
810
841
|
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
@@ -828,7 +859,9 @@ def constant(
|
|
|
828
859
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
829
860
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
830
861
|
"""
|
|
831
|
-
return cast(
|
|
862
|
+
return cast(
|
|
863
|
+
EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
|
|
864
|
+
)
|
|
832
865
|
|
|
833
866
|
|
|
834
867
|
def _constant_thunk(
|
|
@@ -842,24 +875,34 @@ def _constant_thunk(
|
|
|
842
875
|
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
843
876
|
|
|
844
877
|
|
|
845
|
-
def
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
var_args: list[object] = [
|
|
855
|
-
RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
|
|
856
|
-
for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
|
|
857
|
-
]
|
|
878
|
+
def _create_default_value(
|
|
879
|
+
decls: Declarations,
|
|
880
|
+
ref: CallableRef | None,
|
|
881
|
+
fn: Callable,
|
|
882
|
+
args: Iterable[TypedExprDecl],
|
|
883
|
+
ruleset: Ruleset | None,
|
|
884
|
+
) -> object:
|
|
885
|
+
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
|
|
886
|
+
|
|
858
887
|
# If this is a classmethod, add the class as the first arg
|
|
859
888
|
if isinstance(ref, ClassMethodRef):
|
|
860
889
|
tp = decls.get_paramaterized_class(ref.class_name)
|
|
861
|
-
|
|
862
|
-
|
|
890
|
+
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
891
|
+
with set_current_ruleset(ruleset):
|
|
892
|
+
return fn(*args)
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
def _add_default_rewrite_function(
|
|
896
|
+
decls: Declarations,
|
|
897
|
+
ref: CallableRef,
|
|
898
|
+
res_type: TypeOrVarRef,
|
|
899
|
+
ruleset: Ruleset | None,
|
|
900
|
+
value_thunk: Callable[[], object],
|
|
901
|
+
) -> None:
|
|
902
|
+
"""
|
|
903
|
+
Helper functions that resolves a value thunk to create the default value.
|
|
904
|
+
"""
|
|
905
|
+
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
|
|
863
906
|
|
|
864
907
|
|
|
865
908
|
def _add_default_rewrite(
|
|
@@ -872,7 +915,7 @@ def _add_default_rewrite(
|
|
|
872
915
|
"""
|
|
873
916
|
if default_rewrite is None:
|
|
874
917
|
return
|
|
875
|
-
resolved_value = resolve_literal(type_ref, default_rewrite)
|
|
918
|
+
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
|
|
876
919
|
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
|
|
877
920
|
if ruleset:
|
|
878
921
|
ruleset_decls = ruleset._current_egg_decls
|
|
@@ -931,6 +974,8 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
931
974
|
max_calls_per_function: int | None
|
|
932
975
|
n_inline_leaves: int
|
|
933
976
|
split_primitive_outputs: bool
|
|
977
|
+
split_functions: list[object]
|
|
978
|
+
include_temporary_functions: bool
|
|
934
979
|
|
|
935
980
|
|
|
936
981
|
@dataclass
|
|
@@ -973,81 +1018,19 @@ class EGraph(_BaseModule):
|
|
|
973
1018
|
raise ValueError(msg)
|
|
974
1019
|
return cmds
|
|
975
1020
|
|
|
976
|
-
def
|
|
977
|
-
|
|
978
|
-
Returns the graphviz representation of the e-graph.
|
|
979
|
-
"""
|
|
980
|
-
return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
|
|
981
|
-
|
|
982
|
-
def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
983
|
-
# By default we want to split primitive outputs
|
|
984
|
-
kwargs.setdefault("split_primitive_outputs", True)
|
|
985
|
-
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
986
|
-
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
987
|
-
serialized.map_ops(self._state.op_mapping())
|
|
988
|
-
for _ in range(n_inline):
|
|
989
|
-
serialized.inline_leaves()
|
|
990
|
-
original = serialized.to_dot()
|
|
991
|
-
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
992
|
-
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
993
|
-
styles = """/* the lines within the edges */
|
|
994
|
-
.edge:active path,
|
|
995
|
-
.edge:hover path {
|
|
996
|
-
stroke: fuchsia;
|
|
997
|
-
stroke-width: 3;
|
|
998
|
-
stroke-opacity: 1;
|
|
999
|
-
}
|
|
1000
|
-
/* arrows are typically drawn with a polygon */
|
|
1001
|
-
.edge:active polygon,
|
|
1002
|
-
.edge:hover polygon {
|
|
1003
|
-
stroke: fuchsia;
|
|
1004
|
-
stroke-width: 3;
|
|
1005
|
-
fill: fuchsia;
|
|
1006
|
-
stroke-opacity: 1;
|
|
1007
|
-
fill-opacity: 1;
|
|
1008
|
-
}
|
|
1009
|
-
/* If you happen to have text and want to color that as well... */
|
|
1010
|
-
.edge:active text,
|
|
1011
|
-
.edge:hover text {
|
|
1012
|
-
fill: fuchsia;
|
|
1013
|
-
}"""
|
|
1014
|
-
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
1015
|
-
p.write_text(styles)
|
|
1016
|
-
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
1017
|
-
return graphviz.Source(with_stylesheet)
|
|
1018
|
-
|
|
1019
|
-
def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
|
|
1020
|
-
return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
1021
|
-
|
|
1022
|
-
def _repr_html_(self) -> str:
|
|
1023
|
-
"""
|
|
1024
|
-
Add a _repr_html_ to be an SVG to work with sphinx gallery.
|
|
1025
|
-
|
|
1026
|
-
ala https://github.com/xflr6/graphviz/pull/121
|
|
1027
|
-
until this PR is merged and released
|
|
1028
|
-
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
|
|
1029
|
-
"""
|
|
1030
|
-
return self.graphviz_svg()
|
|
1031
|
-
|
|
1032
|
-
def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1033
|
-
"""
|
|
1034
|
-
Displays the e-graph in the notebook.
|
|
1035
|
-
"""
|
|
1036
|
-
if IN_IPYTHON:
|
|
1037
|
-
from IPython.display import SVG, display
|
|
1038
|
-
|
|
1039
|
-
display(SVG(self.graphviz_svg(**kwargs)))
|
|
1040
|
-
else:
|
|
1041
|
-
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1021
|
+
def _ipython_display_(self) -> None:
|
|
1022
|
+
self.display()
|
|
1042
1023
|
|
|
1043
1024
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
1044
1025
|
"""
|
|
1045
1026
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1046
1027
|
"""
|
|
1028
|
+
self._egraph.run_program(bindings.Input(bindings.DUMMY_SPAN, self._callable_to_egg(fn), path))
|
|
1029
|
+
|
|
1030
|
+
def _callable_to_egg(self, fn: object) -> str:
|
|
1047
1031
|
ref, decls = resolve_callable(fn)
|
|
1048
1032
|
self._add_decls(decls)
|
|
1049
|
-
|
|
1050
|
-
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1033
|
+
return self._state.callable_ref_to_egg(ref)
|
|
1051
1034
|
|
|
1052
1035
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
1053
1036
|
"""
|
|
@@ -1059,8 +1042,8 @@ class EGraph(_BaseModule):
|
|
|
1059
1042
|
self._add_decls(runtime_expr)
|
|
1060
1043
|
return cast(
|
|
1061
1044
|
EXPR,
|
|
1062
|
-
RuntimeExpr.
|
|
1063
|
-
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
1045
|
+
RuntimeExpr.__from_values__(
|
|
1046
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
|
|
1064
1047
|
),
|
|
1065
1048
|
)
|
|
1066
1049
|
|
|
@@ -1084,13 +1067,13 @@ class EGraph(_BaseModule):
|
|
|
1084
1067
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1085
1068
|
# Must also register type
|
|
1086
1069
|
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1087
|
-
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1070
|
+
self._egraph.run_program(bindings.Simplify(bindings.DUMMY_SPAN, egg_expr, egg_schedule))
|
|
1088
1071
|
extract_report = self._egraph.extract_report()
|
|
1089
1072
|
if not isinstance(extract_report, bindings.Best):
|
|
1090
1073
|
msg = "No extract report saved"
|
|
1091
1074
|
raise ValueError(msg) # noqa: TRY004
|
|
1092
1075
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1093
|
-
return cast(EXPR, RuntimeExpr.
|
|
1076
|
+
return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1094
1077
|
|
|
1095
1078
|
def include(self, path: str) -> None:
|
|
1096
1079
|
"""
|
|
@@ -1139,13 +1122,13 @@ class EGraph(_BaseModule):
|
|
|
1139
1122
|
"""
|
|
1140
1123
|
Checks that one of the facts is not true
|
|
1141
1124
|
"""
|
|
1142
|
-
self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
|
|
1125
|
+
self._egraph.run_program(bindings.Fail(bindings.DUMMY_SPAN, self._facts_to_check(facts)))
|
|
1143
1126
|
|
|
1144
1127
|
def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
|
|
1145
1128
|
facts = _fact_likes(fact_likes)
|
|
1146
1129
|
self._add_decls(*facts)
|
|
1147
1130
|
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1148
|
-
return bindings.Check(egg_facts)
|
|
1131
|
+
return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
|
|
1149
1132
|
|
|
1150
1133
|
@overload
|
|
1151
1134
|
def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
|
|
@@ -1167,7 +1150,7 @@ class EGraph(_BaseModule):
|
|
|
1167
1150
|
raise ValueError(msg) # noqa: TRY004
|
|
1168
1151
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1169
1152
|
|
|
1170
|
-
res = cast(EXPR, RuntimeExpr.
|
|
1153
|
+
res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1171
1154
|
if include_cost:
|
|
1172
1155
|
return res, extract_report.cost
|
|
1173
1156
|
return res
|
|
@@ -1185,11 +1168,15 @@ class EGraph(_BaseModule):
|
|
|
1185
1168
|
msg = "Wrong extract report type"
|
|
1186
1169
|
raise ValueError(msg) # noqa: TRY004
|
|
1187
1170
|
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1188
|
-
return [cast(EXPR, RuntimeExpr.
|
|
1171
|
+
return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1189
1172
|
|
|
1190
1173
|
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1191
1174
|
expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1192
|
-
self._egraph.run_program(
|
|
1175
|
+
self._egraph.run_program(
|
|
1176
|
+
bindings.ActionCommand(
|
|
1177
|
+
bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
|
|
1178
|
+
)
|
|
1179
|
+
)
|
|
1193
1180
|
extract_report = self._egraph.extract_report()
|
|
1194
1181
|
if not extract_report:
|
|
1195
1182
|
msg = "No extract report saved"
|
|
@@ -1208,7 +1195,7 @@ class EGraph(_BaseModule):
|
|
|
1208
1195
|
"""
|
|
1209
1196
|
Pop the current state of the egraph, reverting back to the previous state.
|
|
1210
1197
|
"""
|
|
1211
|
-
self._egraph.run_program(bindings.Pop(1))
|
|
1198
|
+
self._egraph.run_program(bindings.Pop(bindings.DUMMY_SPAN, 1))
|
|
1212
1199
|
self._state = self._state_stack.pop()
|
|
1213
1200
|
|
|
1214
1201
|
def __enter__(self) -> Self:
|
|
@@ -1221,7 +1208,7 @@ class EGraph(_BaseModule):
|
|
|
1221
1208
|
self.push()
|
|
1222
1209
|
return self
|
|
1223
1210
|
|
|
1224
|
-
def __exit__(self, exc_type, exc, exc_tb) -> None:
|
|
1211
|
+
def __exit__(self, exc_type, exc, exc_tb) -> None:
|
|
1225
1212
|
CURRENT_EGRAPH.reset(self._token_stack.pop())
|
|
1226
1213
|
self.pop()
|
|
1227
1214
|
|
|
@@ -1261,42 +1248,119 @@ class EGraph(_BaseModule):
|
|
|
1261
1248
|
return self._egraph.eval_py_object(egg_expr)
|
|
1262
1249
|
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1263
1250
|
|
|
1264
|
-
def
|
|
1265
|
-
self,
|
|
1266
|
-
|
|
1267
|
-
|
|
1251
|
+
def _serialize(
|
|
1252
|
+
self,
|
|
1253
|
+
**kwargs: Unpack[GraphvizKwargs],
|
|
1254
|
+
) -> bindings.SerializedEGraph:
|
|
1255
|
+
max_functions = kwargs.pop("max_functions", None)
|
|
1256
|
+
max_calls_per_function = kwargs.pop("max_calls_per_function", None)
|
|
1257
|
+
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
|
|
1258
|
+
split_functions = kwargs.pop("split_functions", [])
|
|
1259
|
+
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
|
|
1260
|
+
n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
|
|
1261
|
+
serialized = self._egraph.serialize(
|
|
1262
|
+
[],
|
|
1263
|
+
max_functions=max_functions,
|
|
1264
|
+
max_calls_per_function=max_calls_per_function,
|
|
1265
|
+
include_temporary_functions=include_temporary_functions,
|
|
1266
|
+
)
|
|
1267
|
+
if split_primitive_outputs or split_functions:
|
|
1268
|
+
additional_ops = set(map(self._callable_to_egg, split_functions))
|
|
1269
|
+
serialized.split_classes(self._egraph, additional_ops)
|
|
1270
|
+
serialized.map_ops(self._state.op_mapping())
|
|
1268
1271
|
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
while self.run(1).updated and i < max:
|
|
1272
|
-
i += 1
|
|
1273
|
-
dots.append(str(self.graphviz(**kwargs)))
|
|
1274
|
-
return graphviz_widget_with_slider(dots, performance=performance)
|
|
1272
|
+
for _ in range(n_inline_leaves):
|
|
1273
|
+
serialized.inline_leaves()
|
|
1275
1274
|
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
) ->
|
|
1279
|
-
|
|
1275
|
+
return serialized
|
|
1276
|
+
|
|
1277
|
+
def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
1278
|
+
serialized = self._serialize(**kwargs)
|
|
1279
|
+
|
|
1280
|
+
original = serialized.to_dot()
|
|
1281
|
+
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
1282
|
+
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
1283
|
+
styles = """/* the lines within the edges */
|
|
1284
|
+
.edge:active path,
|
|
1285
|
+
.edge:hover path {
|
|
1286
|
+
stroke: fuchsia;
|
|
1287
|
+
stroke-width: 3;
|
|
1288
|
+
stroke-opacity: 1;
|
|
1289
|
+
}
|
|
1290
|
+
/* arrows are typically drawn with a polygon */
|
|
1291
|
+
.edge:active polygon,
|
|
1292
|
+
.edge:hover polygon {
|
|
1293
|
+
stroke: fuchsia;
|
|
1294
|
+
stroke-width: 3;
|
|
1295
|
+
fill: fuchsia;
|
|
1296
|
+
stroke-opacity: 1;
|
|
1297
|
+
fill-opacity: 1;
|
|
1298
|
+
}
|
|
1299
|
+
/* If you happen to have text and want to color that as well... */
|
|
1300
|
+
.edge:active text,
|
|
1301
|
+
.edge:hover text {
|
|
1302
|
+
fill: fuchsia;
|
|
1303
|
+
}"""
|
|
1304
|
+
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
1305
|
+
p.write_text(styles)
|
|
1306
|
+
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
1307
|
+
return graphviz.Source(with_stylesheet)
|
|
1308
|
+
|
|
1309
|
+
def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1310
|
+
"""
|
|
1311
|
+
Displays the e-graph.
|
|
1312
|
+
|
|
1313
|
+
If in IPython it will display it inline, otherwise it will write it to a file and open it.
|
|
1314
|
+
"""
|
|
1315
|
+
from IPython.display import SVG, display
|
|
1280
1316
|
|
|
1281
|
-
|
|
1317
|
+
from .visualizer_widget import VisualizerWidget
|
|
1282
1318
|
|
|
1283
|
-
|
|
1319
|
+
if graphviz:
|
|
1320
|
+
if IN_IPYTHON:
|
|
1321
|
+
svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
1322
|
+
display(SVG(svg))
|
|
1323
|
+
else:
|
|
1324
|
+
self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1325
|
+
else:
|
|
1326
|
+
serialized = self._serialize(**kwargs)
|
|
1327
|
+
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
|
|
1284
1328
|
|
|
1285
|
-
|
|
1286
|
-
|
|
1329
|
+
def saturate(
|
|
1330
|
+
self,
|
|
1331
|
+
schedule: Schedule | None = None,
|
|
1332
|
+
*,
|
|
1333
|
+
expr: Expr | None = None,
|
|
1334
|
+
max: int = 1000,
|
|
1335
|
+
**kwargs: Unpack[GraphvizKwargs],
|
|
1336
|
+
) -> None:
|
|
1337
|
+
"""
|
|
1338
|
+
Saturate the egraph, running the given schedule until the egraph is saturated.
|
|
1339
|
+
It serializes the egraph at each step and returns a widget to visualize the egraph.
|
|
1340
|
+
"""
|
|
1341
|
+
from .visualizer_widget import VisualizerWidget
|
|
1287
1342
|
|
|
1288
|
-
|
|
1343
|
+
def to_json() -> str:
|
|
1344
|
+
if expr:
|
|
1345
|
+
print(self.extract(expr))
|
|
1346
|
+
return self._serialize(**kwargs).to_json()
|
|
1289
1347
|
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1348
|
+
egraphs = [to_json()]
|
|
1349
|
+
i = 0
|
|
1350
|
+
while self.run(schedule or 1).updated and i < max:
|
|
1351
|
+
i += 1
|
|
1352
|
+
egraphs.append(to_json())
|
|
1353
|
+
VisualizerWidget(egraphs=egraphs).display_or_open()
|
|
1293
1354
|
|
|
1294
1355
|
@classmethod
|
|
1295
1356
|
def current(cls) -> EGraph:
|
|
1296
1357
|
"""
|
|
1297
1358
|
Returns the current egraph, which is the one in the context.
|
|
1298
1359
|
"""
|
|
1299
|
-
|
|
1360
|
+
try:
|
|
1361
|
+
return CURRENT_EGRAPH.get()
|
|
1362
|
+
except LookupError:
|
|
1363
|
+
return cls(save_egglog_string=True)
|
|
1300
1364
|
|
|
1301
1365
|
@property
|
|
1302
1366
|
def _egraph(self) -> bindings.EGraph:
|
|
@@ -1448,7 +1512,8 @@ class Ruleset(Schedule):
|
|
|
1448
1512
|
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1449
1513
|
"""
|
|
1450
1514
|
while self.deferred_rule_gens:
|
|
1451
|
-
|
|
1515
|
+
with set_current_ruleset(self):
|
|
1516
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1452
1517
|
self._current_egg_decls.update(*rules)
|
|
1453
1518
|
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1454
1519
|
return self._current_egg_decls
|
|
@@ -1688,16 +1753,16 @@ def action_command(action: Action) -> Action:
|
|
|
1688
1753
|
return action
|
|
1689
1754
|
|
|
1690
1755
|
|
|
1691
|
-
def var(name: str, bound: type[
|
|
1756
|
+
def var(name: str, bound: type[T]) -> T:
|
|
1692
1757
|
"""Create a new variable with the given name and type."""
|
|
1693
|
-
return cast(
|
|
1758
|
+
return cast(T, _var(name, bound))
|
|
1694
1759
|
|
|
1695
1760
|
|
|
1696
1761
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1697
1762
|
"""Create a new variable with the given name and type."""
|
|
1698
1763
|
decls = Declarations()
|
|
1699
1764
|
type_ref = resolve_type_annotation(decls, bound)
|
|
1700
|
-
return RuntimeExpr.
|
|
1765
|
+
return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
|
|
1701
1766
|
|
|
1702
1767
|
|
|
1703
1768
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1790,7 +1855,7 @@ class _NeBuilder(Generic[EXPR]):
|
|
|
1790
1855
|
lhs = to_runtime_expr(self.lhs)
|
|
1791
1856
|
rhs = convert_to_same_type(rhs, lhs)
|
|
1792
1857
|
assert isinstance(Unit, RuntimeClass)
|
|
1793
|
-
res = RuntimeExpr.
|
|
1858
|
+
res = RuntimeExpr.__from_values__(
|
|
1794
1859
|
Declarations.create(Unit, lhs, rhs),
|
|
1795
1860
|
TypedExprDecl(
|
|
1796
1861
|
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
@@ -1944,7 +2009,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1944
2009
|
if "Callable" not in globals:
|
|
1945
2010
|
globals["Callable"] = Callable
|
|
1946
2011
|
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1947
|
-
args = [_var(
|
|
2012
|
+
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1948
2013
|
return list(gen(*args)) # type: ignore[misc]
|
|
1949
2014
|
|
|
1950
2015
|
|
|
@@ -1959,3 +2024,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
|
|
|
1959
2024
|
if isinstance(fact_like, Expr):
|
|
1960
2025
|
return expr_fact(fact_like)
|
|
1961
2026
|
return fact_like
|
|
2027
|
+
|
|
2028
|
+
|
|
2029
|
+
_CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
|
|
2030
|
+
|
|
2031
|
+
|
|
2032
|
+
def get_current_ruleset() -> Ruleset | None:
|
|
2033
|
+
return _CURRENT_RULESET.get()
|
|
2034
|
+
|
|
2035
|
+
|
|
2036
|
+
@contextlib.contextmanager
|
|
2037
|
+
def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
2038
|
+
token = _CURRENT_RULESET.set(r)
|
|
2039
|
+
try:
|
|
2040
|
+
yield
|
|
2041
|
+
finally:
|
|
2042
|
+
_CURRENT_RULESET.reset(token)
|