egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Files changed (46) hide show
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +734 -0
  4. egglog/builtins.py +1133 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +286 -0
  7. egglog/declarations.py +912 -0
  8. egglog/deconstruct.py +173 -0
  9. egglog/egraph.py +1875 -0
  10. egglog/egraph_state.py +680 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +67 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/__init__.py +3 -0
  26. egglog/exp/array_api.py +2019 -0
  27. egglog/exp/array_api_jit.py +51 -0
  28. egglog/exp/array_api_loopnest.py +74 -0
  29. egglog/exp/array_api_numba.py +69 -0
  30. egglog/exp/array_api_program_gen.py +510 -0
  31. egglog/exp/program_gen.py +425 -0
  32. egglog/exp/siu_examples.py +32 -0
  33. egglog/ipython_magic.py +41 -0
  34. egglog/pretty.py +509 -0
  35. egglog/py.typed +0 -0
  36. egglog/runtime.py +712 -0
  37. egglog/thunk.py +97 -0
  38. egglog/type_constraint_solver.py +113 -0
  39. egglog/version_compat.py +87 -0
  40. egglog/visualizer.css +1 -0
  41. egglog/visualizer.js +35777 -0
  42. egglog/visualizer_widget.py +39 -0
  43. egglog-11.2.0.dist-info/METADATA +74 -0
  44. egglog-11.2.0.dist-info/RECORD +46 -0
  45. egglog-11.2.0.dist-info/WHEEL +4 -0
  46. egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,2019 @@
1
+ """
2
+
3
+
4
+ ## Lists
5
+
6
+ Lists have two main constructors:
7
+
8
+ - `List(length, idx_fn)`
9
+ - `List.EMPTY` / `initial.append(last)`
10
+
11
+ This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic
12
+ length that could not be resolved to an integer.
13
+
14
+ There are rewrites to convert between these constructors in both directions. The only limitation however is that
15
+ `length` has to be a real i64 in order to be converted to a cons list.
16
+
17
+ When you are writing a function that uses ints, feel free to use the `__getitem__` or `length()` methods or match
18
+ directly on `List()` constructor. If you can write your function using that interface please do. But for some other
19
+ methods, where the resulting length/index function depends on the rest of the list, you can only define it with a known
20
+ length, so that you can then use the cons list constructors.
21
+
22
+ We also support creating lists from vectors. These can be converted one to one to the snoc list representation.
23
+
24
+ It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet.
25
+
26
+ We are guaranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use
27
+ the `.length` and `.__getitem__` methods, unless you want to depend on it having a known length, in which
28
+ case you can match directly on the cons list.
29
+
30
+ To be a list, you must implement two methods:
31
+
32
+ * `l.length() -> Int`
33
+ * `l.__getitem__(i: Int) -> T`
34
+
35
+ There are three main types of constructors for lists which all implement these methods:
36
+
37
+ * Functional `List(length, idx_fn)`
38
+ * cons (well reversed cons) lists `List.EMPTY` and `l.append(x)`
39
+ * Vectors `List.from_vec(vec)`
40
+
41
+ Also, all list constructors must be converted to the functional representation, so that we can match on it
42
+ and convert lists with known lengths into cons lists and into vectors.
43
+
44
+ This is necessary so that known length lists are properly materialized during extraction.
45
+
46
+ Q: Why are they implemented as SNOC lists instead of CONS lists?
47
+ A: So that when converting from functional to lists we can use the same index function by starting at the end and folding
48
+ that way recursively.
49
+
50
+
51
+ """
52
+
53
+ # mypy: disable-error-code="empty-body"
54
+
55
+ from __future__ import annotations
56
+
57
+ import contextlib
58
+ import itertools
59
+ import math
60
+ import numbers
61
+ import os
62
+ import sys
63
+ from collections.abc import Callable
64
+ from copy import copy
65
+ from types import EllipsisType
66
+ from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
67
+
68
+ import numpy as np
69
+
70
+ from egglog import *
71
+ from egglog.runtime import RuntimeExpr
72
+ from egglog.version_compat import add_note
73
+
74
+ from .program_gen import *
75
+
76
+ if TYPE_CHECKING:
77
+ from collections.abc import Iterator
78
+ from types import ModuleType
79
+
80
+
81
# Single ruleset that every expression class and rule group in this module registers into.
array_api_ruleset = ruleset(name="array_api_ruleset")

# scikit-learn performs isinstance(x, numbers.Integral) checks; registering RuntimeExpr as a
# virtual subclass lets symbolic expressions pass them.
numbers.Integral.register(RuntimeExpr)

# SciPy reads this flag when it is first imported, so it must be set before scipy is ever imported.
# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support
os.environ.update(SCIPY_ARRAY_API="1")
89
+
90
+
91
class Boolean(Expr, ruleset=array_api_ruleset):
    """
    Symbolic boolean expression used throughout the array API.

    Literals are built with ``Boolean(True)`` / ``Boolean(False)``. ``to_bool`` exposes the underlying
    egglog ``Bool`` once a rule has set it for an expression known to equal a literal.
    """

    def __init__(self, value: BoolLike) -> None: ...

    # Python truthiness forces evaluation of the expression in the current e-graph.
    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.eval()

    # Runs ``array_api_schedule`` on the current e-graph and reads back ``to_bool``
    # (see ``try_evaling`` for what happens when no value can be determined).
    @method(preserve=True)
    def eval(self) -> bool:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)

    @property
    def to_bool(self) -> Bool: ...

    def __or__(self, other: BooleanLike) -> Boolean: ...

    def __and__(self, other: BooleanLike) -> Boolean: ...

    def __invert__(self) -> Boolean: ...

    # Symbolic equality: builds a Boolean expression rather than returning a Python bool.
    def __eq__(self, other: BooleanLike) -> Boolean: ...  # type: ignore[override]
113
+
114
BooleanLike = Boolean | BoolLike

# Canonical boolean literals, shared by the rewrite rules below.
TRUE, FALSE = Boolean(True), Boolean(False)

# Accept egglog's primitive ``Bool`` wherever a ``Boolean`` is expected.
converter(Bool, Boolean, Boolean)
119
+
120
+
121
@array_api_ruleset.register
def _bool(x: Boolean, i: Int, j: Int, b: Bool):
    """
    Rules for ``Boolean``: recording primitive values and simplifying literal ``|``, ``&``, ``~`` and ``==``.

    NOTE(review): ``i`` and ``j`` are not used by any rule here.
    """
    return [
        # Record the primitive value of any Boolean known to equal a literal, so ``to_bool``/``eval`` can read it.
        rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b)),
        # Only the left operand is matched as a literal in these and/or rules.
        rewrite(TRUE | x).to(TRUE),
        rewrite(FALSE | x).to(x),
        rewrite(TRUE & x).to(x),
        rewrite(FALSE & x).to(FALSE),
        rewrite(~TRUE).to(FALSE),
        rewrite(~FALSE).to(TRUE),
        # Sanity check: merging the two literals means the rules derived a contradiction.
        rule(eq(FALSE).to(TRUE)).then(panic("False cannot equal True")),
        rewrite(x == x).to(TRUE),  # noqa: PLR0124
        rewrite(FALSE == TRUE).to(FALSE),
        rewrite(TRUE == FALSE).to(FALSE),
    ]
136
+
137
+
138
class Int(Expr, ruleset=array_api_ruleset):
    """
    Symbolic integer expression.

    ``Int(value)`` is a literal and ``Int.var(name)`` a named unknown. Operators build new expressions.
    ``eval()`` (also used by ``__index__``, ``__int__``, ``__float__`` and ``__bool__``) runs the array API
    schedule on the current e-graph and reads back ``to_i64``.
    """

    # An Int that should never exist. It could represent, for example, indexing into an array with a value that
    # is out of bounds.
    # https://en.wikipedia.org/wiki/Bottom_type
    NEVER: ClassVar[Int]

    @classmethod
    def var(cls, name: StringLike) -> Int: ...

    def __init__(self, value: i64Like) -> None: ...

    def __invert__(self) -> Int: ...

    def __lt__(self, other: IntLike) -> Boolean: ...

    def __le__(self, other: IntLike) -> Boolean: ...

    def __eq__(self, other: IntLike) -> Boolean:  # type: ignore[override]
        ...

    # add a hash so that this test can pass
    # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
    # Side effect: registers this expression in the current e-graph and runs the schedule. It then hashes the
    # extracted term, so expressions that extract to the same term hash equally.
    @method(preserve=True)
    def __hash__(self) -> int:
        egraph = _get_current_egraph()
        egraph.register(self)
        egraph.run(array_api_schedule)
        simplified = egraph.extract(self)
        return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)

    def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...

    # TODO: Fix this?
    # Make != always return a Python bool, so that numpy.unique works on a tuple of ints
    # In _unique1d
    # ``not`` calls ``Boolean.__bool__`` on the symbolic ``==`` result, which evaluates it in the e-graph.
    @method(preserve=True)
    def __ne__(self, other: Int) -> bool:  # type: ignore[override]
        return not (self == other)

    def __gt__(self, other: IntLike) -> Boolean: ...

    def __ge__(self, other: IntLike) -> Boolean: ...

    def __add__(self, other: IntLike) -> Int: ...

    def __sub__(self, other: IntLike) -> Int: ...

    def __mul__(self, other: IntLike) -> Int: ...

    def __truediv__(self, other: IntLike) -> Int: ...

    def __floordiv__(self, other: IntLike) -> Int: ...

    def __mod__(self, other: IntLike) -> Int: ...

    def __divmod__(self, other: IntLike) -> Int: ...

    def __pow__(self, other: IntLike) -> Int: ...

    def __lshift__(self, other: IntLike) -> Int: ...

    def __rshift__(self, other: IntLike) -> Int: ...

    def __and__(self, other: IntLike) -> Int: ...

    def __xor__(self, other: IntLike) -> Int: ...

    def __or__(self, other: IntLike) -> Int: ...

    def __radd__(self, other: IntLike) -> Int: ...

    def __rsub__(self, other: IntLike) -> Int: ...

    def __rmul__(self, other: IntLike) -> Int: ...

    def __rmatmul__(self, other: IntLike) -> Int: ...

    def __rtruediv__(self, other: IntLike) -> Int: ...

    def __rfloordiv__(self, other: IntLike) -> Int: ...

    def __rmod__(self, other: IntLike) -> Int: ...

    def __rpow__(self, other: IntLike) -> Int: ...

    def __rlshift__(self, other: IntLike) -> Int: ...

    def __rrshift__(self, other: IntLike) -> Int: ...

    def __rand__(self, other: IntLike) -> Int: ...

    def __rxor__(self, other: IntLike) -> Int: ...

    def __ror__(self, other: IntLike) -> Int: ...

    # Primitive value, set by a rule for Ints known to equal a literal.
    @property
    def to_i64(self) -> i64: ...

    @method(preserve=True)
    def eval(self) -> int:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)

    # The conversions below all go through ``eval``, so they require the value to be determinable.
    @method(preserve=True)
    def __index__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __int__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __float__(self) -> float:
        return float(self.eval())

    @method(preserve=True)
    def __bool__(self) -> bool:
        return bool(self.eval())

    # Symbolic conditional: ``i`` when ``b`` is TRUE, ``j`` when it is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ...
