gradient-script 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +80 -3
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +332 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Rewrite Engine for E-Graph Equality Saturation
|
|
3
|
+
*
|
|
4
|
+
* Applies rewrite rules until saturation (no new merges), the iteration limit, or the e-graph size limit is reached.
|
|
5
|
+
*/
|
|
6
|
+
import { matchPattern, instantiatePattern } from './Pattern.js';
|
|
7
|
+
/**
 * Apply equality saturation to an e-graph.
 *
 * Each round collects up to `maxMatchesPerIter` rule matches, instantiates every
 * match's RHS, merges it with the matched e-class, and rebuilds the e-graph.
 * The loop stops when:
 * - a round finds no matches, or a complete (non-truncated) round of matches
 *   produces no merges -> `saturated: true`;
 * - a truncated round (the match budget was hit) produces no merges -> stop with
 *   `saturated: false`, because matches beyond the budget were never examined;
 * - `maxIterations` rounds have run;
 * - `egraph.size` exceeds `maxClassSize` at the start of a round.
 *
 * @param {EGraph} egraph - e-graph rewritten in place.
 * @param {Rule[]} rules - rewrite rules with `lhs` / `rhs` patterns.
 * @param {object} [options]
 * @param {number} [options.maxIterations=10] - maximum number of rounds.
 * @param {number} [options.maxClassSize=5000] - stop once `egraph.size` exceeds this.
 * @param {number} [options.maxMatchesPerIter=5000] - match budget per round.
 * @param {boolean} [options.verbose=false] - log progress with console.log.
 * @returns {{iterations: number, totalMatches: number, merges: number, saturated: boolean, classCount: number}}
 *   `iterations` counts rounds started, including one ended by the size check.
 */
export function saturate(egraph, rules, options = {}) {
    const { maxIterations = 10, maxClassSize = 5000, maxMatchesPerIter = 5000, verbose = false } = options;
    const stats = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: false,
        classCount: egraph.size
    };
    for (let iter = 0; iter < maxIterations; iter++) {
        stats.iterations = iter + 1;
        // Check size limit
        if (egraph.size > maxClassSize) {
            if (verbose) {
                console.log(`[saturate] Size limit exceeded: ${egraph.size} > ${maxClassSize}`);
            }
            break;
        }
        // Collect all rule matches (with limit)
        const matches = collectMatches(egraph, rules, maxMatchesPerIter);
        stats.totalMatches += matches.length;
        if (matches.length === 0) {
            stats.saturated = true;
            if (verbose) {
                console.log(`[saturate] Saturated after ${iter + 1} iterations`);
            }
            break;
        }
        // collectMatches returns early once the budget is reached, so a full
        // list may omit matches that were never examined.
        const truncated = matches.length >= maxMatchesPerIter;
        // Apply all matches
        let mergesThisIter = 0;
        for (const { rule, classId, subst } of matches) {
            const rhsId = instantiatePattern(egraph, rule.rhs, subst);
            // Canonicalize both sides: earlier merges in this round may have
            // changed the representatives.
            const lhsCanon = egraph.find(classId);
            const rhsCanon = egraph.find(rhsId);
            if (lhsCanon !== rhsCanon) {
                egraph.merge(lhsCanon, rhsCanon);
                mergesThisIter++;
            }
        }
        stats.merges += mergesThisIter;
        // Rebuild to restore invariants
        egraph.rebuild();
        if (verbose) {
            console.log(`[saturate] Iter ${iter + 1}: ${matches.length} matches, ${mergesThisIter} merges, ${egraph.size} classes`);
        }
        if (mergesThisIter === 0) {
            // No merges this round. Only an exhaustive match set proves saturation;
            // a truncated one would presumably be collected again unchanged next
            // round, so stop without claiming saturation.
            stats.saturated = !truncated;
            if (verbose && truncated) {
                console.log(`[saturate] Stopped: match limit ${maxMatchesPerIter} reached with no merges (not saturated)`);
            }
            break;
        }
    }
    stats.classCount = egraph.size;
    return stats;
}
|
|
69
|
+
/**
 * Gather every (rule, classId, subst) triple where a rule's LHS matches.
 *
 * Classes are walked in `egraph.getClassIds()` order and, within a class, rules
 * in array order. Gathering stops as soon as `maxMatches` triples are held, so
 * under a tight budget the earliest classes and rules take precedence.
 *
 * @param {EGraph} egraph
 * @param {Rule[]} rules
 * @param {number} maxMatches - budget; checked after each push.
 * @returns {{rule: Rule, classId: EClassId, subst: Substitution}[]}
 */
