egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,510 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ from __future__ import annotations
3
+
4
+ from egglog import *
5
+
6
+ from .array_api import *
7
+ from .program_gen import *
8
+
9
##
# Functionality to compile expressions to strings of NumPy code.
# Depends on `np` as a global variable.
##

# Rules that rewrite the `*_program(...)` functions below into `Program` source fragments.
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
# Rules that wrap generated programs into `EvalProgram`s (see `ndarray_function_two`).
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")

# This module's rulesets combined with the generic program-generation rules and the
# vec-to-cons rules (presumably defined in `program_gen` / `array_api`, which are star-imported).
array_api_program_gen_combined_ruleset = (
    array_api_program_gen_ruleset
    | program_gen_ruleset
    | array_api_program_gen_eval_ruleset
    | array_api_vec_to_cons_ruleset
)
# Run the combined rules together with the program-evaluation rules until saturation.
# NOTE: `eval_program_rulseset` (sic) is an imported name, not defined in this file;
# renaming it here alone would break the import.
array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_rulseset).saturate()
24
+
25
+
26
@function
def bool_program(x: Boolean) -> Program:
    """Render a `Boolean` expression as a Python source fragment."""


@array_api_program_gen_ruleset.register
def _bool_program():
    # Boolean literals map one-to-one onto Python's constants.
    for literal, source in ((TRUE, "True"), (FALSE, "False")):
        yield rewrite(bool_program(literal)).to(Program(source))
34
+
35
+
36
@function
def int_program(x: Int) -> Program:
    """Render an `Int` expression as a Python source fragment."""