257
+
258
+
259
@array_api_ruleset.register
def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
    """Rules for ``Int``: constant folding of literals, comparisons, ``if_`` selection and ``NEVER`` sanity checks."""
    # Equality of literals
    yield rewrite(Int(i) == Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))

    # Ordering of literals: identical operands are handled by a rewrite, distinct ones by comparing the i64 values.
    yield rewrite(Int(i) >= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) >= Int(j)), i < j).then(union(r).with_(FALSE))

    yield rewrite(Int(i) < Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) < Int(j)), i < j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) < Int(j)), i > j).then(union(r).with_(FALSE))

    yield rewrite(Int(i) > Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE))

    # Expose the primitive value of any Int known to equal a literal, for ``to_i64``/``eval``.
    yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j))

    # Sanity check: two different literals must never be merged.
    yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints"))

    # Arithmetic on literals is folded with the corresponding i64 primitive.
    # NOTE(review): ``//`` and ``%`` fold to i64 ``/`` and ``%``. If those truncate toward zero, results differ
    # from Python's floor division/modulo for negative operands — confirm.
    yield rewrite(Int(i) + Int(j)).to(Int(i + j))
    yield rewrite(Int(i) - Int(j)).to(Int(i - j))
    yield rewrite(Int(i) * Int(j)).to(Int(i * j))
    yield rewrite(Int(i) // Int(j)).to(Int(i / j))
    yield rewrite(Int(i) % Int(j)).to(Int(i % j))
    yield rewrite(Int(i) & Int(j)).to(Int(i & j))
    yield rewrite(Int(i) | Int(j)).to(Int(i | j))
    yield rewrite(Int(i) ^ Int(j)).to(Int(i ^ j))
    yield rewrite(Int(i) << Int(j)).to(Int(i << j))
    yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
    yield rewrite(~Int(i)).to(Int(~i))

    # A literal condition selects its branch; the matched ``if_`` term is subsumed.
    yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
    yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)

    # Rounding an Int with no ``ndigits`` is the identity.
    yield rewrite(o.__round__(OptionalInt.none)).to(o)

    # Never cannot be equal to anything real
    yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
299
+
300
+
301
# Let egglog i64 values be used wherever an Int is expected.
converter(i64, Int, lambda value: Int(value))

IntLike: TypeAlias = Int | i64Like
304
+
305
+
306
@function(ruleset=array_api_ruleset)
def check_index(length: IntLike, idx: IntLike) -> Int:
    """
    Bounds-check ``idx`` against ``length``.

    Evaluates to ``idx`` when ``0 <= idx < length`` and to ``Int.NEVER`` otherwise.
    """
    upper = cast("Int", length)
    candidate = cast("Int", idx)
    non_negative = candidate >= 0
    below_upper = candidate < upper
    return Int.if_(non_negative & below_upper, candidate, Int.NEVER)
314
+
315
+
316
+ # @array_api_ruleset.register
317
+ # def _check_index(i: i64, j: i64, x: Int):
318
+ # yield rewrite(
319
+ # check_index(Int(i), Int(j)),
320
+ # ).to(
321
+ # Int(j),
322
+ # i >= 0,
323
+ # i < j,
324
+ # )
325
+
326
+ # yield rewrite(
327
+ # check_index(x, Int(i)),
328
+ # ).to(
329
+ # Int.NEVER,
330
+ # i < 0,
331
+ # )
332
+
333
+ # yield rewrite(
334
+ # check_index(Int(i), Int(j)),
335
+ # ).to(
336
+ # Int.NEVER,
337
+ # i >= j,
338
+ # )
339
+
340
+
341
class Float(Expr, ruleset=array_api_ruleset):
    """
    Symbolic floating point expression.

    There are three constructors: ``Float(f64)``, ``Float.rational(BigRat)`` and ``Float.from_int(Int)``.
    ``eval()`` runs the array API schedule and reads back ``to_f64``.
    """

    # Differentiate costs of three constructors so extraction is deterministic if all three are present
    @method(cost=3)
    def __init__(self, value: f64Like) -> None: ...

    # Primitive value, set by a rule for Floats known to equal an f64 literal.
    @property
    def to_f64(self) -> f64: ...

    @method(preserve=True)
    def eval(self) -> float:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)

    def abs(self) -> Float: ...

    @method(cost=2)
    @classmethod
    def rational(cls, r: BigRat) -> Float: ...

    @classmethod
    def from_int(cls, i: IntLike) -> Float: ...

    def __truediv__(self, other: FloatLike) -> Float: ...

    def __mul__(self, other: FloatLike) -> Float: ...

    def __add__(self, other: FloatLike) -> Float: ...

    def __sub__(self, other: FloatLike) -> Float: ...

    def __pow__(self, other: FloatLike) -> Float: ...

    def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...

    # Comparisons are symbolic and produce Boolean expressions.
    def __eq__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]

    def __ne__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]

    def __lt__(self, other: FloatLike) -> Boolean: ...

    def __le__(self, other: FloatLike) -> Boolean: ...

    def __gt__(self, other: FloatLike) -> Boolean: ...

    def __ge__(self, other: FloatLike) -> Boolean: ...
379
+
380
+
381
# Implicit conversions into Float: from Python floats and from symbolic Ints.
converter(float, Float, lambda value: Float(value))
converter(Int, Float, lambda value: Float.from_int(value))

FloatLike: TypeAlias = Float | float | IntLike
386
+
387
+
388
@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
    """
    Rewrite rules for ``Float``.

    Covers recording the primitive ``f64`` value, conversions between the three constructors, constant folding
    of arithmetic on literals, and deciding comparisons between literals. Each ordering comparison has both a
    TRUE and a FALSE rule, so that a comparison of two f64 literals always resolves.
    """
    return [
        rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
        rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
        rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
        rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
        # Convert from float to rational if it is a whole number, i.e. it round-trips through i64 unchanged
        rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
        # Always convert from int to rational
        rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
        rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
        rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
        rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
        rewrite(Float.rational(r) / Float.rational(r1)).to(Float.rational(r / r1)),
        rewrite(Float.rational(r) + Float.rational(r1)).to(Float.rational(r + r1)),
        rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
        rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
        rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
        # comparisons
        rewrite(Float(f) == Float(f)).to(TRUE),
        rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
        rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
        rewrite(Float(f) != Float(f)).to(FALSE),
        rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
        rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
        rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
        rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
        rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
        rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
        rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
        # Previously missing: without this, `<` between literals could only ever resolve to TRUE.
        rewrite(Float(f) < Float(f2)).to(FALSE, f >= f2),
        rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
        rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
        # round
        rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
    ]
