gradient-script 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. package/README.md +3 -1
  2. package/dist/cli.js +219 -6
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +336 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -2,11 +2,12 @@
2
2
  * Code generation for GradientScript DSL
3
3
  * Generates TypeScript/JavaScript code with gradient functions
4
4
  */
5
- import { simplifyGradients } from './Simplify.js';
5
+ import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
6
6
  import { ExpressionTransformer } from './ExpressionTransformer.js';
7
- import { eliminateCommonSubexpressionsStructured } from './CSE.js';
7
+ import { optimizeWithEGraph } from './egraph/index.js';
8
8
  import { CodeGenError } from './Errors.js';
9
9
  import { serializeExpression } from './ExpressionUtils.js';
10
+ import { inlineExpression } from './Inliner.js';
10
11
  function capitalize(str) {
11
12
  return str.charAt(0).toUpperCase() + str.slice(1);
12
13
  }
@@ -79,10 +80,12 @@ export class ExpressionCodeGen {
79
80
  return left;
80
81
  }
81
82
  else if (exponent === 2) {
82
- return `${left} * ${left}`;
83
+ // Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
84
+ // Without parens, a / b^2 would become a / b * b instead of a / (b * b)
85
+ return `(${left} * ${left})`;
83
86
  }
84
87
  else if (exponent === 3) {
85
- return `${left} * ${left} * ${left}`;
88
+ return `(${left} * ${left} * ${left})`;
86
89
  }
87
90
  }
88
91
  }
@@ -166,6 +169,10 @@ export class ExpressionCodeGen {
166
169
  }
167
170
  genUnary(expr) {
168
171
  const operand = this.generate(expr.operand);
172
+ // Parenthesize binary operands to avoid precedence bugs: -(a + b) not -a + b
173
+ if (expr.operand.kind === 'binary') {
174
+ return `${expr.operator}(${operand})`;
175
+ }
169
176
  return `${expr.operator}${operand}`;
170
177
  }