@array_api_program_gen_ruleset.register
def _int_program(i64_: i64, i: Int, j: Int, s: String):
    def infix(op: str) -> Program:
        # Every binary expression is fully parenthesized so it can be embedded anywhere
        # without caring about operator precedence.
        return Program("(") + int_program(i) + f" {op} " + int_program(j) + ")"

    def compare(term: Boolean, op: str) -> Command:
        return rewrite(bool_program(term)).to(infix(op))

    def arith(term: Int, op: str) -> Command:
        return rewrite(int_program(term)).to(infix(op))

    # Leaves and unary operators.
    yield rewrite(int_program(Int.var(s))).to(Program(s, True))
    yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string()))
    yield rewrite(int_program(~i)).to(Program("~") + int_program(i))

    # Comparisons produce booleans, so they are rendered through `bool_program`.
    yield compare(i < j, "<")
    yield compare(i <= j, "<=")
    yield compare(i > j, ">")
    yield compare(i >= j, ">=")
    yield compare(i == j, "==")

    # Arithmetic and bitwise operators.
    yield arith(i + j, "+")
    yield arith(i - j, "-")
    yield arith(i * j, "*")
    yield arith(i / j, "/")
    yield arith(i % j, "%")
    yield arith(i**j, "**")
    yield arith(i & j, "&")
    yield arith(i | j, "|")
    yield arith(i ^ j, "^")
    yield arith(i << j, "<<")
    yield arith(i >> j, ">>")
    yield arith(i // j, "//")
62
+
63
+
64
@function
def tuple_int_foldl_program(xs: TupleIntLike, f: Callable[[Program, Int], Program], init: ProgramLike) -> Program:
    """Left-fold `f` over the elements of `xs`, starting from the program `init`."""


@function(ruleset=array_api_program_gen_ruleset)
def tuple_int_program(x: TupleIntLike) -> Program:
    # "(" then "<item>, " for each element, then ")": always a valid Python tuple literal,
    # including the trailing comma a one-element tuple needs.
    opened = tuple_int_foldl_program(x, lambda acc, i: acc + int_program(i) + ", ", "(")
    return opened + ")"


@array_api_program_gen_ruleset.register
def _tuple_int_program(i: Int, ti: TupleInt, ti2: TupleInt, f: Callable[[Program, Int], Program], init: Program):
    rendered = tuple_int_program(ti)

    # Indexing and length use Python's own syntax on the rendered tuple.
    yield rewrite(int_program(ti[i])).to(rendered + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + rendered + ")")

    # Unroll the fold by peeling off the most recently appended element.
    yield rewrite(tuple_int_foldl_program(TupleInt.EMPTY, f, init)).to(init)
    folded_prefix = tuple_int_foldl_program(ti, f, init)
    yield rewrite(tuple_int_foldl_program(ti.append(i), f, init)).to(f(folded_prefix, i))

    # Concatenation becomes a parenthesized `+` of the two rendered tuples.
    yield rewrite(tuple_int_program(ti + ti2)).to(Program("(") + rendered + " + " + tuple_int_program(ti2) + ")")
84
+
85
+
86
@function
def ndarray_program(x: NDArray) -> Program:
    """Render an `NDArray` expression as Python/NumPy source."""


@function(ruleset=array_api_program_gen_ruleset)
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
    # Package the rendering of `res` as a two-argument function over the renderings of
    # `l` and `r` (the exact shape of the emitted function is defined by `Program.function_two`).
    result, left, right = ndarray_program(res), ndarray_program(l), ndarray_program(r)
    return result.function_two(left, right)


@function(ruleset=array_api_program_gen_eval_ruleset)
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
    # Generated code refers to `np.*`, so numpy is supplied as the `np` global.
    program = ndarray_function_two_program(res, l, r)
    return EvalProgram(program, {"np": np})
98
+
99
+
100
@function
def dtype_program(x: DType) -> Program:
    """Render a `DType` as a NumPy `np.dtype(...)` expression."""


@array_api_program_gen_ruleset.register
def _dtype_program():
    # Only scalar-type names that exist in every NumPy release are emitted. `np.bool` was
    # removed in NumPy 1.24 and only reintroduced in 2.0, so generated code using it fails on
    # 1.24-1.26; `np.bool_` (like `np.object_` below) works on all versions and denotes the
    # same dtype.
    yield rewrite(dtype_program(DType.float64)).to(Program("np.dtype(np.float64)"))
    yield rewrite(dtype_program(DType.float32)).to(Program("np.dtype(np.float32)"))
    yield rewrite(dtype_program(DType.int64)).to(Program("np.dtype(np.int64)"))
    yield rewrite(dtype_program(DType.int32)).to(Program("np.dtype(np.int32)"))
    yield rewrite(dtype_program(DType.bool)).to(Program("np.dtype(np.bool_)"))
    yield rewrite(dtype_program(DType.object)).to(Program("np.dtype(np.object_)"))
112
+
113
+
114
@function
def float_program(x: Float) -> Program:
    """Render a `Float` expression as a Python source fragment."""


@array_api_program_gen_ruleset.register
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: BigRat):
    def infix(op: str) -> Program:
        # Fully parenthesized so the fragment is precedence-safe when embedded.
        return Program("(") + float_program(f) + f" {op} " + float_program(g) + ")"

    yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string()))
    yield rewrite(float_program(f.abs())).to(Program("np.abs(") + float_program(f) + ")")
    # Floats built from integers reuse the integer rendering.
    yield rewrite(float_program(Float.from_int(i))).to(int_program(i))

    yield rewrite(float_program(f + g)).to(infix("+"))
    yield rewrite(float_program(f - g)).to(infix("-"))
    yield rewrite(float_program(f * g)).to(infix("*"))
    yield rewrite(float_program(f / g)).to(infix("/"))

    # Rationals: an explicit division when the denominator is not 1, otherwise just the numerator.
    numerator = Program(r.numer.to_string())
    denominator = Program(r.denom.to_string())
    yield rewrite(float_program(Float.rational(r))).to(
        Program("float(") + numerator + " / " + denominator + ")",
        ne(r.denom).to(BigInt(1)),
    )
    yield rewrite(float_program(Float.rational(r))).to(
        Program("float(") + numerator + ")",
        eq(r.denom).to(BigInt(1)),
    )
134
+
135
+
136
@function
def value_program(x: Value) -> Program:
    """Render a scalar `Value` expression as a Python source fragment."""


