gradient-script 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +80 -3
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +332 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph Optimizer for GradientScript
|
|
3
|
+
*
|
|
4
|
+
* Main entry point for e-graph-based optimization.
|
|
5
|
+
* Uses equality saturation for CSE and algebraic simplification.
|
|
6
|
+
*/
|
|
7
|
+
import { Expression } from '../AST.js';
|
|
8
|
+
import { SaturationStats } from './Rewriter.js';
|
|
9
|
+
import { CostModel } from './Extractor.js';
|
|
10
|
+
/**
 * Result from e-graph optimization (matches GlobalCSEResult interface)
 */
export interface EGraphOptimizeResult {
    /** Shared subexpressions hoisted into temporaries (the extractor's `temps` map). */
    intermediates: Map<string, Expression>;
    /** Optimized gradients, keyed by parameter name and then by component. */
    gradients: Map<string, Map<string, Expression>>;
    /** Run statistics; optional in the type, but optimizeWithEGraph always fills it in. */
    stats?: OptimizationStats;
}
|
|
18
|
+
/**
 * Statistics from optimization
 */
export interface OptimizationStats {
    /** Stats returned by `saturate` / `saturatePhased`. */
    saturation: SaturationStats;
    /** Number of temporaries extracted (equals `intermediates.size`). */
    tempsCreated: number;
    /** Total cost reported by the extractor under the active cost model. */
    totalCost: number;
}
|
|
26
|
+
/**
 * Options for e-graph optimization
 *
 * Any field left undefined falls back to the default noted on it.
 */
export interface EGraphOptimizeOptions {
    /** Maximum saturation iterations (default: 30) */
    maxIterations?: number;
    /** Maximum e-graph size before stopping saturation (default: 15000) */
    maxClassSize?: number;
    /**
     * Which rule sets to use (default: ['core', 'algebra']).
     * Unrecognized names are ignored; selected sets are always applied in the
     * order core, algebra, function regardless of their order here.
     */
    ruleSets?: ('core' | 'algebra' | 'function')[];
    /** Use phased saturation: each selected rule set runs as its own phase, core first (default: true) */
    phased?: boolean;
    /** Minimum cost for a subexpression to become a temp (default: 2) */
    minSharedCost?: number;
    /** Custom cost model (default: `defaultCostModel` from Extractor) */
    costModel?: CostModel;
    /** Print verbose progress output to the console (default: false) */
    verbose?: boolean;
}
|
|
45
|
+
/**
 * Optimize gradients using e-graph equality saturation
 *
 * This is designed to be a drop-in replacement for eliminateCommonSubexpressionsGlobal
 *
 * @param allGradients - gradient expressions keyed by parameter name, then by component
 * @param options - optimization knobs; see EGraphOptimizeOptions for defaults
 * @returns extracted temporaries, the optimized gradients (same parameter and
 *   component keys), and run statistics
 * @throws Error if extraction yields no expression for one of the gradient roots
 */
export declare function optimizeWithEGraph(allGradients: Map<string, Map<string, Expression>>, options?: EGraphOptimizeOptions): EGraphOptimizeResult;
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph Optimizer for GradientScript
|
|
3
|
+
*
|
|
4
|
+
* Main entry point for e-graph-based optimization.
|
|
5
|
+
* Uses equality saturation for CSE and algebraic simplification.
|
|
6
|
+
*/
|
|
7
|
+
import { EGraph } from './EGraph.js';
|
|
8
|
+
import { addGradients, getRootIds } from './Convert.js';
|
|
9
|
+
import { saturate, saturatePhased } from './Rewriter.js';
|
|
10
|
+
import { extractWithCSE, defaultCostModel } from './Extractor.js';
|
|
11
|
+
import { coreRules, algebraRules, functionRules } from './Rules.js';
|
|
12
|
+
/**
 * Optimize gradients using e-graph equality saturation.
 *
 * Designed as a drop-in replacement for eliminateCommonSubexpressionsGlobal.
 * Pipeline: load every gradient component into one shared e-graph, saturate
 * it with the selected rewrite rules, then extract the cheapest expressions
 * while hoisting shared subexpressions into temporaries.
 *
 * @param allGradients - gradients keyed by parameter name, then by component
 * @param options - tuning knobs; any field left undefined takes the default
 *   given in the destructuring below
 * @returns extracted temporaries, the rewritten gradients (same parameter and
 *   component keys), and run statistics
 * @throws Error when the extractor returns no expression for a gradient root
 */
