gradient-script 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/README.md +52 -9
  2. package/dist/cli.js +134 -19
  3. package/dist/dsl/AST.d.ts +8 -0
  4. package/dist/dsl/CodeGen.d.ts +8 -3
  5. package/dist/dsl/CodeGen.js +583 -132
  6. package/dist/dsl/Errors.d.ts +6 -1
  7. package/dist/dsl/Errors.js +70 -1
  8. package/dist/dsl/Expander.js +5 -2
  9. package/dist/dsl/ExpressionUtils.d.ts +14 -0
  10. package/dist/dsl/ExpressionUtils.js +56 -0
  11. package/dist/dsl/GradientChecker.d.ts +21 -0
  12. package/dist/dsl/GradientChecker.js +109 -23
  13. package/dist/dsl/Guards.d.ts +3 -1
  14. package/dist/dsl/Guards.js +86 -43
  15. package/dist/dsl/Inliner.d.ts +5 -0
  16. package/dist/dsl/Inliner.js +11 -2
  17. package/dist/dsl/Lexer.js +3 -1
  18. package/dist/dsl/Parser.js +11 -5
  19. package/dist/dsl/Simplify.d.ts +7 -0
  20. package/dist/dsl/Simplify.js +183 -0
  21. package/dist/dsl/egraph/Convert.d.ts +23 -0
  22. package/dist/dsl/egraph/Convert.js +84 -0
  23. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  24. package/dist/dsl/egraph/EGraph.js +292 -0
  25. package/dist/dsl/egraph/ENode.d.ts +63 -0
  26. package/dist/dsl/egraph/ENode.js +94 -0
  27. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  28. package/dist/dsl/egraph/Extractor.js +1068 -0
  29. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  30. package/dist/dsl/egraph/Optimizer.js +88 -0
  31. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  32. package/dist/dsl/egraph/Pattern.js +325 -0
  33. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  34. package/dist/dsl/egraph/Rewriter.js +131 -0
  35. package/dist/dsl/egraph/Rules.d.ts +44 -0
  36. package/dist/dsl/egraph/Rules.js +187 -0
  37. package/dist/dsl/egraph/index.d.ts +15 -0
  38. package/dist/dsl/egraph/index.js +21 -0
  39. package/package.json +1 -1
  40. package/dist/dsl/CSE.d.ts +0 -21
  41. package/dist/dsl/CSE.js +0 -194
  42. package/dist/symbolic/AST.d.ts +0 -113
  43. package/dist/symbolic/AST.js +0 -128
  44. package/dist/symbolic/CodeGen.d.ts +0 -35
  45. package/dist/symbolic/CodeGen.js +0 -280
  46. package/dist/symbolic/Parser.d.ts +0 -64
  47. package/dist/symbolic/Parser.js +0 -329
  48. package/dist/symbolic/Simplify.d.ts +0 -10
  49. package/dist/symbolic/Simplify.js +0 -244
  50. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  51. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -0,0 +1,131 @@
1
+ /**
2
+ * Rewrite Engine for E-Graph Equality Saturation
3
+ *
4
+ * Applies rewrite rules until saturation (no new merges) or iteration limit.
5
+ */
6
+ import { matchPattern, instantiatePattern } from './Pattern.js';
7
/**
 * Apply equality saturation to an e-graph.
 *
 * Each pass collects rule matches (at most `maxMatchesPerIter`), instantiates
 * every match's right-hand side, merges it with the matched e-class when the
 * two are not already equal, then calls `egraph.rebuild()`. Passes stop when:
 * - a pass finds no matches at all (saturated);
 * - a pass whose match list was complete produces no merges (saturated);
 * - a pass whose match list was cut off at `maxMatchesPerIter` produces no
 *   merges. Matches beyond the cap were never examined, so this is NOT
 *   reported as saturated; another pass would most likely revisit the same
 *   leading matches, so we stop anyway;
 * - `maxIterations` passes have run;
 * - `egraph.size` exceeds `maxClassSize` at the start of a pass.
 *
 * @param {object} egraph - E-graph exposing `size`, `find`, `merge`, `rebuild`, `getClassIds`.
 * @param {Array<{lhs: object, rhs: object}>} rules - Rewrite rules to apply.
 * @param {object} [options]
 * @param {number} [options.maxIterations=10] - Maximum number of passes.
 * @param {number} [options.maxClassSize=5000] - Stop once `egraph.size` exceeds this.
 * @param {number} [options.maxMatchesPerIter=5000] - Cap on matches collected per pass.
 * @param {boolean} [options.verbose=false] - Log progress with `console.log`.
 * @returns {{iterations: number, totalMatches: number, merges: number, saturated: boolean, classCount: number}}
 */