424
+
425
+
426
class TupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[int, ...]

    All constructors should be rewritten to the functional semantics in the __init__ method, i.e.
    ``TupleInt(length, idx_fn)`` where ``idx_fn(i)`` gives the element at index ``i``.
    """

    @classmethod
    def var(cls, name: StringLike) -> TupleInt: ...

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

    EMPTY: ClassVar[TupleInt]
    NEVER: ClassVar[TupleInt]

    # Snoc constructor: adds ``i`` at the end.
    def append(self, i: IntLike) -> TupleInt: ...

    @classmethod
    def single(cls, i: Int) -> TupleInt:
        """One-element tuple containing ``i``."""
        return TupleInt(Int(1), lambda _: i)

    @method(subsume=True)
    @classmethod
    def range(cls, stop: IntLike) -> TupleInt:
        """Tuple ``(0, 1, ..., stop - 1)``."""
        return TupleInt(stop, lambda i: i)

    @classmethod
    def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...

    def __add__(self, other: TupleIntLike) -> TupleInt:
        """Concatenation: indices past ``self.length()`` are read from ``other``, shifted down."""
        other = cast("TupleInt", other)
        return TupleInt(
            self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> Int: ...

    # Python protocols below evaluate in the current e-graph, so the length must be determinable.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[Int]:
        return iter(self.eval())

    # Set by a rule for tuples known to equal ``TupleInt.from_vec(...)``.
    @property
    def to_vec(self) -> Vec[Int]: ...

    @method(preserve=True)
    def eval(self) -> tuple[Int, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    # Left folds, one per accumulator type; they are rewritten on the EMPTY/append (snoc) representation.
    def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...

    def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...

    def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: ...

    @method(subsume=True)
    def contains(self, i: Int) -> Boolean:
        """Boolean that is TRUE if any element equals ``i``."""
        return self.foldl_boolean(lambda acc, j: acc | (i == j), FALSE)

    @method(subsume=True)
    def filter(self, f: Callable[[Int], Boolean]) -> TupleInt:
        """Elements for which ``f`` is TRUE, in their original order."""
        return self.foldl_tuple_int(
            lambda acc, v: TupleInt.if_(f(v), acc.append(v), acc),
            TupleInt.EMPTY,
        )

    @method(subsume=True)
    def map(self, f: Callable[[Int], Int]) -> TupleInt:
        """Apply ``f`` to every element."""
        return TupleInt(self.length(), lambda i: f(self[i]))

    @classmethod
    def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ...

    def drop(self, n: Int) -> TupleInt:
        """Tuple without its first ``n`` elements."""
        return TupleInt(self.length() - n, lambda i: self[i + n])

    def product(self) -> Int:
        """Product of all elements (``Int(1)`` for the empty tuple)."""
        return self.foldl(lambda acc, i: acc * i, Int(1))

    def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt:
        """Map every element to a TupleInt, producing a TupleTupleInt."""
        return TupleTupleInt(self.length(), lambda i: f(self[i]))

    def select(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements at the given indices
        """
        indices = cast("TupleInt", indices)
        return indices.map(lambda i: self[i])

    def deselect(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements not at the given indices
        """
        indices = cast("TupleInt", indices)
        return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
523
+
524
+
525
# Any Vec[Int] may stand in for a TupleInt.
converter(Vec[Int], TupleInt, lambda vec: TupleInt.from_vec(vec))

TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]
528
+
529
+
530
@array_api_ruleset.register
def _tuple_int(
    i: Int,
    i2: Int,
    f: Callable[[Int, Int], Int],
    bool_f: Callable[[Boolean, Int], Boolean],
    idx_fn: Callable[[Int], Int],
    tuple_int_f: Callable[[TupleInt, Int], TupleInt],
    vs: Vec[Int],
    b: Boolean,
    ti: TupleInt,
    ti2: TupleInt,
    k: i64,
):
    """
    Rewrite rules for ``TupleInt``.

    Covers element/length access for both the functional and the EMPTY/append (snoc) forms, unrolling
    known-length functional tuples into snoc form and then into vectors, and folds over the snoc form.
    """
    return [
        # Record ``to_vec`` for tuples known to be built from a vector.
        rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs)),
        # Functional access
        rewrite(TupleInt(i, idx_fn).length()).to(i),
        # check_index(length, idx) passes Int.NEVER to idx_fn for out-of-range indices.
        rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2))),
        # cons access
        rewrite(TupleInt.EMPTY.length()).to(Int(0)),
        rewrite(TupleInt.EMPTY[i]).to(Int.NEVER),
        rewrite(ti.append(i).length()).to(ti.length() + 1),
        rewrite(ti.append(i)[i2]).to(Int.if_(i2 == ti.length(), i, ti[i2])),
        # cons to functional (removed so that there are no infinite replacements between them)
        # rewrite(TupleInt.EMPTY).to(TupleInt(0, lambda _: Int.NEVER)),
        # rewrite(TupleInt(i, idx_fn).append(i2)).to(TupleInt(i + 1, lambda j: Int.if_(j == i, i2, idx_fn(j)))),
        # functional to cons: a literal length k > 0 is peeled one element at a time from the end
        rewrite(TupleInt(0, idx_fn), subsume=True).to(TupleInt.EMPTY),
        rewrite(TupleInt(Int(k), idx_fn), subsume=True).to(TupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0),
        # cons to vec
        rewrite(TupleInt.EMPTY).to(TupleInt.from_vec(Vec[Int]())),
        rewrite(TupleInt.from_vec(vs).append(i)).to(TupleInt.from_vec(vs.append(Vec(i)))),
        # fold
        rewrite(TupleInt.EMPTY.foldl(f, i), subsume=True).to(i),
        rewrite(ti.append(i2).foldl(f, i), subsume=True).to(f(ti.foldl(f, i), i2)),
        # fold boolean
        rewrite(TupleInt.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b),
        rewrite(ti.append(i2).foldl_boolean(bool_f, b), subsume=True).to(bool_f(ti.foldl_boolean(bool_f, b), i2)),
        # fold tuple_int
        rewrite(TupleInt.EMPTY.foldl_tuple_int(tuple_int_f, ti), subsume=True).to(ti),
        rewrite(ti.append(i2).foldl_tuple_int(tuple_int_f, ti2), subsume=True).to(
            tuple_int_f(ti.foldl_tuple_int(tuple_int_f, ti2), i2)
        ),
        # if_
        rewrite(TupleInt.if_(TRUE, ti, ti2), subsume=True).to(ti),
        rewrite(TupleInt.if_(FALSE, ti, ti2), subsume=True).to(ti2),
        # unify append: equal snoc lists have equal prefixes and equal last elements
        rule(eq(ti.append(i)).to(ti2.append(i2))).then(union(ti).with_(ti2), union(i).with_(i2)),
    ]
580
+
581
+
582
class TupleTupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[tuple[int, ...], ...].

    Mirrors ``TupleInt``: a functional constructor ``TupleTupleInt(length, idx_fn)``, an EMPTY/append (snoc)
    form, and a vector form.
    """

    @classmethod
    def var(cls, name: StringLike) -> TupleTupleInt: ...

    EMPTY: ClassVar[TupleTupleInt]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

    @method(subsume=True)
    @classmethod
    def single(cls, i: TupleIntLike) -> TupleTupleInt:
        """One-element tuple containing ``i``."""
        i = cast("TupleInt", i)
        return TupleTupleInt(1, lambda _: i)

    @method(subsume=True)
    @classmethod
    def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...

    # Snoc constructor: adds ``i`` at the end.
    def append(self, i: TupleIntLike) -> TupleTupleInt: ...

    def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
        """Concatenation: indices past ``self.length()`` are read from ``other``, shifted down."""
        other = cast("TupleTupleInt", other)
        return TupleTupleInt(
            self.length() + other.length(),
            lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> TupleInt: ...

    # Python protocols below evaluate in the current e-graph, so the length must be determinable.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.eval())

    # Set by a rule for tuples known to equal ``TupleTupleInt.from_vec(...)``.
    @property
    def to_vec(self) -> Vec[TupleInt]: ...

    @method(preserve=True)
    def eval(self) -> tuple[TupleInt, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    def drop(self, n: Int) -> TupleTupleInt:
        """Tuple without its first ``n`` elements."""
        return TupleTupleInt(self.length() - n, lambda i: self[i + n])

    def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt:
        """Map every inner tuple to an Int."""
        return TupleInt(self.length(), lambda i: f(self[i]))

    def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ...

    @method(subsume=True)
    def product(self) -> TupleTupleInt:
        """
        Cartesian product of inputs

        https://docs.python.org/3/library/itertools.html#itertools.product

        https://github.com/saulshanabrook/saulshanabrook/discussions/39
        """
        # Result i, position j: divide i by the product of the lengths of all later inputs (mixed-radix
        # decomposition), then wrap by the length of input j.
        return TupleTupleInt(
            self.map_int(lambda x: x.length()).product(),
            lambda i: TupleInt(
                self.length(),
                lambda j: self[j][(i // self.drop(j + 1).map_int(lambda x: x.length()).product()) % self[j].length()],
            ),
        )
651
+
652
+
653
# Any Vec[TupleInt] may stand in for a TupleTupleInt.
converter(Vec[TupleInt], TupleTupleInt, lambda vec: TupleTupleInt.from_vec(vec))

TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
656
+
657
+
658
@array_api_ruleset.register
def _tuple_tuple_int(
    length: Int,
    fn: Callable[[TupleInt], Int],
    idx_fn: Callable[[Int], TupleInt],
    f: Callable[[Value, TupleInt], Value],
    i: Value,
    k: i64,
    idx: Int,
    vs: Vec[TupleInt],
    ti: TupleInt,
    ti1: TupleInt,
    tti: TupleTupleInt,
    tti1: TupleTupleInt,
):
    """
    Rewrite rules for ``TupleTupleInt``, mirroring those for ``TupleInt``.

    Covers element/length access for the functional and EMPTY/append (snoc) forms, unrolling known-length
    functional tuples into snoc form and then into vectors, and ``foldl_value`` over the snoc form.
    """
    yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs))
    yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
    # ``check_index`` takes ``(length, idx)``. With the arguments swapped, the length would be checked against
    # the index, and every in-range access would become ``idx_fn(Int.NEVER)``.
    yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(length, idx)))

    # cons access
    yield rewrite(TupleTupleInt.EMPTY.length()).to(Int(0))
    yield rewrite(TupleTupleInt.EMPTY[idx]).to(TupleInt.NEVER)
    yield rewrite(tti.append(ti).length()).to(tti.length() + 1)
    yield rewrite(tti.append(ti)[idx]).to(TupleInt.if_(idx == tti.length(), ti, tti[idx]))

    # functional to cons: a literal length k > 0 is peeled one element at a time from the end
    yield rewrite(TupleTupleInt(0, idx_fn), subsume=True).to(TupleTupleInt.EMPTY)
    yield rewrite(TupleTupleInt(Int(k), idx_fn), subsume=True).to(
        TupleTupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )
    # cons to vec
    yield rewrite(TupleTupleInt.EMPTY).to(TupleTupleInt.from_vec(Vec[TupleInt]()))
    yield rewrite(TupleTupleInt.from_vec(vs).append(ti)).to(TupleTupleInt.from_vec(vs.append(Vec(ti))))
    # fold value
    yield rewrite(TupleTupleInt.EMPTY.foldl_value(f, i), subsume=True).to(i)
    yield rewrite(tti.append(ti).foldl_value(f, i), subsume=True).to(f(tti.foldl_value(f, i), ti))

    # unify append: equal snoc lists have equal prefixes and equal last elements
    yield rule(eq(tti.append(ti)).to(tti1.append(ti1))).then(union(tti).with_(tti1), union(ti).with_(ti1))
697
+
698
+
699
class OptionalInt(Expr, ruleset=array_api_ruleset):
    """An ``Int`` that may be absent: either ``OptionalInt.none`` or ``OptionalInt.some(value)``."""

    none: ClassVar[OptionalInt]

    @classmethod
    def some(cls, value: Int) -> OptionalInt: ...
704
+
705
+
706
OptionalIntLike: TypeAlias = OptionalInt | IntLike | None

# ``None`` maps to the empty option; an Int is wrapped with ``some``.
converter(type(None), OptionalInt, lambda _none: OptionalInt.none)
converter(Int, OptionalInt, OptionalInt.some)
710
+
711
+
712
class DType(Expr, ruleset=array_api_ruleset):
    """Array data type; each supported dtype is a constant on this class."""

    float64: ClassVar[DType]
    float32: ClassVar[DType]
    int64: ClassVar[DType]
    int32: ClassVar[DType]
    object: ClassVar[DType]
    bool: ClassVar[DType]

    # Symbolic equality, decided by rewrite rules over pairs of concrete dtypes.
    def __eq__(self, other: DType) -> Boolean:  # type: ignore[override]
        ...
722
+
723
+
724
# Module-level aliases mirroring the array API dtype names.
float64, float32, int32, int64 = DType.float64, DType.float32, DType.int32, DType.int64

# Dtypes enumerated by the dtype equality rules (``DType.bool`` is not listed here).
_DTYPES = [float64, float32, int32, int64, DType.object]

# Type objects are first normalized through ``np.dtype`` and then converted.
converter(type, DType, lambda typ: convert(np.dtype(typ), DType))
# NumPy dtypes map onto the DType class attribute with the same ``name``.
converter(type(np.dtype), DType, lambda dt: getattr(DType, dt.name))  # type:ignore[call-overload]
733
+
734
+
735
@array_api_ruleset.register
def _():
    """
    Decide ``DType.__eq__`` for every ordered pair of concrete dtypes.

    ``DType.bool`` is a concrete dtype but is not part of ``_DTYPES``. It is added here so that equality
    involving it also simplifies to TRUE/FALSE.
    """
    dtypes = [*_DTYPES, DType.bool]
    for lhs, rhs in itertools.product(dtypes, repeat=2):
        # Each list entry is a distinct dtype object, so identity stands in for "same dtype".
        yield rewrite(lhs == rhs).to(TRUE if lhs is rhs else FALSE)
739
+
740
+
741
class IsDtypeKind(Expr, ruleset=array_api_ruleset):
    """
    The ``kind`` argument of ``isdtype``.

    It is a kind-name string, a specific dtype, a union of kinds built with ``|``, or ``NULL``. ``NULL`` is the
    empty union: ``isdtype`` of it is FALSE, and ``k | NULL`` rewrites to ``k``.
    """

    NULL: ClassVar[IsDtypeKind]

    @classmethod
    def string(cls, s: StringLike) -> IsDtypeKind: ...

    @classmethod
    def dtype(cls, d: DType) -> IsDtypeKind: ...

    # Higher cost than the other constructors; presumably so extraction prefers a simpler equivalent kind.
    @method(cost=10)
    def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
752
+
753
+
754
# TODO: Make kind more generic to support tuples.
# NOTE(review): tuples are already accepted through the tuple -> IsDtypeKind converter, so this TODO may be stale.
# Declared without a body; its value comes from the ``isdtype`` rewrite rules.
@function
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
757
+
758
+
759
converter(DType, IsDtypeKind, lambda dt: IsDtypeKind.dtype(dt))
converter(str, IsDtypeKind, lambda name: IsDtypeKind.string(name))
# A tuple of kinds becomes a right-nested ``|`` chain terminated by ``IsDtypeKind.NULL``.
converter(
    tuple,
    IsDtypeKind,
    lambda kinds: (convert(kinds[0], IsDtypeKind) | convert(kinds[1:], IsDtypeKind)) if kinds else IsDtypeKind.NULL,
)
764
+
765
+
766
@array_api_ruleset.register
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
    """
    Rules deciding ``isdtype``.

    Covers a truth table for the concrete dtypes and the supported kind strings, plus structural rules for
    ``NULL``, single-dtype kinds and ``|`` unions of kinds.
    """
    concrete = (DType.float32, DType.float64, DType.object, DType.int64, DType.int32)
    # For each kind string, the answer for each dtype in ``concrete`` (same order).
    answers = {
        "integral": (FALSE, FALSE, FALSE, TRUE, TRUE),
        "real floating": (TRUE, TRUE, FALSE, FALSE, FALSE),
        "complex floating": (FALSE, FALSE, FALSE, FALSE, FALSE),
    }
    table_rules = [
        rewrite(isdtype(dtype, IsDtypeKind.string(kind))).to(result)
        for kind, results in answers.items()
        for dtype, result in zip(concrete, results, strict=True)
    ]
    structural_rules = [
        rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
        rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
        rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
        rewrite(k1 | IsDtypeKind.NULL).to(k1),
    ]
    return table_rules + structural_rules
789
+
790
+
791
+ # TODO: Add pushdown for math on scalars to values
792
+ # and add replacements
793
+
794
+
795
class Value(Expr, ruleset=array_api_ruleset):
    """
    A scalar array element: an Int, Float or Boolean wrapped by ``Value.int``, ``Value.float`` or ``Value.bool``.
    """

    NEVER: ClassVar[Value]

    @classmethod
    def int(cls, i: IntLike) -> Value: ...

    @classmethod
    def float(cls, f: FloatLike) -> Value: ...

    @classmethod
    def bool(cls, b: BooleanLike) -> Value: ...

    def isfinite(self) -> Boolean: ...

    # NOTE(review): unlike ``__eq__``, this returns a Value rather than a Boolean.
    def __lt__(self, other: ValueLike) -> Value: ...

    def __truediv__(self, other: ValueLike) -> Value: ...

    def __mul__(self, other: ValueLike) -> Value: ...

    def __add__(self, other: ValueLike) -> Value: ...

    def astype(self, dtype: DType) -> Value: ...

    # TODO: Add all operations

    @property
    def dtype(self) -> DType:
        """
        Default dtype for this scalar value.
        """

    @property
    def to_bool(self) -> Boolean: ...

    @property
    def to_int(self) -> Int: ...

    @property
    def to_truthy_value(self) -> Value:
        """
        Converts the value to a bool, based on whether it is truthy.

        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
        """

    def conj(self) -> Value: ...

    def real(self) -> Value: ...

    def sqrt(self) -> Value: ...

    # Symbolic conditional: ``i`` when ``b`` is TRUE, ``j`` when it is FALSE.
    @classmethod
    def if_(cls, b: BooleanLike, i: ValueLike, j: ValueLike) -> Value: ...

    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
849
+
850
+
851
# Anything accepted where a ``Value`` is expected.
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike


# Implicit conversions from egglog primitives into scalar ``Value`` literals.
converter(Int, Value, Value.int)
converter(Float, Value, Value.float)
converter(Boolean, Value, Value.bool)
# Conversion back to ``Int`` goes through the ``to_int`` projection. The trailing 10 is presumably a
# conversion cost that makes this direction less preferred during implicit conversion — confirm.
converter(Value, Int, lambda x: x.to_int, 10)
858
+
859
+
860
@array_api_ruleset.register
def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float, b1: Boolean):
    """
    Rewrite rules for scalar ``Value`` expressions.

    Covers default dtypes of literals, projections back to primitives, identities for
    ``conj``/``real``/``sqrt``/``+``, selection by ``if_`` on a known condition, and folding of
    equality between two literals of the same kind.
    """
    int_value = Value.int(i)
    float_value = Value.float(f)
    bool_value = Value.bool(b)

    # Default dtypes
    # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
    yield from (
        rewrite(int_value.dtype).to(DType.int64),
        rewrite(float_value.dtype).to(DType.float64),
        rewrite(bool_value.dtype).to(DType.bool),
    )

    # Projections back to the wrapped primitive
    yield from (
        rewrite(bool_value.to_bool).to(b),
        rewrite(int_value.to_int).to(i),
    )

    # A boolean is already its own truth value
    # TODO: Add more rules for to_bool_value
    yield rewrite(bool_value.to_truthy_value).to(bool_value)

    # Real-valued scalars are their own conjugate and their own real part
    yield from (
        rewrite(float_value.conj()).to(float_value),
        rewrite(float_value.real()).to(float_value),
        rewrite(int_value.real()).to(int_value),
        rewrite(int_value.conj()).to(int_value),
    )

    # sqrt of a float scalar becomes the float raised to 0.5
    yield rewrite(float_value.sqrt()).to(Value.float(f**0.5))

    # Adding a float zero on the left is dropped.
    # NOTE(review): when ``v`` is an integer Value this also drops the float promotion of ``0.0 + v``.
    yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)

    # if_ with a known condition picks the corresponding branch
    yield from (
        rewrite(Value.if_(TRUE, v, v1)).to(v),
        rewrite(Value.if_(FALSE, v, v1)).to(v1),
    )

    # ==: equality of two literals of the same kind is equality of the primitives
    yield from (
        rewrite(int_value == Value.int(i1)).to(i == i1),
        rewrite(float_value == Value.float(f1)).to(f == f1),
        rewrite(bool_value == Value.bool(b1)).to(b == b1),
    )
