egglog 11.2.0__cp311-cp311-win_amd64.whl → 11.4.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the registry page for this release for more details.

egglog/egraph.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
16
16
  ClassVar,
17
17
  Generic,
18
18
  Literal,
19
+ Protocol,
19
20
  TypeAlias,
20
21
  TypedDict,
21
22
  TypeVar,
@@ -23,6 +24,7 @@ from typing import (
23
24
  get_type_hints,
24
25
  overload,
25
26
  )
27
+ from uuid import uuid4
26
28
  from warnings import warn
27
29
 
28
30
  import graphviz
@@ -40,20 +42,24 @@ from .thunk import *
40
42
  from .version_compat import *
41
43
 
42
44
  if TYPE_CHECKING:
43
- from .builtins import String, Unit, i64Like
45
+ from .builtins import String, Unit, i64, i64Like
44
46
 
45
47
 
46
48
  __all__ = [
47
49
  "Action",
50
+ "BackOff",
48
51
  "BaseExpr",
49
52
  "BuiltinExpr",
50
53
  "Command",
51
54
  "Command",
55
+ "CostModel",
52
56
  "EGraph",
53
57
  "Expr",
58
+ "ExprCallable",
54
59
  "Fact",
55
60
  "Fact",
56
61
  "GraphvizKwargs",
62
+ "GreedyDagCost",
57
63
  "RewriteOrRule",
58
64
  "Ruleset",
59
65
  "Schedule",
@@ -63,16 +69,20 @@ __all__ = [
63
69
  "_RewriteBuilder",
64
70
  "_SetBuilder",
65
71
  "_UnionBuilder",
72
+ "back_off",
66
73
  "birewrite",
67
74
  "check",
68
75
  "check_eq",
69
76
  "constant",
77
+ "default_cost_model",
70
78
  "delete",
71
79
  "eq",
72
80
  "expr_action",
73
81
  "expr_fact",
74
82
  "expr_parts",
75
83
  "function",
84
+ "get_cost",
85
+ "greedy_dag_cost_model",
76
86
  "let",
77
87
  "method",
78
88
  "ne",
@@ -85,6 +95,7 @@ __all__ = [
85
95
  "seq",
86
96
  "set_",
87
97
  "set_cost",
98
+ "set_current_ruleset",
88
99
  "subsume",
89
100
  "union",
90
101
  "unstable_combine_rulesets",
@@ -449,7 +460,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
449
460
  continue
450
461
  locals = frame.f_locals
451
462
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
452
- # TODO: Store deprecated message so we can print at runtime
463
+ # TODO: Store deprecated message so we can get at runtime
453
464
  if (getattr(fn, "__deprecated__", None)) is not None:
454
465
  fn = fn.__wrapped__ # type: ignore[attr-defined]
455
466
  match fn:
@@ -905,8 +916,8 @@ class EGraph:
905
916
 
906
917
  def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
907
918
  self._add_decls(schedule)
908
- egg_schedule = self._state.schedule_to_egg(schedule.schedule)
909
- (command_output,) = self._egraph.run_program(bindings.RunSchedule(egg_schedule))
919
+ cmd = self._state.run_schedule_to_egg(schedule.schedule)
920
+ (command_output,) = self._egraph.run_program(cmd)
910
921
  assert isinstance(command_output, bindings.RunScheduleOutput)
911
922
  return command_output.report
912
923
 
@@ -950,22 +961,45 @@ class EGraph:
950
961
  return bindings.Check(span(2), egg_facts)
951
962
 
952
963
  @overload
953
- def extract(self, expr: BASE_EXPR, /, include_cost: Literal[False] = False) -> BASE_EXPR: ...
964
+ def extract(
965
+ self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None
966
+ ) -> BASE_EXPR: ...
954
967
 
955
968
  @overload
956
- def extract(self, expr: BASE_EXPR, /, include_cost: Literal[True]) -> tuple[BASE_EXPR, int]: ...
969
+ def extract(
970
+ self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: None = None
971
+ ) -> tuple[BASE_EXPR, int]: ...
957
972
 
958
- def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tuple[BASE_EXPR, int]:
973
+ @overload
974
+ def extract(
975
+ self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST]
976
+ ) -> tuple[BASE_EXPR, COST]: ...
977
+
978
+ def extract(
979
+ self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None
980
+ ) -> BASE_EXPR | tuple[BASE_EXPR, COST]:
959
981
  """
960
982
  Extract the lowest cost expression from the egraph.
961
983
  """
962
984
  runtime_expr = to_runtime_expr(expr)
963
- extract_report = self._run_extract(runtime_expr, 0)
964
- assert isinstance(extract_report, bindings.ExtractBest)
965
- res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp)
966
- if include_cost:
967
- return res, extract_report.cost
968
- return res
985
+ self._add_decls(runtime_expr)
986
+ tp = runtime_expr.__egg_typed_expr__.tp
987
+ if cost_model is None:
988
+ extract_report = self._run_extract(runtime_expr, 0)
989
+ assert isinstance(extract_report, bindings.ExtractBest)
990
+ res = self._from_termdag(extract_report.termdag, extract_report.term, tp)
991
+ cost = cast("COST", extract_report.cost)
992
+ else:
993
+ # TODO: For some reason we need this or else it wont be registered. Not sure why
994
+ self.register(expr)
995
+ egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model()
996
+ egg_sort = self._state.type_ref_to_egg(tp)
997
+ extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model)
998
+ termdag = bindings.TermDag()
999
+ value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__)
1000
+ cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort)
1001
+ res = self._from_termdag(termdag, term, tp)
1002
+ return (res, cost) if include_cost else res
969
1003
 
