egglog-9.0.0-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

@@ -0,0 +1,38 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Boolean data type example and test
4
+ ==================================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from egglog import *
10
+
11
# Shorthands for the two boolean constants used throughout this example.
T = Bool(True)
F = Bool(False)
# Each ``check`` below asserts a single fact about the builtin ``Bool`` operators.
check(eq(T & T).to(T))
check(eq(T & F).to(F))
check(eq(T | F).to(T))
check(ne(T | F).to(F))

# ``bool_lt`` compares two i64 values and returns a ``Bool``; it is strict (1 < 1 is false).
check(eq(i64(1).bool_lt(2)).to(T))
check(eq(i64(2).bool_lt(1)).to(F))
check(eq(i64(1).bool_lt(1)).to(F))

# ``bool_le`` is the non-strict variant (1 <= 1 is true).
check(eq(i64(1).bool_le(2)).to(T))
check(eq(i64(2).bool_le(1)).to(F))
check(eq(i64(1).bool_le(1)).to(T))

# A unary relation over i64, used as the premise of the rule in the final check.
R = relation("R", i64)
27
+
28
+
29
# A function from i64 to Bool with no body of its own; its value is only ever
# set by the rule in the final ``check``.
@function
def f(i: i64Like) -> Bool: ...
31
+
32
+
33
i = var("i", i64)
# Given ``R(0)``, run (three times) a rule that sets ``f(i)`` to ``T`` for every ``i`` in ``R``,
# then check that ``f(0)`` is ``T``.
check(
    eq(f(0)).to(T),
    ruleset(rule(R(i)).then(set_(f(i)).to(T))) * 3,
    R(i64(0)),
)
@@ -0,0 +1,45 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Basic equality saturation example.
4
+ ==================================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from egglog import *
10
+
11
# NOTE(review): this e-graph is not referenced anywhere else in this example; the standalone
# ``check`` below does all the work. Consider removing it.
egraph = EGraph()
12
+
13
+
14
class Num(Expr):
    """Numeric expressions: integer literals, named variables, addition and multiplication."""

    # An integer literal.
    def __init__(self, value: i64Like) -> None: ...

    # A named variable, e.g. ``Num.var("x")``.
    @classmethod
    def var(cls, name: StringLike) -> Num: ...

    def __add__(self, other: Num) -> Num: ...

    def __mul__(self, other: Num) -> Num: ...
23
+
24
+
25
# Two expressions that are equal by distributivity and constant folding:
# 2 * (x + 3) == 6 + 2 * x
expr1 = Num(2) * (Num.var("x") + Num(3))
expr2 = Num(6) + Num(2) * Num.var("x")

# Pattern variables for the rewrite rules.
a, b, c = vars_("a b c", Num)
i, j = vars_("i j", i64)

check(
    # Check that these expressions are equal
    eq(expr1).to(expr2),
    # After running these rules, up to ten times
    ruleset(
        # commutativity of addition
        rewrite(a + b).to(b + a),
        # multiplication distributes over addition
        rewrite(a * (b + c)).to((a * b) + (a * c)),
        # constant folding of integer literals
        rewrite(Num(i) + Num(j)).to(Num(i + j)),
        rewrite(Num(i) * Num(j)).to(Num(i * j)),
    )
    * 10,
    # On these two initial expressions
    expr1,
    expr2,
)
egglog/examples/fib.py ADDED
@@ -0,0 +1,28 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Fibonacci numbers example
4
+ =========================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from egglog import *
10
+
11
+
12
# ``fib(x)`` has no body; its values come from the ``set_`` actions and the rule in the check below.
@function
def fib(x: i64Like) -> i64: ...
14
+
15
+
16
f0, f1, x = vars_("f0 f1 x", i64)
# Starting from fib(0) = fib(1) = 1, each run of the rule derives fib(x + 2) = fib(x) + fib(x + 1)
# for every pair of consecutive known values. Seven runs are enough to reach fib(7), which is 21
# for a sequence starting 1, 1.
check(
    eq(fib(i64(7))).to(i64(21)),
    ruleset(
        rule(
            eq(f0).to(fib(x)),
            eq(f1).to(fib(x + 1)),
        ).then(set_(fib(x + 2)).to(f0 + f1)),
    )
    * 7,
    set_(fib(0)).to(i64(1)),
    set_(fib(1)).to(i64(1)),
)
@@ -0,0 +1,45 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Higher Order Functions
4
+ ======================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ from egglog import *
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable
15
+
16
+
17
class Math(Expr):
    """Integer expressions supporting addition."""

    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: Math) -> Math: ...
