egglog 10.0.1__cp310-cp310-win_amd64.whl → 11.0.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +231 -182
- egglog/conversion.py +61 -43
- egglog/declarations.py +104 -18
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +86 -144
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +46 -12
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/exp/program_gen.py +4 -3
- egglog/pretty.py +11 -25
- egglog/runtime.py +203 -151
- egglog/thunk.py +6 -4
- egglog/type_constraint_solver.py +1 -1
- egglog/version_compat.py +87 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/METADATA +1 -1
- egglog-11.0.0.dist-info/RECORD +45 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- egglog-10.0.1.dist-info/RECORD +0 -44
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -5,7 +5,7 @@ import inspect
|
|
|
5
5
|
import pathlib
|
|
6
6
|
import tempfile
|
|
7
7
|
from collections.abc import Callable, Generator, Iterable
|
|
8
|
-
from contextvars import ContextVar
|
|
8
|
+
from contextvars import ContextVar, Token
|
|
9
9
|
from dataclasses import InitVar, dataclass, field
|
|
10
10
|
from functools import partial
|
|
11
11
|
from inspect import Parameter, currentframe, signature
|
|
@@ -16,7 +16,6 @@ from typing import (
|
|
|
16
16
|
ClassVar,
|
|
17
17
|
Generic,
|
|
18
18
|
Literal,
|
|
19
|
-
Never,
|
|
20
19
|
TypeAlias,
|
|
21
20
|
TypedDict,
|
|
22
21
|
TypeVar,
|
|
@@ -26,16 +25,18 @@ from typing import (
|
|
|
26
25
|
)
|
|
27
26
|
|
|
28
27
|
import graphviz
|
|
29
|
-
from typing_extensions import ParamSpec, Self, Unpack, assert_never
|
|
28
|
+
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
|
|
30
29
|
|
|
31
30
|
from . import bindings
|
|
32
31
|
from .conversion import *
|
|
32
|
+
from .conversion import convert_to_same_type, resolve_literal
|
|
33
33
|
from .declarations import *
|
|
34
34
|
from .egraph_state import *
|
|
35
35
|
from .ipython_magic import IN_IPYTHON
|
|
36
36
|
from .pretty import pretty_decl
|
|
37
37
|
from .runtime import *
|
|
38
38
|
from .thunk import *
|
|
39
|
+
from .version_compat import *
|
|
39
40
|
|
|
40
41
|
if TYPE_CHECKING:
|
|
41
42
|
from .builtins import String, Unit
|
|
@@ -82,7 +83,6 @@ __all__ = [
|
|
|
82
83
|
"run",
|
|
83
84
|
"seq",
|
|
84
85
|
"set_",
|
|
85
|
-
"simplify",
|
|
86
86
|
"subsume",
|
|
87
87
|
"union",
|
|
88
88
|
"unstable_combine_rulesets",
|
|
@@ -110,12 +110,12 @@ IGNORED_ATTRIBUTES = {
|
|
|
110
110
|
"__weakref__",
|
|
111
111
|
"__orig_bases__",
|
|
112
112
|
"__annotations__",
|
|
113
|
-
"__hash__",
|
|
114
113
|
"__qualname__",
|
|
115
114
|
"__firstlineno__",
|
|
116
115
|
"__static_attributes__",
|
|
116
|
+
"__match_args__",
|
|
117
117
|
# Ignore all reflected binary method
|
|
118
|
-
*
|
|
118
|
+
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
|
|
119
119
|
}
|
|
120
120
|
|
|
121
121
|
|
|
@@ -139,15 +139,6 @@ ALWAYS_PRESERVED = {
|
|
|
139
139
|
}
|
|
140
140
|
|
|
141
141
|
|
|
142
|
-
def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
|
|
143
|
-
"""
|
|
144
|
-
Simplify an expression by running the schedule.
|
|
145
|
-
"""
|
|
146
|
-
if schedule:
|
|
147
|
-
return EGraph().simplify(x, schedule)
|
|
148
|
-
return EGraph().extract(x)
|
|
149
|
-
|
|
150
|
-
|
|
151
142
|
def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph:
|
|
152
143
|
"""
|
|
153
144
|
Verifies that two expressions are equal after running the schedule.
|
|
@@ -169,8 +160,9 @@ def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, ad
|
|
|
169
160
|
except bindings.EggSmolError as err:
|
|
170
161
|
if display:
|
|
171
162
|
egraph.display()
|
|
172
|
-
|
|
173
|
-
|
|
163
|
+
raise add_note(
|
|
164
|
+
f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})", err
|
|
165
|
+
) from None
|
|
174
166
|
return egraph
|
|
175
167
|
|
|
176
168
|
|
|
@@ -290,7 +282,6 @@ def function(
|
|
|
290
282
|
mutates_first_arg: bool = ...,
|
|
291
283
|
unextractable: bool = ...,
|
|
292
284
|
ruleset: Ruleset | None = ...,
|
|
293
|
-
use_body_as_name: bool = ...,
|
|
294
285
|
subsume: bool = ...,
|
|
295
286
|
) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
|
|
296
287
|
|
|
@@ -370,6 +361,7 @@ class BaseExpr(metaclass=_ExprMetaclass):
|
|
|
370
361
|
|
|
371
362
|
def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body]
|
|
372
363
|
|
|
364
|
+
# not currently dissalowing other types of equality https://github.com/python/typeshed/issues/8217#issuecomment-3140873292
|
|
373
365
|
def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body]
|
|
374
366
|
|
|
375
367
|
|
|
@@ -403,7 +395,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
403
395
|
)
|
|
404
396
|
type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters)
|
|
405
397
|
del parameters
|
|
406
|
-
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
398
|
+
cls_decl = ClassDecl(egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ()))
|
|
407
399
|
decls = Declarations(_classes={cls_name: cls_decl})
|
|
408
400
|
# Update class think eagerly when resolving so that lookups work in methods
|
|
409
401
|
runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
|
|
@@ -455,6 +447,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
455
447
|
continue
|
|
456
448
|
locals = frame.f_locals
|
|
457
449
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
450
|
+
# TODO: Store deprecated message so we can print at runtime
|
|
451
|
+
if (getattr(fn, "__deprecated__", None)) is not None:
|
|
452
|
+
fn = fn.__wrapped__ # type: ignore[attr-defined]
|
|
458
453
|
match fn:
|
|
459
454
|
case classmethod():
|
|
460
455
|
ref = ClassMethodRef(cls_name, method_name)
|
|
@@ -476,7 +471,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
476
471
|
decls.set_function_decl(ref, decl)
|
|
477
472
|
continue
|
|
478
473
|
try:
|
|
479
|
-
|
|
474
|
+
add_rewrite = _fn_decl(
|
|
480
475
|
decls,
|
|
481
476
|
egg_fn,
|
|
482
477
|
ref,
|
|
@@ -492,8 +487,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
492
487
|
reverse_args=reverse_args,
|
|
493
488
|
)
|
|
494
489
|
except Exception as e:
|
|
495
|
-
|
|
496
|
-
raise
|
|
490
|
+
raise add_note(f"Error processing {cls_name}.{method_name}", e) from None
|
|
497
491
|
|
|
498
492
|
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
499
493
|
add_default_funcs.append(add_rewrite)
|
|
@@ -515,7 +509,6 @@ class _FunctionConstructor:
|
|
|
515
509
|
merge: Callable[[object, object], object] | None = None
|
|
516
510
|
unextractable: bool = False
|
|
517
511
|
ruleset: Ruleset | None = None
|
|
518
|
-
use_body_as_name: bool = False
|
|
519
512
|
subsume: bool = False
|
|
520
513
|
|
|
521
514
|
def __call__(self, fn: Callable) -> RuntimeFunction:
|
|
@@ -523,11 +516,10 @@ class _FunctionConstructor:
|
|
|
523
516
|
|
|
524
517
|
def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
|
|
525
518
|
decls = Declarations()
|
|
526
|
-
|
|
527
|
-
ref, add_rewrite = _fn_decl(
|
|
519
|
+
add_rewrite = _fn_decl(
|
|
528
520
|
decls,
|
|
529
521
|
self.egg_fn,
|
|
530
|
-
ref,
|
|
522
|
+
ref := FunctionRef(fn.__name__),
|
|
531
523
|
fn,
|
|
532
524
|
self.hint_locals,
|
|
533
525
|
self.cost,
|
|
@@ -545,8 +537,7 @@ class _FunctionConstructor:
|
|
|
545
537
|
def _fn_decl(
|
|
546
538
|
decls: Declarations,
|
|
547
539
|
egg_name: str | None,
|
|
548
|
-
|
|
549
|
-
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
|
|
540
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
550
541
|
fn: object,
|
|
551
542
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
552
543
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -559,7 +550,7 @@ def _fn_decl(
|
|
|
559
550
|
ruleset: Ruleset | None = None,
|
|
560
551
|
unextractable: bool = False,
|
|
561
552
|
reverse_args: bool = False,
|
|
562
|
-
) ->
|
|
553
|
+
) -> Callable[[], None]:
|
|
563
554
|
"""
|
|
564
555
|
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
|
|
565
556
|
"""
|
|
@@ -569,16 +560,11 @@ def _fn_decl(
|
|
|
569
560
|
if not isinstance(fn, FunctionType):
|
|
570
561
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
571
562
|
|
|
572
|
-
hint_globals = fn.__globals__.copy()
|
|
573
|
-
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
574
|
-
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
575
|
-
if "Callable" not in hint_globals:
|
|
576
|
-
hint_globals["Callable"] = Callable
|
|
577
563
|
# Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
|
|
578
564
|
# won't be resolved correctly
|
|
579
565
|
# We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
|
|
580
566
|
# https://github.com/egraphs-good/egglog-python/issues/210
|
|
581
|
-
hint_globals.
|
|
567
|
+
hint_globals = {**fn.__globals__, **hint_locals}
|
|
582
568
|
hints = get_type_hints(fn, hint_globals)
|
|
583
569
|
|
|
584
570
|
params = list(signature(fn).parameters.values())
|
|
@@ -624,60 +610,49 @@ def _fn_decl(
|
|
|
624
610
|
else resolve_literal(
|
|
625
611
|
return_type,
|
|
626
612
|
merge(
|
|
627
|
-
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(),
|
|
628
|
-
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(),
|
|
613
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("old", "old"))),
|
|
614
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), UnboundVarDecl("new", "new"))),
|
|
629
615
|
),
|
|
630
616
|
lambda: decls,
|
|
631
617
|
)
|
|
632
618
|
)
|
|
633
619
|
decls |= merged
|
|
634
620
|
|
|
635
|
-
# defer this in generator so it
|
|
636
|
-
args = (TypedExprDecl(tp.to_just(),
|
|
637
|
-
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
|
|
638
|
-
res_thunk: Callable[[], object]
|
|
639
|
-
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
|
|
640
|
-
if not ref:
|
|
641
|
-
tuple_args = tuple(args)
|
|
642
|
-
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
|
|
643
|
-
assert isinstance(res, RuntimeExpr)
|
|
644
|
-
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
|
|
645
|
-
decls._unnamed_functions.add(res_ref)
|
|
646
|
-
res_thunk = Thunk.value(res)
|
|
621
|
+
# defer this in generator so it doesn't resolve for builtins eagerly
|
|
622
|
+
args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
|
|
647
623
|
|
|
624
|
+
return_type_is_eqsort = (
|
|
625
|
+
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
|
|
626
|
+
)
|
|
627
|
+
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
|
|
628
|
+
signature_ = FunctionSignature(
|
|
629
|
+
return_type=None if mutates_first_arg else return_type,
|
|
630
|
+
var_arg_type=var_arg_type,
|
|
631
|
+
arg_types=arg_types,
|
|
632
|
+
arg_names=arg_names,
|
|
633
|
+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
634
|
+
reverse_args=reverse_args,
|
|
635
|
+
)
|
|
636
|
+
decl: ConstructorDecl | FunctionDecl
|
|
637
|
+
if is_constructor:
|
|
638
|
+
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
|
|
648
639
|
else:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
640
|
+
if cost is not None:
|
|
641
|
+
msg = "Cost can only be set for constructors"
|
|
642
|
+
raise ValueError(msg)
|
|
643
|
+
if unextractable:
|
|
644
|
+
msg = "Unextractable can only be set for constructors"
|
|
645
|
+
raise ValueError(msg)
|
|
646
|
+
decl = FunctionDecl(
|
|
647
|
+
signature=signature_,
|
|
648
|
+
egg_name=egg_name,
|
|
649
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
650
|
+
builtin=is_builtin,
|
|
660
651
|
)
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
if cost is not None:
|
|
666
|
-
msg = "Cost can only be set for constructors"
|
|
667
|
-
raise ValueError(msg)
|
|
668
|
-
if unextractable:
|
|
669
|
-
msg = "Unextractable can only be set for constructors"
|
|
670
|
-
raise ValueError(msg)
|
|
671
|
-
decl = FunctionDecl(
|
|
672
|
-
signature=signature_,
|
|
673
|
-
egg_name=egg_name,
|
|
674
|
-
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
675
|
-
builtin=is_builtin,
|
|
676
|
-
)
|
|
677
|
-
res_ref = ref
|
|
678
|
-
decls.set_function_decl(ref, decl)
|
|
679
|
-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
|
|
680
|
-
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
|
|
652
|
+
decls.set_function_decl(ref, decl)
|
|
653
|
+
return Thunk.fn(
|
|
654
|
+
_add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
|
|
655
|
+
)
|
|
681
656
|
|
|
682
657
|
|
|
683
658
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -712,7 +687,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
712
687
|
|
|
713
688
|
|
|
714
689
|
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
715
|
-
from .builtins import Unit
|
|
690
|
+
from .builtins import Unit # noqa: PLC0415
|
|
716
691
|
|
|
717
692
|
decls = Declarations()
|
|
718
693
|
decls |= cast("RuntimeClass", Unit)
|
|
@@ -751,13 +726,15 @@ def _constant_thunk(
|
|
|
751
726
|
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
752
727
|
|
|
753
728
|
|
|
754
|
-
def
|
|
729
|
+
def _add_default_rewrite_function(
|
|
755
730
|
decls: Declarations,
|
|
756
|
-
ref:
|
|
731
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
757
732
|
fn: Callable,
|
|
758
733
|
args: Iterable[TypedExprDecl],
|
|
759
734
|
ruleset: Ruleset | None,
|
|
760
|
-
|
|
735
|
+
subsume: bool,
|
|
736
|
+
res_type: TypeOrVarRef,
|
|
737
|
+
) -> None:
|
|
761
738
|
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
|
|
762
739
|
|
|
763
740
|
# If this is a classmethod, add the class as the first arg
|
|
@@ -765,21 +742,8 @@ def _create_default_value(
|
|
|
765
742
|
tp = decls.get_paramaterized_class(ref.class_name)
|
|
766
743
|
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
767
744
|
with set_current_ruleset(ruleset):
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
def _add_default_rewrite_function(
|
|
772
|
-
decls: Declarations,
|
|
773
|
-
ref: CallableRef,
|
|
774
|
-
res_type: TypeOrVarRef,
|
|
775
|
-
ruleset: Ruleset | None,
|
|
776
|
-
value_thunk: Callable[[], object],
|
|
777
|
-
subsume: bool,
|
|
778
|
-
) -> None:
|
|
779
|
-
"""
|
|
780
|
-
Helper functions that resolves a value thunk to create the default value.
|
|
781
|
-
"""
|
|
782
|
-
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
|
|
745
|
+
res = fn(*args)
|
|
746
|
+
_add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
|
|
783
747
|
|
|
784
748
|
|
|
785
749
|
def _add_default_rewrite(
|
|
@@ -799,6 +763,13 @@ def _add_default_rewrite(
|
|
|
799
763
|
return
|
|
800
764
|
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
|
|
801
765
|
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
|
|
766
|
+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
|
|
767
|
+
ruleset_decls |= resolved_value
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _add_default_rewrite_inner(
|
|
771
|
+
decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
|
|
772
|
+
) -> Declarations:
|
|
802
773
|
if ruleset:
|
|
803
774
|
ruleset_decls = ruleset._current_egg_decls
|
|
804
775
|
ruleset_decl = ruleset.__egg_ruleset__
|
|
@@ -806,7 +777,7 @@ def _add_default_rewrite(
|
|
|
806
777
|
ruleset_decls = decls
|
|
807
778
|
ruleset_decl = decls.default_ruleset
|
|
808
779
|
ruleset_decl.rules.append(rewrite_decl)
|
|
809
|
-
ruleset_decls
|
|
780
|
+
return ruleset_decls
|
|
810
781
|
|
|
811
782
|
|
|
812
783
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -887,6 +858,7 @@ class EGraph:
|
|
|
887
858
|
self._add_decls(decls)
|
|
888
859
|
return self._state.callable_ref_to_egg(ref)[0]
|
|
889
860
|
|
|
861
|
+
# TODO: Change let to be action...
|
|
890
862
|
def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
|
|
891
863
|
"""
|
|
892
864
|
Define a new expression in the egraph and return a reference to it.
|
|
@@ -898,38 +870,10 @@ class EGraph:
|
|
|
898
870
|
return cast(
|
|
899
871
|
"BASE_EXPR",
|
|
900
872
|
RuntimeExpr.__from_values__(
|
|
901
|
-
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp,
|
|
873
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, LetRefDecl(name))
|
|
902
874
|
),
|
|
903
875
|
)
|
|
904
876
|
|
|
905
|
-
@overload
|
|
906
|
-
def simplify(self, expr: BASE_EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> BASE_EXPR: ...
|
|
907
|
-
|
|
908
|
-
@overload
|
|
909
|
-
def simplify(self, expr: BASE_EXPR, schedule: Schedule, /) -> BASE_EXPR: ...
|
|
910
|
-
|
|
911
|
-
def simplify(
|
|
912
|
-
self, expr: BASE_EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
|
|
913
|
-
) -> BASE_EXPR:
|
|
914
|
-
"""
|
|
915
|
-
Simplifies the given expression.
|
|
916
|
-
"""
|
|
917
|
-
schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
|
|
918
|
-
del limit_or_schedule, until, ruleset
|
|
919
|
-
runtime_expr = to_runtime_expr(expr)
|
|
920
|
-
self._add_decls(runtime_expr, schedule)
|
|
921
|
-
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
922
|
-
typed_expr = runtime_expr.__egg_typed_expr__
|
|
923
|
-
# Must also register type
|
|
924
|
-
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
925
|
-
self._egraph.run_program(bindings.Simplify(span(1), egg_expr, egg_schedule))
|
|
926
|
-
extract_report = self._egraph.extract_report()
|
|
927
|
-
if not isinstance(extract_report, bindings.Best):
|
|
928
|
-
msg = "No extract report saved"
|
|
929
|
-
raise ValueError(msg) # noqa: TRY004
|
|
930
|
-
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
931
|
-
return cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
932
|
-
|
|
933
877
|
def include(self, path: str) -> None:
|
|
934
878
|
"""
|
|
935
879
|
Include a file of rules.
|
|
@@ -1041,12 +985,9 @@ class EGraph:
|
|
|
1041
985
|
self._add_decls(expr)
|
|
1042
986
|
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
|
|
1043
987
|
try:
|
|
1044
|
-
self._egraph.run_program(
|
|
1045
|
-
bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
|
|
1046
|
-
)
|
|
988
|
+
self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
|
|
1047
989
|
except BaseException as e:
|
|
1048
|
-
|
|
1049
|
-
raise
|
|
990
|
+
raise add_note("Extracting: " + str(expr), e) # noqa: B904
|
|
1050
991
|
extract_report = self._egraph.extract_report()
|
|
1051
992
|
if not extract_report:
|
|
1052
993
|
msg = "No extract report saved"
|
|
@@ -1144,9 +1085,9 @@ class EGraph:
|
|
|
1144
1085
|
|
|
1145
1086
|
If in IPython it will display it inline, otherwise it will write it to a file and open it.
|
|
1146
1087
|
"""
|
|
1147
|
-
from IPython.display import SVG, display
|
|
1088
|
+
from IPython.display import SVG, display # noqa: PLC0415
|
|
1148
1089
|
|
|
1149
|
-
from .visualizer_widget import VisualizerWidget
|
|
1090
|
+
from .visualizer_widget import VisualizerWidget # noqa: PLC0415
|
|
1150
1091
|
|
|
1151
1092
|
if graphviz:
|
|
1152
1093
|
if IN_IPYTHON:
|
|
@@ -1173,7 +1114,7 @@ class EGraph:
|
|
|
1173
1114
|
|
|
1174
1115
|
If an `expr` is passed, it's also extracted after each run and printed
|
|
1175
1116
|
"""
|
|
1176
|
-
from .visualizer_widget import VisualizerWidget
|
|
1117
|
+
from .visualizer_widget import VisualizerWidget # noqa: PLC0415
|
|
1177
1118
|
|
|
1178
1119
|
def to_json() -> str:
|
|
1179
1120
|
if expr is not None:
|
|
@@ -1565,16 +1506,17 @@ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _Ru
|
|
|
1565
1506
|
return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
|
|
1566
1507
|
|
|
1567
1508
|
|
|
1568
|
-
def var(name: str, bound: type[T]) -> T:
|
|
1509
|
+
def var(name: str, bound: type[T], egg_name: str | None = None) -> T:
|
|
1569
1510
|
"""Create a new variable with the given name and type."""
|
|
1570
|
-
return cast("T", _var(name, bound))
|
|
1511
|
+
return cast("T", _var(name, bound, egg_name=egg_name))
|
|
1571
1512
|
|
|
1572
1513
|
|
|
1573
|
-
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1514
|
+
def _var(name: str, bound: object, egg_name: str | None) -> RuntimeExpr:
|
|
1574
1515
|
"""Create a new variable with the given name and type."""
|
|
1575
1516
|
decls_like, type_ref = resolve_type_annotation(bound)
|
|
1576
1517
|
return RuntimeExpr(
|
|
1577
|
-
Thunk.fn(Declarations.create, decls_like),
|
|
1518
|
+
Thunk.fn(Declarations.create, decls_like),
|
|
1519
|
+
Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name, egg_name))),
|
|
1578
1520
|
)
|
|
1579
1521
|
|
|
1580
1522
|
|
|
@@ -1665,7 +1607,7 @@ class _NeBuilder(Generic[BASE_EXPR]):
|
|
|
1665
1607
|
lhs: BASE_EXPR
|
|
1666
1608
|
|
|
1667
1609
|
def to(self, rhs: BASE_EXPR) -> Unit:
|
|
1668
|
-
from .builtins import Unit
|
|
1610
|
+
from .builtins import Unit # noqa: PLC0415
|
|
1669
1611
|
|
|
1670
1612
|
lhs = to_runtime_expr(self.lhs)
|
|
1671
1613
|
rhs = convert_to_same_type(rhs, lhs)
|
|
@@ -1824,7 +1766,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1824
1766
|
# python/tests/test_no_import_star.py::test_no_import_star_rulesset
|
|
1825
1767
|
combined = {**gen.__globals__, **frame.f_locals}
|
|
1826
1768
|
hints = get_type_hints(gen, combined, combined)
|
|
1827
|
-
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1769
|
+
args = [_var(p.name, hints[p.name], egg_name=None) for p in signature(gen).parameters.values()]
|
|
1828
1770
|
return list(gen(*args)) # type: ignore[misc]
|
|
1829
1771
|
|
|
1830
1772
|
|
|
@@ -1850,7 +1792,7 @@ def get_current_ruleset() -> Ruleset | None:
|
|
|
1850
1792
|
|
|
1851
1793
|
@contextlib.contextmanager
|
|
1852
1794
|
def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
1853
|
-
token = _CURRENT_RULESET.set(r)
|
|
1795
|
+
token: Token[Ruleset | None] = _CURRENT_RULESET.set(r)
|
|
1854
1796
|
try:
|
|
1855
1797
|
yield
|
|
1856
1798
|
finally:
|
egglog/egraph_state.py
CHANGED
|
@@ -108,7 +108,7 @@ class EGraphState:
|
|
|
108
108
|
case RulesetDecl(rules):
|
|
109
109
|
if name not in self.rulesets:
|
|
110
110
|
if name:
|
|
111
|
-
self.egraph.run_program(bindings.AddRuleset(name))
|
|
111
|
+
self.egraph.run_program(bindings.AddRuleset(span(), name))
|
|
112
112
|
added_rules = self.rulesets[name] = set()
|
|
113
113
|
else:
|
|
114
114
|
added_rules = self.rulesets[name]
|
|
@@ -125,7 +125,7 @@ class EGraphState:
|
|
|
125
125
|
self.rulesets[name] = set()
|
|
126
126
|
for ruleset in rulesets:
|
|
127
127
|
self.ruleset_to_egg(ruleset)
|
|
128
|
-
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
|
|
128
|
+
self.egraph.run_program(bindings.UnstableCombinedRuleset(span(), name, list(rulesets)))
|
|
129
129
|
|
|
130
130
|
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
|
|
131
131
|
match cmd:
|
|
@@ -160,7 +160,7 @@ class EGraphState:
|
|
|
160
160
|
assert isinstance(sig, FunctionSignature)
|
|
161
161
|
# Replace args with rule_var_name mapping
|
|
162
162
|
arg_mapping = tuple(
|
|
163
|
-
TypedExprDecl(tp.to_just(),
|
|
163
|
+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name))
|
|
164
164
|
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
165
165
|
)
|
|
166
166
|
rewrite_decl = RewriteDecl(
|
|
@@ -179,7 +179,7 @@ class EGraphState:
|
|
|
179
179
|
def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
|
|
180
180
|
match action:
|
|
181
181
|
case LetDecl(name, typed_expr):
|
|
182
|
-
var_decl =
|
|
182
|
+
var_decl = LetRefDecl(name)
|
|
183
183
|
var_egg = self._expr_to_egg(var_decl)
|
|
184
184
|
self.expr_to_egg_cache[var_decl] = var_egg
|
|
185
185
|
return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
|
|
@@ -369,7 +369,8 @@ class EGraphState:
|
|
|
369
369
|
"""
|
|
370
370
|
Rewrites this expression as a let binding if it's not already a let binding.
|
|
371
371
|
"""
|
|
372
|
-
|
|
372
|
+
# TODO: Replace with counter so that it works with hash collisions and is more stable
|
|
373
|
+
var_decl = LetRefDecl(f"__expr_{hash(typed_expr)}")
|
|
373
374
|
if var_decl in self.expr_to_egg_cache:
|
|
374
375
|
return None
|
|
375
376
|
var_egg = self._expr_to_egg(var_decl)
|
|
@@ -387,7 +388,7 @@ class EGraphState:
|
|
|
387
388
|
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
388
389
|
|
|
389
390
|
@overload
|
|
390
|
-
def _expr_to_egg(self, expr_decl:
|
|
391
|
+
def _expr_to_egg(self, expr_decl: UnboundVarDecl | LetRefDecl) -> bindings.Var: ...
|
|
391
392
|
|
|
392
393
|
@overload
|
|
393
394
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
@@ -402,11 +403,10 @@ class EGraphState:
|
|
|
402
403
|
pass
|
|
403
404
|
res: bindings._Expr
|
|
404
405
|
match expr_decl:
|
|
405
|
-
case
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
res = bindings.Var(span(), name)
|
|
406
|
+
case LetRefDecl(name):
|
|
407
|
+
res = bindings.Var(span(), f"{name}")
|
|
408
|
+
case UnboundVarDecl(name, egg_name):
|
|
409
|
+
res = bindings.Var(span(), egg_name or f"_{name}")
|
|
410
410
|
case LitDecl(value):
|
|
411
411
|
l: bindings._Literal
|
|
412
412
|
match value:
|
|
@@ -467,7 +467,8 @@ class EGraphState:
|
|
|
467
467
|
return name
|
|
468
468
|
|
|
469
469
|
case ConstantRef(name):
|
|
470
|
-
|
|
470
|
+
# Prefix to avoid name collisions with local vars
|
|
471
|
+
return f"%{name}"
|
|
471
472
|
case (
|
|
472
473
|
MethodRef(cls_name, name)
|
|
473
474
|
| ClassMethodRef(cls_name, name)
|
|
@@ -549,7 +550,7 @@ class FromEggState:
|
|
|
549
550
|
"""
|
|
550
551
|
expr_decl: ExprDecl
|
|
551
552
|
if isinstance(term, bindings.TermVar):
|
|
552
|
-
expr_decl =
|
|
553
|
+
expr_decl = LetRefDecl(term.name)
|
|
553
554
|
elif isinstance(term, bindings.TermLit):
|
|
554
555
|
value = term.value
|
|
555
556
|
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
@@ -624,7 +625,9 @@ class FromEggState:
|
|
|
624
625
|
# but dont need to store them
|
|
625
626
|
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
|
|
626
627
|
)
|
|
627
|
-
raise ValueError(
|
|
628
|
+
raise ValueError(
|
|
629
|
+
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
|
|
630
|
+
)
|
|
628
631
|
|
|
629
632
|
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
630
633
|
try:
|
egglog/examples/bignum.py
CHANGED
egglog/examples/multiset.py
CHANGED
|
@@ -32,7 +32,7 @@ egraph.register(xs)
|
|
|
32
32
|
egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
|
|
33
33
|
egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
|
|
34
34
|
|
|
35
|
-
assert Counter(egraph.extract(xs).
|
|
35
|
+
assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
@@ -45,7 +45,7 @@ assert Math(4) not in xs
|
|
|
45
45
|
|
|
46
46
|
egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
|
|
47
47
|
|
|
48
|
-
assert egraph.extract(xs.length()).
|
|
48
|
+
assert egraph.extract(xs.length()).value == 3
|
|
49
49
|
assert len(xs) == 3
|
|
50
50
|
|
|
51
51
|
egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
|