gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Differentiation for GradientScript DSL
|
|
3
|
+
* Computes symbolic gradients for structured types
|
|
4
|
+
*/
|
|
5
|
+
import { Types } from './Types.js';
|
|
6
|
+
import { expandBuiltIn, shouldExpand } from './Expander.js';
|
|
7
|
+
import { inlineIntermediateVariables } from './Inliner.js';
|
|
8
|
+
import { containsVariable } from './ExpressionUtils.js';
|
|
9
|
+
import { DifferentiationError } from './Errors.js';
|
|
10
|
+
/**
 * Symbolic differentiation engine for the GradientScript DSL.
 *
 * Operates on expression AST nodes of kind `number`, `variable`, `binary`,
 * `unary`, `call` and `component`, and returns a new AST for the derivative.
 * No algebraic simplification is performed here. The differentiation
 * variable `wrt` is either a plain name (`"t"`) or a dotted component name
 * (`"u.x"`).
 */
export class Differentiator {
    env;
    /**
     * @param env type environment; consulted (via `getOrThrow`) only to tell
     *            scalar operands from structured ones when expanding
     *            component access such as `(s * u).x`
     */
    constructor(env) {
        this.env = env;
    }
    /**
     * Differentiate an expression with respect to a variable (component-level).
     *
     * @param expr expression AST node
     * @param wrt  variable name, or `object.component` for a struct component
     * @returns derivative AST (not simplified)
     * @throws {DifferentiationError} for node kinds, operators or functions
     *         that have no derivative rule
     */
    differentiate(expr, wrt) {
        switch (expr.kind) {
            case 'number':
                return this.diffNumber(expr, wrt);
            case 'variable':
                return this.diffVariable(expr, wrt);
            case 'binary':
                return this.diffBinary(expr, wrt);
            case 'unary':
                return this.diffUnary(expr, wrt);
            case 'call':
                return this.diffCall(expr, wrt);
            case 'component':
                return this.diffComponent(expr, wrt);
            default:
                // Fail loudly instead of letting `undefined` leak into the derivative tree.
                throw new DifferentiationError('Unsupported expression kind', String(expr.kind), 'No differentiation rule exists for this AST node');
        }
    }
    diffNumber(expr, wrt) {
        // d/dx(c) = 0
        return { kind: 'number', value: 0 };
    }
    diffVariable(expr, wrt) {
        // d/dx(x) = 1, d/dx(y) = 0
        if (expr.name === wrt) {
            return { kind: 'number', value: 1 };
        }
        return { kind: 'number', value: 0 };
    }
    diffBinary(expr, wrt) {
        const { operator, left, right } = expr;
        switch (operator) {
            case '+':
            case '-':
                // (f + g)' = f' + g', (f - g)' = f' - g'
                return {
                    kind: 'binary',
                    operator,
                    left: this.differentiate(left, wrt),
                    right: this.differentiate(right, wrt)
                };
            case '*':
                // Product rule: (f * g)' = f' * g + f * g'
                return {
                    kind: 'binary',
                    operator: '+',
                    left: {
                        kind: 'binary',
                        operator: '*',
                        left: this.differentiate(left, wrt),
                        right: right
                    },
                    right: {
                        kind: 'binary',
                        operator: '*',
                        left: left,
                        right: this.differentiate(right, wrt)
                    }
                };
            case '/':
                // Quotient rule: (f / g)' = (f' * g - f * g') / g^2
                return {
                    kind: 'binary',
                    operator: '/',
                    left: {
                        kind: 'binary',
                        operator: '-',
                        left: {
                            kind: 'binary',
                            operator: '*',
                            left: this.differentiate(left, wrt),
                            right: right
                        },
                        right: {
                            kind: 'binary',
                            operator: '*',
                            left: left,
                            right: this.differentiate(right, wrt)
                        }
                    },
                    right: {
                        kind: 'binary',
                        operator: '^',
                        left: right,
                        right: { kind: 'number', value: 2 }
                    }
                };
            case '^':
            case '**':
                return this.diffPower(left, right, wrt);
            default:
                throw new DifferentiationError('Unsupported binary operator', String(operator), 'No differentiation rule exists for this operator');
        }
    }
    /**
     * Derivative of base ^ exponent.
     *
     * - exponent constant w.r.t. wrt: (f^n)' = n * f^(n-1) * f'
     * - base constant w.r.t. wrt:     (c^g)' = c^g * ln(c) * g'
     * - otherwise:                    (f^g)' = f^g * (g' * ln(f) + g * f' / f)
     *
     * The last two forms use `log` (natural log, matching the `log` rule
     * below), so they are only real-valued where the base is positive.
     */
    diffPower(base, exponent, wrt) {
        if (this.isConstant(exponent, wrt)) {
            return {
                kind: 'binary',
                operator: '*',
                left: {
                    kind: 'binary',
                    operator: '*',
                    left: exponent,
                    right: {
                        kind: 'binary',
                        operator: '^',
                        left: base,
                        right: {
                            kind: 'binary',
                            operator: '-',
                            left: exponent,
                            right: { kind: 'number', value: 1 }
                        }
                    }
                },
                right: this.differentiate(base, wrt)
            };
        }
        const power = { kind: 'binary', operator: '^', left: base, right: exponent };
        const logBase = { kind: 'call', name: 'log', args: [base] };
        const exponentPrime = this.differentiate(exponent, wrt);
        if (this.isConstant(base, wrt)) {
            // Dedicated exponential form avoids the f'/f term, which would be 0/0 for a zero base.
            return {
                kind: 'binary',
                operator: '*',
                left: { kind: 'binary', operator: '*', left: power, right: logBase },
                right: exponentPrime
            };
        }
        return {
            kind: 'binary',
            operator: '*',
            left: power,
            right: {
                kind: 'binary',
                operator: '+',
                left: { kind: 'binary', operator: '*', left: exponentPrime, right: logBase },
                right: {
                    kind: 'binary',
                    operator: '/',
                    left: {
                        kind: 'binary',
                        operator: '*',
                        left: exponent,
                        right: this.differentiate(base, wrt)
                    },
                    right: base
                }
            }
        };
    }
    diffUnary(expr, wrt) {
        const { operator, operand } = expr;
        if (operator === '-') {
            // (-f)' = -f'
            return {
                kind: 'unary',
                operator: '-',
                operand: this.differentiate(operand, wrt)
            };
        }
        else {
            // (+f)' = f'
            return this.differentiate(operand, wrt);
        }
    }
    diffCall(expr, wrt) {
        const { name, args } = expr;
        // Expand built-in functions before differentiating
        if (shouldExpand(name)) {
            const expanded = expandBuiltIn(expr);
            return this.differentiate(expanded, wrt);
        }
        // Differentiate scalar math functions
        return this.diffMathFunction(name, args, wrt);
    }
    /**
     * Derivative of a call to a scalar math function.
     *
     * Multi-argument functions (atan2, pow, min, max, clamp) are validated for
     * arity and dispatched first; everything else must take exactly one argument.
     *
     * @throws {DifferentiationError} on a wrong argument count or an unknown function
     */
    diffMathFunction(name, args, wrt) {
        switch (name) {
            case 'atan2':
                this.requireArity(name, args, 2, 2);
                return this.diffAtan2(args[0], args[1], wrt);
            case 'pow':
                this.requireArity(name, args, 2, 2);
                // pow(f, g) denotes the same value as f ^ g
                return this.diffPower(args[0], args[1], wrt);
            case 'min':
            case 'max':
                this.requireArity(name, args, 2, Infinity);
                return this.diffMinMax(name, args, wrt);
            case 'clamp': {
                this.requireArity(name, args, 3, 3);
                const [x, lo, hi] = args;
                // clamp(x, lo, hi) = min(max(x, lo), hi): 0 outside [lo, hi], x' inside,
                // and lo' / hi' when pinned to a bound that itself depends on wrt.
                const lowerBounded = { kind: 'call', name: 'max', args: [x, lo] };
                return this.diffMinMax('min', [lowerBounded, hi], wrt);
            }
        }
        if (args.length !== 1) {
            throw new DifferentiationError('Function differentiation not yet supported', name, `Expected 1 argument or specific multi-arg function, got ${args.length} arguments`);
        }
        const arg = args[0];
        const argPrime = this.differentiate(arg, wrt);
        switch (name) {
            case 'sin':
                // sin(f)' = cos(f) * f'
                return {
                    kind: 'binary',
                    operator: '*',
                    left: { kind: 'call', name: 'cos', args: [arg] },
                    right: argPrime
                };
            case 'cos':
                // cos(f)' = -sin(f) * f'
                return {
                    kind: 'binary',
                    operator: '*',
                    left: {
                        kind: 'unary',
                        operator: '-',
                        operand: { kind: 'call', name: 'sin', args: [arg] }
                    },
                    right: argPrime
                };
            case 'tan':
                // tan(f)' = sec^2(f) * f' = (1 / cos^2(f)) * f'
                return {
                    kind: 'binary',
                    operator: '*',
                    left: {
                        kind: 'binary',
                        operator: '/',
                        left: { kind: 'number', value: 1 },
                        right: {
                            kind: 'binary',
                            operator: '^',
                            left: { kind: 'call', name: 'cos', args: [arg] },
                            right: { kind: 'number', value: 2 }
                        }
                    },
                    right: argPrime
                };
            case 'exp':
                // exp(f)' = exp(f) * f'
                return {
                    kind: 'binary',
                    operator: '*',
                    left: { kind: 'call', name: 'exp', args: [arg] },
                    right: argPrime
                };
            case 'log':
                // log(f)' = f' / f
                return {
                    kind: 'binary',
                    operator: '/',
                    left: argPrime,
                    right: arg
                };
            case 'sqrt':
                // sqrt(f)' = f' / (2 * sqrt(f))
                return {
                    kind: 'binary',
                    operator: '/',
                    left: argPrime,
                    right: {
                        kind: 'binary',
                        operator: '*',
                        left: { kind: 'number', value: 2 },
                        right: { kind: 'call', name: 'sqrt', args: [arg] }
                    }
                };
            case 'abs':
                // abs(f)' = f' * sign(f) = f' * f / abs(f)  (NaN at f = 0)
                return {
                    kind: 'binary',
                    operator: '*',
                    left: argPrime,
                    right: {
                        kind: 'binary',
                        operator: '/',
                        left: arg,
                        right: { kind: 'call', name: 'abs', args: [arg] }
                    }
                };
            default:
                throw new DifferentiationError('Function differentiation not implemented', name, 'This mathematical function does not have a derivative rule defined');
        }
    }
    /**
     * Throw unless `min <= args.length <= max`, so bad calls fail with a
     * DifferentiationError instead of crashing on a missing argument.
     */
    requireArity(name, args, min, max) {
        if (args.length < min || args.length > max) {
            const expected = min === max ? `${min}` : `at least ${min}`;
            throw new DifferentiationError('Wrong number of arguments', name, `Expected ${expected} arguments, got ${args.length}`);
        }
    }
    /**
     * atan2(f, g)' = (g * f' - f * g') / (f^2 + g^2), where y = f and x = g.
     */
    diffAtan2(y, x, wrt) {
        const yPrime = this.differentiate(y, wrt);
        const xPrime = this.differentiate(x, wrt);
        return {
            kind: 'binary',
            operator: '/',
            left: {
                kind: 'binary',
                operator: '-',
                left: {
                    kind: 'binary',
                    operator: '*',
                    left: x,
                    right: yPrime
                },
                right: {
                    kind: 'binary',
                    operator: '*',
                    left: y,
                    right: xPrime
                }
            },
            right: {
                kind: 'binary',
                operator: '+',
                left: {
                    kind: 'binary',
                    operator: '^',
                    left: x,
                    right: { kind: 'number', value: 2 }
                },
                right: {
                    kind: 'binary',
                    operator: '^',
                    left: y,
                    right: { kind: 'number', value: 2 }
                }
            }
        };
    }
    /**
     * Derivative of min(a, b, ...) or max(a, b, ...).
     *
     * Uses min/max(a, b) = (a + b)/2 -/+ |a - b|/2, whose derivative is
     *   (a' + b')/2 -/+ sign(a - b) * (a' - b')/2,
     * i.e. exactly a' where a is the active argument and b' where b is.
     * sign(d) is written d / abs(d), the same form the `abs` rule uses, so,
     * like that rule, the result is NaN exactly at a tie (a === b).
     * Calls with more than two arguments are folded left:
     * min(a, b, c) = min(min(a, b), c).
     */
    diffMinMax(name, args, wrt) {
        const a = args.length > 2 ? { kind: 'call', name, args: args.slice(0, -1) } : args[0];
        const b = args[args.length - 1];
        const aPrime = this.differentiate(a, wrt);
        const bPrime = this.differentiate(b, wrt);
        const gap = { kind: 'binary', operator: '-', left: a, right: b };
        const sign = {
            kind: 'binary',
            operator: '/',
            left: gap,
            right: { kind: 'call', name: 'abs', args: [gap] }
        };
        const mean = {
            kind: 'binary',
            operator: '/',
            left: { kind: 'binary', operator: '+', left: aPrime, right: bPrime },
            right: { kind: 'number', value: 2 }
        };
        const halfSpread = {
            kind: 'binary',
            operator: '/',
            left: {
                kind: 'binary',
                operator: '*',
                left: sign,
                right: { kind: 'binary', operator: '-', left: aPrime, right: bPrime }
            },
            right: { kind: 'number', value: 2 }
        };
        return {
            kind: 'binary',
            operator: name === 'min' ? '-' : '+',
            left: mean,
            right: halfSpread
        };
    }
    /**
     * Derivative of a component access such as `u.x`, `(u - v).x` or `(-u).x`.
     *
     * `v.c` is 1 when `"v.c"` equals wrt and 0 otherwise. Binary and unary
     * objects are pushed through the component access first, then differentiated.
     */
    diffComponent(expr, wrt) {
        const { object, component } = expr;
        if (object.kind === 'variable') {
            return { kind: 'number', value: `${object.name}.${component}` === wrt ? 1 : 0 };
        }
        if (object.kind === 'binary') {
            // (u-v).x -> u.x - v.x, then differentiate
            return this.differentiate(this.expandComponentAccess(expr), wrt);
        }
        if (object.kind === 'unary') {
            // (-u).x -> -(u.x), (+u).x -> +(u.x)
            return this.differentiate({
                kind: 'unary',
                operator: object.operator,
                operand: { kind: 'component', object: object.operand, component }
            }, wrt);
        }
        // NOTE(review): other object kinds (e.g. struct-valued calls) are treated
        // as constant; this is exact only when they do not depend on wrt.
        return { kind: 'number', value: 0 };
    }
    /**
     * Rewrite `(left op right).comp` as `left.comp op right.comp`.
     * Operands known to be scalar are left unprojected, so they broadcast:
     * `(2 * u).x` becomes `2 * u.x`, not `2.x * u.x`.
     */
    expandComponentAccess(expr) {
        if (expr.object.kind !== 'binary') {
            return expr;
        }
        const { operator, left, right } = expr.object;
        const project = (side) => this.isScalarOperand(side)
            ? side
            : { kind: 'component', object: side, component: expr.component };
        return {
            kind: 'binary',
            operator,
            left: project(left),
            right: project(right)
        };
    }
    /**
     * Conservatively decide whether `node` is scalar-valued. It returns false
     * when unsure, which keeps the component-wise expansion.
     */
    isScalarOperand(node) {
        switch (node.kind) {
            case 'number':
                return true;
            case 'variable':
                return this.isScalarVariable(node.name);
            case 'unary':
                return this.isScalarOperand(node.operand);
            case 'binary':
                return this.isScalarOperand(node.left) && this.isScalarOperand(node.right);
            default:
                return false;
        }
    }
    /** Look `name` up in the type environment; false if it cannot be resolved. */
    isScalarVariable(name) {
        if (!this.env || typeof this.env.getOrThrow !== 'function') {
            return false;
        }
        try {
            return Types.isScalar(this.env.getOrThrow(name));
        }
        catch {
            // Unknown name: keep the component-wise expansion.
            return false;
        }
    }
    /**
     * Check if expression is constant with respect to wrt
     */
    isConstant(expr, wrt) {
        return !containsVariable(expr, wrt);
    }
}
|
|
389
|
+
/**
 * Compute symbolic gradients of a function's return value.
 *
 * Intermediate variables are inlined first, so every derivative is taken on a
 * single expression. Only parameters flagged `requiresGrad` are processed.
 * A scalar parameter maps to one derivative expression. A structured parameter
 * maps to `{ components: Map<component, derivative> }`, with one entry per
 * component listed by its type and differentiated w.r.t. `"<param>.<component>"`.
 *
 * @param func function definition (parameters + body + return expression)
 * @param env  type environment; `getOrThrow(name)` supplies each parameter's type
 * @returns `{ gradients }`, a Map keyed by parameter name
 */
