egglog 11.3.0__cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 11.4.0__cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/bindings.pyi CHANGED
@@ -1,6 +1,8 @@
1
+ from collections.abc import Callable
1
2
  from datetime import timedelta
3
+ from fractions import Fraction
2
4
  from pathlib import Path
3
- from typing import TypeAlias
5
+ from typing import Any, Generic, Protocol, TypeAlias, TypeVar
4
6
 
5
7
  from typing_extensions import final
6
8
 
@@ -14,6 +16,7 @@ __all__ = [
14
16
  "Change",
15
17
  "Check",
16
18
  "Constructor",
19
+ "CostModel",
17
20
  "Datatype",
18
21
  "Datatypes",
19
22
  "DefaultPrintFunctionMode",
@@ -26,6 +29,7 @@ __all__ = [
26
29
  "Extract",
27
30
  "ExtractBest",
28
31
  "ExtractVariants",
32
+ "Extractor",
29
33
  "Fact",
30
34
  "Fail",
31
35
  "Float",
@@ -83,6 +87,7 @@ __all__ = [
83
87
  "UserDefined",
84
88
  "UserDefinedCommandOutput",
85
89
  "UserDefinedOutput",
90
+ "Value",
86
91
  "Var",
87
92
  "Variant",
88
93
  ]
@@ -128,6 +133,31 @@ class EGraph:
128
133
  max_calls_per_function: int | None = None,
129
134
  include_temporary_functions: bool = False,
130
135
  ) -> SerializedEGraph: ...
136
+ def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
137
+ def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
138
+ def value_to_i64(self, v: Value) -> int: ...
139
+ def value_to_f64(self, v: Value) -> float: ...
140
+ def value_to_string(self, v: Value) -> str: ...
141
+ def value_to_bool(self, v: Value) -> bool: ...
142
+ def value_to_rational(self, v: Value) -> Fraction: ...
143
+ def value_to_bigint(self, v: Value) -> int: ...
144
+ def value_to_bigrat(self, v: Value) -> Fraction: ...
145
+ def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
146
+ def value_to_map(self, v: Value) -> dict[Value, Value]: ...
147
+ def value_to_multiset(self, v: Value) -> list[Value]: ...
148
+ def value_to_vec(self, v: Value) -> list[Value]: ...
149
+ def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ...
150
+ def value_to_set(self, v: Value) -> set[Value]: ...
151
+ # def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ...
152
+
153
+ @final
154
+ class Value:
155
+ def __hash__(self) -> int: ...
156
+ def __eq__(self, value: object) -> bool: ...
157
+ def __lt__(self, other: object) -> bool: ...
158
+ def __le__(self, other: object) -> bool: ...
159
+ def __gt__(self, other: object) -> bool: ...
160
+ def __ge__(self, other: object) -> bool: ...
131
161
 
132
162
  @final
133
163
  class EggSmolError(Exception):
@@ -732,3 +762,34 @@ class TermDag:
732
762
  def expr_to_term(self, expr: _Expr) -> _Term: ...
733
763
  def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
734
764
  def to_string(self, term: _Term) -> str: ...
765
+
766
+ ##
767
+ # Extraction
768
+ ##
769
+ class _Cost(Protocol):
770
+ def __lt__(self, other: _Cost) -> bool: ...
771
+ def __le__(self, other: _Cost) -> bool: ...
772
+ def __gt__(self, other: _Cost) -> bool: ...
773
+ def __ge__(self, other: _Cost) -> bool: ...
774
+
775
+ _COST = TypeVar("_COST", bound=_Cost)
776
+
777
+ _ENODE_COST = TypeVar("_ENODE_COST")
778
+
779
+ @final
780
+ class CostModel(Generic[_COST, _ENODE_COST]):
781
+ def __init__(
782
+ self,
783
+ fold: Callable[[str, _ENODE_COST, list[_COST]], _COST],
784
+ enode_cost: Callable[[str, list[Value]], _ENODE_COST],
785
+ container_cost: Callable[[str, Value, list[_COST]], _COST],
786
+ base_value_cost: Callable[[str, Value], _COST],
787
+ ) -> None: ...
788
+
789
+ @final
790
+ class Extractor(Generic[_COST]):
791
+ def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]) -> None: ...
792
+ def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
793
+ def extract_variants(
794
+ self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
795
+ ) -> list[tuple[_COST, _Term]]: ...
egglog/builtins.py CHANGED
@@ -33,10 +33,12 @@ __all__ = [
33
33
  "BigRatLike",
34
34
  "Bool",
35
35
  "BoolLike",
36
+ "Container",
36
37
  "ExprValueError",
37
38
  "Map",
38
39
  "MapLike",
39
40
  "MultiSet",
41
+ "Primitive",
40
42
  "PyObject",
41
43
  "Rational",
42
44
  "Set",
@@ -1135,3 +1137,7 @@ def _convert_function(fn: FunctionType) -> UnstableFn:
1135
1137
 
1136
1138
 
1137
1139
  converter(FunctionType, UnstableFn, _convert_function)
1140
+
1141
+
1142
+ Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn
1143
+ Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit
egglog/declarations.py CHANGED
@@ -14,6 +14,8 @@ from weakref import WeakValueDictionary
14
14
 
15
15
  from typing_extensions import Self, assert_never
16
16
 
17
+ from .bindings import Value
18
+
17
19
  if TYPE_CHECKING:
18
20
  from collections.abc import Callable, Iterable, Mapping
19
21
 
@@ -49,6 +51,7 @@ __all__ = [
49
51
  "FunctionDecl",
50
52
  "FunctionRef",
51
53
  "FunctionSignature",
54
+ "GetCostDecl",
52
55
  "HasDeclerations",
53
56
  "InitRef",
54
57
  "JustTypeRef",
@@ -82,6 +85,7 @@ __all__ = [
82
85
  "UnboundVarDecl",
83
86
  "UnionDecl",
84
87
  "UnnamedFunctionRef",
88
+ "ValueDecl",
85
89
  "collect_unbound_vars",
86
90
  "replace_typed_expr",
87
91
  "upcast_declerations",
@@ -639,7 +643,7 @@ class CallDecl:
639
643
  args: tuple[TypedExprDecl, ...] = ()
640
644
  # type parameters that were bound to the callable, if it is a classmethod
641
645
  # Used for pretty printing classmethod calls with type parameters
642
- bound_tp_params: tuple[JustTypeRef, ...] | None = None
646
+ bound_tp_params: tuple[JustTypeRef, ...] = ()
643
647
 
644
648
  # pool objects for faster __eq__
645
649
  _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
@@ -654,7 +658,7 @@ class CallDecl:
654
658
  # normalize the args/kwargs to a tuple so that they can be compared
655
659
  callable = args[0] if args else kwargs["callable"]
656
660
  args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
657
- bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")
661
+ bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ())
658
662
 
659
663
  normalized_args = (callable, args_, bound_tp_params)
660
664
  try:
@@ -696,7 +700,20 @@ class PartialCallDecl:
696
700
  call: CallDecl
697
701
 
698
702
 
699
- ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
703
+ @dataclass(frozen=True)
704
+ class GetCostDecl:
705
+ callable: CallableRef
706
+ args: tuple[TypedExprDecl, ...]
707
+
708
+
709
+ @dataclass(frozen=True)
710
+ class ValueDecl:
711
+ value: Value
712
+
713
+
714
+ ExprDecl: TypeAlias = (
715
+ UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
716
+ )
700
717
 
701
718
 
702
719
  @dataclass(frozen=True)
egglog/deconstruct.py CHANGED
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, TypeVar, overload
11
11
  from typing_extensions import TypeVarTuple, Unpack
12
12
 
13
13
  from .declarations import *
14
- from .egraph import BaseExpr
14
+ from .egraph import BaseExpr, Expr
15
15
  from .runtime import *
16
16
  from .thunk import *
17
17
 
@@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ...
49
49
  def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
50
50
 
51
51
 
52
- def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
52
+ @overload
53
+ def get_literal_value(x: Expr) -> None: ...
54
+
55
+
56
+ def get_literal_value(x: object) -> object:
53
57
  """
54
58
  Returns the literal value of an expression if it is a literal.
55
59
  If it is not a literal, returns None.
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None:
95
99
  return None
96
100
 
97
101
 
98
- def get_callable_fn(x: T) -> Callable[..., T] | None:
102
+ def get_callable_fn(x: T) -> Callable[..., T] | T | None:
99
103
  """
100
- Gets the function of an expression if it is a call expression.
101
- If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
102
- For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
103
- to return the Python value.
104
+ Gets the function of an expression, or if it's a constant or classvar, return that.
104
105
  """
105
106
  if not isinstance(x, RuntimeExpr):
106
107
  raise TypeError(f"Expected Expression, got {type(x).__name__}")
@@ -159,6 +160,7 @@ def _deconstruct_call_decl(
159
160
  """
160
161
  args = call.args
161
162
  arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
163
+ # TODO: handle values? Like constants
162
164
  if isinstance(call.callable, InitRef):
163
165
  return RuntimeClass(
164
166
  decls_thunk,
egglog/egraph.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
16
16
  ClassVar,
17
17
  Generic,
18
18
  Literal,
19
+ Protocol,
19
20
  TypeAlias,
20
21
  TypedDict,
21
22
  TypeVar,
@@ -41,7 +42,7 @@ from .thunk import *
41
42
  from .version_compat import *
42
43
 
43
44
  if TYPE_CHECKING:
44
- from .builtins import String, Unit, i64Like
45
+ from .builtins import String, Unit, i64, i64Like
45
46
 
46
47
 
47
48
  __all__ = [
@@ -51,11 +52,14 @@ __all__ = [
51
52
  "BuiltinExpr",
52
53
  "Command",
53
54
  "Command",
55
+ "CostModel",
54
56
  "EGraph",
55
57
  "Expr",
58
+ "ExprCallable",
56
59
  "Fact",
57
60
  "Fact",
58
61
  "GraphvizKwargs",
62
+ "GreedyDagCost",
59
63
  "RewriteOrRule",
60
64
  "Ruleset",
61
65
  "Schedule",
@@ -70,12 +74,15 @@ __all__ = [
70
74
  "check",
71
75
  "check_eq",
72
76
  "constant",
77
+ "default_cost_model",
73
78
  "delete",
74
79
  "eq",
75
80
  "expr_action",
76
81
  "expr_fact",
77
82
  "expr_parts",
78
83
  "function",
84
+ "get_cost",
85
+ "greedy_dag_cost_model",
79
86
  "let",
80
87
  "method",
81
88
  "ne",
@@ -88,6 +95,7 @@ __all__ = [
88
95
  "seq",
89
96
  "set_",
90
97
  "set_cost",
98
+ "set_current_ruleset",
91
99
  "subsume",
92
100
  "union",
93
101
  "unstable_combine_rulesets",
@@ -452,7 +460,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
452
460
  continue
453
461
  locals = frame.f_locals
454
462
  ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
455
- # TODO: Store deprecated message so we can print at runtime
463
+ # TODO: Store deprecated message so we can get at runtime
456
464
  if (getattr(fn, "__deprecated__", None)) is not None:
457
465
  fn = fn.__wrapped__ # type: ignore[attr-defined]
458
466
  match fn:
@@ -953,22 +961,45 @@ class EGraph:
953
961
  return bindings.Check(span(2), egg_facts)
954
962
 
955
963
  @overload
956
- def extract(self, expr: BASE_EXPR, /, include_cost: Literal[False] = False) -> BASE_EXPR: ...
964
+ def extract(
965
+ self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None
966
+ ) -> BASE_EXPR: ...
957
967
 
958
968
  @overload
959
- def extract(self, expr: BASE_EXPR, /, include_cost: Literal[True]) -> tuple[BASE_EXPR, int]: ...
969
+ def extract(
970
+ self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: None = None
971
+ ) -> tuple[BASE_EXPR, int]: ...
960
972
 
961
- def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tuple[BASE_EXPR, int]:
973
+ @overload
974
+ def extract(
975
+ self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST]
976
+ ) -> tuple[BASE_EXPR, COST]: ...
977
+
978
+ def extract(
979
+ self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None
980
+ ) -> BASE_EXPR | tuple[BASE_EXPR, COST]:
962
981
  """
963
982
  Extract the lowest cost expression from the egraph.
964
983
  """
965
984
  runtime_expr = to_runtime_expr(expr)
966
- extract_report = self._run_extract(runtime_expr, 0)
967
- assert isinstance(extract_report, bindings.ExtractBest)
968
- res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp)
969
- if include_cost:
970
- return res, extract_report.cost
971
- return res
985
+ self._add_decls(runtime_expr)
986
+ tp = runtime_expr.__egg_typed_expr__.tp
987
+ if cost_model is None:
988
+ extract_report = self._run_extract(runtime_expr, 0)
989
+ assert isinstance(extract_report, bindings.ExtractBest)
990
+ res = self._from_termdag(extract_report.termdag, extract_report.term, tp)
991
+ cost = cast("COST", extract_report.cost)
992
+ else:
993
+ # TODO: For some reason we need this or else it wont be registered. Not sure why
994
+ self.register(expr)
995
+ egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model()
996
+ egg_sort = self._state.type_ref_to_egg(tp)
997
+ extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model)
998
+ termdag = bindings.TermDag()
999
+ value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__)
1000
+ cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort)
1001
+ res = self._from_termdag(termdag, term, tp)
1002
+ return (res, cost) if include_cost else res
972
1003
 
973
1004
  def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
974
1005
  (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
@@ -979,6 +1010,7 @@ class EGraph:
979
1010
  Extract multiple expressions from the egraph.
980
1011
  """
981
1012
  runtime_expr = to_runtime_expr(expr)
1013
+ self._add_decls(runtime_expr)
982
1014
  extract_report = self._run_extract(runtime_expr, n)
983
1015
  assert isinstance(extract_report, bindings.ExtractVariants)
984
1016
  new_exprs = self._state.exprs_from_egg(
@@ -987,7 +1019,6 @@ class EGraph:
987
1019
  return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
988
1020
 
989
1021
  def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
990
- self._add_decls(expr)
991
1022
  expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
992
1023
  # If we have defined any cost tables use the custom extraction
993
1024
  args = (expr, bindings.Lit(span(2), bindings.Int(n)))
@@ -1212,16 +1243,12 @@ class EGraph:
1212
1243
  """
1213
1244
  (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
1214
1245
  assert isinstance(output, bindings.PrintAllFunctionsSize)
1246
+ return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
1247
+
1248
+ def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
1215
1249
  return [
1216
- (
1217
- cast(
1218
- "ExprCallable",
1219
- create_callable(self._state.__egg_decls__, next(iter(refs))),
1220
- ),
1221
- size,
1222
- )
1223
- for (name, size) in output.sizes
1224
- if (refs := self._state.egg_fn_to_callable_refs[name])
1250
+ cast("ExprCallable", create_callable(self._state.__egg_decls__, ref))
1251
+ for ref in self._state.egg_fn_to_callable_refs[egg_fn]
1225
1252
  ]
1226
1253
 
1227
1254
  def function_values(
@@ -1245,6 +1272,33 @@ class EGraph:
1245
1272
  for (call, res) in output.terms
1246
1273
  }
1247
1274
 
1275
+ def lookup_function_value(self, expr: BASE_EXPR) -> BASE_EXPR | None:
1276
+ """
1277
+ Given an expression that is a function call, looks up the value of the function call if it exists.
1278
+ """
1279
+ runtime_expr = to_runtime_expr(expr)
1280
+ typed_expr = runtime_expr.__egg_typed_expr__
1281
+ assert isinstance(typed_expr.expr, CallDecl | GetCostDecl)
1282
+ egg_fn, typed_args = self._state.translate_call(typed_expr.expr)
1283
+ values_args = [self._state.typed_expr_to_value(a) for a in typed_args]
1284
+ possible_value = self._egraph.lookup_function(egg_fn, values_args)
1285
+ if possible_value is None:
1286
+ return None
1287
+ return cast(
1288
+ "BASE_EXPR",
1289
+ RuntimeExpr.__from_values__(
1290
+ self.__egg_decls__,
1291
+ TypedExprDecl(typed_expr.tp, self._state.value_to_expr(typed_expr.tp, possible_value)),
1292
+ ),
1293
+ )
1294
+
1295
+ def has_custom_cost(self, fn: ExprCallable) -> bool:
1296
+ """
1297
+ Checks if the any custom costs have been set for this expression callable.
1298
+ """
1299
+ resolved, _ = resolve_callable(fn)
1300
+ return resolved in self._state.cost_callables
1301
+
1248
1302
 
1249
1303
  # Either a constant or a function.
1250
1304
  ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
@@ -1910,3 +1964,246 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1910
1964
  yield
1911
1965
  finally:
1912
1966
  _CURRENT_RULESET.reset(token)
1967
+
1968
+
1969
+ def get_cost(expr: BaseExpr) -> i64:
1970
+ """
1971
+ Return a lookup of the cost of an expression. If not set, won't match.
1972
+ """
1973
+ assert isinstance(expr, RuntimeExpr)
1974
+ expr_decl = expr.__egg_typed_expr__.expr
1975
+ if not isinstance(expr_decl, CallDecl):
1976
+ msg = "Can only get cost of function calls, not literals or variables"
1977
+ raise TypeError(msg)
1978
+ return RuntimeExpr.__from_values__(
1979
+ expr.__egg_decls__,
1980
+ TypedExprDecl(JustTypeRef("i64"), GetCostDecl(expr_decl.callable, expr_decl.args)),
1981
+ )
1982
+
1983
+
1984
+ class Comparable(Protocol):
1985
+ def __lt__(self, other: Self) -> bool: ...
1986
+ def __le__(self, other: Self) -> bool: ...
1987
+ def __gt__(self, other: Self) -> bool: ...
1988
+ def __ge__(self, other: Self) -> bool: ...
1989
+
1990
+
1991
+ COST = TypeVar("COST", bound=Comparable)
1992
+
1993
+
1994
+ class CostModel(Protocol, Generic[COST]):
1995
+ """
1996
+ A cost model for an e-graph. Used to determine the cost of an expression based on its structure and the costs of its sub-expressions.
1997
+
1998
+ Called with an expression and the costs of its children, returns the total cost of the expression.
1999
+
2000
+ Additionally, the cost model should guarantee that a term has a no-smaller cost
2001
+ than its subterms to avoid cycles in the extracted terms for common case usages.
2002
+ For more niche usages, a term can have a cost less than its subterms.
2003
+ As long as there is no negative cost cycle, the default extractor is guaranteed to terminate in computing the costs.
2004
+ However, the user needs to be careful to guarantee acyclicity in the extracted terms.
2005
+ """
2006
+
2007
+ def __call__(self, egraph: EGraph, expr: BaseExpr, children_costs: list[COST]) -> COST:
2008
+ """
2009
+ The total cost of a term given the cost of the root e-node and its immediate children's total costs.
2010
+ """
2011
+ raise NotImplementedError
2012
+
2013
+
2014
+ def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int:
2015
+ """
2016
+ A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
2017
+ """
2018
+ from .builtins import Container # noqa: PLC0415
2019
+ from .deconstruct import get_callable_fn # noqa: PLC0415
2020
+
2021
+ # 1. First prefer if the expr has a custom cost set on it
2022
+ if (
2023
+ (callable_fn := get_callable_fn(expr)) is not None
2024
+ and egraph.has_custom_cost(callable_fn)
2025
+ and (i := egraph.lookup_function_value(get_cost(expr))) is not None
2026
+ ):
2027
+ self_cost = int(i)
2028
+ # 2. Else, check if this is a callable and it has a cost set on its declaration
2029
+ elif callable_fn is not None and (callable_cost := get_callable_cost(callable_fn)) is not None:
2030
+ self_cost = callable_cost
2031
+ # 3. Else, if this is a container, it has no cost, otherwise it has a cost of 1
2032
+ else:
2033
+ # By default, all nodes have a cost of 1 except for containers which have a cost of 0
2034
+ self_cost = 0 if isinstance(expr, Container) else 1
2035
+ # Sum up the costs of the children and our own cost
2036
+ return sum(children_costs, start=self_cost)
2037
+
2038
+
2039
+ class ComparableAddSub(Comparable, Protocol):
2040
+ def __add__(self, other: Self) -> Self: ...
2041
+ def __sub__(self, other: Self) -> Self: ...
2042
+
2043
+
2044
+ DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub)
2045
+
2046
+
2047
+ @dataclass
2048
+ class GreedyDagCost(Generic[DAG_COST]):
2049
+ """
2050
+ Cost of a DAG, which stores children costs. Use `.total` to get the underlying cost.
2051
+ """
2052
+
2053
+ total: DAG_COST
2054
+ _costs: dict[TypedExprDecl, DAG_COST] = field(repr=False)
2055
+
2056
+ def __eq__(self, other: object) -> bool:
2057
+ if not isinstance(other, GreedyDagCost):
2058
+ return NotImplemented
2059
+ return self.total == other.total
2060
+
2061
+ def __lt__(self, other: Self) -> bool:
2062
+ return self.total < other.total
2063
+
2064
+ def __le__(self, other: Self) -> bool:
2065
+ return self.total <= other.total
2066
+
2067
+ def __gt__(self, other: Self) -> bool:
2068
+ return self.total > other.total
2069
+
2070
+ def __ge__(self, other: Self) -> bool:
2071
+ return self.total >= other.total
2072
+
2073
+ def __hash__(self) -> int:
2074
+ return hash(self.total)
2075
+
2076
+
2077
+ @dataclass
2078
+ class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
2079
+ """
2080
+ A cost model which will count duplicate nodes only once.
2081
+
2082
+ Should have similar behavior as https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs
2083
+ but implemented as a cost model that will be used with the default extractor.
2084
+ """
2085
+
2086
+ base: CostModel[DAG_COST]
2087
+
2088
+ def __call__(
2089
+ self, egraph: EGraph, expr: BaseExpr, children_costs: list[GreedyDagCost[DAG_COST]]
2090
+ ) -> GreedyDagCost[DAG_COST]:
2091
+ cost = self.base(egraph, expr, [c.total for c in children_costs])
2092
+ for c in children_costs:
2093
+ cost -= c.total
2094
+ costs = {}
2095
+ for c in children_costs:
2096
+ costs.update(c._costs)
2097
+ total = sum(costs.values(), start=cost)
2098
+ costs[to_runtime_expr(expr).__egg_typed_expr__] = cost
2099
+ return GreedyDagCost(total, costs)
2100
+
2101
+
2102
+ @overload
2103
+ def greedy_dag_cost_model() -> CostModel[GreedyDagCost[int]]: ...
2104
+
2105
+
2106
+ @overload
2107
+ def greedy_dag_cost_model(base: CostModel[DAG_COST]) -> CostModel[GreedyDagCost[DAG_COST]]: ...
2108
+
2109
+
2110
+ def greedy_dag_cost_model(base: CostModel[Any] = default_cost_model) -> CostModel[GreedyDagCost[Any]]:
2111
+ """
2112
+ Creates a greedy dag cost model from a base cost model.
2113
+ """
2114
+ return GreedyDagCostModel(base or default_cost_model)
2115
+
2116
+
2117
+ def get_callable_cost(fn: ExprCallable) -> int | None:
2118
+ """
2119
+ Returns the cost of a callable, if it has one set. Otherwise returns None.
2120
+ """
2121
+ callable_ref, decls = resolve_callable(fn)
2122
+ callable_decl = decls.get_callable_decl(callable_ref)
2123
+ return callable_decl.cost if isinstance(callable_decl, ConstructorDecl) else 1
2124
+
2125
+
2126
+ @dataclass
2127
+ class _CostModel(Generic[COST]):
2128
+ """
2129
+ Implements the methods compatible with the bindings for the cost model.
2130
+ """
2131
+
2132
+ model: CostModel[COST]
2133
+ egraph: EGraph
2134
+ enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict)
2135
+ enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list)
2136
+ fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict)
2137
+ base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict)
2138
+ container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict)
2139
+
2140
+ def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST:
2141
+ return self.model(self.egraph, cast("BaseExpr", expr), children_costs)
2142
+ # if __debug__:
2143
+ # for c in children_costs:
2144
+ # if res <= c:
2145
+ # msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}"
2146
+ # raise ValueError(msg)
2147
+
2148
+ def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST:
2149
+ try:
2150
+ return self.fold_results[(index, tuple(children_costs))]
2151
+ except KeyError:
2152
+ pass
2153
+
2154
+ expr = self.enode_cost_expressions[index]
2155
+ return self.call_model(expr, children_costs)
2156
+
2157
+ # enode cost is only ever called right before fold, for the head_cost
2158
+ def enode_cost(self, name: str, args: list[bindings.Value]) -> int:
2159
+ try:
2160
+ return self.enode_cost_results[(name, tuple(args))]
2161
+ except KeyError:
2162
+ pass
2163
+ (callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name]
2164
+ signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature
2165
+ assert isinstance(signature, FunctionSignature)
2166
+ arg_exprs = [
2167
+ TypedExprDecl(tp.to_just(), self.egraph._state.value_to_expr(tp.to_just(), arg))
2168
+ for (arg, tp) in zip(args, signature.arg_types, strict=True)
2169
+ ]
2170
+ res_type = signature.semantic_return_type.to_just()
2171
+ res = RuntimeExpr.__from_values__(
2172
+ self.egraph.__egg_decls__,
2173
+ TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))),
2174
+ )
2175
+ index = len(self.enode_cost_expressions)
2176
+ self.enode_cost_expressions.append(res)
2177
+ self.enode_cost_results[(name, tuple(args))] = index
2178
+ return index
2179
+
2180
+ def base_value_cost(self, tp: str, value: bindings.Value) -> COST:
2181
+ try:
2182
+ return self.base_value_cost_results[(tp, value)]
2183
+ except KeyError:
2184
+ pass
2185
+ type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
2186
+ expr = RuntimeExpr.__from_values__(
2187
+ self.egraph.__egg_decls__,
2188
+ TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
2189
+ )
2190
+ res = self.call_model(expr, [])
2191
+ self.base_value_cost_results[(tp, value)] = res
2192
+ return res
2193
+
2194
+ def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST:
2195
+ try:
2196
+ return self.container_cost_results[(tp, value, tuple(element_costs))]
2197
+ except KeyError:
2198
+ pass
2199
+ type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
2200
+ expr = RuntimeExpr.__from_values__(
2201
+ self.egraph.__egg_decls__,
2202
+ TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
2203
+ )
2204
+ res = self.call_model(expr, element_costs)
2205
+ self.container_cost_results[(tp, value, tuple(element_costs))] = res
2206
+ return res
2207
+
2208
+ def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]:
2209
+ return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)
egglog/egraph_state.py CHANGED
@@ -68,6 +68,7 @@ class EGraphState:
68
68
 