171
178
  genCall(expr) {
@@ -300,15 +307,162 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
300
307
  // Forward pass - compute intermediate variables
301
308
  // Track which expressions are already computed for CSE reuse
302
309
  const forwardExpressionMap = new Map();
310
+ // Build substitution map for inlining (to match gradient expressions which are fully inlined)
311
+ const substitutionMap = new Map();
312
+ for (const stmt of func.body) {
313
+ if (stmt.kind === 'assignment') {
314
+ substitutionMap.set(stmt.variable, stmt.expression);
315
+ }
316
+ }
317
+ // Collect forward variable names and expressions for optimization
318
+ const forwardVars = new Set();
319
+ const forwardVarExprs = new Map();
320
+ for (const stmt of func.body) {
321
+ if (stmt.kind === 'assignment') {
322
+ forwardVars.add(stmt.variable);
323
+ forwardVarExprs.set(stmt.variable, stmt.expression);
324
+ }
325
+ }
326
+ // Optimize forward expressions with e-graph
327
+ let optimizedForwardExprs = forwardVarExprs;
328
+ let forwardCseTemps = new Map();
329
+ if (options.cse !== false && forwardVarExprs.size > 0) {
330
+ const forOptimizer = new Map();
331
+ forOptimizer.set('_forward', forwardVarExprs);
332
+ const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
333
+ // Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
334
+ const rawTemps = forwardResult.intermediates;
335
+ const renameMap = new Map();
336
+ for (const oldName of rawTemps.keys()) {
337
+ const newName = oldName.replace('_tmp', '_fwd');
338
+ renameMap.set(oldName, newName);
339
+ }
340
+ // Apply renaming to temp definitions
341
+ for (const [oldName, expr] of rawTemps) {
342
+ const newName = renameMap.get(oldName);
343
+ forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
344
+ }
345
+ // Apply renaming to optimized expressions
346
+ optimizedForwardExprs = new Map();
347
+ for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
348
+ optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
349
+ }
350
+ }
351
+ // Helper to rename temp references in an expression
352
+ function renameTempRefs(expr, renameMap) {
353
+ if (expr.kind === 'variable') {
354
+ const newName = renameMap.get(expr.name);
355
+ return newName ? { kind: 'variable', name: newName } : expr;
356
+ }
357
+ else if (expr.kind === 'binary') {
358
+ return {
359
+ kind: 'binary',
360
+ operator: expr.operator,
361
+ left: renameTempRefs(expr.left, renameMap),
362
+ right: renameTempRefs(expr.right, renameMap)
363
+ };
364
+ }
365
+ else if (expr.kind === 'unary') {
366
+ return {
367
+ kind: 'unary',
368
+ operator: expr.operator,
369
+ operand: renameTempRefs(expr.operand, renameMap)
370
+ };
371
+ }
372
+ else if (expr.kind === 'call') {
373
+ return {
374
+ kind: 'call',
375
+ name: expr.name,
376
+ args: expr.args.map(a => renameTempRefs(a, renameMap))
377
+ };
378
+ }
379
+ else if (expr.kind === 'component') {
380
+ return {
381
+ kind: 'component',
382
+ object: renameTempRefs(expr.object, renameMap),
383
+ component: expr.component
384
+ };
385
+ }
386
+ return expr;
387
+ }
388
+ // Helper to find which forward vars an expression depends on
389
+ function findForwardVarDeps(expr) {
390
+ const deps = new Set();
391
+ function visit(e) {
392
+ if (e.kind === 'variable' && forwardVars.has(e.name)) {
393
+ deps.add(e.name);
394
+ }
395
+ else if (e.kind === 'binary') {
396
+ visit(e.left);
397
+ visit(e.right);
398
+ }
399
+ else if (e.kind === 'unary') {
400
+ visit(e.operand);
401
+ }
402
+ else if (e.kind === 'call') {
403
+ e.args.forEach(visit);
404
+ }
405
+ else if (e.kind === 'component') {
406
+ visit(e.object);
407
+ }
408
+ }
409
+ visit(expr);
410
+ return deps;
411
+ }
412
+ // Track which CSE temps need to be emitted after which forward var
413
+ const fwdTempAfterVar = new Map();
414
+ const fwdTempsBeforeAny = [];
415
+ for (const [tempName, tempExpr] of forwardCseTemps) {
416
+ const deps = findForwardVarDeps(tempExpr);
417
+ if (deps.size === 0) {
418
+ fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
419
+ }
420
+ else {
421
+ let lastDep = '';
422
+ for (const stmt of func.body) {
423
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
424
+ lastDep = stmt.variable;
425
+ }
426
+ }
427
+ if (lastDep) {
428
+ if (!fwdTempAfterVar.has(lastDep)) {
429
+ fwdTempAfterVar.set(lastDep, []);
430
+ }
431
+ fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
432
+ }
433
+ }
434
+ }
435
+ // Emit forward CSE temps that don't depend on forward vars
436
+ for (const { name: tempName, expr } of fwdTempsBeforeAny) {
437
+ const code = codegen.generate(expr);
438
+ if (format === 'typescript' || format === 'javascript') {
439
+ lines.push(` const ${tempName} = ${code};`);
440
+ }
441
+ else if (format === 'python') {
442
+ lines.push(` ${tempName} = ${code}`);
443
+ }
444
+ else if (format === 'csharp') {
445
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
446
+ }
447
+ }
448
+ // Generate forward variable assignments with interleaved temps
303
449
  for (const stmt of func.body) {
304
450
  if (stmt.kind === 'assignment') {
305
451
  const varName = stmt.variable;
306
- const generatedExpr = codegen.generate(stmt.expression);
452
+ const expr = optimizedForwardExprs.get(varName) || stmt.expression;
453
+ const generatedExpr = codegen.generate(expr);
307
454
  if (shouldTrackForForwardReuse(stmt.expression)) {
455
+ // Register the original expression
308
456
  const exprKey = serializeExpression(stmt.expression);
309
457
  if (!forwardExpressionMap.has(exprKey)) {
310
458
  forwardExpressionMap.set(exprKey, varName);
311
459
  }
460
+ // Also register the fully inlined form (this is what gradient expressions will have)
461
+ const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
462
+ const inlinedKey = serializeExpression(inlinedExpr);
463
+ if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
464
+ forwardExpressionMap.set(inlinedKey, varName);
465
+ }
312
466
  }
313
467
  if (format === 'typescript' || format === 'javascript') {
314
468
  lines.push(` const ${varName} = ${generatedExpr};`);
@@ -319,6 +473,20 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
319
473
  else if (format === 'csharp') {
320
474
  lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
321
475
  }
476
+ // Emit any forward CSE temps that depend on this var
477
+ const tempsForVar = fwdTempAfterVar.get(varName) || [];
478
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
479
+ const tempCode = codegen.generate(tempExpr);
480
+ if (format === 'typescript' || format === 'javascript') {
481
+ lines.push(` const ${tempName} = ${tempCode};`);
482
+ }
483
+ else if (format === 'python') {
484
+ lines.push(` ${tempName} = ${tempCode}`);
485
+ }
486
+ else if (format === 'csharp') {
487
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
488
+ }
489
+ }
322
490
  }