@array_api_program_gen_ruleset.register
def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Value, xs: NDArray, ti: TupleInt):
    lhs = value_program(v1)
    rhs = value_program(v2)

    def infix(op: str) -> Program:
        # Fully parenthesized binary expression over the two operand renderings.
        return Program("(") + lhs + f" {op} " + rhs + ")"

    def call(prefix: str) -> Program:
        # Unary NumPy call: `prefix` already contains the function name and "(".
        return Program(prefix) + lhs + ")"

    # Wrapped scalars defer to the renderer for their underlying type.
    yield rewrite(value_program(Value.int(i))).to(int_program(i))
    yield rewrite(value_program(Value.bool(b))).to(bool_program(b))
    yield rewrite(value_program(Value.float(f))).to(float_program(f))
    # Could add .item() but we usually dont need it.
    yield rewrite(value_program(x.to_value())).to(ndarray_program(x))

    yield rewrite(value_program(v1 < v2)).to(infix("<"))
    yield rewrite(value_program(v1 / v2)).to(infix("/"))
    yield rewrite(value_program(v1 + v2)).to(infix("+"))
    yield rewrite(value_program(v1 * v2)).to(infix("*"))

    # Converting a value to bool/int is a no-op in the generated source.
    yield rewrite(bool_program(v1.to_bool)).to(lhs)
    yield rewrite(int_program(v1.to_int)).to(lhs)

    indexed = ndarray_program(xs) + "[" + tuple_int_program(ti) + "]"
    yield rewrite(value_program(xs.index(ti))).to(indexed.assign())

    yield rewrite(value_program(v1.sqrt())).to(call("np.sqrt("))
    yield rewrite(value_program(v1.real())).to(call("np.real("))
    yield rewrite(value_program(v1.conj())).to(call("np.conj("))
157
+
158
+
159
@function
def tuple_value_foldl_program(
    xs: TupleValueLike, f: Callable[[Program, Value], Program], init: ProgramLike
) -> Program:
    """Left-fold `f` over the elements of `xs`, starting from the program `init`."""


@function(ruleset=array_api_program_gen_ruleset)
def tuple_value_program(x: TupleValueLike) -> Program:
    # "(" then "<item>, " for each element, then ")": a valid Python tuple literal.
    opened = tuple_value_foldl_program(x, lambda acc, i: acc + value_program(i) + ", ", "(")
    return opened + ")"


@array_api_program_gen_ruleset.register
def _tuple_value_program(i: Int, ti: TupleValue, f: Callable[[Program, Value], Program], v: Value, init: Program):
    rendered = tuple_value_program(ti)

    # Indexing and length use Python's own syntax on the rendered tuple.
    yield rewrite(value_program(ti[i])).to(rendered + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + rendered + ")")

    # Unroll the fold by peeling off the most recently appended element.
    yield rewrite(tuple_value_foldl_program(TupleValue.EMPTY, f, init)).to(init)
    folded_prefix = tuple_value_foldl_program(ti, f, init)
    yield rewrite(tuple_value_foldl_program(ti.append(v), f, init)).to(f(folded_prefix, v))
177
+
178
+
179
@function
def tuple_ndarray_foldl_program(
    xs: TupleNDArrayLike, f: Callable[[Program, NDArray], Program], init: ProgramLike
) -> Program:
    """Left-fold `f` over the elements of `xs`, starting from the program `init`."""


@function(ruleset=array_api_program_gen_ruleset)
def tuple_ndarray_program(x: TupleNDArrayLike) -> Program:
    # "(" then "<array>, " for each element, then ")": a valid Python tuple literal.
    opened = tuple_ndarray_foldl_program(x, lambda acc, i: acc + ndarray_program(i) + ", ", "(")
    return opened + ")"


@array_api_program_gen_ruleset.register
def _tuple_ndarray_program(
    i: Int, ti: TupleNDArray, f: Callable[[Program, NDArray], Program], v: NDArray, init: Program
):
    rendered = tuple_ndarray_program(ti)

    # Indexing and length use Python's own syntax on the rendered tuple.
    yield rewrite(ndarray_program(ti[i])).to(rendered + "[" + int_program(i) + "]")
    yield rewrite(int_program(ti.length())).to(Program("len(") + rendered + ")")

    # Unroll the fold by peeling off the most recently appended element.
    yield rewrite(tuple_ndarray_foldl_program(TupleNDArray.EMPTY, f, init)).to(init)
    folded_prefix = tuple_ndarray_foldl_program(ti, f, init)
    yield rewrite(tuple_ndarray_foldl_program(ti.append(v), f, init)).to(f(folded_prefix, v))
