gradient-script 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +80 -3
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +332 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -1,339 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Symbolic differentiation engine.
|
|
3
|
-
* Applies differentiation rules to AST and generates symbolic gradient expressions.
|
|
4
|
-
* @internal
|
|
5
|
-
*/
|
|
6
|
-
import { NumberNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
|
|
7
|
-
/**
 * Derivative of `magnitude(arg)` or `sqrMagnitude(arg)` with respect to `wrt`.
 *
 * `wrt` is expected to be either a component of the vector variable `arg`
 * (`'v.x'`, `'v.y'`, `'v.z'`) or some unrelated variable name.
 *
 * The result is returned WITHOUT a chain-rule factor. `arg` is a bare vector
 * variable, and `differentiate(arg, 'v.x')` yields 0 (visitVariable only
 * matches exact names), so multiplying by it would wrongly zero the result.
 *
 * @param {string} fnName - 'magnitude' or 'sqrMagnitude'.
 * @param {object} arg - The single argument node of the call.
 * @param {string} wrt - Differentiation variable.
 * @returns {object} For a component: `v.c / magnitude(v)` (magnitude) or
 *   `2 * v.c` (sqrMagnitude). For an unrelated variable: `NumberNode(0)`.
 * @throws {Error} If `arg` is not a Variable, or if `wrt` names the whole
 *   vector or a component other than x/y/z.
 */
function differentiateVectorNorm(fnName, arg, wrt) {
    if (arg.type !== 'Variable') {
        throw new Error(`${fnName} differentiation requires vector variable`);
    }
    const prefix = `${arg.name}.`;
    const refersToVector = wrt === arg.name || (typeof wrt === 'string' && wrt.startsWith(prefix));
    if (!refersToVector) {
        // The norm of `arg` does not involve `wrt`, so it is constant w.r.t. it.
        return new NumberNode(0);
    }
    const component = wrt === arg.name ? null : wrt.slice(prefix.length);
    if (!['x', 'y', 'z'].includes(component)) {
        throw new Error(`${fnName} differentiation w.r.t. '${wrt}' is not supported; differentiate w.r.t. a component (x, y or z)`);
    }
    const access = new VectorAccessNode(arg, component);
    if (fnName === 'magnitude') {
        // d|v|/d(v.c) = v.c / |v|
        return new BinaryOpNode('/', access, new FunctionCallNode('magnitude', [arg]));
    }
    // d(|v|^2)/d(v.c) = 2 * v.c
    return new BinaryOpNode('*', new NumberNode(2), access);
}
/**
 * Visitor that builds the symbolic derivative of an AST node with respect to
 * a single variable.
 *
 * `wrt` is either a plain variable name (`'x'`) or a vector-component name of
 * the form `'<vector>.<component>'` (`'v.x'`). The latter is how
 * `visitVectorAccess` names components.
 *
 * Results are unsimplified AST nodes, built from the node classes imported
 * from './AST'. Unsupported operators and functions throw an `Error`.
 */
export class DifferentiationVisitor {
    wrt;
    constructor(wrt) {
        this.wrt = wrt;
    }
    visitNumber(node) {
        // d/dx(c) = 0
        return new NumberNode(0);
    }
    visitVariable(node) {
        // d/dx(x) = 1, d/dx(y) = 0 (exact name match only)
        return new NumberNode(node.name === this.wrt ? 1 : 0);
    }
    visitUnaryOp(node) {
        const inner = differentiate(node.operand, this.wrt);
        if (node.op === '-') {
            // d/dx(-f) = -df/dx
            return new UnaryOpNode('-', inner);
        }
        else {
            // d/dx(+f) = df/dx
            return inner;
        }
    }
    visitBinaryOp(node) {
        const u = node.left;
        const v = node.right;
        const du = differentiate(u, this.wrt);
        const dv = differentiate(v, this.wrt);
        switch (node.op) {
            case '+':
                // d/dx(u + v) = du/dx + dv/dx
                return new BinaryOpNode('+', du, dv);
            case '-':
                // d/dx(u - v) = du/dx - dv/dx
                return new BinaryOpNode('-', du, dv);
            case '*':
                // d/dx(u * v) = u * dv/dx + v * du/dx (product rule)
                return new BinaryOpNode('+', new BinaryOpNode('*', u, dv), new BinaryOpNode('*', v, du));
            case '/':
                // d/dx(u / v) = (v * du/dx - u * dv/dx) / v^2 (quotient rule)
                return new BinaryOpNode('/', new BinaryOpNode('-', new BinaryOpNode('*', v, du), new BinaryOpNode('*', u, dv)), new BinaryOpNode('**', v, new NumberNode(2)));
            case '**':
            case 'pow': {
                // Pick the cheapest rule based on which side depends on wrt.
                const vIsConstant = !dependsOn(v, this.wrt);
                const uIsConstant = !dependsOn(u, this.wrt);
                if (vIsConstant && uIsConstant) {
                    // Both constant
                    return new NumberNode(0);
                }
                else if (vIsConstant) {
                    // d/dx(u^c) = c * u^(c-1) * du/dx (power rule)
                    return new BinaryOpNode('*', new BinaryOpNode('*', v, new BinaryOpNode('**', u, new BinaryOpNode('-', v, new NumberNode(1)))), du);
                }
                else if (uIsConstant) {
                    // d/dx(c^v) = c^v * ln(c) * dv/dx
                    return new BinaryOpNode('*', new BinaryOpNode('*', new BinaryOpNode('**', u, v), new FunctionCallNode('log', [u])), dv);
                }
                else {
                    // d/dx(u^v) = u^v * (v' * ln(u) + v * u'/u) (general power rule)
                    return new BinaryOpNode('*', new BinaryOpNode('**', u, v), new BinaryOpNode('+', new BinaryOpNode('*', dv, new FunctionCallNode('log', [u])), new BinaryOpNode('*', v, new BinaryOpNode('/', du, u))));
                }
            }
            default:
                throw new Error(`Unknown binary operator: ${node.op}`);
        }
    }
    visitFunctionCall(node) {
        // Single-argument functions: f'(u) * du/dx (chain rule), except where noted.
        if (node.args.length === 1) {
            const arg = node.args[0];
            const darg = differentiate(arg, this.wrt);
            let derivative;
            switch (node.name) {
                case 'sin':
                    // d/dx(sin(u)) = cos(u) * du/dx
                    derivative = new FunctionCallNode('cos', [arg]);
                    break;
                case 'cos':
                    // d/dx(cos(u)) = -sin(u) * du/dx
                    derivative = new UnaryOpNode('-', new FunctionCallNode('sin', [arg]));
                    break;
                case 'tan':
                    // d/dx(tan(u)) = sec^2(u) * du/dx = 1/cos^2(u) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('**', new FunctionCallNode('cos', [arg]), new NumberNode(2)));
                    break;
                case 'exp':
                    // d/dx(exp(u)) = exp(u) * du/dx
                    derivative = new FunctionCallNode('exp', [arg]);
                    break;
                case 'log':
                case 'ln':
                    // d/dx(ln(u)) = 1/u * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), arg);
                    break;
                case 'sqrt':
                    // d/dx(sqrt(u)) = 1/(2*sqrt(u)) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('*', new NumberNode(2), new FunctionCallNode('sqrt', [arg])));
                    break;
                case 'abs':
                    // d/dx(|u|) = u/|u| * du/dx = sign(u) * du/dx
                    derivative = new FunctionCallNode('sign', [arg]);
                    break;
                case 'asin':
                    // d/dx(asin(u)) = 1/sqrt(1 - u^2) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new FunctionCallNode('sqrt', [
                        new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2)))
                    ]));
                    break;
                case 'acos':
                    // d/dx(acos(u)) = -1/sqrt(1 - u^2) * du/dx
                    derivative = new UnaryOpNode('-', new BinaryOpNode('/', new NumberNode(1), new FunctionCallNode('sqrt', [
                        new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2)))
                    ])));
                    break;
                case 'atan':
                    // d/dx(atan(u)) = 1/(1 + u^2) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('+', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2))));
                    break;
                case 'sinh':
                    // d/dx(sinh(u)) = cosh(u) * du/dx
                    derivative = new FunctionCallNode('cosh', [arg]);
                    break;
                case 'cosh':
                    // d/dx(cosh(u)) = sinh(u) * du/dx
                    derivative = new FunctionCallNode('sinh', [arg]);
                    break;
                case 'tanh':
                    // d/dx(tanh(u)) = (1 - tanh^2(u)) * du/dx
                    derivative = new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', new FunctionCallNode('tanh', [arg]), new NumberNode(2)));
                    break;
                case 'sigmoid':
                    // d/dx(sigmoid(u)) = sigmoid(u) * (1 - sigmoid(u)) * du/dx
                    derivative = new BinaryOpNode('*', new FunctionCallNode('sigmoid', [arg]), new BinaryOpNode('-', new NumberNode(1), new FunctionCallNode('sigmoid', [arg])));
                    break;
                case 'relu':
                    // d/dx(relu(u)) = (u > 0 ? 1 : 0) * du/dx, emitted as a 'heaviside' call
                    derivative = new FunctionCallNode('heaviside', [arg]);
                    break;
                case 'sign':
                    // d/dx(sign(u)) = 0 (almost everywhere)
                    return new NumberNode(0);
                case 'floor':
                case 'ceil':
                case 'round':
                    // d/dx(floor(u)) = 0 (almost everywhere)
                    return new NumberNode(0);
                case 'magnitude':
                case 'sqrMagnitude':
                    // Returned directly (no chain-rule factor); see differentiateVectorNorm.
                    return differentiateVectorNorm(node.name, arg, this.wrt);
                default:
                    throw new Error(`Unknown function: ${node.name}`);
            }
            // Apply chain rule: f'(u) * u'
            return new BinaryOpNode('*', derivative, darg);
        }
        // Two-argument functions
        if (node.args.length === 2) {
            const [arg1, arg2] = node.args;
            const darg1 = differentiate(arg1, this.wrt);
            const darg2 = differentiate(arg2, this.wrt);
            switch (node.name) {
                case 'pow':
                    // Same as ** operator
                    return differentiate(new BinaryOpNode('**', arg1, arg2), this.wrt);
                case 'min':
                case 'max':
                    // Derivative is discontinuous where u == v, so emit a placeholder call
                    // d_min/d_max(u, v, du, dv) for later stages to resolve.
                    return new FunctionCallNode(`d_${node.name}`, [arg1, arg2, darg1, darg2]);
                case 'atan2': {
                    // atan2(y, x):
                    //   d/dy = x/(x^2 + y^2), d/dx = -y/(x^2 + y^2)
                    const y = arg1;
                    const x = arg2;
                    const dy = darg1;
                    const dx = darg2;
                    const denom = new BinaryOpNode('+', new BinaryOpNode('**', x, new NumberNode(2)), new BinaryOpNode('**', y, new NumberNode(2)));
                    // Chain rule: (datan2/dy) * dy + (datan2/dx) * dx
                    const term1 = new BinaryOpNode('*', new BinaryOpNode('/', x, denom), dy);
                    const term2 = new BinaryOpNode('*', new UnaryOpNode('-', new BinaryOpNode('/', y, denom)), dx);
                    return new BinaryOpNode('+', term1, term2);
                }
                case 'dot': {
                    // d/dx(u . v) = du/dx . v + u . dv/dx
                    return new BinaryOpNode('+', new FunctionCallNode('dot', [darg1, arg2]), new FunctionCallNode('dot', [arg1, darg2]));
                }
                default:
                    throw new Error(`Unknown 2-arg function: ${node.name}`);
            }
        }
        throw new Error(`Unsupported function arity: ${node.name} with ${node.args.length} args`);
    }
    visitVectorAccess(node) {
        // Only component accesses on named vectors are supported: d(v.c)/d(wrt)
        // is 1 when wrt is exactly 'v.c', else 0.
        if (node.vector.type === 'Variable') {
            const varName = node.vector.name;
            const fullName = `${varName}.${node.component}`;
            if (fullName === this.wrt) {
                return new NumberNode(1);
            }
            else {
                return new NumberNode(0);
            }
        }
        // Computed vectors would require differentiating the computation itself.
        throw new Error('VectorAccess differentiation only supported for variable vectors');
    }
    visitVectorConstructor(node) {
        // d/dx(Vec2(u, v)) = Vec2(du/dx, dv/dx)
        const diffComponents = node.components.map(c => differentiate(c, this.wrt));
        return new VectorConstructorNode(node.vectorType, diffComponents);
    }
}
|
|
262
|
-
/**
 * Report whether an expression references a variable, directly or through
 * sub-expressions.
 *
 * `varName` may be a plain variable name (`'x'`) or a vector-component name
 * `'<vector>.<component>'` (`'v.x'`). This is the same naming that
 * DifferentiationVisitor uses for `wrt`. Matching rules for a named vector:
 * - the variable `v` depends on `'v'` and on each of its components (`'v.x'`);
 * - the access `v.x` depends on `'v.x'` and on the whole vector `'v'`;
 * - `v.y` does NOT depend on `'v.x'`.
 *
 * @param {object} node - AST node, dispatched on `node.type`.
 * @param {string} varName - Variable or component name.
 * @returns {boolean} True when `node` references `varName`. Unknown node types yield false.
 */
