gradient-script 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. package/README.md +3 -1
  2. package/dist/cli.js +219 -6
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +336 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -0,0 +1,131 @@
1
+ /**
2
+ * Rewrite Engine for E-Graph Equality Saturation
3
+ *
4
+ * Applies rewrite rules until saturation (no new merges) or iteration limit.
5
+ */
6
+ import { matchPattern, instantiatePattern } from './Pattern.js';
7
/**
 * Apply equality saturation to an e-graph.
 *
 * Repeatedly collects rule matches, instantiates each right-hand side,
 * merges it with the matched e-class and rebuilds. Stops when:
 * - a pass discovers no new equivalences (saturated),
 * - `maxIterations` passes have run, or
 * - the e-graph grows beyond `maxClassSize`.
 *
 * @param {object} egraph - E-graph providing `size`, `find`, `merge` and `rebuild`.
 * @param {Array<{name: string, lhs: object, rhs: object}>} rules - Rewrite rules to apply.
 * @param {object} [options]
 * @param {number} [options.maxIterations=10] - Maximum number of rewrite passes.
 * @param {number} [options.maxClassSize=5000] - Stop once `egraph.size` exceeds this.
 * @param {number} [options.maxMatchesPerIter=5000] - Cap on matches collected per pass.
 * @param {boolean} [options.verbose=false] - Log progress with `console.log`.
 * @returns {{iterations: number, totalMatches: number, merges: number,
 *            saturated: boolean, classCount: number}} Run statistics.
 *   `saturated` is true only when a pass examined every (class, rule) pair
 *   and found nothing new; stopping because the match cap was hit never
 *   counts as saturation.
 */
export function saturate(egraph, rules, options = {}) {
    const { maxIterations = 10, maxClassSize = 5000, maxMatchesPerIter = 5000, verbose = false } = options;
    const stats = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: false,
        classCount: egraph.size
    };
    for (let iter = 0; iter < maxIterations; iter++) {
        stats.iterations = iter + 1;
        // Check size limit
        if (egraph.size > maxClassSize) {
            if (verbose) {
                console.log(`[saturate] Size limit exceeded: ${egraph.size} > ${maxClassSize}`);
            }
            break;
        }
        // Collect all rule matches (with limit)
        const matches = collectMatches(egraph, rules, maxMatchesPerIter);
        stats.totalMatches += matches.length;
        if (matches.length === 0) {
            // An empty batch can only come from a complete scan, so this is
            // genuine saturation.
            stats.saturated = true;
            if (verbose) {
                console.log(`[saturate] Saturated after ${iter + 1} iterations`);
            }
            break;
        }
        // collectMatches returns as soon as it reaches the cap, so a full-size
        // batch means some (class, rule) pairs were not examined this pass.
        const truncated = matches.length >= maxMatchesPerIter;
        // Apply all matches
        let mergesThisIter = 0;
        for (const { rule, classId, subst } of matches) {
            const rhsId = instantiatePattern(egraph, rule.rhs, subst);
            const lhsCanon = egraph.find(classId);
            const rhsCanon = egraph.find(rhsId);
            if (lhsCanon !== rhsCanon) {
                egraph.merge(lhsCanon, rhsCanon);
                mergesThisIter++;
            }
        }
        stats.merges += mergesThisIter;
        // Rebuild to restore invariants
        egraph.rebuild();
        if (verbose) {
            console.log(`[saturate] Iter ${iter + 1}: ${matches.length} matches, ${mergesThisIter} merges, ${egraph.size} classes`);
        }
        if (mergesThisIter === 0) {
            // No progress this pass. With a truncated batch, the next pass would
            // most likely collect the same prefix of matches again, so stop, but
            // do not claim saturation for the part of the graph never examined.
            stats.saturated = !truncated;
            if (verbose && truncated) {
                console.log(`[saturate] Stopped after ${iter + 1} iterations: match limit ${maxMatchesPerIter} reached without new merges`);
            }
            break;
        }
    }
    stats.classCount = egraph.size;
    return stats;
}
69
/**
 * Gather (rule, class, substitution) triples for every place a rule's
 * left-hand side matches.
 *
 * Classes form the outer loop and rules the inner loop. Collection stops
 * the moment `maxMatches` triples have been gathered, so an explosive rule
 * set cannot blow up a single pass.
 *
 * @param {object} egraph - E-graph whose `getClassIds()` is scanned.
 * @param {Array<{lhs: object}>} rules - Rules whose `lhs` patterns are matched.
 * @param {number} maxMatches - Maximum number of matches to return.
 * @returns {Array<{rule: object, classId: *, subst: object}>} Matches in discovery order.
 */
