logic-loom 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
logic_loom/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """Logic-Loom: a compiler that understands mathematics.
2
+
3
+ Instead of peephole-optimizing instructions, Logic-Loom reasons about
4
+ the *algebra* of an expression. It uses equality saturation over an
5
+ e-graph to discover the whole space of equivalent forms, then extracts
6
+ the cheapest one under a configurable, hardware-aware cost model.
7
+
8
+ >>> from logic_loom import optimize
9
+ >>> print(optimize("a*b + a*c"))
10
+ a * b + a * c => a * (b + c)
11
+ cost 5.4 -> 3.3 (1.64x)
12
+ """
13
+
14
+ from .analysis import Analysis, analyze, reachable_rules
15
+ from .codegen import free_vars, to_code, to_llvm
16
+ from .compiler import Result, build_egraph, optimize
17
+ from .cost import (
18
+ DEFAULT_MODEL,
19
+ PROFILES,
20
+ CostModel,
21
+ expr_cost,
22
+ extract,
23
+ get_profile,
24
+ )
25
+ from .effects import is_effect_safe, tainted_classes
26
+ from .egraph import EGraph
27
+ from .expr import Expr, evaluate
28
+ from .parser import parse
29
+ from .rules import ALL_RULES, DEFAULT_RULES, EXTENDED_RULES, Rule, rule
30
+ from .saturate import BackoffScheduler, SaturationReport, saturate
31
+ from .viz import to_dot
32
+
33
+ __version__ = "0.3.0"
34
+
35
+ __all__ = [
36
+ "optimize",
37
+ "build_egraph",
38
+ "Result",
39
+ "Expr",
40
+ "evaluate",
41
+ "parse",
42
+ "EGraph",
43
+ "Rule",
44
+ "rule",
45
+ "DEFAULT_RULES",
46
+ "EXTENDED_RULES",
47
+ "ALL_RULES",
48
+ "saturate",
49
+ "SaturationReport",
50
+ "BackoffScheduler",
51
+ "extract",
52
+ "expr_cost",
53
+ "CostModel",
54
+ "PROFILES",
55
+ "DEFAULT_MODEL",
56
+ "get_profile",
57
+ "to_code",
58
+ "to_llvm",
59
+ "free_vars",
60
+ "to_dot",
61
+ "analyze",
62
+ "Analysis",
63
+ "reachable_rules",
64
+ "tainted_classes",
65
+ "is_effect_safe",
66
+ "__version__",
67
+ ]
logic_loom/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .cli import main
2
+
3
+ if __name__ == "__main__":
4
+ raise SystemExit(main())
logic_loom/analysis.py ADDED
@@ -0,0 +1,94 @@
1
+ """Static analysis to tame the search before saturation begins.
2
+
3
+ Equality saturation explores blindly; a little static reasoning about the
4
+ input lets us avoid work that provably cannot help.
5
+
6
+ Two analyses live here:
7
+
8
+ 1. **Reachable-rule pruning.** A rule can only ever fire if the operators
9
+ in its left-hand side are present -- and operators only appear if some
10
+ *other* fireable rule introduces them. Computing this set as a
11
+ fixed-point lets us drop rules that can never match. This is sound:
12
+ the pruned rules would have contributed nothing, so the result is
13
+ identical, only reached faster.
14
+
15
+ 2. **Complexity estimate.** Counting associative/commutative operators
16
+ predicts how badly the e-graph might blow up, which we use to size the
17
+ resource limits automatically.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass
23
+ from typing import List, Set
24
+
25
+ from .expr import Expr
26
+
27
+
28
def operators(e: Expr, acc: Set[str] | None = None) -> Set[str]:
    """Collect the names of every ``"app"`` node found in ``e``.

    If ``acc`` is given, names are added to it in place and that same set is
    returned; otherwise a fresh set is created and returned.
    """
    heads = set() if acc is None else acc
    # Explicit work-list traversal; visiting order is irrelevant because the
    # result is a set.
    pending = [e]
    while pending:
        node = pending.pop()
        if node.kind == "app":
            heads.add(node.name)
        pending.extend(node.args)
    return heads
37
+
38
+
39
def _size(e: Expr) -> int:
    """Return the number of nodes in ``e``, counting ``e`` itself."""
    total = 0
    stack = [e]
    while stack:
        node = stack.pop()
        total += 1
        stack.extend(node.args)
    return total
41
+
42
+
43
def _count_heads(e: Expr, heads) -> int:
    """Count the ``"app"`` nodes in ``e`` whose name is contained in ``heads``."""
    hits = 0
    queue = [e]
    while queue:
        node = queue.pop()
        if node.kind == "app" and node.name in heads:
            hits += 1
        queue.extend(node.args)
    return hits
46
+
47
+
48
@dataclass
class Analysis:
    """Measurements of an expression plus the search limits derived from them."""

    size: int         # total number of nodes in the expression
    variables: int    # number of distinct variable names
    ac_ops: int       # number of + and * nodes (the explosion driver)
    node_limit: int   # node budget chosen for saturation
    match_limit: int  # match budget chosen for the scheduler

    def summary(self) -> str:
        """Return all five fields formatted on a single line."""
        measured = f"size={self.size} vars={self.variables} ac_ops={self.ac_ops}"
        limits = f"node_limit={self.node_limit} match_limit={self.match_limit}"
        return f"{measured} -> {limits}"
59
+
60
+
61
def reachable_rules(expr: Expr, rules: List) -> List:
    """Return the rules whose left-hand-side operators can ever be present.

    The reachable operator set starts as the operators of ``expr`` and grows
    to a fixed point: whenever every operator of a rule's left-hand side is
    reachable, the operators of its right-hand side become reachable as well.
    The relative order of ``rules`` is preserved in the result.
    """
    live = operators(expr)
    table = [(r, operators(r.lhs), operators(r.rhs)) for r in rules]

    grew = True
    while grew:
        grew = False
        for _rule, needs, adds in table:
            if needs.issubset(live) and not adds.issubset(live):
                live.update(adds)
                grew = True

    return [r for r, needs, _adds in table if needs.issubset(live)]
73
+
74
+
75
def analyze(expr: Expr) -> Analysis:
    """Measure ``expr`` and derive saturation limits from the measurements.

    Returns an :class:`Analysis` holding the node count, the number of
    distinct variables, the number of ``+``/``*`` nodes, and the node and
    match limits computed from that ``+``/``*`` count.
    """
    ac = _count_heads(expr, {"+", "*"})
    # Small inputs can saturate fully; large AC-heavy ones need tighter
    # budgets so we stop early with the best form found instead of thrashing.
    node_limit = min(20_000, max(2_000, 400 * (ac + 1)))
    match_limit = 2_000 if ac <= 3 else max(300, 2_000 // ac)
    return Analysis(
        size=_size(expr),
        # _vars already returns a set, so its length is the distinct count.
        variables=len(_vars(expr)),
        ac_ops=ac,
        node_limit=node_limit,
        match_limit=match_limit,
    )
85
+
86
+
87
def _vars(e: Expr, acc: Set[str] | None = None) -> Set[str]:
    """Collect the names of every ``"var"`` node found in ``e``.

    If ``acc`` is given, names are added to it in place and that same set is
    returned; otherwise a fresh set is created and returned.
    """
    names = set() if acc is None else acc
    pending = [e]
    while pending:
        node = pending.pop()
        if node.kind == "var":
            names.add(node.name)
        pending.extend(node.args)
    return names
logic_loom/cli.py ADDED
@@ -0,0 +1,88 @@
1
+ """Command-line interface: python -m logic_loom "a*b + a*c"."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+
8
+ from . import __version__
9
+ from .codegen import to_code, to_llvm
10
+ from .compiler import build_egraph, optimize
11
+ from .cost import PROFILES
12
+ from .rules import ALL_RULES, DEFAULT_RULES
13
+ from .viz import to_dot
14
+
15
+
16
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``logic_loom`` command line."""
    parser = argparse.ArgumentParser(
        prog="logic_loom",
        description=(
            "A compiler that understands mathematics. "
            "Give it an expression; it returns the cheapest equivalent form."
        ),
    )
    add = parser.add_argument
    add("expression", nargs="*", help="math expression, e.g. 'a*b + a*c'")
    add("-v", "--verbose", action="store_true",
        help="show saturation statistics")
    add("--extended", action="store_true",
        help="enable transcendental-function rules (exp/log/sqrt/trig)")
    add("--profile", choices=list(PROFILES), default="default",
        help="cost profile to optimize for (default: default)")
    add("--target", choices=["c", "rust", "js", "llvm"],
        help="also emit the optimized form as source code / IR")
    add("--impure", default="",
        help="comma-separated names of side-effecting functions")
    add("--explain", action="store_true",
        help="report the domain assumptions the result relies on")
    add("--dot", action="store_true",
        help="print the saturated e-graph as Graphviz DOT")
    add("--max-iters", type=int, default=30)
    add("--node-limit", type=int, default=None)
    add("--version", action="version", version=f"logic-loom {__version__}")
    return parser


def main(argv=None) -> int:
    """Run the command-line interface and return the process exit status.

    Expressions come from the positional arguments (joined with spaces into
    one expression) or, when none are given, from the lines of standard
    input. Returns 1 after printing the help text when there is no input at
    all, and 0 otherwise.
    """
    parser = _build_parser()
    opts = parser.parse_args(argv)

    rules = ALL_RULES if opts.extended else DEFAULT_RULES
    impure = {name.strip() for name in opts.impure.split(",") if name.strip()}
    if opts.expression:
        sources = [" ".join(opts.expression)]
    else:
        sources = _read_stdin()
    if not sources:
        parser.print_help()
        return 1

    # Keyword arguments shared by both the DOT path and the optimize path.
    search = {
        "rules": rules,
        "impure": impure,
        "max_iters": opts.max_iters,
        "node_limit": opts.node_limit,
    }

    for raw in sources:
        text = raw.strip()
        if not text:
            continue

        if opts.dot:
            graph, root_id, _report = build_egraph(text, **search)
            print(to_dot(graph, root_id))
            continue

        result = optimize(text, profile=opts.profile, **search)
        print(result)

        target = opts.target
        if target == "llvm":
            print(to_llvm(result.optimized))
        elif target:
            print(f" {target}: {to_code(result.optimized, target)}")

        if opts.explain and result.assumptions:
            print(f" assumes (for soundness): {'; '.join(result.assumptions)}")

        if opts.verbose:
            stats = result.report
            print(f" [{stats.stop_reason}] profile={result.model.name} "
                  f"iterations={stats.iterations} e-nodes={stats.nodes} "
                  f"e-classes={stats.classes}")
            if stats.banned:
                print(f" throttled rules: {', '.join(stats.banned)}")
    return 0
79
+
80
+
81
def _read_stdin():
    """Return the lines of standard input, or ``[]`` when it is a terminal."""
    stream = sys.stdin
    return [] if stream.isatty() else stream.read().splitlines()
85
+
86
+
87
+ if __name__ == "__main__":
88
+ raise SystemExit(main())
logic_loom/codegen.py ADDED
@@ -0,0 +1,189 @@
1
+ """Emit an optimized expression as source code in a real language.
2
+
3
+ The whole point of optimizing an expression is to *run* it. This module
4
+ turns an :class:`Expr` tree into a snippet of C, Rust or JavaScript, so
5
+ the optimized form can be pasted straight into a program.
6
+
7
+ >>> from logic_loom import optimize, to_code
8
+ >>> e = optimize("a*x*x + b*x + c").optimized
9
+ >>> print(to_code(e, "c"))
10
+ x * (a * x + b) + c
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from .expr import Expr
16
+
17
+ # Per-language rendering of operators and functions.
18
_LANGS = {
    "c": {
        "pow": lambda a, b: f"pow({a}, {b})",
        "funcs": {"sin": "sin", "cos": "cos", "tan": "tan",
                  "exp": "exp", "log": "log", "sqrt": "sqrt"},
    },
    "rust": {
        "pow": lambda a, b: f"({a}).powf({b})",
        "funcs": {"sin": "{0}.sin()", "cos": "{0}.cos()", "tan": "{0}.tan()",
                  "exp": "{0}.exp()", "log": "{0}.ln()", "sqrt": "{0}.sqrt()"},
    },
    "js": {
        "pow": lambda a, b: f"Math.pow({a}, {b})",
        "funcs": {"sin": "Math.sin", "cos": "Math.cos", "tan": "Math.tan",
                  "exp": "Math.exp", "log": "Math.log", "sqrt": "Math.sqrt"},
    },
}

