gradient-script 0.1.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +52 -9
- package/dist/cli.js +134 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CodeGen.d.ts +8 -3
- package/dist/dsl/CodeGen.js +583 -132
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +14 -0
- package/dist/dsl/ExpressionUtils.js +56 -0
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +3 -1
- package/dist/dsl/Guards.js +86 -43
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +11 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +183 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -194
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/dist/dsl/CodeGen.js
CHANGED
|
@@ -2,16 +2,33 @@
|
|
|
2
2
|
* Code generation for GradientScript DSL
|
|
3
3
|
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
4
|
*/
|
|
5
|
-
import { simplifyGradients } from './Simplify.js';
|
|
6
|
-
import {
|
|
5
|
+
import { simplifyGradients, simplifyPostCSE } from './Simplify.js';
|
|
6
|
+
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
7
|
+
import { optimizeWithEGraph } from './egraph/index.js';
|
|
7
8
|
import { CodeGenError } from './Errors.js';
|
|
9
|
+
import { serializeExpression } from './ExpressionUtils.js';
|
|
10
|
+
import { inlineExpression } from './Inliner.js';
|
|
11
|
+
/**
 * Upper-case the first character of `str`, leaving the remainder untouched
 * (e.g. 'dx' -> 'Dx'). An empty string comes back unchanged.
 *
 * @param {string} str - identifier to capitalize.
 * @returns {string} the capitalized identifier.
 */
function capitalize(str) {
    const first = str.charAt(0).toUpperCase();
    return `${first}${str.slice(1)}`;
}
|
|
14
|
+
/**
 * Decide whether a forward-pass expression should be registered for reuse.
 *
 * Literal numbers and bare variable references are excluded (presumably
 * because aliasing them saves no computation); every other node kind,
 * including unknown/missing kinds, is tracked.
 *
 * @param {{kind: string}} expr - AST node.
 * @returns {boolean} true when the expression should be tracked.
 */
function shouldTrackForForwardReuse(expr) {
    const { kind } = expr;
    return kind !== 'number' && kind !== 'variable';
}
|
|
8
23
|
/**
|
|
9
24
|
* Code generator for expressions
|
|
10
25
|
*/
|
|
11
26
|
export class ExpressionCodeGen {
|
|
12
27
|
format;
|
|
13
|
-
|
|
28
|
+
csharpFloatType;
|
|
29
|
+
constructor(format = 'typescript', csharpFloatType = 'float') {
|
|
14
30
|
this.format = format;
|
|
31
|
+
this.csharpFloatType = csharpFloatType;
|
|
15
32
|
}
|
|
16
33
|
/**
|
|
17
34
|
* Generate code for an expression
|
|
@@ -47,7 +64,7 @@ export class ExpressionCodeGen {
|
|
|
47
64
|
if (this.format === 'python' && (op === '^' || op === '**')) {
|
|
48
65
|
op = '**'; // Python uses **
|
|
49
66
|
}
|
|
50
|
-
else if ((this.format === 'typescript' || this.format === 'javascript') && (op === '^' || op === '**')) {
|
|
67
|
+
else if ((this.format === 'typescript' || this.format === 'javascript' || this.format === 'csharp') && (op === '^' || op === '**')) {
|
|
51
68
|
// Optimize: x^2 -> x*x, x^3 -> x*x*x (faster than Math.pow)
|
|
52
69
|
// Only for simple expressions (variables, component access)
|
|
53
70
|
const isSimple = expr.left.kind === 'variable' ||
|
|
@@ -63,15 +80,23 @@ export class ExpressionCodeGen {
|
|
|
63
80
|
return left;
|
|
64
81
|
}
|
|
65
82
|
else if (exponent === 2) {
|
|
66
|
-
|
|
83
|
+
// Wrap in parens because we're changing ^ (precedence 3) to * (precedence 2)
|
|
84
|
+
// Without parens, a / b^2 would become a / b * b instead of a / (b * b)
|
|
85
|
+
return `(${left} * ${left})`;
|
|
67
86
|
}
|
|
68
87
|
else if (exponent === 3) {
|
|
69
|
-
return
|
|
88
|
+
return `(${left} * ${left} * ${left})`;
|
|
70
89
|
}
|
|
71
90
|
}
|
|
72
91
|
}
|
|
73
|
-
// Fall back to Math.pow for complex expressions or larger exponents
|
|
74
|
-
|
|
92
|
+
// Fall back to Math.pow / MathF.Pow for complex expressions or larger exponents
|
|
93
|
+
if (this.format === 'csharp') {
|
|
94
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
95
|
+
return `${mathClass}.Pow(${left}, ${right})`;
|
|
96
|
+
}
|
|
97
|
+
else {
|
|
98
|
+
return `Math.pow(${left}, ${right})`;
|
|
99
|
+
}
|
|
75
100
|
}
|
|
76
101
|
return `${left} ${op} ${right}`;
|
|
77
102
|
}
|
|
@@ -160,6 +185,10 @@ export class ExpressionCodeGen {
|
|
|
160
185
|
else if (this.format === 'python') {
|
|
161
186
|
return `max(${min}, min(${max}, ${x}))`;
|
|
162
187
|
}
|
|
188
|
+
else if (this.format === 'csharp') {
|
|
189
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
190
|
+
return `${mathClass}.Max(${min}, ${mathClass}.Min(${max}, ${x}))`;
|
|
191
|
+
}
|
|
163
192
|
}
|
|
164
193
|
// Map function names for different formats
|
|
165
194
|
const funcName = this.mapFunctionName(expr.name);
|
|
@@ -167,63 +196,91 @@ export class ExpressionCodeGen {
|
|
|
167
196
|
}
|
|
168
197
|
genComponent(expr) {
|
|
169
198
|
const obj = this.generate(expr.object);
|
|
199
|
+
if (this.format === 'csharp') {
|
|
200
|
+
// C# uses PascalCase for properties
|
|
201
|
+
return `${obj}.${capitalize(expr.component)}`;
|
|
202
|
+
}
|
|
170
203
|
return `${obj}.${expr.component}`;
|
|
171
204
|
}
|
|
205
|
+
// Math functions that should be mapped across all formats
|
|
206
|
+
static MATH_FUNCTIONS = [
|
|
207
|
+
'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
|
|
208
|
+
'exp', 'log', 'sqrt', 'abs', 'pow', 'min', 'max'
|
|
209
|
+
];
|
|
210
|
+
// Python built-in functions that don't need the math. prefix
|
|
211
|
+
static PYTHON_BUILTINS = ['abs', 'pow', 'min', 'max'];
|
|
172
212
|
mapFunctionName(name) {
|
|
173
|
-
if
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
'
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
213
|
+
// Check if this is a known math function
|
|
214
|
+
if (!ExpressionCodeGen.MATH_FUNCTIONS.includes(name)) {
|
|
215
|
+
return name;
|
|
216
|
+
}
|
|
217
|
+
// Define format-specific mappers
|
|
218
|
+
const mappers = {
|
|
219
|
+
typescript: (fn) => `Math.${fn}`,
|
|
220
|
+
javascript: (fn) => `Math.${fn}`,
|
|
221
|
+
python: (fn) => ExpressionCodeGen.PYTHON_BUILTINS.includes(fn) ? fn : `math.${fn}`,
|
|
222
|
+
csharp: (fn) => {
|
|
223
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
224
|
+
const capitalized = fn.charAt(0).toUpperCase() + fn.slice(1);
|
|
225
|
+
return `${mathClass}.${capitalized}`;
|
|
226
|
+
}
|
|
227
|
+
};
|
|
228
|
+
const mapper = mappers[this.format];
|
|
229
|
+
return mapper ? mapper(name) : name;
|
|
230
|
+
}
|
|
231
|
+
}
|
|
232
|
+
/**
 * Generate the C# struct declarations used as a gradient function's return type.
 *
 * Emits `<Func>GradResult` holding a `Value` field plus one `D<param>` field
 * per entry of `gradients.gradients`. After it comes one `<Param>Grad`
 * struct (one field per component) for every structured gradient, in the
 * same order the parameters appear in the map.
 *
 * @param {object} func - function node; only `func.name` is read here.
 * @param {object} gradients - object whose `gradients` Map is keyed by parameter name.
 * @param {string} floatType - C# scalar type used for every numeric field.
 * @returns {string[]} generated C# source lines.
 */
function generateCSharpGradientStruct(func, gradients, floatType) {
    const lines = [];
    const structName = `${capitalize(func.name)}GradResult`;
    // Nested struct types to emit after the result struct (collected in the
    // same pass so each gradient is classified only once).
    const nestedStructs = [];
    lines.push(`public struct ${structName}`);
    lines.push('{');
    lines.push(` public ${floatType} Value;`);
    for (const [paramName, gradient] of gradients.gradients.entries()) {
        const propName = capitalize(`d${paramName}`);
        if (isStructuredGradient(gradient)) {
            const gradTypeName = `${capitalize(paramName)}Grad`;
            lines.push(` public ${gradTypeName} ${propName};`);
            nestedStructs.push({ gradTypeName, components: gradient.components });
        }
        else {
            lines.push(` public ${floatType} ${propName};`);
        }
    }
    lines.push('}');
    for (const { gradTypeName, components } of nestedStructs) {
        lines.push('');
        lines.push(`public struct ${gradTypeName}`);
        lines.push('{');
        for (const comp of components.keys()) {
            lines.push(` public ${floatType} ${capitalize(comp)};`);
        }
        lines.push('}');
    }
    return lines;
}
|
|
214
267
|
/**
|
|
215
268
|
* Generate complete gradient function code
|
|
216
269
|
*/
|
|
217
270
|
export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
218
271
|
const format = options.format || 'typescript';
|
|
272
|
+
const csharpFloatType = options.csharpFloatType || 'float';
|
|
219
273
|
const includeComments = options.includeComments !== false;
|
|
220
274
|
const shouldSimplify = options.simplify !== false; // Default to true
|
|
221
|
-
//
|
|
222
|
-
const gradientsToUse =
|
|
223
|
-
|
|
224
|
-
: gradients;
|
|
225
|
-
const codegen = new ExpressionCodeGen(format);
|
|
275
|
+
// Note: We don't simplify here yet - we'll do it after forward expression substitution
|
|
276
|
+
const gradientsToUse = gradients;
|
|
277
|
+
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
226
278
|
const lines = [];
|
|
279
|
+
// For C#, we need to generate a struct for the return type first
|
|
280
|
+
if (format === 'csharp') {
|
|
281
|
+
lines.push(...generateCSharpGradientStruct(func, gradientsToUse, csharpFloatType));
|
|
282
|
+
lines.push('');
|
|
283
|
+
}
|
|
227
284
|
// Function signature
|
|
228
285
|
const paramNames = func.parameters.map(p => p.name).join(', ');
|
|
229
286
|
if (format === 'typescript' || format === 'javascript') {
|
|
@@ -232,44 +289,246 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
232
289
|
else if (format === 'python') {
|
|
233
290
|
lines.push(`def ${func.name}_grad(${paramNames}):`);
|
|
234
291
|
}
|
|
292
|
+
else if (format === 'csharp') {
|
|
293
|
+
const params = func.parameters.map(p => {
|
|
294
|
+
if (p.paramType && p.paramType.components) {
|
|
295
|
+
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
296
|
+
}
|
|
297
|
+
return `${csharpFloatType} ${p.name}`;
|
|
298
|
+
}).join(', ');
|
|
299
|
+
const returnType = `${capitalize(func.name)}GradResult`;
|
|
300
|
+
lines.push(`public static ${returnType} ${capitalize(func.name)}_Grad(${params})`);
|
|
301
|
+
lines.push('{');
|
|
302
|
+
}
|
|
235
303
|
// Forward pass - compute intermediate variables
|
|
236
304
|
// Track which expressions are already computed for CSE reuse
|
|
237
|
-
const
|
|
305
|
+
const forwardExpressionMap = new Map();
|
|
306
|
+
// Build substitution map for inlining (to match gradient expressions which are fully inlined)
|
|
307
|
+
const substitutionMap = new Map();
|
|
308
|
+
for (const stmt of func.body) {
|
|
309
|
+
if (stmt.kind === 'assignment') {
|
|
310
|
+
substitutionMap.set(stmt.variable, stmt.expression);
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
// Collect forward variable names and expressions for optimization
|
|
314
|
+
const forwardVars = new Set();
|
|
315
|
+
const forwardVarExprs = new Map();
|
|
316
|
+
for (const stmt of func.body) {
|
|
317
|
+
if (stmt.kind === 'assignment') {
|
|
318
|
+
forwardVars.add(stmt.variable);
|
|
319
|
+
forwardVarExprs.set(stmt.variable, stmt.expression);
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
// Optimize forward expressions with e-graph
|
|
323
|
+
let optimizedForwardExprs = forwardVarExprs;
|
|
324
|
+
let forwardCseTemps = new Map();
|
|
325
|
+
if (options.cse !== false && forwardVarExprs.size > 0) {
|
|
326
|
+
const forOptimizer = new Map();
|
|
327
|
+
forOptimizer.set('_forward', forwardVarExprs);
|
|
328
|
+
const forwardResult = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
329
|
+
// Rename forward temps to avoid conflicts with gradient temps (use _fwd prefix)
|
|
330
|
+
const rawTemps = forwardResult.intermediates;
|
|
331
|
+
const renameMap = new Map();
|
|
332
|
+
for (const oldName of rawTemps.keys()) {
|
|
333
|
+
const newName = oldName.replace('_tmp', '_fwd');
|
|
334
|
+
renameMap.set(oldName, newName);
|
|
335
|
+
}
|
|
336
|
+
// Apply renaming to temp definitions
|
|
337
|
+
for (const [oldName, expr] of rawTemps) {
|
|
338
|
+
const newName = renameMap.get(oldName);
|
|
339
|
+
forwardCseTemps.set(newName, renameTempRefs(expr, renameMap));
|
|
340
|
+
}
|
|
341
|
+
// Apply renaming to optimized expressions
|
|
342
|
+
optimizedForwardExprs = new Map();
|
|
343
|
+
for (const [varName, expr] of (forwardResult.gradients.get('_forward') || forwardVarExprs)) {
|
|
344
|
+
optimizedForwardExprs.set(varName, renameTempRefs(expr, renameMap));
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
// Helper: rebuild `expr`, replacing each variable whose name is a key of
// `renameMap` with a variable carrying the mapped name.
// Binary/unary/call/component nodes are rebuilt with only their structural
// fields; unrenamed variables and unrecognized node kinds are returned by
// reference. Kept as a function declaration because the enclosing function
// calls it before this point and relies on hoisting.
function renameTempRefs(expr, renameMap) {
    const recurse = (child) => renameTempRefs(child, renameMap);
    switch (expr.kind) {
        case 'variable': {
            const renamed = renameMap.get(expr.name);
            if (!renamed) {
                return expr;
            }
            return { kind: 'variable', name: renamed };
        }
        case 'binary':
            return {
                kind: 'binary',
                operator: expr.operator,
                left: recurse(expr.left),
                right: recurse(expr.right),
            };
        case 'unary':
            return {
                kind: 'unary',
                operator: expr.operator,
                operand: recurse(expr.operand),
            };
        case 'call':
            return { kind: 'call', name: expr.name, args: expr.args.map(recurse) };
        case 'component':
            return {
                kind: 'component',
                object: recurse(expr.object),
                component: expr.component,
            };
        default:
            return expr;
    }
}
|
|
384
|
+
// Helper: collect the names of forward-pass variables (members of the
// enclosing `forwardVars` set) referenced anywhere inside `expr`.
// Descends through binary, unary, call and component nodes; names are
// recorded in the same left-to-right pre-order a recursive walk would use.
function findForwardVarDeps(expr) {
    const deps = new Set();
    const pending = [expr];
    while (pending.length > 0) {
        const node = pending.pop();
        switch (node.kind) {
            case 'variable':
                if (forwardVars.has(node.name)) {
                    deps.add(node.name);
                }
                break;
            case 'binary':
                // Push right first so that left is visited first.
                pending.push(node.right, node.left);
                break;
            case 'unary':
                pending.push(node.operand);
                break;
            case 'call':
                node.args.slice().reverse().forEach((arg) => pending.push(arg));
                break;
            case 'component':
                pending.push(node.object);
                break;
            default:
                break;
        }
    }
    return deps;
}
|
|
408
|
+
// Track which CSE temps need to be emitted after which forward var
|
|
409
|
+
const fwdTempAfterVar = new Map();
|
|
410
|
+
const fwdTempsBeforeAny = [];
|
|
411
|
+
for (const [tempName, tempExpr] of forwardCseTemps) {
|
|
412
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
413
|
+
if (deps.size === 0) {
|
|
414
|
+
fwdTempsBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
415
|
+
}
|
|
416
|
+
else {
|
|
417
|
+
let lastDep = '';
|
|
418
|
+
for (const stmt of func.body) {
|
|
419
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
420
|
+
lastDep = stmt.variable;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
if (lastDep) {
|
|
424
|
+
if (!fwdTempAfterVar.has(lastDep)) {
|
|
425
|
+
fwdTempAfterVar.set(lastDep, []);
|
|
426
|
+
}
|
|
427
|
+
fwdTempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
428
|
+
}
|
|
429
|
+
}
|
|
430
|
+
}
|
|
431
|
+
// Emit forward CSE temps that don't depend on forward vars
|
|
432
|
+
for (const { name: tempName, expr } of fwdTempsBeforeAny) {
|
|
433
|
+
const code = codegen.generate(expr);
|
|
434
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
435
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
436
|
+
}
|
|
437
|
+
else if (format === 'python') {
|
|
438
|
+
lines.push(` ${tempName} = ${code}`);
|
|
439
|
+
}
|
|
440
|
+
else if (format === 'csharp') {
|
|
441
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
// Generate forward variable assignments with interleaved temps
|
|
238
445
|
for (const stmt of func.body) {
|
|
239
446
|
if (stmt.kind === 'assignment') {
|
|
240
447
|
const varName = stmt.variable;
|
|
241
|
-
const expr =
|
|
242
|
-
|
|
243
|
-
|
|
448
|
+
const expr = optimizedForwardExprs.get(varName) || stmt.expression;
|
|
449
|
+
const generatedExpr = codegen.generate(expr);
|
|
450
|
+
if (shouldTrackForForwardReuse(stmt.expression)) {
|
|
451
|
+
// Register the original expression
|
|
452
|
+
const exprKey = serializeExpression(stmt.expression);
|
|
453
|
+
if (!forwardExpressionMap.has(exprKey)) {
|
|
454
|
+
forwardExpressionMap.set(exprKey, varName);
|
|
455
|
+
}
|
|
456
|
+
// Also register the fully inlined form (this is what gradient expressions will have)
|
|
457
|
+
const inlinedExpr = inlineExpression(stmt.expression, substitutionMap);
|
|
458
|
+
const inlinedKey = serializeExpression(inlinedExpr);
|
|
459
|
+
if (inlinedKey !== exprKey && !forwardExpressionMap.has(inlinedKey)) {
|
|
460
|
+
forwardExpressionMap.set(inlinedKey, varName);
|
|
461
|
+
}
|
|
462
|
+
}
|
|
244
463
|
if (format === 'typescript' || format === 'javascript') {
|
|
245
|
-
lines.push(` const ${varName} = ${
|
|
464
|
+
lines.push(` const ${varName} = ${generatedExpr};`);
|
|
246
465
|
}
|
|
247
|
-
else {
|
|
248
|
-
lines.push(` ${varName} = ${
|
|
466
|
+
else if (format === 'python') {
|
|
467
|
+
lines.push(` ${varName} = ${generatedExpr}`);
|
|
468
|
+
}
|
|
469
|
+
else if (format === 'csharp') {
|
|
470
|
+
lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
|
|
471
|
+
}
|
|
472
|
+
// Emit any forward CSE temps that depend on this var
|
|
473
|
+
const tempsForVar = fwdTempAfterVar.get(varName) || [];
|
|
474
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
475
|
+
const tempCode = codegen.generate(tempExpr);
|
|
476
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
477
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
478
|
+
}
|
|
479
|
+
else if (format === 'python') {
|
|
480
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
481
|
+
}
|
|
482
|
+
else if (format === 'csharp') {
|
|
483
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
484
|
+
}
|
|
249
485
|
}
|
|
250
486
|
}
|
|
251
487
|
}
|
|
252
488
|
// Compute output value - reuse forward pass variables if possible
|
|
253
|
-
|
|
489
|
+
const valueExpr = func.returnExpr;
|
|
490
|
+
const valueKey = serializeExpression(valueExpr);
|
|
491
|
+
const existingVar = forwardExpressionMap.get(valueKey);
|
|
254
492
|
const valueCode = codegen.generate(valueExpr);
|
|
255
|
-
const existingVar = forwardPassVars.get(valueCode);
|
|
256
493
|
if (existingVar) {
|
|
257
494
|
// Reuse existing variable
|
|
258
495
|
if (format === 'typescript' || format === 'javascript') {
|
|
259
496
|
lines.push(` const value = ${existingVar};`);
|
|
260
497
|
}
|
|
261
|
-
else {
|
|
498
|
+
else if (format === 'python') {
|
|
262
499
|
lines.push(` value = ${existingVar}`);
|
|
263
500
|
}
|
|
501
|
+
else if (format === 'csharp') {
|
|
502
|
+
lines.push(` ${csharpFloatType} value = ${existingVar};`);
|
|
503
|
+
}
|
|
264
504
|
}
|
|
265
505
|
else {
|
|
266
506
|
// Compute new value
|
|
267
507
|
if (format === 'typescript' || format === 'javascript') {
|
|
268
508
|
lines.push(` const value = ${valueCode};`);
|
|
269
509
|
}
|
|
270
|
-
else {
|
|
510
|
+
else if (format === 'python') {
|
|
271
511
|
lines.push(` value = ${valueCode}`);
|
|
272
512
|
}
|
|
513
|
+
else if (format === 'csharp') {
|
|
514
|
+
lines.push(` ${csharpFloatType} value = ${valueCode};`);
|
|
515
|
+
}
|
|
516
|
+
if (shouldTrackForForwardReuse(valueExpr) && !forwardExpressionMap.has(valueKey)) {
|
|
517
|
+
forwardExpressionMap.set(valueKey, 'value');
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
// Apply forward expression substitution multiple times until no more changes
|
|
521
|
+
// This handles nested expressions like sqrt(dx*dx + dy*dy) where dx = pix - pjx
|
|
522
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
523
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
524
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
525
|
+
// Simplify gradients after forward expression substitution
|
|
526
|
+
if (shouldSimplify) {
|
|
527
|
+
const simplified = simplifyGradients(gradientsToUse.gradients);
|
|
528
|
+
gradientsToUse.gradients.clear();
|
|
529
|
+
for (const [key, value] of simplified.entries()) {
|
|
530
|
+
gradientsToUse.gradients.set(key, value);
|
|
531
|
+
}
|
|
273
532
|
}
|
|
274
533
|
lines.push('');
|
|
275
534
|
// Generate gradients
|
|
@@ -277,74 +536,71 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
277
536
|
if (includeComments) {
|
|
278
537
|
lines.push(` ${comment} Gradients`);
|
|
279
538
|
}
|
|
280
|
-
// Apply CSE
|
|
539
|
+
// Apply e-graph optimization (CSE + algebraic simplification)
|
|
281
540
|
const shouldApplyCSE = options.cse !== false; // Default to true
|
|
282
|
-
|
|
541
|
+
let cseIntermediates = new Map();
|
|
283
542
|
if (shouldApplyCSE) {
|
|
284
|
-
// Collect all gradient
|
|
543
|
+
// Collect all gradient components into a single map for global optimization
|
|
544
|
+
// Include both structured and scalar gradients (scalar uses 'value' as component)
|
|
545
|
+
const allGradientComponents = new Map();
|
|
285
546
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
286
547
|
if (isStructuredGradient(gradient)) {
|
|
287
|
-
|
|
288
|
-
// Merge intermediates
|
|
289
|
-
for (const [name, expr] of cseResult.intermediates.entries()) {
|
|
290
|
-
cseIntermediates.set(name, expr);
|
|
291
|
-
}
|
|
292
|
-
// Update gradient components with CSE-simplified versions
|
|
293
|
-
gradient.components = cseResult.components;
|
|
294
|
-
}
|
|
295
|
-
}
|
|
296
|
-
// Generate intermediate variables from CSE
|
|
297
|
-
if (cseIntermediates.size > 0) {
|
|
298
|
-
// Check if we should emit guards (opt-in)
|
|
299
|
-
const shouldEmitGuards = options.emitGuards === true;
|
|
300
|
-
const epsilon = options.epsilon || 1e-10;
|
|
301
|
-
// Identify potential denominators (sum of squares patterns)
|
|
302
|
-
const denominatorVars = new Set();
|
|
303
|
-
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
304
|
-
const code = codegen.generate(expr);
|
|
305
|
-
// Check if this looks like a denominator (contains + and squared terms)
|
|
306
|
-
if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
|
|
307
|
-
denominatorVars.add(varName);
|
|
308
|
-
}
|
|
548
|
+
allGradientComponents.set(paramName, gradient.components);
|
|
309
549
|
}
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
550
|
+
else {
|
|
551
|
+
// Scalar gradient - wrap as a single 'value' component
|
|
552
|
+
allGradientComponents.set(paramName, new Map([['value', gradient]]));
|
|
553
|
+
}
|
|
554
|
+
}
|
|
555
|
+
// Run e-graph optimization globally across ALL gradient expressions
|
|
556
|
+
const globalCSE = optimizeWithEGraph(allGradientComponents, { verbose: false });
|
|
557
|
+
cseIntermediates = globalCSE.intermediates;
|
|
558
|
+
// Update gradient components with optimized versions
|
|
559
|
+
for (const [paramName, simplifiedComponents] of globalCSE.gradients.entries()) {
|
|
560
|
+
const gradient = gradientsToUse.gradients.get(paramName);
|
|
561
|
+
if (gradient && isStructuredGradient(gradient)) {
|
|
562
|
+
gradient.components = simplifiedComponents;
|
|
318
563
|
}
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
if (
|
|
323
|
-
|
|
564
|
+
else {
|
|
565
|
+
// Scalar gradient - unwrap from 'value' component
|
|
566
|
+
const optimizedExpr = simplifiedComponents.get('value');
|
|
567
|
+
if (optimizedExpr) {
|
|
568
|
+
gradientsToUse.gradients.set(paramName, optimizedExpr);
|
|
324
569
|
}
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
}
|
|
337
|
-
else {
|
|
338
|
-
zeroGrads.push(`d${paramName}: 0`);
|
|
339
|
-
}
|
|
340
|
-
}
|
|
341
|
-
lines.push(` return { value, ${zeroGrads.join(', ')} };`);
|
|
342
|
-
lines.push(` }`);
|
|
343
|
-
}
|
|
570
|
+
}
|
|
571
|
+
}
|
|
572
|
+
// Post-CSE simplification: apply rules that were skipped to avoid CSE interference
|
|
573
|
+
// Specifically: a + a → 2 * a (now safe because temps have been extracted)
|
|
574
|
+
for (const [varName, expr] of cseIntermediates) {
|
|
575
|
+
cseIntermediates.set(varName, simplifyPostCSE(expr));
|
|
576
|
+
}
|
|
577
|
+
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
578
|
+
if (isStructuredGradient(gradient)) {
|
|
579
|
+
for (const [comp, expr] of gradient.components.entries()) {
|
|
580
|
+
gradient.components.set(comp, simplifyPostCSE(expr));
|
|
344
581
|
}
|
|
345
582
|
}
|
|
346
|
-
|
|
583
|
+
else {
|
|
584
|
+
// Scalar gradient - apply post-CSE simplification
|
|
585
|
+
gradientsToUse.gradients.set(paramName, simplifyPostCSE(gradient));
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
}
|
|
589
|
+
// Generate CSE intermediate variables
|
|
590
|
+
if (cseIntermediates.size > 0) {
|
|
591
|
+
for (const [varName, expr] of cseIntermediates.entries()) {
|
|
592
|
+
const code = codegen.generate(expr);
|
|
593
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
594
|
+
lines.push(` const ${varName} = ${code};`);
|
|
595
|
+
}
|
|
596
|
+
else if (format === 'python') {
|
|
597
|
+
lines.push(` ${varName} = ${code}`);
|
|
598
|
+
}
|
|
599
|
+
else if (format === 'csharp') {
|
|
600
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
601
|
+
}
|
|
347
602
|
}
|
|
603
|
+
lines.push('');
|
|
348
604
|
}
|
|
349
605
|
for (const [paramName, gradient] of gradientsToUse.gradients.entries()) {
|
|
350
606
|
// Use shorter names: du, dv instead of grad_u, grad_v
|
|
@@ -366,7 +622,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
366
622
|
}
|
|
367
623
|
lines.push(` };`);
|
|
368
624
|
}
|
|
369
|
-
else {
|
|
625
|
+
else if (format === 'python') {
|
|
370
626
|
lines.push(` ${gradName} = {`);
|
|
371
627
|
for (const comp of components) {
|
|
372
628
|
const [key, value] = comp.split(': ');
|
|
@@ -374,6 +630,15 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
374
630
|
}
|
|
375
631
|
lines.push(` }`);
|
|
376
632
|
}
|
|
633
|
+
else if (format === 'csharp') {
|
|
634
|
+
lines.push(` var ${gradName} = new ${capitalize(paramName)}Grad`);
|
|
635
|
+
lines.push(` {`);
|
|
636
|
+
for (const comp of components) {
|
|
637
|
+
const [key, value] = comp.split(': ');
|
|
638
|
+
lines.push(` ${capitalize(key)} = ${value},`);
|
|
639
|
+
}
|
|
640
|
+
lines.push(` };`);
|
|
641
|
+
}
|
|
377
642
|
}
|
|
378
643
|
else {
|
|
379
644
|
// Scalar gradient
|
|
@@ -381,9 +646,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
381
646
|
if (format === 'typescript' || format === 'javascript') {
|
|
382
647
|
lines.push(` const ${gradName} = ${code};`);
|
|
383
648
|
}
|
|
384
|
-
else {
|
|
649
|
+
else if (format === 'python') {
|
|
385
650
|
lines.push(` ${gradName} = ${code}`);
|
|
386
651
|
}
|
|
652
|
+
else if (format === 'csharp') {
|
|
653
|
+
lines.push(` ${csharpFloatType} ${gradName} = ${code};`);
|
|
654
|
+
}
|
|
387
655
|
}
|
|
388
656
|
}
|
|
389
657
|
lines.push('');
|
|
@@ -399,7 +667,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
399
667
|
lines.push(` };`);
|
|
400
668
|
lines.push('}');
|
|
401
669
|
}
|
|
402
|
-
else {
|
|
670
|
+
else if (format === 'python') {
|
|
403
671
|
lines.push(` return {`);
|
|
404
672
|
lines.push(` "value": value,`);
|
|
405
673
|
for (const gradName of gradNames) {
|
|
@@ -407,33 +675,160 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
407
675
|
}
|
|
408
676
|
lines.push(` }`);
|
|
409
677
|
}
|
|
678
|
+
else if (format === 'csharp') {
|
|
679
|
+
lines.push(` return new ${capitalize(func.name)}GradResult`);
|
|
680
|
+
lines.push(` {`);
|
|
681
|
+
lines.push(` Value = value,`);
|
|
682
|
+
for (const gradName of gradNames) {
|
|
683
|
+
lines.push(` ${capitalize(gradName)} = ${gradName},`);
|
|
684
|
+
}
|
|
685
|
+
lines.push(` };`);
|
|
686
|
+
lines.push('}');
|
|
687
|
+
}
|
|
410
688
|
return lines.join('\n');
|
|
411
689
|
}
|
|
412
690
|
/**
|
|
413
|
-
* Generate the original forward function
|
|
691
|
+
* Generate the original forward function (with optional e-graph optimization)
|
|
414
692
|
*/
|
|
415
693
|
export function generateForwardFunction(func, options = {}) {
|
|
416
694
|
const format = options.format || 'typescript';
|
|
417
|
-
const
|
|
695
|
+
const csharpFloatType = options.csharpFloatType || 'float';
|
|
696
|
+
const shouldOptimize = options.cse !== false; // Optimize by default
|
|
697
|
+
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
418
698
|
const lines = [];
|
|
419
699
|
// Function signature
|
|
420
700
|
const paramNames = func.parameters.map(p => p.name).join(', ');
|
|
421
701
|
if (format === 'typescript' || format === 'javascript') {
|
|
422
702
|
lines.push(`function ${func.name}(${paramNames}) {`);
|
|
423
703
|
}
|
|
424
|
-
else {
|
|
704
|
+
else if (format === 'python') {
|
|
425
705
|
lines.push(`def ${func.name}(${paramNames}):`);
|
|
426
706
|
}
|
|
427
|
-
|
|
707
|
+
else if (format === 'csharp') {
|
|
708
|
+
const floatType = csharpFloatType;
|
|
709
|
+
const params = func.parameters.map(p => {
|
|
710
|
+
if (p.paramType && p.paramType.components) {
|
|
711
|
+
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
712
|
+
}
|
|
713
|
+
return `${floatType} ${p.name}`;
|
|
714
|
+
}).join(', ');
|
|
715
|
+
lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
|
|
716
|
+
lines.push('{');
|
|
717
|
+
}
|
|
718
|
+
// Collect all forward variable names (for dependency tracking)
|
|
719
|
+
const forwardVars = new Set();
|
|
720
|
+
for (const stmt of func.body) {
|
|
721
|
+
if (stmt.kind === 'assignment') {
|
|
722
|
+
forwardVars.add(stmt.variable);
|
|
723
|
+
}
|
|
724
|
+
}
|
|
725
|
+
// Collect expressions for optimization
|
|
726
|
+
const varExpressions = new Map();
|
|
727
|
+
for (const stmt of func.body) {
|
|
728
|
+
if (stmt.kind === 'assignment') {
|
|
729
|
+
varExpressions.set(stmt.variable, stmt.expression);
|
|
730
|
+
}
|
|
731
|
+
}
|
|
732
|
+
// Optimize with e-graph if enabled
|
|
733
|
+
let optimizedExprs = varExpressions;
|
|
734
|
+
let cseTemps = new Map();
|
|
735
|
+
if (shouldOptimize && varExpressions.size > 0) {
|
|
736
|
+
const forOptimizer = new Map();
|
|
737
|
+
forOptimizer.set('_forward', varExpressions);
|
|
738
|
+
const result = optimizeWithEGraph(forOptimizer, { verbose: false });
|
|
739
|
+
cseTemps = result.intermediates;
|
|
740
|
+
optimizedExprs = result.gradients.get('_forward') || varExpressions;
|
|
741
|
+
}
|
|
742
|
+
// Helper to find which forward vars an expression depends on
|
|
743
|
+
function findForwardVarDeps(expr) {
|
|
744
|
+
const deps = new Set();
|
|
745
|
+
function visit(e) {
|
|
746
|
+
if (e.kind === 'variable' && forwardVars.has(e.name)) {
|
|
747
|
+
deps.add(e.name);
|
|
748
|
+
}
|
|
749
|
+
else if (e.kind === 'binary') {
|
|
750
|
+
visit(e.left);
|
|
751
|
+
visit(e.right);
|
|
752
|
+
}
|
|
753
|
+
else if (e.kind === 'unary') {
|
|
754
|
+
visit(e.operand);
|
|
755
|
+
}
|
|
756
|
+
else if (e.kind === 'call') {
|
|
757
|
+
e.args.forEach(visit);
|
|
758
|
+
}
|
|
759
|
+
else if (e.kind === 'component') {
|
|
760
|
+
visit(e.object);
|
|
761
|
+
}
|
|
762
|
+
}
|
|
763
|
+
visit(expr);
|
|
764
|
+
return deps;
|
|
765
|
+
}
|
|
766
|
+
// Track which temps need to be emitted after which forward var
|
|
767
|
+
const tempAfterVar = new Map();
|
|
768
|
+
const tempsEmittedBeforeAny = [];
|
|
769
|
+
for (const [tempName, tempExpr] of cseTemps) {
|
|
770
|
+
const deps = findForwardVarDeps(tempExpr);
|
|
771
|
+
if (deps.size === 0) {
|
|
772
|
+
// Temp only depends on params - emit before any forward vars
|
|
773
|
+
tempsEmittedBeforeAny.push({ name: tempName, expr: tempExpr });
|
|
774
|
+
}
|
|
775
|
+
else {
|
|
776
|
+
// Find the last forward var this temp depends on
|
|
777
|
+
let lastDep = '';
|
|
778
|
+
for (const stmt of func.body) {
|
|
779
|
+
if (stmt.kind === 'assignment' && deps.has(stmt.variable)) {
|
|
780
|
+
lastDep = stmt.variable;
|
|
781
|
+
}
|
|
782
|
+
}
|
|
783
|
+
if (lastDep) {
|
|
784
|
+
if (!tempAfterVar.has(lastDep)) {
|
|
785
|
+
tempAfterVar.set(lastDep, []);
|
|
786
|
+
}
|
|
787
|
+
tempAfterVar.get(lastDep).push({ name: tempName, expr: tempExpr });
|
|
788
|
+
}
|
|
789
|
+
}
|
|
790
|
+
}
|
|
791
|
+
// Emit temps that don't depend on forward vars
|
|
792
|
+
for (const { name: tempName, expr } of tempsEmittedBeforeAny) {
|
|
793
|
+
const code = codegen.generate(expr);
|
|
794
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
795
|
+
lines.push(` const ${tempName} = ${code};`);
|
|
796
|
+
}
|
|
797
|
+
else if (format === 'python') {
|
|
798
|
+
lines.push(` ${tempName} = ${code}`);
|
|
799
|
+
}
|
|
800
|
+
else if (format === 'csharp') {
|
|
801
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${code};`);
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
// Generate variable assignments with interleaved temps
|
|
428
805
|
for (const stmt of func.body) {
|
|
429
806
|
if (stmt.kind === 'assignment') {
|
|
430
807
|
const varName = stmt.variable;
|
|
431
|
-
const expr =
|
|
808
|
+
const expr = optimizedExprs.get(varName) || stmt.expression;
|
|
809
|
+
const code = codegen.generate(expr);
|
|
432
810
|
if (format === 'typescript' || format === 'javascript') {
|
|
433
|
-
lines.push(` const ${varName} = ${
|
|
811
|
+
lines.push(` const ${varName} = ${code};`);
|
|
434
812
|
}
|
|
435
|
-
else {
|
|
436
|
-
lines.push(` ${varName} = ${
|
|
813
|
+
else if (format === 'python') {
|
|
814
|
+
lines.push(` ${varName} = ${code}`);
|
|
815
|
+
}
|
|
816
|
+
else if (format === 'csharp') {
|
|
817
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
818
|
+
}
|
|
819
|
+
// Emit any temps that depend on this var
|
|
820
|
+
const tempsForVar = tempAfterVar.get(varName) || [];
|
|
821
|
+
for (const { name: tempName, expr: tempExpr } of tempsForVar) {
|
|
822
|
+
const tempCode = codegen.generate(tempExpr);
|
|
823
|
+
if (format === 'typescript' || format === 'javascript') {
|
|
824
|
+
lines.push(` const ${tempName} = ${tempCode};`);
|
|
825
|
+
}
|
|
826
|
+
else if (format === 'python') {
|
|
827
|
+
lines.push(` ${tempName} = ${tempCode}`);
|
|
828
|
+
}
|
|
829
|
+
else if (format === 'csharp') {
|
|
830
|
+
lines.push(` ${csharpFloatType} ${tempName} = ${tempCode};`);
|
|
831
|
+
}
|
|
437
832
|
}
|
|
438
833
|
}
|
|
439
834
|
}
|
|
@@ -443,9 +838,31 @@ export function generateForwardFunction(func, options = {}) {
|
|
|
443
838
|
lines.push(` return ${returnExpr};`);
|
|
444
839
|
lines.push('}');
|
|
445
840
|
}
|
|
446
|
-
else {
|
|
841
|
+
else if (format === 'python') {
|
|
447
842
|
lines.push(` return ${returnExpr}`);
|
|
448
843
|
}
|
|
844
|
+
else if (format === 'csharp') {
|
|
845
|
+
lines.push(` return ${returnExpr};`);
|
|
846
|
+
lines.push('}');
|
|
847
|
+
}
|
|
848
|
+
// For C#, prepend struct definitions
|
|
849
|
+
if (format === 'csharp') {
|
|
850
|
+
const structLines = [];
|
|
851
|
+
for (const param of func.parameters) {
|
|
852
|
+
if (param.paramType && param.paramType.components) {
|
|
853
|
+
structLines.push(`public struct ${capitalize(param.name)}Struct`);
|
|
854
|
+
structLines.push('{');
|
|
855
|
+
for (const comp of param.paramType.components) {
|
|
856
|
+
structLines.push(` public ${csharpFloatType} ${capitalize(comp)};`);
|
|
857
|
+
}
|
|
858
|
+
structLines.push('}');
|
|
859
|
+
structLines.push('');
|
|
860
|
+
}
|
|
861
|
+
}
|
|
862
|
+
if (structLines.length > 0) {
|
|
863
|
+
return structLines.join('\n') + lines.join('\n');
|
|
864
|
+
}
|
|
865
|
+
}
|
|
449
866
|
return lines.join('\n');
|
|
450
867
|
}
|
|
451
868
|
/**
|
|
@@ -466,6 +883,40 @@ export function generateComplete(func, gradients, env, options = {}) {
|
|
|
466
883
|
lines.push(generateGradientFunction(func, gradients, env, options));
|
|
467
884
|
return lines.join('\n');
|
|
468
885
|
}
|
|
886
|
+
class ForwardExpressionSubstituter extends ExpressionTransformer {
|
|
887
|
+
forwardExpressions;
|
|
888
|
+
constructor(forwardExpressions) {
|
|
889
|
+
super();
|
|
890
|
+
this.forwardExpressions = forwardExpressions;
|
|
891
|
+
}
|
|
892
|
+
transform(expr) {
|
|
893
|
+
const key = serializeExpression(expr);
|
|
894
|
+
const varName = this.forwardExpressions.get(key);
|
|
895
|
+
if (varName) {
|
|
896
|
+
return {
|
|
897
|
+
kind: 'variable',
|
|
898
|
+
name: varName
|
|
899
|
+
};
|
|
900
|
+
}
|
|
901
|
+
return super.transform(expr);
|
|
902
|
+
}
|
|
903
|
+
}
|
|
904
|
+
function reuseForwardExpressionsInGradients(gradients, forwardExpressions) {
|
|
905
|
+
if (forwardExpressions.size === 0) {
|
|
906
|
+
return;
|
|
907
|
+
}
|
|
908
|
+
const substituter = new ForwardExpressionSubstituter(forwardExpressions);
|
|
909
|
+
for (const [paramName, gradient] of gradients.entries()) {
|
|
910
|
+
if (isStructuredGradient(gradient)) {
|
|
911
|
+
for (const [component, expr] of gradient.components.entries()) {
|
|
912
|
+
gradient.components.set(component, substituter.transform(expr));
|
|
913
|
+
}
|
|
914
|
+
}
|
|
915
|
+
else {
|
|
916
|
+
gradients.set(paramName, substituter.transform(gradient));
|
|
917
|
+
}
|
|
918
|
+
}
|
|
919
|
+
}
|
|
469
920
|
/**
|
|
470
921
|
* Type guard for StructuredGradient
|
|
471
922
|
*/
|