gradient-script 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. package/README.md +3 -1
  2. package/dist/cli.js +80 -3
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +332 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -0,0 +1,1068 @@
1
+ /**
2
+ * Cost-Based Extraction from E-Graphs
3
+ *
4
+ * Extracts the lowest-cost expression from each e-class.
5
+ * Also detects common subexpressions (shared e-classes) for CSE.
6
+ */
7
+ import { enodeChildren } from './ENode.js';
8
+ // =============================================================================
9
+ // Constant Folding Helpers
10
+ // =============================================================================
11
/**
 * Build a binary expression node.
 *
 * When both operands are numeric literals the operation is evaluated
 * immediately and a literal node is returned instead, provided the value is
 * a finite number. Division by zero and unrecognised operators are never
 * folded; they yield an ordinary binary node.
 *
 * @param operator - One of '+', '-', '*', '/', '^' (others are passed through unfolded).
 * @param left - Left operand expression node.
 * @param right - Right operand expression node.
 * @returns A `{ kind: 'number' }` node when folded, otherwise a `{ kind: 'binary' }` node.
 */
function makeBinary(operator, left, right) {
    const bothLiterals = left.kind === 'number' && right.kind === 'number';
    if (bothLiterals) {
        const a = left.value;
        const b = right.value;
        // An unknown operator leaves `folded` undefined, which fails the
        // finiteness check below and falls through to the generic node.
        const folded = operator === '+' ? a + b
            : operator === '-' ? a - b
            : operator === '*' ? a * b
            : operator === '/' ? (b !== 0 ? a / b : NaN)
            : operator === '^' ? Math.pow(a, b)
            : undefined;
        if (Number.isFinite(folded)) {
            return { kind: 'number', value: folded };
        }
    }
    return { kind: 'binary', operator, left, right };
}
44
/**
 * Build a unary expression node, folding constants when possible.
 *
 * Folding is only performed for operators whose meaning on a numeric literal
 * is known here: '-' negates the literal and '+' leaves it unchanged. Any
 * other operator applied to a literal is kept as a unary node rather than
 * being silently treated as a negation.
 *
 * @param operator - Unary operator symbol (this file only ever passes '-').
 * @param operand - Operand expression node.
 * @returns A `{ kind: 'number' }` node when folded, otherwise a `{ kind: 'unary' }` node.
 */
function makeUnary(operator, operand) {
    if (operand.kind === 'number') {
        if (operator === '-') {
            return { kind: 'number', value: -operand.value };
        }
        if (operator === '+') {
            return { kind: 'number', value: operand.value };
        }
    }
    return { kind: 'unary', operator, operand };
}
54
/**
 * Decide whether an expression is too simple to deserve its own temporary.
 *
 * Trivial expressions are:
 * - numeric literals,
 * - variables,
 * - a unary operator applied directly to a numeric literal.
 *
 * @param expr - Expression node to classify.
 * @returns {boolean} true when the expression is trivial.
 */
function isTrivialExpression(expr) {
    switch (expr.kind) {
        case 'number':
        case 'variable':
            return true;
        case 'unary':
            return expr.operand.kind === 'number';
        default:
            return false;
    }
}
69
/**
 * Cost model used when callers do not supply their own.
 *
 * Each key is an e-node tag and each value is the cost charged for one
 * occurrence of that operation (children are charged separately).
 */
export const defaultCostModel = {
    // Leaves
    num: 1,
    var: 1,
    // Cheap arithmetic
    add: 2,
    sub: 2,
    mul: 2,
    // Division is deliberately expensive so extraction favours factored forms
    div: 8,
    pow: 4,
    neg: 1,
    // Reciprocal (1/x): cheaper than a general division but still significant
    inv: 5,
    call: 3,
    component: 1,
};
85
/**
 * Extract the cheapest expression rooted at a single e-class.
 *
 * Costs are computed for the whole e-graph first, then the lowest-cost
 * representative of `rootId` is turned into an expression tree.
 *
 * @param egraph - E-graph to extract from.
 * @param rootId - Id of the e-class to extract.
 * @param costModel - Per-operation costs; defaults to `defaultCostModel`.
 * @returns The extracted expression node.
 */
export function extractBest(egraph, rootId, costModel = defaultCostModel) {
    return extractFromClass(egraph, rootId, computeCosts(egraph, costModel), costModel);
}
92
/**
 * Extract several root expressions at once, hoisting shared subexpressions
 * into named temporaries (`_tmpN`).
 *
 * Pipeline:
 *  1. compute per-class costs and count how often each class is referenced;
 *  2. give a temp name to every class referenced at least twice whose cost
 *     is strictly greater than `minSharedCost`;
 *  3. extract temp definitions and roots, then textually substitute temps
 *     into other temp definitions;
 *  4. inline temps that end up used at most once;
 *  5. run the post-extraction passes (`postExtractionCSE`,
 *     `mergeNegativePairs`, `normalizeAddNegMul`), which mutate the maps;
 *  6. sum the cost of everything that remains.
 *
 * @param egraph - E-graph to extract from.
 * @param roots - Iterable of root e-class ids.
 * @param costModel - Per-operation costs; defaults to `defaultCostModel`.
 * @param minSharedCost - Only classes costing more than this become temps.
 * @returns `{ temps, expressions, totalCost }` where `temps` maps temp name
 *          to definition (dependencies first) and `expressions` maps each
 *          given root id to its expression.
 */