890
+
891
+
892
class TupleValue(Expr, ruleset=array_api_ruleset):
    """
    A tuple of scalar ``Value``s.

    It can be written functionally (``TupleValue(length, idx_fn)``), as a cons list
    (``TupleValue.EMPTY.append(...)``) or from a ``Vec``; the ``_tuple_value`` rules rewrite
    between these forms.
    """

    EMPTY: ClassVar[TupleValue]

    # Functional form: element ``i`` is ``idx_fn(i)`` for ``0 <= i < length``.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Value]) -> None: ...

    def append(self, i: ValueLike) -> TupleValue: ...

    @classmethod
    def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...

    def __add__(self, other: TupleValueLike) -> TupleValue:
        """
        Concatenate two tuples: index ``i`` reads ``self[i]`` while ``i < self.length()`` and
        ``other[i - self.length()]`` afterwards.
        """
        other = cast("TupleValue", other)
        return TupleValue(
            self.length() + other.length(),
            lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: Int) -> Value: ...

    # Left fold of ``f`` over the elements, starting from ``init``.
    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...

    def contains(self, value: ValueLike) -> Boolean:
        """
        Whether any element equals ``value``: an ``|``-fold of ``value == element`` starting from ``FALSE``.
        """
        value = cast("Value", value)
        return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)

    @method(subsume=True)
    @classmethod
    def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
        """
        Convert a ``TupleInt`` by wrapping each element in ``Value.int``.

        NOTE(review): declared with ``subsume=True``; presumably the call is replaced by this
        definition rather than kept alongside it — confirm against egglog's ``method`` docs.
        """
        ti = cast("TupleInt", ti)
        return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
924
+
925
+
926
# Implicit conversions into ``TupleValue`` from a ``Vec[Value]`` and from a ``TupleInt`` (each element via ``Value.int``).
converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))

# Anything accepted where a ``TupleValue`` is expected.
TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
930
+
931
+
932
@array_api_ruleset.register
def _tuple_value(
    length: Int,
    idx_fn: Callable[[Int], Value],
    k: i64,
    idx: Int,
    vs: Vec[Value],
    v: Value,
    v1: Value,
    tv: TupleValue,
    tv1: TupleValue,
    bool_f: Callable[[Boolean, Value], Boolean],
    b: Boolean,
):
    """
    Rewrite rules for ``TupleValue``.

    They give length/indexing for the functional and cons forms and unroll functional tuples of
    literal length into cons lists. Cons lists are turned into ``Vec``-backed tuples, boolean
    folds are evaluated, and equal appends are unified element-wise.
    """
    functional = TupleValue(length, idx_fn)
    empty = TupleValue.EMPTY
    snoc = tv.append(v)
    last = k - 1

    # Functional form: length and (bounds-checked) element access
    yield from (
        rewrite(functional.length()).to(length),
        rewrite(functional[idx]).to(idx_fn(check_index(idx, length))),
    )

    # cons access
    yield from (
        rewrite(empty.length()).to(Int(0)),
        rewrite(empty[idx]).to(Value.NEVER),
        rewrite(snoc.length()).to(tv.length() + 1),
        rewrite(snoc[idx]).to(Value.if_(idx == tv.length(), v, tv[idx])),
    )

    # functional to cons: a literal length is unrolled one element at a time
    yield from (
        rewrite(TupleValue(0, idx_fn), subsume=True).to(empty),
        rewrite(TupleValue(Int(k), idx_fn), subsume=True).to(TupleValue(last, idx_fn).append(idx_fn(Int(last))), k > 0),
    )

    # cons to vec
    yield from (
        rewrite(empty).to(TupleValue.from_vec(Vec[Value]())),
        rewrite(TupleValue.from_vec(vs).append(v)).to(TupleValue.from_vec(vs.append(Vec(v)))),
    )

    # fold boolean
    yield from (
        rewrite(empty.foldl_boolean(bool_f, b), subsume=True).to(b),
        rewrite(snoc.foldl_boolean(bool_f, b), subsume=True).to(bool_f(tv.foldl_boolean(bool_f, b), v)),
    )

    # unify append: equal appends have equal prefixes and equal last elements
    yield rule(eq(snoc).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
971
+
972
+
973
@function
def possible_values(values: Value) -> TupleValue:
    """
    All possible concrete values that the scalar ``values`` may take, as a ``TupleValue``.

    For a boolean literal this is the one-element tuple containing it; the ``any`` rule checks
    whether ``Value.bool(TRUE)`` is among the possible truthy values of an array.
    """
978
+
979
+
980
class Slice(Expr, ruleset=array_api_ruleset):
    """
    A ``start:stop:step`` slice whose three parts are each optional integers (default ``OptionalInt.none``).
    """

    def __init__(
        self,
        start: OptionalInt = OptionalInt.none,
        stop: OptionalInt = OptionalInt.none,
        step: OptionalInt = OptionalInt.none,
    ) -> None: ...


# A Python ``slice`` converts field by field, each of ``start``/``stop``/``step`` going through ``OptionalInt``.
converter(
    slice,
    Slice,
    lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
)

# Anything accepted where a ``Slice`` is expected.
SliceLike: TypeAlias = Slice | slice
996
+
997
+
998
class MultiAxisIndexKeyItem(Expr, ruleset=array_api_ruleset):
    """
    One entry of a multi-axis (tuple) index: an integer, a slice, ``...`` or ``None``.

    In the array API a ``None`` entry inserts a new axis (presumably what ``NONE`` models — confirm).
    """

    ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
    NONE: ClassVar[MultiAxisIndexKeyItem]

    @classmethod
    def int(cls, i: Int) -> MultiAxisIndexKeyItem: ...

    @classmethod
    def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ...


# Python ``...`` and ``None`` map to the corresponding constants; ``Int``/``Slice`` are wrapped.
converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS)
converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NONE)
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)

