egglog 9.0.0__pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Files changed (44) hide show
  1. egglog/__init__.py +10 -0
  2. egglog/bindings.pyi +667 -0
  3. egglog/bindings.pypy311-pp73-aarch64-linux-gnu.so +0 -0
  4. egglog/builtins.py +1045 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +262 -0
  7. egglog/declarations.py +818 -0
  8. egglog/egraph.py +1909 -0
  9. egglog/egraph_state.py +634 -0
  10. egglog/examples/README.rst +5 -0
  11. egglog/examples/__init__.py +3 -0
  12. egglog/examples/bignum.py +31 -0
  13. egglog/examples/bool.py +38 -0
  14. egglog/examples/eqsat_basic.py +45 -0
  15. egglog/examples/fib.py +28 -0
  16. egglog/examples/higher_order_functions.py +45 -0
  17. egglog/examples/lambda_.py +288 -0
  18. egglog/examples/matrix.py +175 -0
  19. egglog/examples/multiset.py +61 -0
  20. egglog/examples/ndarrays.py +144 -0
  21. egglog/examples/resolution.py +84 -0
  22. egglog/examples/schedule_demo.py +34 -0
  23. egglog/exp/__init__.py +3 -0
  24. egglog/exp/array_api.py +1943 -0
  25. egglog/exp/array_api_jit.py +44 -0
  26. egglog/exp/array_api_loopnest.py +74 -0
  27. egglog/exp/array_api_numba.py +69 -0
  28. egglog/exp/array_api_program_gen.py +510 -0
  29. egglog/exp/program_gen.py +424 -0
  30. egglog/exp/siu_examples.py +32 -0
  31. egglog/functionalize.py +91 -0
  32. egglog/ipython_magic.py +41 -0
  33. egglog/pretty.py +510 -0
  34. egglog/py.typed +0 -0
  35. egglog/runtime.py +633 -0
  36. egglog/thunk.py +95 -0
  37. egglog/type_constraint_solver.py +113 -0
  38. egglog/visualizer.css +1 -0
  39. egglog/visualizer.js +35777 -0
  40. egglog/visualizer_widget.py +39 -0
  41. egglog-9.0.0.dist-info/METADATA +74 -0
  42. egglog-9.0.0.dist-info/RECORD +44 -0
  43. egglog-9.0.0.dist-info/WHEEL +4 -0
  44. egglog-9.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1943 @@
1
+ """
2
+
3
+
4
+ ## Lists
5
+
6
+ Lists have two main constructors:
7
+
8
+ - `List(length, idx_fn)`
9
+ - `List.EMPTY` / `initial.append(last)`
10
+
11
+ This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic
12
+ length that could not be resolved to an integer.
13
+
14
+ There are rewrites to convert between these constructors in both directions. The only limitation however is that
15
+ `length` has to be a real i64 in order to be converted to a cons list.
16
+
17
+ When you are writing a function that uses ints, feel free to use the `__getitem__` or `length()` methods or match
18
+ directly on `List()` constructor. If you can write your function using that interface please do. But for some other
19
+ methods where the resulting length/index function is dependent on the rest of it, you can only define it with a known
20
+ length, so you can then use the cons list constructors.
21
+
22
+ We also support creating lists from vectors. These can be converted one to one to the snoc list representation.
23
+
24
+ It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet.
25
+
26
+ We are guaranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use
27
+ the `.length` and `.__getitem__` methods, unless you want to depend on it having known length, in which
28
+ case you can match directly on the cons list.
29
+
30
+ To be a list, you must implement two methods:
31
+
32
+ * `l.length() -> Int`
33
+ * `l.__getitem__(i: Int) -> T`
34
+
35
+ There are three main types of constructors for lists which all implement these methods:
36
+
37
+ * Functional `List(length, idx_fn)`
38
+ * cons (well reversed cons) lists `List.EMPTY` and `l.append(x)`
39
+ * Vectors `List.from_vec(vec)`
40
+
41
+ Also all lists constructors must be converted to the functional representation, so that we can match on it
42
+ and convert lists with known lengths into cons lists and into vectors.
43
+
44
+ This is necessary so that known length lists are properly materialized during extraction.
45
+
46
+ Q: Why are they implemented as SNOC lists instead of CONS lists?
47
+ A: So that when converting from functional to lists we can use the same index function by starting at the end and folding
48
+ that way recursively.
49
+
50
+
51
+ """
52
+
53
+ # mypy: disable-error-code="empty-body"
54
+
55
+ from __future__ import annotations
56
+
57
+ import itertools
58
+ import math
59
+ import numbers
60
+ import os
61
+ import sys
62
+ from copy import copy
63
+ from types import EllipsisType
64
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
65
+
66
+ import numpy as np
67
+
68
+ from egglog import *
69
+ from egglog.runtime import RuntimeExpr
70
+
71
+ from .program_gen import *
72
+
73
+ if TYPE_CHECKING:
74
+ from collections.abc import Callable, Iterator
75
+ from types import ModuleType
76
+
77
# Pretend that exprs are numbers b/c sklearn does isinstance checks.
# ABC registration makes `isinstance(expr, numbers.Integral)` succeed for every RuntimeExpr.
numbers.Integral.register(RuntimeExpr)

# Set this to 1 before scipy is ever imported
# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support
# NOTE: this mutates the process environment as an import-time side effect of this module.
os.environ["SCIPY_ARRAY_API"] = "1"

# Shared ruleset: the Expr classes below are declared with `ruleset=array_api_ruleset` and the
# rule generators below are attached to it with `@array_api_ruleset.register`.
array_api_ruleset = ruleset(name="array_api_ruleset")
85
+
86
+
87
class Boolean(Expr, ruleset=array_api_ruleset):
    """
    Symbolic boolean expression.

    The operators are declared without bodies; their simplifications for the `TRUE` / `FALSE`
    constants are provided by the `_bool` rules registered on the same ruleset.
    """

    # Wrap a primitive boolean value.
    def __init__(self, value: BoolLike) -> None: ...

    # Python truthiness delegates to `eval()`.
    # NOTE(review): `preserve=True` presumably keeps this as an ordinary Python method instead of an
    # e-graph function -- confirm against the egglog `method` documentation.
    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.eval()

    # `try_evaling` and `array_api_schedule` are defined outside this part of the file; presumably this
    # runs the schedule and extracts the primitive stored in `to_bool` -- TODO confirm.
    @method(preserve=True)
    def eval(self) -> bool:
        return try_evaling(array_api_schedule, self, self.to_bool)

    # Primitive `Bool` for this expression; set by the `_bool` rule when the expression equals `Boolean(b)`.
    @property
    def to_bool(self) -> Bool: ...

    def __or__(self, other: BooleanLike) -> Boolean: ...

    def __and__(self, other: BooleanLike) -> Boolean: ...

    def __invert__(self) -> Boolean: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: BooleanLike) -> Boolean: ...  # type: ignore[override]
108
+
109
+
110
# Anything accepted where a Boolean is expected: a Boolean expression or a BoolLike value.
BooleanLike = Boolean | BoolLike

# Shared constants used throughout the rewrite rules in this module.
TRUE = Boolean(True)
FALSE = Boolean(False)
# Use the Boolean constructor as the conversion from the primitive Bool to Boolean.
converter(Bool, Boolean, Boolean)
115
+
116
+
117
@array_api_ruleset.register
def _bool(x: Boolean, i: Int, j: Int, b: Bool):
    """
    Rules for `Boolean`: primitive extraction, `|` / `&` / `~` on constants, and `==`.

    `i` and `j` are declared but not used by any rule here.
    """
    # Expose the wrapped primitive through `to_bool` whenever an expression equals `Boolean(b)`.
    yield rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b))

    # Disjunction with a constant on the left.
    yield rewrite(TRUE | x).to(TRUE)
    yield rewrite(FALSE | x).to(x)

    # Conjunction with a constant on the left.
    yield rewrite(TRUE & x).to(x)
    yield rewrite(FALSE & x).to(FALSE)

    # Negation of constants.
    yield rewrite(~TRUE).to(FALSE)
    yield rewrite(~FALSE).to(TRUE)

    # Consistency check: the two constants must never end up equal.
    yield rule(eq(FALSE).to(TRUE)).then(panic("False cannot equal True"))

    # Equality.
    yield rewrite(x == x).to(TRUE)  # noqa: PLR0124
    yield rewrite(FALSE == TRUE).to(FALSE)
    yield rewrite(TRUE == FALSE).to(FALSE)
132
+
133
+
134
class Int(Expr, ruleset=array_api_ruleset):
    """
    Symbolic integer expression.

    Operators are declared without bodies. The `_int` rules registered on the same ruleset fold a subset
    of them (comparisons, arithmetic and bitwise operators) when both operands are `Int(i64)` literals.
    """

    # A "never" int is a value that should not exist. It could represent, for example, the result of
    # indexing into an array with an index that is out of bounds.
    # https://en.wikipedia.org/wiki/Bottom_type
    NEVER: ClassVar[Int]

    # Named symbolic integer.
    @classmethod
    def var(cls, name: StringLike) -> Int: ...

    # Wrap a primitive i64 value.
    def __init__(self, value: i64Like) -> None: ...

    def __invert__(self) -> Int: ...

    # Comparisons return symbolic Boolean expressions, not Python bools.
    def __lt__(self, other: IntLike) -> Boolean: ...

    def __le__(self, other: IntLike) -> Boolean: ...

    def __eq__(self, other: IntLike) -> Boolean:  # type: ignore[override]
        ...

    # TODO: Fix this?
    # Make != always return a Bool, so that numpy.unique works on a tuple of ints
    # In _unique1d
    # `not` calls `Boolean.__bool__` on the symbolic result of `==`, which evaluates it.
    @method(preserve=True)
    def __ne__(self, other: Int) -> bool:  # type: ignore[override]
        return not (self == other)

    def __gt__(self, other: IntLike) -> Boolean: ...

    def __ge__(self, other: IntLike) -> Boolean: ...

    # Binary operators with this Int on the left.
    def __add__(self, other: IntLike) -> Int: ...

    def __sub__(self, other: IntLike) -> Int: ...

    def __mul__(self, other: IntLike) -> Int: ...

    def __truediv__(self, other: IntLike) -> Int: ...

    def __floordiv__(self, other: IntLike) -> Int: ...

    def __mod__(self, other: IntLike) -> Int: ...

    def __divmod__(self, other: IntLike) -> Int: ...

    def __pow__(self, other: IntLike) -> Int: ...

    def __lshift__(self, other: IntLike) -> Int: ...

    def __rshift__(self, other: IntLike) -> Int: ...

    def __and__(self, other: IntLike) -> Int: ...

    def __xor__(self, other: IntLike) -> Int: ...

    def __or__(self, other: IntLike) -> Int: ...

    # Reflected operators, used when the left operand is not an Int.
    def __radd__(self, other: IntLike) -> Int: ...

    def __rsub__(self, other: IntLike) -> Int: ...

    def __rmul__(self, other: IntLike) -> Int: ...

    def __rmatmul__(self, other: IntLike) -> Int: ...

    def __rtruediv__(self, other: IntLike) -> Int: ...

    def __rfloordiv__(self, other: IntLike) -> Int: ...

    def __rmod__(self, other: IntLike) -> Int: ...

    def __rpow__(self, other: IntLike) -> Int: ...

    def __rlshift__(self, other: IntLike) -> Int: ...

    def __rrshift__(self, other: IntLike) -> Int: ...

    def __rand__(self, other: IntLike) -> Int: ...

    def __rxor__(self, other: IntLike) -> Int: ...

    def __ror__(self, other: IntLike) -> Int: ...

    # Primitive i64 for this expression; set by the `_int` rule when the expression equals `Int(j)`.
    @property
    def to_i64(self) -> i64: ...

    # `try_evaling` and `array_api_schedule` are defined outside this part of the file; presumably this
    # runs the schedule and extracts the primitive stored in `to_i64` -- TODO confirm.
    @method(preserve=True)
    def eval(self) -> int:
        return try_evaling(array_api_schedule, self, self.to_i64)

    # The Python numeric protocol methods below all go through `eval()`.
    @method(preserve=True)
    def __index__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __int__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __float__(self) -> float:
        return float(self.eval())

    @method(preserve=True)
    def __bool__(self) -> bool:
        return bool(self.eval())

    # Symbolic conditional; `_int` rewrites it to `i` when `b` is TRUE and to `j` when `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ...
