gradient-script 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/README.md +515 -0
  2. package/dist/cli.d.ts +2 -0
  3. package/dist/cli.js +136 -0
  4. package/dist/dsl/AST.d.ts +123 -0
  5. package/dist/dsl/AST.js +23 -0
  6. package/dist/dsl/BuiltIns.d.ts +58 -0
  7. package/dist/dsl/BuiltIns.js +181 -0
  8. package/dist/dsl/CSE.d.ts +21 -0
  9. package/dist/dsl/CSE.js +194 -0
  10. package/dist/dsl/CodeGen.d.ts +60 -0
  11. package/dist/dsl/CodeGen.js +474 -0
  12. package/dist/dsl/Differentiation.d.ts +45 -0
  13. package/dist/dsl/Differentiation.js +421 -0
  14. package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
  15. package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
  16. package/dist/dsl/Errors.d.ts +22 -0
  17. package/dist/dsl/Errors.js +49 -0
  18. package/dist/dsl/Expander.d.ts +13 -0
  19. package/dist/dsl/Expander.js +220 -0
  20. package/dist/dsl/ExpressionTransformer.d.ts +54 -0
  21. package/dist/dsl/ExpressionTransformer.js +102 -0
  22. package/dist/dsl/ExpressionUtils.d.ts +55 -0
  23. package/dist/dsl/ExpressionUtils.js +175 -0
  24. package/dist/dsl/GradientChecker.d.ts +71 -0
  25. package/dist/dsl/GradientChecker.js +258 -0
  26. package/dist/dsl/Guards.d.ts +27 -0
  27. package/dist/dsl/Guards.js +206 -0
  28. package/dist/dsl/Inliner.d.ts +10 -0
  29. package/dist/dsl/Inliner.js +40 -0
  30. package/dist/dsl/Lexer.d.ts +63 -0
  31. package/dist/dsl/Lexer.js +243 -0
  32. package/dist/dsl/Parser.d.ts +92 -0
  33. package/dist/dsl/Parser.js +328 -0
  34. package/dist/dsl/Simplify.d.ts +17 -0
  35. package/dist/dsl/Simplify.js +276 -0
  36. package/dist/dsl/TypeInference.d.ts +39 -0
  37. package/dist/dsl/TypeInference.js +147 -0
  38. package/dist/dsl/Types.d.ts +58 -0
  39. package/dist/dsl/Types.js +114 -0
  40. package/dist/index.d.ts +13 -0
  41. package/dist/index.js +11 -0
  42. package/dist/symbolic/AST.d.ts +113 -0
  43. package/dist/symbolic/AST.js +128 -0
  44. package/dist/symbolic/CodeGen.d.ts +35 -0
  45. package/dist/symbolic/CodeGen.js +280 -0
  46. package/dist/symbolic/Parser.d.ts +64 -0
  47. package/dist/symbolic/Parser.js +329 -0
  48. package/dist/symbolic/Simplify.d.ts +10 -0
  49. package/dist/symbolic/Simplify.js +244 -0
  50. package/dist/symbolic/SymbolicDiff.d.ts +35 -0
  51. package/dist/symbolic/SymbolicDiff.js +339 -0
  52. package/package.json +56 -0
@@ -0,0 +1,474 @@
1
+ /**
2
+ * Code generation for GradientScript DSL
3
+ * Generates TypeScript/JavaScript code with gradient functions
4
+ */
5
+ import { simplifyGradients } from './Simplify.js';
6
+ import { eliminateCommonSubexpressionsStructured } from './CSE.js';
7
+ import { CodeGenError } from './Errors.js';
8
/**
 * Code generator for GradientScript expression trees.
 *
 * Converts an expression AST node (`number`, `variable`, `binary`, `unary`,
 * `call`, `component`) into source text for the configured target. The
 * formats handled explicitly are 'typescript', 'javascript' and 'python';
 * any other format emits operators and function names unchanged.
 */
