gradient-script 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +49 -8
- package/dist/cli.js +57 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CSE.js +5 -31
- package/dist/dsl/CodeGen.d.ts +7 -2
- package/dist/dsl/CodeGen.js +259 -66
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +8 -0
- package/dist/dsl/ExpressionUtils.js +24 -0
- package/dist/dsl/Guards.d.ts +2 -0
- package/dist/dsl/Guards.js +78 -36
- package/dist/dsl/Inliner.js +3 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.js +47 -0
- package/package.json +1 -1
package/dist/dsl/CodeGen.js
CHANGED
|
@@ -3,15 +3,31 @@
|
|
|
3
3
|
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
4
|
*/
|
|
5
5
|
import { simplifyGradients } from './Simplify.js';
|
|
6
|
+
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
6
7
|
import { eliminateCommonSubexpressionsStructured } from './CSE.js';
|
|
7
8
|
import { CodeGenError } from './Errors.js';
|
|
9
|
+
import { serializeExpression } from './ExpressionUtils.js';
|
|
10
|
+
/**
 * Upper-case the first UTF-16 code unit of `str`, leaving the remainder untouched.
 * Used throughout the C# backend to derive PascalCase identifiers.
 */
function capitalize(str) {
    const head = str.substring(0, 1);
    const tail = str.substring(1);
    return head.toUpperCase() + tail;
}
|
|
13
|
+
/**
 * Decide whether a forward-pass expression should be recorded for later reuse.
 * Numeric literals and plain variable references are leaves: replacing one with a
 * reference to another variable gains nothing, so only compound expressions count.
 */