export function optimizeWithEGraph(allGradients, options = {}) {
    const {
        maxIterations = 30,
        maxClassSize = 15000,
        ruleSets = ['core', 'algebra'],
        phased = true,
        minSharedCost = 2,
        costModel = defaultCostModel,
        verbose = false,
    } = options;
    // One shared e-graph for every component, so identical subexpressions in
    // different gradients can land in the same e-class.
    const egraph = new EGraph();
    const gradientIds = addGradients(egraph, allGradients);
    const rootIds = getRootIds(gradientIds);
    if (verbose) {
        console.log(`[egraph] Added ${egraph.size} e-classes from ${rootIds.length} gradient expressions`);
    }
    const limits = { maxIterations, maxClassSize, verbose };
    let stats;
    if (phased) {
        // Each selected rule set becomes its own phase, always ordered
        // core -> algebra -> function.
        const phases = [
            ['core', coreRules],
            ['algebra', algebraRules],
            ['function', functionRules],
        ]
            .filter(([setName]) => ruleSets.includes(setName))
            .map(([, setRules]) => setRules);
        stats = saturatePhased(egraph, phases, limits);
    }
    else {
        // The flat rule list is only needed for single-pass saturation.
        stats = saturate(egraph, selectRules(ruleSets), limits);
    }
    if (verbose) {
        console.log(`[egraph] Saturation: ${stats.iterations} iters, ${stats.merges} merges, ${egraph.size} classes`);
    }
    const extraction = extractWithCSE(egraph, rootIds, costModel, minSharedCost);
    if (verbose) {
        console.log(`[egraph] Extracted ${extraction.temps.size} temps, total cost ${extraction.totalCost}`);
    }
    // Rebuild the parameter -> component layout, consuming rootIds in order.
    // NOTE(review): assumes getRootIds flattens gradientIds in this same
    // iteration order — confirm against Convert.js.
    const optimizedGradients = new Map();
    let nextRoot = 0;
    for (const [paramName, componentIds] of gradientIds) {
        const components = new Map();
        for (const [comp] of componentIds) {
            const rootId = rootIds[nextRoot];
            nextRoot += 1;
            const expr = extraction.expressions.get(rootId);
            if (!expr) {
                throw new Error(`Missing extraction for root ${rootId}`);
            }
            components.set(comp, expr);
        }
        optimizedGradients.set(paramName, components);
    }
    return {
        intermediates: extraction.temps,
        gradients: optimizedGradients,
        stats: {
            saturation: stats,
            tempsCreated: extraction.temps.size,
            totalCost: extraction.totalCost,
        },
    };
}
|
|
76
|
+
/**
 * Flatten the requested rule sets into one rule list.
 *
 * Sets are concatenated in the fixed order core -> algebra -> function,
 * independent of the order or repetition of names in `ruleSets`;
 * unrecognized names contribute nothing.
 */
function selectRules(ruleSets) {
    const catalog = [
        ['core', coreRules],
        ['algebra', algebraRules],
        ['function', functionRules],
    ];
    const selected = [];
    for (const [setName, setRules] of catalog) {
        if (!ruleSets.includes(setName)) {
            continue;
        }
        for (const rule of setRules) {
            selected.push(rule);
        }
    }
    return selected;
}
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Pattern Matching for E-Graph Rewrite Rules
|
|
3
|
+
*
|
|
4
|
+
* Patterns are expression templates with variables (?a, ?b, etc.)
|
|
5
|
+
* that can match against e-classes in the e-graph.
|
|
6
|
+
*/
|
|
7
|
+
import { EGraph } from './EGraph.js';
|
|
8
|
+
import { EClassId } from './ENode.js';
|
|
9
|
+
/**
 * Pattern AST
 *
 * Each non-variable variant matches one kind of e-node:
 *   padd/pmul/psub/pdiv/ppow -> add/mul/sub/div/pow (children[0], children[1])
 *   pneg/pinv                -> neg/inv (single child)
 *   pcall                    -> call with the same name and argument count
 *   pnum                     -> num with a strictly equal value
 * A `pvar` matches any e-class and binds it in the Substitution.
 */