export class ExpressionCodeGen {
    format;
    /**
     * @param format Target language; defaults to 'typescript'.
     */
    constructor(format = 'typescript') {
        this.format = format;
    }
    /**
     * Generate code for an expression.
     *
     * @param expr Expression AST node.
     * @returns Source text for `expr` in the target format.
     * @throws CodeGenError for an unrecognised node kind (instead of silently
     *   splicing the text "undefined" into the generated program).
     */
    generate(expr) {
        switch (expr.kind) {
            case 'number':
                return this.genNumber(expr);
            case 'variable':
                return this.genVariable(expr);
            case 'binary':
                return this.genBinary(expr);
            case 'unary':
                return this.genUnary(expr);
            case 'call':
                return this.genCall(expr);
            case 'component':
                return this.genComponent(expr);
            default:
                throw new CodeGenError(`Unsupported expression kind: ${expr.kind}`, expr.kind, this.format);
        }
    }
    /** True for the JavaScript-family targets, which share operator/function mapping. */
    isJsLike() {
        return this.format === 'typescript' || this.format === 'javascript';
    }
    /** True for the DSL's exponentiation operators. */
    isPowerOperator(op) {
        return op === '^' || op === '**';
    }
    genNumber(expr) {
        return String(expr.value);
    }
    genVariable(expr) {
        return expr.name;
    }
    /**
     * Generate a binary operation.
     *
     * Power is mapped per target: `**` in Python. In JS-like targets, small
     * integer powers (0..3) of a variable, component access or number literal
     * are expanded inline (`x * x`, faster than Math.pow). Every other power
     * uses `Math.pow(base, exp)`.
     */
    genBinary(expr) {
        const left = this.genWithPrecedence(expr.left, expr, 'left');
        const right = this.genWithPrecedence(expr.right, expr, 'right');
        if (this.isPowerOperator(expr.operator)) {
            if (this.format === 'python') {
                return `${left} ** ${right}`;
            }
            if (this.isJsLike()) {
                switch (this.expandedPowerExponent(expr)) {
                    case 0:
                        return '1';
                    case 1:
                        return left;
                    case 2:
                        return `${left} * ${left}`;
                    case 3:
                        return `${left} * ${left} * ${left}`;
                    default:
                        return `Math.pow(${left}, ${right})`;
                }
            }
        }
        return `${left} ${expr.operator} ${right}`;
    }
    /**
     * Return the exponent of a power node if JS-like targets inline it, else null.
     *
     * A node is inlined when its base is a variable, component access or
     * number literal and its exponent is an integer literal in [0, 3].
     */
    expandedPowerExponent(expr) {
        const baseKind = expr.left.kind;
        const simpleBase = baseKind === 'variable' || baseKind === 'component' || baseKind === 'number';
        if (!simpleBase || expr.right.kind !== 'number') {
            return null;
        }
        const exponent = expr.right.value;
        return Number.isInteger(exponent) && exponent >= 0 && exponent <= 3 ? exponent : null;
    }
    /**
     * Operator that governs how tightly the generated text of a binary node binds.
     *
     * A power expanded to `x * x` in JS-like targets behaves like `*`.
     * Previously it was ranked as `^`, so `a / x^2` was emitted as
     * `a / x * x`, which evaluates to `a`.
     */
    effectiveOperator(expr) {
        if (this.isJsLike() && this.isPowerOperator(expr.operator)) {
            const exponent = this.expandedPowerExponent(expr);
            if (exponent === 2 || exponent === 3) {
                return '*';
            }
        }
        return expr.operator;
    }
    /**
     * Generate an operand of `parent`.
     *
     * Parentheses are added wherever the operand's text would otherwise
     * bind differently than the AST says.
     */
    genWithPrecedence(expr, parent, side) {
        const code = this.generate(expr);
        if (expr.kind === 'binary') {
            return this.needsParentheses(expr, parent, side) ? `(${code})` : code;
        }
        if (expr.kind === 'unary') {
            // Unary operands of tight-binding operators are always grouped: `a * (-b)`.
            const tight = parent.operator === '*' || parent.operator === '/' || this.isPowerOperator(parent.operator);
            return tight ? `(${code})` : code;
        }
        if (expr.kind === 'number' && expr.value < 0 && side === 'left' && this.isPowerOperator(parent.operator)) {
            // Python parses `-2 ** 2` as `-(2 ** 2)`; keep a negative base grouped.
            return `(${code})`;
        }
        return code;
    }
    /**
     * Determine whether a binary child needs parentheses.
     *
     * @param child  Binary child node.
     * @param parent Binary parent node.
     * @param side   'left' or 'right': which operand of `parent` the child is.
     */
    needsParentheses(child, parent, side) {
        const childOp = this.effectiveOperator(child);
        const parentOp = parent.operator;
        const childPrec = this.getPrecedence(childOp);
        const parentPrec = this.getPrecedence(parentOp);
        // Lower precedence always needs parentheses.
        if (childPrec < parentPrec) {
            return true;
        }
        if (childPrec !== parentPrec) {
            return false;
        }
        if (side === 'left') {
            // '-' and '/' keep their (redundant but harmless) left grouping.
            // Outside JS-like targets power is emitted literally and is
            // right-associative (Python), so `(a ** b) ** c` needs its parentheses.
            return parentOp === '/' || parentOp === '-' ||
                (!this.isJsLike() && this.isPowerOperator(parentOp));
        }
        // Right side: non-associative parents must keep the child's grouping.
        if (childOp === '+' || childOp === '-') {
            return parentOp === '-';
        }
        if (childOp === '*' || childOp === '/') {
            return parentOp === '/';
        }
        return false;
    }
    /**
     * Get operator precedence (higher number = higher precedence; unknown operators get 0).
     */
    getPrecedence(op) {
        switch (op) {
            case '+':
            case '-':
                return 1;
            case '*':
            case '/':
                return 2;
            case '^':
            case '**':
                return 3;
            default:
                return 0;
        }
    }
    /**
     * Generate a unary operation.
     *
     * Compound operands are parenthesised. Without this, `-(a + b)` would
     * become `-a + b`, and `-(-x)` would become `--x` (a decrement in JS).
     */
    genUnary(expr) {
        const operand = this.generate(expr.operand);
        const needsParens = expr.operand.kind === 'binary' ||
            operand.startsWith('-') ||
            operand.startsWith('+');
        return needsParens ? `${expr.operator}(${operand})` : `${expr.operator}${operand}`;
    }
    /**
     * Generate a function call.
     *
     * `clamp(x, min, max)` is expanded to min/max calls for JS-like and
     * Python targets. Other names go through mapFunctionName.
     *
     * @throws CodeGenError when clamp does not receive exactly 3 arguments.
     */
    genCall(expr) {
        const args = expr.args.map(arg => this.generate(arg));
        if (expr.name === 'clamp') {
            if (args.length !== 3) {
                throw new CodeGenError('clamp requires 3 arguments: clamp(x, min, max)', expr.name, this.format);
            }
            const [x, min, max] = args;
            if (this.isJsLike()) {
                return `Math.max(${min}, Math.min(${max}, ${x}))`;
            }
            if (this.format === 'python') {
                return `max(${min}, min(${max}, ${x}))`;
            }
        }
        return `${this.mapFunctionName(expr.name)}(${args.join(', ')})`;
    }
    genComponent(expr) {
        const obj = this.generate(expr.object);
        return `${obj}.${expr.component}`;
    }
    /**
     * Map a DSL function name to the target's spelling (e.g. `sin` → `Math.sin`
     * or `math.sin`). Unknown names are returned unchanged.
     */
    mapFunctionName(name) {
        let table;
        if (this.isJsLike()) {
            table = {
                'sin': 'Math.sin',
                'cos': 'Math.cos',
                'tan': 'Math.tan',
                'asin': 'Math.asin',
                'acos': 'Math.acos',
                'atan': 'Math.atan',
                'atan2': 'Math.atan2',
                'exp': 'Math.exp',
                'log': 'Math.log',
                'sqrt': 'Math.sqrt',
                'abs': 'Math.abs',
                'pow': 'Math.pow',
                'min': 'Math.min',
                'max': 'Math.max'
            };
        }
        else if (this.format === 'python') {
            table = {
                'atan2': 'math.atan2',
                'sin': 'math.sin',
                'cos': 'math.cos',
                'tan': 'math.tan',
                'asin': 'math.asin',
                'acos': 'math.acos',
                'atan': 'math.atan',
                'exp': 'math.exp',
                'log': 'math.log',
                'sqrt': 'math.sqrt',
                'abs': 'abs',
                'pow': 'pow',
                'min': 'min',
                'max': 'max'
            };
        }
        else {
            return name;
        }
        // Own-property lookup: names like 'toString' or 'constructor' must not
        // resolve to Object.prototype members.
        return Object.prototype.hasOwnProperty.call(table, name) ? table[name] : name;
    }
}
214
/**
 * Generate the complete gradient function `<name>_grad`.
 *
 * The emitted function performs these steps:
 * - recompute the forward pass;
 * - store the function value in `value`;
 * - compute one gradient `d<param>` per entry of `gradients.gradients`;
 * - return `{ value, d<param>, ... }` (an object literal in TS/JS, a dict in Python).
 *
 * @param func      Parsed function definition (`name`, `parameters`, `body`, `returnExpr`).
 * @param gradients Differentiation result. `gradients.gradients` maps each
 *                  parameter name to an expression or a `{ components }`
 *                  gradient. The object is not modified by this function.
 * @param env       Type environment; not read by this function.
 * @param options   Generation options:
 *   - `format`: 'typescript' (default), 'javascript' or 'python'
 *   - `includeComments`: emit explanatory comments (default true)
 *   - `simplify`: run simplifyGradients first (default true)
 *   - `cse`: hoist common subexpressions of structured gradients (default true)
 *   - `emitGuards`: TS/JS only; return zero gradients when a suspected
 *     denominator is near zero (default false)
 *   - `epsilon`: guard threshold (default 1e-10; an explicit 0 is respected)
 * @returns The generated source, lines joined with '\n'.
 */
