egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.1__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of egglog has been flagged as possibly problematic; see the package's registry page for more details.

@@ -0,0 +1,149 @@
1
+ """
2
+ In progress module
3
+
4
+ https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
5
+
6
+ """
7
+
8
+ # %%
9
+ # mypy: disable-error-code="empty-body"
10
+ from __future__ import annotations
11
+
12
+ from egglog import *
13
+ from egglog.exp.array_api import *
14
+
15
+
16
class ShapeAPI(Expr):
    """
    Wrapper around a tuple of dimensions that supports keeping or dropping a set of axes.

    Every method is marked ``unextractable``; ``shape_api_ruleset`` rewrites ``select``/``deselect`` into new
    ``ShapeAPI`` values and ``to_tuple`` back into the plain dimensions.
    """

    @method(unextractable=True)
    def __init__(self, dims: TupleIntLike) -> None:
        """Wrap the dimensions ``dims``."""

    @method(unextractable=True)
    def deselect(self, axis: TupleIntLike) -> ShapeAPI:
        """Shape with the dimensions at the indices in ``axis`` removed."""

    @method(unextractable=True)
    def select(self, axis: TupleIntLike) -> ShapeAPI:
        """Shape containing only the dimensions at the indices in ``axis``."""

    @method(unextractable=True)
    def to_tuple(self) -> TupleInt:
        """The wrapped dimensions as a ``TupleInt``."""
28
+
29
+
30
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
    """
    Rules lowering ``ShapeAPI`` terms to ``TupleInt`` expressions.

    ``deselect``/``select`` walk every axis index of ``dims``, keep the indices that are not / are contained in
    ``axis``, and look the matching dimensions back up in ``dims``.
    """
    shape = ShapeAPI(dims)
    every_axis = TupleInt.range(dims.length())

    kept_without_axis = every_axis.filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i])
    yield rewrite(shape.deselect(axis)).to(ShapeAPI(kept_without_axis))

    kept_in_axis = every_axis.filter(lambda i: axis.contains(i)).map(lambda i: dims[i])
    yield rewrite(shape.select(axis)).to(ShapeAPI(kept_in_axis))

    # Unwrapping simply gives back the dimensions
    yield rewrite(shape.to_tuple()).to(dims)
40
+
41
+
42
class OptionalLoopNestAPI(Expr):
    """Either a wrapped ``LoopNestAPI`` or ``NONE`` (e.g. what ``LoopNestAPI.from_tuple`` gives for an empty tuple)."""

    def __init__(self, value: LoopNestAPI) -> None:
        """Wrap a present loop nest."""

    NONE: ClassVar[OptionalLoopNestAPI]

    def unwrap(self) -> LoopNestAPI:
        """The wrapped loop nest."""
48
+
49
+
50
class LoopNestAPI(Expr):
    """
    A nest of loops: ``dim`` is the extent of the outermost loop and ``inner`` holds the remaining loops, if any.
    """

    def __init__(self, dim: Int, inner: OptionalLoopNestAPI) -> None:
        """Create a loop of extent ``dim`` around the optional ``inner`` nest."""

    @classmethod
    def from_tuple(cls, args: TupleInt) -> OptionalLoopNestAPI:
        """Create a nest with one loop per entry of ``args`` (outermost first); ``NONE`` if ``args`` is empty."""

    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        # NOTE: ``preserve=True`` presumably keeps this as an ordinary Python method; iterating yields the
        # index tuples from ``indices``.
        index_tuples = self.indices
        return iter(index_tuples)

    @property
    def indices(self) -> TupleTupleInt:
        """Every index tuple visited by the nest (a product of ``range`` over each loop's extent)."""

    def get_dims(self) -> TupleInt:
        """Extents of all loops, outermost first."""

    def fold(self, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArrayLike) -> NDArray:
        """Combine ``fn`` over every index tuple of the nest, starting from ``init``."""