20
+
21
+
22
class MathList(Expr):
    """A list of ``Math`` values, built from the empty list ``MathList()`` by repeated ``append``."""

    # The empty list.
    def __init__(self) -> None: ...
    def append(self, i: Math) -> MathList: ...
    # Applies ``f`` to every element; its meaning is given only by the rewrites in ``math_ruleset``.
    def map(self, f: Callable[[Math], Math]) -> MathList: ...
26
+
27
+
28
@ruleset
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
    # The parameters serve as the pattern variables shared by the rules below.
    # Constant-fold addition of two literals.
    yield rewrite(Math(i) + Math(j)).to(Math(i + j))
    # Push ``map`` through ``append``: map over the prefix, then apply ``f`` to the last element.
    yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
    # Mapping over the empty list yields the empty list.
    yield rewrite(MathList().map(f)).to(MathList())
33
+
34
+
35
# The body defines ``incr_list`` in terms of ``map``; presumably this definition is added as a
# rewrite to ``math_ruleset`` via the ``ruleset=`` argument (verify against egglog docs).
@function(ruleset=math_ruleset)
def incr_list(xs: MathList) -> MathList:
    # Add one to every element of ``xs``; the lambda is passed as a higher-order function value.
    return xs.map(lambda x: x + Math(1))
38
+
39
+
40
egraph = EGraph()
# Increment every element of the list [1, 2].
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
# Run ``math_ruleset`` to saturation.
egraph.run(math_ruleset.saturate())
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))