export function generateGradientFunction(func, gradients, env, options = {}) {
    const format = options.format || 'typescript';
    const isJs = format === 'typescript' || format === 'javascript';
    const includeComments = options.includeComments !== false;
    const comment = format === 'python' ? '#' : '//';
    const codegen = new ExpressionCodeGen(format);
    const lines = [];
    // Emit a local declaration in the target language's syntax.
    const emitAssignment = (name, code) => {
        lines.push(isJs ? ` const ${name} = ${code};` : ` ${name} = ${code}`);
    };
    // Simplify gradients unless disabled.
    const sourceGradients = options.simplify !== false
        ? simplifyGradients(gradients.gradients)
        : gradients.gradients;
    // Function signature
    const paramNames = func.parameters.map(p => p.name).join(', ');
    if (isJs) {
        lines.push(`function ${func.name}_grad(${paramNames}) {`);
    }
    else if (format === 'python') {
        lines.push(`def ${func.name}_grad(${paramNames}):`);
    }
    // Forward pass. Remember generated code -> variable name so the output
    // value can reuse an identical intermediate instead of recomputing it.
    const forwardPassVars = new Map();
    for (const stmt of func.body) {
        if (stmt.kind !== 'assignment') {
            continue;
        }
        const code = codegen.generate(stmt.expression);
        forwardPassVars.set(code, stmt.variable);
        emitAssignment(stmt.variable, code);
    }
    const valueCode = codegen.generate(func.returnExpr);
    emitAssignment('value', forwardPassVars.get(valueCode) ?? valueCode);
    lines.push('');
    if (includeComments) {
        lines.push(` ${comment} Gradients`);
    }
    // CSE over structured gradients. Results go into a fresh map: writing them
    // back into the caller's gradient objects made repeated calls (e.g. one per
    // output format) run CSE again on already-rewritten components.
    // NOTE(review): intermediates of all parameters share one namespace; this
    // assumes eliminateCommonSubexpressionsStructured yields names that are
    // unique across calls. Confirm in CSE.js.
    const finalGradients = new Map();
    const cseIntermediates = new Map();
    for (const [paramName, gradient] of sourceGradients.entries()) {
        if (options.cse !== false && isStructuredGradient(gradient)) {
            const cseResult = eliminateCommonSubexpressionsStructured(gradient.components);
            for (const [name, expr] of cseResult.intermediates.entries()) {
                cseIntermediates.set(name, expr);
            }
            finalGradients.set(paramName, { ...gradient, components: cseResult.components });
        }
        else {
            finalGradients.set(paramName, gradient);
        }
    }
    if (cseIntermediates.size > 0) {
        // Emit intermediates, and remember the ones that look like
        // denominators (a sum containing products/powers) as guard candidates.
        const denominatorVars = [];
        for (const [varName, expr] of cseIntermediates.entries()) {
            const code = codegen.generate(expr);
            emitAssignment(varName, code);
            if (code.includes('+') && (code.includes('* ') || code.includes('Math.pow'))) {
                denominatorVars.push(varName);
            }
        }
        // Guards are opt-in and only implemented for TS/JS output.
        if (options.emitGuards === true && isJs && denominatorVars.length > 0) {
            const epsilon = options.epsilon ?? 1e-10;
            // The zero-gradient return object is the same for every guard.
            const zeroGrads = Array.from(finalGradients.entries(), ([paramName, gradient]) => {
                if (!isStructuredGradient(gradient)) {
                    return `d${paramName}: 0`;
                }
                const zeroStruct = Array.from(gradient.components.keys(), c => `${c}: 0`).join(', ');
                return `d${paramName}: { ${zeroStruct} }`;
            }).join(', ');
            lines.push('');
            if (includeComments) {
                lines.push(` ${comment} Guard against division by zero`);
            }
            for (const denom of denominatorVars) {
                lines.push(` if (Math.abs(${denom}) < ${epsilon}) {`);
                if (includeComments) {
                    lines.push(` ${comment} Return zero gradients for degenerate case`);
                }
                lines.push(` return { value, ${zeroGrads} };`);
                lines.push(` }`);
            }
        }
        lines.push('');
    }
    // Gradients: short names du, dv instead of grad_u, grad_v.
    for (const [paramName, gradient] of finalGradients.entries()) {
        const gradName = `d${paramName}`;
        if (!isStructuredGradient(gradient)) {
            emitAssignment(gradName, codegen.generate(gradient));
            continue;
        }
        if (includeComments) {
            lines.push(` ${comment} Gradient for ${paramName}`);
        }
        // Keep (component, code) pairs; never re-split generated text.
        const entries = Array.from(gradient.components.entries(), ([comp, expr]) => [comp, codegen.generate(expr)]);
        if (isJs) {
            lines.push(` const ${gradName} = {`);
            for (const [comp, code] of entries) {
                lines.push(` ${comp}: ${code},`);
            }
            lines.push(` };`);
        }
        else {
            lines.push(` ${gradName} = {`);
            for (const [comp, code] of entries) {
                lines.push(` "${comp}": ${code},`);
            }
            lines.push(` }`);
        }
    }
    lines.push('');
    // Return result
    const gradNames = Array.from(finalGradients.keys(), n => `d${n}`);
    if (isJs) {
        lines.push(` return {`);
        lines.push(` value,`);
        for (const gradName of gradNames) {
            lines.push(` ${gradName},`);
        }
        lines.push(` };`);
        lines.push('}');
    }
    else {
        lines.push(` return {`);
        lines.push(` "value": value,`);
        for (const gradName of gradNames) {
            lines.push(` "${gradName}": ${gradName},`);
        }
        lines.push(` }`);
    }
    return lines.join('\n');
}
412
/**
 * Emit the plain forward function `<name>(params)`.
 *
 * Every `assignment` statement of the body becomes a local binding, followed
 * by a return of `func.returnExpr`. Non-assignment statements are skipped.
 *
 * @param func    Parsed function definition (`name`, `parameters`, `body`, `returnExpr`).
 * @param options `format`: 'typescript' (default) or 'javascript' produce JS
 *                syntax; any other value produces Python syntax.
 * @returns The generated source, lines joined with '\n'.
 */
