gradient-script 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/README.md +52 -9
  2. package/dist/cli.js +134 -19
  3. package/dist/dsl/AST.d.ts +8 -0
  4. package/dist/dsl/CodeGen.d.ts +8 -3
  5. package/dist/dsl/CodeGen.js +583 -132
  6. package/dist/dsl/Errors.d.ts +6 -1
  7. package/dist/dsl/Errors.js +70 -1
  8. package/dist/dsl/Expander.js +5 -2
  9. package/dist/dsl/ExpressionUtils.d.ts +14 -0
  10. package/dist/dsl/ExpressionUtils.js +56 -0
  11. package/dist/dsl/GradientChecker.d.ts +21 -0
  12. package/dist/dsl/GradientChecker.js +109 -23
  13. package/dist/dsl/Guards.d.ts +3 -1
  14. package/dist/dsl/Guards.js +86 -43
  15. package/dist/dsl/Inliner.d.ts +5 -0
  16. package/dist/dsl/Inliner.js +11 -2
  17. package/dist/dsl/Lexer.js +3 -1
  18. package/dist/dsl/Parser.js +11 -5
  19. package/dist/dsl/Simplify.d.ts +7 -0
  20. package/dist/dsl/Simplify.js +183 -0
  21. package/dist/dsl/egraph/Convert.d.ts +23 -0
  22. package/dist/dsl/egraph/Convert.js +84 -0
  23. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  24. package/dist/dsl/egraph/EGraph.js +292 -0
  25. package/dist/dsl/egraph/ENode.d.ts +63 -0
  26. package/dist/dsl/egraph/ENode.js +94 -0
  27. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  28. package/dist/dsl/egraph/Extractor.js +1068 -0
  29. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  30. package/dist/dsl/egraph/Optimizer.js +88 -0
  31. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  32. package/dist/dsl/egraph/Pattern.js +325 -0
  33. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  34. package/dist/dsl/egraph/Rewriter.js +131 -0
  35. package/dist/dsl/egraph/Rules.d.ts +44 -0
  36. package/dist/dsl/egraph/Rules.js +187 -0
  37. package/dist/dsl/egraph/index.d.ts +15 -0
  38. package/dist/dsl/egraph/index.js +21 -0
  39. package/package.json +1 -1
  40. package/dist/dsl/CSE.d.ts +0 -21
  41. package/dist/dsl/CSE.js +0 -194
  42. package/dist/symbolic/AST.d.ts +0 -113
  43. package/dist/symbolic/AST.js +0 -128
  44. package/dist/symbolic/CodeGen.d.ts +0 -35
  45. package/dist/symbolic/CodeGen.js +0 -280
  46. package/dist/symbolic/Parser.d.ts +0 -64
  47. package/dist/symbolic/Parser.js +0 -329
  48. package/dist/symbolic/Simplify.d.ts +0 -10
  49. package/dist/symbolic/Simplify.js +0 -244
  50. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  51. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -2,16 +2,33 @@
2
2
  * Code generation for GradientScript DSL
3
3
  * Generates TypeScript/JavaScript code with gradient functions
4
4
  */
5
- import { simplifyGradients } from './Simplify.js';
6
- import { eliminateCommonSubexpressionsStructured } from './CSE.js';
5
+ import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
6
+ import { ExpressionTransformer } from './ExpressionTransformer.js';
7
+ import { optimizeWithEGraph } from './egraph/index.js';
7
8
  import { CodeGenError } from './Errors.js';
9
+ import { serializeExpression } from './ExpressionUtils.js';
10
+ import { inlineExpression } from './Inliner.js';
11
/**
 * Upper-case the first UTF-16 code unit of `str` and leave the rest as-is.
 * Used to derive PascalCase C# identifiers (struct names, fields, methods).
 *
 * @param {string} str - Identifier to transform; '' yields ''.
 * @returns {string}
 */
function capitalize(str) {
  const head = str.slice(0, 1);
  const tail = str.slice(1);
  return head.toUpperCase() + tail;
}
14
/**
 * Decide whether a forward-pass expression should be registered in the
 * forward-expression map so gradient code can refer to it by variable name.
 * Bare numeric literals and plain variable references are never registered;
 * every other node kind is.
 *
 * @param {{kind: string}} expr - Expression AST node.
 * @returns {boolean}
 */
function shouldTrackForForwardReuse(expr) {
  const isLeaf = expr.kind === 'number' || expr.kind === 'variable';
  return !isLeaf;
}
8
23
  /**
9
24
  * Code generator for expressions
10
25
  */
