egglog 9.0.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +10 -0
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +667 -0
- egglog/builtins.py +1045 -0
- egglog/config.py +8 -0
- egglog/conversion.py +262 -0
- egglog/declarations.py +818 -0
- egglog/egraph.py +1909 -0
- egglog/egraph_state.py +634 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +31 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +45 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +45 -0
- egglog/examples/lambda_.py +288 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +61 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/array_api.py +1943 -0
- egglog/exp/array_api_jit.py +44 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +424 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +510 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +633 -0
- egglog/thunk.py +95 -0
- egglog/type_constraint_solver.py +113 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35777 -0
- egglog/visualizer_widget.py +39 -0
- egglog-9.0.0.dist-info/METADATA +74 -0
- egglog-9.0.0.dist-info/RECORD +44 -0
- egglog-9.0.0.dist-info/WHEEL +4 -0
- egglog-9.0.0.dist-info/licenses/LICENSE +21 -0
egglog/examples/bool.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Boolean data type example and test
|
|
4
|
+
==================================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from egglog import *
|
|
10
|
+
|
|
11
|
+
# Shorthands for the two Bool literals used throughout this example.
T = Bool(True)
F = Bool(False)
# `&` and `|` on Bool values: conjunction and disjunction.
check(eq(T & T).to(T))
check(eq(T & F).to(F))
check(eq(T | F).to(T))
# `ne(...)` is the negative form of `eq(...)`: T | F must NOT equal F.
check(ne(T | F).to(F))

# `bool_lt` compares two i64 values and yields a Bool (strict less-than).
check(eq(i64(1).bool_lt(2)).to(T))
check(eq(i64(2).bool_lt(1)).to(F))
check(eq(i64(1).bool_lt(1)).to(F))

# `bool_le` is the non-strict variant: 1 <= 1 is true.
check(eq(i64(1).bool_le(2)).to(T))
check(eq(i64(2).bool_le(1)).to(F))
check(eq(i64(1).bool_le(1)).to(T))

# A relation over i64 values; its facts drive the rule in the final check.
R = relation("R", i64)


# A Bool-valued function of an i64 with no body of its own; in this example
# its values come only from the `set_(f(i)).to(T)` rule action.
@function
def f(i: i64Like) -> Bool: ...


# Pattern variable ranging over i64 values.
i = var("i", i64)
check(
    # Goal: f(0) ends up equal to T ...
    eq(f(0)).to(T),
    # ... after running this rule 3 times: for every i with R(i), set f(i) to T ...
    ruleset(rule(R(i)).then(set_(f(i)).to(T))) * 3,
    # ... starting from the single fact R(0).
    R(i64(0)),
)
|
|
egglog/examples/eqsat_basic.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Basic equality saturation example.
|
|
4
|
+
==================================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from egglog import *
|
|
10
|
+
|
|
11
|
+
# NOTE(review): this e-graph is never used below; the example goes through the
# top-level `check(...)` instead. Left in place so the module's names don't change.
egraph = EGraph()


class Num(Expr):
    """
    Arithmetic expression built from i64 literals, named variables, ``+`` and ``*``.
    """

    # Integer literal, e.g. ``Num(2)``.
    def __init__(self, value: i64Like) -> None: ...

    # Named variable, e.g. ``Num.var("x")``.
    @classmethod
    def var(cls, name: StringLike) -> Num: ...

    # ``a + b``
    def __add__(self, other: Num) -> Num: ...

    # ``a * b``
    def __mul__(self, other: Num) -> Num: ...


# 2 * (x + 3) and 6 + 2 * x: equal after distributing, folding 2 * 3 and commuting.
expr1 = Num(2) * (Num.var("x") + Num(3))
expr2 = Num(6) + Num(2) * Num.var("x")

# Pattern variables for the rewrite rules below.
a, b, c = vars_("a b c", Num)
i, j = vars_("i j", i64)

