egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import TypeVar, cast
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from egglog import EGraph, greedy_dag_cost_model
|
|
8
|
+
from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
|
|
9
|
+
from egglog.exp.array_api_numba import array_api_numba_schedule
|
|
10
|
+
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
|
|
11
|
+
|
|
12
|
+
from .program_gen import Program
|
|
13
|
+
|
|
14
|
+
# Type of the function handed to ``jit``; bound to ``Callable`` so ``jit`` can return the
# compiled result with the same static type as its input.
X = TypeVar("X", bound=Callable)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def jit(
    fn: X,
    *,
    handle_expr: Callable[[NDArray], None] | None = None,
    handle_optimized_expr: Callable[[NDArray], None] | None = None,
) -> X:
    """
    JIT-compile ``fn`` through the egglog array API pipeline.

    ``fn`` is traced with symbolic ``NDArray`` arguments and optimized by
    ``function_to_program``. The resulting ``Program`` is then evaluated (with ``np``
    available as a global) into a Python object, which is returned.

    :param fn: function of two array arguments to compile.
    :param handle_expr: optional hook called with the traced, unoptimized expression.
    :param handle_optimized_expr: optional hook called with the extracted, optimized expression.
    :return: the evaluated program, typed like ``fn``.
    """
    egraph, traced, optimized, program = function_to_program(fn, save_egglog_string=False)

    # Give callers a look at the intermediate expressions, in pipeline order.
    hooks = ((handle_expr, traced), (handle_optimized_expr, optimized))
    for hook, expr in hooks:
        if hook:
            hook(expr)

    evaluator = EvalProgram(program, {"np": np})
    compiled = try_evaling(egraph, array_api_program_gen_schedule, evaluator, evaluator.as_py_object)
    return cast("X", compiled)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
    """
    Trace a two-argument array function, optimize it, and build a ``Program`` for it.

    ``fn`` is called with two ``NDArray.var`` placeholders named after its parameters. The
    resulting expression is registered in a fresh ``EGraph``, the egraph is run with
    ``array_api_numba_schedule``, and the result is extracted with ``greedy_dag_cost_model``.

    :param fn: function taking exactly two positional array arguments.
    :param save_egglog_string: forwarded to ``EGraph(save_egglog_string=...)``.
    :return: ``(egraph, res, res_optimized, program)``: the egraph used, the traced
        expression, the extracted optimized expression, and the program generated from it.
    :raises ValueError: if ``fn`` does not declare exactly two parameters.
    """
    param_names = list(inspect.signature(fn).parameters)
    if len(param_names) != 2:
        msg = f"Expected a function with exactly two parameters, got {len(param_names)}: {param_names!r}"
        raise ValueError(msg)
    arg1, arg2 = param_names

    egraph = EGraph(save_egglog_string=save_egglog_string)
    # Entered in the same order as before: the egraph first, then set_array_api_egraph(egraph).
    with egraph, set_array_api_egraph(egraph):
        res = fn(NDArray.var(arg1), NDArray.var(arg2))
        egraph.register(res)
        egraph.run(array_api_numba_schedule)
        res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())

    # Fresh argument vars are built here instead of reusing the objects passed to ``fn``.
    # NOTE(review): presumably because ``fn`` may mutate its NDArray arguments in place
    # (e.g. via item assignment); confirm before sharing them.
    program = ndarray_function_two_program(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
    return egraph, res, res_optimized, program
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Example module to replicate behavior expressed in
|
|
3
|
+
|
|
4
|
+
https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# %%
|
|
8
|
+
# mypy: disable-error-code="empty-body"
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from egglog import *
|
|
12
|
+
from egglog.exp.array_api import *
|
|
13
|
+
|
|
14
|
+
# Public names exported by ``from ... import *`` of this module.
__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ShapeAPI(Expr):
    """
    Expression wrapping a tuple of dimension sizes (``dims``).

    ``deselect``, ``select`` and ``to_tuple`` have no bodies here; their meaning comes from
    the subsuming rewrites in ``shape_api_ruleset``.
    """

    def __init__(self, dims: TupleIntLike) -> None: ...

    # Drop the dimensions whose indices appear in ``axis``; the rest keep their original order.
    def deselect(self, axis: TupleIntLike) -> ShapeAPI: ...

    # Keep only the dimensions whose indices appear in ``axis``, in their original order.
    def select(self, axis: TupleIntLike) -> ShapeAPI: ...

    # The wrapped tuple of dimension sizes.
    def to_tuple(self) -> TupleInt: ...
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
    """
    Rewrites lowering ``ShapeAPI`` methods to plain ``TupleInt`` operations.

    Every rule is registered with ``subsume=True``.
    """
    shape = ShapeAPI(dims)
    positions = TupleInt.range(dims.length())
    return [
        # deselect: keep each dimension whose index is NOT listed in ``axis``.
        rewrite(shape.deselect(axis), subsume=True).to(
            ShapeAPI(positions.filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
        ),
        # select: keep only the dimensions whose index IS listed in ``axis``.
        rewrite(shape.select(axis), subsume=True).to(
            ShapeAPI(positions.filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
        ),
        # to_tuple: unwrap to the raw dimension sizes.
        rewrite(shape.to_tuple(), subsume=True).to(dims),
    ]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class OptionalLoopNestAPI(Expr):
    """
    An optional ``LoopNestAPI``: either ``OptionalLoopNestAPI(value)`` or ``OptionalLoopNestAPI.NONE``.
    """

    def __init__(self, value: LoopNestAPI) -> None: ...

    # The empty case (no loop nest); produced by ``LoopNestAPI.from_tuple`` on an empty tuple.
    NONE: ClassVar[OptionalLoopNestAPI]

    # Return the wrapped loop nest. Only ``OptionalLoopNestAPI(value).unwrap()`` has a rewrite
    # (to ``value``) in ``_loopnest_api_ruleset``; ``NONE.unwrap()`` is left undefined.
    def unwrap(self) -> LoopNestAPI: ...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LoopNestAPI(Expr, ruleset=array_api_ruleset):
    """
    A loop nest over a sequence of dimension sizes, stored as a linked list.

    ``LoopNestAPI(last, before)`` holds the final dimension ``last`` and the optional nest of
    the dimensions that precede it. Method bodies defined here (``indices``) are presumably
    registered into ``array_api_ruleset`` via the class ``ruleset=`` argument; confirm.
    """

    def __init__(self, last: Int, before: OptionalLoopNestAPI) -> None: ...

    # Build a nest from a tuple of dimension sizes. An empty tuple gives
    # ``OptionalLoopNestAPI.NONE``; the rules live in ``_loopnest_api_ruleset``.
    @classmethod
    def from_tuple(cls, args: TupleIntLike) -> OptionalLoopNestAPI: ...

    # Python-level iteration over the index tuples of the nest. NOTE(review): ``preserve=True``
    # presumably keeps this as a regular Python method rather than an egglog function; confirm.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.indices())

    # Every index tuple: ``product()`` of ``TupleInt.range(d)`` for each dimension ``d``
    # (presumably the cartesian product; confirm against ``TupleTupleInt.product``).
    def indices(self) -> TupleTupleInt:
        return self.get_dims().map_tuple_int(TupleInt.range).product()

    # The dimension sizes of the nest, in the same order that was passed to ``from_tuple``.
    def get_dims(self) -> TupleInt: ...
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@array_api_ruleset.register
def _loopnest_api_ruleset(lna: LoopNestAPI, dim: Int, ti: TupleInt, idx_fn: Callable[[Int], Int], i: i64):
    """
    Rewrites defining ``LoopNestAPI.from_tuple``, ``LoopNestAPI.get_dims`` and
    ``OptionalLoopNestAPI.unwrap``.

    ``idx_fn`` and ``i`` are not referenced by any rule.
    """
    none = OptionalLoopNestAPI.NONE
    return [
        # from_tuple: an empty tuple has no nest; otherwise peel off the last dimension and recurse.
        rewrite(LoopNestAPI.from_tuple(TupleInt.EMPTY), subsume=True).to(none),
        rewrite(LoopNestAPI.from_tuple(ti.append(dim)), subsume=True).to(
            OptionalLoopNestAPI(LoopNestAPI(dim, LoopNestAPI.from_tuple(ti)))
        ),
        # get_dims: rebuild the tuple, appending ``last`` after the dimensions of ``before``.
        rewrite(LoopNestAPI(dim, none).get_dims(), subsume=True).to(TupleInt.single(dim)),
        rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims(), subsume=True).to(
            lna.get_dims().append(dim)
        ),
        # unwrap: only the populated case has a rule; unlike the others it does not subsume.
        rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna),
    ]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Module for generating array api code that works with Numba.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from egglog import *
|
|
9
|
+
from egglog.exp.array_api import *
|
|
10
|
+
|
|
11
|
+
# Numba-specific rewrites; the ``@array_api_numba_ruleset.register`` functions in this module add to it.
array_api_numba_ruleset = ruleset()
|
|
12
|
+
# Run the combined array API rules together with the numba-specific rules until saturation.
# NOTE(review): rules are registered on ``array_api_numba_ruleset`` after this line; this relies
# on the schedule referencing the ruleset object rather than a snapshot of it; confirm.
array_api_numba_schedule = (array_api_combined_ruleset | array_api_numba_ruleset).saturate()
|
|
13
|
+
# For these rules, we not only want to rewrite, we also want to subsume the original expression,
# so that the rewritten one is used, even if the original one is simpler.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Rewrite mean(x, <int>, <keepdims>) in terms of sum, because numba can't compute a mean along an axis.
# https://github.com/numba/numba/issues/1269
@array_api_numba_ruleset.register
def _mean(y: NDArray, x: NDArray, i: Int):
    """
    Lower ``mean`` over a single integer axis ``i`` to ``sum(x, axis) / x.shape[i]``.

    ``y`` is not referenced by any rule.
    """
    single_axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
    axis_len = NDArray.scalar(Value.int(x.shape[i]))
    averaged = sum(x, single_axis) / axis_len
    return [
        rewrite(mean(x, single_axis, FALSE), subsume=True).to(averaged),
        # keepdims=TRUE: reinsert the reduced axis with expand_dims.
        rewrite(mean(x, single_axis, TRUE), subsume=True).to(expand_dims(averaged, i)),
    ]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Rewrite std(x, <int>) in terms of mean and sum, because numba can't compute a std along an axis.
@array_api_numba_ruleset.register
def _std(y: NDArray, x: NDArray, i: Int):
    """
    Lower ``std`` over a single integer axis to ``sqrt(mean(square(x - mean(x, keepdims))))``.

    Follows https://numpy.org/doc/stable/reference/generated/numpy.std.html:
    "std = sqrt(mean(x)), where x = abs(a - a.mean())**2." For real-valued inputs
    ``abs(d)**2 == square(d)``. ``y`` is not referenced by the rule.
    """
    single_axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
    centered = x - mean(x, single_axis, keepdims=TRUE)
    lowered = sqrt(mean(square(centered), single_axis))
    return [rewrite(std(x, single_axis), subsume=True).to(lowered)]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True).