export function saturate(egraph, rules, options = {}) {
    const { maxIterations = 10, maxClassSize = 5000, maxMatchesPerIter = 5000, verbose = false } = options;
    const stats = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: false,
        classCount: egraph.size
    };
    for (let iter = 0; iter < maxIterations; iter++) {
        stats.iterations = iter + 1;
        // Check size limit
        if (egraph.size > maxClassSize) {
            if (verbose) {
                console.log(`[saturate] Size limit exceeded: ${egraph.size} > ${maxClassSize}`);
            }
            break;
        }
        // Collect all rule matches (with limit)
        const matches = collectMatches(egraph, rules, maxMatchesPerIter);
        // collectMatches returns early at the cap, so a full list means some
        // matches may not have been examined this pass.
        const truncated = matches.length >= maxMatchesPerIter;
        stats.totalMatches += matches.length;
        if (matches.length === 0) {
            stats.saturated = true;
            if (verbose) {
                console.log(`[saturate] Saturated after ${iter + 1} iterations`);
            }
            break;
        }
        // Apply all matches
        let mergesThisIter = 0;
        for (const { rule, classId, subst } of matches) {
            const rhsId = instantiatePattern(egraph, rule.rhs, subst);
            // Compare canonical ids: an earlier merge in this pass may already
            // have unified the two classes.
            const lhsCanon = egraph.find(classId);
            const rhsCanon = egraph.find(rhsId);
            if (lhsCanon !== rhsCanon) {
                egraph.merge(lhsCanon, rhsCanon);
                mergesThisIter++;
            }
        }
        stats.merges += mergesThisIter;
        // Rebuild to restore invariants
        egraph.rebuild();
        if (verbose) {
            console.log(`[saturate] Iter ${iter + 1}: ${matches.length} matches, ${mergesThisIter} merges, ${egraph.size} classes`);
        }
        if (mergesThisIter === 0) {
            // No new equivalences. That only proves saturation when every
            // match was examined, i.e. the match list was not truncated.
            stats.saturated = !truncated;
            if (verbose && truncated) {
                console.log(`[saturate] Match limit (${maxMatchesPerIter}) reached without new merges; stopping unsaturated`);
            }
            break;
        }
    }
    stats.classCount = egraph.size;
    return stats;
}
69
/**
 * Gather (rule, e-class, substitution) triples for every rule match in the e-graph.
 *
 * Classes are visited in `egraph.getClassIds()` order and, within each class,
 * rules in the order given. Collection stops as soon as `maxMatches` triples
 * have been gathered, which bounds the work done in one saturation pass.
 *
 * @param {object} egraph - E-graph to search.
 * @param {Array<{lhs: object}>} rules - Rules whose left-hand sides are matched.
 * @param {number} maxMatches - Maximum number of triples to return.
 * @returns {Array<{rule: object, classId: *, subst: *}>}
 */
function collectMatches(egraph, rules, maxMatches) {
    const gathered = [];
    for (const match of enumerateMatches(egraph, rules)) {
        gathered.push(match);
        if (gathered.length >= maxMatches) {
            // Cap reached: stop early so a rule explosion cannot run away.
            break;
        }
    }
    return gathered;
}
/**
 * Lazily yield matches class-by-class and rule-by-rule, so that
 * `collectMatches` can stop without calling `matchPattern` on the remaining
 * (class, rule) pairs.
 */
function* enumerateMatches(egraph, rules) {
    for (const classId of egraph.getClassIds()) {
        for (const rule of rules) {
            for (const subst of matchPattern(egraph, rule.lhs, classId)) {
                yield { rule, classId, subst };
            }
        }
    }
}
88
/**
 * Apply a single rewrite rule for one pass, then rebuild the e-graph.
 *
 * Every match of `rule.lhs` has `rule.rhs` instantiated with its bindings.
 * The result is merged with the matched class unless the two are already equal.
 *
 * @param {object} egraph - E-graph to rewrite in place.
 * @param {{lhs: object, rhs: object}} rule - Rule to apply.
 * @param {number} [maxMatches=5000] - Cap on matches collected. The default is
 *   the same value `saturate` uses for `maxMatchesPerIter`.
 * @returns {number} Number of merges performed.
 */
export function applyRuleOnce(egraph, rule, maxMatches = 5000) {
    const matches = collectMatches(egraph, [rule], maxMatches);
    let merges = 0;
    for (const { classId, subst } of matches) {
        const rhsId = instantiatePattern(egraph, rule.rhs, subst);
        // Compare canonical ids: an earlier merge may already have unified them.
        const lhsCanon = egraph.find(classId);
        const rhsCanon = egraph.find(rhsId);
        if (lhsCanon !== rhsCanon) {
            egraph.merge(lhsCanon, rhsCanon);
            merges++;
        }
    }
    egraph.rebuild();
    return merges;
}
106
/**
 * Run `saturate` once per phase, in order, and sum the per-phase statistics.
 *
 * This allows, for example, canonicalizing with core rules before
 * introducing the more expansive algebra rules.
 *
 * @param {object} egraph - E-graph shared by all phases.
 * @param {Array<Array<object>>} phases - One rule list per phase.
 * @param {object} [options] - Passed unchanged to every `saturate` call;
 *   `options.verbose` also enables a per-phase log line.
 * @returns {{iterations: number, totalMatches: number, merges: number, saturated: boolean, classCount: number}}
 *   Summed counters; `saturated` holds only if every phase saturated.
 */