11
26
  export class ExpressionCodeGen {
12
27
  format;
13
- constructor(format = 'typescript') {
28
+ csharpFloatType;
29
+ constructor(format = 'typescript', csharpFloatType = 'float') {
14
30
  this.format = format;
31
+ this.csharpFloatType = csharpFloatType;
15
32
  }
16
33
  /**
17
34
  * Generate code for an expression
@@ -47,7 +64,7 @@ export class ExpressionCodeGen {
47
64
  if (this.format === 'python' && (op === '^' || op === '**')) {
48
65
  op = '**'; // Python uses **
49
66
  }
50
- else if ((this.format === 'typescript' || this.format === 'javascript') && (op === '^' || op === '**')) {
67
+ else if ((this.format === 'typescript' || this.format === 'javascript' || this.format === 'csharp') && (op === '^' || op === '**')) {
51
68
  // Optimize: x^2 -> x*x, x^3 -> x*x*x (faster than Math.pow)
52
69
  // Only for simple expressions (variables, component access)
53
70
  const isSimple = expr.left.kind === 'variable' ||
@@ -63,15 +80,23 @@ export class ExpressionCodeGen {
63
80
  return left;
64
81
  }
65
82
  else if (exponent === 2) {
66
- return `${left} * ${left}`;
83
+ // Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
84
+ // Without parens, a / b^2 would become a / b * b instead of a / (b * b)
85
+ return `(${left} * ${left})`;
67
86
  }
68
87
  else if (exponent === 3) {
69
- return `${left} * ${left} * ${left}`;
88
+ return `(${left} * ${left} * ${left})`;
70
89
  }
71
90
  }
72
91
  }
73
- // Fall back to Math.pow for complex expressions or larger exponents
74
- return `Math.pow(${left}, ${right})`;
92
+ // Fall back to Math.pow / MathF.Pow for complex expressions or larger exponents
93
+ if (this.format === 'csharp') {
94
+ const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
95
+ return `${mathClass}.Pow(${left}, ${right})`;
96
+ }
97
+ else {
98
+ return `Math.pow(${left}, ${right})`;
99
+ }
75
100
  }
76
101
  return `${left} ${op} ${right}`;
77
102
  }
@@ -160,6 +185,10 @@ export class ExpressionCodeGen {
160
185
  else if (this.format === 'python') {
161
186
  return `max(${min}, min(${max}, ${x}))`;
162
187
  }
188
+ else if (this.format === 'csharp') {
189
+ const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
190
+ return `${mathClass}.Max(${min}, ${mathClass}.Min(${max}, ${x}))`;
191
+ }
163
192
  }
164
193
  // Map function names for different formats
165
194
  const funcName = this.mapFunctionName(expr.name);
@@ -167,63 +196,91 @@ export class ExpressionCodeGen {
167
196
  }
168
197
  genComponent(expr) {
169
198
  const obj = this.generate(expr.object);
199
+ if (this.format === 'csharp') {
200
+ // C# uses PascalCase for properties
201
+ return `${obj}.${capitalize(expr.component)}`;
202
+ }
170
203
  return `${obj}.${expr.component}`;
171
204
  }
205
+ // Math functions that should be mapped across all formats
206
+ static MATH_FUNCTIONS = [
207
+ 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
208
+ 'exp', 'log', 'sqrt', 'abs', 'pow', 'min', 'max'
209
+ ];
210
+ // Python built-in functions that don't need the math. prefix
211
+ static PYTHON_BUILTINS = ['abs', 'pow', 'min', 'max'];
172
212
  mapFunctionName(name) {
173
- if (this.format === 'typescript' || this.format === 'javascript') {
174
- const mathFuncs = {
175
- 'sin': 'Math.sin',
176
- 'cos': 'Math.cos',
177
- 'tan': 'Math.tan',
178
- 'asin': 'Math.asin',
179
- 'acos': 'Math.acos',
180
- 'atan': 'Math.atan',
181
- 'atan2': 'Math.atan2',
182
- 'exp': 'Math.exp',
183
- 'log': 'Math.log',
184
- 'sqrt': 'Math.sqrt',
185
- 'abs': 'Math.abs',
186
- 'pow': 'Math.pow',
187
- 'min': 'Math.min',
188
- 'max': 'Math.max'
189
- };
190
- return mathFuncs[name] || name;
191
- }
192
- else if (this.format === 'python') {
193
- const mathFuncs = {
194
- 'atan2': 'math.atan2',
195
- 'sin': 'math.sin',
196
- 'cos': 'math.cos',
197
- 'tan': 'math.tan',
198
- 'asin': 'math.asin',
199
- 'acos': 'math.acos',
200
- 'atan': 'math.atan',
201
- 'exp': 'math.exp',
202
- 'log': 'math.log',
203
- 'sqrt': 'math.sqrt',
204
- 'abs': 'abs',
205
- 'pow': 'pow',
206
- 'min': 'min',
207
- 'max': 'max'
208
- };
209
- return mathFuncs[name] || name;
213
+ // Check if this is a known math function
214
+ if (!ExpressionCodeGen.MATH_FUNCTIONS.includes(name)) {
215
+ return name;
216
+ }
217
+ // Define format-specific mappers
218
+ const mappers = {
219
+ typescript: (fn) => `Math.${fn}`,
220
+ javascript: (fn) => `Math.${fn}`,
221
+ python: (fn) => ExpressionCodeGen.PYTHON_BUILTINS.includes(fn) ? fn : `math.${fn}`,
222
+ csharp: (fn) => {
223
+ const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
224
+ const capitalized = fn.charAt(0).toUpperCase() + fn.slice(1);
225
+ return `${mathClass}.${capitalized}`;
226
+ }
227
+ };
228
+ const mapper = mappers[this.format];
229
+ return mapper ? mapper(name) : name;
230
+ }
231
+ }
232
/**
 * Generate the C# struct declarations used as the return type of the
 * generated gradient function.
 *
 * Output order: first `<Func>GradResult` (a `Value` field plus one `D<param>`
 * field per gradient), then, for each structured gradient in map order, a
 * blank line followed by a `<Param>Grad` struct with one field per component.
 *
 * @param {{name: string}} func - Function AST node; only `name` is read.
 * @param {{gradients: Map<string, *>}} gradients - Holder whose `gradients`
 *   map is keyed by parameter name.
 * @param {string} floatType - C# scalar type used for every scalar field.
 * @returns {string[]} Source lines, without trailing newlines.
 */
function generateCSharpGradientStruct(func, gradients, floatType) {
  const lines = [];
  // Nested struct declarations are collected during the same pass and
  // appended after the result struct, preserving gradient order.
  const nestedStructLines = [];
  lines.push(`public struct ${capitalize(func.name)}GradResult`);
  lines.push('{');
  lines.push(` public ${floatType} Value;`);
  for (const [paramName, gradient] of gradients.gradients.entries()) {
    const propName = capitalize(`d${paramName}`);
    if (isStructuredGradient(gradient)) {
      const nestedTypeName = `${capitalize(paramName)}Grad`;
      lines.push(` public ${nestedTypeName} ${propName};`);
      nestedStructLines.push('');
      nestedStructLines.push(`public struct ${nestedTypeName}`);
      nestedStructLines.push('{');
      for (const comp of gradient.components.keys()) {
        nestedStructLines.push(` public ${floatType} ${capitalize(comp)};`);
      }
      nestedStructLines.push('}');
    }
    else {
      lines.push(` public ${floatType} ${propName};`);
    }
  }
  lines.push('}');
  return lines.concat(nestedStructLines);
}
214
267
  /**
215
268
  * Generate complete gradient function code
216
269
  */
217
270
  export function generateGradientFunction(func, gradients, env, options = {}) {
218
271
  const format = options.format || 'typescript';
272
+ const csharpFloatType = options.csharpFloatType || 'float';
219
273
  const includeComments = options.includeComments !== false;
220
274
  const shouldSimplify = options.simplify !== false; // Default to true
221
- // Simplify gradients if requested
222
- const gradientsToUse = shouldSimplify
223
- ? { gradients: simplifyGradients(gradients.gradients) }
224
- : gradients;
225
- const codegen = new ExpressionCodeGen(format);
275
+ // Note: We don't simplify here yet - we'll do it after forward expression substitution
276
+ const gradientsToUse = gradients;
277
+ const codegen = new ExpressionCodeGen(format, csharpFloatType);
226
278
  const lines = [];
279
+ // For C#, we need to generate a struct for the return type first
280
+ if (format === 'csharp') {
281
+ lines.push(...generateCSharpGradientStruct(func, gradientsToUse, csharpFloatType));
282
+ lines.push('');
283
+ }
227
284
  // Function signature
228
285
  const paramNames = func.parameters.map(p => p.name).join(', ');
229
286
  if (format === 'typescript' || format === 'javascript') {
@@ -232,44 +289,246 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
232
289
  else if (format === 'python') {
233
290
  lines.push(`def ${func.name}_grad(${paramNames}):`);
234
291
  }
292
+ else if (format === 'csharp') {
293
+ const params = func.parameters.map(p => {
294
+ if (p.paramType && p.paramType.components) {
295
+ return `${capitalize(p.name)}Struct ${p.name}`;
296
+ }
297
+ return `${csharpFloatType} ${p.name}`;
298
+ }).join(', ');
299
+ const returnType = `${capitalize(func.name)}GradResult`;
300
+ lines.push(`public static ${returnType} ${capitalize(func.name)}_Grad(${params})`);
301
+ lines.push('{');
302
+ }
235
303
  // Forward pass - compute intermediate variables
236
304
  // Track which expressions are already computed for CSE reuse
237
- const forwardPassVars = new Map();
305
+ const forwardExpressionMap = new Map();
306
+ // Build substitution map for inlining (to match gradient expressions which are fully inlined)
307
+ const substitutionMap = new Map();
308
+ for (const stmt of func.body) {
309
+ if (stmt.kind === 'assignment') {
310
+ substitutionMap.set(stmt.variable, stmt.expression);
311
+ }
312
+ }
313
+ // Collect forward variable names and expressions for optimization
314
+ const forwardVars = new Set();
315
+ const forwardVarExprs = new Map();
316
+ for (const stmt of func.body) {
317
+ if (stmt.kind === 'assignment') {
318
+ forwardVars.add(stmt.variable);
319
+ forwardVarExprs.set(stmt.variable, stmt.expression);
320
+ }
321
+ }
322
+ // Optimize forward expressions with e-graph
323
+ let optimizedForwardExprs = forwardVarExprs;
324
+ let forwardCseTemps = new Map();
325
+ if (options.cse !== false && forwardVarExprs.size > 0) {
326
+ const forOptimizer = new Map();
327
+ forOptimizer.set('_forward', forwardVarExprs);
328
+ const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
329
+ // Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
330
+ const rawTemps = forwardResult.intermediates;
331
+ const renameMap = new Map();
332
+ for (const oldName of rawTemps.keys()) {
333
+ const newName = oldName.replace('_tmp', '_fwd');
334
+ renameMap.set(oldName, newName);
335
+ }
336
+ // Apply renaming to temp definitions
337
+ for (const [oldName, expr] of rawTemps) {
338
+ const newName = renameMap.get(oldName);
339
+ forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
340
+ }
341
+ // Apply renaming to optimized expressions
342
+ optimizedForwardExprs = new Map();
343
+ for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
344
+ optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
345
+ }
346
+ }
347
+ // Helper to rename temp references in an expression
348
/**
 * Return a copy of `expr` in which every variable reference whose name is a
 * key of `renameMap` is replaced by the mapped name.
 *
 * Only binary, unary, call and component nodes are rebuilt (recursively).
 * Unrenamed variables and any other node kind are returned as the same
 * object, not copied.
 *
 * @param {object} expr - Expression AST node.
 * @param {Map<string, string>} renameMap - Old name -> new name.
 * @returns {object}
 */
function renameTempRefs(expr, renameMap) {
  const recur = (child) => renameTempRefs(child, renameMap);
  switch (expr.kind) {
    case 'variable': {
      const renamed = renameMap.get(expr.name);
      if (!renamed) {
        return expr;
      }
      return { kind: 'variable', name: renamed };
    }
    case 'binary':
      return {
        kind: 'binary',
        operator: expr.operator,
        left: recur(expr.left),
        right: recur(expr.right)
      };
    case 'unary':
      return {
        kind: 'unary',
        operator: expr.operator,
        operand: recur(expr.operand)
      };
    case 'call':
      return {
        kind: 'call',
        name: expr.name,
        args: expr.args.map((arg) => recur(arg))
      };
    case 'component':
      return {
        kind: 'component',
        object: recur(expr.object),
        component: expr.component
      };
    default:
      return expr;
  }
}
384
+ // Helper to find which forward vars an expression depends on
385
/**
 * Collect the names of forward-pass variables (members of the enclosing
 * `forwardVars` set) referenced anywhere inside `expr`.
 * Walks binary, unary, call and component nodes; other node kinds are leaves.
 *
 * @param {object} expr - Expression AST node.
 * @returns {Set<string>} Referenced forward variable names, in visit order.
 */
function findForwardVarDeps(expr) {
  const found = new Set();
  const walk = (node) => {
    switch (node.kind) {
      case 'variable':
        if (forwardVars.has(node.name)) {
          found.add(node.name);
        }
        break;
      case 'binary':
        walk(node.left);
        walk(node.right);
        break;
      case 'unary':
        walk(node.operand);
        break;
      case 'call':
        node.args.forEach((arg) => walk(arg));
        break;
      case 'component':
        walk(node.object);
        break;
      default:
        break;
    }
  };
  walk(expr);
  return found;
}
408
+ // Track which CSE temps need to be emitted after which forward var
409
+ const fwdTempAfterVar = new Map();
410
+ const fwdTempsBeforeAny = [];
411
+ for (const [tempName, tempExpr] of forwardCseTemps) {
412
+ const deps = findForwardVarDeps(tempExpr);
413
+ if (deps.size === 0) {
414
+ fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
415
+ }
416
+ else {
417
+ let lastDep = '';
418
+ for (const stmt of func.body) {
419
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
420
+ lastDep = stmt.variable;
421
+ }
422
+ }
423
+ if (lastDep) {
424
+ if (!fwdTempAfterVar.has(lastDep)) {
425
+ fwdTempAfterVar.set(lastDep, []);
426
+ }
427
+ fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
428
+ }
429
+ }
430
+ }
431
+ // Emit forward CSE temps that don't depend on forward vars
432
+ for (const { name: tempName, expr } of fwdTempsBeforeAny) {
433
+ const code = codegen.generate(expr);
434
+ if (format === 'typescript' || format === 'javascript') {
435
+ lines.push(` const ${tempName} = ${code};`);
436
+ }
437
+ else if (format === 'python') {
438
+ lines.push(` ${tempName} = ${code}`);
439
+ }
440
+ else if (format === 'csharp') {
441
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
442
+ }
443
+ }
444
+ // Generate forward variable assignments with interleaved temps
238
445
  for (const stmt of func.body) {
239
446
  if (stmt.kind === 'assignment') {
240
447
  const varName = stmt.variable;
241
- const expr = codegen.generate(stmt.expression);
242
- // Track this for CSE reuse (store expression -> variable name mapping)
243
- forwardPassVars.set(expr, varName);
448
+ const expr = optimizedForwardExprs.get(varName) || stmt.expression;
449
+ const generatedExpr = codegen.generate(expr);
450
+ if (shouldTrackForForwardReuse(stmt.expression)) {
451
+ // Register the original expression
452
+ const exprKey = serializeExpression(stmt.expression);
453
+ if (!forwardExpressionMap.has(exprKey)) {
454
+ forwardExpressionMap.set(exprKey, varName);
455
+ }
456
+ // Also register the fully inlined form (this is what gradient expressions will have)
457
+ const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
458
+ const inlinedKey = serializeExpression(inlinedExpr);
459
+ if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
460
+ forwardExpressionMap.set(inlinedKey, varName);
461
+ }
462
+ }
244
463
  if (format === 'typescript' || format === 'javascript') {
245
- lines.push(` const ${varName} = ${expr};`);
464
+ lines.push(` const ${varName} = ${generatedExpr};`);
246
465
  }
247
- else {
248
- lines.push(` ${varName} = ${expr}`);
466
+ else if (format === 'python') {
467
+ lines.push(` ${varName} = ${generatedExpr}`);
468
+ }
469
+ else if (format === 'csharp') {
470
+ lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
471
+ }
472
+ // Emit any forward CSE temps that depend on this var
473
+ const tempsForVar = fwdTempAfterVar.get(varName) || [];
474
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
475
+ const tempCode = codegen.generate(tempExpr);
476
+ if (format === 'typescript' || format === 'javascript') {
477
+ lines.push(` const ${tempName} = ${tempCode};`);
478
+ }
479
+ else if (format === 'python') {
480
+ lines.push(` ${tempName} = ${tempCode}`);
481
+ }
482
+ else if (format === 'csharp') {
483
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
484
+ }
249
485
  }
250
486
  }
251
487
  }
252
488
  // Compute output value - reuse forward pass variables if possible
253
- let valueExpr = func.returnExpr;
489
+ const valueExpr = func.returnExpr;
490
+ const valueKey = serializeExpression(valueExpr);
491
+ const existingVar = forwardExpressionMap.get(valueKey);
254
492
  const valueCode = codegen.generate(valueExpr);
255
- const existingVar = forwardPassVars.get(valueCode);
256
493
  if (existingVar) {
257
494
  // Reuse existing variable
258
495
  if (format === 'typescript' || format === 'javascript') {
259
496
  lines.push(` const value = ${existingVar};`);
260
497
  }
261
- else {
498
+ else if (format === 'python') {
262
499
  lines.push(` value = ${existingVar}`);
263
500
  }
501
+ else if (format === 'csharp') {
502
+ lines.push(` ${csharpFloatType} value = ${existingVar};`);
503
+ }
264
504
  }
265
505
  else {
266
506
  // Compute new value
267
507
  if (format === 'typescript' || format === 'javascript') {
268
508
  lines.push(` const value = ${valueCode};`);
269
509
  }
270
- else {
510
+ else if (format === 'python') {
271
511
  lines.push(` value = ${valueCode}`);
272
512
  }
513
+ else if (format === 'csharp') {
514
+ lines.push(` ${csharpFloatType} value = ${valueCode};`);
515
+ }
516
+ if (shouldTrackForForwardReuse(valueExpr) && !forwardExpressionMap.has(valueKey)) {
517
+ forwardExpressionMap.set(valueKey, 'value');
518
+ }
519
+ }
520
+ // Apply forward expression substitution multiple times until no more changes
521
+ // This handles nested expressions like sqrt(dx*dx + dy*dy) where dx = pix - pjx
522
+ reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
523
+ reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
524
+ reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
525
+ // Simplify gradients after forward expression substitution
526
+ if (shouldSimplify) {
527
+ const simplified = simplifyGradients(gradientsToUse.gradients);
528
+ gradientsToUse.gradients.clear();
529
+ for (const [key, value] of simplified.entries()) {
530
+ gradientsToUse.gradients.set(key, value);
531
+ }
273
532
  }
274
533
  lines.push('');
275
534
  // Generate gradients
@@ -277,74 +536,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
277
536
  if (includeComments) {
278
537
  lines.push(` ${comment} Gradients`);
279
538
  }
280
- // Apply CSE if requested
539
+ // Apply e-graph optimization (CSE + algebraic simplification)
281
540
  const shouldApplyCSE = options.cse !== false; // Default to true
282
- const cseIntermediates = new Map();
541
+ let cseIntermediates = new Map();
283
542
  if (shouldApplyCSE) {
284
- // Collect all gradient expressions for CSE analysis
543
+ // Collect all gradient components into a single map for global optimization
544
+ // Include both structured and scalar gradients (scalar uses 'value' as component)
545
+ const allGradientComponents = new Map();
285
546
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
286
547
  if (isStructuredGradient(gradient)) {
287
- const cseResult = eliminateCommonSubexpressionsStructured(gradient.components);
288
- // Merge intermediates
289
- for (const [name, expr] of cseResult.intermediates.entries()) {
290
- cseIntermediates.set(name, expr);
291
- }
292
- // Update gradient components with CSE-simplified versions
293
- gradient.components = cseResult.components;
294
- }
295
- }
296
- // Generate intermediate variables from CSE
297
- if (cseIntermediates.size > 0) {
298
- // Check if we should emit guards (opt-in)
299
- const shouldEmitGuards = options.emitGuards === true;
300
- const epsilon = options.epsilon || 1e-10;
301
- // Identify potential denominators (sum of squares patterns)
302
- const denominatorVars = new Set();
303
- for (const [varName, expr] of cseIntermediates.entries()) {
304
- const code = codegen.generate(expr);
305
- // Check if this looks like a denominator (contains + and squared terms)
306
- if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
307
- denominatorVars.add(varName);
308
- }
548
+ allGradientComponents.set(paramName, gradient.components);
309
549
  }
310
- for (const [varName, expr] of cseIntermediates.entries()) {
311
- const code = codegen.generate(expr);
312
- if (format === 'typescript' || format === 'javascript') {
313
- lines.push(` const ${varName} = ${code};`);
314
- }
315
- else {
316
- lines.push(` ${varName} = ${code}`);
317
- }
550
+ else {
551
+ // Scalar gradient - wrap as a single 'value' component
552
+ allGradientComponents.set(paramName, new Map([['value', gradient]]));
553
+ }
554
+ }
555
+ // Run e-graph optimization globally across ALL gradient expressions
556
+ const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
557
+ cseIntermediates = globalCSE.intermediates;
558
+ // Update gradient components with optimized versions
559
+ for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
560
+ const gradient = gradientsToUse.gradients.get(paramName);
561
+ if (gradient && isStructuredGradient(gradient)) {
562
+ gradient.components = simplifiedComponents;
318
563
  }
319
- // Emit epsilon guard if needed
320
- if (shouldEmitGuards && denominatorVars.size > 0) {
321
- lines.push('');
322
- if (includeComments) {
323
- lines.push(` ${comment} Guard against division by zero`);
564
+ else {
565
+ // Scalar gradient - unwrap from 'value' component
566
+ const optimizedExpr = simplifiedComponents.get('value');
567
+ if (optimizedExpr) {
568
+ gradientsToUse.gradients.set(paramName, optimizedExpr);
324
569
  }
325
- for (const denom of denominatorVars) {
326
- if (format === 'typescript' || format === 'javascript') {
327
- lines.push(` if (Math.abs(${denom}) < ${epsilon}) {`);
328
- lines.push(` ${comment} Return zero gradients for degenerate case`);
329
- // Emit zero gradient structure
330
- const zeroGrads = [];
331
- for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
332
- if (isStructuredGradient(gradient)) {
333
- const components = Array.from(gradient.components.keys());
334
- const zeroStruct = components.map(c => `${c}: 0`).join(', ');
335
- zeroGrads.push(`d${paramName}: { ${zeroStruct} }`);
336
- }
337
- else {
338
- zeroGrads.push(`d${paramName}: 0`);
339
- }
340
- }
341
- lines.push(` return { value, ${zeroGrads.join(', ')} };`);
342
- lines.push(` }`);
343
- }
570
+ }
571
+ }
572
+ // Post-CSE simplification: apply rules that were skipped to avoid CSE interference
573
+ // Specifically: a + a → 2 * a (now safe because temps have been extracted)
574
+ for (const [varName, expr] of cseIntermediates) {
575
+ cseIntermediates.set(varName, simplifyPostCSE(expr));
576
+ }
577
+ for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
578
+ if (isStructuredGradient(gradient)) {
579
+ for (const [comp, expr] of gradient.components.entries()) {
580
+ gradient.components.set(comp, simplifyPostCSE(expr));
344
581
  }
345
582
  }
346
- lines.push('');
583
+ else {
584
+ // Scalar gradient - apply post-CSE simplification
585
+ gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
586
+ }
587
+ }
588
+ }
589
+ // Generate CSE intermediate variables
590
+ if (cseIntermediates.size > 0) {
591
+ for (const [varName, expr] of cseIntermediates.entries()) {
592
+ const code = codegen.generate(expr);
593
+ if (format === 'typescript' || format === 'javascript') {
594
+ lines.push(` const ${varName} = ${code};`);
595
+ }
596
+ else if (format === 'python') {
597
+ lines.push(` ${varName} = ${code}`);
598
+ }
599
+ else if (format === 'csharp') {
600
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
601
+ }
347
602
  }
603
+ lines.push('');
348
604
  }
349
605
  for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
350
606
  // Use shorter names: du, dv instead of grad_u, grad_v
@@ -366,7 +622,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
366
622
  }
367
623
  lines.push(` };`);
368
624
  }