# Anything accepted where a ``MultiAxisIndexKeyItem`` is expected.
MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
1015
+
1016
+
1017
class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
    """
    A tuple index key: a sequence of ``MultiAxisIndexKeyItem``s, in functional or ``Vec`` form.
    """

    # Functional form: item ``i`` is ``idx_fn(i)`` for ``0 <= i < length``.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...

    def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...

    @classmethod
    def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...


# Anything accepted where a ``MultiAxisIndexKey`` is expected.
MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"


# A Python tuple converts item by item into a Vec-backed key.
converter(
    tuple,
    MultiAxisIndexKey,
    lambda x: MultiAxisIndexKey.from_vec(Vec(*(convert(i, MultiAxisIndexKeyItem) for i in x))),
)
# A ``TupleInt`` becomes a functional key whose items are the integers wrapped in ``MultiAxisIndexKeyItem.int``.
converter(
    TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
)
1037
+
1038
+
1039
class IndexKey(Expr, ruleset=array_api_ruleset):
    """
    A key for indexing into an array

    https://data-apis.org/array-api/2022.12/API_specification/indexing.html

    It is equivalent to the following type signature:

    Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
    """

    ELLIPSIS: ClassVar[IndexKey]

    @classmethod
    def int(cls, i: Int) -> IndexKey: ...

    @classmethod
    def slice(cls, slice: Slice) -> IndexKey: ...

    # Disabled until we support late binding
    # @classmethod
    # def boolean_array(cls, b: NDArray) -> IndexKey:
    #     ...

    @classmethod
    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...

    @classmethod
    def ndarray(cls, key: NDArray) -> IndexKey:
        """
        Indexes by an array key (e.g. a boolean mask array).
        """


# Anything accepted where an ``IndexKey`` is expected. A string alias because ``NDArrayLike`` is defined later.
IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"


# Implicit conversions into ``IndexKey`` (the ``NDArray`` one is registered after ``NDArray`` is defined).
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
converter(Int, IndexKey, lambda i: IndexKey.int(i))
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
1080
+
1081
+
1082
# Opaque device handle used by ``device`` arguments/properties; no constructors are declared for it here.
class Device(Expr, ruleset=array_api_ruleset): ...
1083
+
1084
+
1085
# Named symbolic index constant. Its only use nearby is the ``any`` rule (``x.index(ALL_INDICES)``),
# where it appears to stand for "an arbitrary index of x" — confirm intended semantics.
ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
1086
+
1087
+
1088
class NDArray(Expr, ruleset=array_api_ruleset):
    """
    An array-API style n-dimensional array expression.

    The functional constructor ``NDArray(shape, dtype, idx_fn)`` describes an array by its shape,
    its dtype and a function from an index tuple to the scalar ``Value`` stored there. Methods
    marked ``preserve=True`` run as ordinary Python, e.g. to evaluate the expression through the
    e-graph.
    """

    def __init__(self, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...

    NEVER: ClassVar[NDArray]

    @method(cost=200)
    @classmethod
    def var(cls, name: StringLike) -> NDArray: ...

    @method(preserve=True)
    def __array_namespace__(self, api_version: object = None) -> ModuleType:
        # This module itself acts as the array API namespace.
        return sys.modules[__name__]

    @property
    def ndim(self) -> Int: ...

    @property
    def dtype(self) -> DType: ...

    @property
    def device(self) -> Device: ...

    @property
    def shape(self) -> TupleInt: ...

    @method(preserve=True)
    def __bool__(self) -> bool:
        # Only meaningful for scalar arrays: goes through ``to_value`` and evaluates its boolean.
        return self.to_value().to_bool.eval()

    @property
    def size(self) -> Int: ...

    @method(preserve=True)
    def __len__(self) -> int:
        """
        Length of the first axis, as for NumPy arrays.

        Uses ``shape[0]`` rather than ``size``: for arrays with more than one dimension the total
        element count would make ``__iter__`` index past the end of the first axis.
        """
        return self.shape[Int(0)].eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        # Iterate over the first axis, yielding ``self[i]`` for each position.
        for i in range(len(self)):
            yield self[IndexKey.int(Int(i))]

    def __getitem__(self, key: IndexKeyLike) -> NDArray: ...

    def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...

    # Comparison operators produce arrays (of booleans), not Python bools.
    def __lt__(self, other: NDArrayLike) -> NDArray: ...

    def __le__(self, other: NDArrayLike) -> NDArray: ...

    def __eq__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
        ...

    # TODO: Add support for overloaded __ne__
    # def __ne__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
    #     ...

    def __gt__(self, other: NDArrayLike) -> NDArray: ...

    def __ge__(self, other: NDArrayLike) -> NDArray: ...

    # Arithmetic and bitwise operators.
    def __add__(self, other: NDArrayLike) -> NDArray: ...

    def __sub__(self, other: NDArrayLike) -> NDArray: ...

    def __mul__(self, other: NDArrayLike) -> NDArray: ...

    def __matmul__(self, other: NDArrayLike) -> NDArray: ...

    def __truediv__(self, other: NDArrayLike) -> NDArray: ...

    def __floordiv__(self, other: NDArrayLike) -> NDArray: ...

    def __mod__(self, other: NDArrayLike) -> NDArray: ...

    def __divmod__(self, other: NDArrayLike) -> NDArray: ...

    def __pow__(self, other: NDArrayLike) -> NDArray: ...

    def __lshift__(self, other: NDArrayLike) -> NDArray: ...

    def __rshift__(self, other: NDArrayLike) -> NDArray: ...

    def __and__(self, other: NDArrayLike) -> NDArray: ...

    def __xor__(self, other: NDArrayLike) -> NDArray: ...

    def __or__(self, other: NDArrayLike) -> NDArray: ...

    # Reflected operators (``other <op> self``).
    def __radd__(self, other: NDArray) -> NDArray: ...

    def __rsub__(self, other: NDArray) -> NDArray: ...

    def __rmul__(self, other: NDArray) -> NDArray: ...

    def __rmatmul__(self, other: NDArray) -> NDArray: ...

    def __rtruediv__(self, other: NDArray) -> NDArray: ...

    def __rfloordiv__(self, other: NDArray) -> NDArray: ...

    def __rmod__(self, other: NDArray) -> NDArray: ...

    def __rpow__(self, other: NDArray) -> NDArray: ...

    def __rlshift__(self, other: NDArray) -> NDArray: ...

    def __rrshift__(self, other: NDArray) -> NDArray: ...

    def __rand__(self, other: NDArray) -> NDArray: ...

    def __rxor__(self, other: NDArray) -> NDArray: ...

    def __ror__(self, other: NDArray) -> NDArray: ...

    @classmethod
    def scalar(cls, value: Value) -> NDArray:
        # A 0-D array: empty shape, the value's dtype, and every index maps to ``value``.
        return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)

    def to_value(self) -> Value:
        """
        Returns the value if this is a scalar.
        """

    def to_values(self) -> TupleValue:
        """
        Returns the values if this is a vector.
        """

    @property
    def T(self) -> NDArray:
        """
        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
        """

    @classmethod
    def vector(cls, values: TupleValueLike) -> NDArray: ...

    def index(self, indices: TupleIntLike) -> Value:
        """
        Return the value at the given indices.
        """

    # Selects ``i`` when ``b`` is true and ``j`` otherwise.
    @classmethod
    def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
1232
+
1233
+
1234
# Anything accepted where an ``NDArray`` is expected.
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike

# Arrays can be used directly as index keys (e.g. masks).
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
# Scalars become 0-D arrays.
converter(Value, NDArray, lambda v: NDArray.scalar(v))
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
# to prefer upcasting in the other direction when we can, which is safer at runtime
converter(NDArray, Value, lambda n: n.to_value(), 100)
# Tuples of scalars become 1-D arrays.
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
# The TupleInt -> TupleValue converter is registered right after the TupleValue class; registering
# it again here would be a no-op, so it is intentionally not repeated.
1243
+
1244
+
1245
@array_api_ruleset.register
def _ndarray(
    x: NDArray,
    x1: NDArray,
    b: Boolean,
    f: Float,
    fi1: f64,
    fi2: f64,
    shape: TupleInt,
    dtype: DType,
    idx_fn: Callable[[TupleInt], Value],
    idx: TupleInt,
    tv: TupleValue,
):
    """
    Rewrite rules for ``NDArray``.

    Covers accessors of the functional form, ``ndim``, scalar/vector projections, a few folds on
    float scalar arrays, double transpose and ``if_`` on a known condition.
    """
    functional = NDArray(shape, dtype, idx_fn)
    same_float = NDArray.scalar(Value.float(f))
    lhs_literal = NDArray.scalar(Value.float(Float(fi1)))
    rhs_literal = NDArray.scalar(Value.float(Float(fi2)))

    structural = [
        rewrite(functional.shape).to(shape),
        rewrite(functional.dtype).to(dtype),
        rewrite(functional.index(idx), subsume=True).to(idx_fn(idx)),
        rewrite(x.ndim).to(x.shape.length()),
        # rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
        # Converting to a value requires a scalar bool value
        rewrite(x.to_value()).to(x.index(TupleInt.EMPTY)),
        rewrite(NDArray.vector(tv).to_values()).to(tv),
    ]
    # TODO: Push these down to float
    # NOTE(review): x / x -> 1 and x - x -> 0 do not hold under IEEE 754 when the float is 0, +/-inf
    # or NaN (0/0, inf/inf and inf - inf are NaN); confirm this is acceptable for the intended use.
    float_folds = [
        rewrite(same_float / same_float).to(NDArray.scalar(Value.float(Float(1.0)))),
        rewrite(same_float - same_float).to(NDArray.scalar(Value.float(Float(0.0)))),
        # Comparison of two float literals; neither rule fires when a NaN is involved.
        rewrite(lhs_literal > rhs_literal).to(NDArray.scalar(Value.bool(TRUE)), fi1 > fi2),
        rewrite(lhs_literal > rhs_literal).to(NDArray.scalar(Value.bool(FALSE)), fi1 <= fi2),
    ]
    selection = [
        # Transpose of transpose is the original array
        rewrite(x.T.T).to(x),
        # if_
        rewrite(NDArray.if_(TRUE, x, x1)).to(x),
        rewrite(NDArray.if_(FALSE, x, x1)).to(x1),
    ]
    return structural + float_folds + selection