970
1004
  def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
971
1005
  (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
@@ -976,6 +1010,7 @@ class EGraph:
976
1010
  Extract multiple expressions from the egraph.
977
1011
  """
978
1012
  runtime_expr = to_runtime_expr(expr)
1013
+ self._add_decls(runtime_expr)
979
1014
  extract_report = self._run_extract(runtime_expr, n)
980
1015
  assert isinstance(extract_report, bindings.ExtractVariants)
981
1016
  new_exprs = self._state.exprs_from_egg(
@@ -984,7 +1019,6 @@ class EGraph:
984
1019
  return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
985
1020
 
986
1021
  def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
987
- self._add_decls(expr)
988
1022
  expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
989
1023
  # If we have defined any cost tables use the custom extraction
990
1024
  args = (expr, bindings.Lit(span(2), bindings.Int(n)))
@@ -1033,7 +1067,7 @@ class EGraph:
1033
1067
  split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
1034
1068
  split_functions = kwargs.pop("split_functions", [])
1035
1069
  include_temporary_functions = kwargs.pop("include_temporary_functions", False)
1036
- n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
1070
+ n_inline_leaves = kwargs.pop("n_inline_leaves", 0)
1037
1071
  serialized = self._egraph.serialize(
1038
1072
  [],
1039
1073
  max_functions=max_functions,
@@ -1209,16 +1243,12 @@ class EGraph:
1209
1243
  """
1210
1244
  (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
1211
1245
  assert isinstance(output, bindings.PrintAllFunctionsSize)
1246
+ return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
1247
+
1248
+ def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
1212
1249
  return [
1213
- (
1214
- cast(
1215
- "ExprCallable",
1216
- create_callable(self._state.__egg_decls__, next(iter(refs))),
1217
- ),
1218
- size,
1219
- )
1220
- for (name, size) in output.sizes
1221
- if (refs := self._state.egg_fn_to_callable_refs[name])
1250
+ cast("ExprCallable", create_callable(self._state.__egg_decls__, ref))
1251
+ for ref in self._state.egg_fn_to_callable_refs[egg_fn]
1222
1252
  ]
1223
1253
 
1224
1254
  def function_values(
@@ -1242,6 +1272,33 @@ class EGraph:
1242
1272
  for (call, res) in output.terms
1243
1273
  }
1244
1274
 
1275
+ def lookup_function_value(self, expr: BASE_EXPR) -> BASE_EXPR | None:
1276
+ """
1277
+ Given an expression that is a function call, looks up the value of the function call if it exists.
1278
+ """
1279
+ runtime_expr = to_runtime_expr(expr)
1280
+ typed_expr = runtime_expr.__egg_typed_expr__
1281
+ assert isinstance(typed_expr.expr, CallDecl | GetCostDecl)
1282
+ egg_fn, typed_args = self._state.translate_call(typed_expr.expr)
1283
+ values_args = [self._state.typed_expr_to_value(a) for a in typed_args]
1284
+ possible_value = self._egraph.lookup_function(egg_fn, values_args)
1285
+ if possible_value is None:
1286
+ return None
1287
+ return cast(
1288
+ "BASE_EXPR",
1289
+ RuntimeExpr.__from_values__(
1290
+ self.__egg_decls__,
1291
+ TypedExprDecl(typed_expr.tp, self._state.value_to_expr(typed_expr.tp, possible_value)),
1292
+ ),
1293
+ )
1294
+
1295
+ def has_custom_cost(self, fn: ExprCallable) -> bool:
1296
+ """
1297
+ Checks if the any custom costs have been set for this expression callable.
1298
+ """
1299
+ resolved, _ = resolve_callable(fn)
1300
+ return resolved in self._state.cost_callables
1301
+
1245
1302
 
1246
1303
# An expression callable: either a constant expression itself, or a callable that builds one.
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
@@ -1786,17 +1843,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr:
1786
1843
  return expr
1787
1844
 
1788
1845
 
1789
def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule:
    """
    Create a run configuration.

    :param ruleset: ruleset to run; when omitted, the empty ruleset name ``""`` is used.
    :param until: optional facts, stored on the ``RunDecl`` as its ``until`` facts (``None`` when none are given).
    :param scheduler: optional backoff scheduler (see ``back_off``) to attach to this run.
    """
    facts = _fact_likes(until)
    decls_thunk = Thunk.fn(Declarations.create, ruleset, *facts)
    run_decl = RunDecl(
        ruleset.__egg_name__ if ruleset else "",
        tuple(f.fact for f in facts) or None,
        scheduler.scheduler if scheduler else None,
    )
    return Schedule(decls_thunk, run_decl)
1798
1859
 
1799
1860
 
1861
def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff:
    """
    Create a backoff scheduler configuration.

    Each call assigns a fresh ``uuid4`` id, so two calls with identical limits yield declarations with different ids.

    ```python
    schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
    ```
    This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler.
    """
    decl = BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length)
    return BackOff(decl)
1871
+
1872
+
1873
@dataclass(frozen=True)
class BackOff:
    """
    A backoff scheduler configuration wrapping a ``BackOffDecl``; usually created with ``back_off()``.
    """

    # The scheduler declaration passed to ``RunDecl`` / ``LetSchedulerDecl``.
    scheduler: BackOffDecl

    def scope(self, schedule: Schedule) -> Schedule:
        """
        Wrap ``schedule`` so this scheduler is created directly before it, rather than at the outermost scope
        (the default).
        """
        decls_thunk = schedule.__egg_decls_thunk__
        return Schedule(decls_thunk, LetSchedulerDecl(self.scheduler, schedule.schedule))

    def __str__(self) -> str:
        return pretty_decl(Declarations(), self.scheduler)

    def __repr__(self) -> str:
        return self.__str__()
+
1890
+
1800
1891
  def seq(*schedules: Schedule) -> Schedule:
1801
1892
  """
1802
1893
  Run a sequence of schedules.
@@ -1873,3 +1964,246 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1873
1964
  yield
1874
1965
  finally:
1875
1966
  _CURRENT_RULESET.reset(token)
1967
+
1968
+
1969
def get_cost(expr: BaseExpr) -> i64:
    """
    Return an ``i64`` expression that looks up the cost set on the function call ``expr``.

    The result wraps a ``GetCostDecl`` for the same callable and arguments. If no cost has been set for that call,
    it won't match.

    :raises TypeError: if ``expr`` is not a runtime expression, or if it is not a function call (e.g. a literal or
        a variable).
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``, which would turn a bad
    # argument into a confusing AttributeError further down.
    if not isinstance(expr, RuntimeExpr):
        msg = f"Expected an egglog expression, got {type(expr).__name__}"
        raise TypeError(msg)
    expr_decl = expr.__egg_typed_expr__.expr
    if not isinstance(expr_decl, CallDecl):
        msg = "Can only get cost of function calls, not literals or variables"
        raise TypeError(msg)
    return cast(
        "i64",
        RuntimeExpr.__from_values__(
            expr.__egg_decls__,
            TypedExprDecl(JustTypeRef("i64"), GetCostDecl(expr_decl.callable, expr_decl.args)),
        ),
    )
1982
+
1983
+
1984
+ class Comparable(Protocol):
1985
+ def __lt__(self, other: Self) -> bool: ...
1986
+ def __le__(self, other: Self) -> bool: ...
1987
+ def __gt__(self, other: Self) -> bool: ...
1988
+ def __ge__(self, other: Self) -> bool: ...
1989
+
1990
+
1991
+ COST = TypeVar("COST", bound=Comparable)
1992
+
1993
+
1994
class CostModel(Protocol, Generic[COST]):
    """
    A cost model for an e-graph, which determines an expression's cost from its structure and the costs of its
    sub-expressions.

    It is called with an expression and the total costs of its children, and returns the total cost of the
    expression.

    For common usages, a term's cost should be no smaller than the cost of any of its subterms; this avoids cycles
    in the extracted terms. Niche usages may give a term a cost below its subterms. As long as there is no negative
    cost cycle, the default extractor is guaranteed to terminate when computing costs, but the user must then take
    care to guarantee that the extracted terms are acyclic.
    """

    def __call__(self, egraph: EGraph, expr: BaseExpr, children_costs: list[COST]) -> COST:
        """
        Return the total cost of a term, given its root e-node and the total costs of its immediate children.
        """
        raise NotImplementedError
2012
+
2013
+
2014
def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int:
    """
    Default cost model: the node's own cost plus the sum of its children's costs.

    The node's own cost is taken from the first of these that applies:

    1. a custom cost set on this exact call in the e-graph (looked up via ``get_cost``);
    2. a cost declared on the callable (``get_callable_cost``);
    3. otherwise 0 for containers and 1 for every other node.
    """
    from .builtins import Container  # noqa: PLC0415
    from .deconstruct import get_callable_fn  # noqa: PLC0415

    callable_fn = get_callable_fn(expr)
    own_cost: int | None = None
    if callable_fn is not None:
        # 1. A custom cost stored in the e-graph for this specific call wins.
        if egraph.has_custom_cost(callable_fn):
            looked_up = egraph.lookup_function_value(get_cost(expr))
            if looked_up is not None:
                own_cost = int(looked_up)
        # 2. Otherwise use the cost declared on the callable, if there is one.
        if own_cost is None:
            own_cost = get_callable_cost(callable_fn)
    # 3. Nothing set: containers are free, every other node costs 1.
    if own_cost is None:
        own_cost = 0 if isinstance(expr, Container) else 1
    return sum(children_costs, start=own_cost)
2037
+
2038
+
2039
class ComparableAddSub(Comparable, Protocol):
    """
    An orderable value that also supports addition and subtraction with its own type.
    """

    def __add__(self, other: Self) -> Self: ...

    def __sub__(self, other: Self) -> Self: ...


# Underlying cost type for greedy DAG costs: orderable and closed under + and -.
DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub)
2045
+
2046
+
2047
@dataclass
class GreedyDagCost(Generic[DAG_COST]):
    """
    Cost of a DAG that also remembers the cost of each distinct node in it. Use ``.total`` for the underlying cost.

    Equality, ordering, and hashing look only at ``total``; the per-node breakdown in ``_costs`` is ignored.
    """

    total: DAG_COST
    # Own cost of each distinct node in the DAG, keyed by the node's typed expression.
    _costs: dict[TypedExprDecl, DAG_COST] = field(repr=False)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, GreedyDagCost):
            return self.total == other.total
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.total)

    def __lt__(self, other: Self) -> bool:
        return self.total < other.total

    def __le__(self, other: Self) -> bool:
        return self.total <= other.total

    def __gt__(self, other: Self) -> bool:
        return self.total > other.total

    def __ge__(self, other: Self) -> bool:
        return self.total >= other.total
