scalar-autograd 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. package/README.md +127 -2
  2. package/dist/CompiledFunctions.d.ts +111 -0
  3. package/dist/CompiledFunctions.js +268 -0
  4. package/dist/CompiledResiduals.d.ts +74 -0
  5. package/dist/CompiledResiduals.js +94 -0
  6. package/dist/EigenvalueHelpers.d.ts +14 -0
  7. package/dist/EigenvalueHelpers.js +93 -0
  8. package/dist/Geometry.d.ts +131 -0
  9. package/dist/Geometry.js +213 -0
  10. package/dist/GraphBuilder.d.ts +64 -0
  11. package/dist/GraphBuilder.js +237 -0
  12. package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
  13. package/dist/GraphCanonicalizerNoSort.js +190 -0
  14. package/dist/GraphHashCanonicalizer.d.ts +46 -0
  15. package/dist/GraphHashCanonicalizer.js +220 -0
  16. package/dist/GraphSignature.d.ts +7 -0
  17. package/dist/GraphSignature.js +7 -0
  18. package/dist/KernelPool.d.ts +55 -0
  19. package/dist/KernelPool.js +124 -0
  20. package/dist/LBFGS.d.ts +84 -0
  21. package/dist/LBFGS.js +313 -0
  22. package/dist/LinearSolver.d.ts +69 -0
  23. package/dist/LinearSolver.js +213 -0
  24. package/dist/Losses.d.ts +9 -0
  25. package/dist/Losses.js +42 -37
  26. package/dist/Matrix3x3.d.ts +50 -0
  27. package/dist/Matrix3x3.js +146 -0
  28. package/dist/NonlinearLeastSquares.d.ts +33 -0
  29. package/dist/NonlinearLeastSquares.js +252 -0
  30. package/dist/Optimizers.d.ts +70 -14
  31. package/dist/Optimizers.js +42 -19
  32. package/dist/V.d.ts +0 -0
  33. package/dist/V.js +0 -0
  34. package/dist/Value.d.ts +84 -2
  35. package/dist/Value.js +296 -58
  36. package/dist/ValueActivation.js +10 -14
  37. package/dist/ValueArithmetic.d.ts +1 -0
  38. package/dist/ValueArithmetic.js +58 -50
  39. package/dist/ValueComparison.js +9 -13
  40. package/dist/ValueRegistry.d.ts +38 -0
  41. package/dist/ValueRegistry.js +88 -0
  42. package/dist/ValueTrig.js +14 -18
  43. package/dist/Vec2.d.ts +45 -0
  44. package/dist/Vec2.js +93 -0
  45. package/dist/Vec3.d.ts +78 -0
  46. package/dist/Vec3.js +169 -0
  47. package/dist/Vec4.d.ts +45 -0
  48. package/dist/Vec4.js +126 -0
  49. package/dist/__tests__/duplicate-inputs.test.js +33 -0
  50. package/dist/cli/gradient-gen.d.ts +19 -0
  51. package/dist/cli/gradient-gen.js +264 -0
  52. package/dist/compileIndirectKernel.d.ts +24 -0
  53. package/dist/compileIndirectKernel.js +148 -0
  54. package/dist/index.d.ts +20 -0
  55. package/dist/index.js +20 -0
  56. package/dist/scalar-autograd.d.ts +1157 -0
  57. package/dist/symbolic/AST.d.ts +113 -0
  58. package/dist/symbolic/AST.js +128 -0
  59. package/dist/symbolic/CodeGen.d.ts +35 -0
  60. package/dist/symbolic/CodeGen.js +280 -0
  61. package/dist/symbolic/Parser.d.ts +64 -0
  62. package/dist/symbolic/Parser.js +329 -0
  63. package/dist/symbolic/Simplify.d.ts +10 -0
  64. package/dist/symbolic/Simplify.js +244 -0
  65. package/dist/symbolic/SymbolicDiff.d.ts +35 -0
  66. package/dist/symbolic/SymbolicDiff.js +339 -0
  67. package/dist/tsdoc-metadata.json +11 -0
  68. package/package.json +29 -5
  69. package/dist/Losses.spec.js +0 -54
  70. package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
  71. package/dist/Optimizers.edge-cases.spec.js +0 -29
  72. package/dist/Optimizers.spec.d.ts +0 -1
  73. package/dist/Optimizers.spec.js +0 -56
  74. package/dist/Value.edge-cases.spec.d.ts +0 -1
  75. package/dist/Value.edge-cases.spec.js +0 -54
  76. package/dist/Value.grad-flow.spec.d.ts +0 -1
  77. package/dist/Value.grad-flow.spec.js +0 -24
  78. package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
  79. package/dist/Value.losses-edge-cases.spec.js +0 -30
  80. package/dist/Value.memory.spec.d.ts +0 -1
  81. package/dist/Value.memory.spec.js +0 -23
  82. package/dist/Value.nn.spec.d.ts +0 -1
  83. package/dist/Value.nn.spec.js +0 -111
  84. package/dist/Value.spec.d.ts +0 -1
  85. package/dist/Value.spec.js +0 -245
  86. /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
