gradient-script 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/README.md +52 -9
  2. package/dist/cli.js +134 -19
  3. package/dist/dsl/AST.d.ts +8 -0
  4. package/dist/dsl/CodeGen.d.ts +8 -3
  5. package/dist/dsl/CodeGen.js +583 -132
  6. package/dist/dsl/Errors.d.ts +6 -1
  7. package/dist/dsl/Errors.js +70 -1
  8. package/dist/dsl/Expander.js +5 -2
  9. package/dist/dsl/ExpressionUtils.d.ts +14 -0
  10. package/dist/dsl/ExpressionUtils.js +56 -0
  11. package/dist/dsl/GradientChecker.d.ts +21 -0
  12. package/dist/dsl/GradientChecker.js +109 -23
  13. package/dist/dsl/Guards.d.ts +3 -1
  14. package/dist/dsl/Guards.js +86 -43
  15. package/dist/dsl/Inliner.d.ts +5 -0
  16. package/dist/dsl/Inliner.js +11 -2
  17. package/dist/dsl/Lexer.js +3 -1
  18. package/dist/dsl/Parser.js +11 -5
  19. package/dist/dsl/Simplify.d.ts +7 -0
  20. package/dist/dsl/Simplify.js +183 -0
  21. package/dist/dsl/egraph/Convert.d.ts +23 -0
  22. package/dist/dsl/egraph/Convert.js +84 -0
  23. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  24. package/dist/dsl/egraph/EGraph.js +292 -0
  25. package/dist/dsl/egraph/ENode.d.ts +63 -0
  26. package/dist/dsl/egraph/ENode.js +94 -0
  27. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  28. package/dist/dsl/egraph/Extractor.js +1068 -0
  29. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  30. package/dist/dsl/egraph/Optimizer.js +88 -0
  31. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  32. package/dist/dsl/egraph/Pattern.js +325 -0
  33. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  34. package/dist/dsl/egraph/Rewriter.js +131 -0
  35. package/dist/dsl/egraph/Rules.d.ts +44 -0
  36. package/dist/dsl/egraph/Rules.js +187 -0
  37. package/dist/dsl/egraph/index.d.ts +15 -0
  38. package/dist/dsl/egraph/index.js +21 -0
  39. package/package.json +1 -1
  40. package/dist/dsl/CSE.d.ts +0 -21
  41. package/dist/dsl/CSE.js +0 -194
  42. package/dist/symbolic/AST.d.ts +0 -113
  43. package/dist/symbolic/AST.js +0 -128
  44. package/dist/symbolic/CodeGen.d.ts +0 -35
  45. package/dist/symbolic/CodeGen.js +0 -280
  46. package/dist/symbolic/Parser.d.ts +0 -64
  47. package/dist/symbolic/Parser.js +0 -329
  48. package/dist/symbolic/Simplify.d.ts +0 -10
  49. package/dist/symbolic/Simplify.js +0 -244
  50. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  51. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -1,329 +0,0 @@