369
- else {
625
+ else if (format === 'python') {
370
626
  lines.push(` ${gradName} = {`);
371
627
  for (const comp of components) {
372
628
  const [key, value] = comp.split(': ');
@@ -374,6 +630,15 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
374
630
  }
375
631
  lines.push(` }`);
376
632
  }
633
+ else if (format === 'csharp') {
634
+ lines.push(` var ${gradName} = new ${capitalize(paramName)}Grad`);
635
+ lines.push(` {`);
636
+ for (const comp of components) {
637
+ const [key, value] = comp.split(': ');
638
+ lines.push(` ${capitalize(key)} = ${value},`);
639
+ }
640
+ lines.push(` };`);
641
+ }
377
642
  }
378
643
  else {
379
644
  // Scalar gradient
@@ -381,9 +646,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
381
646
  if (format === 'typescript' || format === 'javascript') {
382
647
  lines.push(` const ${gradName} = ${code};`);
383
648
  }
384
- else {
649
+ else if (format === 'python') {
385
650
  lines.push(` ${gradName} = ${code}`);
386
651
  }
652
+ else if (format === 'csharp') {
653
+ lines.push(` ${csharpFloatType} ${gradName} = ${code};`);
654
+ }
387
655
  }
388
656
  }
389
657
  lines.push('');