323
491
  }
324
492
  // Compute output value - reuse forward pass variables if possible
@@ -372,77 +540,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
372
540
  if (includeComments) {
373
541
  lines.push(` ${comment} Gradients`);
374
542
  }
375
- // Apply CSE if requested
543
+ // Apply e-graph optimization (CSE + algebraic simplification)
376
544
  const shouldApplyCSE = options.cse !== false; // Default to true
377
- const cseIntermediates = new Map();
545
+ let cseIntermediates = new Map();
378
546
  if (shouldApplyCSE) {
379
- // Collect all gradient expressions for CSE analysis
547
+ // Collect all gradient components into a single map for global optimization
548
+ // Include both structured and scalar gradients (scalar uses 'value' as component)
549
+ const allGradientComponents = new Map();
380
550
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
381
551
  if (isStructuredGradient(gradient)) {
382
- const cseResult = eliminateCommonSubexpressionsStructured(gradient.components);
383
- // Merge intermediates
384
- for (const [name, expr] of cseResult.intermediates.entries()) {
385
- cseIntermediates.set(name, expr);
386
- }
387
- // Update gradient components with CSE-simplified versions
388
- gradient.components = cseResult.components;
389
- }
390
- }
391
- // Generate intermediate variables from CSE
392
- if (cseIntermediates.size > 0) {
393
- // Check if we should emit guards (opt-in)
394
- const shouldEmitGuards = options.emitGuards === true;
395
- const epsilon = options.epsilon || 1e-10;
396
- // Identify potential denominators (sum of squares patterns)
397
- const denominatorVars = new Set();
398
- for (const [varName, expr] of cseIntermediates.entries()) {
399
- const code = codegen.generate(expr);
400
- // Check if this looks like a denominator (contains + and squared terms)
401
- if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
402
- denominatorVars.add(varName);
403
- }
552
+ allGradientComponents.set(paramName, gradient.components);
404
553
  }
405
- for (const [varName, expr] of cseIntermediates.entries()) {
406
- const code = codegen.generate(expr);
407
- if (format === 'typescript' || format === 'javascript') {
408
- lines.push(` const ${varName} = ${code};`);
409
- }
410
- else if (format === 'python') {
411
- lines.push(` ${varName} = ${code}`);
412
- }
413
- else if (format === 'csharp') {
414
- lines.push(` ${csharpFloatType} ${varName} = ${code};`);
415
- }
554
+ else {
555
+ // Scalar gradient - wrap as a single 'value' component
556
+ allGradientComponents.set(paramName, new Map([['value', gradient]]));
557
+ }
558
+ }
559
+ // Run e-graph optimization globally across ALL gradient expressions
560
+ const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
561
+ cseIntermediates = globalCSE.intermediates;
562
+ // Update gradient components with optimized versions
563
+ for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
564
+ const gradient = gradientsToUse.gradients.get(paramName);
565
+ if (gradient && isStructuredGradient(gradient)) {
566
+ gradient.components = simplifiedComponents;
416
567
  }