69
69
  # Bidirectional mapping between egg sort names and python type references.
70
70
  type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
71
+ egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
71
72
 
72
73
  # Cache of egg expressions for converting to egg
73
74
  expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
@@ -86,6 +87,7 @@ class EGraphState:
86
87
  egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
87
88
  callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
88
89
  type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
90
+ egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(),
89
91
  expr_to_egg_cache=self.expr_to_egg_cache.copy(),
90
92
  cost_callables=self.cost_callables.copy(),
91
93
  )
@@ -352,6 +354,7 @@ class EGraphState:
352
354
  Creates the egg cost table if needed and gets the name of the table.
353
355
  """
354
356
  name = self.cost_table_name(ref)
357
+ print(name, self.cost_callables)
355
358
  if ref not in self.cost_callables:
356
359
  self.cost_callables.add(ref)
357
360
  signature = self.__egg_decls__.get_callable_decl(ref).signature
@@ -455,10 +458,14 @@ class EGraphState:
455
458
  pass
456
459
  decl = self.__egg_decls__._classes[ref.name]
457
460
  self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
461
+ self.egg_sort_to_type_ref[egg_name] = ref
458
462
  if not decl.builtin or ref.args:
459
463
  if ref.args:
460
464
  if ref.name == "UnstableFn":
461
465
  # UnstableFn is a special case, where the rest of args are collected into a call
466
+ if len(ref.args) < 2:
467
+ msg = "Zero argument higher order functions not supported"
468
+ raise NotImplementedError(msg)
462
469
  type_args: list[bindings._Expr] = [
463
470
  bindings.Call(
464
471
  span(),
@@ -589,11 +596,9 @@ class EGraphState:
589
596
  case _:
590
597
  assert_never(value)
591
598
  res = bindings.Lit(span(), l)
592
- case CallDecl(ref, args, _):
593
- egg_fn, reverse_args = self.callable_ref_to_egg(ref)
594
- egg_args = [self.typed_expr_to_egg(a, False) for a in args]
595
- if reverse_args:
596
- egg_args.reverse()
599
+ case CallDecl() | GetCostDecl():
600
+ egg_fn, typed_args = self.translate_call(expr_decl)
601
+ egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args]
597
602
  res = bindings.Call(span(), egg_fn, egg_args)
598
603
  case PyObjectDecl(value):
599
604
  res = GLOBAL_PY_OBJECT_SORT.store(value)
@@ -604,11 +609,31 @@ class EGraphState:
604
609
  "unstable-fn",
605
610
  [bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
606
611
  )
612
+ case ValueDecl():
613
+ msg = "Cannot turn a Value into an expression"
614
+ raise ValueError(msg)
607
615
  case _:
608
616
  assert_never(expr_decl.expr)
609
617
  self.expr_to_egg_cache[expr_decl] = res
610
618
  return res
611
619
 
620
+ def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]:
621
+ """
622
+ Handle get cost and call decl, turn into egg table name and typed expr decls.
623
+ """
624
+ match expr:
625
+ case CallDecl(ref, args, _):
626
+ egg_fn, reverse_args = self.callable_ref_to_egg(ref)
627
+ args_list = list(args)
628
+ if reverse_args:
629
+ args_list.reverse()
630
+ return egg_fn, args_list
631
+ case GetCostDecl(ref, args):
632
+ cost_table = self.create_cost_table(ref)
633
+ return cost_table, list(args)
634
+ case _:
635
+ assert_never(expr)
636
+
612
637
  def exprs_from_egg(
613
638
  self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
614
639
  ) -> Iterable[TypedExprDecl]:
@@ -652,6 +677,129 @@ class EGraphState:
652
677
  case _:
653
678
  assert_never(ref)
654
679
 
680
+ def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
681
+ egg_expr = self.typed_expr_to_egg(typed_expr, False)
682
+ return self.egraph.eval_expr(egg_expr)[1]
683
+
684
+ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912
685
+ match tp.name:
686
+ # Should match list in egraph bindings
687
+ case "i64":
688
+ return LitDecl(self.egraph.value_to_i64(value))
689
+ case "f64":
690
+ return LitDecl(self.egraph.value_to_f64(value))
691
+ case "Bool":
692
+ return LitDecl(self.egraph.value_to_bool(value))
693
+ case "String":
694
+ return LitDecl(self.egraph.value_to_string(value))
695
+ case "Unit":
696
+ return LitDecl(None)
697
+ case "PyObject":
698
+ return PyObjectDecl(self.egraph.value_to_pyobject(GLOBAL_PY_OBJECT_SORT, value))
699
+ case "Rational":
700
+ fraction = self.egraph.value_to_rational(value)
701
+ return CallDecl(
702
+ InitRef("Rational"),
703
+ (
704
+ TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.numerator)),
705
+ TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.denominator)),
706
+ ),
707
+ )
708
+ case "BigInt":
709
+ i = self.egraph.value_to_bigint(value)
710
+ return CallDecl(
711
+ ClassMethodRef("BigInt", "from_string"),
712
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(i))),),
713
+ )
714
+ case "BigRat":
715
+ fraction = self.egraph.value_to_bigrat(value)
716
+ return CallDecl(
717
+ InitRef("BigRat"),
718
+ (
719
+ TypedExprDecl(
720
+ JustTypeRef("BigInt"),
721
+ CallDecl(
722
+ ClassMethodRef("BigInt", "from_string"),
723
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.numerator))),),
724
+ ),
725
+ ),
726
+ TypedExprDecl(
727
+ JustTypeRef("BigInt"),
728
+ CallDecl(
729
+ ClassMethodRef("BigInt", "from_string"),
730
+ (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.denominator))),),
731
+ ),
732
+ ),
733
+ ),
734
+ )
735
+ case "Map":
736
+ k_tp, v_tp = tp.args
737
+ expr = CallDecl(ClassMethodRef("Map", "empty"), (), (k_tp, v_tp))
738
+ for k, v in self.egraph.value_to_map(value).items():
739
+ expr = CallDecl(
740
+ MethodRef("Map", "insert"),
741
+ (
742
+ TypedExprDecl(tp, expr),
743
+ TypedExprDecl(k_tp, self.value_to_expr(k_tp, k)),
744
+ TypedExprDecl(v_tp, self.value_to_expr(v_tp, v)),
745
+ ),
746
+ )
747
+ return expr
748
+ case "Set":
749
+ xs_ = self.egraph.value_to_set(value)
750
+ (v_tp,) = tp.args
751
+ return CallDecl(
752
+ InitRef("Set"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), (v_tp,)
753
+ )
754
+ case "Vec":
755
+ xs = self.egraph.value_to_vec(value)
756
+ (v_tp,) = tp.args
757
+ return CallDecl(
758
+ InitRef("Vec"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
759
+ )
760
+ case "MultiSet":
761
+ xs = self.egraph.value_to_multiset(value)
762
+ (v_tp,) = tp.args
763
+ return CallDecl(
764
+ InitRef("MultiSet"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,)
765
+ )
766
+ case "UnstableFn":
767
+ _names, _args = self.egraph.value_to_function(value)
768
+ return_tp, *arg_types = tp.args
769
+ return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types)
770
+ return ValueDecl(value)
771
+
772
+ def _unstable_fn_value_to_expr(
773
+ self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef]
774
+ ) -> PartialCallDecl:
775
+ # Similar to FromEggState::from_call but accepts partial list of args and returns in values
776
+ # Find first callable ref whose return type matches and fill in arg types.
777
+ for callable_ref in self.egg_fn_to_callable_refs[name]:
778
+ signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
779
+ if not isinstance(signature, FunctionSignature):
780
+ continue
781
+ if signature.semantic_return_type.name != return_tp.name:
782
+ continue
783
+ tcs = TypeConstraintSolver(self.__egg_decls__)
784
+
785
+ arg_types, bound_tp_params = tcs.infer_arg_types(
786
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None
787
+ )
788
+
789
+ args = tuple(
790
+ TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
791
+ )
792
+
793
+ call_decl = CallDecl(
794
+ callable_ref,
795
+ args,
796
+ # Don't include bound type params if this is just a method, we only needed them for type resolution
797
+ # but dont need to store them
798
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
799
+ )
800
+ return PartialCallDecl(call_decl)
801
+ raise ValueError(f"Function '{name}' not found")
802
+
655
803
 
656
804
  # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
657
805
  _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
@@ -789,7 +937,7 @@ class FromEggState:
789
937
  args,
790
938
  # Don't include bound type params if this is just a method, we only needed them for type resolution
791
939
  # but dont need to store them
792
- bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
940
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
793
941
  )
794
942
  raise ValueError(
795
943
  f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
@@ -4,7 +4,7 @@ from typing import TypeVar, cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from egglog import EGraph
7
+ from egglog import EGraph, greedy_dag_cost_model
8
8
  from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
9
  from egglog.exp.array_api_numba import array_api_numba_schedule
10
10
  from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
@@ -41,7 +41,7 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
41
41
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
42
42
  egraph.register(res)
43
43
  egraph.run(array_api_numba_schedule)
44
- res_optimized = egraph.extract(res)
44
+ res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())
45
45
 
46
46
  return (
47
47
  egraph,
egglog/pretty.py CHANGED
@@ -183,7 +183,7 @@ class TraverseContext:
183
183
  if isinstance(de, DefaultRewriteDecl):
184
184
  continue
185
185
  self(de)
186
- case CallDecl(ref, exprs, _):
186
+ case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs):
187
187
  match ref:
188
188
  case FunctionRef(UnnamedFunctionRef(_, res)):
189
189
  self(res.expr)
@@ -205,12 +205,13 @@ class TraverseContext:
205
205
  case SetCostDecl(_, e, c):
206
206
  self(e)
207
207
  self(c)
208
- case BackOffDecl():
208
+ case BackOffDecl() | ValueDecl():
209
209
  pass
210
210
  case LetSchedulerDecl(scheduler, schedule):
211
211
  self(scheduler)
212
212
  self(schedule)
213
-
213
+ case GetCostDecl(ref, args):
214
+ self(CallDecl(ref, args))
214
215
  case _:
215
216
  assert_never(decl)
216
217
 
@@ -353,6 +354,10 @@ class PrettyContext:
353
354
  if ban_length is not None:
354
355
  list_args.append(f"ban_length={ban_length}")
355
356
  return f"back_off({', '.join(list_args)})", "scheduler"
357
+ case ValueDecl(value):
358
+ return str(value), "value"
359
+ case GetCostDecl(ref, args):
360
+ return f"get_cost({self(CallDecl(ref, args))})", "get_cost"
356
361
  assert_never(decl)
357
362
 
358
363
  def _call(
egglog/runtime.py CHANGED
@@ -457,7 +457,7 @@ class RuntimeFunction(DelayedDeclerations):
457
457
  arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
458
458
  return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
459
459
  bound_params = (
460
- cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
460
+ cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
461
461
  )
462
462
  # If we were using unstable-app to call a funciton, add that function back as the first arg.
463
463
  if function_value:
@@ -584,11 +584,17 @@ class RuntimeExpr(DelayedDeclerations):
584
584
  if (method := _get_expr_method(self, "__eq__")) is not None:
585
585
  return method(other)
586
586
 
587
- # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
588
- # expr gets a chance to resolve __eq__ which could be a preserved method.
589
- from .egraph import BaseExpr, eq # noqa: PLC0415
587
+ if not (isinstance(self, RuntimeExpr) and isinstance(other, RuntimeExpr)):
588
+ return NotImplemented
589
+ if self.__egg_typed_expr__.tp != other.__egg_typed_expr__.tp:
590
+ return NotImplemented
590
591
 
591
- return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
592
+ from .egraph import Fact # noqa: PLC0415
593
+
594
+ return Fact(
595
+ Declarations.create(self, other),
596
+ EqDecl(self.__egg_typed_expr__.tp, self.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr),
597
+ )
592
598
 
593
599
  def __ne__(self, other: object) -> object: # type: ignore[override]
594
600
  if (method := _get_expr_method(self, "__ne__")) is not None:
@@ -54,7 +54,7 @@ class TypeConstraintSolver:
54
54
  fn_var_args: TypeOrVarRef | None,
55
55
  return_: JustTypeRef,
56
56
  cls_name: str | None,
57
- ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...] | None]:
57
+ ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]:
58
58
  """
