egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.1__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/exp/array_api.py CHANGED
@@ -7,7 +7,9 @@ import math
7
7
  import numbers
8
8
  import sys
9
9
  from copy import copy
10
- from typing import TYPE_CHECKING, ClassVar, overload
10
+ from functools import partial
11
+ from types import EllipsisType
12
+ from typing import TYPE_CHECKING, ClassVar, TypeAlias, overload
11
13
 
12
14
  import numpy as np
13
15
 
@@ -42,6 +44,8 @@ class Boolean(Expr):
42
44
 
43
45
  def if_int(self, true_value: Int, false_value: Int) -> Int: ...
44
46
 
47
+ def __invert__(self) -> Boolean: ...
48
+
45
49
 
46
50
  BooleanLike = Boolean | bool
47
51
 
@@ -61,6 +65,8 @@ def _bool(x: Boolean, i: Int, j: Int):
61
65
  rewrite(FALSE & x).to(FALSE),
62
66
  rewrite(TRUE.if_int(i, j)).to(i),
63
67
  rewrite(FALSE.if_int(i, j)).to(j),
68
+ rewrite(~TRUE).to(FALSE),
69
+ rewrite(~FALSE).to(TRUE),
64
70
  ]
65
71
 
66
72
 
@@ -69,74 +75,75 @@ class Int(Expr):
69
75
 
70
76
  def __invert__(self) -> Int: ...
71
77
 
72
- def __lt__(self, other: Int) -> Boolean: ...
78
+ def __lt__(self, other: IntLike) -> Boolean: ...
73
79
 
74
- def __le__(self, other: Int) -> Boolean: ...
80
+ def __le__(self, other: IntLike) -> Boolean: ...
75
81
 
76
- def __eq__(self, other: Int) -> Boolean: # type: ignore[override]
82
+ def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
77
83
  ...
78
84
 
85
+ # TODO: Fix this?
79
86
  # Make != always return a Bool, so that numpy.unique works on a tuple of ints
80
87
  # In _unique1d
81
88
  @method(preserve=True)
82
89
  def __ne__(self, other: Int) -> bool: # type: ignore[override]
83
90
  return not (self == other)
84
91
 
85
- def __gt__(self, other: Int) -> Boolean: ...
92
+ def __gt__(self, other: IntLike) -> Boolean: ...
86
93
 
87
- def __ge__(self, other: Int) -> Boolean: ...
94
+ def __ge__(self, other: IntLike) -> Boolean: ...
88
95
 
89
- def __add__(self, other: Int) -> Int: ...
96
+ def __add__(self, other: IntLike) -> Int: ...
90
97
 
91
- def __sub__(self, other: Int) -> Int: ...
98
+ def __sub__(self, other: IntLike) -> Int: ...
92
99
 
93
- def __mul__(self, other: Int) -> Int: ...
100
+ def __mul__(self, other: IntLike) -> Int: ...
94
101
 
95
- def __truediv__(self, other: Int) -> Int: ...
102
+ def __truediv__(self, other: IntLike) -> Int: ...
96
103
 
97
- def __floordiv__(self, other: Int) -> Int: ...
104
+ def __floordiv__(self, other: IntLike) -> Int: ...
98
105
 
99
- def __mod__(self, other: Int) -> Int: ...
106
+ def __mod__(self, other: IntLike) -> Int: ...
100
107
 
101
- def __divmod__(self, other: Int) -> Int: ...
108
+ def __divmod__(self, other: IntLike) -> Int: ...
102
109
 
103
- def __pow__(self, other: Int) -> Int: ...
110
+ def __pow__(self, other: IntLike) -> Int: ...
104
111
 
105
- def __lshift__(self, other: Int) -> Int: ...
112
+ def __lshift__(self, other: IntLike) -> Int: ...
106
113
 
107
- def __rshift__(self, other: Int) -> Int: ...
114
+ def __rshift__(self, other: IntLike) -> Int: ...
108
115
 
109
- def __and__(self, other: Int) -> Int: ...
116
+ def __and__(self, other: IntLike) -> Int: ...
110
117
 
