egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,51 @@
1
+ import inspect
2
+ from collections.abc import Callable
3
+ from typing import TypeVar, cast
4
+
5
+ import numpy as np
6
+
7
+ from egglog import EGraph, greedy_dag_cost_model
8
+ from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
+ from egglog.exp.array_api_numba import array_api_numba_schedule
10
+ from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
11
+
12
+ from .program_gen import Program
13
+
14
# Type of the function being compiled; ``jit`` hands back an object cast to this same type.
X = TypeVar("X", bound=Callable)


def jit(
    fn: X,
    *,
    handle_expr: Callable[[NDArray], None] | None = None,
    handle_optimized_expr: Callable[[NDArray], None] | None = None,
) -> X:
    """
    Jit compiles a function.

    ``fn`` is traced and optimized by ``function_to_program``. The resulting program is then
    evaluated with ``np`` bound to numpy, and the evaluated object is returned.

    :param fn: Function of two array arguments to compile.
    :param handle_expr: Optional callback that receives the traced (unoptimized) expression.
    :param handle_optimized_expr: Optional callback that receives the extracted, optimized expression.
    :return: The evaluated program object, cast to the type of ``fn``.
    """
    egraph, traced, optimized, program = function_to_program(fn, save_egglog_string=False)
    # Let callers inspect the expressions before the program is evaluated;
    # the unoptimized one is always reported first.
    for callback, expr in ((handle_expr, traced), (handle_optimized_expr, optimized)):
        if callback:
            callback(expr)
    # The program is evaluated with ``np`` available under that name.
    evaluator = EvalProgram(program, {"np": np})
    compiled = try_evaling(egraph, array_api_program_gen_schedule, evaluator, evaluator.as_py_object)
    return cast("X", compiled)
33
+
34
+
35
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
    """
    Trace ``fn`` on symbolic arrays, optimize it in an e-graph, and build a program from it.

    :param fn: Function taking exactly two array arguments. Its parameter names become the
        names of the symbolic ``NDArray`` variables it is traced with.
    :param save_egglog_string: Forwarded to the ``EGraph`` constructor.
    :return: ``(egraph, res, res_optimized, program)``, where:

        - ``res`` is the expression returned by ``fn``.
        - ``res_optimized`` is the expression extracted with ``greedy_dag_cost_model`` after
          running ``array_api_numba_schedule``.
        - ``program`` is ``ndarray_function_two_program`` applied to ``res_optimized``.
    :raises ValueError: If ``fn`` does not declare exactly two parameters.
    """
    params = list(inspect.signature(fn).parameters)
    # Fail with a descriptive message instead of a bare tuple-unpacking error.
    if len(params) != 2:
        msg = f"Expected a function with exactly two parameters, got {len(params)}: {params!r}"
        raise ValueError(msg)
    arg1, arg2 = params

    egraph = EGraph(save_egglog_string=save_egglog_string)
    with egraph:
        # Only the tracing of ``fn`` runs with this e-graph set as the array API e-graph.
        with set_array_api_egraph(egraph):
            res = fn(NDArray.var(arg1), NDArray.var(arg2))
        egraph.register(res)
        egraph.run(array_api_numba_schedule)
        res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model())

    # Fresh variables are built here instead of reusing the objects handed to ``fn``.
    # NOTE(review): presumably this guards against ``fn`` having mutated its arguments
    # in place; confirm before sharing them.
    return (
        egraph,
        res,
        res_optimized,
        ndarray_function_two_program(res_optimized, NDArray.var(arg1), NDArray.var(arg2)),
    )
