gradient-script 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/README.md +3 -1
  2. package/dist/cli.js +219 -6
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +336 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -1,339 +0,0 @@
1
- /**
2
- * Symbolic differentiation engine.
3
- * Applies differentiation rules to AST and generates symbolic gradient expressions.
4
- * @internal
5
- */
6
- import { NumberNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode } from './AST';
7
/**
 * If `wrt` names a component of the vector variable `vectorName` (for example
 * "v.x" or "v.z" for vector "v"), return the component name; otherwise null.
 * Component names follow the `${vector}.${component}` convention used by
 * `DifferentiationVisitor.visitVectorAccess`.
 */
function componentOfWrt(vectorName, wrt) {
    const prefix = `${vectorName}.`;
    return typeof wrt === 'string' && wrt.startsWith(prefix) ? wrt.slice(prefix.length) : null;
}
/**
 * AST visitor that returns the symbolic derivative of each visited node with
 * respect to a single variable name (`wrt`). Vector components are addressed
 * as "name.component", e.g. "v.x".
 */
export class DifferentiationVisitor {
    wrt;
    constructor(wrt) {
        this.wrt = wrt;
    }
    visitNumber(node) {
        // d/dx(c) = 0
        return new NumberNode(0);
    }
    visitVariable(node) {
        // d/dx(x) = 1, d/dx(y) = 0
        return new NumberNode(node.name === this.wrt ? 1 : 0);
    }
    visitUnaryOp(node) {
        const inner = differentiate(node.operand, this.wrt);
        if (node.op === '-') {
            // d/dx(-f) = -df/dx
            return new UnaryOpNode('-', inner);
        }
        // d/dx(+f) = df/dx
        return inner;
    }
    visitBinaryOp(node) {
        const u = node.left;
        const v = node.right;
        // Both operand derivatives are computed eagerly, so every sub-visitor
        // must tolerate variables that its operand does not depend on.
        const du = differentiate(u, this.wrt);
        const dv = differentiate(v, this.wrt);
        switch (node.op) {
            case '+':
                // d/dx(u + v) = du/dx + dv/dx
                return new BinaryOpNode('+', du, dv);
            case '-':
                // d/dx(u - v) = du/dx - dv/dx
                return new BinaryOpNode('-', du, dv);
            case '*':
                // d/dx(u * v) = u * dv/dx + v * du/dx (product rule)
                return new BinaryOpNode('+', new BinaryOpNode('*', u, dv), new BinaryOpNode('*', v, du));
            case '/':
                // d/dx(u / v) = (v * du/dx - u * dv/dx) / v^2 (quotient rule)
                return new BinaryOpNode('/', new BinaryOpNode('-', new BinaryOpNode('*', v, du), new BinaryOpNode('*', u, dv)), new BinaryOpNode('**', v, new NumberNode(2)));
            case '**':
            case 'pow': {
                // Pick the simplest applicable rule depending on which side varies with wrt.
                const vIsConstant = !dependsOn(v, this.wrt);
                const uIsConstant = !dependsOn(u, this.wrt);
                if (vIsConstant && uIsConstant) {
                    // Both constant
                    return new NumberNode(0);
                }
                if (vIsConstant) {
                    // d/dx(u^c) = c * u^(c-1) * du/dx (power rule)
                    return new BinaryOpNode('*', new BinaryOpNode('*', v, new BinaryOpNode('**', u, new BinaryOpNode('-', v, new NumberNode(1)))), du);
                }
                if (uIsConstant) {
                    // d/dx(c^v) = c^v * ln(c) * dv/dx
                    return new BinaryOpNode('*', new BinaryOpNode('*', new BinaryOpNode('**', u, v), new FunctionCallNode('log', [u])), dv);
                }
                // d/dx(u^v) = u^v * (v' * ln(u) + v * u'/u) (general power rule)
                return new BinaryOpNode('*', new BinaryOpNode('**', u, v), new BinaryOpNode('+', new BinaryOpNode('*', dv, new FunctionCallNode('log', [u])), new BinaryOpNode('*', v, new BinaryOpNode('/', du, u))));
            }
            default:
                throw new Error(`Unknown binary operator: ${node.op}`);
        }
    }
    visitFunctionCall(node) {
        // For single-argument functions, apply chain rule
        if (node.args.length === 1) {
            const arg = node.args[0];
            const darg = differentiate(arg, this.wrt);
            let derivative;
            switch (node.name) {
                case 'sin':
                    // d/dx(sin(u)) = cos(u) * du/dx
                    derivative = new FunctionCallNode('cos', [arg]);
                    break;
                case 'cos':
                    // d/dx(cos(u)) = -sin(u) * du/dx
                    derivative = new UnaryOpNode('-', new FunctionCallNode('sin', [arg]));
                    break;
                case 'tan':
                    // d/dx(tan(u)) = sec^2(u) * du/dx = 1/cos^2(u) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('**', new FunctionCallNode('cos', [arg]), new NumberNode(2)));
                    break;
                case 'exp':
                    // d/dx(exp(u)) = exp(u) * du/dx
                    derivative = new FunctionCallNode('exp', [arg]);
                    break;
                case 'log':
                case 'ln':
                    // d/dx(ln(u)) = 1/u * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), arg);
                    break;
                case 'sqrt':
                    // d/dx(sqrt(u)) = 1/(2*sqrt(u)) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('*', new NumberNode(2), new FunctionCallNode('sqrt', [arg])));
                    break;
                case 'abs':
                    // d/dx(|u|) = u/|u| * du/dx = sign(u) * du/dx
                    derivative = new FunctionCallNode('sign', [arg]);
                    break;
                case 'asin':
                    // d/dx(asin(u)) = 1/sqrt(1 - u^2) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new FunctionCallNode('sqrt', [
                        new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2)))
                    ]));
                    break;
                case 'acos':
                    // d/dx(acos(u)) = -1/sqrt(1 - u^2) * du/dx
                    derivative = new UnaryOpNode('-', new BinaryOpNode('/', new NumberNode(1), new FunctionCallNode('sqrt', [
                        new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2)))
                    ])));
                    break;
                case 'atan':
                    // d/dx(atan(u)) = 1/(1 + u^2) * du/dx
                    derivative = new BinaryOpNode('/', new NumberNode(1), new BinaryOpNode('+', new NumberNode(1), new BinaryOpNode('**', arg, new NumberNode(2))));
                    break;
                case 'sinh':
                    // d/dx(sinh(u)) = cosh(u) * du/dx
                    derivative = new FunctionCallNode('cosh', [arg]);
                    break;
                case 'cosh':
                    // d/dx(cosh(u)) = sinh(u) * du/dx
                    derivative = new FunctionCallNode('sinh', [arg]);
                    break;
                case 'tanh':
                    // d/dx(tanh(u)) = (1 - tanh^2(u)) * du/dx
                    derivative = new BinaryOpNode('-', new NumberNode(1), new BinaryOpNode('**', new FunctionCallNode('tanh', [arg]), new NumberNode(2)));
                    break;
                case 'sigmoid':
                    // d/dx(sigmoid(u)) = sigmoid(u) * (1 - sigmoid(u)) * du/dx
                    derivative = new BinaryOpNode('*', new FunctionCallNode('sigmoid', [arg]), new BinaryOpNode('-', new NumberNode(1), new FunctionCallNode('sigmoid', [arg])));
                    break;
                case 'relu':
                    // d/dx(relu(u)) = (u > 0 ? 1 : 0) * du/dx
                    // For symbolic, we'll use a heaviside-like representation
                    derivative = new FunctionCallNode('heaviside', [arg]);
                    break;
                case 'sign':
                    // d/dx(sign(u)) = 0 (almost everywhere)
                    return new NumberNode(0);
                case 'floor':
                case 'ceil':
                case 'round':
                    // d/dx(floor(u)) = 0 (almost everywhere)
                    return new NumberNode(0);
                case 'magnitude': {
                    // d|v|/d(v.c) = v.c / |v| for any component c (x, y, z, ...).
                    if (arg.type !== 'Variable') {
                        throw new Error('magnitude differentiation requires vector variable');
                    }
                    const component = componentOfWrt(arg.name, this.wrt);
                    if (component !== null) {
                        return new BinaryOpNode('/', new VectorAccessNode(arg, component), new FunctionCallNode('magnitude', [arg]));
                    }
                    if (this.wrt === arg.name) {
                        throw new Error(`magnitude differentiation requires a component of '${arg.name}' (e.g. '${arg.name}.x'), not the whole vector`);
                    }
                    // |v| does not depend on unrelated variables.
                    return new NumberNode(0);
                }
                case 'sqrMagnitude': {
                    // d|v|^2/d(v.c) = 2 * v.c for any component c.
                    if (arg.type !== 'Variable') {
                        throw new Error('sqrMagnitude differentiation requires vector variable');
                    }
                    const component = componentOfWrt(arg.name, this.wrt);
                    if (component !== null) {
                        return new BinaryOpNode('*', new NumberNode(2), new VectorAccessNode(arg, component));
                    }
                    if (this.wrt === arg.name) {
                        throw new Error(`sqrMagnitude differentiation requires a component of '${arg.name}' (e.g. '${arg.name}.x'), not the whole vector`);
                    }
                    // |v|^2 does not depend on unrelated variables.
                    return new NumberNode(0);
                }
                default:
                    throw new Error(`Unknown function: ${node.name}`);
            }
            // Apply chain rule: f'(u) * u'
            return new BinaryOpNode('*', derivative, darg);
        }
        // Multi-argument functions
        if (node.args.length === 2) {
            const [arg1, arg2] = node.args;
            const darg1 = differentiate(arg1, this.wrt);
            const darg2 = differentiate(arg2, this.wrt);
            switch (node.name) {
                case 'pow':
                    // Same as ** operator
                    return differentiate(new BinaryOpNode('**', arg1, arg2), this.wrt);
                case 'min':
                case 'max':
                    // Derivative is discontinuous at boundary - for symbolic we'll note it
                    // d/dx(min(u,v)) = du/dx if u < v, dv/dx if v < u
                    // For now, return a placeholder
                    return new FunctionCallNode(`d_${node.name}`, [arg1, arg2, darg1, darg2]);
                case 'atan2': {
                    // atan2(y, x): angle from positive x-axis to point (x, y)
                    // ∂/∂y(atan2(y, x)) = x/(x² + y²)
                    // ∂/∂x(atan2(y, x)) = -y/(x² + y²)
                    const y = arg1;
                    const x = arg2;
                    const dy = darg1;
                    const dx = darg2;
                    // denominator: x² + y²
                    const denom = new BinaryOpNode('+', new BinaryOpNode('**', x, new NumberNode(2)), new BinaryOpNode('**', y, new NumberNode(2)));
                    // Chain rule: (∂atan2/∂y) * dy + (∂atan2/∂x) * dx
                    const term1 = new BinaryOpNode('*', new BinaryOpNode('/', x, denom), dy);
                    const term2 = new BinaryOpNode('*', new UnaryOpNode('-', new BinaryOpNode('/', y, denom)), dx);
                    return new BinaryOpNode('+', term1, term2);
                }
                case 'dot': {
                    // dot product: u.x * v.x + u.y * v.y
                    // d/dx(u.v) = du/dx . v + u . dv/dx
                    return new BinaryOpNode('+', new FunctionCallNode('dot', [darg1, arg2]), new FunctionCallNode('dot', [arg1, darg2]));
                }
                default:
                    throw new Error(`Unknown 2-arg function: ${node.name}`);
            }
        }
        throw new Error(`Unsupported function arity: ${node.name} with ${node.args.length} args`);
    }
    visitVectorAccess(node) {
        // d(v.c)/d(wrt) is 1 exactly when wrt names this component ("v.c"), else 0.
        if (node.vector.type === 'Variable') {
            const fullName = `${node.vector.name}.${node.component}`;
            return new NumberNode(fullName === this.wrt ? 1 : 0);
        }
        // For computed vectors, we'd need to differentiate the computation
        throw new Error('VectorAccess differentiation only supported for variable vectors');
    }
    visitVectorConstructor(node) {
        // d/dx(Vec2(u, v)) = Vec2(du/dx, dv/dx)
        const diffComponents = node.components.map(c => differentiate(c, this.wrt));
        return new VectorConstructorNode(node.vectorType, diffComponents);
    }
}
262
/**
 * Check whether an expression depends on a variable.
 *
 * `varName` may name a whole variable ("v") or a vector component ("v.x"),
 * matching the `${vector}.${component}` naming used when differentiating
 * vector accesses. A whole-vector reference `v` therefore depends on "v.x",
 * while the access `v.y` does not.
 *
 * @param node - AST node to inspect.
 * @param varName - variable or component name to look for.
 * @returns true if `varName` occurs (directly or as a component) in `node`.
 */