export type Pattern = {
    tag: 'pvar';
    /** Variable name without the leading `?`. */
    name: string;
} | {
    tag: 'pnum';
    value: number;
} | {
    tag: 'padd';
    left: Pattern;
    right: Pattern;
} | {
    tag: 'pmul';
    left: Pattern;
    right: Pattern;
} | {
    tag: 'psub';
    left: Pattern;
    right: Pattern;
} | {
    tag: 'pdiv';
    left: Pattern;
    right: Pattern;
} | {
    tag: 'ppow';
    left: Pattern;
    right: Pattern;
} | {
    tag: 'pneg';
    child: Pattern;
} | {
    tag: 'pinv';
    child: Pattern;
} | {
    tag: 'pcall';
    /** Function name, e.g. `sqrt`. */
    name: string;
    args: Pattern[];
};
|
|
49
|
+
/**
 * A substitution mapping pattern variables to e-class IDs
 *
 * Keys are variable names without the `?`; values are the canonical e-class
 * IDs (via `egraph.find`) at the time the variable was bound.
 */
export type Substitution = Map<string, EClassId>;
|
|
53
|
+
/**
 * Parse a pattern string into a Pattern AST
 *
 * Syntax:
 *   ?a, ?b, ?x    - pattern variables (the `?` is dropped from the stored name)
 *   0, 1, 2, -1   - number literals (an optional fraction such as 0.5 is
 *                   accepted; exponent notation such as 1e3 is not)
 *   (+ ?a ?b)     - addition
 *   (* ?a ?b)     - multiplication
 *   (- ?a ?b)     - subtraction
 *   (/ ?a ?b)     - division
 *   (^ ?a ?b)     - power
 *   (neg ?a)      - negation
 *   (inv ?a)      - `inv` node (presumably the reciprocal 1/a)
 *   (sqrt ?a)     - function call; any other head token is parsed as a call
 *                   taking every following sub-pattern up to `)` as arguments
 *
 * Throws an Error on malformed input, e.g. a missing `)` or tokens left over
 * after a complete pattern.
 */
export declare function parsePattern(input: string): Pattern;
|
|
68
|
+
/**
 * Match a pattern against an e-class, returning all valid substitutions
 *
 * A variable that appears more than once must bind the same e-class at every
 * occurrence. At most 100 substitutions are returned; further matches are
 * silently dropped to keep rule application bounded.
 */
export declare function matchPattern(egraph: EGraph, pattern: Pattern, classId: EClassId): Substitution[];
|
|
72
|
+
/**
 * Instantiate a pattern with a substitution, adding nodes to the e-graph
 * Returns the e-class ID of the instantiated pattern
 *
 * A bare variable adds no node and returns the ID bound in `subst`; an
 * unbound variable throws an Error.
 */
export declare function instantiatePattern(egraph: EGraph, pattern: Pattern, subst: Substitution): EClassId;
|
|
77
|
+
/**
 * Convert pattern to string for debugging
 *
 * Output uses the same s-expression syntax that parsePattern accepts,
 * e.g. `(+ ?a 0)`.
 */
export declare function patternToString(pattern: Pattern): string;
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Pattern Matching for E-Graph Rewrite Rules
|
|
3
|
+
*
|
|
4
|
+
* Patterns are expression templates with variables (?a, ?b, etc.)
|
|
5
|
+
* that can match against e-classes in the e-graph.
|
|
6
|
+
*/
|
|
7
|
+
/** Number literal syntax accepted in patterns: optional minus, digits, optional fraction. */
const NUMBER_TOKEN = /^-?\d+(\.\d+)?$/;
/** A single token: a parenthesis, or a maximal run of non-space, non-paren characters. */
const PATTERN_TOKEN = /[()]|[^\s()]+/g;
/** Binary operator symbol -> Pattern tag. */
const BINARY_PATTERN_TAGS = new Map([
    ['+', 'padd'],
    ['-', 'psub'],
    ['*', 'pmul'],
    ['/', 'pdiv'],
    ['^', 'ppow'],
]);
/**
 * Whether `tok` may appear right after `(` as an operator or function name.
 * Parentheses, pattern variables and number literals cannot.
 */