@@ -0,0 +1,74 @@
1
+ """
2
+ Example module to replicate behavior expressed in
3
+
4
+ https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
5
+ """
6
+
7
+ # %%
8
+ # mypy: disable-error-code="empty-body"
9
+ from __future__ import annotations
10
+
11
+ from egglog import *
12
+ from egglog.exp.array_api import *
13
+
14
+ __all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]
15
+
16
+
17
class ShapeAPI(Expr):
    """Symbolic array shape wrapping a tuple of dimension sizes."""

    def __init__(self, dims: TupleIntLike) -> None:
        """Wrap ``dims`` as a shape."""

    def deselect(self, axis: TupleIntLike) -> ShapeAPI:
        """Shape with the dimensions at the indices listed in ``axis`` dropped."""

    def select(self, axis: TupleIntLike) -> ShapeAPI:
        """Shape keeping only the dimensions at the indices listed in ``axis``."""

    def to_tuple(self) -> TupleInt:
        """The wrapped dimension sizes as a ``TupleInt``."""
25
+
26
+
27
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
    """
    Lower ``ShapeAPI`` operations onto ``TupleInt`` operations (all rewrites subsume).

    ``deselect`` and ``select`` filter the indices ``0..len(dims)`` by membership in ``axis``
    and then look up each remaining dimension. ``to_tuple`` simply yields the stored dims.
    """
    shape = ShapeAPI(dims)
    every_index = TupleInt.range(dims.length())

    yield rewrite(shape.deselect(axis), subsume=True).to(
        ShapeAPI(every_index.filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
    )
    yield rewrite(shape.select(axis), subsume=True).to(
        ShapeAPI(every_index.filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
    )
    yield rewrite(shape.to_tuple(), subsume=True).to(dims)
37
+
38
+
39
class OptionalLoopNestAPI(Expr):
    """Either a wrapped ``LoopNestAPI`` or ``NONE``."""

    def __init__(self, value: LoopNestAPI) -> None:
        """Wrap a present ``LoopNestAPI``."""

    # The absent value (what ``LoopNestAPI.from_tuple`` gives for an empty tuple).
    NONE: ClassVar[OptionalLoopNestAPI]

    def unwrap(self) -> LoopNestAPI:
        """The wrapped loop nest; only defined by rewrite for the non-``NONE`` case."""
45
+
46
+
47
class LoopNestAPI(Expr, ruleset=array_api_ruleset):
    """
    A nest of loops over a tuple of dimension sizes.

    ``last`` is the final dimension. ``before`` holds the (optional) nest for all
    preceding dimensions.
    """

    def __init__(self, last: Int, before: OptionalLoopNestAPI) -> None:
        """Nest whose final dimension is ``last``, preceded by ``before``."""

    @classmethod
    def from_tuple(cls, args: TupleIntLike) -> OptionalLoopNestAPI:
        """Build a nest from a tuple of dimension sizes; an empty tuple gives ``NONE``."""

    # NOTE(review): ``preserve=True`` presumably keeps this as a plain Python method
    # rather than an egglog function; confirm.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.indices())

    def indices(self) -> TupleTupleInt:
        """All index tuples: the product of ``TupleInt.range(d)`` over every dimension ``d``."""
        per_dim_ranges = self.get_dims().map_tuple_int(TupleInt.range)
        return per_dim_ranges.product()

    def get_dims(self) -> TupleInt:
        """The dimension sizes of this nest, first to last."""
61
+
62
+
63
@array_api_ruleset.register
def _loopnest_api_ruleset(lna: LoopNestAPI, dim: Int, ti: TupleInt, idx_fn: Callable[[Int], Int], i: i64):
    """
    Rules defining ``LoopNestAPI.from_tuple``, ``LoopNestAPI.get_dims`` and
    ``OptionalLoopNestAPI.unwrap``.

    NOTE(review): ``idx_fn`` and ``i`` are not referenced by any rule here.
    """
    return [
        # from_tuple: no dims -> NONE; otherwise the last dim wraps the nest of the rest.
        rewrite(LoopNestAPI.from_tuple(TupleInt.EMPTY), subsume=True).to(OptionalLoopNestAPI.NONE),
        rewrite(LoopNestAPI.from_tuple(ti.append(dim)), subsume=True).to(
            OptionalLoopNestAPI(LoopNestAPI(dim, LoopNestAPI.from_tuple(ti)))
        ),
        # get_dims: a lone loop has one dim; otherwise append ``dim`` to the preceding dims.
        rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims(), subsume=True).to(TupleInt.single(dim)),
        rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims(), subsume=True).to(
            lna.get_dims().append(dim)
        ),
        # unwrap: unlike the rules above, this one does not subsume.
        rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna),
    ]
