egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from egglog import *
|
|
5
|
+
|
|
6
|
+
from .array_api import *
|
|
7
|
+
from .program_gen import *
|
|
8
|
+
|
|
9
|
+
##
|
|
10
|
+
# Functionality to compile expressions into strings of NumPy source code.
|
|
11
|
+
# Depends on `np` as a global variable.
|
|
12
|
+
##
|
|
13
|
+
|
|
14
|
+
# Rules that translate array-API expressions into `Program` source fragments.
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
# Rules that package generated programs as `EvalProgram`s (see `ndarray_function_two`).
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")

# Everything needed to lower an array-API term to source text. It combines this module's rules with
# the generic `Program` rules from `program_gen`, plus `array_api_vec_to_cons_ruleset` from `array_api`.
# NOTE(review): `array_api_vec_to_cons_ruleset` presumably normalizes Vec-backed tuples into the
# EMPTY/append form that the foldl rules match on. Confirm in array_api.
array_api_program_gen_combined_ruleset = (
    array_api_program_gen_ruleset
    | program_gen_ruleset
    | array_api_program_gen_eval_ruleset
    | array_api_vec_to_cons_ruleset
)

# Run the combined rules together with program evaluation until saturation.
# (The `eval_program_rulseset` spelling is the name exported by `program_gen`.)
array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_rulseset).saturate()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@function
def bool_program(x: Boolean) -> Program:
    """Python source for the ``Boolean`` term ``x``; defined only through rules in ``array_api_program_gen_ruleset``."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@array_api_program_gen_ruleset.register
def _bool_program():
    """Map the two ``Boolean`` literals onto Python's ``True`` / ``False`` keywords."""
    for literal, source in ((TRUE, "True"), (FALSE, "False")):
        yield rewrite(bool_program(literal)).to(Program(source))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@function
