gradient-script 0.2.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +219 -6
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +336 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -0,0 +1,1068 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Cost-Based Extraction from E-Graphs
|
|
3
|
+
*
|
|
4
|
+
* Extracts the lowest-cost expression from each e-class.
|
|
5
|
+
* Also detects common subexpressions (shared e-classes) for CSE.
|
|
6
|
+
*/
|
|
7
|
+
import { enodeChildren } from './ENode.js';
|
|
8
|
+
// =============================================================================
|
|
9
|
+
// Constant Folding Helpers
|
|
10
|
+
// =============================================================================
|
|
11
|
+
/**
 * Build a binary AST node, evaluating it eagerly when possible.
 *
 * When both operands are numeric literals and the operator is one of
 * '+', '-', '*', '/', '^', the result is computed. It is returned as a
 * literal only if it is a finite number. Division by zero, overflow, NaN
 * and unknown operators all fall back to a plain `binary` node.
 *
 * @param {string} operator - Binary operator symbol.
 * @param {object} left - Left operand expression.
 * @param {object} right - Right operand expression.
 * @returns {object} A `number` node when folded, otherwise a `binary` node.
 */
function makeBinary(operator, left, right) {
    const unfolded = () => ({ kind: 'binary', operator, left, right });
    if (left.kind !== 'number' || right.kind !== 'number') {
        return unfolded();
    }
    const a = left.value;
    const b = right.value;
    let folded;
    if (operator === '+') {
        folded = a + b;
    } else if (operator === '-') {
        folded = a - b;
    } else if (operator === '*') {
        folded = a * b;
    } else if (operator === '/') {
        // Explicit zero guard: never fold a division by zero.
        folded = b !== 0 ? a / b : NaN;
    } else if (operator === '^') {
        folded = Math.pow(a, b);
    }
    // `folded` stays undefined for unknown operators; isFinite rejects that too.
    return Number.isFinite(folded) ? { kind: 'number', value: folded } : unfolded();
}
|
|
44
|
+
/**
 * Build a unary AST node, folding the negation of a numeric literal.
 *
 * Only the '-' operator is folded into a literal. Any other operator is kept
 * as a `unary` node rather than being silently treated as a negation. The
 * callers visible in this file only pass '-'.
 *
 * @param {string} operator - Unary operator symbol.
 * @param {object} operand - Operand expression.
 * @returns {object} A `number` node for `-<literal>`, otherwise a `unary` node.
 */
function makeUnary(operator, operand) {
    if (operator === '-' && operand.kind === 'number') {
        return { kind: 'number', value: -operand.value };
    }
    return { kind: 'unary', operator, operand };
}
|
|
54
|
+
/**
 * Decide whether an expression is too simple to deserve its own temp.
 *
 * Trivial means: a numeric literal, a variable reference, or a unary
 * expression applied directly to a numeric literal (e.g. `-(2)`).
 *
 * @param {object} expr - AST expression.
 * @returns {boolean} true for trivial expressions.
 */
function isTrivialExpression(expr) {
    switch (expr.kind) {
        case 'number':
        case 'variable':
            return true;
        case 'unary':
            return expr.operand.kind === 'number';
        default:
            return false;
    }
}
|
|
69
|
+
/**
 * Default per-operation costs used by extraction.
 *
 * Leaves, negation and field access cost 1. Add/sub/mul cost 2, calls 3,
 * pow 4 and the reciprocal `inv` 5. Division is weighted heaviest at 8 so
 * extraction favours forms with fewer divisions.
 */
export const defaultCostModel = {
    // leaves
    num: 1, var: 1,
    // cheap arithmetic
    add: 2, sub: 2, mul: 2,
    // heavily penalised so factored / division-free forms win
    div: 8,
    pow: 4,
    neg: 1,
    // reciprocal 1/x: cheaper than a general division but not free
    inv: 5,
    call: 3,
    component: 1,
};
|
|
85
|
+
/**
 * Extract the cheapest expression represented by a single e-class.
 *
 * Class costs are recomputed on every call (no caching across calls).
 *
 * @param {object} egraph - E-graph to extract from.
 * @param {*} rootId - E-class id to extract.
 * @param {object} [costModel=defaultCostModel] - Per-operation costs.
 * @returns {object} AST expression with numeric literals folded.
 */
export function extractBest(egraph, rootId, costModel = defaultCostModel) {
    const classCosts = computeCosts(egraph, costModel);
    const best = extractFromClass(egraph, rootId, classCosts, costModel);
    return best;
}
|
|
92
|
+
/**
 * Extract several roots at once, turning shared subexpressions into temps.
 *
 * A shared e-class is one reached at least twice from the roots through best
 * nodes (see `countReferences`). If its cost is strictly above
 * `minSharedCost`, it is named `_tmp0`, `_tmp1`, ... and extracted once, and
 * the roots refer to it by name. The result is then tidied:
 *   - temp definitions are rewritten to reference other temps with identical
 *     (serialized) subexpressions, then ordered so dependencies come first;
 *   - temps referenced at most once are inlined back and removed;
 *   - `postExtractionCSE`, `mergeNegativePairs` and `normalizeAddNegMul` run.
 *     According to the call-site notes, these handle repeated
 *     post-substitution patterns, (a-b)/(b-a) pairs and `a + -1*b` cleanup.
 *
 * @param {object} egraph - E-graph providing `find`, `getNodes`, `getClassIds`.
 * @param {Iterable} roots - Root e-class ids. They are iterated twice, so pass
 *   a re-iterable collection such as an array.
 * @param {object} [costModel=defaultCostModel] - Per-operation costs.
 * @param {number} [minSharedCost=3] - A shared class becomes a temp only when
 *   its cost is strictly greater than this value.
 * @returns {{temps: Map<string, object>, expressions: Map<*, object>, totalCost: number}}
 *   `temps` maps temp name -> definition in emission order. `expressions` maps
 *   each root id (as passed in) -> its expression. `totalCost` sums
 *   `expressionCost` over both.
 */