check(
    # Check that these expressions are equal
    eq(expr1).to(expr2),
    # After running these rules, up to ten times
    ruleset(
        # addition is commutative
        rewrite(a + b).to(b + a),
        # multiplication distributes over addition
        rewrite(a * (b + c)).to((a * b) + (a * c)),
        # constant folding of literal sums and products
        rewrite(Num(i) + Num(j)).to(Num(i + j)),
        rewrite(Num(i) * Num(j)).to(Num(i * j)),
    )
    * 10,
    # On these two initial expressions
    expr1,
    expr2,
)
|
egglog/examples/fib.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Fibonacci numbers example
|
|
4
|
+
=========================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from egglog import *
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# fib(x) as an i64-valued function with no body; its values are supplied by the
# `set_` actions in the check below (two seeds plus the recurrence rule).
@function
def fib(x: i64Like) -> i64: ...


# Pattern variables: two consecutive Fibonacci values and their index.
f0, f1, x = vars_("f0 f1 x", i64)
check(
    # With fib(0) = fib(1) = 1 the sequence is 1, 1, 2, 3, 5, 8, 13, 21, so fib(7) = 21.
    eq(fib(i64(7))).to(i64(21)),
    # Whenever fib(x) and fib(x + 1) are both known, set fib(x + 2) to their sum.
    # The rule is run 7 times.
    ruleset(
        rule(
            eq(f0).to(fib(x)),
            eq(f1).to(fib(x + 1)),
        ).then(set_(fib(x + 2)).to(f0 + f1)),
    )
    * 7,
    # Seed values.
    set_(fib(0)).to(i64(1)),
    set_(fib(1)).to(i64(1)),
)
|
|
egglog/examples/higher_order_functions.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Higher Order Functions
|
|
4
|
+
======================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
from egglog import *
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Math(Expr):
    """
    Integer expression: i64 literals combined with ``+``.
    """

    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: Math) -> Math: ...


class MathList(Expr):
    """
    List of Math values, built from the empty list by repeated ``append``.
    """

    # The empty list.
    def __init__(self) -> None: ...
    # Add one element; ``MathList().append(a).append(b)`` holds a then b.
    def append(self, i: Math) -> MathList: ...
    # Apply the Math -> Math function ``f`` to every element (see math_ruleset).
    def map(self, f: Callable[[Math], Math]) -> MathList: ...


# Rewrite rules for Math / MathList; the parameters act as pattern variables in
# the yielded rules (including ``f``, a variable over Math -> Math functions).
@ruleset
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
    # constant-fold addition of two literals
    yield rewrite(Math(i) + Math(j)).to(Math(i + j))
    # mapping over a list with one more element: map the rest, then append f(x)
    yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
    # mapping over the empty list gives the empty list
    yield rewrite(MathList().map(f)).to(MathList())


# Add 1 to every element of ``xs`` by mapping a Python lambda over the list.
# Unlike the ``...`` declarations, this function has a body; with
# ``ruleset=math_ruleset`` it is presumably turned into a rewrite in that ruleset
# (the check below relies on math_ruleset unfolding incr_list) -- verify.
@function(ruleset=math_ruleset)
def incr_list(xs: MathList) -> MathList:
    return xs.map(lambda x: x + Math(1))


egraph = EGraph()
# y = incr_list of the list [1, 2]
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
# Run math_ruleset to saturation.
egraph.run(math_ruleset.saturate())
# Expect the list [2, 3].
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))

# Bare expression: no effect when run as a plain script; presumably here so that
# notebook / gallery renderers display the e-graph.
egraph
|
|
egglog/examples/lambda_.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Lambda Calculus
|
|
5
|
+
===============
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
11
|
+
|
|
12
|
+
from egglog import *
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Val(Expr):
    """
    A value is a number or a boolean.
    """

    # The two boolean values; `==` terms evaluate to one of these (see the eval rules).
    TRUE: ClassVar[Val]
    FALSE: ClassVar[Val]

    # An integer value.
    def __init__(self, v: i64Like) -> None: ...


class Var(Expr):
    """
    A variable name in a lambda-calculus term, e.g. ``Var("x")``.
    """

    def __init__(self, v: StringLike) -> None: ...