241
+
242
+
243
+ @array_api_ruleset.register
244
+ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
245
+ yield rewrite(Int(i) == Int(i)).to(TRUE)
246
+ yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))
247
+
248
+ yield rewrite(Int(i) >= Int(i)).to(TRUE)
249
+ yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
250
+ yield rule(eq(r).to(Int(i) >= Int(j)), i < j).then(union(r).with_(FALSE))
251
+
252
+ yield rewrite(Int(i) < Int(i)).to(FALSE)
253
+ yield rule(eq(r).to(Int(i) < Int(j)), i < j).then(union(r).with_(TRUE))
254
+ yield rule(eq(r).to(Int(i) < Int(j)), i > j).then(union(r).with_(FALSE))
255
+
256
+ yield rewrite(Int(i) > Int(i)).to(FALSE)
257
+ yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE))
258
+ yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE))
259
+
260
+ yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j))
261
+
262
+ yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints"))
263
+
264
+ yield rewrite(Int(i) + Int(j)).to(Int(i + j))
265
+ yield rewrite(Int(i) - Int(j)).to(Int(i - j))
266
+ yield rewrite(Int(i) * Int(j)).to(Int(i * j))
267
+ yield rewrite(Int(i) // Int(j)).to(Int(i / j))
268
+ yield rewrite(Int(i) % Int(j)).to(Int(i % j))
269
+ yield rewrite(Int(i) & Int(j)).to(Int(i & j))
270
+ yield rewrite(Int(i) | Int(j)).to(Int(i | j))
271
+ yield rewrite(Int(i) ^ Int(j)).to(Int(i ^ j))
272
+ yield rewrite(Int(i) << Int(j)).to(Int(i << j))
273
+ yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
274
+ yield rewrite(~Int(i)).to(Int(~i))
275
+
276
+ yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
277
+ yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
278
+
279
+ # Never cannot be equal to anything real
280
+ yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
281
+
282
+
283
# Wrap primitive i64 values in Int when an Int is expected.
converter(i64, Int, lambda x: Int(x))

# Anything accepted where an Int is expected: an Int expression or an i64Like value.
IntLike: TypeAlias = Int | i64Like
286
+
287
+
288
@function(ruleset=array_api_ruleset)
def check_index(length: IntLike, idx: IntLike) -> Int:
    """
    Bounds-check `idx` against `length`.

    Evaluates to `idx` when `0 <= idx < length`, and to `Int.NEVER` otherwise.
    """
    # The casts only narrow the static type; no value is changed at runtime.
    size = cast("Int", length)
    position = cast("Int", idx)
    within_bounds = (position >= 0) & (position < size)
    return Int.if_(within_bounds, position, Int.NEVER)
296
+
297
+
298
+ # @array_api_ruleset.register
299
+ # def _check_index(i: i64, j: i64, x: Int):
300
+ # yield rewrite(
301
+ # check_index(Int(i), Int(j)),
302
+ # ).to(
303
+ # Int(j),
304
+ # i >= 0,
305
+ # i < j,
306
+ # )
307
+
308
+ # yield rewrite(
309
+ # check_index(x, Int(i)),
310
+ # ).to(
311
+ # Int.NEVER,
312
+ # i < 0,
313
+ # )
314
+
315
+ # yield rewrite(
316
+ # check_index(Int(i), Int(j)),
317
+ # ).to(
318
+ # Int.NEVER,
319
+ # i >= j,
320
+ # )
321
+
322
+
323
class Float(Expr, ruleset=array_api_ruleset):
    """
    Symbolic floating point expression.

    A Float can be built from a primitive f64 (`Float(f)`), from a `BigRat` (`Float.rational`) or from an
    `Int` (`Float.from_int`). The `_float` rules fold arithmetic and equality on literals.
    """

    # Differentiate costs of three constructors so extraction is deterministic if all three are present
    @method(cost=3)
    def __init__(self, value: f64Like) -> None: ...

    # Primitive f64 for this expression; set by the `_float` rule when the expression equals `Float(f)`.
    @property
    def to_f64(self) -> f64: ...

    # `try_evaling` and `array_api_schedule` are defined outside this part of the file; presumably this
    # runs the schedule and extracts the primitive stored in `to_f64` -- TODO confirm.
    @method(preserve=True)
    def eval(self) -> float:
        return try_evaling(array_api_schedule, self, self.to_f64)

    def abs(self) -> Float: ...

    # Rational-valued constructor.
    @method(cost=2)
    @classmethod
    def rational(cls, r: BigRat) -> Float: ...

    # Integer-valued constructor; `_float` rewrites `from_int(Int(i))` to a rational.
    @classmethod
    def from_int(cls, i: IntLike) -> Float: ...

    def __truediv__(self, other: FloatLike) -> Float: ...

    def __mul__(self, other: FloatLike) -> Float: ...

    def __add__(self, other: FloatLike) -> Float: ...

    def __sub__(self, other: FloatLike) -> Float: ...

    def __pow__(self, other: FloatLike) -> Float: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
355
+
356
+
357
# Python floats are wrapped directly; Ints are converted through `Float.from_int`.
converter(float, Float, lambda x: Float(x))
converter(Int, Float, lambda x: Float.from_int(x))


# Anything accepted where a Float is expected.
FloatLike: TypeAlias = Float | float | IntLike
362
+
363
+
364
@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
    """
    Rules for `Float`: primitive extraction, `abs`, normalisation to rationals, and folding of
    arithmetic and equality on f64 and rational literals.
    """
    # Expose the wrapped primitive through `to_f64`.
    yield rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f))

    # abs() on f64 literals.
    yield rewrite(Float(f).abs()).to(Float(f), f >= 0.0)
    yield rewrite(Float(f).abs()).to(Float(-f), f < 0.0)

    # Convert from float to rational, if it's a whole number i.e. can be converted to int
    yield rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f))
    # always convert from int to rational
    yield rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1)))

    # Arithmetic on f64 literals.
    yield rewrite(Float(f) + Float(f2)).to(Float(f + f2))
    yield rewrite(Float(f) - Float(f2)).to(Float(f - f2))
    yield rewrite(Float(f) * Float(f2)).to(Float(f * f2))

    # Arithmetic on rational literals.
    yield rewrite(Float.rational(r) / Float.rational(r1)).to(Float.rational(r / r1))
    yield rewrite(Float.rational(r) + Float.rational(r1)).to(Float.rational(r + r1))
    yield rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1))
    yield rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1))

    # Power on f64 literals.
    yield rewrite(Float(f) ** Float(f2)).to(Float(f**f2))

    # ==
    yield rewrite(Float(f) == Float(f)).to(TRUE)
    yield rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2))
    yield rewrite(Float.rational(r) == Float.rational(r)).to(TRUE)
    yield rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1))
