egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,2019 @@
1
+ """
2
+
3
+
4
+ ## Lists
5
+
6
+ Lists have two main constructors:
7
+
8
+ - `List(length, idx_fn)`
9
+ - `List.EMPTY` / `initial.append(last)`
10
+
11
+ This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic
12
+ length that could not be resolved to an integer.
13
+
14
+ There are rewrites to convert between these constructors in both directions. The only limitation however is that
15
+ `length` has to be a real i64 in order to be converted to a cons list.
16
+
17
+ When you are writing a function that uses ints, feel free to use the `__getitem__` or `length()` methods or match
18
+ directly on `List()` constructor. If you can write your function using that interface please do. But for some other
19
+ methods, where the resulting length/index function depends on the rest of the list, you can only define them with a known
20
+ length, so you can then use the cons list constructors.
21
+
22
+ We also support creating lists from vectors. These can be converted one to one to the snoc list representation.
23
+
24
+ It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet.
25
+
26
+ We are guaranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use
27
+ the `.length` and `.__getitem__` methods, unless you want to depend on it having a known length, in which
28
+ case you can match directly on the cons list.
29
+
30
+ To be a list, you must implement two methods:
31
+
32
+ * `l.length() -> Int`
33
+ * `l.__getitem__(i: Int) -> T`
34
+
35
+ There are three main types of constructors for lists which all implement these methods:
36
+
37
+ * Functional `List(length, idx_fn)`
38
+ * cons (well reversed cons) lists `List.EMPTY` and `l.append(x)`
39
+ * Vectors `List.from_vec(vec)`
40
+
41
+ Also, all list constructors must be converted to the functional representation, so that we can match on it
42
+ and convert lists with known lengths into cons lists and into vectors.
43
+
44
+ This is necessary so that known length lists are properly materialized during extraction.
45
+
46
+ Q: Why are they implemented as SNOC lists instead of CONS lists?
47
+ A: So that when converting from functional to lists we can use the same index function by starting at the end and folding
48
+ that way recursively.
49
+
50
+
51
+ """
52
+
53
+ # mypy: disable-error-code="empty-body"
54
+
55
+ from __future__ import annotations
56
+
57
+ import contextlib
58
+ import itertools
59
+ import math
60
+ import numbers
61
+ import os
62
+ import sys
63
+ from collections.abc import Callable
64
+ from copy import copy
65
+ from types import EllipsisType
66
+ from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
67
+
68
+ import numpy as np
69
+
70
+ from egglog import *
71
+ from egglog.runtime import RuntimeExpr
72
+
73
+ from .program_gen import *
74
+
75
+ if TYPE_CHECKING:
76
+ from collections.abc import Iterator
77
+ from types import ModuleType
78
+
79
+
80
# Register RuntimeExpr as a virtual subclass of numbers.Integral so that
# `isinstance(expr, numbers.Integral)` checks (e.g. the ones sklearn performs)
# accept egglog expressions as if they were integers.
numbers.Integral.register(RuntimeExpr)

# Opt scipy into its array API standard support. The variable has to be set before
# scipy is imported for the first time, which is why it is done at module import time.
# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support
os.environ["SCIPY_ARRAY_API"] = "1"

# Shared ruleset: the Expr classes below (via `ruleset=`) and every
# `@array_api_ruleset.register` function add their rewrites/rules to it.
array_api_ruleset = ruleset(name="array_api_ruleset")
88
+
89
+
90
class Boolean(Expr, ruleset=array_api_ruleset):
    """
    Symbolic boolean expression built on top of the egglog ``Bool`` primitive.
    """

    # Wrap a concrete primitive bool.
    def __init__(self, value: BoolLike) -> None: ...

    # Lets the expression be used in a Python boolean context (`if expr:`),
    # which forces evaluation through `eval`.
    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.eval()

    # Evaluate to a Python bool: runs `array_api_schedule` on the current e-graph
    # (via `try_evaling`) and reads back the `to_bool` primitive.
    @method(preserve=True)
    def eval(self) -> bool:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)

    # Primitive value of this boolean, available once it is known to be `Boolean(b)`.
    @property
    def to_bool(self) -> Bool: ...

    def __or__(self, other: BooleanLike) -> Boolean: ...

    def __and__(self, other: BooleanLike) -> Boolean: ...

    def __invert__(self) -> Boolean: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: BooleanLike) -> Boolean: ...  # type: ignore[override]
111
+
112
+
113
# Anything that can be passed where a Boolean is expected.
BooleanLike = Boolean | BoolLike

# Canonical boolean constants used throughout the rules in this module.
TRUE = Boolean(True)
FALSE = Boolean(False)
# Allow primitive Bool values to be converted implicitly into Boolean.
converter(Bool, Boolean, Boolean)
118
+
119
+
120
@array_api_ruleset.register
def _bool(x: Boolean, i: Int, j: Int, b: Bool):
    """
    Simplification rules for Boolean.

    Identity and annihilator laws for ``|`` and ``&`` are given with the constant on
    either side, so e.g. both ``TRUE | x`` and ``x | TRUE`` simplify to ``TRUE``.
    """
    return [
        # Record the primitive value whenever the e-class contains a concrete Boolean(b).
        rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b)),
        # `|` with a constant operand
        rewrite(TRUE | x).to(TRUE),
        rewrite(x | TRUE).to(TRUE),
        rewrite(FALSE | x).to(x),
        rewrite(x | FALSE).to(x),
        # `&` with a constant operand
        rewrite(TRUE & x).to(x),
        rewrite(x & TRUE).to(x),
        rewrite(FALSE & x).to(FALSE),
        rewrite(x & FALSE).to(FALSE),
        # negation of constants
        rewrite(~TRUE).to(FALSE),
        rewrite(~FALSE).to(TRUE),
        # Sanity check: merging the two constants means the rules derived a contradiction.
        rule(eq(FALSE).to(TRUE)).then(panic("False cannot equal True")),
        # equality
        rewrite(x == x).to(TRUE),  # noqa: PLR0124
        rewrite(FALSE == TRUE).to(FALSE),
        rewrite(TRUE == FALSE).to(FALSE),
    ]
135
+
136
+
137
class Int(Expr, ruleset=array_api_ruleset):
    """
    Symbolic integer expression built on top of the egglog ``i64`` primitive.
    """

    # A "never" int is one that should not exist. It could represent, for example, the result of
    # indexing into an array with a value that is out of bounds.
    # https://en.wikipedia.org/wiki/Bottom_type
    NEVER: ClassVar[Int]

    # A named symbolic integer whose value is unknown.
    @classmethod
    def var(cls, name: StringLike) -> Int: ...

    # Wrap a concrete primitive i64.
    def __init__(self, value: i64Like) -> None: ...

    def __invert__(self) -> Int: ...

    # Comparisons return symbolic Booleans.
    def __lt__(self, other: IntLike) -> Boolean: ...

    def __le__(self, other: IntLike) -> Boolean: ...

    def __eq__(self, other: IntLike) -> Boolean:  # type: ignore[override]
        ...

    # Hash by the extracted (simplified) form of the expression so that sklearn's
    # validation code, which hashes these values, can run on them:
    # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
    # Side effect: registers `self` in the current e-graph and runs `array_api_schedule` on it.
    @method(preserve=True)
    def __hash__(self) -> int:
        egraph = _get_current_egraph()
        egraph.register(self)
        egraph.run(array_api_schedule)
        simplified = egraph.extract(self)
        return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)

    # `ndigits=None` is converted to `OptionalInt.none`.
    def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...

    # TODO: Fix this?
    # Deliberately returns a Python bool (not a symbolic Boolean) by evaluating
    # `self == other`, so that numpy.unique works on a tuple of ints (in `_unique1d`).
    @method(preserve=True)
    def __ne__(self, other: Int) -> bool:  # type: ignore[override]
        return not (self == other)

    def __gt__(self, other: IntLike) -> Boolean: ...

    def __ge__(self, other: IntLike) -> Boolean: ...

    # Arithmetic / bitwise operators; all stay in the Int domain.
    def __add__(self, other: IntLike) -> Int: ...

    def __sub__(self, other: IntLike) -> Int: ...

    def __mul__(self, other: IntLike) -> Int: ...

    # NOTE(review): true division is typed as returning an Int, not a Float.
    def __truediv__(self, other: IntLike) -> Int: ...

    def __floordiv__(self, other: IntLike) -> Int: ...

    def __mod__(self, other: IntLike) -> Int: ...

    # NOTE(review): typed as a single Int rather than a (quotient, remainder) pair.
    def __divmod__(self, other: IntLike) -> Int: ...

    def __pow__(self, other: IntLike) -> Int: ...

    def __lshift__(self, other: IntLike) -> Int: ...

    def __rshift__(self, other: IntLike) -> Int: ...

    def __and__(self, other: IntLike) -> Int: ...

    def __xor__(self, other: IntLike) -> Int: ...

    def __or__(self, other: IntLike) -> Int: ...

    # Reflected operators so that e.g. `1 + Int(...)` also builds an Int expression.
    def __radd__(self, other: IntLike) -> Int: ...

    def __rsub__(self, other: IntLike) -> Int: ...

    def __rmul__(self, other: IntLike) -> Int: ...

    def __rmatmul__(self, other: IntLike) -> Int: ...

    def __rtruediv__(self, other: IntLike) -> Int: ...

    def __rfloordiv__(self, other: IntLike) -> Int: ...

    def __rmod__(self, other: IntLike) -> Int: ...

    def __rpow__(self, other: IntLike) -> Int: ...

    def __rlshift__(self, other: IntLike) -> Int: ...

    def __rrshift__(self, other: IntLike) -> Int: ...

    def __rand__(self, other: IntLike) -> Int: ...

    def __rxor__(self, other: IntLike) -> Int: ...

    def __ror__(self, other: IntLike) -> Int: ...

    # Primitive value of this int, available once it is known to be `Int(i)`.
    @property
    def to_i64(self) -> i64: ...

    # Evaluate to a Python int: runs `array_api_schedule` on the current e-graph
    # (via `try_evaling`) and reads back the `to_i64` primitive.
    @method(preserve=True)
    def eval(self) -> int:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)

    # Python numeric protocols; each one forces evaluation.
    @method(preserve=True)
    def __index__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __int__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __float__(self) -> float:
        return float(self.eval())

    @method(preserve=True)
    def __bool__(self) -> bool:
        return bool(self.eval())

    # Symbolic conditional: `i` if `b` is TRUE, `j` if `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ...