class Term(Expr):
    """
    A lambda-calculus term.

    Besides the constructors declared here, terms are also built with the
    module-level functions ``lam``, ``let_``, ``fix`` and ``if_``.
    """

    # A constant value.
    @classmethod
    def val(cls, v: Val) -> Term: ...

    # A reference to a variable.
    @classmethod
    def var(cls, v: Var) -> Term: ...

    # Addition of two terms.
    def __add__(self, other: Term) -> Term: ...

    # Equality *term*: ``t1 == t2`` builds a Term (which the eval rules reduce to
    # Val.TRUE or Val.FALSE), not a Python bool -- hence the override.
    def __eq__(self, other: Term) -> Term:  # type: ignore[override]
        ...

    # Function application: ``f(arg)``.
    def __call__(self, other: Term) -> Term: ...

    # The value this term evaluates to; filled in by the "eval" rules.
    def eval(self) -> Val: ...

    # A variable derived from this term; the capture-avoiding let/lam rule uses
    # ``t.v()`` as the new name for a renamed binder.
    def v(self) -> Var: ...


# Lambda abstraction: binds ``x`` in the body ``t``.
@function
def lam(x: Var, t: Term) -> Term: ...


# ``let x = t in b``. The rewrite rules push ``let_`` inward through the term
# structure, so it also serves as substitution (beta reduction rewrites
# ``lam(x, t)(t1)`` into ``let_(x, t1, t)``).
@function
def let_(x: Var, t: Term, b: Term) -> Term: ...


# Fixed point for recursion: the "Fix" rule unfolds ``fix(x, t)`` into
# ``let_(x, fix(x, t), t)``, i.e. ``x`` stands for the whole expression inside ``t``.
@function
def fix(x: Var, t: Term) -> Term: ...


# Conditional: rewritten to ``t`` when ``c`` is the constant Val.TRUE and to
# ``f`` when ``c`` is Val.FALSE (see the "if" rules).
@function
def if_(c: Term, t: Term, f: Term) -> Term: ...


# Set type used for free-variable sets; despite the name its elements are Var.
StringSet = Set[Var]


# Free variables of a term, computed by the "freer" rules. The ``merge``
# function combines two conflicting values for the same term by intersection (&).
@function(merge=lambda old, new: old & new)
def freer(t: Term) -> StringSet: ...