388
+
389
+
390
class TupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[int, ...]

    All constructors should be rewritten to the functional semantics in the __init__ method.

    Three representations are declared here: the functional form `TupleInt(length, idx_fn)`, the
    append-based form built from `EMPTY` and `append`, and the vector form `from_vec`. The `_tuple_int`
    rules define `length()` and `__getitem__` for the first two and convert between the representations.
    """

    # Named symbolic tuple.
    @classmethod
    def var(cls, name: StringLike) -> TupleInt: ...

    # Functional form: `idx_fn` maps an index to the element at that index.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

    EMPTY: ClassVar[TupleInt]
    NEVER: ClassVar[TupleInt]

    # Add `i` as the new last element.
    def append(self, i: IntLike) -> TupleInt: ...

    # One-element tuple whose index function ignores its argument.
    @classmethod
    def single(cls, i: Int) -> TupleInt:
        return TupleInt(Int(1), lambda _: i)

    # Tuple of `stop` elements whose element at index i is i itself.
    @method(subsume=True)
    @classmethod
    def range(cls, stop: IntLike) -> TupleInt:
        return TupleInt(stop, lambda i: i)

    @classmethod
    def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...

    # Concatenation: indices below len(self) read from self, the rest read from `other` shifted down.
    def __add__(self, other: TupleIntLike) -> TupleInt:
        other = cast("TupleInt", other)
        return TupleInt(
            self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> Int: ...

    # Python `len()` evaluates the symbolic length to a concrete int.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    # Python iteration evaluates the whole tuple first.
    @method(preserve=True)
    def __iter__(self) -> Iterator[Int]:
        return iter(self.eval())

    # Vector form of this tuple; set by the `_tuple_int` rule when the tuple equals `from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[Int]: ...

    # `try_evaling` and `array_api_schedule` are defined outside this part of the file; presumably this
    # runs the schedule and extracts the value stored in `to_vec` -- TODO confirm.
    @method(preserve=True)
    def eval(self) -> tuple[Int, ...]:
        return try_evaling(array_api_schedule, self, self.to_vec)

    # Left folds, one per accumulator type; defined by `_tuple_int` on the EMPTY/append form.
    def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
    def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
    def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: ...

    # TRUE if any element equals `i`: an OR-fold of the element-wise equality, starting from FALSE.
    @method(subsume=True)
    def contains(self, i: Int) -> Boolean:
        return self.foldl_boolean(lambda acc, j: acc | (i == j), FALSE)

    # Keep the elements for which `f` holds, by folding into a new tuple starting from EMPTY.
    @method(subsume=True)
    def filter(self, f: Callable[[Int], Boolean]) -> TupleInt:
        return self.foldl_tuple_int(
            lambda acc, v: TupleInt.if_(f(v), acc.append(v), acc),
            TupleInt.EMPTY,
        )

    # Apply `f` to every element; the length is unchanged.
    @method(subsume=True)
    def map(self, f: Callable[[Int], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    # Symbolic conditional; `_tuple_int` rewrites it to `i` when `b` is TRUE and to `j` when `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ...

    # Skip the first `n` elements.
    def drop(self, n: Int) -> TupleInt:
        return TupleInt(self.length() - n, lambda i: self[i + n])

    # Multiply all elements together, starting from 1.
    def product(self) -> Int:
        return self.foldl(lambda acc, i: acc * i, Int(1))

    # Like `map`, but each element maps to a TupleInt, giving a TupleTupleInt.
    def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt:
        return TupleTupleInt(self.length(), lambda i: f(self[i]))

    def select(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements at the given indices
        """
        indices = cast("TupleInt", indices)
        return indices.map(lambda i: self[i])

    def deselect(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements not at the given indices
        """
        indices = cast("TupleInt", indices)
        # Filter the positions 0..length-1 down to those not listed, then look each one up.
        return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
487
+
488
+
489
# Vectors of Int convert to TupleInt through the vector constructor.
converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x))

# Anything accepted where a TupleInt is expected.
TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]
492
+
493
+
494
@array_api_ruleset.register
def _tuple_int(
    i: Int,
    i2: Int,
    f: Callable[[Int, Int], Int],
    bool_f: Callable[[Boolean, Int], Boolean],
    idx_fn: Callable[[Int], Int],
    tuple_int_f: Callable[[TupleInt, Int], TupleInt],
    vs: Vec[Int],
    b: Boolean,
    ti: TupleInt,
    ti2: TupleInt,
    k: i64,
):
    """
    Rules for `TupleInt`: element access for the functional and append forms, conversion from
    functional to append form to vector form, the three folds, `if_` selection and append unification.
    """
    # Expose the vector form through `to_vec`.
    yield rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs))

    # Functional access
    yield rewrite(TupleInt(i, idx_fn).length()).to(i)
    yield rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2)))

    # cons access
    yield rewrite(TupleInt.EMPTY.length()).to(Int(0))
    yield rewrite(TupleInt.EMPTY[i]).to(Int.NEVER)
    yield rewrite(ti.append(i).length()).to(ti.length() + 1)
    yield rewrite(ti.append(i)[i2]).to(Int.if_(i2 == ti.length(), i, ti[i2]))

    # cons to functional (removed this so that there are no infinite replacements between them)
    # rewrite(TupleInt.EMPTY).to(TupleInt(0, lambda _: Int.NEVER)),
    # rewrite(TupleInt(i, idx_fn).append(i2)).to(TupleInt(i + 1, lambda j: Int.if_(j == i, i2, idx_fn(j)))),

    # functional to cons
    yield rewrite(TupleInt(0, idx_fn), subsume=True).to(TupleInt.EMPTY)
    yield rewrite(TupleInt(Int(k), idx_fn), subsume=True).to(
        TupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(TupleInt.EMPTY).to(TupleInt.from_vec(Vec[Int]()))
    yield rewrite(TupleInt.from_vec(vs).append(i)).to(TupleInt.from_vec(vs.append(Vec(i))))

    # fold
    yield rewrite(TupleInt.EMPTY.foldl(f, i), subsume=True).to(i)
    yield rewrite(ti.append(i2).foldl(f, i), subsume=True).to(f(ti.foldl(f, i), i2))

    # fold boolean
    yield rewrite(TupleInt.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b)
    yield rewrite(ti.append(i2).foldl_boolean(bool_f, b), subsume=True).to(
        bool_f(ti.foldl_boolean(bool_f, b), i2)
    )

    # fold tuple_int
    yield rewrite(TupleInt.EMPTY.foldl_tuple_int(tuple_int_f, ti), subsume=True).to(ti)
    yield rewrite(ti.append(i2).foldl_tuple_int(tuple_int_f, ti2), subsume=True).to(
        tuple_int_f(ti.foldl_tuple_int(tuple_int_f, ti2), i2)
    )

    # if_
    yield rewrite(TupleInt.if_(TRUE, ti, ti2), subsume=True).to(ti)
    yield rewrite(TupleInt.if_(FALSE, ti, ti2), subsume=True).to(ti2)

    # unify append
    yield rule(eq(ti.append(i)).to(ti2.append(i2))).then(union(ti).with_(ti2), union(i).with_(i2))
544
+
545
+
546
class TupleTupleInt(Expr, ruleset=array_api_ruleset):
    """
    Symbolic tuple whose elements are `TupleInt` values.

    Declares the same three representations as `TupleInt`: functional (`TupleTupleInt(length, idx_fn)`),
    append-based (`EMPTY` / `append`) and vector (`from_vec`).
    """

    # Named symbolic tuple of tuples.
    @classmethod
    def var(cls, name: StringLike) -> TupleTupleInt: ...

    EMPTY: ClassVar[TupleTupleInt]

    # Functional form: `idx_fn` maps an index to the TupleInt at that index.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

    # One-element tuple whose index function ignores its argument.
    @method(subsume=True)
    @classmethod
    def single(cls, i: TupleIntLike) -> TupleTupleInt:
        i = cast("TupleInt", i)
        return TupleTupleInt(1, lambda _: i)

    @method(subsume=True)
    @classmethod
    def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...

    # Add `i` as the new last element.
    def append(self, i: TupleIntLike) -> TupleTupleInt: ...

    # Concatenation: indices below len(self) read from self, the rest read from `other` shifted down.
    def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
        other = cast("TupleTupleInt", other)
        return TupleTupleInt(
            self.length() + other.length(),
            lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> TupleInt: ...

    # Python `len()` evaluates the symbolic length to a concrete int.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    # Python iteration evaluates the whole tuple first.
    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.eval())

    # Vector form of this tuple; set by the `_tuple_tuple_int` rule when it equals `from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[TupleInt]: ...

    # `try_evaling` and `array_api_schedule` are defined outside this part of the file; presumably this
    # runs the schedule and extracts the value stored in `to_vec` -- TODO confirm.
    @method(preserve=True)
    def eval(self) -> tuple[TupleInt, ...]:
        return try_evaling(array_api_schedule, self, self.to_vec)

    # Skip the first `n` elements.
    def drop(self, n: Int) -> TupleTupleInt:
        return TupleTupleInt(self.length() - n, lambda i: self[i + n])

    # Map every inner tuple to an Int, giving a TupleInt of the same length.
    def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    # Left fold with a Value accumulator; defined by `_tuple_tuple_int` on the EMPTY/append form.
    def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ...

    @method(subsume=True)
    def product(self) -> TupleTupleInt:
        """
        Cartesian product of inputs

        https://docs.python.org/3/library/itertools.html#itertools.product

        https://github.com/saulshanabrook/saulshanabrook/discussions/39
        """
        # The result has prod(len(inner)) rows. Position j of row i is taken from inner tuple j at
        # index (i // prod(len(inner tuples after j))) % len(inner tuple j).
        return TupleTupleInt(
            self.map_int(lambda x: x.length()).product(),
            lambda i: TupleInt(
                self.length(),
                lambda j: self[j][(i // self.drop(j + 1).map_int(lambda x: x.length()).product()) % self[j].length()],
            ),
        )
615
+
616
+
617
# Vectors of TupleInt convert to TupleTupleInt through the vector constructor.
converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))

# Anything accepted where a TupleTupleInt is expected.
TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
620
+
621
+
622
@array_api_ruleset.register
def _tuple_tuple_int(
    length: Int,
    fn: Callable[[TupleInt], Int],
    idx_fn: Callable[[Int], TupleInt],
    f: Callable[[Value, TupleInt], Value],
    i: Value,
    k: i64,
    idx: Int,
    vs: Vec[TupleInt],
    ti: TupleInt,
    ti1: TupleInt,
    tti: TupleTupleInt,
    tti1: TupleTupleInt,
):
    """
    Rules for `TupleTupleInt`: element access for the functional and append forms, conversion from
    functional to append form to vector form, the value fold and append unification.

    `fn` is declared but not used by any rule here.
    """
    # Expose the vector form through `to_vec`.
    yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs))

    # Functional access
    yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
    # `check_index` is declared as (length, idx). The arguments were previously passed in the opposite
    # order, which bounds-checked the length against the index and then used the length as the index.
    yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(length, idx)))

    # cons access
    yield rewrite(TupleTupleInt.EMPTY.length()).to(Int(0))
    yield rewrite(TupleTupleInt.EMPTY[idx]).to(TupleInt.NEVER)
    yield rewrite(tti.append(ti).length()).to(tti.length() + 1)
    yield rewrite(tti.append(ti)[idx]).to(TupleInt.if_(idx == tti.length(), ti, tti[idx]))

    # functional to cons
    yield rewrite(TupleTupleInt(0, idx_fn), subsume=True).to(TupleTupleInt.EMPTY)
    yield rewrite(TupleTupleInt(Int(k), idx_fn), subsume=True).to(
        TupleTupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )
    # cons to vec
    yield rewrite(TupleTupleInt.EMPTY).to(TupleTupleInt.from_vec(Vec[TupleInt]()))
    yield rewrite(TupleTupleInt.from_vec(vs).append(ti)).to(TupleTupleInt.from_vec(vs.append(Vec(ti))))
    # fold value
    yield rewrite(TupleTupleInt.EMPTY.foldl_value(f, i), subsume=True).to(i)
    yield rewrite(tti.append(ti).foldl_value(f, i), subsume=True).to(f(tti.foldl_value(f, i), ti))

    # unify append
    yield rule(eq(tti.append(ti)).to(tti1.append(ti1))).then(union(tti).with_(tti1), union(ti).with_(ti1))
661
+
662
+
663
class OptionalInt(Expr, ruleset=array_api_ruleset):
    """
    Optional `Int`: either the `none` constant or `some(value)`.
    """

    none: ClassVar[OptionalInt]

    @classmethod
    def some(cls, value: Int) -> OptionalInt: ...
668
+
669
+
670
+ converter(type(None), OptionalInt, lambda _: OptionalInt.none)
671
+ converter(Int, OptionalInt, OptionalInt.some)
672
+
673
+
674
class DType(Expr, ruleset=array_api_ruleset):
    """
    Symbolic array data type, with one constant per supported dtype.
    """

    float64: ClassVar[DType]
    float32: ClassVar[DType]
    int64: ClassVar[DType]
    int32: ClassVar[DType]
    object: ClassVar[DType]
    bool: ClassVar[DType]

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: DType) -> Boolean:  # type: ignore[override]
        ...
684
+
685
+
686
# Module-level aliases for the dtype constants.
float64 = DType.float64
float32 = DType.float32
int32 = DType.int32
int64 = DType.int64

# Dtypes that get pairwise equality rules below.
# NOTE(review): DType.bool is not in this list -- confirm whether that is intentional.
_DTYPES = [float64, float32, int32, int64, DType.object]

# Python types are converted by first turning them into a NumPy dtype.
converter(type, DType, lambda x: convert(np.dtype(x), DType))
# The DType constant is looked up by the NumPy dtype's `name` attribute.
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name))  # type: ignore[call-overload]
695
+
696
+
697
@array_api_ruleset.register
def _():
    """
    Emit one equality rewrite for every ordered pair of dtypes in `_DTYPES`.
    """
    for lhs, rhs in itertools.product(_DTYPES, repeat=2):
        # Identity is used on purpose: `==` on DType builds a symbolic Boolean rather than a Python bool.
        outcome = TRUE if lhs is rhs else FALSE
        yield rewrite(lhs == rhs).to(outcome)
701
+
702
+
703
class IsDtypeKind(Expr, ruleset=array_api_ruleset):
    """
    The `kind` argument of `isdtype`: a kind name, a specific dtype, or several kinds combined with `|`.
    """

    # Empty kind; `_isdtype` rewrites `isdtype(d, NULL)` to FALSE.
    NULL: ClassVar[IsDtypeKind]

    # Kind given by name, e.g. "integral" or "real floating".
    @classmethod
    def string(cls, s: StringLike) -> IsDtypeKind: ...

    # Kind that matches one specific dtype.
    @classmethod
    def dtype(cls, d: DType) -> IsDtypeKind: ...

    # Union of two kinds.
    @method(cost=10)
    def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
714
+
715
+
716
# TODO: Make kind more generic to support tuples.
# Symbolic test of whether `dtype` belongs to `kind`; it has no body and is defined by the `_isdtype` rules.
@function
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
719
+
720
+
721
# A DType or a kind-name string converts to the matching IsDtypeKind constructor.
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
# A tuple is converted recursively: first element | rest of the tuple, with the empty tuple mapping to NULL.
converter(
    tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
)
726
+
727
+
728
@array_api_ruleset.register
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
    """
    Rules for `isdtype`: a truth table for the named kinds, then the structural cases
    (NULL, a specific dtype, a union of kinds).
    """
    # "integral"
    yield rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE)
    yield rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE)
    yield rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE)
    yield rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE)
    yield rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE)

    # "real floating"
    yield rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE)
    yield rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE)
    yield rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE)
    yield rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE)
    yield rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE)

    # "complex floating"
    yield rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE)
    yield rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE)
    yield rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE)
    yield rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE)
    yield rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE)

    # Structural cases.
    yield rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE)
    yield rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE)
    yield rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2))
    yield rewrite(k1 | IsDtypeKind.NULL).to(k1)
751
+
752
+
753
+ # TODO: Add pushdown for math on scalars to values
754
+ # and add replacements
755
+
756
+
757
class Value(Expr, ruleset=array_api_ruleset):
    """
    Symbolic scalar value, built from an Int, a Float or a Boolean.
    """

    NEVER: ClassVar[Value]

    # Constructors, one per underlying scalar type.
    @classmethod
    def int(cls, i: IntLike) -> Value: ...

    @classmethod
    def float(cls, f: FloatLike) -> Value: ...

    @classmethod
    def bool(cls, b: BooleanLike) -> Value: ...

    def isfinite(self) -> Boolean: ...

    # Unlike `__eq__` below, `<` returns a Value rather than a Boolean.
    def __lt__(self, other: ValueLike) -> Value: ...

    def __truediv__(self, other: ValueLike) -> Value: ...

    def __mul__(self, other: ValueLike) -> Value: ...

    def __add__(self, other: ValueLike) -> Value: ...

    def astype(self, dtype: DType) -> Value: ...

    # TODO: Add all operations

    @property
    def dtype(self) -> DType:
        """
        Default dtype for this scalar value
        """

    # Underlying Boolean; `_value` defines it for `Value.bool(b)`.
    @property
    def to_bool(self) -> Boolean: ...

    # Underlying Int; `_value` defines it for `Value.int(i)`.
    @property
    def to_int(self) -> Int: ...

    @property
    def to_truthy_value(self) -> Value:
        """
        Converts the value to a bool, based on if its truthy.

        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
        """

    def conj(self) -> Value: ...
    def real(self) -> Value: ...
    def sqrt(self) -> Value: ...

    # Symbolic conditional; `_value` rewrites it to `i` when `b` is TRUE and to `j` when `b` is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: ValueLike, j: ValueLike) -> Value: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
811
+
812
+
813
# Anything accepted where a Value is expected.
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike


# Scalars convert to Value through the matching constructor.
converter(Int, Value, Value.int)
converter(Float, Value, Value.float)
converter(Boolean, Value, Value.bool)
# Value converts back to Int through `to_int`; the trailing 10 is an extra positional argument to `converter`.
# NOTE(review): presumably a conversion cost that makes this path less preferred -- confirm.
converter(Value, Int, lambda x: x.to_int, 10)
820
+
821
+
822
+ @array_api_ruleset.register
823
+ def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float, b1: Boolean):
824
+ # Default dtypes
825
+ # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
826
+ yield rewrite(Value.int(i).dtype).to(DType.int64)
827
+ yield rewrite(Value.float(f).dtype).to(DType.float64)
828
+ yield rewrite(Value.bool(b).dtype).to(DType.bool)
829
+
830
+ yield rewrite(Value.bool(b).to_bool).to(b)
831
+ yield rewrite(Value.int(i).to_int).to(i)
832
+
833
+ yield rewrite(Value.bool(b).to_truthy_value).to(Value.bool(b))
834
+ # TODO: Add more rules for to_bool_value
835
+
836
+ yield rewrite(Value.float(f).conj()).to(Value.float(f))
837
+ yield rewrite(Value.float(f).real()).to(Value.float(f))
838
+ yield rewrite(Value.int(i).real()).to(Value.int(i))
839
+ yield rewrite(Value.int(i).conj()).to(Value.int(i))
840
+
841
+ yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))
842
+
843
+ yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)
844
+
845
+ yield rewrite(Value.if_(TRUE, v, v1)).to(v)
846
+ yield rewrite(Value.if_(FALSE, v, v1)).to(v1)
847
+
848
+ # ==
849
+ yield rewrite(Value.int(i) == Value.int(i1)).to(i == i1)
850
+ yield rewrite(Value.float(f) == Value.float(f1)).to(f == f1)
851
+ yield rewrite(Value.bool(b) == Value.bool(b1)).to(b == b1)
852
+
853
+
854
class TupleValue(Expr, ruleset=array_api_ruleset):
    """
    Tuple of `Value` expressions.

    It can be built from a length plus an index function, by appending to another tuple,
    from a `Vec[Value]`, or from a tuple of ints.
    """

    # The tuple without any items.
    EMPTY: ClassVar[TupleValue]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Value]) -> None: ...

    def append(self, i: ValueLike) -> TupleValue: ...

    @classmethod
    def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...

    def __add__(self, other: TupleValueLike) -> TupleValue:
        """
        Concatenate: indices below this tuple's length read from `self`, the rest read from `other`.
        """
        rhs = cast("TupleValue", other)
        own_length = self.length()
        return TupleValue(
            own_length + rhs.length(),
            lambda i: Value.if_(i < own_length, self[i], rhs[i - own_length]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: Int) -> Value: ...

    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...

    def contains(self, value: ValueLike) -> Boolean:
        """
        Fold over the items, or-ing together `value == item`, starting from `FALSE`.
        """
        needle = cast("Value", value)
        return self.foldl_boolean(lambda acc, j: acc | (needle == j), FALSE)

    @method(subsume=True)
    @classmethod
    def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
        """
        Wrap every int of `ti` with `Value.int`, keeping the same length.
        """
        ints = cast("TupleInt", ti)
        return TupleValue(ints.length(), lambda i: Value.int(ints[i]))
886
+
887
+
888
# Both converters go through a lambda rather than referencing the classmethod directly.
# NOTE(review): presumably this defers the attribute lookup until conversion time -- confirm.
converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))

# Anything that may be passed where a `TupleValue` is expected.
TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
892
+
893
+
894
@array_api_ruleset.register
def _tuple_value(
    length: Int,
    idx_fn: Callable[[Int], Value],
    k: i64,
    idx: Int,
    vs: Vec[Value],
    v: Value,
    v1: Value,
    tv: TupleValue,
    tv1: TupleValue,
    bool_f: Callable[[Boolean, Value], Boolean],
    b: Boolean,
):
    """
    Rules relating the three `TupleValue` forms: functional (length + index function), cons (append) and vec.
    """
    functional = TupleValue(length, idx_fn)
    empty = TupleValue.EMPTY
    appended = tv.append(v)

    # functional access
    yield rewrite(functional.length()).to(length)
    yield rewrite(functional[idx]).to(idx_fn(check_index(idx, length)))

    # cons access
    yield rewrite(empty.length()).to(Int(0))
    yield rewrite(empty[idx]).to(Value.NEVER)
    yield rewrite(appended.length()).to(tv.length() + 1)
    yield rewrite(appended[idx]).to(Value.if_(idx == tv.length(), v, tv[idx]))

    # functional to cons: peel the last item off until the length reaches zero
    yield rewrite(TupleValue(0, idx_fn), subsume=True).to(empty)
    yield rewrite(TupleValue(Int(k), idx_fn), subsume=True).to(
        TupleValue(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(empty).to(TupleValue.from_vec(Vec[Value]()))
    yield rewrite(TupleValue.from_vec(vs).append(v)).to(TupleValue.from_vec(vs.append(Vec(v))))

    # fold boolean
    yield rewrite(empty.foldl_boolean(bool_f, b), subsume=True).to(b)
    yield rewrite(appended.foldl_boolean(bool_f, b), subsume=True).to(bool_f(tv.foldl_boolean(bool_f, b), v))

    # unify append: equal appends have equal prefixes and equal last items
    yield rule(eq(appended).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
933
+
934
+
935
@function
def possible_values(values: Value) -> TupleValue:
    """
    Tuple of every value that the input value may take.

    Declaration only: the body gives no default definition.
    """
940
+
941
+
942
class Slice(Expr, ruleset=array_api_ruleset):
    """
    Slice expression with optional `start`, `stop` and `step`, each defaulting to `OptionalInt.none`.
    """

    def __init__(
        self,
        start: OptionalInt = OptionalInt.none,
        stop: OptionalInt = OptionalInt.none,
        step: OptionalInt = OptionalInt.none,
    ) -> None: ...
949
+
950
+
951
# A builtin `slice` is converted field by field: start, stop and step each become an `OptionalInt`.
converter(
    slice,
    Slice,
    lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
)

# Anything that may be passed where a `Slice` is expected.
SliceLike: TypeAlias = Slice | slice
958
+
959
+
960
class MultiAxisIndexKeyItem(Expr, ruleset=array_api_ruleset):
    """
    One entry of a multi-axis index key: an ellipsis, `None`, an int or a slice.
    """

    ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
    NONE: ClassVar[MultiAxisIndexKeyItem]

    @classmethod
    def int(cls, i: Int) -> MultiAxisIndexKeyItem: ...

    @classmethod
    def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ...
969
+
970
+
971
# The Python singletons `...` and `None` map onto the matching class constants.
converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS)
converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NONE)
# Ints and slices are wrapped by the corresponding constructor.
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)

# Anything that may be passed where a `MultiAxisIndexKeyItem` is expected.
MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
977
+
978
+
979
class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
    """
    Sequence of `MultiAxisIndexKeyItem`s, built from a length plus an index function or from a `Vec`.
    """

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...

    def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...

    @classmethod
    def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...
986
+
987
+
988
# Anything that may be passed where a `MultiAxisIndexKey` is expected (kept as a string alias).
MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"


# A Python tuple converts item by item into a vec-backed key.
converter(
    tuple,
    MultiAxisIndexKey,
    lambda x: MultiAxisIndexKey.from_vec(Vec(*(convert(i, MultiAxisIndexKeyItem) for i in x))),
)
# A `TupleInt` becomes a key of the same length whose i-th item is the i-th int.
converter(
    TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
)
999
+
1000
+
1001
class IndexKey(Expr, ruleset=array_api_ruleset):
    """
    Key used to index into an array.

    https://data-apis.org/array-api/2022.12/API_specification/indexing.html

    Equivalent to this type signature:

        Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
    """

    ELLIPSIS: ClassVar[IndexKey]

    @classmethod
    def int(cls, i: Int) -> IndexKey: ...

    @classmethod
    def slice(cls, slice: Slice) -> IndexKey: ...

    # Disabled until we support late binding
    # @classmethod
    # def boolean_array(cls, b: NDArray) -> IndexKey:
    #     ...

    @classmethod
    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...

    @classmethod
    def ndarray(cls, key: NDArray) -> IndexKey:
        """
        Index with another array used as a mask.
        """
1033
+
1034
+
1035
# Anything that may be passed where an `IndexKey` is expected (kept as a string alias).
IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"


# Each supported key shape is wrapped by the matching `IndexKey` constructor.
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
converter(Int, IndexKey, lambda i: IndexKey.int(i))
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
1042
+
1043
+
1044
class Device(Expr, ruleset=array_api_ruleset):
    """
    Opaque device sort; it declares no constructors or methods of its own.
    """
1045
+
1046
+
1047
# Named `TupleInt` constant; the `any` rewrite in this file indexes an array with it.
ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
1048
+
1049
+
1050
class NDArray(Expr, ruleset=array_api_ruleset):
    """
    N-dimensional array expression, described by a shape, a dtype and a function from indices to values.
    """

    def __init__(self, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...

    NEVER: ClassVar[NDArray]

    @method(cost=200)
    @classmethod
    def var(cls, name: StringLike) -> NDArray: ...

    @method(preserve=True)
    def __array_namespace__(self, api_version: object = None) -> ModuleType:
        """
        Return this module as the array namespace; `api_version` is ignored.
        """
        return sys.modules[__name__]

    @property
    def ndim(self) -> Int: ...

    @property
    def dtype(self) -> DType: ...

    @property
    def device(self) -> Device: ...

    @property
    def shape(self) -> TupleInt: ...

    @method(preserve=True)
    def __bool__(self) -> bool:
        """
        Evaluate the array as a scalar value and return its boolean.
        """
        scalar_value = self.to_value()
        return scalar_value.to_bool.eval()

    @property
    def size(self) -> Int: ...

    @method(preserve=True)
    def __len__(self) -> int:
        """
        Evaluate `size` to a Python int.
        """
        # NOTE(review): this is the total element count, not the length of the first axis -- confirm intent.
        return self.size.eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        """
        Yield `self[IndexKey.int(Int(i))]` for every `i` below `len(self)`.
        """
        count = len(self)
        for position in range(count):
            yield self[IndexKey.int(Int(position))]

    def __getitem__(self, key: IndexKeyLike) -> NDArray: ...

    def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...

    # Comparison operators

    def __lt__(self, other: NDArrayLike) -> NDArray: ...

    def __le__(self, other: NDArrayLike) -> NDArray: ...

    def __eq__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
        ...

    # TODO: Add support for overloaded __ne__
    # def __ne__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
    #     ...

    def __gt__(self, other: NDArrayLike) -> NDArray: ...

    def __ge__(self, other: NDArrayLike) -> NDArray: ...

    # Arithmetic and bitwise operators

    def __add__(self, other: NDArrayLike) -> NDArray: ...

    def __sub__(self, other: NDArrayLike) -> NDArray: ...

    def __mul__(self, other: NDArrayLike) -> NDArray: ...

    def __matmul__(self, other: NDArrayLike) -> NDArray: ...

    def __truediv__(self, other: NDArrayLike) -> NDArray: ...

    def __floordiv__(self, other: NDArrayLike) -> NDArray: ...

    def __mod__(self, other: NDArrayLike) -> NDArray: ...

    def __divmod__(self, other: NDArrayLike) -> NDArray: ...

    def __pow__(self, other: NDArrayLike) -> NDArray: ...

    def __lshift__(self, other: NDArrayLike) -> NDArray: ...

    def __rshift__(self, other: NDArrayLike) -> NDArray: ...

    def __and__(self, other: NDArrayLike) -> NDArray: ...

    def __xor__(self, other: NDArrayLike) -> NDArray: ...

    def __or__(self, other: NDArrayLike) -> NDArray: ...

    # Reflected operators

    def __radd__(self, other: NDArray) -> NDArray: ...

    def __rsub__(self, other: NDArray) -> NDArray: ...

    def __rmul__(self, other: NDArray) -> NDArray: ...

    def __rmatmul__(self, other: NDArray) -> NDArray: ...

    def __rtruediv__(self, other: NDArray) -> NDArray: ...

    def __rfloordiv__(self, other: NDArray) -> NDArray: ...

    def __rmod__(self, other: NDArray) -> NDArray: ...

    def __rpow__(self, other: NDArray) -> NDArray: ...

    def __rlshift__(self, other: NDArray) -> NDArray: ...

    def __rrshift__(self, other: NDArray) -> NDArray: ...

    def __rand__(self, other: NDArray) -> NDArray: ...

    def __rxor__(self, other: NDArray) -> NDArray: ...

    def __ror__(self, other: NDArray) -> NDArray: ...

    @classmethod
    def scalar(cls, value: Value) -> NDArray:
        """
        Array with an empty shape and the value's dtype, whose index function always returns `value`.
        """
        return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)

    def to_value(self) -> Value:
        """
        The single value of this array, if it is a scalar.
        """

    def to_values(self) -> TupleValue:
        """
        The values of this array, if it is a vector.
        """

    @property
    def T(self) -> NDArray:
        """
        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
        """

    @classmethod
    def vector(cls, values: TupleValueLike) -> NDArray: ...

    def index(self, indices: TupleIntLike) -> Value:
        """
        The value stored at the given indices.
        """

    @classmethod
    def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
1194
+
1195
+
1196
# Anything that may be passed where an `NDArray` is expected.
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike

converter(NDArray, IndexKey, IndexKey.ndarray)
converter(Value, NDArray, NDArray.scalar)
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
# to prefer upcasting in the other direction when we can, which is safer at runtime
converter(NDArray, Value, lambda n: n.to_value(), 100)
converter(TupleValue, NDArray, NDArray.vector)
# NOTE(review): a TupleInt -> TupleValue converter is also registered right after the TupleValue class;
# this second registration is presumably redundant -- confirm before removing.
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
1205
+
1206
+
1207
@array_api_ruleset.register
def _ndarray(
    x: NDArray,
    x1: NDArray,
    b: Boolean,
    f: Float,
    fi1: f64,
    fi2: f64,
    shape: TupleInt,
    dtype: DType,
    idx_fn: Callable[[TupleInt], Value],
    idx: TupleInt,
    tv: TupleValue,
):
    """
    Core `NDArray` rewrites: field access on the functional form, scalar float folding, transpose and `if_`.
    """
    functional = NDArray(shape, dtype, idx_fn)
    float_scalar = NDArray.scalar(Value.float(f))
    # Comparison of two primitive float scalars, matched by both `>` rules below.
    literal_gt = NDArray.scalar(Value.float(Float(fi1))) > NDArray.scalar(Value.float(Float(fi2)))
    return [
        # Field access on the functional constructor returns the stored pieces.
        rewrite(functional.shape).to(shape),
        rewrite(functional.dtype).to(dtype),
        rewrite(functional.index(idx), subsume=True).to(idx_fn(idx)),
        rewrite(x.ndim).to(x.shape.length()),
        # rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
        # Converting to a value requires a scalar bool value
        rewrite(x.to_value()).to(x.index(TupleInt.EMPTY)),
        rewrite(NDArray.vector(tv).to_values()).to(tv),
        # TODO: Push these down to float
        rewrite(float_scalar / float_scalar).to(NDArray.scalar(Value.float(Float(1.0)))),
        rewrite(float_scalar - float_scalar).to(NDArray.scalar(Value.float(Float(0.0)))),
        rewrite(literal_gt).to(NDArray.scalar(Value.bool(TRUE)), fi1 > fi2),
        rewrite(literal_gt).to(NDArray.scalar(Value.bool(FALSE)), fi1 <= fi2),
        # Transpose of transpose is the original array
        rewrite(x.T.T).to(x),
        # if_
        rewrite(NDArray.if_(TRUE, x, x1)).to(x),
        rewrite(NDArray.if_(FALSE, x, x1)).to(x1),
    ]
1249
+
1250
+
1251
class TupleNDArray(Expr, ruleset=array_api_ruleset):
    """
    Tuple of `NDArray` expressions, built from a length plus an index function, by appending, or from a `Vec`.
    """

    # The tuple without any items.
    EMPTY: ClassVar[TupleNDArray]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> None: ...

    def append(self, i: NDArrayLike) -> TupleNDArray: ...

    @classmethod
    def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...

    def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
        """
        Concatenate: indices below this tuple's length read from `self`, the rest read from `other`.
        """
        rhs = cast("TupleNDArray", other)
        own_length = self.length()
        return TupleNDArray(
            own_length + rhs.length(),
            lambda i: NDArray.if_(i < own_length, self[i], rhs[i - own_length]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> NDArray: ...

    @method(preserve=True)
    def __len__(self) -> int:
        """
        Evaluate `length()` to a Python int.
        """
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        """
        Iterate over the evaluated tuple returned by `eval()`.
        """
        return iter(self.eval())

    @property
    def to_vec(self) -> Vec[NDArray]: ...

    @method(preserve=True)
    def eval(self) -> tuple[NDArray, ...]:
        """
        Evaluate through `try_evaling` with `array_api_schedule`, reading the result from `to_vec`.
        """
        return try_evaling(array_api_schedule, self, self.to_vec)
1286
+
1287
+
1288
# A vec of arrays converts into the vec-backed tuple form.
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))

# Anything that may be passed where a `TupleNDArray` is expected.
TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]
1291
+
1292
+
1293
@array_api_ruleset.register
def _tuple_ndarray(
    length: Int,
    idx_fn: Callable[[Int], NDArray],
    k: i64,
    idx: Int,
    vs: Vec[NDArray],
    v: NDArray,
    v1: NDArray,
    tv: TupleNDArray,
    tv1: TupleNDArray,
    b: Boolean,
):
    """
    Rules relating the three `TupleNDArray` forms: functional (length + index function), cons (append) and vec.
    """
    functional = TupleNDArray(length, idx_fn)
    empty = TupleNDArray.EMPTY
    appended = tv.append(v)

    # A tuple equal to a vec-backed tuple exposes that vec through `to_vec`.
    yield rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs))

    # functional access
    yield rewrite(functional.length()).to(length)
    yield rewrite(functional[idx]).to(idx_fn(check_index(idx, length)))

    # cons access
    yield rewrite(empty.length()).to(Int(0))
    yield rewrite(empty[idx]).to(NDArray.NEVER)
    yield rewrite(appended.length()).to(tv.length() + 1)
    yield rewrite(appended[idx]).to(NDArray.if_(idx == tv.length(), v, tv[idx]))

    # functional to cons: peel the last item off until the length reaches zero
    yield rewrite(TupleNDArray(0, idx_fn), subsume=True).to(empty)
    yield rewrite(TupleNDArray(Int(k), idx_fn), subsume=True).to(
        TupleNDArray(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(empty).to(TupleNDArray.from_vec(Vec[NDArray]()))
    yield rewrite(TupleNDArray.from_vec(vs).append(v)).to(TupleNDArray.from_vec(vs.append(Vec(v))))

    # unify append: equal appends have equal prefixes and equal last items
    yield rule(eq(appended).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
1327
+
1328
+
1329
class OptionalBool(Expr, ruleset=array_api_ruleset):
    """
    Either `none` or `some(Boolean)`.
    """

    none: ClassVar[OptionalBool]

    @classmethod
    def some(cls, value: Boolean) -> OptionalBool: ...
1334
+
1335
+
1336
# Python `None` maps to `none`; a `Boolean` is wrapped with `some`.
converter(type(None), OptionalBool, lambda _: OptionalBool.none)
converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
1338
+
1339
+
1340
class OptionalDType(Expr, ruleset=array_api_ruleset):
    """
    Either `none` or `some(DType)`.
    """

    none: ClassVar[OptionalDType]

    @classmethod
    def some(cls, value: DType) -> OptionalDType: ...
1345
+
1346
+
1347
# Python `None` maps to `none`; a `DType` is wrapped with `some`.
converter(type(None), OptionalDType, lambda _: OptionalDType.none)
converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
1349
+
1350
+
1351
class OptionalDevice(Expr, ruleset=array_api_ruleset):
    """
    Either `none` or `some(Device)`.
    """

    none: ClassVar[OptionalDevice]

    @classmethod
    def some(cls, value: Device) -> OptionalDevice: ...
1356
+
1357
+
1358
# Python `None` maps to `none`; a `Device` is wrapped with `some`.
converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
1360
+
1361
+
1362
class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
    """
    Either `none` or `some(TupleInt)`.
    """

    none: ClassVar[OptionalTupleInt]

    @classmethod
    def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...
1367
+
1368
+
1369
# Python `None` maps to `none`; a `TupleInt` is wrapped with `some`.
converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
1371
+
1372
+
1373
class IntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    Either a single `Int` or a tuple of ints.
    """

    none: ClassVar[IntOrTuple]

    @classmethod
    def int(cls, value: Int) -> IntOrTuple: ...

    @classmethod
    def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
1381
+
1382
+
1383
# Ints and int tuples are wrapped by the matching constructor.
converter(Int, IntOrTuple, IntOrTuple.int)
converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
1385
+
1386
+
1387
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    Either `none` or `some(IntOrTuple)`.
    """

    none: ClassVar[OptionalIntOrTuple]

    @classmethod
    def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
1392
+
1393
+
1394
# Python `None` maps to `none`; an `IntOrTuple` is wrapped with `some`.
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
1396
+
1397
+
1398
@function
def asarray(
    a: NDArray,
    dtype: OptionalDType = OptionalDType.none,
    copy: OptionalBool = OptionalBool.none,
    device: OptionalDevice = OptionalDevice.none,
) -> NDArray:
    """
    Declare `asarray`; `dtype`, `copy` and `device` all default to their `none` value.
    """
1405
+
1406
+
1407
@array_api_ruleset.register
def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool):
    """
    Rewrites for `asarray`.
    """
    # asarray doesn't change ndim
    yield rewrite(asarray(a, d, ob).ndim).to(a.ndim)
    # With every optional argument left at its default, asarray is the identity.
    yield rewrite(asarray(a)).to(a)
1411
+
1412
+
1413
@function
def isfinite(x: NDArray) -> NDArray:
    """
    Declare `isfinite` on arrays; the body gives no default definition.
    """
1415
+
1416
+
1417
@function
def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    Declare the array `sum` reduction; `axis` defaults to `OptionalIntOrTuple.none`.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
    """
1422
+
1423
+
1424
@array_api_ruleset.register
def _sum(x: NDArray, y: NDArray, v: Value, dtype: DType):
    """
    Rewrites for `sum`.
    """
    divisor = NDArray.scalar(v)
    # Division by a scalar is pulled outside of the sum.
    yield rewrite(sum(x / divisor)).to(sum(x) / divisor)
    # Sum of 0D array is
1430
+
1431
+
1432
@function
def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray:
    """
    Declare `reshape`; `copy` defaults to `OptionalBool.none`.
    """
1434
+
1435
+
1436
+ # @function
1437
+ # def reshape_transform_index(original_shape: TupleInt, shape: TupleInt, index: TupleInt) -> TupleInt:
1438
+ # """
1439
+ # Transforms an indexing operation on a reshaped array to an indexing operation on the original array.
1440
+ # """
1441
+ # ...
1442
+
1443
+
1444
+ # @function
1445
+ # def reshape_transform_shape(original_shape: TupleInt, shape: TupleInt) -> TupleInt:
1446
+ # """
1447
+ # Transforms the shape of an array to one that is reshaped, by replacing -1 with the correct value.
1448
+ # """
1449
+ # ...
1450
+
1451
+
1452
+ # @array_api_ruleset.register
1453
+ # def _reshape(
1454
+ # x: NDArray,
1455
+ # y: NDArray,
1456
+ # shape: TupleInt,
1457
+ # copy: OptionalBool,
1458
+ # i: Int,
1459
+ # s: String,
1460
+ # ix: TupleInt,
1461
+ # ):
1462
+ # return [
1463
+ # # dtype of result is same as input
1464
+ # rewrite(reshape(x, shape, copy).dtype).to(x.dtype),
1465
+ # # Indexing into a reshaped array is the same as indexing into the original array with a transformed index
1466
+ # rewrite(reshape(x, shape, copy).index(ix)).to(x.index(reshape_transform_index(x.shape, shape, ix))),
1467
+ # rewrite(reshape(x, shape, copy).shape).to(reshape_transform_shape(x.shape, shape)),
1468
+ # # reshape_transform_shape recursively
1469
+ # # TODO: handle all cases
1470
+ # rewrite(reshape_transform_shape(TupleInt(i), TupleInt(Int(-1)))).to(TupleInt(i)),
1471
+ # ]
1472
+
1473
+
1474
@function
def unique_values(x: NDArrayLike) -> NDArray:
    """
    Unique elements of the input array `x`, flattened, in arbitrary order.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_values.html
    """
1481
+
1482
+
1483
@array_api_ruleset.register
def _unique_values(x: NDArray):
    """
    Rewrites for `unique_values`.
    """
    once = unique_values(x)
    # Applying unique_values twice is the same as applying it once.
    yield rewrite(unique_values(once)).to(once)
1488
+
1489
+
1490
@function
def concat(arrays: TupleNDArrayLike, axis: OptionalInt = OptionalInt.none) -> NDArray:
    """
    Declare `concat` over a tuple of arrays; `axis` defaults to `OptionalInt.none`.
    """
1492
+
1493
+
1494
@array_api_ruleset.register
def _concat(x: NDArray):
    """
    Rewrites for `concat`.
    """
    # only support no-op concat for now: concatenating a single array gives that array
    single = TupleNDArray.EMPTY.append(x)
    yield rewrite(concat(single)).to(x)
1500
+
1501
+
1502
@function
def astype(x: NDArray, dtype: DType) -> NDArray:
    """
    Declare `astype`, which casts array `x` to `dtype`.
    """
1504
+
1505
+
1506
@array_api_ruleset.register
def _astype(x: NDArray, dtype: DType, i: i64):
    """
    Rewrites for `astype`.
    """
    # The dtype of the result is the requested dtype.
    yield rewrite(astype(x, dtype).dtype).to(dtype)

    # Casting a primitive int scalar to float64 converts the primitive directly.
    int_scalar = NDArray.scalar(Value.int(Int(i)))
    float_scalar = NDArray.scalar(Value.float(Float(f64.from_i64(i))))
    yield rewrite(astype(int_scalar, float64)).to(float_scalar)
1514
+
1515
+
1516
@function
def unique_counts(x: NDArray) -> TupleNDArray:
    """
    Unique elements of the input array `x`, together with the count of each unique element in `x`.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
    """
1524
+
1525
+
1526
@array_api_ruleset.register
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType):
    """
    Rewrites for `unique_counts`.
    """
    result = unique_counts(x)
    counts = result[Int(1)]
    total = NDArray.scalar(Value.int(x.size))
    return [
        # rewrite(unique_counts(x).length()).to(Int(2)),
        # The result is a 2-tuple whose items are read back by indexing the result itself.
        rewrite(result).to(TupleNDArray(2, result.__getitem__)),
        # Sum of all unique counts is the size of the array
        rewrite(sum(counts)).to(total),
        # Same but with astype in the middle
        # TODO: Replace
        rewrite(sum(astype(counts, dtype))).to(astype(total, dtype)),
    ]