417
- // Emit epsilon guard if needed
418
- if (shouldEmitGuards && denominatorVars.size > 0) {
419
- lines.push('');
420
- if (includeComments) {
421
- lines.push(` ${comment} Guard against division by zero`);
568
+ else {
569
+ // Scalar gradient - unwrap from 'value' component
570
+ const optimizedExpr = simplifiedComponents.get('value');
571
+ if (optimizedExpr) {
572
+ gradientsToUse.gradients.set(paramName, optimizedExpr);
422
573
  }
423
- for (const denom of denominatorVars) {
424
- if (format === 'typescript' || format === 'javascript') {
425
- lines.push(` if (Math.abs(${denom}) < ${epsilon}) {`);
426
- lines.push(` ${comment} Return zero gradients for degenerate case`);
427
- // Emit zero gradient structure
428
- const zeroGrads = [];
429
- for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
430
- if (isStructuredGradient(gradient)) {
431
- const components = Array.from(gradient.components.keys());
432
- const zeroStruct = components.map(c => `${c}: 0`).join(', ');
433
- zeroGrads.push(`d${paramName}: { ${zeroStruct} }`);
434
- }
435
- else {
436
- zeroGrads.push(`d${paramName}: 0`);
437
- }
438
- }
439
- lines.push(` return { value, ${zeroGrads.join(', ')} };`);
440
- lines.push(` }`);
441
- }
574
+ }
575
+ }
576
+ // Post-CSE simplification: apply rules that were skipped to avoid CSE interference
577
+ // Specifically: a + a → 2 * a (now safe because temps have been extracted)
578
+ for (const [varName, expr] of cseIntermediates) {
579
+ cseIntermediates.set(varName, simplifyPostCSE(expr));
580
+ }
581
+ for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
582
+ if (isStructuredGradient(gradient)) {
583
+ for (const [comp, expr] of gradient.components.entries()) {
584
+ gradient.components.set(comp, simplifyPostCSE(expr));
442
585
  }
443
586
  }
444
- lines.push('');
587
+ else {
588
+ // Scalar gradient - apply post-CSE simplification
589
+ gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
590
+ }
591
+ }
592
+ }
593
+ // Generate CSE intermediate variables
594
+ if (cseIntermediates.size > 0) {
595
+ for (const [varName, expr] of cseIntermediates.entries()) {
596
+ const code = codegen.generate(expr);
597
+ if (format === 'typescript' || format === 'javascript') {
598
+ lines.push(` const ${varName} = ${code};`);
599
+ }
600
+ else if (format === 'python') {
601
+ lines.push(` ${varName} = ${code}`);
602
+ }
603
+ else if (format === 'csharp') {
604
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
605
+ }
445
606
  }
607
+ lines.push('');
446
608
  }
447
609
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
448
610
  // Use shorter names: du, dv instead of grad_u, grad_v
@@ -530,11 +692,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
530
692
  return lines.join('\n');
531
693
  }
532
694
  /**
533
- * Generate the original forward function
695
+ * Generate the original forward function (with optional e-graph optimization)
534
696
  */