2075
+
2076
+
2077
@dataclass
class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
    """
    A cost model that counts each distinct node only once, even when it is shared by several parents.

    It should behave similarly to
    https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs, but is implemented as a
    cost model used with the default extractor.
    """

    # Per-node cost model whose totals are re-aggregated as a DAG.
    base: CostModel[DAG_COST]

    def __call__(
        self, egraph: EGraph, expr: BaseExpr, children_costs: list[GreedyDagCost[DAG_COST]]
    ) -> GreedyDagCost[DAG_COST]:
        # The base model returns a total that includes the children's totals; subtract them back out to recover
        # the cost of this node alone.
        own_cost = self.base(egraph, expr, [child.total for child in children_costs])
        for child in children_costs:
            own_cost -= child.total
        # Union the children's per-node costs so that shared sub-DAGs are counted once.
        node_costs: dict[TypedExprDecl, DAG_COST] = {}
        for child in children_costs:
            node_costs.update(child._costs)
        total = sum(node_costs.values(), start=own_cost)
        node_costs[to_runtime_expr(expr).__egg_typed_expr__] = own_cost
        return GreedyDagCost(total, node_costs)
2100
+
2101
+
2102
@overload
def greedy_dag_cost_model() -> CostModel[GreedyDagCost[int]]: ...


