gradient-script 0.1.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +52 -9
- package/dist/cli.js +134 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CodeGen.d.ts +8 -3
- package/dist/dsl/CodeGen.js +583 -132
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +14 -0
- package/dist/dsl/ExpressionUtils.js +56 -0
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +3 -1
- package/dist/dsl/Guards.js +86 -43
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +11 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +183 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -194
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/dist/symbolic/Parser.js
DELETED
|
@@ -1,329 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Expression parser for symbolic gradient generation.
|
|
3
|
-
* Parses operator-overloaded mathematical expressions into AST.
|
|
4
|
-
* @internal
|
|
5
|
-
*/
|
|
6
|
-
import { NumberNode, VariableNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
|
|
7
|
-
/**
 * Token kinds produced by the Lexer.
 *
 * Laid out like a compiled TypeScript numeric enum: each name maps to its
 * ordinal (TokenType.PLUS === 2) and each ordinal maps back to its name
 * (TokenType[2] === 'PLUS'). Error messages rely on the reverse mapping to
 * print readable token names.
 */
var TokenType;
(function (TokenType) {
    const tokenNames = [
        'NUMBER', 'IDENTIFIER',
        'PLUS', 'MINUS', 'MULTIPLY', 'DIVIDE', 'POWER',
        'LPAREN', 'RPAREN', 'DOT', 'COMMA',
        'EQUALS', 'SEMICOLON', 'NEWLINE', 'EOF',
    ];
    tokenNames.forEach((tokenName, ordinal) => {
        TokenType[tokenName] = ordinal;
        TokenType[ordinal] = tokenName;
    });
})(TokenType || (TokenType = {}));
|
|
28
|
-
/**
 * Lexer: converts source text into a flat list of tokens.
 *
 * Spaces, tabs and carriage returns are skipped. Newlines and semicolons are
 * recognised but discarded by tokenize(); the parser separates statements by
 * looking for an `IDENTIFIER '='` prefix instead.
 *
 * NOTE(review): because separators are discarded, a line that starts with
 * '-' or '(' continues the previous statement ("a = x\n-y" lexes exactly like
 * "a = x - y"). Fixing that needs the parser to understand separators.
 */
class Lexer {
    pos = 0;
    text;
    constructor(text) {
        this.text = text;
    }
    /** Character at `pos + offset`, or '\0' once past the end of the input. */
    peek(offset = 0) {
        const pos = this.pos + offset;
        return pos < this.text.length ? this.text[pos] : '\0';
    }
    /** Consume and return the current character. */
    advance() {
        const ch = this.peek();
        this.pos++;
        return ch;
    }
    /** Skip spaces, tabs and carriage returns (newlines are tokens). */
    skipWhitespace() {
        while (this.peek() === ' ' || this.peek() === '\t' || this.peek() === '\r') {
            this.advance();
        }
    }
    /**
     * Read a numeric literal: digits with at most one decimal point, followed
     * by an optional exponent (`1e-3`, `2.5E+4`).
     *
     * Accepting only one '.' makes malformed input such as "1.2.3" fail in
     * the parser instead of being silently truncated to 1.2 by parseFloat.
     * The exponent is consumed only when at least one digit follows it, so a
     * lone identifier like `e` after a number still lexes as an identifier.
     */
    readNumber() {
        const start = this.pos;
        let numStr = '';
        let seenDot = false;
        for (;;) {
            const ch = this.peek();
            if (/[0-9]/.test(ch)) {
                numStr += this.advance();
            }
            else if (ch === '.' && !seenDot) {
                seenDot = true;
                numStr += this.advance();
            }
            else {
                break;
            }
        }
        const marker = this.peek();
        if (marker === 'e' || marker === 'E') {
            const sign = this.peek(1);
            const signLen = sign === '+' || sign === '-' ? 1 : 0;
            if (/[0-9]/.test(this.peek(1 + signLen))) {
                // Consume the 'e'/'E' marker and, if present, its sign.
                for (let i = 0; i <= signLen; i++) {
                    numStr += this.advance();
                }
                while (/[0-9]/.test(this.peek())) {
                    numStr += this.advance();
                }
            }
        }
        return { type: TokenType.NUMBER, value: parseFloat(numStr), pos: start };
    }
    /** Read an identifier: a letter or '_' followed by letters, digits or '_'. */
    readIdentifier() {
        const start = this.pos;
        let id = '';
        while (/[a-zA-Z0-9_]/.test(this.peek())) {
            id += this.advance();
        }
        return { type: TokenType.IDENTIFIER, value: id, pos: start };
    }
    /**
     * Produce the next token. Each token records the offset where it starts.
     * @throws Error on a character that starts no known token.
     */
    nextToken() {
        this.skipWhitespace();
        const ch = this.peek();
        const pos = this.pos;
        if (ch === '\0') {
            return { type: TokenType.EOF, value: '', pos };
        }
        if (ch === '\n') {
            this.advance();
            return { type: TokenType.NEWLINE, value: '\n', pos };
        }
        if (/[0-9]/.test(ch)) {
            return this.readNumber();
        }
        if (/[a-zA-Z_]/.test(ch)) {
            return this.readIdentifier();
        }
        // Single-character tokens (plus the two-character '**').
        this.advance();
        switch (ch) {
            case '+': return { type: TokenType.PLUS, value: '+', pos };
            case '-': return { type: TokenType.MINUS, value: '-', pos };
            case '*':
                // Check for **
                if (this.peek() === '*') {
                    this.advance();
                    return { type: TokenType.POWER, value: '**', pos };
                }
                return { type: TokenType.MULTIPLY, value: '*', pos };
            case '/': return { type: TokenType.DIVIDE, value: '/', pos };
            case '(': return { type: TokenType.LPAREN, value: '(', pos };
            case ')': return { type: TokenType.RPAREN, value: ')', pos };
            case '.': return { type: TokenType.DOT, value: '.', pos };
            case ',': return { type: TokenType.COMMA, value: ',', pos };
            case '=': return { type: TokenType.EQUALS, value: '=', pos };
            case ';': return { type: TokenType.SEMICOLON, value: ';', pos };
            default:
                throw new Error(`Unexpected character '${ch}' at position ${pos}`);
        }
    }
    /**
     * Tokenize the whole input. NEWLINE and SEMICOLON tokens are dropped; the
     * returned list always ends with exactly one EOF token.
     */
    tokenize() {
        const tokens = [];
        let token;
        do {
            token = this.nextToken();
            // Skip newlines and semicolons (treat as statement separators)
            if (token.type !== TokenType.NEWLINE && token.type !== TokenType.SEMICOLON) {
                tokens.push(token);
            }
        } while (token.type !== TokenType.EOF);
        return tokens;
    }
}
|
|
120
|
-
/**
 * Parser: converts the Lexer's token list into AST nodes.
 *
 * Grammar (precedence from lowest to highest; implementing method in brackets):
 *   program            → (assignment | expression)*                [parseProgram]
 *   assignment         → IDENTIFIER '=' expression
 *   expression         → term                                      [parseExpression]
 *   term               → factor (('+' | '-') factor)*              [parseTerm]
 *   factor             → unary (('*' | '/') unary)*                [parseFactor]
 *   unary              → ('+' | '-') unary | power                 [parseUnary]
 *   power              → postfix ('**' unary)?                     [parsePower]
 *   postfix            → primary ('.' member)*                     [parsePostfix]
 *   member             → 'x' | 'y' | 'z'
 *                      | 'magnitude' | 'sqrMagnitude' | 'normalized'
 *                      | IDENTIFIER '(' arg_list? ')'
 *   primary            → NUMBER | IDENTIFIER | function_call
 *                      | vector_constructor | '(' expression ')'   [parsePrimary]
 *   function_call      → IDENTIFIER '(' arg_list? ')'
 *   vector_constructor → ('Vec2' | 'Vec3') '(' arg_list? ')'
 *   arg_list           → expression (',' expression)*              [parseArgList]
 *
 * '**' is right-associative (2**3**2 = 2**(3**2)). A unary sign binds more
 * loosely than '**' and '.', as in Python and conventional math notation:
 * `-x**2` parses as -(x**2) and `-v.magnitude` as -(v.magnitude). An
 * exponent may itself carry a sign (`2**-x`).
 */
export class Parser {
    tokens;
    current = 0;
    constructor(text) {
        const lexer = new Lexer(text);
        this.tokens = lexer.tokenize();
    }
    /** Token at `current + offset`; clamps to the trailing EOF token. */
    peek(offset = 0) {
        const idx = this.current + offset;
        return idx < this.tokens.length ? this.tokens[idx] : this.tokens[this.tokens.length - 1];
    }
    /** Consume and return the current token (never advances past EOF). */
    advance() {
        const token = this.peek();
        if (token.type !== TokenType.EOF) {
            this.current++;
        }
        return token;
    }
    /**
     * Consume a token of the given type.
     * @throws Error with `message`, or a generic description, on mismatch.
     */
    expect(type, message) {
        const token = this.peek();
        if (token.type !== type) {
            throw new Error(message || `Expected token type ${TokenType[type]}, got ${TokenType[token.type]} at position ${token.pos}`);
        }
        return this.advance();
    }
    /**
     * Parse a complete program: a sequence of assignments and/or bare
     * expressions.
     *
     * A bare expression is stored under a generated name `_output<N>`, where
     * N is the number of statements before it. The program's output is the
     * variable named `output` if one is assigned, otherwise the last statement.
     */
    parseProgram() {
        const assignments = [];
        let outputVar = '';
        while (this.peek().type !== TokenType.EOF) {
            // Check if this is an assignment
            if (this.peek().type === TokenType.IDENTIFIER && this.peek(1).type === TokenType.EQUALS) {
                const varName = this.peek().value;
                this.advance(); // identifier
                this.advance(); // equals
                const expression = this.parseExpression();
                assignments.push({ variable: varName, expression });
                // Track the last assigned variable as potential output
                outputVar = varName;
            }
            else {
                // Standalone expression - treat as output
                const expression = this.parseExpression();
                // Generate a temporary output variable
                const tempVar = `_output${assignments.length}`;
                assignments.push({ variable: tempVar, expression });
                outputVar = tempVar;
            }
        }
        // An explicit "output" variable wins over the last statement.
        const outputAssignment = assignments.find(a => a.variable === 'output');
        if (outputAssignment) {
            outputVar = 'output';
        }
        return { assignments, output: outputVar };
    }
    /** Parse an expression (lowest-precedence entry point). */
    parseExpression() {
        return this.parseTerm();
    }
    /** Parse a left-associative chain of additions/subtractions. */
    parseTerm() {
        let left = this.parseFactor();
        while (this.peek().type === TokenType.PLUS || this.peek().type === TokenType.MINUS) {
            const op = this.advance().value;
            const right = this.parseFactor();
            left = new BinaryOpNode(op, left, right);
        }
        return left;
    }
    /** Parse a left-associative chain of multiplications/divisions. */
    parseFactor() {
        let left = this.parseUnary();
        while (this.peek().type === TokenType.MULTIPLY || this.peek().type === TokenType.DIVIDE) {
            const op = this.advance().value;
            const right = this.parseUnary();
            left = new BinaryOpNode(op, left, right);
        }
        return left;
    }
    /**
     * Parse a prefix '+' / '-'. This level sits below '**' so that the sign
     * applies to the whole power: -x**2 → -(x**2).
     */
    parseUnary() {
        const token = this.peek();
        if (token.type === TokenType.PLUS || token.type === TokenType.MINUS) {
            this.advance();
            return new UnaryOpNode(token.value, this.parseUnary());
        }
        return this.parsePower();
    }
    /**
     * Parse exponentiation. Right-associative: 2**3**2 = 2**(3**2) = 512.
     * The exponent is parsed at the unary level so `2**-x` is accepted.
     */
    parsePower() {
        const base = this.parsePostfix();
        if (this.peek().type === TokenType.POWER) {
            this.advance();
            const exponent = this.parseUnary();
            return new BinaryOpNode('**', base, exponent);
        }
        return base;
    }
    /** Parse member access (v.x), vector properties (v.magnitude) and method calls (v.dot(u)). */
    parsePostfix() {
        let node = this.parsePrimary();
        while (this.peek().type === TokenType.DOT) {
            this.advance(); // consume '.'
            const member = this.expect(TokenType.IDENTIFIER, 'Expected component name after "."');
            const memberName = member.value;
            if (memberName === 'x' || memberName === 'y' || memberName === 'z') {
                node = new VectorAccessNode(node, memberName);
            }
            else if (memberName === 'magnitude' || memberName === 'sqrMagnitude' || memberName === 'normalized') {
                // Property-style access lowered to a one-argument function call.
                node = new FunctionCallNode(memberName, [node]);
            }
            else {
                // Method call (e.g., v.dot(u)); the receiver becomes the first argument.
                this.expect(TokenType.LPAREN);
                const args = [node];
                args.push(...this.parseArgList());
                this.expect(TokenType.RPAREN);
                node = new FunctionCallNode(memberName, args);
            }
        }
        return node;
    }
    /**
     * Parse a primary expression: number, variable, function call, vector
     * constructor, or parenthesised expression. Prefix signs are handled by
     * parseUnary, not here.
     * @throws Error on any other token.
     */
    parsePrimary() {
        const token = this.peek();
        // Number literal
        if (token.type === TokenType.NUMBER) {
            this.advance();
            return new NumberNode(token.value);
        }
        // Parenthesized expression
        if (token.type === TokenType.LPAREN) {
            this.advance();
            const expr = this.parseExpression();
            this.expect(TokenType.RPAREN, 'Expected closing parenthesis');
            return expr;
        }
        // Identifier (variable, function call, or vector constructor)
        if (token.type === TokenType.IDENTIFIER) {
            const name = token.value;
            this.advance();
            if (this.peek().type === TokenType.LPAREN) {
                this.advance(); // consume '('
                const args = this.parseArgList();
                this.expect(TokenType.RPAREN, 'Expected closing parenthesis in function call');
                if (name === 'Vec2' || name === 'Vec3') {
                    return new VectorConstructorNode(name, args);
                }
                return new FunctionCallNode(name, args);
            }
            // Just a variable reference
            return new VariableNode(name);
        }
        throw new Error(`Unexpected token ${TokenType[token.type]} at position ${token.pos}`);
    }
    /** Parse a possibly empty, comma-separated argument list (stops before ')'). */
    parseArgList() {
        const args = [];
        if (this.peek().type === TokenType.RPAREN) {
            return args; // Empty arg list
        }
        args.push(this.parseExpression());
        while (this.peek().type === TokenType.COMMA) {
            this.advance(); // consume comma
            args.push(this.parseExpression());
        }
        return args;
    }
}
|
|
323
|
-
/**
 * Parse a mathematical expression string into an AST program.
 *
 * Convenience wrapper that builds a Parser for `text` and returns the
 * `{ assignments, output }` object produced by its parseProgram().
 */
export function parse(text) {
    return new Parser(text).parseProgram();
}
|
|
package/dist/symbolic/Simplify.d.ts
DELETED
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Expression simplification for symbolic gradients.
|
|
3
|
-
* Applies algebraic simplification rules to make formulas more readable.
|
|
4
|
-
* @internal
|
|
5
|
-
*/
|
|
6
|
-
import { ASTNode } from './AST';
|
|
7
|
-
/**
 * Simplify an AST node.
 *
 * Applies algebraic simplification rules (constant folding and identities
 * such as x + 0 = x) to make generated formulas more readable.
 *
 * @param node - Expression tree to simplify.
 * @returns The simplified expression tree.
 */
export declare function simplify(node: ASTNode): ASTNode;
|
|
package/dist/symbolic/Simplify.js
DELETED
|
@@ -1,244 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Expression simplification for symbolic gradients.
|
|
3
|
-
* Applies algebraic simplification rules to make formulas more readable.
|
|
4
|
-
* @internal
|
|
5
|
-
*/
|
|
6
|
-
import { NumberNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
|
|
7
|
-
/**
 * Simplification visitor.
 *
 * Rebuilds the tree bottom-up: children are simplified first, then constant
 * folding and local algebraic identities are applied to the rebuilt node.
 * One pass need not reach a fixed point; simplify() re-runs the visitor
 * until the tree stops changing.
 *
 * Some identities (0 * x = 0, 0 / x = 0, x / x = 1) assume the eliminated
 * sub-expression is finite and non-zero, as is usual for a
 * readability-oriented simplifier.
 */
class SimplificationVisitor {
    visitNumber(node) {
        return node;
    }
    visitVariable(node) {
        return node;
    }
    visitUnaryOp(node) {
        const operand = node.operand.accept(this);
        // -(-x) = x
        if (node.op === '-' && operand.type === 'UnaryOp') {
            const unary = operand;
            if (unary.op === '-') {
                return unary.operand;
            }
        }
        // -(c) folds to the negated literal for any number c
        if (node.op === '-' && operand.type === 'Number') {
            return new NumberNode(-operand.value);
        }
        // +x = x
        if (node.op === '+') {
            return operand;
        }
        return new UnaryOpNode(node.op, operand);
    }
    visitBinaryOp(node) {
        // Simplify children first so the rules below see folded operands.
        const left = node.left.accept(this);
        const right = node.right.accept(this);
        // Numeric values of literal operands (null when not a literal).
        const leftNum = left.type === 'Number' ? left.value : null;
        const rightNum = right.type === 'Number' ? right.value : null;
        // Constant folding
        if (leftNum !== null && rightNum !== null) {
            switch (node.op) {
                case '+': return new NumberNode(leftNum + rightNum);
                case '-': return new NumberNode(leftNum - rightNum);
                case '*': return new NumberNode(leftNum * rightNum);
                case '/': return new NumberNode(leftNum / rightNum);
                case '**': return new NumberNode(Math.pow(leftNum, rightNum));
            }
        }
        if (node.op === '+') {
            // 0 + x = x
            if (leftNum === 0)
                return right;
            // x + 0 = x
            if (rightNum === 0)
                return left;
            // x + x = 2*x
            if (nodesEqual(left, right)) {
                return new BinaryOpNode('*', new NumberNode(2), left);
            }
        }
        if (node.op === '-') {
            // x - 0 = x
            if (rightNum === 0)
                return left;
            // 0 - x = -x
            if (leftNum === 0)
                return new UnaryOpNode('-', right);
            // x - x = 0
            if (nodesEqual(left, right)) {
                return new NumberNode(0);
            }
        }
        if (node.op === '*') {
            // 0 * x = 0
            if (leftNum === 0 || rightNum === 0)
                return new NumberNode(0);
            // 1 * x = x
            if (leftNum === 1)
                return right;
            // x * 1 = x
            if (rightNum === 1)
                return left;
            // -1 * x = -x
            if (leftNum === -1)
                return new UnaryOpNode('-', right);
            // x * -1 = -x
            if (rightNum === -1)
                return new UnaryOpNode('-', left);
            // x * x = x^2
            if (nodesEqual(left, right)) {
                return new BinaryOpNode('**', left, new NumberNode(2));
            }
        }
        if (node.op === '/') {
            // 0 / x = 0
            if (leftNum === 0)
                return new NumberNode(0);
            // x / 1 = x
            if (rightNum === 1)
                return left;
            // x / x = 1
            if (nodesEqual(left, right))
                return new NumberNode(1);
        }
        if (node.op === '**') {
            // x^0 = 1
            if (rightNum === 0)
                return new NumberNode(1);
            // x^1 = x
            if (rightNum === 1)
                return left;
            // 1^x = 1
            if (leftNum === 1)
                return new NumberNode(1);
            // 0^x is deliberately NOT rewritten to 0. A numeric exponent was
            // already folded above, and for a symbolic exponent the value is
            // 1 at x = 0 and Infinity for x < 0, so the rewrite is unsound.
        }
        return new BinaryOpNode(node.op, left, right);
    }
    visitFunctionCall(node) {
        const args = node.args.map(arg => arg.accept(this));
        // Guarded so zero-argument calls such as `exp()` don't dereference args[0].
        const first = args.length > 0 ? args[0] : null;
        // Fold single-argument calls on a numeric literal.
        if (args.length === 1 && first.type === 'Number') {
            const value = first.value;
            switch (node.name) {
                case 'sin': return new NumberNode(Math.sin(value));
                case 'cos': return new NumberNode(Math.cos(value));
                case 'tan': return new NumberNode(Math.tan(value));
                case 'exp': return new NumberNode(Math.exp(value));
                case 'log':
                case 'ln': return new NumberNode(Math.log(value));
                case 'sqrt': return new NumberNode(Math.sqrt(value));
                case 'abs': return new NumberNode(Math.abs(value));
                case 'sign': return new NumberNode(Math.sign(value));
                case 'floor': return new NumberNode(Math.floor(value));
                case 'ceil': return new NumberNode(Math.ceil(value));
                case 'round': return new NumberNode(Math.round(value));
                case 'asin': return new NumberNode(Math.asin(value));
                case 'acos': return new NumberNode(Math.acos(value));
                case 'atan': return new NumberNode(Math.atan(value));
                case 'sinh': return new NumberNode(Math.sinh(value));
                case 'cosh': return new NumberNode(Math.cosh(value));
                case 'tanh': return new NumberNode(Math.tanh(value));
            }
        }
        // Fold two-argument calls on numeric literals.
        if (args.length === 2 && args[0].type === 'Number' && args[1].type === 'Number') {
            const val1 = args[0].value;
            const val2 = args[1].value;
            switch (node.name) {
                case 'pow': return new NumberNode(Math.pow(val1, val2));
                case 'min': return new NumberNode(Math.min(val1, val2));
                case 'max': return new NumberNode(Math.max(val1, val2));
            }
        }
        // Identities on a literal first argument. The single-argument case was
        // already folded above, so in practice these fire only when the call
        // carries additional arguments.
        if (first !== null && first.type === 'Number') {
            // exp(0) = 1
            if (node.name === 'exp' && first.value === 0) {
                return new NumberNode(1);
            }
            // log(1) = 0
            if ((node.name === 'log' || node.name === 'ln') && first.value === 1) {
                return new NumberNode(0);
            }
            // sqrt(1) = 1
            if (node.name === 'sqrt' && first.value === 1) {
                return new NumberNode(1);
            }
        }
        return new FunctionCallNode(node.name, args);
    }
    visitVectorAccess(node) {
        const vector = node.vector.accept(this);
        return new VectorAccessNode(vector, node.component);
    }
    visitVectorConstructor(node) {
        const components = node.components.map(c => c.accept(this));
        return new VectorConstructorNode(node.vectorType, components);
    }
}
|
|
190
|
-
/**
 * Check whether two AST nodes are structurally equal.
 *
 * Compares the node kind and, recursively, every field that distinguishes
 * nodes of that kind (including vector constructors, which were previously
 * always reported unequal). Numbers compare with `===`, so 0 equals -0 and
 * NaN never equals itself. Unrecognised node kinds compare as unequal.
 */
function nodesEqual(a, b) {
    if (a.type !== b.type)
        return false;
    if (a.type === 'Number' && b.type === 'Number') {
        return a.value === b.value;
    }
    if (a.type === 'Variable' && b.type === 'Variable') {
        return a.name === b.name;
    }
    if (a.type === 'UnaryOp' && b.type === 'UnaryOp') {
        return a.op === b.op && nodesEqual(a.operand, b.operand);
    }
    if (a.type === 'BinaryOp' && b.type === 'BinaryOp') {
        return a.op === b.op &&
            nodesEqual(a.left, b.left) &&
            nodesEqual(a.right, b.right);
    }
    if (a.type === 'FunctionCall' && b.type === 'FunctionCall') {
        return a.name === b.name &&
            a.args.length === b.args.length &&
            a.args.every((arg, i) => nodesEqual(arg, b.args[i]));
    }
    if (a.type === 'VectorAccess' && b.type === 'VectorAccess') {
        return a.component === b.component && nodesEqual(a.vector, b.vector);
    }
    if (a.type === 'VectorConstructor' && b.type === 'VectorConstructor') {
        return a.vectorType === b.vectorType &&
            a.components.length === b.components.length &&
            a.components.every((c, i) => nodesEqual(c, b.components[i]));
    }
    return false;
}
|
|
228
|
-
/**
 * Simplify an AST node.
 *
 * Runs the simplification visitor repeatedly, stopping as soon as a pass
 * leaves the tree structurally unchanged or after 10 passes, whichever comes
 * first.
 */
export function simplify(node) {
    const visitor = new SimplificationVisitor();
    const maxIterations = 10;
    let result = node;
    for (let pass = 0; pass < maxIterations; pass++) {
        const before = result;
        result = before.accept(visitor);
        if (nodesEqual(result, before)) {
            break;
        }
    }
    return result;
}
|
|
package/dist/symbolic/SymbolicDiff.d.ts
DELETED
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Symbolic differentiation engine.
|
|
3
|
-
* Applies differentiation rules to AST and generates symbolic gradient expressions.
|
|
4
|
-
* @internal
|
|
5
|
-
*/
|
|
6
|
-
import { ASTNode, ASTVisitor, NumberNode, VariableNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode, Program } from './AST';
|
|
7
|
-
/**
 * AST visitor that differentiates each node kind with respect to a single
 * variable, returning the derivative as an ASTNode.
 *
 * NOTE(review): the implementation is not visible here; how vector-valued
 * nodes are handled and whether results are simplified should be confirmed
 * against SymbolicDiff.js.
 */
export declare class DifferentiationVisitor implements ASTVisitor<ASTNode> {
    /** Name of the variable being differentiated with respect to. */
    private wrt;
    /** @param wrt - Name of the variable to differentiate with respect to. */
    constructor(wrt: string);
    visitNumber(node: NumberNode): ASTNode;
    visitVariable(node: VariableNode): ASTNode;
    visitUnaryOp(node: UnaryOpNode): ASTNode;
    visitBinaryOp(node: BinaryOpNode): ASTNode;
    visitFunctionCall(node: FunctionCallNode): ASTNode;
    visitVectorAccess(node: VectorAccessNode): ASTNode;
    visitVectorConstructor(node: VectorConstructorNode): ASTNode;
}
|
|
21
|
-
/**
 * Differentiate an AST node with respect to a variable.
 *
 * @param node - Expression tree to differentiate.
 * @param wrt - Name of the variable to differentiate with respect to.
 * @returns An AST for the derivative of `node` with respect to `wrt`.
 */
export declare function differentiate(node: ASTNode, wrt: string): ASTNode;
|
|
25
|
-
/**
 * A single gradient: the variable it is taken with respect to, paired with
 * the gradient expression.
 *
 * NOTE(review): computeGradients returns Map<string, ASTNode> rather than
 * GradientResult[]; confirm whether this interface is still used.
 */
export interface GradientResult {
    /** Name of the variable the gradient is taken with respect to (presumably a parameter name — confirm). */
    variable: string;
    /** The gradient expression as an AST. */
    gradient: ASTNode;
}
|
|
32
|
-
/**
 * Compute gradients of the program's output with respect to each parameter.
 *
 * @param program - Parsed program whose output is differentiated.
 * @param parameters - Names of the variables to differentiate with respect to.
 * @returns A map from variable name to gradient AST. NOTE(review): presumably
 *   keyed by the entries of `parameters` — confirm against SymbolicDiff.js.
 */
export declare function computeGradients(program: Program, parameters: string[]): Map<string, ASTNode>;
|