535
697
  export function generateForwardFunction(func, options = {}) {
536
698
  const format = options.format || 'typescript';
537
699
  const csharpFloatType = options.csharpFloatType || 'float';
700
+ const shouldOptimize = options.cse !== false; // Optimize by default
538
701
  const codegen = new ExpressionCodeGen(format, csharpFloatType);
539
702
  const lines = [];
540
703
  // Function signature
@@ -549,28 +712,127 @@ export function generateForwardFunction(func, options = {}) {
549
712
  const floatType = csharpFloatType;
550
713
  const params = func.parameters.map(p => {
551
714
  if (p.paramType && p.paramType.components) {
552
- // Structured parameter - create a struct type name
553
715
  return `${capitalize(p.name)}Struct ${p.name}`;
554
716
  }
555
717
  return `${floatType} ${p.name}`;
556
718
  }).join(', ');
557
- // Generate struct definitions for structured parameters first (we'll prepend them later)
558
719
  lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
559
720
  lines.push('{');
560
721
  }
561
- // Body
722
+ // Collect all forward variable names (for dependency tracking)
723
+ const forwardVars = new Set();
724
+ for (const stmt of func.body) {
725
+ if (stmt.kind === 'assignment') {
726
+ forwardVars.add(stmt.variable);
727
+ }
728
+ }
729
+ // Collect expressions for optimization
730
+ const varExpressions = new Map();
731
+ for (const stmt of func.body) {
732
+ if (stmt.kind === 'assignment') {
733
+ varExpressions.set(stmt.variable, stmt.expression);
734
+ }
735
+ }
736
+ // Optimize with e-graph if enabled
737
+ let optimizedExprs = varExpressions;
738
+ let cseTemps = new Map();
739
+ if (shouldOptimize && varExpressions.size > 0) {
740
+ const forOptimizer = new Map();
741
+ forOptimizer.set('_forward', varExpressions);
742
+ const result = optimizeWithEGraph(forOptimizer, { verbose: false });
743
+ cseTemps = result.intermediates;
744
+ optimizedExprs = result.gradients.get('_forward') || varExpressions;
745
+ }
746
+ // Helper to find which forward vars an expression depends on
747
+ function findForwardVarDeps(expr) {
748
+ const deps = new Set();
749
+ function visit(e) {
750
+ if (e.kind === 'variable' && forwardVars.has(e.name)) {
751
+ deps.add(e.name);
752
+ }
753
+ else if (e.kind === 'binary') {
754
+ visit(e.left);
755
+ visit(e.right);
756
+ }
757
+ else if (e.kind === 'unary') {
758
+ visit(e.operand);
759
+ }
760
+ else if (e.kind === 'call') {
761
+ e.args.forEach(visit);
762
+ }
763
+ else if (e.kind === 'component') {
764
+ visit(e.object);
765
+ }
766
+ }
767
+ visit(expr);
768
+ return deps;
769
+ }
770
+ // Track which temps need to be emitted after which forward var
771
+ const tempAfterVar = new Map();
772
+ const tempsEmittedBeforeAny = [];
773
+ for (const [tempName, tempExpr] of cseTemps) {
774
+ const deps = findForwardVarDeps(tempExpr);
775
+ if (deps.size === 0) {
776
+ // Temp only depends on params - emit before any forward vars
777
+ tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
778
+ }
779
+ else {
780
+ // Find the last forward var this temp depends on
781
+ let lastDep = '';
782
+ for (const stmt of func.body) {
783
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
784
+ lastDep = stmt.variable;
785
+ }
786
+ }
787
+ if (lastDep) {
788
+ if (!tempAfterVar.has(lastDep)) {
789
+ tempAfterVar.set(lastDep, []);
790
+ }
791
+ tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
792
+ }
793
+ }
794
+ }
795
+ // Emit temps that don't depend on forward vars
796
+ for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
797
+ const code = codegen.generate(expr);
798
+ if (format === 'typescript' || format === 'javascript') {
799
+ lines.push(` const ${tempName} = ${code};`);
800
+ }
801
+ else if (format === 'python') {
802
+ lines.push(` ${tempName} = ${code}`);
803
+ }
804
+ else if (format === 'csharp') {
805
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
806
+ }
807
+ }
808
+ // Generate variable assignments with interleaved temps
562
809
  for (const stmt of func.body) {
563
810
  if (stmt.kind === 'assignment') {
564
811
  const varName = stmt.variable;
565
- const expr = codegen.generate(stmt.expression);
812
+ const expr = optimizedExprs.get(varName) || stmt.expression;
813
+ const code = codegen.generate(expr);
566
814
  if (format === 'typescript' || format === 'javascript') {
567
- lines.push(` const ${varName} = ${expr};`);
815
+ lines.push(` const ${varName} = ${code};`);
568
816
  }
569
817
  else if (format === 'python') {
570
- lines.push(` ${varName} = ${expr}`);
818
+ lines.push(` ${varName} = ${code}`);
571
819
  }
572
820
  else if (format === 'csharp') {
573
- lines.push(` ${csharpFloatType} ${varName} = ${expr};`);
821
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
822
+ }
823
+ // Emit any temps that depend on this var
824
+ const tempsForVar = tempAfterVar.get(varName) || [];
825
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
826
+ const tempCode = codegen.generate(tempExpr);
827
+ if (format === 'typescript' || format === 'javascript') {
828
+ lines.push(` const ${tempName} = ${tempCode};`);
829
+ }
830
+ else if (format === 'python') {
831
+ lines.push(` ${tempName} = ${tempCode}`);
832
+ }
833
+ else if (format === 'csharp') {
834
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
835
+ }
574
836
  }
575
837
  }
576
838
  }
@@ -54,10 +54,16 @@ export declare function containsVariable(expr: Expression, varName: string): boo
54
54
  */
55
55
  export declare function expressionDepth(expr: Expression): number;
