egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +83 -4
- egglog/egraph.py +241 -173
- egglog/egraph_state.py +137 -61
- egglog/examples/higher_order_functions.py +3 -8
- egglog/exp/array_api.py +274 -92
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +84 -40
- egglog/runtime.py +52 -39
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/METADATA +33 -32
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In progress module
|
|
3
|
+
|
|
4
|
+
https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# %%
|
|
9
|
+
# mypy: disable-error-code="empty-body"
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from egglog import *
|
|
13
|
+
from egglog.exp.array_api import *
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Helper expression for manipulating an array shape (a `TupleInt` of dimensions) by axis.
# All methods are unextractable; `shape_api_ruleset` rewrites them into plain `TupleInt` expressions.
class ShapeAPI(Expr):
    # Wrap a tuple of dimensions.
    @method(unextractable=True)
    def __init__(self, dims: TupleIntLike) -> None: ...

    # The shape with the dimensions whose indices appear in `axis` removed.
    @method(unextractable=True)
    def deselect(self, axis: TupleIntLike) -> ShapeAPI: ...

    # The shape keeping only the dimensions whose indices appear in `axis` (in the original dimension order).
    @method(unextractable=True)
    def select(self, axis: TupleIntLike) -> ShapeAPI: ...

    # Unwrap back to the underlying tuple of dimensions.
    @method(unextractable=True)
    def to_tuple(self) -> TupleInt: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):  # noqa: ANN201
    """
    Rewrite rules lowering the `ShapeAPI` helpers to plain `TupleInt` operations.
    """
    s = ShapeAPI(dims)
    # deselect: keep dims[i] for every index i of `dims` that is NOT contained in `axis`
    # (iterates `range(len(dims))`, so the original dimension order is preserved).
    yield rewrite(s.deselect(axis)).to(
        ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
    )
    # select: keep dims[i] only for indices i contained in `axis`, in the order of `dims` (not the order of `axis`).
    yield rewrite(s.select(axis)).to(
        ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
    )
    # to_tuple: unwrap to the raw dimension tuple.
    yield rewrite(s.to_tuple()).to(dims)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Option-like wrapper around a `LoopNestAPI`: either holds a loop nest or is `NONE`.
# `NONE` is what `LoopNestAPI.from_tuple` produces for an empty tuple, and it terminates the `inner` chain.
class OptionalLoopNestAPI(Expr):
    # Wrap a present loop nest.
    def __init__(self, value: LoopNestAPI) -> None: ...

    # The empty case (no loops).
    NONE: ClassVar[OptionalLoopNestAPI]

    # Presumably returns the wrapped `LoopNestAPI`; only meaningful for the non-`NONE` case.
    def unwrap(self) -> LoopNestAPI: ...
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# A nest of loops: `dim` is the extent of this (outer) loop, `inner` the optional loop nest inside it.
class LoopNestAPI(Expr):
    def __init__(self, dim: Int, inner: OptionalLoopNestAPI) -> None: ...

    # Build a loop nest from a tuple of extents (first element = outermost loop).
    # An empty tuple gives `OptionalLoopNestAPI.NONE`.
    @classmethod
    def from_tuple(cls, args: TupleInt) -> OptionalLoopNestAPI: ...

    # Python-side iteration over the index tuples in `indices`.
    # NOTE(review): `preserve=True` presumably keeps this as a plain Python method rather than an egglog function.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.indices)

    # Every index tuple visited by the nest; rewritten to the product of `TupleInt.range(d)` for each extent d.
    @property
    def indices(self) -> TupleTupleInt: ...

    # The extents of all loops in the nest, outermost first.
    def get_dims(self) -> TupleInt: ...

    # Fold `fn(carry, index)` over `indices`; rewritten to `tuple_tuple_int_reduce_ndarray(indices, fn, init)`.
    def fold(self, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArrayLike) -> NDArray: ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@function