# Bare expression as the last statement, presumably so the e-graph is shown when the example is rendered.
egraph
@@ -0,0 +1,288 @@
1
+ # mypy: disable-error-code="empty-body"
2
+
3
+ """
4
+ Lambda Calculus
5
+ ===============
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING, ClassVar
11
+
12
+ from egglog import *
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Callable
16
+
17
+
18
class Val(Expr):
    """
    A value: an integer ``Val(n)`` or one of the booleans ``Val.TRUE`` / ``Val.FALSE``.
    """

    # Boolean constants, declared as class variables rather than constructors.
    TRUE: ClassVar[Val]
    FALSE: ClassVar[Val]

    # An integer value.
    def __init__(self, v: i64Like) -> None: ...
27
+
28
+
29
class Var(Expr):
    """A variable name in a lambda term."""

    def __init__(self, v: StringLike) -> None: ...
31
+
32
+
33
class Term(Expr):
    """
    A lambda-calculus term. Binders (``lam``, ``let_``, ``fix``) and ``if_`` are declared as
    separate functions below.
    """

    # Embed a value as a term.
    @classmethod
    def val(cls, v: Val) -> Term: ...

    # Reference a variable.
    @classmethod
    def var(cls, v: Var) -> Term: ...

    def __add__(self, other: Term) -> Term: ...

    # Overridden so that ``t1 == t2`` builds an equality *term* rather than a Python bool.
    def __eq__(self, other: Term) -> Term:  # type: ignore[override]
        ...

    # Function application: ``t1(t2)``.
    def __call__(self, other: Term) -> Term: ...

    # The value this term evaluates to; populated by the "eval" rules via ``union``.
    def eval(self) -> Val: ...

    # A variable derived from a term; used as a fresh name by the capture-avoiding let-lam-diff rule.
    def v(self) -> Var: ...
50
+
51
+
52
# Lambda abstraction: ``lam(x, t)`` binds ``x`` in the body ``t``.
@function
def lam(x: Var, t: Term) -> Term: ...


# ``let_(x, t, b)`` binds ``x`` to the value ``t`` within the body ``b``.
@function
def let_(x: Var, t: Term, b: Term) -> Term: ...


# Recursive binding: ``fix(x, t)`` unfolds to ``let_(x, fix(x, t), t)`` (see the "Fix" rewrite).
@function
def fix(x: Var, t: Term) -> Term: ...


# Conditional: reduces to ``t`` or ``f`` once ``c`` is a boolean value (see the "if" rewrites).
@function
def if_(c: Term, t: Term, f: Term) -> Term: ...
66
+
67
+
68
# A set of variables, used for free-variable sets (despite the name, the elements are ``Var``).
StringSet = Set[Var]


# Free variables of a term. If two different sets are set for the same term, they are
# merged by intersection (``old & new``).
@function(merge=lambda old, new: old & new)
def freer(t: Term) -> StringSet: ...
73
+
74
+
75
# Pattern variables shared by the rules of the lambda ruleset.
(v, v1, v2) = vars_("v v1 v2", Val)
# NOTE(review): ``t4`` is not used by any rule in this file.
(t, t1, t2, t3, t4) = vars_("t t1 t2 t3 t4", Term)
(x, y) = vars_("x y", Var)
fv, fv1, fv2, fv3 = vars_("fv fv1 fv2 fv3", StringSet)
i1, i2 = vars_("i1 i2", i64)
80
# All rules for the lambda calculus: free-variable analysis, evaluation, and the
# substitution/simplification rewrites.
lambda_ruleset = ruleset(
    # freer: ``freer(t)`` is the set of variables occurring free in ``t``.
    rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
    rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x))),
    rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(lam(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    # ``let_(x, t1, t2)`` binds ``x`` only in the body ``t2`` (``t1`` is the bound value), so ``x``
    # is removed from the body's free variables, not from the value's.
    rule(eq(t).to(let_(x, t1, t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(
        set_(freer(t)).to(fv1 | fv2.remove(x))
    ),
    rule(eq(t).to(fix(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    rule(eq(t).to(if_(t1, t2, t3)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2), eq(freer(t3)).to(fv3)).then(
        set_(freer(t)).to(fv1 | fv2 | fv3)
    ),
    # eval
    rule(eq(t).to(Term.val(v))).then(union(t.eval()).with_(v)),
    rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
        union(t.eval()).with_(Val(i1 + i2))
    ),
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), ne(v1).to(v2)).then(
        union(t.eval()).with_(Val.FALSE)
    ),
    rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
    # if
    rewrite(if_(Term.val(Val.TRUE), t1, t2)).to(t1),
    rewrite(if_(Term.val(Val.FALSE), t1, t2)).to(t2),
    # if-elim
    # Add both let terms to the e-graph so the next rewrite can compare them.
    rule(eq(t).to(if_(Term.var(x) == t1, t2, t3))).then(let_(x, t1, t2), let_(x, t1, t3)),
    rewrite(if_(Term.var(x) == t1, t2, t3)).to(
        t3,
        eq(let_(x, t1, t2)).to(let_(x, t1, t3)),
    ),
    # add-comm
    rewrite(t1 + t2).to(t2 + t1),
    # add-assoc
    rewrite((t1 + t2) + t3).to(t1 + (t2 + t3)),
    # eq-comm
    rewrite(t1 == t2).to(t2 == t1),
    # Fix
    rewrite(fix(x, t)).to(let_(x, fix(x, t), t)),
    # beta reduction
    rewrite(lam(x, t)(t1)).to(let_(x, t1, t)),
    # let-app
    rewrite(let_(x, t, t1(t2))).to(let_(x, t, t1)(let_(x, t, t2))),
    # let-add
    rewrite(let_(x, t, t1 + t2)).to(let_(x, t, t1) + let_(x, t, t2)),
    # let-eq
    rewrite(let_(x, t, t1 == t2)).to(let_(x, t, t1) == let_(x, t, t2)),
    # let-const
    rewrite(let_(x, t, Term.val(v))).to(Term.val(v)),
    # let-if
    rewrite(let_(x, t, if_(t1, t2, t3))).to(if_(let_(x, t, t1), let_(x, t, t2), let_(x, t, t3))),
    # let-var-same
    rewrite(let_(x, t, Term.var(x))).to(t),
    # let-var-diff
    rewrite(let_(x, t, Term.var(y))).to(Term.var(y), ne(x).to(y)),
    # let-lam-same
    rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
    # let-lam-diff: push the let under the lambda when ``y`` is not free in the bound value...
    rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), ne(x).to(y), eq(fv).to(freer(t)), fv.not_contains(y)),
    # ...otherwise rename the lambda's variable to a fresh one first to avoid capturing ``y``.
    rule(eq(t).to(let_(x, t1, lam(y, t2))), ne(x).to(y), eq(fv).to(freer(t1)), fv.contains(y)).then(
        union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
    ),
)
# Backwards-compatible alias for the original (misspelled) name, so existing references keep working.
lamdba_ruleset = lambda_ruleset
147
+
148
# Nullary flag relation; set by the rule registered in the lambda_compose test.
result = relation("result")
149
+
150
+
151
def l(fn: Callable[[Term], Term]) -> Term:  # noqa: E743
    """
    Create a lambda term from a Python function.

    The name of ``fn``'s first parameter becomes the bound variable, so
    ``l(lambda x: body)`` builds ``lam(Var("x"), fn(Term.var(Var("x"))))``.

    :param fn: A callable taking one ``Term`` and returning a ``Term``.
    :raises TypeError: If ``fn`` accepts no parameters.
    """
    import inspect

    # ``inspect.signature`` (unlike ``__code__.co_varnames``) skips ``self`` on bound methods and
    # also works for callables without ``__code__``, such as ``functools.partial`` objects.
    name = next(iter(inspect.signature(fn).parameters), None)
    if name is None:
        msg = f"{fn!r} must take at least one parameter to be turned into a lambda term"
        raise TypeError(msg)
    return lam(Var(name), fn(Term.var(Var(name))))
158
+
159
+
160
def assert_simplifies(left: Expr, right: Expr, iterations: int = 30) -> None:
    """
    Print ``left ➡ right``, then check that ``left`` becomes equal to ``right``.

    ``check`` is called with ``left`` as the given expression and ``lamdba_ruleset``
    repeated ``iterations`` times as the schedule.

    :param left: The expression to simplify.
    :param right: The expected simplified form.
    :param iterations: How many times to repeat the ruleset (defaults to 30).
    """
    print(f"{left} ➡ {right}")
    check(eq(left).to(right), lamdba_ruleset * iterations, left)
166
+
167
+
168
# Evaluating a constant, and constant folding of addition.
assert_simplifies((Term.val(Val(1))).eval(), Val(1))
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))


# lambda under: reduction also happens under a binder (4 + (λy. y)(4) becomes 8).
assert_simplifies(
    l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
    l(lambda x: Term.val(Val(8))),
)
# lambda if elim: when a == b the two branches are equal, so the conditional equals ``a + b``.
a = Term.var(Var("a"))
b = Term.var(Var("b"))
assert_simplifies(if_(a == b, a + a, a + b), a + b)

# lambda let simple
# From here on ``x`` and ``y`` are concrete variables, shadowing the pattern variables of the
# same names that were used to build the ruleset.
x = Var("x")
y = Var("y")
assert_simplifies(
    let_(x, Term.val(Val(0)), let_(y, Term.val(Val(1)), Term.var(x) + Term.var(y))),
    Term.val(Val(1)),
)
# lambda capture: the inner lambda rebinds ``x``, so the outer let does not affect it.
assert_simplifies(
    let_(x, Term.val(Val(1)), l(lambda x: x)),
    l(lambda x: x),
)
194
# lambda capture free: ``y`` is bound to ``x + x`` and substituted under ``λx``. The free ``x`` must
# not be captured, so the result must not become ``λx. x + x``. (The lambda's own parameter is
# unused: its body refers to the outer ``y``.)
egraph = EGraph()
e5 = egraph.let("e5", let_(y, Term.var(x) + Term.var(x), l(lambda x: Term.var(y))))
egraph.run(lamdba_ruleset * 10)
egraph.check(freer(l(lambda x: Term.var(y))).contains(y))
egraph.check_fail(eq(e5).to(l(lambda x: x + x)))

# lambda_closure_not_seven: ``add-five`` refers to the first binding of ``five`` (5), so rebinding
# ``five`` to 6 afterwards must not change the result: 1 + 5 == 6, not 1 + 6 == 7.
egraph = EGraph()
e6 = egraph.let(
    "e6",
    let_(
        Var("five"),
        Term.val(Val(5)),
        let_(
            Var("add-five"),
            l(lambda x: x + Term.var(Var("five"))),
            let_(Var("five"), Term.val(Val(6)), Term.var(Var("add-five"))(Term.val(Val(1)))),
        ),
    ),
)
egraph.run(lamdba_ruleset * 10)
egraph.check_fail(eq(e6).to(Term.val(Val(7))))
egraph.check(eq(e6).to(Term.val(Val(6))))
218
+
219
+
220
# lambda_compose: compose(add1)(add1) should behave like λx. x + 2.
egraph = EGraph()
compose = Var("compose")
add1 = Var("add1")
e7 = egraph.let(
    "e7",
    let_(
        compose,
        l(
            lambda f: l(
                lambda g: l(
                    lambda x: f(g(x)),
                ),
            ),
        ),
        let_(
            add1,
            l(lambda y: y + Term.val(Val(1))),
            Term.var(compose)(Term.var(add1))(Term.var(add1)),
        ),
    ),
)
egraph.run(lamdba_ruleset * 20)
# Set the ``result`` flag once both terms below are present in the e-graph.
# NOTE(review): ``t1`` and ``t2`` are independent pattern variables and neither is tied to ``e7``,
# so this only checks that both terms exist somewhere, not that ``e7`` equals them. Confirm that
# this weaker check is intended.
egraph.register(
    rule(
        eq(t1).to(l(lambda x: Term.val(Val(1)) + l(lambda y: Term.val(Val(1)) + y)(x))),
        eq(t2).to(l(lambda x: x + Term.val(Val(2)))),
    ).then(result())
)
egraph.run(1)
egraph.check(result())
251
+
252
+
253
# lambda_if_simple: ``1 == 1`` evaluates to true, so the then-branch is taken.
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))


# lambda_compose_many: six nested ``compose(add1)`` applications around a final ``add1``
# apply ``add1`` seven times, giving λx. x + 7.
assert_simplifies(
    let_(
        compose,
        l(lambda f: l(lambda g: l(lambda x: f(g(x))))),
        let_(
            add1,
            l(lambda y: y + Term.val(Val(1))),
            Term.var(compose)(Term.var(add1))(
                Term.var(compose)(Term.var(add1))(
                    Term.var(compose)(Term.var(add1))(
                        Term.var(compose)(Term.var(add1))(
                            Term.var(compose)(Term.var(add1))(Term.var(compose)(Term.var(add1))(Term.var(add1)))
                        )
                    )
                )
            ),
        ),
    ),
    l(lambda x: x + Term.val(Val(7))),
)

# lambda_if: zeroone(0) + zeroone(10) == 0 + 1 == 1.
zeroone = Var("zeroone")
assert_simplifies(
    let_(
        zeroone,
        l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
        Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
    ),
    Term.val(Val(1)),
)
@@ -0,0 +1,175 @@
1
+ """
2
+ Matrix multiplication and Kronecker product.
3
+ ============================================
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from egglog import *
9
+
10
# The single e-graph into which all rewrites and facts of this example are registered.
egraph = EGraph()
11
+
12
+
13
class Dim(Expr):
    """
    A dimension of a matrix.

    >>> Dim(3) * Dim.named("n")
    Dim(3) * Dim.named("n")
    """

    # An integer-literal dimension (egglog function ``Lit``).
    @method(egg_fn="Lit")
    def __init__(self, value: i64Like) -> None: ...

    # A symbolic dimension identified by name (egglog function ``NamedDim``).
    @method(egg_fn="NamedDim")
    @classmethod
    def named(cls, name: StringLike) -> Dim:  # type: ignore[empty-body]
        ...

    # Product of two dimensions (egglog function ``Times``).
    @method(egg_fn="Times")
    def __mul__(self, other: Dim) -> Dim:  # type: ignore[empty-body]
        ...