@@ -0,0 +1,69 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Module for generating array api code that works with Numba.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from egglog import *
9
+ from egglog.exp.array_api import *
10
+
11
# Numba-specific rewrites; the rules below register themselves into it.
array_api_numba_ruleset = ruleset()
# Run the combined array API rules together with the Numba rules to saturation.
array_api_numba_schedule = (
    array_api_combined_ruleset | array_api_numba_ruleset
).saturate()
# For these rules we do not just rewrite; we also subsume the original expression,
# so that the rewritten form is used even when the original one is simpler.
15
+
16
+
17
# Numba can't compute mean along an axis (https://github.com/numba/numba/issues/1269),
# so rewrite mean(x, <int>, <keepdims>) as a sum divided by the length of that axis.
@array_api_numba_ruleset.register
def _mean(y: NDArray, x: NDArray, i: Int):
    # NOTE(review): ``y`` is not used by the rules below.
    axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
    axis_length = NDArray.scalar(Value.int(x.shape[i]))
    averaged = sum(x, axis) / axis_length

    # Without keepdims the reduced result is used directly; with keepdims the reduced
    # axis is re-inserted via expand_dims.
    for keepdims, replacement in ((FALSE, averaged), (TRUE, expand_dims(averaged, i))):
        yield rewrite(mean(x, axis, keepdims), subsume=True).to(replacement)
26
+
27
+
28
# Rewrite std(x, <int>) in terms of mean and square, because Numba can't compute std along an axis.
@array_api_numba_ruleset.register
def _std(y: NDArray, x: NDArray, i: Int):
    # NOTE(review): ``y`` is not used by the rule below.
    axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
    # NumPy: "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
    # https://numpy.org/doc/stable/reference/generated/numpy.std.html
    deviations = x - mean(x, axis, keepdims=TRUE)
    variance = mean(square(deviations), axis)
    yield rewrite(std(x, axis), subsume=True).to(sqrt(variance))
40
+
41
+
42
# Rewrite unique_counts to count each value one at a time, since Numba doesn't support
# np.unique(..., return_counts=True).
@function(ruleset=array_api_numba_ruleset, subsume=True)
def count_values(x: NDArrayLike, values: TupleValueLike) -> TupleValue:
    """
    Return a tuple whose i-th element is the number of entries of ``x`` equal to ``values[i]``.

    Defined by its body, which is registered (with subsumption) in ``array_api_numba_ruleset``.
    """
    # The casts only narrow the ``*Like`` parameter types for the type checker.
    x = cast(NDArray, x)
    values = cast(TupleValue, values)
    # ``sum`` here is the one brought in by the star imports (it is applied to an array mask
    # and ``.to_value()`` is called on its result), not the Python builtin.
    return TupleValue(values.length(), lambda i: sum(x == values[i]).to_value())
51
+
52
+
53
@array_api_numba_ruleset.register
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
    # NOTE(review): ``c``, ``tv`` and ``v`` are not used by the rule below.
    # The counts half of unique_counts is the count of each unique value of ``x``.
    per_value_counts = NDArray.vector(count_values(x, unique_values(x).to_values()))
    yield rewrite(unique_counts(x)[1], subsume=True).to(per_value_counts)
59
+
60
+
61
# Do the same for unique_inverse.
@array_api_numba_ruleset.register
def _unique_inverse(x: NDArray, i: Int):
    # A mask of where the unique inverse equals index ``i`` is the same as a mask of
    # where ``x`` equals the unique value at index ``i``.
    inverse_matches_index = unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i))
    x_matches_value = x == NDArray.scalar(unique_values(x).index((i,)))
    yield rewrite(inverse_matches_index, subsume=True).to(x_matches_value)