59
59
  Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
60
60
 
@@ -75,7 +75,7 @@ class TypeConstraintSolver:
75
75
  )
76
76
  )
77
77
  if cls_name
78
- else None
78
+ else ()
79
79
  )
80
80
  return arg_types, bound_typevars
81
81
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 11.3.0
3
+ Version: 11.4.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -1,16 +1,16 @@
1
- egglog-11.3.0.dist-info/METADATA,sha256=_7z3OUAk7vWdO7bwhHHUwhirFqvbHsguTnkGY9qvidQ,4554
2
- egglog-11.3.0.dist-info/WHEEL,sha256=Xci0wQUn185O40gd7BpQOd6FhkCRTTECoTx1iWoeZos,131
3
- egglog-11.3.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
1
+ egglog-11.4.0.dist-info/METADATA,sha256=QIRn-wS4KchOL4lLxMhdsJKD8Z-7ikr467ysgLww_jM,4554
2
+ egglog-11.4.0.dist-info/WHEEL,sha256=Xci0wQUn185O40gd7BpQOd6FhkCRTTECoTx1iWoeZos,131
3
+ egglog-11.4.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
4
4
  egglog/__init__.py,sha256=0r3MzQbU-9U0fSCeAoJ3deVhZ77tI-1tf8A_WFOhbJs,344
