egglog 11.2.0__cp310-cp310-manylinux_2_17_ppc64.manylinux2014_ppc64.whl → 11.4.0__cp310-cp310-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cpython-310-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +62 -1
- egglog/builtins.py +10 -0
- egglog/declarations.py +39 -4
- egglog/deconstruct.py +9 -7
- egglog/egraph.py +360 -26
- egglog/egraph_state.py +283 -12
- egglog/examples/jointree.py +0 -3
- egglog/exp/array_api_jit.py +2 -2
- egglog/pretty.py +38 -8
- egglog/runtime.py +22 -7
- egglog/type_constraint_solver.py +2 -2
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/METADATA +19 -1
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/RECORD +16 -16
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/WHEEL +0 -0
- {egglog-11.2.0.dist-info → egglog-11.4.0.dist-info}/licenses/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
1
2
|
from datetime import timedelta
|
|
3
|
+
from fractions import Fraction
|
|
2
4
|
from pathlib import Path
|
|
3
|
-
from typing import TypeAlias
|
|
5
|
+
from typing import Any, Generic, Protocol, TypeAlias, TypeVar
|
|
4
6
|
|
|
5
7
|
from typing_extensions import final
|
|
6
8
|
|
|
@@ -14,6 +16,7 @@ __all__ = [
|
|
|
14
16
|
"Change",
|
|
15
17
|
"Check",
|
|
16
18
|
"Constructor",
|
|
19
|
+
"CostModel",
|
|
17
20
|
"Datatype",
|
|
18
21
|
"Datatypes",
|
|
19
22
|
"DefaultPrintFunctionMode",
|
|
@@ -26,6 +29,7 @@ __all__ = [
|
|
|
26
29
|
"Extract",
|
|
27
30
|
"ExtractBest",
|
|
28
31
|
"ExtractVariants",
|
|
32
|
+
"Extractor",
|
|
29
33
|
"Fact",
|
|
30
34
|
"Fail",
|
|
31
35
|
"Float",
|
|
@@ -83,6 +87,7 @@ __all__ = [
|
|
|
83
87
|
"UserDefined",
|
|
84
88
|
"UserDefinedCommandOutput",
|
|
85
89
|
"UserDefinedOutput",
|
|
90
|
+
"Value",
|
|
86
91
|
"Var",
|
|
87
92
|
"Variant",
|
|
88
93
|
]
|
|
@@ -128,6 +133,31 @@ class EGraph:
|
|
|
128
133
|
max_calls_per_function: int | None = None,
|
|
129
134
|
include_temporary_functions: bool = False,
|
|
130
135
|
) -> SerializedEGraph: ...
|
|
136
|
+
def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
|
|
137
|
+
def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
|
|
138
|
+
def value_to_i64(self, v: Value) -> int: ...
|
|
139
|
+
def value_to_f64(self, v: Value) -> float: ...
|
|
140
|
+
def value_to_string(self, v: Value) -> str: ...
|
|
141
|
+
def value_to_bool(self, v: Value) -> bool: ...
|
|
142
|
+
def value_to_rational(self, v: Value) -> Fraction: ...
|
|
143
|
+
def value_to_bigint(self, v: Value) -> int: ...
|
|
144
|
+
def value_to_bigrat(self, v: Value) -> Fraction: ...
|
|
145
|
+
def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
|
|
146
|
+
def value_to_map(self, v: Value) -> dict[Value, Value]: ...
|
|
147
|
+
def value_to_multiset(self, v: Value) -> list[Value]: ...
|
|
148
|
+
def value_to_vec(self, v: Value) -> list[Value]: ...
|
|
149
|
+
def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ...
|
|
150
|
+
def value_to_set(self, v: Value) -> set[Value]: ...
|
|
151
|
+
# def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ...
|
|
152
|
+
|
|
153
|
+
@final
|
|
154
|
+
class Value:
|
|
155
|
+
def __hash__(self) -> int: ...
|
|
156
|
+
def __eq__(self, value: object) -> bool: ...
|
|
157
|
+
def __lt__(self, other: object) -> bool: ...
|
|
158
|
+
def __le__(self, other: object) -> bool: ...
|
|
159
|
+
def __gt__(self, other: object) -> bool: ...
|
|
160
|
+
def __ge__(self, other: object) -> bool: ...
|
|
131
161
|
|
|
132
162
|
@final
|
|
133
163
|
class EggSmolError(Exception):
|
|
@@ -732,3 +762,34 @@ class TermDag:
|
|
|
732
762
|
def expr_to_term(self, expr: _Expr) -> _Term: ...
|
|
733
763
|
def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
|
|
734
764
|
def to_string(self, term: _Term) -> str: ...
|
|
765
|
+
|
|
766
|
+
##
|
|
767
|
+
# Extraction
|
|
768
|
+
##
|
|
769
|
+
class _Cost(Protocol):
|
|
770
|
+
def __lt__(self, other: _Cost) -> bool: ...
|
|
771
|
+
def __le__(self, other: _Cost) -> bool: ...
|
|
772
|
+
def __gt__(self, other: _Cost) -> bool: ...
|
|
773
|
+
def __ge__(self, other: _Cost) -> bool: ...
|
|
774
|
+
|
|
775
|
+
_COST = TypeVar("_COST", bound=_Cost)
|
|
776
|
+
|
|
777
|
+
_ENODE_COST = TypeVar("_ENODE_COST")
|
|
778
|
+
|
|
779
|
+
@final
|
|
780
|
+
class CostModel(Generic[_COST, _ENODE_COST]):
|
|
781
|
+
def __init__(
|
|
782
|
+
self,
|
|
783
|
+
fold: Callable[[str, _ENODE_COST, list[_COST]], _COST],
|
|
784
|
+
enode_cost: Callable[[str, list[Value]], _ENODE_COST],
|
|
785
|
+
container_cost: Callable[[str, Value, list[_COST]], _COST],
|
|
786
|
+
base_value_cost: Callable[[str, Value], _COST],
|
|
787
|
+
) -> None: ...
|
|
788
|
+
|
|
789
|
+
@final
|
|
790
|
+
class Extractor(Generic[_COST]):
|
|
791
|
+
def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]) -> None: ...
|
|
792
|
+
def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
|
|
793
|
+
def extract_variants(
|
|
794
|
+
self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
|
|
795
|
+
) -> list[tuple[_COST, _Term]]: ...
|
egglog/builtins.py
CHANGED
|
@@ -33,10 +33,12 @@ __all__ = [
|
|
|
33
33
|
"BigRatLike",
|
|
34
34
|
"Bool",
|
|
35
35
|
"BoolLike",
|
|
36
|
+
"Container",
|
|
36
37
|
"ExprValueError",
|
|
37
38
|
"Map",
|
|
38
39
|
"MapLike",
|
|
39
40
|
"MultiSet",
|
|
41
|
+
"Primitive",
|
|
40
42
|
"PyObject",
|
|
41
43
|
"Rational",
|
|
42
44
|
"Set",
|
|
@@ -103,6 +105,10 @@ class String(BuiltinExpr):
|
|
|
103
105
|
@method(egg_fn="replace")
|
|
104
106
|
def replace(self, old: StringLike, new: StringLike) -> String: ...
|
|
105
107
|
|
|
108
|
+
@method(preserve=True)
|
|
109
|
+
def __add__(self, other: StringLike) -> String:
|
|
110
|
+
return join(self, other)
|
|
111
|
+
|
|
106
112
|
|
|
107
113
|
StringLike: TypeAlias = String | str
|
|
108
114
|
|
|
@@ -1131,3 +1137,7 @@ def _convert_function(fn: FunctionType) -> UnstableFn:
|
|
|
1131
1137
|
|
|
1132
1138
|
|
|
1133
1139
|
converter(FunctionType, UnstableFn, _convert_function)
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn
|
|
1143
|
+
Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit
|
egglog/declarations.py
CHANGED
|
@@ -9,10 +9,13 @@ from __future__ import annotations
|
|
|
9
9
|
from dataclasses import dataclass, field
|
|
10
10
|
from functools import cached_property
|
|
11
11
|
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
|
|
12
|
+
from uuid import UUID
|
|
12
13
|
from weakref import WeakValueDictionary
|
|
13
14
|
|
|
14
15
|
from typing_extensions import Self, assert_never
|
|
15
16
|
|
|
17
|
+
from .bindings import Value
|
|
18
|
+
|
|
16
19
|
if TYPE_CHECKING:
|
|
17
20
|
from collections.abc import Callable, Iterable, Mapping
|
|
18
21
|
|
|
@@ -20,6 +23,7 @@ if TYPE_CHECKING:
|
|
|
20
23
|
__all__ = [
|
|
21
24
|
"ActionCommandDecl",
|
|
22
25
|
"ActionDecl",
|
|
26
|
+
"BackOffDecl",
|
|
23
27
|
"BiRewriteDecl",
|
|
24
28
|
"CallDecl",
|
|
25
29
|
"CallableDecl",
|
|
@@ -47,11 +51,13 @@ __all__ = [
|
|
|
47
51
|
"FunctionDecl",
|
|
48
52
|
"FunctionRef",
|
|
49
53
|
"FunctionSignature",
|
|
54
|
+
"GetCostDecl",
|
|
50
55
|
"HasDeclerations",
|
|
51
56
|
"InitRef",
|
|
52
57
|
"JustTypeRef",
|
|
53
58
|
"LetDecl",
|
|
54
59
|
"LetRefDecl",
|
|
60
|
+
"LetSchedulerDecl",
|
|
55
61
|
"LitDecl",
|
|
56
62
|
"LitType",
|
|
57
63
|
"MethodRef",
|
|
@@ -79,6 +85,7 @@ __all__ = [
|
|
|
79
85
|
"UnboundVarDecl",
|
|
80
86
|
"UnionDecl",
|
|
81
87
|
"UnnamedFunctionRef",
|
|
88
|
+
"ValueDecl",
|
|
82
89
|
"collect_unbound_vars",
|
|
83
90
|
"replace_typed_expr",
|
|
84
91
|
"upcast_declerations",
|
|
@@ -636,7 +643,7 @@ class CallDecl:
|
|
|
636
643
|
args: tuple[TypedExprDecl, ...] = ()
|
|
637
644
|
# type parameters that were bound to the callable, if it is a classmethod
|
|
638
645
|
# Used for pretty printing classmethod calls with type parameters
|
|
639
|
-
bound_tp_params: tuple[JustTypeRef, ...]
|
|
646
|
+
bound_tp_params: tuple[JustTypeRef, ...] = ()
|
|
640
647
|
|
|
641
648
|
# pool objects for faster __eq__
|
|
642
649
|
_args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
|
|
@@ -651,7 +658,7 @@ class CallDecl:
|
|
|
651
658
|
# normalize the args/kwargs to a tuple so that they can be compared
|
|
652
659
|
callable = args[0] if args else kwargs["callable"]
|
|
653
660
|
args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
|
|
654
|
-
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")
|
|
661
|
+
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ())
|
|
655
662
|
|
|
656
663
|
normalized_args = (callable, args_, bound_tp_params)
|
|
657
664
|
try:
|
|
@@ -693,7 +700,20 @@ class PartialCallDecl:
|
|
|
693
700
|
call: CallDecl
|
|
694
701
|
|
|
695
702
|
|
|
696
|
-
|
|
703
|
+
@dataclass(frozen=True)
|
|
704
|
+
class GetCostDecl:
|
|
705
|
+
callable: CallableRef
|
|
706
|
+
args: tuple[TypedExprDecl, ...]
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@dataclass(frozen=True)
|
|
710
|
+
class ValueDecl:
|
|
711
|
+
value: Value
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
ExprDecl: TypeAlias = (
|
|
715
|
+
UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
|
|
716
|
+
)
|
|
697
717
|
|
|
698
718
|
|
|
699
719
|
@dataclass(frozen=True)
|
|
@@ -790,9 +810,24 @@ class SequenceDecl:
|
|
|
790
810
|
class RunDecl:
|
|
791
811
|
ruleset: str
|
|
792
812
|
until: tuple[FactDecl, ...] | None
|
|
813
|
+
scheduler: BackOffDecl | None = None
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
@dataclass(frozen=True)
|
|
817
|
+
class LetSchedulerDecl:
|
|
818
|
+
scheduler: BackOffDecl
|
|
819
|
+
inner: ScheduleDecl
|
|
793
820
|
|
|
794
821
|
|
|
795
|
-
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
|
|
822
|
+
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
@dataclass(frozen=True)
|
|
826
|
+
class BackOffDecl:
|
|
827
|
+
id: UUID
|
|
828
|
+
match_limit: int | None
|
|
829
|
+
ban_length: int | None
|
|
830
|
+
|
|
796
831
|
|
|
797
832
|
##
|
|
798
833
|
# Facts
|
egglog/deconstruct.py
CHANGED
|
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, TypeVar, overload
|
|
|
11
11
|
from typing_extensions import TypeVarTuple, Unpack
|
|
12
12
|
|
|
13
13
|
from .declarations import *
|
|
14
|
-
from .egraph import BaseExpr
|
|
14
|
+
from .egraph import BaseExpr, Expr
|
|
15
15
|
from .runtime import *
|
|
16
16
|
from .thunk import *
|
|
17
17
|
|
|
@@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ...
|
|
|
49
49
|
def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
|
|
52
|
+
@overload
|
|
53
|
+
def get_literal_value(x: Expr) -> None: ...
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_literal_value(x: object) -> object:
|
|
53
57
|
"""
|
|
54
58
|
Returns the literal value of an expression if it is a literal.
|
|
55
59
|
If it is not a literal, returns None.
|
|
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None:
|
|
|
95
99
|
return None
|
|
96
100
|
|
|
97
101
|
|
|
98
|
-
def get_callable_fn(x: T) -> Callable[..., T] | None:
|
|
102
|
+
def get_callable_fn(x: T) -> Callable[..., T] | T | None:
|
|
99
103
|
"""
|
|
100
|
-
Gets the function of an expression if it
|
|
101
|
-
If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
|
|
102
|
-
For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
|
|
103
|
-
to return the Python value.
|
|
104
|
+
Gets the function of an expression, or if it's a constant or classvar, return that.
|
|
104
105
|
"""
|
|
105
106
|
if not isinstance(x, RuntimeExpr):
|
|
106
107
|
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
@@ -159,6 +160,7 @@ def _deconstruct_call_decl(
|
|
|
159
160
|
"""
|
|
160
161
|
args = call.args
|
|
161
162
|
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
|
|
163
|
+
# TODO: handle values? Like constants
|
|
162
164
|
if isinstance(call.callable, InitRef):
|
|
163
165
|
return RuntimeClass(
|
|
164
166
|
decls_thunk,
|