export function extractWithCSE(egraph, roots, costModel = defaultCostModel, minSharedCost = 3) {
    // Overwrite `map` in place with `entries`, keeping their order.
    const replaceEntries = (map, entries) => {
        map.clear();
        for (const [key, value] of entries) {
            map.set(key, value);
        }
    };

    const costs = computeCosts(egraph, costModel);
    const refCounts = countReferences(egraph, roots, costs, costModel);

    // Name every sufficiently expensive shared class, in first-reference order.
    const tempsToExtract = new Map(); // canonical class id -> temp name
    let tempCounter = 0;
    for (const [classId, count] of refCounts) {
        if (count < 2) {
            continue;
        }
        const canonical = egraph.find(classId);
        const classCost = costs.get(canonical) ?? Infinity;
        if (classCost > minSharedCost) {
            tempsToExtract.set(canonical, `_tmp${tempCounter++}`);
        }
    }

    // Extract temp bodies without temp references. Trivial bodies
    // (literals, variables, negated literals) are not worth a temp.
    const temps = new Map(); // temp name -> definition
    for (const [classId, tempName] of tempsToExtract) {
        const definition = extractFromClass(egraph, classId, costs, costModel);
        if (isTrivialExpression(definition)) {
            tempsToExtract.delete(classId);
        } else {
            temps.set(tempName, definition);
        }
    }

    // Roots stop expanding at any class that owns a temp.
    const expressions = new Map();
    for (const rootId of roots) {
        expressions.set(rootId, extractWithTemps(egraph, rootId, costs, costModel, tempsToExtract));
    }

    // Let each temp definition reuse other temps it structurally contains.
    const exprToTemp = new Map();
    for (const [tempName, definition] of temps) {
        exprToTemp.set(serializeExpr(definition), tempName);
    }
    for (const [tempName, definition] of temps) {
        temps.set(tempName, substituteTempRefs(definition, exprToTemp, tempName));
    }
    replaceEntries(temps, topologicalSortTemps(temps));

    // Tally how often each temp name is referenced by roots and by other temps.
    const tempUsageCounts = new Map();
    const tallyTempRefs = (expr) => {
        switch (expr.kind) {
            case 'variable':
                if (expr.name.startsWith('_tmp')) {
                    tempUsageCounts.set(expr.name, (tempUsageCounts.get(expr.name) ?? 0) + 1);
                }
                break;
            case 'binary':
                tallyTempRefs(expr.left);
                tallyTempRefs(expr.right);
                break;
            case 'unary':
                tallyTempRefs(expr.operand);
                break;
            case 'call':
                expr.args.forEach(tallyTempRefs);
                break;
            case 'component':
                tallyTempRefs(expr.object);
                break;
        }
    };
    for (const expr of expressions.values()) {
        tallyTempRefs(expr);
    }
    for (const expr of temps.values()) {
        tallyTempRefs(expr);
    }

    // A temp referenced zero times or once does not pay for itself.
    const tempsToInline = new Set(
        [...temps.keys()].filter((name) => (tempUsageCounts.get(name) ?? 0) <= 1)
    );

    if (tempsToInline.size > 0) {
        // Rebuild `expr`, replacing references to inlined temps by their
        // (recursively inlined) definitions. Always returns fresh nodes.
        const inlineTemps = (expr) => {
            switch (expr.kind) {
                case 'variable': {
                    if (!tempsToInline.has(expr.name)) {
                        return expr;
                    }
                    const definition = temps.get(expr.name);
                    return definition ? inlineTemps(definition) : expr;
                }
                case 'binary':
                    return {
                        kind: 'binary',
                        operator: expr.operator,
                        left: inlineTemps(expr.left),
                        right: inlineTemps(expr.right)
                    };
                case 'unary':
                    return {
                        kind: 'unary',
                        operator: expr.operator,
                        operand: inlineTemps(expr.operand)
                    };
                case 'call':
                    return {
                        kind: 'call',
                        name: expr.name,
                        args: expr.args.map(inlineTemps)
                    };
                case 'component':
                    return {
                        kind: 'component',
                        object: inlineTemps(expr.object),
                        component: expr.component
                    };
                default:
                    return expr;
            }
        };
        // Inlined temps' own definitions are read as-is, so only survivors are rewritten.
        for (const [tempName, definition] of temps) {
            if (!tempsToInline.has(tempName)) {
                temps.set(tempName, inlineTemps(definition));
            }
        }
        for (const [rootId, expr] of expressions) {
            expressions.set(rootId, inlineTemps(expr));
        }
        for (const tempName of tempsToInline) {
            temps.delete(tempName);
        }
        // Inlining can change which temps depend on which; order them again.
        replaceEntries(temps, topologicalSortTemps(temps));
    }

    // Late clean-up passes: new repeated patterns, negated pairs, a + -1*b.
    postExtractionCSE(temps, expressions, minSharedCost, tempCounter, costModel);
    mergeNegativePairs(temps, expressions);
    normalizeAddNegMul(temps, expressions);

    let totalCost = 0;
    for (const expr of [...temps.values(), ...expressions.values()]) {
        totalCost += expressionCost(expr, costModel);
    }
    return { temps, expressions, totalCost };
}
|
|
258
|
+
/**
 * Compute, for every e-class, the lowest cost among its nodes.
 *
 * All classes start at Infinity. Every class is then swept repeatedly, and a
 * class's cost is lowered whenever one of its nodes is cheaper. Sweeping stops
 * after a sweep with no improvement or after 100 sweeps. A class for which no
 * node ever gets a finite cost keeps Infinity.
 *
 * @param {object} egraph - E-graph providing `getClassIds`, `find`, `getNodes`.
 * @param {object} costModel - Per-operation costs.
 * @returns {Map<*, number>} class id -> cost (updated under canonical ids).
 */
function computeCosts(egraph, costModel) {
    const MAX_SWEEPS = 100;
    const classIds = egraph.getClassIds();
    const costs = new Map();
    for (const id of classIds) {
        costs.set(id, Infinity);
    }
    for (let sweep = 0; sweep < MAX_SWEEPS; sweep++) {
        let improved = false;
        for (const classId of classIds) {
            const canonId = egraph.find(classId);
            for (const node of egraph.getNodes(canonId)) {
                const candidate = computeNodeCost(node, costs, costModel);
                if (candidate < (costs.get(canonId) ?? Infinity)) {
                    costs.set(canonId, candidate);
                    improved = true;
                }
            }
        }
        if (!improved) {
            break;
        }
    }
    return costs;
}
|
|
290
|
+
/**
 * Cost of a single e-node: its operation cost plus the current cost of each
 * child class. Children with no recorded cost count as Infinity.
 *
 * @param {object} node - E-node with a `tag` and tag-specific child fields.
 * @param {Map<*, number>} classCosts - Current per-class costs.
 * @param {object} costModel - Per-operation costs, keyed by tag.
 * @returns {number|undefined} The cost, or undefined for an unrecognised tag.
 */