66
+
67
+
68
@function
def tuple_tuple_int_reduce_ndarray(
    xs: TupleTupleInt, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArray
) -> NDArray:
    """Reduce the index tuples ``xs`` with ``fn`` starting from ``init``; ``LoopNestAPI.fold`` lowers to this."""
72
+
73
+
74
@function
def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt:
    """Apply ``fn`` to every element of ``xs``, collecting the resulting tuples."""
76
+
77
+
78
@function
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt:
    """Presumably the Cartesian product of the tuples in ``xs`` (used to enumerate loop-nest indices)."""
80
+
81
+
82
@array_api_ruleset.register
def _loopnest_api_ruleset(
    lna: LoopNestAPI,
    fn: Callable[[NDArray, TupleInt], NDArray],
    init: NDArray,
    dim: Int,
    idx_fn: Callable[[Int], Int],
    i: i64,
):
    """
    Rules lowering ``LoopNestAPI`` operations to tuple operations.

    The unused ``head``/``tail`` rule variables of the previous version were dropped; they did not appear in any rule.
    """
    # from_tuple: an empty tuple has no loops
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(0, idx_fn))).to(OptionalLoopNestAPI.NONE)
    # otherwise element 0 is the outer loop and the remaining elements (index shifted by one) form the inner nest.
    # The lambda's own ``i`` (an Int index) shadows the ``i: i64`` length variable.
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(Int(i), idx_fn))).to(
        OptionalLoopNestAPI(
            LoopNestAPI(idx_fn(Int(0)), LoopNestAPI.from_tuple(TupleInt(Int(i - 1), lambda i: idx_fn(i + 1))))
        ),
        ne(i).to(i64(0)),
    )
    # fold: reduce over every index tuple of the nest
    yield rewrite(lna.fold(fn, init)).to(tuple_tuple_int_reduce_ndarray(lna.indices, fn, init))
    # get_dims: outer extent followed by the inner nest's extents
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(TupleInt.single(dim))
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(TupleInt.single(dim) + lna.get_dims())
    # indices: product of the ranges of every dimension
    yield rewrite(lna.indices).to(tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)))
108
+
109
+
110
@function(ruleset=array_api_ruleset, unextractable=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """
    Norm of ``X`` over the axes in ``axis``: for every remaining index, the square root of the sum of
    ``real(conj(x) * x)`` over the reduced axes.

    NOTE(review): the element is read as ``X[i + k]`` (reduced indices first, then output indices), which only
    lines up with ``X``'s axes when ``axis`` names the leading axes, e.g. ``(0, 1)`` — confirm for other axes.
    """
    shape = ShapeAPI(X.shape)
    # Shape of the result: X's shape with the reduced axes removed
    outshape = shape.deselect(axis).to_tuple()
    # Extents being reduced over: only the dimensions listed in ``axis``
    reduce_axis = shape.select(axis).to_tuple()

    return NDArray(
        outshape,
        X.dtype,
        lambda k: sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )
126
+
127
+
128
# %%
# Scratch cell: alternative way to inspect the shape of ``val`` (defined in the next cell).
# egraph = EGraph(save_egglog_string=True)

# egraph.register(val.shape)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract_multiple(val.shape, 10)

# %%
# Scratch example: norm of a (3, 2, 3, 4) array over its first two axes.

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))
egraph = EGraph()
# NOTE(review): deselecting axes (0, 1) from (3, 2, 3, 4) leaves a result shape of (3, 4), so index 2 looks out
# of range here — confirm which dimension was meant.
x = egraph.let("x", val.shape[2])
# egraph.display(n_inline_leaves=0)
# egraph.extract(x)
# egraph.saturate(array_api_ruleset, expr=val.shape[2], split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
# egraph.run(array_api_ruleset.saturate())
# egraph.display()


# %%
@@ -67,6 +67,6 @@ def _unique_inverse(x: NDArray, i: Int):
67
67
  return [
68
68
  # Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
69
69
  rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
70
- x == NDArray.scalar(unique_values(x).index(TupleInt(i)))
70
+ x == NDArray.scalar(unique_values(x).index((i,)))
71
71
  ),