_INFIX = {"+": "+", "-": "-", "*": "*", "/": "/"}
_PREC = {"+": 1, "-": 1, "*": 2, "/": 2, "neg": 3, "^": 4}


def to_code(e: Expr, lang: str = "c") -> str:
    """Render ``e`` as an expression in ``lang`` (``"c"``, ``"rust"`` or ``"js"``).

    ``lang`` is matched case-insensitively.

    Raises
    ------
    ValueError
        If ``lang`` is not a supported language, or if ``e`` contains a
        pattern variable.
    """
    lang = lang.lower()
    if lang not in _LANGS:
        raise ValueError(f"unsupported language {lang!r}; choose from {list(_LANGS)}")
    return _emit(e, _LANGS[lang], 0)


def _as_receiver(node: Expr, text: str) -> str:
    """Parenthesise ``text`` when it is unsafe as the receiver of ``.method()``.

    A method call binds tighter than unary minus and every infix operator, so
    ``a + b`` or ``-x`` pasted in front of ``.sin()`` would change meaning.
    """
    compound = node.kind == "app" and (node.name == "neg" or node.name in _INFIX)
    if compound or text.startswith("-"):
        return f"({text})"
    return text


def _emit(e: Expr, spec, parent_prec: int) -> str:
    """Render ``e`` using the language table ``spec``.

    ``parent_prec`` is the precedence required by the surrounding context;
    an infix expression whose own precedence is lower is wrapped in
    parentheses.

    Raises
    ------
    ValueError
        If a pattern variable is encountered.
    """
    if e.kind == "num":
        v = e.value
        # Print integral floats without the trailing ".0".
        if isinstance(v, float) and v.is_integer():
            v = int(v)
        return str(v)
    if e.kind == "var":
        return e.name
    if e.kind == "patvar":
        raise ValueError("cannot generate code from a pattern variable")

    op = e.name
    if op == "neg":
        operand = _emit(e.args[0], spec, _PREC["neg"])
        # "--x" would be read as a decrement in C and JavaScript, so a
        # leading minus in the operand (nested negation, negative literal)
        # must be kept apart with parentheses.
        if operand.startswith("-"):
            operand = f"({operand})"
        return f"-{operand}"

    if op == "^":
        a = _emit(e.args[0], spec, 0)
        b = _emit(e.args[1], spec, 0)
        return spec["pow"](a, b)

    if op in _INFIX:
        prec = _PREC[op]
        left = _emit(e.args[0], spec, prec)
        # prec + 1 on the right forces parentheses around an operand of equal
        # precedence, e.g. a - (b - c).
        right = _emit(e.args[1], spec, prec + 1)
        s = f"{left} {_INFIX[op]} {right}"
        return f"({s})" if prec < parent_prec else s

    # function call
    funcs = spec["funcs"]
    if op in funcs:
        tmpl = funcs[op]
        if "{0}" in tmpl:  # method style (Rust)
            args = [_as_receiver(a, _emit(a, spec, 0)) for a in e.args]
            return tmpl.format(*args)
        args = [_emit(a, spec, 0) for a in e.args]
        return f"{tmpl}({', '.join(args)})"  # call style (C / JS)

    # unknown function: emit verbatim
    args = ", ".join(_emit(a, spec, 0) for a in e.args)
    return f"{op}({args})"