function computeNodeCost(node, classCosts, costModel) {
    const costOf = (id) => classCosts.get(id) ?? Infinity;
    const tag = node.tag;
    switch (tag) {
        case 'num':
        case 'var':
            return costModel[tag];
        case 'add':
        case 'sub':
        case 'mul':
        case 'div':
        case 'pow':
            return costModel[tag] + costOf(node.children[0]) + costOf(node.children[1]);
        case 'neg':
        case 'inv':
            return costModel[tag] + costOf(node.child);
        case 'call': {
            const argsCost = node.children.reduce((total, id) => total + costOf(id), 0);
            return costModel.call + argsCost;
        }
        case 'component':
            return costModel.component + costOf(node.object);
        default:
            return undefined;
    }
}
|
|
320
|
+
/**
 * Extract the cheapest expression for an e-class using precomputed costs.
 *
 * The first node with the strictly lowest cost wins; ties keep the earlier
 * node. That node is converted recursively via `nodeToExpression`.
 *
 * @param {object} egraph - E-graph providing `find` and `getNodes`.
 * @param {*} classId - E-class to extract (canonicalised with `find`).
 * @param {Map<*, number>} costs - Class costs from `computeCosts`.
 * @param {object} costModel - Per-operation costs.
 * @returns {object} AST expression.
 * @throws {Error} If the class has no nodes, or none of its nodes has a finite cost.
 */
function extractFromClass(egraph, classId, costs, costModel) {
    const canonId = egraph.find(classId);
    const nodes = egraph.getNodes(canonId);
    let bestNode = null;
    let bestCost = Infinity;
    let nodeCount = 0;
    for (const node of nodes) {
        nodeCount++;
        const cost = computeNodeCost(node, costs, costModel);
        if (cost < bestCost) {
            bestCost = cost;
            bestNode = node;
        }
    }
    if (!bestNode) {
        // An empty class and a class whose nodes all have infinite (or
        // undefined) cost are different failures; report them differently.
        throw new Error(nodeCount === 0
            ? `No nodes in e-class ${canonId}`
            : `No finite-cost node in e-class ${canonId} (${nodeCount} node(s) considered)`);
    }
    return nodeToExpression(bestNode, egraph, costs, costModel);
}
|
|
341
|
+
/**
 * Extract the cheapest expression for an e-class, stopping at temp classes.
 *
 * If the canonical class has an entry in `temps`, a `variable` node naming
 * that temp is returned instead of expanding the class. Otherwise the first
 * node with the strictly lowest cost is expanded via
 * `nodeToExpressionWithTemps`.
 *
 * @param {object} egraph - E-graph providing `find` and `getNodes`.
 * @param {*} classId - E-class to extract (canonicalised with `find`).
 * @param {Map<*, number>} costs - Class costs from `computeCosts`.
 * @param {object} costModel - Per-operation costs.
 * @param {Map<*, string>} temps - canonical class id -> temp name.
 * @returns {object} AST expression.
 * @throws {Error} If the class has no nodes, or none of its nodes has a finite cost.
 */
function extractWithTemps(egraph, classId, costs, costModel, temps) {
    const canonId = egraph.find(classId);
    const tempName = temps.get(canonId);
    if (tempName) {
        return { kind: 'variable', name: tempName };
    }
    const nodes = egraph.getNodes(canonId);
    let bestNode = null;
    let bestCost = Infinity;
    let nodeCount = 0;
    for (const node of nodes) {
        nodeCount++;
        const cost = computeNodeCost(node, costs, costModel);
        if (cost < bestCost) {
            bestCost = cost;
            bestNode = node;
        }
    }
    if (!bestNode) {
        // An empty class and a class whose nodes all have infinite (or
        // undefined) cost are different failures; report them differently.
        throw new Error(nodeCount === 0
            ? `No nodes in e-class ${canonId}`
            : `No finite-cost node in e-class ${canonId} (${nodeCount} node(s) considered)`);
    }
    return nodeToExpressionWithTemps(bestNode, egraph, costs, costModel, temps);
}
|
|
367
|
+
/**
 * Convert one e-node into an AST expression, extracting its children with
 * `extractFromClass`.
 *
 * add/sub/mul/div/pow become binary '+', '-', '*', '/', '^' through
 * `makeBinary` (which folds numeric literals). neg goes through `makeUnary`,
 * and inv becomes `1 / child`. call and component nodes are rebuilt
 * structurally. An unrecognised tag yields undefined.
 *
 * @param {object} node - E-node to convert.
 * @param {object} egraph - E-graph the node belongs to.
 * @param {Map<*, number>} costs - Class costs from `computeCosts`.
 * @param {object} costModel - Per-operation costs.
 * @returns {object|undefined} AST expression.
 */
function nodeToExpression(node, egraph, costs, costModel) {
    const sub = (classId) => extractFromClass(egraph, classId, costs, costModel);
    const binary = (operator) => makeBinary(operator, sub(node.children[0]), sub(node.children[1]));
    switch (node.tag) {
        case 'num':
            return { kind: 'number', value: node.value };
        case 'var':
            return { kind: 'variable', name: node.name };
        case 'add':
            return binary('+');
        case 'sub':
            return binary('-');
        case 'mul':
            return binary('*');
        case 'div':
            return binary('/');
        case 'pow':
            return binary('^');
        case 'neg':
            return makeUnary('-', sub(node.child));
        case 'inv': {
            // Reciprocal is emitted as an explicit division by the child.
            const one = { kind: 'number', value: 1 };
            return makeBinary('/', one, sub(node.child));
        }
        case 'call': {
            const args = node.children.map((id) => sub(id));
            return { kind: 'call', name: node.name, args };
        }
        case 'component': {
            const object = sub(node.object);
            return { kind: 'component', object, component: node.field };
        }
        default:
            return undefined;
    }
}
|
|
406
|
+
/**
 * Convert one e-node into an AST expression, extracting its children with
 * `extractWithTemps` so that temp classes appear as variable references.
 *
 * The tag handling matches `nodeToExpression`. Binary tags go through
 * `makeBinary`, neg through `makeUnary`, inv becomes `1 / child`, and
 * call/component nodes are rebuilt structurally. An unknown tag yields
 * undefined.
 *
 * @param {object} node - E-node to convert.
 * @param {object} egraph - E-graph the node belongs to.
 * @param {Map<*, number>} costs - Class costs from `computeCosts`.
 * @param {object} costModel - Per-operation costs.
 * @param {Map<*, string>} temps - canonical class id -> temp name.
 * @returns {object|undefined} AST expression.
 */
