egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
In progress module
|
|
3
|
+
|
|
4
|
+
https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# %%
|
|
9
|
+
# mypy: disable-error-code="empty-body"
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from egglog import *
|
|
13
|
+
from egglog.exp.array_api import *
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ShapeAPI(Expr):
    """
    Wrapper around a tuple of dimensions (a shape) that supports keeping or dropping axes by position.

    Every method is declared with ``unextractable=True``; the rules in ``shape_api_ruleset`` rewrite these
    operations into plain ``TupleInt`` operations.
    """

    @method(unextractable=True)
    def __init__(self, dims: TupleIntLike) -> None: ...

    # Shape with the dimensions at the positions listed in `axis` removed.
    @method(unextractable=True)
    def deselect(self, axis: TupleIntLike) -> ShapeAPI: ...

    # Shape with only the dimensions at the positions listed in `axis` kept
    # (in the order they appear in `dims`, not the order of `axis`).
    @method(unextractable=True)
    def select(self, axis: TupleIntLike) -> ShapeAPI: ...

    # The wrapped dimensions as a TupleInt.
    @method(unextractable=True)
    def to_tuple(self) -> TupleInt: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@array_api_ruleset.register
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
    """
    Rewrite rules that lower ``ShapeAPI`` operations to ``TupleInt`` operations.

    The parameters presumably act as the rule variables (supplied by egglog when the ruleset is registered).
    """
    s = ShapeAPI(dims)
    # deselect: keep dims[i] for every index i of `dims` that is NOT contained in `axis`
    yield rewrite(s.deselect(axis)).to(
        ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
    )
    # select: keep dims[i] for every index i of `dims` that IS contained in `axis` (order follows `dims`)
    yield rewrite(s.select(axis)).to(
        ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
    )
    # to_tuple: unwrap back to the plain tuple of dimensions
    yield rewrite(s.to_tuple()).to(dims)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class OptionalLoopNestAPI(Expr):
    """
    Optional ``LoopNestAPI``: either ``OptionalLoopNestAPI(value)`` or ``OptionalLoopNestAPI.NONE``.
    """

    def __init__(self, value: LoopNestAPI) -> None: ...

    # No loop nest; produced by `LoopNestAPI.from_tuple` for an empty tuple and used as the `inner` of the
    # innermost loop level.
    NONE: ClassVar[OptionalLoopNestAPI]

    # The wrapped LoopNestAPI; presumably only meaningful on a non-NONE value — TODO confirm.
    def unwrap(self) -> LoopNestAPI: ...
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LoopNestAPI(Expr):
    """
    A nest of loops with one level per dimension.

    ``LoopNestAPI(dim, inner)`` is a loop bounded by ``dim`` whose body is the optional ``inner`` nest
    (``OptionalLoopNestAPI.NONE`` for the innermost level).
    """

    def __init__(self, dim: Int, inner: OptionalLoopNestAPI) -> None: ...

    # Builds a nest with one level per element of `args` (first element = outermost loop);
    # an empty tuple gives OptionalLoopNestAPI.NONE.
    @classmethod
    def from_tuple(cls, args: TupleInt) -> OptionalLoopNestAPI: ...

    # Python-side iteration over every index tuple of the nest. `preserve=True` presumably keeps this as a
    # plain Python method rather than an egglog function — TODO confirm.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.indices)

    # Every index tuple of the nest; rewritten to the product of TupleInt.range(d) for each d in get_dims()
    # (presumably the cartesian product, outermost loop first).
    @property
    def indices(self) -> TupleTupleInt: ...

    # The loop bounds, outermost first.
    def get_dims(self) -> TupleInt: ...

    # Folds `fn(carry, index)` over all index tuples of the nest, starting from `init`.
    def fold(self, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArrayLike) -> NDArray: ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Reduces `fn(carry, index)` over the tuples in `xs`, starting from `init` (the target of `LoopNestAPI.fold`).