199
+
200
+
201
@function
def optional_dtype_program(x: OptionalDType) -> Program:
    """Render an `OptionalDType`; an absent dtype becomes `None`."""


@array_api_program_gen_ruleset.register
def _optional_dtype_program(dtype: DType):
    cases = (
        (OptionalDType.none, Program("None")),
        (OptionalDType.some(dtype), dtype_program(dtype)),
    )
    for optional, rendered in cases:
        yield rewrite(optional_dtype_program(optional)).to(rendered)
209
+
210
+
211
@function
def optional_int_program(x: OptionalInt) -> Program:
    """Render an `OptionalInt`; an absent value becomes `None`."""


@array_api_program_gen_ruleset.register
def _optional_int_program(x: Int):
    cases = (
        (OptionalInt.none, Program("None")),
        (OptionalInt.some(x), int_program(x)),
    )
    for optional, rendered in cases:
        yield rewrite(optional_int_program(optional)).to(rendered)
219
+
220
+
221
@function
def optional_int_slice_program(x: OptionalInt) -> Program:
    """
    Render an `OptionalInt` for use as a slice bound.

    Unlike `optional_int_program`, an absent value is rendered as the empty string rather
    than `None`, so it can sit directly between the colons of `start:stop:step`.
    """


@array_api_program_gen_ruleset.register
def _optional_int_slice_program(x: Int):
    cases = (
        (OptionalInt.none, Program("")),
        (OptionalInt.some(x), int_program(x)),
    )
    for optional, rendered in cases:
        yield rewrite(optional_int_slice_program(optional)).to(rendered)
232
+
233
+
234
@function
def slice_program(x: Slice) -> Program:
    """Render a `Slice` using Python's `start:stop` / `start:stop:step` syntax."""


@array_api_program_gen_ruleset.register
def _slice_program(start: OptionalInt, stop: OptionalInt, i: Int):
    bounds = optional_int_slice_program(start) + ":" + optional_int_slice_program(stop)
    # Without a step only `start:stop` is emitted; with one, `:step` is appended.
    yield rewrite(slice_program(Slice(start, stop, OptionalInt.none))).to(bounds)
    yield rewrite(slice_program(Slice(start, stop, OptionalInt.some(i)))).to(bounds + ":" + int_program(i))
246
+
247
+
248
@function
def multi_axis_index_key_item_program(x: MultiAxisIndexKeyItem) -> Program:
    """Render a single item of a multi-axis index key."""


@array_api_program_gen_ruleset.register
def _multi_axis_index_key_item_program(i: Int, s: Slice):
    item_renderings = (
        (MultiAxisIndexKeyItem.int(i), int_program(i)),
        (MultiAxisIndexKeyItem.slice(s), slice_program(s)),
        (MultiAxisIndexKeyItem.ELLIPSIS, Program("...")),
        (MultiAxisIndexKeyItem.NONE, Program("None")),
    )
    for item, rendered in item_renderings:
        yield rewrite(multi_axis_index_key_item_program(item)).to(rendered)
258
+
259
+
260
@function
# Render a multi-axis index key as a comma-separated item list (the contents of a `[...]` subscript).
def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...


@array_api_program_gen_ruleset.register
def _multi_axis_index_key_program(
    idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
):
    # NOTE(review): the pattern variable `i` is not referenced by any rule below (the lambda's
    # own `i` shadows it); it looks removable.

    # Function-backed key of length 0 renders as the empty string. The recursive rule below
    # relies on this as its base case.
    # NOTE(review): `index_key_program(IndexKey.multi_axis(key))` is spliced directly between
    # "[" and "]" (see the getitem/setitem rules in `_ndarray_program`), so a top-level empty key
    # would generate `x[]`, which is a Python syntax error (`x[()]` would be needed). Fixing it
    # needs a separate top-level wrapper, since the recursion depends on "" here.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))

    # Function-backed key of length k != 0: render item 0 followed by ", ", then recurse on the
    # remaining k - 1 items by shifting the index function by one. Output therefore ends with
    # a trailing ", " (e.g. "a, b, "), which is still valid tuple syntax inside a subscript.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
        multi_axis_index_key_item_program(idx_fn(Int(0)))
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
        ne(k).to(i64(0)),
    )

    # Vec-backed key: empty vec renders as the empty string (same `x[]` caveat as above).
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
        Program("")
    )
    # Exactly one item: "item," - the trailing comma keeps the subscript a 1-tuple.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0]) + ",",
        eq(vec.length()).to(i64(1)),
    )
    # More than one item: "item, " followed by the rendering of the remaining items.
    yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
        multi_axis_index_key_item_program(vec[0])
        + ", "
        + multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
        vec.length() > 1,
    )