export function extractWithCSE(egraph, roots, costModel = defaultCostModel, minSharedCost = 3) {
    const costs = computeCosts(egraph, costModel);
    const refCounts = countReferences(egraph, roots, costs, costModel);

    // Select the e-classes that are shared and expensive enough to name.
    const tempsToExtract = new Map();
    let tempCounter = 0;
    for (const [classId, count] of refCounts) {
        if (count < 2) {
            continue;
        }
        const canonId = egraph.find(classId);
        const classCost = costs.get(canonId) ?? Infinity;
        if (classCost > minSharedCost) {
            tempsToExtract.set(canonId, `_tmp${tempCounter++}`);
        }
    }

    // Extract each temp's definition in full (no temp references yet).
    // Trivial definitions are dropped from the candidate set entirely.
    const temps = new Map();
    for (const [classId, tempName] of tempsToExtract) {
        const definition = extractFromClass(egraph, classId, costs, costModel);
        if (isTrivialExpression(definition)) {
            tempsToExtract.delete(classId);
        }
        else {
            temps.set(tempName, definition);
        }
    }

    // Extract the roots, referring to temps by name where a class has one.
    const expressions = new Map();
    for (const rootId of roots) {
        expressions.set(rootId, extractWithTemps(egraph, rootId, costs, costModel, tempsToExtract));
    }

    // Let temp definitions refer to each other: index definitions by their
    // serialized form, then rewrite matching subtrees into temp references.
    const exprToTemp = new Map();
    for (const [tempName, definition] of temps) {
        exprToTemp.set(serializeExpr(definition), tempName);
    }
    for (const [tempName, definition] of temps) {
        temps.set(tempName, substituteTempRefs(definition, exprToTemp, tempName));
    }

    // Rebuild `temps` in dependency order (dependencies first).
    const reorderTemps = () => {
        const ordered = topologicalSortTemps(temps);
        temps.clear();
        for (const [name, definition] of ordered) {
            temps.set(name, definition);
        }
    };
    reorderTemps();

    // Count how many times each temp name is actually referenced in the
    // roots and in the temp definitions.
    const tempUsageCounts = new Map();
    const countTempUsage = (expr) => {
        switch (expr.kind) {
            case 'variable':
                if (expr.name.startsWith('_tmp')) {
                    tempUsageCounts.set(expr.name, (tempUsageCounts.get(expr.name) ?? 0) + 1);
                }
                break;
            case 'binary':
                countTempUsage(expr.left);
                countTempUsage(expr.right);
                break;
            case 'unary':
                countTempUsage(expr.operand);
                break;
            case 'call':
                for (const arg of expr.args) {
                    countTempUsage(arg);
                }
                break;
            case 'component':
                countTempUsage(expr.object);
                break;
            default:
                break;
        }
    };
    for (const rootExpr of expressions.values()) {
        countTempUsage(rootExpr);
    }
    for (const definition of temps.values()) {
        countTempUsage(definition);
    }

    // A temp referenced zero or one times is not worth keeping.
    const tempsToInline = new Set();
    for (const tempName of temps.keys()) {
        if ((tempUsageCounts.get(tempName) ?? 0) <= 1) {
            tempsToInline.add(tempName);
        }
    }

    if (tempsToInline.size > 0) {
        // Replace references to inlined temps by their (recursively
        // inlined) definitions; composite nodes are always rebuilt.
        const inlineTemps = (expr) => {
            switch (expr.kind) {
                case 'variable': {
                    if (!tempsToInline.has(expr.name)) {
                        return expr;
                    }
                    const definition = temps.get(expr.name);
                    return definition ? inlineTemps(definition) : expr;
                }
                case 'binary':
                    return {
                        kind: 'binary',
                        operator: expr.operator,
                        left: inlineTemps(expr.left),
                        right: inlineTemps(expr.right),
                    };
                case 'unary':
                    return {
                        kind: 'unary',
                        operator: expr.operator,
                        operand: inlineTemps(expr.operand),
                    };
                case 'call':
                    return {
                        kind: 'call',
                        name: expr.name,
                        args: expr.args.map((arg) => inlineTemps(arg)),
                    };
                case 'component':
                    return {
                        kind: 'component',
                        object: inlineTemps(expr.object),
                        component: expr.component,
                    };
                default:
                    return expr;
            }
        };
        // Surviving temps first, while the inlined definitions still exist.
        for (const [tempName, definition] of temps) {
            if (!tempsToInline.has(tempName)) {
                temps.set(tempName, inlineTemps(definition));
            }
        }
        for (const [rootId, rootExpr] of expressions) {
            expressions.set(rootId, inlineTemps(rootExpr));
        }
        for (const tempName of tempsToInline) {
            temps.delete(tempName);
        }
        // Inlining may have changed which temps depend on which.
        reorderTemps();
    }

    // Later passes operate in place on `temps` and `expressions`.
    postExtractionCSE(temps, expressions, minSharedCost, tempCounter, costModel);
    mergeNegativePairs(temps, expressions);
    normalizeAddNegMul(temps, expressions);

    let totalCost = 0;
    for (const definition of temps.values()) {
        totalCost += expressionCost(definition, costModel);
    }
    for (const rootExpr of expressions.values()) {
        totalCost += expressionCost(rootExpr, costModel);
    }
    return { temps, expressions, totalCost };
}
258
/**
 * Compute the minimum known cost of every e-class by fixed-point iteration.
 *
 * Every class id starts at Infinity. Each pass re-evaluates every node of
 * every class against the current table and keeps any strictly lower cost.
 * Iteration stops after a pass with no improvement, or after 100 passes.
 *
 * @param egraph - E-graph providing `getClassIds()`, `find(id)` and `getNodes(id)`.
 * @param costModel - Per-operation costs.
 * @returns {Map} class id -> lowest cost found (Infinity if none was found).
 */