32
+
33
+
34
a, b, c, n = vars_("a b c n", Dim)
i, j = vars_("i j", i64)
# Dimension products: associativity (both directions), folding of two literals into one,
# and commutativity.
egraph.register(
    rewrite(a * (b * c)).to((a * b) * c),
    rewrite((a * b) * c).to(a * (b * c)),
    rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),
    rewrite(a * b).to(b * a),
)
42
+
43
+
44
class Matrix(Expr, egg_sort="MExpr"):
    """A symbolic matrix expression (egglog sort ``MExpr``)."""

    @method(egg_fn="Id")
    @classmethod
    def identity(cls, dim: Dim) -> Matrix:  # type: ignore[empty-body]
        """
        Create an identity matrix of the given dimension.
        """

    @method(egg_fn="NamedMat")
    @classmethod
    def named(cls, name: StringLike) -> Matrix:  # type: ignore[empty-body]
        """
        Create a named matrix.
        """

    @method(egg_fn="MMul")
    def __matmul__(self, other: Matrix) -> Matrix:  # type: ignore[empty-body]
        """
        Matrix multiplication.
        """

    @method(egg_fn="nrows")
    def nrows(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of rows in the matrix.
        """

    @method(egg_fn="ncols")
    def ncols(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of columns in the matrix.
        """
76
+
77
+
78
@function(egg_fn="Kron")
def kron(a: Matrix, b: Matrix) -> Matrix:  # type: ignore[empty-body]
    """
    Kronecker product of two matrices.

    If ``a`` is m x n and ``b`` is p x q, the result is (m*p) x (n*q); the dimension
    rewrites registered below encode this via ``nrows``/``ncols``.

    https://en.wikipedia.org/wiki/Kronecker_product#Definition
    """
85
+
86
+
87
A, B, C, D = vars_("A B C D", Matrix)
# NOTE: ``n`` below is still the ``Dim`` pattern variable declared together with ``a, b, c``.
egraph.register(
    # The dimensions of a kronecker product are the product of the dimensions
    rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),
    rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),
    # The dimensions of a matrix multiplication are the number of rows of the first
    # matrix and the number of columns of the second matrix.
    rewrite((A @ B).nrows()).to(A.nrows()),
    rewrite((A @ B).ncols()).to(B.ncols()),
    # The dimensions of an identity matrix are the input dimension
    rewrite(Matrix.identity(n).nrows()).to(n),
    rewrite(Matrix.identity(n).ncols()).to(n),
)
egraph.register(
    # Multiplication by an identity matrix is the same as the other matrix
    rewrite(Matrix.identity(n) @ A).to(A),
    rewrite(A @ Matrix.identity(n)).to(A),
    # Matrix multiplication is associative
    rewrite(A @ (B @ C)).to((A @ B) @ C),
    rewrite((A @ B) @ C).to(A @ (B @ C)),
    # Kronecker product is associative
    rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),
    rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),
    # Mixed-product property: kron(A, B) @ kron(C, D) == kron(A @ C, B @ D).
    # Splitting is unconditional; combining requires the inner dimensions to match.
    rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),
    rewrite(kron(A, B) @ kron(C, D)).to(
        kron(A @ C, B @ D),
        # Only when the dimensions match
        eq(A.ncols()).to(C.nrows()),
        eq(B.ncols()).to(D.nrows()),
    ),
)
egraph.register(
    # demand rows and columns when we multiply matrices
    # (adds these terms so the dimension rewrites and the conditional rewrite above can use them)
    rule(eq(C).to(A @ B)).then(
        A.ncols(),
        A.nrows(),
        B.ncols(),
        B.nrows(),
    ),
    # demand rows and columns when we take the kronecker product
    rule(eq(C).to(kron(A, B))).then(
        A.ncols(),
        A.nrows(),
        B.ncols(),
        B.nrows(),
    ),
)
135
+
136
+
137
# Define three named dimensions. ``n`` (and ``A``, ``B``, ``C`` below) now rebind names that were
# pattern variables above; the rules that used those variables are already registered.
n = egraph.let("n", Dim.named("n"))
m = egraph.let("m", Dim.named("m"))
p = egraph.let("p", Dim.named("p"))

