gradient-script 0.2.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +219 -6
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +336 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/dist/dsl/CodeGen.js
CHANGED
|
@@ -2,11 +2,12 @@
|
|
|
2
2
|
* Code generation for GradientScript DSL
|
|
3
3
|
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
4
|
*/
|
|
5
|
-
import { simplifyGradients } from './Simplify.js';
|
|
5
|
+
import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
|
|
6
6
|
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
7
|
-
import {
|
|
7
|
+
import { optimizeWithEGraph } from './egraph/index.js';
|
|
8
8
|
import { CodeGenError } from './Errors.js';
|
|
9
9
|
import { serializeExpression } from './ExpressionUtils.js';
|
|
10
|
+
import { inlineExpression } from './Inliner.js';
|
|
10
11
|
function capitalize(str) {
|
|
11
12
|
return str.charAt(0).toUpperCase() + str.slice(1);
|
|
12
13
|
}
|
|
@@ -79,10 +80,12 @@ export class ExpressionCodeGen {
|
|
|
79
80
|
return left;
|
|
80
81
|
}
|
|
81
82
|
else if (exponent === 2) {
|
|
82
|
-
|
|
83
|
+
// Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
|
|
84
|
+
// Without parens, a / b^2 would become a / b * b instead of a / (b * b)
|
|
85
|
+
return `(${left} * ${left})`;
|
|
83
86
|
}
|
|
84
87
|
else if (exponent === 3) {
|
|
85
|
-
return
|
|
88
|
+
return `(${left} * ${left} * ${left})`;
|
|
86
89
|
}
|
|
87
90
|
}
|
|
88
91
|
}
|
|
@@ -166,6 +169,10 @@ export class ExpressionCodeGen {
|
|
|
166
169
|
}
|
|
167
170
|
genUnary(expr) {
|
|
168
171
|
const operand = this.generate(expr.operand);
|
|
172
|
+
// Parenthesize binary operands to avoid precedence bugs: -(a + b) not -a + b
|
|
173
|
+
if (expr.operand.kind === 'binary') {
|
|
174
|
+
return `${expr.operator}(${operand})`;
|
|
175
|
+
}
|
|
169
176
|
return `${expr.operator}${operand}`;
|
|
170
177
|
}
|
|
171
178
|
genCall(expr) {
|
|
@@ -300,15 +307,162 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
300
307
|
// Forward pass - compute intermediate variables
|
|
301
308
|
// Track which expressions are already computed for CSE reuse
|
|
302
309
|
const forwardExpressionMap = new Map();
|
|
310
|
+
// Build substitution map for inlining (to match gradient expressions which are fully inlined)
|
|
311
|
+
const substitutionMap = new Map();
|
|
312
|
+
for (const stmt of func.body) {
|
|
313
|
+
if (stmt.kind === 'assignment') {
|
|
314
|
+
substitutionMap.set(stmt.variable, stmt.expression);
|
|
315
|
+
}
|
|
316
|
+
}
|
|
317
|
+
// Collect forward variable names and expressions for optimization
|
|
318
|
+
const forwardVars = new Set();
|
|
319
|
+
const forwardVarExprs = new Map();
|
|
320
|
+
for (const stmt of func.body) {
|
|
321
|
+
if (stmt.kind === 'assignment') {
|
|
322
|
+
forwardVars.add(stmt.variable);
|
|
323
|
+
forwardVarExprs.set(stmt.variable, stmt.expression);
|
|
324
|
+
}
|
|
325
|
+
}
|
|
326
|
+
// Optimize forward expressions with e-graph
|
|
327
|
+
let optimizedForwardExprs = forwardVarExprs;
|
|
328
|
+
let forwardCseTemps = new Map();
|
|
329
|
+
if (options.cse !== false && forwardVarExprs.size > 0) {
|
|
330
|
+
const forOptimizer = new Map();
|
|
331
|
+
forOptimizer.set('_forward', forwardVarExprs);
|
|
332
|
+
const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
333
|
+
// Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
|
|
334
|
+
const rawTemps = forwardResult.intermediates;
|
|
335
|
+
const renameMap = new Map();
|
|
336
|
+
for (const oldName of rawTemps.keys()) {
|
|
337
|
+
const newName = oldName.replace('_tmp', '_fwd');
|
|
338
|
+
renameMap.set(oldName, newName);
|
|
339
|
+
}
|
|
340
|
+
// Apply renaming to temp definitions
|
|
341
|
+
for (const [oldName, expr] of rawTemps) {
|
|
342
|
+
const newName = renameMap.get(oldName);
|
|
343
|
+
forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
|
|
344
|
+
}
|
|
345
|
+
// Apply renaming to optimized expressions
|
|
346
|
+
optimizedForwardExprs = new Map();
|
|
347
|
+
for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
|
|
348
|
+
optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
|
|
349
|
+
}
|
|
350
|
+
}
|
|
351
|
+
// Helper to rename temp references in an expression
|
|
352
|
+
function renameTempRefs(expr, renameMap) {
|
|
353
|
+
if (expr.kind === 'variable') {
|
|
354
|
+
const newName = renameMap.get(expr.name);
|
|
355
|
+
return newName ? { kind: 'variable', name: newName } : expr;
|
|
356
|
+
}
|
|
357
|
+
else if (expr.kind === 'binary') {
|
|
358
|
+
return {
|
|
359
|
+
kind: 'binary',
|
|
360
|
+
operator: expr.operator,
|
|
361
|
+
left: renameTempRefs(expr.left, renameMap),
|
|
362
|
+
right: renameTempRefs(expr.right, renameMap)
|
|
363
|
+
};
|
|
364
|
+
}
|
|
365
|
+
else if (expr.kind === 'unary') {
|
|
366
|
+
return {
|
|
367
|
+
kind: 'unary',
|
|
368
|
+
operator: expr.operator,
|
|
369
|
+
operand: renameTempRefs(expr.operand, renameMap)
|
|
370
|
+
};
|
|
371
|
+
}
|
|
372
|
+
else if (expr.kind === 'call') {
|
|
373
|
+
return {
|
|
374
|
+
kind: 'call',
|
|
375
|
+
name: expr.name,
|
|
376
|
+
args: expr.args.map(a => renameTempRefs(a, renameMap))
|
|
377
|
+
};
|
|
378
|
+
}
|
|
379
|
+
else if (expr.kind === 'component') {
|
|
380
|
+
return {
|
|
381
|
+
kind: 'component',
|
|
382
|
+
object: renameTempRefs(expr.object, renameMap),
|
|
383
|
+
component: expr.component
|
|
384
|
+
};
|
|
385
|
+
}
|
|
386
|
+
return expr;
|
|
387
|
+
}
|
|
388
|
+
// Helper to find which forward vars an expression depends on
|
|
389
|
+
function findForwardVarDeps(expr) {
|
|
390
|
+
const deps = new Set();
|
|
391
|
+
function visit(e) {
|
|
392
|
+
if (e.kind === 'variable' && forwardVars.has(e.name)) {
|
|
393
|
+
deps.add(e.name);
|
|
394
|
+
}
|
|
395
|
+
else if (e.kind === 'binary') {
|
|
396
|
+
visit(e.left);
|
|
397
|
+
visit(e.right);
|
|
398
|
+
}
|
|
399
|
+
else if (e.kind === 'unary') {
|
|
400
|
+
visit(e.operand);
|
|
401
|
+
}
|
|
402
|
+
else if (e.kind === 'call') {
|
|
403
|
+
e.args.forEach(visit);
|
|
404
|
+
}
|
|
405
|
+
else if (e.kind === 'component') {
|
|
406
|
+
visit(e.object);
|
|
407
|
+
}
|
|
408
|
+
}
|
|
409
|
+
visit(expr);
|
|
410
|
+
return deps;
|
|
411
|
+
}
|
|
412
|
+
// Track which CSE temps need to be emitted after which forward var
|
|
413
|
+
const fwdTempAfterVar = new Map();
|
|
414
|
+
const fwdTempsBeforeAny = [];
|
|
415
|
+
for (const [tempName, tempExpr] of forwardCseTemps) {
|
|
416
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
417
|
+
if (deps.size === 0) {
|
|
418
|
+
fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
419
|
+
}
|
|
420
|
+
else {
|
|
421
|
+
let lastDep = '';
|
|
422
|
+
for (const stmt of func.body) {
|
|
423
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
424
|
+
lastDep = stmt.variable;
|
|
425
|
+
}
|
|
426
|
+
}
|
|
427
|
+
if (lastDep) {
|
|
428
|
+
if (!fwdTempAfterVar.has(lastDep)) {
|
|
429
|
+
fwdTempAfterVar.set(lastDep, []);
|
|
430
|
+
}
|
|
431
|
+
fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
432
|
+
}
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
// Emit forward CSE temps that don't depend on forward vars
|
|
436
|
+
for (const { name: tempName, expr } of fwdTempsBeforeAny) {
|
|
437
|
+
const code = codegen.generate(expr);
|
|
438
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
439
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
440
|
+
}
|
|
441
|
+
else if (format === 'python') {
|
|
442
|
+
lines.push(` ${tempName} = ${code}`);
|
|
443
|
+
}
|
|
444
|
+
else if (format === 'csharp') {
|
|
445
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
446
|
+
}
|
|
447
|
+
}
|
|
448
|
+
// Generate forward variable assignments with interleaved temps
|
|
303
449
|
for (const stmt of func.body) {
|
|
304
450
|
if (stmt.kind === 'assignment') {
|
|
305
451
|
const varName = stmt.variable;
|
|
306
|
-
const
|
|
452
|
+
const expr = optimizedForwardExprs.get(varName) || stmt.expression;
|
|
453
|
+
const generatedExpr = codegen.generate(expr);
|
|
307
454
|
if (shouldTrackForForwardReuse(stmt.expression)) {
|
|
455
|
+
// Register the original expression
|
|
308
456
|
const exprKey = serializeExpression(stmt.expression);
|
|
309
457
|
if (!forwardExpressionMap.has(exprKey)) {
|
|
310
458
|
forwardExpressionMap.set(exprKey, varName);
|
|
311
459
|
}
|
|
460
|
+
// Also register the fully inlined form (this is what gradient expressions will have)
|
|
461
|
+
const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
|
|
462
|
+
const inlinedKey = serializeExpression(inlinedExpr);
|
|
463
|
+
if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
|
|
464
|
+
forwardExpressionMap.set(inlinedKey, varName);
|
|
465
|
+
}
|
|
312
466
|
}
|
|
313
467
|
if (format === 'typescript' || format === 'javascript') {
|
|
314
468
|
lines.push(` const ${varName} = ${generatedExpr};`);
|
|
@@ -319,6 +473,20 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
319
473
|
else if (format === 'csharp') {
|
|
320
474
|
lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
|
|
321
475
|
}
|
|
476
|
+
// Emit any forward CSE temps that depend on this var
|
|
477
|
+
const tempsForVar = fwdTempAfterVar.get(varName) || [];
|
|
478
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
479
|
+
const tempCode = codegen.generate(tempExpr);
|
|
480
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
481
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
482
|
+
}
|
|
483
|
+
else if (format === 'python') {
|
|
484
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
485
|
+
}
|
|
486
|
+
else if (format === 'csharp') {
|
|
487
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
488
|
+
}
|
|
489
|
+
}
|
|
322
490
|
}
|
|
323
491
|
}
|
|
324
492
|
// Compute output value - reuse forward pass variables if possible
|
|
@@ -372,77 +540,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
372
540
|
if (includeComments) {
|
|
373
541
|
lines.push(` ${comment} Gradients`);
|
|
374
542
|
}
|
|
375
|
-
// Apply CSE
|
|
543
|
+
// Apply e-graph optimization (CSE + algebraic simplification)
|
|
376
544
|
const shouldApplyCSE = options.cse !== false; // Default to true
|
|
377
|
-
|
|
545
|
+
let cseIntermediates = new Map();
|
|
378
546
|
if (shouldApplyCSE) {
|
|
379
|
-
// Collect all gradient
|
|
547
|
+
// Collect all gradient components into a single map for global optimization
|
|
548
|
+
// Include both structured and scalar gradients (scalar uses 'value' as component)
|
|
549
|
+
const allGradientComponents = new Map();
|
|
380
550
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
381
551
|
if (isStructuredGradient(gradient)) {
|
|
382
|
-
|
|
383
|
-
// Merge intermediates
|
|
384
|
-
for (const [name, expr] of cseResult.intermediates.entries()) {
|
|
385
|
-
cseIntermediates.set(name, expr);
|
|
386
|
-
}
|
|
387
|
-
// Update gradient components with CSE-simplified versions
|
|
388
|
-
gradient.components = cseResult.components;
|
|
389
|
-
}
|
|
390
|
-
}
|
|
391
|
-
// Generate intermediate variables from CSE
|
|
392
|
-
if (cseIntermediates.size > 0) {
|
|
393
|
-
// Check if we should emit guards (opt-in)
|
|
394
|
-
const shouldEmitGuards = options.emitGuards === true;
|
|
395
|
-
const epsilon = options.epsilon || 1e-10;
|
|
396
|
-
// Identify potential denominators (sum of squares patterns)
|
|
397
|
-
const denominatorVars = new Set();
|
|
398
|
-
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
399
|
-
const code = codegen.generate(expr);
|
|
400
|
-
// Check if this looks like a denominator (contains + and squared terms)
|
|
401
|
-
if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
|
|
402
|
-
denominatorVars.add(varName);
|
|
403
|
-
}
|
|
552
|
+
allGradientComponents.set(paramName, gradient.components);
|
|
404
553
|
}
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
554
|
+
else {
|
|
555
|
+
// Scalar gradient - wrap as a single 'value' component
|
|
556
|
+
allGradientComponents.set(paramName, new Map([['value', gradient]]));
|
|
557
|
+
}
|
|
558
|
+
}
|
|
559
|
+
// Run e-graph optimization globally across ALL gradient expressions
|
|
560
|
+
const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
|
|
561
|
+
cseIntermediates = globalCSE.intermediates;
|
|
562
|
+
// Update gradient components with optimized versions
|
|
563
|
+
for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
|
|
564
|
+
const gradient = gradientsToUse.gradients.get(paramName);
|
|
565
|
+
if (gradient && isStructuredGradient(gradient)) {
|
|
566
|
+
gradient.components = simplifiedComponents;
|
|
416
567
|
}
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
if (
|
|
421
|
-
|
|
568
|
+
else {
|
|
569
|
+
// Scalar gradient - unwrap from 'value' component
|
|
570
|
+
const optimizedExpr = simplifiedComponents.get('value');
|
|
571
|
+
if (optimizedExpr) {
|
|
572
|
+
gradientsToUse.gradients.set(paramName, optimizedExpr);
|
|
422
573
|
}
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
}
|
|
435
|
-
else {
|
|
436
|
-
zeroGrads.push(`d${paramName}: 0`);
|
|
437
|
-
}
|
|
438
|
-
}
|
|
439
|
-
lines.push(` return { value, ${zeroGrads.join(', ')} };`);
|
|
440
|
-
lines.push(` }`);
|
|
441
|
-
}
|
|
574
|
+
}
|
|
575
|
+
}
|
|
576
|
+
// Post-CSE simplification: apply rules that were skipped to avoid CSE interference
|
|
577
|
+
// Specifically: a + a → 2 * a (now safe because temps have been extracted)
|
|
578
|
+
for (const [varName, expr] of cseIntermediates) {
|
|
579
|
+
cseIntermediates.set(varName, simplifyPostCSE(expr));
|
|
580
|
+
}
|
|
581
|
+
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
582
|
+
if (isStructuredGradient(gradient)) {
|
|
583
|
+
for (const [comp, expr] of gradient.components.entries()) {
|
|
584
|
+
gradient.components.set(comp, simplifyPostCSE(expr));
|
|
442
585
|
}
|
|
443
586
|
}
|
|
444
|
-
|
|
587
|
+
else {
|
|
588
|
+
// Scalar gradient - apply post-CSE simplification
|
|
589
|
+
gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
|
|
590
|
+
}
|
|
591
|
+
}
|
|
592
|
+
}
|
|
593
|
+
// Generate CSE intermediate variables
|
|
594
|
+
if (cseIntermediates.size > 0) {
|
|
595
|
+
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
596
|
+
const code = codegen.generate(expr);
|
|
597
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
598
|
+
lines.push(` const ${varName} = ${code};`);
|
|
599
|
+
}
|
|
600
|
+
else if (format === 'python') {
|
|
601
|
+
lines.push(` ${varName} = ${code}`);
|
|
602
|
+
}
|
|
603
|
+
else if (format === 'csharp') {
|
|
604
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
605
|
+
}
|
|
445
606
|
}
|
|
607
|
+
lines.push('');
|
|
446
608
|
}
|
|
447
609
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
448
610
|
// Use shorter names: du, dv instead of grad_u, grad_v
|
|
@@ -530,11 +692,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
530
692
|
return lines.join('\n');
|
|
531
693
|
}
|
|
532
694
|
/**
|
|
533
|
-
* Generate the original forward function
|
|
695
|
+
* Generate the original forward function (with optional e-graph optimization)
|
|
534
696
|
*/
|
|
535
697
|
export function generateForwardFunction(func, options = {}) {
|
|
536
698
|
const format = options.format || 'typescript';
|
|
537
699
|
const csharpFloatType = options.csharpFloatType || 'float';
|
|
700
|
+
const shouldOptimize = options.cse !== false; // Optimize by default
|
|
538
701
|
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
539
702
|
const lines = [];
|
|
540
703
|
// Function signature
|
|
@@ -549,28 +712,127 @@ export function generateForwardFunction(func, options = {}) {
|
|
|
549
712
|
const floatType = csharpFloatType;
|
|
550
713
|
const params = func.parameters.map(p => {
|
|
551
714
|
if (p.paramType && p.paramType.components) {
|
|
552
|
-
// Structured parameter - create a struct type name
|
|
553
715
|
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
554
716
|
}
|
|
555
717
|
return `${floatType} ${p.name}`;
|
|
556
718
|
}).join(', ');
|
|
557
|
-
// Generate struct definitions for structured parameters first (we'll prepend them later)
|
|
558
719
|
lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
|
|
559
720
|
lines.push('{');
|
|
560
721
|
}
|
|
561
|
-
//
|
|
722
|
+
// Collect all forward variable names (for dependency tracking)
|
|
723
|
+
const forwardVars = new Set();
|
|
724
|
+
for (const stmt of func.body) {
|
|
725
|
+
if (stmt.kind === 'assignment') {
|
|
726
|
+
forwardVars.add(stmt.variable);
|
|
727
|
+
}
|
|
728
|
+
}
|
|
729
|
+
// Collect expressions for optimization
|
|
730
|
+
const varExpressions = new Map();
|
|
731
|
+
for (const stmt of func.body) {
|
|
732
|
+
if (stmt.kind === 'assignment') {
|
|
733
|
+
varExpressions.set(stmt.variable, stmt.expression);
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
// Optimize with e-graph if enabled
|
|
737
|
+
let optimizedExprs = varExpressions;
|
|
738
|
+
let cseTemps = new Map();
|
|
739
|
+
if (shouldOptimize && varExpressions.size > 0) {
|
|
740
|
+
const forOptimizer = new Map();
|
|
741
|
+
forOptimizer.set('_forward', varExpressions);
|
|
742
|
+
const result = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
743
|
+
cseTemps = result.intermediates;
|
|
744
|
+
optimizedExprs = result.gradients.get('_forward') || varExpressions;
|
|
745
|
+
}
|
|
746
|
+
// Helper to find which forward vars an expression depends on
|
|
747
|
+
function findForwardVarDeps(expr) {
|
|
748
|
+
const deps = new Set();
|
|
749
|
+
function visit(e) {
|
|
750
|
+
if (e.kind === 'variable' && forwardVars.has(e.name)) {
|
|
751
|
+
deps.add(e.name);
|
|
752
|
+
}
|
|
753
|
+
else if (e.kind === 'binary') {
|
|
754
|
+
visit(e.left);
|
|
755
|
+
visit(e.right);
|
|
756
|
+
}
|
|
757
|
+
else if (e.kind === 'unary') {
|
|
758
|
+
visit(e.operand);
|
|
759
|
+
}
|
|
760
|
+
else if (e.kind === 'call') {
|
|
761
|
+
e.args.forEach(visit);
|
|
762
|
+
}
|
|
763
|
+
else if (e.kind === 'component') {
|
|
764
|
+
visit(e.object);
|
|
765
|
+
}
|
|
766
|
+
}
|
|
767
|
+
visit(expr);
|
|
768
|
+
return deps;
|
|
769
|
+
}
|
|
770
|
+
// Track which temps need to be emitted after which forward var
|
|
771
|
+
const tempAfterVar = new Map();
|
|
772
|
+
const tempsEmittedBeforeAny = [];
|
|
773
|
+
for (const [tempName, tempExpr] of cseTemps) {
|
|
774
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
775
|
+
if (deps.size === 0) {
|
|
776
|
+
// Temp only depends on params - emit before any forward vars
|
|
777
|
+
tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
778
|
+
}
|
|
779
|
+
else {
|
|
780
|
+
// Find the last forward var this temp depends on
|
|
781
|
+
let lastDep = '';
|
|
782
|
+
for (const stmt of func.body) {
|
|
783
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
784
|
+
lastDep = stmt.variable;
|
|
785
|
+
}
|
|
786
|
+
}
|
|
787
|
+
if (lastDep) {
|
|
788
|
+
if (!tempAfterVar.has(lastDep)) {
|
|
789
|
+
tempAfterVar.set(lastDep, []);
|
|
790
|
+
}
|
|
791
|
+
tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
792
|
+
}
|
|
793
|
+
}
|
|
794
|
+
}
|
|
795
|
+
// Emit temps that don't depend on forward vars
|
|
796
|
+
for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
|
|
797
|
+
const code = codegen.generate(expr);
|
|
798
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
799
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
800
|
+
}
|
|
801
|
+
else if (format === 'python') {
|
|
802
|
+
lines.push(` ${tempName} = ${code}`);
|
|
803
|
+
}
|
|
804
|
+
else if (format === 'csharp') {
|
|
805
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
806
|
+
}
|
|
807
|
+
}
|
|
808
|
+
// Generate variable assignments with interleaved temps
|
|
562
809
|
for (const stmt of func.body) {
|
|
563
810
|
if (stmt.kind === 'assignment') {
|
|
564
811
|
const varName = stmt.variable;
|
|
565
|
-
const expr =
|
|
812
|
+
const expr = optimizedExprs.get(varName) || stmt.expression;
|
|
813
|
+
const code = codegen.generate(expr);
|
|
566
814
|
if (format === 'typescript' || format === 'javascript') {
|
|
567
|
-
lines.push(` const ${varName} = ${
|
|
815
|
+
lines.push(` const ${varName} = ${code};`);
|
|
568
816
|
}
|
|
569
817
|
else if (format === 'python') {
|
|
570
|
-
lines.push(` ${varName} = ${
|
|
818
|
+
lines.push(` ${varName} = ${code}`);
|
|
571
819
|
}
|
|
572
820
|
else if (format === 'csharp') {
|
|
573
|
-
lines.push(` ${csharpFloatType} ${varName} = ${
|
|
821
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
822
|
+
}
|
|
823
|
+
// Emit any temps that depend on this var
|
|
824
|
+
const tempsForVar = tempAfterVar.get(varName) || [];
|
|
825
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
826
|
+
const tempCode = codegen.generate(tempExpr);
|
|
827
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
828
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
829
|
+
}
|
|
830
|
+
else if (format === 'python') {
|
|
831
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
832
|
+
}
|
|
833
|
+
else if (format === 'csharp') {
|
|
834
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
835
|
+
}
|
|
574
836
|
}
|
|
575
837
|
}
|
|
576
838
|
}
|
|
@@ -54,10 +54,16 @@ export declare function containsVariable(expr: Expression, varName: string): boo
|
|
|
54
54
|
*/
|
|
55
55
|
export declare function expressionDepth(expr: Expression): number;
|
|
56
56
|
/**
|
|
57
|
-
* Serializes an expression to
|
|
58
|
-
* Used for expression comparison
|
|
57
|
+
* Serializes an expression to structural string representation.
|
|
58
|
+
* Used for exact expression comparison - operand order matters.
|
|
59
59
|
*
|
|
60
60
|
* This ensures consistent string representation of expressions across different
|
|
61
61
|
* parts of the codebase.
|
|
62
62
|
*/
|
|
63
63
|
export declare function serializeExpression(expr: Expression): string;
|
|
64
|
+
/**
|
|
65
|
+
* Serializes an expression to canonical form for CSE matching.
|
|
66
|
+
* Commutative operations (+ and *) have operands sorted lexicographically,
|
|
67
|
+
* so a*b and b*a produce the same canonical string.
|
|
68
|
+
*/
|
|
69
|
+
export declare function serializeCanonical(expr: Expression): string;
|
|
@@ -174,8 +174,8 @@ export function expressionDepth(expr) {
|
|
|
174
174
|
}
|
|
175
175
|
}
|
|
176
176
|
/**
|
|
177
|
-
* Serializes an expression to
|
|
178
|
-
* Used for expression comparison
|
|
177
|
+
* Serializes an expression to structural string representation.
|
|
178
|
+
* Used for exact expression comparison - operand order matters.
|
|
179
179
|
*
|
|
180
180
|
* This ensures consistent string representation of expressions across different
|
|
181
181
|
* parts of the codebase.
|
|
@@ -197,3 +197,35 @@ export function serializeExpression(expr) {
|
|
|
197
197
|
return `comp(${serializeExpression(expr.object)},${expr.component})`;
|
|
198
198
|
}
|
|
199
199
|
}
|
|
200
|
+
/**
|
|
201
|
+
* Serializes an expression to canonical form for CSE matching.
|
|
202
|
+
* Commutative operations (+ and *) have operands sorted lexicographically,
|
|
203
|
+
* so a*b and b*a produce the same canonical string.
|
|
204
|
+
*/
|
|
205
|
+
export function serializeCanonical(expr) {
|
|
206
|
+
switch (expr.kind) {
|
|
207
|
+
case 'number':
|
|
208
|
+
return `num(${expr.value})`;
|
|
209
|
+
case 'variable':
|
|
210
|
+
return `var(${expr.name})`;
|
|
211
|
+
case 'binary': {
|
|
212
|
+
const leftStr = serializeCanonical(expr.left);
|
|
213
|
+
const rightStr = serializeCanonical(expr.right);
|
|
214
|
+
// For commutative operations, sort operands lexicographically
|
|
215
|
+
if (expr.operator === '+' || expr.operator === '*') {
|
|
216
|
+
const [first, second] = leftStr <= rightStr ? [leftStr, rightStr] : [rightStr, leftStr];
|
|
217
|
+
return `bin(${expr.operator},${first},${second})`;
|
|
218
|
+
}
|
|
219
|
+
// Non-commutative: preserve order
|
|
220
|
+
return `bin(${expr.operator},${leftStr},${rightStr})`;
|
|
221
|
+
}
|
|
222
|
+
case 'unary':
|
|
223
|
+
return `un(${expr.operator},${serializeCanonical(expr.operand)})`;
|
|
224
|
+
case 'call': {
|
|
225
|
+
const args = expr.args.map(arg => serializeCanonical(arg)).join(',');
|
|
226
|
+
return `call(${expr.name},${args})`;
|
|
227
|
+
}
|
|
228
|
+
case 'component':
|
|
229
|
+
return `comp(${serializeCanonical(expr.object)},${expr.component})`;
|
|
230
|
+
}
|
|
231
|
+
}
|
|
@@ -17,9 +17,25 @@ type NumValue = number | {
|
|
|
17
17
|
export interface GradCheckResult {
|
|
18
18
|
passed: boolean;
|
|
19
19
|
errors: GradCheckError[];
|
|
20
|
+
singularities: GradCheckSingularity[];
|
|
20
21
|
maxError: number;
|
|
21
22
|
meanError: number;
|
|
23
|
+
totalChecks: number;
|
|
22
24
|
}
|
|
25
|
+
/**
|
|
26
|
+
* Singularity detected during gradient checking
|
|
27
|
+
* When both analytical and numerical produce NaN/Inf, it's a singularity, not a bug
|
|
28
|
+
*/
|
|
29
|
+
export interface GradCheckSingularity {
|
|
30
|
+
parameter: string;
|
|
31
|
+
component?: string;
|
|
32
|
+
analytical: number;
|
|
33
|
+
numerical: number;
|
|
34
|
+
}
|
|
35
|
+
/**
|
|
36
|
+
* Format gradient check results as a human-readable string
|
|
37
|
+
*/
|
|
38
|
+
export declare function formatGradCheckResult(result: GradCheckResult, funcName: string): string;
|
|
23
39
|
export interface GradCheckError {
|
|
24
40
|
parameter: string;
|
|
25
41
|
component?: string;
|
|
@@ -39,6 +55,11 @@ export declare class GradientChecker {
|
|
|
39
55
|
* Check gradients for a function
|
|
40
56
|
*/
|
|
41
57
|
check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
|
|
58
|
+
/**
|
|
59
|
+
* Compare analytical and numerical gradients
|
|
60
|
+
* Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
|
|
61
|
+
*/
|
|
62
|
+
private compareGradients;
|
|
42
63
|
/**
|
|
43
64
|
* Compute numerical gradient for scalar parameter using finite differences
|
|
44
65
|
*/
|