function dependsOn(node, varName) {
    switch (node.type) {
        case 'Number':
            return false;
        case 'Variable':
            // A vector variable depends on each of its own components.
            return node.name === varName || varName.startsWith(`${node.name}.`);
        case 'UnaryOp':
            return dependsOn(node.operand, varName);
        case 'BinaryOp':
            return dependsOn(node.left, varName) || dependsOn(node.right, varName);
        case 'FunctionCall':
            return node.args.some(arg => dependsOn(arg, varName));
        case 'VectorAccess': {
            const { vector, component } = node;
            if (vector.type === 'Variable') {
                // v.x depends on the whole vector "v" or on exactly "v.x", not on "v.y".
                return vector.name === varName || `${vector.name}.${component}` === varName;
            }
            return dependsOn(vector, varName);
        }
        case 'VectorConstructor':
            return node.components.some(c => dependsOn(c, varName));
        default:
            return false;
    }
}
290
/**
 * Symbolically differentiate `node` with respect to the variable named `wrt`.
 * Dispatches through the node's visitor `accept` method using a fresh
 * `DifferentiationVisitor` and returns whatever AST that visitor produces.
 */
export function differentiate(node, wrt) {
    return node.accept(new DifferentiationVisitor(wrt));
}
297
/**
 * Compute gradients of `program.output` with respect to each parameter using
 * reverse-mode accumulation over the program's assignments.
 *
 * Assignments are visited from last to first. Each visited variable that has
 * received an adjoint (d output / d variable) pushes it through the partial
 * derivatives of its defining expression. Contributions go to every parameter
 * AND every assigned variable that the expression references. As a result,
 * gradients flow through intermediate variables back to the parameters they
 * were computed from.
 *
 * Assumes a straight-line program in which each variable is assigned once and
 * only references earlier definitions. TODO(review): confirm against the parser.
 *
 * @param program - object with `assignments` (array of `{ variable, expression }`
 *   in definition order) and `output` (name of the output variable).
 * @param parameters - iterable of parameter names to differentiate with respect to.
 * @returns Map from name to gradient expression. It holds the output (`1`) plus
 *   every parameter that receives a gradient; parameters the output does not
 *   depend on are absent.
 * @throws Error if the output variable is not assigned in the program.
 */