@@ -0,0 +1,113 @@
1
+ /**
2
+ * Abstract Syntax Tree (AST) node definitions for symbolic differentiation.
3
+ * @internal
4
+ */
/**
 * Contract every AST node satisfies: a kind tag plus visitor dispatch.
 */
export interface ASTNode {
    type: string;
    accept<T>(visitor: ASTVisitor<T>): T;
}
/**
 * Double-dispatch target: one handler per concrete node class, each producing a `T`.
 */
export interface ASTVisitor<T> {
    visitNumber(node: NumberNode): T;
    visitVariable(node: VariableNode): T;
    visitBinaryOp(node: BinaryOpNode): T;
    visitUnaryOp(node: UnaryOpNode): T;
    visitFunctionCall(node: FunctionCallNode): T;
    visitVectorAccess(node: VectorAccessNode): T;
    visitVectorConstructor(node: VectorConstructorNode): T;
}
/** Numeric literal such as `2` or `0.5`. */
export declare class NumberNode implements ASTNode {
    type: 'Number';
    value: number;
    constructor(value: number);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Reference to a named variable such as `x` or `a`. */
export declare class VariableNode implements ASTNode {
    type: 'Variable';
    name: string;
    constructor(name: string);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Infix operation `left op right`, e.g. `a + b` or `x * y`. */
export declare class BinaryOpNode implements ASTNode {
    type: 'BinaryOp';
    op: '+' | '-' | '*' | '/' | '**' | 'pow';
    left: ASTNode;
    right: ASTNode;
    constructor(op: '+' | '-' | '*' | '/' | '**' | 'pow', left: ASTNode, right: ASTNode);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Prefix sign applied to an operand, e.g. `-x` or `+y`. */
export declare class UnaryOpNode implements ASTNode {
    type: 'UnaryOp';
    op: '+' | '-';
    operand: ASTNode;
    constructor(op: '+' | '-', operand: ASTNode);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Call of a named function with zero or more arguments, e.g. `sin(x)` or `max(a, b)`. */
export declare class FunctionCallNode implements ASTNode {
    type: 'FunctionCall';
    name: string;
    args: Array<ASTNode>;
    constructor(name: string, args: Array<ASTNode>);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Component read on a vector-valued expression, e.g. `v.x` or `point.y`. */
export declare class VectorAccessNode implements ASTNode {
    type: 'VectorAccess';
    vector: ASTNode;
    component: 'x' | 'y' | 'z';
    constructor(vector: ASTNode, component: 'x' | 'y' | 'z');
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** Vector literal built from component expressions, e.g. `Vec2(x, y)` or `Vec3(1, 2, 3)`. */
export declare class VectorConstructorNode implements ASTNode {
    type: 'VectorConstructor';
    vectorType: 'Vec2' | 'Vec3';
    components: Array<ASTNode>;
    constructor(vectorType: 'Vec2' | 'Vec3', components: Array<ASTNode>);
    accept<T>(visitor: ASTVisitor<T>): T;
    toString(): string;
}
/** One statement `variable = expression`, e.g. `y = x + 2`. */
export interface Assignment {
    variable: string;
    expression: ASTNode;
}
/** A whole parsed program: its assignments in source order and the name of the output variable. */
export interface Program {
    assignments: Array<Assignment>;
    output: string;
}
@@ -0,0 +1,128 @@
1
+ /**
2
+ * Abstract Syntax Tree (AST) node definitions for symbolic differentiation.
3
+ * @internal
4
+ */
5
+ /**
6
+ * Numeric constant node
7
+ */
8
+ export class NumberNode {
9
+ value;
10
+ type = 'Number';
11
+ constructor(value) {
12
+ this.value = value;
13
+ }
14
+ accept(visitor) {
15
+ return visitor.visitNumber(this);
16
+ }
17
+ toString() {
18
+ return String(this.value);
19
+ }
20
+ }
21
+ /**
22
+ * Variable reference node (e.g., 'x', 'y', 'a')
23
+ */
24
+ export class VariableNode {
25
+ name;
26
+ type = 'Variable';
27
+ constructor(name) {
28
+ this.name = name;
29
+ }
30
+ accept(visitor) {
31
+ return visitor.visitVariable(this);
32
+ }
33
+ toString() {
34
+ return this.name;
35
+ }
36
+ }
37
+ /**
38
+ * Binary operation node (e.g., a + b, x * y)
39
+ */
40
+ export class BinaryOpNode {
41
+ op;
42
+ left;
43
+ right;
44
+ type = 'BinaryOp';
45
+ constructor(op, left, right) {
46
+ this.op = op;
47
+ this.left = left;
48
+ this.right = right;
49
+ }
50
+ accept(visitor) {
51
+ return visitor.visitBinaryOp(this);
52
+ }
53
+ toString() {
54
+ return `(${this.left.toString()} ${this.op} ${this.right.toString()})`;
55
+ }
56
+ }
57
+ /**
58
+ * Unary operation node (e.g., -x, +y)
59
+ */
60
+ export class UnaryOpNode {
61
+ op;
62
+ operand;
63
+ type = 'UnaryOp';
64
+ constructor(op, operand) {
65
+ this.op = op;
66
+ this.operand = operand;
67
+ }
68
+ accept(visitor) {
69
+ return visitor.visitUnaryOp(this);
70
+ }
71
+ toString() {
72
+ return `${this.op}${this.operand.toString()}`;
73
+ }
74
+ }
75
+ /**
76
+ * Function call node (e.g., sin(x), sqrt(y), max(a, b))
77
+ */
78
+ export class FunctionCallNode {
79
+ name;
80
+ args;
81
+ type = 'FunctionCall';
82
+ constructor(name, args) {
83
+ this.name = name;
84
+ this.args = args;
85
+ }
86
+ accept(visitor) {
87
+ return visitor.visitFunctionCall(this);
88
+ }
89
+ toString() {
90
+ return `${this.name}(${this.args.map(a => a.toString()).join(', ')})`;
91
+ }
92
+ }
93
+ /**
94
+ * Vector component access node (e.g., v.x, point.y)
95
+ */
96
+ export class VectorAccessNode {
97
+ vector;
98
+ component;
99
+ type = 'VectorAccess';
100
+ constructor(vector, component) {
101
+ this.vector = vector;
102
+ this.component = component;
103
+ }
104
+ accept(visitor) {
105
+ return visitor.visitVectorAccess(this);
106
+ }
107
+ toString() {
108
+ return `${this.vector.toString()}.${this.component}`;
109
+ }
110
+ }
111
+ /**
112
+ * Vector constructor node (e.g., Vec2(x, y), Vec3(1, 2, 3))
113
+ */
114
+ export class VectorConstructorNode {
115
+ vectorType;
116
+ components;
117
+ type = 'VectorConstructor';
118
+ constructor(vectorType, components) {
119
+ this.vectorType = vectorType;
120
+ this.components = components;
121
+ }
122
+ accept(visitor) {
123
+ return visitor.visitVectorConstructor(this);
124
+ }
125
+ toString() {
126
+ return `${this.vectorType}(${this.components.map(c => c.toString()).join(', ')})`;
127
+ }
128
+ }
@@ -0,0 +1,35 @@
1
+ /**
2
+ * Code generator for symbolic gradients.
3
+ * Generates executable JavaScript/TypeScript code with mathematical annotations.
4
+ * @internal
5
+ */
6
+ import { ASTNode, Program } from './AST';
/**
 * Render an expression AST as a JavaScript expression string.
 * Exponentiation is emitted as `Math.pow(...)`, and known math function names are
 * mapped to their `Math.*` equivalents.
 */
export declare function generateCode(node: ASTNode): string;
/**
 * Render an expression AST in human-readable math notation (`^` for powers).
 * Used for the explanatory comments in generated code.
 */
export declare function generateMathNotation(node: ASTNode): string;
/**
 * Options for {@link generateGradientCode} and {@link generateGradientFunction}.
 */
export interface CodeGenOptions {
    /** Emit a math-notation comment above each generated statement. Defaults to `true`. */
    includeMath?: boolean;
    /**
     * Declaration keyword for generated variables. Defaults to `'const'`.
     * Only honoured by `generateGradientCode`; `generateGradientFunction` always uses `const`.
     */
    varStyle?: 'const' | 'let' | 'var';
    /**
     * Emit the forward-pass assignments before the gradient statements. Defaults to `true`.
     * Only honoured by `generateGradientCode`.
     */
    includeForward?: boolean;
    /** Indentation unit used for nested lines of generated code. */
    indent?: string;
}
/**
 * Generate a standalone script that evaluates the program's assignments and gradients,
 * and declares `result = { value, gradients }`.
 */
export declare function generateGradientCode(program: Program, gradients: Map<string, ASTNode>, options?: CodeGenOptions): string;
/**
 * Generate a named JavaScript function `functionName(...parameters)` that evaluates the
 * program and returns `{ value, gradients }`.
 */
export declare function generateGradientFunction(program: Program, gradients: Map<string, ASTNode>, functionName: string, parameters: string[], options?: CodeGenOptions): string;
@@ -0,0 +1,280 @@
1
+ /**
2
+ * Code generator for symbolic gradients.
3
+ * Generates executable JavaScript/TypeScript code with mathematical annotations.
4
+ * @internal
5
+ */
6
+ /**
7
+ * Code generation visitor - generates JavaScript expressions
8
+ */
9
+ class CodeGenVisitor {
10
+ visitNumber(node) {
11
+ return String(node.value);
12
+ }
13
+ visitVariable(node) {
14
+ return node.name;
15
+ }
16
+ visitUnaryOp(node) {
17
+ const operand = node.operand.accept(this);
18
+ // Add parentheses if operand is complex
19
+ if (needsParens(node.operand)) {
20
+ return `${node.op}(${operand})`;
21
+ }
22
+ return `${node.op}${operand}`;
23
+ }
24
+ visitBinaryOp(node) {
25
+ const left = node.left.accept(this);
26
+ const right = node.right.accept(this);
27
+ // Handle precedence with parentheses
28
+ const leftStr = needsParens(node.left, node.op) ? `(${left})` : left;
29
+ const rightStr = needsParens(node.right, node.op, 'right') ? `(${right})` : right;
30
+ if (node.op === '**') {
31
+ return `Math.pow(${leftStr}, ${rightStr})`;
32
+ }
33
+ return `${leftStr} ${node.op} ${rightStr}`;
34
+ }
35
+ visitFunctionCall(node) {
36
+ const args = node.args.map(arg => arg.accept(this)).join(', ');
37
+ // Map function names to Math functions
38
+ const funcMap = {
39
+ 'sin': 'Math.sin',
40
+ 'cos': 'Math.cos',
41
+ 'tan': 'Math.tan',
42
+ 'asin': 'Math.asin',
43
+ 'acos': 'Math.acos',
44
+ 'atan': 'Math.atan',
45
+ 'sinh': 'Math.sinh',
46
+ 'cosh': 'Math.cosh',
47
+ 'tanh': 'Math.tanh',
48
+ 'exp': 'Math.exp',
49
+ 'log': 'Math.log',
50
+ 'ln': 'Math.log',
51
+ 'sqrt': 'Math.sqrt',
52
+ 'abs': 'Math.abs',
53
+ 'sign': 'Math.sign',
54
+ 'floor': 'Math.floor',
55
+ 'ceil': 'Math.ceil',
56
+ 'round': 'Math.round',
57
+ 'pow': 'Math.pow',
58
+ 'min': 'Math.min',
59
+ 'max': 'Math.max',
60
+ 'atan2': 'Math.atan2',
61
+ 'heaviside': '(x => x > 0 ? 1 : 0)',
62
+ 'sigmoid': '(x => 1 / (1 + Math.exp(-x)))'
63
+ };
64
+ const funcName = funcMap[node.name] || node.name;
65
+ return `${funcName}(${args})`;
66
+ }
67
+ visitVectorAccess(node) {
68
+ const vector = node.vector.accept(this);
69
+ return `${vector}.${node.component}`;
70
+ }
71
+ visitVectorConstructor(node) {
72
+ const components = node.components.map(c => c.accept(this)).join(', ');
73
+ return `${node.vectorType}(${components})`;
74
+ }
75
+ }
76
+ /**
77
+ * Check if node needs parentheses based on operator precedence
78
+ */
79
+ function needsParens(node, parentOp, position = 'left') {
80
+ if (node.type === 'Number' || node.type === 'Variable' ||
81
+ node.type === 'FunctionCall' || node.type === 'VectorAccess') {
82
+ return false;
83
+ }
84
+ if (node.type === 'UnaryOp') {
85
+ return false; // Unary operators have high precedence
86
+ }
87
+ if (node.type === 'BinaryOp' && parentOp) {
88
+ const binNode = node;
89
+ const nodePrecedence = getPrecedence(binNode.op);
90
+ const parentPrecedence = getPrecedence(parentOp);
91
+ // Need parens if lower precedence
92
+ if (nodePrecedence < parentPrecedence) {
93
+ return true;
94
+ }
95
+ // For same precedence, check associativity
96
+ if (nodePrecedence === parentPrecedence) {
97
+ // Right-associative operators: ** (power)
98
+ if (parentOp === '**') {
99
+ return position === 'left'; // Left operand needs parens
100
+ }
101
+ // Left-associative: need parens on right for -, /
102
+ if ((parentOp === '-' || parentOp === '/') && position === 'right') {
103
+ return true;
104
+ }
105
+ }
106
+ }
107
+ return false;
108
+ }
109
+ /**
110
+ * Get operator precedence (higher = tighter binding)
111
+ */
112
+ function getPrecedence(op) {
113
+ switch (op) {
114
+ case '+':
115
+ case '-':
116
+ return 1;
117
+ case '*':
118
+ case '/':
119
+ return 2;
120
+ case '**':
121
+ return 3;
122
+ default:
123
+ return 0;
124
+ }
125
+ }
126
+ /**
127
+ * Generate mathematical notation for expressions (LaTeX-style comments)
128
+ */
129
+ class MathNotationVisitor {
130
+ visitNumber(node) {
131
+ return String(node.value);
132
+ }
133
+ visitVariable(node) {
134
+ return node.name;
135
+ }
136
+ visitUnaryOp(node) {
137
+ const operand = node.operand.accept(this);
138
+ if (needsParens(node.operand)) {
139
+ return `${node.op}(${operand})`;
140
+ }
141
+ return `${node.op}${operand}`;
142
+ }
143
+ visitBinaryOp(node) {
144
+ const left = node.left.accept(this);
145
+ const right = node.right.accept(this);
146
+ const leftStr = needsParens(node.left, node.op) ? `(${left})` : left;
147
+ const rightStr = needsParens(node.right, node.op, 'right') ? `(${right})` : right;
148
+ if (node.op === '**') {
149
+ return `${leftStr}^${rightStr}`;
150
+ }
151
+ return `${leftStr} ${node.op} ${rightStr}`;
152
+ }
153
+ visitFunctionCall(node) {
154
+ const args = node.args.map(arg => arg.accept(this)).join(', ');
155
+ return `${node.name}(${args})`;
156
+ }
157
+ visitVectorAccess(node) {
158
+ const vector = node.vector.accept(this);
159
+ return `${vector}.${node.component}`;
160
+ }
161
+ visitVectorConstructor(node) {
162
+ const components = node.components.map(c => c.accept(this)).join(', ');
163
+ return `${node.vectorType}(${components})`;
164
+ }
165
+ }
166
+ /**
167
+ * Generate JavaScript code from AST
168
+ */
169
+ export function generateCode(node) {
170
+ const visitor = new CodeGenVisitor();
171
+ return node.accept(visitor);
172
+ }
173
+ /**
174
+ * Generate mathematical notation from AST
175
+ */
176
+ export function generateMathNotation(node) {
177
+ const visitor = new MathNotationVisitor();
178
+ return node.accept(visitor);
179
+ }
180
+ /**
181
+ * Generate complete gradient code for a program
182
+ */
183
+ export function generateGradientCode(program, gradients, options = {}) {
184
+ const { includeMath = true, varStyle = 'const', includeForward = true, indent = ' ' } = options;
185
+ const lines = [];
186
+ // Header comment
187
+ lines.push('// Auto-generated gradient computation');
188
+ lines.push('// Generated by ScalarAutograd symbolic differentiation');
189
+ lines.push('');
190
+ if (includeForward) {
191
+ lines.push('// Forward pass');
192
+ for (const assignment of program.assignments) {
193
+ const code = generateCode(assignment.expression);
194
+ if (includeMath) {
195
+ const math = generateMathNotation(assignment.expression);
196
+ lines.push(`// ${assignment.variable} = ${math}`);
197
+ }
198
+ lines.push(`${varStyle} ${assignment.variable} = ${code};`);
199
+ lines.push('');
200
+ }
201
+ lines.push('// Gradient computation (reverse-mode autodiff)');
202
+ lines.push('');
203
+ }
204
+ // Generate gradients in topological order
205
+ const paramNames = Array.from(gradients.keys()).filter(name => name !== program.output);
206
+ for (const param of paramNames) {
207
+ const gradNode = gradients.get(param);
208
+ if (!gradNode)
209
+ continue;
210
+ const code = generateCode(gradNode);
211
+ if (includeMath) {
212
+ const math = generateMathNotation(gradNode);
213
+ lines.push(`// ∂${program.output}/∂${param} = ${math}`);
214
+ }
215
+ lines.push(`${varStyle} grad_${param} = ${code};`);
216
+ lines.push('');
217
+ }
218
+ // Export result
219
+ lines.push('// Result');
220
+ lines.push(`${varStyle} result = {`);
221
+ lines.push(`${indent}value: ${program.output},`);
222
+ lines.push(`${indent}gradients: {`);
223
+ for (const param of paramNames) {
224
+ lines.push(`${indent}${indent}${param}: grad_${param},`);
225
+ }
226
+ lines.push(`${indent}}`);
227
+ lines.push('};');
228
+ return lines.join('\n');
229
+ }
230
+ /**
231
+ * Generate gradient code as a function
232
+ */
233
+ export function generateGradientFunction(program, gradients, functionName, parameters, options = {}) {
234
+ const { includeMath = true, indent = ' ' } = options;
235
+ const lines = [];
236
+ // Function signature
237
+ lines.push(`/**`);
238
+ lines.push(` * Compute ${program.output} and its gradients`);
239
+ for (const param of parameters) {
240
+ lines.push(` * @param ${param} - Input parameter`);
241
+ }
242
+ lines.push(` * @returns Object with value and gradients`);
243
+ lines.push(` */`);
244
+ lines.push(`function ${functionName}(${parameters.join(', ')}) {`);
245
+ // Forward pass
246
+ if (includeMath) {
247
+ lines.push(`${indent}// Forward pass`);
248
+ }
249
+ for (const assignment of program.assignments) {
250
+ const code = generateCode(assignment.expression);
251
+ if (includeMath) {
252
+ const math = generateMathNotation(assignment.expression);
253
+ lines.push(`${indent}// ${assignment.variable} = ${math}`);
254
+ }
255
+ lines.push(`${indent}const ${assignment.variable} = ${code};`);
256
+ }
257
+ lines.push('');
258
+ if (includeMath) {
259
+ lines.push(`${indent}// Gradient computation`);
260
+ }
261
+ // Gradients
262
+ for (const param of parameters) {
263
+ const gradNode = gradients.get(param);
264
+ if (!gradNode)
265
+ continue;
266
+ const code = generateCode(gradNode);
267
+ if (includeMath) {
268
+ const math = generateMathNotation(gradNode);
269
+ lines.push(`${indent}// ∂${program.output}/∂${param} = ${math}`);
270
+ }
271
+ lines.push(`${indent}const grad_${param} = ${code};`);
272
+ }
273
+ lines.push('');
274
+ lines.push(`${indent}return {`);
275
+ lines.push(`${indent}${indent}value: ${program.output},`);
276
+ lines.push(`${indent}${indent}gradients: { ${parameters.map(p => `${p}: grad_${p}`).join(', ')} }`);
277
+ lines.push(`${indent}};`);
278
+ lines.push('}');
279
+ return lines.join('\n');
280
+ }
@@ -0,0 +1,64 @@
1
+ /**
2
+ * Expression parser for symbolic gradient generation.
3
+ * Parses operator-overloaded mathematical expressions into AST.
4
+ * @internal
5
+ */
6
+ import { Program } from './AST';
/**
 * Recursive-descent parser that turns source text into a `Program` AST.
 *
 * Grammar, lowest to highest precedence:
 *   assignment         → IDENTIFIER '=' expression
 *   expression         → term (('+' | '-') term)*
 *   term               → factor (('*' | '/') factor)*
 *   factor             → power
 *   power              → postfix ('**' postfix)*
 *   postfix            → primary ('.' IDENTIFIER)*
 *   primary            → NUMBER | IDENTIFIER | function_call | vector_constructor
 *                      | '(' expression ')' | ('+' | '-') primary
 *   function_call      → IDENTIFIER '(' arg_list? ')'
 *   vector_constructor → ('Vec2' | 'Vec3') '(' arg_list ')'
 *   arg_list           → expression (',' expression)*
 *
 * A leading sign belongs to `primary`, so in this grammar it binds tighter than '**':
 * `-x ** 2` means `(-x) ** 2`. In JavaScript the same text is a SyntaxError.
 *
 * NOTE(review): the grammar puts '+'/'-' in `expression` and '*'/'/' in `term`. The member
 * comments below instead put '+'/'-' in parseTerm and '*'/'/' in parseFactor. One of the
 * two descriptions is off by a level — confirm against the implementation.
 */
export declare class Parser {
    private tokens;
    private current;
    /** @param text - source text to parse */
    constructor(text: string);
    private peek;
    private advance;
    private expect;
    /** Parse the whole input into a Program (its assignments and output variable). */
    parseProgram(): Program;
    /** Parse one expression. */
    private parseExpression;
    /** Parse a term (addition/subtraction level, per the original member comment). */
    private parseTerm;
    /** Parse a factor (multiplication/division level, per the original member comment). */
    private parseFactor;
    /** Parse exponentiation ('**'). */
    private parsePower;
    /** Parse postfix member access such as `v.x`. */
    private parsePostfix;
    /** Parse a primary expression: literal, name, call, vector constructor, group, or signed primary. */
    private parsePrimary;
    /** Parse a comma-separated function/constructor argument list. */
    private parseArgList;
}
/**
 * Parse a mathematical program string into a Program AST.
 * NOTE(review): presumably equivalent to `new Parser(text).parseProgram()` — confirm in Parser.js.
 */
export declare function parse(text: string): Program;