5
- egglog/bindings.cpython-310-aarch64-linux-gnu.so,sha256=7bRwR1m-c25hqH9s5kumX6FlJw5a73yGfhZ17Y2-AuA,165375536
6
- egglog/bindings.pyi,sha256=Y_YpdAKmVHZ0nIHTTPeg0sigBEPiS8z_U-Z161zKSK4,15330
7
- egglog/builtins.py,sha256=qXbBOtT1qwgR9uQu9yb1gUp4dm2L6BgvJIWYU4zCzuw,30317
5
+ egglog/bindings.cpython-310-aarch64-linux-gnu.so,sha256=gYDZxkrNtHtSH3cfbVoPmFupRnYCoVhr4F-ormRcwgo,169197376
6
+ egglog/bindings.pyi,sha256=ntbf2xtpeiabQzUUzx0EdKzu0COXh7sX0_qaUdtB0BY,17874
7
+ egglog/builtins.py,sha256=OSK-JUCKDhpacwPhezPUI9KEru_XIcnA9u9drEyuJi4,30512
8
8
  egglog/config.py,sha256=yM3FIcVCKnhWZmHD0pxkzx7ah7t5PxZx3WUqKtA9tjU,168
9
9
  egglog/conversion.py,sha256=DO76lxRbbTqHs6hRo_Lckvtwu0c6LaKoX7k5_B2AfuY,11238
