gradient-script 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. package/README.md +3 -1
  2. package/dist/cli.js +80 -3
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +332 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
package/README.md CHANGED
@@ -12,6 +12,8 @@
12
12
 
13
13
  GradientScript is a source-to-source compiler that automatically generates gradient functions from your mathematical code. Unlike numerical AD frameworks (JAX, PyTorch), it produces clean, human-readable gradient formulas you can inspect, optimize, and integrate directly into your codebase.
14
14
 
15
+ It is well suited to LLM-assisted work: an LLM can verify existing gradients, or construct the gradients you need, with less risk of error.
16
+
15
17
  ## Why GradientScript?
16
18
 
17
19
  - **From real code to gradients**: Write natural math code, get symbolic derivatives
@@ -40,7 +42,7 @@ function distance(u: Vec2, v: Vec2): number {
40
42
  }
41
43
  ```
42
44
 
43
- Convert it to GradientScript by marking what you need gradients for:
45
+ Convert it to GradientScript by marking what you need gradients for (in practice, you can let your LLM do the conversion — point it at this README as a reference and/or let it use the CLI freely):
44
46
 
45
47
  ```typescript
46
48
  // distance.gs
package/dist/cli.js CHANGED
@@ -6,6 +6,68 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
6
6
  import { generateComplete } from './dsl/CodeGen.js';
7
7
  import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
8
8
  import { ParseError, formatParseError } from './dsl/Errors.js';
9
+ import { GradientChecker, formatGradCheckResult } from './dsl/GradientChecker.js';
10
+ import { Types } from './dsl/Types.js';
11
+ /**
12
+ * Generate random test points for gradient verification.
13
+ * Uses multiple test points to catch errors at different values.
14
+ */
15
+ function generateTestPoints(func, env) {
16
+ const testPoints = [];
17
+ // Generate 3 different test points with varying scales
18
+ const scales = [1.0, 0.1, 10.0];
19
+ for (const scale of scales) {
20
+ const point = new Map();
21
+ for (const param of func.parameters) {
22
+ const paramType = env.getOrThrow(param.name);
23
+ if (Types.isScalar(paramType)) {
24
+ // Random scalar in range [-scale, scale], avoid zero
25
+ point.set(param.name, (Math.random() * 2 - 1) * scale + 0.1 * scale);
26
+ }
27
+ else {
28
+ // Structured type - get components
29
+ const struct = {};
30
+ for (const comp of paramType.components) {
31
+ struct[comp] = (Math.random() * 2 - 1) * scale + 0.1 * scale;
32
+ }
33
+ point.set(param.name, struct);
34
+ }
35
+ }
36
+ testPoints.push(point);
37
+ }
38
+ return testPoints;
39
+ }
40
+ /**
41
+ * Verify gradients for a function using numerical differentiation.
42
+ * Returns true if all gradients pass, false otherwise.
43
+ */
44
+ function verifyGradients(func, gradients, env) {
45
+ const checker = new GradientChecker(1e-5, 1e-4);
46
+ const testPoints = generateTestPoints(func, env);
47
+ let allPassed = true;
48
+ for (let i = 0; i < testPoints.length; i++) {
49
+ const result = checker.check(func, gradients, env, testPoints[i]);
50
+ if (!result.passed) {
51
+ if (allPassed) {
52
+ // First failure - print header (as comment for valid output)
53
+ console.error(`// Gradient verification FAILED for "${func.name}":`);
54
+ }
55
+ // Prefix each line with // so output remains valid code
56
+ const formattedResult = formatGradCheckResult(result, func.name)
57
+ .split('\n')
58
+ .map(line => '// ' + line)
59
+ .join('\n');
60
+ console.error(`// Test point ${i + 1}: ${formattedResult}`);
61
+ allPassed = false;
62
+ }
63
+ }
64
+ if (allPassed) {
65
+ const result = checker.check(func, gradients, env, testPoints[0]);
66
+ // Prefix with // so output is valid code
67
+ console.error('// ' + formatGradCheckResult(result, func.name));
68
+ }
69
+ return allPassed;
70
+ }
9
71
  function printUsage() {
10
72
  console.log(`
11
73
  GradientScript - Symbolic Differentiation for Structured Types
@@ -16,7 +78,7 @@ Usage:
16
78
  Options:
17
79
  --format <format> Output format: typescript (default), javascript, python, csharp
18
80
  --no-simplify Disable gradient simplification
19
- --no-cse Disable common subexpression elimination
81
+ --no-cse Disable optimization (e-graph CSE)
20
82
  --no-comments Omit comments in generated code
21
83
  --guards Emit runtime guards for division by zero (experimental)
22
84
  --epsilon <value> Epsilon value for guards (default: 1e-10)
@@ -45,6 +107,9 @@ For more information and examples:
45
107
 
46
108
  README (raw, LLM-friendly):
47
109
  https://raw.githubusercontent.com/mfagerlund/gradient-script/main/README.md
110
+
111
+ LLM Optimization Guide (for AI agents writing .gs files):
112
+ https://raw.githubusercontent.com/mfagerlund/gradient-script/main/docs/LLM-OPTIMIZATION-GUIDE.md
48
113
  `.trim());
49
114
  }