111
- def __xor__(self, other: Int) -> Int: ...
118
+ def __xor__(self, other: IntLike) -> Int: ...
112
119
 
113
- def __or__(self, other: Int) -> Int: ...
120
+ def __or__(self, other: IntLike) -> Int: ...
114
121
 
115
- def __radd__(self, other: Int) -> Int: ...
122
+ def __radd__(self, other: IntLike) -> Int: ...
116
123
 
117
- def __rsub__(self, other: Int) -> Int: ...
124
+ def __rsub__(self, other: IntLike) -> Int: ...
118
125
 
119
- def __rmul__(self, other: Int) -> Int: ...
126
+ def __rmul__(self, other: IntLike) -> Int: ...
120
127
 
121
- def __rmatmul__(self, other: Int) -> Int: ...
128
+ def __rmatmul__(self, other: IntLike) -> Int: ...
122
129
 
123
- def __rtruediv__(self, other: Int) -> Int: ...
130
+ def __rtruediv__(self, other: IntLike) -> Int: ...
124
131
 
125
- def __rfloordiv__(self, other: Int) -> Int: ...
132
+ def __rfloordiv__(self, other: IntLike) -> Int: ...
126
133
 
127
- def __rmod__(self, other: Int) -> Int: ...
134
+ def __rmod__(self, other: IntLike) -> Int: ...
128
135
 
129
- def __rpow__(self, other: Int) -> Int: ...
136
+ def __rpow__(self, other: IntLike) -> Int: ...
130
137
 
131
- def __rlshift__(self, other: Int) -> Int: ...
138
+ def __rlshift__(self, other: IntLike) -> Int: ...
132
139
 
133
- def __rrshift__(self, other: Int) -> Int: ...
140
+ def __rrshift__(self, other: IntLike) -> Int: ...
134
141
 
135
- def __rand__(self, other: Int) -> Int: ...
142
+ def __rand__(self, other: IntLike) -> Int: ...
136
143
 
137
- def __rxor__(self, other: Int) -> Int: ...
144
+ def __rxor__(self, other: IntLike) -> Int: ...
138
145
 
139
- def __ror__(self, other: Int) -> Int: ...
146
+ def __ror__(self, other: IntLike) -> Int: ...
140
147
 
141
148
  @property
142
149
  def i64(self) -> i64: ...
@@ -157,9 +164,12 @@ class Int(Expr):
157
164
  def __bool__(self) -> bool:
158
165
  return bool(int(self))
159
166
 
167
+ @classmethod
168
+ def if_(cls, b: Boolean, i: Int, j: Int) -> Int: ...
169
+
160
170
 
161
171
  @array_api_ruleset.register
162
- def _int(i: i64, j: i64, r: Boolean, o: Int):
172
+ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
163
173
  yield rewrite(Int(i) == Int(i)).to(TRUE)
164
174
  yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))
165
175
 
@@ -189,8 +199,13 @@ def _int(i: i64, j: i64, r: Boolean, o: Int):
189
199
  yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
190
200
  yield rewrite(~Int(i)).to(Int(~i))
191
201
 
202
+ yield rewrite(Int.if_(TRUE, o, b)).to(o)
203
+ yield rewrite(Int.if_(FALSE, o, b)).to(b)
204
+
192
205
 
193
- converter(int, Int, lambda x: Int(x))
206
+ converter(i64, Int, lambda x: Int(x))
207
+
208
+ IntLike: TypeAlias = Int | i64Like
194
209
 
195
210
 
196
211
  class Float(Expr):
@@ -220,6 +235,9 @@ converter(float, Float, lambda x: Float(x))
220
235
  converter(Int, Float, lambda x: Float.from_int(x))
221
236
 
222
237
 
238
+ FloatLike: TypeAlias = Float | float | IntLike
239
+
240
+
223
241
  @array_api_ruleset.register
224
242
  def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
225
243
  return [
@@ -238,14 +256,45 @@ def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
238
256
  ]
239
257
 
240
258
 
