egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
- egglog/bindings.pyi +734 -0
- egglog/builtins.py +1133 -0
- egglog/config.py +8 -0
- egglog/conversion.py +286 -0
- egglog/declarations.py +912 -0
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +1875 -0
- egglog/egraph_state.py +680 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +67 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +425 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +509 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +712 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +113 -0
- egglog/version_compat.py +87 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35777 -0
- egglog/visualizer_widget.py +39 -0
- egglog-11.2.0.dist-info/METADATA +74 -0
- egglog-11.2.0.dist-info/RECORD +46 -0
- egglog-11.2.0.dist-info/WHEEL +4 -0
- egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
BigNum/BigRat Example
=====================
"""

from __future__ import annotations

from egglog import *

# Arbitrary-precision integers: one built from a Python int, one parsed from a string.
x = BigInt(-1234)
y = BigInt.from_string("2")
# The rational -1234 / 2. The assertion below shows its numerator is -617,
# i.e. the fraction is stored in reduced form.
z = BigRat(x, y)

egraph = EGraph()

assert egraph.extract(z.numer.to_string()).value == "-617"


# Uninterpreted function mapping two BigInts to a BigRat; its value is set below.
@function
def bignums(x: BigInt, y: BigInt) -> BigRat: ...


# Store the fact bignums(-1234, 2) = -1234/2 in the e-graph.
egraph.register(set_(bignums(x, y)).to(z))

c = var("c", BigRat)
a, b = vars_("a b", BigInt)
# Pattern query: for the stored bignums(a, b) = c, the reduced numerator equals
# a >> 1 and the reduced denominator equals b >> 1 (with a = -1234 and b = 2
# these are -617 and 1).
egraph.check(
    bignums(a, b) == c,
    c.numer == a >> 1,
    c.denom == b >> 1,
)
|
egglog/examples/bool.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
Boolean data type example and test
==================================
"""

from __future__ import annotations

from egglog import *

T = Bool(True)
F = Bool(False)
# Boolean connectives on literal values.
check(eq(T & T).to(T))
check(eq(T & F).to(F))
check(eq(T | F).to(T))
check(ne(T | F).to(F))

# Strict less-than on i64 values, producing a Bool: equal operands give False.
check(eq(i64(1).bool_lt(2)).to(T))
check(eq(i64(2).bool_lt(1)).to(F))
check(eq(i64(1).bool_lt(1)).to(F))

# Less-than-or-equal: unlike bool_lt, equal operands give True.
check(eq(i64(1).bool_le(2)).to(T))
check(eq(i64(2).bool_le(1)).to(F))
check(eq(i64(1).bool_le(1)).to(T))

R = relation("R", i64)


@function
def f(i: i64Like) -> Bool: ...


i = var("i", i64)
# The rule sets f(i) to True whenever R(i) holds. The fact R(0) and the rule
# (run 3 times) are passed alongside the expectation f(0) == T.
check(
    eq(f(0)).to(T),
    ruleset(rule(R(i)).then(set_(f(i)).to(T))) * 3,
    R(i64(0)),
)
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
Basic equality saturation example.
==================================
"""

from __future__ import annotations

from egglog import *


class Num(Expr):
    def __init__(self, value: i64Like) -> None: ...

    @classmethod
    def var(cls, name: StringLike) -> Num: ...

    def __add__(self, other: Num) -> Num: ...

    def __mul__(self, other: Num) -> Num: ...


# 2 * (x + 3) and 6 + 2 * x. Proving them equal needs distribution,
# constant folding and commutativity of addition.
expr1 = Num(2) * (Num.var("x") + Num(3))
expr2 = Num(6) + Num(2) * Num.var("x")

# Pattern variables used by the rewrite rules below.
a, b, c = vars_("a b c", Num)
i, j = vars_("i j", i64)

# A single e-graph is created once, right before it is used.
egraph = EGraph()
egraph.register(expr1, expr2)

egraph.run(
    ruleset(
        # commutativity of addition
        rewrite(a + b).to(b + a),
        # distribute multiplication over addition
        rewrite(a * (b + c)).to((a * b) + (a * c)),
        # constant folding of literal additions / multiplications
        rewrite(Num(i) + Num(j)).to(Num(i + j)),
        rewrite(Num(i) * Num(j)).to(Num(i * j)),
    )
    * 10
)

egraph.check(expr1 == expr2)
|
egglog/examples/fib.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
Fibonacci numbers example
=========================
"""