72
72
  ]
@@ -1,8 +1,6 @@
1
1
  # mypy: disable-error-code="empty-body"
2
2
  from __future__ import annotations
3
3
 
4
- import numpy as np
5
-
6
4
  from egglog import *
7
5
 
8
6
  from .array_api import *
@@ -13,9 +11,12 @@ from .program_gen import *
13
11
  # Depends on `np` as a global variable.
14
12
  ##
15
13
 
16
# Rules that translate array API expressions into ``Program`` source
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
# Rules that wrap generated programs in ``EvalProgram`` so they are evaluated in Python
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")

_array_api_program_gen_rulesets = (
    array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
)
# Run generation, compilation and evaluation rules together until nothing changes
array_api_program_gen_schedule = _array_api_program_gen_rulesets.saturate()
19
20
 
20
21
 
21
22
  @function
@@ -56,39 +57,56 @@ def _int_program(i64_: i64, i: Int, j: Int):
56
57
 
57
58
 
58
59
@function
def tuple_int_program(x: TupleInt) -> Program:
    """
    Program building the tuple ``x``, including the surrounding parentheses.

    Could be rewritten as a fold, but generic folds are not supported yet, e.g.
    ``x.fold(Program("("), lambda acc, i: acc + ", " + int_program(i)) + ")"``.
    """
60
64
 
61
65
 
62
66
@function
def tuple_int_program_inner(x: TupleInt) -> Program:
    """
    Program for the comma-separated elements of the tuple ``x``, without the surrounding parentheses.
    """
64
71
 
65
72
 
66
73
@array_api_program_gen_ruleset.register
def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int], vec_int: Vec[Int]):
    """
    Rules turning ``TupleInt`` expressions into Python source.

    Tuples built from a length and an index function render each element followed by ", " (e.g. "(a, b, )");
    tuples built from a vector render "a," for one element and "a, " + rest for more (e.g. "(a, b,)").
    """
    ti_program = tuple_int_program(ti)

    # Indexing into a tuple and taking its length
    yield rewrite(int_program(ti[i])).to(ti_program + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + ti_program + ")")

    # A tuple program is its elements wrapped in parentheses
    yield rewrite(ti_program).to(Program("(") + tuple_int_program_inner(ti) + ")")

    # Length + index function: emit element 0, then recurse on the tail shifted by one
    yield rewrite(tuple_int_program_inner(TupleInt(0, idx_fn))).to(Program(""))
    head = int_program(idx_fn(Int(0)))
    tail = TupleInt(Int(k - 1), lambda i: idx_fn(i + 1))
    yield rewrite(tuple_int_program_inner(TupleInt(Int(k), idx_fn))).to(
        head + ", " + tuple_int_program_inner(tail),
        ne(k).to(i64(0)),
    )

    # Vector-backed tuples
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(Vec[Int]()))).to(Program(""))
    first = int_program(vec_int[0])
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
        first + ", " + tuple_int_program_inner(TupleInt.from_vec(vec_int.remove(0))),
        vec_int.length() > 1,
    )
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
        first + ",",
        eq(vec_int.length()).to(i64(1)),
    )
75
96
 
76
97
 
77
98
@function
def ndarray_program(x: NDArray) -> Program:
    """Program that evaluates to the array ``x``."""
79
100
 
80
101
 
81
@function(ruleset=array_api_program_gen_ruleset)
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
    """
    Program built with ``Program.function_two`` from the program for ``res`` and the programs for ``l`` and ``r``
    (presumably a two-argument function of ``l`` and ``r`` returning ``res``).
    """
    body = ndarray_program(res)
    return body.function_two(ndarray_program(l), ndarray_program(r))
83
105
 
84
106
 
