egglog 7.0.0__cp310-none-win_amd64.whl → 7.2.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +7 -0
- egglog/builtins.py +41 -1
- egglog/conversion.py +22 -17
- egglog/declarations.py +122 -37
- egglog/egraph.py +219 -78
- egglog/egraph_state.py +124 -54
- egglog/examples/higher_order_functions.py +50 -0
- egglog/exp/array_api.py +12 -9
- egglog/pretty.py +71 -15
- egglog/runtime.py +118 -33
- egglog/thunk.py +17 -6
- egglog/type_constraint_solver.py +5 -4
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/METADATA +10 -10
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/RECORD +17 -16
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/WHEEL +0 -0
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -75,6 +75,7 @@ __all__ = [
|
|
|
75
75
|
"seq",
|
|
76
76
|
"Command",
|
|
77
77
|
"simplify",
|
|
78
|
+
"unstable_combine_rulesets",
|
|
78
79
|
"check",
|
|
79
80
|
"GraphvizKwargs",
|
|
80
81
|
"Ruleset",
|
|
@@ -88,6 +89,7 @@ __all__ = [
|
|
|
88
89
|
"Fact",
|
|
89
90
|
"Action",
|
|
90
91
|
"Command",
|
|
92
|
+
"check_eq",
|
|
91
93
|
]
|
|
92
94
|
|
|
93
95
|
T = TypeVar("T")
|
|
@@ -145,6 +147,23 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
|
145
147
|
return EGraph().extract(x)
|
|
146
148
|
|
|
147
149
|
|
|
150
|
+
def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
|
|
151
|
+
"""
|
|
152
|
+
Verifies that two expressions are equal after running the schedule.
|
|
153
|
+
"""
|
|
154
|
+
egraph = EGraph()
|
|
155
|
+
x_var = egraph.let("__check_eq_x", x)
|
|
156
|
+
y_var = egraph.let("__check_eq_y", y)
|
|
157
|
+
if schedule:
|
|
158
|
+
egraph.run(schedule)
|
|
159
|
+
fact = eq(x_var).to(y_var)
|
|
160
|
+
try:
|
|
161
|
+
egraph.check(fact)
|
|
162
|
+
except bindings.EggSmolError as err:
|
|
163
|
+
raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
|
|
164
|
+
return egraph
|
|
165
|
+
|
|
166
|
+
|
|
148
167
|
def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
|
|
149
168
|
"""
|
|
150
169
|
Verifies that the fact is true given some assumptions and after running the schedule.
|
|
@@ -333,7 +352,7 @@ class _BaseModule:
|
|
|
333
352
|
This is the same as defining a nullary function with a high cost.
|
|
334
353
|
# TODO: Rename as declare to match eggglog?
|
|
335
354
|
"""
|
|
336
|
-
return constant(name, tp, egg_name)
|
|
355
|
+
return constant(name, tp, egg_name=egg_name)
|
|
337
356
|
|
|
338
357
|
def register(
|
|
339
358
|
self,
|
|
@@ -433,6 +452,7 @@ class _ExprMetaclass(type):
|
|
|
433
452
|
namespace: dict[str, Any],
|
|
434
453
|
egg_sort: str | None = None,
|
|
435
454
|
builtin: bool = False,
|
|
455
|
+
ruleset: Ruleset | None = None,
|
|
436
456
|
) -> RuntimeClass | type:
|
|
437
457
|
# If this is the Expr subclass, just return the class
|
|
438
458
|
if not bases:
|
|
@@ -448,7 +468,14 @@ class _ExprMetaclass(type):
|
|
|
448
468
|
# Otherwise, f_locals returns a copy
|
|
449
469
|
# https://peps.python.org/pep-0667/
|
|
450
470
|
decls_thunk = Thunk.fn(
|
|
451
|
-
_generate_class_decls,
|
|
471
|
+
_generate_class_decls,
|
|
472
|
+
namespace,
|
|
473
|
+
prev_frame,
|
|
474
|
+
builtin,
|
|
475
|
+
egg_sort,
|
|
476
|
+
name,
|
|
477
|
+
ruleset,
|
|
478
|
+
fallback=Declarations,
|
|
452
479
|
)
|
|
453
480
|
return RuntimeClass(decls_thunk, TypeRefWithVars(name))
|
|
454
481
|
|
|
@@ -457,7 +484,12 @@ class _ExprMetaclass(type):
|
|
|
457
484
|
|
|
458
485
|
|
|
459
486
|
def _generate_class_decls(
|
|
460
|
-
namespace: dict[str, Any],
|
|
487
|
+
namespace: dict[str, Any],
|
|
488
|
+
frame: FrameType,
|
|
489
|
+
builtin: bool,
|
|
490
|
+
egg_sort: str | None,
|
|
491
|
+
cls_name: str,
|
|
492
|
+
ruleset: Ruleset | None,
|
|
461
493
|
) -> Declarations:
|
|
462
494
|
"""
|
|
463
495
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
@@ -479,9 +511,9 @@ def _generate_class_decls(
|
|
|
479
511
|
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
480
512
|
if getattr(v, "__origin__", None) == ClassVar:
|
|
481
513
|
(inner_tp,) = v.__args__
|
|
482
|
-
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
483
|
-
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
484
|
-
|
|
514
|
+
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
515
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
|
|
516
|
+
_add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
|
|
485
517
|
else:
|
|
486
518
|
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
487
519
|
raise NotImplementedError(msg)
|
|
@@ -491,13 +523,15 @@ def _generate_class_decls(
|
|
|
491
523
|
##
|
|
492
524
|
|
|
493
525
|
# The type ref of self is paramterized by the type vars
|
|
494
|
-
|
|
526
|
+
TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
495
527
|
|
|
496
528
|
# Get all the methods from the class
|
|
497
529
|
filtered_namespace: list[tuple[str, Any]] = [
|
|
498
530
|
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
499
531
|
]
|
|
500
532
|
|
|
533
|
+
# all methods we should try adding default functions for
|
|
534
|
+
default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
|
|
501
535
|
# Then register each of its methods
|
|
502
536
|
for method_name, method in filtered_namespace:
|
|
503
537
|
is_init = method_name == "__init__"
|
|
@@ -517,33 +551,26 @@ def _generate_class_decls(
|
|
|
517
551
|
continue
|
|
518
552
|
locals = frame.f_locals
|
|
519
553
|
|
|
520
|
-
|
|
521
|
-
return _fn_decl(
|
|
522
|
-
decls,
|
|
523
|
-
egg_fn, # noqa: B023
|
|
524
|
-
fn,
|
|
525
|
-
locals, # noqa: B023
|
|
526
|
-
default, # noqa: B023
|
|
527
|
-
cost, # noqa: B023
|
|
528
|
-
merge, # noqa: B023
|
|
529
|
-
on_merge, # noqa: B023
|
|
530
|
-
mutates, # noqa: B023
|
|
531
|
-
builtin,
|
|
532
|
-
first,
|
|
533
|
-
is_init, # noqa: B023
|
|
534
|
-
unextractable, # noqa: B023
|
|
535
|
-
)
|
|
536
|
-
|
|
554
|
+
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
537
555
|
match fn:
|
|
538
556
|
case classmethod():
|
|
539
|
-
|
|
557
|
+
ref = ClassMethodRef(cls_name, method_name)
|
|
558
|
+
fn = fn.__func__
|
|
540
559
|
case property():
|
|
541
|
-
|
|
560
|
+
ref = PropertyRef(cls_name, method_name)
|
|
561
|
+
fn = fn.fget
|
|
542
562
|
case _:
|
|
543
|
-
if is_init
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
563
|
+
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
|
|
564
|
+
|
|
565
|
+
_fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
|
|
566
|
+
|
|
567
|
+
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
568
|
+
default_function_refs[ref] = fn
|
|
569
|
+
|
|
570
|
+
# Add all rewrite methods at the end so that all methods are registered first and can be accessed
|
|
571
|
+
# in the bodies
|
|
572
|
+
for ref, fn in default_function_refs.items():
|
|
573
|
+
_add_default_rewrite_function(decls, ref, fn, ruleset)
|
|
547
574
|
|
|
548
575
|
return decls
|
|
549
576
|
|
|
@@ -562,6 +589,7 @@ def function(
|
|
|
562
589
|
mutates_first_arg: bool = False,
|
|
563
590
|
unextractable: bool = False,
|
|
564
591
|
builtin: bool = False,
|
|
592
|
+
ruleset: Ruleset | None = None,
|
|
565
593
|
) -> Callable[[CALLABLE], CALLABLE]: ...
|
|
566
594
|
|
|
567
595
|
|
|
@@ -575,6 +603,7 @@ def function(
|
|
|
575
603
|
on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
|
|
576
604
|
mutates_first_arg: bool = False,
|
|
577
605
|
unextractable: bool = False,
|
|
606
|
+
ruleset: Ruleset | None = None,
|
|
578
607
|
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
|
|
579
608
|
|
|
580
609
|
|
|
@@ -605,15 +634,17 @@ class _FunctionConstructor:
|
|
|
605
634
|
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
606
635
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
607
636
|
unextractable: bool = False
|
|
637
|
+
ruleset: Ruleset | None = None
|
|
608
638
|
|
|
609
639
|
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
610
640
|
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
|
|
611
641
|
|
|
612
642
|
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
613
643
|
decls = Declarations()
|
|
614
|
-
|
|
644
|
+
_fn_decl(
|
|
615
645
|
decls,
|
|
616
646
|
self.egg_fn,
|
|
647
|
+
(ref := FunctionRef(fn.__name__)),
|
|
617
648
|
fn,
|
|
618
649
|
self.hint_locals,
|
|
619
650
|
self.default,
|
|
@@ -624,12 +655,14 @@ class _FunctionConstructor:
|
|
|
624
655
|
self.builtin,
|
|
625
656
|
unextractable=self.unextractable,
|
|
626
657
|
)
|
|
658
|
+
_add_default_rewrite_function(decls, ref, fn, self.ruleset)
|
|
627
659
|
return decls
|
|
628
660
|
|
|
629
661
|
|
|
630
662
|
def _fn_decl(
|
|
631
663
|
decls: Declarations,
|
|
632
664
|
egg_name: str | None,
|
|
665
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
633
666
|
fn: object,
|
|
634
667
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
635
668
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -640,22 +673,35 @@ def _fn_decl(
|
|
|
640
673
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
|
|
641
674
|
mutates_first_arg: bool,
|
|
642
675
|
is_builtin: bool,
|
|
643
|
-
# The first arg is either cls, for a classmethod, a self type, or none for a function
|
|
644
|
-
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
645
|
-
is_init: bool = False,
|
|
646
676
|
unextractable: bool = False,
|
|
647
|
-
) ->
|
|
677
|
+
) -> None:
|
|
678
|
+
"""
|
|
679
|
+
Sets the function decl for the function object.
|
|
680
|
+
"""
|
|
648
681
|
if not isinstance(fn, FunctionType):
|
|
649
682
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
650
683
|
|
|
684
|
+
# partial function creation and calling are handled with a special case in the type checker, so don't
|
|
685
|
+
# use the normal logic
|
|
686
|
+
special_function_name: SpecialFunctions | None = (
|
|
687
|
+
"fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
|
|
688
|
+
)
|
|
689
|
+
if special_function_name:
|
|
690
|
+
decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
|
|
691
|
+
return
|
|
692
|
+
|
|
651
693
|
hint_globals = fn.__globals__.copy()
|
|
694
|
+
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
695
|
+
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
696
|
+
if "Callable" not in hint_globals:
|
|
697
|
+
hint_globals["Callable"] = Callable
|
|
652
698
|
|
|
653
699
|
hints = get_type_hints(fn, hint_globals, hint_locals)
|
|
654
700
|
|
|
655
701
|
params = list(signature(fn).parameters.values())
|
|
656
702
|
|
|
657
|
-
# If this is an init function, or a classmethod,
|
|
658
|
-
if
|
|
703
|
+
# If this is an init function, or a classmethod, the first arg is not used
|
|
704
|
+
if isinstance(ref, ClassMethodRef | InitRef):
|
|
659
705
|
params = params[1:]
|
|
660
706
|
|
|
661
707
|
if _last_param_variable(params):
|
|
@@ -665,9 +711,8 @@ def _fn_decl(
|
|
|
665
711
|
else:
|
|
666
712
|
var_arg_type = None
|
|
667
713
|
arg_types = tuple(
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
|
|
714
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
715
|
+
if i == 0 and isinstance(ref, MethodRef | PropertyRef)
|
|
671
716
|
else resolve_type_annotation(decls, hints[t.name])
|
|
672
717
|
for i, t in enumerate(params)
|
|
673
718
|
)
|
|
@@ -680,17 +725,15 @@ def _fn_decl(
|
|
|
680
725
|
|
|
681
726
|
decls.update(*arg_defaults)
|
|
682
727
|
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
else:
|
|
693
|
-
return_type = resolve_type_annotation(decls, hints["return"])
|
|
728
|
+
return_type = (
|
|
729
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
730
|
+
if isinstance(ref, InitRef)
|
|
731
|
+
else arg_types[0]
|
|
732
|
+
if mutates_first_arg
|
|
733
|
+
else resolve_type_annotation(decls, hints["return"])
|
|
734
|
+
)
|
|
735
|
+
|
|
736
|
+
arg_names = tuple(t.name for t in params)
|
|
694
737
|
|
|
695
738
|
decls |= default
|
|
696
739
|
merged = (
|
|
@@ -714,12 +757,14 @@ def _fn_decl(
|
|
|
714
757
|
)
|
|
715
758
|
)
|
|
716
759
|
decls.update(*merge_action)
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
760
|
+
decl = FunctionDecl(
|
|
761
|
+
FunctionSignature(
|
|
762
|
+
return_type=None if mutates_first_arg else return_type,
|
|
763
|
+
var_arg_type=var_arg_type,
|
|
764
|
+
arg_types=arg_types,
|
|
765
|
+
arg_names=arg_names,
|
|
766
|
+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
767
|
+
),
|
|
723
768
|
cost=cost,
|
|
724
769
|
egg_name=egg_name,
|
|
725
770
|
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
@@ -728,6 +773,7 @@ def _fn_decl(
|
|
|
728
773
|
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
729
774
|
on_merge=tuple(a.action for a in merge_action),
|
|
730
775
|
)
|
|
776
|
+
decls.set_function_decl(ref, decl)
|
|
731
777
|
|
|
732
778
|
|
|
733
779
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -769,19 +815,73 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
|
|
|
769
815
|
return decls
|
|
770
816
|
|
|
771
817
|
|
|
772
|
-
def constant(
|
|
818
|
+
def constant(
|
|
819
|
+
name: str,
|
|
820
|
+
tp: type[EXPR],
|
|
821
|
+
default_replacement: EXPR | None = None,
|
|
822
|
+
/,
|
|
823
|
+
*,
|
|
824
|
+
egg_name: str | None = None,
|
|
825
|
+
ruleset: Ruleset | None = None,
|
|
826
|
+
) -> EXPR:
|
|
773
827
|
"""
|
|
774
828
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
775
829
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
776
830
|
"""
|
|
777
|
-
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
|
|
831
|
+
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
|
|
778
832
|
|
|
779
833
|
|
|
780
|
-
def _constant_thunk(
|
|
834
|
+
def _constant_thunk(
|
|
835
|
+
name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
|
|
836
|
+
) -> tuple[Declarations, TypedExprDecl]:
|
|
781
837
|
decls = Declarations()
|
|
782
|
-
type_ref = resolve_type_annotation(decls, tp)
|
|
783
|
-
|
|
784
|
-
|
|
838
|
+
type_ref = resolve_type_annotation(decls, tp)
|
|
839
|
+
callable_ref = ConstantRef(name)
|
|
840
|
+
decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
|
|
841
|
+
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
|
|
842
|
+
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
|
|
846
|
+
"""
|
|
847
|
+
Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
|
|
848
|
+
"""
|
|
849
|
+
callable_decl = decls.get_callable_decl(ref)
|
|
850
|
+
assert isinstance(callable_decl, FunctionDecl)
|
|
851
|
+
signature = callable_decl.signature
|
|
852
|
+
assert isinstance(signature, FunctionSignature)
|
|
853
|
+
|
|
854
|
+
var_args: list[object] = [
|
|
855
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
|
|
856
|
+
for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
|
|
857
|
+
]
|
|
858
|
+
# If this is a classmethod, add the class as the first arg
|
|
859
|
+
if isinstance(ref, ClassMethodRef):
|
|
860
|
+
tp = decls.get_paramaterized_class(ref.class_name)
|
|
861
|
+
var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
862
|
+
_add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def _add_default_rewrite(
|
|
866
|
+
decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
|
|
867
|
+
) -> None:
|
|
868
|
+
"""
|
|
869
|
+
Adds a default rewrite for the callable, if the default rewrite is not None
|
|
870
|
+
|
|
871
|
+
Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
|
|
872
|
+
"""
|
|
873
|
+
if default_rewrite is None:
|
|
874
|
+
return
|
|
875
|
+
resolved_value = resolve_literal(type_ref, default_rewrite)
|
|
876
|
+
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
|
|
877
|
+
if ruleset:
|
|
878
|
+
ruleset_decls = ruleset._current_egg_decls
|
|
879
|
+
ruleset_decl = ruleset.__egg_ruleset__
|
|
880
|
+
else:
|
|
881
|
+
ruleset_decls = decls
|
|
882
|
+
ruleset_decl = decls.default_ruleset
|
|
883
|
+
ruleset_decl.rules.append(rewrite_decl)
|
|
884
|
+
ruleset_decls |= resolved_value
|
|
785
885
|
|
|
786
886
|
|
|
787
887
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -933,13 +1033,12 @@ class EGraph(_BaseModule):
|
|
|
933
1033
|
"""
|
|
934
1034
|
Displays the e-graph in the notebook.
|
|
935
1035
|
"""
|
|
936
|
-
graphviz = self.graphviz(**kwargs)
|
|
937
1036
|
if IN_IPYTHON:
|
|
938
1037
|
from IPython.display import SVG, display
|
|
939
1038
|
|
|
940
1039
|
display(SVG(self.graphviz_svg(**kwargs)))
|
|
941
1040
|
else:
|
|
942
|
-
graphviz.render(view=True, format="svg", quiet=True)
|
|
1041
|
+
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
943
1042
|
|
|
944
1043
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
945
1044
|
"""
|
|
@@ -957,6 +1056,7 @@ class EGraph(_BaseModule):
|
|
|
957
1056
|
action = let(name, expr)
|
|
958
1057
|
self.register(action)
|
|
959
1058
|
runtime_expr = to_runtime_expr(expr)
|
|
1059
|
+
self._add_decls(runtime_expr)
|
|
960
1060
|
return cast(
|
|
961
1061
|
EXPR,
|
|
962
1062
|
RuntimeExpr.__from_value__(
|
|
@@ -982,7 +1082,8 @@ class EGraph(_BaseModule):
|
|
|
982
1082
|
self._add_decls(runtime_expr, schedule)
|
|
983
1083
|
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
984
1084
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
985
|
-
|
|
1085
|
+
# Must also register type
|
|
1086
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
986
1087
|
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
987
1088
|
extract_report = self._egraph.extract_report()
|
|
988
1089
|
if not isinstance(extract_report, bindings.Best):
|
|
@@ -1059,7 +1160,7 @@ class EGraph(_BaseModule):
|
|
|
1059
1160
|
runtime_expr = to_runtime_expr(expr)
|
|
1060
1161
|
self._add_decls(runtime_expr)
|
|
1061
1162
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1062
|
-
extract_report = self._run_extract(typed_expr
|
|
1163
|
+
extract_report = self._run_extract(typed_expr, 0)
|
|
1063
1164
|
|
|
1064
1165
|
if not isinstance(extract_report, bindings.Best):
|
|
1065
1166
|
msg = "No extract report saved"
|
|
@@ -1079,15 +1180,15 @@ class EGraph(_BaseModule):
|
|
|
1079
1180
|
self._add_decls(runtime_expr)
|
|
1080
1181
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1081
1182
|
|
|
1082
|
-
extract_report = self._run_extract(typed_expr
|
|
1183
|
+
extract_report = self._run_extract(typed_expr, n)
|
|
1083
1184
|
if not isinstance(extract_report, bindings.Variants):
|
|
1084
1185
|
msg = "Wrong extract report type"
|
|
1085
1186
|
raise ValueError(msg) # noqa: TRY004
|
|
1086
1187
|
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1087
1188
|
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1088
1189
|
|
|
1089
|
-
def _run_extract(self,
|
|
1090
|
-
expr = self._state.
|
|
1190
|
+
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1191
|
+
expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1091
1192
|
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1092
1193
|
extract_report = self._egraph.extract_report()
|
|
1093
1194
|
if not extract_report:
|
|
@@ -1146,7 +1247,7 @@ class EGraph(_BaseModule):
|
|
|
1146
1247
|
runtime_expr = to_runtime_expr(expr)
|
|
1147
1248
|
self._add_decls(runtime_expr)
|
|
1148
1249
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1149
|
-
egg_expr = self._state.
|
|
1250
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1150
1251
|
match typed_expr.tp:
|
|
1151
1252
|
case JustTypeRef("i64"):
|
|
1152
1253
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1276,8 +1377,10 @@ def ruleset(
|
|
|
1276
1377
|
"""
|
|
1277
1378
|
Creates a ruleset with the following rules.
|
|
1278
1379
|
|
|
1279
|
-
If no name is provided,
|
|
1380
|
+
If no name is provided, try using the name of the funciton.
|
|
1280
1381
|
"""
|
|
1382
|
+
if isinstance(rule_or_generator, FunctionType):
|
|
1383
|
+
name = name or rule_or_generator.__name__
|
|
1281
1384
|
r = Ruleset(name)
|
|
1282
1385
|
if rule_or_generator is not None:
|
|
1283
1386
|
r.register(rule_or_generator, *rules, _increase_frame=True)
|
|
@@ -1388,12 +1491,48 @@ class Ruleset(Schedule):
|
|
|
1388
1491
|
def __repr__(self) -> str:
|
|
1389
1492
|
return str(self)
|
|
1390
1493
|
|
|
1494
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1495
|
+
return unstable_combine_rulesets(self, other)
|
|
1496
|
+
|
|
1391
1497
|
# Create a unique name if we didn't pass one from the user
|
|
1392
1498
|
@property
|
|
1393
1499
|
def __egg_name__(self) -> str:
|
|
1394
1500
|
return self.name or f"ruleset_{id(self)}"
|
|
1395
1501
|
|
|
1396
1502
|
|
|
1503
|
+
@dataclass
|
|
1504
|
+
class UnstableCombinedRuleset(Schedule):
|
|
1505
|
+
__egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
|
|
1506
|
+
schedule: RunDecl = field(init=False)
|
|
1507
|
+
name: str | None
|
|
1508
|
+
rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
|
|
1509
|
+
|
|
1510
|
+
def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
|
|
1511
|
+
self.schedule = RunDecl(self.__egg_name__, ())
|
|
1512
|
+
self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
|
|
1513
|
+
|
|
1514
|
+
@property
|
|
1515
|
+
def __egg_name__(self) -> str:
|
|
1516
|
+
return self.name or f"combined_ruleset_{id(self)}"
|
|
1517
|
+
|
|
1518
|
+
def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
|
|
1519
|
+
decls = Declarations.create(*rulesets)
|
|
1520
|
+
decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
|
|
1521
|
+
return decls
|
|
1522
|
+
|
|
1523
|
+
def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
|
|
1524
|
+
return unstable_combine_rulesets(self, other)
|
|
1525
|
+
|
|
1526
|
+
|
|
1527
|
+
def unstable_combine_rulesets(
|
|
1528
|
+
*rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
|
|
1529
|
+
) -> UnstableCombinedRuleset:
|
|
1530
|
+
"""
|
|
1531
|
+
Combine multiple rulesets into a single ruleset.
|
|
1532
|
+
"""
|
|
1533
|
+
return UnstableCombinedRuleset(name, list(rulesets))
|
|
1534
|
+
|
|
1535
|
+
|
|
1397
1536
|
@dataclass
|
|
1398
1537
|
class RewriteOrRule:
|
|
1399
1538
|
__egg_decls__: Declarations
|
|
@@ -1556,9 +1695,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
|
|
|
1556
1695
|
|
|
1557
1696
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1558
1697
|
"""Create a new variable with the given name and type."""
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
return RuntimeExpr.__from_value__(
|
|
1698
|
+
decls = Declarations()
|
|
1699
|
+
type_ref = resolve_type_annotation(decls, bound)
|
|
1700
|
+
return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
|
|
1562
1701
|
|
|
1563
1702
|
|
|
1564
1703
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1801,9 +1940,11 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1801
1940
|
"""
|
|
1802
1941
|
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
|
|
1803
1942
|
# but not in the globals
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1943
|
+
globals = gen.__globals__.copy()
|
|
1944
|
+
if "Callable" not in globals:
|
|
1945
|
+
globals["Callable"] = Callable
|
|
1946
|
+
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1947
|
+
args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1807
1948
|
return list(gen(*args)) # type: ignore[misc]
|
|
1808
1949
|
|
|
1809
1950
|
|