241
- class TupleInt(Expr):
259
+ @function
260
+ def index_vec_int(xs: Vec[Int], i: Int) -> Int: ...
261
+
262
+
263
+ class TupleInt(Expr, ruleset=array_api_ruleset):
264
+ """
265
+ Should act like a tuple[int, ...]
266
+
267
+ All constructors should be rewritten to the functional semantics in the __init__ method.
268
+ """
269
+
270
+ @classmethod
271
+ def var(cls, name: StringLike) -> TupleInt: ...
272
+
242
273
  EMPTY: ClassVar[TupleInt]
243
274
 
244
- def __init__(self, head: Int) -> None: ...
275
+ @method(unextractable=True)
276
+ def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...
277
+
278
+ @classmethod
279
+ def single(cls, i: Int) -> TupleInt:
280
+ return TupleInt(Int(1), lambda _: i)
281
+
282
+ @classmethod
283
+ def range(cls, stop: Int) -> TupleInt:
284
+ return TupleInt(stop, lambda i: i)
285
+
286
+ @classmethod
287
+ def from_vec(cls, vec: Vec[Int]) -> TupleInt:
288
+ return TupleInt(vec.length(), partial(index_vec_int, vec))
245
289
 
246
- def __add__(self, other: TupleInt) -> TupleInt: ...
290
+ def __add__(self, other: TupleInt) -> TupleInt:
291
+ return TupleInt(
292
+ self.length() + other.length(),
293
+ lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()]),
294
+ )
247
295
 
248
296
  def length(self) -> Int: ...
297
+ def __getitem__(self, i: IntLike) -> Int: ...
249
298
 
250
299
  @method(preserve=True)
251
300
  def __len__(self) -> int:
@@ -253,41 +302,121 @@ class TupleInt(Expr):
253
302
 
254
303
  @method(preserve=True)
255
304
  def __iter__(self) -> Iterator[Int]:
256
- return iter(self[Int(i)] for i in range(len(self)))
257
-
258
- def __getitem__(self, i: Int) -> Int: ...
305
+ return iter(self[i] for i in range(len(self)))
259
306
 
260
307
  def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
261
308
 
309
+ def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...
262
310
 
263
- converter(
264
- tuple,
265
- TupleInt,
266
- lambda x: TupleInt(convert(x[0], Int)) + convert(x[1:], TupleInt)
267
- if len(x) > 1
268
- else TupleInt(convert(x[0], Int))
269
- if x
270
- else TupleInt.EMPTY,
271
- )
311
+ def contains(self, i: Int) -> Boolean:
312
+ return self.fold_boolean(FALSE, lambda acc, j: acc | (i == j))
313
+
314
+ @method(cost=100)
315
+ def filter(self, f: Callable[[Int], Boolean]) -> TupleInt: ...
316
+
317
+ @method(cost=100)
318
+ def map(self, f: Callable[[Int], Int]) -> TupleInt:
319
+ return TupleInt(self.length(), lambda i: f(self[i]))
320
+
321
+ @classmethod
322
+ def if_(cls, b: Boolean, i: TupleInt, j: TupleInt) -> TupleInt: ...
323
+
324
+ @method(preserve=True)
325
+ def to_py(self) -> tuple[int, ...]:
326
+ return tuple(int(i) for i in self)
327
+
328
+
329
+ # TODO: Upcast args for Vec[Int] constructor
330
+ converter(tuple, TupleInt, lambda x: TupleInt.from_vec(Vec(*(convert(i, Int) for i in x))))
331
+
332
+ TupleIntLike: TypeAlias = TupleInt | tuple[IntLike, ...]
272
333
 
273
334
 
274
335
  @array_api_ruleset.register