function collectMatches(egraph, rules, maxMatches) {
    const found = [];
    for (const classId of egraph.getClassIds()) {
        for (const rule of rules) {
            for (const subst of matchPattern(egraph, rule.lhs, classId)) {
                found.push({ rule, classId, subst });
                // Bail out early to bound the work done in one saturation round.
                if (found.length >= maxMatches) {
                    return found;
                }
            }
        }
    }
    return found;
}
|
|
88
|
+
/**
 * Apply a single rule once: collect its matches, merge each instantiated RHS
 * with the matched class, then rebuild the e-graph.
 *
 * @param {EGraph} egraph - e-graph rewritten in place.
 * @param {Rule} rule - rule whose `lhs` is searched and `rhs` instantiated.
 * @param {number} [maxMatches=5000] - cap on collected matches (previously
 *   hard-coded; the default keeps the old behavior).
 * @returns {number} number of merges performed.
 */
export function applyRuleOnce(egraph, rule, maxMatches = 5000) {
    const matches = collectMatches(egraph, [rule], maxMatches);
    let merges = 0;
    for (const { classId, subst } of matches) {
        const rhsId = instantiatePattern(egraph, rule.rhs, subst);
        // Canonicalize both sides: earlier merges in this loop may have changed
        // the representatives.
        const lhsCanon = egraph.find(classId);
        const rhsCanon = egraph.find(rhsId);
        if (lhsCanon !== rhsCanon) {
            egraph.merge(lhsCanon, rhsCanon);
            merges++;
        }
    }
    egraph.rebuild();
    return merges;
}
|
|
106
|
+
/**
 * Run `saturate` once per rule set in `phases`, in order, on the same e-graph.
 * E.g., apply core rules first, then algebra rules.
 *
 * `options` is forwarded unchanged to every phase. Iteration, match and merge
 * counts are summed across phases; `saturated` is true only if every phase
 * reported saturation (and true when `phases` is empty).
 *
 * @param {EGraph} egraph
 * @param {Rule[][]} phases
 * @param {object} [options] - same options as `saturate`; `verbose` also logs phase headers.
 * @returns {{iterations: number, totalMatches: number, merges: number, saturated: boolean, classCount: number}}
 */