export function computeGradients(program, parameters) {
    // Build a map of variable definitions
    const variableMap = new Map();
    for (const { variable, expression } of program.assignments) {
        variableMap.set(variable, expression);
    }
    const outputExpr = variableMap.get(program.output);
    if (!outputExpr) {
        throw new Error(`Output variable '${program.output}' not found`);
    }
    const paramList = [...parameters];
    // Symbols a gradient can be pushed to: parameters plus intermediate variables.
    const targets = [...new Set([...paramList, ...variableMap.keys()])];
    // adjoints.get(name) = d(output)/d(name), accumulated as an expression tree.
    const adjoints = new Map([[program.output, new NumberNode(1)]]);
    // Reverse definition order: every consumer of a variable is processed
    // before the variable itself, so its adjoint is complete when we reach it.
    const order = [...variableMap.keys()].reverse();
    for (const variable of order) {
        const adjoint = adjoints.get(variable);
        if (!adjoint)
            continue; // No gradient flows to this variable
        const expr = variableMap.get(variable);
        for (const target of targets) {
            if (target === variable || !dependsOn(expr, target))
                continue;
            // Chain rule: d(out)/d(target) += d(out)/d(variable) * ∂expr/∂target
            const contribution = new BinaryOpNode('*', adjoint, differentiate(expr, target));
            const existing = adjoints.get(target);
            adjoints.set(target, existing ? new BinaryOpNode('+', existing, contribution) : contribution);
        }
    }
    // Expose only the output and parameter gradients, as callers expect.
    const gradients = new Map();
    for (const name of [program.output, ...paramList]) {
        const grad = adjoints.get(name);
        if (grad && !gradients.has(name)) {
            gradients.set(name, grad);
        }
    }
    return gradients;
}