275
- def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
336
+ def _tuple_int(
337
+ i: Int,
338
+ i2: Int,
339
+ k: i64,
340
+ f: Callable[[Int, Int], Int],
341
+ bool_f: Callable[[Boolean, Int], Boolean],
342
+ idx_fn: Callable[[Int], Int],
343
+ map_fn: Callable[[Int], Int],
344
+ filter_f: Callable[[Int], Boolean],
345
+ vs: Vec[Int],
346
+ b: Boolean,
347
+ ti: TupleInt,
348
+ ti2: TupleInt,
349
+ ):
350
+ remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
276
351
  return [
277
- rewrite(ti + TupleInt.EMPTY).to(ti),
278
- rewrite(TupleInt(i).length()).to(Int(1)),
279
- rewrite((ti + ti2).length()).to(ti.length() + ti2.length()),
280
- rewrite(TupleInt(i)[Int(0)]).to(i),
281
- rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
282
- # Rule for indexing > 0
283
- rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
352
+ rewrite(TupleInt(i, idx_fn).length()).to(i),
353
+ rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
354
+ # index_vec_int
355
+ rewrite(index_vec_int(vs, Int(k))).to(vs[k], vs.length() > k),
284
356
  # fold
285
- rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
286
- rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
287
- rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
357
+ rewrite(TupleInt(0, idx_fn).fold(i, f)).to(i),
358
+ rewrite(TupleInt(Int(k), idx_fn).fold(i, f)).to(
359
+ f(TupleInt(k - 1, lambda i: idx_fn(i + 1)).fold(i, f), idx_fn(Int(0))),
360
+ ne(k).to(i64(0)),
361
+ ),
362
+ # fold boolean
363
+ rewrite(TupleInt(0, idx_fn).fold_boolean(b, bool_f)).to(b),
364
+ rewrite(TupleInt(Int(k), idx_fn).fold_boolean(b, bool_f)).to(
365
+ bool_f(TupleInt(k - 1, lambda i: idx_fn(i + 1)).fold_boolean(b, bool_f), idx_fn(Int(0))),
366
+ ne(k).to(i64(0)),
367
+ ),
368
+ # filter TODO: could be written as fold w/ generic types
369
+ rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
370
+ rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
371
+ TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
372
+ ne(k).to(i64(0)),
373
+ ),
374
+ # Empty
375
+ rewrite(TupleInt.EMPTY).to(TupleInt(0, bottom_indexing)),
376
+ # if_
377
+ rewrite(TupleInt.if_(TRUE, ti, ti2)).to(ti),
378
+ rewrite(TupleInt.if_(FALSE, ti, ti2)).to(ti2),
288
379
  ]
289
380
 
290
381
 
382
+ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
383
+ @classmethod
384
+ def var(cls, name: StringLike) -> TupleTupleInt: ...
385
+
386
+ EMPTY: ClassVar[TupleTupleInt]
387
+
388
+ def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...
389
+
390
+ @classmethod
391
+ def single(cls, i: TupleInt) -> TupleTupleInt:
392
+ return TupleTupleInt(Int(1), lambda _: i)
393
+
394
+ @classmethod
395
+ def from_vec(cls, vec: Vec[Int]) -> TupleInt:
396
+ return TupleInt(vec.length(), partial(index_vec_int, vec))
397
+
398
+ def __add__(self, other: TupleTupleInt) -> TupleTupleInt:
399
+ return TupleTupleInt(
400
+ self.length() + other.length(),
401
+ lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
402
+ )
403
+
404
+ def length(self) -> Int: ...
405
+ def __getitem__(self, i: IntLike) -> TupleInt: ...
406
+
407
+ @method(preserve=True)
408
+ def __len__(self) -> int:
409
+ return int(self.length())
410
+
411
+ @method(preserve=True)
412
+ def __iter__(self) -> Iterator[TupleInt]:
413
+ return iter(self[i] for i in range(len(self)))
414
+
415
+
416
+ @function
417
+ def bottom_indexing(i: Int) -> Int: ...
418
+
419
+
291
420
  class OptionalInt(Expr):
292
421
  none: ClassVar[OptionalInt]
293
422
 
@@ -393,6 +522,8 @@ converter(
393
522
  lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
394
523
  )
395
524
 
525
+ SliceLike: TypeAlias = Slice | slice
526
+
396
527
 
397
528
  class MultiAxisIndexKeyItem(Expr):
398
529
  ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
@@ -410,23 +541,28 @@ converter(type(None), MultiAxisIndexKeyItem, lambda _: MultiAxisIndexKeyItem.NON
410
541
  converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
411
542
  converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)
412
543
 