1287
+
1288
+
1289
class TupleNDArray(Expr, ruleset=array_api_ruleset):
    """
    A tuple of ``NDArray``s (e.g. the results of ``unique_counts`` or ``svd``).

    Like ``TupleValue`` it has functional, cons and ``Vec`` forms; ``eval`` extracts a Python tuple.
    """

    EMPTY: ClassVar[TupleNDArray]

    # Functional form: element ``i`` is ``idx_fn(i)`` for ``0 <= i < length``.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> None: ...

    def append(self, i: NDArrayLike) -> TupleNDArray: ...

    @classmethod
    def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...

    def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
        """
        Concatenate two tuples: index ``i`` reads ``self[i]`` while ``i < self.length()`` and
        ``other[i - self.length()]`` afterwards.
        """
        other = cast("TupleNDArray", other)
        return TupleNDArray(
            self.length() + other.length(),
            lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> NDArray: ...

    @method(preserve=True)
    def __len__(self) -> int:
        # Evaluates the symbolic length to a Python int.
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        # Iterates over the evaluated Python tuple.
        return iter(self.eval())

    # The Vec form of this tuple; set by ``_tuple_ndarray`` once the tuple is proven equal to a ``from_vec``.
    @property
    def to_vec(self) -> Vec[NDArray]: ...

    @method(preserve=True)
    def eval(self) -> tuple[NDArray, ...]:
        # Delegates to ``try_evaling`` with the current e-graph, ``array_api_schedule`` and ``to_vec``;
        # see ``try_evaling`` for what happens when no value can be extracted.
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)


# A ``Vec[NDArray]`` converts directly into a Vec-backed tuple.
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))