# NOTE(review): no rewrite rules for this function are visible in this module — presumably still to be written.
@function
def tuple_tuple_int_reduce_ndarray(
    xs: TupleTupleInt, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArray
) -> NDArray: ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Maps each Int of `xs` to a TupleInt via `fn`, collecting the results into a TupleTupleInt.
# NOTE(review): no rewrite rules for this function are visible in this module.
@function
def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt: ...
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Presumably the cartesian product of the tuples in `xs` (used to enumerate every index of a loop nest) —
# TODO confirm; no rewrite rules for this function are visible in this module.
@function
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt: ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@array_api_ruleset.register
def _loopnest_api_ruleset(
    head: Int,
    tail: TupleInt,
    lna: LoopNestAPI,
    fn: Callable[[NDArray, TupleInt], NDArray],
    init: NDArray,
    dim: Int,
    idx_fn: Callable[[Int], Int],
    i: i64,
):
    """
    Rewrite rules for ``LoopNestAPI`` and ``OptionalLoopNestAPI``.

    The parameters are the rule variables; ``head`` and ``tail`` are currently unused.
    """
    # from_tuple: an empty tuple has no loops
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(0, idx_fn))).to(OptionalLoopNestAPI.NONE)
    # from_tuple: a non-empty tuple peels off its first element as the outermost loop and recurses on the rest.
    # The lambda parameter is named `j` so it does not shadow the `i: i64` rule variable used for the length.
    yield rewrite(LoopNestAPI.from_tuple(TupleInt(Int(i), idx_fn))).to(
        OptionalLoopNestAPI(
            LoopNestAPI(idx_fn(Int(0)), LoopNestAPI.from_tuple(TupleInt(Int(i - 1), lambda j: idx_fn(j + 1))))
        ),
        ne(i).to(i64(0)),
    )
    # unwrap: a present optional gives back the wrapped loop nest. Without this rule
    # `LoopNestAPI.from_tuple(...).unwrap()` never becomes a `LoopNestAPI(dim, inner)` term, so the get_dims
    # rules below (and therefore indices/fold) could not fire on it.
    yield rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna)
    # reduce
    yield rewrite(lna.fold(fn, init)).to(tuple_tuple_int_reduce_ndarray(lna.indices, fn, init))
    # get_dims: the innermost level contributes just its own bound ...
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(TupleInt.single(dim))
    # ... and an outer level prepends its bound to the inner nest's dims
    yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(TupleInt.single(dim) + lna.get_dims())
    # indices: product of range(d) over every loop bound d
    yield rewrite(lna.indices).to(tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@function(ruleset=array_api_ruleset, unextractable=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """
    Norm of ``X`` over the axes in ``axis``: for every output position, ``sqrt(sum(real(conj(x) * x)))`` over
    all positions of the reduced axes.

    The result's shape is ``X.shape`` with the axes in ``axis`` removed, and it keeps ``X.dtype``.
    Since the function is declared with a ruleset, this body presumably serves as the rewrite definition of
    ``linalg_norm(X, axis)`` — confirm against egglog's ``@function(ruleset=...)`` semantics.

    NOTE(review): each element is read as ``X[i + k]``, i.e. the reduced-axis index followed by the kept-axis
    index. That matches X's axis order only when the reduced axes are the leading axes (e.g. ``axis=(0, 1)``);
    for other axes the index components would need to be interleaved — TODO confirm/fix.
    """
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction (the sizes of the reduced axes, not the axis numbers)
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

    return NDArray(
        outshape,
        X.dtype,
        # For output index k: loop over every index i of the reduced axes, accumulate real(conj(x) * x)
        # for x = X[i + k], then take the square root.
        lambda k: sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# %%
# Scratch cells (`# %%` markers delimit notebook-style cells). The uncommented statements below run at
# import time and leave X, val, egraph and x as module globals.
# egraph = EGraph(save_egglog_string=True)

# egraph.register(val.shape)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract_multiple(val.shape, 10)

# %%

# Example: norm of a symbolic array with assumed shape (3, 2, 3, 4) over its two leading axes.
X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))
egraph = EGraph()
# NOTE(review): dropping axes 0 and 1 from (3, 2, 3, 4) presumably leaves a 2-element result shape (3, 4),
# so index 2 looks out of range — confirm whether val.shape[0]/val.shape[1] (or val.shape) was intended.
x = egraph.let("x", val.shape[2])
# egraph.display(n_inline_leaves=0)
# egraph.extract(x)
# egraph.saturate(array_api_ruleset, expr=val.shape[2], split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
# egraph.run(array_api_ruleset.saturate())
# egraph.display()


# %%
|
egglog/exp/array_api_numba.py
CHANGED
|
@@ -67,6 +67,6 @@ def _unique_inverse(x: NDArray, i: Int):
|
|
|
67
67
|
return [
|
|
68
68
|
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
|
|
69
69
|
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
|
|
70
|
-
x == NDArray.scalar(unique_values(x).index(
|
|
70
|
+
x == NDArray.scalar(unique_values(x).index((i,)))
|
|
71
71
|
),
|
|
72
72
|
]
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# mypy: disable-error-code="empty-body"
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
4
|
from egglog import *
|
|
7
5
|
|
|
8
6
|
from .array_api import *
|
|
@@ -13,9 +11,12 @@ from .program_gen import *
|
|
|
13
11
|
# Depends on `np` as a global variable.
|
|
14
12
|
##
|
|
15
13
|
|
|
16
|
-
array_api_program_gen_ruleset = ruleset()
|
|
14
|
+
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
|
|
15
|
+
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")
|
|
17
16
|
|
|
18
|
-
array_api_program_gen_schedule =
|
|
17
|
+
array_api_program_gen_schedule = (
|
|
18
|
+
array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
|
|
19
|
+
).saturate()
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
@function
|
|
@@ -56,39 +57,56 @@ def _int_program(i64_: i64, i: Int, j: Int):
|
|
|
56
57
|
|
|
57
58
|
|
|
58
59
|
@function
|
|
59
|
-
def tuple_int_program(x: TupleInt) -> Program:
|
|
60
|
+
def tuple_int_program(x: TupleInt) -> Program:
|
|
61
|
+
...
|
|
62
|
+
# Could be rewritten as a fold, but we don't support generic folds yet
|
|
63
|
+
# return x.fold(Program("("), lambda acc, i: acc + ", " + int_program(i)) + ")"
|
|
60
64
|
|
|
61
65
|
|
|
62
66
|
@function
|
|
63
|
-
def tuple_int_program_inner(x: TupleInt) -> Program:
|
|
67
|
+
def tuple_int_program_inner(x: TupleInt) -> Program:
|
|
68
|
+
"""
|
|
69
|
+
Returns the tuple w/ out the parenthesis
|
|
70
|
+
"""
|
|
64
71
|
|
|
65
72
|
|
|
66
73
|
@array_api_program_gen_ruleset.register
|
|
67
|
-
def _tuple_int_program(i: Int,
|
|
74
|
+
def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int], vec_int: Vec[Int]):
|
|
68
75
|
yield rewrite(int_program(ti[i])).to(tuple_int_program(ti) + "[" + int_program(i) + "]")
|
|
76
|
+
yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_int_program(ti) + ")")
|
|
69
77
|
|
|
70
78
|
yield rewrite(tuple_int_program(ti)).to(Program("(") + tuple_int_program_inner(ti) + ")")
|
|
71
|
-
|
|
72
|
-
|
|
79
|
+
|
|
80
|
+
yield rewrite(tuple_int_program_inner(TupleInt(0, idx_fn))).to(Program(""))
|
|
81
|
+
|
|
82
|
+
yield rewrite(tuple_int_program_inner(TupleInt(Int(k), idx_fn))).to(
|
|
83
|
+
int_program(idx_fn(Int(0))) + ", " + tuple_int_program_inner(TupleInt(Int(k - 1), lambda i: idx_fn(i + 1))),
|
|
84
|
+
ne(k).to(i64(0)),
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(Vec[Int]()))).to(Program(""))
|
|
88
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
|
|
89
|
+
int_program(vec_int[0]) + ", " + tuple_int_program_inner(TupleInt.from_vec(vec_int.remove(0))),
|
|
90
|
+
vec_int.length() > 1,
|
|
91
|
+
)
|
|
92
|
+
yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
|
|
93
|
+
int_program(vec_int[0]) + ",",
|
|
94
|
+
eq(vec_int.length()).to(i64(1)),
|
|
73
95
|
)
|
|
74
|
-
yield rewrite(tuple_int_program_inner(TupleInt(i))).to(int_program(i) + ",")
|
|
75
96
|
|
|
76
97
|
|
|
77
98
|
@function
|
|
78
99
|
def ndarray_program(x: NDArray) -> Program: ...
|
|
79
100
|
|
|
80
101
|
|
|
81
|
-
@function
|
|
82
|
-
def
|
|
102
|
+
@function(ruleset=array_api_program_gen_ruleset)
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
    """
    Program obtained by calling ``Program.function_two`` on the program for ``res``, with the programs for
    ``l`` and ``r`` as arguments.

    ``function_two`` is defined on ``Program`` elsewhere; its exact generated code is not visible here.
    """
    return ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))