def tuple_tuple_int_reduce_ndarray(
    xs: TupleTupleInt, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArray
) -> NDArray:
    """
    Reduce `fn(carry, index)` over the index tuples in `xs`, presumably starting from `init`.

    NOTE(review): no rewrite rules for this function appear in this module; confirm they are defined elsewhere.
    """
    ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@function
def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt:
    """
    Map each `Int` of `xs` to a `TupleInt` via `fn`, giving a tuple of tuples.

    Used by the loop-nest rules to turn each loop extent into `TupleInt.range(extent)`.
    NOTE(review): no rewrite rules for this function appear in this module; confirm they are defined elsewhere.
    """
    ...
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@function
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt:
    """
    Presumably the cartesian product of the tuples in `xs`; used to enumerate every index of a loop nest.

    NOTE(review): no rewrite rules for this function appear in this module; confirm they are defined elsewhere.
    """
    ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@array_api_ruleset.register
def _loopnest_api_ruleset(
    head: Int,
    tail: TupleInt,
    lna: LoopNestAPI,
    fn: Callable[[NDArray, TupleInt], NDArray],
    init: NDArray,
    dim: Int,
    idx_fn: Callable[[Int], Int],
    i: i64,
):
    """
    Rewrite rules lowering the `LoopNestAPI` / `OptionalLoopNestAPI` helpers.

    `head` and `tail` are currently unused by any rule.
    """
    # from_tuple: an empty tuple has no loops.
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(0, idx_fn))).to(OptionalLoopNestAPI.NONE)
    # from_tuple: a non-empty tuple becomes its first element as the outer loop, with the remaining `i - 1`
    # elements (index function shifted by one) nested inside. The lambda's `i` is its own Int parameter and
    # shadows the outer `i: i64`.
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(Int(i), idx_fn))).to(
        OptionalLoopNestAPI(
            LoopNestAPI(idx_fn(Int(0)), LoopNestAPI.from_tuple(TupleInt(Int(i - 1), lambda i: idx_fn(i + 1))))
        ),
        ne(i).to(i64(0)),
    )
    # unwrap: extract the loop nest from a present optional. Without this, `from_tuple(...).unwrap()` never
    # reaches a concrete `LoopNestAPI(dim, inner)`, so the `get_dims` rules below could not fire on it.
    yield rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna)
    # reduce
    yield rewrite(lna.fold(fn, init)).to(tuple_tuple_int_reduce_ndarray(lna.indices, fn, init))
    # get_dims: extents of every loop, outermost first.
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(TupleInt.single(dim))
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(TupleInt.single(dim) + lna.get_dims())
    # indices: product of `range(extent)` over all extents.
    yield rewrite(lna.indices).to(tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@function(ruleset=array_api_ruleset, unextractable=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:  # noqa: N803
    """
    Norm of `X` over the dimensions in `axis`: for each output index, sqrt(sum(real(conj(x) * x))).

    The result's shape is `X.shape` with the `axis` dimensions removed.
    """
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

    return NDArray(
        outshape,
        X.dtype,
        # `k` indexes the output; `i` ranges over the reduced dimensions. The element index is the tuple
        # concatenation `i + k` (reduced indices first, then output indices).
        # NOTE(review): that ordering only matches X's layout when `axis` names the leading dimensions
        # (e.g. (0, 1)); other axes would need the indices interleaved — confirm before general use.
        lambda k: sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# %%
# Scratch cells in "# %%" notebook style.
# NOTE(review): these statements run at import time (they build and saturate e-graphs and call `display`);
# consider guarding them before this module is imported from elsewhere.
egraph = EGraph(save_egglog_string=True)

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
# Norm over the two leading axes; the result shape is expected to simplify to the remaining dims (3, 4).
val = linalg_norm(X, (0, 1))
egraph.register(val.shape)
egraph.run(array_api_ruleset.saturate())
# Extract up to 10 equivalent forms of the shape (only visible when run as an interactive cell).
egraph.extract_multiple(val.shape, 10)

# %%
# Fresh e-graph: inspect a single output dimension and visualize the saturated graph.
egraph = EGraph()
egraph.register(val.shape[2])
egraph.run(array_api_ruleset.saturate())
egraph.display(split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)


# %%
|
egglog/exp/array_api_numba.py
CHANGED
|
@@ -67,6 +67,6 @@ def _unique_inverse(x: NDArray, i: Int):
|
|
|
67
67
|
return [
|
|
68
68
|
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
|
|
69
69
|
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
|
|
70
|
-
x == NDArray.scalar(unique_values(x).index(
|
|
70
|
+
x == NDArray.scalar(unique_values(x).index((i,)))
|
|
71
71
|
),
|
|
72
72
|
]
|
|
@@ -56,22 +56,42 @@ def _int_program(i64_: i64, i: Int, j: Int):
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
@function
def tuple_int_program(x: TupleInt) -> Program:
    """
    Returns the program for the tuple, including the surrounding parentheses.
    """
    ...
    # Could be rewritten as a fold, but we don't support generic folds yet.
    # Note: as written, this fold would emit a leading ", " right after "(" (e.g. "(, a, b)").
    # return x.fold(Program("("), lambda acc, i: acc + ", " + int_program(i)) + ")"
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
@function
def tuple_int_program_inner(x: TupleInt) -> Program:
    """
    Returns the program for the tuple's comma-separated elements, without the surrounding parentheses.
    """
|
|
64
70
|
|
|
65
71
|
|
|
66
72
|
@array_api_program_gen_ruleset.register
def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int], vec_int: Vec[Int]):
    # Indexing into a tuple and taking its length are printed directly on the tuple's program.
    yield rewrite(int_program(ti[i])).to(tuple_int_program(ti) + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_int_program(ti) + ")")

    # A tuple is printed as its inner element list wrapped in parentheses.
    yield rewrite(tuple_int_program(ti)).to(Program("(") + tuple_int_program_inner(ti) + ")")

    # Length + index-function tuples: peel off element 0 and recurse on the shifted remainder.
    # Every element is followed by ", " (e.g. length 2 prints "a, b, "), which is still valid Python inside "()".
    yield rewrite(tuple_int_program_inner(TupleInt(0, idx_fn))).to(Program(""))

    yield rewrite(tuple_int_program_inner(TupleInt(Int(k), idx_fn))).to(
        int_program(idx_fn(Int(0))) + ", " + tuple_int_program_inner(TupleInt(Int(k - 1), lambda i: idx_fn(i + 1))),
        ne(k).to(i64(0)),
    )

    # Vec-backed tuples: empty prints nothing, longer vecs print "first, " + the rest,
    # and the final single element prints with a trailing "," (so a 1-tuple stays a tuple).
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(Vec[Int]()))).to(Program(""))
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
        int_program(vec_int[0]) + ", " + tuple_int_program_inner(TupleInt.from_vec(vec_int.remove(0))),
        vec_int.length() > 1,
    )
    yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
        int_program(vec_int[0]) + ",",
        eq(vec_int.length()).to(i64(1)),
    )