function collectMatches(egraph, rules, maxMatches) {
    const found = [];
    for (const id of egraph.getClassIds()) {
        for (const candidate of rules) {
            for (const binding of matchPattern(egraph, candidate.lhs, id)) {
                found.push({ rule: candidate, classId: id, subst: binding });
                if (found.length >= maxMatches) {
                    return found;
                }
            }
        }
    }
    return found;
}
88
/**
 * Apply a single rule once, then rebuild the e-graph.
 *
 * Every collected match has the rule's right-hand side instantiated and
 * merged with the matched class (when not already equal).
 *
 * @param {object} egraph - E-graph providing `find`, `merge` and `rebuild`.
 * @param {{lhs: object, rhs: object}} rule - Rule to apply.
 * @param {number} [maxMatches=5000] - Cap on matches collected. This was
 *   previously hard-coded; the default matches `saturate`'s `maxMatchesPerIter`.
 * @returns {number} Number of merges performed.
 */
export function applyRuleOnce(egraph, rule, maxMatches = 5000) {
    const matches = collectMatches(egraph, [rule], maxMatches);
    let merges = 0;
    for (const { classId, subst } of matches) {
        const rhsId = instantiatePattern(egraph, rule.rhs, subst);
        const lhsCanon = egraph.find(classId);
        const rhsCanon = egraph.find(rhsId);
        if (lhsCanon !== rhsCanon) {
            egraph.merge(lhsCanon, rhsCanon);
            merges++;
        }
    }
    // Restore e-graph invariants after the batch of merges.
    egraph.rebuild();
    return merges;
}
106
/**
 * Run `saturate` once per phase, in order, on the same e-graph.
 *
 * Lets cheap normalization rules run to completion before more expansive
 * ones (e.g. core rules first, then algebra rules). Each phase receives the
 * same `options` object.
 *
 * @param {object} egraph - E-graph shared by every phase.
 * @param {Array<Array<object>>} phases - Rule lists, applied first to last.
 * @param {object} [options] - Forwarded unchanged to `saturate`; `verbose`
 *   additionally enables a per-phase log line.
 * @returns {{iterations: number, totalMatches: number, merges: number,
 *            saturated: boolean, classCount: number}} Summed statistics.
 *   `saturated` is true only if every phase saturated (vacuously true when
 *   `phases` is empty).
 */
export function saturatePhased(egraph, phases, options = {}) {
    const totals = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: true,
        classCount: egraph.size,
    };
    for (const [index, phaseRules] of phases.entries()) {
        if (options.verbose) {
            console.log(`[saturatePhased] Phase ${index + 1}/${phases.length}: ${phaseRules.length} rules`);
        }
        const { iterations, totalMatches, merges, saturated } = saturate(egraph, phaseRules, options);
        totals.iterations += iterations;
        totals.totalMatches += totalMatches;
        totals.merges += merges;
        totals.saturated = totals.saturated && saturated;
    }
    totals.classCount = egraph.size;
    return totals;
}
@@ -0,0 +1,44 @@
1
+ /**
2
+ * Algebraic Rewrite Rules for E-Graph Optimization
3
+ *
4
+ * Rules are organized by category:
5
+ * - Core: Always safe, essential (commutativity, identity, etc.)
6
+ * - Algebra: More powerful, can cause expansion (distribution, factoring)
7
+ * - Functions: Domain-specific (sqrt, trig, exp/log)
8
+ *
9
+ * Based on rules from Herbie (https://github.com/herbie-fp/herbie)
10
+ * and egg (https://egraphs-good.github.io/)
11
+ */
12
+ import { Pattern } from './Pattern.js';
13
/**
 * A rewrite rule: if LHS matches, RHS is equivalent.
 */