# Pattern variables shared by the rules below (t4 is declared but unused).
(v, v1, v2) = vars_("v v1 v2", Val)
(t, t1, t2, t3, t4) = vars_("t t1 t2 t3 t4", Term)
(x, y) = vars_("x y", Var)
fv, fv1, fv2, fv3 = vars_("fv fv1 fv2 fv3", StringSet)
i1, i2 = vars_("i1 i2", i64)
# All rules of the lambda-calculus example.
# NOTE: "lamdba" is a misspelling of "lambda"; the name is kept because the rest
# of this example (assert_simplifies and the tests) refers to it.
lamdba_ruleset = ruleset(
    # freer: free variables of each kind of term; binders remove their variable
    rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
    rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x))),
    rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
    rule(eq(t).to(lam(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    rule(eq(t).to(let_(x, t1, t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(
        set_(freer(t)).to(fv1.remove(x) | fv2)
    ),
    rule(eq(t).to(fix(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
    rule(eq(t).to(if_(t1, t2, t3)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2), eq(freer(t3)).to(fv3)).then(
        set_(freer(t)).to(fv1 | fv2 | fv3)
    ),
    # eval: constants evaluate to themselves, `+` adds the operands' values, and
    # `==` evaluates to TRUE when both sides have the same value and to FALSE when
    # their values differ
    rule(eq(t).to(Term.val(v))).then(union(t.eval()).with_(v)),
    rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
        union(t.eval()).with_(Val(i1 + i2))
    ),
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
    rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), ne(v1).to(v2)).then(
        union(t.eval()).with_(Val.FALSE)
    ),
    # once a term's value is known, the term equals that constant
    rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
    # if
    rewrite(if_(Term.val(Val.TRUE), t1, t2)).to(t1),
    rewrite(if_(Term.val(Val.FALSE), t1, t2)).to(t2),
    # if-elim
    # Create both let_ terms so the next rule can match on them
    rule(eq(t).to(if_(Term.var(x) == t1, t2, t3))).then(let_(x, t1, t2), let_(x, t1, t3)),
    # if substituting t1 for x makes both branches equal, the `if` equals its else branch
    rewrite(if_(Term.var(x) == t1, t2, t3)).to(
        t3,
        eq(let_(x, t1, t2)).to(let_(x, t1, t3)),
    ),
    # add-comm
    rewrite(t1 + t2).to(t2 + t1),
    # add-assoc
    rewrite((t1 + t2) + t3).to(t1 + (t2 + t3)),
    # eq-comm
    rewrite(t1 == t2).to(t2 == t1),
    # Fix
    rewrite(fix(x, t)).to(let_(x, fix(x, t), t)),
    # beta reduction
    rewrite(lam(x, t)(t1)).to(let_(x, t1, t)),
    # let-app
    rewrite(let_(x, t, t1(t2))).to(let_(x, t, t1)(let_(x, t, t2))),
    # let-add
    rewrite(let_(x, t, t1 + t2)).to(let_(x, t, t1) + let_(x, t, t2)),
    # let-eq
    rewrite(let_(x, t, t1 == t2)).to(let_(x, t, t1) == let_(x, t, t2)),
    # let-const
    rewrite(let_(x, t, Term.val(v))).to(Term.val(v)),
    # let-if
    rewrite(let_(x, t, if_(t1, t2, t3))).to(if_(let_(x, t, t1), let_(x, t, t2), let_(x, t, t3))),
    # let-var-same
    rewrite(let_(x, t, Term.var(x))).to(t),
    # let-var-diff
    rewrite(let_(x, t, Term.var(y))).to(Term.var(y), ne(x).to(y)),
    # let-lam-same: the inner binder shadows x, so the substitution stops here
    rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
    # let-lam-diff: push the substitution under a differently named binder only when
    # that name is not free in the substituted term; otherwise (second rule) rename
    # the binder to t.v() first, to avoid capturing it
    rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), ne(x).to(y), eq(fv).to(freer(t)), fv.not_contains(y)),
    rule(eq(t).to(let_(x, t1, lam(y, t2))), ne(x).to(y), eq(fv).to(freer(t1)), fv.contains(y)).then(
        union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
    ),
)

# Nullary relation used as a success flag by the lambda_compose test.
result = relation("result")


def l(fn: Callable[[Term], Term]) -> Term:  # noqa: E743
    """
    Turn a Python function into a ``lam`` term.

    The binder is named after the first entry of ``fn.__code__.co_varnames``
    (for the one-argument lambdas used in this file, the parameter name), and
    the body is ``fn`` applied to ``Term.var`` of that binder.
    """
    binder = Var(fn.__code__.co_varnames[0])
    return lam(binder, fn(Term.var(binder)))


def assert_simplifies(left: Expr, right: Expr) -> None:
    """
    Print ``left ➡ right``, then check that the two terms are equal.

    ``check`` receives ``left`` as the starting term and ``lamdba_ruleset``
    repeated 30 times as the schedule.
    """
    print(f"{left} ➡ {right}")
    goal = eq(left).to(right)
    schedule = lamdba_ruleset * 30
    check(goal, schedule, left)


# Evaluating a constant, and 1 + 2.
assert_simplifies((Term.val(Val(1))).eval(), Val(1))
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))


# lambda under: λx. 4 + (λy. y)(4) simplifies to λx. 8
assert_simplifies(
    l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
    l(lambda x: Term.val(Val(8))),
)
# lambda if elim: if a == b then a + a else a + b  ==>  a + b
a = Term.var(Var("a"))
b = Term.var(Var("b"))
assert_simplifies(if_(a == b, a + a, a + b), a + b)

# lambda let simple: let x = 0 in let y = 1 in x + y  ==>  1
# (from here on `x` and `y` are concrete Vars, no longer pattern variables)
x = Var("x")
y = Var("y")
assert_simplifies(
    let_(x, Term.val(Val(0)), let_(y, Term.val(Val(1)), Term.var(x) + Term.var(y))),
    Term.val(Val(1)),
)
# lambda capture: let x = 1 in λx. x  ==>  λx. x  (the inner binder shadows x)
assert_simplifies(
    let_(x, Term.val(Val(1)), l(lambda x: x)),
    l(lambda x: x),
)
# lambda capture free: let y = x + x in λx. y must not become λx. x + x
# (the free x in the substituted term must not be captured by the binder)
egraph = EGraph()
e5 = egraph.let("e5", let_(y, Term.var(x) + Term.var(x), l(lambda x: Term.var(y))))
egraph.run(lamdba_ruleset * 10)
egraph.check(freer(l(lambda x: Term.var(y))).contains(y))
egraph.check_fail(eq(e5).to(l(lambda x: x + x)))

# lambda_closure_not_seven: add-five closes over five = 5, so rebinding five = 6
# around the call must not change it: the result is 1 + 5 = 6, not 7
egraph = EGraph()
e6 = egraph.let(
    "e6",
    let_(
        Var("five"),
        Term.val(Val(5)),
        let_(
            Var("add-five"),
            l(lambda x: x + Term.var(Var("five"))),
            let_(Var("five"), Term.val(Val(6)), Term.var(Var("add-five"))(Term.val(Val(1)))),
        ),
    ),
)
egraph.run(lamdba_ruleset * 10)
egraph.check_fail(eq(e6).to(Term.val(Val(7))))
egraph.check(eq(e6).to(Term.val(Val(6))))


# lambda_compose: compose add1 add1
egraph = EGraph()
compose = Var("compose")
add1 = Var("add1")
e7 = egraph.let(
    "e7",
    let_(
        compose,
        l(
            lambda f: l(
                lambda g: l(
                    lambda x: f(g(x)),
                ),
            ),
        ),
        let_(
            add1,
            l(lambda y: y + Term.val(Val(1))),
            Term.var(compose)(Term.var(add1))(Term.var(add1)),
        ),
    ),
)
egraph.run(lamdba_ruleset * 20)
# Sets the `result` flag when both lambda terms below are present in the e-graph.
# NOTE(review): t1 and t2 are not required to be equal to each other (or to e7);
# the rule only checks that both terms exist.
egraph.register(
    rule(
        eq(t1).to(l(lambda x: Term.val(Val(1)) + l(lambda y: Term.val(Val(1)) + y)(x))),
        eq(t2).to(l(lambda x: x + Term.val(Val(2)))),
    ).then(result())
)
egraph.run(1)
egraph.check(result())


# lambda_if_simple: if 1 == 1 then 7 else 9  ==>  7
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))