|
|
74
|
-
yield rewrite(tuple_int_program_inner(TupleInt(i))).to(int_program(i) + ",")
|
|
75
95
|
|
|
76
96
|
|
|
77
97
|
@function
|
|
@@ -248,12 +268,31 @@ def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...
|
|
|
248
268
|
|
|
249
269
|
|
|
250
270
|
@array_api_program_gen_ruleset.register
def _multi_axis_index_key_program(
    idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
):
    # Prints a multi-axis index key as its comma-separated items, with no surrounding brackets or parentheses.
    # NOTE(review): the `i` parameter is not used by any rule below (the lambdas bind their own `i`).
    # NOTE(review): an empty key prints as ""; if the caller wraps this in "[...]" that would give "x[]",
    # which is not valid Python — confirm empty keys cannot reach these rules.

    # Length + index-function keys: peel off item 0 and recurse on the shifted remainder (each item followed by ", ").
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))

    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
        multi_axis_index_key_item_program(idx_fn(Int(0)))
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
        ne(k).to(i64(0)),
    )

    # Vec-backed keys: empty prints nothing, a single item gets a trailing "," (so it indexes with a 1-tuple),
    # and longer vecs print "first, " followed by the rest.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
        Program("")
    )
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0]) + ",",
        eq(vec.length()).to(i64(1)),
    )
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0])
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
        vec.length() > 1,
    )