@function(ruleset=array_api_numba_ruleset, subsume=True)
def count_values(x: NDArrayLike, values: TupleValueLike) -> TupleValue:
    """
    Count how many elements of ``x`` equal each entry of ``values``.

    Returns a ``TupleValue`` with one count per entry of ``values``, in the same order.
    """
    arr = cast(NDArray, x)
    candidates = cast(TupleValue, values)
    return TupleValue(candidates.length(), lambda i: sum(arr == candidates[i]).to_value())
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@array_api_numba_ruleset.register
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
    """
    Replace the counts output of ``unique_counts`` with an explicit per-value count.

    ``c``, ``tv`` and ``v`` are not referenced by the rule.
    """
    # The counts are the number of occurrences of each of the unique values.
    per_value = count_values(x, unique_values(x).to_values())
    counts_output = unique_counts(x)[1]
    yield rewrite(counts_output, subsume=True).to(NDArray.vector(per_value))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Same approach for unique_inverse (NOTE(review): presumably also unsupported by numba; confirm).
@array_api_numba_ruleset.register
def _unique_inverse(x: NDArray, i: Int):
    """
    Rewrite masks built by comparing ``unique_inverse`` against an index.

    ``inverse == i`` is true exactly where ``x`` equals the ``i``-th unique value, so the mask
    is rebuilt directly from ``x`` and ``unique_values(x)``.
    """
    inverse_mask = unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i))
    value_mask = x == NDArray.scalar(unique_values(x).index((i,)))
    yield rewrite(inverse_mask, subsume=True).to(value_mask)
|