256
+
257
+
258
@array_api_ruleset.register
def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
    """
    Constant-folding and simplification rules for Int.

    Every comparison operator declared on Int (``==``, ``<``, ``<=``, ``>``, ``>=``)
    is decided for pairs of concrete ints.
    """
    # ==
    yield rewrite(Int(i) == Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))

    # >=
    yield rewrite(Int(i) >= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) >= Int(j)), i < j).then(union(r).with_(FALSE))

    # <
    yield rewrite(Int(i) < Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) < Int(j)), i < j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) < Int(j)), i > j).then(union(r).with_(FALSE))

    # >
    yield rewrite(Int(i) > Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE))

    # <= (declared on Int but previously had no rules, so it never simplified)
    yield rewrite(Int(i) <= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) <= Int(j)), i < j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) <= Int(j)), i > j).then(union(r).with_(FALSE))

    # Record the primitive value whenever the e-class contains a concrete Int(j).
    yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j))

    # Sanity check: two different concrete ints must never be merged.
    yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints"))

    # Constant folding onto i64 primitives.
    # NOTE(review): `//` and `%` are mapped onto the i64 `/` and `%` primitives; if those
    # truncate toward zero, results differ from Python's floor semantics for negative
    # operands — confirm.
    yield rewrite(Int(i) + Int(j)).to(Int(i + j))
    yield rewrite(Int(i) - Int(j)).to(Int(i - j))
    yield rewrite(Int(i) * Int(j)).to(Int(i * j))
    yield rewrite(Int(i) // Int(j)).to(Int(i / j))
    yield rewrite(Int(i) % Int(j)).to(Int(i % j))
    yield rewrite(Int(i) & Int(j)).to(Int(i & j))
    yield rewrite(Int(i) | Int(j)).to(Int(i | j))
    yield rewrite(Int(i) ^ Int(j)).to(Int(i ^ j))
    yield rewrite(Int(i) << Int(j)).to(Int(i << j))
    yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
    yield rewrite(~Int(i)).to(Int(~i))

    # Resolve conditionals once the condition is a constant.
    yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
    yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)

    # Rounding an int with no `ndigits` is the identity.
    yield rewrite(o.__round__(OptionalInt.none)).to(o)

    # Never cannot be equal to anything real
    yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
298
+
299
+
300
# Allow primitive i64 values (and so Python ints) to be converted implicitly into Int.
converter(i64, Int, lambda x: Int(x))

# Anything that can be passed where an Int is expected.
IntLike: TypeAlias = Int | i64Like
303
+
304
+
305
@function(ruleset=array_api_ruleset)
def check_index(length: IntLike, idx: IntLike) -> Int:
    """
    Bounds-check ``idx`` against ``length``.

    :param length: number of elements in the sequence being indexed.
    :param idx: the index to validate.
    :return: ``idx`` itself when ``0 <= idx < length``, otherwise ``Int.NEVER``.
    """
    n = cast("Int", length)
    k = cast("Int", idx)
    non_negative = k >= 0
    below_length = k < n
    return Int.if_(non_negative & below_length, k, Int.NEVER)
313
+
314
+
315
+ # @array_api_ruleset.register
316
+ # def _check_index(i: i64, j: i64, x: Int):
317
+ # yield rewrite(
318
+ # check_index(Int(i), Int(j)),
319
+ # ).to(
320
+ # Int(j),
321
+ # i >= 0,
322
+ # i < j,
323
+ # )
324
+
325
+ # yield rewrite(
326
+ # check_index(x, Int(i)),
327
+ # ).to(
328
+ # Int.NEVER,
329
+ # i < 0,
330
+ # )
331
+
332
+ # yield rewrite(
333
+ # check_index(Int(i), Int(j)),
334
+ # ).to(
335
+ # Int.NEVER,
336
+ # i >= j,
337
+ # )
338
+
339
+
340
class Float(Expr, ruleset=array_api_ruleset):
    """
    Symbolic floating point expression.

    It has three constructors: a raw ``f64`` (``Float(f)``), an exact rational
    (``Float.rational``) and a conversion from an Int (``Float.from_int``).
    """

    # Differentiate costs of the three constructors so extraction is deterministic if all three
    # are present in one e-class: from_int (default cost) < rational (cost 2) < raw f64 (cost 3).
    @method(cost=3)
    def __init__(self, value: f64Like) -> None: ...

    # Primitive value of this float, available once it is known to be `Float(f)`.
    @property
    def to_f64(self) -> f64: ...

    # Evaluate to a Python float: runs `array_api_schedule` on the current e-graph
    # (via `try_evaling`) and reads back the `to_f64` primitive.
    @method(preserve=True)
    def eval(self) -> float:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)

    def abs(self) -> Float: ...

    # Exact rational representation.
    @method(cost=2)
    @classmethod
    def rational(cls, r: BigRat) -> Float: ...

    @classmethod
    def from_int(cls, i: IntLike) -> Float: ...

    def __truediv__(self, other: FloatLike) -> Float: ...

    def __mul__(self, other: FloatLike) -> Float: ...

    def __add__(self, other: FloatLike) -> Float: ...

    def __sub__(self, other: FloatLike) -> Float: ...

    def __pow__(self, other: FloatLike) -> Float: ...
    def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...

    # Comparisons return symbolic Booleans.
    def __eq__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
    def __ne__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
    def __lt__(self, other: FloatLike) -> Boolean: ...
    def __le__(self, other: FloatLike) -> Boolean: ...
    def __gt__(self, other: FloatLike) -> Boolean: ...
    def __ge__(self, other: FloatLike) -> Boolean: ...
378
+
379
+
380
# Python floats convert to the raw f64 constructor; Ints convert via `Float.from_int`.
converter(float, Float, lambda x: Float(x))
converter(Int, Float, lambda x: Float.from_int(x))


# Anything that can be passed where a Float is expected (ints are accepted too).
FloatLike: TypeAlias = Float | float | IntLike
385
+
386
+
387
@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
    """
    Constant-folding and simplification rules for Float.

    Every ordering comparison on two concrete floats has both a TRUE and a FALSE rule,
    so ``<``, ``<=``, ``>`` and ``>=`` all resolve when their operands are known.
    """
    return [
        # Record the primitive value whenever the e-class contains a concrete Float(f).
        rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
        rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
        rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
        rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
        # Convert from float to rational if it is a whole number, i.e. it round-trips through i64
        rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
        # always convert from int to rational
        rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
        rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
        rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
        rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
        rewrite(Float.rational(r) / Float.rational(r1)).to(Float.rational(r / r1)),
        rewrite(Float.rational(r) + Float.rational(r1)).to(Float.rational(r + r1)),
        rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
        rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
        rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
        # comparisons
        rewrite(Float(f) == Float(f)).to(TRUE),
        rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
        rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
        rewrite(Float(f) != Float(f)).to(FALSE),
        rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
        rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
        rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
        rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
        rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
        rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
        rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
        # Previously missing: `<` could only ever simplify to TRUE, never to FALSE.
        rewrite(Float(f) < Float(f2)).to(FALSE, f >= f2),
        rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
        rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
        # round
        rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
    ]
