egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/exp/array_api.py
ADDED
|
@@ -0,0 +1,2019 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
## Lists
|
|
5
|
+
|
|
6
|
+
Lists have two main constructors:
|
|
7
|
+
|
|
8
|
+
- `List(length, idx_fn)`
|
|
9
|
+
- `List.EMPTY` / `initial.append(last)`
|
|
10
|
+
|
|
11
|
+
This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic
|
|
12
|
+
length that could not be resolved to an integer.
|
|
13
|
+
|
|
14
|
+
There are rewrites to convert between these constructors in both directions. The only limitation however is that
|
|
15
|
+
`length` has to be a real i64 in order to be converted to a cons list.
|
|
16
|
+
|
|
17
|
+
When you are writing a function that uses ints, feel free to use the `__getitem__` or `length()` methods or match
|
|
18
|
+
directly on the `List()` constructor. If you can write your function using that interface, please do. But for some other
|
|
19
|
+
methods, where the resulting length/index function depends on the rest of the list, you can only define them with a known
|
|
20
|
+
length, so you can then use the cons list constructors.
|
|
21
|
+
|
|
22
|
+
We also support creating lists from vectors. These can be converted one to one to the snoc list representation.
|
|
23
|
+
|
|
24
|
+
It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet.
|
|
25
|
+
|
|
26
|
+
We are guaranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use
|
|
27
|
+
the `.length` and `.__getitem__` methods, unless you want to depend on it having a known length, in which
|
|
28
|
+
case you can match directly on the cons list.
|
|
29
|
+
|
|
30
|
+
To be a list, you must implement two methods:
|
|
31
|
+
|
|
32
|
+
* `l.length() -> Int`
|
|
33
|
+
* `l.__getitem__(i: Int) -> T`
|
|
34
|
+
|
|
35
|
+
There are three main types of constructors for lists which all implement these methods:
|
|
36
|
+
|
|
37
|
+
* Functional `List(length, idx_fn)`
|
|
38
|
+
* cons (well reversed cons) lists `List.EMPTY` and `l.append(x)`
|
|
39
|
+
* Vectors `List.from_vec(vec)`
|
|
40
|
+
|
|
41
|
+
Also, all list constructors must be converted to the functional representation, so that we can match on it
|
|
42
|
+
and convert lists with known lengths into cons lists and into vectors.
|
|
43
|
+
|
|
44
|
+
This is necessary so that known length lists are properly materialized during extraction.
|
|
45
|
+
|
|
46
|
+
Q: Why are they implemented as SNOC lists instead of CONS lists?
|
|
47
|
+
A: So that when converting from functional to lists we can use the same index function by starting at the end and folding
|
|
48
|
+
that way recursively.
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
# mypy: disable-error-code="empty-body"
|
|
54
|
+
|
|
55
|
+
from __future__ import annotations
|
|
56
|
+
|
|
57
|
+
import contextlib
|
|
58
|
+
import itertools
|
|
59
|
+
import math
|
|
60
|
+
import numbers
|
|
61
|
+
import os
|
|
62
|
+
import sys
|
|
63
|
+
from collections.abc import Callable
|
|
64
|
+
from copy import copy
|
|
65
|
+
from types import EllipsisType
|
|
66
|
+
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
|
|
67
|
+
|
|
68
|
+
import numpy as np
|
|
69
|
+
|
|
70
|
+
from egglog import *
|
|
71
|
+
from egglog.runtime import RuntimeExpr
|
|
72
|
+
|
|
73
|
+
from .program_gen import *
|
|
74
|
+
|
|
75
|
+
if TYPE_CHECKING:
|
|
76
|
+
from collections.abc import Iterator
|
|
77
|
+
from types import ModuleType
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Pretend that exprs are numbers b/c sklearn does isinstance checks
# (RuntimeExpr is registered as a virtual subclass, so isinstance(expr, numbers.Integral) is True for every expr).
numbers.Integral.register(RuntimeExpr)

# Set this to 1 before scipy is ever imported
# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support
# NOTE: this mutates the process environment as a side effect of importing this module.
os.environ["SCIPY_ARRAY_API"] = "1"

# Shared ruleset that the Expr classes and rule groups in this module register into.
array_api_ruleset = ruleset(name="array_api_ruleset")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class Boolean(Expr, ruleset=array_api_ruleset):
    """
    Symbolic boolean expression.

    The operators only build new expressions; the `_bool` rule group holds the rewrites that simplify them.
    `bool(x)` / `x.eval()` run `array_api_schedule` on the current e-graph to obtain a concrete Python bool.
    """

    def __init__(self, value: BoolLike) -> None: ...

    # Python truthiness: evaluates the expression to a concrete bool.
    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.eval()

    # Evaluates via `try_evaling` using the current e-graph, the array API schedule and the `to_bool` property.
    @method(preserve=True)
    def eval(self) -> bool:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)

    # Concrete builtin `Bool` value; the `_bool` rules set it once this expr is known to equal `Boolean(b)`.
    @property
    def to_bool(self) -> Bool: ...

    def __or__(self, other: BooleanLike) -> Boolean: ...

    def __and__(self, other: BooleanLike) -> Boolean: ...

    def __invert__(self) -> Boolean: ...

    # Symbolic equality: returns a Boolean expression, not a Python bool.
    def __eq__(self, other: BooleanLike) -> Boolean: ...  # type: ignore[override]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Anything accepted where a `Boolean` is expected: a `Boolean` expr or a `BoolLike` value.
BooleanLike = Boolean | BoolLike

