egglog 0.4.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

@@ -0,0 +1,5 @@
1
+ Examples Gallery
2
+ ================
3
+
4
+ This is a gallery of examples, most of which were translated from the original
5
+ `egglog Rust examples <https://github.com/egraphs-good/egglog/tree/08a6e8fecdb77e6ba72a1b1d9ff4aff33229912c/tests>`_.
File without changes
@@ -0,0 +1,43 @@
1
+ """
2
+ Basic equality saturation example.
3
+ ==================================
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from egglog import *
8
+
9
# E-graph on which this example's sort, rewrite rules and checks are registered.
egraph = EGraph()
10
+
11
+
12
@egraph.class_
class Num(BaseExpr):
    """
    Symbolic integer expression: literals, named variables, sums and products.

    The ``...`` bodies are stubs; only the signatures matter to
    ``@egraph.class_``, and the methods are used to build expressions such as
    ``Num(2) * (Num.var("x") + Num(3))``.
    """

    def __init__(self, value: i64Like) -> None:
        # Integer literal, e.g. ``Num(2)``.
        ...

    @classmethod
    def var(cls, name: StringLike) -> Num:  # type: ignore[empty-body]
        # Named variable, e.g. ``Num.var("x")``.
        ...

    def __add__(self, other: Num) -> Num:  # type: ignore[empty-body]
        ...

    def __mul__(self, other: Num) -> Num:  # type: ignore[empty-body]
        ...
26
+
27
+
28
# expr1 = 2 * (x + 3)
expr1 = egraph.define("expr1", Num(2) * (Num.var("x") + Num(3)))
# expr2 = 6 + 2 * x
expr2 = egraph.define("expr2", Num(6) + Num(2) * Num.var("x"))

# Pattern variables for the rewrite rules below.
a, b, c = vars_("a b c", Num)
i, j = vars_("i j", i64)
egraph.register(
    # Addition is commutative.
    rewrite(a + b).to(b + a),
    # Multiplication distributes over addition.
    rewrite(a * (b + c)).to((a * b) + (a * c)),
    # Constant folding: sums and products of literals are evaluated.
    rewrite(Num(i) + Num(j)).to(Num(i + j)),
    rewrite(Num(i) * Num(j)).to(Num(i * j)),
)
# Run the rules (``run(10)``), then assert both expressions landed in the same
# e-class: 2 * (x + 3) -> 2x + 2*3 -> 2x + 6 -> 6 + 2x.
egraph.run(10)
egraph.check(eq(expr1).to(expr2))
# Trailing bare expression, presumably so the e-graph is rendered when this
# script is shown in the examples gallery / a notebook.
egraph
egglog/examples/fib.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Fibonacci numbers example
3
+ =========================
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from egglog import *
8
+
9
# E-graph holding the ``fib`` table, its facts and the recurrence rule.
egraph = EGraph()
10
+
11
+
12
@egraph.function
def fib(x: i64Like) -> i64:  # type: ignore[empty-body]
    """
    Fibonacci number for index ``x``, declared as an e-graph function.

    The body is a stub; values are supplied by the ``set_`` facts and the
    recurrence rule registered in this script, with fib(0) = fib(1) = 1.
    """
    ...
15
+
16
+
17
# Pattern variables for the recurrence rule.
f0, f1, x = vars_("f0 f1 x", i64)
egraph.register(
    # Base cases: fib(0) = fib(1) = 1.
    set_(fib(0)).to(i64(1)),
    set_(fib(1)).to(i64(1)),
    # Whenever fib(x) and fib(x + 1) are both known, define fib(x + 2) as
    # their sum.
    rule(
        eq(f0).to(fib(x)),
        eq(f1).to(fib(x + 1)),
    ).then(set_(fib(x + 2)).to(f0 + f1)),
)
# 7 iterations are enough for the table to reach fib(7).
egraph.run(7)
# Sequence from index 0: 1, 1, 2, 3, 5, 8, 13, 21 -> fib(7) == 21.
egraph.check(eq(fib(i64(7))).to(i64(21)))
# Trailing bare expression, presumably for display in the examples gallery.
egraph
@@ -0,0 +1,310 @@
1
+ """
2
+ Lambda Calculus
3
+ ===============
4
+ """
5
+ # mypy: disable-error-code=empty-body
6
+ from __future__ import annotations
7
+
8
+ from typing import Callable, ClassVar
9
+
10
+ from egglog import *
11
+
12
# E-graph on which every sort, function and rule of this example is registered.
egraph = EGraph()

# TODO: Debug extracting constants
15
+
16
+
17
@egraph.class_
class Val(BaseExpr):
    """
    A value is a number or a boolean.

    Numbers are built with ``Val(n)``. The booleans are the class-level
    constants ``Val.TRUE`` and ``Val.FALSE``, which the ``if_`` rewrites match on.
    """

    # Boolean constants, declared as class variables.
    TRUE: ClassVar[Val]
    FALSE: ClassVar[Val]

    def __init__(self, v: i64Like) -> None:
        ...
28
+
29
+
30
@egraph.class_
class Var(BaseExpr):
    """A variable name, e.g. ``Var("x")``."""

    def __init__(self, v: StringLike) -> None:
        ...
34
+
35
+
36
@egraph.class_
class Term(BaseExpr):
    """
    A lambda-calculus term.

    Leaves are values (``Term.val``) and variables (``Term.var``). The dunder
    methods build addition, equality tests and application. The binders
    ``lam``, ``let_`` and ``fix``, and the conditional ``if_``, are
    module-level functions.
    """

    @classmethod
    def val(cls, v: Val) -> Term:
        ...

    @classmethod
    def var(cls, v: Var) -> Term:
        ...

    def __add__(self, other: Term) -> Term:
        ...

    # ``t1 == t2`` builds an equality *term* rather than returning a Python
    # bool, hence the ignored override error.
    def __eq__(self, other: Term) -> Term:  # type: ignore[override]
        ...

    # ``t1(t2)`` builds the application of t1 to t2.
    def __call__(self, other: Term) -> Term:
        ...

    # The value this term evaluates to; filled in by the "eval" rules.
    def eval(self) -> Val:
        ...

    # A variable associated with this term. The capture-avoiding let/lam rule
    # uses it as the new binder name (presumably so it is fresh -- confirm).
    def v(self) -> Var:
        ...
60
+
61
+
62
@egraph.function
def lam(x: Var, t: Term) -> Term:
    """Lambda abstraction binding ``x`` in the body ``t``."""
    ...
65
+
66
+
67
@egraph.function
def let_(x: Var, t: Term, b: Term) -> Term:
    """``let x = t in b``: bind ``x`` to the value ``t`` within the body ``b``."""
    ...
70
+
71
+
72
@egraph.function
def fix(x: Var, t: Term) -> Term:
    """
    Fixed point binding ``x`` in ``t``.

    Unfolded by the rewrite ``fix(x, t) -> let_(x, fix(x, t), t)``.
    """
    ...
75
+
76
+
77
@egraph.function
def if_(c: Term, t: Term, f: Term) -> Term:
    """
    Conditional: rewrites to ``t`` when ``c`` is ``Val.TRUE`` and to ``f``
    when ``c`` is ``Val.FALSE``.
    """
    ...
80
+
81
+
82
# Set sort used for free-variable sets. Despite the name, its elements are
# ``Var`` values, not raw strings.
StringSet = Set[Var]
83
+
84
+
85
@egraph.function(merge=lambda old, new: old & new)
def freer(t: Term) -> StringSet:
    """
    Free variables of ``t``.

    If the rules derive more than one set for the same e-class (e.g. from
    different equivalent terms), the merge keeps ``old & new`` -- presumably
    the set intersection, so only variables free in every derivation remain.
    """
    ...
88
+
89
+
90
# Pattern variables shared by the rules below.
(v, v1, v2) = vars_("v v1 v2", Val)
(t, t1, t2, t3, t4) = vars_("t t1 t2 t3 t4", Term)
(x, y) = vars_("x y", Var)
fv, fv1, fv2, fv3 = vars_("fv fv1 fv2 fv3", StringSet)
i1, i2 = vars_("i1 i2", i64)
egraph.register(
    # freer: free-variable analysis, computed bottom-up.
    # A value has no free variables; a variable is free in itself.
    rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
    rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x))),
    # Addition, equality and application: union of the children's sets.
    rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    # A lambda's bound variable is not free in it.
    rule(eq(t).to(lam(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    # let_(x, value, body) binds x only in the body (t2). So x is removed from
    # the body's set fv2, while the value's free variables fv1 stay free. This
    # agrees with beta reduction below, which turns lam(x, t)(t1) into
    # let_(x, t1, t): both forms must have the same free variables.
    rule(eq(t).to(let_(x, t1, t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(
        set_(freer(t)).to(fv1 | fv2.remove(x))
    ),
    rule(eq(t).to(fix(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    rule(eq(t).to(if_(t1, t2, t3)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2), eq(freer(t3)).to(fv3)).then(
        set_(freer(t)).to(fv1 | fv2 | fv3)
    ),
    # eval: a value term evaluates to its value ...
    rule(eq(t).to(Term.val(v))).then(set_(t.eval()).to(v)),
    # ... sums of evaluated integers are folded ...
    rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
        union(t.eval()).with_(Val(i1 + i2))
    ),
    # ... equality is TRUE when both sides evaluate to the same value and
    # FALSE when they evaluate to different values ...
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), v1 != v2).then(
        union(t.eval()).with_(Val.FALSE)
    ),
    # ... and any term with a known value is equal to that value as a term.
    rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
    # if
    rewrite(if_(Term.val(Val.TRUE), t1, t2)).to(t1),
    rewrite(if_(Term.val(Val.FALSE), t1, t2)).to(t2),
    # if-elim: ``if x == t1 then t2 else t3`` equals t3 when substituting
    # x := t1 makes both branches equal.
    # Adds let rules so next one can match on them
    rule(eq(t).to(if_(Term.var(x) == t1, t2, t3))).then(let_(x, t1, t2), let_(x, t1, t3)),
    rewrite(if_(Term.var(x) == t1, t2, t3)).to(
        t3,
        eq(let_(x, t1, t2)).to(let_(x, t1, t3)),
    ),
    # add-comm
    rewrite(t1 + t2).to(t2 + t1),
    # add-assoc
    rewrite((t1 + t2) + t3).to(t1 + (t2 + t3)),
    # eq-comm
    rewrite(t1 == t2).to(t2 == t1),
    # Fix
    rewrite(fix(x, t)).to(let_(x, fix(x, t), t)),
    # beta reduction
    rewrite(lam(x, t)(t1)).to(let_(x, t1, t)),
    # let-app
    rewrite(let_(x, t, t1(t2))).to(let_(x, t, t1)(let_(x, t, t2))),
    # let-add
    rewrite(let_(x, t, t1 + t2)).to(let_(x, t, t1) + let_(x, t, t2)),
    # let-eq
    rewrite(let_(x, t, t1 == t2)).to(let_(x, t, t1) == let_(x, t, t2)),
    # let-const
    rewrite(let_(x, t, Term.val(v))).to(Term.val(v)),
    # let-if
    rewrite(let_(x, t, if_(t1, t2, t3))).to(if_(let_(x, t, t1), let_(x, t, t2), let_(x, t, t3))),
    # let-var-same
    rewrite(let_(x, t, Term.var(x))).to(t),
    # let-var-diff
    rewrite(let_(x, t, Term.var(y))).to(Term.var(y), x != y),
    # let-lam-same: the inner lambda shadows x, so the let has no effect.
    rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
    # let-lam-diff: push the let under the lambda only if y is not free in
    # the substituted value (otherwise y would be captured) ...
    rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), x != y, eq(fv).to(freer(t)), fv.not_contains(y)),
    # ... and if y IS free in the value, first rename the binder to t.v().
    rule(eq(t).to(let_(x, t1, lam(y, t2))), x != y, eq(fv).to(freer(t1)), fv.contains(y)).then(
        union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
    ),
)
162
+
163
# Argument-less relation used as a flag: the lambda_compose example fires
# ``result()`` from a rule and then checks ``result()``.
result = egraph.relation("result")
164
+
165
+
166
def l(fn: Callable[[Term], Term]) -> Term:  # noqa
    """
    Build a ``lam`` term from a one-argument Python callable.

    The callable's first parameter name is reused as the bound variable's
    name, so ``l(lambda x: x + x)`` yields
    ``lam(Var("x"), Term.var(Var("x")) + Term.var(Var("x")))``.
    """
    code = fn.__code__
    param_name = code.co_varnames[0]
    body = fn(Term.var(Var(param_name)))
    return lam(Var(param_name), body)
173
+
174
+
175
def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
    """
    Simplify ``left``, print the result, and check that it equals ``right``.

    Everything happens inside ``with egraph:`` (presumably a scope whose
    additions are discarded on exit, so examples don't interfere -- confirm
    against the egglog docs). Inside that scope, ``left`` is registered, the
    rules are run with ``egraph.run(30)``, and the term extracted for ``left``
    is printed as ``"<left> ➡ <extracted>"``. Finally ``egraph.check`` asserts
    that ``left`` and ``right`` are equal (presumably raising if not).
    """
    with egraph:
        egraph.register(left)
        egraph.run(30)
        res = egraph.extract(left)
        print(f"{left} ➡ {res}")
        egraph.check(eq(left).to(right))
185
+
186
+
187
# Evaluating a value term gives that value; sums of literals are folded.
assert_simplifies((Term.val(Val(1))).eval(), Val(1))
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))


# lambda under
# Reduction also happens under a binder: the lambda with body
# 4 + (lambda y: y)(4) simplifies to the lambda with body 8.
assert_simplifies(
    l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
    l(lambda x: Term.val(Val(8))),
)
196
# lambda if elim
# ``if a == b then a + a else a + b``: substituting a := b makes both branches
# b + b, so the if-elim rule equates the conditional with the else branch.
a = Term.var(Var("a"))
b = Term.var(Var("b"))
with egraph:
    e1 = egraph.define("e1", if_(a == b, a + a, a + b))
    egraph.run(10)
    egraph.check(eq(e1).to(a + b))
203
+
204
# lambda let simple
# From here on, x and y are concrete variables; this rebinds the module-level
# names that previously held the rule pattern variables.
x = Var("x")
y = Var("y")
assert_simplifies(
    let_(x, Term.val(Val(0)), let_(y, Term.val(Val(1)), Term.var(x) + Term.var(y))),
    Term.val(Val(1)),
)
# lambda capture
# The inner lambda re-binds x, shadowing the let, so the let has no effect.
assert_simplifies(
    let_(x, Term.val(Val(1)), l(lambda x: x)),
    l(lambda x: x),
)
# lambda capture free
# ``let y = x + x in (lambda x: y)`` must NOT become ``lambda x: x + x``:
# the free x in the bound value would be captured by the lambda's binder.
with egraph:
    e5 = egraph.define("e5", let_(y, Term.var(x) + Term.var(x), l(lambda x: Term.var(y))))
    egraph.run(10)
    egraph.check(freer(l(lambda x: Term.var(y))).contains(y))
    egraph.check_fail(eq(e5).to(l(lambda x: x + x)))
222
+
223
# lambda_closure_not_seven
# Lexical scoping: add-five is created while five = 5. Re-binding five to 6
# around the call must not change it, so add-five(1) is 6, not 7.
with egraph:
    e6 = egraph.define(
        "e6",
        let_(
            Var("five"),
            Term.val(Val(5)),
            let_(
                Var("add-five"),
                l(lambda x: x + Term.var(Var("five"))),
                let_(Var("five"), Term.val(Val(6)), Term.var(Var("add-five"))(Term.val(Val(1)))),
            ),
        ),
    )
    egraph.run(10)
    egraph.check_fail(eq(e6).to(Term.val(Val(7))))
    egraph.check(eq(e6).to(Term.val(Val(6))))
240
+
241
+
242
# lambda_compose
# compose = f -> g -> x -> f(g(x)) and add1 = y -> y + 1, so
# compose(add1)(add1) should behave like x -> x + 2.
# ``compose`` and ``add1`` stay bound after this block and are reused below.
with egraph:
    compose = Var("compose")
    add1 = Var("add1")
    e7 = egraph.define(
        "e7",
        let_(
            compose,
            l(
                lambda f: l(
                    lambda g: l(
                        lambda x: f(g(x)),
                    ),
                ),
            ),
            let_(
                add1,
                l(lambda y: y + Term.val(Val(1))),
                Term.var(compose)(Term.var(add1))(Term.var(add1)),
            ),
        ),
    )
    egraph.run(20)
    # Fire ``result`` when both expected forms are present in the e-graph.
    # NOTE(review): t1 and t2 are the rule pattern variables and are not tied
    # to e7 (or to each other), so this only checks that both terms exist
    # somewhere, not that e7 simplified to them -- confirm this is intended.
    egraph.register(
        rule(
            eq(t1).to(l(lambda x: Term.val(Val(1)) + l(lambda y: Term.val(Val(1)) + y)(x))),
            eq(t2).to(l(lambda x: x + Term.val(Val(2)))),
        ).then(result())
    )
    egraph.run(1)
    egraph.check(result())
273
+
274
+
275
# lambda_if_simple
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))


# lambda_compose_many
# Six applications of compose(add1) wrapped around add1 compose seven add1s,
# i.e. x -> x + 7. Reuses ``compose`` and ``add1`` from lambda_compose.
assert_simplifies(
    let_(
        compose,
        l(lambda f: l(lambda g: l(lambda x: f(g(x))))),
        let_(
            add1,
            l(lambda y: y + Term.val(Val(1))),
            Term.var(compose)(Term.var(add1))(
                Term.var(compose)(Term.var(add1))(
                    Term.var(compose)(Term.var(add1))(
                        Term.var(compose)(Term.var(add1))(
                            Term.var(compose)(Term.var(add1))(Term.var(compose)(Term.var(add1))(Term.var(add1)))
                        )
                    )
                )
            ),
        ),
    ),
    l(lambda x: x + Term.val(Val(7))),
)

# lambda_if
# zeroone(0) + zeroone(10) = 0 + 1 = 1.
zeroone = Var("zeroone")
assert_simplifies(
    let_(
        zeroone,
        l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
        Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
    ),
    Term.val(Val(1)),
)
@@ -0,0 +1,184 @@
1
+ """
2
+ Matrix multiplication and Kronecker product.
3
+ ============================================
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from egglog import *
8
+
9
# E-graph on which the Dim/Matrix sorts, rewrite rules and checks are registered.
egraph = EGraph()
10
+
11
+
12
@egraph.class_
class Dim(BaseExpr):
    """
    A dimension of a matrix.

    >>> Dim(3) * Dim.named("n")
    Dim(3) * Dim.named("n")
    """

    # Integer literal dimension (egglog name ``Lit``).
    @egraph.method(egg_fn="Lit")
    def __init__(self, value: i64Like) -> None:
        ...

    # Symbolic dimension identified by a name (egglog name ``NamedDim``).
    @egraph.method(egg_fn="NamedDim")
    @classmethod
    def named(cls, name: StringLike) -> Dim:  # type: ignore[empty-body]
        ...

    # Product of two dimensions (egglog name ``Times``).
    @egraph.method(egg_fn="Times")
    def __mul__(self, other: Dim) -> Dim:  # type: ignore[empty-body]
        ...