function isHeadToken(tok) {
    return tok !== '(' && tok !== ')' && !tok.startsWith('?') && !NUMBER_TOKEN.test(tok);
}
/**
 * Parse a pattern string into a Pattern AST
 *
 * Syntax:
 *   ?a, ?b, ?x    - pattern variables (stored without the `?`)
 *   0, 1, 2, -1   - number literals (optional fraction, e.g. 0.5)
 *   (+ ?a ?b)     - addition
 *   (* ?a ?b)     - multiplication
 *   (- ?a ?b)     - subtraction
 *   (/ ?a ?b)     - division
 *   (^ ?a ?b)     - power
 *   (neg ?a)      - negation
 *   (inv ?a)      - inv node
 *   (sqrt ?a)     - function call (any other head; zero or more arguments)
 *
 * @param input - pattern source text
 * @returns the parsed Pattern
 * @throws Error on malformed input: premature end, a head that is not an
 *   operator/function name, an empty variable name, a missing `)`, or
 *   tokens left over after a complete pattern
 */
export function parsePattern(input) {
    const tokens = tokenize(input);
    let pos = 0;
    const peek = () => tokens[pos];
    const consume = () => tokens[pos++];
    function parseExpr() {
        const token = peek();
        if (token === undefined) {
            throw new Error('Unexpected end of pattern');
        }
        if (token === '(') {
            consume(); // (
            const op = consume();
            if (op === undefined) {
                throw new Error('Unexpected end of pattern');
            }
            // Without this check `(?f ?a)` or `((+ ?a ?b))` silently became a
            // call whose "name" is `?f` or `(`.
            if (!isHeadToken(op)) {
                throw new Error(`Expected operator or function name, got: ${op}`);
            }
            let result;
            if (op === 'neg') {
                result = { tag: 'pneg', child: parseExpr() };
            }
            else if (op === 'inv') {
                result = { tag: 'pinv', child: parseExpr() };
            }
            else if (BINARY_PATTERN_TAGS.has(op)) {
                const left = parseExpr();
                const right = parseExpr();
                result = { tag: BINARY_PATTERN_TAGS.get(op), left, right };
            }
            else {
                // Function call: every sub-pattern up to the closing paren is an argument.
                // An unterminated call reaches end of input and parseExpr throws.
                const args = [];
                while (peek() !== ')') {
                    args.push(parseExpr());
                }
                result = { tag: 'pcall', name: op, args };
            }
            if (consume() !== ')') {
                throw new Error('Expected )');
            }
            return result;
        }
        if (token.startsWith('?')) {
            if (token.length === 1) {
                throw new Error(`Pattern variable needs a name: ${token}`);
            }
            consume();
            return { tag: 'pvar', name: token.slice(1) };
        }
        if (NUMBER_TOKEN.test(token)) {
            consume();
            return { tag: 'pnum', value: parseFloat(token) };
        }
        throw new Error(`Unexpected token: ${token}`);
    }
    const result = parseExpr();
    if (pos < tokens.length) {
        throw new Error(`Unexpected token after pattern: ${tokens[pos]}`);
    }
    return result;
}
/**
 * Tokenize a pattern string into parentheses and whitespace-separated atoms.
 * A parenthesis always ends the preceding atom, so `?a(neg ?b)` yields
 * `?a`, `(`, `neg`, `?b`, `)`.
 */
function tokenize(input) {
    return input.match(PATTERN_TOKEN) ?? [];
}
|
|
124
|
+
/** Maximum substitutions to return from a single match (prevents explosion) */
const MAX_SUBSTITUTIONS = 100;
/**
 * Match a pattern against an e-class, returning all valid substitutions.
 *
 * A variable occurring more than once must bind the same (canonical) e-class
 * each time.
 *
 * @param egraph - must provide `find(id)` and `getNodes(id)`
 * @param pattern - pattern to match
 * @param classId - e-class to match against (canonicalized via `find`)
 * @returns at most MAX_SUBSTITUTIONS substitutions; extra matches are dropped
 * @throws Error if the pattern contains an unknown tag
 */