544
+ MultiAxisIndexKeyItemLike: TypeAlias = MultiAxisIndexKeyItem | EllipsisType | None | IntLike | SliceLike
413
545
 
414
- class MultiAxisIndexKey(Expr):
415
- def __init__(self, item: MultiAxisIndexKeyItem) -> None: ...
416
546
 
417
- EMPTY: ClassVar[MultiAxisIndexKey]
547
+ class MultiAxisIndexKey(Expr, ruleset=array_api_ruleset):
548
+ def __init__(self, length: IntLike, idx_fn: Callable[[Int], MultiAxisIndexKeyItem]) -> None: ...
418
549
 
419
550
  def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: ...
420
551
 
552
+ @classmethod
553
+ def from_vec(cls, vec: Vec[MultiAxisIndexKeyItem]) -> MultiAxisIndexKey: ...
554
+
555
+
556
+ MultiAxisIndexKeyLike: TypeAlias = "MultiAxisIndexKey | tuple[MultiAxisIndexKeyItemLike, ...] | TupleIntLike"
557
+
421
558
 
422
559
  converter(
423
560
  tuple,
424
561
  MultiAxisIndexKey,
425
- lambda x: MultiAxisIndexKey(convert(x[0], MultiAxisIndexKeyItem)) + convert(x[1:], MultiAxisIndexKey)
426
- if len(x) > 1
427
- else MultiAxisIndexKey(convert(x[0], MultiAxisIndexKeyItem))
428
- if x
429
- else MultiAxisIndexKey.EMPTY,
562
+ lambda x: MultiAxisIndexKey.from_vec(Vec(*(convert(i, MultiAxisIndexKeyItem) for i in x))),
563
+ )
564
+ converter(
565
+ TupleInt, MultiAxisIndexKey, lambda ti: MultiAxisIndexKey(ti.length(), lambda i: MultiAxisIndexKeyItem.int(ti[i]))
430
566
  )
431
567
 
432
568
 
@@ -457,11 +593,20 @@ class IndexKey(Expr):
457
593
  @classmethod
458
594
  def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: ...
459
595
 
596
+ @classmethod
597
+ def ndarray(cls, key: NDArray) -> IndexKey:
598
+ """
599
+ Indexes by a masked array
600
+ """
601
+
602
+
603
+ IndexKeyLike: TypeAlias = "IndexKey | IntLike | SliceLike | MultiAxisIndexKeyLike | NDArrayLike"
604
+
460
605
 
461
606
  converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
462
- converter(Int, IndexKey, IndexKey.int)
463
- converter(Slice, IndexKey, IndexKey.slice)
464
- converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)
607
+ converter(Int, IndexKey, lambda i: IndexKey.int(i))
608
+ converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
609
+ converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
465
610
 
466
611
 
467
612
  class Device(Expr): ...
@@ -520,6 +665,8 @@ converter(Float, Value, Value.float)
520
665
  converter(Boolean, Value, Value.bool)
521
666
  converter(Value, Int, lambda x: x.to_int, 10)
522
667
 
668
+ ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike
669
+
523
670
 
524
671
  @array_api_ruleset.register
525
672
  def _value(i: Int, f: Float, b: Boolean):
@@ -560,6 +707,8 @@ converter(
560
707
  else TupleValue.EMPTY,
561
708
  )
562
709
 
710
+ TupleValueLike: TypeAlias = TupleValue | tuple[ValueLike, ...]
711
+
563
712
 
564
713
  @array_api_ruleset.register