export function saturatePhased(egraph, phases, options = {}) {
    const totals = {
        iterations: 0,
        totalMatches: 0,
        merges: 0,
        saturated: true,
        classCount: egraph.size
    };
    let phaseNumber = 0;
    for (const phaseRules of phases) {
        phaseNumber += 1;
        if (options.verbose) {
            console.log(`[saturatePhased] Phase ${phaseNumber}/${phases.length}: ${phaseRules.length} rules`);
        }
        const { iterations, totalMatches, merges, saturated } = saturate(egraph, phaseRules, options);
        totals.iterations += iterations;
        totals.totalMatches += totalMatches;
        totals.merges += merges;
        totals.saturated = totals.saturated && saturated;
    }
    totals.classCount = egraph.size;
    return totals;
}
@@ -0,0 +1,44 @@
1
+ /**
2
+ * Algebraic Rewrite Rules for E-Graph Optimization
3
+ *
4
+ * Rules are organized by category:
5
+ * - Core: Essential normalizations (commutativity, identity, etc.); a few assume non-zero divisors (e.g. div-self)
6
+ * - Algebra: More powerful, can cause expansion (distribution, factoring)
7
+ * - Functions: Domain-specific (sqrt, trig, exp/log)
8
+ *
9
+ * Based on rules from Herbie (https://github.com/herbie-fp/herbie)
10
+ * and egg (https://egraphs-good.github.io/)
11
+ */
12
+ import { Pattern } from './Pattern.js';
13
/**
 * A rewrite rule: wherever `lhs` matches, `rhs` is asserted to be equivalent.
 */
export interface Rule {
    /** Rule identifier, e.g. `comm-add`. */
    name: string;
    /** Pattern searched for in the e-graph. */
    lhs: Pattern;
    /** Pattern instantiated with the match's bindings and merged with the matched class. */
    rhs: Pattern;
}
/**
 * Create a rule by parsing two s-expression pattern strings
 * (e.g. `'(+ ?a 0)'` and `'?a'`).
 */
export declare function rule(name: string, lhs: string, rhs: string): Rule;
/**
 * Create a pair of rules covering both directions: `${name}-l` rewrites
 * `a` to `b`, and `${name}-r` rewrites `b` to `a`.
 */
export declare function biRule(name: string, a: string, b: string): Rule[];
/** Commutativity, associativity, identity, zero and negation rules. */
export declare const coreRules: Rule[];
/** Distribution, negation propagation, inverse and power rules; these can grow the e-graph. */
export declare const algebraRules: Rule[];
/** Rules for sqrt, abs, trig and exp/log. */
export declare const functionRules: Rule[];
/**
 * All rules combined: core, then algebra, then function rules.
 */
export declare const allRules: Rule[];
/**
 * Minimal rules for canonicalization only: commutativity plus identity,
 * zero, self-cancellation and double-negation rules (no expansion).
 */
export declare const canonRules: Rule[];
/**
 * Get rules by category, concatenated in the fixed order core, algebra,
 * function. Unrecognized category names are ignored.
 */
export declare function getRuleSet(categories: ('core' | 'algebra' | 'function')[]): Rule[];
@@ -0,0 +1,187 @@
1
+ /**
2
+ * Algebraic Rewrite Rules for E-Graph Optimization
3
+ *
4
+ * Rules are organized by category:
5
+ * - Core: Essential normalizations (commutativity, identity, etc.); a few assume non-zero divisors (e.g. div-self)
6
+ * - Algebra: More powerful, can cause expansion (distribution, factoring)
7
+ * - Functions: Domain-specific (sqrt, trig, exp/log)
8
+ *
9
+ * Based on rules from Herbie (https://github.com/herbie-fp/herbie)
10
+ * and egg (https://egraphs-good.github.io/)
11
+ */
12
+ import { parsePattern } from './Pattern.js';
13
/**
 * Build a rewrite rule from two s-expression pattern strings.
 *
 * @param {string} name - Identifier for the rule.
 * @param {string} lhs - Pattern to search for, e.g. `'(+ ?a 0)'`.
 * @param {string} rhs - Equivalent pattern, e.g. `'?a'`.
 * @returns {{name: string, lhs: object, rhs: object}} The rule, with both
 *   sides parsed by `parsePattern` (left-hand side first).
 */