10
- egglog/declarations.py,sha256=pc2KEYwyKNQXuKndbBCC6iuVROgHkaSKJJf_s9liZi8,26260
11
- egglog/deconstruct.py,sha256=CovORrpROMIwOLgERPUw8doqRUDUehj6LJEB5FMbpmI,5635
12
- egglog/egraph.py,sha256=zJpAoC6JXXqnRsp24CvQN5M5EZ0PrOj93R9U4w6bqlw,65417
13
- egglog/egraph_state.py,sha256=3VLwkAsR3oCydHLx_BXmFw4UHXgdZ9jooQdWUcQeUD0,36375
10
+ egglog/declarations.py,sha256=9gFcg0mc86JuJ4DDE_8qm8d324JiTJZxyLMnx59P3qs,26521
11
+ egglog/deconstruct.py,sha256=b7D5uCaLII-JtlMWcOK8s6LWB8nJe2N879iOIio3-ak,5455
12
+ egglog/egraph.py,sha256=nqL6idjBkURnL3gJ7V-EwhyM5pmfo9-CG19TUO6b9f8,77397
13
+ egglog/egraph_state.py,sha256=UOjcY242GZ0E-4lDt_taKihgxdVgVFX4YCWaDW0_AQs,43394
14
14
  egglog/examples/README.rst,sha256=ztTvpofR0eotSqGoCy_C1fPLDPCncjvcqDanXtLHNNU,232