function computeCosts(egraph, costModel) {
    const costs = new Map();
    const classIds = egraph.getClassIds();
    for (const id of classIds) {
        costs.set(id, Infinity);
    }
    const maxIterations = 100;
    for (let pass = 0; pass < maxIterations; pass++) {
        let improved = false;
        for (const classId of classIds) {
            const canonId = egraph.find(classId);
            for (const node of egraph.getNodes(canonId)) {
                const candidate = computeNodeCost(node, costs, costModel);
                const best = costs.get(canonId) ?? Infinity;
                if (candidate < best) {
                    costs.set(canonId, candidate);
                    improved = true;
                }
            }
        }
        if (!improved) {
            break;
        }
    }
    return costs;
}
290
/**
 * Cost of one e-node: the operation's own cost from `costModel` plus the
 * current best cost of each child e-class.
 *
 * A child with no entry in `classCosts` counts as Infinity. Tags not listed
 * below produce `undefined`.
 *
 * @param node - E-node with a `tag` and tag-specific child fields.
 * @param {Map} classCosts - e-class id -> current best cost.
 * @param costModel - Per-operation costs keyed by tag.
 * @returns {number|undefined} the node's cost.
 */
function computeNodeCost(node, classCosts, costModel) {
    const costOf = (id) => classCosts.get(id) ?? Infinity;
    switch (node.tag) {
        // Leaves: the cost model entry is the whole cost.
        case 'num':
        case 'var':
            return costModel[node.tag];
        // Binary operators: cost model keys coincide with the tags.
        case 'add':
        case 'sub':
        case 'mul':
        case 'div':
        case 'pow':
            return costModel[node.tag] + costOf(node.children[0]) + costOf(node.children[1]);
        // Single-child operators.
        case 'neg':
        case 'inv':
            return costModel[node.tag] + costOf(node.child);
        case 'call': {
            let argumentCost = 0;
            for (const id of node.children) {
                argumentCost += costOf(id);
            }
            return costModel.call + argumentCost;
        }
        case 'component':
            return costModel.component + costOf(node.object);
        default:
            return undefined;
    }
}
320
/**
 * Turn an e-class into an expression by picking its cheapest node (according
 * to the precomputed `costs`) and converting that node recursively.
 *
 * @param egraph - E-graph providing `find(id)` and `getNodes(id)`.
 * @param classId - Id of the e-class to extract.
 * @param {Map} costs - e-class id -> best cost, as produced by `computeCosts`.
 * @param costModel - Per-operation costs.
 * @returns The extracted expression node.
 * @throws {Error} when no node of the class has a cost below Infinity
 *         (which includes the class having no nodes at all).
 */
function extractFromClass(egraph, classId, costs, costModel) {
    const canonId = egraph.find(classId);
    let winner = null;
    let winnerCost = Infinity;
    for (const candidate of egraph.getNodes(canonId)) {
        const candidateCost = computeNodeCost(candidate, costs, costModel);
        // Strict comparison: the first node seen wins ties.
        if (candidateCost < winnerCost) {
            winner = candidate;
            winnerCost = candidateCost;
        }
    }
    if (!winner) {
        throw new Error(`No nodes in e-class ${canonId}`);
    }
    return nodeToExpression(winner, egraph, costs, costModel);
}
341
/**
 * Like `extractFromClass`, but an e-class that has been assigned a temp name
 * is emitted as a reference to that temp instead of being expanded.
 *
 * @param egraph - E-graph providing `find(id)` and `getNodes(id)`.
 * @param classId - Id of the e-class to extract.
 * @param {Map} costs - e-class id -> best cost.
 * @param costModel - Per-operation costs.
 * @param {Map} temps - canonical e-class id -> temp name.
 * @returns The extracted expression node.
 * @throws {Error} when no node of the class has a cost below Infinity.
 */
function extractWithTemps(egraph, classId, costs, costModel, temps) {
    const canonId = egraph.find(classId);
    // A named class short-circuits to a variable reference.
    const assignedName = temps.get(canonId);
    if (assignedName) {
        return { kind: 'variable', name: assignedName };
    }
    let winner = null;
    let winnerCost = Infinity;
    for (const candidate of egraph.getNodes(canonId)) {
        const candidateCost = computeNodeCost(candidate, costs, costModel);
        // Strict comparison: the first node seen wins ties.
        if (candidateCost < winnerCost) {
            winner = candidate;
            winnerCost = candidateCost;
        }
    }
    if (!winner) {
        throw new Error(`No nodes in e-class ${canonId}`);
    }
    return nodeToExpressionWithTemps(winner, egraph, costs, costModel, temps);
}
367
/**
 * Convert one e-node into an expression tree, extracting each child e-class
 * with `extractFromClass`. Arithmetic nodes go through `makeBinary` /
 * `makeUnary`, so constant operands are folded.
 *
 * @param node - E-node to convert.
 * @param egraph - E-graph the node belongs to.
 * @param {Map} costs - e-class id -> best cost.
 * @param costModel - Per-operation costs.
 * @returns The expression node, or `undefined` for an unrecognised tag.
 */