@overload
def greedy_dag_cost_model(base: CostModel[DAG_COST]) -> CostModel[GreedyDagCost[DAG_COST]]: ...


def greedy_dag_cost_model(base: CostModel[Any] = default_cost_model) -> CostModel[GreedyDagCost[Any]]:
    """
    Wrap a base cost model so that repeated nodes are counted only once (greedy DAG cost).

    :param base: per-node cost model to wrap; defaults to ``default_cost_model``. A falsy value (such as ``None``)
        also falls back to ``default_cost_model``.
    """
    wrapped = base if base else default_cost_model
    return GreedyDagCostModel(wrapped)
2115
+
2116
+
2117
def get_callable_cost(fn: ExprCallable) -> int | None:
    """
    Return the cost declared on a callable, or ``None`` if it has none.

    Only ``ConstructorDecl`` costs are consulted. For any other kind of callable this returns ``None``, so callers
    such as ``default_cost_model`` fall through to their own defaults (e.g. 0 for containers) instead of always
    treating the node as cost 1.
    """
    callable_ref, decls = resolve_callable(fn)
    callable_decl = decls.get_callable_decl(callable_ref)
    if isinstance(callable_decl, ConstructorDecl):
        # NOTE(review): presumably ``cost`` is None when the constructor was declared without one — confirm.
        return callable_decl.cost
    return None