423
+
424
+
425
class TupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[int, ...]

    All constructors should be rewritten to the functional semantics in the __init__ method.
    """

    # A named symbolic tuple whose contents are unknown.
    @classmethod
    def var(cls, name: StringLike) -> TupleInt: ...

    # Functional constructor: `length` elements, element `i` given by `idx_fn(i)`.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

    # EMPTY is the start of the snoc ("reversed cons") list built with `append`.
    EMPTY: ClassVar[TupleInt]
    NEVER: ClassVar[TupleInt]

    # Snoc: a new tuple with `i` added at the end.
    def append(self, i: IntLike) -> TupleInt: ...

    # One-element tuple containing `i`.
    @classmethod
    def single(cls, i: Int) -> TupleInt:
        return TupleInt(Int(1), lambda _: i)

    # (0, 1, ..., stop - 1), like Python's range(stop).
    @method(subsume=True)
    @classmethod
    def range(cls, stop: IntLike) -> TupleInt:
        return TupleInt(stop, lambda i: i)

    @classmethod
    def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...

    # Concatenation: indices below len(self) read from self, the rest from `other`.
    def __add__(self, other: TupleIntLike) -> TupleInt:
        other = cast("TupleInt", other)
        return TupleInt(
            self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> Int: ...

    # Python protocols; both force evaluation on the current e-graph.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[Int]:
        return iter(self.eval())

    # Primitive Vec of the elements, available once the tuple is known to be `from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[Int]: ...

    # Evaluate via `try_evaling` with `array_api_schedule`, reading back `to_vec`.
    @method(preserve=True)
    def eval(self) -> tuple[Int, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    # Left folds over the elements, one per result type.
    def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
    def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
    def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: ...

    # Membership test: OR of `i == j` over all elements, starting from FALSE.
    @method(subsume=True)
    def contains(self, i: Int) -> Boolean:
        return self.foldl_boolean(lambda acc, j: acc | (i == j), FALSE)

    # Keep the elements for which `f` is TRUE, preserving order.
    @method(subsume=True)
    def filter(self, f: Callable[[Int], Boolean]) -> TupleInt:
        return self.foldl_tuple_int(
            lambda acc, v: TupleInt.if_(f(v), acc.append(v), acc),
            TupleInt.EMPTY,
        )

    @method(subsume=True)
    def map(self, f: Callable[[Int], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    # Symbolic conditional: `i` if `b` is TRUE, `j` if `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ...

    # Drop the first `n` elements. NOTE(review): `n` is not clamped to [0, len(self)].
    def drop(self, n: Int) -> TupleInt:
        return TupleInt(self.length() - n, lambda i: self[i + n])

    # Product of all elements (1 for an empty tuple).
    def product(self) -> Int:
        return self.foldl(lambda acc, i: acc * i, Int(1))

    def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt:
        return TupleTupleInt(self.length(), lambda i: f(self[i]))

    def select(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements at the given indices
        """
        indices = cast("TupleInt", indices)
        return indices.map(lambda i: self[i])

    def deselect(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements not at the given indices
        """
        indices = cast("TupleInt", indices)
        return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
522
+
523
+
524
# A primitive Vec[Int] converts to a TupleInt via `from_vec`.
converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x))

# Anything that can be passed where a TupleInt is expected.
TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]
527
+
528
+
529
@array_api_ruleset.register
def _tuple_int(
    i: Int,
    i2: Int,
    f: Callable[[Int, Int], Int],
    bool_f: Callable[[Boolean, Int], Boolean],
    idx_fn: Callable[[Int], Int],
    tuple_int_f: Callable[[TupleInt, Int], TupleInt],
    vs: Vec[Int],
    b: Boolean,
    ti: TupleInt,
    ti2: TupleInt,
    k: i64,
):
    """
    Rules for TupleInt.

    They define length/indexing for both the functional and the snoc (EMPTY/append)
    representations, convert functional tuples with a concrete length into snoc form
    and then into a Vec, and evaluate the folds and ``if_`` over snoc lists.
    """
    return [
        # Record the primitive Vec whenever the e-class contains `from_vec(vs)`.
        rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs)),
        # Functional access (out-of-range indices go through check_index and become Int.NEVER)
        rewrite(TupleInt(i, idx_fn).length()).to(i),
        rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2))),
        # cons access: the last appended element is at index `ti.length()`
        rewrite(TupleInt.EMPTY.length()).to(Int(0)),
        rewrite(TupleInt.EMPTY[i]).to(Int.NEVER),
        rewrite(ti.append(i).length()).to(ti.length() + 1),
        rewrite(ti.append(i)[i2]).to(Int.if_(i2 == ti.length(), i, ti[i2])),
        # cons to functional (removed so that the two directions do not keep rewriting into each other forever)
        # rewrite(TupleInt.EMPTY).to(TupleInt(0, lambda _: Int.NEVER)),
        # rewrite(TupleInt(i, idx_fn).append(i2)).to(TupleInt(i + 1, lambda j: Int.if_(j == i, i2, idx_fn(j)))),
        # functional to cons: peel off the last element while the length is a concrete positive i64
        rewrite(TupleInt(0, idx_fn), subsume=True).to(TupleInt.EMPTY),
        rewrite(TupleInt(Int(k), idx_fn), subsume=True).to(TupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0),
        # cons to vec
        rewrite(TupleInt.EMPTY).to(TupleInt.from_vec(Vec[Int]())),
        rewrite(TupleInt.from_vec(vs).append(i)).to(TupleInt.from_vec(vs.append(Vec(i)))),
        # fold
        rewrite(TupleInt.EMPTY.foldl(f, i), subsume=True).to(i),
        rewrite(ti.append(i2).foldl(f, i), subsume=True).to(f(ti.foldl(f, i), i2)),
        # fold boolean
        rewrite(TupleInt.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b),
        rewrite(ti.append(i2).foldl_boolean(bool_f, b), subsume=True).to(bool_f(ti.foldl_boolean(bool_f, b), i2)),
        # fold tuple_int
        rewrite(TupleInt.EMPTY.foldl_tuple_int(tuple_int_f, ti), subsume=True).to(ti),
        rewrite(ti.append(i2).foldl_tuple_int(tuple_int_f, ti2), subsume=True).to(
            tuple_int_f(ti.foldl_tuple_int(tuple_int_f, ti2), i2)
        ),
        # if_
        rewrite(TupleInt.if_(TRUE, ti, ti2), subsume=True).to(ti),
        rewrite(TupleInt.if_(FALSE, ti, ti2), subsume=True).to(ti2),
        # unify append: equal snoc lists have equal prefixes and equal last elements
        rule(eq(ti.append(i)).to(ti2.append(i2))).then(union(ti).with_(ti2), union(i).with_(i2)),
    ]
579
+
580
+
581
class TupleTupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[tuple[int, ...], ...]; mirrors TupleInt with TupleInt elements.
    """

    # A named symbolic tuple-of-tuples whose contents are unknown.
    @classmethod
    def var(cls, name: StringLike) -> TupleTupleInt: ...

    EMPTY: ClassVar[TupleTupleInt]

    # Functional constructor: `length` elements, element `i` given by `idx_fn(i)`.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

    # One-element tuple containing `i`.
    @method(subsume=True)
    @classmethod
    def single(cls, i: TupleIntLike) -> TupleTupleInt:
        i = cast("TupleInt", i)
        return TupleTupleInt(1, lambda _: i)

    @method(subsume=True)
    @classmethod
    def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...

    # Snoc: a new tuple with `i` added at the end.
    def append(self, i: TupleIntLike) -> TupleTupleInt: ...

    # Concatenation: indices below len(self) read from self, the rest from `other`.
    def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
        other = cast("TupleTupleInt", other)
        return TupleTupleInt(
            self.length() + other.length(),
            lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> TupleInt: ...

    # Python protocols; both force evaluation on the current e-graph.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.eval())

    # Primitive Vec of the elements, available once the tuple is known to be `from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[TupleInt]: ...

    # Evaluate via `try_evaling` with `array_api_schedule`, reading back `to_vec`.
    @method(preserve=True)
    def eval(self) -> tuple[TupleInt, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    # Drop the first `n` elements. NOTE(review): `n` is not clamped to [0, len(self)].
    def drop(self, n: Int) -> TupleTupleInt:
        return TupleTupleInt(self.length() - n, lambda i: self[i + n])

    # Map each inner tuple to an Int, giving a flat TupleInt.
    def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ...

    @method(subsume=True)
    def product(self) -> TupleTupleInt:
        """
        Cartesian product of inputs

        The result has prod(len(t) for t in self) rows. In row ``i``, entry ``j`` is
        ``self[j][(i // prod(lengths of self[j+1:])) % len(self[j])]``; that is, ``i`` is
        decoded as a mixed-radix number in which the last input varies fastest.

        https://docs.python.org/3/library/itertools.html#itertools.product

        https://github.com/saulshanabrook/saulshanabrook/discussions/39
        """
        return TupleTupleInt(
            self.map_int(lambda x: x.length()).product(),
            lambda i: TupleInt(
                self.length(),
                lambda j: self[j][(i // self.drop(j + 1).map_int(lambda x: x.length()).product()) % self[j].length()],
            ),
        )
650
+
651
+
652
# A primitive Vec[TupleInt] converts to a TupleTupleInt via `from_vec`.
converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))

# Anything that can be passed where a TupleTupleInt is expected.
TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
655
+
656
+
657
@array_api_ruleset.register
def _tuple_tuple_int(
    length: Int,
    fn: Callable[[TupleInt], Int],
    idx_fn: Callable[[Int], TupleInt],
    f: Callable[[Value, TupleInt], Value],
    i: Value,
    k: i64,
    idx: Int,
    vs: Vec[TupleInt],
    ti: TupleInt,
    ti1: TupleInt,
    tti: TupleTupleInt,
    tti1: TupleTupleInt,
):
    """
    Rules for TupleTupleInt, mirroring the TupleInt rules.

    They define length/indexing for both the functional and the snoc (EMPTY/append)
    representations, convert concrete-length functional tuples into snoc form and
    then into a Vec, and evaluate ``foldl_value`` over snoc lists.
    """
    # Record the primitive Vec whenever the e-class contains `from_vec(vs)`.
    yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs))
    # Functional access
    yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
    # check_index's signature is (length, idx); the arguments were previously swapped here,
    # which bounds-checked the length against the index instead of the index against the length.
    yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(length, idx)))

    # cons access: the last appended element is at index `tti.length()`
    yield rewrite(TupleTupleInt.EMPTY.length()).to(Int(0))
    yield rewrite(TupleTupleInt.EMPTY[idx]).to(TupleInt.NEVER)
    yield rewrite(tti.append(ti).length()).to(tti.length() + 1)
    yield rewrite(tti.append(ti)[idx]).to(TupleInt.if_(idx == tti.length(), ti, tti[idx]))

    # functional to cons: peel off the last element while the length is a concrete positive i64
    yield rewrite(TupleTupleInt(0, idx_fn), subsume=True).to(TupleTupleInt.EMPTY)
    yield rewrite(TupleTupleInt(Int(k), idx_fn), subsume=True).to(
        TupleTupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )
    # cons to vec
    yield rewrite(TupleTupleInt.EMPTY).to(TupleTupleInt.from_vec(Vec[TupleInt]()))
    yield rewrite(TupleTupleInt.from_vec(vs).append(ti)).to(TupleTupleInt.from_vec(vs.append(Vec(ti))))
    # fold value
    yield rewrite(TupleTupleInt.EMPTY.foldl_value(f, i), subsume=True).to(i)
    yield rewrite(tti.append(ti).foldl_value(f, i), subsume=True).to(f(tti.foldl_value(f, i), ti))

    # unify append: equal snoc lists have equal prefixes and equal last elements
    yield rule(eq(tti.append(ti)).to(tti1.append(ti1))).then(union(tti).with_(tti1), union(ti).with_(ti1))
696
+
697
+
698
class OptionalInt(Expr, ruleset=array_api_ruleset):
    """
    An Int that may be absent, analogous to ``int | None``.
    """

    # The absent value (what Python `None` converts to).
    none: ClassVar[OptionalInt]

    # A present value wrapping `value`.
    @classmethod
    def some(cls, value: Int) -> OptionalInt: ...
703
+
704
+
705
# Anything that can be passed where an OptionalInt is expected.
OptionalIntLike: TypeAlias = OptionalInt | IntLike | None

# Python None -> OptionalInt.none; an Int -> OptionalInt.some(int).
converter(type(None), OptionalInt, lambda _: OptionalInt.none)
converter(Int, OptionalInt, OptionalInt.some)
709
+
710
+
711
class DType(Expr, ruleset=array_api_ruleset):
    """
    Array element data type.
    """

    # Class-level constants, one per supported dtype. Their names match numpy dtype
    # names, which the np.dtype -> DType converter relies on.
    float64: ClassVar[DType]
    float32: ClassVar[DType]
    int64: ClassVar[DType]
    int32: ClassVar[DType]
    object: ClassVar[DType]
    bool: ClassVar[DType]

    # Symbolic equality; whether two concrete dtypes are equal is decided by rewrites
    # registered elsewhere in this module.
    def __eq__(self, other: DType) -> Boolean:  # type: ignore[override]
        ...
721
+
722
+
723
# Module-level aliases, array-API style (e.g. `float64` instead of `DType.float64`).
float64 = DType.float64
float32 = DType.float32
int32 = DType.int32
int64 = DType.int64

# Concrete dtypes listed once, so rules can be generated over them.
# NOTE(review): DType.bool is not in this list, so rules generated from it do not cover bool.
_DTYPES = [float64, float32, int32, int64, DType.object]