export function generateForwardFunction(func, options = {}) {
    const format = options.format || 'typescript';
    const jsFamily = format === 'typescript' || format === 'javascript';
    const gen = new ExpressionCodeGen(format);
    const params = func.parameters.map(p => p.name).join(', ');
    const out = [
        jsFamily ? `function ${func.name}(${params}) {` : `def ${func.name}(${params}):`,
    ];
    for (const stmt of func.body) {
        if (stmt.kind !== 'assignment') {
            continue;
        }
        const rhs = gen.generate(stmt.expression);
        out.push(jsFamily ? ` const ${stmt.variable} = ${rhs};` : ` ${stmt.variable} = ${rhs}`);
    }
    const result = gen.generate(func.returnExpr);
    if (!jsFamily) {
        out.push(` return ${result}`);
        return out.join('\n');
    }
    out.push(` return ${result};`, '}');
    return out.join('\n');
}
451
/**
 * Generate complete output.
 *
 * The output contains an optional header comment, the forward function and
 * the `_grad` function, separated by blank lines.
 *
 * For Python output an `import math` line is emitted before the functions.
 * ExpressionCodeGen maps sin/cos/tan/exp/log/sqrt/... to `math.*` calls, so
 * without the import the generated module fails with NameError. A duplicate
 * import added by a caller is harmless in Python.
 *
 * @param func      Parsed function definition.
 * @param gradients Differentiation result, forwarded to generateGradientFunction.
 * @param env       Type environment, forwarded to generateGradientFunction.
 * @param options   Options shared by generateForwardFunction/generateGradientFunction
 *                  (`format`, `includeComments`, ...).
 * @returns The generated source text.
 */
