egglog-6.1.0-cp311-none-win_amd64.whl → egglog-7.1.0-cp311-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. More details are available on the registry page for this release.

egglog/exp/array_api.py CHANGED
@@ -18,13 +18,13 @@ from egglog.runtime import RuntimeExpr
18
18
  from .program_gen import *
19
19
 
20
20
  if TYPE_CHECKING:
21
- from collections.abc import Iterator
21
+ from collections.abc import Callable, Iterator
22
22
  from types import ModuleType
23
23
 
24
24
  # Pretend that exprs are numbers b/c sklearn does isinstance checks
25
25
  numbers.Integral.register(RuntimeExpr)
26
26
 
27
- array_api_ruleset = ruleset()
27
+ array_api_ruleset = ruleset(name="array_api_ruleset")
28
28
  array_api_schedule = array_api_ruleset.saturate()
29
29
 
30
30
 
@@ -36,10 +36,14 @@ class Boolean(Expr):
36
36
  @property
37
37
  def bool(self) -> Bool: ...
38
38
 
39
- def __or__(self, other: Boolean) -> Boolean: ...
39
+ def __or__(self, other: BooleanLike) -> Boolean: ...
40
40
 
41
- def __and__(self, other: Boolean) -> Boolean: ...
41
+ def __and__(self, other: BooleanLike) -> Boolean: ...
42
42
 
43
+ def if_int(self, true_value: Int, false_value: Int) -> Int: ...
44
+
45
+
46
+ BooleanLike = Boolean | bool
43
47
 
44
48
  TRUE = constant("TRUE", Boolean)
45
49
  FALSE = constant("FALSE", Boolean)
@@ -47,7 +51,7 @@ converter(bool, Boolean, lambda x: TRUE if x else FALSE)
47
51
 
48
52
 
49
53
  @array_api_ruleset.register
50
- def _bool(x: Boolean):
54
+ def _bool(x: Boolean, i: Int, j: Int):
51
55
  return [
52
56
  rule(eq(x).to(TRUE)).then(set_(x.bool).to(Bool(True))),
53
57
  rule(eq(x).to(FALSE)).then(set_(x.bool).to(Bool(False))),
@@ -55,82 +59,8 @@ def _bool(x: Boolean):
55
59
  rewrite(FALSE | x).to(x),
56
60
  rewrite(TRUE & x).to(x),
57
61
  rewrite(FALSE & x).to(FALSE),
58
- ]
59
-
60
-
61
- class DType(Expr):
62
- float64: ClassVar[DType]
63
- float32: ClassVar[DType]
64
- int64: ClassVar[DType]
65
- int32: ClassVar[DType]
66
- object: ClassVar[DType]
67
- bool: ClassVar[DType]
68
-
69
- def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
70
- ...
71
-
72
-
73
- float64 = DType.float64
74
- float32 = DType.float32
75
- int32 = DType.int32
76
- int64 = DType.int64
77
-
78
- _DTYPES = [float64, float32, int32, int64, DType.object]
79
-
80
- converter(type, DType, lambda x: convert(np.dtype(x), DType))
81
- converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
82
- array_api_ruleset.register(
83
- *(rewrite(l == r).to(TRUE if l is r else FALSE) for l, r in itertools.product(_DTYPES, repeat=2))
84
- )
85
-
86
-
87
- class IsDtypeKind(Expr):
88
- NULL: ClassVar[IsDtypeKind]
89
-
90
- @classmethod
91
- def string(cls, s: StringLike) -> IsDtypeKind: ...
92
-
93
- @classmethod
94
- def dtype(cls, d: DType) -> IsDtypeKind: ...
95
-
96
- @method(cost=10)
97
- def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
98
-
99
-
100
- # TODO: Make kind more generic to support tuples.
101
- @function
102
- def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
103
-
104
-
105
- converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
106
- converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
107
- converter(
108
- tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
109
- )
110
-
111
-
112
- @array_api_ruleset.register
113
- def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
114
- return [
115
- rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
116
- rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
117
- rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
118
- rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
119
- rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
120
- rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
121
- rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
122
- rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
123
- rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
124
- rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
125
- rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
126
- rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
127
- rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
128
- rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
129
- rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
130
- rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
131
- rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
132
- rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
133
- rewrite(k1 | IsDtypeKind.NULL).to(k1),
62
+ rewrite(TRUE.if_int(i, j)).to(i),
63
+ rewrite(FALSE.if_int(i, j)).to(j),
134
64
  ]
135
65
 
136
66
 
@@ -264,10 +194,13 @@ converter(int, Int, lambda x: Int(x))
264
194
 
265
195
 
266
196
  class Float(Expr):
197
+ # Differentiate costs of three constructors so extraction is deterministic if all three are present
198
+ @method(cost=3)
267
199
  def __init__(self, value: f64Like) -> None: ...
268
200
 
269
201
  def abs(self) -> Float: ...
270
202
 
203
+ @method(cost=2)
271
204
  @classmethod
272
205
  def rational(cls, r: Rational) -> Float: ...
273
206
 
@@ -324,7 +257,7 @@ class TupleInt(Expr):
324
257
 
325
258
  def __getitem__(self, i: Int) -> Int: ...
326
259
 