# lambda_compose_many: add1 composed seven times  ==>  λx. x + 7
assert_simplifies(
    let_(
        compose,
        l(lambda f: l(lambda g: l(lambda x: f(g(x))))),
        let_(
            add1,
            l(lambda y: y + Term.val(Val(1))),
            Term.var(compose)(Term.var(add1))(
                Term.var(compose)(Term.var(add1))(
                    Term.var(compose)(Term.var(add1))(
                        Term.var(compose)(Term.var(add1))(
                            Term.var(compose)(Term.var(add1))(Term.var(compose)(Term.var(add1))(Term.var(add1)))
                        )
                    )
                )
            ),
        ),
    ),
    l(lambda x: x + Term.val(Val(7))),
)

# lambda_if: zeroone(0) + zeroone(10) = 0 + 1 = 1
zeroone = Var("zeroone")
assert_simplifies(
    let_(
        zeroone,
        l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
        Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
    ),
    Term.val(Val(1)),
)
|
|
egglog/examples/matrix.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Matrix multiplication and Kronecker product.
|
|
3
|
+
============================================
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from egglog import *
|
|
9
|
+
|
|
10
|
+
# The single e-graph on which every rule and fact in this example is registered.
egraph = EGraph()


class Dim(Expr):
    """
    A dimension of a matrix.

    >>> Dim(3) * Dim.named("n")
    Dim(3) * Dim.named("n")
    """

    # Literal dimension. (``egg_fn`` presumably names the underlying egglog
    # function for each constructor -- "Lit", "NamedDim", "Times".)
    @method(egg_fn="Lit")
    def __init__(self, value: i64Like) -> None: ...

    # Symbolic dimension, e.g. ``Dim.named("n")``.
    @method(egg_fn="NamedDim")
    @classmethod
    def named(cls, name: StringLike) -> Dim:  # type: ignore[empty-body]
        ...

    # Product of two dimensions.
    @method(egg_fn="Times")
    def __mul__(self, other: Dim) -> Dim:  # type: ignore[empty-body]
        ...


