egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
- egglog/bindings.pyi +734 -0
- egglog/builtins.py +1133 -0
- egglog/config.py +8 -0
- egglog/conversion.py +286 -0
- egglog/declarations.py +912 -0
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +1875 -0
- egglog/egraph_state.py +680 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +67 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +425 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +509 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +712 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +113 -0
- egglog/version_compat.py +87 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35777 -0
- egglog/visualizer_widget.py +39 -0
- egglog-11.2.0.dist-info/METADATA +74 -0
- egglog-11.2.0.dist-info/RECORD +46 -0
- egglog-11.2.0.dist-info/WHEEL +4 -0
- egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
egglog/deconstruct.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions to deconstruct expressions in Python.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING, TypeVar, overload
|
|
10
|
+
|
|
11
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
12
|
+
|
|
13
|
+
from .declarations import *
|
|
14
|
+
from .egraph import BaseExpr
|
|
15
|
+
from .runtime import *
|
|
16
|
+
from .thunk import *
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .builtins import Bool, PyObject, String, UnstableFn, f64, i64
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Type variable ranging over any egglog expression type.
T = TypeVar("T", bound=BaseExpr)
# Variadic type variable for argument types; when not inferred it defaults to an
# arbitrary-length tuple of expressions.
TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]])

