egglog 11.2.0__cp310-cp310-win_amd64.whl → 11.4.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -1,6 +1,8 @@
1
+ from collections.abc import Callable
1
2
  from datetime import timedelta
3
+ from fractions import Fraction
2
4
  from pathlib import Path
3
- from typing import TypeAlias
5
+ from typing import Any, Generic, Protocol, TypeAlias, TypeVar
4
6
 
5
7
  from typing_extensions import final
6
8
 
@@ -14,6 +16,7 @@ __all__ = [
14
16
  "Change",
15
17
  "Check",
16
18
  "Constructor",
19
+ "CostModel",
17
20
  "Datatype",
18
21
  "Datatypes",
19
22
  "DefaultPrintFunctionMode",
@@ -26,6 +29,7 @@ __all__ = [
26
29
  "Extract",
27
30
  "ExtractBest",
28
31
  "ExtractVariants",
32
+ "Extractor",
29
33
  "Fact",
30
34
  "Fail",
31
35
  "Float",
@@ -83,6 +87,7 @@ __all__ = [
83
87
  "UserDefined",
84
88
  "UserDefinedCommandOutput",
85
89
  "UserDefinedOutput",
90
+ "Value",
86
91
  "Var",
87
92
  "Variant",
88
93
  ]
@@ -128,6 +133,31 @@ class EGraph:
128
133
  max_calls_per_function: int | None = None,
129
134
  include_temporary_functions: bool = False,
130
135
  ) -> SerializedEGraph: ...
136
+ def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
137
+ def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
138
+ def value_to_i64(self, v: Value) -> int: ...
139
+ def value_to_f64(self, v: Value) -> float: ...
140
+ def value_to_string(self, v: Value) -> str: ...
141
+ def value_to_bool(self, v: Value) -> bool: ...
142
+ def value_to_rational(self, v: Value) -> Fraction: ...
143
+ def value_to_bigint(self, v: Value) -> int: ...
144
+ def value_to_bigrat(self, v: Value) -> Fraction: ...
145
+ def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
146
+ def value_to_map(self, v: Value) -> dict[Value, Value]: ...
147
+ def value_to_multiset(self, v: Value) -> list[Value]: ...
148
+ def value_to_vec(self, v: Value) -> list[Value]: ...
149
+ def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ...
150
+ def value_to_set(self, v: Value) -> set[Value]: ...
151
+ # def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ...
152
+
153
+ @final
154
+ class Value:
155
+ def __hash__(self) -> int: ...
156
+ def __eq__(self, value: object) -> bool: ...
157
+ def __lt__(self, other: object) -> bool: ...
158
+ def __le__(self, other: object) -> bool: ...
159
+ def __gt__(self, other: object) -> bool: ...
160
+ def __ge__(self, other: object) -> bool: ...
131
161
 
132
162
  @final
133
163
  class EggSmolError(Exception):
@@ -732,3 +762,34 @@ class TermDag:
732
762
  def expr_to_term(self, expr: _Expr) -> _Term: ...
733
763
  def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
734
764
  def to_string(self, term: _Term) -> str: ...
765
+
766
+ ##
767
+ # Extraction
768
+ ##
769
+ class _Cost(Protocol):
770
+ def __lt__(self, other: _Cost) -> bool: ...
771
+ def __le__(self, other: _Cost) -> bool: ...
772
+ def __gt__(self, other: _Cost) -> bool: ...
773
+ def __ge__(self, other: _Cost) -> bool: ...
774
+
775
+ _COST = TypeVar("_COST", bound=_Cost)
776
+
777
+ _ENODE_COST = TypeVar("_ENODE_COST")
778
+
779
+ @final
780
+ class CostModel(Generic[_COST, _ENODE_COST]):
781
+ def __init__(
782
+ self,
783
+ fold: Callable[[str, _ENODE_COST, list[_COST]], _COST],
784
+ enode_cost: Callable[[str, list[Value]], _ENODE_COST],
785
+ container_cost: Callable[[str, Value, list[_COST]], _COST],
786
+ base_value_cost: Callable[[str, Value], _COST],
787
+ ) -> None: ...
788
+
789
+ @final
790
+ class Extractor(Generic[_COST]):
791
+ def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]) -> None: ...
792
+ def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
793
+ def extract_variants(
794
+ self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
795
+ ) -> list[tuple[_COST, _Term]]: ...
egglog/builtins.py CHANGED
@@ -33,10 +33,12 @@ __all__ = [
33
33
  "BigRatLike",
34
34
  "Bool",
35
35
  "BoolLike",
36
+ "Container",
36
37
  "ExprValueError",
37
38
  "Map",
38
39
  "MapLike",
39
40
  "MultiSet",
41
+ "Primitive",
40
42
  "PyObject",
41
43
  "Rational",
42
44
  "Set",
@@ -103,6 +105,10 @@ class String(BuiltinExpr):
103
105
  @method(egg_fn="replace")
104
106
  def replace(self, old: StringLike, new: StringLike) -> String: ...
105
107
 
108
+ @method(preserve=True)
109
+ def __add__(self, other: StringLike) -> String:
110
+ return join(self, other)
111
+
106
112
 
107
113
  StringLike: TypeAlias = String | str
108
114
 
@@ -1131,3 +1137,7 @@ def _convert_function(fn: FunctionType) -> UnstableFn:
1131
1137
 
1132
1138
 
1133
1139
  converter(FunctionType, UnstableFn, _convert_function)
1140
+
1141
+
1142
+ Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn
1143
+ Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit
egglog/declarations.py CHANGED
@@ -9,10 +9,13 @@ from __future__ import annotations
9
9
  from dataclasses import dataclass, field
10
10
  from functools import cached_property
11
11
  from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12
+ from uuid import UUID
12
13
  from weakref import WeakValueDictionary
13
14
 
14
15
  from typing_extensions import Self, assert_never
15
16
 
17
+ from .bindings import Value
18
+
16
19
  if TYPE_CHECKING:
17
20
  from collections.abc import Callable, Iterable, Mapping
18
21
 
@@ -20,6 +23,7 @@ if TYPE_CHECKING:
20
23
  __all__ = [
21
24
  "ActionCommandDecl",
22
25
  "ActionDecl",
26
+ "BackOffDecl",
23
27
  "BiRewriteDecl",
24
28
  "CallDecl",
25
29
  "CallableDecl",
@@ -47,11 +51,13 @@ __all__ = [
47
51
  "FunctionDecl",
48
52
  "FunctionRef",
49
53
  "FunctionSignature",
54
+ "GetCostDecl",
50
55
  "HasDeclerations",
51
56
  "InitRef",
52
57
  "JustTypeRef",
53
58
  "LetDecl",
54
59
  "LetRefDecl",
60
+ "LetSchedulerDecl",
55
61
  "LitDecl",
56
62
  "LitType",
57
63
  "MethodRef",
@@ -79,6 +85,7 @@ __all__ = [
79
85
  "UnboundVarDecl",
80
86
  "UnionDecl",
81
87
  "UnnamedFunctionRef",
88
+ "ValueDecl",
82
89
  "collect_unbound_vars",
83
90
  "replace_typed_expr",
84
91
  "upcast_declerations",
@@ -636,7 +643,7 @@ class CallDecl:
636
643
  args: tuple[TypedExprDecl, ...] = ()
637
644
  # type parameters that were bound to the callable, if it is a classmethod
638
645
  # Used for pretty printing classmethod calls with type parameters
639
- bound_tp_params: tuple[JustTypeRef, ...] | None = None
646
+ bound_tp_params: tuple[JustTypeRef, ...] = ()
640
647
 
641
648
  # pool objects for faster __eq__
642
649
  _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
@@ -651,7 +658,7 @@ class CallDecl:
651
658
  # normalize the args/kwargs to a tuple so that they can be compared
652
659
  callable = args[0] if args else kwargs["callable"]
653
660
  args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
654
- bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")
661
+ bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ())
655
662
 
656
663
  normalized_args = (callable, args_, bound_tp_params)
657
664
  try:
@@ -693,7 +700,20 @@ class PartialCallDecl:
693
700
  call: CallDecl
694
701
 
695
702
 
696
- ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
703
+ @dataclass(frozen=True)
704
+ class GetCostDecl:
705
+ callable: CallableRef
706
+ args: tuple[TypedExprDecl, ...]
707
+
708
+
709
+ @dataclass(frozen=True)
710
+ class ValueDecl:
711
+ value: Value
712
+
713
+
714
+ ExprDecl: TypeAlias = (
715
+ UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
716
+ )
697
717
 
698
718
 
699
719
  @dataclass(frozen=True)
@@ -790,9 +810,24 @@ class SequenceDecl:
790
810
  class RunDecl:
791
811
  ruleset: str
792
812
  until: tuple[FactDecl, ...] | None
813
+ scheduler: BackOffDecl | None = None
814
+
815
+
816
+ @dataclass(frozen=True)
817
+ class LetSchedulerDecl:
818
+ scheduler: BackOffDecl
819
+ inner: ScheduleDecl
793
820
 
794
821
 
795
- ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
822
+ ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl
823
+
824
+
825
+ @dataclass(frozen=True)
826
+ class BackOffDecl:
827
+ id: UUID
828
+ match_limit: int | None
829
+ ban_length: int | None
830
+
796
831
 
797
832
  ##
798
833
  # Facts
egglog/deconstruct.py CHANGED
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, TypeVar, overload
11
11
  from typing_extensions import TypeVarTuple, Unpack
12
12
 
13
13
  from .declarations import *
14
- from .egraph import BaseExpr
14
+ from .egraph import BaseExpr, Expr
15
15
  from .runtime import *
16
16
  from .thunk import *
17
17
 
@@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ...
49
49
  def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
50
50
 
51
51
 
52
- def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
52
+ @overload
53
+ def get_literal_value(x: Expr) -> None: ...
54
+
55
+
56
+ def get_literal_value(x: object) -> object:
53
57
  """
54
58
  Returns the literal value of an expression if it is a literal.
55
59
  If it is not a literal, returns None.
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None:
95
99
  return None
96
100
 
97
101
 
98
- def get_callable_fn(x: T) -> Callable[..., T] | None:
102
+ def get_callable_fn(x: T) -> Callable[..., T] | T | None:
99
103
  """
100
- Gets the function of an expression if it is a call expression.
101
- If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
102
- For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
103
- to return the Python value.
104
+ Gets the function of an expression, or if it's a constant or classvar, return that.
104
105
  """
105
106
  if not isinstance(x, RuntimeExpr):
106
107
  raise TypeError(f"Expected Expression, got {type(x).__name__}")
@@ -159,6 +160,7 @@ def _deconstruct_call_decl(
159
160
  """
160
161
  args = call.args
161
162
  arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
163
+ # TODO: handle values? Like constants
162
164
  if isinstance(call.callable, InitRef):
163
165
  return RuntimeClass(
164
166
  decls_thunk,