function nodeToExpressionWithTemps(node, egraph, costs, costModel, temps) {
    const sub = (classId) => extractWithTemps(egraph, classId, costs, costModel, temps);
    const binary = (operator) => makeBinary(operator, sub(node.children[0]), sub(node.children[1]));
    switch (node.tag) {
        case 'num':
            return { kind: 'number', value: node.value };
        case 'var':
            return { kind: 'variable', name: node.name };
        case 'add':
            return binary('+');
        case 'sub':
            return binary('-');
        case 'mul':
            return binary('*');
        case 'div':
            return binary('/');
        case 'pow':
            return binary('^');
        case 'neg':
            return makeUnary('-', sub(node.child));
        case 'inv': {
            // Reciprocal is emitted as an explicit division by the child.
            const one = { kind: 'number', value: 1 };
            return makeBinary('/', one, sub(node.child));
        }
        case 'call': {
            const args = node.children.map((id) => sub(id));
            return { kind: 'call', name: node.name, args };
        }
        case 'component': {
            const object = sub(node.object);
            return { kind: 'component', object, component: node.field };
        }
        default:
            return undefined;
    }
}
|
|
445
|
+
/**
 * Count how many times each e-class is reached from the roots, following
 * only the cheapest node of every class.
 *
 * A class's count is incremented on every arrival. Each child is explored
 * with its own copy of the current path's visited set, so a class reachable
 * along k distinct paths is counted k times, and shared subgraphs are
 * re-walked once per path, which may be expensive for heavily shared graphs.
 * The path-local set only stops descent on a cycle; the revisit is still
 * counted.
 *
 * @param {object} egraph - E-graph providing `find` and `getNodes`.
 * @param {Iterable} roots - Root e-class ids.
 * @param {Map<*, number>} costs - Class costs from `computeCosts`.
 * @param {object} costModel - Per-operation costs.
 * @returns {Map<*, number>} canonical class id -> reference count, in first-visit order.
 */
function countReferences(egraph, roots, costs, costModel) {
    const counts = new Map();

    // First node with the strictly lowest cost, or null if none is finite.
    const cheapestNode = (canonId) => {
        let best = null;
        let lowest = Infinity;
        for (const node of egraph.getNodes(canonId)) {
            const cost = computeNodeCost(node, costs, costModel);
            if (cost < lowest) {
                lowest = cost;
                best = node;
            }
        }
        return best;
    };

    const visit = (classId, onPath) => {
        const canonId = egraph.find(classId);
        counts.set(canonId, (counts.get(canonId) ?? 0) + 1);
        if (onPath.has(canonId)) {
            return;
        }
        onPath.add(canonId);
        const best = cheapestNode(canonId);
        if (!best) {
            return;
        }
        for (const childId of enodeChildren(best)) {
            visit(childId, new Set(onPath));
        }
    };

    for (const rootId of roots) {
        visit(rootId, new Set());
    }
    return counts;
}
|
|
481
|
+
/**
 * Cost of an AST expression under a cost model.
 *
 * Operator costs mirror `computeNodeCost`:
 * '/' -> div, '^' -> pow, '*' -> mul, '-' -> sub, anything else -> add.
 * Unary expressions cost `neg`, calls cost `call` plus their arguments, and
 * field access costs `component` plus its object. An unrecognised kind
 * yields undefined.
 *
 * @param {object} expr - AST expression.
 * @param {object} costModel - Per-operation costs.
 * @returns {number|undefined} Total cost.
 */
function expressionCost(expr, costModel) {
    switch (expr.kind) {
        case 'number':
            return costModel.num;
        case 'variable':
            return costModel.var;
        case 'binary': {
            let opCost;
            switch (expr.operator) {
                case '/':
                    opCost = costModel.div;
                    break;
                case '^':
                    opCost = costModel.pow;
                    break;
                case '*':
                    opCost = costModel.mul;
                    break;
                case '-':
                    // Previously charged as `add`; use `sub` to agree with computeNodeCost.
                    opCost = costModel.sub;
                    break;
                default:
                    opCost = costModel.add;
                    break;
            }
            return opCost + expressionCost(expr.left, costModel) + expressionCost(expr.right, costModel);
        }
        case 'unary':
            return costModel.neg + expressionCost(expr.operand, costModel);
        case 'call':
            return costModel.call + expr.args.reduce((sum, arg) => sum + expressionCost(arg, costModel), 0);
        case 'component':
            return costModel.component + expressionCost(expr.object, costModel);
    }
}
|
|
504
|
+
/**
 * Serialise an expression into a string key for structural comparison.
 *
 * Literals become `N<value>` and variables `V<name>`. Binary expressions are
 * wrapped as `(<left><op><right>)`, unary as `U<op><operand>`, calls as
 * `C<name>(<arg>,<arg>...)`, and field access as `<object>.<component>`.
 * An unknown kind yields undefined.
 *
 * @param {object} expr - AST expression.
 * @returns {string|undefined} Serialised key.
 */
function serializeExpr(expr) {
    const kind = expr.kind;
    if (kind === 'number') {
        return `N${expr.value}`;
    }
    if (kind === 'variable') {
        return `V${expr.name}`;
    }
    if (kind === 'binary') {
        const lhs = serializeExpr(expr.left);
        const rhs = serializeExpr(expr.right);
        return `(${lhs}${expr.operator}${rhs})`;
    }
    if (kind === 'unary') {
        return `U${expr.operator}${serializeExpr(expr.operand)}`;
    }
    if (kind === 'call') {
        const parts = expr.args.map((arg) => serializeExpr(arg));
        return `C${expr.name}(${parts.join(',')})`;
    }
    if (kind === 'component') {
        return `${serializeExpr(expr.object)}.${expr.component}`;
    }
    return undefined;
}
|
|
523
|
+
/**
 * Replace subexpressions of `expr` that serialise identically to another
 * temp's definition with a reference to that temp (bottom-up).
 *
 * Children are rewritten first, and then the resulting node itself is looked
 * up. A match against `currentTemp` is ignored so a definition never refers to
 * itself. Unchanged subtrees are returned by identity.
 *
 * @param {object} expr - Expression to rewrite.
 * @param {Map<string, string>} exprToTemp - serialised definition -> temp name.
 * @param {string} currentTemp - Name of the temp whose definition is being rewritten.
 * @returns {object} Rewritten expression.
 */
function substituteTempRefs(expr, exprToTemp, currentTemp) {
    const visit = (child) => substituteTempRefs(child, exprToTemp, currentTemp);
    let rewritten;
    if (expr.kind === 'number' || expr.kind === 'variable') {
        rewritten = expr;
    } else if (expr.kind === 'binary') {
        const left = visit(expr.left);
        const right = visit(expr.right);
        const unchanged = left === expr.left && right === expr.right;
        rewritten = unchanged ? expr : { kind: 'binary', operator: expr.operator, left, right };
    } else if (expr.kind === 'unary') {
        const operand = visit(expr.operand);
        rewritten = operand === expr.operand
            ? expr
            : { kind: 'unary', operator: expr.operator, operand };
    } else if (expr.kind === 'call') {
        const args = expr.args.map((arg) => visit(arg));
        const unchanged = args.every((arg, i) => arg === expr.args[i]);
        rewritten = unchanged ? expr : { kind: 'call', name: expr.name, args };
    } else if (expr.kind === 'component') {
        const object = visit(expr.object);
        rewritten = object === expr.object
            ? expr
            : { kind: 'component', object, component: expr.component };
    }
    const tempName = exprToTemp.get(serializeExpr(rewritten));
    const usable = tempName && tempName !== currentTemp;
    return usable ? { kind: 'variable', name: tempName } : rewritten;
}
|
|
573
|
+
/**
 * Order temp definitions so each temp comes after every temp it references.
 *
 * At each step, the earliest remaining temp (in `temps` insertion order)
 * whose referenced temps have all been emitted is taken next. If no remaining
 * temp qualifies (a reference cycle, including a self-reference), all
 * remaining temps are appended in their current order and the sort stops.
 *
 * @param {Map<string, object>} temps - temp name -> definition.
 * @returns {Array<[string, object]>} [name, definition] pairs in dependency order.
 */