33
+
34
+
35
# Pattern variables for the dimension rewrites. ``n`` is also used by the
# identity-matrix rules registered later.
a, b, c, n = vars_("a b c n", Dim)
i, j = vars_("i j", i64)
egraph.register(
    # Dimension multiplication is associative (both directions) ...
    rewrite(a * (b * c)).to((a * b) * c),
    rewrite((a * b) * c).to(a * (b * c)),
    # ... folds products of literals ...
    rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),
    # ... and is commutative.
    rewrite(a * b).to(b * a),
)
43
+
44
+
45
@egraph.class_(egg_sort="MExpr")
class Matrix(BaseExpr):
    """
    A symbolic matrix expression (egglog sort ``MExpr``).

    Dimensions are not stored on construction; ``nrows``/``ncols`` are given
    values by rewrite rules and ``set_`` facts.
    """

    @egraph.method(egg_fn="Id")
    @classmethod
    def identity(cls, dim: Dim) -> Matrix:  # type: ignore[empty-body]
        """
        Create an identity matrix of the given dimension.
        """
        ...

    @egraph.method(egg_fn="NamedMat")
    @classmethod
    def named(cls, name: StringLike) -> Matrix:  # type: ignore[empty-body]
        """
        Create a named matrix.
        """
        ...

    @egraph.method(egg_fn="MMul")
    def __matmul__(self, other: Matrix) -> Matrix:  # type: ignore[empty-body]
        """
        Matrix multiplication.
        """
        ...

    @egraph.method(egg_fn="nrows")
    def nrows(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of rows in the matrix.
        """
        ...

    @egraph.method(egg_fn="ncols")
    def ncols(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of columns in the matrix.
        """
        ...
83
+
84
+
85
@egraph.function(egg_fn="Kron")
def kron(a: Matrix, b: Matrix) -> Matrix:  # type: ignore[empty-body]
    """
    Kronecker product of two matrices.

    The result has ``a.nrows() * b.nrows()`` rows and
    ``a.ncols() * b.ncols()`` columns, as encoded by the dimension rewrites.

    https://en.wikipedia.org/wiki/Kronecker_product#Definition
    """
    ...
93
+
94
+
95
# Pattern variables for the matrix rules (rebound to concrete matrices later).
A, B, C, D = vars_("A B C D", Matrix)
egraph.register(
    # The dimensions of a kronecker product are the product of the dimensions
    rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),
    rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),
    # The dimensions of a matrix multiplication are the number of rows of the first
    # matrix and the number of columns of the second matrix.
    rewrite((A @ B).nrows()).to(A.nrows()),
    rewrite((A @ B).ncols()).to(B.ncols()),
    # The dimensions of an identity matrix are the input dimension
    rewrite(Matrix.identity(n).nrows()).to(n),
    rewrite(Matrix.identity(n).ncols()).to(n),
)
egraph.register(
    # Multiplication by an identity matrix is the same as the other matrix
    rewrite(Matrix.identity(n) @ A).to(A),
    rewrite(A @ Matrix.identity(n)).to(A),
    # Matrix multiplication is associative
    rewrite(A @ (B @ C)).to((A @ B) @ C),
    rewrite((A @ B) @ C).to(A @ (B @ C)),
    # Kronecker product is associative
    rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),
    rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),
    # Kronecker product distributes over matrix multiplication
    # (mixed-product property). The forward direction has no side condition;
    # it assumes the products A @ C and B @ D are themselves well-formed.
    rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),
    rewrite(kron(A, B) @ kron(C, D)).to(
        kron(A @ C, B @ D),
        # Only when the dimensions match
        eq(A.ncols()).to(C.nrows()),
        eq(B.ncols()).to(D.nrows()),
    ),
)
egraph.register(
    # demand rows and columns when we multiply matrices
    # The ``let`` actions only create the nrows/ncols terms of the operands
    # (bound to throwaway names), presumably so the dimension rewrites and the
    # dimension side conditions above have terms to match on.
    rule(eq(C).to(A @ B)).then(
        let("demand1", A.ncols()),
        let("demand2", A.nrows()),
        let("demand3", B.ncols()),
        let("demand4", B.nrows()),
    ),
    # demand rows and columns when we take the kronecker product
    rule(eq(C).to(kron(A, B))).then(
        let("demand1", A.ncols()),
        let("demand2", A.nrows()),
        let("demand3", B.ncols()),
        let("demand4", B.nrows()),
    ),
)
143
+
144
+
145
# Define a number of dimensions
# (these rebind ``n`` and the matrix names that held rule pattern variables).
n = egraph.define("n", Dim.named("n"))
m = egraph.define("m", Dim.named("m"))
p = egraph.define("p", Dim.named("p"))