87
+
88
+
89
+ # --------------------------------------------------------------------- #
90
+ # LLVM IR transpiler
91
+ # --------------------------------------------------------------------- #
92
+ # This lets Logic-Loom plug into a real toolchain: emit the optimized
93
+ # expression as an LLVM IR function that clang/opt can compile, inline,
94
+ # and vectorize alongside the rest of a C/C++/Rust program.
95
+ _LLVM_BINOP = {"+": "fadd", "-": "fsub", "*": "fmul", "/": "fdiv"}
96
+ _LLVM_INTRINSIC = {
97
+ "sin": "@llvm.sin.f64", "cos": "@llvm.cos.f64",
98
+ "exp": "@llvm.exp.f64", "log": "@llvm.log.f64",
99
+ "sqrt": "@llvm.sqrt.f64", "^": "@llvm.pow.f64",
100
+ }
101
+
102
+
103
def _llvm_const(v) -> str:
    """Render ``v`` as an LLVM IR ``double`` constant without losing precision.

    The short ``%e`` decimal form is used whenever it parses back to exactly
    the same double. Any other value, including infinities and NaN, is
    written as the hexadecimal bit pattern of the IEEE-754 double.
    """
    import math
    import struct

    x = float(v)
    text = f"{x:e}"
    # "%e" keeps only 7 significant digits, so accept it only if it
    # round-trips to the same double.
    if math.isfinite(x) and float(text) == x:
        return text
    return "0x" + struct.pack(">d", x).hex().upper()