function nodeToExpression(node, egraph, costs, costModel) {
    const childExpr = (id) => extractFromClass(egraph, id, costs, costModel);
    // Left child is extracted before the right one.
    const binary = (operator) => makeBinary(operator, childExpr(node.children[0]), childExpr(node.children[1]));
    switch (node.tag) {
        case 'num':
            return { kind: 'number', value: node.value };
        case 'var':
            return { kind: 'variable', name: node.name };
        case 'add':
            return binary('+');
        case 'sub':
            return binary('-');
        case 'mul':
            return binary('*');
        case 'div':
            return binary('/');
        case 'pow':
            return binary('^');
        case 'neg':
            return makeUnary('-', childExpr(node.child));
        case 'inv':
            // inv(x) is emitted as 1 / x (and folded if x is a constant).
            return makeBinary('/', { kind: 'number', value: 1 }, childExpr(node.child));
        case 'call':
            return {
                kind: 'call',
                name: node.name,
                args: node.children.map((id) => childExpr(id)),
            };
        case 'component':
            return {
                kind: 'component',
                object: childExpr(node.object),
                component: node.field,
            };
    }
}
406
/**
 * Convert one e-node into an expression tree, extracting each child e-class
 * with `extractWithTemps` so that children which have a temp name appear as
 * references to that temp. Arithmetic nodes go through `makeBinary` /
 * `makeUnary`, so constant operands are folded.
 *
 * @param node - E-node to convert.
 * @param egraph - E-graph the node belongs to.
 * @param {Map} costs - e-class id -> best cost.
 * @param costModel - Per-operation costs.
 * @param {Map} temps - canonical e-class id -> temp name.
 * @returns The expression node, or `undefined` for an unrecognised tag.
 */
function nodeToExpressionWithTemps(node, egraph, costs, costModel, temps) {
    const childExpr = (id) => extractWithTemps(egraph, id, costs, costModel, temps);
    // Left child is extracted before the right one.
    const binary = (operator) => makeBinary(operator, childExpr(node.children[0]), childExpr(node.children[1]));
    switch (node.tag) {
        case 'num':
            return { kind: 'number', value: node.value };
        case 'var':
            return { kind: 'variable', name: node.name };
        case 'add':
            return binary('+');
        case 'sub':
            return binary('-');
        case 'mul':
            return binary('*');
        case 'div':
            return binary('/');
        case 'pow':
            return binary('^');
        case 'neg':
            return makeUnary('-', childExpr(node.child));
        case 'inv':
            // inv(x) is emitted as 1 / x (and folded if x is a constant).
            return makeBinary('/', { kind: 'number', value: 1 }, childExpr(node.child));
        case 'call':
            return {
                kind: 'call',
                name: node.name,
                args: node.children.map((id) => childExpr(id)),
            };
        case 'component':
            return {
                kind: 'component',
                object: childExpr(node.object),
                component: node.field,
            };
    }
}
445
/**
 * Count how many times each e-class is reached when walking from the roots
 * through the cheapest node of every class.
 *
 * Every arrival at a class increments its count. The walk does not descend
 * again into a class that is already on the current path, but it does
 * descend again when the class is reached along a different path.
 *
 * @param egraph - E-graph providing `find(id)` and `getNodes(id)`.
 * @param roots - Iterable of root e-class ids.
 * @param {Map} costs - e-class id -> best cost.
 * @param costModel - Per-operation costs.
 * @returns {Map} canonical e-class id -> number of arrivals.
 */
function countReferences(egraph, roots, costs, costModel) {
    const counts = new Map();
    const visit = (classId, path) => {
        const canonId = egraph.find(classId);
        counts.set(canonId, (counts.get(canonId) ?? 0) + 1);
        if (path.has(canonId)) {
            return;
        }
        path.add(canonId);
        // Select the cheapest node; the first node seen wins ties.
        let chosen = null;
        let chosenCost = Infinity;
        for (const candidate of egraph.getNodes(canonId)) {
            const candidateCost = computeNodeCost(candidate, costs, costModel);
            if (candidateCost < chosenCost) {
                chosen = candidate;
                chosenCost = candidateCost;
            }
        }
        if (!chosen) {
            return;
        }
        for (const childId of enodeChildren(chosen)) {
            // Each child gets its own copy of the path so siblings do not
            // see each other's descendants as already visited.
            visit(childId, new Set(path));
        }
    };
    for (const rootId of roots) {
        visit(rootId, new Set());
    }
    return counts;
}
481
/**
 * Total cost of an expression tree under `costModel`: each node contributes
 * the cost of its operation plus the cost of its children.
 *
 * Binary operators map to cost-model entries as follows: '/' -> div,
 * '^' -> pow, '*' -> mul, '-' -> sub, anything else -> add. If the model
 * has no `sub` entry, subtraction is charged as `add`.
 * Every unary node is charged `costModel.neg`.
 *
 * @param expr - Expression node.
 * @param costModel - Per-operation costs.
 * @returns {number|undefined} the total cost; `undefined` for an unrecognised node kind.
 */
function expressionCost(expr, costModel) {
    switch (expr.kind) {
        case 'number':
            return costModel.num;
        case 'variable':
            return costModel.var;
        case 'binary': {
            let opCost;
            switch (expr.operator) {
                case '/':
                    opCost = costModel.div;
                    break;
                case '^':
                    opCost = costModel.pow;
                    break;
                case '*':
                    opCost = costModel.mul;
                    break;
                case '-':
                    // Keep subtraction consistent with the e-node cost
                    // ('sub'), falling back to 'add' for models without it.
                    opCost = costModel.sub ?? costModel.add;
                    break;
                default:
                    opCost = costModel.add;
            }
            return opCost + expressionCost(expr.left, costModel) + expressionCost(expr.right, costModel);
        }
        case 'unary':
            return costModel.neg + expressionCost(expr.operand, costModel);
        case 'call':
            return costModel.call + expr.args.reduce((sum, arg) => sum + expressionCost(arg, costModel), 0);
        case 'component':
            return costModel.component + expressionCost(expr.object, costModel);
    }
}
504
/**
 * Produce a string key for an expression so that structurally identical
 * expressions compare equal.
 *
 * Format: `N<value>` for numbers, `V<name>` for variables,
 * `(<left><op><right>)` for binary nodes, `U<op><operand>` for unary nodes,
 * `C<name>(<arg>,<arg>,...)` for calls and `<object>.<component>` for
 * component access. Unrecognised kinds yield `undefined`.
 *
 * @param expr - Expression node.
 * @returns {string|undefined} the serialized form.
 */