export function computeFunctionGradients(func, env) {
    const differentiator = new Differentiator(env);
    const expression = inlineIntermediateVariables(func);
    const gradients = new Map();
    for (const param of func.parameters) {
        if (param.requiresGrad) {
            const type = env.getOrThrow(param.name);
            let gradient;
            if (Types.isScalar(type)) {
                gradient = differentiator.differentiate(expression, param.name);
            }
            else {
                const components = new Map();
                for (const component of type.components) {
                    components.set(component, differentiator.differentiate(expression, `${param.name}.${component}`));
                }
                gradient = { components };
            }
            gradients.set(param.name, gradient);
        }
    }
    return { gradients };
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Analyzes functions for known discontinuities
|
|
3
|
+
*/
|
|
4
|
+
import { FunctionDef } from './AST.js';
|
|
5
|
+
import { DiscontinuityInfo } from './BuiltIns.js';
|
|
6
|
+
/**
 * A single discontinuity warning: one call to a built-in function that has
 * known discontinuities, together with where in the function it was found.
 */
export interface DiscontinuityWarning {
    /** Name of the called built-in function. */
    functionName: string;
    /** Where the call appears: `'return expression'` or `variable '<name>'`. */
    location: string;
    /** Discontinuities reported for that built-in by `builtIns.getDiscontinuities`. */
    discontinuities: DiscontinuityInfo[];
}
|
|
11
|
+
/**
 * Analyze a function for discontinuities.
 *
 * Scans the return expression and then every `assignment` statement of the
 * body, recursing through calls, binary/unary operators and component access.
 * It reports each call to a built-in that has known discontinuities.
 *
 * @param func parsed function definition to scan
 * @returns one warning per offending call. Warnings from the return expression
 *          come first, then those from body assignments in statement order.
 */
export declare function analyzeDiscontinuities(func: FunctionDef): DiscontinuityWarning[];
|
|
15
|
+
/**
 * Format discontinuity warnings for display.
 *
 * @param warnings warnings produced by `analyzeDiscontinuities`
 * @returns a multi-line, human-readable report (lines joined with `'\n'`), or
 *          the empty string when `warnings` is empty
 */
export declare function formatDiscontinuityWarnings(warnings: DiscontinuityWarning[]): string;
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Analyzes functions for known discontinuities
|
|
3
|
+
*/
|
|
4
|
+
import { builtIns } from './BuiltIns.js';
|
|
5
|
+
/**
 * Analyze a function for discontinuities.
 *
 * The return expression is scanned first, labelled `'return expression'`.
 * Each `assignment` statement of the body follows, labelled with its variable
 * name. Statements of any other kind are skipped.
 *
 * @param func function definition with `returnExpr` and `body`
 * @returns the accumulated list of warnings
 */
export function analyzeDiscontinuities(func) {
    const warnings = [];
    const scan = (expression, where) => collectDiscontinuities(expression, where, warnings);
    scan(func.returnExpr, 'return expression');
    for (const statement of func.body) {
        if (statement.kind !== 'assignment') {
            continue;
        }
        scan(statement.expression, `variable '${statement.variable}'`);
    }
    return warnings;
}
|
|
20
|
+
/**
 * Recursively collect discontinuities from an expression.
 *
 * Every call whose built-in reports at least one discontinuity produces a
 * warning. The search continues into call arguments, both sides of binary
 * operators, unary operands and component-access objects.
 *
 * @param expr     expression AST node to scan
 * @param location label copied into every warning produced from this expression
 * @param warnings output array; warnings are appended in visit order
 */
function collectDiscontinuities(expr, location, warnings) {
    switch (expr.kind) {
        case 'call': {
            // Braced so the `const` is scoped to this clause only.
            // `?? []` tolerates a lookup that yields nothing for an unknown name.
            const disconts = builtIns.getDiscontinuities(expr.name) ?? [];
            if (disconts.length > 0) {
                warnings.push({
                    functionName: expr.name,
                    location,
                    discontinuities: disconts
                });
            }
            // Recurse into arguments
            for (const arg of expr.args) {
                collectDiscontinuities(arg, location, warnings);
            }
            break;
        }
        case 'binary':
            collectDiscontinuities(expr.left, location, warnings);
            collectDiscontinuities(expr.right, location, warnings);
            break;
        case 'unary':
            collectDiscontinuities(expr.operand, location, warnings);
            break;
        case 'component':
            collectDiscontinuities(expr.object, location, warnings);
            break;
        default:
            // Other node kinds (e.g. 'number', 'variable') are not scanned further.
            break;
    }
}
|
|
51
|
+
/**
 * Format discontinuity warnings for display.
 *
 * Builds a report of three parts: a fixed header, one section per warning
 * (a bullet line, then a description/condition pair per discontinuity, then a
 * blank line), and a fixed closing note.
 *
 * @param warnings warnings produced by `analyzeDiscontinuities`
 * @returns the report joined with '\n', or '' when there is nothing to report
 */
export function formatDiscontinuityWarnings(warnings) {
    if (warnings.length === 0) {
        return '';
    }
    const header = [
        '⚠️ DISCONTINUITY WARNINGS:',
        '',
        'The following functions have known discontinuities that may affect',
        'numerical gradient checking:',
        '',
    ];
    const sections = warnings.flatMap((warning) => [
        ` • ${warning.functionName} (in ${warning.location})`,
        ...warning.discontinuities.flatMap((discont) => [
            ` - ${discont.description}`,
            ` Occurs when: ${discont.condition}`,
        ]),
        '',
    ]);
    const footer = [
        'Note: Symbolic gradients remain correct at these points,',
        'but numerical validation may show large errors due to discontinuities.',
    ];
    return [...header, ...sections, ...footer].join('\n');
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
/**
 * Error raised when DSL source text cannot be parsed.
 * Its message has the form `Parse error at <line>:<column>: <message>`.
 */
export declare class ParseError extends Error {
    /** Line of the error location, as supplied by the thrower. */
    line: number;
    /** Column of the error location, as supplied by the thrower. */
    column: number;
    /** Offending token text, if the thrower supplied one (not included in the message). */
    token?: string | undefined;
    /**
     * @param message description of the problem
     * @param line    line of the error location
     * @param column  column of the error location
     * @param token   offending token text (optional)
     */
    constructor(message: string, line: number, column: number, token?: string | undefined);
}
|
|
7
|
+
/**
 * Error raised when a DSL expression has an invalid type.
 * Its message has the form `Type error in '<expression>': <message>`. The suffix
 * ` (expected <expectedType>, got <actualType>)` is added only when both types
 * are given.
 * NOTE: this class is named `TypeError`, so importing it by name shadows the
 * built-in global `TypeError` inside the importing module.
 */
export declare class TypeError extends Error {
    /** Source text (or description) of the offending expression. */
    expression: string;
    /** Type the expression was required to have, if known. */
    expectedType?: string | undefined;
    /** Type the expression actually has, if known. */
    actualType?: string | undefined;
    /**
     * @param message      description of the problem
     * @param expression   offending expression text
     * @param expectedType required type (optional)
     * @param actualType   actual type (optional)
     */
    constructor(message: string, expression: string, expectedType?: string | undefined, actualType?: string | undefined);
}
|
|
13
|
+
/**
 * Error raised when an expression cannot be symbolically differentiated.
 * Its message has the form `Differentiation error for '<operation>': <message>`,
 * followed by ` - <reason>` when a reason is given.
 */
export declare class DifferentiationError extends Error {
    /** Operation, operator or function name that could not be differentiated. */
    operation: string;
    /** Optional explanation of why differentiation failed. */
    reason?: string | undefined;
    /**
     * @param message   description of the problem
     * @param operation the operation or function involved
     * @param reason    extra detail (optional)
     */
    constructor(message: string, operation: string, reason?: string | undefined);
}
|
|
18
|
+
/**
 * Error raised when code cannot be generated for an AST node.
 * Its message has the form `Code generation error for '<node>': <message>`,
 * followed by ` (format: <format>)` when a target format is given.
 */
export declare class CodeGenError extends Error {
    /** Description of the AST node that could not be emitted. */
    node: string;
    /** Target output format, if known. */
    format?: string | undefined;
    /**
     * @param message description of the problem
     * @param node    the offending node
     * @param format  target format (optional)
     */
    constructor(message: string, node: string, format?: string | undefined);
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
/**
 * Error for malformed DSL source.
 * Message: `Parse error at <line>:<column>: <message>`. The raw position and
 * the optional offending token are kept as fields.
 */
export class ParseError extends Error {
    line;
    column;
    token;
    constructor(message, line, column, token) {
        const where = `${line}:${column}`;
        super(`Parse error at ${where}: ${message}`);
        Object.assign(this, { line, column, token });
        this.name = 'ParseError';
    }
}
|
|
13
|
+
/**
 * Error for ill-typed DSL expressions.
 * Message: `Type error in '<expression>': <message>`. The
 * ` (expected X, got Y)` suffix is appended only when both types are truthy.
 * Note that the name shadows the global `TypeError` wherever this is imported.
 */
export class TypeError extends Error {
    expression;
    expectedType;
    actualType;
    constructor(message, expression, expectedType, actualType) {
        let detail = '';
        if (expectedType && actualType) {
            detail = ` (expected ${expectedType}, got ${actualType})`;
        }
        super(`Type error in '${expression}': ${message}${detail}`);
        Object.assign(this, { expression, expectedType, actualType });
        this.name = 'TypeError';
    }
}
|
|
28
|
+
/**
 * Error for expressions the differentiator cannot handle.
 * Message: `Differentiation error for '<operation>': <message>`, plus
 * ` - <reason>` when a truthy reason is supplied.
 */
export class DifferentiationError extends Error {
    operation;
    reason;
    constructor(message, operation, reason) {
        const parts = [`Differentiation error for '${operation}': ${message}`];
        if (reason) {
            parts.push(` - ${reason}`);
        }
        super(parts.join(''));
        Object.assign(this, { operation, reason });
        this.name = 'DifferentiationError';
    }
}
|
|
39
|
+
/**
 * Error for AST nodes that cannot be emitted as code.
 * Message: `Code generation error for '<node>': <message>`, plus
 * ` (format: <format>)` when a truthy format is supplied.
 */
export class CodeGenError extends Error {
    node;
    format;
    constructor(message, node, format) {
        const parts = [`Code generation error for '${node}': ${message}`];
        if (format) {
            parts.push(` (format: ${format})`);
        }
        super(parts.join(''));
        Object.assign(this, { node, format });
        this.name = 'CodeGenError';
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Expander for GradientScript DSL
|
|
3
|
+
* Expands built-in functions and struct operations into scalar operations
|
|
4
|
+
*/
|
|
5
|
+
import { Expression, FunctionCall } from './AST.js';
|
|
6
|
+
/**
 * Expand a built-in function call into an expression of simpler operations.
 *
 * The differentiator calls this only when `shouldExpand(call.name)` is true.
 * It then differentiates the returned expression in place of the original call.
 * NOTE(review): the set of built-ins handled, and their expansions, live in
 * Expander.js and are not visible from this declaration.
 *
 * @param call the function-call node to expand
 * @returns the expanded expression
 */
export declare function expandBuiltIn(call: FunctionCall): Expression;
|
|
10
|
+
/**
 * Check whether a call to the named function should be expanded with
 * `expandBuiltIn` rather than handled by a direct derivative rule.
 *
 * @param name function name from a call node
 * @returns true if the call should be expanded first
 */
export declare function shouldExpand(name: string): boolean;
|