def int_program(x: Int) -> Program:
    """Python source for the ``Int`` term ``x``; defined only through rules in ``array_api_program_gen_ruleset``."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@array_api_program_gen_ruleset.register
def _int_program(i64_: i64, i: Int, j: Int, s: String):
    """
    Rules lowering ``Int`` expressions, and ``Boolean`` comparisons between them, to Python source.

    Every binary operation is emitted fully parenthesized as ``(<lhs> <op> <rhs>)`` so the emitted
    text does not depend on the precedence of whatever surrounds it.
    """

    def infix(op: str) -> Program:
        # `op` carries its own surrounding spaces, e.g. " + ".
        return Program("(") + int_program(i) + op + int_program(j) + ")"

    yield rewrite(int_program(Int.var(s))).to(Program(s, True))
    yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string()))
    yield rewrite(int_program(~i)).to(Program("~") + int_program(i))

    # Comparisons produce Booleans.
    yield rewrite(bool_program(i < j)).to(infix(" < "))
    yield rewrite(bool_program(i <= j)).to(infix(" <= "))
    yield rewrite(bool_program(i > j)).to(infix(" > "))
    yield rewrite(bool_program(i >= j)).to(infix(" >= "))
    yield rewrite(bool_program(i == j)).to(infix(" == "))

    # Arithmetic and bitwise operators.
    yield rewrite(int_program(i + j)).to(infix(" + "))
    yield rewrite(int_program(i - j)).to(infix(" - "))
    yield rewrite(int_program(i * j)).to(infix(" * "))
    yield rewrite(int_program(i / j)).to(infix(" / "))
    yield rewrite(int_program(i % j)).to(infix(" % "))
    # In Python `**` binds tighter than unary `-` and `~`. A bare base such as a negative literal
    # ("-5") or "~x" would otherwise emit "(-5 ** 2)", which evaluates to -25, not 25. The base is
    # therefore parenthesized explicitly.
    yield rewrite(int_program(i**j)).to(Program("((") + int_program(i) + ") ** " + int_program(j) + ")")
    yield rewrite(int_program(i & j)).to(infix(" & "))
    yield rewrite(int_program(i | j)).to(infix(" | "))
    yield rewrite(int_program(i ^ j)).to(infix(" ^ "))
    yield rewrite(int_program(i << j)).to(infix(" << "))
    yield rewrite(int_program(i >> j)).to(infix(" >> "))
    yield rewrite(int_program(i // j)).to(infix(" // "))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@function
def tuple_int_foldl_program(xs: TupleIntLike, f: Callable[[Program, Int], Program], init: ProgramLike) -> Program:
    """
    Left fold over ``xs`` producing a ``Program``.

    Starts from ``init`` and combines the accumulator with each ``Int`` element, in order, via ``f``.
    """
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@function(ruleset=array_api_program_gen_ruleset)
def tuple_int_program(x: TupleIntLike) -> Program:
    """Render ``x`` as a Python tuple literal: ``(`` + ``<item>, `` for each item + ``)``."""
    opened = tuple_int_foldl_program(x, lambda acc, i: acc + int_program(i) + ", ", "(")
    return opened + ")"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@array_api_program_gen_ruleset.register
def _tuple_int_program(i: Int, ti: TupleInt, ti2: TupleInt, f: Callable[[Program, Int], Program], init: Program):
    """Rules for indexing, length, folding and concatenation of ``TupleInt`` values."""
    ti_source = tuple_int_program(ti)

    # Element access and length operate directly on the rendered tuple literal.
    yield rewrite(int_program(ti[i])).to(ti_source + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + ti_source + ")")

    # Fold: the empty tuple yields the seed. Otherwise fold the prefix, then apply `f` to the last element.
    yield rewrite(tuple_int_foldl_program(TupleInt.EMPTY, f, init)).to(init)
    yield rewrite(tuple_int_foldl_program(ti.append(i), f, init)).to(f(tuple_int_foldl_program(ti, f, init), i))

    # Concatenation maps onto Python tuple `+`.
    yield rewrite(tuple_int_program(ti + ti2)).to(Program("(") + ti_source + " + " + tuple_int_program(ti2) + ")")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@function
def ndarray_program(x: NDArray) -> Program:
    """NumPy source for the ``NDArray`` term ``x``; defined only through rules in ``array_api_program_gen_ruleset``."""
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@function(ruleset=array_api_program_gen_ruleset)
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
    """
    Program for a two-argument function computing ``res`` from ``l`` and ``r``.

    NOTE(review): relies on ``Program.function_two`` from ``program_gen``. Presumably it wraps the body
    in a function whose parameters are the ``l`` / ``r`` programs. Confirm there.
    """
    body = ndarray_program(res)
    return body.function_two(ndarray_program(l), ndarray_program(r))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@function(ruleset=array_api_program_gen_eval_ruleset)
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
    """Wrap ``ndarray_function_two_program`` in an ``EvalProgram`` whose globals expose NumPy as ``np``."""
    source = ndarray_function_two_program(res, l, r)
    return EvalProgram(source, {"np": np})
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@function
def dtype_program(x: DType) -> Program:
    """NumPy source constructing the ``np.dtype`` that corresponds to ``x``."""
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@array_api_program_gen_ruleset.register
def _dtype_program():
    """One rule per supported ``DType``, each rendering the matching ``np.dtype(...)`` call."""
    dtype_sources = (
        (DType.float64, "np.dtype(np.float64)"),
        (DType.float32, "np.dtype(np.float32)"),
        (DType.int64, "np.dtype(np.int64)"),
        (DType.int32, "np.dtype(np.int32)"),
        (DType.bool, "np.dtype(np.bool)"),
        (DType.object, "np.dtype(np.object_)"),
    )
    for dt, source in dtype_sources:
        yield rewrite(dtype_program(dt)).to(Program(source))
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@function
def float_program(x: Float) -> Program:
    """Python source for the ``Float`` term ``x``; defined only through rules in ``array_api_program_gen_ruleset``."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@array_api_program_gen_ruleset.register
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: BigRat):
    """Rules lowering ``Float`` literals, arithmetic and rationals to Python source."""

    def infix(op: str) -> Program:
        # Parenthesized binary operation between `f` and `g`; `op` includes surrounding spaces.
        return Program("(") + float_program(f) + op + float_program(g) + ")"

    yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string()))
    yield rewrite(float_program(f.abs())).to(Program("np.abs(") + float_program(f) + ")")
    yield rewrite(float_program(Float.from_int(i))).to(int_program(i))
    yield rewrite(float_program(f + g)).to(infix(" + "))
    yield rewrite(float_program(f - g)).to(infix(" - "))
    yield rewrite(float_program(f * g)).to(infix(" * "))
    yield rewrite(float_program(f / g)).to(infix(" / "))

    # Rationals: emit an explicit division unless the denominator is 1.
    rational = float_program(Float.rational(r))
    numerator = Program(r.numer.to_string())
    yield rewrite(rational).to(
        Program("float(") + numerator + " / " + Program(r.denom.to_string()) + ")",
        ne(r.denom).to(BigInt(1)),
    )
    yield rewrite(rational).to(Program("float(") + numerator + ")", eq(r.denom).to(BigInt(1)))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@function