export function rule(name, lhs, rhs) {
    const lhsPattern = parsePattern(lhs);
    const rhsPattern = parsePattern(rhs);
    return { name, lhs: lhsPattern, rhs: rhsPattern };
}
23
/**
 * Build the two rules for an equivalence usable in either direction.
 *
 * @param {string} name - Base name; the results are named `${name}-l` (a → b)
 *   and `${name}-r` (b → a).
 * @param {string} a - First pattern string.
 * @param {string} b - Second pattern string.
 * @returns {Array<object>} `[forward, backward]`.
 */
export function biRule(name, a, b) {
    const forward = rule(`${name}-l`, a, b);
    const backward = rule(`${name}-r`, b, a);
    return [forward, backward];
}
32
// =============================================================================
// CORE RULES - essential normalizations
// =============================================================================
/**
 * Core rewrite rules: commutativity, associativity, additive and
 * multiplicative identities, zero absorption, self-cancellation and
 * negation forms.
 *
 * Several of these (e.g. `div-self`, `0-div`, `mul-0-l`/`mul-0-r`) assume
 * finite operands and non-zero divisors. Under IEEE-754 the left-hand sides
 * can evaluate to NaN (0/0, 0*Infinity), while the right-hand sides are
 * ordinary numbers.
 */
export const coreRules = [
    // Commutativity
    ...[
        ['comm-add', '(+ ?a ?b)', '(+ ?b ?a)'],
        ['comm-mul', '(* ?a ?b)', '(* ?b ?a)'],
    ].map(([name, lhs, rhs]) => rule(name, lhs, rhs)),
    // Associativity, registered in both directions (`-l` and `-r` variants)
    ...biRule('assoc-add', '(+ (+ ?a ?b) ?c)', '(+ ?a (+ ?b ?c))'),
    ...biRule('assoc-mul', '(* (* ?a ?b) ?c)', '(* ?a (* ?b ?c))'),
    ...[
        // Additive identity
        ['add-0-l', '(+ 0 ?a)', '?a'],
        ['add-0-r', '(+ ?a 0)', '?a'],
        // Subtraction involving zero
        ['sub-0', '(- ?a 0)', '?a'],
        ['0-sub', '(- 0 ?a)', '(neg ?a)'],
        // Multiplicative identity
        ['mul-1-l', '(* 1 ?a)', '?a'],
        ['mul-1-r', '(* ?a 1)', '?a'],
        // Division by one
        ['div-1', '(/ ?a 1)', '?a'],
        // Exponents 0 and 1
        ['pow-0', '(^ ?a 0)', '1'],
        ['pow-1', '(^ ?a 1)', '?a'],
        // Zero absorbs multiplication
        ['mul-0-l', '(* 0 ?a)', '0'],
        ['mul-0-r', '(* ?a 0)', '0'],
        // Zero numerator
        ['0-div', '(/ 0 ?a)', '0'],
        // Self-cancellation
        ['sub-self', '(- ?a ?a)', '0'],
        ['div-self', '(/ ?a ?a)', '1'],
        // Negation forms
        ['neg-neg', '(neg (neg ?a))', '?a'],
        ['neg-mul-neg', '(* (neg ?a) (neg ?b))', '(* ?a ?b)'],
        ['mul-neg-1', '(* -1 ?a)', '(neg ?a)'],
        ['neg-to-mul', '(neg ?a)', '(* -1 ?a)'],
    ].map(([name, lhs, rhs]) => rule(name, lhs, rhs)),
];
72
// =============================================================================
// ALGEBRA RULES - more powerful, may grow the e-graph
// =============================================================================
/**
 * Algebraic rewrites: distribution/factoring, negation propagation,
 * subtraction/division normal forms, reciprocal (`inv`) rules, squaring
 * and combining like terms.
 *
 * The distribution pair is registered in both directions, so these rules can
 * add many new e-nodes per pass.
 */