327
- def product(self) -> Int: ...
260
+ def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
328
261
 
329
262
 
330
263
  converter(
@@ -339,7 +272,7 @@ converter(
339
272
 
340
273
 
341
274
  @array_api_ruleset.register
342
- def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
275
+ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
343
276
  return [
344
277
  rewrite(ti + TupleInt.EMPTY).to(ti),
345
278
  rewrite(TupleInt(i).length()).to(Int(1)),
@@ -348,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
348
281
  rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
349
282
  # Rule for indexing > 0
350
283
  rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
351
- # Product
352
- rewrite(TupleInt(i).product()).to(i),
353
- rewrite((TupleInt(i) + ti).product()).to(i * ti.product()),
354
- rewrite(TupleInt.EMPTY.product()).to(Int(1)),
284
+ # fold
285
+ rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
286
+ rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
287
+ rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
355
288
  ]
356
289
 
357
290
 
@@ -366,6 +299,85 @@ converter(type(None), OptionalInt, lambda _: OptionalInt.none)
366
299
  converter(Int, OptionalInt, OptionalInt.some)
367
300
 
368
301
 
302
+ class DType(Expr):
303
+ float64: ClassVar[DType]
304
+ float32: ClassVar[DType]
305
+ int64: ClassVar[DType]
306
+ int32: ClassVar[DType]
307
+ object: ClassVar[DType]
308
+ bool: ClassVar[DType]
309
+
310
+ def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
311
+ ...
312
+
313
+
314
+ float64 = DType.float64
315
+ float32 = DType.float32
316
+ int32 = DType.int32
317
+ int64 = DType.int64
318
+
319
+ _DTYPES = [float64, float32, int32, int64, DType.object]
320
+
321
+ converter(type, DType, lambda x: convert(np.dtype(x), DType))
322
+ converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
323
+
324
+
325
+ @array_api_ruleset.register
326
+ def _():
327
+ for l, r in itertools.product(_DTYPES, repeat=2):
328
+ yield rewrite(l == r).to(TRUE if l is r else FALSE)
329
+
330
+
331
+ class IsDtypeKind(Expr):
332
+ NULL: ClassVar[IsDtypeKind]
333
+
334
+ @classmethod
335
+ def string(cls, s: StringLike) -> IsDtypeKind: ...
336
+
337
+ @classmethod
338
+ def dtype(cls, d: DType) -> IsDtypeKind: ...
339
+
340
+ @method(cost=10)
341
+ def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
342
+
343
+
344
+ # TODO: Make kind more generic to support tuples.
345
+ @function
346
+ def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
347
+
348
+
349
+ converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
350
+ converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
351
+ converter(
352
+ tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
353
+ )
354
+
355
+
356
+ @array_api_ruleset.register
357
+ def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
358
+ return [
359
+ rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
360
+ rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
361
+ rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
362
+ rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
363
+ rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
364
+ rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
365
+ rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
366
+ rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
367
+ rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
368
+ rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
369
+ rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
370
+ rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
371
+ rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
372
+ rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
373
+ rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
374
+ rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
375
+ rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
376
+ rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
377
+ rewrite(k1 | IsDtypeKind.NULL).to(k1),
378
+ ]
379
+
380
+
369
381
  class Slice(Expr):
370
382
  def __init__(
371
383
  self,
@@ -1334,7 +1346,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
1334
1346
 
1335
1347
  @array_api_ruleset.register
1336
1348
  def _size(x: NDArray):
1337
- yield rewrite(x.size).to(x.shape.product())
1349
+ yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
1338
1350
 
1339
1351
 
1340
1352
  @overload
@@ -31,7 +31,12 @@ def _std(y: NDArray, x: NDArray, i: Int):
31
31
  axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
32
32
  # https://numpy.org/doc/stable/reference/generated/numpy.std.html
33
33
  # "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
34
- yield rewrite(std(x, axis), subsume=True).to(sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)))
34
+ yield rewrite(
35
+ std(x, axis),
36
+ subsume=True,
37
+ ).to(
38
+ sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)),
39
+ )
35
40
 
36
41
 
37
42
  # rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True)
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ import egglog
4
+
5
+ from .array_api import Int
6
+
7
+ # https://github.com/sklam/pyasir/blob/c363ff4f8f91177700ad4108dd5042b9b97d8289/pyasir/tests/test_fib.py
8
+
9
+ # In progress - should be able to re-create this
10
+ # @df.func
11
+ # def fib_ir(n: pyasir.Int64) -> pyasir.Int64:
12
+ # @df.switch(n <= 1)
13
+ # def swt(n):
14
+ # @df.case(1)
15
+ # def case0(n):
16
+ # return 1
17
+
18
+ # @df.case(0)
19
+ # def case1(n):
20
+ # return fib_ir(n - 1) + fib_ir(n - 2)
21
+
22
+ # yield case0
23
+ # yield case1
24
+
25
+ # r = swt(n)
26
+ # return r
27
+
28
+
29
+ # With something like this:
30
+ @egglog.function
31
+ def fib(n: Int) -> Int:
32
+ return (n <= Int(1)).if_int(
33
+ Int(1),
34
+ fib(n - Int(1)) + fib(n - Int(2)),
35
+ )