export interface Rule {
    /** Human-readable identifier such as 'comm-add'; the saturation loop in Rewriter.js does not read it. */
    name: string;
    /** Pattern matched against each e-class. */
    lhs: Pattern;
    /** Pattern instantiated with the match's substitution and merged into the matched e-class. */
    rhs: Pattern;
}
/**
 * Create a rule from pattern strings.
 * Both strings are parsed with `parsePattern`, e.g. `rule('add-0-r', '(+ ?a 0)', '?a')`.
 */
export declare function rule(name: string, lhs: string, rhs: string): Rule;
/**
 * Create bidirectional rules (both directions).
 * Returns `[a -> b, b -> a]`, named `${name}-l` and `${name}-r` respectively.
 */
export declare function biRule(name: string, a: string, b: string): Rule[];
/** Normalization rules: commutativity, associativity, identities, zero/inverse laws and negation. */
export declare const coreRules: Rule[];
/** Expansion-prone rules: distribution, negation/inverse propagation, power <-> product, like terms. */
export declare const algebraRules: Rule[];
/** Function-specific rules for sqrt, abs, sin/cos and exp/log. */
export declare const functionRules: Rule[];
/**
 * All rules combined (core, then algebra, then function).
 */
export declare const allRules: Rule[];
/**
 * Minimal rules for canonicalization only
 * (no expansion, just normalization).
 */
export declare const canonRules: Rule[];
/**
 * Get rules by category.
 * Selected sets are concatenated in core -> algebra -> function order,
 * regardless of the order of `categories`.
 */
export declare function getRuleSet(categories: ('core' | 'algebra' | 'function')[]): Rule[];
@@ -0,0 +1,187 @@
1
+ /**
2
+ * Algebraic Rewrite Rules for E-Graph Optimization
3
+ *
4
+ * Rules are organized by category:
5
+ * - Core: Always safe, essential (commutativity, identity, etc.)
6
+ * - Algebra: More powerful, can cause expansion (distribution, factoring)
7
+ * - Functions: Domain-specific (sqrt, trig, exp/log)
8
+ *
9
+ * Based on rules from Herbie (https://github.com/herbie-fp/herbie)
10
+ * and egg (https://egraphs-good.github.io/)
11
+ */
12
+ import { parsePattern } from './Pattern.js';
13
/**
 * Build a rewrite rule from S-expression pattern strings.
 *
 * @param {string} name - Rule identifier, e.g. `'comm-add'`.
 * @param {string} lhs - Pattern to match, e.g. `'(+ ?a ?b)'`.
 * @param {string} rhs - Equivalent pattern, e.g. `'(+ ?b ?a)'`.
 * @returns {{name: string, lhs: object, rhs: object}} The rule, with both
 *   sides parsed by `parsePattern` (left-hand side first).
 */
export function rule(name, lhs, rhs) {
    const lhsPattern = parsePattern(lhs);
    const rhsPattern = parsePattern(rhs);
    return { name, lhs: lhsPattern, rhs: rhsPattern };
}
23
/**
 * Build a pair of rules stating `a` and `b` are equivalent in both directions.
 *
 * @param {string} name - Base name; the results are suffixed `-l` (a -> b) and `-r` (b -> a).
 * @param {string} a - First pattern string.
 * @param {string} b - Second pattern string.
 * @returns {Array<object>} `[forward, backward]`.
 */