1537
+
1538
+
1539
@function
def square(x: NDArray) -> NDArray:
    """
    Declare `square` on arrays; the body gives no default definition.
    """
1541
+
1542
+
1543
@function
def any(x: NDArray) -> NDArray:
    """
    Declare `any` on arrays; the body gives no default definition.
    """
1545
+
1546
+
1547
@function(egg_fn="ndarray-abs")
def abs(x: NDArray) -> NDArray:
    """
    Declare `abs` on arrays, registered under the egg function name `ndarray-abs`.
    """
1549
+
1550
+
1551
@function(egg_fn="ndarray-log")
def log(x: NDArray) -> NDArray:
    """
    Declare `log` on arrays, registered under the egg function name `ndarray-log`.
    """
1553
+
1554
+
1555
@array_api_ruleset.register
def _abs(f: Float):
    """
    Rewrites for `abs`.
    """
    # abs of a float scalar is pushed down to `Float.abs`.
    float_scalar = NDArray.scalar(Value.float(f))
    yield rewrite(abs(float_scalar)).to(NDArray.scalar(Value.float(f.abs())))
1560
+
1561
+
1562
@function
def unique_inverse(x: NDArray) -> TupleNDArray:
    """
    Unique elements of the input array `x`, together with the indices into them that reconstruct `x`.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
    """
