egglog 11.3.0__cp310-cp310-manylinux_2_17_ppc64.manylinux2014_ppc64.whl → 11.4.0__cp310-cp310-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cpython-310-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +62 -1
- egglog/builtins.py +6 -0
- egglog/declarations.py +20 -3
- egglog/deconstruct.py +9 -7
- egglog/egraph.py +318 -21
- egglog/egraph_state.py +154 -6
- egglog/exp/array_api_jit.py +2 -2
- egglog/pretty.py +8 -3
- egglog/runtime.py +11 -5
- egglog/type_constraint_solver.py +2 -2
- {egglog-11.3.0.dist-info → egglog-11.4.0.dist-info}/METADATA +1 -1
- {egglog-11.3.0.dist-info → egglog-11.4.0.dist-info}/RECORD +15 -15
- {egglog-11.3.0.dist-info → egglog-11.4.0.dist-info}/WHEEL +0 -0
- {egglog-11.3.0.dist-info → egglog-11.4.0.dist-info}/licenses/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
1
2
|
from datetime import timedelta
|
|
3
|
+
from fractions import Fraction
|
|
2
4
|
from pathlib import Path
|
|
3
|
-
from typing import TypeAlias
|
|
5
|
+
from typing import Any, Generic, Protocol, TypeAlias, TypeVar
|
|
4
6
|
|
|
5
7
|
from typing_extensions import final
|
|
6
8
|
|
|
@@ -14,6 +16,7 @@ __all__ = [
|
|
|
14
16
|
"Change",
|
|
15
17
|
"Check",
|
|
16
18
|
"Constructor",
|
|
19
|
+
"CostModel",
|
|
17
20
|
"Datatype",
|
|
18
21
|
"Datatypes",
|
|
19
22
|
"DefaultPrintFunctionMode",
|
|
@@ -26,6 +29,7 @@ __all__ = [
|
|
|
26
29
|
"Extract",
|
|
27
30
|
"ExtractBest",
|
|
28
31
|
"ExtractVariants",
|
|
32
|
+
"Extractor",
|
|
29
33
|
"Fact",
|
|
30
34
|
"Fail",
|
|
31
35
|
"Float",
|
|
@@ -83,6 +87,7 @@ __all__ = [
|
|
|
83
87
|
"UserDefined",
|
|
84
88
|
"UserDefinedCommandOutput",
|
|
85
89
|
"UserDefinedOutput",
|
|
90
|
+
"Value",
|
|
86
91
|
"Var",
|
|
87
92
|
"Variant",
|
|
88
93
|
]
|
|
@@ -128,6 +133,31 @@ class EGraph:
|
|
|
128
133
|
max_calls_per_function: int | None = None,
|
|
129
134
|
include_temporary_functions: bool = False,
|
|
130
135
|
) -> SerializedEGraph: ...
|
|
136
|
+
def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
|
|
137
|
+
def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
|
|
138
|
+
def value_to_i64(self, v: Value) -> int: ...
|
|
139
|
+
def value_to_f64(self, v: Value) -> float: ...
|
|
140
|
+
def value_to_string(self, v: Value) -> str: ...
|
|
141
|
+
def value_to_bool(self, v: Value) -> bool: ...
|
|
142
|
+
def value_to_rational(self, v: Value) -> Fraction: ...
|
|
143
|
+
def value_to_bigint(self, v: Value) -> int: ...
|
|
144
|
+
def value_to_bigrat(self, v: Value) -> Fraction: ...
|
|
145
|
+
def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
|
|
146
|
+
def value_to_map(self, v: Value) -> dict[Value, Value]: ...
|
|
147
|
+
def value_to_multiset(self, v: Value) -> list[Value]: ...
|
|
148
|
+
def value_to_vec(self, v: Value) -> list[Value]: ...
|
|
149
|
+
def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ...
|
|
150
|
+
def value_to_set(self, v: Value) -> set[Value]: ...
|
|
151
|
+
# def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ...
|
|
152
|
+
|
|
153
|
+
@final
|
|
154
|
+
class Value:
|
|
155
|
+
def __hash__(self) -> int: ...
|
|
156
|
+
def __eq__(self, value: object) -> bool: ...
|
|
157
|
+
def __lt__(self, other: object) -> bool: ...
|
|
158
|
+
def __le__(self, other: object) -> bool: ...
|
|
159
|
+
def __gt__(self, other: object) -> bool: ...
|
|
160
|
+
def __ge__(self, other: object) -> bool: ...
|
|
131
161
|
|
|
132
162
|
@final
|
|
133
163
|
class EggSmolError(Exception):
|
|
@@ -732,3 +762,34 @@ class TermDag:
|
|
|
732
762
|
def expr_to_term(self, expr: _Expr) -> _Term: ...
|
|
733
763
|
def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
|
|
734
764
|
def to_string(self, term: _Term) -> str: ...
|
|
765
|
+
|
|
766
|
+
##
|
|
767
|
+
# Extraction
|
|
768
|
+
##
|
|
769
|
+
class _Cost(Protocol):
|
|
770
|
+
def __lt__(self, other: _Cost) -> bool: ...
|
|
771
|
+
def __le__(self, other: _Cost) -> bool: ...
|
|
772
|
+
def __gt__(self, other: _Cost) -> bool: ...
|
|
773
|
+
def __ge__(self, other: _Cost) -> bool: ...
|
|
774
|
+
|
|
775
|
+
_COST = TypeVar("_COST", bound=_Cost)
|
|
776
|
+
|
|
777
|
+
_ENODE_COST = TypeVar("_ENODE_COST")
|
|
778
|
+
|
|
779
|
+
@final
|
|
780
|
+
class CostModel(Generic[_COST, _ENODE_COST]):
|
|
781
|
+
def __init__(
|
|
782
|
+
self,
|
|
783
|
+
fold: Callable[[str, _ENODE_COST, list[_COST]], _COST],
|
|
784
|
+
enode_cost: Callable[[str, list[Value]], _ENODE_COST],
|
|
785
|
+
container_cost: Callable[[str, Value, list[_COST]], _COST],
|
|
786
|
+
base_value_cost: Callable[[str, Value], _COST],
|
|
787
|
+
) -> None: ...
|
|
788
|
+
|
|
789
|
+
@final
|
|
790
|
+
class Extractor(Generic[_COST]):
|
|
791
|
+
def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]) -> None: ...
|
|
792
|
+
def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
|
|
793
|
+
def extract_variants(
|
|
794
|
+
self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
|
|
795
|
+
) -> list[tuple[_COST, _Term]]: ...
|
egglog/builtins.py
CHANGED
|
@@ -33,10 +33,12 @@ __all__ = [
|
|
|
33
33
|
"BigRatLike",
|
|
34
34
|
"Bool",
|
|
35
35
|
"BoolLike",
|
|
36
|
+
"Container",
|
|
36
37
|
"ExprValueError",
|
|
37
38
|
"Map",
|
|
38
39
|
"MapLike",
|
|
39
40
|
"MultiSet",
|
|
41
|
+
"Primitive",
|
|
40
42
|
"PyObject",
|
|
41
43
|
"Rational",
|
|
42
44
|
"Set",
|
|
@@ -1135,3 +1137,7 @@ def _convert_function(fn: FunctionType) -> UnstableFn:
|
|
|
1135
1137
|
|
|
1136
1138
|
|
|
1137
1139
|
converter(FunctionType, UnstableFn, _convert_function)
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn
|
|
1143
|
+
Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit
|
egglog/declarations.py
CHANGED
|
@@ -14,6 +14,8 @@ from weakref import WeakValueDictionary
|
|
|
14
14
|
|
|
15
15
|
from typing_extensions import Self, assert_never
|
|
16
16
|
|
|
17
|
+
from .bindings import Value
|
|
18
|
+
|
|
17
19
|
if TYPE_CHECKING:
|
|
18
20
|
from collections.abc import Callable, Iterable, Mapping
|
|
19
21
|
|
|
@@ -49,6 +51,7 @@ __all__ = [
|
|
|
49
51
|
"FunctionDecl",
|
|
50
52
|
"FunctionRef",
|
|
51
53
|
"FunctionSignature",
|
|
54
|
+
"GetCostDecl",
|
|
52
55
|
"HasDeclerations",
|
|
53
56
|
"InitRef",
|
|
54
57
|
"JustTypeRef",
|
|
@@ -82,6 +85,7 @@ __all__ = [
|
|
|
82
85
|
"UnboundVarDecl",
|
|
83
86
|
"UnionDecl",
|
|
84
87
|
"UnnamedFunctionRef",
|
|
88
|
+
"ValueDecl",
|
|
85
89
|
"collect_unbound_vars",
|
|
86
90
|
"replace_typed_expr",
|
|
87
91
|
"upcast_declerations",
|
|
@@ -639,7 +643,7 @@ class CallDecl:
|
|
|
639
643
|
args: tuple[TypedExprDecl, ...] = ()
|
|
640
644
|
# type parameters that were bound to the callable, if it is a classmethod
|
|
641
645
|
# Used for pretty printing classmethod calls with type parameters
|
|
642
|
-
bound_tp_params: tuple[JustTypeRef, ...]
|
|
646
|
+
bound_tp_params: tuple[JustTypeRef, ...] = ()
|
|
643
647
|
|
|
644
648
|
# pool objects for faster __eq__
|
|
645
649
|
_args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
|
|
@@ -654,7 +658,7 @@ class CallDecl:
|
|
|
654
658
|
# normalize the args/kwargs to a tuple so that they can be compared
|
|
655
659
|
callable = args[0] if args else kwargs["callable"]
|
|
656
660
|
args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
|
|
657
|
-
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")
|
|
661
|
+
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ())
|
|
658
662
|
|
|
659
663
|
normalized_args = (callable, args_, bound_tp_params)
|
|
660
664
|
try:
|
|
@@ -696,7 +700,20 @@ class PartialCallDecl:
|
|
|
696
700
|
call: CallDecl
|
|
697
701
|
|
|
698
702
|
|
|
699
|
-
|
|
703
|
+
@dataclass(frozen=True)
class GetCostDecl:
    """
    Expression declaration that stands for a lookup of the cost of a call.

    It carries the same ``callable`` and ``args`` as the call whose cost is
    being looked up.
    """

    # Reference to the callable of the call whose cost is looked up.
    callable: CallableRef
    # Typed argument expressions of that call.
    args: tuple[TypedExprDecl, ...]
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@dataclass(frozen=True)
class ValueDecl:
    """
    Expression declaration that wraps a raw ``Value`` coming from the bindings.
    """

    # The wrapped bindings value.
    value: Value
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
ExprDecl: TypeAlias = (
|
|
715
|
+
UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
|
|
716
|
+
)
|
|
700
717
|
|
|
701
718
|
|
|
702
719
|
@dataclass(frozen=True)
|
egglog/deconstruct.py
CHANGED
|
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, TypeVar, overload
|
|
|
11
11
|
from typing_extensions import TypeVarTuple, Unpack
|
|
12
12
|
|
|
13
13
|
from .declarations import *
|
|
14
|
-
from .egraph import BaseExpr
|
|
14
|
+
from .egraph import BaseExpr, Expr
|
|
15
15
|
from .runtime import *
|
|
16
16
|
from .thunk import *
|
|
17
17
|
|
|
@@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ...
|
|
|
49
49
|
def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
|
|
52
|
+
@overload
|
|
53
|
+
def get_literal_value(x: Expr) -> None: ...
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_literal_value(x: object) -> object:
|
|
53
57
|
"""
|
|
54
58
|
Returns the literal value of an expression if it is a literal.
|
|
55
59
|
If it is not a literal, returns None.
|
|
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None:
|
|
|
95
99
|
return None
|
|
96
100
|
|
|
97
101
|
|
|
98
|
-
def get_callable_fn(x: T) -> Callable[..., T] | None:
|
|
102
|
+
def get_callable_fn(x: T) -> Callable[..., T] | T | None:
|
|
99
103
|
"""
|
|
100
|
-
Gets the function of an expression if it
|
|
101
|
-
If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
|
|
102
|
-
For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
|
|
103
|
-
to return the Python value.
|
|
104
|
+
Gets the function of an expression, or if it's a constant or classvar, return that.
|
|
104
105
|
"""
|
|
105
106
|
if not isinstance(x, RuntimeExpr):
|
|
106
107
|
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
@@ -159,6 +160,7 @@ def _deconstruct_call_decl(
|
|
|
159
160
|
"""
|
|
160
161
|
args = call.args
|
|
161
162
|
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
|
|
163
|
+
# TODO: handle values? Like constants
|
|
162
164
|
if isinstance(call.callable, InitRef):
|
|
163
165
|
return RuntimeClass(
|
|
164
166
|
decls_thunk,
|
egglog/egraph.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
ClassVar,
|
|
17
17
|
Generic,
|
|
18
18
|
Literal,
|
|
19
|
+
Protocol,
|
|
19
20
|
TypeAlias,
|
|
20
21
|
TypedDict,
|
|
21
22
|
TypeVar,
|
|
@@ -41,7 +42,7 @@ from .thunk import *
|
|
|
41
42
|
from .version_compat import *
|
|
42
43
|
|
|
43
44
|
if TYPE_CHECKING:
|
|
44
|
-
from .builtins import String, Unit, i64Like
|
|
45
|
+
from .builtins import String, Unit, i64, i64Like
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
__all__ = [
|
|
@@ -51,11 +52,14 @@ __all__ = [
|
|
|
51
52
|
"BuiltinExpr",
|
|
52
53
|
"Command",
|
|
53
54
|
"Command",
|
|
55
|
+
"CostModel",
|
|
54
56
|
"EGraph",
|
|
55
57
|
"Expr",
|
|
58
|
+
"ExprCallable",
|
|
56
59
|
"Fact",
|
|
57
60
|
"Fact",
|
|
58
61
|
"GraphvizKwargs",
|
|
62
|
+
"GreedyDagCost",
|
|
59
63
|
"RewriteOrRule",
|
|
60
64
|
"Ruleset",
|
|
61
65
|
"Schedule",
|
|
@@ -70,12 +74,15 @@ __all__ = [
|
|
|
70
74
|
"check",
|
|
71
75
|
"check_eq",
|
|
72
76
|
"constant",
|
|
77
|
+
"default_cost_model",
|
|
73
78
|
"delete",
|
|
74
79
|
"eq",
|
|
75
80
|
"expr_action",
|
|
76
81
|
"expr_fact",
|
|
77
82
|
"expr_parts",
|
|
78
83
|
"function",
|
|
84
|
+
"get_cost",
|
|
85
|
+
"greedy_dag_cost_model",
|
|
79
86
|
"let",
|
|
80
87
|
"method",
|
|
81
88
|
"ne",
|
|
@@ -88,6 +95,7 @@ __all__ = [
|
|
|
88
95
|
"seq",
|
|
89
96
|
"set_",
|
|
90
97
|
"set_cost",
|
|
98
|
+
"set_current_ruleset",
|
|
91
99
|
"subsume",
|
|
92
100
|
"union",
|
|
93
101
|
"unstable_combine_rulesets",
|
|
@@ -452,7 +460,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
452
460
|
continue
|
|
453
461
|
locals = frame.f_locals
|
|
454
462
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
455
|
-
# TODO: Store deprecated message so we can
|
|
463
|
+
# TODO: Store deprecated message so we can get at runtime
|
|
456
464
|
if (getattr(fn, "__deprecated__", None)) is not None:
|
|
457
465
|
fn = fn.__wrapped__ # type: ignore[attr-defined]
|
|
458
466
|
match fn:
|
|
@@ -953,22 +961,45 @@ class EGraph:
|
|
|
953
961
|
return bindings.Check(span(2), egg_facts)
|
|
954
962
|
|
|
955
963
|
@overload
|
|
956
|
-
def extract(
|
|
964
|
+
def extract(
|
|
965
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None
|
|
966
|
+
) -> BASE_EXPR: ...
|
|
957
967
|
|
|
958
968
|
@overload
|
|
959
|
-
def extract(
|
|
969
|
+
def extract(
|
|
970
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: None = None
|
|
971
|
+
) -> tuple[BASE_EXPR, int]: ...
|
|
960
972
|
|
|
961
|
-
|
|
973
|
+
@overload
|
|
974
|
+
def extract(
|
|
975
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST]
|
|
976
|
+
) -> tuple[BASE_EXPR, COST]: ...
|
|
977
|
+
|
|
978
|
+
def extract(
|
|
979
|
+
self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None
|
|
980
|
+
) -> BASE_EXPR | tuple[BASE_EXPR, COST]:
|
|
962
981
|
"""
|
|
963
982
|
Extract the lowest cost expression from the egraph.
|
|
964
983
|
"""
|
|
965
984
|
runtime_expr = to_runtime_expr(expr)
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
985
|
+
self._add_decls(runtime_expr)
|
|
986
|
+
tp = runtime_expr.__egg_typed_expr__.tp
|
|
987
|
+
if cost_model is None:
|
|
988
|
+
extract_report = self._run_extract(runtime_expr, 0)
|
|
989
|
+
assert isinstance(extract_report, bindings.ExtractBest)
|
|
990
|
+
res = self._from_termdag(extract_report.termdag, extract_report.term, tp)
|
|
991
|
+
cost = cast("COST", extract_report.cost)
|
|
992
|
+
else:
|
|
993
|
+
# TODO: For some reason we need this or else it wont be registered. Not sure why
|
|
994
|
+
self.register(expr)
|
|
995
|
+
egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model()
|
|
996
|
+
egg_sort = self._state.type_ref_to_egg(tp)
|
|
997
|
+
extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model)
|
|
998
|
+
termdag = bindings.TermDag()
|
|
999
|
+
value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__)
|
|
1000
|
+
cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort)
|
|
1001
|
+
res = self._from_termdag(termdag, term, tp)
|
|
1002
|
+
return (res, cost) if include_cost else res
|
|
972
1003
|
|
|
973
1004
|
def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
|
|
974
1005
|
(new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
|
|
@@ -979,6 +1010,7 @@ class EGraph:
|
|
|
979
1010
|
Extract multiple expressions from the egraph.
|
|
980
1011
|
"""
|
|
981
1012
|
runtime_expr = to_runtime_expr(expr)
|
|
1013
|
+
self._add_decls(runtime_expr)
|
|
982
1014
|
extract_report = self._run_extract(runtime_expr, n)
|
|
983
1015
|
assert isinstance(extract_report, bindings.ExtractVariants)
|
|
984
1016
|
new_exprs = self._state.exprs_from_egg(
|
|
@@ -987,7 +1019,6 @@ class EGraph:
|
|
|
987
1019
|
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
988
1020
|
|
|
989
1021
|
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
|
|
990
|
-
self._add_decls(expr)
|
|
991
1022
|
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
|
|
992
1023
|
# If we have defined any cost tables use the custom extraction
|
|
993
1024
|
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
|
|
@@ -1212,16 +1243,12 @@ class EGraph:
|
|
|
1212
1243
|
"""
|
|
1213
1244
|
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
|
|
1214
1245
|
assert isinstance(output, bindings.PrintAllFunctionsSize)
|
|
1246
|
+
return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
|
|
1247
|
+
|
|
1248
|
+
def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
|
|
1215
1249
|
return [
|
|
1216
|
-
(
|
|
1217
|
-
|
|
1218
|
-
"ExprCallable",
|
|
1219
|
-
create_callable(self._state.__egg_decls__, next(iter(refs))),
|
|
1220
|
-
),
|
|
1221
|
-
size,
|
|
1222
|
-
)
|
|
1223
|
-
for (name, size) in output.sizes
|
|
1224
|
-
if (refs := self._state.egg_fn_to_callable_refs[name])
|
|
1250
|
+
cast("ExprCallable", create_callable(self._state.__egg_decls__, ref))
|
|
1251
|
+
for ref in self._state.egg_fn_to_callable_refs[egg_fn]
|
|
1225
1252
|
]
|
|
1226
1253
|
|
|
1227
1254
|
def function_values(
|
|
@@ -1245,6 +1272,33 @@ class EGraph:
|
|
|
1245
1272
|
for (call, res) in output.terms
|
|
1246
1273
|
}
|
|
1247
1274
|
|
|
1275
|
+
def lookup_function_value(self, expr: BASE_EXPR) -> BASE_EXPR | None:
    """
    Look up the value stored in the e-graph for a function-call expression.

    The expression must be a call (or a cost lookup produced by `get_cost`).
    Returns the stored value wrapped as an expression of the same type, or
    None when the e-graph has no entry for these arguments.
    """
    typed = to_runtime_expr(expr).__egg_typed_expr__
    call_decl = typed.expr
    assert isinstance(call_decl, CallDecl | GetCostDecl)

    state = self._state
    table_name, arg_decls = state.translate_call(call_decl)
    # The bindings look the row up by the raw values of the arguments.
    key = [state.typed_expr_to_value(arg) for arg in arg_decls]
    found = self._egraph.lookup_function(table_name, key)
    if found is not None:
        decls = self.__egg_decls__
        result_decl = TypedExprDecl(typed.tp, state.value_to_expr(typed.tp, found))
        return cast("BASE_EXPR", RuntimeExpr.__from_values__(decls, result_decl))
    return None
|
|
1294
|
+
|
|
1295
|
+
def has_custom_cost(self, fn: ExprCallable) -> bool:
    """
    Report whether any custom cost has been set for this expression callable.
    """
    callable_ref, _decls = resolve_callable(fn)
    registered = self._state.cost_callables
    return callable_ref in registered
|
|
1301
|
+
|
|
1248
1302
|
|
|
1249
1303
|
# Either a constant or a function.
|
|
1250
1304
|
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
|
|
@@ -1910,3 +1964,246 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
|
1910
1964
|
yield
|
|
1911
1965
|
finally:
|
|
1912
1966
|
_CURRENT_RULESET.reset(token)
|
|
1967
|
+
|
|
1968
|
+
|
|
1969
|
+
def get_cost(expr: BaseExpr) -> i64:
    """
    Build an expression that looks up the cost of `expr`. If no cost is set, it won't match.

    Raises a TypeError when `expr` is not a function call.
    """
    assert isinstance(expr, RuntimeExpr)
    call = expr.__egg_typed_expr__.expr
    if isinstance(call, CallDecl):
        decls = expr.__egg_decls__
        # Same callable and arguments as the call, typed as an i64 lookup.
        cost_decl = TypedExprDecl(JustTypeRef("i64"), GetCostDecl(call.callable, call.args))
        return RuntimeExpr.__from_values__(decls, cost_decl)
    msg = "Can only get cost of function calls, not literals or variables"
    raise TypeError(msg)
|
|
1982
|
+
|
|
1983
|
+
|
|
1984
|
+
class Comparable(Protocol):
    """
    Structural type for values that support the four ordering comparisons
    against another value of their own type.
    """

    def __lt__(self, other: Self) -> bool: ...

    def __gt__(self, other: Self) -> bool: ...

    def __le__(self, other: Self) -> bool: ...

    def __ge__(self, other: Self) -> bool: ...
|
|
1989
|
+
|
|
1990
|
+
|
|
1991
|
+
COST = TypeVar("COST", bound=Comparable)
|
|
1992
|
+
|
|
1993
|
+
|
|
1994
|
+
class CostModel(Protocol, Generic[COST]):
    """
    Cost model for an e-graph: decides what an expression costs from its structure
    and from the costs of its sub-expressions.

    It is called with an expression together with the costs of its children and
    returns the total cost of that expression.

    For common usages the model should guarantee that a term never costs less than
    any of its subterms, which avoids cycles in the extracted terms.
    More niche usages may give a term a cost smaller than its subterms.
    The default extractor is guaranteed to terminate when computing costs as long as
    there is no negative cost cycle; keeping the extracted terms acyclic is then the
    user's responsibility.
    """

    def __call__(self, egraph: EGraph, expr: BaseExpr, children_costs: list[COST]) -> COST:
        """
        Return the total cost of a term, given its root e-node and the total costs of its immediate children.
        """
        raise NotImplementedError
|
|
2012
|
+
|
|
2013
|
+
|
|
2014
|
+
def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int:
    """
    Default cost model: uses costs set on function calls when available, otherwise 1.

    The node's own cost is resolved in this order of preference:

    1. a custom cost stored in the e-graph for this exact expression,
    2. the cost reported by `get_callable_cost` for the expression's callable,
    3. 0 for containers and 1 for everything else.

    The result is that own cost plus the sum of `children_costs`.
    """
    from .builtins import Container  # noqa: PLC0415
    from .deconstruct import get_callable_fn  # noqa: PLC0415

    own_cost = None
    fn = get_callable_fn(expr)
    if fn is not None:
        # Prefer a cost set on this specific expression, if the callable has any.
        if egraph.has_custom_cost(fn):
            stored = egraph.lookup_function_value(get_cost(expr))
            if stored is not None:
                own_cost = int(stored)
        # Otherwise fall back to the cost attached to the callable's declaration.
        if own_cost is None:
            own_cost = get_callable_cost(fn)
    if own_cost is None:
        # Containers are free; any other node costs 1 by default.
        own_cost = 0 if isinstance(expr, Container) else 1
    return sum(children_costs, own_cost)
|
|
2037
|
+
|
|
2038
|
+
|
|
2039
|
+
class ComparableAddSub(Comparable, Protocol):
    """
    A `Comparable` that additionally supports `+` and `-` with its own type,
    producing a value of that same type.
    """

    def __sub__(self, other: Self) -> Self: ...

    def __add__(self, other: Self) -> Self: ...
|
|
2042
|
+
|
|
2043
|
+
|
|
2044
|
+
DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub)
|
|
2045
|
+
|
|
2046
|
+
|
|
2047
|
+
@dataclass
class GreedyDagCost(Generic[DAG_COST]):
    """
    Cost of a DAG, which also remembers the costs of the nodes it is made of.

    Read `.total` to get the underlying cost. Equality, ordering and hashing
    all look at `total` only; the per-node costs are ignored.
    """

    # Overall cost of the DAG.
    total: DAG_COST
    # Per-expression costs collected so far; left out of the repr.
    _costs: dict[TypedExprDecl, DAG_COST] = field(repr=False)

    def __hash__(self) -> int:
        return hash(self.total)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, GreedyDagCost):
            return self.total == other.total
        return NotImplemented

    def __lt__(self, other: Self) -> bool:
        mine, theirs = self.total, other.total
        return mine < theirs

    def __le__(self, other: Self) -> bool:
        mine, theirs = self.total, other.total
        return mine <= theirs

    def __gt__(self, other: Self) -> bool:
        mine, theirs = self.total, other.total
        return mine > theirs

    def __ge__(self, other: Self) -> bool:
        mine, theirs = self.total, other.total
        return mine >= theirs
|
|
2075
|
+
|
|
2076
|
+
|
|
2077
|
+
@dataclass
class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
    """
    Cost model that counts duplicate nodes only once.

    Intended to behave similarly to
    https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs
    but written as a cost model so it can be used with the default extractor.
    """

    # Underlying cost model that prices each individual node.
    base: CostModel[DAG_COST]

    def __call__(
        self, egraph: EGraph, expr: BaseExpr, children_costs: list[GreedyDagCost[DAG_COST]]
    ) -> GreedyDagCost[DAG_COST]:
        """
        Combine the children's DAG costs with this node's own cost, counting each shared expression once.
        """
        child_totals = [child.total for child in children_costs]
        node_cost = self.base(egraph, expr, child_totals)
        # The base model returns a total; subtract the children to keep only this node's share.
        for child in children_costs:
            node_cost -= child.total
        # Merge the per-expression costs of all children; an expression shared by
        # several children ends up with a single entry.
        merged = {decl: cost for child in children_costs for decl, cost in child._costs.items()}
        total = sum(merged.values(), start=node_cost)
        merged[to_runtime_expr(expr).__egg_typed_expr__] = node_cost
        return GreedyDagCost(total, merged)
|
|
2100
|
+
|
|
2101
|
+
|
|
2102
|
+
@overload
def greedy_dag_cost_model() -> CostModel[GreedyDagCost[int]]: ...


@overload
def greedy_dag_cost_model(base: CostModel[DAG_COST]) -> CostModel[GreedyDagCost[DAG_COST]]: ...


def greedy_dag_cost_model(base: CostModel[Any] = default_cost_model) -> CostModel[GreedyDagCost[Any]]:
    """
    Build a greedy DAG cost model on top of a base cost model.

    `default_cost_model` is used when no base is given, and also when a falsy
    base is passed.
    """
    chosen = base if base else default_cost_model
    return GreedyDagCostModel(chosen)
|
|
2115
|
+
|
|
2116
|
+
|
|
2117
|
+
def get_callable_cost(fn: ExprCallable) -> int | None:
    """
    Return the cost attached to a callable's declaration.

    For a constructor declaration this is its `cost` attribute, returned as-is.
    Every other kind of callable reports a cost of 1.
    """
    ref, decls = resolve_callable(fn)
    decl = decls.get_callable_decl(ref)
    if isinstance(decl, ConstructorDecl):
        return decl.cost
    return 1
|
|
2124
|
+
|
|
2125
|
+
|
|
2126
|
+
@dataclass
class _CostModel(Generic[COST]):
    """
    Adapter that exposes a Python `CostModel` through the callbacks the bindings cost model is built from.

    `enode_cost`, `base_value_cost` and `container_cost` memoize their results on the instance.
    """

    # User-supplied cost model; always invoked through `call_model`.
    model: CostModel[COST]
    # E-graph handed to the model and used to turn raw values back into expressions.
    egraph: EGraph
    # (egg function name, argument values) -> index into `enode_cost_expressions`.
    enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict)
    # Expressions rebuilt by `enode_cost`; `fold` fetches them back by index.
    enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list)
    # Consulted by `fold`.
    # NOTE(review): nothing in this class ever writes to this dict, so the lookup in `fold`
    # cannot hit unless a caller pre-populates it -- confirm whether that is intended.
    fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict)
    # (egg sort name, value) -> cost computed by `base_value_cost`.
    base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict)
    # (egg sort name, value, element costs) -> cost computed by `container_cost`.
    container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict)

    def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST:
        """
        Invoke the wrapped model on `expr` with the given children costs.
        """
        # The original carried a disabled debug check here that rejected a result
        # less than or equal to any child cost; it is not active.
        return self.model(self.egraph, cast("BaseExpr", expr), children_costs)

    def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST:
        """
        Total cost of the e-node previously registered under `index`, given its children's costs.
        """
        cache_key = (index, tuple(children_costs))
        if cache_key in self.fold_results:
            return self.fold_results[cache_key]
        return self.call_model(self.enode_cost_expressions[index], children_costs)

    # enode cost is only ever called right before fold, for the head_cost
    def enode_cost(self, name: str, args: list[bindings.Value]) -> int:
        """
        Rebuild the call expression for an e-node and return the index it is stored under.

        The returned int is not a cost; it is the handle `fold` later uses to fetch the expression.
        """
        cache_key = (name, tuple(args))
        if cache_key in self.enode_cost_results:
            return self.enode_cost_results[cache_key]

        state = self.egraph._state
        # Exactly one callable is expected per egg function name; anything else fails the unpacking.
        (callable_ref,) = state.egg_fn_to_callable_refs[name]
        signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature
        assert isinstance(signature, FunctionSignature)

        typed_args = tuple(
            TypedExprDecl(arg_tp.to_just(), state.value_to_expr(arg_tp.to_just(), raw))
            for (raw, arg_tp) in zip(args, signature.arg_types, strict=True)
        )
        return_tp = signature.semantic_return_type.to_just()
        call_expr = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(return_tp, CallDecl(callable_ref, typed_args)),
        )

        # The new expression's index is its position in the list.
        new_index = len(self.enode_cost_expressions)
        self.enode_cost_expressions.append(call_expr)
        self.enode_cost_results[cache_key] = new_index
        return new_index

    def base_value_cost(self, tp: str, value: bindings.Value) -> COST:
        """
        Cost of a base value of egg sort `tp`: the model is called on it with no children.
        """
        cache_key = (tp, value)
        if cache_key in self.base_value_cost_results:
            return self.base_value_cost_results[cache_key]

        state = self.egraph._state
        sort_ref = state.egg_sort_to_type_ref[tp]
        value_expr = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(sort_ref, state.value_to_expr(sort_ref, value)),
        )
        computed = self.call_model(value_expr, [])
        self.base_value_cost_results[cache_key] = computed
        return computed

    def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST:
        """
        Cost of a container value of egg sort `tp`: the model is called on it with its elements' costs.
        """
        cache_key = (tp, value, tuple(element_costs))
        if cache_key in self.container_cost_results:
            return self.container_cost_results[cache_key]

        state = self.egraph._state
        sort_ref = state.egg_sort_to_type_ref[tp]
        container_expr = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(sort_ref, state.value_to_expr(sort_ref, value)),
        )
        computed = self.call_model(container_expr, element_costs)
        self.container_cost_results[cache_key] = computed
        return computed

    def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]:
        """
        Wrap this adapter's four callbacks into a bindings cost model.
        """
        return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)