# Anything accepted where a ``TupleNDArray`` is expected.
TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]
1329
+
1330
+
1331
@array_api_ruleset.register
def _tuple_ndarray(
    length: Int,
    idx_fn: Callable[[Int], NDArray],
    k: i64,
    idx: Int,
    vs: Vec[NDArray],
    v: NDArray,
    v1: NDArray,
    tv: TupleNDArray,
    tv1: TupleNDArray,
):
    """
    Rewrite rules for ``TupleNDArray``, mirroring those for ``TupleValue``.

    They relate the functional form ``TupleNDArray(length, idx_fn)``, the cons form
    (``EMPTY.append(...)``) and the ``Vec`` form, and expose the latter through ``to_vec``.
    (The unused ``b: Boolean`` rule variable of the previous version was dropped.)
    """
    # Once a tuple is known to equal a Vec-backed tuple, record that Vec as its ``to_vec`` value
    yield rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs))
    yield rewrite(TupleNDArray(length, idx_fn).length()).to(length)
    yield rewrite(TupleNDArray(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length)))

    # cons access
    yield rewrite(TupleNDArray.EMPTY.length()).to(Int(0))
    yield rewrite(TupleNDArray.EMPTY[idx]).to(NDArray.NEVER)
    yield rewrite(tv.append(v).length()).to(tv.length() + 1)
    yield rewrite(tv.append(v)[idx]).to(NDArray.if_(idx == tv.length(), v, tv[idx]))

    # functional to cons: a literal length is unrolled one element at a time
    yield rewrite(TupleNDArray(0, idx_fn), subsume=True).to(TupleNDArray.EMPTY)
    yield rewrite(TupleNDArray(Int(k), idx_fn), subsume=True).to(
        TupleNDArray(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(TupleNDArray.EMPTY).to(TupleNDArray.from_vec(Vec[NDArray]()))
    yield rewrite(TupleNDArray.from_vec(vs).append(v)).to(TupleNDArray.from_vec(vs.append(Vec(v))))

    # unify append: equal appends have equal prefixes and equal last elements
    yield rule(eq(tv.append(v)).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
1365
+
1366
+
1367
class OptionalBool(Expr, ruleset=array_api_ruleset):
    """
    An optional boolean argument: ``OptionalBool.none`` or ``OptionalBool.some(value)``.
    """

    none: ClassVar[OptionalBool]

    @classmethod
    def some(cls, value: Boolean) -> OptionalBool: ...


# Python ``None`` -> ``none``; a ``Boolean`` -> ``some(value)``.
converter(type(None), OptionalBool, lambda _: OptionalBool.none)
converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
1376
+
1377
+
1378
class OptionalDType(Expr, ruleset=array_api_ruleset):
    """
    An optional dtype argument: ``OptionalDType.none`` or ``OptionalDType.some(dtype)``.
    """

    none: ClassVar[OptionalDType]

    @classmethod
    def some(cls, value: DType) -> OptionalDType: ...


# Python ``None`` -> ``none``; a ``DType`` -> ``some(value)``.
converter(type(None), OptionalDType, lambda _: OptionalDType.none)
converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
1387
+
1388
+
1389
class OptionalDevice(Expr, ruleset=array_api_ruleset):
    """
    An optional device argument: ``OptionalDevice.none`` or ``OptionalDevice.some(device)``.
    """

    none: ClassVar[OptionalDevice]

    @classmethod
    def some(cls, value: Device) -> OptionalDevice: ...


# Python ``None`` -> ``none``; a ``Device`` -> ``some(value)``.
converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
1398
+
1399
+
1400
class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
    """
    An optional integer tuple argument: ``OptionalTupleInt.none`` or ``OptionalTupleInt.some(tuple)``.
    """

    none: ClassVar[OptionalTupleInt]

    @classmethod
    def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...


# Python ``None`` -> ``none``; a ``TupleInt`` -> ``some(value)``.
converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
1409
+
1410
+
1411
class IntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    Either a single integer or a tuple of integers (e.g. an ``axis`` argument).
    """

    none: ClassVar[IntOrTuple]

    @classmethod
    def int(cls, value: Int) -> IntOrTuple: ...

    @classmethod
    def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...


# ``Int`` and ``TupleInt`` are wrapped in the matching constructor (no converter from ``None`` here).
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
1423
+
1424
+
1425
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    An optional ``IntOrTuple`` argument, e.g. the ``axis`` parameter of reductions.
    """

    none: ClassVar[OptionalIntOrTuple]

    @classmethod
    def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...


# Python ``None`` -> ``none``; an ``IntOrTuple`` -> ``some(value)``.
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
1434
+
1435
+
1436
# Array API ``asarray`` restricted to NDArray input; the optional dtype/copy/device arguments default to none.
@function
def asarray(
    a: NDArray,
    dtype: OptionalDType = OptionalDType.none,
    copy: OptionalBool = OptionalBool.none,
    device: OptionalDevice = OptionalDevice.none,
) -> NDArray: ...
1443
+
1444
+
1445
@array_api_ruleset.register
def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool, dev: OptionalDevice):
    """
    Rewrite rules for ``asarray``.

    ``asarray`` never changes the number of dimensions or the shape of its input, whatever dtype,
    copy or device is requested. The device is therefore a rule variable too, instead of only
    matching the default ``OptionalDevice.none``.
    """
    yield rewrite(asarray(a, d, ob, dev).ndim).to(a.ndim)
    yield rewrite(asarray(a, d, ob, dev).shape).to(a.shape)
    # With every option left at its default, asarray of an NDArray is the array itself.
    yield rewrite(asarray(a)).to(a)
1449
+
1450
+
1451
# Array API ``isfinite`` for NDArrays. Declared without a body, so it only simplifies through
# rewrite rules (none accompany it here).
@function
def isfinite(x: NDArray) -> NDArray: ...
1453
+
1454
+
1455
# NOTE: shadows the builtin ``sum`` inside this module.
@function
def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    Sum of the elements of ``x``, over all axes unless ``axis`` is given.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
    """
1460
+
1461
+
1462
@array_api_ruleset.register
def _sum(x: NDArray, v: Value):
    """
    Rewrite rules for ``sum`` (only the default ``axis``, i.e. a full reduction, is matched).
    """
    return [
        # Dividing every element by the same scalar commutes with a full sum
        # (exact mathematically; with floats the rounding may differ).
        rewrite(sum(x / NDArray.scalar(v))).to(sum(x) / NDArray.scalar(v)),
        # TODO: sum of a 0-D array; needs care with the array API's dtype promotion for sums.
    ]
1468
+
1469
+
1470
# Array API ``reshape``. Declared without a body; its rewrite rules are currently commented out below.
@function
def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray: ...
1472
+
1473
+
1474
+ # @function
1475
+ # def reshape_transform_index(original_shape: TupleInt, shape: TupleInt, index: TupleInt) -> TupleInt:
1476
+ # """
1477
+ # Transforms an indexing operation on a reshaped array to an indexing operation on the original array.
1478
+ # """
1479
+ # ...
1480
+
1481
+
1482
+ # @function
1483
+ # def reshape_transform_shape(original_shape: TupleInt, shape: TupleInt) -> TupleInt:
1484
+ # """
1485
+ # Transforms the shape of an array to one that is reshaped, by replacing -1 with the correct value.
1486
+ # """
1487
+ # ...
1488
+
1489
+
1490
+ # @array_api_ruleset.register
1491
+ # def _reshape(
1492
+ # x: NDArray,
1493
+ # y: NDArray,
1494
+ # shape: TupleInt,
1495
+ # copy: OptionalBool,
1496
+ # i: Int,
1497
+ # s: String,
1498
+ # ix: TupleInt,
1499
+ # ):
1500
+ # return [
1501
+ # # dtype of result is same as input
1502
+ # rewrite(reshape(x, shape, copy).dtype).to(x.dtype),
1503
+ # # Indexing into a reshaped array is the same as indexing into the original array with a transformed index
1504
+ # rewrite(reshape(x, shape, copy).index(ix)).to(x.index(reshape_transform_index(x.shape, shape, ix))),
1505
+ # rewrite(reshape(x, shape, copy).shape).to(reshape_transform_shape(x.shape, shape)),
1506
+ # # reshape_transform_shape recursively
1507
+ # # TODO: handle all cases
1508
+ # rewrite(reshape_transform_shape(TupleInt(i), TupleInt(Int(-1)))).to(TupleInt(i)),
1509
+ # ]
1510
+
1511
+
1512
@function
def unique_values(x: NDArrayLike) -> NDArray:
    """
    Returns the unique elements of an input array x flattened with arbitrary ordering.

    The ``_unique_values`` rules treat it as idempotent.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_values.html
    """
1519
+
1520
+
1521
@array_api_ruleset.register
def _unique_values(x: NDArray):
    """
    ``unique_values`` is idempotent: de-duplicating an already de-duplicated array is a no-op.
    """
    deduplicated = unique_values(x)
    return [rewrite(unique_values(deduplicated)).to(deduplicated)]
1526
+
1527
+
1528
# Array API ``concat`` of a tuple of arrays along an optional axis; see ``_concat`` for the (limited) rules.
@function
def concat(arrays: TupleNDArrayLike, axis: OptionalInt = OptionalInt.none) -> NDArray: ...
1530
+
1531
+
1532
@array_api_ruleset.register
def _concat(x: NDArray):
    """
    Rewrite rules for ``concat``: only the no-op case of a single array is handled for now.
    """
    only_x = TupleNDArray.EMPTY.append(x)
    return [rewrite(concat(only_x)).to(x)]
1538
+
1539
+
1540
# Array API ``astype``: cast ``x`` to ``dtype``; see ``_astype`` for its rules.
@function
def astype(x: NDArray, dtype: DType) -> NDArray: ...
1542
+
1543
+
1544
@array_api_ruleset.register
def _astype(x: NDArray, dtype: DType, i: i64):
    """
    Rewrite rules for ``astype``.

    The result dtype is the requested dtype. An integer literal scalar cast to ``float64``
    constant-folds into the equivalent float literal scalar.
    """
    int_literal = NDArray.scalar(Value.int(Int(i)))
    float_literal = NDArray.scalar(Value.float(Float(f64.from_i64(i))))
    result_dtype = rewrite(astype(x, dtype).dtype).to(dtype)
    fold_int_to_float = rewrite(astype(int_literal, float64)).to(float_literal)
    return [result_dtype, fold_int_to_float]
1552
+
1553
+
1554
@function
def unique_counts(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array x and the corresponding counts for each unique element in x.

    ``_unique_counts`` rewrites it into a two-element tuple whose element 1 holds the counts.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
    """
1562
+
1563
+
1564
@array_api_ruleset.register
def _unique_counts(x: NDArray, dtype: DType):
    """
    Rewrite rules for ``unique_counts``.

    Only the rule variables that are actually used are declared (the previous unused ``c``, ``tv``
    and ``v`` were dropped).
    """
    counts = unique_counts(x)[Int(1)]
    return [
        # unique_counts yields two arrays (values, counts): expose it as a length-2 functional tuple
        rewrite(unique_counts(x)).to(TupleNDArray(2, unique_counts(x).__getitem__)),
        # Sum of all unique counts is the size of the array
        rewrite(sum(counts)).to(NDArray.scalar(Value.int(x.size))),
        # Same but with astype in the middle
        # TODO: replace this special case with a more general rule
        rewrite(sum(astype(counts, dtype))).to(astype(NDArray.scalar(Value.int(x.size)), dtype)),
    ]
1575
+
1576
+
1577
# Array API ``square``. Declared without a body; no rewrite rules accompany it here.
@function
def square(x: NDArray) -> NDArray: ...
1579
+
1580
+
1581
# Array API ``any``; the interval-analysis rules rewrite it into a scalar boolean array.
# NOTE: shadows the builtin ``any`` inside this module.
@function
def any(x: NDArray) -> NDArray: ...
1583
+
1584
+
1585
# Array API ``abs``, registered under the egglog name "ndarray-abs" (presumably to avoid a clash with
# another ``abs`` — confirm). NOTE: shadows the builtin ``abs`` inside this module.
@function(egg_fn="ndarray-abs")
def abs(x: NDArray) -> NDArray: ...
1587
+
1588
+
1589
# Array API ``log``, registered under the egglog name "ndarray-log". Declared without a body.
@function(egg_fn="ndarray-log")
def log(x: NDArray) -> NDArray: ...
1591
+
1592
+
1593
@array_api_ruleset.register
def _abs(f: Float):
    """
    Constant-fold ``abs`` of a float scalar array into the scalar of ``f.abs()``.
    """
    float_scalar = NDArray.scalar(Value.float(f))
    folded = NDArray.scalar(Value.float(f.abs()))
    return [rewrite(abs(float_scalar)).to(folded)]
1598
+
1599
+
1600
@function
def unique_inverse(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.

    ``_unique_inverse`` rewrites it into a two-element tuple whose element 0 is ``unique_values(x)``.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
    """
1607
+
1608
+
1609
@array_api_ruleset.register
def _unique_inverse(x: NDArray):
    """
    Rewrite rules for ``unique_inverse`` (the unused ``i`` rule variable was dropped).
    """
    return [
        # unique_inverse yields two arrays (values, inverse indices): expose it as a length-2 functional tuple
        rewrite(unique_inverse(x)).to(TupleNDArray(2, unique_inverse(x).__getitem__)),
        # The first element of unique_inverse is the array of unique values itself (not just same-shaped)
        rewrite(unique_inverse(x)[Int(0)]).to(unique_values(x)),
    ]
1617
+
1618
+
1619
# Array API ``zeros``: a new array of the given shape with optional dtype/device. Declared without a body.
@function
def zeros(
    shape: TupleIntLike, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none
) -> NDArray: ...
1623
+
1624
+
1625
# Array API ``expand_dims`` (``axis`` defaults to 0). Declared without a body.
@function
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...
1627
+
1628
+
1629
# Array API ``mean`` with optional ``axis`` and ``keepdims`` (default FALSE). Declared without a body.
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...
1631
+
1632
+
1633
# TODO: Possibly change names to include modules.
# Array API ``sqrt``, registered under the egglog name "ndarray-sqrt". Declared without a body.
@function(egg_fn="ndarray-sqrt")
def sqrt(x: NDArray) -> NDArray: ...
1636
+
1637
+
1638
# Array API ``std`` with an optional ``axis``. Declared without a body.
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...
1640
+
1641
+
1642
# Array API ``real`` (real component). Declared without a body.
@function
def real(x: NDArray) -> NDArray: ...
1644
+
1645
+
1646
# Array API ``conj`` (complex conjugate). Declared without a body.
@function
def conj(x: NDArray) -> NDArray: ...
1648
+
1649
+
1650
# The ``linalg`` extension namespace is this module itself, so ``xp.linalg.svd`` resolves to ``svd`` below.
linalg = sys.modules[__name__]
1651
+
1652
+
1653
@function
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
    """
    Singular value decomposition; ``_linalg`` rewrites the result into a three-element tuple.

    https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
    """
1658
+
1659
+
1660
@array_api_ruleset.register
def _linalg(x: NDArray, full_matrices: Boolean):
    """
    ``svd`` yields three arrays: expose its result as a length-3 functional tuple.
    """
    decomposition = svd(x, full_matrices)
    return [rewrite(decomposition).to(TupleNDArray(3, decomposition.__getitem__))]
1666
+
1667
+
1668
@function(ruleset=array_api_ruleset)
def ndindex(shape: TupleIntLike) -> TupleTupleInt:
    """
    All index tuples of an array with the given ``shape``, built as the product of ``TupleInt.range(n)``
    over every dimension ``n``.

    https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
    """
    per_axis_ranges = cast("TupleInt", shape).map_tuple_int(TupleInt.range)
    return per_axis_ranges.product()
1675
+
1676
+
1677
##
# Interval analysis
#
# to analyze `bool(any((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))))`
##

# Relation marking scalar ``Value``s known to be strictly greater than zero.
greater_zero = relation("greater_zero", Value)
1684
+
1685
+
1686
+ # @function
1687
+ # def ndarray_all_greater_0(x: NDArray) -> Unit:
1688
+ # ...
1689
+
1690
+
1691
+ # @function
1692
+ # def ndarray_all_false(x: NDArray) -> Unit:
1693
+ # ...
1694
+
1695
+
1696
+ # @function
1697
+ # def ndarray_all_true(x: NDArray) -> Unit:
1698
+ # ...
1699
+
1700
+
1701
+ # any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).to_bool()
1702
+
1703
+ # sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.int(Int(150))))
1704
+ # And also
1705
+
1706
+ # def
1707
+
1708
+
1709
@function
def broadcast_index(from_shape: TupleIntLike, to_shape: TupleIntLike, index: TupleIntLike) -> TupleInt:
    """
    Returns the index in the original array (of shape ``from_shape``) corresponding to ``index`` in the
    array broadcast to ``to_shape``.

    Declared without a body; used by the interval-analysis rules to push indexing through ``<`` and ``/``.
    """
1714
+
1715
+
1716
@function
def broadcast_shapes(shape1: TupleIntLike, shape2: TupleIntLike) -> TupleInt:
    """
    Returns the shape of the array obtained by broadcasting arrays of shapes ``shape1`` and ``shape2``.

    Declared without a body; used by the interval-analysis rules together with ``broadcast_index``.
    """
1721
+
1722
+
1723
@array_api_ruleset.register
def _interval_analaysis(
    x: NDArray,
    y: NDArray,
    dtype: DType,
    f: f64,
    i: i64,
    b: Boolean,
    idx: TupleInt,
    v: Value,
    v1: Value,
    v2: Value,
):
    """
    Interval-style analysis used to decide comparisons such as ``counts / n < 0``.

    ``greater_zero(v)`` marks a scalar ``Value`` known to be strictly positive. The facts are
    seeded from positive literals and from ``unique_counts`` counts, and propagated through casts
    to float64 and through division. They are then used to fold ``v < 0`` to ``False``.
    Element-wise ``<``, ``/``, scalar arrays and ``astype`` are pushed down to single ``Value``s
    via ``index`` so the analysis can work per element.
    """
    res_shape = broadcast_shapes(x.shape, y.shape)
    x_value = x.index(broadcast_index(x.shape, res_shape, idx))
    y_value = y.index(broadcast_index(y.shape, res_shape, idx))
    return [
        # Calling any on an array gives back a scalar, which is true if any of the values are truthy
        rewrite(any(x)).to(
            NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE))))
        ),
        # Indexing x < y is the same as broadcasting the index and then indexing both and then comparing
        rewrite((x < y).index(idx)).to(x_value < y_value),
        # Same for x / y
        rewrite((x / y).index(idx)).to(x_value / y_value),
        # Indexing a scalar is the same as the scalar
        rewrite(NDArray.scalar(v).index(idx)).to(v),
        # Indexing of astype is same as astype of indexing
        rewrite(astype(x, dtype).index(idx)).to(x.index(idx).astype(dtype)),
        # Indexing into unique counts counts are all positive
        rule(
            eq(v).to(unique_counts(x)[Int(1)].index(idx)),
        ).then(greater_zero(v)),
        # Positivity survives a cast to float64 (from ints, float32 or float64). It is NOT propagated for
        # arbitrary dtypes: casting to an integer dtype truncates (0.5 -> 0) and narrowing float casts can
        # underflow to 0. This replaces two identical, unrestricted copies of this rule.
        rule(
            greater_zero(v),
            eq(v1).to(v.astype(DType.float64)),
        ).then(
            greater_zero(v1),
        ),
        # Min value of scalar is scalar itself
        rule(eq(v).to(Value.float(Float(f))), f > 0.0).then(greater_zero(v)),
        rule(eq(v).to(Value.int(Int(i))), i > 0).then(greater_zero(v)),
        # If we have division of v and v1, and both greater than zero, then the result is greater than zero
        # NOTE(review): extreme float quotients can underflow to 0; acceptable for the intended analysis?
        rule(
            greater_zero(v),
            greater_zero(v1),
            eq(v2).to(v / v1),
        ).then(
            greater_zero(v2),
        ),
        # Define v < 0 to be false, if greater_zero(v)
        rule(
            greater_zero(v),
            eq(v1).to(v < Value.int(Int(0))),
        ).then(
            union(v1).with_(Value.bool(FALSE)),
        ),
        # possible values of bool is bool
        rewrite(possible_values(Value.bool(b))).to(TupleValue.EMPTY.append(Value.bool(b))),
    ]