function serializeExpr(expr) {
    const { kind } = expr;
    if (kind === 'number') {
        return `N${expr.value}`;
    }
    if (kind === 'variable') {
        return `V${expr.name}`;
    }
    if (kind === 'binary') {
        const lhs = serializeExpr(expr.left);
        const rhs = serializeExpr(expr.right);
        return `(${lhs}${expr.operator}${rhs})`;
    }
    if (kind === 'unary') {
        return `U${expr.operator}${serializeExpr(expr.operand)}`;
    }
    if (kind === 'call') {
        const argList = expr.args.map((arg) => serializeExpr(arg)).join(',');
        return `C${expr.name}(${argList})`;
    }
    if (kind === 'component') {
        return `${serializeExpr(expr.object)}.${expr.component}`;
    }
    return undefined;
}
523
/**
 * Rewrite an expression so that any subtree whose serialized form matches
 * the definition of another temp becomes a reference to that temp.
 *
 * Children are rewritten first (bottom-up), then the resulting node itself
 * is looked up. A node is returned unchanged (same object) when none of its
 * children changed and it does not match a temp.
 *
 * @param expr - Expression to rewrite.
 * @param {Map} exprToTemp - serialized definition -> temp name.
 * @param currentTemp - Name of the temp whose definition is being rewritten;
 *        matches against this name are ignored to avoid self-reference.
 * @returns The rewritten expression.
 */
function substituteTempRefs(expr, exprToTemp, currentTemp) {
    const rewriteChild = (child) => substituteTempRefs(child, exprToTemp, currentTemp);
    let rewritten;
    if (expr.kind === 'number' || expr.kind === 'variable') {
        rewritten = expr;
    }
    else if (expr.kind === 'binary') {
        const left = rewriteChild(expr.left);
        const right = rewriteChild(expr.right);
        const untouched = left === expr.left && right === expr.right;
        rewritten = untouched ? expr : { kind: 'binary', operator: expr.operator, left, right };
    }
    else if (expr.kind === 'unary') {
        const operand = rewriteChild(expr.operand);
        rewritten = operand === expr.operand ? expr : { kind: 'unary', operator: expr.operator, operand };
    }
    else if (expr.kind === 'call') {
        const args = expr.args.map((arg) => rewriteChild(arg));
        const untouched = args.every((arg, i) => arg === expr.args[i]);
        rewritten = untouched ? expr : { kind: 'call', name: expr.name, args };
    }
    else if (expr.kind === 'component') {
        const object = rewriteChild(expr.object);
        rewritten = object === expr.object ? expr : { kind: 'component', object, component: expr.component };
    }
    // With the children settled, see whether this node is itself the
    // definition of some other temp.
    const matchingTemp = exprToTemp.get(serializeExpr(rewritten));
    if (matchingTemp && matchingTemp !== currentTemp) {
        return { kind: 'variable', name: matchingTemp };
    }
    return rewritten;
}
573
/**
 * Order temps so that every temp appears after the temps its definition
 * refers to.
 *
 * The order is built greedily: at each step the first not-yet-emitted temp
 * (in the map's iteration order) whose dependencies have all been emitted is
 * appended. If no such temp exists the remaining temps form a cycle and are
 * appended as-is, in their current order.
 *
 * @param {Map} temps - temp name -> definition expression.
 * @returns {Array} `[name, definition]` pairs in dependency order.
 */
function topologicalSortTemps(temps) {
    const tempNames = new Set(temps.keys());
    // Collect the names of temps referenced anywhere inside `expr`.
    const collectDeps = (expr, into) => {
        switch (expr.kind) {
            case 'variable':
                if (tempNames.has(expr.name)) {
                    into.add(expr.name);
                }
                break;
            case 'binary':
                collectDeps(expr.left, into);
                collectDeps(expr.right, into);
                break;
            case 'unary':
                collectDeps(expr.operand, into);
                break;
            case 'call':
                for (const arg of expr.args) {
                    collectDeps(arg, into);
                }
                break;
            case 'component':
                collectDeps(expr.object, into);
                break;
            default:
                break;
        }
    };
    const deps = new Map();
    for (const [name, definition] of temps) {
        const referenced = new Set();
        collectDeps(definition, referenced);
        deps.set(name, referenced);
    }
    const ordered = [];
    const emitted = new Set();
    const pending = [...temps.keys()];
    const isReady = (name) => {
        for (const dep of deps.get(name)) {
            if (!emitted.has(dep)) {
                return false;
            }
        }
        return true;
    };
    while (pending.length > 0) {
        const readyIndex = pending.findIndex(isReady);
        if (readyIndex === -1) {
            // Cycle: nothing can be emitted safely, so flush what is left.
            for (const name of pending) {
                ordered.push([name, temps.get(name)]);
            }
            break;
        }
        const [name] = pending.splice(readyIndex, 1);
        ordered.push([name, temps.get(name)]);
        emitted.add(name);
    }
    return ordered;
}
631
/**
 * Second CSE pass, run on the already-extracted expression trees.
 *
 * Finds non-trivial subexpressions that occur at least twice across all temp
 * definitions and root expressions and whose cost is strictly greater than
 * `minSharedCost`, gives each one a temp (reusing an existing temp when its
 * definition is identical), substitutes references, re-sorts the temps and
 * finally inlines any temp that ends up referenced at most once.
 *
 * Both `temps` and `expressions` are modified in place; nothing is returned.
 *
 * @param {Map} temps - temp name -> definition.
 * @param {Map} expressions - root id -> expression.
 * @param minSharedCost - Cost threshold a subexpression must exceed.
 * @param startingTempCounter - First index to try for new `_tmpN` names.
 * @param costModel - Per-operation costs.
 */
