egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +146 -8
- egglog/egraph.py +337 -203
- egglog/egraph_state.py +171 -64
- egglog/examples/higher_order_functions.py +45 -0
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +97 -43
- egglog/runtime.py +60 -44
- egglog/thunk.py +44 -20
- egglog/type_constraint_solver.py +5 -4
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/METADATA +31 -30
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.1.0.dist-info/RECORD +0 -39
- {egglog-7.1.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In progress module
|
|
3
|
+
|
|
4
|
+
https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# %%
|
|
9
|
+
# mypy: disable-error-code="empty-body"
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from egglog import *
|
|
13
|
+
from egglog.exp.array_api import *
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ShapeAPI(Expr):
    # Helper expression wrapping a shape tuple (`dims`) so that axes can be
    # kept or dropped symbolically. Every member is declared with
    # unextractable=True; the rewrites in shape_api_ruleset lower these calls
    # to plain TupleInt expressions.
    @method(unextractable=True)
    def __init__(self, dims: TupleIntLike) -> None: ...

    # Shape with every dimension whose index appears in `axis` removed.
    @method(unextractable=True)
    def deselect(self, axis: TupleIntLike) -> ShapeAPI: ...

    # Shape containing only the dimensions whose index appears in `axis`.
    @method(unextractable=True)
    def select(self, axis: TupleIntLike) -> ShapeAPI: ...

    # The wrapped dimensions as a TupleInt.
    @method(unextractable=True)
    def to_tuple(self) -> TupleInt: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):  # noqa: ANN201
    # Lower ShapeAPI to plain TupleInt expressions. `dims` and `axis` act as
    # pattern variables; the dimensions are partitioned by whether their index
    # occurs in `axis`.
    shape = ShapeAPI(dims)
    positions = TupleInt.range(dims.length())
    # Indices NOT listed in axis -> the dimensions that survive `deselect`.
    deselected = positions.filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i])
    # Indices listed in axis -> the dimensions kept by `select`.
    selected = positions.filter(lambda i: axis.contains(i)).map(lambda i: dims[i])

    yield rewrite(shape.deselect(axis)).to(ShapeAPI(deselected))
    yield rewrite(shape.select(axis)).to(ShapeAPI(selected))
    yield rewrite(shape.to_tuple()).to(dims)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class OptionalLoopNestAPI(Expr):
    # Option-like wrapper around a LoopNestAPI. NONE stands for "no loops" and
    # is what LoopNestAPI.from_tuple produces for an empty tuple.
    def __init__(self, value: LoopNestAPI) -> None: ...

    NONE: ClassVar[OptionalLoopNestAPI]

    # Return the wrapped LoopNestAPI. Presumably only meaningful for a
    # non-NONE value — TODO confirm the intended behavior for NONE.
    def unwrap(self) -> LoopNestAPI: ...
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LoopNestAPI(Expr):
    # A loop nest stored as a linked list of levels: `dim` is the extent of
    # this level and `inner` holds the remaining (possibly NONE) levels.
    def __init__(self, dim: Int, inner: OptionalLoopNestAPI) -> None: ...

    # Build a nest from a tuple of extents, outermost first. The empty tuple
    # gives OptionalLoopNestAPI.NONE (see _loopnest_api_ruleset).
    @classmethod
    def from_tuple(cls, args: TupleInt) -> OptionalLoopNestAPI: ...

    # Python-side iteration over the index tuples in `indices`.
    # preserve=True presumably keeps this as a plain Python method rather than
    # an egglog function — confirm against egglog's `method` docs.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.indices)

    # All index tuples of the nest. _loopnest_api_ruleset rewrites this to the
    # product of TupleInt.range over each extent returned by get_dims().
    @property
    def indices(self) -> TupleTupleInt: ...

    # The extent of every level, outermost first.
    def get_dims(self) -> TupleInt: ...

    # Fold `fn(carry, index)` over every index tuple, starting from `init`.
    def fold(self, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArrayLike) -> NDArray: ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Fold `fn` over every index tuple in `xs`, starting from `init`. This is the