function topologicalSortTemps(temps) {
    const tempNames = new Set(temps.keys());

    // Collect the names of temps referenced anywhere inside `expr`.
    const collectRefs = (expr, into) => {
        switch (expr.kind) {
            case 'variable':
                if (tempNames.has(expr.name)) {
                    into.add(expr.name);
                }
                break;
            case 'binary':
                collectRefs(expr.left, into);
                collectRefs(expr.right, into);
                break;
            case 'unary':
                collectRefs(expr.operand, into);
                break;
            case 'call':
                for (const arg of expr.args) {
                    collectRefs(arg, into);
                }
                break;
            case 'component':
                collectRefs(expr.object, into);
                break;
        }
    };

    const deps = new Map();
    for (const [name, definition] of temps) {
        const refs = new Set();
        collectRefs(definition, refs);
        deps.set(name, refs);
    }

    const ordered = [];
    const emitted = new Set();
    const pending = new Set(temps.keys());
    while (pending.size > 0) {
        let next;
        for (const name of pending) {
            if ([...deps.get(name)].every((dep) => emitted.has(dep))) {
                next = name;
                break;
            }
        }
        if (next === undefined) {
            // Nothing is ready: a cycle remains, so emit the rest unordered.
            for (const name of pending) {
                ordered.push([name, temps.get(name)]);
            }
            break;
        }
        ordered.push([next, temps.get(next)]);
        pending.delete(next);
        emitted.add(next);
    }
    return ordered;
}
|
|
631
|
+
/**
|
|
632
|
+
* Post-extraction CSE: find repeated patterns that emerge AFTER temp substitution
|
|
633
|
+
* e.g., "_tmp22 + _tmp23" appearing multiple times should become its own temp
|
|
634
|
+
*/
|
|
635
|
+
function postExtractionCSE(temps, expressions, minSharedCost, startingTempCounter, costModel) {
|
|
636
|
+
// Count occurrences of each subexpression
|
|
637
|
+
const exprCounts = new Map();
|
|
638
|
+
function countSubexprs(expr) {
|
|
639
|
+
// Don't count simple expressions
|
|
640
|
+
if (expr.kind === 'number' || expr.kind === 'variable')
|
|
641
|
+
return;
|
|
642
|
+
const serialized = serializeExpr(expr);
|
|
643
|
+
const cost = expressionCost(expr, costModel);
|
|
644
|
+
const existing = exprCounts.get(serialized);
|
|
645
|
+
if (existing) {
|
|
646
|
+
existing.count++;
|
|
647
|
+
}
|
|
648
|
+
else {
|
|
649
|
+
exprCounts.set(serialized, { count: 1, expr, cost });
|
|
650
|
+
}
|
|
651
|
+
// Recurse into children
|
|
652
|
+
if (expr.kind === 'binary') {
|
|
653
|
+
countSubexprs(expr.left);
|
|
654
|
+
countSubexprs(expr.right);
|
|
655
|
+
}
|
|
656
|
+
else if (expr.kind === 'unary') {
|
|
657
|
+
countSubexprs(expr.operand);
|
|
658
|
+
}
|
|
659
|
+
else if (expr.kind === 'call') {
|
|
660
|
+
expr.args.forEach(countSubexprs);
|
|
661
|
+
}
|
|
662
|
+
else if (expr.kind === 'component') {
|
|
663
|
+
countSubexprs(expr.object);
|
|
664
|
+
}
|
|
665
|
+
}
|
|
666
|
+
// Count in all temps and root expressions
|
|
667
|
+
for (const expr of temps.values()) {
|
|
668
|
+
countSubexprs(expr);
|
|
669
|
+
}
|
|
670
|
+
for (const expr of expressions.values()) {
|
|
671
|
+
countSubexprs(expr);
|
|
672
|
+
}
|
|
673
|
+
// Find subexpressions worth extracting (count >= 2 and cost > threshold)
|
|
674
|
+
const toExtract = [];
|
|
675
|
+
for (const [serialized, { count, expr, cost }] of exprCounts) {
|
|
676
|
+
if (count >= 2 && cost > minSharedCost) {
|
|
677
|
+
// Skip if it's just a temp reference
|
|
678
|
+
if (expr.kind === 'variable' && expr.name.startsWith('_tmp'))
|
|
679
|
+
continue;
|
|
680
|
+
// Skip trivial expressions (constants, negations of constants)
|
|
681
|
+
if (isTrivialExpression(expr))
|
|
682
|
+
continue;
|
|
683
|
+
toExtract.push({ serialized, expr, cost });
|
|
684
|
+
}
|
|
685
|
+
}
|
|
686
|
+
if (toExtract.length === 0)
|
|
687
|
+
return;
|
|
688
|
+
// Sort by cost ASCENDING (extract smaller/cheaper expressions first!)
|
|
689
|
+
// This is critical because larger patterns contain smaller ones.
|
|
690
|
+
// If we extract (a+b) first as _tmp100, then later patterns
|
|
691
|
+
// like (2 * (a+b)) will be serialized as (2 * _tmp100) and won't match.
|
|
692
|
+
toExtract.sort((a, b) => a.cost - b.cost);
|
|
693
|
+
// Build a map of existing temp RHS to prevent duplicates
|
|
694
|
+
const existingTempRHS = new Map();
|
|
695
|
+
for (const [tempName, expr] of temps) {
|
|
696
|
+
existingTempRHS.set(serializeExpr(expr), tempName);
|
|
697
|
+
}
|
|
698
|
+
// Create temps for repeated expressions
|
|
699
|
+
let tempCounter = startingTempCounter;
|
|
700
|
+
const serToTemp = new Map();
|
|
701
|
+
for (const { serialized, expr } of toExtract) {
|
|
702
|
+
// Skip if already defined as a temp
|
|
703
|
+
const existingTemp = existingTempRHS.get(serialized);
|
|
704
|
+
if (existingTemp) {
|
|
705
|
+
serToTemp.set(serialized, existingTemp);
|
|
706
|
+
continue;
|
|
707
|
+
}
|
|
708
|
+
// Find unique temp name
|
|
709
|
+
while (temps.has(`_tmp${tempCounter}`)) {
|
|
710
|
+
tempCounter++;
|
|
711
|
+
}
|
|
712
|
+
const tempName = `_tmp${tempCounter++}`;
|
|
713
|
+
serToTemp.set(serialized, tempName);
|
|
714
|
+
temps.set(tempName, expr);
|
|
715
|
+
existingTempRHS.set(serialized, tempName);
|
|
716
|
+
}
|
|
717
|
+
if (serToTemp.size === 0)
|
|
718
|
+
return;
|
|
719
|
+
// Substitute new temps into all expressions
|
|
720
|
+
function substitute(expr) {
|
|
721
|
+
if (expr.kind === 'number' || expr.kind === 'variable')
|
|
722
|
+
return expr;
|
|
723
|
+
const serialized = serializeExpr(expr);
|
|
724
|
+
const tempName = serToTemp.get(serialized);
|
|
725
|
+
if (tempName) {
|
|
726
|
+
return { kind: 'variable', name: tempName };
|
|
727
|
+
}
|
|
728
|
+
// Recurse
|
|
729
|
+
if (expr.kind === 'binary') {
|
|
730
|
+
const left = substitute(expr.left);
|
|
731
|
+
const right = substitute(expr.right);
|
|
732
|
+
return (left === expr.left && right === expr.right)
|
|
733
|
+
? expr
|
|
734
|
+
: { kind: 'binary', operator: expr.operator, left, right };
|
|
735
|
+
}
|
|
736
|
+
else if (expr.kind === 'unary') {
|
|
737
|
+
const operand = substitute(expr.operand);
|
|
738
|
+
return (operand === expr.operand)
|
|
739
|
+
? expr
|
|
740
|
+
: { kind: 'unary', operator: expr.operator, operand };
|
|
741
|
+
}
|
|
742
|
+
else if (expr.kind === 'call') {
|
|
743
|
+
const args = expr.args.map(substitute);
|
|
744
|
+
return args.every((arg, i) => arg === expr.args[i])
|
|
745
|
+
? expr
|
|
746
|
+
: { kind: 'call', name: expr.name, args };
|
|
747
|
+
}
|
|
748
|
+
else if (expr.kind === 'component') {
|
|
749
|
+
const object = substitute(expr.object);
|
|
750
|
+
return (object === expr.object)
|
|
751
|
+
? expr
|
|
752
|
+
: { kind: 'component', object, component: expr.component };
|
|
753
|
+
}
|
|
754
|
+
return expr;
|
|
755
|
+
}
|
|
756
|
+
// Substitute in ALL temps, including newly created ones
|
|
757
|
+
// But skip substituting a temp with itself (self-reference)
|
|
758
|
+
for (const [tempName, expr] of temps) {
|
|
759
|
+
// Create a substitute function that won't replace with the current temp
|
|
760
|
+
const subWithoutSelf = (e) => {
|
|
761
|
+
if (e.kind === 'number' || e.kind === 'variable')
|
|
762
|
+
return e;
|
|
763
|
+
const serialized = serializeExpr(e);
|
|
764
|
+
const targetTemp = serToTemp.get(serialized);
|
|
765
|
+
// Don't substitute if it would create self-reference
|
|
766
|
+
if (targetTemp && targetTemp !== tempName) {
|
|
767
|
+
return { kind: 'variable', name: targetTemp };
|
|
768
|
+
}
|
|
769
|
+
if (e.kind === 'binary') {
|
|
770
|
+
const left = subWithoutSelf(e.left);
|
|
771
|
+
const right = subWithoutSelf(e.right);
|
|
772
|
+
return (left === e.left && right === e.right)
|
|
773
|
+
? e
|
|
774
|
+
: { kind: 'binary', operator: e.operator, left, right };
|
|
775
|
+
}
|
|
776
|
+
else if (e.kind === 'unary') {
|
|
777
|
+
const operand = subWithoutSelf(e.operand);
|
|
778
|
+
return (operand === e.operand)
|
|
779
|
+
? e
|
|
780
|
+
: { kind: 'unary', operator: e.operator, operand };
|
|
781
|
+
}
|
|
782
|
+
else if (e.kind === 'call') {
|
|
783
|
+
const args = e.args.map(subWithoutSelf);
|
|
784
|
+
return args.every((arg, i) => arg === e.args[i])
|
|
785
|
+
? e
|
|
786
|
+
: { kind: 'call', name: e.name, args };
|
|
787
|
+
}
|
|
788
|
+
else if (e.kind === 'component') {
|
|
789
|
+
const object = subWithoutSelf(e.object);
|
|
790
|
+
return (object === e.object)
|
|
791
|
+
? e
|
|
792
|
+
: { kind: 'component', object, component: e.component };
|
|
793
|
+
}
|
|
794
|
+
return e;
|
|
795
|
+
};
|
|
796
|
+
temps.set(tempName, subWithoutSelf(expr));
|
|
797
|
+
}
|
|
798
|
+
// Substitute in root expressions
|
|
799
|
+
for (const [rootId, expr] of expressions) {
|
|
800
|
+
expressions.set(rootId, substitute(expr));
|
|
801
|
+
}
|
|
802
|
+
// Re-sort temps topologically
|
|
803
|
+
const sorted = topologicalSortTemps(temps);
|
|
804
|
+
temps.clear();
|
|
805
|
+
for (const [name, expr] of sorted) {
|
|
806
|
+
temps.set(name, expr);
|
|
807
|
+
}
|
|
808
|
+
// Inline temps that are now used only once (after all substitutions)
|
|
809
|
+
// This is critical because postExtractionCSE may have created temps
|
|
810
|
+
// that turned out to be used only once after substitution
|
|
811
|
+
const usageCounts = new Map();
|
|
812
|
+
function countUsage(expr) {
|
|
813
|
+
if (expr.kind === 'variable' && expr.name.startsWith('_tmp')) {
|
|
814
|
+
usageCounts.set(expr.name, (usageCounts.get(expr.name) ?? 0) + 1);
|
|
815
|
+
}
|
|
816
|
+
else if (expr.kind === 'binary') {
|
|
817
|
+
countUsage(expr.left);
|
|
818
|
+
countUsage(expr.right);
|
|
819
|
+
}
|
|
820
|
+
else if (expr.kind === 'unary') {
|
|
821
|
+
countUsage(expr.operand);
|
|
822
|
+
}
|
|
823
|
+
else if (expr.kind === 'call') {
|
|
824
|
+
expr.args.forEach(countUsage);
|
|
825
|
+
}
|
|
826
|
+
else if (expr.kind === 'component') {
|
|
827
|
+
countUsage(expr.object);
|
|
828
|
+
}
|
|
829
|
+
}
|
|
830
|
+
for (const expr of temps.values())
|
|
831
|
+
countUsage(expr);
|
|
832
|
+
for (const expr of expressions.values())
|
|
833
|
+
countUsage(expr);
|
|
834
|
+
// Find temps to inline (used 0 or 1 times)
|
|
835
|
+
const toInline = new Set();
|
|
836
|
+
for (const [name] of temps) {
|
|
837
|
+
const count = usageCounts.get(name) ?? 0;
|
|
838
|
+
if (count <= 1)
|
|
839
|
+
toInline.add(name);
|
|
840
|
+
}
|
|
841
|
+
if (toInline.size > 0) {
|
|
842
|
+
function inlineTemps(expr) {
|
|
843
|
+
if (expr.kind === 'variable' && toInline.has(expr.name)) {
|
|
844
|
+
const tempExpr = temps.get(expr.name);
|
|
845
|
+
return tempExpr ? inlineTemps(tempExpr) : expr;
|
|
846
|
+
}
|
|
847
|
+
else if (expr.kind === 'binary') {
|
|
848
|
+
const left = inlineTemps(expr.left);
|
|
849
|
+
const right = inlineTemps(expr.right);
|
|
850
|
+
return (left === expr.left && right === expr.right) ? expr
|
|
851
|
+
: { kind: 'binary', operator: expr.operator, left, right };
|
|
852
|
+
}
|
|
853
|
+
else if (expr.kind === 'unary') {
|
|
854
|
+
const operand = inlineTemps(expr.operand);
|
|
855
|
+
return (operand === expr.operand) ? expr
|
|
856
|
+
: { kind: 'unary', operator: expr.operator, operand };
|
|
857
|
+
}
|
|
858
|
+
else if (expr.kind === 'call') {
|
|
859
|
+
const args = expr.args.map(inlineTemps);
|
|
860
|
+
return args.every((a, i) => a === expr.args[i]) ? expr
|
|
861
|
+
: { kind: 'call', name: expr.name, args };
|
|
862
|
+
}
|
|
863
|
+
else if (expr.kind === 'component') {
|
|
864
|
+
const object = inlineTemps(expr.object);
|
|
865
|
+
return (object === expr.object) ? expr
|
|
866
|
+
: { kind: 'component', object, component: expr.component };
|
|
867
|
+
}
|
|
868
|
+
return expr;
|
|
869
|
+
}
|
|
870
|
+
// Inline in remaining temps
|
|
871
|
+
for (const [name, expr] of temps) {
|
|
872
|
+
if (!toInline.has(name)) {
|
|
873
|
+
temps.set(name, inlineTemps(expr));
|
|
874
|
+
}
|
|
875
|
+
}
|
|
876
|
+
// Inline in root expressions
|
|
877
|
+
for (const [rootId, expr] of expressions) {
|
|
878
|
+
expressions.set(rootId, inlineTemps(expr));
|
|
879
|
+
}
|
|
880
|
+
// Remove inlined temps
|
|
881
|
+
for (const name of toInline)
|
|
882
|
+
temps.delete(name);
|
|
883
|
+
// Re-sort topologically after inlining (inlining may have changed dependencies)
|
|
884
|
+
const finalSorted = topologicalSortTemps(temps);
|
|
885
|
+
temps.clear();
|
|
886
|
+
for (const [name, expr] of finalSorted) {
|
|
887
|
+
temps.set(name, expr);
|
|
888
|
+
}
|
|
889
|
+
}
|
|
890
|
+
}
|
|
891
|
+
/**
 * Merge pairs of temps that are negatives of each other, e.g.
 *   _tmp1 = k * (a - b)   and   _tmp2 = k * (b - a)
 * One temp of each pair is kept; every use of the other is rewritten to
 * `-keeper` (via replaceTempWithNegation) and the other temp is deleted.
 *
 * Only the shapes recognised by extractSubtractionPattern are considered:
 * `(a - b)`, `k * (a - b)` and `(a - b) * k`. Two temps pair up when their
 * coefficients serialize identically and their subtraction operands are
 * swapped.
 *
 * The keeper is the "non-negated" orientation, i.e. the one whose left
 * operand serializes before its right operand. When that keeper is defined
 * *after* its partner, its definition is first moved to sit immediately
 * before the partner. Otherwise temps that used the partner (and now
 * reference the keeper) would refer to the keeper before it is defined.
 * The move is safe because both temps depend on exactly the same
 * sub-expressions, all of which already precede the partner.
 *
 * @param {Map<string, object>} temps - temp name -> expression, in definition order; mutated in place.
 * @param {Map<*, object>} expressions - root id -> expression; mutated in place.
 */