function postExtractionCSE(temps, expressions, minSharedCost, startingTempCounter, costModel) {
    // Rebuild `temps` in dependency order (dependencies first).
    const reorderTemps = () => {
        const ordered = topologicalSortTemps(temps);
        temps.clear();
        for (const [name, definition] of ordered) {
            temps.set(name, definition);
        }
    };

    // --- 1. Tally every composite subexpression by its serialized form ---
    const exprCounts = new Map();
    const tally = (expr) => {
        if (expr.kind === 'number' || expr.kind === 'variable') {
            return;
        }
        const key = serializeExpr(expr);
        const seen = exprCounts.get(key);
        if (seen) {
            seen.count++;
        }
        else {
            exprCounts.set(key, { count: 1, expr, cost: expressionCost(expr, costModel) });
        }
        switch (expr.kind) {
            case 'binary':
                tally(expr.left);
                tally(expr.right);
                break;
            case 'unary':
                tally(expr.operand);
                break;
            case 'call':
                for (const arg of expr.args) {
                    tally(arg);
                }
                break;
            case 'component':
                tally(expr.object);
                break;
            default:
                break;
        }
    };
    for (const definition of temps.values()) {
        tally(definition);
    }
    for (const rootExpr of expressions.values()) {
        tally(rootExpr);
    }

    // --- 2. Keep the repeated, sufficiently expensive, non-trivial ones ---
    const toExtract = [];
    for (const [serialized, { count, expr, cost }] of exprCounts) {
        const worthIt = count >= 2 && cost > minSharedCost;
        if (worthIt && !isTrivialExpression(expr)) {
            toExtract.push({ serialized, expr, cost });
        }
    }
    if (toExtract.length === 0) {
        return;
    }
    // Cheapest first, so smaller patterns are named before the larger
    // patterns that contain them.
    toExtract.sort((p, q) => p.cost - q.cost);

    // --- 3. Assign temp names, reusing temps with an identical definition ---
    const existingTempRHS = new Map();
    for (const [tempName, definition] of temps) {
        existingTempRHS.set(serializeExpr(definition), tempName);
    }
    let tempCounter = startingTempCounter;
    const serToTemp = new Map();
    for (const { serialized, expr } of toExtract) {
        const reused = existingTempRHS.get(serialized);
        if (reused) {
            serToTemp.set(serialized, reused);
            continue;
        }
        // Skip over names that are already taken.
        while (temps.has(`_tmp${tempCounter}`)) {
            tempCounter++;
        }
        const freshName = `_tmp${tempCounter++}`;
        serToTemp.set(serialized, freshName);
        temps.set(freshName, expr);
        existingTempRHS.set(serialized, freshName);
    }
    if (serToTemp.size === 0) {
        return;
    }

    // --- 4. Substitute temp references (top-down) ---
    // `selfName` is the temp being rewritten; a match against it is skipped
    // so a definition never collapses into a reference to itself. Roots pass
    // `undefined`, which never equals a temp name.
    const substituterFor = (selfName) => {
        const rewrite = (expr) => {
            if (expr.kind === 'number' || expr.kind === 'variable') {
                return expr;
            }
            const target = serToTemp.get(serializeExpr(expr));
            if (target && target !== selfName) {
                return { kind: 'variable', name: target };
            }
            switch (expr.kind) {
                case 'binary': {
                    const left = rewrite(expr.left);
                    const right = rewrite(expr.right);
                    return (left === expr.left && right === expr.right)
                        ? expr
                        : { kind: 'binary', operator: expr.operator, left, right };
                }
                case 'unary': {
                    const operand = rewrite(expr.operand);
                    return (operand === expr.operand)
                        ? expr
                        : { kind: 'unary', operator: expr.operator, operand };
                }
                case 'call': {
                    const args = expr.args.map((arg) => rewrite(arg));
                    return args.every((arg, i) => arg === expr.args[i])
                        ? expr
                        : { kind: 'call', name: expr.name, args };
                }
                case 'component': {
                    const object = rewrite(expr.object);
                    return (object === expr.object)
                        ? expr
                        : { kind: 'component', object, component: expr.component };
                }
                default:
                    return expr;
            }
        };
        return rewrite;
    };
    // Every temp, including the ones created above.
    for (const [tempName, definition] of temps) {
        temps.set(tempName, substituterFor(tempName)(definition));
    }
    const rewriteRoot = substituterFor(undefined);
    for (const [rootId, rootExpr] of expressions) {
        expressions.set(rootId, rewriteRoot(rootExpr));
    }
    reorderTemps();

    // --- 5. Inline temps that are now referenced at most once ---
    const usageCounts = new Map();
    const countUsage = (expr) => {
        switch (expr.kind) {
            case 'variable':
                if (expr.name.startsWith('_tmp')) {
                    usageCounts.set(expr.name, (usageCounts.get(expr.name) ?? 0) + 1);
                }
                break;
            case 'binary':
                countUsage(expr.left);
                countUsage(expr.right);
                break;
            case 'unary':
                countUsage(expr.operand);
                break;
            case 'call':
                for (const arg of expr.args) {
                    countUsage(arg);
                }
                break;
            case 'component':
                countUsage(expr.object);
                break;
            default:
                break;
        }
    };
    for (const definition of temps.values()) {
        countUsage(definition);
    }
    for (const rootExpr of expressions.values()) {
        countUsage(rootExpr);
    }
    const toInline = new Set();
    for (const name of temps.keys()) {
        if ((usageCounts.get(name) ?? 0) <= 1) {
            toInline.add(name);
        }
    }
    if (toInline.size === 0) {
        return;
    }
    // Nodes whose children are unaffected are returned as the same object.
    const inlineTemps = (expr) => {
        switch (expr.kind) {
            case 'variable': {
                if (!toInline.has(expr.name)) {
                    return expr;
                }
                const definition = temps.get(expr.name);
                return definition ? inlineTemps(definition) : expr;
            }
            case 'binary': {
                const left = inlineTemps(expr.left);
                const right = inlineTemps(expr.right);
                return (left === expr.left && right === expr.right)
                    ? expr
                    : { kind: 'binary', operator: expr.operator, left, right };
            }
            case 'unary': {
                const operand = inlineTemps(expr.operand);
                return (operand === expr.operand)
                    ? expr
                    : { kind: 'unary', operator: expr.operator, operand };
            }
            case 'call': {
                const args = expr.args.map((arg) => inlineTemps(arg));
                return args.every((arg, i) => arg === expr.args[i])
                    ? expr
                    : { kind: 'call', name: expr.name, args };
            }
            case 'component': {
                const object = inlineTemps(expr.object);
                return (object === expr.object)
                    ? expr
                    : { kind: 'component', object, component: expr.component };
            }
            default:
                return expr;
        }
    };
    // Surviving temps first, while the inlined definitions still exist.
    for (const [name, definition] of temps) {
        if (!toInline.has(name)) {
            temps.set(name, inlineTemps(definition));
        }
    }
    for (const [rootId, rootExpr] of expressions) {
        expressions.set(rootId, inlineTemps(rootExpr));
    }
    for (const name of toInline) {
        temps.delete(name);
    }
    // Inlining may have changed which temps depend on which.
    reorderTemps();
}
891
/**
 * Merge temps whose values are exact negatives of one another.
 *
 * Example: `_tmp1 = k * (a - b)` and `_tmp2 = k * (b - a)`. One of the two is
 * kept and every use of the other is rewritten to the negation of the kept
 * temp (see replaceTempWithNegation), after which the redundant temp is
 * removed from `temps`.
 *
 * Two temps are paired when extractSubtractionPattern recognises both and
 * they share the same coefficient and the same unordered pair of operands,
 * compared via serializeExpr. The temp whose left operand serializes strictly
 * before its right operand is treated as the "non-negated" orientation and is
 * the one preferred for keeping.
 *
 * Both maps are mutated in place.
 *
 * @param {Map<string, object>} temps - Temp name -> defining expression,
 *   iterated in insertion order.
 * @param {Map<*, object>} expressions - Root id -> root expression.
 * @returns {void}
 */