from __future__ import annotations

from egglog import *


@function
def fib(x: i64Like) -> i64: ...


f0, f1, x = vars_("f0 f1 x", i64)
# Seeds fib(0) = fib(1) = 1 and expects fib(7) == 21 after running the rule 7 times.
check(
    eq(fib(i64(7))).to(i64(21)),
    ruleset(
        # Whenever fib(x) and fib(x + 1) are both known, define fib(x + 2) as
        # their sum, so each run can extend the known values by one index.
        rule(
            eq(f0).to(fib(x)),
            eq(f1).to(fib(x + 1)),
        ).then(set_(fib(x + 2)).to(f0 + f1)),
    )
    * 7,
    set_(fib(0)).to(i64(1)),
    set_(fib(1)).to(i64(1)),
)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
"""
Higher Order Functions
======================
"""

from __future__ import annotations

from collections.abc import Callable

from egglog import *


class Math(Expr):
    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: Math) -> Math: ...


class MathList(Expr):
    def __init__(self) -> None: ...
    def append(self, i: Math) -> MathList: ...
    # Takes a function-valued argument (Math -> Math).
    def map(self, f: Callable[[Math], Math]) -> MathList: ...


@ruleset
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
    # Constant-fold the sum of two literals.
    yield rewrite(Math(i) + Math(j)).to(Math(i + j))
    # map over a non-empty list: map the prefix, then append f applied to the last element.
    yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
    # map over the empty list is the empty list.
    yield rewrite(MathList().map(f)).to(MathList())


# Function defined by a Python body. NOTE(review): `ruleset=math_ruleset`
# presumably adds the unfolding incr_list(xs) -> xs.map(...) to math_ruleset;
# the check below relies on that unfolding happening.
@function(ruleset=math_ruleset)
def incr_list(xs: MathList) -> MathList:
    return xs.map(lambda x: x + Math(1))


egraph = EGraph()
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
egraph.run(math_ruleset.saturate())
# [1, 2] with every element incremented is [2, 3].
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))

# Bare expression, presumably so a notebook / docs gallery renders the e-graph.
egraph
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"

"""
Join Tree (custom costs)
========================

Example of using custom cost functions for jointree.

From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
"""

from __future__ import annotations

from egglog import *


class JoinTree(Expr):
    def __init__(self, name: StringLike) -> None: ...

    def join(self, other: JoinTree) -> JoinTree: ...

    # When two sizes are recorded for the same tree, keep the smaller one.
    @method(merge=lambda old, new: old.min(new))  # type:ignore[prop-decorator]
    @property
    def size(self) -> i64: ...


ra = JoinTree("a")
rb = JoinTree("b")
rc = JoinTree("c")
rd = JoinTree("d")
re = JoinTree("e")
rf = JoinTree("f")

# Left-deep join of all six relations; the rules below explore other join orders.
query = ra.join(rb).join(rc).join(rd).join(re).join(rf)

egraph = EGraph()
egraph.register(
    set_(ra.size).to(50),
    set_(rb.size).to(200),
    set_(rc.size).to(10),
    set_(rd.size).to(123),
    set_(re.size).to(10000),
    set_(rf.size).to(1),
)


@egraph.register
def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
    # cost of relation is its size minus 1, since the string arg will have a cost of 1 as well
    yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1))
    # cost/size of join is product of sizes
    yield rule(a.join(b), a.size == asize, b.size == bsize).then(
        set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize)
    )
    # commutativity: a ⋈ b == b ⋈ a
    yield rewrite(a.join(b)).to(b.join(a))
    # associativity: (a ⋈ b) ⋈ c == a ⋈ (b ⋈ c)
    yield rewrite(a.join(b).join(c)).to(a.join(b.join(c)))


egraph.register(query)
# Run the registered rules for up to 1000 iterations, then extract the cheapest plan.
egraph.run(1000)
print(egraph.extract(query))
print(egraph.extract(query.size))