export const algebraRules = [
    // Distribution and factoring (both directions)
    ...biRule('dist-mul-add', '(* ?a (+ ?b ?c))', '(+ (* ?a ?b) (* ?a ?c))'),
    ...biRule('dist-mul-sub', '(* ?a (- ?b ?c))', '(- (* ?a ?b) (* ?a ?c))'),
    ...[
        // Pushing negation inward / pulling it out of products and quotients
        ['neg-add', '(neg (+ ?a ?b))', '(+ (neg ?a) (neg ?b))'],
        ['neg-sub', '(neg (- ?a ?b))', '(- ?b ?a)'],
        ['neg-mul-l', '(* (neg ?a) ?b)', '(neg (* ?a ?b))'],
        ['neg-mul-r', '(* ?a (neg ?b))', '(neg (* ?a ?b))'],
        ['neg-div-l', '(/ (neg ?a) ?b)', '(neg (/ ?a ?b))'],
        ['neg-div-r', '(/ ?a (neg ?b))', '(neg (/ ?a ?b))'],
        // a - b <-> a + (-b)
        ['sub-to-add', '(- ?a ?b)', '(+ ?a (neg ?b))'],
        ['add-neg-to-sub', '(+ ?a (neg ?b))', '(- ?a ?b)'],
        // a / b -> a * (1/b)
        ['div-to-mul-inv', '(/ ?a ?b)', '(* ?a (inv ?b))'],
        // Reciprocals
        ['inv-1', '(inv 1)', '1'],
        ['inv-inv', '(inv (inv ?a))', '?a'],
        ['div-1-to-inv', '(/ 1 ?a)', '(inv ?a)'],
        ['mul-inv-cancel', '(* ?a (inv ?a))', '1'],
        ['inv-mul', '(inv (* ?a ?b))', '(* (inv ?a) (inv ?b))'],
        ['div-self-sq', '(/ ?a (* ?a ?a))', '(inv ?a)'],
        ['div-sq-self', '(/ (* ?a ?a) ?a)', '?a'],
        ['div-self-pow2', '(/ ?a (^ ?a 2))', '(inv ?a)'],
        ['div-pow2-self', '(/ (^ ?a 2) ?a)', '?a'],
        // Squares: a^2 <-> a*a
        ['pow-2', '(^ ?a 2)', '(* ?a ?a)'],
        ['sq-to-pow', '(* ?a ?a)', '(^ ?a 2)'],
        // Like terms
        ['add-same', '(+ ?a ?a)', '(* 2 ?a)'],
        ['sub-neg-same', '(- ?a (neg ?a))', '(* 2 ?a)'],
    ].map(([name, lhs, rhs]) => rule(name, lhs, rhs)),
];
108
// =============================================================================
// FUNCTION RULES - Domain-specific
// =============================================================================
/**
 * Rewrite rules for sqrt, abs, trig and exp/log.
 *
 * Soundness note: in an e-graph a rewrite merges whole equivalence classes,
 * and extraction may then choose any member of a class. A rule is therefore
 * only acceptable if it holds wherever its left-hand side is defined.
 *
 * The following rules fail that test and are intentionally omitted:
 *   - sqrt(a*b) = sqrt(a)*sqrt(b) and sqrt(a/b) = sqrt(a)/sqrt(b)
 *     (both are wrong when a and b are both negative);
 *   - log(a*b) = log(a)+log(b) and log(a/b) = log(a)-log(b) (same problem);
 *   - log(a^b) = b*log(a) (wrong for a < 0 with an even power).
 *
 * Combined with `sqrt-sq`, the sqrt split even let the optimizer derive
 * sqrt(x*x) = x, which is wrong for negative x. `sqrt-sq-abs` keeps the
 * correct form of that simplification: sqrt(x*x) = |x|.
 *
 * `sqrt-sq` and `exp-log` are kept. Their left-hand sides only evaluate to a
 * number when a >= 0 (respectively a > 0), and there the identity holds.
 */
export const functionRules = [
    // === Sqrt ===
    rule('sqrt-sq', '(* (sqrt ?a) (sqrt ?a))', '?a'),
    rule('sqrt-sq-abs', '(sqrt (* ?a ?a))', '(abs ?a)'), // valid for every real a
    rule('sqrt-1', '(sqrt 1)', '1'),
    rule('sqrt-0', '(sqrt 0)', '0'),
    // === Abs ===
    rule('abs-neg', '(abs (neg ?a))', '(abs ?a)'),
    rule('abs-abs', '(abs (abs ?a))', '(abs ?a)'),
    rule('abs-sq', '(abs (* ?a ?a))', '(* ?a ?a)'),
    // === Trig ===
    rule('sin-0', '(sin 0)', '0'),
    rule('cos-0', '(cos 0)', '1'),
    rule('sin-neg', '(sin (neg ?a))', '(neg (sin ?a))'),
    rule('cos-neg', '(cos (neg ?a))', '(cos ?a)'),
    // === Exp/Log ===
    rule('exp-0', '(exp 0)', '1'),
    rule('log-1', '(log 1)', '0'),
    rule('exp-log', '(exp (log ?a))', '?a'),
    rule('log-exp', '(log (exp ?a))', '?a'),
];
136
// =============================================================================
// RULE SETS
// =============================================================================
/**
 * Every rule in one list: the core rules, then the algebra rules, then the
 * function rules. The rule objects are shared with those arrays, not copied.
 */
export const allRules = [].concat(coreRules, algebraRules, functionRules);
147
/**
 * Minimal rule set for canonicalization: commutativity plus identity, zero,
 * self-cancellation and double-negation rules. It contains no distribution or
 * other expanding rewrites.
 */
