egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/deconstruct.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions to deconstruct expressions in Python.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING, TypeVar, overload
|
|
10
|
+
|
|
11
|
+
import cloudpickle
|
|
12
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
13
|
+
|
|
14
|
+
from .declarations import *
|
|
15
|
+
from .egraph import BaseExpr, Expr
|
|
16
|
+
from .runtime import *
|
|
17
|
+
from .thunk import *
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from .builtins import Bool, PyObject, String, UnstableFn, f64, i64
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Expression type produced by a deconstructed callable.
T = TypeVar("T", bound=BaseExpr)
# Argument types of a deconstructed callable; when not inferred it falls back to
# an arbitrary-length tuple of BaseExpr.
TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]])

# Public helpers exported by this module.
__all__ = [
    "get_callable_args",
    "get_callable_fn",
    "get_let_name",
    "get_literal_value",
    "get_var_name",
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@overload
|
|
30
|
+
def get_literal_value(x: String) -> str | None: ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@overload
|
|
34
|
+
def get_literal_value(x: Bool) -> bool | None: ...
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@overload
|
|
38
|
+
def get_literal_value(x: i64) -> int | None: ...
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@overload
|
|
42
|
+
def get_literal_value(x: f64) -> float | None: ...
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@overload
|
|
46
|
+
def get_literal_value(x: PyObject) -> object: ...
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@overload
|
|
50
|
+
def get_literal_value(x: UnstableFn[T, *TS]) -> Callable[[Unpack[TS]], T] | None: ...
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@overload
|
|
54
|
+
def get_literal_value(x: Expr) -> None: ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_literal_value(x: object) -> object:
|
|
58
|
+
"""
|
|
59
|
+
Returns the literal value of an expression if it is a literal.
|
|
60
|
+
If it is not a literal, returns None.
|
|
61
|
+
"""
|
|
62
|
+
if not isinstance(x, RuntimeExpr):
|
|
63
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
64
|
+
match x.__egg_typed_expr__.expr:
|
|
65
|
+
case LitDecl(v):
|
|
66
|
+
return v
|
|
67
|
+
case PyObjectDecl(obj):
|
|
68
|
+
return cloudpickle.loads(obj)
|
|
69
|
+
case PartialCallDecl(call):
|
|
70
|
+
fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
|
|
71
|
+
if not args:
|
|
72
|
+
return fn
|
|
73
|
+
return partial(fn, *args)
|
|
74
|
+
return None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_let_name(x: BaseExpr) -> str | None:
|
|
78
|
+
"""
|
|
79
|
+
Check if the expression is a `let` expression and return the name of the variable.
|
|
80
|
+
If it is not a `let` expression, return None.
|
|
81
|
+
"""
|
|
82
|
+
if not isinstance(x, RuntimeExpr):
|
|
83
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
84
|
+
match x.__egg_typed_expr__.expr:
|
|
85
|
+
case LetRefDecl(name):
|
|
86
|
+
return name
|
|
87
|
+
return None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_var_name(x: BaseExpr) -> str | None:
|
|
91
|
+
"""
|
|
92
|
+
Check if the expression is a variable and return its name.
|
|
93
|
+
If it is not a variable, return None.
|
|
94
|
+
"""
|
|
95
|
+
if not isinstance(x, RuntimeExpr):
|
|
96
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
97
|
+
match x.__egg_typed_expr__.expr:
|
|
98
|
+
case UnboundVarDecl(name, _egg_name):
|
|
99
|
+
return name
|
|
100
|
+
return None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_callable_fn(x: T) -> Callable[..., T] | T | None:
    """
    Return the callable that was applied to build ``x``, if ``x`` is a call expression.

    The callable is rebuilt from the call declaration: a runtime class for constructor
    calls, otherwise a runtime function. Expressions that are not calls yield ``None``.

    NOTE(review): an earlier docstring stated that constants and class variables are
    returned directly. This function does not special-case them, so confirm how they are
    represented before relying on that.

    Raises:
        TypeError: if ``x`` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    decl = x.__egg_typed_expr__.expr
    if not isinstance(decl, CallDecl):
        return None
    target, _unused_args = _deconstruct_call_decl(x.__egg_decls_thunk__, decl)
    return target
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@overload
|
|
117
|
+
def get_callable_args(x: T, fn: None = ...) -> tuple[BaseExpr, ...]: ...
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@overload
|
|
121
|
+
def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T]) -> tuple[*TS] | None: ...
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T] | None = None) -> tuple[*TS] | None:
|
|
125
|
+
"""
|
|
126
|
+
Gets all the arguments of an expression.
|
|
127
|
+
If a function is provided, it will only return the arguments if the expression is a call
|
|
128
|
+
to that function.
|
|
129
|
+
|
|
130
|
+
Note that recursively calling the arguments is the safe way to walk the expression tree.
|
|
131
|
+
"""
|
|
132
|
+
if not isinstance(x, RuntimeExpr):
|
|
133
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
134
|
+
match x.__egg_typed_expr__.expr:
|
|
135
|
+
case CallDecl() as call:
|
|
136
|
+
actual_fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
|
|
137
|
+
if fn is None:
|
|
138
|
+
return args
|
|
139
|
+
# Compare functions and classes without considering bound type parameters, so that you can pass
|
|
140
|
+
# in a binding like Vec[i64] and match Vec[i64](...) or Vec(...) calls.
|
|
141
|
+
if (
|
|
142
|
+
isinstance(actual_fn, RuntimeFunction)
|
|
143
|
+
and isinstance(fn, RuntimeFunction)
|
|
144
|
+
and actual_fn.__egg_ref__ == fn.__egg_ref__
|
|
145
|
+
):
|
|
146
|
+
return args
|
|
147
|
+
if (
|
|
148
|
+
isinstance(actual_fn, RuntimeClass)
|
|
149
|
+
and isinstance(fn, RuntimeClass)
|
|
150
|
+
and actual_fn.__egg_tp__.ident == fn.__egg_tp__.ident
|
|
151
|
+
):
|
|
152
|
+
return args
|
|
153
|
+
return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _deconstruct_call_decl(
    decls_thunk: Callable[[], Declarations], call: CallDecl
) -> tuple[Callable, tuple[object, ...]]:
    """
    Split a ``CallDecl`` into a runtime callable and its runtime-expression arguments.

    Each argument is wrapped in a ``RuntimeExpr`` that shares ``decls_thunk``.

    * An ``InitRef`` callable becomes a ``RuntimeClass``. Its bound type parameters, if
      any, are converted to type variables.
    * Any other callable becomes a ``RuntimeFunction``. A ``ClassMethodRef`` is additionally
      given a ``JustTypeRef`` bound built from its class identifier and bound type parameters.
    """
    wrapped_args = tuple(RuntimeExpr(decls_thunk, Thunk.value(arg)) for arg in call.args)
    # TODO: handle values? Like constants
    ref = call.callable
    if isinstance(ref, InitRef):
        type_vars = tuple(tp.to_var() for tp in (call.bound_tp_params or ()))
        return RuntimeClass(decls_thunk, TypeRefWithVars(ref.ident, type_vars)), wrapped_args

    bound_type = None
    if isinstance(ref, ClassMethodRef):
        bound_type = JustTypeRef(ref.ident, call.bound_tp_params or ())
    return RuntimeFunction(decls_thunk, Thunk.value(ref), bound_type), wrapped_args
|