50
115
  function main() {
@@ -64,6 +129,7 @@ function main() {
64
129
  simplify: true,
65
130
  cse: true
66
131
  };
132
+ let skipVerify = false;
67
133
  for (let i = 1; i < args.length; i++) {
68
134
  const arg = args[i];
69
135
  if (arg === '--format') {
@@ -138,13 +204,20 @@ function main() {
138
204
  process.exit(1);
139
205
  }
140
206
  const outputs = [];
207
+ let hasVerificationFailure = false;
141
208
  program.functions.forEach((func, index) => {
142
209
  const env = inferFunction(func);
143
210
  const gradients = computeFunctionGradients(func, env);
211
+ // MANDATORY gradient verification
212
+ const verified = verifyGradients(func, gradients, env);
213
+ if (!verified) {
214
+ hasVerificationFailure = true;
215
+ }
144
216
  const guardAnalysis = analyzeGuards(func);
145
217
  if (guardAnalysis.hasIssues) {
146
- console.error('Function "' + func.name + '" may have edge cases:');
147
- console.error(formatGuardWarnings(guardAnalysis));
218
+ // Format warnings as comments so output remains valid code even if stderr is captured
219
+ console.error('// Function "' + func.name + '" may have edge cases:');
220
+ console.error(formatGuardWarnings(guardAnalysis, true));
148
221
  }
149
222
  const perFunctionOptions = { ...options };
150
223
  if (index > 0 && perFunctionOptions.includeComments !== false) {
@@ -153,6 +226,10 @@ function main() {
153
226
  const code = generateComplete(func, gradients, env, perFunctionOptions);
154
227
  outputs.push(code);
155
228
  });
229
+ if (hasVerificationFailure) {
230
+ console.error('// ERROR: Gradient verification failed. Output may contain incorrect gradients!');
231
+ process.exit(1);
232
+ }
156
233
  console.log(outputs.join('\n\n'));
157
234
  }
158
235
  catch (err) {
@@ -56,7 +56,7 @@ export declare class ExpressionCodeGen {
56
56
  */
57
57
  export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
58
58
  /**
59
- * Generate the original forward function
59
+ * Generate the original forward function (with optional e-graph optimization)
60
60
  */
61
61
  export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
62
62
  /**
@@ -2,11 +2,12 @@
2
2
  * Code generation for GradientScript DSL
3
3
  * Generates TypeScript/JavaScript code with gradient functions
4
4
  */
5
- import { simplifyGradients } from './Simplify.js';
5
+ import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
6
6
  import { ExpressionTransformer } from './ExpressionTransformer.js';
7
- import { eliminateCommonSubexpressionsStructured } from './CSE.js';
7
+ import { optimizeWithEGraph } from './egraph/index.js';
8
8
  import { CodeGenError } from './Errors.js';
9
9
  import { serializeExpression } from './ExpressionUtils.js';
10
+ import { inlineExpression } from './Inliner.js';
10
11
  function capitalize(str) {
11
12
  return str.charAt(0).toUpperCase() + str.slice(1);
12
13
  }
@@ -79,10 +80,12 @@ export class ExpressionCodeGen {
79
80
  return left;
80
81
  }
81
82
  else if (exponent === 2) {
82
- return `${left} * ${left}`;
83
+ // Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
84
+ // Without parens, a / b^2 would become a / b * b instead of a / (b * b)
85
+ return `(${left} * ${left})`;
83
86
  }
84
87
  else if (exponent === 3) {
85
- return `${left} * ${left} * ${left}`;
88
+ return `(${left} * ${left} * ${left})`;
86
89
  }
87
90
  }
88
91
  }
@@ -300,15 +303,162 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
300
303
  // Forward pass - compute intermediate variables
301
304
  // Track which expressions are already computed for CSE reuse
302
305
  const forwardExpressionMap = new Map();
306
+ // Build substitution map for inlining (to match gradient expressions which are fully inlined)
307
+ const substitutionMap = new Map();
308
+ for (const stmt of func.body) {
309
+ if (stmt.kind === 'assignment') {
310
+ substitutionMap.set(stmt.variable, stmt.expression);
311
+ }
312
+ }
313
+ // Collect forward variable names and expressions for optimization
314
+ const forwardVars = new Set();
315
+ const forwardVarExprs = new Map();
316
+ for (const stmt of func.body) {
317
+ if (stmt.kind === 'assignment') {
318
+ forwardVars.add(stmt.variable);
319
+ forwardVarExprs.set(stmt.variable, stmt.expression);
320
+ }
321
+ }
322
+ // Optimize forward expressions with e-graph
323
+ let optimizedForwardExprs = forwardVarExprs;
324
+ let forwardCseTemps = new Map();
325
+ if (options.cse !== false && forwardVarExprs.size > 0) {
326
+ const forOptimizer = new Map();
327
+ forOptimizer.set('_forward', forwardVarExprs);
328
+ const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
329
+ // Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
330
+ const rawTemps = forwardResult.intermediates;
331
+ const renameMap = new Map();
332
+ for (const oldName of rawTemps.keys()) {
333
+ const newName = oldName.replace('_tmp', '_fwd');
334
+ renameMap.set(oldName, newName);
335
+ }
336
+ // Apply renaming to temp definitions
337
+ for (const [oldName, expr] of rawTemps) {
338
+ const newName = renameMap.get(oldName);
339
+ forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
340
+ }
341
+ // Apply renaming to optimized expressions
342
+ optimizedForwardExprs = new Map();
343
+ for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
344
+ optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
345
+ }
346
+ }
347
+ // Helper to rename temp references in an expression
348
+ function renameTempRefs(expr, renameMap) {
349
+ if (expr.kind === 'variable') {
350
+ const newName = renameMap.get(expr.name);
351
+ return newName ? { kind: 'variable', name: newName } : expr;
352
+ }
353
+ else if (expr.kind === 'binary') {
354
+ return {
355
+ kind: 'binary',
356
+ operator: expr.operator,
357
+ left: renameTempRefs(expr.left, renameMap),
358
+ right: renameTempRefs(expr.right, renameMap)
359
+ };
360
+ }
361
+ else if (expr.kind === 'unary') {
362
+ return {
363
+ kind: 'unary',
364
+ operator: expr.operator,
365
+ operand: renameTempRefs(expr.operand, renameMap)
366
+ };
367
+ }
368
+ else if (expr.kind === 'call') {
369
+ return {
370
+ kind: 'call',
371
+ name: expr.name,
372
+ args: expr.args.map(a => renameTempRefs(a, renameMap))
373
+ };
374
+ }
375
+ else if (expr.kind === 'component') {
376
+ return {
377
+ kind: 'component',
378
+ object: renameTempRefs(expr.object, renameMap),
379
+ component: expr.component
380
+ };
381
+ }
382
+ return expr;
383
+ }
384
+ // Helper to find which forward vars an expression depends on
385
+ function findForwardVarDeps(expr) {
386
+ const deps = new Set();
387
+ function visit(e) {
388
+ if (e.kind === 'variable' && forwardVars.has(e.name)) {
389
+ deps.add(e.name);
390
+ }
391
+ else if (e.kind === 'binary') {
392
+ visit(e.left);
393
+ visit(e.right);
394
+ }
395
+ else if (e.kind === 'unary') {
396
+ visit(e.operand);
397
+ }
398
+ else if (e.kind === 'call') {
399
+ e.args.forEach(visit);
400
+ }
401
+ else if (e.kind === 'component') {
402
+ visit(e.object);
403
+ }
404
+ }
405
+ visit(expr);
406
+ return deps;
407
+ }
408
+ // Track which CSE temps need to be emitted after which forward var
409
+ const fwdTempAfterVar = new Map();
410
+ const fwdTempsBeforeAny = [];
411
+ for (const [tempName, tempExpr] of forwardCseTemps) {
412
+ const deps = findForwardVarDeps(tempExpr);
413
+ if (deps.size === 0) {
414
+ fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
415
+ }
416
+ else {
417
+ let lastDep = '';
418
+ for (const stmt of func.body) {
419
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
420
+ lastDep = stmt.variable;
421
+ }
422
+ }
423
+ if (lastDep) {
424
+ if (!fwdTempAfterVar.has(lastDep)) {
425
+ fwdTempAfterVar.set(lastDep, []);
426
+ }
427
+ fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
428
+ }
429
+ }
430
+ }
431
+ // Emit forward CSE temps that don't depend on forward vars
432
+ for (const { name: tempName, expr } of fwdTempsBeforeAny) {
433
+ const code = codegen.generate(expr);
434
+ if (format === 'typescript' || format === 'javascript') {
435
+ lines.push(` const ${tempName} = ${code};`);
436
+ }
437
+ else if (format === 'python') {
438
+ lines.push(` ${tempName} = ${code}`);
439
+ }
440
+ else if (format === 'csharp') {
441
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
442
+ }
443
+ }
444
+ // Generate forward variable assignments with interleaved temps
303
445
  for (const stmt of func.body) {
304
446
  if (stmt.kind === 'assignment') {
305
447
  const varName = stmt.variable;
306
- const generatedExpr = codegen.generate(stmt.expression);
448
+ const expr = optimizedForwardExprs.get(varName) || stmt.expression;
449
+ const generatedExpr = codegen.generate(expr);
307
450
  if (shouldTrackForForwardReuse(stmt.expression)) {
451
+ // Register the original expression
308
452
  const exprKey = serializeExpression(stmt.expression);
309
453
  if (!forwardExpressionMap.has(exprKey)) {
310
454
  forwardExpressionMap.set(exprKey, varName);
311
455
  }
456
+ // Also register the fully inlined form (this is what gradient expressions will have)
457
+ const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
458
+ const inlinedKey = serializeExpression(inlinedExpr);
459
+ if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
460
+ forwardExpressionMap.set(inlinedKey, varName);
461
+ }
312
462
  }
313
463
  if (format === 'typescript' || format === 'javascript') {
314
464
  lines.push(` const ${varName} = ${generatedExpr};`);
@@ -319,6 +469,20 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
319
469
  else if (format === 'csharp') {
320
470
  lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
321
471
  }
472
+ // Emit any forward CSE temps that depend on this var
473
+ const tempsForVar = fwdTempAfterVar.get(varName) || [];
474
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
475
+ const tempCode = codegen.generate(tempExpr);
476
+ if (format === 'typescript' || format === 'javascript') {
477
+ lines.push(` const ${tempName} = ${tempCode};`);
478
+ }
479
+ else if (format === 'python') {
480
+ lines.push(` ${tempName} = ${tempCode}`);
481
+ }
482
+ else if (format === 'csharp') {
483
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
484
+ }
485
+ }
322
486
  }
323
487
  }
324
488
  // Compute output value - reuse forward pass variables if possible
@@ -372,77 +536,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
372
536
  if (includeComments) {
373
537
  lines.push(` ${comment} Gradients`);
374
538
  }
375
- // Apply CSE if requested
539
+ // Apply e-graph optimization (CSE + algebraic simplification)
376
540
  const shouldApplyCSE = options.cse !== false; // Default to true
377
- const cseIntermediates = new Map();
541
+ let cseIntermediates = new Map();
378
542
  if (shouldApplyCSE) {
379
- // Collect all gradient expressions for CSE analysis
543
+ // Collect all gradient components into a single map for global optimization
544
+ // Include both structured and scalar gradients (scalar uses 'value' as component)
545
+ const allGradientComponents = new Map();
380
546
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
381
547
  if (isStructuredGradient(gradient)) {
382
- const cseResult = eliminateCommonSubexpressionsStructured(gradient.components);
383
- // Merge intermediates
384
- for (const [name, expr] of cseResult.intermediates.entries()) {
385
- cseIntermediates.set(name, expr);
386
- }
387
- // Update gradient components with CSE-simplified versions
388
- gradient.components = cseResult.components;
389
- }
390
- }
391
- // Generate intermediate variables from CSE
392
- if (cseIntermediates.size > 0) {
393
- // Check if we should emit guards (opt-in)
394
- const shouldEmitGuards = options.emitGuards === true;
395
- const epsilon = options.epsilon || 1e-10;
396
- // Identify potential denominators (sum of squares patterns)
397
- const denominatorVars = new Set();
398
- for (const [varName, expr] of cseIntermediates.entries()) {
399
- const code = codegen.generate(expr);
400
- // Check if this looks like a denominator (contains + and squared terms)
401
- if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
402
- denominatorVars.add(varName);
403
- }
548
+ allGradientComponents.set(paramName, gradient.components);
404
549
  }
405
- for (const [varName, expr] of cseIntermediates.entries()) {
406
- const code = codegen.generate(expr);
407
- if (format === 'typescript' || format === 'javascript') {
408
- lines.push(` const ${varName} = ${code};`);
409
- }
410
- else if (format === 'python') {
411
- lines.push(` ${varName} = ${code}`);
412
- }
413
- else if (format === 'csharp') {
414
- lines.push(` ${csharpFloatType} ${varName} = ${code};`);
415
- }
550
+ else {
551
+ // Scalar gradient - wrap as a single 'value' component
552
+ allGradientComponents.set(paramName, new Map([['value', gradient]]));
416
553
  }
417
- // Emit epsilon guard if needed
418
- if (shouldEmitGuards && denominatorVars.size > 0) {
419
- lines.push('');
420
- if (includeComments) {
421
- lines.push(` ${comment} Guard against division by zero`);
554
+ }
555
+ // Run e-graph optimization globally across ALL gradient expressions
556
+ const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
557
+ cseIntermediates = globalCSE.intermediates;
558
+ // Update gradient components with optimized versions
559
+ for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
560
+ const gradient = gradientsToUse.gradients.get(paramName);
561
+ if (gradient && isStructuredGradient(gradient)) {
562
+ gradient.components = simplifiedComponents;
563
+ }
564
+ else {
565
+ // Scalar gradient - unwrap from 'value' component
566
+ const optimizedExpr = simplifiedComponents.get('value');
567
+ if (optimizedExpr) {
568
+ gradientsToUse.gradients.set(paramName, optimizedExpr);
422
569
  }
423
- for (const denom of denominatorVars) {
424
- if (format === 'typescript' || format === 'javascript') {
425
- lines.push(` if (Math.abs(${denom}) < ${epsilon}) {`);
426
- lines.push(` ${comment} Return zero gradients for degenerate case`);
427
- // Emit zero gradient structure
428
- const zeroGrads = [];
429
- for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
430
- if (isStructuredGradient(gradient)) {
431
- const components = Array.from(gradient.components.keys());
432
- const zeroStruct = components.map(c => `${c}: 0`).join(', ');
433
- zeroGrads.push(`d${paramName}: { ${zeroStruct} }`);
434
- }
435
- else {
436
- zeroGrads.push(`d${paramName}: 0`);
437
- }
438
- }
439
- lines.push(` return { value, ${zeroGrads.join(', ')} };`);
440
- lines.push(` }`);
441
- }
570
+ }
571
+ }
572
+ // Post-CSE simplification: apply rules that were skipped to avoid CSE interference
573
+ // Specifically: a + a → 2 * a (now safe because temps have been extracted)
574
+ for (const [varName, expr] of cseIntermediates) {
575
+ cseIntermediates.set(varName, simplifyPostCSE(expr));
576
+ }
577
+ for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
578
+ if (isStructuredGradient(gradient)) {
579
+ for (const [comp, expr] of gradient.components.entries()) {
580
+ gradient.components.set(comp, simplifyPostCSE(expr));
442
581
  }
443
582
  }
444
- lines.push('');
583
+ else {
584
+ // Scalar gradient - apply post-CSE simplification
585
+ gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
586
+ }
587
+ }
588
+ }
589
+ // Generate CSE intermediate variables
590
+ if (cseIntermediates.size > 0) {
591
+ for (const [varName, expr] of cseIntermediates.entries()) {
592
+ const code = codegen.generate(expr);
593
+ if (format === 'typescript' || format === 'javascript') {
594
+ lines.push(` const ${varName} = ${code};`);
595
+ }
596
+ else if (format === 'python') {
597
+ lines.push(` ${varName} = ${code}`);
598
+ }
599
+ else if (format === 'csharp') {
600
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
601
+ }
445
602
  }
603
+ lines.push('');
446
604
  }