56
56
  /**
57
- * Serializes an expression to canonical string representation.
58
- * Used for expression comparison and hashing (CSE, CodeGen forward reuse).
57
+ * Serializes an expression to structural string representation.
58
+ * Used for exact expression comparison - operand order matters.
59
59
  *
60
60
  * This ensures consistent string representation of expressions across different
61
61
  * parts of the codebase.
62
62
  */
63
63
  export declare function serializeExpression(expr: Expression): string;
64
+ /**
65
+ * Serializes an expression to canonical form for CSE matching.
66
+ * Commutative operations (+ and *) have operands sorted lexicographically,
67
+ * so a*b and b*a produce the same canonical string.
68
+ */
69
+ export declare function serializeCanonical(expr: Expression): string;
@@ -174,8 +174,8 @@ export function expressionDepth(expr) {
174
174
  }
175
175
  }
176
176
  /**
177
- * Serializes an expression to canonical string representation.
178
- * Used for expression comparison and hashing (CSE, CodeGen forward reuse).
177
+ * Serializes an expression to structural string representation.
178
+ * Used for exact expression comparison - operand order matters.
179
179
  *
180
180
  * This ensures consistent string representation of expressions across different
181
181
  * parts of the codebase.
@@ -197,3 +197,35 @@ export function serializeExpression(expr) {
197
197
  return `comp(${serializeExpression(expr.object)},${expr.component})`;
198
198
  }
199
199
  }
200
+ /**
201
+ * Serializes an expression to canonical form for CSE matching.
202
+ * Commutative operations (+ and *) have operands sorted lexicographically,
203
+ * so a*b and b*a produce the same canonical string.
204
+ */
205
+ export function serializeCanonical(expr) {
206
+ switch (expr.kind) {
207
+ case 'number':
208
+ return `num(${expr.value})`;
209
+ case 'variable':
210
+ return `var(${expr.name})`;
211
+ case 'binary': {
212
+ const leftStr = serializeCanonical(expr.left);
213
+ const rightStr = serializeCanonical(expr.right);
214
+ // For commutative operations, sort operands lexicographically
215
+ if (expr.operator === '+' || expr.operator === '*') {
216
+ const [first, second] = leftStr <= rightStr ? [leftStr, rightStr] : [rightStr, leftStr];
217
+ return `bin(${expr.operator},${first},${second})`;
218
+ }
219
+ // Non-commutative: preserve order
220
+ return `bin(${expr.operator},${leftStr},${rightStr})`;
221
+ }
222
+ case 'unary':
223
+ return `un(${expr.operator},${serializeCanonical(expr.operand)})`;
224
+ case 'call': {
225
+ const args = expr.args.map(arg => serializeCanonical(arg)).join(',');
226
+ return `call(${expr.name},${args})`;
227
+ }
228
+ case 'component':
229
+ return `comp(${serializeCanonical(expr.object)},${expr.component})`;
230
+ }
231
+ }
@@ -17,9 +17,25 @@ type NumValue = number | {
17
17
  export interface GradCheckResult {
18
18
  passed: boolean;
19
19
  errors: GradCheckError[];
20
+ singularities: GradCheckSingularity[];
20
21
  maxError: number;
21
22
  meanError: number;
23
+ totalChecks: number;
22
24
  }
25
+ /**
26
+ * Singularity detected during gradient checking
27
+ * When both analytical and numerical produce NaN/Inf, it's a singularity, not a bug
28
+ */
29
+ export interface GradCheckSingularity {
30
+ parameter: string;
31
+ component?: string;
32
+ analytical: number;
33
+ numerical: number;
34
+ }
35
+ /**
36
+ * Format gradient check results as a human-readable string
37
+ */
38
+ export declare function formatGradCheckResult(result: GradCheckResult, funcName: string): string;
23
39
  export interface GradCheckError {
24
40
  parameter: string;
25
41
  component?: string;
@@ -39,6 +55,11 @@ export declare class GradientChecker {
39
55
  * Check gradients for a function
40
56
  */
41
57
  check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
58
+ /**
59
+ * Compare analytical and numerical gradients
60
+ * Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
61
+ */
62
+ private compareGradients;
42
63
  /**
43
64
  * Compute numerical gradient for scalar parameter using finite differences
44
65
  */