15
15
  egglog/examples/__init__.py,sha256=wm9evUbMPfbtylXIjbDdRTAVMLH4OjT4Z77PCBFyaPU,31
16
16
  egglog/examples/bignum.py,sha256=jfL57XXpQqIqizQQ3sSUCCjTrkdjtB71BmjrQIQorQk,535
@@ -27,20 +27,20 @@ egglog/examples/resolution.py,sha256=BJd5JClA3DBVGfiVRa-H0gbbFvIqeP3uYbhCXHblSQc
27
27
  egglog/examples/schedule_demo.py,sha256=JbXdPII7_adxtgyKVAiqCyV2sj88VZ-DhomYrdn8vuc,618
28
28
  egglog/exp/__init__.py,sha256=nPtzrH1bz1LVZhZCuS0S9Qild8m5gEikjOVqWAFIa88,49
29
29
  egglog/exp/array_api.py,sha256=dKgEufUIyoT7J_RvnyGtOkg_DK25ZnxIgt7olVygaH8,65547
30
- egglog/exp/array_api_jit.py,sha256=Ak4QhmfYLKimjPf8ffUvPv62OhxOneJ9NEWQJuMxKJc,1680
30
+ egglog/exp/array_api_jit.py,sha256=S5XGWT8a_bFNHUXbQZi6U2cmy__xjvDMNl1eUgvRDyw,1739
31
31
  egglog/exp/array_api_loopnest.py,sha256=-kbyorlGxvlaNsLx1nmLfEZHQM7VMEBwSKtV0l-bs0g,2444