# Python types (e.g. `float`) go through np.dtype first; an np.dtype is mapped to the
# DType class attribute with the same name. A numpy dtype whose name is not an
# attribute of DType (e.g. "int8") makes `getattr` raise AttributeError.
converter(type, DType, lambda x: convert(np.dtype(x), DType))
converter(np.dtype, DType, lambda x: getattr(DType, x.name))
732
+
733
+
734
@array_api_ruleset.register
def _():
    """
    Decide equality between every pair of concrete dtypes: TRUE when both sides are the
    same dtype, FALSE otherwise.
    """
    # DType.bool is not part of _DTYPES, but it is a real dtype (np.dtype(bool) converts to it
    # by name), so it is included here. Otherwise any comparison involving bool would never
    # simplify. Each dtype object appears exactly once in this list, so `is` identifies equal pairs.
    # If DType.bool is ever added to _DTYPES, drop the extra entry here to avoid conflicting rules.
    dtypes = [*_DTYPES, DType.bool]
    for l, r in itertools.product(dtypes, repeat=2):
        yield rewrite(l == r).to(TRUE if l is r else FALSE)
738
+
739
+
740
class IsDtypeKind(Expr, ruleset=array_api_ruleset):
    """
    The ``kind`` argument of ``isdtype``: a kind string, a specific dtype, or a union of kinds.
    """

    # Empty union; `isdtype(d, NULL)` is FALSE.
    NULL: ClassVar[IsDtypeKind]

    # A named kind such as "integral" or "real floating".
    @classmethod
    def string(cls, s: StringLike) -> IsDtypeKind: ...

    # Matches exactly the dtype `d`.
    @classmethod
    def dtype(cls, d: DType) -> IsDtypeKind: ...

    # Union of two kinds. The higher cost makes extraction prefer forms without unions.
    @method(cost=10)
    def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
751
+
752
+
753
# TODO: Make kind more generic to support tuples.
# Array API `isdtype(dtype, kind)`: whether `dtype` is of the given kind.
# Declared without a body; its rewrites are registered separately in this module.
@function
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
756
+
757
+
758
# A DType converts to a kind matching exactly that dtype; a str converts to a named kind.
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
# A tuple converts recursively into a union: (a, b) -> a | (b | NULL); () -> NULL.
# The conditional expression binds loosest, so this is `(head | rest) if x else NULL`.
converter(
    tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
)
763
+
764
+
765
@array_api_ruleset.register
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
    """
    Rules evaluating ``isdtype`` for concrete dtypes and kinds.

    The kind-string table covers every DType constant, including ``bool``, which
    per the array API standard is neither integral, real floating nor complex floating.
    """
    return [
        rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
        rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
        rewrite(isdtype(DType.bool, IsDtypeKind.string("integral"))).to(FALSE),
        rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
        rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
        rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
        rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
        rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
        rewrite(isdtype(DType.bool, IsDtypeKind.string("real floating"))).to(FALSE),
        rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
        rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
        rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
        rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
        rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
        rewrite(isdtype(DType.bool, IsDtypeKind.string("complex floating"))).to(FALSE),
        # The empty union matches nothing.
        rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
        # A dtype always matches itself as a kind.
        rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
        # A union matches if either side does.
        rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
        rewrite(k1 | IsDtypeKind.NULL).to(k1),
    ]
788
+
789
+
790
+ # TODO: Add pushdown for math on scalars to values
791
+ # and add replacements
792
+
793
+
794
class Value(Expr, ruleset=array_api_ruleset):
    """
    A scalar value: an Int, a Float or a Boolean.
    """

    NEVER: ClassVar[Value]

    # Constructors from each scalar kind. (These method names shadow the `int`, `float`
    # and `bool` builtins inside this class body.)
    @classmethod
    def int(cls, i: IntLike) -> Value: ...

    @classmethod
    def float(cls, f: FloatLike) -> Value: ...

    @classmethod
    def bool(cls, b: BooleanLike) -> Value: ...

    def isfinite(self) -> Boolean: ...

    # NOTE(review): unlike the other comparison operators in this module, `<` is typed to
    # return a Value rather than a Boolean.
    def __lt__(self, other: ValueLike) -> Value: ...

    def __truediv__(self, other: ValueLike) -> Value: ...

    def __mul__(self, other: ValueLike) -> Value: ...

    def __add__(self, other: ValueLike) -> Value: ...

    def astype(self, dtype: DType) -> Value: ...

    # TODO: Add all operations

    @property
    def dtype(self) -> DType:
        """
        Default dtype for this scalar value
        """

    @property
    def to_bool(self) -> Boolean: ...

    @property
    def to_int(self) -> Int: ...

    @property
    def to_truthy_value(self) -> Value:
        """
        Converts the value to a bool, based on whether it is truthy.

        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
        """

    def conj(self) -> Value: ...
    def real(self) -> Value: ...
    def sqrt(self) -> Value: ...

    # Symbolic conditional: `i` if `b` is TRUE, `j` if `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: ValueLike, j: ValueLike) -> Value: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
848
+
849
+
850
# Anything that can be converted into a Value
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike


# Lift the primitive egglog scalars into Value
converter(Int, Value, Value.int)
converter(Float, Value, Value.float)
converter(Boolean, Value, Value.bool)
# Allow a Value where an Int is expected; the extra cost argument (10) presumably makes this
# conversion less preferred than the default-cost ones above
converter(Value, Int, lambda x: x.to_int, 10)
857
+
858
+
859
@array_api_ruleset.register
def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float, b1: Boolean):
    """
    Simplification rules for scalar ``Value``s.
    """
    zero_float = Value.float(Float.rational(BigRat(0, 1)))
    return [
        # Default dtypes, see
        # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
        rewrite(Value.int(i).dtype).to(DType.int64),
        rewrite(Value.float(f).dtype).to(DType.float64),
        rewrite(Value.bool(b).dtype).to(DType.bool),
        # Unwrapping back to the primitive
        rewrite(Value.bool(b).to_bool).to(b),
        rewrite(Value.int(i).to_int).to(i),
        # A bool is already its own truthy value
        # TODO: Add more rules for to_truthy_value
        rewrite(Value.bool(b).to_truthy_value).to(Value.bool(b)),
        # conj/real leave real ints and floats unchanged
        rewrite(Value.float(f).conj()).to(Value.float(f)),
        rewrite(Value.float(f).real()).to(Value.float(f)),
        rewrite(Value.int(i).real()).to(Value.int(i)),
        rewrite(Value.int(i).conj()).to(Value.int(i)),
        rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5))),
        # Adding a (rational) zero on the left is a no-op
        rewrite(zero_float + v).to(v),
        # Conditionals whose condition is known
        rewrite(Value.if_(TRUE, v, v1)).to(v),
        rewrite(Value.if_(FALSE, v, v1)).to(v1),
        # Equality of two literals of the same kind reduces to equality of the primitives
        rewrite(Value.int(i) == Value.int(i1)).to(i == i1),
        rewrite(Value.float(f) == Value.float(f1)).to(f == f1),
        rewrite(Value.bool(b) == Value.bool(b1)).to(b == b1),
    ]
889
+
890
+
891
class TupleValue(Expr, ruleset=array_api_ruleset):
    """
    A tuple of scalar ``Value``s.

    It can be built functionally (a length plus an index function), by appending to ``EMPTY``,
    or from an egglog ``Vec``.
    """

    EMPTY: ClassVar[TupleValue]

    # Functional form: `idx_fn(i)` gives the element at index `i`
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Value]) -> None: ...

    def append(self, i: ValueLike) -> TupleValue: ...

    @classmethod
    def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...

    def __add__(self, other: TupleValueLike) -> TupleValue:
        """
        Concatenate: indices below ``self.length()`` read from ``self``, the rest from ``other``.
        """
        other = cast("TupleValue", other)
        return TupleValue(
            self.length() + other.length(),
            lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: Int) -> Value: ...

    # Left fold of `f` over the elements, starting from `init`
    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...

    def contains(self, value: ValueLike) -> Boolean:
        """
        Whether any element equals ``value``, computed as an ``|``-fold starting from FALSE.
        """
        value = cast("Value", value)
        return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)

    # subsume=True presumably removes the from_tuple_int call once it has been rewritten — confirm
    @method(subsume=True)
    @classmethod
    def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
        """
        Wrap each element of a TupleInt in ``Value.int``.
        """
        ti = cast("TupleInt", ti)
        return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
923
+
924
+
925
# Allow egglog Vecs of Values and TupleInts wherever a TupleValue is expected
converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))

# Anything that can be converted into a TupleValue
TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
929
+
930
+
931
@array_api_ruleset.register
def _tuple_value(
    length: Int,
    idx_fn: Callable[[Int], Value],
    k: i64,
    idx: Int,
    vs: Vec[Value],
    v: Value,
    v1: Value,
    tv: TupleValue,
    tv1: TupleValue,
    bool_f: Callable[[Boolean, Value], Boolean],
    b: Boolean,
):
    """
    Rules relating the functional, cons (EMPTY/append) and Vec forms of ``TupleValue``.
    """
    functional = TupleValue(length, idx_fn)
    appended = tv.append(v)
    return [
        # Access on the functional form
        rewrite(functional.length()).to(length),
        rewrite(functional[idx]).to(idx_fn(check_index(idx, length))),
        # Access on the cons form
        rewrite(TupleValue.EMPTY.length()).to(Int(0)),
        rewrite(TupleValue.EMPTY[idx]).to(Value.NEVER),
        rewrite(appended.length()).to(tv.length() + 1),
        rewrite(appended[idx]).to(Value.if_(idx == tv.length(), v, tv[idx])),
        # Functional -> cons, once the length is a literal
        rewrite(TupleValue(0, idx_fn), subsume=True).to(TupleValue.EMPTY),
        rewrite(TupleValue(Int(k), idx_fn), subsume=True).to(
            TupleValue(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
        ),
        # Cons -> Vec
        rewrite(TupleValue.EMPTY).to(TupleValue.from_vec(Vec[Value]())),
        rewrite(TupleValue.from_vec(vs).append(v)).to(TupleValue.from_vec(vs.append(Vec(v)))),
        # Boolean left fold over the cons form
        rewrite(TupleValue.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b),
        rewrite(appended.foldl_boolean(bool_f, b), subsume=True).to(bool_f(tv.foldl_boolean(bool_f, b), v)),
        # Two equal appends have equal prefixes and equal last elements
        rule(eq(appended).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1)),
    ]
970
+
971
+
972
@function
def possible_values(values: Value) -> TupleValue:
    """
    All possible values in the input value.

    The ``any`` rule uses this to check whether ``Value.bool(TRUE)`` is among them.
    """
977
+
978
+
979
class Slice(Expr, ruleset=array_api_ruleset):
    """
    A slice with optional ``start``, ``stop`` and ``step``, mirroring Python's builtin ``slice``.
    """

    def __init__(
        self,
        start: OptionalInt = OptionalInt.none,
        stop: OptionalInt = OptionalInt.none,
        step: OptionalInt = OptionalInt.none,
    ) -> None: ...