function mergeNegativePairs(temps, expressions) {
    // canonical key -> { tempName, isNegated, coefficient } of the temp currently kept
    const subPatterns = new Map();
    // Becomes true when a later temp replaces an earlier one. In that case temps
    // sitting between the two now reference a temp that is defined after them,
    // so the insertion order of `temps` no longer respects dependencies.
    let needsResort = false;
    for (const [tempName, expr] of temps) {
        const pattern = extractSubtractionPattern(expr);
        if (!pattern) {
            continue;
        }
        const { left, right, coefficient } = pattern;
        const leftSer = serializeExpr(left);
        const rightSer = serializeExpr(right);
        const coeffSer = coefficient ? serializeExpr(coefficient) : '';
        // Canonical orientation puts the smaller serialization first; equal
        // serializations count as negated, matching a strict `<` comparison.
        const isNegated = !(leftSer < rightSer);
        const canonKey = isNegated
            ? `${coeffSer}:(${rightSer})-(${leftSer})`
            : `${coeffSer}:(${leftSer})-(${rightSer})`;
        const existing = subPatterns.get(canonKey);
        if (!existing) {
            subPatterns.set(canonKey, { tempName, isNegated, coefficient });
            continue;
        }
        if (existing.isNegated && !isNegated) {
            // The current temp has the preferred orientation: drop the earlier
            // one and point its users at the negation of the current temp.
            replaceTempWithNegation(temps, expressions, existing.tempName, tempName);
            subPatterns.set(canonKey, { tempName, isNegated, coefficient });
            needsResort = true;
        }
        else if (!existing.isNegated && isNegated) {
            // The earlier temp has the preferred orientation: drop the current
            // one. Its users come later in the map, so order stays valid.
            replaceTempWithNegation(temps, expressions, tempName, existing.tempName);
        }
        // Same orientation on both sides: the temps are not negatives of each
        // other, so nothing is merged here.
    }
    if (needsResort) {
        // Restore a dependency-respecting order, the same way the other passes
        // in this file do after they rewrite temp references.
        const sorted = topologicalSortTemps(temps);
        temps.clear();
        for (const [name, expr] of sorted) {
            temps.set(name, expr);
        }
    }
}
939
/**
 * Recognise an expression that is a subtraction, optionally scaled by a
 * single multiplicative factor.
 *
 * Shapes matched, in this order of precedence:
 *   a - b        -> { left: a, right: b, coefficient: null }
 *   k * (a - b)  -> { left: a, right: b, coefficient: k }
 *   (a - b) * k  -> { left: a, right: b, coefficient: k }
 *
 * When both factors of a product are subtractions, the right-hand one is
 * reported as the pattern and the left-hand one as the coefficient.
 *
 * @param {object} expr - Expression node to inspect.
 * @returns {{left: object, right: object, coefficient: (object|null)}|null}
 *   The matched parts, or null when the expression has none of these shapes.
 */