function shouldTrackForForwardReuse(expr) {
    const isLeaf = expr.kind === 'number' || expr.kind === 'variable';
    return !isLeaf;
}
|
|
8
22
|
/**
|
|
9
23
|
* Code generator for expressions
|
|
10
24
|
*/
|
|
11
25
|
export class ExpressionCodeGen {
|
|
12
26
|
format;
|
|
13
|
-
|
|
27
|
+
csharpFloatType;
|
|
28
|
+
constructor(format = 'typescript', csharpFloatType = 'float') {
|
|
14
29
|
this.format = format;
|
|
30
|
+
this.csharpFloatType = csharpFloatType;
|
|
15
31
|
}
|
|
16
32
|
/**
|
|
17
33
|
* Generate code for an expression
|
|
@@ -47,7 +63,7 @@ export class ExpressionCodeGen {
|
|
|
47
63
|
if (this.format === 'python' && (op === '^' || op === '**')) {
|
|
48
64
|
op = '**'; // Python uses **
|
|
49
65
|
}
|
|
50
|
-
else if ((this.format === 'typescript' || this.format === 'javascript') && (op === '^' || op === '**')) {
|
|
66
|
+
else if ((this.format === 'typescript' || this.format === 'javascript' || this.format === 'csharp') && (op === '^' || op === '**')) {
|
|
51
67
|
// Optimize: x^2 -> x*x, x^3 -> x*x*x (faster than Math.pow)
|
|
52
68
|
// Only for simple expressions (variables, component access)
|
|
53
69
|
const isSimple = expr.left.kind === 'variable' ||
|
|
@@ -70,8 +86,14 @@ export class ExpressionCodeGen {
|
|
|
70
86
|
}
|
|
71
87
|
}
|
|
72
88
|
}
|
|
73
|
-
// Fall back to Math.pow for complex expressions or larger exponents
|
|
74
|
-
|
|
89
|
+
// Fall back to Math.pow / MathF.Pow for complex expressions or larger exponents
|
|
90
|
+
if (this.format === 'csharp') {
|
|
91
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
92
|
+
return `${mathClass}.Pow(${left}, ${right})`;
|
|
93
|
+
}
|
|
94
|
+
else {
|
|
95
|
+
return `Math.pow(${left}, ${right})`;
|
|
96
|
+
}
|
|
75
97
|
}
|
|
76
98
|
return `${left} ${op} ${right}`;
|
|
77
99
|
}
|
|
@@ -160,6 +182,10 @@ export class ExpressionCodeGen {
|
|
|
160
182
|
else if (this.format === 'python') {
|
|
161
183
|
return `max(${min}, min(${max}, ${x}))`;
|
|
162
184
|
}
|
|
185
|
+
else if (this.format === 'csharp') {
|
|
186
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
187
|
+
return `${mathClass}.Max(${min}, ${mathClass}.Min(${max}, ${x}))`;
|
|
188
|
+
}
|
|
163
189
|
}
|
|
164
190
|
// Map function names for different formats
|
|
165
191
|
const funcName = this.mapFunctionName(expr.name);
|
|
@@ -167,63 +193,91 @@ export class ExpressionCodeGen {
|
|
|
167
193
|
}
|
|
168
194
|
genComponent(expr) {
|
|
169
195
|
const obj = this.generate(expr.object);
|
|
196
|
+
if (this.format === 'csharp') {
|
|
197
|
+
// C# uses PascalCase for properties
|
|
198
|
+
return `${obj}.${capitalize(expr.component)}`;
|
|
199
|
+
}
|
|
170
200
|
return `${obj}.${expr.component}`;
|
|
171
201
|
}
|
|
202
|
+
// Math functions that should be mapped across all formats
|
|
203
|
+
static MATH_FUNCTIONS = [
|
|
204
|
+
'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
|
|
205
|
+
'exp', 'log', 'sqrt', 'abs', 'pow', 'min', 'max'
|
|
206
|
+
];
|
|
207
|
+
// Python built-in functions that don't need the math. prefix
|
|
208
|
+
static PYTHON_BUILTINS = ['abs', 'pow', 'min', 'max'];
|
|
172
209
|
mapFunctionName(name) {
|
|
173
|
-
if
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
'cos': 'Math.cos',
|
|
177
|
-
'tan': 'Math.tan',
|
|
178
|
-
'asin': 'Math.asin',
|
|
179
|
-
'acos': 'Math.acos',
|
|
180
|
-
'atan': 'Math.atan',
|
|
181
|
-
'atan2': 'Math.atan2',
|
|
182
|
-
'exp': 'Math.exp',
|
|
183
|
-
'log': 'Math.log',
|
|
184
|
-
'sqrt': 'Math.sqrt',
|
|
185
|
-
'abs': 'Math.abs',
|
|
186
|
-
'pow': 'Math.pow',
|
|
187
|
-
'min': 'Math.min',
|
|
188
|
-
'max': 'Math.max'
|
|
189
|
-
};
|
|
190
|
-
return mathFuncs[name] || name;
|
|
191
|
-
}
|
|
192
|
-
else if (this.format === 'python') {
|
|
193
|
-
const mathFuncs = {
|
|
194
|
-
'atan2': 'math.atan2',
|
|
195
|
-
'sin': 'math.sin',
|
|
196
|
-
'cos': 'math.cos',
|
|
197
|
-
'tan': 'math.tan',
|
|
198
|
-
'asin': 'math.asin',
|
|
199
|
-
'acos': 'math.acos',
|
|
200
|
-
'atan': 'math.atan',
|
|
201
|
-
'exp': 'math.exp',
|
|
202
|
-
'log': 'math.log',
|
|
203
|
-
'sqrt': 'math.sqrt',
|
|
204
|
-
'abs': 'abs',
|
|
205
|
-
'pow': 'pow',
|
|
206
|
-
'min': 'min',
|
|
207
|
-
'max': 'max'
|
|
208
|
-
};
|
|
209
|
-
return mathFuncs[name] || name;
|
|
210
|
+
// Check if this is a known math function
|
|
211
|
+
if (!ExpressionCodeGen.MATH_FUNCTIONS.includes(name)) {
|
|
212
|
+
return name;
|
|
210
213
|
}
|
|
211
|
-
|
|
214
|
+
// Define format-specific mappers
|
|
215
|
+
const mappers = {
|
|
216
|
+
typescript: (fn) => `Math.${fn}`,
|
|
217
|
+
javascript: (fn) => `Math.${fn}`,
|
|
218
|
+
python: (fn) => ExpressionCodeGen.PYTHON_BUILTINS.includes(fn) ? fn : `math.${fn}`,
|
|
219
|
+
csharp: (fn) => {
|
|
220
|
+
const mathClass = this.csharpFloatType === 'float' ? 'MathF' : 'Math';
|
|
221
|
+
const capitalized = fn.charAt(0).toUpperCase() + fn.slice(1);
|
|
222
|
+
return `${mathClass}.${capitalized}`;
|
|
223
|
+
}
|
|
224
|
+
};
|
|
225
|
+
const mapper = mappers[this.format];
|
|
226
|
+
return mapper ? mapper(name) : name;
|
|
212
227
|
}
|
|
213
228
|
}
|
|
229
|
+
/**
 * Generate C# struct for gradient return type.
 *
 * Emits `<Func>GradResult` holding `Value` plus one `D<param>` field per entry of
 * `gradients.gradients`. For every structured gradient it also emits a nested
 * `<Param>Grad` struct with one scalar field per component. The nested structs are
 * appended after the result struct, each preceded by a blank line.
 *
 * @param func - function AST node; only `func.name` is read here
 * @param gradients - object whose `gradients` Map is keyed by parameter name
 * @param floatType - C# scalar type name written for every scalar field
 * @returns array of source lines
 */
function generateCSharpGradientStruct(func, gradients, floatType) {
    const lines = [
        `public struct ${capitalize(func.name)}GradResult`,
        '{',
        ` public ${floatType} Value;`,
    ];
    // Nested struct declarations are gathered during the same pass over the map,
    // then appended after the closing brace of the result struct.
    const nestedStructs = [];
    for (const [paramName, gradient] of gradients.gradients.entries()) {
        const propName = capitalize(`d${paramName}`);
        if (isStructuredGradient(gradient)) {
            const nestedName = `${capitalize(paramName)}Grad`;
            lines.push(` public ${nestedName} ${propName};`);
            nestedStructs.push('', `public struct ${nestedName}`, '{');
            for (const comp of gradient.components.keys()) {
                nestedStructs.push(` public ${floatType} ${capitalize(comp)};`);
            }
            nestedStructs.push('}');
        }
        else {
            lines.push(` public ${floatType} ${propName};`);
        }
    }
    lines.push('}', ...nestedStructs);
    return lines;
}
|
|
214
264
|
/**
|
|
215
265
|
* Generate complete gradient function code
|
|
216
266
|
*/
|
|
217
267
|
export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
218
268
|
const format = options.format || 'typescript';
|
|
269
|
+
const csharpFloatType = options.csharpFloatType || 'float';
|
|
219
270
|
const includeComments = options.includeComments !== false;
|
|
220
271
|
const shouldSimplify = options.simplify !== false; // Default to true
|
|
221
|
-
//
|
|
222
|
-
const gradientsToUse =
|
|
223
|
-
|
|
224
|
-
: gradients;
|
|
225
|
-
const codegen = new ExpressionCodeGen(format);
|
|
272
|
+
// Note: We don't simplify here yet - we'll do it after forward expression substitution
|
|
273
|
+
const gradientsToUse = gradients;
|
|
274
|
+
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
226
275
|
const lines = [];
|
|
276
|
+
// For C#, we need to generate a struct for the return type first
|
|
277
|
+
if (format === 'csharp') {
|
|
278
|
+
lines.push(...generateCSharpGradientStruct(func, gradientsToUse, csharpFloatType));
|
|
279
|
+
lines.push('');
|
|
280
|
+
}
|
|
227
281
|
// Function signature
|
|
228
282
|
const paramNames = func.parameters.map(p => p.name).join(', ');
|
|
229
283
|
if (format === 'typescript' || format === 'javascript') {
|
|
@@ -232,44 +286,85 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
232
286
|
else if (format === 'python') {
|
|
233
287
|
lines.push(`def ${func.name}_grad(${paramNames}):`);
|
|
234
288
|
}
|
|
289
|
+
else if (format === 'csharp') {
|
|
290
|
+
const params = func.parameters.map(p => {
|
|
291
|
+
if (p.paramType && p.paramType.components) {
|
|
292
|
+
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
293
|
+
}
|
|
294
|
+
return `${csharpFloatType} ${p.name}`;
|
|
295
|
+
}).join(', ');
|
|
296
|
+
const returnType = `${capitalize(func.name)}GradResult`;
|
|
297
|
+
lines.push(`public static ${returnType} ${capitalize(func.name)}_Grad(${params})`);
|
|
298
|
+
lines.push('{');
|
|
299
|
+
}
|
|
235
300
|
// Forward pass - compute intermediate variables
|
|
236
301
|
// Track which expressions are already computed for CSE reuse
|
|
237
|
-
const
|
|
302
|
+
const forwardExpressionMap = new Map();
|
|
238
303
|
for (const stmt of func.body) {
|
|
239
304
|
if (stmt.kind === 'assignment') {
|
|
240
305
|
const varName = stmt.variable;
|
|
241
|
-
const
|
|
242
|
-
|
|
243
|
-
|
|
306
|
+
const generatedExpr = codegen.generate(stmt.expression);
|
|
307
|
+
if (shouldTrackForForwardReuse(stmt.expression)) {
|
|
308
|
+
const exprKey = serializeExpression(stmt.expression);
|
|
309
|
+
if (!forwardExpressionMap.has(exprKey)) {
|
|
310
|
+
forwardExpressionMap.set(exprKey, varName);
|
|
311
|
+
}
|
|
312
|
+
}
|
|
244
313
|
if (format === 'typescript' || format === 'javascript') {
|
|
245
|
-
lines.push(` const ${varName} = ${
|
|
314
|
+
lines.push(` const ${varName} = ${generatedExpr};`);
|
|
246
315
|
}
|
|
247
|
-
else {
|
|
248
|
-
lines.push(` ${varName} = ${
|
|
316
|
+
else if (format === 'python') {
|
|
317
|
+
lines.push(` ${varName} = ${generatedExpr}`);
|
|
318
|
+
}
|
|
319
|
+
else if (format === 'csharp') {
|
|
320
|
+
lines.push(` ${csharpFloatType} ${varName} = ${generatedExpr};`);
|
|
249
321
|
}
|
|
250
322
|
}
|
|
251
323
|
}
|
|
252
324
|
// Compute output value - reuse forward pass variables if possible
|
|
253
|
-
|
|
325
|
+
const valueExpr = func.returnExpr;
|
|
326
|
+
const valueKey = serializeExpression(valueExpr);
|
|
327
|
+
const existingVar = forwardExpressionMap.get(valueKey);
|
|
254
328
|
const valueCode = codegen.generate(valueExpr);
|
|
255
|
-
const existingVar = forwardPassVars.get(valueCode);
|
|
256
329
|
if (existingVar) {
|
|
257
330
|
// Reuse existing variable
|
|
258
331
|
if (format === 'typescript' || format === 'javascript') {
|
|
259
332
|
lines.push(` const value = ${existingVar};`);
|
|
260
333
|
}
|
|
261
|
-
else {
|
|
334
|
+
else if (format === 'python') {
|
|
262
335
|
lines.push(` value = ${existingVar}`);
|
|
263
336
|
}
|
|
337
|
+
else if (format === 'csharp') {
|
|
338
|
+
lines.push(` ${csharpFloatType} value = ${existingVar};`);
|
|
339
|
+
}
|
|
264
340
|
}
|
|
265
341
|
else {
|
|
266
342
|
// Compute new value
|
|
267
343
|
if (format === 'typescript' || format === 'javascript') {
|
|
268
344
|
lines.push(` const value = ${valueCode};`);
|
|
269
345
|
}
|
|
270
|
-
else {
|
|
346
|
+
else if (format === 'python') {
|
|
271
347
|
lines.push(` value = ${valueCode}`);
|
|
272
348
|
}
|
|
349
|
+
else if (format === 'csharp') {
|
|
350
|
+
lines.push(` ${csharpFloatType} value = ${valueCode};`);
|
|
351
|
+
}
|
|
352
|
+
if (shouldTrackForForwardReuse(valueExpr) && !forwardExpressionMap.has(valueKey)) {
|
|
353
|
+
forwardExpressionMap.set(valueKey, 'value');
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
// Apply forward expression substitution multiple times until no more changes
|
|
357
|
+
// This handles nested expressions like sqrt(dx*dx + dy*dy) where dx = pix - pjx
|
|
358
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
359
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
360
|
+
reuseForwardExpressionsInGradients(gradientsToUse.gradients, forwardExpressionMap);
|
|
361
|
+
// Simplify gradients after forward expression substitution
|
|
362
|
+
if (shouldSimplify) {
|
|
363
|
+
const simplified = simplifyGradients(gradientsToUse.gradients);
|
|
364
|
+
gradientsToUse.gradients.clear();
|
|
365
|
+
for (const [key, value] of simplified.entries()) {
|
|
366
|
+
gradientsToUse.gradients.set(key, value);
|
|
367
|
+
}
|
|
273
368
|
}
|
|
274
369
|
lines.push('');
|
|
275
370
|
// Generate gradients
|
|
@@ -312,9 +407,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
312
407
|
if (format === 'typescript' || format === 'javascript') {
|
|
313
408
|
lines.push(` const ${varName} = ${code};`);
|
|
314
409
|
}
|
|
315
|
-
else {
|
|
410
|
+
else if (format === 'python') {
|
|
316
411
|
lines.push(` ${varName} = ${code}`);
|
|
317
412
|
}
|
|
413
|
+
else if (format === 'csharp') {
|
|
414
|
+
lines.push(` ${csharpFloatType} ${varName} = ${code};`);
|
|
415
|
+
}
|
|
318
416
|
}
|
|
319
417
|
// Emit epsilon guard if needed
|
|
320
418
|
if (shouldEmitGuards && denominatorVars.size > 0) {
|
|
@@ -366,7 +464,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
366
464
|
}
|
|
367
465
|
lines.push(` };`);
|
|
368
466
|
}
|
|
369
|
-
else {
|
|
467
|
+
else if (format === 'python') {
|
|
370
468
|
lines.push(` ${gradName} = {`);
|
|
371
469
|
for (const comp of components) {
|
|
372
470
|
const [key, value] = comp.split(': ');
|
|
@@ -374,6 +472,15 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
374
472
|
}
|
|
375
473
|
lines.push(` }`);
|
|
376
474
|
}
|
|
475
|
+
else if (format === 'csharp') {
|
|
476
|
+
lines.push(` var ${gradName} = new ${capitalize(paramName)}Grad`);
|
|
477
|
+
lines.push(` {`);
|
|
478
|
+
for (const comp of components) {
|
|
479
|
+
const [key, value] = comp.split(': ');
|
|
480
|
+
lines.push(` ${capitalize(key)} = ${value},`);
|
|
481
|
+
}
|
|
482
|
+
lines.push(` };`);
|
|
483
|
+
}
|
|
377
484
|
}
|
|
378
485
|
else {
|
|
379
486
|
// Scalar gradient
|
|
@@ -381,9 +488,12 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
381
488
|
if (format === 'typescript' || format === 'javascript') {
|
|
382
489
|
lines.push(` const ${gradName} = ${code};`);
|
|
383
490
|
}
|
|
384
|
-
else {
|
|
491
|
+
else if (format === 'python') {
|
|
385
492
|
lines.push(` ${gradName} = ${code}`);
|
|
386
493
|
}
|
|
494
|
+
else if (format === 'csharp') {
|
|
495
|
+
lines.push(` ${csharpFloatType} ${gradName} = ${code};`);
|
|
496
|
+
}
|
|
387
497
|
}
|
|
388
498
|
}
|
|
389
499
|
lines.push('');
|
|
@@ -399,7 +509,7 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
399
509
|
lines.push(` };`);
|
|
400
510
|
lines.push('}');
|
|
401
511
|
}
|
|
402
|
-
else {
|
|
512
|
+
else if (format === 'python') {
|
|
403
513
|
lines.push(` return {`);
|
|
404
514
|
lines.push(` "value": value,`);
|
|
405
515
|
for (const gradName of gradNames) {
|
|
@@ -407,6 +517,16 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
407
517
|
}
|
|
408
518
|
lines.push(` }`);
|
|
409
519
|
}
|
|
520
|
+
else if (format === 'csharp') {
|
|
521
|
+
lines.push(` return new ${capitalize(func.name)}GradResult`);
|
|
522
|
+
lines.push(` {`);
|
|
523
|
+
lines.push(` Value = value,`);
|
|
524
|
+
for (const gradName of gradNames) {
|
|
525
|
+
lines.push(` ${capitalize(gradName)} = ${gradName},`);
|
|
526
|
+
}
|
|
527
|
+
lines.push(` };`);
|
|
528
|
+
lines.push('}');
|
|
529
|
+
}
|
|
410
530
|
return lines.join('\n');
|
|
411
531
|
}
|
|
412
532
|
/**
|
|
@@ -414,16 +534,30 @@ export function generateGradientFunction(func, gradients, env, options = {}) {
|
|
|
414
534
|
*/
|
|
415
535
|
export function generateForwardFunction(func, options = {}) {
|
|
416
536
|
const format = options.format || 'typescript';
|
|
417
|
-
const
|
|
537
|
+
const csharpFloatType = options.csharpFloatType || 'float';
|
|
538
|
+
const codegen = new ExpressionCodeGen(format, csharpFloatType);
|
|
418
539
|
const lines = [];
|
|
419
540
|
// Function signature
|
|
420
541
|
const paramNames = func.parameters.map(p => p.name).join(', ');
|
|
421
542
|
if (format === 'typescript' || format === 'javascript') {
|
|
422
543
|
lines.push(`function ${func.name}(${paramNames}) {`);
|
|
423
544
|
}
|
|
424
|
-
else {
|
|
545
|
+
else if (format === 'python') {
|
|
425
546
|
lines.push(`def ${func.name}(${paramNames}):`);
|
|
426
547
|
}
|
|
548
|
+
else if (format === 'csharp') {
|
|
549
|
+
const floatType = csharpFloatType;
|
|
550
|
+
const params = func.parameters.map(p => {
|
|
551
|
+
if (p.paramType && p.paramType.components) {
|
|
552
|
+
// Structured parameter - create a struct type name
|
|
553
|
+
return `${capitalize(p.name)}Struct ${p.name}`;
|
|
554
|
+
}
|
|
555
|
+
return `${floatType} ${p.name}`;
|
|
556
|
+
}).join(', ');
|
|
557
|
+
// Generate struct definitions for structured parameters first (we'll prepend them later)
|
|
558
|
+
lines.push(`public static ${floatType} ${capitalize(func.name)}(${params})`);
|
|
559
|
+
lines.push('{');
|
|
560
|
+
}
|
|
427
561
|
// Body
|
|
428
562
|
for (const stmt of func.body) {
|
|
429
563
|
if (stmt.kind === 'assignment') {
|
|
@@ -432,9 +566,12 @@ export function generateForwardFunction(func, options = {}) {
|
|
|
432
566
|
if (format === 'typescript' || format === 'javascript') {
|
|
433
567
|
lines.push(` const ${varName} = ${expr};`);
|
|
434
568
|
}
|
|
435
|
-
else {
|
|
569
|
+
else if (format === 'python') {
|
|
436
570
|
lines.push(` ${varName} = ${expr}`);
|
|
437
571
|
}
|
|
572
|
+
else if (format === 'csharp') {
|
|
573
|
+
lines.push(` ${csharpFloatType} ${varName} = ${expr};`);
|
|
574
|
+
}
|
|
438
575
|
}
|
|
439
576
|
}
|
|
440
577
|
// Return
|
|
@@ -443,9 +580,31 @@ export function generateForwardFunction(func, options = {}) {
|
|
|
443
580
|
lines.push(` return ${returnExpr};`);
|
|
444
581
|
lines.push('}');
|
|
445
582
|
}
|
|
446
|
-
else {
|
|
583
|
+
else if (format === 'python') {
|
|
447
584
|
lines.push(` return ${returnExpr}`);
|
|
448
585
|
}
|
|
586
|
+
else if (format === 'csharp') {
|
|
587
|
+
lines.push(` return ${returnExpr};`);
|
|
588
|
+
lines.push('}');
|
|
589
|
+
}
|
|
590
|
+
// For C#, prepend struct definitions
|
|
591
|
+
if (format === 'csharp') {
|
|
592
|
+
const structLines = [];
|
|
593
|
+
for (const param of func.parameters) {
|
|
594
|
+
if (param.paramType && param.paramType.components) {
|
|
595
|
+
structLines.push(`public struct ${capitalize(param.name)}Struct`);
|
|
596
|
+
structLines.push('{');
|
|
597
|
+
for (const comp of param.paramType.components) {
|
|
598
|
+
structLines.push(` public ${csharpFloatType} ${capitalize(comp)};`);
|
|
599
|
+
}
|
|
600
|
+
structLines.push('}');
|
|
601
|
+
structLines.push('');
|
|
602
|
+
}
|
|
603
|
+
}
|
|
604
|
+
if (structLines.length > 0) {
|
|
605
|
+
return structLines.join('\n') + lines.join('\n');
|
|
606
|
+
}
|
|
607
|
+
}
|
|
449
608
|
return lines.join('\n');
|
|
450
609
|
}
|
|
451
610
|
/**
|
|
@@ -466,6 +625,40 @@ export function generateComplete(func, gradients, env, options = {}) {
|
|
|
466
625
|
lines.push(generateGradientFunction(func, gradients, env, options));
|
|
467
626
|
return lines.join('\n');
|
|
468
627
|
}
|
|
628
|
+
/**
 * Expression transformer that replaces any sub-expression whose canonical key
 * (`serializeExpression`) is present in `forwardExpressions` with a variable
 * reference to the forward-pass variable that already holds that value.
 *
 * Matching is top-down: when a node matches, it is replaced and its subtree is not
 * visited. Otherwise the node is handed to the base transformer (presumably recursing
 * into children via `transform`; confirm against ExpressionTransformer).
 */
class ForwardExpressionSubstituter extends ExpressionTransformer {
    // Map from serializeExpression() key to forward-pass variable name.
    forwardExpressions;
    constructor(forwardExpressions) {
        super();
        this.forwardExpressions = forwardExpressions;
    }
    transform(expr) {
        const reusedName = this.forwardExpressions.get(serializeExpression(expr));
        // Only a truthy variable name counts as a hit.
        return reusedName
            ? { kind: 'variable', name: reusedName }
            : super.transform(expr);
    }
}
|
|
646
|
+
/**
 * Rewrite every gradient expression so that sub-expressions already computed in the
 * forward pass become references to the corresponding forward-pass variable.
 *
 * Mutates `gradients` in place, including the `components` maps of structured
 * gradients. Each expression is rewritten repeatedly until a pass changes nothing. Reuse
 * is therefore found at any nesting depth (e.g. sqrt(dx*dx + dy*dy) once dx = pix - pjx
 * has been substituted), not only as deep as the number of times this is called.
 *
 * @param gradients - Map from parameter name to a scalar expression or a structured
 *   gradient (an object with a `components` Map of component name to expression)
 * @param forwardExpressions - Map from serializeExpression() key to forward-pass variable name
 */
function reuseForwardExpressionsInGradients(gradients, forwardExpressions) {
    if (forwardExpressions.size === 0) {
        return;
    }
    const substituter = new ForwardExpressionSubstituter(forwardExpressions);
    // Safety cap on rewrite passes per expression. When the map holds only compound
    // expressions, every effective pass replaces a subtree with a variable, so the
    // expression shrinks and reaches a fixed point well before this limit.
    const maxPasses = 64;
    // Apply the substituter until the canonical form stops changing.
    const substituteToFixpoint = (expr) => {
        let current = expr;
        let currentKey = serializeExpression(expr);
        for (let pass = 0; pass < maxPasses; pass++) {
            const next = substituter.transform(current);
            const nextKey = serializeExpression(next);
            current = next;
            if (nextKey === currentKey) {
                break;
            }
            currentKey = nextKey;
        }
        return current;
    };
    for (const [paramName, gradient] of gradients.entries()) {
        if (isStructuredGradient(gradient)) {
            // Overwriting existing keys does not disturb Map iteration order.
            for (const [component, expr] of gradient.components.entries()) {
                gradient.components.set(component, substituteToFixpoint(expr));
            }
        }
        else {
            gradients.set(paramName, substituteToFixpoint(gradient));
        }
    }
}
|
|
469
662
|
/**
|
|
470
663
|
* Type guard for StructuredGradient
|
|
471
664
|
*/
|
package/dist/dsl/Errors.d.ts
CHANGED
|
@@ -2,8 +2,13 @@ export declare class ParseError extends Error {
|
|
|
2
2
|
line: number;
|
|
3
3
|
column: number;
|
|
4
4
|
token?: string | undefined;
|
|
5
|
-
|
|
5
|
+
sourceContext?: string | undefined;
|
|
6
|
+
constructor(message: string, line: number, column: number, token?: string | undefined, sourceContext?: string | undefined);
|
|
6
7
|
}
|
|
8
|
+
/**
|
|
9
|
+
* Format a user-friendly error message with source context
|
|
10
|
+
*/
|
|
11
|
+
export declare function formatParseError(error: ParseError, sourceCode: string, verbose?: boolean): string;
|
|
7
12
|
export declare class TypeError extends Error {
|
|
8
13
|
expression: string;
|
|
9
14
|
expectedType?: string | undefined;
|
package/dist/dsl/Errors.js
CHANGED
|
@@ -2,14 +2,83 @@ export class ParseError extends Error {
|
|
|
2
2
|
line;
|
|
3
3
|
column;
|
|
4
4
|
token;
|
|
5
|
-
|
|
5
|
+
sourceContext;
|
|
6
|
+
constructor(message, line, column, token, sourceContext) {
|
|
6
7
|
super(`Parse error at ${line}:${column}: ${message}`);
|
|
7
8
|
this.line = line;
|
|
8
9
|
this.column = column;
|
|
9
10
|
this.token = token;
|
|
11
|
+
this.sourceContext = sourceContext;
|
|
10
12
|
this.name = 'ParseError';
|
|
11
13
|
}
|
|
12
14
|
}
|
|
15
|
+
/**
 * Format a user-friendly error message with source context.
 *
 * The output contains, in order:
 * - the parse message without its "Parse error at L:C: " prefix;
 * - the offending source line with a caret under the error column;
 * - any guidance text from formatErrorGuidance();
 * - the stack trace, only when `verbose` is set.
 *
 * @param error - ParseError; `line` and `column` are treated as 1-based positions
 * @param sourceCode - full source text the error refers to
 * @param verbose - append `error.stack` when true
 * @returns the formatted multi-line message
 */
export function formatParseError(error, sourceCode, verbose = false) {
    // Split on CRLF as well as LF. Otherwise a trailing '\r' is echoed with the line
    // and moves the terminal cursor back to column 0.
    const lines = sourceCode.split(/\r?\n/);
    const errorLine = lines[error.line - 1];
    let output = `Error: ${error.message.replace(/^Parse error at \d+:\d+: /, '')}\n`;
    // Show the source line with the error
    if (errorLine) {
        output += `\n ${errorLine}\n`;
        // Add caret pointing to error position
        const caretPos = Math.max(0, error.column - 1);
        // Keep tabs from the source line so the caret stays aligned under tab-indented
        // code. Every other character becomes a space. Pad further if the column lies
        // past the end of the line.
        const padding = errorLine
            .slice(0, caretPos)
            .replace(/[^\t]/g, ' ')
            .padEnd(caretPos, ' ');
        output += ` ${padding}^\n`;
    }
    // Add helpful tips based on the error
    output += formatErrorGuidance(error);
    // Only show stack trace in verbose mode
    if (verbose && error.stack) {
        output += '\n\nStack trace:\n' + error.stack;
    }
    return output;
}
|
|
37
|
+
/**
 * Provide contextual guidance based on error patterns.
 *
 * Checks are applied in this order, and the first match selects the help text:
 * 1. a ';' token;
 * 2. a message containing "expected ':'";
 * 3. a message containing "expected parameter name" or "unexpected".
 *
 * @param error - ParseError; `message` (matched case-insensitively) and `token` are read
 * @returns a guidance block beginning with a newline, or '' when no pattern matches
 */
function formatErrorGuidance(error) {
    // Lower-cased so the substring checks below are case-insensitive.
    const msg = error.message.toLowerCase();
    const token = error.token;
    // Semicolon error
    if (token === ';') {
        return `
Semicolons are not part of gradient-script syntax.
Each statement should be on its own line.

Correct syntax:
function example(x∇, y∇) {
result = x + y
return result
}

💡 Tip: gradient-script uses newline-delimited statements (like Python),
not semicolons (like JavaScript/C#).
`;
    }
    // Missing colon in type annotation
    if (msg.includes("expected ':'")) {
        return `
Type annotations require a colon before the type.

Correct syntax:
function distance(point∇: {x, y}) {
^

💡 Tip: Parameters marked with ∇ need type annotations to specify structure.
`;
    }
    // Missing gradient marker suggestion.
    // NOTE(review): 'unexpected' is a broad substring and will match most
    // unexpected-token errors, not only malformed parameter lists; confirm intent.
    if (msg.includes('expected parameter name') || msg.includes('unexpected')) {
        return `
💡 Tip: Make sure all parameters are properly formatted.
Variables that need gradients must be marked with ∇.

Example: function f(a∇: {x, y}, b) { ... }
`;
    }
    // No recognised pattern: add nothing.
    return '';
}
|
|
13
82
|
export class TypeError extends Error {
|
|
14
83
|
expression;
|
|
15
84
|
expectedType;
|
package/dist/dsl/Expander.js
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
* Expander for GradientScript DSL
|
|
3
3
|
* Expands built-in functions and struct operations into scalar operations
|
|
4
4
|
*/
|
|
5
|
+
import { DifferentiationError } from './Errors.js';
|
|
5
6
|
/**
|
|
6
7
|
* Expand built-in function calls to scalar expressions
|
|
7
8
|
*/
|
|
@@ -15,13 +16,15 @@ export function expandBuiltIn(call) {
|
|
|
15
16
|
case 'magnitude2d':
|
|
16
17
|
return expandMagnitude2d(args[0]);
|
|
17
18
|
case 'normalize2d':
|
|
18
|
-
throw new
|
|
19
|
+
throw new DifferentiationError('normalize2d not yet supported', 'normalize2d', 'Vector normalization requires special handling for zero-length vectors. ' +
|
|
20
|
+
'Use magnitude2d() and division for now.');
|
|
19
21
|
case 'distance2d':
|
|
20
22
|
return expandDistance2d(args[0], args[1]);
|
|
21
23
|
case 'dot3d':
|
|
22
24
|
return expandDot3d(args[0], args[1]);
|
|
23
25
|
case 'cross3d':
|
|
24
|
-
throw new
|
|
26
|
+
throw new DifferentiationError('cross3d returns vector - not yet supported', 'cross3d', 'Cross product returns a 3D vector, which requires structured gradient support. ' +
|
|
27
|
+
'This feature is not yet implemented.');
|
|
25
28
|
case 'magnitude3d':
|
|
26
29
|
return expandMagnitude3d(args[0]);
|
|
27
30
|
default:
|
|
@@ -53,3 +53,11 @@ export declare function containsVariable(expr: Expression, varName: string): boo
|
|
|
53
53
|
* Calculate the maximum nesting depth of an expression
|
|
54
54
|
*/
|
|
55
55
|
export declare function expressionDepth(expr: Expression): number;
|
|
56
|
+
/**
|
|
57
|
+
* Serializes an expression to canonical string representation.
|
|
58
|
+
* Used for expression comparison and hashing (CSE, CodeGen forward reuse).
|
|
59
|
+
*
|
|
60
|
+
* This ensures consistent string representation of expressions across different
|
|
61
|
+
* parts of the codebase.
|
|
62
|
+
*/
|
|
63
|
+
export declare function serializeExpression(expr: Expression): string;
|
|
@@ -173,3 +173,27 @@ export function expressionDepth(expr) {
|
|
|
173
173
|
return 1 + expressionDepth(expr.object);
|
|
174
174
|
}
|
|
175
175
|
}
|
|
176
|
+
/**
 * Serializes an expression to canonical string representation.
 * Used for expression comparison and hashing (CSE, CodeGen forward reuse).
 *
 * This ensures consistent string representation of expressions across different
 * parts of the codebase. Structurally equal expressions (same kinds, operators,
 * names, values and children) produce the same string, so the result is usable
 * as a Map key.
 *
 * @param expr - expression node of kind number | variable | binary | unary | call | component
 * @returns canonical key string
 * @throws Error for any other `kind`. Returning `undefined` instead would embed the
 *   text "undefined" in parent keys and let unrelated expressions collide.
 */
export function serializeExpression(expr) {
    switch (expr.kind) {
        case 'number':
            return `num(${expr.value})`;
        case 'variable':
            return `var(${expr.name})`;
        case 'binary':
            return `bin(${expr.operator},${serializeExpression(expr.left)},${serializeExpression(expr.right)})`;
        case 'unary':
            return `un(${expr.operator},${serializeExpression(expr.operand)})`;
        case 'call': {
            // Braced so the lexical declaration is scoped to this case only.
            const args = expr.args.map(arg => serializeExpression(arg)).join(',');
            return `call(${expr.name},${args})`;
        }
        case 'component':
            return `comp(${serializeExpression(expr.object)},${expr.component})`;
        default:
            throw new Error(`serializeExpression: unsupported expression kind '${expr.kind}'`);
    }
}
|