function dependsOn(node, varName) {
    switch (node.type) {
        case 'Number':
            return false;
        case 'Variable':
            // A vector variable depends on its own components ('v' on 'v.x').
            return node.name === varName
                || (typeof varName === 'string' && varName.startsWith(`${node.name}.`));
        case 'UnaryOp':
            return dependsOn(node.operand, varName);
        case 'BinaryOp':
            return dependsOn(node.left, varName) || dependsOn(node.right, varName);
        case 'FunctionCall':
            return node.args.some((arg) => dependsOn(arg, varName));
        case 'VectorAccess':
            if (node.vector.type === 'Variable') {
                // Component of a named vector: exact component match, or the whole vector.
                const vectorName = node.vector.name;
                return `${vectorName}.${node.component}` === varName || vectorName === varName;
            }
            return dependsOn(node.vector, varName);
        case 'VectorConstructor':
            return node.components.some((c) => dependsOn(c, varName));
        default:
            return false;
    }
}
|
|
290
|
-
/**
 * Build the symbolic derivative of `node` with respect to `wrt`.
 * A fresh DifferentiationVisitor is dispatched through the node's `accept` method.
 *
 * @param {object} node - AST node exposing `accept(visitor)`.
 * @param {string} wrt - Variable name (e.g. 'x') or vector component (e.g. 'v.x').
 * @returns {object} Whatever the visitor returns: an unsimplified derivative AST.
 */