2124
+
2125
+
2126
@dataclass
class _CostModel(Generic[COST]):
    """
    Adapts a Python ``CostModel`` to the callbacks expected by ``bindings.CostModel``.

    Every callback memoizes its result on this instance. Repeated requests for the same e-node (with the same
    children costs), base value, or container (with the same element costs) therefore do not call back into the
    Python model. ``EGraph.extract`` builds a fresh instance for each extraction that uses a custom cost model.

    NOTE(review): cache keys rely on ``COST`` equality and hashing, so costs that compare equal are treated as
    interchangeable. ``GreedyDagCost`` compares by ``total`` only, so equal-total costs with different per-node
    breakdowns share a cache entry.
    """

    model: CostModel[COST]
    egraph: EGraph
    # (egg function name, argument values) -> index into ``enode_cost_expressions``.
    enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict)
    # Python expression for each e-node handed to ``enode_cost``, stored at the index that call returned.
    enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list)
    # (e-node index, children costs) -> cost computed by ``fold``.
    fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict)
    # (egg sort name, value) -> cost computed by ``base_value_cost``.
    base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict)
    # (egg sort name, value, element costs) -> cost computed by ``container_cost``.
    container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict)

    def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST:
        """
        Invoke the wrapped cost model on ``expr`` with the given children costs.
        """
        # Results are deliberately not checked to exceed the children's costs: ``CostModel`` allows niche models
        # in which a term costs less than its subterms.
        return self.model(self.egraph, cast("BaseExpr", expr), children_costs)

    def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST:
        """
        Return the cost of the e-node registered at ``index`` by ``enode_cost``, given its children's costs.

        The result is memoized in ``fold_results``, like the other callbacks' results.
        """
        key = (index, tuple(children_costs))
        try:
            return self.fold_results[key]
        except KeyError:
            pass

        res = self.call_model(self.enode_cost_expressions[index], children_costs)
        self.fold_results[key] = res
        return res

    # enode cost is only ever called right before fold, for the head_cost
    def enode_cost(self, name: str, args: list[bindings.Value]) -> int:
        """
        Register the e-node ``name(*args)`` and return an index that identifies it in later ``fold`` calls.

        The returned "head cost" is not a real cost; it is an index into ``enode_cost_expressions``.
        """
        try:
            return self.enode_cost_results[(name, tuple(args))]
        except KeyError:
            pass
        # Unpacking assumes exactly one Python callable is registered for this egg function.
        (callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name]
        signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature
        assert isinstance(signature, FunctionSignature)
        arg_exprs = [
            TypedExprDecl(tp.to_just(), self.egraph._state.value_to_expr(tp.to_just(), arg))
            for (arg, tp) in zip(args, signature.arg_types, strict=True)
        ]
        res_type = signature.semantic_return_type.to_just()
        res = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))),
        )
        index = len(self.enode_cost_expressions)
        self.enode_cost_expressions.append(res)
        self.enode_cost_results[(name, tuple(args))] = index
        return index

    def base_value_cost(self, tp: str, value: bindings.Value) -> COST:
        """
        Return the cost of a base value of egg sort ``tp``, computed by the model with no children.
        """
        try:
            return self.base_value_cost_results[(tp, value)]
        except KeyError:
            pass
        type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
        expr = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
        )
        res = self.call_model(expr, [])
        self.base_value_cost_results[(tp, value)] = res
        return res

    def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST:
        """
        Return the cost of a container value of egg sort ``tp``, given the costs of its elements.
        """
        try:
            return self.container_cost_results[(tp, value, tuple(element_costs))]
        except KeyError:
            pass
        type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
        expr = RuntimeExpr.__from_values__(
            self.egraph.__egg_decls__,
            TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
        )
        res = self.call_model(expr, element_costs)
        self.container_cost_results[(tp, value, tuple(element_costs))] = res
        return res

    def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]:
        """
        Bundle this adapter's callbacks (fold, enode_cost, container_cost, base_value_cost) into a
        ``bindings.CostModel``.
        """
        return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)