export function generateComplete(func, gradients, env, options = {}) {
    const lines = [];
    const format = options.format || 'typescript';
    if (options.includeComments !== false) {
        const comment = format === 'python' ? '#' : '//';
        lines.push(`${comment} Generated by GradientScript`);
        lines.push('');
    }
    if (format === 'python') {
        lines.push('import math');
        lines.push('');
    }
    // Forward function
    lines.push(generateForwardFunction(func, options));
    lines.push('');
    // Gradient function
    lines.push(generateGradientFunction(func, gradients, env, options));
    return lines.join('\n');
}
469
/**
 * Distinguish a per-component gradient (`{ components: Map }`) from a plain
 * expression gradient by checking for a `components` property.
 */
function isStructuredGradient(grad) {
    return Reflect.has(grad, 'components');
}
@@ -0,0 +1,45 @@
1
+ /**
2
+ * Differentiation for GradientScript DSL
3
+ * Computes symbolic gradients for structured types
4
+ */
5
+ import { Expression, FunctionDef } from './AST.js';
6
+ import { TypeEnv } from './Types.js';
7
/**
 * Result of differentiation.
 *
 * `gradients` maps each differentiated parameter name to its gradient. The
 * value is a plain `Expression` for a non-structured parameter, or a
 * `StructuredGradient` holding one expression per component for a
 * structured parameter (e.g. a Vec2).
 */