export function matchPattern(egraph, pattern, classId) {
    const results = matchPatternWithSubst(egraph, pattern, classId, new Map(), 0);
    return results.slice(0, MAX_SUBSTITUTIONS);
}
/** Guard on pattern nesting depth; deeper sub-patterns simply fail to match. */
const MAX_MATCH_DEPTH = 50;
/** Binary pattern tag -> e-node tag it matches. */
const BINARY_MATCH_TAGS = new Map([
    ['padd', 'add'],
    ['pmul', 'mul'],
    ['psub', 'sub'],
    ['pdiv', 'div'],
    ['ppow', 'pow'],
]);
/** Unary pattern tag -> e-node tag it matches. */
const UNARY_MATCH_TAGS = new Map([
    ['pneg', 'neg'],
    ['pinv', 'inv'],
]);
/**
 * Match `pattern` against e-class `classId`, extending `subst`.
 * `subst` is never mutated; every returned substitution is a fresh Map.
 */
function matchPatternWithSubst(egraph, pattern, classId, subst, depth) {
    if (depth > MAX_MATCH_DEPTH) {
        return [];
    }
    const canonId = egraph.find(classId);
    if (pattern.tag === 'pvar') {
        const existing = subst.get(pattern.name);
        if (existing !== undefined) {
            // Non-linear pattern: a repeated variable must hit the same e-class.
            return egraph.find(existing) === canonId ? [new Map(subst)] : [];
        }
        const bound = new Map(subst);
        bound.set(pattern.name, canonId);
        return [bound];
    }
    // Any node in the e-class may satisfy the pattern.
    const results = [];
    for (const node of egraph.getNodes(canonId)) {
        if (results.length >= MAX_SUBSTITUTIONS)
            break;
        for (const m of matchNodeWithPattern(egraph, pattern, node, subst, depth)) {
            results.push(m);
            if (results.length >= MAX_SUBSTITUTIONS)
                break;
        }
    }
    return results;
}
/** Match a non-variable pattern against one concrete e-node. */
function matchNodeWithPattern(egraph, pattern, node, subst, depth) {
    switch (pattern.tag) {
        case 'pvar':
            // Variables bind whole e-classes and are resolved in matchPatternWithSubst.
            throw new Error('pvar should be handled in matchPatternWithSubst');
        case 'pnum':
            return node.tag === 'num' && node.value === pattern.value ? [new Map(subst)] : [];
        case 'padd':
        case 'pmul':
        case 'psub':
        case 'pdiv':
        case 'ppow':
            return node.tag === BINARY_MATCH_TAGS.get(pattern.tag)
                ? matchBinaryChildren(egraph, pattern.left, pattern.right, node.children, subst, depth)
                : [];
        case 'pneg':
        case 'pinv':
            return node.tag === UNARY_MATCH_TAGS.get(pattern.tag)
                ? matchPatternWithSubst(egraph, pattern.child, node.child, subst, depth + 1)
                : [];
        case 'pcall':
            return node.tag === 'call' && node.name === pattern.name && node.children.length === pattern.args.length
                ? matchCallChildren(egraph, pattern.args, node.children, subst, depth)
                : [];
        default:
            // Fail loudly instead of returning undefined, which callers would
            // then try to iterate.
            throw new Error(`Unknown pattern tag: ${pattern.tag}`);
    }
}
/** Match left child, then right child under each left substitution. */
function matchBinaryChildren(egraph, leftPattern, rightPattern, children, subst, depth) {
    const results = [];
    const leftMatches = matchPatternWithSubst(egraph, leftPattern, children[0], subst, depth + 1);
    for (const leftSubst of leftMatches) {
        if (results.length >= MAX_SUBSTITUTIONS)
            break;
        const rightMatches = matchPatternWithSubst(egraph, rightPattern, children[1], leftSubst, depth + 1);
        for (const m of rightMatches) {
            results.push(m);
            if (results.length >= MAX_SUBSTITUTIONS)
                break;
        }
    }
    return results;
}
/**
 * Match call arguments `index..end` left to right, threading substitutions.
 *
 * Walks by index instead of slicing both arrays per argument. Every argument
 * sits one nesting level below the call, so all are matched at `depth + 1`;
 * the argument's position no longer counts toward MAX_MATCH_DEPTH.
 */