1
- /**
2
- * Expression parser for symbolic gradient generation.
3
- * Parses operator-overloaded mathematical expressions into AST.
4
- * @internal
5
- */
6
- import { NumberNode, VariableNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
7
/**
 * Lexical token categories produced by the Lexer.
 *
 * Compiled numeric-enum shape: every name maps to its index and every index
 * maps back to its name, so `TokenType[type]` yields a readable label for
 * error messages.
 */
var TokenType;
(function (registry) {
    const names = [
        'NUMBER', 'IDENTIFIER',
        'PLUS', 'MINUS', 'MULTIPLY', 'DIVIDE', 'POWER',
        'LPAREN', 'RPAREN', 'DOT', 'COMMA', 'EQUALS',
        'SEMICOLON', 'NEWLINE', 'EOF',
    ];
    names.forEach((name, index) => {
        registry[name] = index;
        registry[index] = name;
    });
})(TokenType || (TokenType = {}));
28
- /**
29
- * Lexer: converts text into tokens
30
- */
31
- class Lexer {
32
- pos = 0;
33
- text;
34
- constructor(text) {
35
- this.text = text;
36
- }
37
- peek(offset = 0) {
38
- const pos = this.pos + offset;
39
- return pos < this.text.length ? this.text[pos] : '\0';
40
- }
41
- advance() {
42
- const ch = this.peek();
43
- this.pos++;
44
- return ch;
45
- }
46
- skipWhitespace() {
47
- while (this.peek() === ' ' || this.peek() === '\t' || this.peek() === '\r') {
48
- this.advance();
49
- }
50
- }
51
- readNumber() {
52
- const start = this.pos;
53
- let numStr = '';
54
- while (/[0-9.]/.test(this.peek())) {
55
- numStr += this.advance();
56
- }
57
- return { type: TokenType.NUMBER, value: parseFloat(numStr), pos: start };
58
- }
59
- readIdentifier() {
60
- const start = this.pos;
61
- let id = '';
62
- while (/[a-zA-Z0-9_]/.test(this.peek())) {
63
- id += this.advance();
64
- }
65
- return { type: TokenType.IDENTIFIER, value: id, pos: start };
66
- }
67
- nextToken() {
68
- this.skipWhitespace();
69
- const ch = this.peek();
70
- const pos = this.pos;
71
- if (ch === '\0') {
72
- return { type: TokenType.EOF, value: '', pos };
73
- }
74
- if (ch === '\n') {
75
- this.advance();
76
- return { type: TokenType.NEWLINE, value: '\n', pos };
77
- }
78
- if (/[0-9]/.test(ch)) {
79
- return this.readNumber();
80
- }
81
- if (/[a-zA-Z_]/.test(ch)) {
82
- return this.readIdentifier();
83
- }
84
- // Single-character tokens
85
- this.advance();
86
- switch (ch) {
87
- case '+': return { type: TokenType.PLUS, value: '+', pos };
88
- case '-': return { type: TokenType.MINUS, value: '-', pos };
89
- case '*':
90
- // Check for **
91
- if (this.peek() === '*') {
92
- this.advance();
93
- return { type: TokenType.POWER, value: '**', pos };
94
- }
95
- return { type: TokenType.MULTIPLY, value: '*', pos };
96
- case '/': return { type: TokenType.DIVIDE, value: '/', pos };
97
- case '(': return { type: TokenType.LPAREN, value: '(', pos };
98
- case ')': return { type: TokenType.RPAREN, value: ')', pos };
99
- case '.': return { type: TokenType.DOT, value: '.', pos };
100
- case ',': return { type: TokenType.COMMA, value: ',', pos };
101
- case '=': return { type: TokenType.EQUALS, value: '=', pos };
102
- case ';': return { type: TokenType.SEMICOLON, value: ';', pos };
103
- default:
104
- throw new Error(`Unexpected character '${ch}' at position ${pos}`);
105
- }
106
- }
107
- tokenize() {
108
- const tokens = [];
109
- let token;
110
- do {
111
- token = this.nextToken();
112
- // Skip newlines and semicolons (treat as statement separators)
113
- if (token.type !== TokenType.NEWLINE && token.type !== TokenType.SEMICOLON) {
114
- tokens.push(token);
115
- }
116
- } while (token.type !== TokenType.EOF);
117
- return tokens;
118
- }
119
- }
120
/**
 * Parser: converts the token stream into an AST.
 *
 * Grammar, from lowest to highest precedence (implementing method in brackets):
 *   program        → (assignment | expression)*                    [parseProgram]
 *   assignment     → IDENTIFIER '=' expression
 *   expression     → additive                                      [parseExpression]
 *   additive       → multiplicative (('+' | '-') multiplicative)*  [parseTerm]
 *   multiplicative → unary (('*' | '/') unary)*                    [parseFactor]
 *   unary          → ('+' | '-') unary | power                     [parseUnary]
 *   power          → postfix ('**' unary)?                         [parsePower]
 *   postfix        → primary ('.' member)*                         [parsePostfix]
 *   primary        → NUMBER | IDENTIFIER | function_call
 *                  | vector_constructor | '(' expression ')'       [parsePrimary]
 *   function_call  → IDENTIFIER '(' arg_list? ')'
 *   vector_constructor → 'Vec2' '(' arg_list ')'   (exactly 2 arguments)
 *                      | 'Vec3' '(' arg_list ')'   (exactly 3 arguments)
 *   arg_list       → expression (',' expression)*                  [parseArgList]
 *
 * '**' is right-associative and binds tighter than a unary sign on its left,
 * matching Python and standard notation: `-x**2` is `-(x**2)`. A signed
 * exponent such as `2**-x` is still accepted.
 */
export class Parser {
    tokens;
    current = 0;
    constructor(text) {
        const lexer = new Lexer(text);
        this.tokens = lexer.tokenize();
    }
    /** Token `offset` positions ahead; clamps to the trailing EOF token. */
    peek(offset = 0) {
        const idx = this.current + offset;
        return idx < this.tokens.length ? this.tokens[idx] : this.tokens[this.tokens.length - 1];
    }
    /** Consume and return the current token (never moves past EOF). */
    advance() {
        const token = this.peek();
        if (token.type !== TokenType.EOF) {
            this.current++;
        }
        return token;
    }
    /**
     * Consume a token of the given type.
     * @throws Error with `message`, or a default message naming both token types.
     */
    expect(type, message) {
        const token = this.peek();
        if (token.type !== type) {
            throw new Error(message || `Expected token type ${TokenType[type]}, got ${TokenType[token.type]} at position ${token.pos}`);
        }
        return this.advance();
    }
    /**
     * Parse a complete program.
     *
     * Each statement is either `name = expression` or a bare expression. A bare
     * expression is stored under a generated name `_output<k>`. The program's
     * output is the variable named `output` if one is assigned; otherwise it is
     * the variable of the last statement ('' for an empty program).
     */
    parseProgram() {
        const assignments = [];
        let outputVar = '';
        while (this.peek().type !== TokenType.EOF) {
            // Check if this is an assignment
            if (this.peek().type === TokenType.IDENTIFIER && this.peek(1).type === TokenType.EQUALS) {
                const varName = this.peek().value;
                this.advance(); // identifier
                this.advance(); // equals
                const expression = this.parseExpression();
                assignments.push({ variable: varName, expression });
                // Track the last assigned variable as potential output
                outputVar = varName;
            }
            else {
                // Standalone expression - treat as output
                const expression = this.parseExpression();
                const tempVar = `_output${assignments.length}`;
                assignments.push({ variable: tempVar, expression });
                outputVar = tempVar;
            }
        }
        // An explicit "output" variable takes precedence
        const outputAssignment = assignments.find(a => a.variable === 'output');
        if (outputAssignment) {
            outputVar = 'output';
        }
        return { assignments, output: outputVar };
    }
    /** Parse an expression (entry point of the precedence chain). */
    parseExpression() {
        return this.parseTerm();
    }
    /** additive → multiplicative (('+' | '-') multiplicative)*, left-associative. */
    parseTerm() {
        let left = this.parseFactor();
        while (this.peek().type === TokenType.PLUS || this.peek().type === TokenType.MINUS) {
            const op = this.advance().value;
            const right = this.parseFactor();
            left = new BinaryOpNode(op, left, right);
        }
        return left;
    }
    /** multiplicative → unary (('*' | '/') unary)*, left-associative. */
    parseFactor() {
        let left = this.parseUnary();
        while (this.peek().type === TokenType.MULTIPLY || this.peek().type === TokenType.DIVIDE) {
            const op = this.advance().value;
            const right = this.parseUnary();
            left = new BinaryOpNode(op, left, right);
        }
        return left;
    }
    /**
     * unary → ('+' | '-') unary | power.
     * A leading sign wraps the whole power expression, so `-x**2` is `-(x**2)`.
     */
    parseUnary() {
        const type = this.peek().type;
        if (type === TokenType.PLUS || type === TokenType.MINUS) {
            const op = this.advance().value;
            return new UnaryOpNode(op, this.parseUnary());
        }
        return this.parsePower();
    }
    /**
     * power → postfix ('**' unary)?.
     * The exponent is parsed via parseUnary. That recursion makes '**'
     * right-associative (2**3**2 = 2**(3**2)) and permits signed exponents (2**-x).
     */
    parsePower() {
        const base = this.parsePostfix();
        if (this.peek().type !== TokenType.POWER) {
            return base;
        }
        this.advance(); // consume '**'
        const exponent = this.parseUnary();
        return new BinaryOpNode('**', base, exponent);
    }
    /**
     * Parse postfix member access.
     *
     * Behaviour depends on the member name:
     * - `.x`/`.y`/`.z` → vector component.
     * - `.magnitude`/`.sqrMagnitude`/`.normalized` → one-argument call on the object.
     * - Anything else must be a method call `v.name(args)`, parsed as a call with
     *   the object as the first argument.
     */
    parsePostfix() {
        let node = this.parsePrimary();
        while (this.peek().type === TokenType.DOT) {
            this.advance(); // consume '.'
            const member = this.expect(TokenType.IDENTIFIER, 'Expected component name after "."');
            const memberName = member.value;
            if (memberName === 'x' || memberName === 'y' || memberName === 'z') {
                node = new VectorAccessNode(node, memberName);
            }
            else if (memberName === 'magnitude' || memberName === 'sqrMagnitude' || memberName === 'normalized') {
                // Property-style accessors represented as function calls
                node = new FunctionCallNode(memberName, [node]);
            }
            else {
                // Method call (e.g., v.dot(u))
                this.expect(TokenType.LPAREN);
                const args = [node]; // First arg is the object itself
                args.push(...this.parseArgList());
                this.expect(TokenType.RPAREN);
                node = new FunctionCallNode(memberName, args);
            }
        }
        return node;
    }
    /**
     * Parse a primary expression: number, parenthesised expression, variable,
     * function call or vector constructor. Unary signs are handled by parseUnary.
     * @throws Error on an unexpected token or a Vec2/Vec3 with the wrong arity.
     */
    parsePrimary() {
        const token = this.peek();
        // Number literal
        if (token.type === TokenType.NUMBER) {
            this.advance();
            return new NumberNode(token.value);
        }
        // Parenthesized expression
        if (token.type === TokenType.LPAREN) {
            this.advance();
            const expr = this.parseExpression();
            this.expect(TokenType.RPAREN, 'Expected closing parenthesis');
            return expr;
        }
        // Identifier (variable, function call, or vector constructor)
        if (token.type === TokenType.IDENTIFIER) {
            const name = token.value;
            this.advance();
            if (this.peek().type !== TokenType.LPAREN) {
                // Just a variable reference
                return new VariableNode(name);
            }
            this.advance(); // consume '('
            const args = this.parseArgList();
            this.expect(TokenType.RPAREN, 'Expected closing parenthesis in function call');
            if (name === 'Vec2' || name === 'Vec3') {
                // Reject malformed constructors here rather than producing a
                // vector node with missing or extra components.
                const arity = name === 'Vec2' ? 2 : 3;
                if (args.length !== arity) {
                    throw new Error(`${name} expects ${arity} arguments, got ${args.length} at position ${token.pos}`);
                }
                return new VectorConstructorNode(name, args);
            }
            return new FunctionCallNode(name, args);
        }
        throw new Error(`Unexpected token ${TokenType[token.type]} at position ${token.pos}`);
    }
    /** Parse a comma-separated argument list (possibly empty); stops before ')'. */
    parseArgList() {
        const args = [];
        if (this.peek().type === TokenType.RPAREN) {
            return args; // Empty arg list
        }
        args.push(this.parseExpression());
        while (this.peek().type === TokenType.COMMA) {
            this.advance(); // consume comma
            args.push(this.parseExpression());
        }
        return args;
    }
}
323
/**
 * Parse DSL source text into a program: its ordered assignments plus the name
 * of the variable chosen as the output.
 *
 * @param text - Source text to parse.
 * @returns The `{ assignments, output }` object built by `Parser.parseProgram`.
 */
export function parse(text) {
    return new Parser(text).parseProgram();
}
@@ -1,10 +0,0 @@
1
- /**
2
- * Expression simplification for symbolic gradients.
3
- * Applies algebraic simplification rules to make formulas more readable.
4
- * @internal
5
- */
6
- import { ASTNode } from './AST';
7
/**
 * Simplify an AST node with algebraic rewrite rules and constant folding.
 * The rewrite pass is repeated until the tree stops changing, up to a
 * bounded number of passes.
 *
 * @param node - Expression tree to simplify.
 * @returns The simplified tree (possibly a newly built node).
 */
export declare function simplify(node: ASTNode): ASTNode;
@@ -1,244 +0,0 @@
1
- /**
2
- * Expression simplification for symbolic gradients.
3
- * Applies algebraic simplification rules to make formulas more readable.
4
- * @internal
5
- */
6
- import { NumberNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
7
- /**
8
- * Simplification visitor
9
- */
10
- class SimplificationVisitor {
11
- visitNumber(node) {
12
- return node;
13
- }
14
- visitVariable(node) {
15
- return node;
16
- }
17
- visitUnaryOp(node) {
18
- const operand = node.operand.accept(this);
19
- // Simplify -(-x) = x
20
- if (node.op === '-' && operand.type === 'UnaryOp') {
21
- const unary = operand;
22
- if (unary.op === '-') {
23
- return unary.operand;
24
- }
25
- }
26
- // Simplify -(0) = 0
27
- if (node.op === '-' && operand.type === 'Number') {
28
- return new NumberNode(-operand.value);
29
- }
30
- // Simplify +(x) = x
31
- if (node.op === '+') {
32
- return operand;
33
- }
34
- return new UnaryOpNode(node.op, operand);
35
- }
36
- visitBinaryOp(node) {
37
- // First simplify children
38
- const left = node.left.accept(this);
39
- const right = node.right.accept(this);
40
- // Get numeric values if both are numbers
41
- const leftNum = left.type === 'Number' ? left.value : null;
42
- const rightNum = right.type === 'Number' ? right.value : null;
43
- // Constant folding
44
- if (leftNum !== null && rightNum !== null) {
45
- switch (node.op) {
46
- case '+': return new NumberNode(leftNum + rightNum);
47
- case '-': return new NumberNode(leftNum - rightNum);
48
- case '*': return new NumberNode(leftNum * rightNum);
49
- case '/': return new NumberNode(leftNum / rightNum);
50
- case '**': return new NumberNode(Math.pow(leftNum, rightNum));
51
- }
52
- }
53
- // Addition simplifications
54
- if (node.op === '+') {
55
- // 0 + x = x
56
- if (leftNum === 0)
57
- return right;
58
- // x + 0 = x
59
- if (rightNum === 0)
60
- return left;
61
- // x + x = 2*x
62
- if (nodesEqual(left, right)) {
63
- return new BinaryOpNode('*', new NumberNode(2), left);
64
- }
65
- }
66
- // Subtraction simplifications
67
- if (node.op === '-') {
68
- // x - 0 = x
69
- if (rightNum === 0)
70
- return left;
71
- // 0 - x = -x
72
- if (leftNum === 0)
73
- return new UnaryOpNode('-', right);
74
- // x - x = 0
75
- if (nodesEqual(left, right)) {
76
- return new NumberNode(0);
77
- }
78
- }
79
- // Multiplication simplifications
80
- if (node.op === '*') {
81
- // 0 * x = 0
82
- if (leftNum === 0 || rightNum === 0)
83
- return new NumberNode(0);
84
- // 1 * x = x
85
- if (leftNum === 1)
86
- return right;
87
- // x * 1 = x
88
- if (rightNum === 1)
89
- return left;
90
- // -1 * x = -x
91
- if (leftNum === -1)
92
- return new UnaryOpNode('-', right);
93
- // x * -1 = -x
94
- if (rightNum === -1)
95
- return new UnaryOpNode('-', left);
96
- // x * x = x^2
97
- if (nodesEqual(left, right)) {
98
- return new BinaryOpNode('**', left, new NumberNode(2));
99
- }
100
- }
101
- // Division simplifications
102
- if (node.op === '/') {
103
- // 0 / x = 0
104
- if (leftNum === 0)
105
- return new NumberNode(0);
106
- // x / 1 = x
107
- if (rightNum === 1)
108
- return left;
109
- // x / x = 1
110
- if (nodesEqual(left, right))
111
- return new NumberNode(1);
112
- }
113
- // Power simplifications
114
- if (node.op === '**') {
115
- // x^0 = 1
116
- if (rightNum === 0)
117
- return new NumberNode(1);
118
- // x^1 = x
119
- if (rightNum === 1)
120
- return left;
121
- // 0^x = 0 (for x > 0)
122
- if (leftNum === 0)
123
- return new NumberNode(0);
124
- // 1^x = 1
125
- if (leftNum === 1)
126
- return new NumberNode(1);
127
- }
128
- return new BinaryOpNode(node.op, left, right);
129
- }
130
- visitFunctionCall(node) {
131
- const args = node.args.map(arg => arg.accept(this));
132
- // Check for constant arguments
133
- if (args.length === 1 && args[0].type === 'Number') {
134
- const value = args[0].value;
135
- switch (node.name) {
136
- case 'sin': return new NumberNode(Math.sin(value));
137
- case 'cos': return new NumberNode(Math.cos(value));
138
- case 'tan': return new NumberNode(Math.tan(value));
139
- case 'exp': return new NumberNode(Math.exp(value));
140
- case 'log':
141
- case 'ln': return new NumberNode(Math.log(value));
142
- case 'sqrt': return new NumberNode(Math.sqrt(value));
143
- case 'abs': return new NumberNode(Math.abs(value));
144
- case 'sign': return new NumberNode(Math.sign(value));
145
- case 'floor': return new NumberNode(Math.floor(value));
146
- case 'ceil': return new NumberNode(Math.ceil(value));
147
- case 'round': return new NumberNode(Math.round(value));
148
- case 'asin': return new NumberNode(Math.asin(value));
149
- case 'acos': return new NumberNode(Math.acos(value));
150
- case 'atan': return new NumberNode(Math.atan(value));
151
- case 'sinh': return new NumberNode(Math.sinh(value));
152
- case 'cosh': return new NumberNode(Math.cosh(value));
153
- case 'tanh': return new NumberNode(Math.tanh(value));
154
- }
155
- }
156
- // Two-argument functions
157
- if (args.length === 2 && args[0].type === 'Number' && args[1].type === 'Number') {
158
- const val1 = args[0].value;
159
- const val2 = args[1].value;
160
- switch (node.name) {
161
- case 'pow': return new NumberNode(Math.pow(val1, val2));
162
- case 'min': return new NumberNode(Math.min(val1, val2));
163
- case 'max': return new NumberNode(Math.max(val1, val2));
164
- }
165
- }
166
- // Special simplifications
167
- // exp(0) = 1
168
- if (node.name === 'exp' && args[0].type === 'Number' && args[0].value === 0) {
169
- return new NumberNode(1);
170
- }
171
- // log(1) = 0
172
- if ((node.name === 'log' || node.name === 'ln') && args[0].type === 'Number' && args[0].value === 1) {
173
- return new NumberNode(0);
174
- }
175
- // sqrt(1) = 1
176
- if (node.name === 'sqrt' && args[0].type === 'Number' && args[0].value === 1) {
177
- return new NumberNode(1);
178
- }
179
- return new FunctionCallNode(node.name, args);
180
- }
181
- visitVectorAccess(node) {
182
- const vector = node.vector.accept(this);
183
- return new VectorAccessNode(vector, node.component);
184
- }
185
- visitVectorConstructor(node) {
186
- const components = node.components.map(c => c.accept(this));
187
- return new VectorConstructorNode(node.vectorType, components);
188
- }
189
- }
190
- /**
191
- * Check if two AST nodes are structurally equal
192
- */
193
- function nodesEqual(a, b) {
194
- if (a.type !== b.type)
195
- return false;
196
- if (a.type === 'Number' && b.type === 'Number') {
197
- return a.value === b.value;
198
- }
199
- if (a.type === 'Variable' && b.type === 'Variable') {
200
- return a.name === b.name;
201
- }
202
- if (a.type === 'UnaryOp' && b.type === 'UnaryOp') {
203
- const aUnary = a;
204
- const bUnary = b;
205
- return aUnary.op === bUnary.op && nodesEqual(aUnary.operand, bUnary.operand);
206
- }
207
- if (a.type === 'BinaryOp' && b.type === 'BinaryOp') {
208
- const aBinary = a;
209
- const bBinary = b;
210
- return aBinary.op === bBinary.op &&
211
- nodesEqual(aBinary.left, bBinary.left) &&
212
- nodesEqual(aBinary.right, bBinary.right);
213
- }
214
- if (a.type === 'FunctionCall' && b.type === 'FunctionCall') {
215
- const aFunc = a;
216
- const bFunc = b;
217
- return aFunc.name === bFunc.name &&
218
- aFunc.args.length === bFunc.args.length &&
219
- aFunc.args.every((arg, i) => nodesEqual(arg, bFunc.args[i]));
220
- }
221
- if (a.type === 'VectorAccess' && b.type === 'VectorAccess') {
222
- const aVec = a;
223
- const bVec = b;
224
- return aVec.component === bVec.component && nodesEqual(aVec.vector, bVec.vector);
225
- }
226
- return false;
227
- }
228
/**
 * Simplify an AST by running the simplification pass repeatedly.
 *
 * Passes stop once a pass leaves the tree structurally unchanged, or after
 * 10 passes, whichever comes first.
 *
 * @param node - Expression tree to simplify.
 * @returns The result of the last pass.
 */
export function simplify(node) {
    const visitor = new SimplificationVisitor();
    const maxIterations = 10;
    let current = node;
    for (let pass = 1; ; pass++) {
        const next = current.accept(visitor);
        // Pass budget is checked first so the final pass skips the comparison
        const settled = pass >= maxIterations || nodesEqual(next, current);
        current = next;
        if (settled) {
            return current;
        }
    }
}
@@ -1,35 +0,0 @@
1
- /**
2
- * Symbolic differentiation engine.
3
- * Applies differentiation rules to AST and generates symbolic gradient expressions.
4
- * @internal
5
- */
6
- import { ASTNode, ASTVisitor, NumberNode, VariableNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode, Program } from './AST';
7
/**
 * AST visitor that produces the derivative of each visited node with respect
 * to the variable name given to the constructor.
 */
export declare class DifferentiationVisitor implements ASTVisitor<ASTNode> {
    private wrt;
    /** @param wrt - Name of the variable to differentiate with respect to. */
    constructor(wrt: string);
    visitNumber(node: NumberNode): ASTNode;
    visitVariable(node: VariableNode): ASTNode;
    visitUnaryOp(node: UnaryOpNode): ASTNode;
    visitBinaryOp(node: BinaryOpNode): ASTNode;
    visitFunctionCall(node: FunctionCallNode): ASTNode;
    visitVectorAccess(node: VectorAccessNode): ASTNode;
    visitVectorConstructor(node: VectorConstructorNode): ASTNode;
}
21
/**
 * Differentiate an AST node with respect to a variable.
 *
 * @param node - Expression to differentiate.
 * @param wrt - Name of the variable to differentiate with respect to.
 * @returns AST of the derivative expression.
 */
export declare function differentiate(node: ASTNode, wrt: string): ASTNode;
25
/**
 * A single gradient entry: a variable name paired with the AST of its
 * gradient expression.
 *
 * NOTE(review): `computeGradients` returns a `Map<string, ASTNode>`, not
 * `GradientResult[]`; confirm whether this type is still used by callers.
 */
export interface GradientResult {
    /** Name of the variable the gradient is taken with respect to. */
    variable: string;
    /** AST of the gradient expression. */
    gradient: ASTNode;
}
32
/**
 * Compute gradients of a program's output with respect to each of the given
 * parameters.
 *
 * @param program - Parsed program (its assignments and the name of its output variable).
 * @param parameters - Names of the variables to differentiate with respect to.
 * @returns Map keyed by parameter name; presumably each value is that
 *   parameter's gradient AST — confirm against the implementation.
 */
export declare function computeGradients(program: Program, parameters: string[]): Map<string, ASTNode>;