565
714
  def _tuple_value(
@@ -595,7 +744,7 @@ def possible_values(values: Value) -> TupleValue:
595
744
 
596
745
 
597
746
  class NDArray(Expr):
598
- def __init__(self, py_array: PyObject) -> None: ...
747
+ def __init__(self, shape: TupleInt, dtype: DType, idx_fn: Callable[[TupleInt], Value]) -> None: ...
599
748
 
600
749
  @method(cost=200)
601
750
  @classmethod
@@ -633,9 +782,9 @@ class NDArray(Expr):
633
782
  for i in range(len(self)):
634
783
  yield self[IndexKey.int(Int(i))]
635
784
 
636
- def __getitem__(self, key: IndexKey) -> NDArray: ...
785
+ def __getitem__(self, key: IndexKeyLike) -> NDArray: ...
637
786
 
638
- def __setitem__(self, key: IndexKey, value: NDArray) -> None: ...
787
+ def __setitem__(self, key: IndexKeyLike, value: NDArray) -> None: ...
639
788
 
640
789
  def __lt__(self, other: NDArray) -> NDArray: ...
641
790
 
@@ -720,20 +869,15 @@ class NDArray(Expr):
720
869
  @classmethod
721
870
  def vector(cls, values: TupleValue) -> NDArray: ...
722
871
 
723
- def index(self, indices: TupleInt) -> Value:
872
+ def index(self, indices: TupleIntLike) -> Value:
724
873
  """
725
874
  Return the value at the given indices.
726
875
  """
727
876
 
728
877
 
729
- @function
730
- def ndarray_index(x: NDArray) -> IndexKey:
731
- """
732
- Indexes by a masked array
733
- """
878
+ NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
734
879
 
735
-
736
- converter(NDArray, IndexKey, ndarray_index)
880
+ converter(NDArray, IndexKey, IndexKey.ndarray)
737
881
  converter(Value, NDArray, NDArray.scalar)
738
882
  # Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
739
883
  # to prefer upcasting in the other direction when we can, which is safter at runtime
@@ -742,8 +886,19 @@ converter(TupleValue, NDArray, NDArray.vector)
742
886
 
743
887
 
744
888
  @array_api_ruleset.register
745
- def _ndarray(x: NDArray, b: Boolean, f: Float, fi1: f64, fi2: f64):
889
+ def _ndarray(
890
+ x: NDArray,
891
+ b: Boolean,
892
+ f: Float,
893
+ fi1: f64,
894
+ fi2: f64,
895
+ shape: TupleInt,
896
+ dtype: DType,
897
+ idx_fn: Callable[[TupleInt], Value],
898
+ ):
746
899
  return [
900
+ rewrite(NDArray(shape, dtype, idx_fn).shape).to(shape),
901
+ rewrite(NDArray(shape, dtype, idx_fn).dtype).to(dtype),
747
902
  rewrite(x.ndim).to(x.shape.length()),
748
903
  # rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
749
904
  # Converting to a value requires a scalar bool value
@@ -1067,6 +1222,14 @@ def sqrt(x: NDArray) -> NDArray: ...
1067
1222
  def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...
1068
1223
 
1069
1224
 
1225
+ @function
1226
+ def real(x: NDArray) -> NDArray: ...
1227
+
1228
+
1229
+ @function
1230
+ def conj(x: NDArray) -> NDArray: ...
1231
+
1232
+
1070
1233
  linalg = sys.modules[__name__]
1071
1234
 
1072
1235
 
@@ -1228,10 +1391,10 @@ def _scalar_math(v: Value, vs: TupleValue, i: Int):
1228
1391
 
1229
1392
 
1230
1393
  @array_api_ruleset.register
1231
- def _vector_math(v: Value, vs: TupleValue, i: Int):
1232
- yield rewrite(NDArray.vector(vs).shape).to(TupleInt(vs.length()))
1394
+ def _vector_math(v: Value, vs: TupleValue, ti: TupleInt):
1395
+ yield rewrite(NDArray.vector(vs).shape).to(TupleInt.single(vs.length()))
1233
1396
  yield rewrite(NDArray.vector(vs).dtype).to(vs[Int(0)].dtype)
1234
- yield rewrite(NDArray.vector(vs).index(TupleInt(i))).to(vs[i])
1397
+ yield rewrite(NDArray.vector(vs).index(ti)).to(vs[ti[0]])
1235
1398
 
1236
1399
 
1237
1400
  @array_api_ruleset.register
@@ -1239,15 +1402,22 @@ def _reshape_math(x: NDArray, shape: TupleInt, copy: OptionalBool):
1239
1402
  res = reshape(x, shape, copy)
1240
1403
 
1241
1404
  yield _demand_shape(res, x)
1405
+ # Demand shape length and index
1406
+ yield rule(res).then(shape.length(), shape[0])
1242
1407
 
1243
1408
  # Reshaping a vec to a vec is the same as the vec
1244
- yield rewrite(reshape(x, TupleInt(Int(-1)), copy)).to(x, eq(x.shape.length()).to(Int(1)))
1409
+ yield rewrite(res).to(
1410
+ x,
1411
+ eq(x.shape.length()).to(Int(1)),
1412
+ eq(shape.length()).to(Int(1)),
1413
+ eq(shape[0]).to(Int(-1)),
1414
+ )
1245
1415
 
1246
1416
 
1247
1417
  @array_api_ruleset.register
1248
1418
  def _indexing_pushdown(x: NDArray, shape: TupleInt, copy: OptionalBool, i: Int):
1249
1419
  # rewrite full getitem to index
1250
- yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(TupleInt(i))))
1420
+ yield rewrite(x[IndexKey.int(i)]).to(NDArray.scalar(x.index(TupleInt.single(i))))
1251
1421
  # TODO: Multi index rewrite as well if all are ints