export interface GradientResult {
    gradients: Map<string, Expression | StructuredGradient>;
}
13
/**
 * Structured gradient (e.g., for Vec2 parameter).
 *
 * `components` maps a component name to the gradient expression for that
 * component. NOTE(review): component names presumably match the fields of
 * the parameter's type (e.g. 'x', 'y'); confirm in Types.ts.
 */
export interface StructuredGradient {
    components: Map<string, Expression>;
}
19
/**
 * Differentiation engine.
 *
 * Declaration only: private members are emitted without types, so their
 * signatures are not visible here; see Differentiation.js.
 */
export declare class Differentiator {
    // Presumably stores the TypeEnv passed to the constructor (TODO confirm).
    private env;
    /**
     * @param env Type environment. NOTE(review): presumably used to resolve
     *            structured parameter types; confirm in Differentiation.js.
     */
    constructor(env: TypeEnv);
    /**
     * Differentiate expression with respect to a variable (component-level)
     *
     * @param expr Expression to differentiate.
     * @param wrt  Name of the variable to differentiate with respect to.
     * @returns The derivative expression.
     */
    differentiate(expr: Expression, wrt: string): Expression;
    // Per-node helpers; names suggest one per Expression kind
    // (number, variable, binary, unary, call, component). Confirm in the implementation.
    private diffNumber;
    private diffVariable;
    private diffBinary;
    private diffUnary;
    private diffCall;
    private diffMathFunction;
    private diffComponent;
    private expandComponentAccess;
    /**
     * Check if expression is constant with respect to wrt
     */
    private isConstant;
}
42
/**
 * Compute gradients for a function.
 *
 * @param func Function definition to differentiate.
 * @param env  Type environment for the function (TODO confirm what it must contain).
 * @returns Gradients keyed by parameter name.
 */
export declare function computeFunctionGradients(func: FunctionDef, env: TypeEnv): GradientResult;