# Define a number of matrices
A = egraph.let("A", Matrix.named("A"))
B = egraph.let("B", Matrix.named("B"))
C = egraph.let("C", Matrix.named("C"))

# Set each to be a square matrix of the given dimension
egraph.register(
    union(A.nrows()).with_(n),
    union(A.ncols()).with_(n),
    union(B.nrows()).with_(m),
    union(B.ncols()).with_(m),
    union(C.nrows()).with_(p),
    union(C.ncols()).with_(p),
)
# Create an example which should equal the kronecker product of A and B:
# kron(I_n, B) @ kron(A, I_m) == kron(I_n @ A, B @ I_m) == kron(A, B), which is allowed
# because ncols(I_n) == nrows(A) == n and ncols(B) == nrows(I_m) == m.
ex1 = egraph.let("ex1", kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m)))
# Presumably named so that ex1's dimension terms are present in the e-graph.
rows = egraph.let("rows", ex1.nrows())
cols = egraph.let("cols", ex1.ncols())

egraph.run(20)

egraph.check(eq(B.nrows()).to(m))
egraph.check(eq(kron(Matrix.identity(n), B).nrows()).to(n * m))

# Verify it matches the expected result
simple_ex1 = egraph.let("simple_ex1", kron(A, B))
egraph.check(eq(ex1).to(simple_ex1))