85
@function(ruleset=array_api_program_gen_eval_ruleset)
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
    """
    Wrap ``ndarray_function_two_program(res, l, r)`` in an ``EvalProgram`` so it is evaluated in Python.

    The generated source is expected to refer to NumPy as ``np``, so ``{"np": numpy}`` is passed as its globals.
    """
    # Imported locally: this module no longer imports numpy at the top level, so relying on `np` leaking in through
    # a star import would be fragile.
    import numpy as np

    return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})
92
110
 
93
111
 
94
112
  @function
@@ -248,12 +266,31 @@ def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...
248
266
 
249
267
 
250
268
@array_api_program_gen_ruleset.register
def _multi_axis_index_key_program(
    idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem]
):
    """
    Rules turning a ``MultiAxisIndexKey`` into the comma-separated source of its items.

    Keys built from a length and an index function render each item followed by ", "; keys built from a vector
    render "item," for one item and "item, " + rest for more. (The previous unused ``i`` rule variable was dropped.)
    """
    # NOTE(review): an empty key renders as "" (the earlier version rendered MultiAxisIndexKey.EMPTY as "()").
    # The empty case is also the recursion terminator below, so it cannot simply become "()"; if an empty key can
    # be placed directly inside "[...]" this would produce invalid source — confirm.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))

    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
        multi_axis_index_key_item_program(idx_fn(Int(0)))
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
        ne(k).to(i64(0)),
    )

    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
        Program("")
    )
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0]) + ",",
        eq(vec.length()).to(i64(1)),
    )
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0])
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
        vec.length() > 1,
    )
257
294
 
258
295
 
259
296
  @function
@@ -266,7 +303,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
266
303
  yield rewrite(index_key_program(IndexKey.int(i))).to(int_program(i))
267
304
  yield rewrite(index_key_program(IndexKey.slice(s))).to(slice_program(s))
268
305
  yield rewrite(index_key_program(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
269
- yield rewrite(index_key_program(ndarray_index(a))).to(ndarray_program(a))
306
+ yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a))
270
307
 
271
308
 
272
309
  @function
egglog/exp/program_gen.py CHANGED
@@ -83,8 +83,18 @@ class Program(Expr):
83
83
  Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
84
84
  """
85
85
 
86
- @method(default=Unit())
87
- def eval_py_object(self, globals: object) -> Unit:
86
+ @property
87
+ def is_identifer(self) -> Bool:
88
+ """
89
+ Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
90
+ """
91
+
92
+
93
+ converter(String, Program, Program)
94
+
95
+
96
+ class EvalProgram(Expr):
97
+ def __init__(self, program: Program, globals: object) -> None:
88
98
  """
89
99
  Evaluates the program and saves as the py_object
90
100
  """
@@ -98,38 +108,34 @@ class Program(Expr):
98
108
  """
99
109
 
100
110
  @property
101
- def is_identifer(self) -> Bool:
111
+ def statements(self) -> String:
102
112
  """
103
- Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
113
+ Returns the statements of the program, if it's been compiled
104
114
  """
105
115
 
106
116
 
107
- converter(String, Program, Program)
108
-
109
- program_gen_ruleset = ruleset()
110
-
111
-
112
@ruleset
def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
    """
    Rules for evaluating an ``EvalProgram`` in Python.

    The wrapped program is compiled first. Once its statements and final expression are known, the statements plus
    ``___res = <expr>`` are run with ``py_exec`` (with ``g`` as globals), ``___res`` is read back out of its result,
    and that value is stored as ``ep.py_object``. The compiled statements are also stored on ``ep.statements``.
    """
    # Step 1: asking for an EvalProgram triggers compilation of the underlying program
    yield rule(EvalProgram(p, g)).then(p.compile())

    # Step 2: once compiled, execute the statements and pull out the value of the final expression
    source = join(statements, "\n", "___res = ", expr)
    executed = py_exec(source, g)
    result = py_eval("l['___res']", PyObject.dict(PyObject.from_string("l"), executed))
    yield rule(
        eq(ep).to(EvalProgram(p, g)),
        eq(p.statements).to(statements),
        eq(p.expr).to(expr),
    ).then(
        set_(ep.py_object).to(result),
        set_(ep.statements).to(statements),
    )