function matchCallChildren(egraph, patterns, children, subst, depth, index = 0) {
    if (index >= patterns.length) {
        return [new Map(subst)];
    }
    const results = [];
    const headMatches = matchPatternWithSubst(egraph, patterns[index], children[index], subst, depth + 1);
    for (const headSubst of headMatches) {
        if (results.length >= MAX_SUBSTITUTIONS)
            break;
        const restMatches = matchCallChildren(egraph, patterns, children, headSubst, depth, index + 1);
        for (const m of restMatches) {
            results.push(m);
            if (results.length >= MAX_SUBSTITUTIONS)
                break;
        }
    }
    return results;
}
|
|
255
|
+
/**
 * Instantiate a pattern with a substitution, adding nodes to the e-graph.
 *
 * Sub-patterns are built bottom-up (left operand before right, call
 * arguments in order) and each resulting node is passed to `egraph.add`.
 * A bare variable adds nothing and returns the e-class ID bound in `subst`.
 *
 * @param egraph - must provide `add(node)` returning an e-class ID
 * @param pattern - pattern to instantiate
 * @param subst - bindings for every variable in `pattern`
 * @returns the e-class ID of the instantiated pattern
 * @throws Error if a pattern variable is unbound or a pattern tag is unknown
 */
export function instantiatePattern(egraph, pattern, subst) {
    switch (pattern.tag) {
        case 'pvar': {
            const id = subst.get(pattern.name);
            if (id === undefined) {
                throw new Error(`Unbound pattern variable: ?${pattern.name}`);
            }
            return id;
        }
        case 'pnum':
            return egraph.add({ tag: 'num', value: pattern.value });
        case 'padd': {
            const left = instantiatePattern(egraph, pattern.left, subst);
            const right = instantiatePattern(egraph, pattern.right, subst);
            return egraph.add({ tag: 'add', children: [left, right] });
        }
        case 'pmul': {
            const left = instantiatePattern(egraph, pattern.left, subst);
            const right = instantiatePattern(egraph, pattern.right, subst);
            return egraph.add({ tag: 'mul', children: [left, right] });
        }
        case 'psub': {
            const left = instantiatePattern(egraph, pattern.left, subst);
            const right = instantiatePattern(egraph, pattern.right, subst);
            return egraph.add({ tag: 'sub', children: [left, right] });
        }
        case 'pdiv': {
            const left = instantiatePattern(egraph, pattern.left, subst);
            const right = instantiatePattern(egraph, pattern.right, subst);
            return egraph.add({ tag: 'div', children: [left, right] });
        }
        case 'ppow': {
            const left = instantiatePattern(egraph, pattern.left, subst);
            const right = instantiatePattern(egraph, pattern.right, subst);
            return egraph.add({ tag: 'pow', children: [left, right] });
        }
        case 'pneg': {
            const child = instantiatePattern(egraph, pattern.child, subst);
            return egraph.add({ tag: 'neg', child });
        }
        case 'pinv': {
            const child = instantiatePattern(egraph, pattern.child, subst);
            return egraph.add({ tag: 'inv', child });
        }
        case 'pcall': {
            const args = pattern.args.map(arg => instantiatePattern(egraph, arg, subst));
            return egraph.add({ tag: 'call', name: pattern.name, children: args });
        }
        default:
            // Fail loudly rather than hand `undefined` back as an e-class ID.
            throw new Error(`Unknown pattern tag: ${pattern.tag}`);
    }
}
|
|
309
|
+
/** Binary pattern tag -> operator symbol used in the s-expression syntax. */
const PATTERN_OPERATOR_SYMBOLS = new Map([
    ['padd', '+'],
    ['pmul', '*'],
    ['psub', '-'],
    ['pdiv', '/'],
    ['ppow', '^'],
]);
/** Unary pattern tag -> head word used in the s-expression syntax. */
const PATTERN_UNARY_WORDS = new Map([
    ['pneg', 'neg'],
    ['pinv', 'inv'],
]);
/**
 * Render a pattern in the s-expression syntax read by parsePattern,
 * for debugging output. Unknown tags render as undefined.
 */