export const canonRules = [
    // Commutativity (lets equal terms meet regardless of operand order)
    ['comm-add', '(+ ?a ?b)', '(+ ?b ?a)'],
    ['comm-mul', '(* ?a ?b)', '(* ?b ?a)'],
    // Identity elements disappear
    ['add-0-l', '(+ 0 ?a)', '?a'],
    ['add-0-r', '(+ ?a 0)', '?a'],
    ['mul-1-l', '(* 1 ?a)', '?a'],
    ['mul-1-r', '(* ?a 1)', '?a'],
    ['sub-0', '(- ?a 0)', '?a'],
    ['div-1', '(/ ?a 1)', '?a'],
    ['pow-1', '(^ ?a 1)', '?a'],
    // Zeros and the zeroth power
    ['mul-0-l', '(* 0 ?a)', '0'],
    ['mul-0-r', '(* ?a 0)', '0'],
    ['0-div', '(/ 0 ?a)', '0'],
    ['pow-0', '(^ ?a 0)', '1'],
    // Self-cancellation
    ['sub-self', '(- ?a ?a)', '0'],
    ['div-self', '(/ ?a ?a)', '1'],
    // Paired negations cancel
    ['neg-neg', '(neg (neg ?a))', '?a'],
    ['neg-mul-neg', '(* (neg ?a) (neg ?b))', '(* ?a ?b)'],
].map(([name, lhs, rhs]) => rule(name, lhs, rhs));
175
/**
 * Build a rule list from the requested categories.
 *
 * Categories are concatenated in the fixed order core, algebra, function,
 * regardless of their order in `categories`. Names other than these three
 * are ignored.
 *
 * @param {Array<'core'|'algebra'|'function'>} categories - Categories to include.
 * @returns {Array<object>} A new array; the rule objects themselves are shared.
 */