986
+
987
+
988
# Convert Python slice objects; each of start/stop/step (which may be None) becomes an OptionalInt
converter(
    slice,
    Slice,
    lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
)

# Anything that can be converted into a Slice
SliceLike: TypeAlias = Slice | slice
995
+
996
+
997
class MultiAxisIndexKeyItem(Expr, ruleset=array_api_ruleset):
    """
    One item of a multi-axis index key: ``...``, ``None``, an integer or a slice.
    """

    ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
    NONE: ClassVar[MultiAxisIndexKeyItem]

    @classmethod
    def int(cls, i: Int) -> MultiAxisIndexKeyItem: ...

    @classmethod
    def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ...
1006
+
1007
+
1008
# `...` and `None` map to the singleton items; Ints and Slices are wrapped
converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS)
converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NONE)
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)

# Anything usable as one item of a multi-axis key
MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
1014
+
1015
+
1016
class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
    """
    A sequence of ``MultiAxisIndexKeyItem``s, as used by ``x[a, b, ...]``.
    """

    # Functional form: `idx_fn(i)` gives the item at position `i`
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...

    def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...

    @classmethod
    def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...
1023
+
1024
+
1025
# Anything that can be converted into a multi-axis index key
MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"


# Python tuples of index items become a Vec-backed key. The Vec's element type is given explicitly,
# as it is for the other empty Vecs in this module, so that the empty tuple (`x[()]`) still yields a
# well-typed Vec instead of an argument-less `Vec()` whose element type cannot be inferred.
converter(
    tuple,
    MultiAxisIndexKey,
    lambda x: MultiAxisIndexKey.from_vec(
        Vec[MultiAxisIndexKeyItem](*(convert(i, MultiAxisIndexKeyItem) for i in x))
    ),
)
# A TupleInt key indexes each axis with the corresponding integer
converter(
    TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
)
1036
+
1037
+
1038
class IndexKey(Expr, ruleset=array_api_ruleset):
    """
    A key for indexing into an array

    https://data-apis.org/array-api/2022.12/API_specification/indexing.html

    It is equivalent to the following type signature:

    Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
    """

    # A bare `...` key
    ELLIPSIS: ClassVar[IndexKey]

    # A single integer key
    @classmethod
    def int(cls, i: Int) -> IndexKey: ...

    # A single slice key
    @classmethod
    def slice(cls, slice: Slice) -> IndexKey: ...

    # Disabled until we support late binding
    # @classmethod
    # def boolean_array(cls, b: NDArray) -> IndexKey:
    #     ...

    # A tuple of per-axis items
    @classmethod
    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...

    @classmethod
    def ndarray(cls, key: NDArray) -> IndexKey:
        """
        Indexes by a masked array
        """
1070
+
1071
+
1072
# Anything that can be converted into an IndexKey
IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"


# Wrap each supported key kind in the corresponding IndexKey constructor
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
converter(Int, IndexKey, lambda i: IndexKey.int(i))
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
1079
+
1080
+
1081
# Opaque device an array lives on; no constructors are declared for it here
class Device(Expr, ruleset=array_api_ruleset): ...
1082
+
1083
+
1084
# Symbolic index standing for "every index" of an array; the `any` rule indexes with it
# to reason about all of an array's values at once
ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
1085
+
1086
+
1087
class NDArray(Expr, ruleset=array_api_ruleset):
    """
    An n-dimensional array following the Python array API standard.

    An array built with the constructor is described by its ``shape``, its ``dtype`` and a
    function from index tuples to scalar ``Value``s; most methods here are declarations
    whose meaning is given by rewrite rules.
    """

    def __init__(self, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...

    NEVER: ClassVar[NDArray]

    @method(cost=200)
    @classmethod
    def var(cls, name: StringLike) -> NDArray: ...

    @method(preserve=True)
    def __array_namespace__(self, api_version: object = None) -> ModuleType:
        # This module itself is the array API namespace
        return sys.modules[__name__]

    @property
    def ndim(self) -> Int: ...

    @property
    def dtype(self) -> DType: ...

    @property
    def device(self) -> Device: ...

    @property
    def shape(self) -> TupleInt: ...

    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.to_value().to_bool.eval()

    @property
    def size(self) -> Int: ...

    @method(preserve=True)
    def __len__(self) -> int:
        """
        Length of the first axis, like ``len`` on a NumPy array.
        """
        # Use shape[0], not size: the total element count over-counts for ndim > 1,
        # and __iter__ relies on this to know how many first-axis sub-arrays to yield.
        return self.shape[Int(0)].eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        """
        Iterate over the sub-arrays along the first axis.
        """
        for i in range(len(self)):
            yield self[IndexKey.int(Int(i))]

    def __getitem__(self, key: IndexKeyLike) -> NDArray: ...

    def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...

    # Comparisons return arrays, not Python bools
    def __lt__(self, other: NDArrayLike) -> NDArray: ...

    def __le__(self, other: NDArrayLike) -> NDArray: ...

    def __eq__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
        ...

    # TODO: Add support for overloaded __ne__
    # def __ne__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
    #     ...

    def __gt__(self, other: NDArrayLike) -> NDArray: ...

    def __ge__(self, other: NDArrayLike) -> NDArray: ...

    def __add__(self, other: NDArrayLike) -> NDArray: ...

    def __sub__(self, other: NDArrayLike) -> NDArray: ...

    def __mul__(self, other: NDArrayLike) -> NDArray: ...

    def __matmul__(self, other: NDArrayLike) -> NDArray: ...

    def __truediv__(self, other: NDArrayLike) -> NDArray: ...

    def __floordiv__(self, other: NDArrayLike) -> NDArray: ...

    def __mod__(self, other: NDArrayLike) -> NDArray: ...

    def __divmod__(self, other: NDArrayLike) -> NDArray: ...

    def __pow__(self, other: NDArrayLike) -> NDArray: ...

    def __lshift__(self, other: NDArrayLike) -> NDArray: ...

    def __rshift__(self, other: NDArrayLike) -> NDArray: ...

    def __and__(self, other: NDArrayLike) -> NDArray: ...

    def __xor__(self, other: NDArrayLike) -> NDArray: ...

    def __or__(self, other: NDArrayLike) -> NDArray: ...

    def __radd__(self, other: NDArray) -> NDArray: ...

    def __rsub__(self, other: NDArray) -> NDArray: ...

    def __rmul__(self, other: NDArray) -> NDArray: ...

    def __rmatmul__(self, other: NDArray) -> NDArray: ...

    def __rtruediv__(self, other: NDArray) -> NDArray: ...

    def __rfloordiv__(self, other: NDArray) -> NDArray: ...

    def __rmod__(self, other: NDArray) -> NDArray: ...

    def __rpow__(self, other: NDArray) -> NDArray: ...

    def __rlshift__(self, other: NDArray) -> NDArray: ...

    def __rrshift__(self, other: NDArray) -> NDArray: ...

    def __rand__(self, other: NDArray) -> NDArray: ...

    def __rxor__(self, other: NDArray) -> NDArray: ...

    def __ror__(self, other: NDArray) -> NDArray: ...

    @classmethod
    def scalar(cls, value: Value) -> NDArray:
        # 0-d array: empty shape, and every index maps to `value`
        return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)

    def to_value(self) -> Value:
        """
        Returns the value if this is a scalar.
        """

    def to_values(self) -> TupleValue:
        """
        Returns the values if this is a vector.
        """

    @property
    def T(self) -> NDArray:
        """
        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
        """

    @classmethod
    def vector(cls, values: TupleValueLike) -> NDArray: ...

    def index(self, indices: TupleIntLike) -> Value:
        """
        Return the value at the given indices.
        """

    @classmethod
    def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
1231
+
1232
+
1233
# Anything that can be converted into an NDArray
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike

# An NDArray can be used as an index key (masked indexing)
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
# A scalar Value becomes a 0-d array
converter(Value, NDArray, lambda v: NDArray.scalar(v))
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
# to prefer upcasting in the other direction when we can, which is safer at runtime
converter(NDArray, Value, lambda n: n.to_value(), 100)
# A tuple of Values becomes a 1-d array
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
# The TupleInt -> TupleValue converter is registered once, right after TupleValue is defined;
# it is not re-registered here.
1242
+
1243
+
1244
@array_api_ruleset.register
def _ndarray(
    x: NDArray,
    x1: NDArray,
    b: Boolean,
    f: Float,
    fi1: f64,
    fi2: f64,
    shape: TupleInt,
    dtype: DType,
    idx_fn: Callable[[TupleInt], Value],
    idx: TupleInt,
    tv: TupleValue,
):
    """
    Core rules for arrays built as ``NDArray(shape, dtype, idx_fn)`` plus simple scalar arithmetic.
    """
    functional = NDArray(shape, dtype, idx_fn)
    # Accessors of the functional form
    yield rewrite(functional.shape).to(shape)
    yield rewrite(functional.dtype).to(dtype)
    yield rewrite(functional.index(idx), subsume=True).to(idx_fn(idx))
    yield rewrite(x.ndim).to(x.shape.length())
    # Converting to a value reads the element at the empty index (requires a 0-d array)
    yield rewrite(x.to_value()).to(x.index(TupleInt.EMPTY))
    yield rewrite(NDArray.vector(tv).to_values()).to(tv)
    # Scalar float arithmetic
    # TODO: Push these down to float
    # NOTE(review): f / f -> 1.0 and f - f -> 0.0 do not hold for 0, inf or nan; confirm this is acceptable.
    scalar_f = NDArray.scalar(Value.float(f))
    yield rewrite(scalar_f / scalar_f).to(NDArray.scalar(Value.float(Float(1.0))))
    yield rewrite(scalar_f - scalar_f).to(NDArray.scalar(Value.float(Float(0.0))))
    lhs = NDArray.scalar(Value.float(Float(fi1)))
    rhs = NDArray.scalar(Value.float(Float(fi2)))
    yield rewrite(lhs > rhs).to(NDArray.scalar(Value.bool(TRUE)), fi1 > fi2)
    yield rewrite(lhs > rhs).to(NDArray.scalar(Value.bool(FALSE)), fi1 <= fi2)
    # Transpose of transpose is the original array
    yield rewrite(x.T.T).to(x)
    # if_ with a known condition
    yield rewrite(NDArray.if_(TRUE, x, x1)).to(x)
    yield rewrite(NDArray.if_(FALSE, x, x1)).to(x1)