32
32
  egglog/exp/array_api_numba.py,sha256=X3H1TnCjPL92uVm6OvcWMJ11IeorAE58zWiOX6huPv4,2696
33
33
  egglog/exp/array_api_program_gen.py,sha256=qnve8iqklRQVyGChllG8ZAjAffRpezmdxc3IdaWytoQ,21779
34
34
  egglog/exp/program_gen.py,sha256=CavsD70x0ERS87V4OU9zkgMvLXswGEpb1ZZFK0WyN_g,13033
35
35
  egglog/exp/siu_examples.py,sha256=yZ-sgH2Y12iTdwBUumP7D2OtCGL83M6pPW7PMobVFXc,719
36
36
  egglog/ipython_magic.py,sha256=2hs3g2cSiyDmbCvE2t1OINmu17Bb8MWV--2DpEWwO7I,1189
37
- egglog/pretty.py,sha256=Sv3H9e0CJcZv3-ylijP58ApCQ5w1BOdXl2VDw6Hst4Y,22061
37
+ egglog/pretty.py,sha256=gVq3zHkvyCeVZ_u1uqxICy_2srD1-oRzpP5WqFiGFmY,22378
38
38
  egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- egglog/runtime.py,sha256=NUA0O-_urneP54RqXRcPLQIlFzNwPacPKIMxGpwAkus,29672