1569
+
1570
+
1571
@array_api_ruleset.register
def _unique_inverse(x: NDArray, i: Int):
    """
    Rewrites for `unique_inverse`.
    """
    result = unique_inverse(x)
    return [
        # rewrite(unique_inverse(x).length()).to(Int(2)),
        # The result is a 2-tuple whose items are read back by indexing the result itself.
        rewrite(result).to(TupleNDArray(2, result.__getitem__)),
        # The first element of unique_inverse is the same as unique_values
        rewrite(result[Int(0)]).to(unique_values(x)),
    ]
1579
+
1580
+
1581
@function
def zeros(
    shape: TupleIntLike, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none
) -> NDArray:
    """
    Declare `zeros`; `dtype` and `device` default to their `none` value.
    """
1585
+
1586
+
1587
@function
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray:
    """
    Declare `expand_dims`; `axis` defaults to `Int(0)`.
    """
1589
+
1590
+
1591
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray:
    """
    Declare `mean`; `axis` defaults to `OptionalIntOrTuple.none` and `keepdims` to `FALSE`.
    """
1593
+
1594
+
1595
+ # TODO: Possibly change names to include modules.
1596
@function(egg_fn="ndarray-sqrt")
def sqrt(x: NDArray) -> NDArray:
    """
    Declare `sqrt` on arrays, registered under the egg function name `ndarray-sqrt`.
    """