export function biRule(name, a, b) {
    const forward = rule(`${name}-l`, a, b);
    const backward = rule(`${name}-r`, b, a);
    return [forward, backward];
}
32
// =============================================================================
// CORE RULES - identities and normalization
// =============================================================================
// Each row is [name, lhs, rhs]. Associativity is listed in both directions
// (-l regroups left-nested into right-nested, -r the reverse). Note that
// div-self, mul-0-* and 0-div replace expressions that can be NaN (0/0, 0*Inf)
// with a finite value.
export const coreRules = [
    // Commutativity
    ['comm-add', '(+ ?a ?b)', '(+ ?b ?a)'],
    ['comm-mul', '(* ?a ?b)', '(* ?b ?a)'],
    // Associativity
    ['assoc-add-l', '(+ (+ ?a ?b) ?c)', '(+ ?a (+ ?b ?c))'],
    ['assoc-add-r', '(+ ?a (+ ?b ?c))', '(+ (+ ?a ?b) ?c)'],
    ['assoc-mul-l', '(* (* ?a ?b) ?c)', '(* ?a (* ?b ?c))'],
    ['assoc-mul-r', '(* ?a (* ?b ?c))', '(* (* ?a ?b) ?c)'],
    // Additive identity
    ['add-0-l', '(+ 0 ?a)', '?a'],
    ['add-0-r', '(+ ?a 0)', '?a'],
    // Subtraction with zero
    ['sub-0', '(- ?a 0)', '?a'],
    ['0-sub', '(- 0 ?a)', '(neg ?a)'],
    // Multiplicative identity
    ['mul-1-l', '(* 1 ?a)', '?a'],
    ['mul-1-r', '(* ?a 1)', '?a'],
    // Division by one
    ['div-1', '(/ ?a 1)', '?a'],
    // Trivial powers
    ['pow-0', '(^ ?a 0)', '1'],
    ['pow-1', '(^ ?a 1)', '?a'],
    // Multiplication by zero
    ['mul-0-l', '(* 0 ?a)', '0'],
    ['mul-0-r', '(* ?a 0)', '0'],
    // Zero numerator
    ['0-div', '(/ 0 ?a)', '0'],
    // Self-cancellation
    ['sub-self', '(- ?a ?a)', '0'],
    ['div-self', '(/ ?a ?a)', '1'],
    // Negation
    ['neg-neg', '(neg (neg ?a))', '?a'],
    ['neg-mul-neg', '(* (neg ?a) (neg ?b))', '(* ?a ?b)'],
    ['mul-neg-1', '(* -1 ?a)', '(neg ?a)'],
    ['neg-to-mul', '(neg ?a)', '(* -1 ?a)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
72
// =============================================================================
// ALGEBRA RULES - More powerful, can cause expansion
// =============================================================================
// Each row is [name, lhs, rhs]. Distribution is listed in both directions
// (-l expands, -r factors).
export const algebraRules = [
    // Distribution / factoring
    ['dist-mul-add-l', '(* ?a (+ ?b ?c))', '(+ (* ?a ?b) (* ?a ?c))'],
    ['dist-mul-add-r', '(+ (* ?a ?b) (* ?a ?c))', '(* ?a (+ ?b ?c))'],
    ['dist-mul-sub-l', '(* ?a (- ?b ?c))', '(- (* ?a ?b) (* ?a ?c))'],
    ['dist-mul-sub-r', '(- (* ?a ?b) (* ?a ?c))', '(* ?a (- ?b ?c))'],
    // Pushing negation through + - * /
    ['neg-add', '(neg (+ ?a ?b))', '(+ (neg ?a) (neg ?b))'],
    ['neg-sub', '(neg (- ?a ?b))', '(- ?b ?a)'],
    ['neg-mul-l', '(* (neg ?a) ?b)', '(neg (* ?a ?b))'],
    ['neg-mul-r', '(* ?a (neg ?b))', '(neg (* ?a ?b))'],
    ['neg-div-l', '(/ (neg ?a) ?b)', '(neg (/ ?a ?b))'],
    ['neg-div-r', '(/ ?a (neg ?b))', '(neg (/ ?a ?b))'],
    // Subtraction <-> addition of a negation
    ['sub-to-add', '(- ?a ?b)', '(+ ?a (neg ?b))'],
    ['add-neg-to-sub', '(+ ?a (neg ?b))', '(- ?a ?b)'],
    // Division as multiplication by a reciprocal
    ['div-to-mul-inv', '(/ ?a ?b)', '(* ?a (inv ?b))'],
    // Reciprocals
    ['inv-1', '(inv 1)', '1'],
    ['inv-inv', '(inv (inv ?a))', '?a'],
    ['div-1-to-inv', '(/ 1 ?a)', '(inv ?a)'],
    ['mul-inv-cancel', '(* ?a (inv ?a))', '1'], // a * (1/a) = 1
    ['inv-mul', '(inv (* ?a ?b))', '(* (inv ?a) (inv ?b))'], // 1/(a*b) = (1/a)*(1/b)
    ['div-self-sq', '(/ ?a (* ?a ?a))', '(inv ?a)'], // a / (a*a) = 1/a
    ['div-sq-self', '(/ (* ?a ?a) ?a)', '?a'], // (a*a) / a = a
    ['div-self-pow2', '(/ ?a (^ ?a 2))', '(inv ?a)'], // a / a^2 = 1/a
    ['div-pow2-self', '(/ (^ ?a 2) ?a)', '?a'], // a^2 / a = a
    // Squares as powers and vice versa
    ['pow-2', '(^ ?a 2)', '(* ?a ?a)'],
    ['sq-to-pow', '(* ?a ?a)', '(^ ?a 2)'],
    // Like terms
    ['add-same', '(+ ?a ?a)', '(* 2 ?a)'],
    ['sub-neg-same', '(- ?a (neg ?a))', '(* 2 ?a)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
108
// =============================================================================
// FUNCTION RULES - Domain-specific
// =============================================================================
// An e-graph merge is symmetric, so a rule that holds on only part of the
// real line can combine with other rules to equate terms that differ
// elsewhere. The following expansions were removed for that reason:
//   sqrt-mul  (sqrt (* a b)) -> (* (sqrt a) (sqrt b))
//   sqrt-div  (sqrt (/ a b)) -> (/ (sqrt a) (sqrt b))
//   log-mul   (log (* a b))  -> (+ (log a) (log b))
//   log-div   (log (/ a b))  -> (- (log a) (log b))
//   log-pow   (log (^ a b))  -> (* b (log a))
// They map finite values to NaN when the operands are negative
// (e.g. sqrt((-2)*(-8)) = 4 but sqrt(-2)*sqrt(-8) = NaN). Worse, sqrt-mul
// followed by sqrt-sq merged sqrt(x*x) with x, which is wrong for x < 0.
export const functionRules = [
    // === Sqrt ===
    rule('sqrt-sq', '(* (sqrt ?a) (sqrt ?a))', '?a'),
    // sqrt(a*a) is |a| for every real a; this is the correct form of the
    // simplification the removed sqrt-mul + sqrt-sq combination produced.
    rule('sqrt-sq-abs', '(sqrt (* ?a ?a))', '(abs ?a)'),
    rule('sqrt-1', '(sqrt 1)', '1'),
    rule('sqrt-0', '(sqrt 0)', '0'),
    // === Abs ===
    rule('abs-neg', '(abs (neg ?a))', '(abs ?a)'),
    rule('abs-abs', '(abs (abs ?a))', '(abs ?a)'),
    rule('abs-sq', '(abs (* ?a ?a))', '(* ?a ?a)'),
    // === Trig ===
    rule('sin-0', '(sin 0)', '0'),
    rule('cos-0', '(cos 0)', '1'),
    rule('sin-neg', '(sin (neg ?a))', '(neg (sin ?a))'),
    rule('cos-neg', '(cos (neg ?a))', '(cos ?a)'),
    // === Exp/Log ===
    rule('exp-0', '(exp 0)', '1'),
    rule('log-1', '(log 1)', '0'),
    rule('exp-log', '(exp (log ?a))', '?a'),
    rule('log-exp', '(log (exp ?a))', '?a'),
];
136
// =============================================================================
// RULE SETS
// =============================================================================
/**
 * Every rule in a single list: core rules first, then algebra, then function rules.
 */
export const allRules = coreRules.concat(algebraRules, functionRules);
147
/**
 * Minimal rules for canonicalization only (no expansion, just normalization).
 *
 * Every entry is taken from `coreRules` by name instead of re-declaring the
 * same patterns, so the two lists cannot drift apart. The order listed here
 * is preserved in the resulting array.
 */
export const canonRules = [
    // Commutativity (for canonical ordering)
    'comm-add', 'comm-mul',
    // Identity removal
    'add-0-l', 'add-0-r', 'mul-1-l', 'mul-1-r', 'sub-0', 'div-1', 'pow-1',
    // Zero
    'mul-0-l', 'mul-0-r', '0-div', 'pow-0',
    // Inverse
    'sub-self', 'div-self',
    // Double negation
    'neg-neg', 'neg-mul-neg',
].map((name) => {
    const found = coreRules.find((r) => r.name === name);
    if (!found) {
        // Fail loudly at module load rather than shipping an undefined rule.
        throw new Error(`canonRules: no core rule named '${name}'`);
    }
    return found;
});
175
/**
 * Assemble a rule list from the named categories.
 *
 * The selected sets are always concatenated in core -> algebra -> function
 * order, whatever order `categories` lists them in; unrecognised names are
 * ignored.
 *
 * @param {Array<'core' | 'algebra' | 'function'>} categories - Categories to include.
 * @returns {Array<object>} A new array containing the selected rules.
 */
export function getRuleSet(categories) {
    const catalogue = [
        ['core', coreRules],
        ['algebra', algebraRules],
        ['function', functionRules],
    ];
    return catalogue
        .filter(([category]) => categories.includes(category))
        .flatMap(([, ruleList]) => ruleList);
}
@@ -0,0 +1,15 @@
1
+ /**
2
+ * E-Graph Module for GradientScript
3
+ *
4
+ * Provides equality saturation-based optimization for gradient expressions.
5
+ * Can be used as an alternative to (or in addition to) the CSE module.
6
+ */
7
+ export { EGraph } from './EGraph.js';
8
+ export type { EClass } from './EGraph.js';
9
+ export { EClassId, ENode, enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
10
+ export { Pattern, Substitution, parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
11
+ export { Rule, rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
12
+ export { saturate, saturatePhased, applyRuleOnce, SaturationStats, SaturationOptions } from './Rewriter.js';
13
+ export { extractBest, extractWithCSE, ExtractionResult, CostModel, defaultCostModel } from './Extractor.js';
14
+ export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
15
+ export { optimizeWithEGraph, EGraphOptimizeResult } from './Optimizer.js';
@@ -0,0 +1,21 @@
1
+ /**
2
+ * E-Graph Module for GradientScript
3
+ *
4
+ * Provides equality saturation-based optimization for gradient expressions.
5
+ * Can be used as an alternative to (or in addition to) the CSE module.
6
+ */
7
+ // Core e-graph
8
+ export { EGraph } from './EGraph.js';
9
+ export { enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
10
+ // Pattern matching
11
+ export { parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
12
+ // Rewrite rules
13
+ export { rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
14
+ // Saturation
15
+ export { saturate, saturatePhased, applyRuleOnce } from './Rewriter.js';
16
+ // Extraction
17
+ export { extractBest, extractWithCSE, defaultCostModel } from './Extractor.js';
18
+ // AST conversion
19
+ export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
20
+ // Main optimizer function
21
+ export { optimizeWithEGraph } from './Optimizer.js';
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "gradient-script",
3
- "version": "0.2.0",
3
+ "version": "0.3.1",
4
4
  "description": "Symbolic differentiation for structured types with a simple DSL",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
package/dist/dsl/CSE.d.ts DELETED
@@ -1,21 +0,0 @@
1
- /**
2
- * Common Subexpression Elimination (CSE)
3
- * Identifies repeated expressions and factors them out
4
- */
5
- import { Expression } from './AST.js';
6
/** Result of common subexpression elimination on a single expression. */
export interface CSEResult {
    /** Extracted subexpressions keyed by generated name (`_tmp0`, `_tmp1`, ...). */
    intermediates: Map<string, Expression>;
    /** The input expression with extracted subexpressions replaced by variable references. */
    simplified: Expression;
}
/** Result of common subexpression elimination across the components of a structured value. */
export interface StructuredCSEResult {
    /** Extracted subexpressions shared by all components, keyed by generated name. */
    intermediates: Map<string, Expression>;
    /** Each input component, rewritten to reference the intermediates. */
    components: Map<string, Expression>;
}
/**
 * Perform CSE on an expression.
 * @param expr - Expression to factor.
 * @param minCount - Minimum number of occurrences before a subexpression is extracted (default 2).
 */
export declare function eliminateCommonSubexpressions(expr: Expression, minCount?: number): CSEResult;
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 * Occurrences are counted across all components together, so a subexpression
 * repeated in different components becomes one shared intermediate.
 * @param components - Component name to expression.
 * @param minCount - Minimum number of occurrences before a subexpression is extracted (default 2).
 */
export declare function eliminateCommonSubexpressionsStructured(components: Map<string, Expression>, minCount?: number): StructuredCSEResult;
package/dist/dsl/CSE.js DELETED
@@ -1,168 +0,0 @@
1
- /**
2
- * Common Subexpression Elimination (CSE)
3
- * Identifies repeated expressions and factors them out
4
- */
5
- import { ExpressionTransformer } from './ExpressionTransformer.js';
6
- import { serializeExpression } from './ExpressionUtils.js';
7
/**
 * Perform CSE on a single expression.
 *
 * @param {object} expr - Expression AST to factor.
 * @param {number} [minCount=2] - Minimum occurrences before a subexpression is extracted.
 * @returns {{intermediates: Map<string, object>, simplified: object}}
 *   `intermediates` maps generated names (`_tmp0`, `_tmp1`, ...) to the
 *   extracted subexpressions; `simplified` is `expr` with those subexpressions
 *   replaced by variable references.
 */
export function eliminateCommonSubexpressions(expr, minCount = 2) {
    const counter = new ExpressionCounter();
    counter.count(expr);
    const { intermediates, subexprMap } = selectCommonSubexpressions(counter, minCount);
    const simplified = substituteExpressions(expr, subexprMap, counter);
    return { intermediates, simplified };
}
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 *
 * Occurrences are counted across all components together, so a subexpression
 * repeated in different components becomes one shared intermediate.
 *
 * @param {Map<string, object>} components - Component name to expression.
 * @param {number} [minCount=2] - Minimum occurrences before a subexpression is extracted.
 * @returns {{intermediates: Map<string, object>, components: Map<string, object>}}
 */
export function eliminateCommonSubexpressionsStructured(components, minCount = 2) {
    const counter = new ExpressionCounter();
    for (const expr of components.values()) {
        counter.count(expr);
    }
    const { intermediates, subexprMap } = selectCommonSubexpressions(counter, minCount);
    const simplifiedComponents = new Map();
    for (const [comp, expr] of components.entries()) {
        simplifiedComponents.set(comp, substituteExpressions(expr, subexprMap, counter));
    }
    return { intermediates, components: simplifiedComponents };
}
/**
 * Choose which counted subexpressions to factor out.
 *
 * Shared by both entry points so they select and name intermediates
 * identically: keys seen at least `minCount` times whose recorded expression
 * passes `shouldExtract`, named `_tmp0`, `_tmp1`, ... in the counter's
 * insertion order.
 *
 * @param {ExpressionCounter} counter - Populated counter.
 * @param {number} minCount - Occurrence threshold.
 * @returns {{intermediates: Map<string, object>, subexprMap: Map<string, string>}}
 *   `intermediates`: name -> expression; `subexprMap`: serialized key -> name.
 */
function selectCommonSubexpressions(counter, minCount) {
    const intermediates = new Map();
    const subexprMap = new Map();
    let varCounter = 0;
    for (const [exprStr, count] of counter.counts.entries()) {
        if (count < minCount) {
            continue;
        }
        const parsed = counter.expressions.get(exprStr);
        if (parsed && shouldExtract(parsed)) {
            const varName = `_tmp${varCounter++}`;
            intermediates.set(varName, parsed);
            subexprMap.set(exprStr, varName);
        }
    }
    return { intermediates, subexprMap };
}
56
/**
 * Decide whether a repeated subexpression is worth naming as an intermediate.
 *
 * - Binary operations and function calls always qualify.
 * - Component access qualifies unless it is taken directly on a variable.
 * - A unary node qualifies exactly when its operand does.
 * - Numbers, variables and any other kind never qualify.
 */
function shouldExtract(expr) {
    const { kind } = expr;
    if (kind === 'binary' || kind === 'call') {
        return true;
    }
    if (kind === 'component') {
        return expr.object.kind !== 'variable';
    }
    if (kind === 'unary') {
        return shouldExtract(expr.operand);
    }
    return false;
}
76
/**
 * Transformer pass that tallies how often each distinct subexpression occurs.
 *
 * `counts` maps a serialized expression to its occurrence count, and
 * `expressions` keeps the first node seen for each serialization. Every
 * visited node kind (number, variable, binary, unary, call, component) is
 * recorded. The non-leaf kinds are then handed to the base class, which
 * presumably continues into their children (ExpressionTransformer is not
 * shown here; confirm).
 */
class ExpressionCounter extends ExpressionTransformer {
    counts = new Map();
    expressions = new Map();
    /** Walk `expr`, adding every node it contains to the tallies. */
    count(expr) {
        this.transform(expr);
    }
    /** Canonical string key for an expression (delegates to `serializeExpression`). */
    serialize(expr) {
        return serializeExpression(expr);
    }
    /** Bump the tally for `expr`, remembering the node the first time its key appears. */
    recordExpression(expr) {
        const key = this.serialize(expr);
        const seenSoFar = this.counts.get(key) ?? 0;
        this.counts.set(key, seenSoFar + 1);
        if (this.expressions.has(key)) {
            return;
        }
        this.expressions.set(key, expr);
    }
    // Leaf nodes: record and return unchanged.
    visitNumber(node) {
        this.recordExpression(node);
        return node;
    }
    visitVariable(node) {
        this.recordExpression(node);
        return node;
    }
    // Interior nodes: record, then defer to the base transformer.
    visitBinaryOp(node) {
        this.recordExpression(node);
        return super.visitBinaryOp(node);
    }
    visitUnaryOp(node) {
        this.recordExpression(node);
        return super.visitUnaryOp(node);
    }
    visitFunctionCall(node) {
        this.recordExpression(node);
        return super.visitFunctionCall(node);
    }
    visitComponentAccess(node) {
        this.recordExpression(node);
        return super.visitComponentAccess(node);
    }
}
121
/**
 * Replace the selected common subexpressions in `expr` with their temporaries.
 *
 * If `expr` as a whole is a selected subexpression, it becomes a single
 * variable reference. Otherwise each `subexprMap` entry (serialized key ->
 * variable name) is applied in turn, in the map's insertion order. Entries
 * are skipped when the counter never recorded them, or when their key
 * equals the current whole expression.
 *
 * @param {object} expr - Expression to rewrite.
 * @param {Map<string, string>} subexprMap - Serialized subexpression -> variable name.
 * @param {ExpressionCounter} counter - Supplies `serialize` and the recorded `expressions`.
 * @returns {object} The rewritten expression.
 */
function substituteExpressions(expr, subexprMap, counter) {
    const wholeKey = counter.serialize(expr);
    if (subexprMap.has(wholeKey)) {
        return { kind: 'variable', name: subexprMap.get(wholeKey) };
    }
    return Array.from(subexprMap.entries()).reduce((current, [exprStr, varName]) => {
        const target = counter.expressions.get(exprStr);
        if (!target || counter.serialize(current) === exprStr) {
            return current;
        }
        return substituteInExpression(current, target, { kind: 'variable', name: varName }, counter);
    }, expr);
}
141
/**
 * Transformer that substitutes a pattern with a replacement expression.
 * Used for CSE optimization to replace repeated subexpressions with
 * intermediate variables.
 *
 * Nodes are compared to the pattern by serialized form (`counter.serialize`).
 * The pattern's key is computed once in the constructor; previously the
 * pattern was re-serialized for every node visited.
 */
class PatternSubstitutionTransformer extends ExpressionTransformer {
    pattern;
    replacement;
    counter;
    /** Serialized form of `pattern`, cached for the lifetime of the transformer. */
    patternKey;
    /**
     * @param {object} pattern - Subexpression to find.
     * @param {object} replacement - Expression returned in place of each match.
     * @param {ExpressionCounter} counter - Provides `serialize`.
     */
    constructor(pattern, replacement, counter) {
        super();
        this.pattern = pattern;
        this.replacement = replacement;
        this.counter = counter;
        this.patternKey = counter.serialize(pattern);
    }
    /** Return `replacement` for a matching node, otherwise defer to the base transformer. */
    transform(expr) {
        if (this.counter.serialize(expr) === this.patternKey) {
            return this.replacement;
        }
        return super.transform(expr);
    }
}
162
/**
 * Replace every occurrence of `pattern` inside `expr` with `replacement`.
 * Matching is done by `PatternSubstitutionTransformer`, which compares
 * serialized forms via `counter.serialize`.
 */
function substituteInExpression(expr, pattern, replacement, counter) {
    return new PatternSubstitutionTransformer(pattern, replacement, counter).transform(expr);
}