1286
+
1287
+
1288
class TupleNDArray(Expr, ruleset=array_api_ruleset):
    """
    A tuple of arrays, e.g. the multiple results of ``unique_counts`` or ``svd``.

    Like ``TupleValue``, it has a functional form (length + index function), a cons form
    (``EMPTY`` + ``append``) and a Vec-backed form.
    """

    EMPTY: ClassVar[TupleNDArray]

    # Functional form: `idx_fn(i)` gives the array at index `i`
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> None: ...

    def append(self, i: NDArrayLike) -> TupleNDArray: ...

    @classmethod
    def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...

    def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
        """
        Concatenate: indices below ``self.length()`` read from ``self``, the rest from ``other``.
        """
        other = cast("TupleNDArray", other)
        return TupleNDArray(
            self.length() + other.length(),
            lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> NDArray: ...

    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        # Evaluates the whole tuple first, then iterates the resulting Python tuple
        return iter(self.eval())

    @property
    def to_vec(self) -> Vec[NDArray]: ...

    @method(preserve=True)
    def eval(self) -> tuple[NDArray, ...]:
        # Extracts `to_vec` from the current e-graph via try_evaling with array_api_schedule
        # (presumably running the schedule until it is known — see try_evaling)
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
1323
+
1324
+
1325
# Allow egglog Vecs of arrays wherever a TupleNDArray is expected
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))

# Anything that can be converted into a TupleNDArray
TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]
1328
+
1329
+
1330
@array_api_ruleset.register
def _tuple_ndarray(
    length: Int,
    idx_fn: Callable[[Int], NDArray],
    k: i64,
    idx: Int,
    vs: Vec[NDArray],
    v: NDArray,
    v1: NDArray,
    tv: TupleNDArray,
    tv1: TupleNDArray,
    b: Boolean,
):
    """
    Rules relating the functional, cons (EMPTY/append) and Vec forms of ``TupleNDArray``.
    """
    functional = TupleNDArray(length, idx_fn)
    appended = tv.append(v)
    return [
        # Record the Vec form so `to_vec` (and therefore `eval`) can extract it
        rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs)),
        # Access on the functional form
        rewrite(functional.length()).to(length),
        rewrite(functional[idx]).to(idx_fn(check_index(idx, length))),
        # Access on the cons form
        rewrite(TupleNDArray.EMPTY.length()).to(Int(0)),
        rewrite(TupleNDArray.EMPTY[idx]).to(NDArray.NEVER),
        rewrite(appended.length()).to(tv.length() + 1),
        rewrite(appended[idx]).to(NDArray.if_(idx == tv.length(), v, tv[idx])),
        # Functional -> cons, once the length is a literal
        rewrite(TupleNDArray(0, idx_fn), subsume=True).to(TupleNDArray.EMPTY),
        rewrite(TupleNDArray(Int(k), idx_fn), subsume=True).to(
            TupleNDArray(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
        ),
        # Cons -> Vec
        rewrite(TupleNDArray.EMPTY).to(TupleNDArray.from_vec(Vec[NDArray]())),
        rewrite(TupleNDArray.from_vec(vs).append(v)).to(TupleNDArray.from_vec(vs.append(Vec(v)))),
        # Two equal appends have equal prefixes and equal last elements
        rule(eq(appended).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1)),
    ]
1364
+
1365
+
1366
class OptionalBool(Expr, ruleset=array_api_ruleset):
    """
    An optional Boolean: either ``none`` or ``some(value)``.
    """

    none: ClassVar[OptionalBool]

    @classmethod
    def some(cls, value: Boolean) -> OptionalBool: ...
1371
+
1372
+
1373
+ converter(type(None), OptionalBool, lambda _: OptionalBool.none)
1374
+ converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
1375
+
1376
+
1377
class OptionalDType(Expr, ruleset=array_api_ruleset):
    """
    An optional DType: either ``none`` or ``some(value)``.
    """

    none: ClassVar[OptionalDType]

    @classmethod
    def some(cls, value: DType) -> OptionalDType: ...
1382
+
1383
+
1384
+ converter(type(None), OptionalDType, lambda _: OptionalDType.none)
1385
+ converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
1386
+
1387
+
1388
class OptionalDevice(Expr, ruleset=array_api_ruleset):
    """
    An optional Device: either ``none`` or ``some(value)``.
    """

    none: ClassVar[OptionalDevice]

    @classmethod
    def some(cls, value: Device) -> OptionalDevice: ...
1393
+
1394
+
1395
+ converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
1396
+ converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
1397
+
1398
+
1399
class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
    """
    An optional TupleInt: either ``none`` or ``some(value)``.
    """

    none: ClassVar[OptionalTupleInt]

    @classmethod
    def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...
1404
+
1405
+
1406
+ converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
1407
+ converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
1408
+
1409
+
1410
class IntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    A single Int or a TupleInt (or ``none``), as accepted by ``axis``-style arguments.
    """

    none: ClassVar[IntOrTuple]

    @classmethod
    def int(cls, value: Int) -> IntOrTuple: ...

    @classmethod
    def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
1418
+
1419
+
1420
+ converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
1421
+ converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
1422
+
1423
+
1424
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    An optional IntOrTuple: either ``none`` or ``some(value)``; used for ``axis`` arguments.
    """

    none: ClassVar[OptionalIntOrTuple]

    @classmethod
    def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
1429
+
1430
+
1431
+ converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
1432
+ converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
1433
+
1434
+
1435
# Array API `asarray`; dtype, copy and device are optional and default to `none`
@function
def asarray(
    a: NDArray,
    dtype: OptionalDType = OptionalDType.none,
    copy: OptionalBool = OptionalBool.none,
    device: OptionalDevice = OptionalDevice.none,
) -> NDArray: ...
1442
+
1443
+
1444
@array_api_ruleset.register
def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool, od: OptionalDevice):
    """
    Rules for ``asarray``.
    """
    # asarray never changes the number of dimensions, whatever dtype, copy or device is requested
    # (previously only the default device was matched)
    yield rewrite(asarray(a, d, ob, od).ndim).to(a.ndim)
    # With every option left at its default, asarray of an NDArray is the array itself
    yield rewrite(asarray(a)).to(a)
1448
+
1449
+
1450
# Array API `isfinite` (declaration only)
@function
def isfinite(x: NDArray) -> NDArray: ...
1452
+
1453
+
1454
# NOTE: intentionally shadows the builtin `sum` inside this module, as required by the array API namespace
@function
def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
    """
1459
+
1460
+
1461
@array_api_ruleset.register
def _sum(x: NDArray, y: NDArray, v: Value, dtype: DType):
    """
    Rules for ``sum`` over all axes.
    """
    divisor = NDArray.scalar(v)
    # Dividing every element by the same scalar can instead be done once, after summing
    yield rewrite(sum(x / divisor)).to(sum(x) / divisor)
    # TODO: add a rule for the sum of a 0-D array
1467
+
1468
+
1469
# Array API `reshape` (declaration only; the rules for it below are commented out)
@function
def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray: ...
1471
+
1472
+
1473
+ # @function
1474
+ # def reshape_transform_index(original_shape: TupleInt, shape: TupleInt, index: TupleInt) -> TupleInt:
1475
+ # """
1476
+ # Transforms an indexing operation on a reshaped array to an indexing operation on the original array.
1477
+ # """
1478
+ # ...
1479
+
1480
+
1481
+ # @function
1482
+ # def reshape_transform_shape(original_shape: TupleInt, shape: TupleInt) -> TupleInt:
1483
+ # """
1484
+ # Transforms the shape of an array to one that is reshaped, by replacing -1 with the correct value.
1485
+ # """
1486
+ # ...
1487
+
1488
+
1489
+ # @array_api_ruleset.register
1490
+ # def _reshape(
1491
+ # x: NDArray,
1492
+ # y: NDArray,
1493
+ # shape: TupleInt,
1494
+ # copy: OptionalBool,
1495
+ # i: Int,
1496
+ # s: String,
1497
+ # ix: TupleInt,
1498
+ # ):
1499
+ # return [
1500
+ # # dtype of result is same as input
1501
+ # rewrite(reshape(x, shape, copy).dtype).to(x.dtype),
1502
+ # # Indexing into a reshaped array is the same as indexing into the original array with a transformed index
1503
+ # rewrite(reshape(x, shape, copy).index(ix)).to(x.index(reshape_transform_index(x.shape, shape, ix))),
1504
+ # rewrite(reshape(x, shape, copy).shape).to(reshape_transform_shape(x.shape, shape)),
1505
+ # # reshape_transform_shape recursively
1506
+ # # TODO: handle all cases
1507
+ # rewrite(reshape_transform_shape(TupleInt(i), TupleInt(Int(-1)))).to(TupleInt(i)),
1508
+ # ]
1509
+
1510
+
1511
@function
def unique_values(x: NDArrayLike) -> NDArray:
    """
    Returns the unique elements of an input array x flattened with arbitrary ordering.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_values.html
    """
1518
+
1519
+
1520
@array_api_ruleset.register
def _unique_values(x: NDArray):
    """
    Rules for ``unique_values``.
    """
    once = unique_values(x)
    # unique_values is idempotent
    yield rewrite(unique_values(once)).to(once)
1525
+
1526
+
1527
# Array API `concat`; `axis` defaults to none
@function
def concat(arrays: TupleNDArrayLike, axis: OptionalInt = OptionalInt.none) -> NDArray: ...
1529
+
1530
+
1531
@array_api_ruleset.register
def _concat(x: NDArray):
    """
    Rules for ``concat``.
    """
    # Only the no-op case is supported for now: concatenating a single array yields that array
    single = TupleNDArray.EMPTY.append(x)
    yield rewrite(concat(single)).to(x)
1537
+
1538
+
1539
# Array API `astype`: cast an array to `dtype`
@function
def astype(x: NDArray, dtype: DType) -> NDArray: ...
1541
+
1542
+
1543
@array_api_ruleset.register
def _astype(x: NDArray, dtype: DType, i: i64):
    """
    Rules for ``astype``.
    """
    int_scalar = NDArray.scalar(Value.int(Int(i)))
    as_float_scalar = NDArray.scalar(Value.float(Float(f64.from_i64(i))))
    # The result has the requested dtype
    yield rewrite(astype(x, dtype).dtype).to(dtype)
    # Casting an int scalar literal to float64 is done directly on the literal
    yield rewrite(astype(int_scalar, float64)).to(as_float_scalar)