290
+
291
+
292
@function
def index_key_program(x: IndexKey) -> Program:
    """Render an `IndexKey` as the contents of a Python subscript `[...]`."""


@array_api_program_gen_ruleset.register
def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
    key_renderings = (
        (IndexKey.ELLIPSIS, Program("...")),
        (IndexKey.int(i), int_program(i)),
        (IndexKey.slice(s), slice_program(s)),
        (IndexKey.multi_axis(key), multi_axis_index_key_program(key)),
        (IndexKey.ndarray(a), ndarray_program(a)),
    )
    for index_key, rendered in key_renderings:
        yield rewrite(index_key_program(index_key)).to(rendered)
303
+
304
+
305
@function
def int_or_tuple_program(x: IntOrTuple) -> Program:
    """Render an `IntOrTuple` as either an integer or a tuple literal."""


@array_api_program_gen_ruleset.register
def _int_or_tuple_program(x: Int, t: TupleInt):
    cases = (
        (IntOrTuple.int(x), int_program(x)),
        (IntOrTuple.tuple(t), tuple_int_program(t)),
    )
    for variant, rendered in cases:
        yield rewrite(int_or_tuple_program(variant)).to(rendered)
313
+
314
+
315
@function
def optional_int_or_tuple_program(x: OptionalIntOrTuple) -> Program:
    """Render an `OptionalIntOrTuple`; an absent value becomes `None`."""


@array_api_program_gen_ruleset.register
def _optional_int_or_tuple_program(it: IntOrTuple):
    cases = (
        (OptionalIntOrTuple.some(it), int_or_tuple_program(it)),
        (OptionalIntOrTuple.none, Program("None")),
    )
    for optional, rendered in cases:
        yield rewrite(optional_int_or_tuple_program(optional)).to(rendered)
