logic-loom 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- logic_loom/__init__.py +67 -0
- logic_loom/__main__.py +4 -0
- logic_loom/analysis.py +94 -0
- logic_loom/cli.py +88 -0
- logic_loom/codegen.py +189 -0
- logic_loom/compiler.py +135 -0
- logic_loom/cost.py +166 -0
- logic_loom/effects.py +88 -0
- logic_loom/egraph.py +155 -0
- logic_loom/expr.py +138 -0
- logic_loom/parser.py +167 -0
- logic_loom/rules.py +181 -0
- logic_loom/saturate.py +210 -0
- logic_loom/viz.py +71 -0
- logic_loom-0.3.0.dist-info/METADATA +504 -0
- logic_loom-0.3.0.dist-info/RECORD +20 -0
- logic_loom-0.3.0.dist-info/WHEEL +5 -0
- logic_loom-0.3.0.dist-info/entry_points.txt +2 -0
- logic_loom-0.3.0.dist-info/licenses/LICENSE +21 -0
- logic_loom-0.3.0.dist-info/top_level.txt +1 -0
logic_loom/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Logic-Loom: a compiler that understands mathematics.
|
|
2
|
+
|
|
3
|
+
Instead of peephole-optimizing instructions, Logic-Loom reasons about
|
|
4
|
+
the *algebra* of an expression. It uses equality saturation over an
|
|
5
|
+
e-graph to discover the whole space of equivalent forms, then extracts
|
|
6
|
+
the cheapest one under a configurable, hardware-aware cost model.
|
|
7
|
+
|
|
8
|
+
>>> from logic_loom import optimize
|
|
9
|
+
>>> print(optimize("a*b + a*c"))
|
|
10
|
+
a * b + a * c => a * (b + c)
|
|
11
|
+
cost 5.4 -> 3.3 (1.64x)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from .analysis import Analysis, analyze, reachable_rules
|
|
15
|
+
from .codegen import free_vars, to_code, to_llvm
|
|
16
|
+
from .compiler import Result, build_egraph, optimize
|
|
17
|
+
from .cost import (
|
|
18
|
+
DEFAULT_MODEL,
|
|
19
|
+
PROFILES,
|
|
20
|
+
CostModel,
|
|
21
|
+
expr_cost,
|
|
22
|
+
extract,
|
|
23
|
+
get_profile,
|
|
24
|
+
)
|
|
25
|
+
from .effects import is_effect_safe, tainted_classes
|
|
26
|
+
from .egraph import EGraph
|
|
27
|
+
from .expr import Expr, evaluate
|
|
28
|
+
from .parser import parse
|
|
29
|
+
from .rules import ALL_RULES, DEFAULT_RULES, EXTENDED_RULES, Rule, rule
|
|
30
|
+
from .saturate import BackoffScheduler, SaturationReport, saturate
|
|
31
|
+
from .viz import to_dot
|
|
32
|
+
|
|
33
|
+
__version__ = "0.3.0"
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
"optimize",
|
|
37
|
+
"build_egraph",
|
|
38
|
+
"Result",
|
|
39
|
+
"Expr",
|
|
40
|
+
"evaluate",
|
|
41
|
+
"parse",
|
|
42
|
+
"EGraph",
|
|
43
|
+
"Rule",
|
|
44
|
+
"rule",
|
|
45
|
+
"DEFAULT_RULES",
|
|
46
|
+
"EXTENDED_RULES",
|
|
47
|
+
"ALL_RULES",
|
|
48
|
+
"saturate",
|
|
49
|
+
"SaturationReport",
|
|
50
|
+
"BackoffScheduler",
|
|
51
|
+
"extract",
|
|
52
|
+
"expr_cost",
|
|
53
|
+
"CostModel",
|
|
54
|
+
"PROFILES",
|
|
55
|
+
"DEFAULT_MODEL",
|
|
56
|
+
"get_profile",
|
|
57
|
+
"to_code",
|
|
58
|
+
"to_llvm",
|
|
59
|
+
"free_vars",
|
|
60
|
+
"to_dot",
|
|
61
|
+
"analyze",
|
|
62
|
+
"Analysis",
|
|
63
|
+
"reachable_rules",
|
|
64
|
+
"tainted_classes",
|
|
65
|
+
"is_effect_safe",
|
|
66
|
+
"__version__",
|
|
67
|
+
]
|
logic_loom/__main__.py
ADDED
logic_loom/analysis.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Static analysis to tame the search before saturation begins.
|
|
2
|
+
|
|
3
|
+
Equality saturation explores blindly; a little static reasoning about the
|
|
4
|
+
input lets us avoid work that provably cannot help.
|
|
5
|
+
|
|
6
|
+
Two analyses live here:
|
|
7
|
+
|
|
8
|
+
1. **Reachable-rule pruning.** A rule can only ever fire if the operators
|
|
9
|
+
in its left-hand side are present -- and operators only appear if some
|
|
10
|
+
*other* fireable rule introduces them. Computing this set as a
|
|
11
|
+
fixed-point lets us drop rules that can never match. This is sound:
|
|
12
|
+
the pruned rules would have contributed nothing, so the result is
|
|
13
|
+
identical, only reached faster.
|
|
14
|
+
|
|
15
|
+
2. **Complexity estimate.** Counting associative/commutative operators
|
|
16
|
+
predicts how badly the e-graph might blow up, which we use to size the
|
|
17
|
+
resource limits automatically.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from typing import List, Set
|
|
24
|
+
|
|
25
|
+
from .expr import Expr
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def operators(e: Expr, acc: Set[str] | None = None) -> Set[str]:
    """Collect the names of all application nodes found in ``e``.

    Every node reachable from ``e`` through ``args`` is visited, and the
    ``name`` of each node whose ``kind`` is ``"app"`` is recorded.

    When ``acc`` is given, names are added to it in place and that same
    set is returned; otherwise a new set is created and returned.
    """
    heads = set() if acc is None else acc
    pending = [e]
    while pending:
        node = pending.pop()
        if node.kind == "app":
            heads.add(node.name)
        pending.extend(node.args)
    return heads
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _size(e: Expr) -> int:
    """Return the number of nodes in ``e``, counting ``e`` itself."""
    total = 0
    stack = [e]
    while stack:
        node = stack.pop()
        total += 1
        stack.extend(node.args)
    return total
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _count_heads(e: Expr, heads) -> int:
    """Count the application nodes in ``e`` whose name is in ``heads``."""
    hits = 0
    todo = [e]
    while todo:
        node = todo.pop()
        if node.kind == "app" and node.name in heads:
            hits += 1
        todo.extend(node.args)
    return hits
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class Analysis:
    """Static measurements of an expression plus the limits derived from them."""

    size: int
    variables: int
    ac_ops: int  # number of + and * nodes (the explosion driver)
    node_limit: int
    match_limit: int

    def summary(self) -> str:
        """Return all five fields rendered on a single line."""
        measured = f"size={self.size} vars={self.variables} ac_ops={self.ac_ops}"
        limits = f"node_limit={self.node_limit} match_limit={self.match_limit}"
        return f"{measured} -> {limits}"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def reachable_rules(expr: Expr, rules: List) -> List:
    """Return the rules whose left-hand-side operators can ever be present.

    The reachable operator set starts as the operators of ``expr``. A rule
    whose left-hand-side operators are all reachable contributes its
    right-hand-side operators; this is repeated until a full pass adds
    nothing. The surviving rules are returned in their original order.
    """
    known = operators(expr)
    table = [(r, operators(r.lhs), operators(r.rhs)) for r in rules]
    while True:
        grew = False
        for _, needs, gives in table:
            # Only rules that can fire AND introduce something new matter.
            if needs <= known and not gives <= known:
                known |= gives
                grew = True
        if not grew:
            break
    return [r for r, needs, _ in table if needs <= known]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def analyze(expr: Expr) -> Analysis:
    """Measure ``expr`` and derive saturation limits from its shape.

    The limits depend only on the number of ``+`` and ``*`` nodes:
    ``node_limit`` is ``400 * (ac + 1)`` clamped to ``[2_000, 20_000]``;
    ``match_limit`` is ``2_000`` for up to three such nodes and
    ``2_000 // ac`` (never below ``300``) beyond that.
    """
    node_count = _size(expr)
    distinct_vars = len(_vars(expr))
    ac = _count_heads(expr, {"+", "*"})
    # Small inputs can saturate fully; large AC-heavy ones need tighter
    # budgets so we stop early with the best form found instead of thrashing.
    node_budget = max(2_000, 400 * (ac + 1))
    if node_budget > 20_000:
        node_budget = 20_000
    if ac <= 3:
        match_budget = 2_000
    else:
        match_budget = max(300, 2_000 // ac)
    return Analysis(
        size=node_count,
        variables=distinct_vars,
        ac_ops=ac,
        node_limit=node_budget,
        match_limit=match_budget,
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _vars(e: Expr, acc: Set[str] | None = None) -> Set[str]:
    """Collect the names of all ``"var"`` nodes reachable from ``e``.

    Names are added to ``acc`` in place when it is supplied (and ``acc``
    itself is returned); otherwise a new set is built.
    """
    names = set() if acc is None else acc
    frontier = [e]
    while frontier:
        node = frontier.pop()
        if node.kind == "var":
            names.add(node.name)
        frontier.extend(node.args)
    return names
|
logic_loom/cli.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Command-line interface: python -m logic_loom "a*b + a*c"."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
from . import __version__
|
|
9
|
+
from .codegen import to_code, to_llvm
|
|
10
|
+
from .compiler import build_egraph, optimize
|
|
11
|
+
from .cost import PROFILES
|
|
12
|
+
from .rules import ALL_RULES, DEFAULT_RULES
|
|
13
|
+
from .viz import to_dot
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main(argv=None) -> int:
    """Run the command-line interface and return an exit status.

    Expressions come from the positional arguments (joined with spaces into
    one expression) or, when none are given, from the lines of standard
    input. Returns ``1`` after printing the help text when there is nothing
    to process, and ``0`` otherwise.
    """
    parser = argparse.ArgumentParser(
        prog="logic_loom",
        description="A compiler that understands mathematics. "
                    "Give it an expression; it returns the cheapest equivalent form.",
    )
    add = parser.add_argument
    add("expression", nargs="*", help="math expression, e.g. 'a*b + a*c'")
    add("-v", "--verbose", action="store_true",
        help="show saturation statistics")
    add("--extended", action="store_true",
        help="enable transcendental-function rules (exp/log/sqrt/trig)")
    add("--profile", choices=list(PROFILES), default="default",
        help="cost profile to optimize for (default: default)")
    add("--target", choices=["c", "rust", "js", "llvm"],
        help="also emit the optimized form as source code / IR")
    add("--impure", default="",
        help="comma-separated names of side-effecting functions")
    add("--explain", action="store_true",
        help="report the domain assumptions the result relies on")
    add("--dot", action="store_true",
        help="print the saturated e-graph as Graphviz DOT")
    add("--max-iters", type=int, default=30)
    add("--node-limit", type=int, default=None)
    add("--version", action="version", version=f"logic-loom {__version__}")
    opts = parser.parse_args(argv)

    rule_set = ALL_RULES if opts.extended else DEFAULT_RULES

    # Split the comma-separated list, discarding blanks such as "a,,b".
    impure_names = set()
    for piece in opts.impure.split(","):
        piece = piece.strip()
        if piece:
            impure_names.add(piece)

    if opts.expression:
        inputs = [" ".join(opts.expression)]
    else:
        inputs = _read_stdin()
    if not inputs:
        parser.print_help()
        return 1

    for raw in inputs:
        text = raw.strip()
        if not text:
            continue

        if opts.dot:
            # DOT mode prints the saturated graph only; no extraction.
            graph, root_id, _ = build_egraph(
                text, rules=rule_set, max_iters=opts.max_iters,
                node_limit=opts.node_limit, impure=impure_names)
            print(to_dot(graph, root_id))
            continue

        result = optimize(text, rules=rule_set, profile=opts.profile,
                          impure=impure_names, max_iters=opts.max_iters,
                          node_limit=opts.node_limit)
        print(result)

        if opts.target == "llvm":
            print(to_llvm(result.optimized))
        elif opts.target:
            rendered = to_code(result.optimized, opts.target)
            print(f" {opts.target}: {rendered}")

        if opts.explain and result.assumptions:
            joined = "; ".join(result.assumptions)
            print(f" assumes (for soundness): {joined}")

        if opts.verbose:
            stats = result.report
            print(f" [{stats.stop_reason}] profile={result.model.name} "
                  f"iterations={stats.iterations} e-nodes={stats.nodes} "
                  f"e-classes={stats.classes}")
            if stats.banned:
                throttled = ", ".join(stats.banned)
                print(f" throttled rules: {throttled}")
    return 0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _read_stdin():
    """Return the lines of standard input, or ``[]`` when it is a terminal."""
    interactive = sys.stdin.isatty()
    return [] if interactive else sys.stdin.read().splitlines()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
if __name__ == "__main__":
|
|
88
|
+
raise SystemExit(main())
|
logic_loom/codegen.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""Emit an optimized expression as source code in a real language.
|
|
2
|
+
|
|
3
|
+
The whole point of optimizing an expression is to *run* it. This module
|
|
4
|
+
turns an :class:`Expr` tree into a snippet of C, Rust or JavaScript, so
|
|
5
|
+
the optimized form can be pasted straight into a program.
|
|
6
|
+
|
|
7
|
+
>>> from logic_loom import optimize, to_code
|
|
8
|
+
>>> e = optimize("a*x*x + b*x + c").optimized
|
|
9
|
+
>>> print(to_code(e, "c"))
|
|
10
|
+
x * (a * x + b) + c
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from .expr import Expr
|
|
16
|
+
|
|
17
|
+
# Per-language rendering of operators and functions.
|
|
18
|
+
_LANGS = {
    "c": {
        "pow": lambda a, b: f"pow({a}, {b})",
        "funcs": {"sin": "sin", "cos": "cos", "tan": "tan",
                  "exp": "exp", "log": "log", "sqrt": "sqrt"},
    },
    "rust": {
        "pow": lambda a, b: f"({a}).powf({b})",
        # A "{0}" placeholder marks method-call style: the first argument
        # becomes the receiver of the method.
        "funcs": {"sin": "{0}.sin()", "cos": "{0}.cos()", "tan": "{0}.tan()",
                  "exp": "{0}.exp()", "log": "{0}.ln()", "sqrt": "{0}.sqrt()"},
    },
    "js": {
        "pow": lambda a, b: f"Math.pow({a}, {b})",
        "funcs": {"sin": "Math.sin", "cos": "Math.cos", "tan": "Math.tan",
                  "exp": "Math.exp", "log": "Math.log", "sqrt": "Math.sqrt"},
    },
}

_INFIX = {"+": "+", "-": "-", "*": "*", "/": "/"}
_PREC = {"+": 1, "-": 1, "*": 2, "/": 2, "neg": 3, "^": 4}


def to_code(e: Expr, lang: str = "c") -> str:
    """Render ``e`` as an expression in ``lang`` (``"c"``, ``"rust"`` or ``"js"``).

    ``lang`` is matched case-insensitively.

    Raises
    ------
    ValueError
        If ``lang`` is not a supported language, or if ``e`` contains a
        pattern variable.
    """
    lang = lang.lower()
    if lang not in _LANGS:
        raise ValueError(f"unsupported language {lang!r}; choose from {list(_LANGS)}")
    return _emit(e, _LANGS[lang], 0)


def _as_receiver(node: Expr, text: str) -> str:
    """Parenthesize ``text`` if a trailing ``.method()`` would bind too tightly."""
    compound = node.kind == "app" and (node.name in _INFIX or node.name == "neg")
    if compound or text.startswith("-"):
        return f"({text})"
    return text


def _emit(e: Expr, spec, parent_prec: int) -> str:
    """Render ``e`` using the language table ``spec``.

    ``parent_prec`` is the precedence of the enclosing operator; an infix
    expression is wrapped in parentheses when its own precedence is lower.
    """
    if e.kind == "num":
        v = e.value
        # Print integral floats without a trailing ".0".
        if isinstance(v, float) and v.is_integer():
            v = int(v)
        return str(v)
    if e.kind == "var":
        return e.name
    if e.kind == "patvar":
        raise ValueError("cannot generate code from a pattern variable")

    op = e.name
    if op == "neg":
        operand = _emit(e.args[0], spec, _PREC["neg"])
        # "--x" is the decrement operator in C and JavaScript, so a nested
        # negation or a negative literal must be parenthesized: "-(-x)".
        if operand.startswith("-"):
            operand = f"({operand})"
        return f"-{operand}"

    if op == "^":
        a = _emit(e.args[0], spec, 0)
        b = _emit(e.args[1], spec, 0)
        return spec["pow"](a, b)

    if op in _INFIX:
        prec = _PREC[op]
        left = _emit(e.args[0], spec, prec)
        # The right operand is emitted one level tighter so that
        # "a - (b - c)" keeps its parentheses.
        right = _emit(e.args[1], spec, prec + 1)
        s = f"{left} {_INFIX[op]} {right}"
        return f"({s})" if prec < parent_prec else s

    # function call
    funcs = spec["funcs"]
    if op in funcs:
        args = [_emit(a, spec, 0) for a in e.args]
        tmpl = funcs[op]
        if "{0}" in tmpl:  # method style (Rust)
            # Without parentheses "sin(a + b)" would render as
            # "a + b.sin()", which applies sin to b only.
            receivers = [_as_receiver(a, s) for a, s in zip(e.args, args)]
            return tmpl.format(*receivers)
        return f"{tmpl}({', '.join(args)})"  # call style (C / JS)

    # unknown function: emit verbatim
    args = ", ".join(_emit(a, spec, 0) for a in e.args)
    return f"{op}({args})"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# --------------------------------------------------------------------- #
|
|
90
|
+
# LLVM IR transpiler
|
|
91
|
+
# --------------------------------------------------------------------- #
|
|
92
|
+
# This lets Logic-Loom plug into a real toolchain: emit the optimized
|
|
93
|
+
# expression as an LLVM IR function that clang/opt can compile, inline,
|
|
94
|
+
# and vectorize alongside the rest of a C/C++/Rust program.
|
|
95
|
+
_LLVM_BINOP = {"+": "fadd", "-": "fsub", "*": "fmul", "/": "fdiv"}
|
|
96
|
+
_LLVM_INTRINSIC = {
|
|
97
|
+
"sin": "@llvm.sin.f64", "cos": "@llvm.cos.f64",
|
|
98
|
+
"exp": "@llvm.exp.f64", "log": "@llvm.log.f64",
|
|
99
|
+
"sqrt": "@llvm.sqrt.f64", "^": "@llvm.pow.f64",
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _llvm_const(v) -> str:
    """Render ``v`` as an LLVM IR ``double`` constant without losing precision.

    The exponential form (``2.000000e+00``) is used whenever it reads back
    as exactly the same double. It carries only seven significant digits,
    so any other value -- and infinities or NaN, which have no decimal
    spelling in IR -- is emitted as the hexadecimal form of its IEEE-754
    bit pattern (``0x400921FB54442D18``).
    """
    import math
    import struct

    x = float(v)
    text = f"{x:e}"
    if math.isfinite(x) and float(text) == x:
        return text
    return "0x" + struct.pack(">d", x).hex().upper()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def free_vars(e: Expr):
    """Return the distinct variable names in ``e`` as a sorted list."""
    names = set()
    queue = [e]
    while queue:
        current = queue.pop()
        if current.kind == "var":
            names.add(current.name)
        queue.extend(current.args)
    return sorted(names)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def to_llvm(e: Expr, name: str = "f") -> str:
    """Render ``e`` as an LLVM IR function ``double @name(double, ...)``.

    The parameters are the free variables of ``e`` in sorted order. Each
    operation is written to a fresh ``%tN`` temporary; numbers with a
    parameter's name are skipped so a variable called ``t1`` cannot clash
    with a temporary. Every intrinsic or external function that is called
    gets a ``declare`` line above the definition. External functions are
    declared with as many ``double`` parameters as the first call to them
    passes.

    Raises
    ------
    ValueError
        If ``e`` contains a pattern variable.
    """
    params = free_vars(e)
    reserved = set(params)
    body: list[str] = []
    counter = [0]
    # Operator/function name -> number of double parameters to declare.
    declared: dict[str, int] = {}

    def fresh() -> str:
        # Skip any number whose name is already taken by a parameter.
        while True:
            counter[0] += 1
            candidate = f"t{counter[0]}"
            if candidate not in reserved:
                return f"%{candidate}"

    def emit(node: Expr) -> str:
        if node.kind == "num":
            return _llvm_const(node.value)
        if node.kind == "var":
            return f"%{node.name}"
        if node.kind == "patvar":
            raise ValueError("cannot generate IR from a pattern variable")

        op = node.name
        if op == "neg":
            x = emit(node.args[0])
            r = fresh()
            body.append(f" {r} = fneg double {x}")
            return r
        if op in _LLVM_BINOP:
            a = emit(node.args[0])
            b = emit(node.args[1])
            r = fresh()
            body.append(f" {r} = {_LLVM_BINOP[op]} double {a}, {b}")
            return r

        args = [emit(a) for a in node.args]
        if op in _LLVM_INTRINSIC:
            target = _LLVM_INTRINSIC[op]
            # Intrinsic signatures are fixed: pow is binary, the rest unary.
            declared[op] = 2 if op == "^" else 1
        else:
            # external function fallback: declare what the call site passes
            target = f"@{op}"
            declared.setdefault(op, len(args))
        r = fresh()
        joined = ", ".join(f"double {a}" for a in args)
        body.append(f" {r} = call double {target}({joined})")
        return r

    ret = emit(e)
    sig = ", ".join(f"double %{p}" for p in params)

    decls = []
    for op in sorted(declared):
        target = _LLVM_INTRINSIC.get(op, f"@{op}")
        param_types = ", ".join(["double"] * declared[op])
        decls.append(f"declare double {target}({param_types})")

    lines = list(decls)
    if decls:
        lines.append("")
    lines.append(f"define double @{name}({sig}) {{")
    lines.append("entry:")
    lines.extend(body)
    lines.append(f" ret double {ret}")
    lines.append("}")
    return "\n".join(lines)
|
logic_loom/compiler.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""High-level API: turn a math string into its cheapest equivalent form."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Iterable, List, Optional
|
|
7
|
+
|
|
8
|
+
from .analysis import analyze, reachable_rules
|
|
9
|
+
from .cost import DEFAULT_MODEL, CostModel, expr_cost, extract, get_profile
|
|
10
|
+
from .egraph import EGraph
|
|
11
|
+
from .expr import Expr
|
|
12
|
+
from .parser import parse
|
|
13
|
+
from .rules import DEFAULT_RULES, Rule
|
|
14
|
+
from .saturate import BackoffScheduler, SaturationReport, saturate
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class Result:
    """Outcome of one optimization run.

    Holds the source text, the parsed and the extracted expressions, their
    costs under ``model``, the saturation report, and the assumptions
    collected from the rules that fired.
    """

    source: str
    original: Expr
    optimized: Expr
    original_cost: float
    optimized_cost: float
    report: SaturationReport
    model: CostModel = DEFAULT_MODEL
    assumptions: List[str] = field(default_factory=list)

    @property
    def improved(self) -> bool:
        """True when the optimized cost is lower by more than ``1e-9``."""
        return self.optimized_cost < self.original_cost - 1e-9

    @property
    def speedup(self) -> float:
        """Ratio of original to optimized cost.

        A non-positive optimized cost yields ``inf`` when the original cost
        was positive. When the original cost was non-positive as well,
        nothing was gained, so the ratio is reported as ``1.0`` rather than
        as an infinite speedup.
        """
        if self.optimized_cost <= 0:
            return float("inf") if self.original_cost > 0 else 1.0
        return self.original_cost / self.optimized_cost

    def __str__(self) -> str:
        # "=>" marks a real improvement; "==" means no cheaper form was found.
        arrow = "=>" if self.improved else "=="
        return (
            f"{self.original} {arrow} {self.optimized}\n"
            f" cost {self.original_cost:.1f} -> {self.optimized_cost:.1f}"
            f" ({self.speedup:.2f}x)"
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _resolve_model(model, profile) -> CostModel:
    """Choose the cost model: a named ``profile`` wins, then ``model``, then the default."""
    if profile is None:
        return DEFAULT_MODEL if model is None else model
    return get_profile(profile)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _prepare(source, rules, auto, node_limit):
    """Parse ``source`` and settle the rules, node limit and scheduler.

    Returns ``(original, rules, node_limit, scheduler)``. With ``auto`` the
    rules are pruned to the reachable ones and a ``BackoffScheduler`` is
    built from the analysis; the node limit defaults to the analysed value.
    Without ``auto`` there is no scheduler and the node limit defaults to
    ``5_000``. An explicit ``node_limit`` is always respected.
    """
    original = parse(source)
    active = list(DEFAULT_RULES) if rules is None else list(rules)
    if not auto:
        limit = 5_000 if node_limit is None else node_limit
        return original, active, limit, None

    active = reachable_rules(original, active)
    sizing = analyze(original)
    limit = sizing.node_limit if node_limit is None else node_limit
    backoff = BackoffScheduler(match_limit=sizing.match_limit)
    return original, active, limit, backoff
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def build_egraph(
    source: str,
    *,
    rules: Optional[List[Rule]] = None,
    auto: bool = True,
    max_iters: int = 30,
    node_limit: Optional[int] = None,
    impure: Optional[Iterable[str]] = None,
):
    """Parse and saturate ``source``; return ``(egraph, root_id, report)``.

    No extraction is performed: the caller receives the saturated e-graph
    itself, the id of the class holding the parsed expression, and the
    saturation report.
    """
    parsed, active_rules, limit, backoff = _prepare(source, rules, auto, node_limit)
    impure_names = set(impure or ())

    graph = EGraph()
    root_id = graph.add_expr(parsed)
    outcome = saturate(
        graph,
        active_rules,
        max_iters=max_iters,
        node_limit=limit,
        scheduler=backoff,
        impure=impure_names,
    )
    return graph, root_id, outcome
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def optimize(
    source: str,
    *,
    rules: Optional[List[Rule]] = None,
    model: Optional[CostModel] = None,
    profile: Optional[str] = None,
    impure: Optional[Iterable[str]] = None,
    auto: bool = True,
    max_iters: int = 30,
    node_limit: Optional[int] = None,
) -> Result:
    """Parse, saturate and extract the cheapest equivalent of ``source``.

    Parameters
    ----------
    rules : rewrite rules to use (defaults to ``DEFAULT_RULES``).
    model : a :class:`~logic_loom.cost.CostModel` for extraction.
    profile : name of a built-in cost profile ("x86", "arm", "gpu", ...);
        overrides ``model`` when given.
    impure : names of side-effecting functions; rewrites that would
        duplicate, drop, or reorder their calls are forbidden.
    auto : enable static rule pruning and automatic limit sizing.

    Returns
    -------
    A :class:`Result` carrying both expressions, their costs under the
    chosen model, the saturation report, and the sorted, de-duplicated
    assumptions of every rule that fired.
    """
    cost_model = _resolve_model(model, profile)
    parsed, active_rules, limit, backoff = _prepare(source, rules, auto, node_limit)

    graph = EGraph()
    root_id = graph.add_expr(parsed)
    report = saturate(graph, active_rules, max_iters=max_iters, node_limit=limit,
                      scheduler=backoff, impure=set(impure or ()))
    best, best_cost = extract(graph, root_id, cost_model)

    # Gather the assumptions attached to each rule that actually fired.
    rule_index = {r.name: r for r in active_rules}
    collected = set()
    for fired_name in report.fired:
        fired_rule = rule_index.get(fired_name, Rule(fired_name, parsed, parsed))
        collected.update(fired_rule.assumes)
    assumptions = sorted(collected)

    baseline_cost = expr_cost(parsed, cost_model)
    return Result(
        source=source,
        original=parsed,
        optimized=best,
        original_cost=baseline_cost,
        optimized_cost=best_cost,
        report=report,
        model=cost_model,
        assumptions=assumptions,
    )
|