function mergeNegativePairs(temps, expressions) {
    // Canonical key -> { tempName, isNegated } of the temp currently kept for that key.
    const subPatterns = new Map();
    // Rebuild `temps` so that `name` sits immediately before `anchor`,
    // keeping every other entry in its original relative order.
    const moveTempBefore = (name, anchor) => {
        const movedExpr = temps.get(name);
        const entries = [...temps].filter(([n]) => n !== name);
        temps.clear();
        for (const [n, e] of entries) {
            if (n === anchor)
                temps.set(name, movedExpr);
            temps.set(n, e);
        }
    };
    // Walk a snapshot of the names: the loop body deletes and reorders entries
    // of `temps`, which must not disturb the traversal itself.
    for (const tempName of [...temps.keys()]) {
        const expr = temps.get(tempName);
        if (expr === undefined)
            continue; // already merged away
        const pattern = extractSubtractionPattern(expr);
        if (!pattern)
            continue;
        const { left, right, coefficient } = pattern;
        const leftSer = serializeExpr(left);
        const rightSer = serializeExpr(right);
        const coeffSer = coefficient ? serializeExpr(coefficient) : '';
        // Canonical form orders the two operands by serialization; isNegated
        // records whether this temp had to be flipped to reach it (equal
        // operands count as negated).
        const isNegated = !(leftSer < rightSer);
        const [first, second] = isNegated ? [rightSer, leftSer] : [leftSer, rightSer];
        const canonKey = `${coeffSer}:(${first})-(${second})`;
        const existing = subPatterns.get(canonKey);
        if (!existing) {
            subPatterns.set(canonKey, { tempName, isNegated });
            continue;
        }
        if (existing.isNegated && !isNegated) {
            // Current temp is the preferred keeper but is defined later:
            // hoist it above its partner before redirecting the partner's uses.
            moveTempBefore(tempName, existing.tempName);
            replaceTempWithNegation(temps, expressions, existing.tempName, tempName);
            subPatterns.set(canonKey, { tempName, isNegated });
        }
        else if (!existing.isNegated && isNegated) {
            // Existing temp is the keeper and already precedes this one.
            replaceTempWithNegation(temps, expressions, tempName, existing.tempName);
        }
        // Same orientation means structurally identical temps; leave them alone.
    }
}
|
|
939
|
+
/**
 * Recognise a subtraction, optionally scaled by a multiplicative factor.
 *
 *   a - b        -> { left: a, right: b, coefficient: null }
 *   k * (a - b)  -> { left: a, right: b, coefficient: k }
 *   (a - b) * k  -> { left: a, right: b, coefficient: k }
 *
 * When both factors of a product are subtractions, the right-hand factor is
 * treated as the subtraction and the left-hand one as the coefficient.
 * Any other shape yields null. The returned nodes are the original
 * sub-trees, not copies.
 *
 * @param {object} expr - expression node.
 * @returns {{left: object, right: object, coefficient: object|null}|null}
 */
