egglog 6.1.0__cp312-none-win_amd64.whl → 7.1.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +9 -0
- egglog/builtins.py +42 -2
- egglog/conversion.py +177 -0
- egglog/declarations.py +354 -734
- egglog/egraph.py +602 -800
- egglog/egraph_state.py +456 -0
- egglog/exp/array_api.py +100 -88
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +464 -0
- egglog/runtime.py +279 -431
- egglog/thunk.py +71 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/METADATA +7 -7
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
egglog/exp/array_api.py
CHANGED
|
@@ -18,13 +18,13 @@ from egglog.runtime import RuntimeExpr
|
|
|
18
18
|
from .program_gen import *
|
|
19
19
|
|
|
20
20
|
if TYPE_CHECKING:
|
|
21
|
-
from collections.abc import Iterator
|
|
21
|
+
from collections.abc import Callable, Iterator
|
|
22
22
|
from types import ModuleType
|
|
23
23
|
|
|
24
24
|
# Pretend that exprs are numbers b/c sklearn does isinstance checks
|
|
25
25
|
numbers.Integral.register(RuntimeExpr)
|
|
26
26
|
|
|
27
|
-
array_api_ruleset = ruleset()
|
|
27
|
+
array_api_ruleset = ruleset(name="array_api_ruleset")
|
|
28
28
|
array_api_schedule = array_api_ruleset.saturate()
|
|
29
29
|
|
|
30
30
|
|
|
@@ -36,10 +36,14 @@ class Boolean(Expr):
|
|
|
36
36
|
@property
|
|
37
37
|
def bool(self) -> Bool: ...
|
|
38
38
|
|
|
39
|
-
def __or__(self, other:
|
|
39
|
+
def __or__(self, other: BooleanLike) -> Boolean: ...
|
|
40
40
|
|
|
41
|
-
def __and__(self, other:
|
|
41
|
+
def __and__(self, other: BooleanLike) -> Boolean: ...
|
|
42
42
|
|
|
43
|
+
def if_int(self, true_value: Int, false_value: Int) -> Int: ...
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
BooleanLike = Boolean | bool
|
|
43
47
|
|
|
44
48
|
TRUE = constant("TRUE", Boolean)
|
|
45
49
|
FALSE = constant("FALSE", Boolean)
|
|
@@ -47,7 +51,7 @@ converter(bool, Boolean, lambda x: TRUE if x else FALSE)
|
|
|
47
51
|
|
|
48
52
|
|
|
49
53
|
@array_api_ruleset.register
|
|
50
|
-
def _bool(x: Boolean):
|
|
54
|
+
def _bool(x: Boolean, i: Int, j: Int):
|
|
51
55
|
return [
|
|
52
56
|
rule(eq(x).to(TRUE)).then(set_(x.bool).to(Bool(True))),
|
|
53
57
|
rule(eq(x).to(FALSE)).then(set_(x.bool).to(Bool(False))),
|
|
@@ -55,82 +59,8 @@ def _bool(x: Boolean):
|
|
|
55
59
|
rewrite(FALSE | x).to(x),
|
|
56
60
|
rewrite(TRUE & x).to(x),
|
|
57
61
|
rewrite(FALSE & x).to(FALSE),
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class DType(Expr):
|
|
62
|
-
float64: ClassVar[DType]
|
|
63
|
-
float32: ClassVar[DType]
|
|
64
|
-
int64: ClassVar[DType]
|
|
65
|
-
int32: ClassVar[DType]
|
|
66
|
-
object: ClassVar[DType]
|
|
67
|
-
bool: ClassVar[DType]
|
|
68
|
-
|
|
69
|
-
def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
|
|
70
|
-
...
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
float64 = DType.float64
|
|
74
|
-
float32 = DType.float32
|
|
75
|
-
int32 = DType.int32
|
|
76
|
-
int64 = DType.int64
|
|
77
|
-
|
|
78
|
-
_DTYPES = [float64, float32, int32, int64, DType.object]
|
|
79
|
-
|
|
80
|
-
converter(type, DType, lambda x: convert(np.dtype(x), DType))
|
|
81
|
-
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
|
|
82
|
-
array_api_ruleset.register(
|
|
83
|
-
*(rewrite(l == r).to(TRUE if l is r else FALSE) for l, r in itertools.product(_DTYPES, repeat=2))
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class IsDtypeKind(Expr):
|
|
88
|
-
NULL: ClassVar[IsDtypeKind]
|
|
89
|
-
|
|
90
|
-
@classmethod
|
|
91
|
-
def string(cls, s: StringLike) -> IsDtypeKind: ...
|
|
92
|
-
|
|
93
|
-
@classmethod
|
|
94
|
-
def dtype(cls, d: DType) -> IsDtypeKind: ...
|
|
95
|
-
|
|
96
|
-
@method(cost=10)
|
|
97
|
-
def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
# TODO: Make kind more generic to support tuples.
|
|
101
|
-
@function
|
|
102
|
-
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
|
|
106
|
-
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
|
|
107
|
-
converter(
|
|
108
|
-
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@array_api_ruleset.register
|
|
113
|
-
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
|
|
114
|
-
return [
|
|
115
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
|
|
116
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
|
|
117
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
|
|
118
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
|
|
119
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
|
|
120
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
121
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
122
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
123
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
124
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
125
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
126
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
127
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
128
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
129
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
130
|
-
rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
|
|
131
|
-
rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
|
|
132
|
-
rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
|
|
133
|
-
rewrite(k1 | IsDtypeKind.NULL).to(k1),
|
|
62
|
+
rewrite(TRUE.if_int(i, j)).to(i),
|
|
63
|
+
rewrite(FALSE.if_int(i, j)).to(j),
|
|
134
64
|
]
|
|
135
65
|
|
|
136
66
|
|
|
@@ -264,10 +194,13 @@ converter(int, Int, lambda x: Int(x))
|
|
|
264
194
|
|
|
265
195
|
|
|
266
196
|
class Float(Expr):
|
|
197
|
+
# Differentiate costs of three constructors so extraction is deterministic if all three are present
|
|
198
|
+
@method(cost=3)
|
|
267
199
|
def __init__(self, value: f64Like) -> None: ...
|
|
268
200
|
|
|
269
201
|
def abs(self) -> Float: ...
|
|
270
202
|
|
|
203
|
+
@method(cost=2)
|
|
271
204
|
@classmethod
|
|
272
205
|
def rational(cls, r: Rational) -> Float: ...
|
|
273
206
|
|
|
@@ -324,7 +257,7 @@ class TupleInt(Expr):
|
|
|
324
257
|
|
|
325
258
|
def __getitem__(self, i: Int) -> Int: ...
|
|
326
259
|
|
|
327
|
-
def
|
|
260
|
+
def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
|
|
328
261
|
|
|
329
262
|
|
|
330
263
|
converter(
|
|
@@ -339,7 +272,7 @@ converter(
|
|
|
339
272
|
|
|
340
273
|
|
|
341
274
|
@array_api_ruleset.register
|
|
342
|
-
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
275
|
+
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
|
|
343
276
|
return [
|
|
344
277
|
rewrite(ti + TupleInt.EMPTY).to(ti),
|
|
345
278
|
rewrite(TupleInt(i).length()).to(Int(1)),
|
|
@@ -348,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
|
348
281
|
rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
|
|
349
282
|
# Rule for indexing > 0
|
|
350
283
|
rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
|
|
351
|
-
#
|
|
352
|
-
rewrite(TupleInt(i)
|
|
353
|
-
rewrite(
|
|
354
|
-
rewrite(TupleInt.
|
|
284
|
+
# fold
|
|
285
|
+
rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
|
|
286
|
+
rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
|
|
287
|
+
rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
|
|
355
288
|
]
|
|
356
289
|
|
|
357
290
|
|
|
@@ -366,6 +299,85 @@ converter(type(None), OptionalInt, lambda _: OptionalInt.none)
|
|
|
366
299
|
converter(Int, OptionalInt, OptionalInt.some)
|
|
367
300
|
|
|
368
301
|
|
|
302
|
+
class DType(Expr):
|
|
303
|
+
float64: ClassVar[DType]
|
|
304
|
+
float32: ClassVar[DType]
|
|
305
|
+
int64: ClassVar[DType]
|
|
306
|
+
int32: ClassVar[DType]
|
|
307
|
+
object: ClassVar[DType]
|
|
308
|
+
bool: ClassVar[DType]
|
|
309
|
+
|
|
310
|
+
def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
|
|
311
|
+
...
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
float64 = DType.float64
|
|
315
|
+
float32 = DType.float32
|
|
316
|
+
int32 = DType.int32
|
|
317
|
+
int64 = DType.int64
|
|
318
|
+
|
|
319
|
+
_DTYPES = [float64, float32, int32, int64, DType.object]
|
|
320
|
+
|
|
321
|
+
converter(type, DType, lambda x: convert(np.dtype(x), DType))
|
|
322
|
+
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@array_api_ruleset.register
|
|
326
|
+
def _():
|
|
327
|
+
for l, r in itertools.product(_DTYPES, repeat=2):
|
|
328
|
+
yield rewrite(l == r).to(TRUE if l is r else FALSE)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class IsDtypeKind(Expr):
|
|
332
|
+
NULL: ClassVar[IsDtypeKind]
|
|
333
|
+
|
|
334
|
+
@classmethod
|
|
335
|
+
def string(cls, s: StringLike) -> IsDtypeKind: ...
|
|
336
|
+
|
|
337
|
+
@classmethod
|
|
338
|
+
def dtype(cls, d: DType) -> IsDtypeKind: ...
|
|
339
|
+
|
|
340
|
+
@method(cost=10)
|
|
341
|
+
def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# TODO: Make kind more generic to support tuples.
|
|
345
|
+
@function
|
|
346
|
+
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
|
|
350
|
+
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
|
|
351
|
+
converter(
|
|
352
|
+
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@array_api_ruleset.register
|
|
357
|
+
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
|
|
358
|
+
return [
|
|
359
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
|
|
360
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
|
|
361
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
|
|
362
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
|
|
363
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
|
|
364
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
365
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
366
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
367
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
368
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
369
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
370
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
371
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
372
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
373
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
374
|
+
rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
|
|
375
|
+
rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
|
|
376
|
+
rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
|
|
377
|
+
rewrite(k1 | IsDtypeKind.NULL).to(k1),
|
|
378
|
+
]
|
|
379
|
+
|
|
380
|
+
|
|
369
381
|
class Slice(Expr):
|
|
370
382
|
def __init__(
|
|
371
383
|
self,
|
|
@@ -1334,7 +1346,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
|
|
|
1334
1346
|
|
|
1335
1347
|
@array_api_ruleset.register
|
|
1336
1348
|
def _size(x: NDArray):
|
|
1337
|
-
yield rewrite(x.size).to(x.shape.
|
|
1349
|
+
yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
|
|
1338
1350
|
|
|
1339
1351
|
|
|
1340
1352
|
@overload
|
egglog/exp/array_api_numba.py
CHANGED
|
@@ -31,7 +31,12 @@ def _std(y: NDArray, x: NDArray, i: Int):
|
|
|
31
31
|
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
|
|
32
32
|
# https://numpy.org/doc/stable/reference/generated/numpy.std.html
|
|
33
33
|
# "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
|
|
34
|
-
yield rewrite(
|
|
34
|
+
yield rewrite(
|
|
35
|
+
std(x, axis),
|
|
36
|
+
subsume=True,
|
|
37
|
+
).to(
|
|
38
|
+
sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)),
|
|
39
|
+
)
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
# rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import egglog
|
|
4
|
+
|
|
5
|
+
from .array_api import Int
|
|
6
|
+
|
|
7
|
+
# https://github.com/sklam/pyasir/blob/c363ff4f8f91177700ad4108dd5042b9b97d8289/pyasir/tests/test_fib.py
|
|
8
|
+
|
|
9
|
+
# In progress - should be able to re-create this
|
|
10
|
+
# @df.func
|
|
11
|
+
# def fib_ir(n: pyasir.Int64) -> pyasir.Int64:
|
|
12
|
+
# @df.switch(n <= 1)
|
|
13
|
+
# def swt(n):
|
|
14
|
+
# @df.case(1)
|
|
15
|
+
# def case0(n):
|
|
16
|
+
# return 1
|
|
17
|
+
|
|
18
|
+
# @df.case(0)
|
|
19
|
+
# def case1(n):
|
|
20
|
+
# return fib_ir(n - 1) + fib_ir(n - 2)
|
|
21
|
+
|
|
22
|
+
# yield case0
|
|
23
|
+
# yield case1
|
|
24
|
+
|
|
25
|
+
# r = swt(n)
|
|
26
|
+
# return r
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# With something like this:
|
|
30
|
+
@egglog.function
|
|
31
|
+
def fib(n: Int) -> Int:
|
|
32
|
+
return (n <= Int(1)).if_int(
|
|
33
|
+
Int(1),
|
|
34
|
+
fib(n - Int(1)) + fib(n - Int(2)),
|
|
35
|
+
)
|