export function differentiate(node, wrt) {
    return node.accept(new DifferentiationVisitor(wrt));
}
|
|
297
|
-
/**
 * Compute symbolic gradients of the program output w.r.t. each parameter,
 * using reverse-mode accumulation over the program's assignments.
 *
 * NOTE(review): assumes `program.assignments` is in definition (topological)
 * order, so each expression references only parameters and earlier
 * assignments — confirm against the parser.
 *
 * Walking the assignments in reverse, each variable's adjoint is pushed, via
 * the chain rule, into every parameter AND every other assigned variable its
 * expression depends on. This lets gradients flow through intermediates.
 *
 * @param {{assignments: {variable: string, expression: object}[], output: string}} program
 * @param {string[]} parameters - Names to differentiate with respect to.
 * @returns {Map<string, object>} `program.output` mapped to NumberNode(1),
 *   plus each parameter that receives gradient mapped to its unsimplified
 *   gradient expression. Parameters the output does not depend on are absent.
 *   Adjoints of intermediate variables are not reported.
 * @throws {Error} If `program.output` is not assigned.
 */
export function computeGradients(program, parameters) {
    // Variable name -> defining expression (a later assignment of the same name wins).
    const variableMap = new Map();
    for (const { variable, expression } of program.assignments) {
        variableMap.set(variable, expression);
    }
    if (!variableMap.get(program.output)) {
        throw new Error(`Output variable '${program.output}' not found`);
    }
    // Adjoints d(output)/d(name), seeded with d(output)/d(output) = 1.
    const adjoints = new Map([[program.output, new NumberNode(1)]]);
    const accumulate = (name, term) => {
        const existing = adjoints.get(name);
        adjoints.set(name, existing ? new BinaryOpNode('+', existing, term) : term);
    };
    // Gradient can flow into parameters and into any assigned (intermediate) variable.
    const targets = [...new Set([...parameters, ...variableMap.keys()])];
    // Reverse definition order; unique names so a variable's adjoint is pushed only once.
    const reversed = [...variableMap.keys()].reverse();
    for (const variable of reversed) {
        const grad = adjoints.get(variable);
        if (!grad) {
            continue; // No gradient flows to this variable
        }
        const expr = variableMap.get(variable);
        for (const target of targets) {
            if (target === variable || !dependsOn(expr, target)) {
                continue;
            }
            // Chain rule: d(out)/d(target) += d(out)/d(variable) * d(variable)/d(target)
            accumulate(target, new BinaryOpNode('*', grad, differentiate(expr, target)));
        }
    }
    // Report only the output seed and parameters, matching the original return shape.
    const reported = new Set([program.output, ...parameters]);
    return new Map([...adjoints].filter(([name]) => reported.has(name)));
}
|