1551
+
1552
+
1553
@function
def unique_counts(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array x and the corresponding counts for each unique element in x.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
    """
1561
+
1562
+
1563
@array_api_ruleset.register
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType):
    """
    Rules for ``unique_counts``.
    """
    pair = unique_counts(x)
    counts = pair[Int(1)]
    total = NDArray.scalar(Value.int(x.size))
    return [
        # unique_counts always produces two arrays (values, counts)
        rewrite(pair).to(TupleNDArray(2, pair.__getitem__)),
        # The counts of all unique elements add up to the size of the array
        rewrite(sum(counts)).to(total),
        # Same, with an astype between the counts and the sum
        # TODO: Replace
        rewrite(sum(astype(counts, dtype))).to(astype(total, dtype)),
    ]
1574
+
1575
+
1576
# Array API `square` (declaration only)
@function
def square(x: NDArray) -> NDArray: ...
1578
+
1579
+
1580
# Array API `any`. NOTE: shadows the builtin `any` inside this module
@function
def any(x: NDArray) -> NDArray: ...
1582
+
1583
+
1584
# Array API `abs`. NOTE: shadows the builtin `abs` inside this module; the explicit egg_fn
# presumably keeps its egglog name distinct from other `abs` functions
@function(egg_fn="ndarray-abs")
def abs(x: NDArray) -> NDArray: ...
1586
+
1587
+
1588
# Array API `log`; explicit egglog name like `abs` above
@function(egg_fn="ndarray-log")
def log(x: NDArray) -> NDArray: ...
1590
+
1591
+
1592
@array_api_ruleset.register
def _abs(f: Float):
    """
    Rules for ``abs``.
    """
    # abs of a float scalar is computed directly on the Float
    float_scalar = NDArray.scalar(Value.float(f))
    yield rewrite(abs(float_scalar)).to(NDArray.scalar(Value.float(f.abs())))
1597
+
1598
+
1599
@function
def unique_inverse(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
    """
1606
+
1607
+
1608
@array_api_ruleset.register
def _unique_inverse(x: NDArray, i: Int):
    """
    Rules for ``unique_inverse``.
    """
    pair = unique_inverse(x)
    # unique_inverse always produces two arrays (values, inverse indices)
    yield rewrite(pair).to(TupleNDArray(2, pair.__getitem__))
    # Its first element is exactly unique_values(x)
    yield rewrite(pair[Int(0)]).to(unique_values(x))
1616
+
1617
+
1618
# Array API `zeros`; dtype and device are optional
@function
def zeros(
    shape: TupleIntLike, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none
) -> NDArray: ...
1622
+
1623
+
1624
# Array API `expand_dims`; `axis` defaults to 0
@function
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...
1626
+
1627
+
1628
# Array API `mean`; `axis` defaults to none and `keepdims` to FALSE
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...
1630
+
1631
+
1632
# TODO: Possibly change names to include modules.
# Array API `sqrt` on arrays; explicit egglog name, like `abs` and `log`
@function(egg_fn="ndarray-sqrt")
def sqrt(x: NDArray) -> NDArray: ...
1635
+
1636
+
1637
# Array API `std`; `axis` defaults to none
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...
1639
+
1640
+
1641
# Array API `real` (declaration only)
@function
def real(x: NDArray) -> NDArray: ...
1643
+
1644
+
1645
# Array API `conj` (declaration only)
@function
def conj(x: NDArray) -> NDArray: ...
1647
+
1648
+
1649
# `linalg` is this module itself, so `xp.linalg.svd` (with `xp` from `__array_namespace__`)
# resolves to `svd` defined below
linalg = sys.modules[__name__]
1650
+
1651
+
1652
@function
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
    """
    https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html

    Returns a tuple of three arrays (see the rule in ``_linalg``).
    """
1657
+
1658
+
1659
@array_api_ruleset.register
def _linalg(x: NDArray, full_matrices: Boolean):
    """
    Rules for the linalg extension.
    """
    decomposition = svd(x, full_matrices)
    # svd always produces three arrays (in the array API: U, S and Vh)
    yield rewrite(decomposition).to(TupleNDArray(3, decomposition.__getitem__))
1665
+
1666
+
1667
@function(ruleset=array_api_ruleset)
def ndindex(shape: TupleIntLike) -> TupleTupleInt:
    """
    https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html

    All index tuples for ``shape``: the product of ``TupleInt.range(n)`` over each dimension ``n``.
    """
    dims = cast("TupleInt", shape)
    per_axis_ranges = dims.map_tuple_int(TupleInt.range)
    return per_axis_ranges.product()
1674
+
1675
+
1676
+ ##
1677
+ # Interval analysis
1678
+ #
1679
+ # to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0))))).bool()`
1680
+ ##
1681
+
1682
# Relation marking scalar Values known to be strictly greater than zero (used by the interval analysis rules)
greater_zero = relation("greater_zero", Value)
1683
+
1684
+
1685
+ # @function
1686
+ # def ndarray_all_greater_0(x: NDArray) -> Unit:
1687
+ # ...
1688
+
1689
+
1690
+ # @function
1691
+ # def ndarray_all_false(x: NDArray) -> Unit:
1692
+ # ...
1693
+
1694
+
1695
+ # @function
1696
+ # def ndarray_all_true(x: NDArray) -> Unit:
1697
+ # ...
1698
+
1699
+
1700
+ # any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).to_bool()
1701
+
1702
+ # sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.int(Int(150))))
1703
+ # And also
1704
+
1705
+ # def
1706
+
1707
+
1708
@function
def broadcast_index(from_shape: TupleIntLike, to_shape: TupleIntLike, index: TupleIntLike) -> TupleInt:
    """
    Returns the index in the original array of the given index in the broadcasted array.

    ``from_shape`` is the original array's shape, ``to_shape`` the broadcast shape, and ``index``
    an index into the broadcast array.
    """
1713
+
1714
+
1715
@function
def broadcast_shapes(shape1: TupleIntLike, shape2: TupleIntLike) -> TupleInt:
    """
    Returns the shape of the broadcasted array.

    This is the shape that results from broadcasting arrays of ``shape1`` and ``shape2`` together.
    """
1720
+
1721
+
1722
@array_api_ruleset.register
def _interval_analaysis(
    x: NDArray,
    y: NDArray,
    z: NDArray,
    dtype: DType,
    f: f64,
    i: i64,
    b: Boolean,
    idx: TupleInt,
    v: Value,
    v1: Value,
    v2: Value,
    float_: Float,
    int_: Int,
):
    """
    Interval-analysis rules, enough to show that unique-count frequencies are never negative.

    Elementwise operations are pushed down to scalar ``Value``s by broadcasting the index, and
    the ``greater_zero`` relation propagates positivity through those scalar operations.
    """
    res_shape = broadcast_shapes(x.shape, y.shape)
    x_value = x.index(broadcast_index(x.shape, res_shape, idx))
    y_value = y.index(broadcast_index(y.shape, res_shape, idx))
    return [
        # Calling any on an array gives back a scalar, which is true if any of the values are truthy
        rewrite(any(x)).to(
            NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE))))
        ),
        # Indexing x < y is the same as broadcasting the index, indexing both and then comparing
        rewrite((x < y).index(idx)).to(x_value < y_value),
        # Same for x / y
        rewrite((x / y).index(idx)).to(x_value / y_value),
        # Indexing a scalar is the same as the scalar
        rewrite(NDArray.scalar(v).index(idx)).to(v),
        # Indexing of astype is the same as astype of indexing
        rewrite(astype(x, dtype).index(idx)).to(x.index(idx).astype(dtype)),
        # Every element of the counts array returned by unique_counts is positive
        rule(
            eq(v).to(unique_counts(x)[Int(1)].index(idx)),
        ).then(greater_zero(v)),
        # Casting to another dtype preserves positivity.
        # (This rule previously appeared twice, the second copy only reordering its facts.)
        rule(
            greater_zero(v),
            eq(v1).to(v.astype(dtype)),
        ).then(
            greater_zero(v1),
        ),
        # Positive float and int literals are positive
        rule(eq(v).to(Value.float(Float(f))), f > 0.0).then(greater_zero(v)),
        rule(eq(v).to(Value.int(Int(i))), i > 0).then(greater_zero(v)),
        # If we have division of v by v1 and both are greater than zero, the result is greater than zero
        rule(
            greater_zero(v),
            greater_zero(v1),
            eq(v2).to(v / v1),
        ).then(
            greater_zero(v2),
        ),
        # Define v < 0 to be false if greater_zero(v)
        rule(
            greater_zero(v),
            eq(v1).to(v < Value.int(Int(0))),
        ).then(
            union(v1).with_(Value.bool(FALSE)),
        ),
        # The only possible value of a bool literal is itself
        rewrite(possible_values(Value.bool(b))).to(TupleValue.EMPTY.append(Value.bool(b))),
    ]
1795
+
1796
+
1797
+ ##
1798
+ # Arrays are described mathematically by:
1799
+ # 1. A shape `.shape`
1800
+ # 2. A dtype `.dtype`
1801
+ # 3. A mapping from indices to values `x.index(idx)`
1802
+ #
1803
+ # For all operations that are supported mathematically, define each of the above.
1804
+ ##
1805
+
1806
+
1807
def _demand_shape(compound: NDArray, inner: NDArray) -> Command:
    """
    Build a rule that, once ``compound`` is in the e-graph, demands ``inner.shape`` and the
    length of that shape so other rules can match on them.
    """
    anchor = var("__a", NDArray)
    inner_shape = inner.shape
    return rule(eq(anchor).to(compound)).then(inner_shape, inner_shape.length())
1810
+
1811
+
1812
@array_api_ruleset.register
def _scalar_math(v: Value, vs: TupleValue, i: Int):
    """Shape, dtype and index rules for zero-dimensional arrays built by ``NDArray.scalar``."""
    scalar = NDArray.scalar(v)
    # A scalar array has an empty shape ...
    yield rewrite(scalar.shape).to(TupleInt.EMPTY)
    # ... takes its dtype from the wrapped value ...
    yield rewrite(scalar.dtype).to(v.dtype)
    # ... and indexing it with the empty tuple gives back the wrapped value.
    yield rewrite(scalar.index(TupleInt.EMPTY)).to(v)
1817
+
1818
+
1819
@array_api_ruleset.register
def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
    """Shape, dtype and index rules for one-dimensional arrays built by ``NDArray.vector``."""
    vec = NDArray.vector(vs)
    first_value = vs[Int(0)]
    # Shape is a single entry: the number of values
    yield rewrite(vec.shape).to(TupleInt.single(vs.length()))
    # The dtype is taken from the first value
    yield rewrite(vec.dtype).to(first_value.dtype)
    # Indexing uses only the first component of the index tuple
    yield rewrite(vec.index(ti)).to(vs[ti[0]])
