gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Code generation for GradientScript DSL
|
|
3
|
+
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
|
+
*/
|
|
5
|
+
import { simplifyGradients } from './Simplify.js';
|
|
6
|
+
import { eliminateCommonSubexpressionsStructured } from './CSE.js';
|
|
7
|
+
import { CodeGenError } from './Errors.js';
|
|
8
|
+
/**
|
|
9
|
+
* Code generator for expressions
|
|
10
|
+
*/
|
|
11
|
+
export class ExpressionCodeGen {
|
|
12
|
+
format;
|
|
13
|
+
constructor(format = 'typescript') {
|
|
14
|
+
this.format = format;
|
|
15
|
+
}
|
|
16
|
+
/**
|
|
17
|
+
* Generate code for an expression
|
|
18
|
+
*/
|
|
19
|
+
generate(expr) {
|
|
20
|
+
switch (expr.kind) {
|
|
21
|
+
case 'number':
|
|
22
|
+
return this.genNumber(expr);
|
|
23
|
+
case 'variable':
|
|
24
|
+
return this.genVariable(expr);
|
|
25
|
+
case 'binary':
|
|
26
|
+
return this.genBinary(expr);
|
|
27
|
+
case 'unary':
|
|
28
|
+
return this.genUnary(expr);
|
|
29
|
+
case 'call':
|
|
30
|
+
return this.genCall(expr);
|
|
31
|
+
case 'component':
|
|
32
|
+
return this.genComponent(expr);
|
|
33
|
+
}
|
|
34
|
+
}
|
|
35
|
+
genNumber(expr) {
|
|
36
|
+
return String(expr.value);
|
|
37
|
+
}
|
|
38
|
+
genVariable(expr) {
|
|
39
|
+
return expr.name;
|
|
40
|
+
}
|
|
41
|
+
genBinary(expr) {
|
|
42
|
+
// Generate left and right with precedence-aware parentheses
|
|
43
|
+
const left = this.genWithPrecedence(expr.left, expr, 'left');
|
|
44
|
+
const right = this.genWithPrecedence(expr.right, expr, 'right');
|
|
45
|
+
// Handle operator mapping for different formats
|
|
46
|
+
let op = expr.operator;
|
|
47
|
+
if (this.format === 'python' && (op === '^' || op === '**')) {
|
|
48
|
+
op = '**'; // Python uses **
|
|
49
|
+
}
|
|
50
|
+
else if ((this.format === 'typescript' || this.format === 'javascript') && (op === '^' || op === '**')) {
|
|
51
|
+
// Optimize: x^2 -> x*x, x^3 -> x*x*x (faster than Math.pow)
|
|
52
|
+
// Only for simple expressions (variables, component access)
|
|
53
|
+
const isSimple = expr.left.kind === 'variable' ||
|
|
54
|
+
expr.left.kind === 'component' ||
|
|
55
|
+
expr.left.kind === 'number';
|
|
56
|
+
if (isSimple && expr.right.kind === 'number') {
|
|
57
|
+
const exponent = expr.right.value;
|
|
58
|
+
if (Number.isInteger(exponent) && exponent >= 0 && exponent <= 3) {
|
|
59
|
+
if (exponent === 0) {
|
|
60
|
+
return '1';
|
|
61
|
+
}
|
|
62
|
+
else if (exponent === 1) {
|
|
63
|
+
return left;
|
|
64
|
+
}
|
|
65
|
+
else if (exponent === 2) {
|
|
66
|
+
return `${left} * ${left}`;
|
|
67
|
+
}
|
|
68
|
+
else if (exponent === 3) {
|
|
69
|
+
return `${left} * ${left} * ${left}`;
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
// Fall back to Math.pow for complex expressions or larger exponents
|
|
74
|
+
return `Math.pow(${left}, ${right})`;
|
|
75
|
+
}
|
|
76
|
+
return `${left} ${op} ${right}`;
|
|
77
|
+
}
|
|
78
|
+
/**
|
|
79
|
+
* Generate expression with parentheses if needed based on precedence
|
|
80
|
+
*/
|
|
81
|
+
genWithPrecedence(expr, parent, side) {
|
|
82
|
+
// Always parenthesize binary operations that are children of other binary ops
|
|
83
|
+
// unless they have higher precedence
|
|
84
|
+
if (expr.kind === 'binary') {
|
|
85
|
+
const needsParens = this.needsParentheses(expr, parent, side);
|
|
86
|
+
const code = this.generate(expr);
|
|
87
|
+
return needsParens ? `(${code})` : code;
|
|
88
|
+
}
|
|
89
|
+
// Unary expressions need parentheses when they're operands of binary operations
|
|
90
|
+
// with higher or equal precedence, to avoid ambiguity
|
|
91
|
+
if (expr.kind === 'unary') {
|
|
92
|
+
const code = this.generate(expr);
|
|
93
|
+
// Unary minus with binary operation inside needs parens when parent is * or /
|
|
94
|
+
if (parent.operator === '*' || parent.operator === '/' || parent.operator === '^' || parent.operator === '**') {
|
|
95
|
+
return `(${code})`;
|
|
96
|
+
}
|
|
97
|
+
return code;
|
|
98
|
+
}
|
|
99
|
+
return this.generate(expr);
|
|
100
|
+
}
|
|
101
|
+
/**
|
|
102
|
+
* Determine if child expression needs parentheses
|
|
103
|
+
*/
|
|
104
|
+
needsParentheses(child, parent, side) {
|
|
105
|
+
const childPrec = this.getPrecedence(child.operator);
|
|
106
|
+
const parentPrec = this.getPrecedence(parent.operator);
|
|
107
|
+
// Lower precedence always needs parentheses
|
|
108
|
+
if (childPrec < parentPrec) {
|
|
109
|
+
return true;
|
|
110
|
+
}
|
|
111
|
+
// Same precedence: check associativity
|
|
112
|
+
if (childPrec === parentPrec) {
|
|
113
|
+
// For non-associative or right-associative on left side, need parens
|
|
114
|
+
if (side === 'left' && (parent.operator === '/' || parent.operator === '-')) {
|
|
115
|
+
return true;
|
|
116
|
+
}
|
|
117
|
+
// For subtraction/division on right side, need parens
|
|
118
|
+
if (side === 'right' && (child.operator === '+' || child.operator === '-')) {
|
|
119
|
+
return parent.operator === '-';
|
|
120
|
+
}
|
|
121
|
+
if (side === 'right' && (child.operator === '*' || child.operator === '/')) {
|
|
122
|
+
return parent.operator === '/';
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
return false;
|
|
126
|
+
}
|
|
127
|
+
/**
|
|
128
|
+
* Get operator precedence (higher number = higher precedence)
|
|
129
|
+
*/
|
|
130
|
+
getPrecedence(op) {
|
|
131
|
+
switch (op) {
|
|
132
|
+
case '+':
|
|
133
|
+
case '-':
|
|
134
|
+
return 1;
|
|
135
|
+
case '*':
|
|
136
|
+
case '/':
|
|
137
|
+
return 2;
|
|
138
|
+
case '^':
|
|
139
|
+
case '**':
|
|
140
|
+
return 3;
|
|
141
|
+
default:
|
|
142
|
+
return 0;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
genUnary(expr) {
|
|
146
|
+
const operand = this.generate(expr.operand);
|
|
147
|
+
return `${expr.operator}${operand}`;
|
|
148
|
+
}
|
|
149
|
+
genCall(expr) {
|
|
150
|
+
const args = expr.args.map(arg => this.generate(arg));
|
|
151
|
+
// Handle clamp specially (not in Math)
|
|
152
|
+
if (expr.name === 'clamp') {
|
|
153
|
+
if (args.length !== 3) {
|
|
154
|
+
throw new CodeGenError('clamp requires 3 arguments: clamp(x, min, max)', expr.name, this.format);
|
|
155
|
+
}
|
|
156
|
+
const [x, min, max] = args;
|
|
157
|
+
if (this.format === 'typescript' || this.format === 'javascript') {
|
|
158
|
+
return `Math.max(${min}, Math.min(${max}, ${x}))`;
|
|
159
|
+
}
|
|
160
|
+
else if (this.format === 'python') {
|
|
161
|
+
return `max(${min}, min(${max}, ${x}))`;
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
// Map function names for different formats
|
|
165
|
+
const funcName = this.mapFunctionName(expr.name);
|
|
166
|
+
return `${funcName}(${args.join(', ')})`;
|
|
167
|
+
}
|
|
168
|
+
genComponent(expr) {
|
|
169
|
+
const obj = this.generate(expr.object);
|
|
170
|
+
return `${obj}.${expr.component}`;
|
|
171
|
+
}
|
|
172
|
+
mapFunctionName(name) {
|
|
173
|
+
if (this.format === 'typescript' || this.format === 'javascript') {
|
|
174
|
+
const mathFuncs = {
|
|
175
|
+
'sin': 'Math.sin',
|
|
176
|
+
'cos': 'Math.cos',
|
|
177
|
+
'tan': 'Math.tan',
|
|
178
|
+
'asin': 'Math.asin',
|
|
179
|
+
'acos': 'Math.acos',
|
|
180
|
+
'atan': 'Math.atan',
|
|
181
|
+
'atan2': 'Math.atan2',
|
|
182
|
+
'exp': 'Math.exp',
|
|
183
|
+
'log': 'Math.log',
|
|
184
|
+
'sqrt': 'Math.sqrt',
|
|
185
|
+
'abs': 'Math.abs',
|
|
186
|
+
'pow': 'Math.pow',
|
|
187
|
+
'min': 'Math.min',
|
|
188
|
+
'max': 'Math.max'
|
|
189
|
+
};
|
|
190
|
+
return mathFuncs[name] || name;
|
|
191
|
+
}
|
|
192
|
+
else if (this.format === 'python') {
|
|
193
|
+
const mathFuncs = {
|
|
194
|
+
'atan2': 'math.atan2',
|
|
195
|
+
'sin': 'math.sin',
|
|
196
|
+
'cos': 'math.cos',
|
|
197
|
+
'tan': 'math.tan',
|
|
198
|
+
'asin': 'math.asin',
|
|
199
|
+
'acos': 'math.acos',
|
|
200
|
+
'atan': 'math.atan',
|
|
201
|
+
'exp': 'math.exp',
|
|
202
|
+
'log': 'math.log',
|
|
203
|
+
'sqrt': 'math.sqrt',
|
|
204
|
+
'abs': 'abs',
|
|
205
|
+
'pow': 'pow',
|
|
206
|
+
'min': 'min',
|
|
207
|
+
'max': 'max'
|
|
208
|
+
};
|
|
209
|
+
return mathFuncs[name] || name;
|
|
210
|
+
}
|
|
211
|
+
return name;
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
/**
 * Generate complete gradient function code (`<name>_grad`).
 *
 * The emitted function runs the forward pass, computes `value`, and returns
 * `value` plus one `d<param>` entry per entry of `gradients.gradients`.
 *
 * @param func Function definition (`name`, `parameters`, `body`, `returnExpr`).
 * @param gradients Differentiation result; `gradients.gradients` maps a
 *   parameter name to a scalar expression or a structured gradient
 *   (`components` map).
 * @param env Not used by this function; kept for API compatibility.
 * @param options Supported keys:
 *   - `format`: 'typescript' (default), 'javascript' or 'python'.
 *   - `includeComments`: default true.
 *   - `simplify`: default true.
 *   - `cse`: default true.
 *   - `emitGuards`: default false.
 *   - `epsilon`: default 1e-10.
 * @returns Generated source text.
 *
 * The caller's gradient objects are never mutated: CSE-rewritten components
 * are kept in a local map, so the same result can be rendered more than once.
 */
export function generateGradientFunction(func, gradients, env, options = {}) {
    const format = options.format || 'typescript';
    const isJs = format === 'typescript' || format === 'javascript';
    const includeComments = options.includeComments !== false;
    const shouldSimplify = options.simplify !== false; // Default to true
    const shouldApplyCSE = options.cse !== false; // Default to true
    const comment = format === 'python' ? '#' : '//';
    const gradientsToUse = shouldSimplify
        ? { gradients: simplifyGradients(gradients.gradients) }
        : gradients;
    const codegen = new ExpressionCodeGen(format);
    const lines = [];
    // Declare a function-body variable in the target syntax.
    const emitAssignment = (name, code) => {
        lines.push(isJs ? ` const ${name} = ${code};` : ` ${name} = ${code}`);
    };
    // Function signature
    const paramNames = func.parameters.map(p => p.name).join(', ');
    if (isJs) {
        lines.push(`function ${func.name}_grad(${paramNames}) {`);
    }
    else if (format === 'python') {
        lines.push(`def ${func.name}_grad(${paramNames}):`);
    }
    // Forward pass; remember generated code -> variable so `value` can reuse it
    const forwardPassVars = new Map();
    for (const stmt of func.body) {
        if (stmt.kind === 'assignment') {
            const code = codegen.generate(stmt.expression);
            forwardPassVars.set(code, stmt.variable);
            emitAssignment(stmt.variable, code);
        }
    }
    const valueCode = codegen.generate(func.returnExpr);
    emitAssignment('value', forwardPassVars.get(valueCode) || valueCode);
    lines.push('');
    if (includeComments) {
        lines.push(` ${comment} Gradients`);
    }
    // Components to emit for each structured gradient (CSE-rewritten when enabled)
    const componentsByParam = new Map();
    const cseIntermediates = new Map();
    for (const [paramName, gradient] of gradientsToUse.gradients) {
        if (!isStructuredGradient(gradient)) {
            continue;
        }
        const components = shouldApplyCSE
            ? mergeCseIntermediates(cseIntermediates, eliminateCommonSubexpressionsStructured(gradient.components), codegen)
            : gradient.components;
        componentsByParam.set(paramName, components);
    }
    if (cseIntermediates.size > 0) {
        const denominatorVars = [];
        for (const [varName, expr] of cseIntermediates) {
            const code = codegen.generate(expr);
            // Heuristic: a sum containing products/powers may be a denominator
            if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
                denominatorVars.push(varName);
            }
            emitAssignment(varName, code);
        }
        // Guards are opt-in and only emitted for formats with a known syntax
        if (options.emitGuards === true && denominatorVars.length > 0 && (isJs || format === 'python')) {
            const epsilon = options.epsilon || 1e-10;
            const zeroGrads = formatZeroGradients(gradientsToUse.gradients, format === 'python');
            lines.push('');
            if (includeComments) {
                lines.push(` ${comment} Guard against division by zero`);
            }
            for (const denom of denominatorVars) {
                if (isJs) {
                    lines.push(` if (Math.abs(${denom}) < ${epsilon}) {`);
                    if (includeComments) {
                        lines.push(` ${comment} Return zero gradients for degenerate case`);
                    }
                    lines.push(` return { value, ${zeroGrads} };`);
                    lines.push(` }`);
                }
                else {
                    lines.push(` if abs(${denom}) < ${epsilon}:`);
                    lines.push(`  return {"value": value, ${zeroGrads}}`);
                }
            }
        }
        lines.push('');
    }
    for (const [paramName, gradient] of gradientsToUse.gradients) {
        // Short names: du, dv instead of grad_u, grad_v
        const gradName = `d${paramName}`;
        if (isStructuredGradient(gradient)) {
            if (includeComments) {
                lines.push(` ${comment} Gradient for ${paramName}`);
            }
            const entries = Array.from(componentsByParam.get(paramName), ([comp, expr]) => [comp, codegen.generate(expr)]);
            if (isJs) {
                lines.push(` const ${gradName} = {`);
                for (const [comp, code] of entries) {
                    lines.push(` ${comp}: ${code},`);
                }
                lines.push(` };`);
            }
            else {
                lines.push(` ${gradName} = {`);
                for (const [comp, code] of entries) {
                    lines.push(` "${comp}": ${code},`);
                }
                lines.push(` }`);
            }
        }
        else {
            emitAssignment(gradName, codegen.generate(gradient));
        }
    }
    lines.push('');
    // Return result
    const gradNames = Array.from(gradientsToUse.gradients.keys(), n => `d${n}`);
    if (isJs) {
        lines.push(` return {`);
        lines.push(` value,`);
        for (const gradName of gradNames) {
            lines.push(` ${gradName},`);
        }
        lines.push(` };`);
        lines.push('}');
    }
    else {
        lines.push(` return {`);
        lines.push(` "value": value,`);
        for (const gradName of gradNames) {
            lines.push(` "${gradName}": ${gradName},`);
        }
        lines.push(` }`);
    }
    return lines.join('\n');
}
/**
 * Merge one gradient's CSE intermediates into the shared `pool`.
 *
 * A name that already exists in the pool with a different definition is
 * renamed (`<name>_<n>`) rather than overwritten. References to it in later
 * intermediates and in the components are rewritten accordingly.
 *
 * @returns The gradient's components, rewritten to the names actually emitted.
 */
function mergeCseIntermediates(pool, cseResult, codegen) {
    const renames = new Map();
    const batchNames = new Set(cseResult.intermediates.keys());
    for (const [name, rawExpr] of cseResult.intermediates) {
        const expr = renameCseVariables(rawExpr, renames);
        const existing = pool.get(name);
        if (existing === undefined) {
            pool.set(name, expr);
            continue;
        }
        if (codegen.generate(existing) === codegen.generate(expr)) {
            continue; // identical definition already in the pool
        }
        let suffix = 1;
        let fresh = `${name}_${suffix}`;
        while (pool.has(fresh) || batchNames.has(fresh)) {
            suffix += 1;
            fresh = `${name}_${suffix}`;
        }
        renames.set(name, fresh);
        pool.set(fresh, expr);
    }
    if (renames.size === 0) {
        return cseResult.components;
    }
    const components = new Map();
    for (const [comp, expr] of cseResult.components) {
        components.set(comp, renameCseVariables(expr, renames));
    }
    return components;
}
/** Return a copy of `expr` with variables renamed per `renames` (unchanged nodes are shared). */
function renameCseVariables(expr, renames) {
    if (renames.size === 0) {
        return expr;
    }
    switch (expr.kind) {
        case 'variable':
            return renames.has(expr.name) ? { ...expr, name: renames.get(expr.name) } : expr;
        case 'binary':
            return { ...expr, left: renameCseVariables(expr.left, renames), right: renameCseVariables(expr.right, renames) };
        case 'unary':
            return { ...expr, operand: renameCseVariables(expr.operand, renames) };
        case 'call':
            return { ...expr, args: expr.args.map(arg => renameCseVariables(arg, renames)) };
        case 'component':
            return { ...expr, object: renameCseVariables(expr.object, renames) };
        default:
            return expr;
    }
}
/** Format the `d<param>: 0` entries of the zero-gradient return used by the guards. */
function formatZeroGradients(gradientMap, isPython) {
    const entries = [];
    for (const [paramName, gradient] of gradientMap) {
        const key = isPython ? `"d${paramName}"` : `d${paramName}`;
        if (isStructuredGradient(gradient)) {
            const fields = Array.from(gradient.components.keys(), c => (isPython ? `"${c}": 0` : `${c}: 0`)).join(', ');
            entries.push(isPython ? `${key}: {${fields}}` : `${key}: { ${fields} }`);
        }
        else {
            entries.push(`${key}: 0`);
        }
    }
    return entries.join(', ');
}
|
|
412
|
+
/**
 * Emit the plain (value-only) version of a DSL function.
 *
 * TypeScript/JavaScript output becomes a `function`; every other format uses
 * Python `def` syntax. Only `assignment` statements of the body are emitted,
 * followed by a `return` of the function's return expression.
 *
 * @param func Function definition (`name`, `parameters`, `body`, `returnExpr`).
 * @param options `format` (default 'typescript').
 * @returns Generated source text.
 */
export function generateForwardFunction(func, options = {}) {
    const format = options.format || 'typescript';
    const jsLike = ['typescript', 'javascript'].includes(format);
    const codegen = new ExpressionCodeGen(format);
    const params = func.parameters.map(({ name }) => name).join(', ');
    const header = jsLike
        ? `function ${func.name}(${params}) {`
        : `def ${func.name}(${params}):`;
    const body = func.body
        .filter(stmt => stmt.kind === 'assignment')
        .map(stmt => {
        const rhs = codegen.generate(stmt.expression);
        return jsLike ? ` const ${stmt.variable} = ${rhs};` : ` ${stmt.variable} = ${rhs}`;
    });
    const result = codegen.generate(func.returnExpr);
    const footer = jsLike ? [` return ${result};`, '}'] : [` return ${result}`];
    return [header, ...body, ...footer].join('\n');
}
|
|
451
|
+
/**
 * Generate complete output with both forward and gradient functions.
 *
 * Python output calls `math.*` functions, so an `import math` line is emitted
 * at the top of the module; importing it twice is harmless if a caller also
 * adds one.
 *
 * @param func Function definition.
 * @param gradients Differentiation result passed to generateGradientFunction.
 * @param env Forwarded to generateGradientFunction.
 * @param options Same options as generateForwardFunction / generateGradientFunction.
 * @returns Generated module text.
 */
export function generateComplete(func, gradients, env, options = {}) {
    const lines = [];
    const format = options.format || 'typescript';
    const isPython = format === 'python';
    if (options.includeComments !== false) {
        const comment = isPython ? '#' : '//';
        lines.push(`${comment} Generated by GradientScript`);
        lines.push('');
    }
    if (isPython) {
        // Generated code references math.sin, math.sqrt, ... which need the module in scope
        lines.push('import math');
        lines.push('');
    }
    // Forward function
    lines.push(generateForwardFunction(func, options));
    lines.push('');
    // Gradient function
    lines.push(generateGradientFunction(func, gradients, env, options));
    return lines.join('\n');
}
|
|
469
|
+
/**
 * Distinguish a structured (per-component) gradient from a scalar expression
 * gradient by the presence of a `components` property (own or inherited).
 */
function isStructuredGradient(grad) {
    return Reflect.has(grad, 'components');
}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Differentiation for GradientScript DSL
|
|
3
|
+
* Computes symbolic gradients for structured types
|
|
4
|
+
*/
|
|
5
|
+
import { Expression, FunctionDef } from './AST.js';
|
|
6
|
+
import { TypeEnv } from './Types.js';
|
|
7
|
+
/**
 * Result of differentiation.
 *
 * `gradients` is keyed by parameter name (the code generator emits each entry
 * as `d<paramName>`). Each value is either a single Expression or a
 * StructuredGradient holding one expression per component (e.g. for a Vec2
 * parameter).
 */
export interface GradientResult {
    gradients: Map<string, Expression | StructuredGradient>;
}
|
|
13
|
+
/**
 * Structured gradient (e.g., for Vec2 parameter)
 *
 * `components` maps a component name (the same names used in component access
 * such as `u.x`) to the derivative expression for that component.
 */
export interface StructuredGradient {
    components: Map<string, Expression>;
}
|
|
19
|
+
/**
 * Differentiation engine
 *
 * Built with a type environment; `differentiate` is the only public method.
 * The private `diff*` members are presumably per-expression-kind rules
 * (number, variable, binary, unary, call, component) — TODO confirm against
 * the implementation.
 */
export declare class Differentiator {
    private env;
    constructor(env: TypeEnv);
    /**
     * Differentiate expression with respect to a variable (component-level)
     *
     * @param expr Expression to differentiate.
     * @param wrt Name of the variable to differentiate against; NOTE(review):
     *   whether this is a bare name or a component path like `u.x` is not
     *   visible from this declaration — confirm with the implementation.
     * @returns The derivative expression.
     */
    differentiate(expr: Expression, wrt: string): Expression;
    private diffNumber;
    private diffVariable;
    private diffBinary;
    private diffUnary;
    private diffCall;
    private diffMathFunction;
    private diffComponent;
    private expandComponentAccess;
    /**
     * Check if expression is constant with respect to wrt
     */
    private isConstant;
}
|
|
42
|
+
/**
 * Compute gradients for a function
 *
 * @param func Function definition to differentiate.
 * @param env Type environment for the function's parameters and variables.
 * @returns A GradientResult whose `gradients` map is keyed by parameter name.
 */
export declare function computeFunctionGradients(func: FunctionDef, env: TypeEnv): GradientResult;
|