1796
+
1797
+
1798
+ ##
1799
+ # Mathematical descriptions of arrays as:
1800
+ # 1. A shape `.shape`
1801
+ # 2. A dtype `.dtype`
1802
+ # 3. A mapping from indices to values `x.index(idx)`
1803
+ #
1804
+ # For all operations that are supported mathematically, define each of the above.
1805
+ ##
1806
+
1807
+
1808
def _demand_shape(compound: NDArray, inner: NDArray) -> Command:
    """
    Build a rule that, whenever ``compound`` appears, also adds ``inner.shape`` and its
    length to the e-graph so that other rules can match on them.
    """
    matched = var("__a", NDArray)
    inner_shape = inner.shape
    return rule(eq(matched).to(compound)).then(inner_shape, inner_shape.length())
1811
+
1812
+
1813
@array_api_ruleset.register
def _scalar_math(v: Value, vs: TupleValue, i: Int):
    """Shape, dtype and indexing rules for arrays built with ``NDArray.scalar``."""
    scalar = NDArray.scalar(v)
    return [
        # A scalar array has an empty shape ...
        rewrite(scalar.shape).to(TupleInt.EMPTY),
        # ... the dtype of its value ...
        rewrite(scalar.dtype).to(v.dtype),
        # ... and indexing it with the empty index yields the value itself.
        rewrite(scalar.index(TupleInt.EMPTY)).to(v),
    ]
1818
+
1819
+
1820
@array_api_ruleset.register
def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
    """Shape, dtype and indexing rules for arrays built with ``NDArray.vector``."""
    vec = NDArray.vector(vs)
    return [
        # One dimension whose length is the number of values
        rewrite(vec.shape).to(TupleInt.single(vs.length())),
        # The dtype is taken from the first value
        rewrite(vec.dtype).to(vs[Int(0)].dtype),
        # Only the first component of the index selects the value
        rewrite(vec.index(ti)).to(vs[ti[0]]),
    ]
1825
+
1826
+
1827
@array_api_ruleset.register
def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
    """Rules relating ``reshape(x, shape, copy)`` back to ``x``."""
    reshaped = reshape(x, shape, copy)
    target_len = shape.length()
    first_dim = shape[0]

    yield _demand_shape(reshaped, x)
    # Make the requested shape's length and first entry exist so the rewrite below can match
    yield rule(reshaped).then(target_len, first_dim)
    # Reshaping a 1-d array with shape (-1,) leaves it unchanged
    yield rewrite(reshaped).to(
        x,
        eq(x.shape.length()).to(Int(1)),
        eq(target_len).to(Int(1)),
        eq(first_dim).to(Int(-1)),
    )
1842
+
1843
+
1844
@array_api_ruleset.register
def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
    """Rewrite integer ``__getitem__`` into element-level ``index`` lookups."""
    single_index = TupleInt.single(i)
    # Rewrite a full getitem with a single int into an index lookup wrapped as a scalar.
    # NOTE(review): this treats x as 1-d; for higher-rank x, x[i] would be a sub-array -- confirm.
    # TODO: Multi index rewrite as well if all are ints
    return [rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(single_index)))]
1849
+
1850
+
1851
+ ##
1852
+ # Assumptions
1853
+ ##
1854
+
1855
+
1856
@function(mutates_first_arg=True)
def assume_dtype(x: NDArray, dtype: DType) -> None:
    """
    Record the assumption that ``x`` has dtype ``dtype``; ``x`` is updated in place.
    """
1861
+
1862
+
1863
@array_api_ruleset.register
def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
    """After ``assume_dtype`` the dtype is the assumed one; shape and values pass through."""
    before = copy(x)
    assume_dtype(x, dtype)
    return [
        rewrite(x.dtype).to(dtype),
        rewrite(x.shape).to(before.shape),
        rewrite(x.index(idx)).to(before.index(idx)),
    ]
1870
+
1871
+
1872
@function(mutates_first_arg=True)
def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
    """
    Record the assumption that ``x`` has shape ``shape``; ``x`` is updated in place.
    """
1877
+
1878
+
1879
@array_api_ruleset.register
def _assume_shape(x: NDArray, shape: TupleInt, idx: TupleInt):
    """After ``assume_shape`` the shape is the assumed one; dtype and values pass through."""
    before = copy(x)
    assume_shape(x, shape)
    return [
        rewrite(x.shape).to(shape),
        rewrite(x.dtype).to(before.dtype),
        rewrite(x.index(idx)).to(before.index(idx)),
    ]
1886
+
1887
+
1888
@function(mutates_first_arg=True)
def assume_isfinite(x: NDArray) -> None:
    """
    Record the assumption that the values of ``x`` are non-null and not infinite;
    ``x`` is updated in place.
    """
1893
+
1894
+
1895
@array_api_ruleset.register
def _isfinite(x: NDArray, ti: TupleInt):
    """``assume_isfinite`` keeps shape, dtype and values, and marks every element finite."""
    source = copy(x)
    assume_isfinite(x)
    # Built after the assumption is applied, so it refers to the updated x
    element = x.index(ti)

    # Shape, dtype and element values pass straight through to the original array
    yield rewrite(x.shape).to(source.shape)
    yield rewrite(x.dtype).to(source.dtype)
    yield rewrite(element).to(source.index(ti))
    # ... and every indexed value is known to be finite
    yield rewrite(element.isfinite()).to(TRUE)
1906
+
1907
+
1908
@function(mutates_first_arg=True)
def assume_value_one_of(x: NDArray, values: TupleValueLike) -> None:
    """
    Record the assumption that every value of ``x`` is one of ``values``;
    ``x`` is updated in place.
    """
1913
+
1914
+
1915
@array_api_ruleset.register
def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt):
    """``assume_value_one_of`` keeps shape, dtype and values, and restricts possible values to ``vs``."""
    original = copy(x)
    assume_value_one_of(x, vs)
    pass_through = [
        rewrite(x.shape).to(original.shape),
        rewrite(x.dtype).to(original.dtype),
    ]
    # Each element equals the original element, and its possible values are exactly vs
    restrict = rule(eq(v).to(x.index(idx))).then(
        union(v).with_(original.index(idx)),
        union(possible_values(v)).with_(vs),
    )
    return [*pass_through, restrict]
1927
+
1928
+
1929
@array_api_ruleset.register
def _ndarray_value_isfinite(arr: NDArray, i: Int, f: f64, b: Boolean):
    """
    Rules deciding ``isfinite`` for literal values and for sums of arrays.
    """
    # Ints and bools are always finite
    yield rewrite(Value.int(i).isfinite()).to(TRUE)
    yield rewrite(Value.bool(b).isfinite()).to(TRUE)
    # A float literal is finite only if it is neither NaN nor +/- infinity
    yield rewrite(Value.float(Float(f)).isfinite()).to(
        TRUE,
        ne(f).to(f64(math.nan)),
        ne(f).to(f64(math.inf)),
        ne(f).to(f64(-math.inf)),
    )

    # a sum of an array is finite if all the values are finite
    yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite())))
1937
+
1938
+
1939
@array_api_ruleset.register
def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
    """``unique_values(a)`` is a vector of the possible values over all elements of ``a``."""
    every_value = possible_values(a.index(ALL_INDICES))
    return [rewrite(unique_values(x=a)).to(NDArray.vector(every_value))]
    # yield rewrite(
    #     possible_values(reshape(a.index(shape, copy), ALL_INDICES)),
    # ).to(possible_values(a.index(ALL_INDICES)))
1945
+
1946
+
1947
@array_api_ruleset.register
def _size(x: NDArray):
    """The size of an array is the product of its shape's dimensions (1 for an empty shape)."""
    product_of_dims = x.shape.foldl(Int.__mul__, Int(1))
    return [rewrite(x.size).to(product_of_dims)]
1950
+
1951
+
1952
# Separate ruleset so that program generation can reuse it
1953
@ruleset
def array_api_vec_to_cons_ruleset(
    vs: Vec[Int],
    vv: Vec[Value],
    vn: Vec[NDArray],
    vt: Vec[TupleInt],
):
    """
    Rewrite ``from_vec`` constructors of the tuple types into ``EMPTY``/``append`` chains.

    An empty vector becomes ``EMPTY``. A non-empty vector becomes ``from_vec`` of the
    vector without its last element, with that last element appended.
    """
    # Same order as the rules were originally listed: TupleInt, TupleValue, TupleTupleInt, TupleNDArray
    for tuple_cls, vec in (
        (TupleInt, vs),
        (TupleValue, vv),
        (TupleTupleInt, vt),
        (TupleNDArray, vn),
    ):
        length = vec.length()
        last = length - 1
        yield rewrite(tuple_cls.from_vec(vec)).to(tuple_cls.EMPTY, eq(length).to(i64(0)))
        yield rewrite(tuple_cls.from_vec(vec)).to(
            tuple_cls.from_vec(vec.remove(last)).append(vec[last]),
            ne(length).to(i64(0)),
        )
1978
+
1979
+
1980
# Core array API rules together with the Vec -> cons-list conversion rules
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
# Schedule that runs the combined rules until saturation
array_api_schedule = array_api_combined_ruleset.saturate()

# Egraph installed by `set_array_api_egraph`; None when no context is active
_CURRENT_EGRAPH: None | EGraph = None
1984
+
1985
+
1986
@contextlib.contextmanager
def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
    """
    Context manager that sets the current egraph for the duration of the ``with`` block.

    The current egraph is reset to None on exit, even if the body raises, so a failure
    inside one block does not leave a stale egraph behind that trips the assertion on
    the next use. Nesting is not supported: entering while an egraph is already set
    fails the assertion.
    """
    global _CURRENT_EGRAPH
    assert _CURRENT_EGRAPH is None
    _CURRENT_EGRAPH = egraph
    try:
        yield
    finally:
        _CURRENT_EGRAPH = None
1996
+
1997
+
1998
+ def _get_current_egraph() -> EGraph:
1999
+ return _CURRENT_EGRAPH or EGraph()
2000
+
2001
+
2002
def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
    """
    Extract the primitive value of ``prim_expr`` from ``egraph``.

    If ``prim_expr`` cannot be extracted yet, ``expr`` is registered and ``schedule`` is
    run so the primitive can be derived, and then extraction is retried. If the retry
    still fails, the original error is re-raised with a note describing ``expr``.
    """
    try:
        extracted = egraph.extract(prim_expr)
    except EggSmolError:
        # If this primitive doesn't exist in the egraph, we need to try to create it by
        # registering the expression and running the schedule
        egraph.register(expr)
        egraph.run(schedule)
        try:
            extracted = egraph.extract(prim_expr)
        except BaseException as e:
            try:
                description = str(egraph.extract(expr))
            except Exception:
                # Describing the expression must not mask the original failure.
                description = str(expr)
            raise add_note(f"Cannot evaluate {description}", e)  # noqa: B904
    return extracted.value  # type: ignore[attr-defined]