1252
1422
 
1253
1423
 
@@ -1273,7 +1443,7 @@ def _assume_dtype(x: NDArray, dtype: DType, idx: TupleInt):
1273
1443
 
1274
1444
 
1275
1445
  @function(mutates_first_arg=True)
1276
- def assume_shape(x: NDArray, shape: TupleInt) -> None:
1446
+ def assume_shape(x: NDArray, shape: TupleIntLike) -> None:
1277
1447
  """
1278
1448
  Asserts that the shape of x is shape.
1279
1449
  """
@@ -1368,8 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
1368
1538
  egraph.register(expr)
1369
1539
  egraph.run(array_api_schedule)
1370
1540
  try:
1371
- return egraph.eval(prim_expr)
1541
+ extracted = egraph.extract(prim_expr)
1372
1542
  except EggSmolError as exc:
1373
- egraph.display(n_inline_leaves=2, split_primitive_outputs=True)
1374
- msg = "Cannot simplify:"
1375
- raise ValueError(msg, egraph.extract(expr)) from exc
1543
+ # Try giving some context, by showing the smallest version of the larger expression
1544
+ try:
1545
+ expr_extracted = egraph.extract(expr)
1546
+ except EggSmolError as inner_exc:
1547
+ raise ValueError(f"Cannot simplify {expr}") from inner_exc
1548
+ egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1549
+ msg = f"Cannot simplify to primitive {expr_extracted}"
1550
+ raise ValueError(msg) from exc
1551
+ return egraph.eval(extracted)
1552
+
1553
+ # string = (
1554
+ # egraph.as_egglog_string
1555
+ # + "\n"
1556
+ # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
1557
+ # )
1558
+ # # save to "tmp.egg"
1559
+ # with open("tmp.egg", "w") as f:
1560
+ # f.write(string)
@@ -14,24 +14,20 @@ def jit(fn: X) -> X:
14
14
  """
15
15
  Jit compiles a function
16
16
  """
17
- from IPython.display import SVG
18
-
19
17
  # 1. Create variables for each of the two args in the functions
20
18
  sig = inspect.signature(fn)
21
19
  arg1, arg2 = sig.parameters.keys()
22
-
23
- with EGraph() as egraph:
20
+ egraph = EGraph()
21
+ with egraph:
24
22
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
25
23
  egraph.register(res)
26
24
  egraph.run(array_api_numba_schedule)
27
25
  res_optimized = egraph.extract(res)
28
- svg = SVG(egraph.graphviz_svg(split_primitive_outputs=True, n_inline_leaves=3))
26
+ # egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
29
27
 
30
- egraph = EGraph()
31
28
  fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
32
29
  egraph.register(fn_program)
33
30
  egraph.run(array_api_program_gen_schedule)
34
- fn = cast(X, egraph.eval(fn_program.py_object))
35
- fn.egraph = svg # type: ignore[attr-defined]
31
+ fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
36
32
  fn.expr = res_optimized # type: ignore[attr-defined]
37
33
  return fn