@@ -399,7 +667,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
399
667
  lines.push(` };`);
400
668
  lines.push('}');
401
669
  }
402
- else {
670
+ else if (format === 'python') {
403
671
  lines.push(` return {`);
404
672
  lines.push(` "value": value,`);
405
673
  for (const gradName of gradNames) {
@@ -407,33 +675,160 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
407
675
  }
408
676
  lines.push(` }`);
409
677
  }
678
+ else if (format === 'csharp') {
679
+ lines.push(` return new ${capitalize(func.name)}GradResult`);
680
+ lines.push(` {`);
681
+ lines.push(` Value = value,`);
682
+ for (const gradName of gradNames) {
683
+ lines.push(` ${capitalize(gradName)} = ${gradName},`);
684
+ }
685
+ lines.push(` };`);
686
+ lines.push('}');
687
+ }
410
688
  return lines.join('\n');
411
689
  }
412
690
  /**
413
- * Generate the original forward function
691
+ * Generate the original forward function (with optional e-graph optimization)
414
692
  */
415
693
  export function generateForwardFunction(func, options = {}) {
416
694
  const format = options.format || 'typescript';
417
- const codegen = new ExpressionCodeGen(format);
695
+ const csharpFloatType = options.csharpFloatType || 'float';
696
+ const shouldOptimize = options.cse !== false; // Optimize by default
697
+ const codegen = new ExpressionCodeGen(format, csharpFloatType);
418
698
  const lines = [];
419
699
  // Function signature
420
700
  const paramNames = func.parameters.map(p => p.name).join(', ');
421
701
  if (format === 'typescript' || format === 'javascript') {
422
702
  lines.push(`function ${func.name}(${paramNames}) {`);
423
703
  }
424
- else {
704
+ else if (format === 'python') {
425
705
  lines.push(`def ${func.name}(${paramNames}):`);
426
706
  }
427
- // Body
707
+ else if (format === 'csharp') {
708
+ const floatType = csharpFloatType;
709
+ const params = func.parameters.map(p => {
710
+ if (p.paramType && p.paramType.components) {
711
+ return `${capitalize(p.name)}Struct ${p.name}`;
712
+ }
713
+ return `${floatType} ${p.name}`;
714
+ }).join(', ');
715
+ lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
716
+ lines.push('{');
717
+ }
718
+ // Collect all forward variable names (for dependency tracking)
719
+ const forwardVars = new Set();
720
+ for (const stmt of func.body) {
721
+ if (stmt.kind === 'assignment') {
722
+ forwardVars.add(stmt.variable);
723
+ }
724
+ }
725
+ // Collect expressions for optimization
726
+ const varExpressions = new Map();
727
+ for (const stmt of func.body) {
728
+ if (stmt.kind === 'assignment') {
729
+ varExpressions.set(stmt.variable, stmt.expression);
730
+ }
731
+ }
732
+ // Optimize with e-graph if enabled
733
+ let optimizedExprs = varExpressions;
734
+ let cseTemps = new Map();
735
+ if (shouldOptimize && varExpressions.size > 0) {
736
+ const forOptimizer = new Map();
737
+ forOptimizer.set('_forward', varExpressions);
738
+ const result = optimizeWithEGraph(forOptimizer, { verbose: false });
739
+ cseTemps = result.intermediates;
740
+ optimizedExprs = result.gradients.get('_forward') || varExpressions;
741
+ }
742
+ // Helper to find which forward vars an expression depends on
743
/**
 * Collect the names of forward-pass variables (members of the enclosing
 * `forwardVars` set) referenced anywhere inside `expr`.
 * Walks binary, unary, call and component nodes; other node kinds are leaves.
 *
 * @param {object} expr - Expression AST node.
 * @returns {Set<string>} Referenced forward variable names, in visit order.
 */
function findForwardVarDeps(expr) {
  const found = new Set();
  const walk = (node) => {
    switch (node.kind) {
      case 'variable':
        if (forwardVars.has(node.name)) {
          found.add(node.name);
        }
        break;
      case 'binary':
        walk(node.left);
        walk(node.right);
        break;
      case 'unary':
        walk(node.operand);
        break;
      case 'call':
        node.args.forEach((arg) => walk(arg));
        break;
      case 'component':
        walk(node.object);
        break;
      default:
        break;
    }
  };
  walk(expr);
  return found;
}
766
+ // Track which temps need to be emitted after which forward var
767
+ const tempAfterVar = new Map();
768
+ const tempsEmittedBeforeAny = [];
769
+ for (const [tempName, tempExpr] of cseTemps) {
770
+ const deps = findForwardVarDeps(tempExpr);
771
+ if (deps.size === 0) {
772
+ // Temp only depends on params - emit before any forward vars
773
+ tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
774
+ }
775
+ else {
776
+ // Find the last forward var this temp depends on
777
+ let lastDep = '';
778
+ for (const stmt of func.body) {
779
+ if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
780
+ lastDep = stmt.variable;
781
+ }
782
+ }
783
+ if (lastDep) {
784
+ if (!tempAfterVar.has(lastDep)) {
785
+ tempAfterVar.set(lastDep, []);
786
+ }
787
+ tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
788
+ }
789
+ }
790
+ }
791
+ // Emit temps that don't depend on forward vars
792
+ for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
793
+ const code = codegen.generate(expr);
794
+ if (format === 'typescript' || format === 'javascript') {
795
+ lines.push(` const ${tempName} = ${code};`);
796
+ }
797
+ else if (format === 'python') {
798
+ lines.push(` ${tempName} = ${code}`);
799
+ }
800
+ else if (format === 'csharp') {
801
+ lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
802
+ }
803
+ }
804
+ // Generate variable assignments with interleaved temps
428
805
  for (const stmt of func.body) {
429
806
  if (stmt.kind === 'assignment') {
430
807
  const varName = stmt.variable;
431
- const expr = codegen.generate(stmt.expression);
808
+ const expr = optimizedExprs.get(varName) || stmt.expression;
809
+ const code = codegen.generate(expr);
432
810
  if (format === 'typescript' || format === 'javascript') {
433
- lines.push(` const ${varName} = ${expr};`);
811
+ lines.push(` const ${varName} = ${code};`);
434
812
  }
435
- else {
436
- lines.push(` ${varName} = ${expr}`);
813
+ else if (format === 'python') {
814
+ lines.push(` ${varName} = ${code}`);
815
+ }
816
+ else if (format === 'csharp') {
817
+ lines.push(` ${csharpFloatType} ${varName} = ${code};`);
818
+ }
819
+ // Emit any temps that depend on this var
820
+ const tempsForVar = tempAfterVar.get(varName) || [];
821
+ for (const { name: tempName, expr: tempExpr } of tempsForVar) {
822
+ const tempCode = codegen.generate(tempExpr);
823
+ if (format === 'typescript' || format === 'javascript') {
824
+ lines.push(` const ${tempName} = ${tempCode};`);
825
+ }
826
+ else if (format === 'python') {
827
+ lines.push(` ${tempName} = ${tempCode}`);
828
+ }
829
+ else if (format === 'csharp') {
830
+ lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
831
+ }
437
832
  }
438
833
  }
439
834
  }
@@ -443,9 +838,31 @@ export function generateForwardFunction(func, options = {}) {
443
838
  lines.push(` return ${returnExpr};`);
444
839
  lines.push('}');
445
840
  }
446
- else {
841
+ else if (format === 'python') {
447
842
  lines.push(` return ${returnExpr}`);
448
843
  }
844
+ else if (format === 'csharp') {
845
+ lines.push(` return ${returnExpr};`);
846
+ lines.push('}');
847
+ }
848
+ // For C#, prepend struct definitions
849
+ if (format === 'csharp') {
850
+ const structLines = [];
851
+ for (const param of func.parameters) {
852
+ if (param.paramType && param.paramType.components) {
853
+ structLines.push(`public struct ${capitalize(param.name)}Struct`);
854
+ structLines.push('{');
855
+ for (const comp of param.paramType.components) {
856
+ structLines.push(` public ${csharpFloatType} ${capitalize(comp)};`);
857
+ }
858
+ structLines.push('}');
859
+ structLines.push('');
860
+ }
861
+ }
862
+ if (structLines.length > 0) {
863
+ return structLines.join('\n') + lines.join('\n');
864
+ }
865
+ }
449
866
  return lines.join('\n');
450
867
  }
451
868
  /**
@@ -466,6 +883,40 @@ export function generateComplete(func, gradients, env, options = {}) {
466
883
  lines.push(generateGradientFunction(func, gradients, env, options));
467
884
  return lines.join('\n');
468
885
  }
886
/**
 * Expression transformer that replaces sub-expressions already computed by the
 * forward pass with a reference to the forward variable holding that value.
 *
 * Lookups are keyed by `serializeExpression(expr)`. On a hit, the whole subtree
 * is replaced by a `{ kind: 'variable', name }` node and is not descended into.
 * On a miss, the node is handed to the base `ExpressionTransformer`, which
 * presumably recurses back into `transform` for child nodes (virtual dispatch);
 * that recursion is what lets nested matches be found.
 */
class ForwardExpressionSubstituter extends ExpressionTransformer {
    /**
     * Lookup from serialized expression to forward variable name.
     * Only `.get(key)` is called on it here.
     */
    forwardExpressions;
    /**
     * @param {Map<string, string>} forwardExpressions - serialized expression -> forward variable name
     */
    constructor(forwardExpressions) {
        super();
        this.forwardExpressions = forwardExpressions;
    }
    /**
     * Rewrite `expr`, substituting any subtree whose serialization has a
     * forward variable.
     *
     * @param {object} expr - expression node to rewrite
     * @returns {object} a variable node when `expr` itself matches; otherwise
     *   whatever the base transformer produces for `expr`
     */
    transform(expr) {
        const forwardVar = this.forwardExpressions.get(serializeExpression(expr));
        // A miss (or a falsy name) falls through to structural recursion.
        if (!forwardVar) {
            return super.transform(expr);
        }
        return { kind: 'variable', name: forwardVar };
    }
}
904
/**
 * Rewrite every gradient expression in place so that sub-expressions already
 * computed by the forward pass become references to the forward variables
 * holding them.
 *
 * @param {Map<string, object>} gradients - parameter name -> gradient. Each value
 *   is either a structured gradient (per `isStructuredGradient`, carrying a
 *   `components` Map of component -> expression) or a single expression.
 *   Mutated in place; only existing keys are overwritten.
 * @param {Map<string, string>} forwardExpressions - serialized expression -> forward
 *   variable name
 * @returns {void}
 */
function reuseForwardExpressionsInGradients(gradients, forwardExpressions) {
    // Nothing to reuse: skip constructing the substituter entirely.
    if (forwardExpressions.size === 0) {
        return;
    }
    const substituter = new ForwardExpressionSubstituter(forwardExpressions);
    const rewrite = (expr) => substituter.transform(expr);
    // Overwriting an existing Map key during forEach neither adds nor revisits
    // entries, so in-place replacement is safe here.
    gradients.forEach((gradient, paramName) => {
        if (!isStructuredGradient(gradient)) {
            gradients.set(paramName, rewrite(gradient));
            return;
        }
        const { components } = gradient;
        components.forEach((expr, component) => {
            components.set(component, rewrite(expr));
        });
    });
}
469
920
  /**
470
921
  * Type guard for StructuredGradient
471
922
  */