gradient-script 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +80 -3
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +332 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/README.md
CHANGED
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
|
|
13
13
|
GradientScript is a source-to-source compiler that automatically generates gradient functions from your mathematical code. Unlike numerical AD frameworks (JAX, PyTorch), it produces clean, human-readable gradient formulas you can inspect, optimize, and integrate directly into your codebase.
|
|
14
14
|
|
|
15
|
+
It's a good fit for LLM-assisted workflows: the LLM can verify existing gradients or construct the gradients you need with less risk of making errors.
|
|
16
|
+
|
|
15
17
|
## Why GradientScript?
|
|
16
18
|
|
|
17
19
|
- **From real code to gradients**: Write natural math code, get symbolic derivatives
|
|
@@ -40,7 +42,7 @@ function distance(u: Vec2, v: Vec2): number {
|
|
|
40
42
|
}
|
|
41
43
|
```
|
|
42
44
|
|
|
43
|
-
Convert it to GradientScript by marking what you need gradients for:
|
|
45
|
+
Convert it to GradientScript (realistically, let your LLM do the conversion, giving it this README as a reference and/or free use of the CLI) by marking what you need gradients for:
|
|
44
46
|
|
|
45
47
|
```typescript
|
|
46
48
|
// distance.gs
|
package/dist/cli.js
CHANGED
|
@@ -6,6 +6,68 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
|
6
6
|
import { generateComplete } from './dsl/CodeGen.js';
|
|
7
7
|
import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
|
|
8
8
|
import { ParseError, formatParseError } from './dsl/Errors.js';
|
|
9
|
+
import { GradientChecker, formatGradCheckResult } from './dsl/GradientChecker.js';
|
|
10
|
+
import { Types } from './dsl/Types.js';
|
|
11
|
+
/**
|
|
12
|
+
* Generate random test points for gradient verification.
|
|
13
|
+
* Uses multiple test points to catch errors at different values.
|
|
14
|
+
*/
|
|
15
|
+
function generateTestPoints(func, env) {
|
|
16
|
+
const testPoints = [];
|
|
17
|
+
// Generate 3 different test points with varying scales
|
|
18
|
+
const scales = [1.0, 0.1, 10.0];
|
|
19
|
+
for (const scale of scales) {
|
|
20
|
+
const point = new Map();
|
|
21
|
+
for (const param of func.parameters) {
|
|
22
|
+
const paramType = env.getOrThrow(param.name);
|
|
23
|
+
if (Types.isScalar(paramType)) {
|
|
24
|
+
// Random scalar in [-0.9 * scale, 1.1 * scale): the 0.1 * scale offset shifts the range but does not strictly exclude zero
|
|
25
|
+
point.set(param.name, (Math.random() * 2 - 1) * scale + 0.1 * scale);
|
|
26
|
+
}
|
|
27
|
+
else {
|
|
28
|
+
// Structured type - get components
|
|
29
|
+
const struct = {};
|
|
30
|
+
for (const comp of paramType.components) {
|
|
31
|
+
struct[comp] = (Math.random() * 2 - 1) * scale + 0.1 * scale;
|
|
32
|
+
}
|
|
33
|
+
point.set(param.name, struct);
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
testPoints.push(point);
|
|
37
|
+
}
|
|
38
|
+
return testPoints;
|
|
39
|
+
}
|
|
40
|
+
/**
|
|
41
|
+
* Verify gradients for a function using numerical differentiation.
|
|
42
|
+
* Returns true if all gradients pass, false otherwise.
|
|
43
|
+
*/
|
|
44
|
+
function verifyGradients(func, gradients, env) {
|
|
45
|
+
const checker = new GradientChecker(1e-5, 1e-4);
|
|
46
|
+
const testPoints = generateTestPoints(func, env);
|
|
47
|
+
let allPassed = true;
|
|
48
|
+
for (let i = 0; i < testPoints.length; i++) {
|
|
49
|
+
const result = checker.check(func, gradients, env, testPoints[i]);
|
|
50
|
+
if (!result.passed) {
|
|
51
|
+
if (allPassed) {
|
|
52
|
+
// First failure - print header (as comment for valid output)
|
|
53
|
+
console.error(`// Gradient verification FAILED for "${func.name}":`);
|
|
54
|
+
}
|
|
55
|
+
// Prefix each line with // so output remains valid code
|
|
56
|
+
const formattedResult = formatGradCheckResult(result, func.name)
|
|
57
|
+
.split('\n')
|
|
58
|
+
.map(line => '// ' + line)
|
|
59
|
+
.join('\n');
|
|
60
|
+
console.error(`// Test point ${i + 1}: ${formattedResult}`);
|
|
61
|
+
allPassed = false;
|
|
62
|
+
}
|
|
63
|
+
}
|
|
64
|
+
if (allPassed) {
|
|
65
|
+
const result = checker.check(func, gradients, env, testPoints[0]);
|
|
66
|
+
// Prefix with // so output is valid code
|
|
67
|
+
console.error('// ' + formatGradCheckResult(result, func.name));
|
|
68
|
+
}
|
|
69
|
+
return allPassed;
|
|
70
|
+
}
|
|
9
71
|
function printUsage() {
|
|
10
72
|
console.log(`
|
|
11
73
|
GradientScript - Symbolic Differentiation for Structured Types
|
|
@@ -16,7 +78,7 @@ Usage:
|
|
|
16
78
|
Options:
|
|
17
79
|
--format <format> Output format: typescript (default), javascript, python, csharp
|
|
18
80
|
--no-simplify Disable gradient simplification
|
|
19
|
-
--no-cse Disable
|
|
81
|
+
--no-cse Disable optimization (e-graph CSE)
|
|
20
82
|
--no-comments Omit comments in generated code
|
|
21
83
|
--guards Emit runtime guards for division by zero (experimental)
|
|
22
84
|
--epsilon <value> Epsilon value for guards (default: 1e-10)
|
|
@@ -45,6 +107,9 @@ For more information and examples:
|
|
|
45
107
|
|
|
46
108
|
README (raw, LLM-friendly):
|
|
47
109
|
https://raw.githubusercontent.com/mfagerlund/gradient-script/main/README.md
|
|
110
|
+
|
|
111
|
+
LLM Optimization Guide (for AI agents writing .gs files):
|
|
112
|
+
https://raw.githubusercontent.com/mfagerlund/gradient-script/main/docs/LLM-OPTIMIZATION-GUIDE.md
|
|
48
113
|
`.trim());
|
|
49
114
|
}
|
|
50
115
|
function main() {
|
|
@@ -64,6 +129,7 @@ function main() {
|
|
|
64
129
|
simplify: true,
|
|
65
130
|
cse: true
|
|
66
131
|
};
|
|
132
|
+
let skipVerify = false;
|
|
67
133
|
for (let i = 1; i < args.length; i++) {
|
|
68
134
|
const arg = args[i];
|
|
69
135
|
if (arg === '--format') {
|
|
@@ -138,13 +204,20 @@ function main() {
|
|
|
138
204
|
process.exit(1);
|
|
139
205
|
}
|
|
140
206
|
const outputs = [];
|
|
207
|
+
let hasVerificationFailure = false;
|
|
141
208
|
program.functions.forEach((func, index) => {
|
|
142
209
|
const env = inferFunction(func);
|
|
143
210
|
const gradients = computeFunctionGradients(func, env);
|
|
211
|
+
// MANDATORY gradient verification
|
|
212
|
+
const verified = verifyGradients(func, gradients, env);
|
|
213
|
+
if (!verified) {
|
|
214
|
+
hasVerificationFailure = true;
|
|
215
|
+
}
|
|
144
216
|
const guardAnalysis = analyzeGuards(func);
|
|
145
217
|
if (guardAnalysis.hasIssues) {
|
|
146
|
-
|
|
147
|
-
console.error(
|
|
218
|
+
// Format warnings as comments so output remains valid code even if stderr is captured
|
|
219
|
+
console.error('// Function "' + func.name + '" may have edge cases:');
|
|
220
|
+
console.error(formatGuardWarnings(guardAnalysis, true));
|
|
148
221
|
}
|
|
149
222
|
const perFunctionOptions = { ...options };
|
|
150
223
|
if (index > 0 && perFunctionOptions.includeComments !== false) {
|
|
@@ -153,6 +226,10 @@ function main() {
|
|
|
153
226
|
const code = generateComplete(func, gradients, env, perFunctionOptions);
|
|
154
227
|
outputs.push(code);
|
|
155
228
|
});
|
|
229
|
+
if (hasVerificationFailure) {
|
|
230
|
+
console.error('// ERROR: Gradient verification failed. Output may contain incorrect gradients!');
|
|
231
|
+
process.exit(1);
|
|
232
|
+
}
|
|
156
233
|
console.log(outputs.join('\n\n'));
|
|
157
234
|
}
|
|
158
235
|
catch (err) {
|
package/dist/dsl/CodeGen.d.ts
CHANGED
|
@@ -56,7 +56,7 @@ export declare class ExpressionCodeGen {
|
|
|
56
56
|
*/
|
|
57
57
|
export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
|
|
58
58
|
/**
|
|
59
|
-
* Generate the original forward function
|
|
59
|
+
* Generate the original forward function (with optional e-graph optimization)
|
|
60
60
|
*/
|
|
61
61
|
export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
|
|
62
62
|
/**
|
package/dist/dsl/CodeGen.js
CHANGED
|
@@ -2,11 +2,12 @@
|
|
|
2
2
|
* Code generation for GradientScript DSL
|
|
3
3
|
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
4
|
*/
|
|
5
|
-
import { simplifyGradients } from './Simplify.js';
|
|
5
|
+
import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
|
|
6
6
|
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
7
|
-
import {
|
|
7
|
+
import { optimizeWithEGraph } from './egraph/index.js';
|
|
8
8
|
import { CodeGenError } from './Errors.js';
|
|
9
9
|
import { serializeExpression } from './ExpressionUtils.js';
|
|
10
|
+
import { inlineExpression } from './Inliner.js';
|
|
10
11
|
function capitalize(str) {
|
|
11
12
|
return str.charAt(0).toUpperCase() + str.slice(1);
|
|
12
13
|
}
|
|
@@ -79,10 +80,12 @@ export class ExpressionCodeGen {
|
|
|
79
80
|
return left;
|
|
80
81
|
}
|
|
81
82
|
else if (exponent === 2) {
|
|
82
|
-
|
|
83
|
+
// Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
|
|
84
|
+
// Without parens, a / b^2 would become a / b * b instead of a / (b * b)
|
|
85
|
+
return `(${left} * ${left})`;
|
|
83
86
|
}
|
|
84
87
|
else if (exponent === 3) {
|
|
85
|
-
return
|
|
88
|
+
return `(${left} * ${left} * ${left})`;
|
|
86
89
|
}
|
|
87
90
|
}
|
|
88
91
|
}
|
|
@@ -300,15 +303,162 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
300
303
|
// Forward pass - compute intermediate variables
|
|
301
304
|
// Track which expressions are already computed for CSE reuse
|
|
302
305
|
const forwardExpressionMap = new Map();
|
|
306
|
+
// Build substitution map for inlining (to match gradient expressions which are fully inlined)
|
|
307
|
+
const substitutionMap = new Map();
|
|
308
|
+
for (const stmt of func.body) {
|
|
309
|
+
if (stmt.kind === 'assignment') {
|
|
310
|
+
substitutionMap.set(stmt.variable, stmt.expression);
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
// Collect forward variable names and expressions for optimization
|
|
314
|
+
const forwardVars = new Set();
|
|
315
|
+
const forwardVarExprs = new Map();
|
|
316
|
+
for (const stmt of func.body) {
|
|
317
|
+
if (stmt.kind === 'assignment') {
|
|
318
|
+
forwardVars.add(stmt.variable);
|
|
319
|
+
forwardVarExprs.set(stmt.variable, stmt.expression);
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
// Optimize forward expressions with e-graph
|
|
323
|
+
let optimizedForwardExprs = forwardVarExprs;
|
|
324
|
+
let forwardCseTemps = new Map();
|
|
325
|
+
if (options.cse !== false && forwardVarExprs.size > 0) {
|
|
326
|
+
const forOptimizer = new Map();
|
|
327
|
+
forOptimizer.set('_forward', forwardVarExprs);
|
|
328
|
+
const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
329
|
+
// Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
|
|
330
|
+
const rawTemps = forwardResult.intermediates;
|
|
331
|
+
const renameMap = new Map();
|
|
332
|
+
for (const oldName of rawTemps.keys()) {
|
|
333
|
+
const newName = oldName.replace('_tmp', '_fwd');
|
|
334
|
+
renameMap.set(oldName, newName);
|
|
335
|
+
}
|
|
336
|
+
// Apply renaming to temp definitions
|
|
337
|
+
for (const [oldName, expr] of rawTemps) {
|
|
338
|
+
const newName = renameMap.get(oldName);
|
|
339
|
+
forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
|
|
340
|
+
}
|
|
341
|
+
// Apply renaming to optimized expressions
|
|
342
|
+
optimizedForwardExprs = new Map();
|
|
343
|
+
for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
|
|
344
|
+
optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
// Helper to rename temp references in an expression
|
|
348
|
+
function renameTempRefs(expr, renameMap) {
|
|
349
|
+
if (expr.kind === 'variable') {
|
|
350
|
+
const newName = renameMap.get(expr.name);
|
|
351
|
+
return newName ? { kind: 'variable', name: newName } : expr;
|
|
352
|
+
}
|
|
353
|
+
else if (expr.kind === 'binary') {
|
|
354
|
+
return {
|
|
355
|
+
kind: 'binary',
|
|
356
|
+
operator: expr.operator,
|
|
357
|
+
left: renameTempRefs(expr.left, renameMap),
|
|
358
|
+
right: renameTempRefs(expr.right, renameMap)
|
|
359
|
+
};
|
|
360
|
+
}
|
|
361
|
+
else if (expr.kind === 'unary') {
|
|
362
|
+
return {
|
|
363
|
+
kind: 'unary',
|
|
364
|
+
operator: expr.operator,
|
|
365
|
+
operand: renameTempRefs(expr.operand, renameMap)
|
|
366
|
+
};
|
|
367
|
+
}
|
|
368
|
+
else if (expr.kind === 'call') {
|
|
369
|
+
return {
|
|
370
|
+
kind: 'call',
|
|
371
|
+
name: expr.name,
|
|
372
|
+
args: expr.args.map(a => renameTempRefs(a, renameMap))
|
|
373
|
+
};
|
|
374
|
+
}
|
|
375
|
+
else if (expr.kind === 'component') {
|
|
376
|
+
return {
|
|
377
|
+
kind: 'component',
|
|
378
|
+
object: renameTempRefs(expr.object, renameMap),
|
|
379
|
+
component: expr.component
|
|
380
|
+
};
|
|
381
|
+
}
|
|
382
|
+
return expr;
|
|
383
|
+
}
|
|
384
|
+
// Helper to find which forward vars an expression depends on
|
|
385
|
+
function findForwardVarDeps(expr) {
|
|
386
|
+
const deps = new Set();
|
|
387
|
+
function visit(e) {
|
|
388
|
+
if (e.kind === 'variable' && forwardVars.has(e.name)) {
|
|
389
|
+
deps.add(e.name);
|
|
390
|
+
}
|
|
391
|
+
else if (e.kind === 'binary') {
|
|
392
|
+
visit(e.left);
|
|
393
|
+
visit(e.right);
|
|
394
|
+
}
|
|
395
|
+
else if (e.kind === 'unary') {
|
|
396
|
+
visit(e.operand);
|
|
397
|
+
}
|
|
398
|
+
else if (e.kind === 'call') {
|
|
399
|
+
e.args.forEach(visit);
|
|
400
|
+
}
|
|
401
|
+
else if (e.kind === 'component') {
|
|
402
|
+
visit(e.object);
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
visit(expr);
|
|
406
|
+
return deps;
|
|
407
|
+
}
|
|
408
|
+
// Track which CSE temps need to be emitted after which forward var
|
|
409
|
+
const fwdTempAfterVar = new Map();
|
|
410
|
+
const fwdTempsBeforeAny = [];
|
|
411
|
+
for (const [tempName, tempExpr] of forwardCseTemps) {
|
|
412
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
413
|
+
if (deps.size === 0) {
|
|
414
|
+
fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
415
|
+
}
|
|
416
|
+
else {
|
|
417
|
+
let lastDep = '';
|
|
418
|
+
for (const stmt of func.body) {
|
|
419
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
420
|
+
lastDep = stmt.variable;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
if (lastDep) {
|
|
424
|
+
if (!fwdTempAfterVar.has(lastDep)) {
|
|
425
|
+
fwdTempAfterVar.set(lastDep, []);
|
|
426
|
+
}
|
|
427
|
+
fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
428
|
+
}
|
|
429
|
+
}
|
|
430
|
+
}
|
|
431
|
+
// Emit forward CSE temps that don't depend on forward vars
|
|
432
|
+
for (const { name: tempName, expr } of fwdTempsBeforeAny) {
|
|
433
|
+
const code = codegen.generate(expr);
|
|
434
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
435
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
436
|
+
}
|
|
437
|
+
else if (format === 'python') {
|
|
438
|
+
lines.push(` ${tempName} = ${code}`);
|
|
439
|
+
}
|
|
440
|
+
else if (format === 'csharp') {
|
|
441
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
// Generate forward variable assignments with interleaved temps
|
|
303
445
|
for (const stmt of func.body) {
|
|
304
446
|
if (stmt.kind === 'assignment') {
|
|
305
447
|
const varName = stmt.variable;
|
|
306
|
-
const
|
|
448
|
+
const expr = optimizedForwardExprs.get(varName) || stmt.expression;
|
|
449
|
+
const generatedExpr = codegen.generate(expr);
|
|
307
450
|
if (shouldTrackForForwardReuse(stmt.expression)) {
|
|
451
|
+
// Register the original expression
|
|
308
452
|
const exprKey = serializeExpression(stmt.expression);
|
|
309
453
|
if (!forwardExpressionMap.has(exprKey)) {
|
|
310
454
|
forwardExpressionMap.set(exprKey, varName);
|
|
311
455
|
}
|
|
456
|
+
// Also register the fully inlined form (this is what gradient expressions will have)
|
|
457
|
+
const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
|
|
458
|
+
const inlinedKey = serializeExpression(inlinedExpr);
|
|
459
|
+
if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
|
|
460
|
+
forwardExpressionMap.set(inlinedKey, varName);
|
|
461
|
+
}
|
|
312
462
|
}
|
|
313
463
|
if (format === 'typescript' || format === 'javascript') {
|
|
314
464
|
lines.push(` const ${varName} = ${generatedExpr};`);
|
|
@@ -319,6 +469,20 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
319
469
|
else if (format === 'csharp') {
|
|
320
470
|
lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
|
|
321
471
|
}
|
|
472
|
+
// Emit any forward CSE temps that depend on this var
|
|
473
|
+
const tempsForVar = fwdTempAfterVar.get(varName) || [];
|
|
474
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
475
|
+
const tempCode = codegen.generate(tempExpr);
|
|
476
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
477
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
478
|
+
}
|
|
479
|
+
else if (format === 'python') {
|
|
480
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
481
|
+
}
|
|
482
|
+
else if (format === 'csharp') {
|
|
483
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
484
|
+
}
|
|
485
|
+
}
|
|
322
486
|
}
|
|
323
487
|
}
|
|
324
488
|
// Compute output value - reuse forward pass variables if possible
|
|
@@ -372,77 +536,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
372
536
|
if (includeComments) {
|
|
373
537
|
lines.push(` ${comment} Gradients`);
|
|
374
538
|
}
|
|
375
|
-
// Apply CSE
|
|
539
|
+
// Apply e-graph optimization (CSE + algebraic simplification)
|
|
376
540
|
const shouldApplyCSE = options.cse !== false; // Default to true
|
|
377
|
-
|
|
541
|
+
let cseIntermediates = new Map();
|
|
378
542
|
if (shouldApplyCSE) {
|
|
379
|
-
// Collect all gradient
|
|
543
|
+
// Collect all gradient components into a single map for global optimization
|
|
544
|
+
// Include both structured and scalar gradients (scalar uses 'value' as component)
|
|
545
|
+
const allGradientComponents = new Map();
|
|
380
546
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
381
547
|
if (isStructuredGradient(gradient)) {
|
|
382
|
-
|
|
383
|
-
// Merge intermediates
|
|
384
|
-
for (const [name, expr] of cseResult.intermediates.entries()) {
|
|
385
|
-
cseIntermediates.set(name, expr);
|
|
386
|
-
}
|
|
387
|
-
// Update gradient components with CSE-simplified versions
|
|
388
|
-
gradient.components = cseResult.components;
|
|
389
|
-
}
|
|
390
|
-
}
|
|
391
|
-
// Generate intermediate variables from CSE
|
|
392
|
-
if (cseIntermediates.size > 0) {
|
|
393
|
-
// Check if we should emit guards (opt-in)
|
|
394
|
-
const shouldEmitGuards = options.emitGuards === true;
|
|
395
|
-
const epsilon = options.epsilon || 1e-10;
|
|
396
|
-
// Identify potential denominators (sum of squares patterns)
|
|
397
|
-
const denominatorVars = new Set();
|
|
398
|
-
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
399
|
-
const code = codegen.generate(expr);
|
|
400
|
-
// Check if this looks like a denominator (contains + and squared terms)
|
|
401
|
-
if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
|
|
402
|
-
denominatorVars.add(varName);
|
|
403
|
-
}
|
|
548
|
+
allGradientComponents.set(paramName, gradient.components);
|
|
404
549
|
}
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
lines.push(` const ${varName} = ${code};`);
|
|
409
|
-
}
|
|
410
|
-
else if (format === 'python') {
|
|
411
|
-
lines.push(` ${varName} = ${code}`);
|
|
412
|
-
}
|
|
413
|
-
else if (format === 'csharp') {
|
|
414
|
-
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
415
|
-
}
|
|
550
|
+
else {
|
|
551
|
+
// Scalar gradient - wrap as a single 'value' component
|
|
552
|
+
allGradientComponents.set(paramName, new Map([['value', gradient]]));
|
|
416
553
|
}
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
554
|
+
}
|
|
555
|
+
// Run e-graph optimization globally across ALL gradient expressions
|
|
556
|
+
const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
|
|
557
|
+
cseIntermediates = globalCSE.intermediates;
|
|
558
|
+
// Update gradient components with optimized versions
|
|
559
|
+
for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
|
|
560
|
+
const gradient = gradientsToUse.gradients.get(paramName);
|
|
561
|
+
if (gradient && isStructuredGradient(gradient)) {
|
|
562
|
+
gradient.components = simplifiedComponents;
|
|
563
|
+
}
|
|
564
|
+
else {
|
|
565
|
+
// Scalar gradient - unwrap from 'value' component
|
|
566
|
+
const optimizedExpr = simplifiedComponents.get('value');
|
|
567
|
+
if (optimizedExpr) {
|
|
568
|
+
gradientsToUse.gradients.set(paramName, optimizedExpr);
|
|
422
569
|
}
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
}
|
|
435
|
-
else {
|
|
436
|
-
zeroGrads.push(`d${paramName}: 0`);
|
|
437
|
-
}
|
|
438
|
-
}
|
|
439
|
-
lines.push(` return { value, ${zeroGrads.join(', ')} };`);
|
|
440
|
-
lines.push(` }`);
|
|
441
|
-
}
|
|
570
|
+
}
|
|
571
|
+
}
|
|
572
|
+
// Post-CSE simplification: apply rules that were skipped to avoid CSE interference
|
|
573
|
+
// Specifically: a + a → 2 * a (now safe because temps have been extracted)
|
|
574
|
+
for (const [varName, expr] of cseIntermediates) {
|
|
575
|
+
cseIntermediates.set(varName, simplifyPostCSE(expr));
|
|
576
|
+
}
|
|
577
|
+
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
578
|
+
if (isStructuredGradient(gradient)) {
|
|
579
|
+
for (const [comp, expr] of gradient.components.entries()) {
|
|
580
|
+
gradient.components.set(comp, simplifyPostCSE(expr));
|
|
442
581
|
}
|
|
443
582
|
}
|
|
444
|
-
|
|
583
|
+
else {
|
|
584
|
+
// Scalar gradient - apply post-CSE simplification
|
|
585
|
+
gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
}
|
|
589
|
+
// Generate CSE intermediate variables
|
|
590
|
+
if (cseIntermediates.size > 0) {
|
|
591
|
+
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
592
|
+
const code = codegen.generate(expr);
|
|
593
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
594
|
+
lines.push(` const ${varName} = ${code};`);
|
|
595
|
+
}
|
|
596
|
+
else if (format === 'python') {
|
|
597
|
+
lines.push(` ${varName} = ${code}`);
|
|
598
|
+
}
|
|
599
|
+
else if (format === 'csharp') {
|
|
600
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
601
|
+
}
|
|
445
602
|
}
|
|
603
|
+
lines.push('');
|
|
446
604
|
}
|
|
447
605
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
448
606
|
// Use shorter names: du, dv instead of grad_u, grad_v
|
|
@@ -530,11 +688,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
530
688
|
return lines.join('\n');
|
|
531
689
|
}
|
|
532
690
|
/**
|
|
533
|
-
* Generate the original forward function
|
|
691
|
+
* Generate the original forward function (with optional e-graph optimization)
|
|
534
692
|
*/
|
|
535
693
|
export function generateForwardFunction(func, options = {}) {
|
|
536
694
|
const format = options.format || 'typescript';
|
|
537
695
|
const csharpFloatType = options.csharpFloatType || 'float';
|
|
696
|
+
const shouldOptimize = options.cse !== false; // Optimize by default
|
|
538
697
|
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
539
698
|
const lines = [];
|
|
540
699
|
// Function signature
|
|
@@ -549,28 +708,127 @@ export function generateForwardFunction(func, options = {}) {
|
|
|
549
708
|
const floatType = csharpFloatType;
|
|
550
709
|
const params = func.parameters.map(p => {
|
|
551
710
|
if (p.paramType && p.paramType.components) {
|
|
552
|
-
// Structured parameter - create a struct type name
|
|
553
711
|
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
554
712
|
}
|
|
555
713
|
return `${floatType} ${p.name}`;
|
|
556
714
|
}).join(', ');
|
|
557
|
-
// Generate struct definitions for structured parameters first (we'll prepend them later)
|
|
558
715
|
lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
|
|
559
716
|
lines.push('{');
|
|
560
717
|
}
|
|
561
|
-
//
|
|
718
|
+
// Collect all forward variable names (for dependency tracking)
|
|
719
|
+
const forwardVars = new Set();
|
|
720
|
+
for (const stmt of func.body) {
|
|
721
|
+
if (stmt.kind === 'assignment') {
|
|
722
|
+
forwardVars.add(stmt.variable);
|
|
723
|
+
}
|
|
724
|
+
}
|
|
725
|
+
// Collect expressions for optimization
|
|
726
|
+
const varExpressions = new Map();
|
|
727
|
+
for (const stmt of func.body) {
|
|
728
|
+
if (stmt.kind === 'assignment') {
|
|
729
|
+
varExpressions.set(stmt.variable, stmt.expression);
|
|
730
|
+
}
|
|
731
|
+
}
|
|
732
|
+
// Optimize with e-graph if enabled
|
|
733
|
+
let optimizedExprs = varExpressions;
|
|
734
|
+
let cseTemps = new Map();
|
|
735
|
+
if (shouldOptimize && varExpressions.size > 0) {
|
|
736
|
+
const forOptimizer = new Map();
|
|
737
|
+
forOptimizer.set('_forward', varExpressions);
|
|
738
|
+
const result = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
739
|
+
cseTemps = result.intermediates;
|
|
740
|
+
optimizedExprs = result.gradients.get('_forward') || varExpressions;
|
|
741
|
+
}
|
|
742
|
+
// Helper to find which forward vars an expression depends on
|
|
743
|
+
function findForwardVarDeps(expr) {
|
|
744
|
+
const deps = new Set();
|
|
745
|
+
function visit(e) {
|
|
746
|
+
if (e.kind === 'variable' && forwardVars.has(e.name)) {
|
|
747
|
+
deps.add(e.name);
|
|
748
|
+
}
|
|
749
|
+
else if (e.kind === 'binary') {
|
|
750
|
+
visit(e.left);
|
|
751
|
+
visit(e.right);
|
|
752
|
+
}
|
|
753
|
+
else if (e.kind === 'unary') {
|
|
754
|
+
visit(e.operand);
|
|
755
|
+
}
|
|
756
|
+
else if (e.kind === 'call') {
|
|
757
|
+
e.args.forEach(visit);
|
|
758
|
+
}
|
|
759
|
+
else if (e.kind === 'component') {
|
|
760
|
+
visit(e.object);
|
|
761
|
+
}
|
|
762
|
+
}
|
|
763
|
+
visit(expr);
|
|
764
|
+
return deps;
|
|
765
|
+
}
|
|
766
|
+
// Track which temps need to be emitted after which forward var
|
|
767
|
+
const tempAfterVar = new Map();
|
|
768
|
+
const tempsEmittedBeforeAny = [];
|
|
769
|
+
for (const [tempName, tempExpr] of cseTemps) {
|
|
770
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
771
|
+
if (deps.size === 0) {
|
|
772
|
+
// Temp only depends on params - emit before any forward vars
|
|
773
|
+
tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
774
|
+
}
|
|
775
|
+
else {
|
|
776
|
+
// Find the last forward var this temp depends on
|
|
777
|
+
let lastDep = '';
|
|
778
|
+
for (const stmt of func.body) {
|
|
779
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
780
|
+
lastDep = stmt.variable;
|
|
781
|
+
}
|
|
782
|
+
}
|
|
783
|
+
if (lastDep) {
|
|
784
|
+
if (!tempAfterVar.has(lastDep)) {
|
|
785
|
+
tempAfterVar.set(lastDep, []);
|
|
786
|
+
}
|
|
787
|
+
tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
788
|
+
}
|
|
789
|
+
}
|
|
790
|
+
}
|
|
791
|
+
// Emit temps that don't depend on forward vars
|
|
792
|
+
for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
|
|
793
|
+
const code = codegen.generate(expr);
|
|
794
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
795
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
796
|
+
}
|
|
797
|
+
else if (format === 'python') {
|
|
798
|
+
lines.push(` ${tempName} = ${code}`);
|
|
799
|
+
}
|
|
800
|
+
else if (format === 'csharp') {
|
|
801
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
// Generate variable assignments with interleaved temps
|
|
562
805
|
for (const stmt of func.body) {
|
|
563
806
|
if (stmt.kind === 'assignment') {
|
|
564
807
|
const varName = stmt.variable;
|
|
565
|
-
const expr =
|
|
808
|
+
const expr = optimizedExprs.get(varName) || stmt.expression;
|
|
809
|
+
const code = codegen.generate(expr);
|
|
566
810
|
if (format === 'typescript' || format === 'javascript') {
|
|
567
|
-
lines.push(` const ${varName} = ${
|
|
811
|
+
lines.push(` const ${varName} = ${code};`);
|
|
568
812
|
}
|
|
569
813
|
else if (format === 'python') {
|
|
570
|
-
lines.push(` ${varName} = ${
|
|
814
|
+
lines.push(` ${varName} = ${code}`);
|
|
571
815
|
}
|
|
572
816
|
else if (format === 'csharp') {
|
|
573
|
-
lines.push(` ${csharpFloatType} ${varName} = ${
|
|
817
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
818
|
+
}
|
|
819
|
+
// Emit any temps that depend on this var
|
|
820
|
+
const tempsForVar = tempAfterVar.get(varName) || [];
|
|
821
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
822
|
+
const tempCode = codegen.generate(tempExpr);
|
|
823
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
824
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
825
|
+
}
|
|
826
|
+
else if (format === 'python') {
|
|
827
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
828
|
+
}
|
|
829
|
+
else if (format === 'csharp') {
|
|
830
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
831
|
+
}
|
|
574
832
|
}
|
|
575
833
|
}
|
|
576
834
|
}
|
|
@@ -54,10 +54,16 @@ export declare function containsVariable(expr: Expression, varName: string): boo
|
|
|
54
54
|
*/
|
|
55
55
|
export declare function expressionDepth(expr: Expression): number;
|
|
56
56
|
/**
|
|
57
|
-
* Serializes an expression to
|
|
58
|
-
* Used for expression comparison
|
|
57
|
+
* Serializes an expression to structural string representation.
|
|
58
|
+
* Used for exact expression comparison - operand order matters.
|
|
59
59
|
*
|
|
60
60
|
* This ensures consistent string representation of expressions across different
|
|
61
61
|
* parts of the codebase.
|
|
62
62
|
*/
|
|
63
63
|
export declare function serializeExpression(expr: Expression): string;
|
|
64
|
+
/**
|
|
65
|
+
* Serializes an expression to canonical form for CSE matching.
|
|
66
|
+
* Commutative operations (+ and *) have operands sorted lexicographically,
|
|
67
|
+
* so a*b and b*a produce the same canonical string.
|
|
68
|
+
*/
|
|
69
|
+
export declare function serializeCanonical(expr: Expression): string;
|