# Pattern variables over dimensions and integers.
a, b, c, n = vars_("a b c n", Dim)
i, j = vars_("i j", i64)
# Dimension products: associative (both directions), literal products folded,
# and commutative.
egraph.register(
    rewrite(a * (b * c)).to((a * b) * c),
    rewrite((a * b) * c).to(a * (b * c)),
    rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),
    rewrite(a * b).to(b * a),
)


class Matrix(Expr, egg_sort="MExpr"):
    """
    A symbolic matrix expression whose dimensions are ``Dim`` terms.
    """

    @method(egg_fn="Id")
    @classmethod
    def identity(cls, dim: Dim) -> Matrix:  # type: ignore[empty-body]
        """
        Create an identity matrix of the given dimension.
        """

    @method(egg_fn="NamedMat")
    @classmethod
    def named(cls, name: StringLike) -> Matrix:  # type: ignore[empty-body]
        """
        Create a named matrix.
        """

    @method(egg_fn="MMul")
    def __matmul__(self, other: Matrix) -> Matrix:  # type: ignore[empty-body]
        """
        Matrix multiplication.
        """

    @method(egg_fn="nrows")
    def nrows(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of rows in the matrix.
        """

    @method(egg_fn="ncols")
    def ncols(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of columns in the matrix.
        """


# Declared as a free function rather than a Matrix method, so it is written kron(A, B).
@function(egg_fn="Kron")
def kron(a: Matrix, b: Matrix) -> Matrix:  # type: ignore[empty-body]
    """
    Kronecker product of two matrices.

    https://en.wikipedia.org/wiki/Kronecker_product#Definition
    """


# Pattern variables over matrices.
A, B, C, D = vars_("A B C D", Matrix)
egraph.register(
    # The dimensions of a kronecker product are the product of the dimensions
    rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),
    rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),
    # The dimensions of a matrix multiplication are the number of rows of the first
    # matrix and the number of columns of the second matrix.
    rewrite((A @ B).nrows()).to(A.nrows()),
    rewrite((A @ B).ncols()).to(B.ncols()),
    # The dimensions of an identity matrix are the input dimension
    rewrite(Matrix.identity(n).nrows()).to(n),
    rewrite(Matrix.identity(n).ncols()).to(n),
)
egraph.register(
    # Multiplication by an identity matrix is the same as the other matrix
    rewrite(Matrix.identity(n) @ A).to(A),
    rewrite(A @ Matrix.identity(n)).to(A),
    # Matrix multiplication is associative
    rewrite(A @ (B @ C)).to((A @ B) @ C),
    rewrite((A @ B) @ C).to(A @ (B @ C)),
    # Kronecker product is associative
    rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),
    rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),
    # Kronecker product distributes over matrix multiplication
    rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),
    rewrite(kron(A, B) @ kron(C, D)).to(
        kron(A @ C, B @ D),
        # Only when the dimensions match
        eq(A.ncols()).to(C.nrows()),
        eq(B.ncols()).to(D.nrows()),
    ),
)
egraph.register(
    # demand rows and columns when we multiply matrices, so the nrows/ncols terms
    # exist for the dimension rules above to work on
    rule(eq(C).to(A @ B)).then(
        A.ncols(),
        A.nrows(),
        B.ncols(),
        B.nrows(),
    ),
    # demand rows and columns when we take the kronecker product
    rule(eq(C).to(kron(A, B))).then(
        A.ncols(),
        A.nrows(),
        B.ncols(),
        B.nrows(),
    ),
)


