egglog 10.0.2__cp310-cp310-win_amd64.whl → 11.0.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +221 -173
- egglog/conversion.py +61 -43
- egglog/declarations.py +103 -17
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +77 -129
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +37 -3
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/pretty.py +11 -25
- egglog/runtime.py +197 -147
- egglog/version_compat.py +2 -2
- {egglog-10.0.2.dist-info → egglog-11.0.0.dist-info}/METADATA +1 -1
- {egglog-10.0.2.dist-info → egglog-11.0.0.dist-info}/RECORD +21 -21
- {egglog-10.0.2.dist-info → egglog-11.0.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- {egglog-10.0.2.dist-info → egglog-11.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -5,7 +5,7 @@ import inspect
|
|
|
5
5
|
import pathlib
|
|
6
6
|
import tempfile
|
|
7
7
|
from collections.abc import Callable, Generator, Iterable
|
|
8
|
-
from contextvars import ContextVar
|
|
8
|
+
from contextvars import ContextVar, Token
|
|
9
9
|
from dataclasses import InitVar, dataclass, field
|
|
10
10
|
from functools import partial
|
|
11
11
|
from inspect import Parameter, currentframe, signature
|
|
@@ -29,6 +29,7 @@ from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
|
|
|
29
29
|
|
|
30
30
|
from . import bindings
|
|
31
31
|
from .conversion import *
|
|
32
|
+
from .conversion import convert_to_same_type, resolve_literal
|
|
32
33
|
from .declarations import *
|
|
33
34
|
from .egraph_state import *
|
|
34
35
|
from .ipython_magic import IN_IPYTHON
|
|
@@ -82,7 +83,6 @@ __all__ = [
|
|
|
82
83
|
"run",
|
|
83
84
|
"seq",
|
|
84
85
|
"set_",
|
|
85
|
-
"simplify",
|
|
86
86
|
"subsume",
|
|
87
87
|
"union",
|
|
88
88
|
"unstable_combine_rulesets",
|
|
@@ -110,12 +110,12 @@ IGNORED_ATTRIBUTES = {
|
|
|
110
110
|
"__weakref__",
|
|
111
111
|
"__orig_bases__",
|
|
112
112
|
"__annotations__",
|
|
113
|
-
"__hash__",
|
|
114
113
|
"__qualname__",
|
|
115
114
|
"__firstlineno__",
|
|
116
115
|
"__static_attributes__",
|
|
116
|
+
"__match_args__",
|
|
117
117
|
# Ignore all reflected binary method
|
|
118
|
-
*
|
|
118
|
+
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
|
|
119
119
|
}
|
|
120
120
|
|
|
121
121
|
|
|
@@ -139,15 +139,6 @@ ALWAYS_PRESERVED = {
|
|
|
139
139
|
}
|
|
140
140
|
|
|
141
141
|
|
|
142
|
-
def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
143
|
-
"""
|
|
144
|
-
Simplify an expression by running the schedule.
|
|
145
|
-
"""
|
|
146
|
-
if schedule:
|
|
147
|
-
return EGraph().simplify(x, schedule)
|
|
148
|
-
return EGraph().extract(x)
|
|
149
|
-
|
|
150
|
-
|
|
151
142
|
def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph:
|
|
152
143
|
"""
|
|
153
144
|
Verifies that two expressions are equal after running the schedule.
|
|
@@ -291,7 +282,6 @@ def function(
|
|
|
291
282
|
mutates_first_arg: bool = ...,
|
|
292
283
|
unextractable: bool = ...,
|
|
293
284
|
ruleset: Ruleset | None = ...,
|
|
294
|
-
use_body_as_name: bool = ...,
|
|
295
285
|
subsume: bool = ...,
|
|
296
286
|
) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
|
|
297
287
|
|
|
@@ -371,6 +361,7 @@ class BaseExpr(metaclass=_ExprMetaclass):
|
|
|
371
361
|
|
|
372
362
|
def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body]
|
|
373
363
|
|
|
364
|
+
# not currently dissalowing other types of equality https://github.com/python/typeshed/issues/8217#issuecomment-3140873292
|
|
374
365
|
def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body]
|
|
375
366
|
|
|
376
367
|
|
|
@@ -404,7 +395,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
404
395
|
)
|
|
405
396
|
type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters)
|
|
406
397
|
del parameters
|
|
407
|
-
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
398
|
+
cls_decl = ClassDecl(egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ()))
|
|
408
399
|
decls = Declarations(_classes={cls_name: cls_decl})
|
|
409
400
|
# Update class think eagerly when resolving so that lookups work in methods
|
|
410
401
|
runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
|
|
@@ -456,6 +447,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
456
447
|
continue
|
|
457
448
|
locals = frame.f_locals
|
|
458
449
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
450
|
+
# TODO: Store deprecated message so we can print at runtime
|
|
451
|
+
if (getattr(fn, "__deprecated__", None)) is not None:
|
|
452
|
+
fn = fn.__wrapped__ # type: ignore[attr-defined]
|
|
459
453
|
match fn:
|
|
460
454
|
case classmethod():
|
|
461
455
|
ref = ClassMethodRef(cls_name, method_name)
|
|
@@ -477,7 +471,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
477
471
|
decls.set_function_decl(ref, decl)
|
|
478
472
|
continue
|
|
479
473
|
try:
|
|
480
|
-
|
|
474
|
+
add_rewrite = _fn_decl(
|
|
481
475
|
decls,
|
|
482
476
|
egg_fn,
|
|
483
477
|
ref,
|
|
@@ -515,7 +509,6 @@ class _FunctionConstructor:
|
|
|
515
509
|
merge: Callable[[object, object], object] | None = None
|
|
516
510
|
unextractable: bool = False
|
|
517
511
|
ruleset: Ruleset | None = None
|
|
518
|
-
use_body_as_name: bool = False
|
|
519
512
|
subsume: bool = False
|
|
520
513
|
|
|
521
514
|
def __call__(self, fn: Callable) -> RuntimeFunction:
|
|
@@ -523,11 +516,10 @@ class _FunctionConstructor:
|
|
|
523
516
|
|
|
524
517
|
def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
|
|
525
518
|
decls = Declarations()
|
|
526
|
-
|
|
527
|
-
ref, add_rewrite = _fn_decl(
|
|
519
|
+
add_rewrite = _fn_decl(
|
|
528
520
|
decls,
|
|
529
521
|
self.egg_fn,
|
|
530
|
-
ref,
|
|
522
|
+
ref := FunctionRef(fn.__name__),
|
|
531
523
|
fn,
|
|
532
524
|
self.hint_locals,
|
|
533
525
|
self.cost,
|
|
@@ -545,8 +537,7 @@ class _FunctionConstructor:
|
|
|
545
537
|
def _fn_decl(
|
|
546
538
|
decls: Declarations,
|
|
547
539
|
egg_name: str | None,
|
|
548
|
-
|
|
549
|
-
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
|
|
540
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
550
541
|
fn: object,
|
|
551
542
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
552
543
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -559,7 +550,7 @@ def _fn_decl(
|
|
|
559
550
|
ruleset: Ruleset | None = None,
|
|
560
551
|
unextractable: bool = False,
|
|
561
552
|
reverse_args: bool = False,
|
|
562
|
-
) ->
|
|
553
|
+
) -> Callable[[], None]:
|
|
563
554
|
"""
|
|
564
555
|
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
|
|
565
556
|
"""
|
|
@@ -619,8 +610,8 @@ def _fn_decl(
|
|
|
619
610
|
else resolve_literal(
|
|
620
611
|
return_type,
|
|
621
612
|
merge(
|
|
622
|
-
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(),
|
|
623
|
-
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(),
|
|
613
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("old", "old"))),
|
|
614
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("new", "new"))),
|
|
624
615
|
),
|
|
625
616
|
lambda: decls,
|
|
626
617
|
)
|
|
@@ -628,51 +619,40 @@ def _fn_decl(
|
|
|
628
619
|
decls |= merged
|
|
629
620
|
|
|
630
621
|
# defer this in generator so it doesn't resolve for builtins eagerly
|
|
631
|
-
args = (TypedExprDecl(tp.to_just(),
|
|
632
|
-
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
|
|
633
|
-
res_thunk: Callable[[], object]
|
|
634
|
-
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
|
|
635
|
-
if not ref:
|
|
636
|
-
tuple_args = tuple(args)
|
|
637
|
-
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
|
|
638
|
-
assert isinstance(res, RuntimeExpr)
|
|
639
|
-
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
|
|
640
|
-
decls._unnamed_functions.add(res_ref)
|
|
641
|
-
res_thunk = Thunk.value(res)
|
|
622
|
+
args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
|
|
642
623
|
|
|
624
|
+
return_type_is_eqsort = (
|
|
625
|
+
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
|
|
626
|
+
)
|
|
627
|
+
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
|
|
628
|
+
signature_ = FunctionSignature(
|
|
629
|
+
return_type=None if mutates_first_arg else return_type,
|
|
630
|
+
var_arg_type=var_arg_type,
|
|
631
|
+
arg_types=arg_types,
|
|
632
|
+
arg_names=arg_names,
|
|
633
|
+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
634
|
+
reverse_args=reverse_args,
|
|
635
|
+
)
|
|
636
|
+
decl: ConstructorDecl | FunctionDecl
|
|
637
|
+
if is_constructor:
|
|
638
|
+
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
|
|
643
639
|
else:
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
640
|
+
if cost is not None:
|
|
641
|
+
msg = "Cost can only be set for constructors"
|
|
642
|
+
raise ValueError(msg)
|
|
643
|
+
if unextractable:
|
|
644
|
+
msg = "Unextractable can only be set for constructors"
|
|
645
|
+
raise ValueError(msg)
|
|
646
|
+
decl = FunctionDecl(
|
|
647
|
+
signature=signature_,
|
|
648
|
+
egg_name=egg_name,
|
|
649
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
650
|
+
builtin=is_builtin,
|
|
655
651
|
)
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
if cost is not None:
|
|
661
|
-
msg = "Cost can only be set for constructors"
|
|
662
|
-
raise ValueError(msg)
|
|
663
|
-
if unextractable:
|
|
664
|
-
msg = "Unextractable can only be set for constructors"
|
|
665
|
-
raise ValueError(msg)
|
|
666
|
-
decl = FunctionDecl(
|
|
667
|
-
signature=signature_,
|
|
668
|
-
egg_name=egg_name,
|
|
669
|
-
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
670
|
-
builtin=is_builtin,
|
|
671
|
-
)
|
|
672
|
-
res_ref = ref
|
|
673
|
-
decls.set_function_decl(ref, decl)
|
|
674
|
-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
|
|
675
|
-
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
|
|
652
|
+
decls.set_function_decl(ref, decl)
|
|
653
|
+
return Thunk.fn(
|
|
654
|
+
_add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
|
|
655
|
+
)
|
|
676
656
|
|
|
677
657
|
|
|
678
658
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -707,7 +687,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
707
687
|
|
|
708
688
|
|
|
709
689
|
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
710
|
-
from .builtins import Unit
|
|
690
|
+
from .builtins import Unit # noqa: PLC0415
|
|
711
691
|
|
|
712
692
|
decls = Declarations()
|
|
713
693
|
decls |= cast("RuntimeClass", Unit)
|
|
@@ -746,13 +726,15 @@ def _constant_thunk(
|
|
|
746
726
|
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
747
727
|
|
|
748
728
|
|
|
749
|
-
def
|
|
729
|
+
def _add_default_rewrite_function(
|
|
750
730
|
decls: Declarations,
|
|
751
|
-
ref:
|
|
731
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
752
732
|
fn: Callable,
|
|
753
733
|
args: Iterable[TypedExprDecl],
|
|
754
734
|
ruleset: Ruleset | None,
|
|
755
|
-
|
|
735
|
+
subsume: bool,
|
|
736
|
+
res_type: TypeOrVarRef,
|
|
737
|
+
) -> None:
|
|
756
738
|
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
|
|
757
739
|
|
|
758
740
|
# If this is a classmethod, add the class as the first arg
|
|
@@ -760,21 +742,8 @@ def _create_default_value(
|
|
|
760
742
|
tp = decls.get_paramaterized_class(ref.class_name)
|
|
761
743
|
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
762
744
|
with set_current_ruleset(ruleset):
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
def _add_default_rewrite_function(
|
|
767
|
-
decls: Declarations,
|
|
768
|
-
ref: CallableRef,
|
|
769
|
-
res_type: TypeOrVarRef,
|
|
770
|
-
ruleset: Ruleset | None,
|
|
771
|
-
value_thunk: Callable[[], object],
|
|
772
|
-
subsume: bool,
|
|
773
|
-
) -> None:
|
|
774
|
-
"""
|
|
775
|
-
Helper functions that resolves a value thunk to create the default value.
|
|
776
|
-
"""
|
|
777
|
-
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
|
|
745
|
+
res = fn(*args)
|
|
746
|
+
_add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
|
|
778
747
|
|
|
779
748
|
|
|
780
749
|
def _add_default_rewrite(
|
|
@@ -794,6 +763,13 @@ def _add_default_rewrite(
|
|
|
794
763
|
return
|
|
795
764
|
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
|
|
796
765
|
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
|
|
766
|
+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
|
|
767
|
+
ruleset_decls |= resolved_value
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _add_default_rewrite_inner(
|
|
771
|
+
decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
|
|
772
|
+
) -> Declarations:
|
|
797
773
|
if ruleset:
|
|
798
774
|
ruleset_decls = ruleset._current_egg_decls
|
|
799
775
|
ruleset_decl = ruleset.__egg_ruleset__
|
|
@@ -801,7 +777,7 @@ def _add_default_rewrite(
|
|
|
801
777
|
ruleset_decls = decls
|
|
802
778
|
ruleset_decl = decls.default_ruleset
|
|
803
779
|
ruleset_decl.rules.append(rewrite_decl)
|
|
804
|
-
ruleset_decls
|
|
780
|
+
return ruleset_decls
|
|
805
781
|
|
|
806
782
|
|
|
807
783
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -882,6 +858,7 @@ class EGraph:
|
|
|
882
858
|
self._add_decls(decls)
|
|
883
859
|
return self._state.callable_ref_to_egg(ref)[0]
|
|
884
860
|
|
|
861
|
+
# TODO: Change let to be action...
|
|
885
862
|
def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
|
|
886
863
|
"""
|
|
887
864
|
Define a new expression in the egraph and return a reference to it.
|
|
@@ -893,38 +870,10 @@ class EGraph:
|
|
|
893
870
|
return cast(
|
|
894
871
|
"BASE_EXPR",
|
|
895
872
|
RuntimeExpr.__from_values__(
|
|
896
|
-
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp,
|
|
873
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, LetRefDecl(name))
|
|
897
874
|
),
|
|
898
875
|
)
|
|
899
876
|
|
|
900
|
-
@overload
|
|
901
|
-
def simplify(self, expr: BASE_EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> BASE_EXPR: ...
|
|
902
|
-
|
|
903
|
-
@overload
|
|
904
|
-
def simplify(self, expr: BASE_EXPR, schedule: Schedule, /) -> BASE_EXPR: ...
|
|
905
|
-
|
|
906
|
-
def simplify(
|
|
907
|
-
self, expr: BASE_EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
|
|
908
|
-
) -> BASE_EXPR:
|
|
909
|
-
"""
|
|
910
|
-
Simplifies the given expression.
|
|
911
|
-
"""
|
|
912
|
-
schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
|
|
913
|
-
del limit_or_schedule, until, ruleset
|
|
914
|
-
runtime_expr = to_runtime_expr(expr)
|
|
915
|
-
self._add_decls(runtime_expr, schedule)
|
|
916
|
-
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
917
|
-
typed_expr = runtime_expr.__egg_typed_expr__
|
|
918
|
-
# Must also register type
|
|
919
|
-
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
920
|
-
self._egraph.run_program(bindings.Simplify(span(1), egg_expr, egg_schedule))
|
|
921
|
-
extract_report = self._egraph.extract_report()
|
|
922
|
-
if not isinstance(extract_report, bindings.Best):
|
|
923
|
-
msg = "No extract report saved"
|
|
924
|
-
raise ValueError(msg) # noqa: TRY004
|
|
925
|
-
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
926
|
-
return cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
927
|
-
|
|
928
877
|
def include(self, path: str) -> None:
|
|
929
878
|
"""
|
|
930
879
|
Include a file of rules.
|
|
@@ -1036,9 +985,7 @@ class EGraph:
|
|
|
1036
985
|
self._add_decls(expr)
|
|
1037
986
|
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
|
|
1038
987
|
try:
|
|
1039
|
-
self._egraph.run_program(
|
|
1040
|
-
bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
|
|
1041
|
-
)
|
|
988
|
+
self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
|
|
1042
989
|
except BaseException as e:
|
|
1043
990
|
raise add_note("Extracting: " + str(expr), e) # noqa: B904
|
|
1044
991
|
extract_report = self._egraph.extract_report()
|
|
@@ -1138,9 +1085,9 @@ class EGraph:
|
|
|
1138
1085
|
|
|
1139
1086
|
If in IPython it will display it inline, otherwise it will write it to a file and open it.
|
|
1140
1087
|
"""
|
|
1141
|
-
from IPython.display import SVG, display
|
|
1088
|
+
from IPython.display import SVG, display # noqa: PLC0415
|
|
1142
1089
|
|
|
1143
|
-
from .visualizer_widget import VisualizerWidget
|
|
1090
|
+
from .visualizer_widget import VisualizerWidget # noqa: PLC0415
|
|
1144
1091
|
|
|
1145
1092
|
if graphviz:
|
|
1146
1093
|
if IN_IPYTHON:
|
|
@@ -1167,7 +1114,7 @@ class EGraph:
|
|
|
1167
1114
|
|
|
1168
1115
|
If an `expr` is passed, it's also extracted after each run and printed
|
|
1169
1116
|
"""
|
|
1170
|
-
from .visualizer_widget import VisualizerWidget
|
|
1117
|
+
from .visualizer_widget import VisualizerWidget # noqa: PLC0415
|
|
1171
1118
|
|
|
1172
1119
|
def to_json() -> str:
|
|
1173
1120
|
if expr is not None:
|
|
@@ -1559,16 +1506,17 @@ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _Ru
|
|
|
1559
1506
|
return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
|
|
1560
1507
|
|
|
1561
1508
|
|
|
1562
|
-
def var(name: str, bound: type[T]) -> T:
|
|
1509
|
+
def var(name: str, bound: type[T], egg_name: str | None = None) -> T:
|
|
1563
1510
|
"""Create a new variable with the given name and type."""
|
|
1564
|
-
return cast("T", _var(name, bound))
|
|
1511
|
+
return cast("T", _var(name, bound, egg_name=egg_name))
|
|
1565
1512
|
|
|
1566
1513
|
|
|
1567
|
-
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1514
|
+
def _var(name: str, bound: object, egg_name: str | None) -> RuntimeExpr:
|
|
1568
1515
|
"""Create a new variable with the given name and type."""
|
|
1569
1516
|
decls_like, type_ref = resolve_type_annotation(bound)
|
|
1570
1517
|
return RuntimeExpr(
|
|
1571
|
-
Thunk.fn(Declarations.create, decls_like),
|
|
1518
|
+
Thunk.fn(Declarations.create, decls_like),
|
|
1519
|
+
Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name, egg_name))),
|
|
1572
1520
|
)
|
|
1573
1521
|
|
|
1574
1522
|
|
|
@@ -1659,7 +1607,7 @@ class _NeBuilder(Generic[BASE_EXPR]):
|
|
|
1659
1607
|
lhs: BASE_EXPR
|
|
1660
1608
|
|
|
1661
1609
|
def to(self, rhs: BASE_EXPR) -> Unit:
|
|
1662
|
-
from .builtins import Unit
|
|
1610
|
+
from .builtins import Unit # noqa: PLC0415
|
|
1663
1611
|
|
|
1664
1612
|
lhs = to_runtime_expr(self.lhs)
|
|
1665
1613
|
rhs = convert_to_same_type(rhs, lhs)
|
|
@@ -1818,7 +1766,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1818
1766
|
# python/tests/test_no_import_star.py::test_no_import_star_rulesset
|
|
1819
1767
|
combined = {**gen.__globals__, **frame.f_locals}
|
|
1820
1768
|
hints = get_type_hints(gen, combined, combined)
|
|
1821
|
-
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1769
|
+
args = [_var(p.name, hints[p.name], egg_name=None) for p in signature(gen).parameters.values()]
|
|
1822
1770
|
return list(gen(*args)) # type: ignore[misc]
|
|
1823
1771
|
|
|
1824
1772
|
|
|
@@ -1844,7 +1792,7 @@ def get_current_ruleset() -> Ruleset | None:
|
|
|
1844
1792
|
|
|
1845
1793
|
@contextlib.contextmanager
|
|
1846
1794
|
def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
1847
|
-
token = _CURRENT_RULESET.set(r)
|
|
1795
|
+
token: Token[Ruleset | None] = _CURRENT_RULESET.set(r)
|
|
1848
1796
|
try:
|
|
1849
1797
|
yield
|
|
1850
1798
|
finally:
|
egglog/egraph_state.py
CHANGED
|
@@ -108,7 +108,7 @@ class EGraphState:
|
|
|
108
108
|
case RulesetDecl(rules):
|
|
109
109
|
if name not in self.rulesets:
|
|
110
110
|
if name:
|
|
111
|
-
self.egraph.run_program(bindings.AddRuleset(name))
|
|
111
|
+
self.egraph.run_program(bindings.AddRuleset(span(), name))
|
|
112
112
|
added_rules = self.rulesets[name] = set()
|
|
113
113
|
else:
|
|
114
114
|
added_rules = self.rulesets[name]
|
|
@@ -125,7 +125,7 @@ class EGraphState:
|
|
|
125
125
|
self.rulesets[name] = set()
|
|
126
126
|
for ruleset in rulesets:
|
|
127
127
|
self.ruleset_to_egg(ruleset)
|
|
128
|
-
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
|
|
128
|
+
self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), name, list(rulesets)))
|
|
129
129
|
|
|
130
130
|
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
|
|
131
131
|
match cmd:
|
|
@@ -160,7 +160,7 @@ class EGraphState:
|
|
|
160
160
|
assert isinstance(sig, FunctionSignature)
|
|
161
161
|
# Replace args with rule_var_name mapping
|
|
162
162
|
arg_mapping = tuple(
|
|
163
|
-
TypedExprDecl(tp.to_just(),
|
|
163
|
+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name))
|
|
164
164
|
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
165
165
|
)
|
|
166
166
|
rewrite_decl = RewriteDecl(
|
|
@@ -179,7 +179,7 @@ class EGraphState:
|
|
|
179
179
|
def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
|
|
180
180
|
match action:
|
|
181
181
|
case LetDecl(name, typed_expr):
|
|
182
|
-
var_decl =
|
|
182
|
+
var_decl = LetRefDecl(name)
|
|
183
183
|
var_egg = self._expr_to_egg(var_decl)
|
|
184
184
|
self.expr_to_egg_cache[var_decl] = var_egg
|
|
185
185
|
return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
@@ -369,7 +369,8 @@ class EGraphState:
|
|
|
369
369
|
"""
|
|
370
370
|
Rewrites this expression as a let binding if it's not already a let binding.
|
|
371
371
|
"""
|
|
372
|
-
|
|
372
|
+
# TODO: Replace with counter so that it works with hash collisions and is more stable
|
|
373
|
+
var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}")
|
|
373
374
|
if var_decl in self.expr_to_egg_cache:
|
|
374
375
|
return None
|
|
375
376
|
var_egg = self._expr_to_egg(var_decl)
|
|
@@ -387,7 +388,7 @@ class EGraphState:
|
|
|
387
388
|
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
388
389
|
|
|
389
390
|
@overload
|
|
390
|
-
def _expr_to_egg(self, expr_decl:
|
|
391
|
+
def _expr_to_egg(self, expr_decl: UnboundVarDecl | LetRefDecl) -> bindings.Var: ...
|
|
391
392
|
|
|
392
393
|
@overload
|
|
393
394
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
@@ -402,11 +403,10 @@ class EGraphState:
|
|
|
402
403
|
pass
|
|
403
404
|
res: bindings._Expr
|
|
404
405
|
match expr_decl:
|
|
405
|
-
case
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
res = bindings.Var(span(), name)
|
|
406
|
+
case LetRefDecl(name):
|
|
407
|
+
res = bindings.Var(span(), f"{name}")
|
|
408
|
+
case UnboundVarDecl(name, egg_name):
|
|
409
|
+
res = bindings.Var(span(), egg_name or f"_{name}")
|
|
410
410
|
case LitDecl(value):
|
|
411
411
|
l: bindings._Literal
|
|
412
412
|
match value:
|
|
@@ -467,7 +467,8 @@ class EGraphState:
|
|
|
467
467
|
return name
|
|
468
468
|
|
|
469
469
|
case ConstantRef(name):
|
|
470
|
-
|
|
470
|
+
# Prefix to avoid name collisions with local vars
|
|
471
|
+
return f"%{name}"
|
|
471
472
|
case (
|
|
472
473
|
MethodRef(cls_name, name)
|
|
473
474
|
| ClassMethodRef(cls_name, name)
|
|
@@ -549,7 +550,7 @@ class FromEggState:
|
|
|
549
550
|
"""
|
|
550
551
|
expr_decl: ExprDecl
|
|
551
552
|
if isinstance(term, bindings.TermVar):
|
|
552
|
-
expr_decl =
|
|
553
|
+
expr_decl = LetRefDecl(term.name)
|
|
553
554
|
elif isinstance(term, bindings.TermLit):
|
|
554
555
|
value = term.value
|
|
555
556
|
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
@@ -624,7 +625,9 @@ class FromEggState:
|
|
|
624
625
|
# but dont need to store them
|
|
625
626
|
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
|
|
626
627
|
)
|
|
627
|
-
raise ValueError(
|
|
628
|
+
raise ValueError(
|
|
629
|
+
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
|
|
630
|
+
)
|
|
628
631
|
|
|
629
632
|
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
630
633
|
try:
|
egglog/examples/bignum.py
CHANGED
egglog/examples/multiset.py
CHANGED
|
@@ -32,7 +32,7 @@ egraph.register(xs)
|
|
|
32
32
|
egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
|
|
33
33
|
egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
|
|
34
34
|
|
|
35
|
-
assert Counter(egraph.extract(xs).
|
|
35
|
+
assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
@@ -45,7 +45,7 @@ assert Math(4) not in xs
|
|
|
45
45
|
|
|
46
46
|
egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
|
|
47
47
|
|
|
48
|
-
assert egraph.extract(xs.length()).
|
|
48
|
+
assert egraph.extract(xs.length()).value == 3
|
|
49
49
|
assert len(xs) == 3
|
|
50
50
|
|
|
51
51
|
egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
|
egglog/exp/array_api.py
CHANGED
|
@@ -154,6 +154,18 @@ class Int(Expr, ruleset=array_api_ruleset):
|
|
|
154
154
|
def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
|
|
155
155
|
...
|
|
156
156
|
|
|
157
|
+
# add a hash so that this test can pass
|
|
158
|
+
# https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
|
|
159
|
+
@method(preserve=True)
|
|
160
|
+
def __hash__(self) -> int:
|
|
161
|
+
egraph = _get_current_egraph()
|
|
162
|
+
egraph.register(self)
|
|
163
|
+
egraph.run(array_api_schedule)
|
|
164
|
+
simplified = egraph.extract(self)
|
|
165
|
+
return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
|
|
166
|
+
|
|
167
|
+
def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...
|
|
168
|
+
|
|
157
169
|
# TODO: Fix this?
|
|
158
170
|
# Make != always return a Bool, so that numpy.unique works on a tuple of ints
|
|
159
171
|
# In _unique1d
|
|
@@ -280,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
|
|
|
280
292
|
yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
|
|
281
293
|
yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
|
|
282
294
|
|
|
295
|
+
yield rewrite(o.__round__(OptionalInt.none)).to(o)
|
|
296
|
+
|
|
283
297
|
# Never cannot be equal to anything real
|
|
284
298
|
yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
|
|
285
299
|
|
|
@@ -354,8 +368,14 @@ class Float(Expr, ruleset=array_api_ruleset):
|
|
|
354
368
|
def __sub__(self, other: FloatLike) -> Float: ...
|
|
355
369
|
|
|
356
370
|
def __pow__(self, other: FloatLike) -> Float: ...
|
|
371
|
+
def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
|
|
357
372
|
|
|
358
373
|
def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
|
|
374
|
+
def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
|
|
375
|
+
def __lt__(self, other: FloatLike) -> Boolean: ...
|
|
376
|
+
def __le__(self, other: FloatLike) -> Boolean: ...
|
|
377
|
+
def __gt__(self, other: FloatLike) -> Boolean: ...
|
|
378
|
+
def __ge__(self, other: FloatLike) -> Boolean: ...
|
|
359
379
|
|
|
360
380
|
|
|
361
381
|
converter(float, Float, lambda x: Float(x))
|
|
@@ -366,9 +386,10 @@ FloatLike: TypeAlias = Float | float | IntLike
|
|
|
366
386
|
|
|
367
387
|
|
|
368
388
|
@array_api_ruleset.register
|
|
369
|
-
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
|
|
389
|
+
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
|
|
370
390
|
return [
|
|
371
391
|
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
|
|
392
|
+
rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
|
|
372
393
|
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
|
|
373
394
|
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
|
|
374
395
|
# Convert from float to rationl, if its a whole number i.e. can be converted to int
|
|
@@ -383,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
|
|
|
383
404
|
rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
|
|
384
405
|
rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
|
|
385
406
|
rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
|
|
386
|
-
#
|
|
407
|
+
# comparisons
|
|
387
408
|
rewrite(Float(f) == Float(f)).to(TRUE),
|
|
388
409
|
rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
|
|
410
|
+
rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
|
|
411
|
+
rewrite(Float(f) != Float(f)).to(FALSE),
|
|
412
|
+
rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
|
|
413
|
+
rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
|
|
414
|
+
rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
|
|
415
|
+
rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
|
|
416
|
+
rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
|
|
417
|
+
rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
|
|
418
|
+
rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
|
|
389
419
|
rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
|
|
390
420
|
rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
|
|
421
|
+
# round
|
|
422
|
+
rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
|
|
391
423
|
]
|
|
392
424
|
|
|
393
425
|
|
|
@@ -671,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
|
|
|
671
703
|
def some(cls, value: Int) -> OptionalInt: ...
|
|
672
704
|
|
|
673
705
|
|
|
706
|
+
OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
|
|
707
|
+
|
|
674
708
|
converter(type(None), OptionalInt, lambda _: OptionalInt.none)
|
|
675
709
|
converter(Int, OptionalInt, OptionalInt.some)
|
|
676
710
|
|
|
@@ -1982,4 +2016,4 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
|
|
|
1982
2016
|
except BaseException as e:
|
|
1983
2017
|
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1984
2018
|
raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
|
|
1985
|
-
return extracted.
|
|
2019
|
+
return extracted.value # type: ignore[attr-defined]
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -14,16 +14,22 @@ from .program_gen import Program
|
|
|
14
14
|
X = TypeVar("X", bound=Callable)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def jit(
|
|
17
|
+
def jit(
|
|
18
|
+
fn: X,
|
|
19
|
+
*,
|
|
20
|
+
handle_expr: Callable[[NDArray], None] | None = None,
|
|
21
|
+
handle_optimized_expr: Callable[[NDArray], None] | None = None,
|
|
22
|
+
) -> X:
|
|
18
23
|
"""
|
|
19
24
|
Jit compiles a function
|
|
20
25
|
"""
|
|
21
26
|
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
|
|
27
|
+
if handle_expr:
|
|
28
|
+
handle_expr(res)
|
|
29
|
+
if handle_optimized_expr:
|
|
30
|
+
handle_optimized_expr(res_optimized)
|
|
22
31
|
fn_program = EvalProgram(program, {"np": np})
|
|
23
|
-
|
|
24
|
-
fn.initial_expr = res # type: ignore[attr-defined]
|
|
25
|
-
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
26
|
-
return fn
|
|
32
|
+
return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
|