1598
+
1599
+
1600
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    Declare `std`; `axis` defaults to `OptionalIntOrTuple.none`.
    """
1602
+
1603
+
1604
@function
def real(x: NDArray) -> NDArray:
    """
    Declare `real` on arrays; the body gives no default definition.
    """
1606
+
1607
+
1608
@function
def conj(x: NDArray) -> NDArray:
    """
    Declare `conj` on arrays; the body gives no default definition.
    """
1610
+
1611
+
1612
# `linalg` is an alias of this very module, so `linalg.<name>` resolves to the functions defined in this file.
linalg = sys.modules[__name__]
1613
+
1614
+
1615
@function
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
    """
    Declare `svd`; `full_matrices` defaults to `TRUE`.

    https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
    """
1620
+
1621
+
1622
@array_api_ruleset.register
def _linalg(x: NDArray, full_matrices: Boolean):
    """
    Rewrites for the linear algebra functions.
    """
    decomposition = svd(x, full_matrices)
    return [
        # rewrite(svd(x, full_matrices).length()).to(Int(3)),
        # The result is a 3-tuple whose items are read back by indexing the result itself.
        rewrite(decomposition).to(TupleNDArray(3, decomposition.__getitem__)),
    ]
1628
+
1629
+
1630
@function(ruleset=array_api_ruleset)
def ndindex(shape: TupleIntLike) -> TupleTupleInt:
    """
    Map every dimension through `TupleInt.range` and take the product of the resulting tuples.

    https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
    """
    dims = cast("TupleInt", shape)
    per_axis_ranges = dims.map_tuple_int(TupleInt.range)
    return per_axis_ranges.product()
1637
+
1638
+
1639
+ ##
1640
+ # Interval analysis
1641
+ #
1642
+ # to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0))) < NDArray.scalar(Value.int(Int(0)))).bool()``
1643
+ ##
1644
+
1645
# Relation over `Value`; the interval-analysis rules in this file use it to mark values known to be > 0.
greater_zero = relation("greater_zero", Value)
1646
+
1647
+
1648
+ # @function
1649
+ # def ndarray_all_greater_0(x: NDArray) -> Unit:
1650
+ # ...
1651
+
1652
+
1653
+ # @function
1654
+ # def ndarray_all_false(x: NDArray) -> Unit:
1655
+ # ...
1656
+
1657
+
1658
+ # @function
1659
+ # def ndarray_all_true(x: NDArray) -> Unit:
1660
+ # ...
1661
+
1662
+
1663
+ # any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).to_bool()
1664
+
1665
+ # sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.int(Int(150))))
1666
+ # And also
1667
+
1668
+ # def
1669
+
1670
+
1671
@function
def broadcast_index(from_shape: TupleIntLike, to_shape: TupleIntLike, index: TupleIntLike) -> TupleInt:
    """
    Index into the original array that corresponds to `index` in the broadcasted array.
    """
1676
+
1677
+
1678
@function
def broadcast_shapes(shape1: TupleIntLike, shape2: TupleIntLike) -> TupleInt:
    """
    Shape of the array that results from broadcasting `shape1` against `shape2`.
    """
1683
+
1684
+
1685
@array_api_ruleset.register
def _interval_analaysis(
    x: NDArray,
    y: NDArray,
    z: NDArray,
    dtype: DType,
    f: f64,
    i: i64,
    b: Boolean,
    idx: TupleInt,
    v: Value,
    v1: Value,
    v2: Value,
    float_: Float,
    int_: Int,
):
    """
    Rules that push indexing through element-wise operations and track which values are known to be > 0.

    The `greater_zero` relation is seeded from unique counts and positive literals, propagated through
    `astype` and division, and finally used to decide `v < 0`.
    """
    # Element-wise binary operations read both operands at the broadcasted index.
    res_shape = broadcast_shapes(x.shape, y.shape)
    x_value = x.index(broadcast_index(x.shape, res_shape, idx))
    y_value = y.index(broadcast_index(y.shape, res_shape, idx))
    return [
        # Calling any on an array gives back a scalar, which is true if any of the values are truthy
        rewrite(any(x)).to(
            NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE))))
        ),
        # Indexing x < y is the same as broadcasting the index and then indexing both and then comparing
        rewrite((x < y).index(idx)).to(x_value < y_value),
        # Same for x / y
        rewrite((x / y).index(idx)).to(x_value / y_value),
        # Indexing a scalar is the same as the scalar
        rewrite(NDArray.scalar(v).index(idx)).to(v),
        # Indexing of astype is same as astype of indexing
        rewrite(astype(x, dtype).index(idx)).to(x.index(idx).astype(dtype)),
        # rule(eq(y).to(x < NDArray.scalar(Value.int(Int(0)))), ndarray_all_greater_0(x)).then(ndarray_all_false(y)),
        # rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray.scalar(Value.bool(FALSE)))),
        # Values indexed out of the unique counts are all positive
        rule(
            eq(v).to(unique_counts(x)[Int(1)].index(idx)),
        ).then(greater_zero(v)),
        # Being greater than zero is preserved over astype.
        # (The original listed this rule a second time at the end, with the two facts swapped;
        # the duplicate matched exactly the same e-nodes and has been dropped.)
        rule(
            greater_zero(v),
            eq(v1).to(v.astype(dtype)),
        ).then(
            greater_zero(v1),
        ),
        # A positive float or int literal is greater than zero
        rule(eq(v).to(Value.float(Float(f))), f > 0.0).then(greater_zero(v)),
        rule(eq(v).to(Value.int(Int(i))), i > 0).then(greater_zero(v)),
        # If we have division of v and v1, and both greater than zero, then the result is greater than zero
        rule(
            greater_zero(v),
            greater_zero(v1),
            eq(v2).to(v / v1),
        ).then(
            greater_zero(v2),
        ),
        # Define v < 0 to be false, if greater_zero(v)
        rule(
            greater_zero(v),
            eq(v1).to(v < Value.int(Int(0))),
        ).then(
            union(v1).with_(Value.bool(FALSE)),
        ),
        # possible values of bool is bool
        rewrite(possible_values(Value.bool(b))).to(TupleValue.EMPTY.append(Value.bool(b))),
    ]
1758
+
1759
+
1760
+ ##
1761
+ # Mathematical descriptions of arrays as:
1762
+ # 1. A shape `.shape`
1763
+ # 2. A dtype `.dtype`
1764
+ # 3. A mapping from indices to values `x.index(idx)`
1765
+ #
1766
+ # For all operations that are supported mathematically, define each of the above.
1767
+ ##
1768
+
1769
+
1770
def _demand_shape(compound: NDArray, inner: NDArray) -> Command:
    """
    Build a rule that demands the shape of `inner` whenever `compound` is present.

    The rule's single fact binds a fresh variable to `compound`; its actions are
    the expressions `inner.shape` and `inner.shape.length()`.
    """
    # Pattern variable that the matched `compound` expression is bound to.
    matched = var("__a", NDArray)
    trigger = eq(matched).to(compound)
    inner_shape = inner.shape
    return rule(trigger).then(inner_shape, inner_shape.length())
1773
+
1774
+
1775
@array_api_ruleset.register
def _scalar_math(v: Value, vs: TupleValue, i: Int):
    """Shape, dtype and index rewrites for an array built with `NDArray.scalar(v)`."""
    scalar_array = NDArray.scalar(v)
    # A scalar array has the empty shape.
    yield rewrite(scalar_array.shape).to(TupleInt.EMPTY)
    # Its dtype is the dtype of the wrapped value.
    yield rewrite(scalar_array.dtype).to(v.dtype)
    # Indexing it with the empty index gives back the wrapped value.
    yield rewrite(scalar_array.index(TupleInt.EMPTY)).to(v)
1780
+
1781
+
1782
@array_api_ruleset.register
def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
    """Shape, dtype and index rewrites for an array built with `NDArray.vector(vs)`."""
    vector_array = NDArray.vector(vs)
    # The shape is the one-element tuple holding the number of values.
    yield rewrite(vector_array.shape).to(TupleInt.single(vs.length()))
    # The dtype is taken from the first value.
    first_value = vs[Int(0)]
    yield rewrite(vector_array.dtype).to(first_value.dtype)
    # Indexing with `ti` selects the value at position `ti[0]`.
    position = ti[0]
    yield rewrite(vector_array.index(ti)).to(vs[position])
1787
+
1788
+
1789
@array_api_ruleset.register
def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
    """Rules for `reshape(x, shape, copy)`."""
    reshaped = reshape(x, shape, copy)

    # When the reshape is present, demand the shape (and its length) of the input.
    yield _demand_shape(reshaped, x)
    # Also demand the length and the first entry of the requested shape.
    yield rule(reshaped).then(shape.length(), shape[0])

    # Reshaping a vec to a vec is the same as the vec: a 1-d input reshaped
    # to the 1-d shape (-1,) is rewritten to the input itself.
    input_is_1d = eq(x.shape.length()).to(Int(1))
    target_is_1d = eq(shape.length()).to(Int(1))
    target_dim_is_minus_one = eq(shape[0]).to(Int(-1))
    yield rewrite(reshaped).to(x, input_is_1d, target_is_1d, target_dim_is_minus_one)
1804
+
1805
+
1806
@array_api_ruleset.register
def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
    """Rewrite `x[...]` with a single integer key in terms of `x.index`."""
    # Full getitem with one int key -> scalar array wrapping the indexed value.
    int_key = IndexKey.int(i)
    indexed_value = x.index(TupleInt.single(i))
    yield rewrite(x[int_key]).to(NDArray.scalar(indexed_value))
    # TODO: Multi index rewrite as well if all are ints
1811
+
1812
+
1813
+ ##
1814
+ # Assumptions
1815
+ ##
1816
+
1817
+
1818
@function(mutates_first_arg=True)
def assume_dtype(x: NDArray, dtype: DType) -> None:
    """
    Assert that `x` has dtype `dtype`.

    Declared with `mutates_first_arg=True` and returns nothing: the assumption
    is recorded on `x` itself.
    """
1823
+
1824
+
1825
@array_api_ruleset.register
def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
    """Rules describing an array after `assume_dtype(x, dtype)` has been applied to it."""
    # Snapshot first: assume_dtype is declared as mutating its first argument.
    before = copy(x)
    assume_dtype(x, dtype)

    # The dtype becomes the assumed one ...
    yield rewrite(x.dtype).to(dtype)
    # ... while shape and indexed values are those of the array before the assumption.
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.index(idx)).to(before.index(idx))
1832
+
1833
+
1834
@function(mutates_first_arg=True)
def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
    """
    Assert that `x` has shape `shape`.

    Declared with `mutates_first_arg=True` and returns nothing: the assumption
    is recorded on `x` itself.
    """
1839
+
1840
+
1841
@array_api_ruleset.register
def _assume_shape(x: NDArray, shape: TupleInt, idx: TupleInt):
    """Rules describing an array after `assume_shape(x, shape)` has been applied to it."""
    # Snapshot first: assume_shape is declared as mutating its first argument.
    before = copy(x)
    assume_shape(x, shape)

    # The shape becomes the assumed one ...
    yield rewrite(x.shape).to(shape)
    # ... while dtype and indexed values are those of the array before the assumption.
    yield rewrite(x.dtype).to(before.dtype)
    yield rewrite(x.index(idx)).to(before.index(idx))
1848
+
1849
+
1850
@function(mutates_first_arg=True)
def assume_isfinite(x: NDArray) -> None:
    """
    Assert that the scalar ndarray `x` is non null and not infinite.

    Declared with `mutates_first_arg=True` and returns nothing: the assumption
    is recorded on `x` itself.
    """
1855
+
1856
+
1857
@array_api_ruleset.register
def _isfinite(x: NDArray, ti: TupleInt):
    """Rules describing an array after `assume_isfinite(x)` has been applied to it."""
    # Snapshot first: assume_isfinite is declared as mutating its first argument.
    before = copy(x)
    assume_isfinite(x)

    # Shape, dtype and indexed values pass through from the array before the assumption.
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.dtype).to(before.dtype)
    element = x.index(ti)
    yield rewrite(element).to(before.index(ti))
    # In addition, any indexed value is finite.
    yield rewrite(element.isfinite()).to(TRUE)
1868
+
1869
+
1870
@function(mutates_first_arg=True)
def assume_value_one_of(x: NDArray, values: TupleValueLike) -> None:
    """
    Assert that the value of `x` is one of the values in the tuple `values`.

    Declared with `mutates_first_arg=True` and returns nothing: the assumption
    is recorded on `x` itself.
    """
1875
+
1876
+
1877
@array_api_ruleset.register
def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt):
    """Rules describing an array after `assume_value_one_of(x, vs)` has been applied to it."""
    # Snapshot first: assume_value_one_of is declared as mutating its first argument.
    before = copy(x)
    assume_value_one_of(x, vs)

    # Shape and dtype pass through unchanged.
    yield rewrite(x.shape).to(before.shape)
    yield rewrite(x.dtype).to(before.dtype)

    # The array values pass through as well, and additionally the
    # possible_values of every indexed value are set to `vs`.
    indexed = eq(v).to(x.index(idx))
    same_as_before = union(v).with_(before.index(idx))
    restricted_to_vs = union(possible_values(v)).with_(vs)
    yield rule(indexed).then(same_as_before, restricted_to_vs)
1889
+
1890
+
1891
@array_api_ruleset.register
def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean):
    """
    Rules that decide `isfinite` for literal values and for the sum of an array.

    - Integer and boolean values are always finite.
    - A float literal is finite only if it is neither NaN nor +/- infinity.
    - `isfinite(sum(arr))` is rewritten to a scalar boolean array holding
      `arr.index(ALL_INDICES).isfinite()`.
    """
    yield rewrite(Value.int(i).isfinite()).to(TRUE)
    yield rewrite(Value.bool(b).isfinite()).to(TRUE)
    # Exclude NaN *and* both infinities: previously only NaN was excluded, so
    # an infinite float literal was wrongly rewritten to "finite".
    yield rewrite(Value.float(Float(f)).isfinite()).to(
        TRUE,
        ne(f).to(f64(math.nan)),
        ne(f).to(f64(math.inf)),
        ne(f).to(f64(-math.inf)),
    )

    # a sum of an array is finite if all the values are finite
    yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite())))
1899
+
1900
+
1901
@array_api_ruleset.register
def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
    """Rewrite `unique_values(a)` to a vector of the possible values of `a`'s elements."""
    all_elements = a.index(ALL_INDICES)
    yield rewrite(unique_values(x=a)).to(NDArray.vector(possible_values(all_elements)))
    # Disabled rule, kept for reference:
    # yield rewrite(
    #     possible_values(reshape(a.index(shape, copy), ALL_INDICES)),
    # ).to(possible_values(a.index(ALL_INDICES)))