105
+
106
+
107
def free_vars(e: Expr):
    """Return the distinct variable names occurring in ``e``, sorted."""
    names = set()
    todo = [e]
    while todo:
        node = todo.pop()
        if node.kind == "var":
            names.add(node.name)
        todo.extend(node.args)
    return sorted(names)
119
+
120
+
121
def to_llvm(e: Expr, name: str = "f") -> str:
    """Render ``e`` as an LLVM IR function ``double @name(double, ...)``.

    The parameters are the free variables of ``e`` in sorted order. Every
    intrinsic or external function that is called gets a ``declare`` line
    ahead of the definition. An external function is declared with the
    number of arguments of its first call in ``e``.

    Raises
    ------
    ValueError
        If ``e`` contains a pattern variable.
    """
    params = free_vars(e)
    body: list[str] = []
    # Callee operator -> number of arguments at its first call site.
    calls: dict[str, int] = {}
    # Parameter names are registers too; temporaries must not reuse them.
    reserved = set(params)
    counter = 0

    def fresh() -> str:
        """Return the next unused ``%tN`` temporary."""
        nonlocal counter
        while True:
            counter += 1
            candidate = f"t{counter}"
            if candidate not in reserved:
                return f"%{candidate}"

    def emit(node: Expr) -> str:
        """Append the instructions for ``node`` and return its operand text."""
        if node.kind == "num":
            return _llvm_const(node.value)
        if node.kind == "var":
            return f"%{node.name}"
        if node.kind == "patvar":
            raise ValueError("cannot generate IR from a pattern variable")

        op = node.name
        if op == "neg":
            x = emit(node.args[0])
            r = fresh()
            body.append(f" {r} = fneg double {x}")
            return r
        if op in _LLVM_BINOP:
            a = emit(node.args[0])
            b = emit(node.args[1])
            r = fresh()
            body.append(f" {r} = {_LLVM_BINOP[op]} double {a}, {b}")
            return r

        # Intrinsic when one is known, otherwise an external function of the
        # same name.
        args = [emit(a) for a in node.args]
        callee = _LLVM_INTRINSIC.get(op, f"@{op}")
        calls.setdefault(op, len(args))
        r = fresh()
        joined = ", ".join(f"double {a}" for a in args)
        body.append(f" {r} = call double {callee}({joined})")
        return r

    ret = emit(e)
    sig = ", ".join(f"double %{p}" for p in params)

    decls = []
    for op in sorted(calls):
        if op in _LLVM_INTRINSIC:
            fn = _LLVM_INTRINSIC[op]
            arity = 2 if op == "^" else 1
        else:
            fn = f"@{op}"
            # Declare externals with the arity they are actually called with.
            arity = calls[op]
        decls.append(f"declare double {fn}({', '.join(['double'] * arity)})")

    lines = list(decls)
    if decls:
        lines.append("")
    lines.append(f"define double @{name}({sig}) {{")
    lines.append("entry:")
    lines.extend(body)
    lines.append(f" ret double {ret}")
    lines.append("}")
    return "\n".join(lines)