|
|
256
|
-
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.EMPTY)).to(Program("()"))
|
|
257
296
|
|
|
258
297
|
|
|
259
298
|
@function
|
|
@@ -266,7 +305,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
|
|
|
266
305
|
yield rewrite(index_key_program(IndexKey.int(i))).to(int_program(i))
|
|
267
306
|
yield rewrite(index_key_program(IndexKey.slice(s))).to(slice_program(s))
|
|
268
307
|
yield rewrite(index_key_program(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
|
|
269
|
-
yield rewrite(index_key_program(
|
|
308
|
+
yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a))
|
|
270
309
|
|
|
271
310
|
|
|
272
311
|
@function
|
egglog/functionalize.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import partial
|
|
5
|
+
from inspect import Parameter, signature
|
|
6
|
+
from typing import Any, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
__all__ = ["functionalize"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T", bound=Callable)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# TODO: Add `to_lift` param so that we only transform those with vars in them to args
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T:
|
|
18
|
+
"""
|
|
19
|
+
Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments
|
|
20
|
+
and then partially applied with their values. The second arg, get_annotation, will be applied to all values
|
|
21
|
+
to get their type annotation. If it is None, that arg will not be added as a parameter.
|
|
22
|
+
|
|
23
|
+
For example if you have:
|
|
24
|
+
|
|
25
|
+
def get_annotation(x): return int if x <= 10 else None
|
|
26
|
+
|
|
27
|
+
g = 10
|
|
28
|
+
def f(a, a2):
|
|
29
|
+
def h(b: Z):
|
|
30
|
+
return a + a2 + b + g
|
|
31
|
+
|
|
32
|
+
return functionalize(h, get_annotation)
|
|
33
|
+
res = f(9, 11)
|
|
34
|
+
|
|
35
|
+
It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only):
|
|
36
|
+
|
|
37
|
+
def h(a: get_annotation(a), g: get_annotation(g), b: Z):
|
|
38
|
+
return a + b + g
|
|
39
|
+
res = partial(h, a, g)
|
|
40
|
+
"""
|
|
41
|
+
code = f.__code__
|
|
42
|
+
names = tuple(n for n in code.co_names if n in f.__globals__)
|
|
43
|
+
free_vars = code.co_freevars
|
|
44
|
+
|
|
45
|
+
global_values: list[Any] = [f.__globals__[name] for name in names]
|
|
46
|
+
free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else []
|
|
47
|
+
assert len(free_var_values) == len(free_vars), "Free vars and their values do not match"
|
|
48
|
+
global_values_filtered = [
|
|
49
|
+
(i, name, value, annotation)
|
|
50
|
+
for i, (name, value) in enumerate(zip(names, global_values, strict=True))
|
|
51
|
+
if (annotation := get_annotation(value)) is not None
|
|
52
|
+
]
|
|
53
|
+
free_var_values_filtered = [
|
|
54
|
+
(i, name, value, annotation)
|
|
55
|
+
for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True))
|
|
56
|
+
if (annotation := get_annotation(value)) is not None
|
|
57
|
+
]
|
|
58
|
+
additional_arg_filtered = global_values_filtered + free_var_values_filtered
|
|
59
|
+
|
|
60
|
+
# Create a wrapper function
|
|
61
|
+
def wrapper(*args):
|
|
62
|
+
# Split args into names, free vars and other args
|
|
63
|
+
name_args, free_var_args, rest_args = (
|
|
64
|
+
args[: (n_names := len(global_values_filtered))],
|
|
65
|
+
args[n_names : (n_args := len(additional_arg_filtered))],
|
|
66
|
+
args[n_args:],
|
|
67
|
+
)
|
|
68
|
+
# Update globals with names
|
|
69
|
+
f.__globals__.update({
|
|
70
|
+
name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False)
|
|
71
|
+
})
|
|
72
|
+
# update function free vars with free var args
|
|
73
|
+
for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True):
|
|
74
|
+
assert f.__closure__, "Function does not have closure"
|
|
75
|
+
f.__closure__[i].cell_contents = value
|
|
76
|
+
return f(*rest_args)
|
|
77
|
+
|
|
78
|
+
# Set the signature of the new function to a signature with the free vars and names added as arguments
|
|
79
|
+
orig_signature = signature(f)
|
|
80
|
+
wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined]
|
|
81
|
+
parameters=[
|
|
82
|
+
*[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered],
|
|
83
|
+
*orig_signature.parameters.values(),
|
|
84
|
+
]
|
|
85
|
+
)
|
|
86
|
+
# Set the annotations of the new function to the annotations of the original function + annotations of passed in values
|
|
87
|
+
wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered}
|
|
88
|
+
wrapper.__name__ = f.__name__
|
|
89
|
+
|
|
90
|
+
# Partially apply the wrapper function with the current values of the free vars
|
|
91
|
+
return cast(T, partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered)))
|
egglog/pretty.py
CHANGED
|
@@ -16,6 +16,7 @@ from .declarations import *
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from collections.abc import Mapping
|
|
18
18
|
|
|
19
|
+
|
|
19
20
|
__all__ = [
|
|
20
21
|
"pretty_decl",
|
|
21
22
|
"pretty_callable_ref",
|
|
@@ -77,9 +78,9 @@ def pretty_decl(
|
|
|
77
78
|
|
|
78
79
|
This will use re-format the result and put the expression on the last line, preceeded by the statements.
|
|
79
80
|
"""
|
|
80
|
-
traverse = TraverseContext()
|
|
81
|
+
traverse = TraverseContext(decls)
|
|
81
82
|
traverse(decl, toplevel=True)
|
|
82
|
-
pretty = traverse.pretty(
|
|
83
|
+
pretty = traverse.pretty()
|
|
83
84
|
expr = pretty(decl, ruleset_name=ruleset_name)
|
|
84
85
|
if wrapping_fn:
|
|
85
86
|
expr = f"{wrapping_fn}({expr})"
|
|
@@ -106,15 +107,20 @@ def pretty_callable_ref(
|
|
|
106
107
|
"""
|
|
107
108
|
# Pass in three dummy args, which are the max used for any operation that
|
|
108
109
|
# is not a generic function call
|
|
109
|
-
args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
|
|
110
|
+
args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
|
|
110
111
|
if first_arg:
|
|
111
112
|
args.insert(0, first_arg)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
)
|
|
113
|
+
context = PrettyContext(decls, defaultdict(lambda: 0))
|
|
114
|
+
res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
|
|
115
115
|
# Either returns a function or a function with args. If args are provided, they would just be called,
|
|
116
116
|
# on the function, so return them, because they are dummies
|
|
117
|
-
|
|
117
|
+
if isinstance(res, tuple):
|
|
118
|
+
name = res[0]
|
|
119
|
+
# if this is an unnamed function, return it but don't partially apply any args
|
|
120
|
+
if isinstance(name, UnnamedFunctionRef):
|
|
121
|
+
return context._pretty_function_body(name, [])
|
|
122
|
+
return name
|
|
123
|
+
return res
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
# TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
|
|
@@ -128,16 +134,18 @@ class TraverseContext:
|
|
|
128
134
|
expression has.
|
|
129
135
|
"""
|
|
130
136
|
|
|
137
|
+
decls: Declarations
|
|
138
|
+
|
|
131
139
|
# All expressions we have seen (incremented the parent counts of all children)
|
|
132
140
|
_seen: set[AllDecls] = field(default_factory=set)
|
|
133
141
|
# The number of parents for each expressions
|
|
134
142
|
parents: Counter[AllDecls] = field(default_factory=Counter)
|
|
135
143
|
|
|
136
|
-
def pretty(self
|
|
144
|
+
def pretty(self) -> PrettyContext:
    """
    Build a PrettyContext from what this traversal collected: its declarations and per-expression parent counts.
    """
    decls = self.decls
    parent_counts = self.parents
    return PrettyContext(decls, parent_counts)
|
|
141
149
|
|
|
142
150
|
def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
|
|
143
151
|
if not toplevel:
|
|
@@ -169,9 +177,13 @@ class TraverseContext:
|
|
|
169
177
|
if isinstance(de, DefaultRewriteDecl):
|
|
170
178
|
continue
|
|
171
179
|
self(de)
|
|
172
|
-
case CallDecl(
|
|
173
|
-
|
|
174
|
-
|
|
180
|
+
case CallDecl(ref, exprs, _):
|
|
181
|
+
match ref:
|
|
182
|
+
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
183
|
+
self(res.expr)
|
|
184
|
+
case _:
|
|
185
|
+
for e in exprs:
|
|
186
|
+
self(e.expr)
|
|
175
187
|
case RunDecl(_, until):
|
|
176
188
|
if until:
|
|
177
189
|
for f in until:
|
|
@@ -244,8 +256,7 @@ class PrettyContext:
|
|
|
244
256
|
case CallDecl(_, _, _):
|
|
245
257
|
return self._call(decl, parens)
|
|
246
258
|
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
247
|
-
|
|
248
|
-
return f"UnstableFn({', '.join(arg_strs)})", "fn"
|
|
259
|
+
return self._pretty_partial(ref, [a.expr for a in typed_args]), "fn"
|
|
249
260
|
case PyObjectDecl(value):
|
|
250
261
|
return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
|
|
251
262
|
case ActionCommandDecl(action):
|
|
@@ -352,12 +363,16 @@ class PrettyContext:
|
|
|
352
363
|
has_multiple_parents = self.parents[first_arg] > 1
|
|
353
364
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
354
365
|
# Set the first arg to be the name of the mutated arg and return the name
|
|
355
|
-
args[0] = VarDecl(expr_name)
|
|
366
|
+
args[0] = VarDecl(expr_name, True)
|
|
356
367
|
else:
|
|
357
368
|
expr_name = None
|
|
358
369
|
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
359
370
|
expr = (
|
|
360
|
-
|
|
371
|
+
(
|
|
372
|
+
f"{name}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
|
|
373
|
+
if isinstance((name := res[0]), str)
|
|
374
|
+
else ((called := self._pretty_function_body(name, res[1])) if not parens else f"({called})")
|
|
375
|
+
)
|
|
361
376
|
if isinstance(res, tuple)
|
|
362
377
|
else res
|
|
363
378
|
)
|
|
@@ -368,8 +383,12 @@ class PrettyContext:
|
|
|
368
383
|
return expr, tp_name
|
|
369
384
|
|
|
370
385
|
def _call_inner( # noqa: PLR0911
|
|
371
|
-
self,
|
|
372
|
-
|
|
386
|
+
self,
|
|
387
|
+
ref: CallableRef,
|
|
388
|
+
args: list[ExprDecl],
|
|
389
|
+
bound_tp_params: tuple[JustTypeRef, ...] | None,
|
|
390
|
+
parens: bool,
|
|
391
|
+
) -> tuple[str | UnnamedFunctionRef, list[ExprDecl]] | str:
|
|
373
392
|
"""
|
|
374
393
|
Pretty print the call, returning either the full function call or a tuple of the function and the args.
|
|
375
394
|
"""
|
|
@@ -408,6 +427,8 @@ class PrettyContext:
|
|
|
408
427
|
case InitRef(class_name):
|
|
409
428
|
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
410
429
|
return str(tp_ref), args
|
|
430
|
+
case UnnamedFunctionRef():
|
|
431
|
+
return ref, args
|
|
411
432
|
assert_never(ref)
|
|
412
433
|
|
|
413
434
|
def _generate_name(self, typ: str) -> str:
|
|
@@ -428,29 +449,52 @@ class PrettyContext:
|
|
|
428
449
|
self.statements.append(f"{name} = {expr_str}")
|
|
429
450
|
return name
|
|
430
451
|
|
|
452
|
+
def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
|
|
453
|
+
"""
|
|
454
|
+
Returns a partial function call as a string.
|
|
455
|
+
"""
|
|
456
|
+
match ref:
|
|
457
|
+
case FunctionRef(name):
|
|
458
|
+
fn = name
|
|
459
|
+
case UnnamedFunctionRef():
|
|
460
|
+
return self._pretty_function_body(ref, args)
|
|
461
|
+
case (
|
|
462
|
+
ClassMethodRef(class_name, method_name)
|
|
463
|
+
| MethodRef(class_name, method_name)
|
|
464
|
+
| PropertyRef(class_name, method_name)
|
|
465
|
+
):
|
|
466
|
+
fn = f"{class_name}.{method_name}"
|
|
467
|
+
case InitRef(class_name):
|
|
468
|
+
fn = class_name
|
|
469
|
+
case ConstantRef(_):
|
|
470
|
+
msg = "Constants should not be callable"
|
|
471
|
+
raise NotImplementedError(msg)
|
|
472
|
+
case ClassVariableRef(_, _):
|
|
473
|
+
msg = "Class variables should not be callable"
|
|
474
|
+
raise NotADirectoryError(msg)
|
|
475
|
+
case _:
|
|
476
|
+
assert_never(ref)
|
|
477
|
+
if not args:
|
|
478
|
+
return fn
|
|
479
|
+
arg_strs = (
|
|
480
|
+
fn,
|
|
481
|
+
*(self(a, parens=False, unwrap_lit=True) for a in args),
|
|
482
|
+
)
|
|
483
|
+
return f"partial({', '.join(arg_strs)})"
|
|
431
484
|
|
|
432
|
-
def
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
case InitRef(class_name):
|
|
446
|
-
return class_name
|
|
447
|
-
case ConstantRef(_):
|
|
448
|
-
msg = "Constants should not be callable"
|
|
449
|
-
raise NotImplementedError(msg)
|
|
450
|
-
case ClassVariableRef(_, _):
|
|
451
|
-
msg = "Class variables should not be callable"
|
|
452
|
-
raise NotADirectoryError(msg)
|
|
453
|
-
assert_never(ref)
|
|
485
|
+
def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
    """
    Pretty print an unnamed function as a lambda, substituting `args` for its leading parameters.

    Parameters bound by `args` are replaced in the body and dropped from the lambda's parameter list.
    If every parameter is bound, the result is `lambda: <body>`. Args beyond the function's arity are ignored.
    """
    # Map each leading parameter to its supplied argument, keeping the parameter's type (zip is deliberately
    # non-strict, so extra args are dropped).
    replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(fn.args, args, strict=False)}
    res = replace_typed_expr(fn.res, replacements)
    remaining_args = fn.args[len(args) :]
    prefix = "lambda"
    if remaining_args:
        prefix += f" {', '.join(self(a.expr) for a in remaining_args)}"
    return f"{prefix}: {self(res.expr)}"
|
|
454
498
|
|
|
455
499
|
|
|
456
500
|
def _plot_line_length(expr: object): # pragma: no cover
|