1824
+
1825
+
1826
@array_api_ruleset.register
def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
    """Demand rules and the vector-to-vector simplification for ``reshape``."""
    reshaped = reshape(x, shape, copy)

    yield _demand_shape(reshaped, x)
    # Demand the target shape's length and first entry so the rewrite below can match on them
    yield rule(reshaped).then(shape.length(), shape[0])

    # Reshaping a 1-d array to ``(-1,)`` gives back the array itself
    source_is_vector = eq(x.shape.length()).to(Int(1))
    target_is_vector = eq(shape.length()).to(Int(1))
    target_is_inferred = eq(shape[0]).to(Int(-1))
    yield rewrite(reshaped).to(x, source_is_vector, target_is_vector, target_is_inferred)
1841
+
1842
+
1843
@array_api_ruleset.register
def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
    """Rewrite ``__getitem__`` with an integer key into an ``NDArray.index`` lookup."""
    # x[i] becomes a scalar array wrapping the value at index (i,)
    single_index = TupleInt.single(i)
    yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(single_index)))
    # TODO: Multi index rewrite as well if all are ints
1848
+
1849
+
1850
+ ##
1851
+ # Assumptions
1852
+ ##
1853
+
1854
+
1855
@function(mutates_first_arg=True)
def assume_dtype(x: NDArray, dtype: DType) -> None:
    """
    Asserts that ``x`` has dtype ``dtype``.

    Mutates ``x`` in place; its shape and indexed values stay those of the original array.
    """
1860
+
1861
+
1862
@array_api_ruleset.register
def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
    """Rules for ``assume_dtype``: the dtype is overridden, shape and values pass through."""
    # Snapshot ``x`` before ``assume_dtype`` mutates it, so the original array can be referenced
    before = copy(x)
    assume_dtype(x, dtype)
    # The assumed dtype replaces whatever dtype x had ...
    yield rewrite(x.dtype).to(dtype)
    # ... while shape and every indexed value come from the original array
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.index(idx)).to(before.index(idx))
1869
+
1870
+
1871
@function(mutates_first_arg=True)
def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
    """
    Asserts that ``x`` has shape ``shape``.

    Mutates ``x`` in place; its dtype and indexed values stay those of the original array.
    """
1876
+
1877
+
1878
@array_api_ruleset.register
def _assume_shape(x: NDArray, shape: TupleInt, idx: TupleInt):
    """Rules for ``assume_shape``: the shape is overridden, dtype and values pass through."""
    # Snapshot ``x`` before ``assume_shape`` mutates it, so the original array can be referenced
    before = copy(x)
    assume_shape(x, shape)
    # The assumed shape replaces whatever shape x had ...
    yield rewrite(x.shape).to(shape)
    # ... while dtype and every indexed value come from the original array
    yield rewrite(x.dtype).to(before.dtype)
    yield rewrite(x.index(idx)).to(before.index(idx))
1885
+
1886
+
1887
@function(mutates_first_arg=True)
def assume_isfinite(x: NDArray) -> None:
    """
    Asserts that the values of ``x`` are non null and not infinite.

    Mutates ``x`` in place; shape, dtype and values are unchanged, but every indexed value
    is then known to be finite.
    """
1892
+
1893
+
1894
@array_api_ruleset.register
def _isfinite(x: NDArray, ti: TupleInt):
    """Rules for ``assume_isfinite``: everything passes through, and indexed values are finite."""
    # Snapshot ``x`` before ``assume_isfinite`` mutates it
    before = copy(x)
    assume_isfinite(x)

    # Shape, dtype and indexed values are those of the original array ...
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.dtype).to(before.dtype)
    element = x.index(ti)
    yield rewrite(element).to(before.index(ti))
    # ... but every indexed value is known to be finite
    yield rewrite(element.isfinite()).to(TRUE)
1905
+
1906
+
1907
@function(mutates_first_arg=True)
def assume_value_one_of(x: NDArray, values: TupleValueLike) -> None:
    """
    Asserts that every value of ``x`` is one of the values in ``values``.

    Mutates ``x`` in place; shape, dtype and indexed values are unchanged.
    """
1912
+
1913
+
1914
@array_api_ruleset.register
def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt):
    """Rules for ``assume_value_one_of``: values pass through and get ``vs`` as possible values."""
    # Snapshot ``x`` before ``assume_value_one_of`` mutates it
    before = copy(x)
    assume_value_one_of(x, vs)
    # Shape and dtype come from the original array
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.dtype).to(before.dtype)
    # Each indexed value equals the original one, and its possible values are exactly ``vs``
    element_is_v = eq(v).to(x.index(idx))
    yield rule(element_is_v).then(
        union(v).with_(before.index(idx)),
        union(possible_values(v)).with_(vs),
    )
1926
+
1927
+
1928
@array_api_ruleset.register
def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean):
    """
    ``isfinite`` rules for scalar values and for ``sum`` of an array.

    ``x`` and ``xs`` are not referenced; they are kept so the signature stays unchanged.
    """
    # Integers and booleans are always finite
    yield rewrite(Value.int(i).isfinite()).to(TRUE)
    yield rewrite(Value.bool(b).isfinite()).to(TRUE)
    # A float literal is finite only if it is neither NaN nor +/- infinity
    yield rewrite(Value.float(Float(f)).isfinite()).to(
        TRUE,
        ne(f).to(f64(math.nan)),
        ne(f).to(f64(math.inf)),
        ne(f).to(f64(-math.inf)),
    )

    # a sum of an array is finite if all the values are finite (overflow is not considered)
    yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite())))
1936
+
1937
+
1938
@array_api_ruleset.register
def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
    """``unique_values`` of an array is the vector of all its possible values."""
    all_possible = possible_values(a.index(ALL_INDICES))
    yield rewrite(unique_values(x=a)).to(NDArray.vector(all_possible))
    # yield rewrite(
    #     possible_values(reshape(a.index(shape, copy), ALL_INDICES)),
    # ).to(possible_values(a.index(ALL_INDICES)))
1944
+
1945
+
1946
@array_api_ruleset.register
def _size(x: NDArray):
    """The size of an array is the product of its shape entries (``Int(1)`` for an empty shape)."""
    shape = x.shape
    yield rewrite(x.size).to(shape.foldl(Int.__mul__, Int(1)))
1949
+
1950
+
1951
# Separate ruleset so we can use it in program gen
@ruleset
def array_api_vec_to_cons_ruleset(
    vs: Vec[Int],
    vv: Vec[Value],
    vn: Vec[NDArray],
    vt: Vec[TupleInt],
):
    """
    Rewrite ``from_vec`` tuples into nested ``append`` calls.

    An empty vec becomes ``EMPTY``; a non-empty vec becomes ``from_vec`` of all but the last
    element with the last element appended.
    """
    zero = i64(0)

    ints_len = vs.length()
    ints_last = ints_len - 1
    yield rewrite(TupleInt.from_vec(vs)).to(TupleInt.EMPTY, eq(ints_len).to(zero))
    yield rewrite(TupleInt.from_vec(vs)).to(
        TupleInt.from_vec(vs.remove(ints_last)).append(vs[ints_last]), ne(ints_len).to(zero)
    )

    values_len = vv.length()
    values_last = values_len - 1
    yield rewrite(TupleValue.from_vec(vv)).to(TupleValue.EMPTY, eq(values_len).to(zero))
    yield rewrite(TupleValue.from_vec(vv)).to(
        TupleValue.from_vec(vv.remove(values_last)).append(vv[values_last]), ne(values_len).to(zero)
    )

    tuples_len = vt.length()
    tuples_last = tuples_len - 1
    yield rewrite(TupleTupleInt.from_vec(vt)).to(TupleTupleInt.EMPTY, eq(tuples_len).to(zero))
    yield rewrite(TupleTupleInt.from_vec(vt)).to(
        TupleTupleInt.from_vec(vt.remove(tuples_last)).append(vt[tuples_last]), ne(tuples_len).to(zero)
    )

    arrays_len = vn.length()
    arrays_last = arrays_len - 1
    yield rewrite(TupleNDArray.from_vec(vn)).to(TupleNDArray.EMPTY, eq(arrays_len).to(zero))
    yield rewrite(TupleNDArray.from_vec(vn)).to(
        TupleNDArray.from_vec(vn.remove(arrays_last)).append(vn[arrays_last]), ne(arrays_len).to(zero)
    )
1977
+
1978
+
1979
# All array API rules together, and a schedule that runs them until saturation
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
array_api_schedule = array_api_combined_ruleset.saturate()

# Egraph installed by `set_array_api_egraph`; None while no egraph is installed
_CURRENT_EGRAPH: EGraph | None = None
1983
+
1984
+
1985
@contextlib.contextmanager
def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
    """
    Context manager that installs ``egraph`` as the current array API egraph.

    The current egraph is reset to ``None`` on exit, including when the body raises, so a
    failed block does not leave a stale egraph installed.

    Raises:
        AssertionError: if an egraph is already installed (nesting is not supported).
    """
    global _CURRENT_EGRAPH
    assert _CURRENT_EGRAPH is None
    _CURRENT_EGRAPH = egraph
    try:
        yield
    finally:
        # Always clear, otherwise every later call would fail the assertion above
        _CURRENT_EGRAPH = None
1995
+
1996
+
1997
def _get_current_egraph() -> EGraph:
    """
    Return the egraph installed by ``set_array_api_egraph``, or a new ``EGraph`` if none is set.
    """
    # Compare with None explicitly instead of relying on truthiness, so an installed egraph
    # is always returned even if the object happens to evaluate as falsy.
    if _CURRENT_EGRAPH is not None:
        return _CURRENT_EGRAPH
    return EGraph()
1999
+
2000
+
2001
+ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
2002
+ """
2003
+ Try evaling the expression that will result in a primitive expression being fill.
2004
+ if it fails, display the egraph and raise an error.
2005
+ """
2006
+ try:
2007
+ extracted = egraph.extract(prim_expr)
2008
+ except EggSmolError:
2009
+ # If this primitive doesn't exist in the egraph, we need to try to create it by
2010
+ # registering the expression and running the schedule
2011
+ egraph.register(expr)
2012
+ egraph.run(schedule)
2013
+ try:
2014
+ extracted = egraph.extract(prim_expr)
2015
+ except BaseException as e:
2016
+ # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
2017
+ e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
2018
+ raise
2019
+ return extracted.value # type: ignore[attr-defined]