egglog 11.2.0__cp310-cp310-win_amd64.whl → 11.4.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +62 -1
- egglog/builtins.py +10 -0
- egglog/declarations.py +39 -4
- egglog/deconstruct.py +9 -7
- egglog/egraph.py +360 -26
- egglog/egraph_state.py +283 -12
- egglog/examples/jointree.py +0 -3
- egglog/exp/array_api_jit.py +2 -2
- egglog/pretty.py +38 -8
- egglog/runtime.py +22 -7
- egglog/type_constraint_solver.py +2 -2
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/METADATA +19 -1
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/RECORD +16 -16
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/WHEEL +0 -0
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/licenses/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
ClassVar,
|
|
17
17
|
Generic,
|
|
18
18
|
Literal,
|
|
19
|
+
Protocol,
|
|
19
20
|
TypeAlias,
|
|
20
21
|
TypedDict,
|
|
21
22
|
TypeVar,
|
|
@@ -23,6 +24,7 @@ from typing import (
|
|
|
23
24
|
get_type_hints,
|
|
24
25
|
overload,
|
|
25
26
|
)
|
|
27
|
+
from uuid import uuid4
|
|
26
28
|
from warnings import warn
|
|
27
29
|
|
|
28
30
|
import graphviz
|
|
@@ -40,20 +42,24 @@ from .thunk import *
|
|
|
40
42
|
from .version_compat import *
|
|
41
43
|
|
|
42
44
|
if TYPE_CHECKING:
|
|
43
|
-
from .builtins import String, Unit, i64Like
|
|
45
|
+
from .builtins import String, Unit, i64, i64Like
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
__all__ = [
|
|
47
49
|
"Action",
|
|
50
|
+
"BackOff",
|
|
48
51
|
"BaseExpr",
|
|
49
52
|
"BuiltinExpr",
|
|
50
53
|
"Command",
|
|
51
54
|
"Command",
|
|
55
|
+
"CostModel",
|
|
52
56
|
"EGraph",
|
|
53
57
|
"Expr",
|
|
58
|
+
"ExprCallable",
|
|
54
59
|
"Fact",
|
|
55
60
|
"Fact",
|
|
56
61
|
"GraphvizKwargs",
|
|
62
|
+
"GreedyDagCost",
|
|
57
63
|
"RewriteOrRule",
|
|
58
64
|
"Ruleset",
|
|
59
65
|
"Schedule",
|
|
@@ -63,16 +69,20 @@ __all__ = [
|
|
|
63
69
|
"_RewriteBuilder",
|
|
64
70
|
"_SetBuilder",
|
|
65
71
|
"_UnionBuilder",
|
|
72
|
+
"back_off",
|
|
66
73
|
"birewrite",
|
|
67
74
|
"check",
|
|
68
75
|
"check_eq",
|
|
69
76
|
"constant",
|
|
77
|
+
"default_cost_model",
|
|
70
78
|
"delete",
|
|
71
79
|
"eq",
|
|
72
80
|
"expr_action",
|
|
73
81
|
"expr_fact",
|
|
74
82
|
"expr_parts",
|
|
75
83
|
"function",
|
|
84
|
+
"get_cost",
|
|
85
|
+
"greedy_dag_cost_model",
|
|
76
86
|
"let",
|
|
77
87
|
"method",
|
|
78
88
|
"ne",
|
|
@@ -85,6 +95,7 @@ __all__ = [
|
|
|
85
95
|
"seq",
|
|
86
96
|
"set_",
|
|
87
97
|
"set_cost",
|
|
98
|
+
"set_current_ruleset",
|
|
88
99
|
"subsume",
|
|
89
100
|
"union",
|
|
90
101
|
"unstable_combine_rulesets",
|
|
@@ -449,7 +460,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
|
|
|
449
460
|
continue
|
|
450
461
|
locals = frame.f_locals
|
|
451
462
|
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
452
|
-
# TODO: Store deprecated message so we can
|
|
463
|
+
# TODO: Store deprecated message so we can get at runtime
|
|
453
464
|
if (getattr(fn, "__deprecated__", None)) is not None:
|
|
454
465
|
fn = fn.__wrapped__ # type: ignore[attr-defined]
|
|
455
466
|
match fn:
|
|
@@ -905,8 +916,8 @@ class EGraph:
|
|
|
905
916
|
|
|
906
917
|
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
|
|
907
918
|
self._add_decls(schedule)
|
|
908
|
-
|
|
909
|
-
(command_output,) = self._egraph.run_program(
|
|
919
|
+
cmd = self._state.run_schedule_to_egg(schedule.schedule)
|
|
920
|
+
(command_output,) = self._egraph.run_program(cmd)
|
|
910
921
|
assert isinstance(command_output, bindings.RunScheduleOutput)
|
|
911
922
|
return command_output.report
|
|
912
923
|
|
|
@@ -950,22 +961,45 @@ class EGraph:
|
|
|
950
961
|
return bindings.Check(span(2), egg_facts)
|
|
951
962
|
|
|
952
963
|
@overload
|
|
953
|
-
def extract(
|
|
964
|
+
def extract(
|
|
965
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None
|
|
966
|
+
) -> BASE_EXPR: ...
|
|
954
967
|
|
|
955
968
|
@overload
|
|
956
|
-
def extract(
|
|
969
|
+
def extract(
|
|
970
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: None = None
|
|
971
|
+
) -> tuple[BASE_EXPR, int]: ...
|
|
957
972
|
|
|
958
|
-
|
|
973
|
+
@overload
|
|
974
|
+
def extract(
|
|
975
|
+
self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST]
|
|
976
|
+
) -> tuple[BASE_EXPR, COST]: ...
|
|
977
|
+
|
|
978
|
+
def extract(
|
|
979
|
+
self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None
|
|
980
|
+
) -> BASE_EXPR | tuple[BASE_EXPR, COST]:
|
|
959
981
|
"""
|
|
960
982
|
Extract the lowest cost expression from the egraph.
|
|
961
983
|
"""
|
|
962
984
|
runtime_expr = to_runtime_expr(expr)
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
985
|
+
self._add_decls(runtime_expr)
|
|
986
|
+
tp = runtime_expr.__egg_typed_expr__.tp
|
|
987
|
+
if cost_model is None:
|
|
988
|
+
extract_report = self._run_extract(runtime_expr, 0)
|
|
989
|
+
assert isinstance(extract_report, bindings.ExtractBest)
|
|
990
|
+
res = self._from_termdag(extract_report.termdag, extract_report.term, tp)
|
|
991
|
+
cost = cast("COST", extract_report.cost)
|
|
992
|
+
else:
|
|
993
|
+
# TODO: For some reason we need this or else it wont be registered. Not sure why
|
|
994
|
+
self.register(expr)
|
|
995
|
+
egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model()
|
|
996
|
+
egg_sort = self._state.type_ref_to_egg(tp)
|
|
997
|
+
extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model)
|
|
998
|
+
termdag = bindings.TermDag()
|
|
999
|
+
value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__)
|
|
1000
|
+
cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort)
|
|
1001
|
+
res = self._from_termdag(termdag, term, tp)
|
|
1002
|
+
return (res, cost) if include_cost else res
|
|
969
1003
|
|
|
970
1004
|
def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
|
|
971
1005
|
(new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
|
|
@@ -976,6 +1010,7 @@ class EGraph:
|
|
|
976
1010
|
Extract multiple expressions from the egraph.
|
|
977
1011
|
"""
|
|
978
1012
|
runtime_expr = to_runtime_expr(expr)
|
|
1013
|
+
self._add_decls(runtime_expr)
|
|
979
1014
|
extract_report = self._run_extract(runtime_expr, n)
|
|
980
1015
|
assert isinstance(extract_report, bindings.ExtractVariants)
|
|
981
1016
|
new_exprs = self._state.exprs_from_egg(
|
|
@@ -984,7 +1019,6 @@ class EGraph:
|
|
|
984
1019
|
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
985
1020
|
|
|
986
1021
|
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
|
|
987
|
-
self._add_decls(expr)
|
|
988
1022
|
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
|
|
989
1023
|
# If we have defined any cost tables use the custom extraction
|
|
990
1024
|
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
|
|
@@ -1033,7 +1067,7 @@ class EGraph:
|
|
|
1033
1067
|
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
|
|
1034
1068
|
split_functions = kwargs.pop("split_functions", [])
|
|
1035
1069
|
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
|
|
1036
|
-
n_inline_leaves = kwargs.pop("n_inline_leaves",
|
|
1070
|
+
n_inline_leaves = kwargs.pop("n_inline_leaves", 0)
|
|
1037
1071
|
serialized = self._egraph.serialize(
|
|
1038
1072
|
[],
|
|
1039
1073
|
max_functions=max_functions,
|
|
@@ -1209,16 +1243,12 @@ class EGraph:
|
|
|
1209
1243
|
"""
|
|
1210
1244
|
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
|
|
1211
1245
|
assert isinstance(output, bindings.PrintAllFunctionsSize)
|
|
1246
|
+
return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
|
|
1247
|
+
|
|
1248
|
+
def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
|
|
1212
1249
|
return [
|
|
1213
|
-
(
|
|
1214
|
-
|
|
1215
|
-
"ExprCallable",
|
|
1216
|
-
create_callable(self._state.__egg_decls__, next(iter(refs))),
|
|
1217
|
-
),
|
|
1218
|
-
size,
|
|
1219
|
-
)
|
|
1220
|
-
for (name, size) in output.sizes
|
|
1221
|
-
if (refs := self._state.egg_fn_to_callable_refs[name])
|
|
1250
|
+
cast("ExprCallable", create_callable(self._state.__egg_decls__, ref))
|
|
1251
|
+
for ref in self._state.egg_fn_to_callable_refs[egg_fn]
|
|
1222
1252
|
]
|
|
1223
1253
|
|
|
1224
1254
|
def function_values(
|
|
@@ -1242,6 +1272,33 @@ class EGraph:
|
|
|
1242
1272
|
for (call, res) in output.terms
|
|
1243
1273
|
}
|
|
1244
1274
|
|
|
1275
|
+
def lookup_function_value(self, expr: BASE_EXPR) -> BASE_EXPR | None:
|
|
1276
|
+
"""
|
|
1277
|
+
Given an expression that is a function call, looks up the value of the function call if it exists.
|
|
1278
|
+
"""
|
|
1279
|
+
runtime_expr = to_runtime_expr(expr)
|
|
1280
|
+
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1281
|
+
assert isinstance(typed_expr.expr, CallDecl | GetCostDecl)
|
|
1282
|
+
egg_fn, typed_args = self._state.translate_call(typed_expr.expr)
|
|
1283
|
+
values_args = [self._state.typed_expr_to_value(a) for a in typed_args]
|
|
1284
|
+
possible_value = self._egraph.lookup_function(egg_fn, values_args)
|
|
1285
|
+
if possible_value is None:
|
|
1286
|
+
return None
|
|
1287
|
+
return cast(
|
|
1288
|
+
"BASE_EXPR",
|
|
1289
|
+
RuntimeExpr.__from_values__(
|
|
1290
|
+
self.__egg_decls__,
|
|
1291
|
+
TypedExprDecl(typed_expr.tp, self._state.value_to_expr(typed_expr.tp, possible_value)),
|
|
1292
|
+
),
|
|
1293
|
+
)
|
|
1294
|
+
|
|
1295
|
+
def has_custom_cost(self, fn: ExprCallable) -> bool:
|
|
1296
|
+
"""
|
|
1297
|
+
Checks if the any custom costs have been set for this expression callable.
|
|
1298
|
+
"""
|
|
1299
|
+
resolved, _ = resolve_callable(fn)
|
|
1300
|
+
return resolved in self._state.cost_callables
|
|
1301
|
+
|
|
1245
1302
|
|
|
1246
1303
|
# Either a constant or a function.
|
|
1247
1304
|
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
|
|
@@ -1786,17 +1843,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr:
|
|
|
1786
1843
|
return expr
|
|
1787
1844
|
|
|
1788
1845
|
|
|
1789
|
-
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
|
|
1846
|
+
def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule:
|
|
1790
1847
|
"""
|
|
1791
1848
|
Create a run configuration.
|
|
1792
1849
|
"""
|
|
1793
1850
|
facts = _fact_likes(until)
|
|
1794
1851
|
return Schedule(
|
|
1795
1852
|
Thunk.fn(Declarations.create, ruleset, *facts),
|
|
1796
|
-
RunDecl(
|
|
1853
|
+
RunDecl(
|
|
1854
|
+
ruleset.__egg_name__ if ruleset else "",
|
|
1855
|
+
tuple(f.fact for f in facts) or None,
|
|
1856
|
+
scheduler.scheduler if scheduler else None,
|
|
1857
|
+
),
|
|
1797
1858
|
)
|
|
1798
1859
|
|
|
1799
1860
|
|
|
1861
|
+
def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff:
|
|
1862
|
+
"""
|
|
1863
|
+
Create a backoff scheduler configuration.
|
|
1864
|
+
|
|
1865
|
+
```python
|
|
1866
|
+
schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
|
|
1867
|
+
```
|
|
1868
|
+
This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler.
|
|
1869
|
+
"""
|
|
1870
|
+
return BackOff(BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length))
|
|
1871
|
+
|
|
1872
|
+
|
|
1873
|
+
@dataclass(frozen=True)
|
|
1874
|
+
class BackOff:
|
|
1875
|
+
scheduler: BackOffDecl
|
|
1876
|
+
|
|
1877
|
+
def scope(self, schedule: Schedule) -> Schedule:
|
|
1878
|
+
"""
|
|
1879
|
+
Defines the scheduler to be created directly before the inner schedule, instead of the default which is at the
|
|
1880
|
+
most outer scope.
|
|
1881
|
+
"""
|
|
1882
|
+
return Schedule(schedule.__egg_decls_thunk__, LetSchedulerDecl(self.scheduler, schedule.schedule))
|
|
1883
|
+
|
|
1884
|
+
def __str__(self) -> str:
|
|
1885
|
+
return pretty_decl(Declarations(), self.scheduler)
|
|
1886
|
+
|
|
1887
|
+
def __repr__(self) -> str:
|
|
1888
|
+
return str(self)
|
|
1889
|
+
|
|
1890
|
+
|
|
1800
1891
|
def seq(*schedules: Schedule) -> Schedule:
|
|
1801
1892
|
"""
|
|
1802
1893
|
Run a sequence of schedules.
|
|
@@ -1873,3 +1964,246 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
|
1873
1964
|
yield
|
|
1874
1965
|
finally:
|
|
1875
1966
|
_CURRENT_RULESET.reset(token)
|
|
1967
|
+
|
|
1968
|
+
|
|
1969
|
+
def get_cost(expr: BaseExpr) -> i64:
|
|
1970
|
+
"""
|
|
1971
|
+
Return a lookup of the cost of an expression. If not set, won't match.
|
|
1972
|
+
"""
|
|
1973
|
+
assert isinstance(expr, RuntimeExpr)
|
|
1974
|
+
expr_decl = expr.__egg_typed_expr__.expr
|
|
1975
|
+
if not isinstance(expr_decl, CallDecl):
|
|
1976
|
+
msg = "Can only get cost of function calls, not literals or variables"
|
|
1977
|
+
raise TypeError(msg)
|
|
1978
|
+
return RuntimeExpr.__from_values__(
|
|
1979
|
+
expr.__egg_decls__,
|
|
1980
|
+
TypedExprDecl(JustTypeRef("i64"), GetCostDecl(expr_decl.callable, expr_decl.args)),
|
|
1981
|
+
)
|
|
1982
|
+
|
|
1983
|
+
|
|
1984
|
+
class Comparable(Protocol):
|
|
1985
|
+
def __lt__(self, other: Self) -> bool: ...
|
|
1986
|
+
def __le__(self, other: Self) -> bool: ...
|
|
1987
|
+
def __gt__(self, other: Self) -> bool: ...
|
|
1988
|
+
def __ge__(self, other: Self) -> bool: ...
|
|
1989
|
+
|
|
1990
|
+
|
|
1991
|
+
COST = TypeVar("COST", bound=Comparable)
|
|
1992
|
+
|
|
1993
|
+
|
|
1994
|
+
class CostModel(Protocol, Generic[COST]):
|
|
1995
|
+
"""
|
|
1996
|
+
A cost model for an e-graph. Used to determine the cost of an expression based on its structure and the costs of its sub-expressions.
|
|
1997
|
+
|
|
1998
|
+
Called with an expression and the costs of its children, returns the total cost of the expression.
|
|
1999
|
+
|
|
2000
|
+
Additionally, the cost model should guarantee that a term has a no-smaller cost
|
|
2001
|
+
than its subterms to avoid cycles in the extracted terms for common case usages.
|
|
2002
|
+
For more niche usages, a term can have a cost less than its subterms.
|
|
2003
|
+
As long as there is no negative cost cycle, the default extractor is guaranteed to terminate in computing the costs.
|
|
2004
|
+
However, the user needs to be careful to guarantee acyclicity in the extracted terms.
|
|
2005
|
+
"""
|
|
2006
|
+
|
|
2007
|
+
def __call__(self, egraph: EGraph, expr: BaseExpr, children_costs: list[COST]) -> COST:
|
|
2008
|
+
"""
|
|
2009
|
+
The total cost of a term given the cost of the root e-node and its immediate children's total costs.
|
|
2010
|
+
"""
|
|
2011
|
+
raise NotImplementedError
|
|
2012
|
+
|
|
2013
|
+
|
|
2014
|
+
def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int:
|
|
2015
|
+
"""
|
|
2016
|
+
A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
|
|
2017
|
+
"""
|
|
2018
|
+
from .builtins import Container # noqa: PLC0415
|
|
2019
|
+
from .deconstruct import get_callable_fn # noqa: PLC0415
|
|
2020
|
+
|
|
2021
|
+
# 1. First prefer if the expr has a custom cost set on it
|
|
2022
|
+
if (
|
|
2023
|
+
(callable_fn := get_callable_fn(expr)) is not None
|
|
2024
|
+
and egraph.has_custom_cost(callable_fn)
|
|
2025
|
+
and (i := egraph.lookup_function_value(get_cost(expr))) is not None
|
|
2026
|
+
):
|
|
2027
|
+
self_cost = int(i)
|
|
2028
|
+
# 2. Else, check if this is a callable and it has a cost set on its declaration
|
|
2029
|
+
elif callable_fn is not None and (callable_cost := get_callable_cost(callable_fn)) is not None:
|
|
2030
|
+
self_cost = callable_cost
|
|
2031
|
+
# 3. Else, if this is a container, it has no cost, otherwise it has a cost of 1
|
|
2032
|
+
else:
|
|
2033
|
+
# By default, all nodes have a cost of 1 except for containers which have a cost of 0
|
|
2034
|
+
self_cost = 0 if isinstance(expr, Container) else 1
|
|
2035
|
+
# Sum up the costs of the children and our own cost
|
|
2036
|
+
return sum(children_costs, start=self_cost)
|
|
2037
|
+
|
|
2038
|
+
|
|
2039
|
+
class ComparableAddSub(Comparable, Protocol):
|
|
2040
|
+
def __add__(self, other: Self) -> Self: ...
|
|
2041
|
+
def __sub__(self, other: Self) -> Self: ...
|
|
2042
|
+
|
|
2043
|
+
|
|
2044
|
+
DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub)
|
|
2045
|
+
|
|
2046
|
+
|
|
2047
|
+
@dataclass
|
|
2048
|
+
class GreedyDagCost(Generic[DAG_COST]):
|
|
2049
|
+
"""
|
|
2050
|
+
Cost of a DAG, which stores children costs. Use `.total` to get the underlying cost.
|
|
2051
|
+
"""
|
|
2052
|
+
|
|
2053
|
+
total: DAG_COST
|
|
2054
|
+
_costs: dict[TypedExprDecl, DAG_COST] = field(repr=False)
|
|
2055
|
+
|
|
2056
|
+
def __eq__(self, other: object) -> bool:
|
|
2057
|
+
if not isinstance(other, GreedyDagCost):
|
|
2058
|
+
return NotImplemented
|
|
2059
|
+
return self.total == other.total
|
|
2060
|
+
|
|
2061
|
+
def __lt__(self, other: Self) -> bool:
|
|
2062
|
+
return self.total < other.total
|
|
2063
|
+
|
|
2064
|
+
def __le__(self, other: Self) -> bool:
|
|
2065
|
+
return self.total <= other.total
|
|
2066
|
+
|
|
2067
|
+
def __gt__(self, other: Self) -> bool:
|
|
2068
|
+
return self.total > other.total
|
|
2069
|
+
|
|
2070
|
+
def __ge__(self, other: Self) -> bool:
|
|
2071
|
+
return self.total >= other.total
|
|
2072
|
+
|
|
2073
|
+
def __hash__(self) -> int:
|
|
2074
|
+
return hash(self.total)
|
|
2075
|
+
|
|
2076
|
+
|
|
2077
|
+
@dataclass
|
|
2078
|
+
class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
|
|
2079
|
+
"""
|
|
2080
|
+
A cost model which will count duplicate nodes only once.
|
|
2081
|
+
|
|
2082
|
+
Should have similar behavior as https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs
|
|
2083
|
+
but implemented as a cost model that will be used with the default extractor.
|
|
2084
|
+
"""
|
|
2085
|
+
|
|
2086
|
+
base: CostModel[DAG_COST]
|
|
2087
|
+
|
|
2088
|
+
def __call__(
|
|
2089
|
+
self, egraph: EGraph, expr: BaseExpr, children_costs: list[GreedyDagCost[DAG_COST]]
|
|
2090
|
+
) -> GreedyDagCost[DAG_COST]:
|
|
2091
|
+
cost = self.base(egraph, expr, [c.total for c in children_costs])
|
|
2092
|
+
for c in children_costs:
|
|
2093
|
+
cost -= c.total
|
|
2094
|
+
costs = {}
|
|
2095
|
+
for c in children_costs:
|
|
2096
|
+
costs.update(c._costs)
|
|
2097
|
+
total = sum(costs.values(), start=cost)
|
|
2098
|
+
costs[to_runtime_expr(expr).__egg_typed_expr__] = cost
|
|
2099
|
+
return GreedyDagCost(total, costs)
|
|
2100
|
+
|
|
2101
|
+
|
|
2102
|
+
@overload
|
|
2103
|
+
def greedy_dag_cost_model() -> CostModel[GreedyDagCost[int]]: ...
|
|
2104
|
+
|
|
2105
|
+
|
|
2106
|
+
@overload
|
|
2107
|
+
def greedy_dag_cost_model(base: CostModel[DAG_COST]) -> CostModel[GreedyDagCost[DAG_COST]]: ...
|
|
2108
|
+
|
|
2109
|
+
|
|
2110
|
+
def greedy_dag_cost_model(base: CostModel[Any] = default_cost_model) -> CostModel[GreedyDagCost[Any]]:
|
|
2111
|
+
"""
|
|
2112
|
+
Creates a greedy dag cost model from a base cost model.
|
|
2113
|
+
"""
|
|
2114
|
+
return GreedyDagCostModel(base or default_cost_model)
|
|
2115
|
+
|
|
2116
|
+
|
|
2117
|
+
def get_callable_cost(fn: ExprCallable) -> int | None:
|
|
2118
|
+
"""
|
|
2119
|
+
Returns the cost of a callable, if it has one set. Otherwise returns None.
|
|
2120
|
+
"""
|
|
2121
|
+
callable_ref, decls = resolve_callable(fn)
|
|
2122
|
+
callable_decl = decls.get_callable_decl(callable_ref)
|
|
2123
|
+
return callable_decl.cost if isinstance(callable_decl, ConstructorDecl) else 1
|
|
2124
|
+
|
|
2125
|
+
|
|
2126
|
+
@dataclass
|
|
2127
|
+
class _CostModel(Generic[COST]):
|
|
2128
|
+
"""
|
|
2129
|
+
Implements the methods compatible with the bindings for the cost model.
|
|
2130
|
+
"""
|
|
2131
|
+
|
|
2132
|
+
model: CostModel[COST]
|
|
2133
|
+
egraph: EGraph
|
|
2134
|
+
enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict)
|
|
2135
|
+
enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list)
|
|
2136
|
+
fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict)
|
|
2137
|
+
base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict)
|
|
2138
|
+
container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict)
|
|
2139
|
+
|
|
2140
|
+
def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST:
|
|
2141
|
+
return self.model(self.egraph, cast("BaseExpr", expr), children_costs)
|
|
2142
|
+
# if __debug__:
|
|
2143
|
+
# for c in children_costs:
|
|
2144
|
+
# if res <= c:
|
|
2145
|
+
# msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}"
|
|
2146
|
+
# raise ValueError(msg)
|
|
2147
|
+
|
|
2148
|
+
def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST:
|
|
2149
|
+
try:
|
|
2150
|
+
return self.fold_results[(index, tuple(children_costs))]
|
|
2151
|
+
except KeyError:
|
|
2152
|
+
pass
|
|
2153
|
+
|
|
2154
|
+
expr = self.enode_cost_expressions[index]
|
|
2155
|
+
return self.call_model(expr, children_costs)
|
|
2156
|
+
|
|
2157
|
+
# enode cost is only ever called right before fold, for the head_cost
|
|
2158
|
+
def enode_cost(self, name: str, args: list[bindings.Value]) -> int:
|
|
2159
|
+
try:
|
|
2160
|
+
return self.enode_cost_results[(name, tuple(args))]
|
|
2161
|
+
except KeyError:
|
|
2162
|
+
pass
|
|
2163
|
+
(callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name]
|
|
2164
|
+
signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature
|
|
2165
|
+
assert isinstance(signature, FunctionSignature)
|
|
2166
|
+
arg_exprs = [
|
|
2167
|
+
TypedExprDecl(tp.to_just(), self.egraph._state.value_to_expr(tp.to_just(), arg))
|
|
2168
|
+
for (arg, tp) in zip(args, signature.arg_types, strict=True)
|
|
2169
|
+
]
|
|
2170
|
+
res_type = signature.semantic_return_type.to_just()
|
|
2171
|
+
res = RuntimeExpr.__from_values__(
|
|
2172
|
+
self.egraph.__egg_decls__,
|
|
2173
|
+
TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))),
|
|
2174
|
+
)
|
|
2175
|
+
index = len(self.enode_cost_expressions)
|
|
2176
|
+
self.enode_cost_expressions.append(res)
|
|
2177
|
+
self.enode_cost_results[(name, tuple(args))] = index
|
|
2178
|
+
return index
|
|
2179
|
+
|
|
2180
|
+
def base_value_cost(self, tp: str, value: bindings.Value) -> COST:
|
|
2181
|
+
try:
|
|
2182
|
+
return self.base_value_cost_results[(tp, value)]
|
|
2183
|
+
except KeyError:
|
|
2184
|
+
pass
|
|
2185
|
+
type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
|
|
2186
|
+
expr = RuntimeExpr.__from_values__(
|
|
2187
|
+
self.egraph.__egg_decls__,
|
|
2188
|
+
TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
|
|
2189
|
+
)
|
|
2190
|
+
res = self.call_model(expr, [])
|
|
2191
|
+
self.base_value_cost_results[(tp, value)] = res
|
|
2192
|
+
return res
|
|
2193
|
+
|
|
2194
|
+
def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST:
|
|
2195
|
+
try:
|
|
2196
|
+
return self.container_cost_results[(tp, value, tuple(element_costs))]
|
|
2197
|
+
except KeyError:
|
|
2198
|
+
pass
|
|
2199
|
+
type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
|
|
2200
|
+
expr = RuntimeExpr.__from_values__(
|
|
2201
|
+
self.egraph.__egg_decls__,
|
|
2202
|
+
TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
|
|
2203
|
+
)
|
|
2204
|
+
res = self.call_model(expr, element_costs)
|
|
2205
|
+
self.container_cost_results[(tp, value, tuple(element_costs))] = res
|
|
2206
|
+
return res
|
|
2207
|
+
|
|
2208
|
+
def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]:
|
|
2209
|
+
return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)
|