# Define a number of dimensions
# (from here on `n`, `A`, `B`, `C` are concrete terms, no longer pattern variables)
n = egraph.let("n", Dim.named("n"))
m = egraph.let("m", Dim.named("m"))
p = egraph.let("p", Dim.named("p"))

# Define a number of matrices
A = egraph.let("A", Matrix.named("A"))
B = egraph.let("B", Matrix.named("B"))
C = egraph.let("C", Matrix.named("C"))

# Set each to be a square matrix of the given dimension
egraph.register(
    union(A.nrows()).with_(n),
    union(A.ncols()).with_(n),
    union(B.nrows()).with_(m),
    union(B.ncols()).with_(m),
    union(C.nrows()).with_(p),
    union(C.ncols()).with_(p),
)
# Create an example which should equal the kronecker product of A and B:
# (I_n ⊗ B) @ (A ⊗ I_m) = (I_n @ A) ⊗ (B @ I_m) since I_n has n columns = A's n rows
# and B has m columns = I_m's m rows.
ex1 = egraph.let("ex1", kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m)))
# Not read below; presumably registered so ex1's dimension terms exist.
rows = egraph.let("rows", ex1.nrows())
cols = egraph.let("cols", ex1.ncols())

egraph.run(20)

egraph.check(eq(B.nrows()).to(m))
egraph.check(eq(kron(Matrix.identity(n), B).nrows()).to(n * m))

# Verify it matches the expected result
simple_ex1 = egraph.let("simple_ex1", kron(A, B))
egraph.check(eq(ex1).to(simple_ex1))

# Here the dimensions don't line up (C and I_p are p x p, but A is n x n and
# I_m is m x m), so the distribution rule's side conditions should not hold.
ex2 = egraph.let("ex2", kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m)))

egraph.run(10)
# Verify it is not simplified
egraph.check_fail(eq(ex2).to(kron(A, C)))
# Bare expression; presumably so notebook / gallery renderers display the e-graph.
egraph
|
|
egglog/examples/multiset.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Multiset example based off of egglog version
|
|
4
|
+
============================================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections import Counter
|
|
10
|
+
|
|
11
|
+
from egglog import *
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Math(Expr):
    """
    An i64 literal wrapped as an expression, used as the multiset element type.
    """

    def __init__(self, x: i64Like) -> None: ...


# Square of a Math value; math_ruleset folds square(Math(i)) to Math(i * i).
@function
def square(x: Math) -> Math: ...


# Single rule: evaluate ``square`` on a literal. ``i`` acts as a pattern variable.
@ruleset
def math_ruleset(i: i64):
    yield rewrite(square(Math(i))).to(Math(i * i))


egraph = EGraph()

xs = MultiSet(Math(1), Math(2), Math(3))
egraph.register(xs)

# With `egraph` set as the current e-graph, the comparisons and conversions below
# (==, !=, in, len, Counter, truthiness of contains) are presumably answered by
# querying it.
with egraph.set_current():
    # Equality ignores element order but counts multiplicity.
    assert xs == MultiSet(Math(1), Math(3), Math(2))
    assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))

    # Iterating a MultiSet yields its elements.
    assert Counter(xs) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})

    # insert adds one element.
    inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
    egraph.register(inserted)
    assert xs.insert(Math(4)) == inserted

    # Membership, both as methods and via `in` / `not in`.
    assert xs.contains(Math(1))
    assert xs.not_contains(Math(4))
    assert Math(1) in xs
    assert Math(4) not in xs

    assert xs.remove(Math(1)) == MultiSet(Math(2), Math(3))

    # Size, as an i64 term and via len().
    assert xs.length() == i64(3)
    assert len(xs) == 3

    # Duplicates count towards the length.
    assert MultiSet(Math(1), Math(1)).length() == i64(2)

    assert MultiSet(Math(1)).pick() == Math(1)

    # map applies `square` to every element; math_ruleset then evaluates the squares.
    mapped = xs.map(square)
    egraph.register(mapped)
    egraph.run(math_ruleset)
    assert mapped == MultiSet(Math(1), Math(4), Math(9))

    # `+` combines two multisets, keeping every occurrence.
    assert xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))
|