# target of the LoopNestAPI.fold rewrite.
# NOTE(review): no rewrites for this function appear in this module.
@function
def tuple_tuple_int_reduce_ndarray(
    xs: TupleTupleInt, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArray
) -> NDArray: ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Map each Int of `xs` to a TupleInt via `fn`. The LoopNestAPI.indices rewrite
# uses it with TupleInt.range to turn each extent into its index range.
@function
def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt: ...
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Product of the tuples in `xs` (presumably the cartesian product of the
# per-level index ranges — TODO confirm). Used by the LoopNestAPI.indices rewrite.
@function
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt: ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@array_api_ruleset.register
def _loopnest_api_ruleset(
    head: Int,
    tail: TupleInt,
    lna: LoopNestAPI,
    fn: Callable[[NDArray, TupleInt], NDArray],
    init: NDArray,
    dim: Int,
    idx_fn: Callable[[Int], Int],
    i: i64,
):
    # Rewrites giving meaning to LoopNestAPI / OptionalLoopNestAPI. The
    # parameters are used as pattern variables in the rules below
    # (`head` and `tail` are currently unused).

    # from_tuple: an empty tuple produces no loop nest ...
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(0, idx_fn))).to(OptionalLoopNestAPI.NONE)
    # ... a non-empty tuple becomes its first extent wrapping the nest built
    # from the remaining extents. (The lambda's `i` is an Int index that
    # shadows the i64 pattern variable `i`.)
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(Int(i), idx_fn))).to(
        OptionalLoopNestAPI(
            LoopNestAPI(idx_fn(Int(0)), LoopNestAPI.from_tuple(TupleInt(Int(i - 1), lambda i: idx_fn(i + 1))))
        ),
        ne(i).to(i64(0)),
    )
    # unwrap: without this rule `LoopNestAPI.from_tuple(...).unwrap()` (as used
    # by linalg_norm) never becomes a LoopNestAPI(dim, inner) term. get_dims,
    # and therefore indices and fold, could then not be rewritten any further.
    yield rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna)
    # reduce
    yield rewrite(lna.fold(fn, init)).to(tuple_tuple_int_reduce_ndarray(lna.indices, fn, init))
    # get_dims
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(TupleInt.single(dim))
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(TupleInt.single(dim) + lna.get_dims())
    # indices
    yield rewrite(lna.indices).to(tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@function(ruleset=array_api_ruleset, unextractable=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:  # noqa: N803
    # Euclidean norm of X over the axes listed in `axis`: for every output
    # index k the result is sqrt(sum over reduced indices i of real(conj(x) * x)).
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

    return NDArray(
        outshape,
        X.dtype,
        # NOTE(review): `i + k` places the reduced index before the kept index.
        # That addresses X correctly only when `axis` names the leading axes
        # (e.g. (0, 1)); confirm before using other axes.
        # NOTE(review): for an empty `axis`, from_tuple yields
        # OptionalLoopNestAPI.NONE, which `.unwrap()` is then applied to.
        lambda k: sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# %%
# Scratch cell: build linalg_norm over a 4-D array, saturate the ruleset and
# extract candidate forms of the result shape.
# NOTE(review): these cells are module-level statements, so they run whenever
# this module is imported.
egraph = EGraph(save_egglog_string=True)

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))
egraph.register(val.shape)
egraph.run(array_api_ruleset.saturate())
egraph.extract_multiple(val.shape, 10)

# %%
# Scratch cell: saturate a fresh e-graph for one entry of the result shape and
# display it.
# NOTE(review): deselecting axes (0, 1) from (3, 2, 3, 4) leaves a 2-element
# shape, so index 2 is past the end — confirm the intended index.
egraph = EGraph()
egraph.register(val.shape[2])
egraph.run(array_api_ruleset.saturate())
egraph.display(split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)


# %%
|
egglog/exp/array_api_numba.py
CHANGED
|
@@ -67,6 +67,6 @@ def _unique_inverse(x: NDArray, i: Int):
|
|
|
67
67
|
return [
|
|
68
68
|
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
|
|
69
69
|
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
|
|
70
|
-
x == NDArray.scalar(unique_values(x).index(
|
|
70
|
+
x == NDArray.scalar(unique_values(x).index((i,)))
|
|
71
71
|
),
|
|
72
72
|
]
|
|
@@ -56,22 +56,42 @@ def _int_program(i64_: i64, i: Int, j: Int):
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
@function
|
|
59
|
-
def tuple_int_program(x: TupleInt) -> Program:
|
|
59
|
+
def tuple_int_program(x: TupleInt) -> Program:
|
|
60
|
+
...
|
|
61
|
+
# Could be rewritten as a fold, but we don't support generic folds yet
|
|
62
|
+
# return x.fold(Program("("), lambda acc, i: acc + ", " + int_program(i)) + ")"
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
@function
|
|
63
|
-
def tuple_int_program_inner(x: TupleInt) -> Program:
|
|
66
|
+
def tuple_int_program_inner(x: TupleInt) -> Program:
|
|
67
|
+
"""
|
|
68
|
+
Returns the tuple w/ out the parenthesis
|
|
69
|
+
"""
|
|
64
70
|
|
|
65
71
|
|
|
66
72
|
@array_api_program_gen_ruleset.register
|
|
67
|
-
def _tuple_int_program(i: Int,
|
|
73
|
+
def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int], vec_int: Vec[Int]):
|
|
68
74
|
yield rewrite(int_program(ti[i])).to(tuple_int_program(ti) + "[" + int_program(i) + "]")
|
|
75
|
+
yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_int_program(ti) + ")")
|
|
69
76
|
|
|
70
77
|
yield rewrite(tuple_int_program(ti)).to(Program("(") + tuple_int_program_inner(ti) + ")")
|
|
71
|
-
|
|
72
|
-
|
|
78
|
+
|
|
79
|
+
yield rewrite(tuple_int_program_inner(TupleInt(0, idx_fn))).to(Program(""))
|
|
80
|
+
|
|
81
|
+
yield rewrite(tuple_int_program_inner(TupleInt(Int(k), idx_fn))).to(
|
|
82
|
+
int_program(idx_fn(Int(0))) + ", " + tuple_int_program_inner(TupleInt(Int(k - 1), lambda i: idx_fn(i + 1))),
|
|
83
|
+
ne(k).to(i64(0)),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(Vec[Int]()))).to(Program(""))
|
|
87
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
|
|
88
|
+
int_program(vec_int[0]) + ", " + tuple_int_program_inner(TupleInt.from_vec(vec_int.remove(0))),
|
|
89
|
+
vec_int.length() > 1,
|
|
90
|
+
)
|
|
91
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
|
|
92
|
+
int_program(vec_int[0]) + ",",
|
|
93
|
+
eq(vec_int.length()).to(i64(1)),
|
|
73
94
|
)
|
|
74
|
-
yield rewrite(tuple_int_program_inner(TupleInt(i))).to(int_program(i) + ",")
|
|
75
95
|
|
|
76
96
|
|
|
77
97
|
@function
|
|
@@ -248,12 +268,31 @@ def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...
|
|
|
248
268
|
|
|
249
269
|
|
|
250
270
|
@array_api_program_gen_ruleset.register
|
|
251
|
-
def _multi_axis_index_key_program(
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
271
|
+
def _multi_axis_index_key_program(
|
|
272
|
+
idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
|
|
273
|
+
):
|
|
274
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))
|
|
275
|
+
|
|
276
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
|
|
277
|
+
multi_axis_index_key_item_program(idx_fn(Int(0)))
|
|
278
|
+
+ ", "
|
|
279
|
+
+ multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
|
|
280
|
+
ne(k).to(i64(0)),
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
|
|
284
|
+
Program("")
|
|
285
|
+
)
|
|
286
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
|
|
287
|
+
multi_axis_index_key_item_program(vec[0]) + ",",
|
|
288
|
+
eq(vec.length()).to(i64(1)),
|
|
289
|
+
)
|
|
290
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
|
|
291
|
+
multi_axis_index_key_item_program(vec[0])
|
|
292
|
+
+ ", "
|
|
293
|
+
+ multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
|
|
294
|
+
vec.length() > 1,
|
|
255
295
|
)
|
|
256
|
-
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.EMPTY)).to(Program("()"))
|
|
257
296
|
|
|
258
297
|
|
|
259
298
|
@function
|
|
@@ -266,7 +305,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
|
|
|
266
305
|
yield rewrite(index_key_program(IndexKey.int(i))).to(int_program(i))
|
|
267
306
|
yield rewrite(index_key_program(IndexKey.slice(s))).to(slice_program(s))
|
|
268
307
|
yield rewrite(index_key_program(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
|
|
269
|
-
yield rewrite(index_key_program(
|
|
308
|
+
yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a))
|
|
270
309
|
|
|
271
310
|
|
|
272
311
|
@function
|
egglog/functionalize.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import partial
|
|
5
|
+
from inspect import Parameter, signature
|
|
6
|
+
from typing import Any, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
__all__ = ["functionalize"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T", bound=Callable)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# TODO: Add `to_lift` param so that we only transform those with vars in them to args
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T:
|
|
18
|
+
"""
|
|
19
|
+
Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments
|
|
20
|
+
and then partially applied with their values. The second arg, get_annotation, will be applied to all values
|
|
21
|
+
to get their type annotation. If it is None, that arg will not be added as a parameter.
|
|
22
|
+
|
|
23
|
+
For example if you have:
|
|
24
|
+
|
|
25
|
+
def get_annotation(x): return int if x <= 10 else None
|
|
26
|
+
|
|
27
|
+
g = 10
|
|
28
|
+
def f(a, a2):
|
|
29
|
+
def h(b: Z):
|
|
30
|
+
return a + a2 + b + g
|
|
31
|
+
|
|
32
|
+
return functionalize(h, get_annotation)
|
|
33
|
+
res = f(9, 11)
|
|
34
|
+
|
|
35
|
+
It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only):
|
|
36
|
+
|
|
37
|
+
def h(a: get_annotation(a), g: get_annotation(g), b: Z):
|
|
38
|
+
return a + b + g
|
|
39
|
+
res = partial(h, a, g)
|
|
40
|
+
"""
|
|
41
|
+
code = f.__code__
|
|
42
|
+
names = tuple(n for n in code.co_names if n in f.__globals__)
|
|
43
|
+
free_vars = code.co_freevars
|
|
44
|
+
|
|
45
|
+
global_values: list[Any] = [f.__globals__[name] for name in names]
|
|
46
|
+
free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else []
|
|
47
|
+
assert len(free_var_values) == len(free_vars), "Free vars and their values do not match"
|
|
48
|
+
global_values_filtered = [
|
|
49
|
+
(i, name, value, annotation)
|
|
50
|
+
for i, (name, value) in enumerate(zip(names, global_values, strict=True))
|
|
51
|
+
if (annotation := get_annotation(value)) is not None
|
|
52
|
+
]
|
|
53
|
+
free_var_values_filtered = [
|
|
54
|
+
(i, name, value, annotation)
|
|
55
|
+
for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True))
|
|
56
|
+
if (annotation := get_annotation(value)) is not None
|
|
57
|
+
]
|
|
58
|
+
additional_arg_filtered = global_values_filtered + free_var_values_filtered
|
|
59
|
+
|
|
60
|
+
# Create a wrapper function
|
|
61
|
+
def wrapper(*args):
|
|
62
|
+
# Split args into names, free vars and other args
|
|
63
|
+
name_args, free_var_args, rest_args = (
|
|
64
|
+
args[: (n_names := len(global_values_filtered))],
|
|
65
|
+
args[n_names : (n_args := len(additional_arg_filtered))],
|
|
66
|
+
args[n_args:],
|
|
67
|
+
)
|
|
68
|
+
# Update globals with names
|
|
69
|
+
f.__globals__.update({
|
|
70
|
+
name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False)
|
|
71
|
+
})
|
|
72
|
+
# update function free vars with free var args
|
|
73
|
+
for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True):
|
|
74
|
+
assert f.__closure__, "Function does not have closure"
|
|
75
|
+
f.__closure__[i].cell_contents = value
|
|
76
|
+
return f(*rest_args)
|
|
77
|
+
|
|
78
|
+
# Set the signature of the new function to a signature with the free vars and names added as arguments
|
|
79
|
+
orig_signature = signature(f)
|
|
80
|
+
wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined]
|
|
81
|
+
parameters=[
|
|
82
|
+
*[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered],
|
|
83
|
+
*orig_signature.parameters.values(),
|
|
84
|
+
]
|
|
85
|
+
)
|
|
86
|
+
# Set the annotations of the new function to the annotations of the original function + annotations of passed in values
|
|
87
|
+
wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered}
|
|
88
|
+
wrapper.__name__ = f.__name__
|
|
89
|
+
|
|
90
|
+
# Partially apply the wrapper function with the current values of the free vars
|
|
91
|
+
return cast(T, partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered)))
|
egglog/pretty.py
CHANGED
|
@@ -16,6 +16,7 @@ from .declarations import *
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from collections.abc import Mapping
|
|
18
18
|
|
|
19
|
+
|
|
19
20
|
__all__ = [
|
|
20
21
|
"pretty_decl",
|
|
21
22
|
"pretty_callable_ref",
|
|
@@ -77,9 +78,9 @@ def pretty_decl(
|
|
|
77
78
|
|
|
78
79
|
This will use re-format the result and put the expression on the last line, preceeded by the statements.
|
|
79
80
|
"""
|
|
80
|
-
traverse = TraverseContext()
|
|
81
|
+
traverse = TraverseContext(decls)
|
|
81
82
|
traverse(decl, toplevel=True)
|
|
82
|
-
pretty = traverse.pretty(
|
|
83
|
+
pretty = traverse.pretty()
|
|
83
84
|
expr = pretty(decl, ruleset_name=ruleset_name)
|
|
84
85
|
if wrapping_fn:
|
|
85
86
|
expr = f"{wrapping_fn}({expr})"
|
|
@@ -106,15 +107,20 @@ def pretty_callable_ref(
|
|
|
106
107
|
"""
|
|
107
108
|
# Pass in three dummy args, which are the max used for any operation that
|
|
108
109
|
# is not a generic function call
|
|
109
|
-
args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
|
|
110
|
+
args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
|
|
110
111
|
if first_arg:
|
|
111
112
|
args.insert(0, first_arg)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
)
|
|
113
|
+
context = PrettyContext(decls, defaultdict(lambda: 0))
|
|
114
|
+
res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
|
|
115
115
|
# Either returns a function or a function with args. If args are provided, they would just be called,
|
|
116
116
|
# on the function, so return them, because they are dummies
|
|
117
|
-
|
|
117
|
+
if isinstance(res, tuple):
|
|
118
|
+
name = res[0]
|
|
119
|
+
# if this is an unnamed function, return it but don't partially apply any args
|
|
120
|
+
if isinstance(name, UnnamedFunctionRef):
|
|
121
|
+
return context._pretty_function_body(name, [])
|
|
122
|
+
return name
|
|
123
|
+
return res
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
# TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
|
|
@@ -128,16 +134,18 @@ class TraverseContext:
|
|
|
128
134
|
expression has.
|
|
129
135
|
"""
|
|
130
136
|
|
|
137
|
+
decls: Declarations
|
|
138
|
+
|
|
131
139
|
# All expressions we have seen (incremented the parent counts of all children)
|
|
132
140
|
_seen: set[AllDecls] = field(default_factory=set)
|
|
133
141
|
# The number of parents for each expressions
|
|
134
142
|
parents: Counter[AllDecls] = field(default_factory=Counter)
|
|
135
143
|
|
|
136
|
-
def pretty(self
|
|
144
|
+
def pretty(self) -> PrettyContext:
|
|
137
145
|
"""
|
|
138
146
|
Create a pretty context from the state of this traverse context.
|
|
139
147
|
"""
|
|
140
|
-
return PrettyContext(decls, self.parents)
|
|
148
|
+
return PrettyContext(self.decls, self.parents)
|
|
141
149
|
|
|
142
150
|
def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
|
|
143
151
|
if not toplevel:
|
|
@@ -166,10 +174,16 @@ class TraverseContext:
|
|
|
166
174
|
pass
|
|
167
175
|
case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
|
|
168
176
|
for de in decls:
|
|
177
|
+
if isinstance(de, DefaultRewriteDecl):
|
|
178
|
+
continue
|
|
169
179
|
self(de)
|
|
170
|
-
case CallDecl(
|
|
171
|
-
|
|
172
|
-
|
|
180
|
+
case CallDecl(ref, exprs, _):
|
|
181
|
+
match ref:
|
|
182
|
+
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
183
|
+
self(res.expr)
|
|
184
|
+
case _:
|
|
185
|
+
for e in exprs:
|
|
186
|
+
self(e.expr)
|
|
173
187
|
case RunDecl(_, until):
|
|
174
188
|
if until:
|
|
175
189
|
for f in until:
|
|
@@ -178,6 +192,8 @@ class TraverseContext:
|
|
|
178
192
|
self(c)
|
|
179
193
|
case CombinedRulesetDecl(_):
|
|
180
194
|
pass
|
|
195
|
+
case DefaultRewriteDecl():
|
|
196
|
+
pass
|
|
181
197
|
case _:
|
|
182
198
|
assert_never(decl)
|
|
183
199
|
|
|
@@ -240,8 +256,7 @@ class PrettyContext:
|
|
|
240
256
|
case CallDecl(_, _, _):
|
|
241
257
|
return self._call(decl, parens)
|
|
242
258
|
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
243
|
-
|
|
244
|
-
return f"UnstableFn({', '.join(arg_strs)})", "fn"
|
|
259
|
+
return self._pretty_partial(ref, [a.expr for a in typed_args]), "fn"
|
|
245
260
|
case PyObjectDecl(value):
|
|
246
261
|
return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
|
|
247
262
|
case ActionCommandDecl(action):
|
|
@@ -276,7 +291,7 @@ class PrettyContext:
|
|
|
276
291
|
case RulesetDecl(rules):
|
|
277
292
|
if ruleset_name:
|
|
278
293
|
return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
|
|
279
|
-
args = ", ".join(
|
|
294
|
+
args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
|
|
280
295
|
return f"ruleset({args})", "ruleset"
|
|
281
296
|
case CombinedRulesetDecl(rulesets):
|
|
282
297
|
if ruleset_name:
|
|
@@ -298,6 +313,9 @@ class PrettyContext:
|
|
|
298
313
|
return ruleset_str, "schedule"
|
|
299
314
|
args = ", ".join(map(self, until))
|
|
300
315
|
return f"run({ruleset_str}, {args})", "schedule"
|
|
316
|
+
case DefaultRewriteDecl():
|
|
317
|
+
msg = "default rewrites should not be pretty printed"
|
|
318
|
+
raise TypeError(msg)
|
|
301
319
|
assert_never(decl)
|
|
302
320
|
|
|
303
321
|
def _call(
|
|
@@ -345,12 +363,16 @@ class PrettyContext:
|
|
|
345
363
|
has_multiple_parents = self.parents[first_arg] > 1
|
|
346
364
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
347
365
|
# Set the first arg to be the name of the mutated arg and return the name
|
|
348
|
-
args[0] = VarDecl(expr_name)
|
|
366
|
+
args[0] = VarDecl(expr_name, True)
|
|
349
367
|
else:
|
|
350
368
|
expr_name = None
|
|
351
369
|
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
352
370
|
expr = (
|
|
353
|
-
|
|
371
|
+
(
|
|
372
|
+
f"{name}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
|
|
373
|
+
if isinstance((name := res[0]), str)
|
|
374
|
+
else ((called := self._pretty_function_body(name, res[1])) if not parens else f"({called})")
|
|
375
|
+
)
|
|
354
376
|
if isinstance(res, tuple)
|
|
355
377
|
else res
|
|
356
378
|
)
|
|
@@ -361,8 +383,12 @@ class PrettyContext:
|
|
|
361
383
|
return expr, tp_name
|
|
362
384
|
|
|
363
385
|
def _call_inner( # noqa: PLR0911
|
|
364
|
-
self,
|
|
365
|
-
|
|
386
|
+
self,
|
|
387
|
+
ref: CallableRef,
|
|
388
|
+
args: list[ExprDecl],
|
|
389
|
+
bound_tp_params: tuple[JustTypeRef, ...] | None,
|
|
390
|
+
parens: bool,
|
|
391
|
+
) -> tuple[str | UnnamedFunctionRef, list[ExprDecl]] | str:
|
|
366
392
|
"""
|
|
367
393
|
Pretty print the call, returning either the full function call or a tuple of the function and the args.
|
|
368
394
|
"""
|
|
@@ -370,10 +396,8 @@ class PrettyContext:
|
|
|
370
396
|
case FunctionRef(name):
|
|
371
397
|
return name, args
|
|
372
398
|
case ClassMethodRef(class_name, method_name):
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
fn_str += f".{method_name}"
|
|
376
|
-
return fn_str, args
|
|
399
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
400
|
+
return f"{tp_ref}.{method_name}", args
|
|
377
401
|
case MethodRef(_class_name, method_name):
|
|
378
402
|
slf, *args = args
|
|
379
403
|
slf = self(slf, parens=True)
|
|
@@ -400,6 +424,11 @@ class PrettyContext:
|
|
|
400
424
|
return f"{class_name}.{variable_name}"
|
|
401
425
|
case PropertyRef(_class_name, property_name):
|
|
402
426
|
return f"{self(args[0], parens=True)}.{property_name}"
|
|
427
|
+
case InitRef(class_name):
|
|
428
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
429
|
+
return str(tp_ref), args
|
|
430
|
+
case UnnamedFunctionRef():
|
|
431
|
+
return ref, args
|
|
403
432
|
assert_never(ref)
|
|
404
433
|
|
|
405
434
|
def _generate_name(self, typ: str) -> str:
|
|
@@ -420,27 +449,52 @@ class PrettyContext:
|
|
|
420
449
|
self.statements.append(f"{name} = {expr_str}")
|
|
421
450
|
return name
|
|
422
451
|
|
|
452
|
+
def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
|
|
453
|
+
"""
|
|
454
|
+
Returns a partial function call as a string.
|
|
455
|
+
"""
|
|
456
|
+
match ref:
|
|
457
|
+
case FunctionRef(name):
|
|
458
|
+
fn = name
|
|
459
|
+
case UnnamedFunctionRef():
|
|
460
|
+
return self._pretty_function_body(ref, args)
|
|
461
|
+
case (
|
|
462
|
+
ClassMethodRef(class_name, method_name)
|
|
463
|
+
| MethodRef(class_name, method_name)
|
|
464
|
+
| PropertyRef(class_name, method_name)
|
|
465
|
+
):
|
|
466
|
+
fn = f"{class_name}.{method_name}"
|
|
467
|
+
case InitRef(class_name):
|
|
468
|
+
fn = class_name
|
|
469
|
+
case ConstantRef(_):
|
|
470
|
+
msg = "Constants should not be callable"
|
|
471
|
+
raise NotImplementedError(msg)
|
|
472
|
+
case ClassVariableRef(_, _):
|
|
473
|
+
msg = "Class variables should not be callable"
|
|
474
|
+
raise NotADirectoryError(msg)
|
|
475
|
+
case _:
|
|
476
|
+
assert_never(ref)
|
|
477
|
+
if not args:
|
|
478
|
+
return fn
|
|
479
|
+
arg_strs = (
|
|
480
|
+
fn,
|
|
481
|
+
*(self(a, parens=False, unwrap_lit=True) for a in args),
|
|
482
|
+
)
|
|
483
|
+
return f"partial({', '.join(arg_strs)})"
|
|
423
484
|
|
|
424
|
-
def
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
case ConstantRef(_):
|
|
438
|
-
msg = "Constants should not be callable"
|
|
439
|
-
raise NotImplementedError(msg)
|
|
440
|
-
case ClassVariableRef(_, _):
|
|
441
|
-
msg = "Class variables should not be callable"
|
|
442
|
-
raise NotADirectoryError(msg)
|
|
443
|
-
assert_never(ref)
|
|
485
|
+
def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
|
|
486
|
+
"""
|
|
487
|
+
Pretty print the body of a function, partially applying some arguments.
|
|
488
|
+
"""
|
|
489
|
+
var_args = fn.args
|
|
490
|
+
replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
|
|
491
|
+
var_args = var_args[len(args) :]
|
|
492
|
+
res = replace_typed_expr(fn.res, replacements)
|
|
493
|
+
arg_names = fn.args[len(args) :]
|
|
494
|
+
prefix = "lambda"
|
|
495
|
+
if arg_names:
|
|
496
|
+
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
|
|
497
|
+
return f"{prefix}: {self(res.expr)}"
|
|
444
498
|
|
|
445
499
|
|
|
446
500
|
def _plot_line_length(expr: object): # pragma: no cover
|