323
+
324
+
325
@array_api_program_gen_ruleset.register
def _ndarray_program(
    x: NDArray,
    y: NDArray,
    z: NDArray,
    s: String,
    dtype: DType,
    ti: TupleInt,
    i: Int,
    tv: TupleValue,
    v: Value,
    ob: OptionalBool,
    tnd: TupleNDArray,
    optional_device_: OptionalDevice,
    int_or_tuple_: IntOrTuple,
    idx: IndexKey,
    odtype: OptionalDType,
):
    """
    Rewrite rules rendering `NDArray` (and `TupleNDArray`) expressions as NumPy source code.

    Most array-producing operations are wrapped in `Program.assign()`. Its semantics live in
    `program_gen` (not visible here); presumably it binds the expression to a fresh name so the
    result is computed once and referenced by name afterwards - TODO confirm.
    """
    # Var: a named input array renders as its own name.
    yield rewrite(ndarray_program(NDArray.var(s))).to(Program(s, True))

    # Assume dtype. Each `assume_*` helper is applied to a fresh `copy(z)`; its return value is
    # unused, so it appears to annotate its argument in place, leaving `z` itself as the plain
    # array whose rendering is `z_program`. The assumed array renders as `z` with a runtime
    # `assert` statement attached via `Program.statement`.
    z_assumed_dtype = copy(z)
    assume_dtype(z_assumed_dtype, dtype)
    z_program = ndarray_program(z)
    yield rewrite(ndarray_program(z_assumed_dtype)).to(
        z_program.statement(Program("assert ") + z_program + ".dtype == " + dtype_program(dtype))
    )
    # Assume shape
    z_assumed_shape = copy(z)
    assume_shape(z_assumed_shape, ti)
    yield rewrite(ndarray_program(z_assumed_shape)).to(
        z_program.statement(Program("assert ") + z_program + ".shape == " + tuple_int_program(ti))
    )

    # Assume isfinite
    z_assumed_isfinite = copy(z)
    assume_isfinite(z_assumed_isfinite)
    yield rewrite(ndarray_program(z_assumed_isfinite)).to(
        z_program.statement(Program("assert np.all(np.isfinite(") + z_program + "))")
    )

    # Assume value_one_of: the generated assert checks that the set of distinct values of `z`
    # equals the set of values in `tv`.
    z_assumed_value_one_of = copy(z)
    assume_value_one_of(z_assumed_value_one_of, tv)
    yield rewrite(ndarray_program(z_assumed_value_one_of)).to(
        z_program.statement(Program("assert set(np.unique(") + z_program + ")) == set(" + tuple_value_program(tv) + ")")
    )

    # reshape (don't include copy, since not present in numpy)
    # NOTE(review): `ob` (the copy flag) is matched but not rendered. Newer NumPy releases
    # (2.1+) added a `copy` argument to reshape - verify whether dropping it is still intended.
    yield rewrite(ndarray_program(reshape(y, ti, ob))).to(
        (ndarray_program(y) + ".reshape(" + tuple_int_program(ti) + ")").assign()
    )

    # astype
    yield rewrite(ndarray_program(astype(y, dtype))).to(
        (ndarray_program(y) + ".astype(" + dtype_program(dtype) + ")").assign()
    )

    # unique_counts(x) => unique(x, return_counts=True)
    yield rewrite(tuple_ndarray_program(unique_counts(x))).to(
        (Program("np.unique(") + ndarray_program(x) + ", return_counts=True)").assign()
    )
    # unique_inverse(x) => unique(x, return_inverse=True)
    yield rewrite(tuple_ndarray_program(unique_inverse(x))).to(
        (Program("np.unique(") + ndarray_program(x) + ", return_inverse=True)").assign()
    )

    # Tuple ndarray indexing
    yield rewrite(ndarray_program(tnd[i])).to(tuple_ndarray_program(tnd) + "[" + int_program(i) + "]")

    # ndarray scalar
    # TODO: Use dtype and shape and indexing instead?
    # TODO: Specify dtype?
    yield rewrite(ndarray_program(NDArray.scalar(v))).to(Program("np.array(") + value_program(v) + ")")

    # zeros: the device argument is matched but not rendered.
    yield rewrite(ndarray_program(zeros(ti, OptionalDType.none, optional_device_))).to(
        (Program("np.zeros(") + tuple_int_program(ti) + ")").assign()
    )
    yield rewrite(ndarray_program(zeros(ti, OptionalDType.some(dtype), optional_device_))).to(
        (Program("np.zeros(") + tuple_int_program(ti) + ", dtype=" + dtype_program(dtype) + ")").assign(),
    )

    # unique_values
    yield rewrite(ndarray_program(unique_values(x))).to((Program("np.unique(") + ndarray_program(x) + ")").assign())

    # Binary operators: `res` (an expression over the pattern vars `x` and `y`) renders as
    # "x <op> y". Operands are not parenthesized; every non-assigned rendering in this ruleset
    # is a name, call, subscript or attribute access, all of which bind tighter than any
    # binary operator.
    def bin_op(res: NDArray, op: str) -> Command:
        return rewrite(ndarray_program(res)).to((ndarray_program(x) + f" {op} " + ndarray_program(y)).assign())

    # NDArray ops
    yield bin_op(x + y, "+")
    yield bin_op(x - y, "-")
    yield bin_op(x * y, "*")
    yield bin_op(x / y, "/")
    yield bin_op(x < y, "<")
    yield bin_op(x <= y, "<=")
    yield bin_op(x > y, ">")
    yield bin_op(x >= y, ">=")
    yield bin_op(x == y, "==")
    yield bin_op(x @ y, "@")
    yield bin_op(x % y, "%")
    yield bin_op(x & y, "&")
    yield bin_op(x | y, "|")
    yield bin_op(x ^ y, "^")
    yield bin_op(x << y, "<<")
    yield bin_op(x >> y, ">>")
    yield bin_op(x // y, "//")
    yield bin_op(x**y, "**")

    # setitem: `mod_x` is a copy of `x` with `mod_x[idx] = y` applied. It renders as the
    # assigned rendering of `x` with an item-assignment statement attached.
    # NOTE(review): if `.assign()` only binds a new name without copying, the generated
    # `name[...] = ...` also mutates the array that `x` renders to - confirm against `program_gen`.
    mod_x = copy(x)
    mod_x[idx] = y
    assigned_x = ndarray_program(x).assign()
    yield rewrite(ndarray_program(mod_x)).to(
        assigned_x.statement(assigned_x + "[" + index_key_program(idx) + "] = " + ndarray_program(y))
    )
    # getitem
    yield rewrite(ndarray_program(x[idx])).to(ndarray_program(x) + "[" + index_key_program(idx) + "]")

    # square
    yield rewrite(ndarray_program(square(x))).to((Program("np.square(") + ndarray_program(x) + ")").assign())

    # expand_dims(x, axis)
    yield rewrite(ndarray_program(expand_dims(x, i))).to(
        (Program("np.expand_dims(") + ndarray_program(x) + ", " + int_program(i) + ")").assign()
    )

    # mean(x, axis, keepdims)
    yield rewrite(ndarray_program(mean(x))).to((Program("np.mean(") + ndarray_program(x) + ")").assign())
    yield rewrite(
        ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), FALSE)),
    ).to(
        (Program("np.mean(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign(),
    )
    yield rewrite(
        ndarray_program(mean(x, OptionalIntOrTuple.some(int_or_tuple_), TRUE)),
    ).to(
        (
            Program("np.mean(")
            + ndarray_program(x)
            + ", axis="
            + int_or_tuple_program(int_or_tuple_)
            + ", keepdims=True)"
        ).assign(),
    )

    # Concat
    # NOTE(review): `OptionalInt.none` renders with no `axis=`, i.e. NumPy's default axis 0.
    # Confirm this matches what `none` means for `concat` (in the array API, `axis=None` flattens).
    yield rewrite(ndarray_program(concat(tnd, OptionalInt.none))).to(
        (Program("np.concatenate(") + tuple_ndarray_program(tnd) + ")").assign()
    )
    yield rewrite(ndarray_program(concat(tnd, OptionalInt.some(i)))).to(
        (Program("np.concatenate(") + tuple_ndarray_program(tnd) + ", axis=" + int_program(i) + ")").assign()
    )
    # Vector
    yield rewrite(ndarray_program(NDArray.vector(tv))).to(Program("np.array(") + tuple_value_program(tv) + ")")
    # std
    yield rewrite(ndarray_program(std(x))).to((Program("np.std(") + ndarray_program(x) + ")").assign())
    yield rewrite(ndarray_program(std(x, OptionalIntOrTuple.some(int_or_tuple_)))).to(
        (Program("np.std(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign(),
    )
    # svd
    yield rewrite(tuple_ndarray_program(svd(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign())
    yield rewrite(tuple_ndarray_program(svd(x, FALSE))).to(
        (Program("np.linalg.svd(") + ndarray_program(x) + ", full_matrices=False)").assign()
    )
    # sqrt
    yield rewrite(ndarray_program(sqrt(x))).to((Program("np.sqrt(") + ndarray_program(x) + ")").assign())
    # Transpose
    yield rewrite(ndarray_program(x.T)).to(ndarray_program(x) + ".T")
    # sum (presumably the star-imported array-API `sum`, not the builtin - TODO confirm)
    yield rewrite(ndarray_program(sum(x))).to((Program("np.sum(") + ndarray_program(x) + ")").assign())
    yield rewrite(ndarray_program(sum(x, OptionalIntOrTuple.some(int_or_tuple_)))).to(
        (Program("np.sum(") + ndarray_program(x) + ", axis=" + int_or_tuple_program(int_or_tuple_) + ")").assign()
    )
    # shape and abs
    yield rewrite(tuple_int_program(x.shape)).to(ndarray_program(x) + ".shape")
    yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())

    # asarray: only the dtype is rendered (as the positional second argument); the copy flag
    # must be `OptionalBool.none` for this rule to match, and the device is not rendered.
    yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
        Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
    )