export function patternToString(pattern) {
    const tag = pattern.tag;
    const symbol = PATTERN_OPERATOR_SYMBOLS.get(tag);
    if (symbol !== undefined) {
        const lhs = patternToString(pattern.left);
        const rhs = patternToString(pattern.right);
        return `(${symbol} ${lhs} ${rhs})`;
    }
    const word = PATTERN_UNARY_WORDS.get(tag);
    if (word !== undefined) {
        return `(${word} ${patternToString(pattern.child)})`;
    }
    if (tag === 'pvar') {
        return `?${pattern.name}`;
    }
    if (tag === 'pnum') {
        return `${pattern.value}`;
    }
    if (tag === 'pcall') {
        const renderedArgs = pattern.args.map(patternToString).join(' ');
        return `(${pattern.name} ${renderedArgs})`;
    }
    return undefined;
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Rewrite Engine for E-Graph Equality Saturation
|
|
3
|
+
*
|
|
4
|
+
* Applies rewrite rules until saturation (no new merges) or iteration limit.
|
|
5
|
+
*/
|
|
6
|
+
import { EGraph } from './EGraph.js';
|
|
7
|
+
import { Rule } from './Rules.js';
|
|
8
|
+
/**
 * Statistics from a saturation run
 *
 * NOTE(review): field meanings below are inferred from their names and from
 * how optimizeWithEGraph reports `iterations` and `merges`; confirm against
 * the Rewriter implementation.
 */
export interface SaturationStats {
    /** Rewrite iterations performed (presumably). */
    iterations: number;
    /** Rule matches found across the whole run (presumably). */
    totalMatches: number;
    /** E-class merges performed (presumably). */
    merges: number;
    /** True when the run stopped because no new equivalences appeared (presumably). */
    saturated: boolean;
    /** Number of e-classes when the run finished (presumably). */
    classCount: number;
}
|
|
18
|
+
/**
 * Options for saturation
 *
 * optimizeWithEGraph passes maxIterations, maxClassSize and verbose, and never
 * sets maxMatchesPerIter. NOTE(review): the defaults used when a field is
 * omitted are not visible from this declaration.
 */
export interface SaturationOptions {
    /** Stop after this many iterations. */
    maxIterations?: number;
    /** Stop once the e-graph grows past this size (presumably counted in e-classes). */
    maxClassSize?: number;
    /** Cap on rule matches applied per iteration (presumably). */
    maxMatchesPerIter?: number;
    /** Emit progress logging (presumably to the console). */
    verbose?: boolean;
}
|
|
27
|
+
/**
 * Apply equality saturation to an e-graph
 *
 * Repeatedly applies rewrite rules until:
 * - No new equivalences are discovered (saturated)
 * - Max iterations reached
 * - E-graph size limit exceeded
 *
 * @param egraph - e-graph to rewrite; it is modified in place (optimizeWithEGraph
 *   extracts from the same instance afterwards)
 * @param rules - rewrite rules to apply
 * @param options - iteration/size limits and logging
 * @returns statistics describing the run
 */
export declare function saturate(egraph: EGraph, rules: Rule[], options?: SaturationOptions): SaturationStats;
|
|
36
|
+
/**
 * Apply a single rule once, returning number of merges
 *
 * @param egraph - e-graph to rewrite (modified in place, presumably)
 * @param rule - the rewrite rule to apply
 * @returns how many merges this application performed
 */
export declare function applyRuleOnce(egraph: EGraph, rule: Rule): number;
|
|
40
|
+
/**
 * Apply rules in phases for better control
 * E.g., apply core rules first, then algebra rules
 *
 * @param egraph - e-graph to rewrite (modified in place, presumably)
 * @param phases - rule groups, presumably run in array order; optimizeWithEGraph
 *   passes them as core -> algebra -> function
 * @param options - iteration/size limits and logging
 * @returns statistics for the whole phased run (presumably aggregated across phases)
 */
export declare function saturatePhased(egraph: EGraph, phases: Rule[][], options?: SaturationOptions): SaturationStats;
|