function extractSubtractionPattern(expr) {
    const isSubtraction = (node) => node.kind === 'binary' && node.operator === '-';
    if (isSubtraction(expr)) {
        return { left: expr.left, right: expr.right, coefficient: null };
    }
    if (expr.kind !== 'binary' || expr.operator !== '*') {
        return null;
    }
    // Try the right factor first (k * (a - b)), then the left ((a - b) * k).
    const candidates = [[expr.right, expr.left], [expr.left, expr.right]];
    for (const [difference, coefficient] of candidates) {
        if (isSubtraction(difference)) {
            return { left: difference.left, right: difference.right, coefficient };
        }
    }
    return null;
}
|
|
961
|
+
/**
 * Rewrite every reference to `oldTemp` as the negation `-newTemp`, then remove
 * `oldTemp` from `temps`.
 *
 * Sub-trees that contain no reference to `oldTemp` are returned as the very
 * same objects. All rewritten sites share a single `-newTemp` node.
 *
 * @param {Map<string, object>} temps - temp name -> expression; mutated in place.
 * @param {Map<*, object>} expressions - root id -> expression; mutated in place.
 * @param {string} oldTemp - temp whose uses are redirected and which is then deleted.
 * @param {string} newTemp - temp whose negation replaces each use of `oldTemp`.
 */