|
egglog/egraph_state.py
CHANGED
|
@@ -68,6 +68,7 @@ class EGraphState:
|
|
|
68
68
|
|
|
69
69
|
# Bidirectional mapping between egg sort names and python type references.
|
|
70
70
|
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
71
|
+
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
|
|
71
72
|
|
|
72
73
|
# Cache of egg expressions for converting to egg
|
|
73
74
|
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
|
|
@@ -86,6 +87,7 @@ class EGraphState:
|
|
|
86
87
|
egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
|
|
87
88
|
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
|
|
88
89
|
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
|
|
90
|
+
egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(),
|
|
89
91
|
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
|
|
90
92
|
cost_callables=self.cost_callables.copy(),
|
|
91
93
|
)
|
|
@@ -352,6 +354,7 @@ class EGraphState:
|
|
|
352
354
|
Creates the egg cost table if needed and gets the name of the table.
|
|
353
355
|
"""
|
|
354
356
|
name = self.cost_table_name(ref)
|
|
357
|
+
print(name, self.cost_callables)
|
|
355
358
|
if ref not in self.cost_callables:
|
|
356
359
|
self.cost_callables.add(ref)
|
|
357
360
|
signature = self.__egg_decls__.get_callable_decl(ref).signature
|
|
@@ -455,10 +458,14 @@ class EGraphState:
|
|
|
455
458
|
pass
|
|
456
459
|
decl = self.__egg_decls__._classes[ref.name]
|
|
457
460
|
self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
|
|
461
|
+
self.egg_sort_to_type_ref[egg_name] = ref
|
|
458
462
|
if not decl.builtin or ref.args:
|
|
459
463
|
if ref.args:
|
|
460
464
|
if ref.name == "UnstableFn":
|
|
461
465
|
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
466
|
+
if len(ref.args) < 2:
|
|
467
|
+
msg = "Zero argument higher order functions not supported"
|
|
468
|
+
raise NotImplementedError(msg)
|
|
462
469
|
type_args: list[bindings._Expr] = [
|
|
463
470
|
bindings.Call(
|
|
464
471
|
span(),
|
|
@@ -589,11 +596,9 @@ class EGraphState:
|
|
|
589
596
|
case _:
|
|
590
597
|
assert_never(value)
|
|
591
598
|
res = bindings.Lit(span(), l)
|
|
592
|
-
case CallDecl(
|
|
593
|
-
egg_fn,
|
|
594
|
-
egg_args = [self.typed_expr_to_egg(a, False) for a in
|
|
595
|
-
if reverse_args:
|
|
596
|
-
egg_args.reverse()
|
|
599
|
+
case CallDecl() | GetCostDecl():
|
|
600
|
+
egg_fn, typed_args = self.translate_call(expr_decl)
|
|
601
|
+
egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args]
|
|
597
602
|
res = bindings.Call(span(), egg_fn, egg_args)
|
|
598
603
|
case PyObjectDecl(value):
|
|
599
604
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
@@ -604,11 +609,31 @@ class EGraphState:
|
|
|
604
609
|
"unstable-fn",
|
|
605
610
|
[bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
|
|
606
611
|
)
|
|
612
|
+
case ValueDecl():
|
|
613
|
+
msg = "Cannot turn a Value into an expression"
|
|
614
|
+
raise ValueError(msg)
|
|
607
615
|
case _:
|
|
608
616
|
assert_never(expr_decl.expr)
|
|
609
617
|
self.expr_to_egg_cache[expr_decl] = res
|
|
610
618
|
return res
|
|
611
619
|
|
|
620
|
+
def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]:
|
|
621
|
+
"""
|
|
622
|
+
Handle get cost and call decl, turn into egg table name and typed expr decls.
|
|
623
|
+
"""
|
|
624
|
+
match expr:
|
|
625
|
+
case CallDecl(ref, args, _):
|
|
626
|
+
egg_fn, reverse_args = self.callable_ref_to_egg(ref)
|
|
627
|
+
args_list = list(args)
|
|
628
|
+
if reverse_args:
|
|
629
|
+
args_list.reverse()
|
|
630
|
+
return egg_fn, args_list
|
|
631
|
+
case GetCostDecl(ref, args):
|
|
632
|
+
cost_table = self.create_cost_table(ref)
|
|
633
|
+
return cost_table, list(args)
|
|
634
|
+
case _:
|
|
635
|
+
assert_never(expr)
|
|
636
|
+
|
|
612
637
|
def exprs_from_egg(
|
|
613
638
|
self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
|
|
614
639
|
) -> Iterable[TypedExprDecl]:
|
|
@@ -652,6 +677,129 @@ class EGraphState:
|
|
|
652
677
|
case _:
|
|
653
678
|
assert_never(ref)
|
|
654
679
|
|
|
680
|
+
def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
|
|
681
|
+
egg_expr = self.typed_expr_to_egg(typed_expr, False)
|
|
682
|
+
return self.egraph.eval_expr(egg_expr)[1]
|
|
683
|
+
|
|
684
|
+
def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912
|
|
685
|
+
match tp.name:
|
|
686
|
+
# Should match list in egraph bindings
|
|
687
|
+
case "i64":
|
|
688
|
+
return LitDecl(self.egraph.value_to_i64(value))
|
|
689
|
+
case "f64":
|
|
690
|
+
return LitDecl(self.egraph.value_to_f64(value))
|
|
691
|
+
case "Bool":
|
|
692
|
+
return LitDecl(self.egraph.value_to_bool(value))
|
|
693
|
+
case "String":
|
|
694
|
+
return LitDecl(self.egraph.value_to_string(value))
|
|
695
|
+
case "Unit":
|
|
696
|
+
return LitDecl(None)
|
|
697
|
+
case "PyObject":
|
|
698
|
+
return PyObjectDecl(self.egraph.value_to_pyobject(GLOBAL_PY_OBJECT_SORT, value))
|
|
699
|
+
case "Rational":
|
|
700
|
+
fraction = self.egraph.value_to_rational(value)
|
|
701
|
+
return CallDecl(
|
|
702
|
+
InitRef("Rational"),
|
|
703
|
+
(
|
|
704
|
+
TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.numerator)),
|
|
705
|
+
TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.denominator)),
|
|
706
|
+
),
|
|
707
|
+
)
|
|
708
|
+
case "BigInt":
|
|
709
|
+
i = self.egraph.value_to_bigint(value)
|
|
710
|
+
return CallDecl(
|
|
711
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
712
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(i))),),
|
|
713
|
+
)
|
|
714
|
+
case "BigRat":
|
|
715
|
+
fraction = self.egraph.value_to_bigrat(value)
|
|
716
|
+
return CallDecl(
|
|
717
|
+
InitRef("BigRat"),
|
|
718
|
+
(
|
|
719
|
+
TypedExprDecl(
|
|
720
|
+
JustTypeRef("BigInt"),
|
|
721
|
+
CallDecl(
|
|
722
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
723
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.numerator))),),
|
|
724
|
+
),
|
|
725
|
+
),
|
|
726
|
+
TypedExprDecl(
|
|
727
|
+
JustTypeRef("BigInt"),
|
|
728
|
+
CallDecl(
|
|
729
|
+
ClassMethodRef("BigInt", "from_string"),
|
|
730
|
+
(TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.denominator))),),
|
|
731
|
+
),
|
|
732
|
+
),
|
|
733
|
+
),
|
|
734
|
+
)
|
|
735
|
+
case "Map":
|
|
736
|
+
k_tp, v_tp = tp.args
|
|
737
|
+
expr = CallDecl(ClassMethodRef("Map", "empty"), (), (k_tp, v_tp))
|
|
738
|
+
for k, v in self.egraph.value_to_map(value).items():
|
|
739
|
+
expr = CallDecl(
|
|
740
|
+
MethodRef("Map", "insert"),
|
|
741
|
+
(
|
|
742
|
+
TypedExprDecl(tp, expr),
|
|
743
|
+
TypedExprDecl(k_tp, self.value_to_expr(k_tp, k)),
|
|
744
|
+
TypedExprDecl(v_tp, self.value_to_expr(v_tp, v)),
|
|
745
|
+
),
|
|
746
|
+
)
|
|
747
|
+
return expr
|
|
748
|
+
case "Set":
|
|
749
|
+
xs_ = self.egraph.value_to_set(value)
|
|
750
|
+
(v_tp,) = tp.args
|
|
751
|
+
return CallDecl(
|
|
752
|
+
InitRef("Set"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), (v_tp,)
|
|
753
|
+
)
|
|
754
|
+
case "Vec":
|
|
755
|
+
xs = self.egraph.value_to_vec(value)
|
|
756
|
+
(v_tp,) = tp.args
|
|
757
|
+
return CallDecl(
|
|
758
|
+
InitRef("Vec"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
|
|
759
|
+
)
|
|
760
|
+
case "MultiSet":
|
|
761
|
+
xs = self.egraph.value_to_multiset(value)
|
|
762
|
+
(v_tp,) = tp.args
|
|
763
|
+
return CallDecl(
|
|
764
|
+
InitRef("MultiSet"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
|
|
765
|
+
)
|
|
766
|
+
case "UnstableFn":
|
|
767
|
+
_names, _args = self.egraph.value_to_function(value)
|
|
768
|
+
return_tp, *arg_types = tp.args
|
|
769
|
+
return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types)
|
|
770
|
+
return ValueDecl(value)
|
|
771
|
+
|
|
772
|
+
def _unstable_fn_value_to_expr(
|
|
773
|
+
self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef]
|
|
774
|
+
) -> PartialCallDecl:
|
|
775
|
+
# Similar to FromEggState::from_call but accepts partial list of args and returns in values
|
|
776
|
+
# Find first callable ref whose return type matches and fill in arg types.
|
|
777
|
+
for callable_ref in self.egg_fn_to_callable_refs[name]:
|
|
778
|
+
signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
|
|
779
|
+
if not isinstance(signature, FunctionSignature):
|
|
780
|
+
continue
|
|
781
|
+
if signature.semantic_return_type.name != return_tp.name:
|
|
782
|
+
continue
|
|
783
|
+
tcs = TypeConstraintSolver(self.__egg_decls__)
|
|
784
|
+
|
|
785
|
+
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
786
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
args = tuple(
|
|
790
|
+
TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
call_decl = CallDecl(
|
|
794
|
+
callable_ref,
|
|
795
|
+
args,
|
|
796
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
797
|
+
# but dont need to store them
|
|
798
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
|
|
799
|
+
)
|
|
800
|
+
return PartialCallDecl(call_decl)
|
|
801
|
+
raise ValueError(f"Function '{name}' not found")
|
|
802
|
+
|
|
655
803
|
|
|
656
804
|
# https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
|
|
657
805
|
_EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
|
|
@@ -789,7 +937,7 @@ class FromEggState:
|
|
|
789
937
|
args,
|
|
790
938
|
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
791
939
|
# but dont need to store them
|
|
792
|
-
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else
|
|
940
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
|
|
793
941
|
)
|
|
794
942
|
raise ValueError(
|
|
795
943
|
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import TypeVar, cast
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
|
-
from egglog import EGraph
|
|
7
|
+
from egglog import EGraph, greedy_dag_cost_model
|
|
8
8
|
from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
|
|
9
9
|
from egglog.exp.array_api_numba import array_api_numba_schedule
|
|
10
10
|
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
|
|
@@ -41,7 +41,7 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
|
|
|
41
41
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
42
42
|
egraph.register(res)
|
|
43
43
|
egraph.run(array_api_numba_schedule)
|
|
44
|
-
res_optimized = egraph.extract(res)
|
|
44
|
+
res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())
|
|
45
45
|
|
|
46
46
|
return (
|
|
47
47
|
egraph,
|
egglog/pretty.py
CHANGED
|
@@ -183,7 +183,7 @@ class TraverseContext:
|
|
|
183
183
|
if isinstance(de, DefaultRewriteDecl):
|
|
184
184
|
continue
|
|
185
185
|
self(de)
|
|
186
|
-
case CallDecl(ref, exprs, _):
|
|
186
|
+
case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs):
|
|
187
187
|
match ref:
|
|
188
188
|
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
189
189
|
self(res.expr)
|
|
@@ -205,12 +205,13 @@ class TraverseContext:
|
|
|
205
205
|
case SetCostDecl(_, e, c):
|
|
206
206
|
self(e)
|
|
207
207
|
self(c)
|
|
208
|
-
case BackOffDecl():
|
|
208
|
+
case BackOffDecl() | ValueDecl():
|
|
209
209
|
pass
|
|
210
210
|
case LetSchedulerDecl(scheduler, schedule):
|
|
211
211
|
self(scheduler)
|
|
212
212
|
self(schedule)
|
|
213
|
-
|
|
213
|
+
case GetCostDecl(ref, args):
|
|
214
|
+
self(CallDecl(ref, args))
|
|
214
215
|
case _:
|
|
215
216
|
assert_never(decl)
|
|
216
217
|
|
|
@@ -353,6 +354,10 @@ class PrettyContext:
|
|
|
353
354
|
if ban_length is not None:
|
|
354
355
|
list_args.append(f"ban_length={ban_length}")
|
|
355
356
|
return f"back_off({', '.join(list_args)})", "scheduler"
|
|
357
|
+
case ValueDecl(value):
|
|
358
|
+
return str(value), "value"
|
|
359
|
+
case GetCostDecl(ref, args):
|
|
360
|
+
return f"get_cost({self(CallDecl(ref, args))})", "get_cost"
|
|
356
361
|
assert_never(decl)
|
|
357
362
|
|
|
358
363
|
def _call(
|
egglog/runtime.py
CHANGED
|
@@ -457,7 +457,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
457
457
|
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
|
|
458
458
|
return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
|
|
459
459
|
bound_params = (
|
|
460
|
-
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else
|
|
460
|
+
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
|
|
461
461
|
)
|
|
462
462
|
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
463
463
|
if function_value:
|
|
@@ -584,11 +584,17 @@ class RuntimeExpr(DelayedDeclerations):
|
|
|
584
584
|
if (method := _get_expr_method(self, "__eq__")) is not None:
|
|
585
585
|
return method(other)
|
|
586
586
|
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
587
|
+
if not (isinstance(self, RuntimeExpr) and isinstance(other, RuntimeExpr)):
|
|
588
|
+
return NotImplemented
|
|
589
|
+
if self.__egg_typed_expr__.tp != other.__egg_typed_expr__.tp:
|
|
590
|
+
return NotImplemented
|
|
590
591
|
|
|
591
|
-
|
|
592
|
+
from .egraph import Fact # noqa: PLC0415
|
|
593
|
+
|
|
594
|
+
return Fact(
|
|
595
|
+
Declarations.create(self, other),
|
|
596
|
+
EqDecl(self.__egg_typed_expr__.tp, self.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr),
|
|
597
|
+
)
|
|
592
598
|
|
|
593
599
|
def __ne__(self, other: object) -> object: # type: ignore[override]
|
|
594
600
|
if (method := _get_expr_method(self, "__ne__")) is not None:
|
egglog/type_constraint_solver.py
CHANGED
|
@@ -54,7 +54,7 @@ class TypeConstraintSolver:
|
|
|
54
54
|
fn_var_args: TypeOrVarRef | None,
|
|
55
55
|
return_: JustTypeRef,
|
|
56
56
|
cls_name: str | None,
|
|
57
|
-
) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]
|
|
57
|
+
) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]:
|
|
58
58
|
"""
|
|
59
59
|
Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
|
|
60
60
|
|
|
@@ -75,7 +75,7 @@ class TypeConstraintSolver:
|
|
|
75
75
|
)
|
|
76
76
|
)
|
|
77
77
|
if cls_name
|
|
78
|
-
else
|
|
78
|
+
else ()
|
|
79
79
|
)
|
|
80
80
|
return arg_types, bound_typevars
|
|
81
81
|
|
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
egglog-11.
|
|
2
|
-
egglog-11.
|
|
3
|
-
egglog-11.
|
|
1
|
+
egglog-11.4.0.dist-info/METADATA,sha256=QIRn-wS4KchOL4lLxMhdsJKD8Z-7ikr467ysgLww_jM,4554
|
|
2
|
+
egglog-11.4.0.dist-info/WHEEL,sha256=PWlsTwJu6g7mtdc3pNbFOLMPvwrDUqswnFv2cL23vxs,127
|
|
3
|
+
egglog-11.4.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
|
|
4
4
|
egglog/__init__.py,sha256=0r3MzQbU-9U0fSCeAoJ3deVhZ77tI-1tf8A_WFOhbJs,344
|
|
5
|
-
egglog/bindings.cpython-310-powerpc64-linux-gnu.so,sha256
|
|
6
|
-
egglog/bindings.pyi,sha256=
|
|
7
|
-
egglog/builtins.py,sha256=
|
|
5
|
+
egglog/bindings.cpython-310-powerpc64-linux-gnu.so,sha256=e6C3pARiYc6eBFFK0rekPQKjzESbRXUb3HLz6Dq2L9Q,174573736
|
|
6
|
+
egglog/bindings.pyi,sha256=ntbf2xtpeiabQzUUzx0EdKzu0COXh7sX0_qaUdtB0BY,17874
|
|
7
|
+
egglog/builtins.py,sha256=OSK-JUCKDhpacwPhezPUI9KEru_XIcnA9u9drEyuJi4,30512
|
|
8
8
|
egglog/config.py,sha256=yM3FIcVCKnhWZmHD0pxkzx7ah7t5PxZx3WUqKtA9tjU,168
|
|
9
9
|
egglog/conversion.py,sha256=DO76lxRbbTqHs6hRo_Lckvtwu0c6LaKoX7k5_B2AfuY,11238
|
|
10
|
-
egglog/declarations.py,sha256=
|
|
11
|
-
egglog/deconstruct.py,sha256=
|
|
12
|
-
egglog/egraph.py,sha256=
|
|
13
|
-
egglog/egraph_state.py,sha256=
|
|
10
|
+
egglog/declarations.py,sha256=9gFcg0mc86JuJ4DDE_8qm8d324JiTJZxyLMnx59P3qs,26521
|
|
11
|
+
egglog/deconstruct.py,sha256=b7D5uCaLII-JtlMWcOK8s6LWB8nJe2N879iOIio3-ak,5455
|
|
12
|
+
egglog/egraph.py,sha256=nqL6idjBkURnL3gJ7V-EwhyM5pmfo9-CG19TUO6b9f8,77397
|
|
13
|
+
egglog/egraph_state.py,sha256=UOjcY242GZ0E-4lDt_taKihgxdVgVFX4YCWaDW0_AQs,43394
|
|
14
14
|
egglog/examples/README.rst,sha256=ztTvpofR0eotSqGoCy_C1fPLDPCncjvcqDanXtLHNNU,232
|
|
15
15
|
egglog/examples/__init__.py,sha256=wm9evUbMPfbtylXIjbDdRTAVMLH4OjT4Z77PCBFyaPU,31
|
|
16
16
|
egglog/examples/bignum.py,sha256=jfL57XXpQqIqizQQ3sSUCCjTrkdjtB71BmjrQIQorQk,535
|
|
@@ -27,20 +27,20 @@ egglog/examples/resolution.py,sha256=BJd5JClA3DBVGfiVRa-H0gbbFvIqeP3uYbhCXHblSQc
|
|
|
27
27
|
egglog/examples/schedule_demo.py,sha256=JbXdPII7_adxtgyKVAiqCyV2sj88VZ-DhomYrdn8vuc,618
|
|
28
28
|
egglog/exp/__init__.py,sha256=nPtzrH1bz1LVZhZCuS0S9Qild8m5gEikjOVqWAFIa88,49
|
|
29
29
|
egglog/exp/array_api.py,sha256=dKgEufUIyoT7J_RvnyGtOkg_DK25ZnxIgt7olVygaH8,65547
|
|
30
|
-
egglog/exp/array_api_jit.py,sha256=
|
|
30
|
+
egglog/exp/array_api_jit.py,sha256=S5XGWT8a_bFNHUXbQZi6U2cmy__xjvDMNl1eUgvRDyw,1739
|
|
31
31
|
egglog/exp/array_api_loopnest.py,sha256=-kbyorlGxvlaNsLx1nmLfEZHQM7VMEBwSKtV0l-bs0g,2444
|
|
32
32
|
egglog/exp/array_api_numba.py,sha256=X3H1TnCjPL92uVm6OvcWMJ11IeorAE58zWiOX6huPv4,2696
|
|
33
33
|
egglog/exp/array_api_program_gen.py,sha256=qnve8iqklRQVyGChllG8ZAjAffRpezmdxc3IdaWytoQ,21779
|
|
34
34
|
egglog/exp/program_gen.py,sha256=CavsD70x0ERS87V4OU9zkgMvLXswGEpb1ZZFK0WyN_g,13033
|
|
35
35
|
egglog/exp/siu_examples.py,sha256=yZ-sgH2Y12iTdwBUumP7D2OtCGL83M6pPW7PMobVFXc,719
|
|
36
36
|
egglog/ipython_magic.py,sha256=2hs3g2cSiyDmbCvE2t1OINmu17Bb8MWV--2DpEWwO7I,1189
|
|
37
|
-
egglog/pretty.py,sha256=
|
|
37
|
+
egglog/pretty.py,sha256=gVq3zHkvyCeVZ_u1uqxICy_2srD1-oRzpP5WqFiGFmY,22378
|
|
38
38
|
egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
|
-
egglog/runtime.py,sha256=
|
|
39
|
+
egglog/runtime.py,sha256=VatMZWUqMbo_tc6RkdsEhnSuEZbUbropiYZdppn7S4Y,29805
|
|
40
40
|
egglog/thunk.py,sha256=MrAlPoGK36VQrUrq8PWSaJFu42sPL0yupwiH18lNips,2271
|
|
41
|
-
egglog/type_constraint_solver.py,sha256=
|
|
41
|
+
egglog/type_constraint_solver.py,sha256=jivNkqjRTm38miaoxQoUzTntZbx1yYJO7FhuZWik3lg,4509
|
|
42
42
|
egglog/version_compat.py,sha256=EaKRMIOPcatrx9XjCofxZD6Nr5WOooiWNdoapkKleww,3512
|
|
43
43
|
egglog/visualizer.css,sha256=eL0POoThQRc0P4OYnDT-d808ln9O5Qy6DizH9Z5LgWc,259398
|
|
44
44
|
egglog/visualizer.js,sha256=2qZZ-9W_INJx4gZMYjnVXl27IjT_JNuQyEeI2dbjWoU,3753315
|
|
45
45
|
egglog/visualizer_widget.py,sha256=LtVfzOtv2WeKtNuILQQ_9SOHWvRr8YdBYQDKQSgry_s,1319
|
|
46
|
-
egglog-11.
|
|
46
|
+
egglog-11.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|