39
+ egglog/runtime.py,sha256=VatMZWUqMbo_tc6RkdsEhnSuEZbUbropiYZdppn7S4Y,29805
40
40
  egglog/thunk.py,sha256=MrAlPoGK36VQrUrq8PWSaJFu42sPL0yupwiH18lNips,2271
41
- egglog/type_constraint_solver.py,sha256=U2GjLgbebTLv5QY8_TU0As5wMKL5_NxkHLen9rpfMwI,4518
41
+ egglog/type_constraint_solver.py,sha256=jivNkqjRTm38miaoxQoUzTntZbx1yYJO7FhuZWik3lg,4509
42
42
  egglog/version_compat.py,sha256=EaKRMIOPcatrx9XjCofxZD6Nr5WOooiWNdoapkKleww,3512
43
43
  egglog/visualizer.css,sha256=eL0POoThQRc0P4OYnDT-d808ln9O5Qy6DizH9Z5LgWc,259398
44
44
  egglog/visualizer.js,sha256=2qZZ-9W_INJx4gZMYjnVXl27IjT_JNuQyEeI2dbjWoU,3753315
45
45
  egglog/visualizer_widget.py,sha256=LtVfzOtv2WeKtNuILQQ_9SOHWvRr8YdBYQDKQSgry_s,1319
46
- egglog-11.3.0.dist-info/RECORD,,
46
+ egglog-11.4.0.dist-info/RECORD,,