export function saturatePhased(egraph, phases, options = {}) {
    const totals = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: true,
        classCount: egraph.size
    };
    for (const [index, rules] of phases.entries()) {
        if (options.verbose) {
            console.log(`[saturatePhased] Phase ${index + 1}/${phases.length}: ${rules.length} rules`);
        }
        const { iterations, totalMatches, merges, saturated } = saturate(egraph, rules, options);
        totals.iterations += iterations;
        totals.totalMatches += totalMatches;
        totals.merges += merges;
        totals.saturated = totals.saturated && saturated;
    }
    totals.classCount = egraph.size;
    return totals;
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Algebraic Rewrite Rules for E-Graph Optimization
|
|
3
|
+
*
|
|
4
|
+
* Rules are organized by category:
|
|
5
|
+
* - Core: Essential identities (commutativity, identity, etc.); division rules such as a/a = 1 assume a nonzero divisor
|
|
6
|
+
* - Algebra: More powerful, can cause expansion (distribution, factoring)
|
|
7
|
+
* - Functions: Domain-specific (sqrt, trig, exp/log)
|
|
8
|
+
*
|
|
9
|
+
* Based on rules from Herbie (https://github.com/herbie-fp/herbie)
|
|
10
|
+
* and egg (https://egraphs-good.github.io/)
|
|
11
|
+
*/
|
|
12
|
+
import { Pattern } from './Pattern.js';
|
|
13
|
+
/**
 * A rewrite rule: wherever `lhs` matches an e-class, the instantiated `rhs`
 * is merged into that class (the two are treated as equivalent).
 */
export interface Rule {
    /** Rule identifier; `biRule` appends "-l" / "-r" to a shared name. */
    name: string;
    /** Pattern searched for in the e-graph. */
    lhs: Pattern;
    /** Pattern instantiated with each match's substitution. */
    rhs: Pattern;
}
/**
 * Create a rule from pattern strings (both parsed with `parsePattern`).
 */
export declare function rule(name: string, lhs: string, rhs: string): Rule;
/**
 * Create bidirectional rules: `${name}-l` rewrites a to b and `${name}-r` rewrites b to a.
 */
export declare function biRule(name: string, a: string, b: string): Rule[];
/** Essential identities: commutativity, associativity, identity, zero, inverse, negation. */
export declare const coreRules: Rule[];
/** Distribution, negation propagation, reciprocal and power rewrites; can grow the e-graph. */
export declare const algebraRules: Rule[];
/** Rules for sqrt, abs, trig and exp/log. */
export declare const functionRules: Rule[];
/**
 * All rules combined: core, then algebra, then function rules.
 */
export declare const allRules: Rule[];
/**
 * Minimal rules for canonicalization only
 * (no expansion, just normalization)
 */
export declare const canonRules: Rule[];
/**
 * Get rules by category.
 * Returns a new array ordered core, algebra, function regardless of the order
 * of `categories`; duplicate entries have no effect. Readonly arrays are accepted.
 */
export declare function getRuleSet(categories: readonly ('core' | 'algebra' | 'function')[]): Rule[];
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Algebraic Rewrite Rules for E-Graph Optimization
|
|
3
|
+
*
|
|
4
|
+
* Rules are organized by category:
|
|
5
|
+
* - Core: Essential identities (commutativity, identity, etc.); division rules such as a/a = 1 assume a nonzero divisor
|
|
6
|
+
* - Algebra: More powerful, can cause expansion (distribution, factoring)
|
|
7
|
+
* - Functions: Domain-specific (sqrt, trig, exp/log)
|
|
8
|
+
*
|
|
9
|
+
* Based on rules from Herbie (https://github.com/herbie-fp/herbie)
|
|
10
|
+
* and egg (https://egraphs-good.github.io/)
|
|
11
|
+
*/
|
|
12
|
+
import { parsePattern } from './Pattern.js';
|
|
13
|
+
/**
 * Build a rewrite rule from two s-expression pattern strings.
 *
 * @param {string} name - rule identifier.
 * @param {string} lhs - pattern to search for (parsed first).
 * @param {string} rhs - pattern asserted equivalent to each LHS match.
 * @returns {{name: string, lhs: Pattern, rhs: Pattern}}
 */
export function rule(name, lhs, rhs) {
    const lhsPattern = parsePattern(lhs);
    const rhsPattern = parsePattern(rhs);
    return { name, lhs: lhsPattern, rhs: rhsPattern };
}
|
|
23
|
+
/**
 * Build a rule pair covering both rewrite directions.
 *
 * @param {string} name - shared base name; suffixed "-l" (a -> b) and "-r" (b -> a).
 * @param {string} a - first pattern string.
 * @param {string} b - second pattern string.
 * @returns {Rule[]} `[a -> b, b -> a]`.
 */
export function biRule(name, a, b) {
    const forward = rule(`${name}-l`, a, b);
    const backward = rule(`${name}-r`, b, a);
    return [forward, backward];
}
|
|
32
|
+
// =============================================================================
// CORE RULES - essential identities: commutativity, associativity, identity,
// zero, inverse and negation. `0-div` and `div-self` assume a nonzero divisor
// (in IEEE arithmetic 0/0 is NaN).
// =============================================================================
export const coreRules = [
    // === Commutativity ===
    ['comm-add', '(+ ?a ?b)', '(+ ?b ?a)'],
    ['comm-mul', '(* ?a ?b)', '(* ?b ?a)'],
    // === Associativity: both directions, named like a biRule "-l"/"-r" pair ===
    ['assoc-add-l', '(+ (+ ?a ?b) ?c)', '(+ ?a (+ ?b ?c))'],
    ['assoc-add-r', '(+ ?a (+ ?b ?c))', '(+ (+ ?a ?b) ?c)'],
    ['assoc-mul-l', '(* (* ?a ?b) ?c)', '(* ?a (* ?b ?c))'],
    ['assoc-mul-r', '(* ?a (* ?b ?c))', '(* (* ?a ?b) ?c)'],
    // === Additive / subtractive identity ===
    ['add-0-l', '(+ 0 ?a)', '?a'],
    ['add-0-r', '(+ ?a 0)', '?a'],
    ['sub-0', '(- ?a 0)', '?a'],
    ['0-sub', '(- 0 ?a)', '(neg ?a)'],
    // === Multiplicative, division and power identity ===
    ['mul-1-l', '(* 1 ?a)', '?a'],
    ['mul-1-r', '(* ?a 1)', '?a'],
    ['div-1', '(/ ?a 1)', '?a'],
    ['pow-0', '(^ ?a 0)', '1'],
    ['pow-1', '(^ ?a 1)', '?a'],
    // === Annihilation by zero ===
    ['mul-0-l', '(* 0 ?a)', '0'],
    ['mul-0-r', '(* ?a 0)', '0'],
    ['0-div', '(/ 0 ?a)', '0'],
    // === Self-inverse ===
    ['sub-self', '(- ?a ?a)', '0'],
    ['div-self', '(/ ?a ?a)', '1'],
    // === Negation ===
    ['neg-neg', '(neg (neg ?a))', '?a'],
    ['neg-mul-neg', '(* (neg ?a) (neg ?b))', '(* ?a ?b)'],
    ['mul-neg-1', '(* -1 ?a)', '(neg ?a)'],
    ['neg-to-mul', '(neg ?a)', '(* -1 ?a)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
|
|
72
|
+
// =============================================================================
// ALGEBRA RULES - more powerful rewrites that can grow the e-graph
// (distribution, negation propagation, reciprocals, squares). The reciprocal
// cancellations assume the cancelled operand is nonzero.
// =============================================================================
export const algebraRules = [
    // === Distribution: both directions (expand and factor) ===
    ['dist-mul-add-l', '(* ?a (+ ?b ?c))', '(+ (* ?a ?b) (* ?a ?c))'],
    ['dist-mul-add-r', '(+ (* ?a ?b) (* ?a ?c))', '(* ?a (+ ?b ?c))'],
    ['dist-mul-sub-l', '(* ?a (- ?b ?c))', '(- (* ?a ?b) (* ?a ?c))'],
    ['dist-mul-sub-r', '(- (* ?a ?b) (* ?a ?c))', '(* ?a (- ?b ?c))'],
    // === Negation propagation ===
    ['neg-add', '(neg (+ ?a ?b))', '(+ (neg ?a) (neg ?b))'],
    ['neg-sub', '(neg (- ?a ?b))', '(- ?b ?a)'],
    ['neg-mul-l', '(* (neg ?a) ?b)', '(neg (* ?a ?b))'],
    ['neg-mul-r', '(* ?a (neg ?b))', '(neg (* ?a ?b))'],
    ['neg-div-l', '(/ (neg ?a) ?b)', '(neg (/ ?a ?b))'],
    ['neg-div-r', '(/ ?a (neg ?b))', '(neg (/ ?a ?b))'],
    // === Subtraction <-> addition of a negation ===
    ['sub-to-add', '(- ?a ?b)', '(+ ?a (neg ?b))'],
    ['add-neg-to-sub', '(+ ?a (neg ?b))', '(- ?a ?b)'],
    // === Division as multiplication by the reciprocal (one direction only) ===
    ['div-to-mul-inv', '(/ ?a ?b)', '(* ?a (inv ?b))'],
    // === Reciprocal identities ===
    ['inv-1', '(inv 1)', '1'],
    ['inv-inv', '(inv (inv ?a))', '?a'],
    ['div-1-to-inv', '(/ 1 ?a)', '(inv ?a)'],
    ['mul-inv-cancel', '(* ?a (inv ?a))', '1'], // a * (1/a) = 1
    ['inv-mul', '(inv (* ?a ?b))', '(* (inv ?a) (inv ?b))'], // 1/(a*b) = (1/a)*(1/b)
    ['div-self-sq', '(/ ?a (* ?a ?a))', '(inv ?a)'], // a / (a*a) = 1/a
    ['div-sq-self', '(/ (* ?a ?a) ?a)', '?a'], // (a*a) / a = a
    ['div-self-pow2', '(/ ?a (^ ?a 2))', '(inv ?a)'], // a / a^2 = 1/a
    ['div-pow2-self', '(/ (^ ?a 2) ?a)', '?a'], // a^2 / a = a
    // === Squares: a^2 <-> a*a ===
    ['pow-2', '(^ ?a 2)', '(* ?a ?a)'],
    ['sq-to-pow', '(* ?a ?a)', '(^ ?a 2)'],
    // === Combining like terms ===
    ['add-same', '(+ ?a ?a)', '(* 2 ?a)'],
    ['sub-neg-same', '(- ?a (neg ?a))', '(* 2 ?a)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
|
|
108
|
+
// =============================================================================
// FUNCTION RULES - Domain-specific
// =============================================================================
// A merge makes both sides of a rule interchangeable during extraction, so a
// rule must not turn an expression that is defined for some real input into one
// that evaluates to NaN there. The following identities only hold on a
// restricted domain and are therefore NOT included:
//   sqrt(a*b) = sqrt(a)*sqrt(b)  fails for a, b < 0; together with sqrt-sq it
//                                equated sqrt(x*x) with x, dropping the sign
//   sqrt(a/b) = sqrt(a)/sqrt(b)  fails for a, b < 0
//   log(a*b)  = log(a) + log(b)  fails for a, b < 0
//   log(a/b)  = log(a) - log(b)  fails for a, b < 0
//   log(a^b)  = b*log(a)         fails for a < 0, e.g. log((-2)^2)
// (A sound replacement for the first would be sqrt(a*a) = abs(a), provided the
// extractor and code generator handle `abs` nodes.)
export const functionRules = [
    // === Sqrt ===
    rule('sqrt-sq', '(* (sqrt ?a) (sqrt ?a))', '?a'),
    rule('sqrt-1', '(sqrt 1)', '1'),
    rule('sqrt-0', '(sqrt 0)', '0'),
    // === Abs ===
    rule('abs-neg', '(abs (neg ?a))', '(abs ?a)'),
    rule('abs-abs', '(abs (abs ?a))', '(abs ?a)'),
    rule('abs-sq', '(abs (* ?a ?a))', '(* ?a ?a)'),
    // === Trig ===
    rule('sin-0', '(sin 0)', '0'),
    rule('cos-0', '(cos 0)', '1'),
    rule('sin-neg', '(sin (neg ?a))', '(neg (sin ?a))'),
    rule('cos-neg', '(cos (neg ?a))', '(cos ?a)'),
    // === Exp/Log ===
    rule('exp-0', '(exp 0)', '1'),
    rule('log-1', '(log 1)', '0'),
    rule('exp-log', '(exp (log ?a))', '?a'),
    rule('log-exp', '(log (exp ?a))', '?a'),
];
|
|
136
|
+
// =============================================================================
// RULE SETS
// =============================================================================
/**
 * All rules combined, in category order: core, algebra, function.
 * This is a new array; the category arrays themselves are left untouched.
 */
export const allRules = [coreRules, algebraRules, functionRules].flat();
|
|
147
|
+
/**
 * Minimal rules for canonicalization only (no expansion, just normalization):
 * commutativity, identity removal, zero annihilation, self-inverse and double
 * negation. Each call builds fresh rule objects, independent of `coreRules`.
 */
export const canonRules = [
    // Commutativity (for canonical ordering)
    ['comm-add', '(+ ?a ?b)', '(+ ?b ?a)'],
    ['comm-mul', '(* ?a ?b)', '(* ?b ?a)'],
    // Identity removal
    ['add-0-l', '(+ 0 ?a)', '?a'],
    ['add-0-r', '(+ ?a 0)', '?a'],
    ['mul-1-l', '(* 1 ?a)', '?a'],
    ['mul-1-r', '(* ?a 1)', '?a'],
    ['sub-0', '(- ?a 0)', '?a'],
    ['div-1', '(/ ?a 1)', '?a'],
    ['pow-1', '(^ ?a 1)', '?a'],
    // Zero
    ['mul-0-l', '(* 0 ?a)', '0'],
    ['mul-0-r', '(* ?a 0)', '0'],
    ['0-div', '(/ 0 ?a)', '0'],
    ['pow-0', '(^ ?a 0)', '1'],
    // Self-inverse
    ['sub-self', '(- ?a ?a)', '0'],
    ['div-self', '(/ ?a ?a)', '1'],
    // Double negation
    ['neg-neg', '(neg (neg ?a))', '?a'],
    ['neg-mul-neg', '(* (neg ?a) (neg ?b))', '(* ?a ?b)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
|
|
175
|
+
/**
 * Get rules by category.
 *
 * @param {Array<'core'|'algebra'|'function'>} categories - categories to include;
 *   membership is tested with `includes`, so order and duplicates do not matter.
 * @returns {Rule[]} a new array in fixed order: core, algebra, then function rules.
 */
export function getRuleSet(categories) {
    const catalog = [
        ['core', coreRules],
        ['algebra', algebraRules],
        ['function', functionRules],
    ];
    return catalog
        .filter(([category]) => categories.includes(category))
        .flatMap(([, ruleList]) => ruleList);
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph Module for GradientScript
|
|
3
|
+
*
|
|
4
|
+
* Provides equality saturation-based optimization for gradient expressions.
|
|
5
|
+
* The former standalone CSE module (dsl/CSE.js) is removed in this release.
|
|
6
|
+
*/
|
|
7
|
+
export { EGraph } from './EGraph.js';
|
|
8
|
+
export type { EClass } from './EGraph.js';
|
|
9
|
+
export { EClassId, ENode, enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
|
|
10
|
+
export { Pattern, Substitution, parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
|
|
11
|
+
export { Rule, rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
|
|
12
|
+
export { saturate, saturatePhased, applyRuleOnce, SaturationStats, SaturationOptions } from './Rewriter.js';
|
|
13
|
+
export { extractBest, extractWithCSE, ExtractionResult, CostModel, defaultCostModel } from './Extractor.js';
|
|
14
|
+
export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
|
|
15
|
+
export { optimizeWithEGraph, EGraphOptimizeResult } from './Optimizer.js';
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph Module for GradientScript
|
|
3
|
+
*
|
|
4
|
+
* Provides equality saturation-based optimization for gradient expressions.
|
|
5
|
+
* The former standalone CSE module (dsl/CSE.js) is removed in this release.
|
|
6
|
+
*/
|
|
7
|
+
// Core e-graph
|
|
8
|
+
export { EGraph } from './EGraph.js';
|
|
9
|
+
export { enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
|
|
10
|
+
// Pattern matching
|
|
11
|
+
export { parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
|
|
12
|
+
// Rewrite rules
|
|
13
|
+
export { rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
|
|
14
|
+
// Saturation
|
|
15
|
+
export { saturate, saturatePhased, applyRuleOnce } from './Rewriter.js';
|
|
16
|
+
// Extraction
|
|
17
|
+
export { extractBest, extractWithCSE, defaultCostModel } from './Extractor.js';
|
|
18
|
+
// AST conversion
|
|
19
|
+
export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
|
|
20
|
+
// Main optimizer function
|
|
21
|
+
export { optimizeWithEGraph } from './Optimizer.js';
|
package/package.json
CHANGED
package/dist/dsl/CSE.d.ts
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Common Subexpression Elimination (CSE)
|
|
3
|
-
* Identifies repeated expressions and factors them out
|
|
4
|
-
*/
|
|
5
|
-
import { Expression } from './AST.js';
|
|
6
|
-
/**
 * Result of CSE on a single expression.
 */
export interface CSEResult {
    /** Extracted subexpressions keyed by generated temporary name (`_tmp0`, `_tmp1`, ...). */
    intermediates: Map<string, Expression>;
    /** The input expression with extracted subexpressions replaced by variable references. */
    simplified: Expression;
}
/**
 * Result of CSE across the components of a structured value.
 */
export interface StructuredCSEResult {
    /** Temporaries shared by all components, keyed by generated temporary name. */
    intermediates: Map<string, Expression>;
    /** Each input component, keyed as in the input, rewritten to reference the temporaries. */
    components: Map<string, Expression>;
}
/**
 * Perform CSE on an expression.
 * @param expr - expression to simplify.
 * @param minCount - occurrences required before a subexpression is extracted (default 2).
 */
export declare function eliminateCommonSubexpressions(expr: Expression, minCount?: number): CSEResult;
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 * Occurrences are counted across all components together, so a subexpression
 * repeated between components can be extracted once.
 * @param components - component expressions keyed by component name.
 * @param minCount - occurrences required before a subexpression is extracted (default 2).
 */
export declare function eliminateCommonSubexpressionsStructured(components: Map<string, Expression>, minCount?: number): StructuredCSEResult;
|
package/dist/dsl/CSE.js
DELETED
|
@@ -1,168 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Common Subexpression Elimination (CSE)
|
|
3
|
-
* Identifies repeated expressions and factors them out
|
|
4
|
-
*/
|
|
5
|
-
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
6
|
-
import { serializeExpression } from './ExpressionUtils.js';
|
|
7
|
-
/**
 * Perform CSE on a single expression.
 *
 * Implemented as the one-component case of
 * `eliminateCommonSubexpressionsStructured`, which performs exactly the same
 * counting, `_tmpN` extraction and substitution steps.
 *
 * @param {Expression} expr - expression to simplify.
 * @param {number} [minCount=2] - occurrences required before extraction.
 * @returns {{intermediates: Map<string, Expression>, simplified: Expression}}
 */
export function eliminateCommonSubexpressions(expr, minCount = 2) {
    const key = 'expr';
    const { intermediates, components } = eliminateCommonSubexpressionsStructured(new Map([[key, expr]]), minCount);
    return { intermediates, simplified: components.get(key) };
}
|
|
29
|
-
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 *
 * Subexpressions are counted across all components. Each one seen at least
 * `minCount` times that `shouldExtract` accepts becomes a `_tmpN` intermediate,
 * and every component is rewritten to reference those temporaries.
 *
 * @param {Map<string, Expression>} components - component expressions by name.
 * @param {number} [minCount=2] - occurrences required before extraction.
 * @returns {{intermediates: Map<string, Expression>, components: Map<string, Expression>}}
 */
export function eliminateCommonSubexpressionsStructured(components, minCount = 2) {
    const counter = new ExpressionCounter();
    components.forEach((expr) => counter.count(expr));
    const intermediates = new Map();
    const subexprMap = new Map();
    for (const [key, occurrences] of counter.counts) {
        if (occurrences < minCount) {
            continue;
        }
        const candidate = counter.expressions.get(key);
        if (!candidate || !shouldExtract(candidate)) {
            continue;
        }
        // One entry is added per extraction, so the map size is the next index.
        const tmpName = `_tmp${intermediates.size}`;
        intermediates.set(tmpName, candidate);
        subexprMap.set(key, tmpName);
    }
    const rewritten = new Map(
        Array.from(components, ([comp, expr]) => [comp, substituteExpressions(expr, subexprMap, counter)])
    );
    return { intermediates, components: rewritten };
}
|
|
56
|
-
/**
 * Decide whether a repeated subexpression is worth hoisting into a temporary.
 *
 * Binary operations and calls always are. A component access is, unless it is
 * taken directly on a variable. A unary node is, iff its operand is. Numbers,
 * variables and unknown kinds never are.
 *
 * @param {Expression} expr
 * @returns {boolean}
 */
function shouldExtract(expr) {
    const { kind } = expr;
    if (kind === 'binary' || kind === 'call') {
        return true;
    }
    if (kind === 'component') {
        return expr.object.kind !== 'variable';
    }
    if (kind === 'unary') {
        return shouldExtract(expr.operand);
    }
    return false;
}
|
|
76
|
-
/**
 * Visitor that tallies how often each subexpression occurs.
 *
 * Every visited node is serialized with `serializeExpression`. `counts` maps a
 * serialization to its occurrence count, and `expressions` keeps the first node
 * seen with that serialization. Composite nodes are recorded before the base
 * transformer visits their children.
 */
class ExpressionCounter extends ExpressionTransformer {
    counts = new Map();
    expressions = new Map();
    /** Walk `expr`, updating the tallies. */
    count(expr) {
        this.transform(expr);
    }
    serialize(expr) {
        return serializeExpression(expr);
    }
    recordExpression(expr) {
        const key = this.serialize(expr);
        this.counts.set(key, (this.counts.get(key) ?? 0) + 1);
        // Keep only the first node for each serialization as its representative.
        if (!this.expressions.has(key)) {
            this.expressions.set(key, expr);
        }
    }
    // Leaves: record and return the node itself (nothing to descend into).
    visitNumber(node) {
        this.recordExpression(node);
        return node;
    }
    visitVariable(node) {
        this.recordExpression(node);
        return node;
    }
    // Composites: record, then let the base class traverse the children.
    visitBinaryOp(node) {
        this.recordExpression(node);
        return super.visitBinaryOp(node);
    }
    visitUnaryOp(node) {
        this.recordExpression(node);
        return super.visitUnaryOp(node);
    }
    visitFunctionCall(node) {
        this.recordExpression(node);
        return super.visitFunctionCall(node);
    }
    visitComponentAccess(node) {
        this.recordExpression(node);
        return super.visitComponentAccess(node);
    }
}
|
|
121
|
-
/**
 * Replace common subexpressions in `expr` with references to their temporaries.
 *
 * If `expr` itself was extracted, it becomes a single variable reference.
 * Otherwise each extracted pattern, in `subexprMap` insertion order, is
 * substituted wherever it occurs. Patterns with no recorded expression are
 * skipped, as are patterns equal to the whole current result.
 *
 * @param {Expression} expr
 * @param {Map<string, string>} subexprMap - serialized expression -> temp name.
 * @param {ExpressionCounter} counter - supplies `serialize` and `expressions`.
 * @returns {Expression}
 */
function substituteExpressions(expr, subexprMap, counter) {
    const wholeName = subexprMap.get(counter.serialize(expr));
    if (wholeName !== undefined) {
        return { kind: 'variable', name: wholeName };
    }
    let current = expr;
    subexprMap.forEach((varName, exprStr) => {
        const target = counter.expressions.get(exprStr);
        if (!target || counter.serialize(current) === exprStr) {
            return;
        }
        current = substituteInExpression(current, target, { kind: 'variable', name: varName }, counter);
    });
    return current;
}
|
|
141
|
-
/**
 * Transformer that substitutes a pattern with a replacement expression.
 * Used for CSE optimization to replace repeated subexpressions with intermediate variables.
 *
 * A node matches when its serialization equals the pattern's. On a match, the
 * replacement is returned and that subtree is not descended into.
 */
class PatternSubstitutionTransformer extends ExpressionTransformer {
    pattern;
    replacement;
    counter;
    /** Serialization of `pattern`, computed once rather than at every visited node. */
    patternKey;
    constructor(pattern, replacement, counter) {
        super();
        this.pattern = pattern;
        this.replacement = replacement;
        this.counter = counter;
        this.patternKey = counter.serialize(pattern);
    }
    transform(expr) {
        if (this.counter.serialize(expr) === this.patternKey) {
            return this.replacement;
        }
        return super.transform(expr);
    }
}
|
|
162
|
-
/**
 * Return a copy of `expr` in which every subtree whose serialization equals
 * `pattern`'s is replaced by `replacement`.
 *
 * @param {Expression} expr
 * @param {Expression} pattern
 * @param {Expression} replacement
 * @param {ExpressionCounter} counter - supplies `serialize`.
 * @returns {Expression}
 */
function substituteInExpression(expr, pattern, replacement, counter) {
    return new PatternSubstitutionTransformer(pattern, replacement, counter).transform(expr);
}
|