egglog-7.2.0-cp311-none-win_amd64.whl → egglog-8.0.1-cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog has been flagged as potentially problematic; see the registry's advisory page for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
egglog/exp/array_api.py
CHANGED
|
@@ -7,7 +7,9 @@ import math
|
|
|
7
7
|
import numbers
|
|
8
8
|
import sys
|
|
9
9
|
from copy import copy
|
|
10
|
-
from
|
|
10
|
+
from functools import partial
|
|
11
|
+
from types import EllipsisType
|
|
12
|
+
from typing import TYPE_CHECKING, ClassVar, TypeAlias, overload
|
|
11
13
|
|
|
12
14
|
import numpy as np
|
|
13
15
|
|
|
@@ -42,6 +44,8 @@ class Boolean(Expr):
|
|
|
42
44
|
|
|
43
45
|
def if_int(self, true_value: Int, false_value: Int) -> Int: ...
|
|
44
46
|
|
|
47
|
+
def __invert__(self) -> Boolean: ...
|
|
48
|
+
|
|
45
49
|
|
|
46
50
|
BooleanLike = Boolean | bool
|
|
47
51
|
|
|
@@ -61,6 +65,8 @@ def _bool(x: Boolean, i: Int, j: Int):
|
|
|
61
65
|
rewrite(FALSE & x).to(FALSE),
|
|
62
66
|
rewrite(TRUE.if_int(i, j)).to(i),
|
|
63
67
|
rewrite(FALSE.if_int(i, j)).to(j),
|
|
68
|
+
rewrite(~TRUE).to(FALSE),
|
|
69
|
+
rewrite(~FALSE).to(TRUE),
|
|
64
70
|
]
|
|
65
71
|
|
|
66
72
|
|
|
@@ -69,74 +75,75 @@ class Int(Expr):
|
|
|
69
75
|
|
|
70
76
|
def __invert__(self) -> Int: ...
|
|
71
77
|
|
|
72
|
-
def __lt__(self, other:
|
|
78
|
+
def __lt__(self, other: IntLike) -> Boolean: ...
|
|
73
79
|
|
|
74
|
-
def __le__(self, other:
|
|
80
|
+
def __le__(self, other: IntLike) -> Boolean: ...
|
|
75
81
|
|
|
76
|
-
def __eq__(self, other:
|
|
82
|
+
def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
|
|
77
83
|
...
|
|
78
84
|
|
|
85
|
+
# TODO: Fix this?
|
|
79
86
|
# Make != always return a Bool, so that numpy.unique works on a tuple of ints
|
|
80
87
|
# In _unique1d
|
|
81
88
|
@method(preserve=True)
|
|
82
89
|
def __ne__(self, other: Int) -> bool: # type: ignore[override]
|
|
83
90
|
return not (self == other)
|
|
84
91
|
|
|
85
|
-
def __gt__(self, other:
|
|
92
|
+
def __gt__(self, other: IntLike) -> Boolean: ...
|
|
86
93
|
|
|
87
|
-
def __ge__(self, other:
|
|
94
|
+
def __ge__(self, other: IntLike) -> Boolean: ...
|
|
88
95
|
|
|
89
|
-
def __add__(self, other:
|
|
96
|
+
def __add__(self, other: IntLike) -> Int: ...
|
|
90
97
|
|
|
91
|
-
def __sub__(self, other:
|
|
98
|
+
def __sub__(self, other: IntLike) -> Int: ...
|
|
92
99
|
|
|
93
|
-
def __mul__(self, other:
|
|
100
|
+
def __mul__(self, other: IntLike) -> Int: ...
|
|
94
101
|
|
|
95
|
-
def __truediv__(self, other:
|
|
102
|
+
def __truediv__(self, other: IntLike) -> Int: ...
|
|
96
103
|
|
|
97
|
-
def __floordiv__(self, other:
|
|
104
|
+
def __floordiv__(self, other: IntLike) -> Int: ...
|
|
98
105
|
|
|
99
|
-
def __mod__(self, other:
|
|
106
|
+
def __mod__(self, other: IntLike) -> Int: ...
|
|
100
107
|
|
|
101
|
-
def __divmod__(self, other:
|
|
108
|
+
def __divmod__(self, other: IntLike) -> Int: ...
|
|
102
109
|
|
|
103
|
-
def __pow__(self, other:
|
|
110
|
+
def __pow__(self, other: IntLike) -> Int: ...
|
|
104
111
|
|
|
105
|
-
def __lshift__(self, other:
|
|
112
|
+
def __lshift__(self, other: IntLike) -> Int: ...
|
|
106
113
|
|
|
107
|
-
def __rshift__(self, other:
|
|
114
|
+
def __rshift__(self, other: IntLike) -> Int: ...
|
|
108
115
|
|
|
109
|
-
def __and__(self, other:
|
|
116
|
+
def __and__(self, other: IntLike) -> Int: ...
|
|
110
117
|
|
|
111
|
-
def __xor__(self, other:
|
|
118
|
+
def __xor__(self, other: IntLike) -> Int: ...
|
|
112
119
|
|
|
113
|
-
def __or__(self, other:
|
|
120
|
+
def __or__(self, other: IntLike) -> Int: ...
|
|
114
121
|
|
|
115
|
-
def __radd__(self, other:
|
|
122
|
+
def __radd__(self, other: IntLike) -> Int: ...
|
|
116
123
|
|
|
117
|
-
def __rsub__(self, other:
|
|
124
|
+
def __rsub__(self, other: IntLike) -> Int: ...
|
|
118
125
|
|
|
119
|
-
def __rmul__(self, other:
|
|
126
|
+
def __rmul__(self, other: IntLike) -> Int: ...
|
|
120
127
|
|
|
121
|
-
def __rmatmul__(self, other:
|
|
128
|
+
def __rmatmul__(self, other: IntLike) -> Int: ...
|
|
122
129
|
|
|
123
|
-
def __rtruediv__(self, other:
|
|
130
|
+
def __rtruediv__(self, other: IntLike) -> Int: ...
|
|
124
131
|
|
|
125
|
-
def __rfloordiv__(self, other:
|
|
132
|
+
def __rfloordiv__(self, other: IntLike) -> Int: ...
|
|
126
133
|
|
|
127
|
-
def __rmod__(self, other:
|
|
134
|
+
def __rmod__(self, other: IntLike) -> Int: ...
|
|
128
135
|
|
|
129
|
-
def __rpow__(self, other:
|
|
136
|
+
def __rpow__(self, other: IntLike) -> Int: ...
|
|
130
137
|
|
|
131
|
-
def __rlshift__(self, other:
|
|
138
|
+
def __rlshift__(self, other: IntLike) -> Int: ...
|
|
132
139
|
|
|
133
|
-
def __rrshift__(self, other:
|
|
140
|
+
def __rrshift__(self, other: IntLike) -> Int: ...
|
|
134
141
|
|
|
135
|
-
def __rand__(self, other:
|
|
142
|
+
def __rand__(self, other: IntLike) -> Int: ...
|
|
136
143
|
|
|
137
|
-
def __rxor__(self, other:
|
|
144
|
+
def __rxor__(self, other: IntLike) -> Int: ...
|
|
138
145
|
|
|
139
|
-
def __ror__(self, other:
|
|
146
|
+
def __ror__(self, other: IntLike) -> Int: ...
|
|
140
147
|
|
|
141
148
|
@property
|
|
142
149
|
def i64(self) -> i64: ...
|
|
@@ -157,9 +164,12 @@ class Int(Expr):
|
|
|
157
164
|
def __bool__(self) -> bool:
|
|
158
165
|
return bool(int(self))
|
|
159
166
|
|
|
167
|
+
@classmethod
|
|
168
|
+
def if_(cls, b: Boolean, i: Int, j: Int) -> Int: ...
|
|
169
|
+
|
|
160
170
|
|
|
161
171
|
@array_api_ruleset.register
|
|
162
|
-
def _int(i: i64, j: i64, r: Boolean, o: Int):
|
|
172
|
+
def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
|
|
163
173
|
yield rewrite(Int(i) == Int(i)).to(TRUE)
|
|
164
174
|
yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))
|
|
165
175
|
|
|
@@ -189,8 +199,13 @@ def _int(i: i64, j: i64, r: Boolean, o: Int):
|
|
|
189
199
|
yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
|
|
190
200
|
yield rewrite(~Int(i)).to(Int(~i))
|
|
191
201
|
|
|
202
|
+
yield rewrite(Int.if_(TRUE, o, b)).to(o)
|
|
203
|
+
yield rewrite(Int.if_(FALSE, o, b)).to(b)
|
|
204
|
+
|
|
192
205
|
|
|
193
|
-
converter(
|
|
206
|
+
converter(i64, Int, lambda x: Int(x))
|
|
207
|
+
|
|
208
|
+
IntLike: TypeAlias = Int | i64Like
|
|
194
209
|
|
|
195
210
|
|
|
196
211
|
class Float(Expr):
|
|
@@ -220,6 +235,9 @@ converter(float, Float, lambda x: Float(x))
|
|
|
220
235
|
converter(Int, Float, lambda x: Float.from_int(x))
|
|
221
236
|
|
|
222
237
|
|
|
238
|
+
FloatLike: TypeAlias = Float | float | IntLike
|
|
239
|
+
|
|
240
|
+
|
|
223
241
|
@array_api_ruleset.register
|
|
224
242
|
def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
|
|
225
243
|
return [
|
|
@@ -238,14 +256,45 @@ def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
|
|
|
238
256
|
]
|
|
239
257
|
|
|
240
258
|
|
|
241
|
-
|
|
259
|
+
@function
|
|
260
|
+
def index_vec_int(xs: Vec[Int], i: Int) -> Int: ...
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class TupleInt(Expr, ruleset=array_api_ruleset):
|
|
264
|
+
"""
|
|
265
|
+
Should act like a tuple[int, ...]
|
|
266
|
+
|
|
267
|
+
All constructors should be rewritten to the functional semantics in the __init__ method.
|
|
268
|
+
"""
|
|
269
|
+
|
|
270
|
+
@classmethod
|
|
271
|
+
def var(cls, name: StringLike) -> TupleInt: ...
|
|
272
|
+
|
|
242
273
|
EMPTY: ClassVar[TupleInt]
|
|
243
274
|
|
|
244
|
-
|
|
275
|
+
@method(unextractable=True)
|
|
276
|
+
def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...
|
|
277
|
+
|
|
278
|
+
@classmethod
|
|
279
|
+
def single(cls, i: Int) -> TupleInt:
|
|
280
|
+
return TupleInt(Int(1), lambda _: i)
|
|
281
|
+
|
|
282
|
+
@classmethod
|
|
283
|
+
def range(cls, stop: Int) -> TupleInt:
|
|
284
|
+
return TupleInt(stop, lambda i: i)
|
|
285
|
+
|
|
286
|
+
@classmethod
|
|
287
|
+
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
|
|
288
|
+
return TupleInt(vec.length(), partial(index_vec_int, vec))
|
|
245
289
|
|
|
246
|
-
def __add__(self, other: TupleInt) -> TupleInt:
|
|
290
|
+
def __add__(self, other: TupleInt) -> TupleInt:
|
|
291
|
+
return TupleInt(
|
|
292
|
+
self.length() + other.length(),
|
|
293
|
+
lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()]),
|
|
294
|
+
)
|
|
247
295
|
|
|
248
296
|
def length(self) -> Int: ...
|
|
297
|
+
def __getitem__(self, i: IntLike) -> Int: ...
|
|
249
298
|
|
|
250
299
|
@method(preserve=True)
|
|
251
300
|
def __len__(self) -> int:
|
|
@@ -253,41 +302,121 @@ class TupleInt(Expr):
|
|
|
253
302
|
|
|
254
303
|
@method(preserve=True)
|
|
255
304
|
def __iter__(self) -> Iterator[Int]:
|
|
256
|
-
return iter(self[
|
|
257
|
-
|
|
258
|
-
def __getitem__(self, i: Int) -> Int: ...
|
|
305
|
+
return iter(self[i] for i in range(len(self)))
|
|
259
306
|
|
|
260
307
|
def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
|
|
261
308
|
|
|
309
|
+
def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...
|
|
262
310
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
)
|
|
311
|
+
def contains(self, i: Int) -> Boolean:
|
|
312
|
+
return self.fold_boolean(FALSE, lambda acc, j: acc | (i == j))
|
|
313
|
+
|
|
314
|
+
@method(cost=100)
|
|
315
|
+
def filter(self, f: Callable[[Int], Boolean]) -> TupleInt: ...
|
|
316
|
+
|
|
317
|
+
@method(cost=100)
|
|
318
|
+
def map(self, f: Callable[[Int], Int]) -> TupleInt:
|
|
319
|
+
return TupleInt(self.length(), lambda i: f(self[i]))
|
|
320
|
+
|
|
321
|
+
@classmethod
|
|
322
|
+
def if_(cls, b: Boolean, i: TupleInt, j: TupleInt) -> TupleInt: ...
|
|
323
|
+
|
|
324
|
+
@method(preserve=True)
|
|
325
|
+
def to_py(self) -> tuple[int, ...]:
|
|
326
|
+
return tuple(int(i) for i in self)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# TODO: Upcast args for Vec[Int] constructor
|
|
330
|
+
converter(tuple, TupleInt, lambda x: TupleInt.from_vec(Vec(*(convert(i, Int) for i in x))))
|
|
331
|
+
|
|
332
|
+
TupleIntLike: TypeAlias = TupleInt | tuple[IntLike, ...]
|
|
272
333
|
|
|
273
334
|
|
|
274
335
|
@array_api_ruleset.register
|
|
275
|
-
def _tuple_int(
|
|
336
|
+
def _tuple_int(
|
|
337
|
+
i: Int,
|
|
338
|
+
i2: Int,
|
|
339
|
+
k: i64,
|
|
340
|
+
f: Callable[[Int, Int], Int],
|
|
341
|
+
bool_f: Callable[[Boolean, Int], Boolean],
|
|
342
|
+
idx_fn: Callable[[Int], Int],
|
|
343
|
+
map_fn: Callable[[Int], Int],
|
|
344
|
+
filter_f: Callable[[Int], Boolean],
|
|
345
|
+
vs: Vec[Int],
|
|
346
|
+
b: Boolean,
|
|
347
|
+
ti: TupleInt,
|
|
348
|
+
ti2: TupleInt,
|
|
349
|
+
):
|
|
350
|
+
remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
|
|
276
351
|
return [
|
|
277
|
-
rewrite(
|
|
278
|
-
rewrite(TupleInt(i)
|
|
279
|
-
|
|
280
|
-
rewrite(
|
|
281
|
-
rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
|
|
282
|
-
# Rule for indexing > 0
|
|
283
|
-
rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
|
|
352
|
+
rewrite(TupleInt(i, idx_fn).length()).to(i),
|
|
353
|
+
rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
|
|
354
|
+
# index_vec_int
|
|
355
|
+
rewrite(index_vec_int(vs, Int(k))).to(vs[k], vs.length() > k),
|
|
284
356
|
# fold
|
|
285
|
-
rewrite(TupleInt.
|
|
286
|
-
rewrite(TupleInt(
|
|
287
|
-
|
|
357
|
+
rewrite(TupleInt(0, idx_fn).fold(i, f)).to(i),
|
|
358
|
+
rewrite(TupleInt(Int(k), idx_fn).fold(i, f)).to(
|
|
359
|
+
f(TupleInt(k - 1, lambda i: idx_fn(i + 1)).fold(i, f), idx_fn(Int(0))),
|
|
360
|
+
ne(k).to(i64(0)),
|
|
361
|
+
),
|
|
362
|
+
# fold boolean
|
|
363
|
+
rewrite(TupleInt(0, idx_fn).fold_boolean(b, bool_f)).to(b),
|
|
364
|
+
rewrite(TupleInt(Int(k), idx_fn).fold_boolean(b, bool_f)).to(
|
|
365
|
+
bool_f(TupleInt(k - 1, lambda i: idx_fn(i + 1)).fold_boolean(b, bool_f), idx_fn(Int(0))),
|
|
366
|
+
ne(k).to(i64(0)),
|
|
367
|
+
),
|
|
368
|
+
# filter TODO: could be written as fold w/ generic types
|
|
369
|
+
rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
|
|
370
|
+
rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
|
|
371
|
+
TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
|
|
372
|
+
ne(k).to(i64(0)),
|
|
373
|
+
),
|
|
374
|
+
# Empty
|
|
375
|
+
rewrite(TupleInt.EMPTY).to(TupleInt(0, bottom_indexing)),
|
|
376
|
+
# if_
|
|
377
|
+
rewrite(TupleInt.if_(TRUE, ti, ti2)).to(ti),
|
|
378
|
+
rewrite(TupleInt.if_(FALSE, ti, ti2)).to(ti2),
|
|
288
379
|
]
|
|
289
380
|
|
|
290
381
|
|
|
382
|
+
class TupleTupleInt(Expr, ruleset=array_api_ruleset):
|
|
383
|
+
@classmethod
|
|
384
|
+
def var(cls, name: StringLike) -> TupleTupleInt: ...
|
|
385
|
+
|
|
386
|
+
EMPTY: ClassVar[TupleTupleInt]
|
|
387
|
+
|
|
388
|
+
def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...
|
|
389
|
+
|
|
390
|
+
@classmethod
|
|
391
|
+
def single(cls, i: TupleInt) -> TupleTupleInt:
|
|
392
|
+
return TupleTupleInt(Int(1), lambda _: i)
|
|
393
|
+
|
|
394
|
+
@classmethod
|
|
395
|
+
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
|
|
396
|
+
return TupleInt(vec.length(), partial(index_vec_int, vec))
|
|
397
|
+
|
|
398
|
+
def __add__(self, other: TupleTupleInt) -> TupleTupleInt:
|
|
399
|
+
return TupleTupleInt(
|
|
400
|
+
self.length() + other.length(),
|
|
401
|
+
lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
def length(self) -> Int: ...
|
|
405
|
+
def __getitem__(self, i: IntLike) -> TupleInt: ...
|
|
406
|
+
|
|
407
|
+
@method(preserve=True)
|
|
408
|
+
def __len__(self) -> int:
|
|
409
|
+
return int(self.length())
|
|
410
|
+
|
|
411
|
+
@method(preserve=True)
|
|
412
|
+
def __iter__(self) -> Iterator[TupleInt]:
|
|
413
|
+
return iter(self[i] for i in range(len(self)))
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
@function
|
|
417
|
+
def bottom_indexing(i: Int) -> Int: ...
|
|
418
|
+
|
|
419
|
+
|
|
291
420
|
class OptionalInt(Expr):
|
|
292
421
|
none: ClassVar[OptionalInt]
|
|
293
422
|
|
|
@@ -393,6 +522,8 @@ converter(
|
|
|
393
522
|
lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
|
|
394
523
|
)
|
|
395
524
|
|
|
525
|
+
SliceLike: TypeAlias = Slice | slice
|
|
526
|
+
|
|
396
527
|
|
|
397
528
|
class MultiAxisIndexKeyItem(Expr):
|
|
398
529
|
ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
|
|
@@ -410,23 +541,28 @@ converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NON
|
|
|
410
541
|
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
|
|
411
542
|
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)
|
|
412
543
|
|
|
544
|
+
MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
|
|
413
545
|
|
|
414
|
-
class MultiAxisIndexKey(Expr):
|
|
415
|
-
def __init__(self, item: MultiAxisIndexKeyItem) -> None: ...
|
|
416
546
|
|
|
417
|
-
|
|
547
|
+
class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
|
|
548
|
+
def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...
|
|
418
549
|
|
|
419
550
|
def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...
|
|
420
551
|
|
|
552
|
+
@classmethod
|
|
553
|
+
def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"
|
|
557
|
+
|
|
421
558
|
|
|
422
559
|
converter(
|
|
423
560
|
tuple,
|
|
424
561
|
MultiAxisIndexKey,
|
|
425
|
-
lambda x: MultiAxisIndexKey(convert(
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
else MultiAxisIndexKey.EMPTY,
|
|
562
|
+
lambda x: MultiAxisIndexKey.from_vec(Vec(*(convert(i, MultiAxisIndexKeyItem) for i in x))),
|
|
563
|
+
)
|
|
564
|
+
converter(
|
|
565
|
+
TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
|
|
430
566
|
)
|
|
431
567
|
|
|
432
568
|
|
|
@@ -457,11 +593,20 @@ class IndexKey(Expr):
|
|
|
457
593
|
@classmethod
|
|
458
594
|
def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...
|
|
459
595
|
|
|
596
|
+
@classmethod
|
|
597
|
+
def ndarray(cls, key: NDArray) -> IndexKey:
|
|
598
|
+
"""
|
|
599
|
+
Indexes by a masked array
|
|
600
|
+
"""
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"
|
|
604
|
+
|
|
460
605
|
|
|
461
606
|
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
|
|
462
|
-
converter(Int, IndexKey, IndexKey.int)
|
|
463
|
-
converter(Slice, IndexKey, IndexKey.slice)
|
|
464
|
-
converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)
|
|
607
|
+
converter(Int, IndexKey, lambda i: IndexKey.int(i))
|
|
608
|
+
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
|
|
609
|
+
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
|
|
465
610
|
|
|
466
611
|
|
|
467
612
|
class Device(Expr): ...
|
|
@@ -520,6 +665,8 @@ converter(Float, Value, Value.float)
|
|
|
520
665
|
converter(Boolean, Value, Value.bool)
|
|
521
666
|
converter(Value, Int, lambda x: x.to_int, 10)
|
|
522
667
|
|
|
668
|
+
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike
|
|
669
|
+
|
|
523
670
|
|
|
524
671
|
@array_api_ruleset.register
|
|
525
672
|
def _value(i: Int, f: Float, b: Boolean):
|
|
@@ -560,6 +707,8 @@ converter(
|
|
|
560
707
|
else TupleValue.EMPTY,
|
|
561
708
|
)
|
|
562
709
|
|
|
710
|
+
TupleValueLike: TypeAlias = TupleValue | tuple[ValueLike, ...]
|
|
711
|
+
|
|
563
712
|
|
|
564
713
|
@array_api_ruleset.register
|
|
565
714
|
def _tuple_value(
|
|
@@ -595,7 +744,7 @@ def possible_values(values: Value) -> TupleValue:
|
|
|
595
744
|
|
|
596
745
|
|
|
597
746
|
class NDArray(Expr):
|
|
598
|
-
def __init__(self,
|
|
747
|
+
def __init__(self, shape: TupleInt, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...
|
|
599
748
|
|
|
600
749
|
@method(cost=200)
|
|
601
750
|
@classmethod
|
|
@@ -633,9 +782,9 @@ class NDArray(Expr):
|
|
|
633
782
|
for i in range(len(self)):
|
|
634
783
|
yield self[IndexKey.int(Int(i))]
|
|
635
784
|
|
|
636
|
-
def __getitem__(self, key:
|
|
785
|
+
def __getitem__(self, key: IndexKeyLike) -> NDArray: ...
|
|
637
786
|
|
|
638
|
-
def __setitem__(self, key:
|
|
787
|
+
def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...
|
|
639
788
|
|
|
640
789
|
def __lt__(self, other: NDArray) -> NDArray: ...
|
|
641
790
|
|
|
@@ -720,20 +869,15 @@ class NDArray(Expr):
|
|
|
720
869
|
@classmethod
|
|
721
870
|
def vector(cls, values: TupleValue) -> NDArray: ...
|
|
722
871
|
|
|
723
|
-
def index(self, indices:
|
|
872
|
+
def index(self, indices: TupleIntLike) -> Value:
|
|
724
873
|
"""
|
|
725
874
|
Return the value at the given indices.
|
|
726
875
|
"""
|
|
727
876
|
|
|
728
877
|
|
|
729
|
-
|
|
730
|
-
def ndarray_index(x: NDArray) -> IndexKey:
|
|
731
|
-
"""
|
|
732
|
-
Indexes by a masked array
|
|
733
|
-
"""
|
|
878
|
+
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
|
|
734
879
|
|
|
735
|
-
|
|
736
|
-
converter(NDArray, IndexKey, ndarray_index)
|
|
880
|
+
converter(NDArray, IndexKey, IndexKey.ndarray)
|
|
737
881
|
converter(Value, NDArray, NDArray.scalar)
|
|
738
882
|
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
|
|
739
883
|
# to prefer upcasting in the other direction when we can, which is safter at runtime
|
|
@@ -742,8 +886,19 @@ converter(TupleValue, NDArray, NDArray.vector)
|
|
|
742
886
|
|
|
743
887
|
|
|
744
888
|
@array_api_ruleset.register
|
|
745
|
-
def _ndarray(
|
|
889
|
+
def _ndarray(
|
|
890
|
+
x: NDArray,
|
|
891
|
+
b: Boolean,
|
|
892
|
+
f: Float,
|
|
893
|
+
fi1: f64,
|
|
894
|
+
fi2: f64,
|
|
895
|
+
shape: TupleInt,
|
|
896
|
+
dtype: DType,
|
|
897
|
+
idx_fn: Callable[[TupleInt], Value],
|
|
898
|
+
):
|
|
746
899
|
return [
|
|
900
|
+
rewrite(NDArray(shape, dtype, idx_fn).shape).to(shape),
|
|
901
|
+
rewrite(NDArray(shape, dtype, idx_fn).dtype).to(dtype),
|
|
747
902
|
rewrite(x.ndim).to(x.shape.length()),
|
|
748
903
|
# rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
|
|
749
904
|
# Converting to a value requires a scalar bool value
|
|
@@ -1067,6 +1222,14 @@ def sqrt(x: NDArray) -> NDArray: ...
|
|
|
1067
1222
|
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...
|
|
1068
1223
|
|
|
1069
1224
|
|
|
1225
|
+
@function
|
|
1226
|
+
def real(x: NDArray) -> NDArray: ...
|
|
1227
|
+
|
|
1228
|
+
|
|
1229
|
+
@function
|
|
1230
|
+
def conj(x: NDArray) -> NDArray: ...
|
|
1231
|
+
|
|
1232
|
+
|
|
1070
1233
|
linalg = sys.modules[__name__]
|
|
1071
1234
|
|
|
1072
1235
|
|
|
@@ -1228,10 +1391,10 @@ def _scalar_math(v: Value, vs: TupleValue, i: Int):
|
|
|
1228
1391
|
|
|
1229
1392
|
|
|
1230
1393
|
@array_api_ruleset.register
|
|
1231
|
-
def _vector_math(v: Value, vs: TupleValue,
|
|
1232
|
-
yield rewrite(NDArray.vector(vs).shape).to(TupleInt(vs.length()))
|
|
1394
|
+
def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
|
|
1395
|
+
yield rewrite(NDArray.vector(vs).shape).to(TupleInt.single(vs.length()))
|
|
1233
1396
|
yield rewrite(NDArray.vector(vs).dtype).to(vs[Int(0)].dtype)
|
|
1234
|
-
yield rewrite(NDArray.vector(vs).index(
|
|
1397
|
+
yield rewrite(NDArray.vector(vs).index(ti)).to(vs[ti[0]])
|
|
1235
1398
|
|
|
1236
1399
|
|
|
1237
1400
|
@array_api_ruleset.register
|
|
@@ -1239,15 +1402,22 @@ def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
|
|
|
1239
1402
|
res = reshape(x, shape, copy)
|
|
1240
1403
|
|
|
1241
1404
|
yield _demand_shape(res, x)
|
|
1405
|
+
# Demand shape length and index
|
|
1406
|
+
yield rule(res).then(shape.length(), shape[0])
|
|
1242
1407
|
|
|
1243
1408
|
# Reshaping a vec to a vec is the same as the vec
|
|
1244
|
-
yield rewrite(
|
|
1409
|
+
yield rewrite(res).to(
|
|
1410
|
+
x,
|
|
1411
|
+
eq(x.shape.length()).to(Int(1)),
|
|
1412
|
+
eq(shape.length()).to(Int(1)),
|
|
1413
|
+
eq(shape[0]).to(Int(-1)),
|
|
1414
|
+
)
|
|
1245
1415
|
|
|
1246
1416
|
|
|
1247
1417
|
@array_api_ruleset.register
|
|
1248
1418
|
def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
|
|
1249
1419
|
# rewrite full getitem to indexec
|
|
1250
|
-
yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(TupleInt(i))))
|
|
1420
|
+
yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(TupleInt.single(i))))
|
|
1251
1421
|
# TODO: Multi index rewrite as well if all are ints
|
|
1252
1422
|
|
|
1253
1423
|
|
|
@@ -1273,7 +1443,7 @@ def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
|
|
|
1273
1443
|
|
|
1274
1444
|
|
|
1275
1445
|
@function(mutates_first_arg=True)
|
|
1276
|
-
def assume_shape(x: NDArray, shape:
|
|
1446
|
+
def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
|
|
1277
1447
|
"""
|
|
1278
1448
|
Asserts that the shape of x is shape.
|
|
1279
1449
|
"""
|
|
@@ -1368,8 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
|
|
|
1368
1538
|
egraph.register(expr)
|
|
1369
1539
|
egraph.run(array_api_schedule)
|
|
1370
1540
|
try:
|
|
1371
|
-
|
|
1541
|
+
extracted = egraph.extract(prim_expr)
|
|
1372
1542
|
except EggSmolError as exc:
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1543
|
+
# Try giving some context, by showing the smallest version of the larger expression
|
|
1544
|
+
try:
|
|
1545
|
+
expr_extracted = egraph.extract(expr)
|
|
1546
|
+
except EggSmolError as inner_exc:
|
|
1547
|
+
raise ValueError(f"Cannot simplify {expr}") from inner_exc
|
|
1548
|
+
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1549
|
+
msg = f"Cannot simplify to primitive {expr_extracted}"
|
|
1550
|
+
raise ValueError(msg) from exc
|
|
1551
|
+
return egraph.eval(extracted)
|
|
1552
|
+
|
|
1553
|
+
# string = (
|
|
1554
|
+
# egraph.as_egglog_string
|
|
1555
|
+
# + "\n"
|
|
1556
|
+
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
|
|
1557
|
+
# )
|
|
1558
|
+
# # save to "tmp.egg"
|
|
1559
|
+
# with open("tmp.egg", "w") as f:
|
|
1560
|
+
# f.write(string)
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -14,24 +14,20 @@ def jit(fn: X) -> X:
|
|
|
14
14
|
"""
|
|
15
15
|
Jit compiles a function
|
|
16
16
|
"""
|
|
17
|
-
from IPython.display import SVG
|
|
18
|
-
|
|
19
17
|
# 1. Create variables for each of the two args in the functions
|
|
20
18
|
sig = inspect.signature(fn)
|
|
21
19
|
arg1, arg2 = sig.parameters.keys()
|
|
22
|
-
|
|
23
|
-
with
|
|
20
|
+
egraph = EGraph()
|
|
21
|
+
with egraph:
|
|
24
22
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
25
23
|
egraph.register(res)
|
|
26
24
|
egraph.run(array_api_numba_schedule)
|
|
27
25
|
res_optimized = egraph.extract(res)
|
|
28
|
-
|
|
26
|
+
# egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
|
|
29
27
|
|
|
30
|
-
egraph = EGraph()
|
|
31
28
|
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
|
|
32
29
|
egraph.register(fn_program)
|
|
33
30
|
egraph.run(array_api_program_gen_schedule)
|
|
34
|
-
fn = cast(X, egraph.eval(fn_program.py_object))
|
|
35
|
-
fn.egraph = svg # type: ignore[attr-defined]
|
|
31
|
+
fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
|
|
36
32
|
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
37
33
|
return fn
|