|
|
83
105
|
|
|
84
106
|
|
|
85
|
-
@
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
yield rule(eq(f).to(ndarray_function_two(res, l, r))).then(
|
|
89
|
-
union(f).with_(ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))),
|
|
90
|
-
f.eval_py_object({"np": np}),
|
|
91
|
-
)
|
|
107
|
+
@function(ruleset=array_api_program_gen_eval_ruleset)
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
    """
    Wraps ``ndarray_function_two_program(res, l, r)`` in an ``EvalProgram`` whose globals mapping
    binds ``"np"`` to ``np``.
    """
    # NOTE(review): this module does not appear to import numpy directly; `np` presumably arrives through one
    # of the star imports — confirm it is in scope, otherwise evaluating this body raises NameError.
    return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})
|
|
92
110
|
|
|
93
111
|
|
|
94
112
|
@function
|
|
@@ -248,12 +266,31 @@ def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...
|
|
|
248
266
|
|
|
249
267
|
|
|
250
268
|
@array_api_program_gen_ruleset.register
|
|
251
|
-
def _multi_axis_index_key_program(
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
269
|
+
def _multi_axis_index_key_program(
|
|
270
|
+
idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
|
|
271
|
+
):
|
|
272
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))
|
|
273
|
+
|
|
274
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
|
|
275
|
+
multi_axis_index_key_item_program(idx_fn(Int(0)))
|
|
276
|
+
+ ", "
|
|
277
|
+
+ multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
|
|
278
|
+
ne(k).to(i64(0)),
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
|
|
282
|
+
Program("")
|
|
283
|
+
)
|
|
284
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
|
|
285
|
+
multi_axis_index_key_item_program(vec[0]) + ",",
|
|
286
|
+
eq(vec.length()).to(i64(1)),
|
|
287
|
+
)
|
|
288
|
+
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
|
|
289
|
+
multi_axis_index_key_item_program(vec[0])
|
|
290
|
+
+ ", "
|
|
291
|
+
+ multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
|
|
292
|
+
vec.length() > 1,
|
|
255
293
|
)
|
|
256
|
-
yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.EMPTY)).to(Program("()"))
|
|
257
294
|
|
|
258
295
|
|
|
259
296
|
@function
|
|
@@ -266,7 +303,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
|
|
|
266
303
|
yield rewrite(index_key_program(IndexKey.int(i))).to(int_program(i))
|
|
267
304
|
yield rewrite(index_key_program(IndexKey.slice(s))).to(slice_program(s))
|
|
268
305
|
yield rewrite(index_key_program(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
|
|
269
|
-
yield rewrite(index_key_program(
|
|
306
|
+
yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a))
|
|
270
307
|
|
|
271
308
|
|
|
272
309
|
@function
|
egglog/exp/program_gen.py
CHANGED
|
@@ -83,8 +83,18 @@ class Program(Expr):
|
|
|
83
83
|
Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
|
|
84
84
|
"""
|
|
85
85
|
|
|
86
|
-
@
|
|
87
|
-
def
|
|
86
|
+
@property
|
|
87
|
+
def is_identifer(self) -> Bool:
|
|
88
|
+
"""
|
|
89
|
+
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
converter(String, Program, Program)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class EvalProgram(Expr):
|
|
97
|
+
def __init__(self, program: Program, globals: object) -> None:
|
|
88
98
|
"""
|
|
89
99
|
Evaluates the program and saves as the py_object
|
|
90
100
|
"""
|
|
@@ -98,38 +108,34 @@ class Program(Expr):
|
|
|
98
108
|
"""
|
|
99
109
|
|
|
100
110
|
@property
|
|
101
|
-
def
|
|
111
|
+
def statements(self) -> String:
|
|
102
112
|
"""
|
|
103
|
-
Returns
|
|
113
|
+
Returns the statements of the program, if it's been compiled
|
|
104
114
|
"""
|
|
105
115
|
|
|
106
116
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
program_gen_ruleset = ruleset()
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@program_gen_ruleset.register
|
|
113
|
-
def _py_object(p: Program, expr: String, statements: String, g: PyObject):
|
|
117
|
+
@ruleset
|
|
118
|
+
def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
|
|
114
119
|
# When we evaluate a program, we first want to compile to a string
|
|
115
|
-
yield rule(p
|
|
120
|
+
yield rule(EvalProgram(p, g)).then(p.compile())
|
|
116
121
|
# Then we want to evaluate the statements/expr
|
|
117
122
|
yield rule(
|
|
118
|
-
|
|
123
|
+
eq(ep).to(EvalProgram(p, g)),
|
|
119
124
|
eq(p.statements).to(statements),
|
|
120
125
|
eq(p.expr).to(expr),
|
|
121
126
|
).then(
|
|
122
|
-
set_(
|
|
127
|
+
set_(ep.py_object).to(
|
|
123
128
|
py_eval(
|
|
124
129
|
"l['___res']",
|
|
125
130
|
PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)),
|
|
126
131
|
)
|
|
127
|
-
)
|
|
132
|
+
),
|
|
133
|
+
set_(ep.statements).to(statements),
|
|
128
134
|
)
|
|
129
135
|
|
|
130
136
|
|
|
131
|
-
@
|
|
132
|
-
def
|
|
137
|
+
@ruleset
|
|
138
|
+
def program_gen_ruleset(
|
|
133
139
|
s: String,
|
|
134
140
|
s1: String,
|
|
135
141
|
s2: String,
|
egglog/functionalize.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import partial
|
|
5
|
+
from inspect import Parameter, signature
|
|
6
|
+
from typing import Any, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
__all__ = ["functionalize"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T", bound=Callable)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# TODO: Add `to_lift` param so that we only transform those with vars in them to args
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T:
|
|
18
|
+
"""
|
|
19
|
+
Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments
|
|
20
|
+
and then partially applied with their values. The second arg, get_annotation, will be applied to all values
|
|
21
|
+
to get their type annotation. If it is None, that arg will not be added as a parameter.
|
|
22
|
+
|
|
23
|
+
For example if you have:
|
|
24
|
+
|
|
25
|
+
def get_annotation(x): return int if x <= 10 else None
|
|
26
|
+
|
|
27
|
+
g = 10
|
|
28
|
+
def f(a, a2):
|
|
29
|
+
def h(b: Z):
|
|
30
|
+
return a + a2 + b + g
|
|
31
|
+
|
|
32
|
+
return functionalize(h, get_annotation)
|
|
33
|
+
res = f(9, 11)
|
|
34
|
+
|
|
35
|
+
It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only):
|
|
36
|
+
|
|
37
|
+
def h(a: get_annotation(a), g: get_annotation(g), b: Z):
|
|
38
|
+
return a + b + g
|
|
39
|
+
res = partial(h, a, g)
|
|
40
|
+
"""
|
|
41
|
+
code = f.__code__
|
|
42
|
+
names = tuple(n for n in code.co_names if n in f.__globals__)
|
|
43
|
+
free_vars = code.co_freevars
|
|
44
|
+
|
|
45
|
+
global_values: list[Any] = [f.__globals__[name] for name in names]
|
|
46
|
+
free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else []
|
|
47
|
+
assert len(free_var_values) == len(free_vars), "Free vars and their values do not match"
|
|
48
|
+
global_values_filtered = [
|
|
49
|
+
(i, name, value, annotation)
|
|
50
|
+
for i, (name, value) in enumerate(zip(names, global_values, strict=True))
|
|
51
|
+
if (annotation := get_annotation(value)) is not None
|
|
52
|
+
]
|
|
53
|
+
free_var_values_filtered = [
|
|
54
|
+
(i, name, value, annotation)
|
|
55
|
+
for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True))
|
|
56
|
+
if (annotation := get_annotation(value)) is not None
|
|
57
|
+
]
|
|
58
|
+
additional_arg_filtered = global_values_filtered + free_var_values_filtered
|
|
59
|
+
|
|
60
|
+
# Create a wrapper function
|
|
61
|
+
def wrapper(*args):
|
|
62
|
+
# Split args into names, free vars and other args
|
|
63
|
+
name_args, free_var_args, rest_args = (
|
|
64
|
+
args[: (n_names := len(global_values_filtered))],
|
|
65
|
+
args[n_names : (n_args := len(additional_arg_filtered))],
|
|
66
|
+
args[n_args:],
|
|
67
|
+
)
|
|
68
|
+
# Update globals with names
|
|
69
|
+
f.__globals__.update({
|
|
70
|
+
name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False)
|
|
71
|
+
})
|
|
72
|
+
# update function free vars with free var args
|
|
73
|
+
for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True):
|
|
74
|
+
assert f.__closure__, "Function does not have closure"
|
|
75
|
+
f.__closure__[i].cell_contents = value
|
|
76
|
+
return f(*rest_args)
|
|
77
|
+
|
|
78
|
+
# Set the signature of the new function to a signature with the free vars and names added as arguments
|
|
79
|
+
orig_signature = signature(f)
|
|
80
|
+
wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined]
|
|
81
|
+
parameters=[
|
|
82
|
+
*[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered],
|
|
83
|
+
*orig_signature.parameters.values(),
|
|
84
|
+
]
|
|
85
|
+
)
|
|
86
|
+
# Set the annotations of the new function to the annotations of the original function + annotations of passed in values
|
|
87
|
+
wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered}
|
|
88
|
+
wrapper.__name__ = f.__name__
|
|
89
|
+
|
|
90
|
+
# Partially apply the wrapper function with the current values of the free vars
|
|
91
|
+
return cast(T, partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered)))
|