# Same shape, but ncols(I_p) == p while nrows(A) == n, so the conditional mixed-product
# rewrite cannot fire.
ex2 = egraph.let("ex2", kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m)))

egraph.run(10)
# Verify it is not simplified
egraph.check_fail(eq(ex2).to(kron(A, C)))
# Bare expression as the last statement, presumably so the e-graph is shown when the example is rendered.
egraph
@@ -0,0 +1,61 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Multiset example based off of egglog version
4
+ ============================================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections import Counter
10
+
11
+ from egglog import *
12
+
13
+
14
class Math(Expr):
    """Integer-literal expressions, used as the multiset element type."""

    def __init__(self, x: i64Like) -> None: ...
16
+
17
+
18
# Symbolic square of a ``Math``; folded to a literal by ``math_ruleset``.
@function
def square(x: Math) -> Math: ...
20
+
21
+
22
@ruleset
def math_ruleset(i: i64):
    # Constant-fold ``square`` applied to an integer literal.
    yield rewrite(square(Math(i))).to(Math(i * i))
25
+
26
+
27
egraph = EGraph()

xs = MultiSet(Math(1), Math(2), Math(3))
egraph.register(xs)

# The asserts below use e-graph expressions as plain Python values (``==``, ``in``, ``len``,
# ``Counter``); presumably this relies on ``egraph`` being made current here.
with egraph.set_current():
    # Order of construction does not matter, but multiplicity does.
    assert xs == MultiSet(Math(1), Math(3), Math(2))
    assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))

    assert Counter(xs) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})

    inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
    egraph.register(inserted)
    assert xs.insert(Math(4)) == inserted

    assert xs.contains(Math(1))
    assert xs.not_contains(Math(4))
    assert Math(1) in xs
    assert Math(4) not in xs

    assert xs.remove(Math(1)) == MultiSet(Math(2), Math(3))

    assert xs.length() == i64(3)
    assert len(xs) == 3

    # ``length`` counts duplicate elements.
    assert MultiSet(Math(1), Math(1)).length() == i64(2)

    assert MultiSet(Math(1)).pick() == Math(1)

    # ``map`` wraps each element in ``square``; running ``math_ruleset`` folds them to literals.
    mapped = xs.map(square)
    egraph.register(mapped)
    egraph.run(math_ruleset)
    assert mapped == MultiSet(Math(1), Math(4), Math(9))

    # ``+`` combines two multisets, keeping every occurrence.
    assert xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))