1907
+
1908
+
1909
@array_api_ruleset.register
def _size(x: NDArray):
    """Rewrite `x.size` to the fold of `Int.__mul__` over `x.shape`, starting from `Int(1)`."""
    product_of_dims = x.shape.foldl(Int.__mul__, Int(1))
    yield rewrite(x.size).to(product_of_dims)
1912
+
1913
+
1914
+ # Separate ruleset so we can use it in program gen
1915
@ruleset
def array_api_vec_to_cons_ruleset(
    vs: Vec[Int],
    vv: Vec[Value],
    vn: Vec[NDArray],
    vt: Vec[TupleInt],
):
    """
    Rewrite each `from_vec` tuple constructor into `EMPTY` plus `append` calls.

    For every tuple class, two rewrites are emitted: an empty vec becomes the
    class's `EMPTY`, and a non-empty vec becomes `from_vec` of the vec without
    its last element, with that last element appended.
    """
    # Emission order: ints, values, tuples of ints, arrays.
    pairs = (
        (TupleInt, vs),
        (TupleValue, vv),
        (TupleTupleInt, vt),
        (TupleNDArray, vn),
    )
    for tuple_cls, vec in pairs:
        from_vec_expr = tuple_cls.from_vec(vec)
        size = vec.length()
        last = size - 1
        # Empty vec -> the empty tuple.
        yield rewrite(from_vec_expr).to(tuple_cls.EMPTY, eq(size).to(i64(0)))
        # Non-empty vec -> convert all but the last element, then append the last one.
        yield rewrite(from_vec_expr).to(
            tuple_cls.from_vec(vec.remove(last)).append(vec[last]),
            ne(size).to(i64(0)),
        )
1940
+
1941
+
1942
# The main array API rules combined (via `|`) with the Vec -> cons-list conversion rules.
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
# Schedule obtained by calling `.saturate()` on the combined ruleset.
array_api_schedule = array_api_combined_ruleset.saturate()