__all__ = [
    "get_callable_args",
    "get_callable_fn",
    "get_let_name",
    "get_literal_value",
    "get_var_name",
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@overload
|
|
29
|
+
def get_literal_value(x: String) -> str | None: ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@overload
|
|
33
|
+
def get_literal_value(x: Bool) -> bool | None: ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@overload
|
|
37
|
+
def get_literal_value(x: i64) -> int | None: ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def get_literal_value(x: f64) -> float | None: ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@overload
|
|
45
|
+
def get_literal_value(x: PyObject) -> object: ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@overload
|
|
49
|
+
def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
|
|
53
|
+
"""
|
|
54
|
+
Returns the literal value of an expression if it is a literal.
|
|
55
|
+
If it is not a literal, returns None.
|
|
56
|
+
"""
|
|
57
|
+
if not isinstance(x, RuntimeExpr):
|
|
58
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
59
|
+
match x.__egg_typed_expr__.expr:
|
|
60
|
+
case LitDecl(v):
|
|
61
|
+
return v
|
|
62
|
+
case PyObjectDecl(obj):
|
|
63
|
+
return obj
|
|
64
|
+
case PartialCallDecl(call):
|
|
65
|
+
fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
|
|
66
|
+
if not args:
|
|
67
|
+
return fn
|
|
68
|
+
return partial(fn, *args)
|
|
69
|
+
return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_let_name(x: BaseExpr) -> str | None:
|
|
73
|
+
"""
|
|
74
|
+
Check if the expression is a `let` expression and return the name of the variable.
|
|
75
|
+
If it is not a `let` expression, return None.
|
|
76
|
+
"""
|
|
77
|
+
if not isinstance(x, RuntimeExpr):
|
|
78
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
79
|
+
match x.__egg_typed_expr__.expr:
|
|
80
|
+
case LetRefDecl(name):
|
|
81
|
+
return name
|
|
82
|
+
return None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_var_name(x: BaseExpr) -> str | None:
|
|
86
|
+
"""
|
|
87
|
+
Check if the expression is a variable and return its name.
|
|
88
|
+
If it is not a variable, return None.
|
|
89
|
+
"""
|
|
90
|
+
if not isinstance(x, RuntimeExpr):
|
|
91
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
92
|
+
match x.__egg_typed_expr__.expr:
|
|
93
|
+
case UnboundVarDecl(name, _egg_name):
|
|
94
|
+
return name
|
|
95
|
+
return None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_callable_fn(x: T) -> Callable[..., T] | None:
    """
    Return the function (or class, for constructor calls) that ``x`` is a call to.

    When ``x`` is not a call expression (a property, a primitive value, constants,
    classvars, a let value), None is returned. Such values can instead be checked by
    comparing them directly with equality, or, for primitives, by calling ``.eval()``
    to obtain the Python value.

    Raises:
        TypeError: If ``x`` is not a runtime expression.
    """
    if not isinstance(x, RuntimeExpr):
        raise TypeError(f"Expected Expression, got {type(x).__name__}")
    expr = x.__egg_typed_expr__.expr
    if not isinstance(expr, CallDecl):
        return None
    callable_, _args = _deconstruct_call_decl(x.__egg_decls_thunk__, expr)
    return callable_
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@overload
|
|
115
|
+
def get_callable_args(x: T, fn: None = ...) -> tuple[BaseExpr, ...]: ...
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@overload
|
|
119
|
+
def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T]) -> tuple[Unpack[TS]] | None: ...
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T] | None = None) -> tuple[Unpack[TS]] | None:
|
|
123
|
+
"""
|
|
124
|
+
Gets all the arguments of an expression.
|
|
125
|
+
If a function is provided, it will only return the arguments if the expression is a call
|
|
126
|
+
to that function.
|
|
127
|
+
|
|
128
|
+
Note that recursively calling the arguments is the safe way to walk the expression tree.
|
|
129
|
+
"""
|
|
130
|
+
if not isinstance(x, RuntimeExpr):
|
|
131
|
+
raise TypeError(f"Expected Expression, got {type(x).__name__}")
|
|
132
|
+
match x.__egg_typed_expr__.expr:
|
|
133
|
+
case CallDecl() as call:
|
|
134
|
+
actual_fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
|
|
135
|
+
if fn is None:
|
|
136
|
+
return args
|
|
137
|
+
# Compare functions and classes without considering bound type parameters, so that you can pass
|
|
138
|
+
# in a binding like Vec[i64] and match Vec[i64](...) or Vec(...) calls.
|
|
139
|
+
if (
|
|
140
|
+
isinstance(actual_fn, RuntimeFunction)
|
|
141
|
+
and isinstance(fn, RuntimeFunction)
|
|
142
|
+
and actual_fn.__egg_ref__ == fn.__egg_ref__
|
|
143
|
+
):
|
|
144
|
+
return args
|
|
145
|
+
if (
|
|
146
|
+
isinstance(actual_fn, RuntimeClass)
|
|
147
|
+
and isinstance(fn, RuntimeClass)
|
|
148
|
+
and actual_fn.__egg_tp__.name == fn.__egg_tp__.name
|
|
149
|
+
):
|
|
150
|
+
return args
|
|
151
|
+
return None
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _deconstruct_call_decl(
    decls_thunk: Callable[[], Declarations], call: CallDecl
) -> tuple[Callable, tuple[object, ...]]:
    """
    Split a CallDecl into a runtime callable plus runtime expressions for its arguments.

    - Constructor calls (``InitRef``) become a ``RuntimeClass`` for the class. Any bound
      type parameters are carried over as type variables.
    - Every other callable becomes a ``RuntimeFunction`` over the callable reference.
      For class methods (``ClassMethodRef``), the class and its bound type parameters
      are supplied as the bound type; for everything else, no bound type is given.

    Each argument is wrapped in a ``RuntimeExpr`` that shares ``decls_thunk``.
    """
    arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(arg)) for arg in call.args)
    ref = call.callable

    if isinstance(ref, InitRef):
        type_vars = tuple(tp.to_var() for tp in (call.bound_tp_params or ()))
        return RuntimeClass(decls_thunk, TypeRefWithVars(ref.class_name, type_vars)), arg_exprs

    if isinstance(ref, ClassMethodRef):
        egg_bound = JustTypeRef(ref.class_name, call.bound_tp_params or ())
    else:
        egg_bound = None
    return RuntimeFunction(decls_thunk, Thunk.value(ref), egg_bound), arg_exprs
|