447
605
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
448
606
  // Use shorter names: du, dv instead of grad_u, grad_v
@@ -530,11 +688,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
530
688
  return lines.join('\n');
531
689
  }
532
690
  /**
533
- * Generate the original forward function
691
+ * Generate the original forward function (with optional e-graph optimization)
534
692
  */
535
693
  export function generateForwardFunction(func, options = {}) {
536
694
  const format = options.format || 'typescript';
537
695
  const csharpFloatType = options.csharpFloatType || 'float';
696
+ const shouldOptimize = options.cse !== false; // Optimize by default
538
697
  const codegen = new ExpressionCodeGen(format, csharpFloatType);
539
698
  const lines = [];
540
699
  // Function signature
@@ -549,28 +708,127 @@ export function generateForwardFunction(func, options = {}) {
549
708
  const floatType = csharpFloatType;
550
709
  const params = func.parameters.map(p => {
551
710
  if (p.paramType && p.paramType.components) {
552
- // Structured parameter - create a struct type name
553
711
  return `${capitalize(p.name)}Struct ${p.name}`;
554
712
  }
555
713
  return `${floatType} ${p.name}`;
556
714
  }).join(', ');
557
- // Generate struct definitions for structured parameters first (we'll prepend them later)
558
715
  lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
559
716
  lines.push('{');
560
717
  }
561
- // Body
718
+ // Collect all forward variable names (for dependency tracking)
719
+ const forwardVars = new Set();
720
+ for (const stmt of func.body) {
721
+ if (stmt.kind === 'assignment') {
722
+ forwardVars.add(stmt.variable);
723
+ }
724
+ }
725
+ // Collect expressions for optimization
726
+ const varExpressions = new Map();
727
+ for (const stmt of func.body) {
728
+ if (stmt.kind === 'assignment') {
729
+ varExpressions.set(stmt.variable, stmt.expression);
730
+ }
731
+ }
732
+ // Optimize with e-graph if enabled
733
+ let optimizedExprs = varExpressions;
734
+ let cseTemps = new Map();
735
+ if (shouldOptimize && varExpressions.size > 0) {
736
+ const forOptimizer = new Map();
737
+ forOptimizer.set('_forward', varExpressions);
738
+ const result = optimizeWithEGraph(forOptimizer, { verbose: false });
739
+ cseTemps = result.intermediates;
740
+ optimizedExprs = result.gradients.get('_forward') || varExpressions;
741
+ }
742
+ // Helper to find which forward vars an expression depends on
743
+ function findForwardVarDeps(expr) {
744
+ const deps = new Set();
745
+ function visit(e) {
746
+ if (e.kind === 'variable' && forwardVars.has(e.name)) {
747
+ deps.add(e.name);
748
+ }
749
+ else if (e.kind === 'binary') {
750
+ visit(e.left);
751
+ visit(e.right);
752
+ }
753
+ else if (e.kind === 'unary') {
754
+ visit(e.operand);
755
+ }
756
+ else if (e.kind === 'call') {
757
+ e.args.forEach(visit);
758
+ }
759
+ else if (e.kind === 'component') {
760
+ visit(e.object);
761
+ }
762
+ }
763
+ visit(expr);
764
+ return deps;
765
+ }
766
+ // Track which temps need to be emitted after which forward var
767
+ const tempAfterVar = new Map();
768
+ const tempsEmittedBeforeAny = [];
769
+ for (const [tempName, tempExpr] of cseTemps) {
770
+ const deps = findForwardVarDeps(tempExpr);
771
+ if (deps.size === 0) {
772
+ // Temp only depends on params - emit before any forward vars
773
+ tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
774
+ }
775
+ else {
776
+ // Find the last forward var this temp depends on
777
+ let lastDep = '';
778
+ for (const stmt of func.body) {
779
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
780
+ lastDep = stmt.variable;
781
+ }
782
+ }
783
+ if (lastDep) {
784
+ if (!tempAfterVar.has(lastDep)) {
785
+ tempAfterVar.set(lastDep, []);
786
+ }
787
+ tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
788
+ }
789
+ }
790
+ }
791
+ // Emit temps that don't depend on forward vars
792
+ for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
793
+ const code = codegen.generate(expr);
794
+ if (format === 'typescript' || format === 'javascript') {
795
+ lines.push(` const ${tempName} = ${code};`);
796
+ }
797
+ else if (format === 'python') {
798
+ lines.push(` ${tempName} = ${code}`);
799
+ }
800
+ else if (format === 'csharp') {
801
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
802
+ }
803
+ }
804
+ // Generate variable assignments with interleaved temps
562
805
  for (const stmt of func.body) {
563
806
  if (stmt.kind === 'assignment') {
564
807
  const varName = stmt.variable;
565
- const expr = codegen.generate(stmt.expression);
808
+ const expr = optimizedExprs.get(varName) || stmt.expression;
809
+ const code = codegen.generate(expr);
566
810
  if (format === 'typescript' || format === 'javascript') {
567
- lines.push(` const ${varName} = ${expr};`);
811
+ lines.push(` const ${varName} = ${code};`);
568
812
  }
569
813
  else if (format === 'python') {
570
- lines.push(` ${varName} = ${expr}`);
814
+ lines.push(` ${varName} = ${code}`);
571
815
  }
572
816
  else if (format === 'csharp') {
573
- lines.push(` ${csharpFloatType} ${varName} = ${expr};`);
817
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
818
+ }
819
+ // Emit any temps that depend on this var
820
+ const tempsForVar = tempAfterVar.get(varName) || [];
821
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
822
+ const tempCode = codegen.generate(tempExpr);
823
+ if (format === 'typescript' || format === 'javascript') {
824
+ lines.push(` const ${tempName} = ${tempCode};`);
825
+ }
826
+ else if (format === 'python') {
827
+ lines.push(` ${tempName} = ${tempCode}`);
828
+ }
829
+ else if (format === 'csharp') {
830
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
831
+ }
574
832
  }
575
833
  }
576
834
  }
@@ -54,10 +54,16 @@ export declare function containsVariable(expr: Expression, varName: string): boo
54
54
  */
55
55
  export declare function expressionDepth(expr: Expression): number;
56
56
  /**
57
- * Serializes an expression to canonical string representation.
58
- * Used for expression comparison and hashing (CSE, CodeGen forward reuse).
57
+ * Serializes an expression to structural string representation.
58
+ * Used for exact expression comparison - operand order matters.
59
59
  *
60
60
  * This ensures consistent string representation of expressions across different
61
61
  * parts of the codebase.
62
62
  */
63
63
  export declare function serializeExpression(expr: Expression): string;
64
+ /**
65
+ * Serializes an expression to canonical form for CSE matching.
66
+ * Commutative operations (+ and *) have operands sorted lexicographically,
67
+ * so a*b and b*a produce the same canonical string.
68
+ */
69
+ export declare function serializeCanonical(expr: Expression): string;