# Literal true/false constants, used directly in the rewrite rules below.
TRUE = Boolean(True)
FALSE = Boolean(False)
# Allow a builtin egglog `Bool` to be passed wherever a `Boolean` is expected (wrapped with `Boolean(...)`).
converter(Bool, Boolean, Boolean)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@array_api_ruleset.register
def _bool(x: Boolean, i: Int, j: Int, b: Bool):
    """
    Simplification rules for `Boolean` expressions.

    Note: `i` and `j` are declared as rule variables but are not referenced by any rule below.
    """
    return [
        # Record the concrete value on `to_bool` once a Boolean is known to equal a literal.
        rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b)),
        # `|` / `&` with a literal on the LEFT operand (no rules here match the literal on the right).
        rewrite(TRUE | x).to(TRUE),
        rewrite(FALSE | x).to(x),
        rewrite(TRUE & x).to(x),
        rewrite(FALSE & x).to(FALSE),
        rewrite(~TRUE).to(FALSE),
        rewrite(~FALSE).to(TRUE),
        # Sanity check: the two literals must never end up in the same e-class.
        rule(eq(FALSE).to(TRUE)).then(panic("False cannot equal True")),
        rewrite(x == x).to(TRUE),  # noqa: PLR0124
        rewrite(FALSE == TRUE).to(FALSE),
        rewrite(TRUE == FALSE).to(FALSE),
    ]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class Int(Expr, ruleset=array_api_ruleset):
    # A NEVER int is one that should not exist. It could represent, for example, indexing into an array at a value
    # that is out of bounds (see `check_index`).
    # https://en.wikipedia.org/wiki/Bottom_type
    NEVER: ClassVar[Int]

    # Named symbolic integer.
    @classmethod
    def var(cls, name: StringLike) -> Int: ...

    # Literal integer.
    def __init__(self, value: i64Like) -> None: ...

    def __invert__(self) -> Int: ...

    # Comparisons build symbolic `Boolean` expressions; concrete results come from the `_int` rules.
    def __lt__(self, other: IntLike) -> Boolean: ...

    def __le__(self, other: IntLike) -> Boolean: ...

    def __eq__(self, other: IntLike) -> Boolean:  # type: ignore[override]
        ...

    # add a hash so that this test can pass
    # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
    # Registers the expr in the current e-graph, runs the schedule, extracts the simplest form and hashes that, so
    # exprs that simplify to the same extracted term hash equally. NOTE: this runs the e-graph on every hash() call.
    @method(preserve=True)
    def __hash__(self) -> int:
        egraph = _get_current_egraph()
        egraph.register(self)
        egraph.run(array_api_schedule)
        simplified = egraph.extract(self)
        return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)

    def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...

    # TODO: Fix this?
    # Make != always return a Python bool, so that numpy.unique works on a tuple of ints
    # In _unique1d
    # (`self == other` builds a Boolean, and `not` forces its evaluation through `Boolean.__bool__`.)
    @method(preserve=True)
    def __ne__(self, other: Int) -> bool:  # type: ignore[override]
        return not (self == other)

    def __gt__(self, other: IntLike) -> Boolean: ...

    def __ge__(self, other: IntLike) -> Boolean: ...

    def __add__(self, other: IntLike) -> Int: ...

    def __sub__(self, other: IntLike) -> Int: ...

    def __mul__(self, other: IntLike) -> Int: ...

    def __truediv__(self, other: IntLike) -> Int: ...

    def __floordiv__(self, other: IntLike) -> Int: ...

    def __mod__(self, other: IntLike) -> Int: ...

    def __divmod__(self, other: IntLike) -> Int: ...

    def __pow__(self, other: IntLike) -> Int: ...

    def __lshift__(self, other: IntLike) -> Int: ...

    def __rshift__(self, other: IntLike) -> Int: ...

    def __and__(self, other: IntLike) -> Int: ...

    def __xor__(self, other: IntLike) -> Int: ...

    def __or__(self, other: IntLike) -> Int: ...

    # Reflected operators, so e.g. `1 + Int(...)` also builds an `Int` expression.
    def __radd__(self, other: IntLike) -> Int: ...

    def __rsub__(self, other: IntLike) -> Int: ...

    def __rmul__(self, other: IntLike) -> Int: ...

    def __rmatmul__(self, other: IntLike) -> Int: ...

    def __rtruediv__(self, other: IntLike) -> Int: ...

    def __rfloordiv__(self, other: IntLike) -> Int: ...

    def __rmod__(self, other: IntLike) -> Int: ...

    def __rpow__(self, other: IntLike) -> Int: ...

    def __rlshift__(self, other: IntLike) -> Int: ...

    def __rrshift__(self, other: IntLike) -> Int: ...

    def __rand__(self, other: IntLike) -> Int: ...

    def __rxor__(self, other: IntLike) -> Int: ...

    def __ror__(self, other: IntLike) -> Int: ...

    # Concrete i64 value; the `_int` rules set it once this expr is known to equal `Int(j)`.
    @property
    def to_i64(self) -> i64: ...

    # Evaluates to a concrete Python int via `try_evaling` on the current e-graph.
    @method(preserve=True)
    def eval(self) -> int:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)

    # The Python conversion protocols below all go through `eval()`.
    @method(preserve=True)
    def __index__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __int__(self) -> int:
        return self.eval()

    @method(preserve=True)
    def __float__(self) -> float:
        return float(self.eval())

    @method(preserve=True)
    def __bool__(self) -> bool:
        return bool(self.eval())

    # Symbolic conditional: `i` when `b` holds, else `j`.
    @classmethod
    def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ...
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@array_api_ruleset.register
def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
    """
    Rewrite rules that evaluate `Int` operations whose operands are literal `Int(i64)` values.

    Rule variables: `i`/`j` bind literal i64 payloads, `r` binds a comparison result, and `o`/`b` bind arbitrary
    `Int` expressions.
    """
    # ==
    yield rewrite(Int(i) == Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))

    # >=
    yield rewrite(Int(i) >= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) >= Int(j)), i < j).then(union(r).with_(FALSE))

    # <=  (`__le__` is declared on Int, so it needs the same evaluation rules as the other comparisons)
    yield rewrite(Int(i) <= Int(i)).to(TRUE)
    yield rule(eq(r).to(Int(i) <= Int(j)), i < j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) <= Int(j)), i > j).then(union(r).with_(FALSE))

    # <
    yield rewrite(Int(i) < Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) < Int(j)), i < j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) < Int(j)), i > j).then(union(r).with_(FALSE))

    # >
    yield rewrite(Int(i) > Int(i)).to(FALSE)
    yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE))
    yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE))

    # Record the concrete value on `to_i64` once an Int is known to equal a literal.
    yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j))

    # Sanity check: two different literals must never be merged.
    yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints"))

    yield rewrite(Int(i) + Int(j)).to(Int(i + j))
    yield rewrite(Int(i) - Int(j)).to(Int(i - j))
    yield rewrite(Int(i) * Int(j)).to(Int(i * j))

    # `//` and `%` must follow Python's floor semantics, but egglog's i64 `/` and `%` truncate toward zero.
    # Both agree when the truncated remainder is zero or has the divisor's sign. Otherwise the floored quotient is
    # one smaller and the floored remainder is shifted by the divisor (e.g. -7 // 2 == -4 and -7 % 2 == 1).
    # A zero divisor makes `i % j` undefined, so none of these rules fire for it.
    yield rewrite(Int(i) // Int(j)).to(Int(i / j), i % j >= 0, j > 0)
    yield rewrite(Int(i) // Int(j)).to(Int(i / j), i % j <= 0, j < 0)
    yield rewrite(Int(i) // Int(j)).to(Int(i / j - 1), i % j < 0, j > 0)
    yield rewrite(Int(i) // Int(j)).to(Int(i / j - 1), i % j > 0, j < 0)
    yield rewrite(Int(i) % Int(j)).to(Int(i % j), i % j >= 0, j > 0)
    yield rewrite(Int(i) % Int(j)).to(Int(i % j), i % j <= 0, j < 0)
    yield rewrite(Int(i) % Int(j)).to(Int(i % j + j), i % j < 0, j > 0)
    yield rewrite(Int(i) % Int(j)).to(Int(i % j + j), i % j > 0, j < 0)

    yield rewrite(Int(i) & Int(j)).to(Int(i & j))
    yield rewrite(Int(i) | Int(j)).to(Int(i | j))
    yield rewrite(Int(i) ^ Int(j)).to(Int(i ^ j))
    yield rewrite(Int(i) << Int(j)).to(Int(i << j))
    yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
    yield rewrite(~Int(i)).to(Int(~i))

    # Resolve conditionals once the condition is a literal.
    yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
    yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)

    # Rounding an integer with no digit count is the identity.
    yield rewrite(o.__round__(OptionalInt.none)).to(o)

    # Never cannot be equal to anything real
    yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
# Allow builtin `i64` values to be used wherever an `Int` is expected (wrapped with `Int(...)`).
converter(i64, Int, lambda x: Int(x))

# Anything accepted where an `Int` is expected.
IntLike: TypeAlias = Int | i64Like
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@function(ruleset=array_api_ruleset)
def check_index(length: IntLike, idx: IntLike) -> Int:
    """
    Bounds-check an index against a length.

    Evaluates to `idx` when ``0 <= idx < length`` and to `Int.NEVER` otherwise.
    """
    size = cast("Int", length)
    position = cast("Int", idx)
    in_bounds = (position >= 0) & (position < size)
    return Int.if_(in_bounds, position, Int.NEVER)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# @array_api_ruleset.register
|
|
316
|
+
# def _check_index(i: i64, j: i64, x: Int):
|
|
317
|
+
# yield rewrite(
|
|
318
|
+
# check_index(Int(i), Int(j)),
|
|
319
|
+
# ).to(
|
|
320
|
+
# Int(j),
|
|
321
|
+
# i >= 0,
|
|
322
|
+
# i < j,
|
|
323
|
+
# )
|
|
324
|
+
|
|
325
|
+
# yield rewrite(
|
|
326
|
+
# check_index(x, Int(i)),
|
|
327
|
+
# ).to(
|
|
328
|
+
# Int.NEVER,
|
|
329
|
+
# i < 0,
|
|
330
|
+
# )
|
|
331
|
+
|
|
332
|
+
# yield rewrite(
|
|
333
|
+
# check_index(Int(i), Int(j)),
|
|
334
|
+
# ).to(
|
|
335
|
+
# Int.NEVER,
|
|
336
|
+
# i >= j,
|
|
337
|
+
# )
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class Float(Expr, ruleset=array_api_ruleset):
    # Differentiate costs of three constructors so extraction is deterministic if all three are present
    # (`__init__` has cost 3, `rational` cost 2, and `from_int` has no explicit cost here).
    @method(cost=3)
    def __init__(self, value: f64Like) -> None: ...

    # Concrete f64 value; the `_float` rules set it once this expr is known to equal `Float(f)`.
    @property
    def to_f64(self) -> f64: ...

    # Evaluates to a concrete Python float via `try_evaling` on the current e-graph.
    @method(preserve=True)
    def eval(self) -> float:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)

    def abs(self) -> Float: ...

    # Exact rational representation; whole-number floats and ints are rewritten into it by `_float`.
    @method(cost=2)
    @classmethod
    def rational(cls, r: BigRat) -> Float: ...

    @classmethod
    def from_int(cls, i: IntLike) -> Float: ...

    def __truediv__(self, other: FloatLike) -> Float: ...

    def __mul__(self, other: FloatLike) -> Float: ...

    def __add__(self, other: FloatLike) -> Float: ...

    def __sub__(self, other: FloatLike) -> Float: ...

    def __pow__(self, other: FloatLike) -> Float: ...
    def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...

    # Comparisons build symbolic `Boolean` expressions; concrete results come from the `_float` rules.
    def __eq__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
    def __ne__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
    def __lt__(self, other: FloatLike) -> Boolean: ...
    def __le__(self, other: FloatLike) -> Boolean: ...
    def __gt__(self, other: FloatLike) -> Boolean: ...
    def __ge__(self, other: FloatLike) -> Boolean: ...
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# Python floats and `Int` expressions are accepted wherever a `Float` is expected.
converter(float, Float, lambda x: Float(x))
converter(Int, Float, lambda x: Float.from_int(x))


# Anything accepted where a `Float` is expected.
FloatLike: TypeAlias = Float | float | IntLike
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
    """
    Rewrite rules that evaluate `Float` operations on literal `Float(f64)` / `Float.rational(BigRat)` values.

    Every comparison operator has both a TRUE and a FALSE rule for literal operands, so the result of
    comparing two concrete floats is always decided.
    """
    return [
        # Record the concrete value on `to_f64` once a Float is known to equal a literal.
        rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
        rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
        rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
        rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
        # Convert from float to rational if it is a whole number, i.e. it can be converted to an int
        rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
        # always convert from int to rational
        rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
        rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
        rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
        rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
        rewrite(Float.rational(r) / Float.rational(r1)).to(Float.rational(r / r1)),
        rewrite(Float.rational(r) + Float.rational(r1)).to(Float.rational(r + r1)),
        rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
        rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
        rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
        # comparisons
        rewrite(Float(f) == Float(f)).to(TRUE),
        rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
        # Use `ne(...)` like the `==` rule above instead of relying on Python `!=` on f64 expressions.
        rewrite(Float(f) != Float(f2)).to(TRUE, ne(f).to(f2)),
        rewrite(Float(f) != Float(f)).to(FALSE),
        rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
        rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
        rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
        rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
        rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
        rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
        rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
        # Previously missing: without it `Float(a) < Float(b)` never resolved when a >= b.
        rewrite(Float(f) < Float(f2)).to(FALSE, f >= f2),
        rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
        rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
        # round
        rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
    ]
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
class TupleInt(Expr, ruleset=array_api_ruleset):
    """
    Should act like a tuple[int, ...]

    All constructors should be rewritten to the functional semantics in the __init__ method.

    Constructors: functional `TupleInt(length, idx_fn)`, snoc-style `TupleInt.EMPTY` / `t.append(i)`, and
    `TupleInt.from_vec(vec)`. The `_tuple_int` rules define `length()` / `__getitem__` for each and convert
    known-length functional tuples to the snoc form.
    """

    @classmethod
    def var(cls, name: StringLike) -> TupleInt: ...

    # Functional constructor: element k is `idx_fn(k)`.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

    EMPTY: ClassVar[TupleInt]
    NEVER: ClassVar[TupleInt]

    # Snoc: returns a new tuple with `i` added at the end.
    def append(self, i: IntLike) -> TupleInt: ...

    @classmethod
    def single(cls, i: Int) -> TupleInt:
        return TupleInt(Int(1), lambda _: i)

    # (0, 1, ..., stop - 1)
    @method(subsume=True)
    @classmethod
    def range(cls, stop: IntLike) -> TupleInt:
        return TupleInt(stop, lambda i: i)

    @classmethod
    def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...

    # Concatenation: indices below `self.length()` read from `self`, the rest from `other` (offset by that length).
    def __add__(self, other: TupleIntLike) -> TupleInt:
        other = cast("TupleInt", other)
        return TupleInt(
            self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> Int: ...

    # Python `len()`: evaluates `length()` to a concrete int.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    # Python iteration: evaluates to a tuple of `Int` exprs and iterates over it.
    @method(preserve=True)
    def __iter__(self) -> Iterator[Int]:
        return iter(self.eval())

    # Concrete vector; the `_tuple_int` rules set it once this tuple is known to equal `TupleInt.from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[Int]: ...

    @method(preserve=True)
    def eval(self) -> tuple[Int, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    # Left folds, one per accumulator type.
    def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
    def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
    def foldl_tuple_int(self, f: Callable[[TupleInt, Int], TupleInt], init: TupleIntLike) -> TupleInt: ...

    # Membership test, expressed as an OR-fold of element equalities.
    @method(subsume=True)
    def contains(self, i: Int) -> Boolean:
        return self.foldl_boolean(lambda acc, j: acc | (i == j), FALSE)

    # Keeps the elements for which `f` holds, preserving order.
    @method(subsume=True)
    def filter(self, f: Callable[[Int], Boolean]) -> TupleInt:
        return self.foldl_tuple_int(
            lambda acc, v: TupleInt.if_(f(v), acc.append(v), acc),
            TupleInt.EMPTY,
        )

    @method(subsume=True)
    def map(self, f: Callable[[Int], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    # Symbolic conditional: `i` when `b` holds, else `j`.
    @classmethod
    def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ...

    # Drops the first `n` elements (no check that n <= length).
    def drop(self, n: Int) -> TupleInt:
        return TupleInt(self.length() - n, lambda i: self[i + n])

    # Product of all elements (1 for an empty tuple).
    def product(self) -> Int:
        return self.foldl(lambda acc, i: acc * i, Int(1))

    def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt:
        return TupleTupleInt(self.length(), lambda i: f(self[i]))

    def select(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements at the given indices
        """
        indices = cast("TupleInt", indices)
        return indices.map(lambda i: self[i])

    def deselect(self, indices: TupleIntLike) -> TupleInt:
        """
        Return a new tuple with the elements not at the given indices
        """
        indices = cast("TupleInt", indices)
        return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
# A `Vec[Int]` is accepted wherever a `TupleInt` is expected (wrapped with `TupleInt.from_vec`).
converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x))

# Anything accepted where a `TupleInt` is expected.
TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
@array_api_ruleset.register
def _tuple_int(
    i: Int,
    i2: Int,
    f: Callable[[Int, Int], Int],
    bool_f: Callable[[Boolean, Int], Boolean],
    idx_fn: Callable[[Int], Int],
    tuple_int_f: Callable[[TupleInt, Int], TupleInt],
    vs: Vec[Int],
    b: Boolean,
    ti: TupleInt,
    ti2: TupleInt,
    k: i64,
):
    """
    Rewrite rules for `TupleInt`.

    Defines `length()` / `__getitem__` for each constructor, converts known-length functional tuples into the
    snoc (`EMPTY` / `append`) form and then into `from_vec`, and unrolls the folds over the snoc form.
    """
    return [
        # Record the concrete vector on `to_vec` once a tuple is known to equal `from_vec(vs)`.
        rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs)),
        # Functional access
        rewrite(TupleInt(i, idx_fn).length()).to(i),
        # Out-of-range indices are routed to `Int.NEVER` by `check_index(length, idx)`.
        rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2))),
        # cons access
        rewrite(TupleInt.EMPTY.length()).to(Int(0)),
        rewrite(TupleInt.EMPTY[i]).to(Int.NEVER),
        rewrite(ti.append(i).length()).to(ti.length() + 1),
        # The appended element sits at index `ti.length()`; earlier indices read from `ti`.
        rewrite(ti.append(i)[i2]).to(Int.if_(i2 == ti.length(), i, ti[i2])),
        # cons to functional (removed this so that there is not infinite replacements between the,)
        # rewrite(TupleInt.EMPTY).to(TupleInt(0, lambda _: Int.NEVER)),
        # rewrite(TupleInt(i, idx_fn).append(i2)).to(TupleInt(i + 1, lambda j: Int.if_(j == i, i2, idx_fn(j)))),
        # functional to cons
        rewrite(TupleInt(0, idx_fn), subsume=True).to(TupleInt.EMPTY),
        # Peel the last element off a functional tuple with a literal positive length.
        rewrite(TupleInt(Int(k), idx_fn), subsume=True).to(TupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0),
        # cons to vec
        rewrite(TupleInt.EMPTY).to(TupleInt.from_vec(Vec[Int]())),
        rewrite(TupleInt.from_vec(vs).append(i)).to(TupleInt.from_vec(vs.append(Vec(i)))),
        # fold
        rewrite(TupleInt.EMPTY.foldl(f, i), subsume=True).to(i),
        rewrite(ti.append(i2).foldl(f, i), subsume=True).to(f(ti.foldl(f, i), i2)),
        # fold boolean
        rewrite(TupleInt.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b),
        rewrite(ti.append(i2).foldl_boolean(bool_f, b), subsume=True).to(bool_f(ti.foldl_boolean(bool_f, b), i2)),
        # fold tuple_int
        rewrite(TupleInt.EMPTY.foldl_tuple_int(tuple_int_f, ti), subsume=True).to(ti),
        rewrite(ti.append(i2).foldl_tuple_int(tuple_int_f, ti2), subsume=True).to(
            tuple_int_f(ti.foldl_tuple_int(tuple_int_f, ti2), i2)
        ),
        # if_
        rewrite(TupleInt.if_(TRUE, ti, ti2), subsume=True).to(ti),
        rewrite(TupleInt.if_(FALSE, ti, ti2), subsume=True).to(ti2),
        # unify append: equal appends imply equal prefixes and equal last elements
        rule(eq(ti.append(i)).to(ti2.append(i2))).then(union(ti).with_(ti2), union(i).with_(i2)),
    ]
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
class TupleTupleInt(Expr, ruleset=array_api_ruleset):
    """
    Tuple of `TupleInt`s, mirroring `TupleInt`'s functional / snoc / `from_vec` constructors.
    """

    @classmethod
    def var(cls, name: StringLike) -> TupleTupleInt: ...

    EMPTY: ClassVar[TupleTupleInt]

    # Functional constructor: element k is `idx_fn(k)`.
    def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

    @method(subsume=True)
    @classmethod
    def single(cls, i: TupleIntLike) -> TupleTupleInt:
        i = cast("TupleInt", i)
        return TupleTupleInt(1, lambda _: i)

    @method(subsume=True)
    @classmethod
    def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...

    # Snoc: returns a new tuple with `i` added at the end.
    def append(self, i: TupleIntLike) -> TupleTupleInt: ...

    # Concatenation: indices below `self.length()` read from `self`, the rest from `other` (offset by that length).
    def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
        other = cast("TupleTupleInt", other)
        return TupleTupleInt(
            self.length() + other.length(),
            lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...
    def __getitem__(self, i: IntLike) -> TupleInt: ...

    # Python `len()`: evaluates `length()` to a concrete int.
    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[TupleInt]:
        return iter(self.eval())

    # Concrete vector; set by `_tuple_tuple_int` once this is known to equal `TupleTupleInt.from_vec(vs)`.
    @property
    def to_vec(self) -> Vec[TupleInt]: ...

    @method(preserve=True)
    def eval(self) -> tuple[TupleInt, ...]:
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)

    # Drops the first `n` elements (no check that n <= length).
    def drop(self, n: Int) -> TupleTupleInt:
        return TupleTupleInt(self.length() - n, lambda i: self[i + n])

    def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt:
        return TupleInt(self.length(), lambda i: f(self[i]))

    def foldl_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ...

    @method(subsume=True)
    def product(self) -> TupleTupleInt:
        """
        Cartesian product of inputs

        https://docs.python.org/3/library/itertools.html#itertools.product

        https://github.com/saulshanabrook/saulshanabrook/discussions/39

        The result has (product of the input lengths) rows. In row `i`, column `j` is the mixed-radix digit
        `(i // product(lengths after j)) % len(self[j])`, so the last input varies fastest, like itertools.product.
        """
        return TupleTupleInt(
            self.map_int(lambda x: x.length()).product(),
            lambda i: TupleInt(
                self.length(),
                lambda j: self[j][(i // self.drop(j + 1).map_int(lambda x: x.length()).product()) % self[j].length()],
            ),
        )
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
# A `Vec[TupleInt]` is accepted wherever a `TupleTupleInt` is expected (wrapped with `TupleTupleInt.from_vec`).
converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))

# Anything accepted where a `TupleTupleInt` is expected.
TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
@array_api_ruleset.register
def _tuple_tuple_int(
    length: Int,
    fn: Callable[[TupleInt], Int],
    idx_fn: Callable[[Int], TupleInt],
    f: Callable[[Value, TupleInt], Value],
    i: Value,
    k: i64,
    idx: Int,
    vs: Vec[TupleInt],
    ti: TupleInt,
    ti1: TupleInt,
    tti: TupleTupleInt,
    tti1: TupleTupleInt,
):
    """
    Rewrite rules for `TupleTupleInt`, mirroring the `TupleInt` rules.

    Defines `length()` / `__getitem__` for each constructor, converts known-length functional tuples into the
    snoc (`EMPTY` / `append`) form and then into `from_vec`, and unrolls `foldl_value` over the snoc form.
    """
    # Record the concrete vector on `to_vec` once a tuple is known to equal `from_vec(vs)`.
    yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs))
    yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
    # `check_index` takes (length, idx). Passing them swapped would send every in-range index to `Int.NEVER`.
    yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(length, idx)))

    # cons access
    yield rewrite(TupleTupleInt.EMPTY.length()).to(Int(0))
    yield rewrite(TupleTupleInt.EMPTY[idx]).to(TupleInt.NEVER)
    yield rewrite(tti.append(ti).length()).to(tti.length() + 1)
    yield rewrite(tti.append(ti)[idx]).to(TupleInt.if_(idx == tti.length(), ti, tti[idx]))

    # functional to cons
    yield rewrite(TupleTupleInt(0, idx_fn), subsume=True).to(TupleTupleInt.EMPTY)
    yield rewrite(TupleTupleInt(Int(k), idx_fn), subsume=True).to(
        TupleTupleInt(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )
    # cons to vec
    yield rewrite(TupleTupleInt.EMPTY).to(TupleTupleInt.from_vec(Vec[TupleInt]()))
    yield rewrite(TupleTupleInt.from_vec(vs).append(ti)).to(TupleTupleInt.from_vec(vs.append(Vec(ti))))
    # fold value
    yield rewrite(TupleTupleInt.EMPTY.foldl_value(f, i), subsume=True).to(i)
    yield rewrite(tti.append(ti).foldl_value(f, i), subsume=True).to(f(tti.foldl_value(f, i), ti))

    # unify append: equal appends imply equal prefixes and equal last elements
    yield rule(eq(tti.append(ti)).to(tti1.append(ti1))).then(union(tti).with_(tti1), union(ti).with_(ti1))
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
class OptionalInt(Expr, ruleset=array_api_ruleset):
    """
    An ``Int`` that may be absent: either ``OptionalInt.none`` or ``OptionalInt.some(value)``.
    """

    none: ClassVar[OptionalInt]

    @classmethod
    def some(cls, value: Int) -> OptionalInt: ...


# Anything accepted where an ``OptionalInt`` is expected: an ``OptionalInt``, an int-like value or ``None``.
OptionalIntLike: TypeAlias = OptionalInt | IntLike | None

# ``None`` becomes ``OptionalInt.none``; an ``Int`` is wrapped with ``OptionalInt.some``.
converter(type(None), OptionalInt, lambda _: OptionalInt.none)
converter(Int, OptionalInt, OptionalInt.some)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
class DType(Expr, ruleset=array_api_ruleset):
    """
    The data type of array elements.

    Equality between two concrete dtypes is decided by the rules generated from ``_DTYPES``.
    """

    float64: ClassVar[DType]
    float32: ClassVar[DType]
    int64: ClassVar[DType]
    int32: ClassVar[DType]
    object: ClassVar[DType]
    bool: ClassVar[DType]

    def __eq__(self, other: DType) -> Boolean:  # type: ignore[override]
        ...


float64 = DType.float64
float32 = DType.float32
int32 = DType.int32
int64 = DType.int64

# Every concrete DType member. The pairwise ``==`` rewrite rules are generated from this list,
# so it must include all of them. ``DType.bool`` is included because ``Value.bool(...).dtype`` rewrites
# to it, and without it comparisons against ``DType.bool`` never resolve.
_DTYPES = [float64, float32, int32, int64, DType.object, DType.bool]

# Python types go through NumPy's dtype; a NumPy dtype maps to the member with the same name.
converter(type, DType, lambda x: convert(np.dtype(x), DType))
converter(np.dtype, DType, lambda x: getattr(DType, x.name))
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@array_api_ruleset.register
def _():
    """
    Decide ``DType.__eq__`` for every ordered pair of concrete dtypes in ``_DTYPES``.

    A dtype compared with itself rewrites to ``TRUE``; any two different entries rewrite to ``FALSE``.
    """
    for lhs in _DTYPES:
        for rhs in _DTYPES:
            outcome = TRUE if lhs is rhs else FALSE
            yield rewrite(lhs == rhs).to(outcome)
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
class IsDtypeKind(Expr, ruleset=array_api_ruleset):
    """
    The ``kind`` argument of ``isdtype``.

    Either a named kind string (e.g. ``"integral"``), a specific ``DType``, the union of two kinds
    (``a | b``), or ``NULL``. ``NULL`` matches no dtype and terminates the union built from a tuple of kinds.
    """

    NULL: ClassVar[IsDtypeKind]

    @classmethod
    def string(cls, s: StringLike) -> IsDtypeKind: ...

    @classmethod
    def dtype(cls, d: DType) -> IsDtypeKind: ...

    # Union of two kinds. NOTE(review): the cost of 10 presumably makes extraction prefer
    # simpler terms over unions — confirm.
    @method(cost=10)
    def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
# TODO: Make kind more generic to support tuples.
# Whether ``dtype`` belongs to ``kind`` (array API ``isdtype``); decided by the rewrite rules in ``_isdtype``.
@function
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...


# A ``DType`` or a kind string can be passed directly as a kind.
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
# A tuple of kinds becomes a right-nested union ending in ``NULL``:
# ``(a, b)`` -> ``a | (b | NULL)``, and the empty tuple is ``NULL``.
converter(
    tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
)
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
@array_api_ruleset.register
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind, d1: DType):
    """
    Rewrite rules that evaluate ``isdtype``.

    - Every named kind is decided for every concrete ``DType`` member.
    - A dtype kind matches exactly when the dtypes are equal.
    - A union of kinds distributes into an ``|`` of booleans.
    - ``NULL`` matches nothing and is the identity of ``|``.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.isdtype.html
    """
    # Members (by ``DType`` attribute name) of each named kind from the array API spec.
    # ``object`` is not an array API dtype, so it belongs to no kind. ``bool`` is not integral.
    kind_members: dict[str, tuple[str, ...]] = {
        "bool": ("bool",),
        "signed integer": ("int64", "int32"),
        "unsigned integer": (),
        "integral": ("int64", "int32"),
        "real floating": ("float32", "float64"),
        "complex floating": (),
        "numeric": ("int64", "int32", "float32", "float64"),
    }
    dtype_names = ("float32", "float64", "object", "int64", "int32", "bool")
    for kind, members in kind_members.items():
        for name in dtype_names:
            yield rewrite(isdtype(getattr(DType, name), IsDtypeKind.string(kind))).to(
                TRUE if name in members else FALSE
            )

    yield rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE)
    yield rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE)
    # A dtype kind matches iff the dtypes are equal; DType equality is decided for concrete pairs
    yield rewrite(isdtype(d, IsDtypeKind.dtype(d1))).to(d == d1)
    yield rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2))
    yield rewrite(k1 | IsDtypeKind.NULL).to(k1)
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
# TODO: Add pushdown for math on scalars to values
|
|
791
|
+
# and add replacements
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
class Value(Expr, ruleset=array_api_ruleset):
    """
    A scalar value: an ``Int``, ``Float`` or ``Boolean`` wrapped as a single expression type.
    """

    # Result of out-of-range accesses, e.g. indexing into an empty ``TupleValue``.
    NEVER: ClassVar[Value]

    @classmethod
    def int(cls, i: IntLike) -> Value: ...

    @classmethod
    def float(cls, f: FloatLike) -> Value: ...

    @classmethod
    def bool(cls, b: BooleanLike) -> Value: ...

    def isfinite(self) -> Boolean: ...

    # NOTE: unlike ``__eq__`` (which returns a ``Boolean``), ``<`` produces a ``Value``.
    def __lt__(self, other: ValueLike) -> Value: ...

    def __truediv__(self, other: ValueLike) -> Value: ...

    def __mul__(self, other: ValueLike) -> Value: ...

    def __add__(self, other: ValueLike) -> Value: ...

    def astype(self, dtype: DType) -> Value: ...

    # TODO: Add all operations

    @property
    def dtype(self) -> DType:
        """
        Default dtype for this scalar value (``int64`` for ints, ``float64`` for floats, ``bool`` for booleans).
        """

    # Unwraps a ``Value.bool`` back to its ``Boolean``.
    @property
    def to_bool(self) -> Boolean: ...

    # Unwraps a ``Value.int`` back to its ``Int``.
    @property
    def to_int(self) -> Int: ...

    @property
    def to_truthy_value(self) -> Value:
        """
        Converts the value to a boolean value, based on whether it is truthy.

        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
        """

    def conj(self) -> Value: ...
    def real(self) -> Value: ...
    def sqrt(self) -> Value: ...

    # Picks ``i`` when ``b`` is true and ``j`` otherwise.
    @classmethod
    def if_(cls, b: BooleanLike, i: ValueLike, j: ValueLike) -> Value: ...

    def __eq__(self, other: ValueLike) -> Boolean: ...  # type: ignore[override]
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
# Anything accepted where a ``Value`` is expected.
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike


# Primitives are wrapped as scalar values.
converter(Int, Value, Value.int)
converter(Float, Value, Value.float)
converter(Boolean, Value, Value.bool)
# A ``Value`` may be used where an ``Int`` is expected. NOTE(review): the cost of 10 presumably makes
# this conversion less preferred than the default-cost ones — confirm.
converter(Value, Int, lambda x: x.to_int, 10)
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
@array_api_ruleset.register
def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float, b1: Boolean):
    """
    Rewrite rules for scalar ``Value``s.

    Covers default dtypes, unwrapping, ``conj``/``real``/``sqrt`` of literals, adding ``0.0``,
    ``if_`` with a known condition, and ``==`` between values of the same kind.
    """
    # Default dtypes
    # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
    yield rewrite(Value.int(i).dtype).to(DType.int64)
    yield rewrite(Value.float(f).dtype).to(DType.float64)
    yield rewrite(Value.bool(b).dtype).to(DType.bool)

    # Unwrap a value back to the primitive it was built from
    yield rewrite(Value.bool(b).to_bool).to(b)
    yield rewrite(Value.int(i).to_int).to(i)

    yield rewrite(Value.bool(b).to_truthy_value).to(Value.bool(b))
    # TODO: Add more rules for to_truthy_value

    # Ints and floats are real numbers: each is its own conjugate and its own real part
    yield rewrite(Value.float(f).conj()).to(Value.float(f))
    yield rewrite(Value.float(f).real()).to(Value.float(f))
    yield rewrite(Value.int(i).real()).to(Value.int(i))
    yield rewrite(Value.int(i).conj()).to(Value.int(i))

    yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))

    # 0.0 is the additive identity
    yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)

    yield rewrite(Value.if_(TRUE, v, v1)).to(v)
    yield rewrite(Value.if_(FALSE, v, v1)).to(v1)

    # == compares the wrapped primitives when both sides have the same kind
    yield rewrite(Value.int(i) == Value.int(i1)).to(i == i1)
    yield rewrite(Value.float(f) == Value.float(f1)).to(f == f1)
    yield rewrite(Value.bool(b) == Value.bool(b1)).to(b == b1)
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
class TupleValue(Expr, ruleset=array_api_ruleset):
    """
    A tuple of ``Value``s.

    It can be built functionally from a length and an index function, as a cons list
    (``EMPTY`` followed by ``append`` calls), or from a ``Vec[Value]``.
    """

    EMPTY: ClassVar[TupleValue]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], Value]) -> None: ...

    def append(self, i: ValueLike) -> TupleValue: ...

    @classmethod
    def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...

    def __add__(self, other: TupleValueLike) -> TupleValue:
        # Concatenation: indices below ``self.length()`` read from ``self``, the rest from ``other``
        other = cast("TupleValue", other)
        return TupleValue(
            self.length() + other.length(),
            lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: Int) -> Value: ...

    # Left fold producing a ``Boolean``: ``f(...f(f(init, t[0]), t[1])..., t[n-1])``.
    def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...

    def contains(self, value: ValueLike) -> Boolean:
        # True if any element compares equal to ``value``: an OR-fold starting from FALSE
        value = cast("Value", value)
        return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)

    # Converts a ``TupleInt`` element-wise with ``Value.int``.
    # NOTE(review): ``subsume=True`` presumably hides the call once it is replaced by its body.
    @method(subsume=True)
    @classmethod
    def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
        ti = cast("TupleInt", ti)
        return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
# A ``Vec[Value]`` or a ``TupleInt`` can be used where a ``TupleValue`` is expected.
converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))

# Anything accepted where a ``TupleValue`` is expected.
TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
@array_api_ruleset.register
def _tuple_value(
    length: Int,
    idx_fn: Callable[[Int], Value],
    k: i64,
    idx: Int,
    vs: Vec[Value],
    v: Value,
    v1: Value,
    tv: TupleValue,
    tv1: TupleValue,
    bool_f: Callable[[Boolean, Value], Boolean],
    b: Boolean,
):
    """
    Rewrite rules for ``TupleValue``.

    Answers ``length()`` and indexing for the functional and cons forms, unrolls functional
    tuples with a literal length into cons form, turns cons form into ``from_vec``, and folds
    booleans over cons tuples.
    """
    yield rewrite(TupleValue(length, idx_fn).length()).to(length)
    # ``check_index`` is defined elsewhere; presumably it guards against out-of-range indices
    yield rewrite(TupleValue(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length)))

    # cons access
    yield rewrite(TupleValue.EMPTY.length()).to(Int(0))
    yield rewrite(TupleValue.EMPTY[idx]).to(Value.NEVER)
    yield rewrite(tv.append(v).length()).to(tv.length() + 1)
    # The appended element sits at index ``tv.length()``; earlier indices read from the prefix
    yield rewrite(tv.append(v)[idx]).to(Value.if_(idx == tv.length(), v, tv[idx]))

    # functional to cons: unroll a literal length k into a length k-1 prefix plus element k-1
    yield rewrite(TupleValue(0, idx_fn), subsume=True).to(TupleValue.EMPTY)
    yield rewrite(TupleValue(Int(k), idx_fn), subsume=True).to(
        TupleValue(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(TupleValue.EMPTY).to(TupleValue.from_vec(Vec[Value]()))
    yield rewrite(TupleValue.from_vec(vs).append(v)).to(TupleValue.from_vec(vs.append(Vec(v))))

    # fold boolean (left fold over the cons list)
    yield rewrite(TupleValue.EMPTY.foldl_boolean(bool_f, b), subsume=True).to(b)
    yield rewrite(tv.append(v).foldl_boolean(bool_f, b), subsume=True).to(bool_f(tv.foldl_boolean(bool_f, b), v))

    # unify append: equal appends have equal prefixes and equal last elements
    yield rule(eq(tv.append(v)).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
@function
def possible_values(values: Value) -> TupleValue:
    """
    All the values that the input value could possibly take, as a tuple.

    NOTE(review): no rewrite rules for this function appear in this part of the file;
    presumably it is populated by assumptions made elsewhere — confirm.
    """
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
class Slice(Expr, ruleset=array_api_ruleset):
    """
    A ``start:stop:step`` slice; each omitted part defaults to ``OptionalInt.none``.
    """

    def __init__(
        self,
        start: OptionalInt = OptionalInt.none,
        stop: OptionalInt = OptionalInt.none,
        step: OptionalInt = OptionalInt.none,
    ) -> None: ...


# Python ``slice`` objects convert field by field (``None`` -> ``OptionalInt.none``).
converter(
    slice,
    Slice,
    lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
)

# Anything accepted where a ``Slice`` is expected.
SliceLike: TypeAlias = Slice | slice
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
class MultiAxisIndexKeyItem(Expr, ruleset=array_api_ruleset):
    """
    One entry of a multi-axis (tuple) index: ``...``, ``None``, an integer or a slice.
    """

    ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
    NONE: ClassVar[MultiAxisIndexKeyItem]

    @classmethod
    def int(cls, i: Int) -> MultiAxisIndexKeyItem: ...

    @classmethod
    def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: ...


# Python ``...`` and ``None`` map to the matching constants; ``Int``/``Slice`` are wrapped.
converter(type(...), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.ELLIPSIS)
converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NONE)
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)

# Anything accepted where a ``MultiAxisIndexKeyItem`` is expected.
MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
|
|
1014
|
+
|
|
1015
|
+
|
|
1016
|
+
class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
    """
    A tuple index over several axes, built from a length and an index function or from a ``Vec``.
    """

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...

    def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...

    @classmethod
    def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...


# Anything accepted where a ``MultiAxisIndexKey`` is expected.
MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"


# A Python tuple converts each entry to an item and collects them with ``from_vec``.
converter(
    tuple,
    MultiAxisIndexKey,
    lambda x: MultiAxisIndexKey.from_vec(Vec(*(convert(i, MultiAxisIndexKeyItem) for i in x))),
)
# A ``TupleInt`` becomes an integer index on each of its positions.
converter(
    TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
)
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
class IndexKey(Expr, ruleset=array_api_ruleset):
    """
    A key for indexing into an array

    https://data-apis.org/array-api/2022.12/API_specification/indexing.html

    It is equivalent to the following type signature:

    Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
    """

    ELLIPSIS: ClassVar[IndexKey]

    @classmethod
    def int(cls, i: Int) -> IndexKey: ...

    @classmethod
    def slice(cls, slice: Slice) -> IndexKey: ...

    # Disabled until we support late binding
    # @classmethod
    # def boolean_array(cls, b: NDArray) -> IndexKey:
    #     ...

    @classmethod
    def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...

    @classmethod
    def ndarray(cls, key: NDArray) -> IndexKey:
        """
        Indexes by a masked array
        """


# Anything accepted where an ``IndexKey`` is expected.
IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"


# Each kind of key is wrapped in the matching ``IndexKey`` constructor.
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
converter(Int, IndexKey, lambda i: IndexKey.int(i))
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
# Opaque device an array lives on; it has no constructors or rules here.
class Device(Expr, ruleset=array_api_ruleset): ...


# A named ``TupleInt`` constant. NOTE(review): its meaning is not defined in this part of the file — confirm with its users.
ALL_INDICES: TupleInt = constant("ALL_INDICES", TupleInt)
|
|
1085
|
+
|
|
1086
|
+
|
|
1087
|
+
class NDArray(Expr, ruleset=array_api_ruleset):
    """
    An n-dimensional array expression following the array API standard.

    A concrete array is built from a shape, a dtype and an index function that maps a
    ``TupleInt`` index to the element ``Value`` stored there.
    """

    def __init__(self, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...

    NEVER: ClassVar[NDArray]

    @method(cost=200)
    @classmethod
    def var(cls, name: StringLike) -> NDArray: ...

    @method(preserve=True)
    def __array_namespace__(self, api_version: object = None) -> ModuleType:
        return sys.modules[__name__]

    @property
    def ndim(self) -> Int: ...

    @property
    def dtype(self) -> DType: ...

    @property
    def device(self) -> Device: ...

    @property
    def shape(self) -> TupleInt: ...

    @method(preserve=True)
    def __bool__(self) -> bool:
        return self.to_value().to_bool.eval()

    @property
    def size(self) -> Int: ...

    @method(preserve=True)
    def __len__(self) -> int:
        """
        Length of the first axis, matching ``len()`` on NumPy arrays.

        ``__iter__`` yields one sub-array per index of the first axis, so this must be
        ``shape[0]`` rather than the total element count ``size``.

        Raises ``TypeError`` for 0-d arrays, which have no first axis.
        """
        if self.ndim.eval() == 0:
            raise TypeError("len() of unsized object")
        return self.shape[Int(0)].eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        # Iterate over the first axis; ``len(self)`` is that axis' length
        for i in range(len(self)):
            yield self[IndexKey.int(Int(i))]

    def __getitem__(self, key: IndexKeyLike) -> NDArray: ...

    def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...

    def __lt__(self, other: NDArrayLike) -> NDArray: ...

    def __le__(self, other: NDArrayLike) -> NDArray: ...

    def __eq__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
        ...

    # TODO: Add support for overloaded __ne__
    # def __ne__(self, other: NDArrayLike) -> NDArray:  # type: ignore[override]
    #     ...

    def __gt__(self, other: NDArrayLike) -> NDArray: ...

    def __ge__(self, other: NDArrayLike) -> NDArray: ...

    def __add__(self, other: NDArrayLike) -> NDArray: ...

    def __sub__(self, other: NDArrayLike) -> NDArray: ...

    def __mul__(self, other: NDArrayLike) -> NDArray: ...

    def __matmul__(self, other: NDArrayLike) -> NDArray: ...

    def __truediv__(self, other: NDArrayLike) -> NDArray: ...

    def __floordiv__(self, other: NDArrayLike) -> NDArray: ...

    def __mod__(self, other: NDArrayLike) -> NDArray: ...

    def __divmod__(self, other: NDArrayLike) -> NDArray: ...

    def __pow__(self, other: NDArrayLike) -> NDArray: ...

    def __lshift__(self, other: NDArrayLike) -> NDArray: ...

    def __rshift__(self, other: NDArrayLike) -> NDArray: ...

    def __and__(self, other: NDArrayLike) -> NDArray: ...

    def __xor__(self, other: NDArrayLike) -> NDArray: ...

    def __or__(self, other: NDArrayLike) -> NDArray: ...

    def __radd__(self, other: NDArray) -> NDArray: ...

    def __rsub__(self, other: NDArray) -> NDArray: ...

    def __rmul__(self, other: NDArray) -> NDArray: ...

    def __rmatmul__(self, other: NDArray) -> NDArray: ...

    def __rtruediv__(self, other: NDArray) -> NDArray: ...

    def __rfloordiv__(self, other: NDArray) -> NDArray: ...

    def __rmod__(self, other: NDArray) -> NDArray: ...

    def __rpow__(self, other: NDArray) -> NDArray: ...

    def __rlshift__(self, other: NDArray) -> NDArray: ...

    def __rrshift__(self, other: NDArray) -> NDArray: ...

    def __rand__(self, other: NDArray) -> NDArray: ...

    def __rxor__(self, other: NDArray) -> NDArray: ...

    def __ror__(self, other: NDArray) -> NDArray: ...

    @classmethod
    def scalar(cls, value: Value) -> NDArray:
        # A 0-d array: empty shape, the value's default dtype, every index maps to ``value``
        return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)

    def to_value(self) -> Value:
        """
        Returns the value if this is a scalar (0-d) array.
        """

    def to_values(self) -> TupleValue:
        """
        Returns the values if this is a vector (1-d) array.
        """

    @property
    def T(self) -> NDArray:
        """
        https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
        """

    @classmethod
    def vector(cls, values: TupleValueLike) -> NDArray: ...

    def index(self, indices: TupleIntLike) -> Value:
        """
        Return the value at the given indices.
        """

    @classmethod
    def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
# Anything accepted where an ``NDArray`` is expected.
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike

converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
converter(Value, NDArray, lambda v: NDArray.scalar(v))
# Lets an NDArray stand in for a Value, e.g. ints taken from 1-d arrays used in slices.
# The high cost (100) makes conversion prefer upcasting Value -> NDArray instead, which is safer
# at runtime because ``to_value`` is only meaningful for 0-d arrays.
converter(NDArray, Value, lambda n: n.to_value(), 100)
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
# NOTE: ``TupleInt -> TupleValue`` is registered once, right after ``TupleValue`` is defined;
# a second identical registration here was redundant and has been removed.
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
@array_api_ruleset.register
def _ndarray(
    x: NDArray,
    x1: NDArray,
    b: Boolean,
    f: Float,
    fi1: f64,
    fi2: f64,
    shape: TupleInt,
    dtype: DType,
    idx_fn: Callable[[TupleInt], Value],
    idx: TupleInt,
    tv: TupleValue,
):
    """
    Core rewrite rules for ``NDArray``.

    Covers accessors of the functional constructor, ``ndim``, scalar/vector unwrapping,
    a few scalar float simplifications, double transpose and ``if_`` with a known condition.
    """
    return [
        rewrite(NDArray(shape, dtype, idx_fn).shape).to(shape),
        rewrite(NDArray(shape, dtype, idx_fn).dtype).to(dtype),
        rewrite(NDArray(shape, dtype, idx_fn).index(idx), subsume=True).to(idx_fn(idx)),
        rewrite(x.ndim).to(x.shape.length()),
        # rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
        # ``to_value`` reads the element at the empty index, i.e. the single element of a 0-d array
        rewrite(x.to_value()).to(x.index(TupleInt.EMPTY)),
        rewrite(NDArray.vector(tv).to_values()).to(tv),
        # TODO: Push these down to float
        # NOTE(review): f / f -> 1.0 and f - f -> 0.0 do not hold under IEEE semantics when f is
        # 0.0 (0/0 is nan), +-inf, or nan; confirm callers never rely on those cases.
        rewrite(NDArray.scalar(Value.float(f)) / NDArray.scalar(Value.float(f))).to(
            NDArray.scalar(Value.float(Float(1.0)))
        ),
        rewrite(NDArray.scalar(Value.float(f)) - NDArray.scalar(Value.float(f))).to(
            NDArray.scalar(Value.float(Float(0.0)))
        ),
        # Constant-fold ``>`` between scalar float literals (nan compares false both ways, so stays unfolded)
        rewrite(NDArray.scalar(Value.float(Float(fi1))) > NDArray.scalar(Value.float(Float(fi2)))).to(
            NDArray.scalar(Value.bool(TRUE)), fi1 > fi2
        ),
        rewrite(NDArray.scalar(Value.float(Float(fi1))) > NDArray.scalar(Value.float(Float(fi2)))).to(
            NDArray.scalar(Value.bool(FALSE)), fi1 <= fi2
        ),
        # Transpose of transpose is the original array
        rewrite(x.T.T).to(x),
        # if_
        rewrite(NDArray.if_(TRUE, x, x1)).to(x),
        rewrite(NDArray.if_(FALSE, x, x1)).to(x1),
    ]
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
class TupleNDArray(Expr, ruleset=array_api_ruleset):
    """
    A tuple of ``NDArray``s.

    It can be built functionally from a length and an index function, as a cons list
    (``EMPTY`` followed by ``append`` calls), or from a ``Vec[NDArray]``.
    """

    EMPTY: ClassVar[TupleNDArray]

    def __init__(self, length: IntLike, idx_fn: Callable[[Int], NDArray]) -> None: ...

    def append(self, i: NDArrayLike) -> TupleNDArray: ...

    @classmethod
    def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...

    def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
        # Concatenation: indices below ``self.length()`` read from ``self``, the rest from ``other``
        other = cast("TupleNDArray", other)
        return TupleNDArray(
            self.length() + other.length(),
            lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
        )

    def length(self) -> Int: ...

    def __getitem__(self, i: IntLike) -> NDArray: ...

    @method(preserve=True)
    def __len__(self) -> int:
        return self.length().eval()

    @method(preserve=True)
    def __iter__(self) -> Iterator[NDArray]:
        # Evaluates the whole tuple eagerly, then iterates over the Python tuple
        return iter(self.eval())

    # The backing vector, set by a rule once the tuple is known to equal a ``from_vec`` form.
    @property
    def to_vec(self) -> Vec[NDArray]: ...

    @method(preserve=True)
    def eval(self) -> tuple[NDArray, ...]:
        # Runs ``array_api_schedule`` on the current e-graph until ``to_vec`` can be evaluated
        return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)


# A ``Vec[NDArray]`` can be used where a ``TupleNDArray`` is expected.
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))

# Anything accepted where a ``TupleNDArray`` is expected.
TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]
|
|
1328
|
+
|
|
1329
|
+
|
|
1330
|
+
@array_api_ruleset.register
def _tuple_ndarray(
    length: Int,
    idx_fn: Callable[[Int], NDArray],
    k: i64,
    idx: Int,
    vs: Vec[NDArray],
    v: NDArray,
    v1: NDArray,
    tv: TupleNDArray,
    tv1: TupleNDArray,
):
    """
    Rewrite rules for ``TupleNDArray``.

    Records the backing vector of ``from_vec`` tuples. Answers ``length()`` and indexing for the
    functional and cons forms, unrolls functional tuples with a literal length into cons form,
    and turns cons form into ``from_vec``.
    """
    # Record the backing vector of any tuple known to equal a ``from_vec`` form
    yield rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs))
    yield rewrite(TupleNDArray(length, idx_fn).length()).to(length)
    # ``check_index`` is defined elsewhere; presumably it guards against out-of-range indices
    yield rewrite(TupleNDArray(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length)))

    # cons access
    yield rewrite(TupleNDArray.EMPTY.length()).to(Int(0))
    yield rewrite(TupleNDArray.EMPTY[idx]).to(NDArray.NEVER)
    yield rewrite(tv.append(v).length()).to(tv.length() + 1)
    # The appended element sits at index ``tv.length()``; earlier indices read from the prefix
    yield rewrite(tv.append(v)[idx]).to(NDArray.if_(idx == tv.length(), v, tv[idx]))
    # functional to cons: unroll a literal length k into a length k-1 prefix plus element k-1
    yield rewrite(TupleNDArray(0, idx_fn), subsume=True).to(TupleNDArray.EMPTY)
    yield rewrite(TupleNDArray(Int(k), idx_fn), subsume=True).to(
        TupleNDArray(k - 1, idx_fn).append(idx_fn(Int(k - 1))), k > 0
    )

    # cons to vec
    yield rewrite(TupleNDArray.EMPTY).to(TupleNDArray.from_vec(Vec[NDArray]()))
    yield rewrite(TupleNDArray.from_vec(vs).append(v)).to(TupleNDArray.from_vec(vs.append(Vec(v))))

    # unify append: equal appends have equal prefixes and equal last elements
    yield rule(eq(tv.append(v)).to(tv1.append(v1))).then(union(tv).with_(tv1), union(v).with_(v1))
|
|
1364
|
+
|
|
1365
|
+
|
|
1366
|
+
class OptionalBool(Expr, ruleset=array_api_ruleset):
    """
    A ``Boolean`` that may be absent: ``OptionalBool.none`` or ``OptionalBool.some(value)``.
    """

    none: ClassVar[OptionalBool]

    @classmethod
    def some(cls, value: Boolean) -> OptionalBool: ...


# ``None`` becomes ``none``; a ``Boolean`` is wrapped with ``some``.
converter(type(None), OptionalBool, lambda _: OptionalBool.none)
converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
|
|
1375
|
+
|
|
1376
|
+
|
|
1377
|
+
class OptionalDType(Expr, ruleset=array_api_ruleset):
    """
    A ``DType`` that may be absent: ``OptionalDType.none`` or ``OptionalDType.some(value)``.
    """

    none: ClassVar[OptionalDType]

    @classmethod
    def some(cls, value: DType) -> OptionalDType: ...


# ``None`` becomes ``none``; a ``DType`` is wrapped with ``some``.
converter(type(None), OptionalDType, lambda _: OptionalDType.none)
converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
|
|
1386
|
+
|
|
1387
|
+
|
|
1388
|
+
class OptionalDevice(Expr, ruleset=array_api_ruleset):
    """
    A ``Device`` that may be absent: ``OptionalDevice.none`` or ``OptionalDevice.some(value)``.
    """

    none: ClassVar[OptionalDevice]

    @classmethod
    def some(cls, value: Device) -> OptionalDevice: ...


# ``None`` becomes ``none``; a ``Device`` is wrapped with ``some``.
converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
|
|
1397
|
+
|
|
1398
|
+
|
|
1399
|
+
class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
    """
    A ``TupleInt`` that may be absent: ``OptionalTupleInt.none`` or ``OptionalTupleInt.some(value)``.
    """

    none: ClassVar[OptionalTupleInt]

    @classmethod
    def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...


# ``None`` becomes ``none``; a ``TupleInt`` is wrapped with ``some``.
converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
class IntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    Either a single ``Int`` or a ``TupleInt`` (e.g. one axis or several axes).
    """

    none: ClassVar[IntOrTuple]

    @classmethod
    def int(cls, value: Int) -> IntOrTuple: ...

    @classmethod
    def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...


# An ``Int`` or a ``TupleInt`` is wrapped with the matching constructor.
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
|
|
1422
|
+
|
|
1423
|
+
|
|
1424
|
+
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
    """
    An ``IntOrTuple`` that may be absent: ``OptionalIntOrTuple.none`` or ``OptionalIntOrTuple.some(value)``.
    """

    none: ClassVar[OptionalIntOrTuple]

    @classmethod
    def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...


# ``None`` becomes ``none``; an ``IntOrTuple`` is wrapped with ``some``.
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
# Array API ``asarray``; with every option left at its default it rewrites to ``a`` (see ``_assarray``).
@function
def asarray(
    a: NDArray,
    dtype: OptionalDType = OptionalDType.none,
    copy: OptionalBool = OptionalBool.none,
    device: OptionalDevice = OptionalDevice.none,
) -> NDArray: ...
|
|
1442
|
+
|
|
1443
|
+
|
|
1444
|
+
@array_api_ruleset.register
def _assarray(a: NDArray, d: OptionalDType, ob: OptionalBool, dev: OptionalDevice, dt: DType):
    """
    Rewrite rules for ``asarray`` applied to an existing array.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.asarray.html
    """
    # asarray never changes the rank or shape of an existing array, whatever dtype/copy/device are given
    yield rewrite(asarray(a, d, ob, dev).ndim).to(a.ndim)
    yield rewrite(asarray(a, d, ob, dev).shape).to(a.shape)
    # Without an explicit dtype the input dtype is inferred (kept); with one, that dtype is used
    yield rewrite(asarray(a, OptionalDType.none, ob, dev).dtype).to(a.dtype)
    yield rewrite(asarray(a, OptionalDType.some(dt), ob, dev).dtype).to(dt)
    # With every option left at its default, asarray is the identity
    yield rewrite(asarray(a)).to(a)
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
# Element-wise finiteness test (array API ``isfinite``); no rewrite rules for it in this part of the file.
@function
def isfinite(x: NDArray) -> NDArray: ...
|
|
1452
|
+
|
|
1453
|
+
|
|
1454
|
+
@function
def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    Sum of the elements of ``x``, over all axes when ``axis`` is none.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
    """
|
|
1459
|
+
|
|
1460
|
+
|
|
1461
|
+
@array_api_ruleset.register
def _sum(x: NDArray, v: Value, axis: OptionalIntOrTuple):
    """
    Rewrite rules for ``sum``.
    """
    return [
        # Dividing every element by the same scalar commutes with summation, both for a full
        # reduction (axis none) and along any given axis or axes
        rewrite(sum(x / NDArray.scalar(v), axis)).to(sum(x, axis) / NDArray.scalar(v)),
        # TODO: Add a rule for the sum of a 0-D array
    ]
|
|
1467
|
+
|
|
1468
|
+
|
|
1469
|
+
# Array API ``reshape``. No rewrite rules are active for it here; the draft rules that follow are commented out.
@function
def reshape(x: NDArray, shape: TupleIntLike, copy: OptionalBool = OptionalBool.none) -> NDArray: ...
|
|
1471
|
+
|
|
1472
|
+
|
|
1473
|
+
# @function
|
|
1474
|
+
# def reshape_transform_index(original_shape: TupleInt, shape: TupleInt, index: TupleInt) -> TupleInt:
|
|
1475
|
+
# """
|
|
1476
|
+
# Transforms an indexing operation on a reshaped array to an indexing operation on the original array.
|
|
1477
|
+
# """
|
|
1478
|
+
# ...
|
|
1479
|
+
|
|
1480
|
+
|
|
1481
|
+
# @function
|
|
1482
|
+
# def reshape_transform_shape(original_shape: TupleInt, shape: TupleInt) -> TupleInt:
|
|
1483
|
+
# """
|
|
1484
|
+
# Transforms the shape of an array to one that is reshaped, by replacing -1 with the correct value.
|
|
1485
|
+
# """
|
|
1486
|
+
# ...
|
|
1487
|
+
|
|
1488
|
+
|
|
1489
|
+
# @array_api_ruleset.register
|
|
1490
|
+
# def _reshape(
|
|
1491
|
+
# x: NDArray,
|
|
1492
|
+
# y: NDArray,
|
|
1493
|
+
# shape: TupleInt,
|
|
1494
|
+
# copy: OptionalBool,
|
|
1495
|
+
# i: Int,
|
|
1496
|
+
# s: String,
|
|
1497
|
+
# ix: TupleInt,
|
|
1498
|
+
# ):
|
|
1499
|
+
# return [
|
|
1500
|
+
# # dtype of result is same as input
|
|
1501
|
+
# rewrite(reshape(x, shape, copy).dtype).to(x.dtype),
|
|
1502
|
+
# # Indexing into a reshaped array is the same as indexing into the original array with a transformed index
|
|
1503
|
+
# rewrite(reshape(x, shape, copy).index(ix)).to(x.index(reshape_transform_index(x.shape, shape, ix))),
|
|
1504
|
+
# rewrite(reshape(x, shape, copy).shape).to(reshape_transform_shape(x.shape, shape)),
|
|
1505
|
+
# # reshape_transform_shape recursively
|
|
1506
|
+
# # TODO: handle all cases
|
|
1507
|
+
# rewrite(reshape_transform_shape(TupleInt(i), TupleInt(Int(-1)))).to(TupleInt(i)),
|
|
1508
|
+
# ]
|
|
1509
|
+
|
|
1510
|
+
|
|
1511
|
+
@function
def unique_values(x: NDArrayLike) -> NDArray:
    """
    Returns the unique elements of an input array ``x``, flattened, in arbitrary order.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_values.html
    """
|
|
1518
|
+
|
|
1519
|
+
|
|
1520
|
+
@array_api_ruleset.register
def _unique_values(x: NDArray):
    """
    ``unique_values`` is idempotent: taking the unique values twice is the same as taking them once.
    """
    once = unique_values(x)
    yield rewrite(unique_values(once)).to(once)
|
|
1525
|
+
|
|
1526
|
+
|
|
1527
|
+
@function
def concat(arrays: TupleNDArrayLike, axis: OptionalInt = OptionalInt.none) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.concat.html
    """
    ...
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
@array_api_ruleset.register
def _concat(x: NDArray):
    """
    Rewrite rules for ``concat``.
    """
    # Only the no-op case is supported for now: concatenating a single array gives that array back.
    just_x = TupleNDArray.EMPTY.append(x)
    yield rewrite(concat(just_x)).to(x)
|
|
1537
|
+
|
|
1538
|
+
|
|
1539
|
+
@function
def astype(x: NDArray, dtype: DType) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.astype.html
    """
    ...
|
|
1541
|
+
|
|
1542
|
+
|
|
1543
|
+
@array_api_ruleset.register
def _astype(x: NDArray, dtype: DType, i: i64):
    """
    Rewrite rules for ``astype``.
    """
    # The result of a cast always has the requested dtype.
    yield rewrite(astype(x, dtype).dtype).to(dtype)
    # Casting an integer scalar to float64 folds to the equivalent float scalar.
    int_scalar = NDArray.scalar(Value.int(Int(i)))
    as_float = NDArray.scalar(Value.float(Float(f64.from_i64(i))))
    yield rewrite(astype(int_scalar, float64)).to(as_float)
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
@function
def unique_counts(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array ``x`` and the corresponding counts for each unique element in ``x``.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
    """
|
|
1561
|
+
|
|
1562
|
+
|
|
1563
|
+
@array_api_ruleset.register
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType):
    """
    Rewrite rules for ``unique_counts``.
    """
    result = unique_counts(x)
    counts = result[Int(1)]
    size_scalar = NDArray.scalar(Value.int(x.size))
    # rewrite(unique_counts(x).length()).to(Int(2)),
    # The result is a two-element tuple.
    yield rewrite(result).to(TupleNDArray(2, result.__getitem__))
    # Summing all the counts gives back the number of elements of x.
    yield rewrite(sum(counts)).to(size_scalar)
    # Same, when the counts are cast before summing.
    # TODO: Replace
    yield rewrite(sum(astype(counts, dtype))).to(astype(size_scalar, dtype))
|
|
1574
|
+
|
|
1575
|
+
|
|
1576
|
+
@function
def square(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.square.html
    """
    ...
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
@function
def any(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
    """
    ...
|
|
1582
|
+
|
|
1583
|
+
|
|
1584
|
+
@function(egg_fn="ndarray-abs")
def abs(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.abs.html
    """
    ...
|
|
1586
|
+
|
|
1587
|
+
|
|
1588
|
+
@function(egg_fn="ndarray-log")
def log(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.log.html
    """
    ...
|
|
1590
|
+
|
|
1591
|
+
|
|
1592
|
+
@array_api_ruleset.register
def _abs(f: Float):
    """
    Constant-fold ``abs`` of a float scalar.
    """
    yield rewrite(abs(NDArray.scalar(Value.float(f)))).to(NDArray.scalar(Value.float(f.abs())))
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
@function
def unique_inverse(x: NDArray) -> TupleNDArray:
    """
    Returns the unique elements of an input array ``x`` and the indices from the set of unique elements that reconstruct ``x``.

    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
    """
|
|
1606
|
+
|
|
1607
|
+
|
|
1608
|
+
@array_api_ruleset.register
def _unique_inverse(x: NDArray, i: Int):
    """
    Rewrite rules for ``unique_inverse``.
    """
    return [
        # rewrite(unique_inverse(x).length()).to(Int(2)),
        # The result is a two-element tuple.
        rewrite(unique_inverse(x)).to(TupleNDArray(2, unique_inverse(x).__getitem__)),
        # The first element of unique_inverse is the unique values themselves
        rewrite(unique_inverse(x)[Int(0)]).to(unique_values(x)),
    ]
|
|
1616
|
+
|
|
1617
|
+
|
|
1618
|
+
@function
def zeros(
    shape: TupleIntLike, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none
) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.zeros.html
    """
    ...
|
|
1622
|
+
|
|
1623
|
+
|
|
1624
|
+
@function
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.expand_dims.html
    """
    ...
|
|
1626
|
+
|
|
1627
|
+
|
|
1628
|
+
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.mean.html
    """
    ...
|
|
1630
|
+
|
|
1631
|
+
|
|
1632
|
+
# TODO: Possibly change names to include modules.
@function(egg_fn="ndarray-sqrt")
def sqrt(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sqrt.html
    """
    ...
|
|
1635
|
+
|
|
1636
|
+
|
|
1637
|
+
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.std.html
    """
    ...
|
|
1639
|
+
|
|
1640
|
+
|
|
1641
|
+
@function
def real(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.real.html
    """
    ...
|
|
1643
|
+
|
|
1644
|
+
|
|
1645
|
+
@function
def conj(x: NDArray) -> NDArray:
    """
    https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.conj.html
    """
    ...
|
|
1647
|
+
|
|
1648
|
+
|
|
1649
|
+
# Alias this module as ``linalg``, presumably so that Array API style calls such as ``linalg.svd(...)``
# resolve to the functions defined here (``svd`` below links to the ``linalg`` extension spec).
linalg = sys.modules[__name__]
|
|
1650
|
+
|
|
1651
|
+
|
|
1652
|
+
@function
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
    """
    Singular value decomposition of ``x``; the result is a tuple of three arrays.

    https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
    """
|
|
1657
|
+
|
|
1658
|
+
|
|
1659
|
+
@array_api_ruleset.register
def _linalg(x: NDArray, full_matrices: Boolean):
    """
    Rewrite rules for the ``linalg`` functions.
    """
    decomposition = svd(x, full_matrices)
    # rewrite(svd(x, full_matrices).length()).to(Int(3)),
    # svd always yields a three-element tuple.
    yield rewrite(decomposition).to(TupleNDArray(3, decomposition.__getitem__))
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
@function(ruleset=array_api_ruleset)
def ndindex(shape: TupleIntLike) -> TupleTupleInt:
    """
    All index tuples of an array with the given shape.

    https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
    """
    shape_tuple = cast("TupleInt", shape)
    # One range per dimension, then the product of those ranges.
    per_axis_ranges = shape_tuple.map_tuple_int(TupleInt.range)
    return per_axis_ranges.product()
|
|
1674
|
+
|
|
1675
|
+
|
|
1676
|
+
##
|
|
1677
|
+
# Interval analysis
|
|
1678
|
+
#
|
|
1679
|
+
# to analyze `any((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).bool()`
|
|
1680
|
+
##
|
|
1681
|
+
|
|
1682
|
+
# Relation marking a ``Value`` as known to be strictly greater than zero.
# Populated and consumed by the interval-analysis rules in ``_interval_analaysis``.
greater_zero = relation("greater_zero", Value)
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
# @function
|
|
1686
|
+
# def ndarray_all_greater_0(x: NDArray) -> Unit:
|
|
1687
|
+
# ...
|
|
1688
|
+
|
|
1689
|
+
|
|
1690
|
+
# @function
|
|
1691
|
+
# def ndarray_all_false(x: NDArray) -> Unit:
|
|
1692
|
+
# ...
|
|
1693
|
+
|
|
1694
|
+
|
|
1695
|
+
# @function
|
|
1696
|
+
# def ndarray_all_true(x: NDArray) -> Unit:
|
|
1697
|
+
# ...
|
|
1698
|
+
|
|
1699
|
+
|
|
1700
|
+
# any((astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))) < NDArray.scalar(Value.int(Int(0)))).to_bool()
|
|
1701
|
+
|
|
1702
|
+
# sum(astype(unique_counts(_NDArray_1)[Int(1)], DType.float64) / NDArray.scalar(Value.int(Int(150))))
|
|
1703
|
+
# And also
|
|
1704
|
+
|
|
1705
|
+
# def
|
|
1706
|
+
|
|
1707
|
+
|
|
1708
|
+
@function
def broadcast_index(from_shape: TupleIntLike, to_shape: TupleIntLike, index: TupleIntLike) -> TupleInt:
    """
    Returns the index in the original array of the given index in the broadcasted array.

    ``from_shape`` is the shape of the original (pre-broadcast) array, ``to_shape`` the shape of the
    broadcast result, and ``index`` an index into that result.
    """
|
|
1713
|
+
|
|
1714
|
+
|
|
1715
|
+
@function
def broadcast_shapes(shape1: TupleIntLike, shape2: TupleIntLike) -> TupleInt:
    """
    Returns the shape of the array obtained by broadcasting arrays of shapes ``shape1`` and ``shape2`` together.
    """
|
|
1720
|
+
|
|
1721
|
+
|
|
1722
|
+
@array_api_ruleset.register
def _interval_analaysis(
    x: NDArray,
    y: NDArray,
    z: NDArray,
    dtype: DType,
    f: f64,
    i: i64,
    b: Boolean,
    idx: TupleInt,
    v: Value,
    v1: Value,
    v2: Value,
    float_: Float,
    int_: Int,
):
    """
    Element-wise indexing rules, plus rules that track which values are strictly positive
    (``greater_zero``) and use that to decide comparisons against zero.

    ``z``, ``float_`` and ``int_`` are currently unused pattern variables.
    """
    res_shape = broadcast_shapes(x.shape, y.shape)
    # Values of x and y at the position ``idx`` of their broadcast result.
    x_value = x.index(broadcast_index(x.shape, res_shape, idx))
    y_value = y.index(broadcast_index(y.shape, res_shape, idx))
    return [
        # Calling any on an array gives back a scalar, which is true if any of the values are truthy
        rewrite(any(x)).to(
            NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE))))
        ),
        # Indexing x < y is the same as broadcasting the index and then indexing both and then comparing
        rewrite((x < y).index(idx)).to(x_value < y_value),
        # Same for x / y
        rewrite((x / y).index(idx)).to(x_value / y_value),
        # Indexing a scalar is the same as the scalar
        rewrite(NDArray.scalar(v).index(idx)).to(v),
        # Indexing of astype is same as astype of indexing
        rewrite(astype(x, dtype).index(idx)).to(x.index(idx).astype(dtype)),
        # rule(eq(y).to(x < NDArray.scalar(Value.int(Int(0)))), ndarray_all_greater_0(x)).then(ndarray_all_false(y)),
        # rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray.scalar(Value.bool(FALSE)))),
        # The counts returned by unique_counts are all positive
        rule(
            eq(v).to(unique_counts(x)[Int(1)].index(idx)),
        ).then(greater_zero(v)),
        # Positivity is preserved over astype.
        # NOTE(review): not sound for every cast (e.g. casting 0.5 to an integer dtype gives 0); confirm
        # the dtypes this is relied on for. (This rule was previously registered twice; the duplicate is removed.)
        rule(
            greater_zero(v),
            eq(v1).to(v.astype(dtype)),
        ).then(
            greater_zero(v1),
        ),
        # Positive literals are greater than zero
        rule(eq(v).to(Value.float(Float(f))), f > 0.0).then(greater_zero(v)),
        rule(eq(v).to(Value.int(Int(i))), i > 0).then(greater_zero(v)),
        # If we have division of v and v1, and both are greater than zero, then the result is greater than zero
        rule(
            greater_zero(v),
            greater_zero(v1),
            eq(v2).to(v / v1),
        ).then(
            greater_zero(v2),
        ),
        # Define v < 0 to be false, if greater_zero(v)
        rule(
            greater_zero(v),
            eq(v1).to(v < Value.int(Int(0))),
        ).then(
            union(v1).with_(Value.bool(FALSE)),
        ),
        # possible values of bool is bool
        rewrite(possible_values(Value.bool(b))).to(TupleValue.EMPTY.append(Value.bool(b))),
    ]
|
|
1795
|
+
|
|
1796
|
+
|
|
1797
|
+
##
|
|
1798
|
+
# Mathematical descriptions of arrays as:
|
|
1799
|
+
# 1. A shape `.shape`
|
|
1800
|
+
# 2. A dtype `.dtype`
|
|
1801
|
+
# 3. A mapping from indices to values `x.index(idx)`
|
|
1802
|
+
#
|
|
1803
|
+
# For all operations that are supported mathematically, define each of the above.
|
|
1804
|
+
##
|
|
1805
|
+
|
|
1806
|
+
|
|
1807
|
+
def _demand_shape(compound: NDArray, inner: NDArray) -> Command:
    """
    Build a rule that, whenever ``compound`` is present, also adds ``inner.shape`` and
    ``inner.shape.length()`` as actions, so other rules can match on them.
    """
    anchor = var("__a", NDArray)
    demanded = (inner.shape, inner.shape.length())
    return rule(eq(anchor).to(compound)).then(*demanded)
|
|
1810
|
+
|
|
1811
|
+
|
|
1812
|
+
@array_api_ruleset.register
def _scalar_math(v: Value, vs: TupleValue, i: Int):
    """
    Shape, dtype and indexing of a 0D array built with ``NDArray.scalar``.
    """
    scalar = NDArray.scalar(v)
    return [
        rewrite(scalar.shape).to(TupleInt.EMPTY),
        rewrite(scalar.dtype).to(v.dtype),
        rewrite(scalar.index(TupleInt.EMPTY)).to(v),
    ]
|
|
1817
|
+
|
|
1818
|
+
|
|
1819
|
+
@array_api_ruleset.register
def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
    """
    Shape, dtype and indexing of a 1D array built with ``NDArray.vector``.
    """
    vec = NDArray.vector(vs)
    return [
        # The shape is a one-element tuple holding the number of values.
        rewrite(vec.shape).to(TupleInt.single(vs.length())),
        # The dtype is taken from the first value.
        rewrite(vec.dtype).to(vs[Int(0)].dtype),
        # Only the first component of the index is used.
        rewrite(vec.index(ti)).to(vs[ti[0]]),
    ]
|
|
1824
|
+
|
|
1825
|
+
|
|
1826
|
+
@array_api_ruleset.register
def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
    """
    Rewrite rules for ``reshape``.
    """
    reshaped = reshape(x, shape, copy)

    # Make the input's shape (and its length) and the requested shape's length/first entry exist,
    # so the conditional rewrite below can match on them.
    yield _demand_shape(reshaped, x)
    yield rule(reshaped).then(shape.length(), shape[0])

    # Reshaping a vec to a vec (shape (-1,)) is the same as the vec
    input_is_vector = eq(x.shape.length()).to(Int(1))
    target_is_flat = eq(shape.length()).to(Int(1))
    size_is_inferred = eq(shape[0]).to(Int(-1))
    yield rewrite(reshaped).to(x, input_is_vector, target_is_flat, size_is_inferred)
|
|
1841
|
+
|
|
1842
|
+
|
|
1843
|
+
@array_api_ruleset.register
def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
    """
    Rewrite a full getitem into ``index``.

    An integer key only selects a single element when ``x`` is one-dimensional; on an N-D array it
    selects an (N-1)-D sub-array. The rewrite is therefore guarded on ``x`` having a shape of length 1,
    and that length is demanded so the guard can be checked.
    """
    key = x[IndexKey.int(i)]
    yield _demand_shape(key, x)
    yield rewrite(key).to(
        NDArray.scalar(x.index(TupleInt.single(i))),
        eq(x.shape.length()).to(Int(1)),
    )
    # TODO: Multi index rewrite as well if all are ints
|
|
1848
|
+
|
|
1849
|
+
|
|
1850
|
+
##
|
|
1851
|
+
# Assumptions
|
|
1852
|
+
##
|
|
1853
|
+
|
|
1854
|
+
|
|
1855
|
+
@function(mutates_first_arg=True)
def assume_dtype(x: NDArray, dtype: DType) -> None:
    """
    Asserts that the dtype of ``x`` is ``dtype``.

    Mutates ``x`` in place (``mutates_first_arg=True``). The shape and element values of ``x`` are
    unchanged; see ``_assume_dtype`` for the rules.
    """
|
|
1860
|
+
|
|
1861
|
+
|
|
1862
|
+
@array_api_ruleset.register
def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
    """
    Rules for ``assume_dtype``: the assumed array has the asserted dtype and otherwise behaves like the original.
    """
    # Keep the un-assumed array: the mutating call below rebinds ``x`` to the assumed expression.
    orig_x = copy(x)
    assume_dtype(x, dtype)
    yield rewrite(x.dtype).to(dtype)
    # Shape and values pass through to the original array.
    yield rewrite(x.shape).to(orig_x.shape)
    yield rewrite(x.index(idx)).to(orig_x.index(idx))
|
|
1869
|
+
|
|
1870
|
+
|
|
1871
|
+
@function(mutates_first_arg=True)
def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
    """
    Asserts that the shape of ``x`` is ``shape``.

    Mutates ``x`` in place (``mutates_first_arg=True``). The dtype and element values of ``x`` are
    unchanged; see ``_assume_shape`` for the rules.
    """
|
|
1876
|
+
|
|
1877
|
+
|
|
1878
|
+
@array_api_ruleset.register
def _assume_shape(x: NDArray, shape: TupleInt, idx: TupleInt):
    """
    Rules for ``assume_shape``: the assumed array has the asserted shape and otherwise behaves like the original.
    """
    # Keep the un-assumed array: the mutating call below rebinds ``x`` to the assumed expression.
    orig_x = copy(x)
    assume_shape(x, shape)
    yield rewrite(x.shape).to(shape)
    # Dtype and values pass through to the original array.
    yield rewrite(x.dtype).to(orig_x.dtype)
    yield rewrite(x.index(idx)).to(orig_x.index(idx))
|
|
1885
|
+
|
|
1886
|
+
|
|
1887
|
+
@function(mutates_first_arg=True)
def assume_isfinite(x: NDArray) -> None:
    """
    Asserts that every value of the ndarray ``x`` is non null and not infinite.

    Mutates ``x`` in place (``mutates_first_arg=True``); see ``_isfinite`` for the rules.
    """
|
|
1892
|
+
|
|
1893
|
+
|
|
1894
|
+
@array_api_ruleset.register
def _isfinite(x: NDArray, ti: TupleInt):
    """
    Rules for ``assume_isfinite``: the assumed array behaves like the original, and every value read from it is finite.
    """
    # Keep the un-assumed array: the mutating call below rebinds ``x`` to the assumed expression.
    orig_x = copy(x)
    assume_isfinite(x)

    # pass through shape, dtype and index
    yield rewrite(x.shape).to(orig_x.shape)
    yield rewrite(x.dtype).to(orig_x.dtype)
    yield rewrite(x.index(ti)).to(orig_x.index(ti))
    # But say that any indexed value is finite
    yield rewrite(x.index(ti).isfinite()).to(TRUE)
|
|
1905
|
+
|
|
1906
|
+
|
|
1907
|
+
@function(mutates_first_arg=True)
def assume_value_one_of(x: NDArray, values: TupleValueLike) -> None:
    """
    Asserts that every value of ``x`` is one of the values in the tuple ``values``.

    Mutates ``x`` in place (``mutates_first_arg=True``); see ``_assume_value_one_of`` for the rules.
    """
|
|
1912
|
+
|
|
1913
|
+
|
|
1914
|
+
@array_api_ruleset.register
def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt):
    """
    Rules for ``assume_value_one_of``: the assumed array behaves like the original, and each of its
    values has ``vs`` as its possible values.
    """
    # Keep the un-assumed array: the mutating call below rebinds ``x`` to the assumed expression.
    x_orig = copy(x)
    assume_value_one_of(x, vs)
    # Pass through dtype and shape
    yield rewrite(x.shape).to(x_orig.shape)
    yield rewrite(x.dtype).to(x_orig.dtype)
    # The array values pass through, but record that the possible_values of each one are ``vs``
    yield rule(eq(v).to(x.index(idx))).then(
        union(v).with_(x_orig.index(idx)),
        union(possible_values(v)).with_(vs),
    )
|
|
1926
|
+
|
|
1927
|
+
|
|
1928
|
+
@array_api_ruleset.register
def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean):
    """
    Rules deciding ``isfinite`` for literal values and for sums.
    """
    # Integers and booleans are always finite.
    yield rewrite(Value.int(i).isfinite()).to(TRUE)
    yield rewrite(Value.bool(b).isfinite()).to(TRUE)
    # A float literal is finite unless it is NaN or +/-infinity (previously +/-inf were wrongly treated as finite).
    yield rewrite(Value.float(Float(f)).isfinite()).to(
        TRUE,
        ne(f).to(f64(math.nan)),
        ne(f).to(f64(math.inf)),
        ne(f).to(f64(-math.inf)),
    )

    # a sum of an array is finite if all the values are finite
    # NOTE(review): ignores overflow of a sum of large finite floats to inf; confirm this is acceptable.
    yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite())))
|
|
1936
|
+
|
|
1937
|
+
|
|
1938
|
+
@array_api_ruleset.register
def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
    """
    The unique values of an array are the possible values of its elements, as a vector.
    """
    every_value = possible_values(a.index(ALL_INDICES))
    return [rewrite(unique_values(x=a)).to(NDArray.vector(every_value))]
    # yield rewrite(
    #     possible_values(reshape(a.index(shape, copy), ALL_INDICES)),
    # ).to(possible_values(a.index(ALL_INDICES)))
|
|
1944
|
+
|
|
1945
|
+
|
|
1946
|
+
@array_api_ruleset.register
def _size(x: NDArray):
    """
    The size of an array is the product of the entries of its shape (1 for an empty shape).
    """
    return [rewrite(x.size).to(x.shape.foldl(Int.__mul__, Int(1)))]
|
|
1949
|
+
|
|
1950
|
+
|
|
1951
|
+
# Separate ruleset so we can use it in program gen
@ruleset
def array_api_vec_to_cons_ruleset(
    vs: Vec[Int],
    vv: Vec[Value],
    vn: Vec[NDArray],
    vt: Vec[TupleInt],
):
    """
    Turn ``from_vec`` tuples into ``EMPTY`` / ``append`` chains, peeling one element off the end at a time.
    """
    for tuple_cls, vec in (
        (TupleInt, vs),
        (TupleValue, vv),
        (TupleTupleInt, vt),
        (TupleNDArray, vn),
    ):
        length = vec.length()
        last = length - 1
        # An empty vec is the empty tuple.
        yield rewrite(tuple_cls.from_vec(vec)).to(tuple_cls.EMPTY, eq(length).to(i64(0)))
        # Otherwise: convert everything but the last element, then append the last element.
        yield rewrite(tuple_cls.from_vec(vec)).to(
            tuple_cls.from_vec(vec.remove(last)).append(vec[last]), ne(length).to(i64(0))
        )
|
|
1977
|
+
|
|
1978
|
+
|
|
1979
|
+
# The main array API rules together with the Vec -> cons-list conversion rules.
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
# Schedule that runs the combined rules until saturation.
array_api_schedule = array_api_combined_ruleset.saturate()

# E-graph installed by ``set_array_api_egraph``; ``None`` when no such context is active.
_CURRENT_EGRAPH: None | EGraph = None
|
|
1983
|
+
|
|
1984
|
+
|
|
1985
|
+
@contextlib.contextmanager
def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
    """
    Context manager that sets the current e-graph for the duration of the ``with`` block.

    Inside the block ``_get_current_egraph`` returns ``egraph``. On exit the current e-graph is
    reset to ``None``, including when the block raises.

    Nesting is not supported: entering while another e-graph is set fails the assertion.
    """
    global _CURRENT_EGRAPH
    assert _CURRENT_EGRAPH is None
    _CURRENT_EGRAPH = egraph
    try:
        yield
    finally:
        # Reset even on error; otherwise every later call would trip the assertion above.
        _CURRENT_EGRAPH = None
|
|
1995
|
+
|
|
1996
|
+
|
|
1997
|
+
def _get_current_egraph() -> EGraph:
    """
    Return the e-graph set by ``set_array_api_egraph``, or a new ``EGraph`` when none is set.
    """
    # Compare against None explicitly instead of relying on the e-graph object's truthiness.
    return _CURRENT_EGRAPH if _CURRENT_EGRAPH is not None else EGraph()
|
|
1999
|
+
|
|
2000
|
+
|
|
2001
|
+
def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
    """
    Evaluate ``prim_expr`` by extracting it from ``egraph`` and returning the extracted ``.value``.

    If the primitive cannot be extracted yet (``EggSmolError``), ``expr`` is registered, ``schedule``
    is run, and extraction is retried. If the retry also fails, the error is re-raised with a note
    describing ``expr``.
    """
    try:
        extracted = egraph.extract(prim_expr)
    except EggSmolError:
        # If this primitive doesn't exist in the egraph, we need to try to create it by
        # registering the expression and running the schedule
        egraph.register(expr)
        egraph.run(schedule)
        try:
            extracted = egraph.extract(prim_expr)
        except Exception as e:
            # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
            # Building the note extracts from the egraph again; never let a failure there
            # replace the original error.
            try:
                description = str(egraph.extract(expr))
            except Exception:
                description = repr(expr)
            e.add_note(f"Cannot evaluate {description}")
            raise
    return extracted.value  # type: ignore[attr-defined]
|