def value_program(x: Value) -> Program:
    """Python source for the scalar ``Value`` term ``x``; defined only through rules in ``array_api_program_gen_ruleset``."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@array_api_program_gen_ruleset.register
def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Value, xs: NDArray, ti: TupleInt):
    """Rules lowering ``Value`` terms (wrapped scalars, arithmetic, indexing, unary NumPy calls) to source."""

    def infix(op: str) -> Program:
        # Parenthesized binary operation between `v1` and `v2`.
        return Program("(") + value_program(v1) + op + value_program(v2) + ")"

    def unary_call(prefix: str) -> Program:
        # `prefix` is the function name with its opening parenthesis, e.g. "np.sqrt(".
        return Program(prefix) + value_program(v1) + ")"

    # Wrapped scalars render as their underlying program.
    yield rewrite(value_program(Value.int(i))).to(int_program(i))
    yield rewrite(value_program(Value.bool(b))).to(bool_program(b))
    yield rewrite(value_program(Value.float(f))).to(float_program(f))
    # Could add .item() but we usually dont need it.
    yield rewrite(value_program(x.to_value())).to(ndarray_program(x))

    yield rewrite(value_program(v1 < v2)).to(infix(" < "))
    yield rewrite(value_program(v1 / v2)).to(infix(" / "))
    yield rewrite(value_program(v1 + v2)).to(infix(" + "))
    yield rewrite(value_program(v1 * v2)).to(infix(" * "))

    # Conversions back to Boolean / Int reuse the value's own source.
    yield rewrite(bool_program(v1.to_bool)).to(value_program(v1))
    yield rewrite(int_program(v1.to_int)).to(value_program(v1))

    indexed = ndarray_program(xs) + "[" + tuple_int_program(ti) + "]"
    yield rewrite(value_program(xs.index(ti))).to(indexed.assign())

    yield rewrite(value_program(v1.sqrt())).to(unary_call("np.sqrt("))
    yield rewrite(value_program(v1.real())).to(unary_call("np.real("))
    yield rewrite(value_program(v1.conj())).to(unary_call("np.conj("))
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@function
def tuple_value_foldl_program(
    xs: TupleValueLike,
    f: Callable[[Program, Value], Program],
    init: ProgramLike,
) -> Program:
    """
    Left fold over ``xs`` producing a ``Program``.

    Starts from ``init`` and combines the accumulator with each ``Value`` element, in order, via ``f``.
    """
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@function(ruleset=array_api_program_gen_ruleset)
def tuple_value_program(x: TupleValueLike) -> Program:
    """Render ``x`` as a Python tuple literal: ``(`` + ``<value>, `` for each value + ``)``."""
    opened = tuple_value_foldl_program(x, lambda acc, i: acc + value_program(i) + ", ", "(")
    return opened + ")"
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@array_api_program_gen_ruleset.register
def _tuple_value_program(i: Int, ti: TupleValue, f: Callable[[Program, Value], Program], v: Value, init: Program):
    """Rules for indexing, length and folding of ``TupleValue`` values."""
    ti_source = tuple_value_program(ti)

    yield rewrite(value_program(ti[i])).to(ti_source + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + ti_source + ")")

    # Fold: the empty tuple yields the seed. Otherwise fold the prefix, then apply `f` to the last element.
    yield rewrite(tuple_value_foldl_program(TupleValue.EMPTY, f, init)).to(init)
    yield rewrite(tuple_value_foldl_program(ti.append(v), f, init)).to(f(tuple_value_foldl_program(ti, f, init), v))
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@function
def tuple_ndarray_foldl_program(
    xs: TupleNDArrayLike,
    f: Callable[[Program, NDArray], Program],
    init: ProgramLike,
) -> Program:
    """
    Left fold over ``xs`` producing a ``Program``.

    Starts from ``init`` and combines the accumulator with each ``NDArray`` element, in order, via ``f``.
    """
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@function(ruleset=array_api_program_gen_ruleset)
def tuple_ndarray_program(x: TupleNDArrayLike) -> Program:
    """Render ``x`` as a Python tuple literal: ``(`` + ``<array>, `` for each array + ``)``."""
    opened = tuple_ndarray_foldl_program(x, lambda acc, i: acc + ndarray_program(i) + ", ", "(")
    return opened + ")"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@array_api_program_gen_ruleset.register
def _tuple_ndarray_program(
    i: Int, ti: TupleNDArray, f: Callable[[Program, NDArray], Program], v: NDArray, init: Program
):
    """Rules for indexing, length and folding of ``TupleNDArray`` values."""
    ti_source = tuple_ndarray_program(ti)

    yield rewrite(ndarray_program(ti[i])).to(ti_source + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + ti_source + ")")

    # Fold: the empty tuple yields the seed. Otherwise fold the prefix, then apply `f` to the last element.
    yield rewrite(tuple_ndarray_foldl_program(TupleNDArray.EMPTY, f, init)).to(init)
    yield rewrite(tuple_ndarray_foldl_program(ti.append(v), f, init)).to(
        f(tuple_ndarray_foldl_program(ti, f, init), v)
    )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@function
def optional_dtype_program(x: OptionalDType) -> Program:
    """Source for an optional dtype: ``None`` when absent, otherwise the ``np.dtype(...)`` call."""
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@array_api_program_gen_ruleset.register
def _optional_dtype_program(dtype: DType):
    """An absent dtype renders as ``None``; a present one renders like a plain ``DType``."""
    absent = optional_dtype_program(OptionalDType.none)
    yield rewrite(absent).to(Program("None"))
    present = optional_dtype_program(OptionalDType.some(dtype))
    yield rewrite(present).to(dtype_program(dtype))
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@function
def optional_int_program(x: OptionalInt) -> Program:
    """Source for an optional int: ``None`` when absent, otherwise the int's own source."""
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@array_api_program_gen_ruleset.register
def _optional_int_program(x: Int):
    """An absent int renders as ``None``; a present one renders via ``int_program``."""
    absent = optional_int_program(OptionalInt.none)
    yield rewrite(absent).to(Program("None"))
    present = optional_int_program(OptionalInt.some(x))
    yield rewrite(present).to(int_program(x))
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@function
def optional_int_slice_program(x: OptionalInt) -> Program:
    """
    Like ``optional_int_program``, but an absent value renders as the empty string instead of ``None``.

    This is the form needed for omitted slice bounds, e.g. ``:stop``.
    """
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@array_api_program_gen_ruleset.register
def _optional_int_slice_program(x: Int):
    """An absent slice bound renders as ``""``; a present one renders via ``int_program``."""
    absent = optional_int_slice_program(OptionalInt.none)
    yield rewrite(absent).to(Program(""))
    present = optional_int_slice_program(OptionalInt.some(x))
    yield rewrite(present).to(int_program(x))
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@function
def slice_program(x: Slice) -> Program:
    """Python slice syntax (``start:stop`` or ``start:stop:step``) for ``x``."""
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@array_api_program_gen_ruleset.register
def _slice_program(start: OptionalInt, stop: OptionalInt, i: Int):
    """Render a ``Slice``; the step (and its colon) is emitted only when present."""
    bounds = optional_int_slice_program(start) + ":" + optional_int_slice_program(stop)
    yield rewrite(slice_program(Slice(start, stop, OptionalInt.none))).to(bounds)
    yield rewrite(slice_program(Slice(start, stop, OptionalInt.some(i)))).to(bounds + ":" + int_program(i))
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@function
def multi_axis_index_key_item_program(x: MultiAxisIndexKeyItem) -> Program:
    """Source for a single component of a multi-axis subscript (int, slice, ``...`` or ``None``)."""
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@array_api_program_gen_ruleset.register
def _multi_axis_index_key_item_program(i: Int, s: Slice):
    """Each item variant renders as the matching Python subscript fragment."""
    item = multi_axis_index_key_item_program
    yield rewrite(item(MultiAxisIndexKeyItem.int(i))).to(int_program(i))
    yield rewrite(item(MultiAxisIndexKeyItem.slice(s))).to(slice_program(s))
    yield rewrite(item(MultiAxisIndexKeyItem.ELLIPSIS)).to(Program("..."))
    yield rewrite(item(MultiAxisIndexKeyItem.NONE)).to(Program("None"))
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@function
def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program:
    """Comma-separated subscript body for a multi-axis index key."""
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@array_api_program_gen_ruleset.register
def _multi_axis_index_key_program(
    idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
):
    """
    Render a ``MultiAxisIndexKey`` as the comma-separated body of a subscript.

    Two encodings are handled: a length paired with an index function, and a ``Vec`` of items.
    In both, every item is followed by a comma, so a single-item key still yields a tuple subscript.
    """
    separator = ", "

    # Length + index-function encoding: an empty key is "". Otherwise emit item 0 and recurse on
    # the remaining k - 1 items, re-indexed so the next item becomes item 0.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
        multi_axis_index_key_item_program(idx_fn(Int(0)))
        + separator
        + multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
        ne(k).to(i64(0)),
    )

    # Vec encoding: empty gives "", a single item gives "<item>,", and longer vecs emit the head and
    # recurse on the tail.
    empty_vec = Vec[MultiAxisIndexKeyItem]()
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(empty_vec))).to(Program(""))
    vec_key = multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))
    head = multi_axis_index_key_item_program(vec[0])
    yield rewrite(vec_key).to(head + ",", eq(vec.length()).to(i64(1)))
    yield rewrite(vec_key).to(
        head + separator + multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
        vec.length() > 1,
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@function
def index_key_program(x: IndexKey) -> Program:
    """Source placed between ``[`` and ``]`` when an ndarray is indexed with ``x``."""
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@array_api_program_gen_ruleset.register
def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
    """Each ``IndexKey`` variant delegates to the renderer for the value it wraps."""
    key_source = index_key_program
    yield rewrite(key_source(IndexKey.ELLIPSIS)).to(Program("..."))
    yield rewrite(key_source(IndexKey.int(i))).to(int_program(i))
    yield rewrite(key_source(IndexKey.slice(s))).to(slice_program(s))
    yield rewrite(key_source(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
    yield rewrite(key_source(IndexKey.ndarray(a))).to(ndarray_program(a))
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@function
def int_or_tuple_program(x: IntOrTuple) -> Program:
    """Source for an int-or-tuple argument (e.g. an ``axis=`` value)."""
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@array_api_program_gen_ruleset.register
def _int_or_tuple_program(x: Int, t: TupleInt):
    """Ints render via ``int_program`` and tuples via ``tuple_int_program``."""
    as_int = int_or_tuple_program(IntOrTuple.int(x))
    yield rewrite(as_int).to(int_program(x))
    as_tuple = int_or_tuple_program(IntOrTuple.tuple(t))
    yield rewrite(as_tuple).to(tuple_int_program(t))
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@function
def optional_int_or_tuple_program(x: OptionalIntOrTuple) -> Program:
    """Source for an optional int-or-tuple: ``None`` when absent."""
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@array_api_program_gen_ruleset.register
def _optional_int_or_tuple_program(it: IntOrTuple):
    """A present value renders via ``int_or_tuple_program``; an absent one renders as ``None``."""
    present = optional_int_or_tuple_program(OptionalIntOrTuple.some(it))
    yield rewrite(present).to(int_or_tuple_program(it))
    absent = optional_int_or_tuple_program(OptionalIntOrTuple.none)
    yield rewrite(absent).to(Program("None"))
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@array_api_program_gen_ruleset.register
def _ndarray_program(
    x: NDArray,
    y: NDArray,
    z: NDArray,
    s: String,
    dtype: DType,
    ti: TupleInt,
    i: Int,
    tv: TupleValue,
    v: Value,
    ob: OptionalBool,
    tnd: TupleNDArray,
    optional_device_: OptionalDevice,
    int_or_tuple_: IntOrTuple,
    idx: IndexKey,
    odtype: OptionalDType,
):
    """
    Rules lowering ``NDArray`` terms (and tuples/shapes derived from them) to NumPy source.

    Most array-producing calls are passed through ``Program.assign()`` from ``program_gen``.
    NOTE(review): presumably this binds the result to a fresh variable. Confirm there.
    """
    # Var
    yield rewrite(ndarray_program(NDArray.var(s))).to(Program(s, True))

    # Assumptions. Each `assume_*` call mutates the Python-level copy it is given, which is why `z`
    # is copied first. The rewritten term is "z with that assumption", and the generated code is
    # `z`'s program plus a matching runtime `assert` statement.
    z_program = ndarray_program(z)

    # assume dtype
    z_assumed_dtype = copy(z)
    assume_dtype(z_assumed_dtype, dtype)
    yield rewrite(ndarray_program(z_assumed_dtype)).to(
        z_program.statement(Program("assert ") + z_program + ".dtype == " + dtype_program(dtype))
    )

    # assume shape
    z_assumed_shape = copy(z)
    assume_shape(z_assumed_shape, ti)
    yield rewrite(ndarray_program(z_assumed_shape)).to(
        z_program.statement(Program("assert ") + z_program + ".shape == " + tuple_int_program(ti))
    )

    # assume isfinite
    z_assumed_isfinite = copy(z)
    assume_isfinite(z_assumed_isfinite)
    yield rewrite(ndarray_program(z_assumed_isfinite)).to(
        z_program.statement(Program("assert np.all(np.isfinite(") + z_program + "))")
    )

    # assume value_one_of
    z_assumed_value_one_of = copy(z)
    assume_value_one_of(z_assumed_value_one_of, tv)
    yield rewrite(ndarray_program(z_assumed_value_one_of)).to(
        z_program.statement(Program("assert set(np.unique(") + z_program + ")) == set(" + tuple_value_program(tv) + ")")
    )

    # reshape (the `copy` flag `ob` is not emitted, since NumPy's reshape has no such argument)
    yield rewrite(ndarray_program(reshape(y, ti, ob))).to(
        (ndarray_program(y) + ".reshape(" + tuple_int_program(ti) + ")").assign()
    )

    # astype
    yield rewrite(ndarray_program(astype(y, dtype))).to(
        (ndarray_program(y) + ".astype(" + dtype_program(dtype) + ")").assign()
    )

    # unique_counts(x) => unique(x, return_counts=True)
    yield rewrite(tuple_ndarray_program(unique_counts(x))).to(
        (Program("np.unique(") + ndarray_program(x) + ", return_counts=True)").assign()
    )
    # unique_inverse(x) => unique(x, return_inverse=True)
    yield rewrite(tuple_ndarray_program(unique_inverse(x))).to(
        (Program("np.unique(") + ndarray_program(x) + ", return_inverse=True)").assign()
    )

    # Tuple ndarray indexing
    yield rewrite(ndarray_program(tnd[i])).to(tuple_ndarray_program(tnd) + "[" + int_program(i) + "]")

    # ndarray scalar
    # TODO: Use dtype and shape and indexing instead?
    # TODO: Specify dtype?
    yield rewrite(ndarray_program(NDArray.scalar(v))).to(Program("np.array(") + value_program(v) + ")")

    # zeros
    yield rewrite(ndarray_program(zeros(ti, OptionalDType.none, optional_device_))).to(
        (Program("np.zeros(") + tuple_int_program(ti) + ")").assign()
    )
    yield rewrite(ndarray_program(zeros(ti, OptionalDType.some(dtype), optional_device_))).to(
        (Program("np.zeros(") + tuple_int_program(ti) + ", dtype=" + dtype_program(dtype) + ")").assign(),
    )

    # unique_values
    yield rewrite(ndarray_program(unique_values(x))).to((Program("np.unique(") + ndarray_program(x) + ")").assign())

    def bin_op(res: NDArray, op: str) -> Command:
        # Rule rendering `res` (an infix operation on `x` and `y`) as "<x> <op> <y>", bound via assign().
        return rewrite(ndarray_program(res)).to((ndarray_program(x) + f" {op} " + ndarray_program(y)).assign())

    # NDArray binary operators
    yield bin_op(x + y, "+")
    yield bin_op(x - y, "-")
    yield bin_op(x * y, "*")
    yield bin_op(x / y, "/")
    yield bin_op(x < y, "<")
    yield bin_op(x <= y, "<=")
    yield bin_op(x > y, ">")
    yield bin_op(x >= y, ">=")
    yield bin_op(x == y, "==")
    yield bin_op(x @ y, "@")
    yield bin_op(x % y, "%")
    yield bin_op(x & y, "&")
    yield bin_op(x | y, "|")
    yield bin_op(x ^ y, "^")
    yield bin_op(x << y, "<<")
    yield bin_op(x >> y, ">>")
    yield bin_op(x // y, "//")
    yield bin_op(x**y, "**")

    # setitem. `copy(x)` here only copies the Python wrapper, so that `mod_x[idx] = y` yields the
    # term "x with idx set to y" while `x` itself is left unchanged. NDArray terms are values, so
    # the generated NumPy code must not write into the array bound to `x`: other parts of the program
    # may still read it, and it may be a view (e.g. from basic indexing). Copy before writing.
    mod_x = copy(x)
    mod_x[idx] = y
    assigned_x = (ndarray_program(x) + ".copy()").assign()
    yield rewrite(ndarray_program(mod_x)).to(
        assigned_x.statement(assigned_x + "[" + index_key_program(idx) + "] = " + ndarray_program(y))
    )
    # getitem
    yield rewrite(ndarray_program(x[idx])).to(ndarray_program(x) + "[" + index_key_program(idx) + "]")

    # square
    yield rewrite(ndarray_program(square(x))).to((Program("np.square(") + ndarray_program(x) + ")").assign())

    # expand_dims(x, axis)
    yield rewrite(ndarray_program(expand_dims(x, i))).to(
        (Program("np.expand_dims(") + ndarray_program(x) + ", " + int_program(i) + ")").assign()
    )

    # mean(x, axis, keepdims)
    yield rewrite(ndarray_program(mean(x))).to((Program("np.mean(") + ndarray_program(x) + ")").assign())
    yield rewrite(
        ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), FALSE)),
    ).to(
        (Program("np.mean(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign(),
    )
    yield rewrite(
        ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), TRUE)),
    ).to(
        (
            Program("np.mean(")
            + ndarray_program(x)
            + ", axis="
            + int_or_tuple_program(int_or_tuple_)
            + ", keepdims=True)"
        ).assign(),
    )

    # concat
    yield rewrite(ndarray_program(concat(tnd, OptionalInt.none))).to(
        (Program("np.concatenate(") + tuple_ndarray_program(tnd) + ")").assign()
    )
    yield rewrite(ndarray_program(concat(tnd, OptionalInt.some(i)))).to(
        (Program("np.concatenate(") + tuple_ndarray_program(tnd) + ", axis=" + int_program(i) + ")").assign()
    )

    # vector
    yield rewrite(ndarray_program(NDArray.vector(tv))).to(Program("np.array(") + tuple_value_program(tv) + ")")

    # std
    yield rewrite(ndarray_program(std(x))).to((Program("np.std(") + ndarray_program(x) + ")").assign())
    yield rewrite(ndarray_program(std(x, OptionalIntOrTuple.some(int_or_tuple_)))).to(
        (Program("np.std(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign(),
    )

    # svd
    yield rewrite(tuple_ndarray_program(svd(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign())
    yield rewrite(tuple_ndarray_program(svd(x, FALSE))).to(
        (Program("np.linalg.svd(") + ndarray_program(x) + ", full_matrices=False)").assign()
    )

    # sqrt
    yield rewrite(ndarray_program(sqrt(x))).to((Program("np.sqrt(") + ndarray_program(x) + ")").assign())

    # transpose
    yield rewrite(ndarray_program(x.T)).to(ndarray_program(x) + ".T")

    # sum
    yield rewrite(ndarray_program(sum(x))).to((Program("np.sum(") + ndarray_program(x) + ")").assign())
    yield rewrite(ndarray_program(sum(x, OptionalIntOrTuple.some(int_or_tuple_)))).to(
        (Program("np.sum(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign()
    )

    # shape
    yield rewrite(tuple_int_program(x.shape)).to(ndarray_program(x) + ".shape")

    # abs
    yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())

    # asarray (only the copy=None form is handled; the device argument is not emitted)
    yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
        Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
    )
|