export function getRuleSet(categories) {
    const catalog = [
        ['core', coreRules],
        ['algebra', algebraRules],
        ['function', functionRules],
    ];
    return catalog
        .filter(([category]) => categories.includes(category))
        .flatMap(([, ruleList]) => ruleList);
}
@@ -0,0 +1,15 @@
1
+ /**
2
+ * E-Graph Module for GradientScript
3
+ *
4
+ * Provides equality saturation-based optimization for gradient expressions.
5
+ * The standalone CSE module (dsl/CSE.js) was removed in this release; common-subexpression extraction is available here via extractWithCSE.
6
+ */
7
+ export { EGraph } from './EGraph.js';
8
+ export type { EClass } from './EGraph.js';
9
+ export { EClassId, ENode, enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
10
+ export { Pattern, Substitution, parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
11
+ export { Rule, rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
12
+ export { saturate, saturatePhased, applyRuleOnce, SaturationStats, SaturationOptions } from './Rewriter.js';
13
+ export { extractBest, extractWithCSE, ExtractionResult, CostModel, defaultCostModel } from './Extractor.js';
14
+ export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
15
+ export { optimizeWithEGraph, EGraphOptimizeResult } from './Optimizer.js';
@@ -0,0 +1,21 @@
1
+ /**
2
+ * E-Graph Module for GradientScript
3
+ *
4
+ * Provides equality saturation-based optimization for gradient expressions.
5
+ * The standalone CSE module (dsl/CSE.js) was removed in this release; common-subexpression extraction is available here via extractWithCSE.
6
+ */
7
+ // Core e-graph
8
+ export { EGraph } from './EGraph.js';
9
+ export { enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
10
+ // Pattern matching
11
+ export { parsePattern, matchPattern, instantiatePattern, patternToString } from './Pattern.js';
12
+ // Rewrite rules
13
+ export { rule, biRule, coreRules, algebraRules, functionRules, allRules, canonRules, getRuleSet } from './Rules.js';
14
+ // Saturation
15
+ export { saturate, saturatePhased, applyRuleOnce } from './Rewriter.js';
16
+ // Extraction
17
+ export { extractBest, extractWithCSE, defaultCostModel } from './Extractor.js';
18
+ // AST conversion
19
+ export { addExpression, addExpressions, addGradients, getRootIds } from './Convert.js';
20
+ // Main optimizer function
21
+ export { optimizeWithEGraph } from './Optimizer.js';
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "gradient-script",
3
- "version": "0.1.0",
3
+ "version": "0.3.0",
4
4
  "description": "Symbolic differentiation for structured types with a simple DSL",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
package/dist/dsl/CSE.d.ts DELETED
@@ -1,21 +0,0 @@
1
- /**
2
- * Common Subexpression Elimination (CSE)
3
- * Identifies repeated expressions and factors them out
4
- */
5
- import { Expression } from './AST.js';
6
/**
 * Result of common-subexpression elimination on a single expression.
 */
export interface CSEResult {
    /** Hoisted subexpressions, keyed by generated temporary name (`_tmp0`, `_tmp1`, ...). */
    intermediates: Map<string, Expression>;
    /** The input expression with hoisted subexpressions replaced by references to those temporaries. */
    simplified: Expression;
}
/**
 * Result of common-subexpression elimination across the components of a structured value.
 */
export interface StructuredCSEResult {
    /** Hoisted subexpressions shared by all components, keyed by temporary name. */
    intermediates: Map<string, Expression>;
    /** Each input component, under its original key, with temporaries substituted. */
    components: Map<string, Expression>;
}
/**
 * Perform CSE on an expression.
 *
 * @param expr - Expression to optimize.
 * @param minCount - Minimum number of occurrences before a subexpression is hoisted (default 2).
 */
export declare function eliminateCommonSubexpressions(expr: Expression, minCount?: number): CSEResult;
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 * Occurrences are counted across all components together, so a subexpression
 * repeated in different components can be shared.
 *
 * @param components - Component name -> expression.
 * @param minCount - Minimum number of occurrences before a subexpression is hoisted (default 2).
 */
export declare function eliminateCommonSubexpressionsStructured(components: Map<string, Expression>, minCount?: number): StructuredCSEResult;
package/dist/dsl/CSE.js DELETED
@@ -1,194 +0,0 @@
1
- /**
2
- * Common Subexpression Elimination (CSE)
3
- * Identifies repeated expressions and factors them out
4
- */
5
- import { ExpressionTransformer } from './ExpressionTransformer.js';
6
/**
 * Turns an expression tree into a canonical string key, so that structurally
 * identical subexpressions produce identical keys.
 * Node kinds other than the six handled below serialize to `undefined`.
 */
class ExpressionSerializer {
    serialize(expr) {
        const { kind } = expr;
        if (kind === 'number') {
            return `num(${expr.value})`;
        }
        if (kind === 'variable') {
            return `var(${expr.name})`;
        }
        if (kind === 'binary') {
            const lhsKey = this.serialize(expr.left);
            const rhsKey = this.serialize(expr.right);
            return `bin(${expr.operator},${lhsKey},${rhsKey})`;
        }
        if (kind === 'unary') {
            return `un(${expr.operator},${this.serialize(expr.operand)})`;
        }
        if (kind === 'call') {
            const argKeys = expr.args.map((arg) => this.serialize(arg));
            return `call(${expr.name},${argKeys.join(',')})`;
        }
        if (kind === 'component') {
            return `comp(${this.serialize(expr.object)},${expr.component})`;
        }
        return undefined;
    }
}
33
/**
 * Perform common-subexpression elimination on a single expression.
 *
 * @param {object} expr - Expression to optimize.
 * @param {number} [minCount=2] - Minimum number of occurrences before a
 *   subexpression is hoisted into a temporary.
 * @returns {{intermediates: Map<string, object>, simplified: object}}
 *   `intermediates` maps generated names (`_tmp0`, `_tmp1`, ...) to the hoisted
 *   expressions; `simplified` is `expr` with those occurrences replaced.
 */
export function eliminateCommonSubexpressions(expr, minCount = 2) {
    const counter = new ExpressionCounter();
    counter.count(expr);
    const { intermediates, subexprMap } = selectCommonSubexpressions(counter, minCount);
    const simplified = substituteExpressions(expr, subexprMap, counter);
    return { intermediates, simplified };
}
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 *
 * Occurrences are counted across all components together, so a subexpression
 * repeated in different components is hoisted once and shared.
 *
 * @param {Map<string, object>} components - Component name -> expression.
 * @param {number} [minCount=2] - Minimum number of occurrences before hoisting.
 * @returns {{intermediates: Map<string, object>, components: Map<string, object>}}
 */
export function eliminateCommonSubexpressionsStructured(components, minCount = 2) {
    const counter = new ExpressionCounter();
    for (const expr of components.values()) {
        counter.count(expr);
    }
    const { intermediates, subexprMap } = selectCommonSubexpressions(counter, minCount);
    const simplifiedComponents = new Map();
    for (const [comp, expr] of components.entries()) {
        simplifiedComponents.set(comp, substituteExpressions(expr, subexprMap, counter));
    }
    return { intermediates, components: simplifiedComponents };
}
/**
 * Shared selection step for both CSE entry points: choose every recorded
 * subexpression seen at least `minCount` times that `shouldExtract` accepts,
 * and name it `_tmpN` in the counter's recording order.
 *
 * @returns {{intermediates: Map<string, object>, subexprMap: Map<string, string>}}
 *   `intermediates`: temp name -> expression; `subexprMap`: serialized key -> temp name.
 */
function selectCommonSubexpressions(counter, minCount) {
    const intermediates = new Map();
    const subexprMap = new Map();
    let varCounter = 0;
    for (const [exprStr, count] of counter.counts.entries()) {
        if (count < minCount) {
            continue;
        }
        const parsed = counter.expressions.get(exprStr);
        if (!parsed || !shouldExtract(parsed)) {
            continue;
        }
        const varName = `_tmp${varCounter++}`;
        intermediates.set(varName, parsed);
        subexprMap.set(exprStr, varName);
    }
    return { intermediates, subexprMap };
}
82
/**
 * Decide whether a repeated subexpression is worth hoisting into a temporary.
 *
 * Binary operations and function calls always are. A unary operation is
 * worth it exactly when its operand is. A component access is worth it unless
 * it reads directly from a variable (e.g. `v.x`). Numbers, variables and any
 * other node kinds are not.
 */
function shouldExtract(expr) {
    const { kind } = expr;
    if (kind === 'binary' || kind === 'call') {
        return true;
    }
    if (kind === 'unary') {
        return shouldExtract(expr.operand);
    }
    if (kind === 'component') {
        return expr.object.kind !== 'variable';
    }
    return false;
}
102
/**
 * Visitor that tallies subexpressions by their serialized key.
 *
 * `counts` maps key -> number of occurrences, and `expressions` maps
 * key -> the first node recorded under that key. Every visit hook below
 * records its node before returning or delegating to the base
 * `ExpressionTransformer`. Which nodes get visited depends on that base class.
 */
class ExpressionCounter extends ExpressionTransformer {
    counts = new Map();
    expressions = new Map();
    serializer = new ExpressionSerializer();
    /** Tally `expr` and the subexpressions the base transformer walks into. */
    count(expr) {
        this.transform(expr);
    }
    /** Canonical string key for `expr`. */
    serialize(expr) {
        return this.serializer.serialize(expr);
    }
    /** Bump the occurrence count for `expr`'s key, keeping the first node seen. */
    recordExpression(expr) {
        const key = this.serialize(expr);
        const seenBefore = this.counts.get(key) ?? 0;
        this.counts.set(key, seenBefore + 1);
        if (!this.expressions.has(key)) {
            this.expressions.set(key, expr);
        }
    }
    // Interior nodes: record, then let the base transformer continue.
    visitBinaryOp(node) {
        this.recordExpression(node);
        return super.visitBinaryOp(node);
    }
    visitUnaryOp(node) {
        this.recordExpression(node);
        return super.visitUnaryOp(node);
    }
    visitFunctionCall(node) {
        this.recordExpression(node);
        return super.visitFunctionCall(node);
    }
    visitComponentAccess(node) {
        this.recordExpression(node);
        return super.visitComponentAccess(node);
    }
    // Leaves: record and hand the node back unchanged.
    visitNumber(node) {
        this.recordExpression(node);
        return node;
    }
    visitVariable(node) {
        this.recordExpression(node);
        return node;
    }
}
148
/**
 * Replace hoisted subexpressions in `expr` with references to their temporaries.
 *
 * If `expr` as a whole is hoisted, a single variable node is returned.
 * Otherwise each entry of `subexprMap` (serialized key -> temp name) is
 * applied in insertion order via `substituteInExpression`. An entry is skipped
 * when the counter has no recorded expression for its key, or when the
 * current result already serializes to that key.
 */
function substituteExpressions(expr, subexprMap, counter) {
    const wholeKey = counter.serialize(expr);
    if (subexprMap.has(wholeKey)) {
        return { kind: 'variable', name: subexprMap.get(wholeKey) };
    }
    let current = expr;
    subexprMap.forEach((varName, exprStr) => {
        const target = counter.expressions.get(exprStr);
        if (!target || counter.serialize(current) === exprStr) {
            return;
        }
        current = substituteInExpression(current, target, { kind: 'variable', name: varName }, counter);
    });
    return current;
}
168
/**
 * Transformer that replaces every subtree structurally equal to `pattern`
 * with `replacement`. Equality is decided by comparing serialized keys.
 *
 * The pattern's key is computed once in the constructor. Previously it was
 * re-serialized on every node visited, doubling the serialization work of
 * each `transform` call without changing the result.
 */
class SubstitutionTransformer extends ExpressionTransformer {
    pattern;
    replacement;
    counter;
    /** Serialized form of `pattern`, cached for the lifetime of the transformer. */
    patternKey;
    constructor(pattern, replacement, counter) {
        super();
        this.pattern = pattern;
        this.replacement = replacement;
        this.counter = counter;
        this.patternKey = counter.serialize(pattern);
    }
    transform(expr) {
        if (this.counter.serialize(expr) === this.patternKey) {
            return this.replacement;
        }
        return super.transform(expr);
    }
}
188
/**
 * Return `expr` with every subtree matching `pattern` (by serialized key)
 * replaced by `replacement`, using a one-off `SubstitutionTransformer`.
 */
function substituteInExpression(expr, pattern, replacement, counter) {
    return new SubstitutionTransformer(pattern, replacement, counter).transform(expr);
}