# Define a number of matrices
A = egraph.define("A", Matrix.named("A"))
B = egraph.define("B", Matrix.named("B"))
C = egraph.define("C", Matrix.named("C"))

# Set each to be a square matrix of the given dimension
egraph.register(
    set_(A.nrows()).to(n),
    set_(A.ncols()).to(n),
    set_(B.nrows()).to(m),
    set_(B.ncols()).to(m),
    set_(C.nrows()).to(p),
    set_(C.ncols()).to(p),
)
# Create an example which should equal the kronecker product of A and B:
# kron(I_n, B) @ kron(A, I_m) -> kron(I_n @ A, B @ I_m) -> kron(A, B),
# which is allowed because I_n.ncols == A.nrows == n and
# B.ncols == I_m.nrows == m.
ex1 = egraph.define("ex1", kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m)))
# Naming ex1's dimensions also adds those terms to the e-graph.
rows = egraph.define("rows", ex1.nrows())
cols = egraph.define("cols", ex1.ncols())

egraph.run(20)

# Sanity checks on the dimension rules.
egraph.check(eq(B.nrows()).to(m))
egraph.check(eq(kron(Matrix.identity(n), B).nrows()).to(n * m))

# Verify it matches the expected result
# TODO
# NOTE(review): the check below is already implemented, so this TODO may be
# stale or refer to follow-up work -- confirm and expand or remove it.
simple_ex1 = egraph.define("simple_ex1", kron(A, B))
egraph.check(eq(ex1).to(simple_ex1))

# Same shape with C (p x p) in place of B: the mixed-product rewrite would
# need p == n and p == m. The named dimensions are not known to be equal, so
# the rewrite must not fire.
ex2 = egraph.define("ex2", kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m)))

egraph.run(10)
# Verify it is not simplified
egraph.check_fail(eq(ex2).to(kron(A, C)))
# Trailing bare expression, presumably for display in the examples gallery.
egraph