logic_loom/compiler.py ADDED
@@ -0,0 +1,135 @@
1
+ """High-level API: turn a math string into its cheapest equivalent form."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Iterable, List, Optional
7
+
8
+ from .analysis import analyze, reachable_rules
9
+ from .cost import DEFAULT_MODEL, CostModel, expr_cost, extract, get_profile
10
+ from .egraph import EGraph
11
+ from .expr import Expr
12
+ from .parser import parse
13
+ from .rules import DEFAULT_RULES, Rule
14
+ from .saturate import BackoffScheduler, SaturationReport, saturate
15
+
16
+
17
@dataclass
class Result:
    """Outcome of one :func:`optimize` call."""

    source: str                  # the input text as given
    original: Expr               # parsed form of ``source``
    optimized: Expr              # extracted form
    original_cost: float         # cost of ``original`` under ``model``
    optimized_cost: float        # cost of ``optimized`` under ``model``
    report: SaturationReport     # report returned by saturation
    model: CostModel = DEFAULT_MODEL
    assumptions: List[str] = field(default_factory=list)

    @property
    def improved(self) -> bool:
        """True when the optimized cost is lower by more than ``1e-9``."""
        return self.original_cost - 1e-9 > self.optimized_cost

    @property
    def speedup(self) -> float:
        """Cost ratio original/optimized; infinite for a non-positive optimized cost."""
        return (
            float("inf")
            if self.optimized_cost <= 0
            else self.original_cost / self.optimized_cost
        )

    def __str__(self) -> str:
        """Two-line report: the rewrite itself, then the cost change."""
        arrow = "=>" if self.improved else "=="
        headline = f"{self.original} {arrow} {self.optimized}"
        costs = (
            f" cost {self.original_cost:.1f} -> {self.optimized_cost:.1f}"
            f" ({self.speedup:.2f}x)"
        )
        return "\n".join((headline, costs))
45
+
46
+
47
def _resolve_model(model, profile) -> CostModel:
    """Choose the cost model: named profile first, then ``model``, then the default."""
    if profile is None:
        return DEFAULT_MODEL if model is None else model
    return get_profile(profile)
53
+
54
+
55
def _prepare(source, rules, auto, node_limit):
    """Parse ``source`` and settle the rule list, node limit and scheduler.

    Returns ``(original, rules, node_limit, scheduler)``. With ``auto`` the
    rules are pruned by ``reachable_rules``, and the node limit (unless given
    explicitly) and the scheduler's match limit come from ``analyze``.
    Without it the node limit falls back to 5000 and the scheduler is
    ``None``.
    """
    original = parse(source)
    rules = list(DEFAULT_RULES) if rules is None else list(rules)

    if not auto:
        limit = 5_000 if node_limit is None else node_limit
        return original, rules, limit, None

    rules = reachable_rules(original, rules)
    sizing = analyze(original)
    limit = sizing.node_limit if node_limit is None else node_limit
    scheduler = BackoffScheduler(match_limit=sizing.match_limit)
    return original, rules, limit, scheduler
68
+
69
+
70
def build_egraph(
    source: str,
    *,
    rules: Optional[List[Rule]] = None,
    auto: bool = True,
    max_iters: int = 30,
    node_limit: Optional[int] = None,
    impure: Optional[Iterable[str]] = None,
):
    """Saturate an e-graph seeded with ``source``.

    Returns the tuple ``(egraph, root_id, report)``, where ``root_id`` is the
    value returned by adding the parsed expression to the graph and
    ``report`` is whatever ``saturate`` returned.
    """
    original, active_rules, limit, scheduler = _prepare(
        source, rules, auto, node_limit)
    impure_names = set(impure or ())

    graph = EGraph()
    root_id = graph.add_expr(original)
    report = saturate(
        graph,
        active_rules,
        max_iters=max_iters,
        node_limit=limit,
        scheduler=scheduler,
        impure=impure_names,
    )
    return graph, root_id, report
86
+
87
+
88
def optimize(
    source: str,
    *,
    rules: Optional[List[Rule]] = None,
    model: Optional[CostModel] = None,
    profile: Optional[str] = None,
    impure: Optional[Iterable[str]] = None,
    auto: bool = True,
    max_iters: int = 30,
    node_limit: Optional[int] = None,
) -> Result:
    """Parse, saturate and extract the cheapest equivalent of ``source``.

    Parameters
    ----------
    rules : rewrite rules to use (defaults to ``DEFAULT_RULES``).
    model : a :class:`~logic_loom.cost.CostModel` for extraction.
    profile : name of a built-in cost profile ("x86", "arm", "gpu", ...);
        overrides ``model`` when given.
    impure : names of side-effecting functions; rewrites that would
        duplicate, drop, or reorder their calls are forbidden.
    auto : enable static rule pruning and automatic limit sizing.
    max_iters : iteration cap passed to saturation.
    node_limit : node budget passed to saturation; when ``None`` a limit is
        chosen automatically.

    Returns
    -------
    Result
        The original and optimized expressions with their costs, the
        saturation report, the cost model used, and the sorted, de-duplicated
        assumptions of the rules that fired.
    """
    cost_model = _resolve_model(model, profile)
    original, used_rules, nl, scheduler = _prepare(source, rules, auto, node_limit)

    eg = EGraph()
    root = eg.add_expr(original)
    report = saturate(eg, used_rules, max_iters=max_iters, node_limit=nl,
                      scheduler=scheduler, impure=set(impure or ()))
    optimized, opt_cost = extract(eg, root, cost_model)

    by_name = {r.name: r for r in used_rules}
    # Fired names with no matching rule contribute nothing; filtering them
    # out avoids building a placeholder Rule on every lookup.
    # NOTE(review): assumes a Rule built without explicit assumptions has an
    # empty ``assumes`` -- confirm against rules.py.
    assumptions = sorted({
        a
        for name in report.fired
        if name in by_name
        for a in by_name[name].assumes
    })

    return Result(
        source=source,
        original=original,
        optimized=optimized,
        original_cost=expr_cost(original, cost_model),
        optimized_cost=opt_cost,
        report=report,
        model=cost_model,
        assumptions=assumptions,
    )