129
135
 
130
136
 
131
- @program_gen_ruleset.register
132
- def _compile(
137
+ @ruleset
138
+ def program_gen_ruleset(
133
139
  s: String,
134
140
  s1: String,
135
141
  s2: String,
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from functools import partial
5
+ from inspect import Parameter, signature
6
+ from typing import Any, TypeVar, cast
7
+
8
__all__ = ["functionalize"]


# Type of the callable accepted and returned by ``functionalize``; the transformed result is cast back to it.
T = TypeVar("T", bound=Callable)


# TODO: Add `to_lift` param so that we only transform those with vars in them to args
15
+
16
+
17
+ def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T:
18
+ """
19
+ Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments
20
+ and then partially applied with their values. The second arg, get_annotation, will be applied to all values
21
+ to get their type annotation. If it is None, that arg will not be added as a parameter.
22
+
23
+ For example if you have:
24
+
25
+ def get_annotation(x): return int if x <= 10 else None
26
+
27
+ g = 10
28
+ def f(a, a2):
29
+ def h(b: Z):
30
+ return a + a2 + b + g
31
+
32
+ return functionalize(h, get_annotation)
33
+ res = f(9, 11)
34
+
35
+ It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only):
36
+
37
+ def h(a: get_annotation(a), g: get_annotation(g), b: Z):
38
+ return a + b + g
39
+ res = partial(h, a, g)
40
+ """
41
+ code = f.__code__
42
+ names = tuple(n for n in code.co_names if n in f.__globals__)
43
+ free_vars = code.co_freevars
44
+
45
+ global_values: list[Any] = [f.__globals__[name] for name in names]
46
+ free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else []
47
+ assert len(free_var_values) == len(free_vars), "Free vars and their values do not match"
48
+ global_values_filtered = [
49
+ (i, name, value, annotation)
50
+ for i, (name, value) in enumerate(zip(names, global_values, strict=True))
51
+ if (annotation := get_annotation(value)) is not None
52
+ ]
53
+ free_var_values_filtered = [
54
+ (i, name, value, annotation)
55
+ for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True))
56
+ if (annotation := get_annotation(value)) is not None
57
+ ]
58
+ additional_arg_filtered = global_values_filtered + free_var_values_filtered
59
+
60
+ # Create a wrapper function
61
+ def wrapper(*args):
62
+ # Split args into names, free vars and other args
63
+ name_args, free_var_args, rest_args = (
64
+ args[: (n_names := len(global_values_filtered))],
65
+ args[n_names : (n_args := len(additional_arg_filtered))],
66
+ args[n_args:],
67
+ )
68
+ # Update globals with names
69
+ f.__globals__.update({
70
+ name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False)
71
+ })
72
+ # update function free vars with free var args
73
+ for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True):
74
+ assert f.__closure__, "Function does not have closure"
75
+ f.__closure__[i].cell_contents = value
76
+ return f(*rest_args)
77
+
78
+ # Set the signature of the new function to a signature with the free vars and names added as arguments
79
+ orig_signature = signature(f)
80
+ wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined]
81
+ parameters=[
82
+ *[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered],
83
+ *orig_signature.parameters.values(),
84
+ ]
85
+ )
86
+ # Set the annotations of the new function to the annotations of the original function + annotations of passed in values
87
+ wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered}
88
+ wrapper.__name__ = f.__name__
89
+
90
+ # Partially apply the wrapper function with the current values of the free vars
91
+ return cast(T, partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered)))
egglog/ipython_magic.py CHANGED
@@ -14,7 +14,7 @@ if IN_IPYTHON:
14
14
 
15
15
  @needs_local_scope
16
16
  @register_cell_magic
17
- def egglog(line, cell, local_ns): # noqa: ANN001, ANN201
17
+ def egglog(line, cell, local_ns):
18
18
  """
19
19
  Run an egglog program.
20
20