egraph
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Lambda Calculus
|
|
5
|
+
===============
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import ClassVar
|
|
12
|
+
|
|
13
|
+
from egglog import *
|
|
14
|
+
from egglog import Expr
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Val(Expr):
|
|
18
|
+
"""
|
|
19
|
+
A value is a number or a boolean.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
TRUE: ClassVar[Val]
|
|
23
|
+
FALSE: ClassVar[Val]
|
|
24
|
+
|
|
25
|
+
def __init__(self, v: i64Like) -> None: ...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Var(Expr):
|
|
29
|
+
def __init__(self, v: StringLike) -> None: ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Term(Expr):
|
|
33
|
+
@classmethod
|
|
34
|
+
def val(cls, v: Val) -> Term: ...
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def var(cls, v: Var) -> Term: ...
|
|
38
|
+
|
|
39
|
+
def __add__(self, other: Term) -> Term: ...
|
|
40
|
+
|
|
41
|
+
def __eq__(self, other: Term) -> Term: # type: ignore[override]
|
|
42
|
+
...
|
|
43
|
+
|
|
44
|
+
def __call__(self, other: Term) -> Term: ...
|
|
45
|
+
|
|
46
|
+
def eval(self) -> Val: ...
|
|
47
|
+
|
|
48
|
+
def v(self) -> Var: ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@function
|
|
52
|
+
def lam(x: Var, t: Term) -> Term: ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@function
|
|
56
|
+
def let_(x: Var, t: Term, b: Term) -> Term: ...
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@function
|
|
60
|
+
def fix(x: Var, t: Term) -> Term: ...
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@function
|
|
64
|
+
def if_(c: Term, t: Term, f: Term) -> Term: ...
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
StringSet = Set[Var]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@function(merge=lambda old, new: old & new)
|
|
71
|
+
def freer(t: Term) -> StringSet: ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
(v, v1, v2) = vars_("v v1 v2", Val)
|
|
75
|
+
(t, t1, t2, t3, t4) = vars_("t t1 t2 t3 t4", Term)
|
|
76
|
+
(x, y) = vars_("x y", Var)
|
|
77
|
+
fv, fv1, fv2, fv3 = vars_("fv fv1 fv2 fv3", StringSet)
|
|
78
|
+
i1, i2 = vars_("i1 i2", i64)
|
|
79
|
+
lamdba_ruleset = ruleset(
|
|
80
|
+
# freer
|
|
81
|
+
rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
|
|
82
|
+
rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x))),
|
|
83
|
+
rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
|
|
84
|
+
rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
|
|
85
|
+
rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
|
|
86
|
+
rule(eq(t).to(lam(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
|
|
87
|
+
rule(eq(t).to(let_(x, t1, t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(
|
|
88
|
+
set_(freer(t)).to(fv1.remove(x) | fv2)
|
|
89
|
+
),
|
|
90
|
+
rule(eq(t).to(fix(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
|
|
91
|
+
rule(eq(t).to(if_(t1, t2, t3)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2), eq(freer(t3)).to(fv3)).then(
|
|
92
|
+
set_(freer(t)).to(fv1 | fv2 | fv3)
|
|
93
|
+
),
|
|
94
|
+
# eval
|
|
95
|
+
rule(eq(t).to(Term.val(v))).then(union(t.eval()).with_(v)),
|
|
96
|
+
rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
|
|
97
|
+
union(t.eval()).with_(Val(i1 + i2))
|
|
98
|
+
),
|
|
99
|
+
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
|
|
100
|
+
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), ne(v1).to(v2)).then(
|
|
101
|
+
union(t.eval()).with_(Val.FALSE)
|
|
102
|
+
),
|
|
103
|
+
rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
|
|
104
|
+
# if
|
|
105
|
+
rewrite(if_(Term.val(Val.TRUE), t1, t2)).to(t1),
|
|
106
|
+
rewrite(if_(Term.val(Val.FALSE), t1, t2)).to(t2),
|
|
107
|
+
# if-elim
|
|
108
|
+
# Adds let rules so next one can match on them
|
|
109
|
+
rule(eq(t).to(if_(Term.var(x) == t1, t2, t3))).then(let_(x, t1, t2), let_(x, t1, t3)),
|
|
110
|
+
rewrite(if_(Term.var(x) == t1, t2, t3)).to(
|
|
111
|
+
t3,
|
|
112
|
+
eq(let_(x, t1, t2)).to(let_(x, t1, t3)),
|
|
113
|
+
),
|
|
114
|
+
# add-comm
|
|
115
|
+
rewrite(t1 + t2).to(t2 + t1),
|
|
116
|
+
# add-assoc
|
|
117
|
+
rewrite((t1 + t2) + t3).to(t1 + (t2 + t3)),
|
|
118
|
+
# eq-comm
|
|
119
|
+
rewrite(t1 == t2).to(t2 == t1),
|
|
120
|
+
# Fix
|
|
121
|
+
rewrite(fix(x, t)).to(let_(x, fix(x, t), t)),
|
|
122
|
+
# beta reduction
|
|
123
|
+
rewrite(lam(x, t)(t1)).to(let_(x, t1, t)),
|
|
124
|
+
# let-app
|
|
125
|
+
rewrite(let_(x, t, t1(t2))).to(let_(x, t, t1)(let_(x, t, t2))),
|
|
126
|
+
# let-add
|
|
127
|
+
rewrite(let_(x, t, t1 + t2)).to(let_(x, t, t1) + let_(x, t, t2)),
|
|
128
|
+
# let-eq
|
|
129
|
+
rewrite(let_(x, t, t1 == t2)).to(let_(x, t, t1) == let_(x, t, t2)),
|
|
130
|
+
# let-const
|
|
131
|
+
rewrite(let_(x, t, Term.val(v))).to(Term.val(v)),
|
|
132
|
+
# let-if
|
|
133
|
+
rewrite(let_(x, t, if_(t1, t2, t3))).to(if_(let_(x, t, t1), let_(x, t, t2), let_(x, t, t3))),
|
|
134
|
+
# let-var-same
|
|
135
|
+
rewrite(let_(x, t, Term.var(x))).to(t),
|
|
136
|
+
# let-var-diff
|
|
137
|
+
rewrite(let_(x, t, Term.var(y))).to(Term.var(y), ne(x).to(y)),
|
|
138
|
+
# let-lam-same
|
|
139
|
+
rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
|
|
140
|
+
# let-lam-diff
|
|
141
|
+
rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), ne(x).to(y), eq(fv).to(freer(t)), fv.not_contains(y)),
|
|
142
|
+
rule(eq(t).to(let_(x, t1, lam(y, t2))), ne(x).to(y), eq(fv).to(freer(t1)), fv.contains(y)).then(
|
|
143
|
+
union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
|
|
144
|
+
),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
result = relation("result")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def l(fn: Callable[[Term], Term]) -> Term: # noqa: E743
|
|
151
|
+
"""
|
|
152
|
+
Create a lambda term from a function
|
|
153
|
+
"""
|
|
154
|
+
# Use first var name from fn
|
|
155
|
+
x = fn.__code__.co_varnames[0]
|
|
156
|
+
return lam(Var(x), fn(Term.var(Var(x))))
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def assert_simplifies(left: Expr, right: Expr) -> None:
|
|
160
|
+
"""
|
|
161
|
+
Simplify and print
|
|
162
|
+
"""
|
|
163
|
+
print(f"{left} ➡ {right}")
|
|
164
|
+
check(eq(left).to(right), lamdba_ruleset * 30, left)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
assert_simplifies((Term.val(Val(1))).eval(), Val(1))
|
|
168
|
+
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# lambda under
|
|
172
|
+
assert_simplifies(
|
|
173
|
+
l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
|
|
174
|
+
l(lambda x: Term.val(Val(8))),
|
|
175
|
+
)
|
|
176
|
+
# lambda if elim
|
|
177
|
+
a = Term.var(Var("a"))
|
|
178
|
+
b = Term.var(Var("b"))
|
|
179
|
+
assert_simplifies(if_(a == b, a + a, a + b), a + b)
|
|
180
|
+
|
|
181
|
+
# lambda let simple
|
|
182
|
+
x = Var("x")
|
|
183
|
+
y = Var("y")
|
|
184
|
+
assert_simplifies(
|
|
185
|
+
let_(x, Term.val(Val(0)), let_(y, Term.val(Val(1)), Term.var(x) + Term.var(y))),
|
|
186
|
+
Term.val(Val(1)),
|
|
187
|
+
)
|
|
188
|
+
# lambda capture
|
|
189
|
+
assert_simplifies(
|
|
190
|
+
let_(x, Term.val(Val(1)), l(lambda x: x)),
|
|
191
|
+
l(lambda x: x),
|
|
192
|
+
)
|
|
193
|
+
# lambda capture free
|
|
194
|
+
egraph = EGraph()
|
|
195
|
+
e5 = egraph.let("e5", let_(y, Term.var(x) + Term.var(x), l(lambda x: Term.var(y))))
|
|
196
|
+
egraph.run(lamdba_ruleset * 10)
|
|
197
|
+
egraph.check(freer(l(lambda x: Term.var(y))).contains(y))
|
|
198
|
+
egraph.check_fail(eq(e5).to(l(lambda x: x + x)))
|
|
199
|
+
|
|
200
|
+
# lambda_closure_not_seven
|
|
201
|
+
egraph = EGraph()
|
|
202
|
+
e6 = egraph.let(
|
|
203
|
+
"e6",
|
|
204
|
+
let_(
|
|
205
|
+
Var("five"),
|
|
206
|
+
Term.val(Val(5)),
|
|
207
|
+
let_(
|
|
208
|
+
Var("add-five"),
|
|
209
|
+
l(lambda x: x + Term.var(Var("five"))),
|
|
210
|
+
let_(Var("five"), Term.val(Val(6)), Term.var(Var("add-five"))(Term.val(Val(1)))),
|
|
211
|
+
),
|
|
212
|
+
),
|
|
213
|
+
)
|
|
214
|
+
egraph.run(lamdba_ruleset * 10)
|
|
215
|
+
egraph.check_fail(eq(e6).to(Term.val(Val(7))))
|
|
216
|
+
egraph.check(eq(e6).to(Term.val(Val(6))))
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# lambda_compose
|
|
220
|
+
egraph = EGraph()
|
|
221
|
+
compose = Var("compose")
|
|
222
|
+
add1 = Var("add1")
|
|
223
|
+
e7 = egraph.let(
|
|
224
|
+
"e7",
|
|
225
|
+
let_(
|
|
226
|
+
compose,
|
|
227
|
+
l(
|
|
228
|
+
lambda f: l(
|
|
229
|
+
lambda g: l(
|
|
230
|
+
lambda x: f(g(x)),
|
|
231
|
+
),
|
|
232
|
+
),
|
|
233
|
+
),
|
|
234
|
+
let_(
|
|
235
|
+
add1,
|
|
236
|
+
l(lambda y: y + Term.val(Val(1))),
|
|
237
|
+
Term.var(compose)(Term.var(add1))(Term.var(add1)),
|
|
238
|
+
),
|
|
239
|
+
),
|
|
240
|
+
)
|
|
241
|
+
egraph.run(lamdba_ruleset * 20)
|
|
242
|
+
egraph.register(
|
|
243
|
+
rule(
|
|
244
|
+
eq(t1).to(l(lambda x: Term.val(Val(1)) + l(lambda y: Term.val(Val(1)) + y)(x))),
|
|
245
|
+
eq(t2).to(l(lambda x: x + Term.val(Val(2)))),
|
|
246
|
+
).then(result())
|
|
247
|
+
)
|
|
248
|
+
egraph.run(1)
|
|
249
|
+
egraph.check(result())
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# lambda_if_simple
|
|
253
|
+
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# # lambda_compose_many
|
|
257
|
+
assert_simplifies(
|
|
258
|
+
let_(
|
|
259
|
+
compose,
|
|
260
|
+
l(lambda f: l(lambda g: l(lambda x: f(g(x))))),
|
|
261
|
+
let_(
|
|
262
|
+
add1,
|
|
263
|
+
l(lambda y: y + Term.val(Val(1))),
|
|
264
|
+
Term.var(compose)(Term.var(add1))(
|
|
265
|
+
Term.var(compose)(Term.var(add1))(
|
|
266
|
+
Term.var(compose)(Term.var(add1))(
|
|
267
|
+
Term.var(compose)(Term.var(add1))(
|
|
268
|
+
Term.var(compose)(Term.var(add1))(Term.var(compose)(Term.var(add1))(Term.var(add1)))
|
|
269
|
+
)
|
|
270
|
+
)
|
|
271
|
+
)
|
|
272
|
+
),
|
|
273
|
+
),
|
|
274
|
+
),
|
|
275
|
+
l(lambda x: x + Term.val(Val(7))),
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
# lambda_if
|
|
279
|
+
zeroone = Var("zeroone")
|
|
280
|
+
assert_simplifies(
|
|
281
|
+
let_(
|
|
282
|
+
zeroone,
|
|
283
|
+
l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
|
|
284
|
+
Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
|
|
285
|
+
),
|
|
286
|
+
Term.val(Val(1)),
|
|
287
|
+
)
|