function replaceTempWithNegation(temps, expressions, oldTemp, newTemp) {
    const negatedNew = {
        kind: 'unary',
        operator: '-',
        operand: { kind: 'variable', name: newTemp },
    };
    const rewrite = (node) => {
        switch (node.kind) {
            case 'variable':
                return node.name === oldTemp ? negatedNew : node;
            case 'binary': {
                const lhs = rewrite(node.left);
                const rhs = rewrite(node.right);
                if (lhs === node.left && rhs === node.right) {
                    return node;
                }
                return { kind: 'binary', operator: node.operator, left: lhs, right: rhs };
            }
            case 'unary': {
                const inner = rewrite(node.operand);
                if (inner === node.operand) {
                    return node;
                }
                return { kind: 'unary', operator: node.operator, operand: inner };
            }
            case 'call': {
                const newArgs = node.args.map((arg) => rewrite(arg));
                const unchanged = newArgs.every((arg, idx) => arg === node.args[idx]);
                return unchanged ? node : { kind: 'call', name: node.name, args: newArgs };
            }
            case 'component': {
                const target = rewrite(node.object);
                if (target === node.object) {
                    return node;
                }
                return { kind: 'component', object: target, component: node.component };
            }
            default:
                return node;
        }
    };
    // Rewrite every other temp; the one being removed is skipped.
    temps.forEach((body, name) => {
        if (name !== oldTemp) {
            temps.set(name, rewrite(body));
        }
    });
    expressions.forEach((body, rootId) => {
        expressions.set(rootId, rewrite(body));
    });
    // All former uses now read `-newTemp`, so the old definition is dead.
    temps.delete(oldTemp);
}
|
|
1012
|
+
/**
 * Rewrite additions that have a `-1` multiplication operand into subtractions:
 *   a + (-1 * b)  ->  a - b        a + (b * -1)  ->  a - b
 *   (-1 * b) + a  ->  a - b        (b * -1) + a  ->  a - b
 * This cleans up cases where e-graph extraction picked the wrong form.
 *
 * The rewrite is applied bottom-up to every temp and root expression. When
 * both operands of a `+` match, the right operand becomes the subtrahend.
 * Sub-trees that need no change are returned as the same objects.
 *
 * @param {Map<string, object>} temps - temp name -> expression; mutated in place.
 * @param {Map<*, object>} expressions - root id -> expression; mutated in place.
 */
function normalizeAddNegMul(temps, expressions) {
    // If `node` is `-1 * b` or `b * -1`, return `b`; otherwise null.
    // The left factor is tested first, so `-1 * -1` yields the right factor.
    function negatedFactor(node) {
        if (node.kind !== 'binary' || node.operator !== '*')
            return null;
        if (node.left.kind === 'number' && node.left.value === -1)
            return node.right;
        if (node.right.kind === 'number' && node.right.value === -1)
            return node.left;
        return null;
    }
    function normalize(expr) {
        if (expr.kind === 'number' || expr.kind === 'variable')
            return expr;
        if (expr.kind === 'binary') {
            // Children are normalized exactly once. The operand pulled out of a
            // `-1 * b` child is already a fixed point of `normalize`, so it is
            // used as-is. Re-normalizing it only repeated work, which became
            // quadratic on nested chains.
            const left = normalize(expr.left);
            const right = normalize(expr.right);
            if (expr.operator === '+') {
                const fromRight = negatedFactor(right);
                if (fromRight)
                    return { kind: 'binary', operator: '-', left, right: fromRight };
                const fromLeft = negatedFactor(left);
                if (fromLeft)
                    return { kind: 'binary', operator: '-', left: right, right: fromLeft };
            }
            return (left === expr.left && right === expr.right)
                ? expr
                : { kind: 'binary', operator: expr.operator, left, right };
        }
        if (expr.kind === 'unary') {
            const operand = normalize(expr.operand);
            return (operand === expr.operand) ? expr : { kind: 'unary', operator: expr.operator, operand };
        }
        if (expr.kind === 'call') {
            const args = expr.args.map(normalize);
            return args.every((a, i) => a === expr.args[i]) ? expr : { kind: 'call', name: expr.name, args };
        }
        if (expr.kind === 'component') {
            const object = normalize(expr.object);
            return (object === expr.object) ? expr : { kind: 'component', object, component: expr.component };
        }
        return expr;
    }
    // Normalize all temps
    for (const [name, expr] of temps) {
        temps.set(name, normalize(expr));
    }
    // Normalize root expressions
    for (const [rootId, expr] of expressions) {
        expressions.set(rootId, normalize(expr));
    }
}
|