function extractSubtractionPattern(expr) {
    if (expr.kind !== 'binary') {
        return null;
    }
    const isDifference = (node) => node.kind === 'binary' && node.operator === '-';
    switch (expr.operator) {
        case '-':
            // Bare subtraction: no scaling factor.
            return { left: expr.left, right: expr.right, coefficient: null };
        case '*': {
            const { left: lhs, right: rhs } = expr;
            // The right-hand factor is examined first, so it wins when both match.
            if (isDifference(rhs)) {
                return { left: rhs.left, right: rhs.right, coefficient: lhs };
            }
            if (isDifference(lhs)) {
                return { left: lhs.left, right: lhs.right, coefficient: rhs };
            }
            return null;
        }
        default:
            return null;
    }
}
961
/**
 * Rewrite every reference to `oldTemp` as `-newTemp`, then drop `oldTemp`.
 *
 * All remaining temp definitions and all root expressions are rewritten.
 * Nodes that contain no reference to `oldTemp` are returned as the same
 * object; only the path down to a replaced reference is rebuilt. Every
 * replaced reference shares one and the same negation node.
 *
 * Both maps are mutated in place.
 *
 * @param {Map<string, object>} temps - Temp name -> defining expression.
 * @param {Map<*, object>} expressions - Root id -> root expression.
 * @param {string} oldTemp - Name of the temp being eliminated.
 * @param {string} newTemp - Name of the temp whose negation replaces it.
 * @returns {void}
 */
function replaceTempWithNegation(temps, expressions, oldTemp, newTemp) {
    // Built once and reused for every occurrence of oldTemp.
    const negatedRef = {
        kind: 'unary',
        operator: '-',
        operand: { kind: 'variable', name: newTemp }
    };
    const rewrite = (node) => {
        switch (node.kind) {
            case 'variable':
                return node.name === oldTemp ? negatedRef : node;
            case 'binary': {
                const lhs = rewrite(node.left);
                const rhs = rewrite(node.right);
                if (lhs === node.left && rhs === node.right) {
                    return node;
                }
                return { kind: 'binary', operator: node.operator, left: lhs, right: rhs };
            }
            case 'unary': {
                const inner = rewrite(node.operand);
                if (inner === node.operand) {
                    return node;
                }
                return { kind: 'unary', operator: node.operator, operand: inner };
            }
            case 'call': {
                const rewrittenArgs = node.args.map(rewrite);
                const unchanged = rewrittenArgs.every((arg, idx) => arg === node.args[idx]);
                return unchanged ? node : { kind: 'call', name: node.name, args: rewrittenArgs };
            }
            case 'component': {
                const target = rewrite(node.object);
                if (target === node.object) {
                    return node;
                }
                return { kind: 'component', object: target, component: node.component };
            }
            default:
                // Any other node kind is left as is.
                return node;
        }
    };
    // The definition of oldTemp itself is skipped; it is removed below.
    for (const [name, definition] of temps) {
        if (name === oldTemp) {
            continue;
        }
        temps.set(name, rewrite(definition));
    }
    for (const [rootId, rootExpr] of expressions) {
        expressions.set(rootId, rewrite(rootExpr));
    }
    temps.delete(oldTemp);
}
1012
/**
 * Rewrite additions of a value multiplied by -1 into subtractions.
 *
 *   a + (-1 * b)  or  a + (b * -1)   ->  a - b
 *   (-1 * b) + a  or  (b * -1) + a   ->  a - b
 *
 * The right operand is checked before the left, so `(-1 * b) + (-1 * c)`
 * becomes `(-1 * b) - c`. The rewrite is applied bottom-up through binary,
 * unary, call and component nodes; any other node kind is returned as is.
 * Subtrees that need no change are returned as the same object.
 *
 * Every temp definition and every root expression is normalized, and both
 * maps are mutated in place.
 *
 * @param {Map<string, object>} temps - Temp name -> defining expression.
 * @param {Map<*, object>} expressions - Root id -> root expression.
 * @returns {void}
 */
function normalizeAddNegMul(temps, expressions) {
    const isMinusOne = (node) => node.kind === 'number' && node.value === -1;
    // For `-1 * x` or `x * -1` returns x; otherwise null. A leading -1 wins
    // when both factors are -1.
    function negatedFactor(node) {
        if (node.kind !== 'binary' || node.operator !== '*') {
            return null;
        }
        if (isMinusOne(node.left)) {
            return node.right;
        }
        if (isMinusOne(node.right)) {
            return node.left;
        }
        return null;
    }
    function normalize(expr) {
        switch (expr.kind) {
            case 'binary': {
                // Children first, so the checks below see normalized operands.
                const left = normalize(expr.left);
                const right = normalize(expr.right);
                if (expr.operator === '+') {
                    // The extracted factor comes out of an already-normalized
                    // operand, so it is used directly without another pass.
                    const fromRight = negatedFactor(right);
                    if (fromRight !== null) {
                        return { kind: 'binary', operator: '-', left, right: fromRight };
                    }
                    const fromLeft = negatedFactor(left);
                    if (fromLeft !== null) {
                        return { kind: 'binary', operator: '-', left: right, right: fromLeft };
                    }
                }
                return (left === expr.left && right === expr.right)
                    ? expr
                    : { kind: 'binary', operator: expr.operator, left, right };
            }
            case 'unary': {
                const operand = normalize(expr.operand);
                return (operand === expr.operand)
                    ? expr
                    : { kind: 'unary', operator: expr.operator, operand };
            }
            case 'call': {
                const args = expr.args.map(normalize);
                return args.every((arg, i) => arg === expr.args[i])
                    ? expr
                    : { kind: 'call', name: expr.name, args };
            }
            case 'component': {
                const object = normalize(expr.object);
                return (object === expr.object)
                    ? expr
                    : { kind: 'component', object, component: expr.component };
            }
            default:
                // Numbers, variables and any other node kind have nothing to rewrite.
                return expr;
        }
    }
    for (const [name, expr] of temps) {
        temps.set(name, normalize(expr));
    }
    for (const [rootId, expr] of expressions) {
        expressions.set(rootId, normalize(expr));
    }
}