gradient-script 0.1.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +52 -9
- package/dist/cli.js +134 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CodeGen.d.ts +8 -3
- package/dist/dsl/CodeGen.js +583 -132
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +14 -0
- package/dist/dsl/ExpressionUtils.js +56 -0
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +3 -1
- package/dist/dsl/Guards.js +86 -43
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +11 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +183 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -194
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/dist/dsl/Errors.d.ts
CHANGED
|
@@ -2,8 +2,13 @@ export declare class ParseError extends Error {
|
|
|
2
2
|
line: number;
|
|
3
3
|
column: number;
|
|
4
4
|
token?: string | undefined;
|
|
5
|
-
|
|
5
|
+
sourceContext?: string | undefined;
|
|
6
|
+
constructor(message: string, line: number, column: number, token?: string | undefined, sourceContext?: string | undefined);
|
|
6
7
|
}
|
|
8
|
+
/**
|
|
9
|
+
* Format a user-friendly error message with source context
|
|
10
|
+
*/
|
|
11
|
+
export declare function formatParseError(error: ParseError, sourceCode: string, verbose?: boolean): string;
|
|
7
12
|
export declare class TypeError extends Error {
|
|
8
13
|
expression: string;
|
|
9
14
|
expectedType?: string | undefined;
|
package/dist/dsl/Errors.js
CHANGED
|
@@ -2,14 +2,83 @@ export class ParseError extends Error {
|
|
|
2
2
|
line;
|
|
3
3
|
column;
|
|
4
4
|
token;
|
|
5
|
-
|
|
5
|
+
sourceContext;
|
|
6
|
+
constructor(message, line, column, token, sourceContext) {
|
|
6
7
|
super(`Parse error at ${line}:${column}: ${message}`);
|
|
7
8
|
this.line = line;
|
|
8
9
|
this.column = column;
|
|
9
10
|
this.token = token;
|
|
11
|
+
this.sourceContext = sourceContext;
|
|
10
12
|
this.name = 'ParseError';
|
|
11
13
|
}
|
|
12
14
|
}
|
|
15
|
+
/**
 * Format a user-friendly error message with source context.
 *
 * The result is built from, in order:
 *   1. the error message with its "Parse error at L:C: " prefix removed,
 *   2. the offending source line followed by a caret under the reported column
 *      (omitted when that line does not exist or is empty),
 *   3. contextual guidance from formatErrorGuidance(),
 *   4. the stack trace, only when `verbose` is true and a stack is present.
 *
 * @param error - Parse error; `message`, `line`, `column`, `token` and `stack` are read.
 *                `line` and `column` are used as 1-based positions.
 * @param sourceCode - Full source text the error refers to (LF or CRLF line endings).
 * @param verbose - When true, append the stack trace.
 * @returns The formatted multi-line message.
 */
export function formatParseError(error, sourceCode, verbose = false) {
    // Split on LF and CRLF so a trailing '\r' never leaks into the echoed line.
    const lines = sourceCode.split(/\r?\n/);
    const errorLine = lines[error.line - 1];
    let output = `Error: ${error.message.replace(/^Parse error at \d+:\d+: /, '')}\n`;
    // Show the source line with the error
    if (errorLine) {
        output += `\n ${errorLine}\n`;
        // Add caret pointing to error position
        const caretPos = Math.max(0, error.column - 1);
        // Keep the tabs of the source prefix and blank out everything else, so the
        // caret stays under the right character however wide a tab is rendered.
        // padEnd covers columns that lie beyond the end of the line.
        const padding = errorLine
            .slice(0, caretPos)
            .replace(/[^\t]/g, ' ')
            .padEnd(caretPos, ' ');
        output += ` ${padding}^\n`;
    }
    // Add helpful tips based on the error
    output += formatErrorGuidance(error);
    // Only show stack trace in verbose mode
    if (verbose && error.stack) {
        output += '\n\nStack trace:\n' + error.stack;
    }
    return output;
}
/**
 * Provide contextual guidance based on error patterns.
 *
 * Rules are checked in this order and the first match wins:
 *   1. the offending token is ';',
 *   2. the lower-cased message contains "expected ':'",
 *   3. the lower-cased message contains "expected parameter name" or "unexpected".
 *
 * @param error - Parse error; `message` and `token` are read.
 * @returns Guidance text (starting and ending with a newline), or '' when no rule matches.
 */
function formatErrorGuidance(error) {
    const msg = error.message.toLowerCase();
    // Semicolon error
    if (error.token === ';') {
        return [
            '',
            'Semicolons are not part of gradient-script syntax.',
            'Each statement should be on its own line.',
            '',
            'Correct syntax:',
            'function example(x∇, y∇) {',
            'result = x + y',
            'return result',
            '}',
            '',
            '💡 Tip: gradient-script uses newline-delimited statements (like Python),',
            'not semicolons (like JavaScript/C#).',
            ''
        ].join('\n');
    }
    // Missing colon in type annotation
    if (msg.includes("expected ':'")) {
        const example = 'function distance(point∇: {x, y}) {';
        // Place the marker directly under the colon it is meant to highlight.
        const marker = `${' '.repeat(example.indexOf(':'))}^`;
        return [
            '',
            'Type annotations require a colon before the type.',
            '',
            'Correct syntax:',
            example,
            marker,
            '',
            '💡 Tip: Parameters marked with ∇ need type annotations to specify structure.',
            ''
        ].join('\n');
    }
    // Missing gradient marker suggestion
    if (msg.includes('expected parameter name') || msg.includes('unexpected')) {
        return [
            '',
            '💡 Tip: Make sure all parameters are properly formatted.',
            'Variables that need gradients must be marked with ∇.',
            '',
            'Example: function f(a∇: {x, y}, b) { ... }',
            ''
        ].join('\n');
    }
    return '';
}
|
|
13
82
|
export class TypeError extends Error {
|
|
14
83
|
expression;
|
|
15
84
|
expectedType;
|
package/dist/dsl/Expander.js
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
* Expander for GradientScript DSL
|
|
3
3
|
* Expands built-in functions and struct operations into scalar operations
|
|
4
4
|
*/
|
|
5
|
+
import { DifferentiationError } from './Errors.js';
|
|
5
6
|
/**
|
|
6
7
|
* Expand built-in function calls to scalar expressions
|
|
7
8
|
*/
|
|
@@ -15,13 +16,15 @@ export function expandBuiltIn(call) {
|
|
|
15
16
|
case 'magnitude2d':
|
|
16
17
|
return expandMagnitude2d(args[0]);
|
|
17
18
|
case 'normalize2d':
|
|
18
|
-
throw new
|
|
19
|
+
throw new DifferentiationError('normalize2d not yet supported', 'normalize2d', 'Vector normalization requires special handling for zero-length vectors. ' +
|
|
20
|
+
'Use magnitude2d() and division for now.');
|
|
19
21
|
case 'distance2d':
|
|
20
22
|
return expandDistance2d(args[0], args[1]);
|
|
21
23
|
case 'dot3d':
|
|
22
24
|
return expandDot3d(args[0], args[1]);
|
|
23
25
|
case 'cross3d':
|
|
24
|
-
throw new
|
|
26
|
+
throw new DifferentiationError('cross3d returns vector - not yet supported', 'cross3d', 'Cross product returns a 3D vector, which requires structured gradient support. ' +
|
|
27
|
+
'This feature is not yet implemented.');
|
|
25
28
|
case 'magnitude3d':
|
|
26
29
|
return expandMagnitude3d(args[0]);
|
|
27
30
|
default:
|
|
@@ -53,3 +53,17 @@ export declare function containsVariable(expr: Expression, varName: string): boo
|
|
|
53
53
|
* Calculate the maximum nesting depth of an expression
|
|
54
54
|
*/
|
|
55
55
|
export declare function expressionDepth(expr: Expression): number;
|
|
56
|
+
/**
|
|
57
|
+
* Serializes an expression to structural string representation.
|
|
58
|
+
* Used for exact expression comparison - operand order matters.
|
|
59
|
+
*
|
|
60
|
+
* This ensures consistent string representation of expressions across different
|
|
61
|
+
* parts of the codebase.
|
|
62
|
+
*/
|
|
63
|
+
export declare function serializeExpression(expr: Expression): string;
|
|
64
|
+
/**
|
|
65
|
+
* Serializes an expression to canonical form for CSE matching.
|
|
66
|
+
* Commutative operations (+ and *) have operands sorted lexicographically,
|
|
67
|
+
* so a*b and b*a produce the same canonical string.
|
|
68
|
+
*/
|
|
69
|
+
export declare function serializeCanonical(expr: Expression): string;
|
|
@@ -173,3 +173,59 @@ export function expressionDepth(expr) {
|
|
|
173
173
|
return 1 + expressionDepth(expr.object);
|
|
174
174
|
}
|
|
175
175
|
}
|
|
176
|
+
/**
 * Build an order-preserving structural key for an expression tree.
 *
 * Two expressions yield the same string only when they have the same shape,
 * the same operators/names and the same operand order (`a*b` and `b*a` differ).
 *
 * Node kinds handled: number, variable, binary, unary, call, component.
 * Any other kind yields `undefined`.
 */
export function serializeExpression(expr) {
    const kind = expr.kind;
    if (kind === 'number') {
        return `num(${expr.value})`;
    }
    if (kind === 'variable') {
        return `var(${expr.name})`;
    }
    if (kind === 'binary') {
        const lhs = serializeExpression(expr.left);
        const rhs = serializeExpression(expr.right);
        return `bin(${expr.operator},${lhs},${rhs})`;
    }
    if (kind === 'unary') {
        const inner = serializeExpression(expr.operand);
        return `un(${expr.operator},${inner})`;
    }
    if (kind === 'call') {
        // Arguments keep their positional order.
        const parts = expr.args.map(arg => serializeExpression(arg));
        return `call(${expr.name},${parts.join(',')})`;
    }
    if (kind === 'component') {
        const owner = serializeExpression(expr.object);
        return `comp(${owner},${expr.component})`;
    }
    return undefined;
}
|
|
200
|
+
/**
|
|
201
|
+
* Serializes an expression to canonical form for CSE matching.
|
|
202
|
+
* Commutative operations (+ and *) have operands sorted lexicographically,
|
|
203
|
+
* so a*b and b*a produce the same canonical string.
|
|
204
|
+
*/
|
|
205
|
+
export function serializeCanonical(expr) {
|
|
206
|
+
switch (expr.kind) {
|
|
207
|
+
case 'number':
|
|
208
|
+
return `num(${expr.value})`;
|
|
209
|
+
case 'variable':
|
|
210
|
+
return `var(${expr.name})`;
|
|
211
|
+
case 'binary': {
|
|
212
|
+
const leftStr = serializeCanonical(expr.left);
|
|
213
|
+
const rightStr = serializeCanonical(expr.right);
|
|
214
|
+
// For commutative operations, sort operands lexicographically
|
|
215
|
+
if (expr.operator === '+' || expr.operator === '*') {
|
|
216
|
+
const [first, second] = leftStr <= rightStr ? [leftStr, rightStr] : [rightStr, leftStr];
|
|
217
|
+
return `bin(${expr.operator},${first},${second})`;
|
|
218
|
+
}
|
|
219
|
+
// Non-commutative: preserve order
|
|
220
|
+
return `bin(${expr.operator},${leftStr},${rightStr})`;
|
|
221
|
+
}
|
|
222
|
+
case 'unary':
|
|
223
|
+
return `un(${expr.operator},${serializeCanonical(expr.operand)})`;
|
|
224
|
+
case 'call': {
|
|
225
|
+
const args = expr.args.map(arg => serializeCanonical(arg)).join(',');
|
|
226
|
+
return `call(${expr.name},${args})`;
|
|
227
|
+
}
|
|
228
|
+
case 'component':
|
|
229
|
+
return `comp(${serializeCanonical(expr.object)},${expr.component})`;
|
|
230
|
+
}
|
|
231
|
+
}
|
|
@@ -17,9 +17,25 @@ type NumValue = number | {
|
|
|
17
17
|
export interface GradCheckResult {
|
|
18
18
|
passed: boolean;
|
|
19
19
|
errors: GradCheckError[];
|
|
20
|
+
singularities: GradCheckSingularity[];
|
|
20
21
|
maxError: number;
|
|
21
22
|
meanError: number;
|
|
23
|
+
totalChecks: number;
|
|
22
24
|
}
|
|
25
|
+
/**
|
|
26
|
+
* Singularity detected during gradient checking
|
|
27
|
+
* When both analytical and numerical produce NaN/Inf, it's a singularity, not a bug
|
|
28
|
+
*/
|
|
29
|
+
export interface GradCheckSingularity {
|
|
30
|
+
parameter: string;
|
|
31
|
+
component?: string;
|
|
32
|
+
analytical: number;
|
|
33
|
+
numerical: number;
|
|
34
|
+
}
|
|
35
|
+
/**
|
|
36
|
+
* Format gradient check results as a human-readable string
|
|
37
|
+
*/
|
|
38
|
+
export declare function formatGradCheckResult(result: GradCheckResult, funcName: string): string;
|
|
23
39
|
export interface GradCheckError {
|
|
24
40
|
parameter: string;
|
|
25
41
|
component?: string;
|
|
@@ -39,6 +55,11 @@ export declare class GradientChecker {
|
|
|
39
55
|
* Check gradients for a function
|
|
40
56
|
*/
|
|
41
57
|
check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
|
|
58
|
+
/**
|
|
59
|
+
* Compare analytical and numerical gradients
|
|
60
|
+
* Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
|
|
61
|
+
*/
|
|
62
|
+
private compareGradients;
|
|
42
63
|
/**
|
|
43
64
|
* Compute numerical gradient for scalar parameter using finite differences
|
|
44
65
|
*/
|
|
@@ -4,6 +4,47 @@
|
|
|
4
4
|
*/
|
|
5
5
|
import { Types } from './Types.js';
|
|
6
6
|
import { expandBuiltIn, shouldExpand } from './Expander.js';
|
|
7
|
+
/**
 * Render a gradient-check result as human-readable text.
 *
 * On success: one "verified" summary line (checks that were neither errors nor
 * singularities, plus the max error), followed by a singularity note if any
 * were recorded.
 * On failure: a "FAILED" header, one line per parameter that has errors
 * (a single error without a component is printed with its analytical and
 * numerical values; otherwise the per-component errors are listed in braces),
 * and a trailing singularity note if any were recorded.
 *
 * @param result - Outcome to render; reads `passed`, `errors`, `singularities`,
 *                 `totalChecks` and, on success, `maxError`.
 * @param funcName - Function name shown in the summary line.
 * @returns Lines joined with '\n'.
 */
export function formatGradCheckResult(result, funcName) {
    const { errors, singularities, totalChecks } = result;
    const nSingular = singularities.length;
    if (result.passed) {
        const verified = totalChecks - errors.length - nSingular;
        const summary = `✓ ${funcName}: ${verified} gradients verified (max error: ${result.maxError.toExponential(2)})`;
        return nSingular > 0
            ? `${summary}\n ⚠ ${nSingular} singularities detected (both analytical and numerical produce NaN/Inf)`
            : summary;
    }
    // Bucket the failures per parameter, keeping first-seen order.
    const grouped = new Map();
    for (const failure of errors) {
        const bucket = grouped.get(failure.parameter);
        if (bucket) {
            bucket.push(failure);
        }
        else {
            grouped.set(failure.parameter, [failure]);
        }
    }
    const report = [`✗ ${funcName}: ${errors.length}/${totalChecks} gradients FAILED`];
    grouped.forEach((failures, param) => {
        const only = failures[0];
        if (failures.length === 1 && !only.component) {
            // Scalar parameter: show both values and the absolute error.
            report.push(` ${param}: analytical=${only.analytical.toFixed(6)}, numerical=${only.numerical.toFixed(6)}, error=${only.error.toExponential(2)}`);
            return;
        }
        // Structured parameter: compact per-component error list.
        const perComponent = failures
            .map(f => `${f.component}:${f.error.toExponential(1)}`)
            .join(', ');
        report.push(` ${param}: {${perComponent}}`);
    });
    if (nSingular > 0) {
        report.push(` ⚠ ${nSingular} singularities also detected`);
    }
    return report.join('\n');
}
|
|
7
48
|
/**
|
|
8
49
|
* Gradient checker
|
|
9
50
|
*/
|
|
@@ -19,6 +60,9 @@ export class GradientChecker {
|
|
|
19
60
|
*/
|
|
20
61
|
check(func, gradients, env, testPoint) {
|
|
21
62
|
const errors = [];
|
|
63
|
+
const singularities = [];
|
|
64
|
+
let totalChecks = 0;
|
|
65
|
+
let maxError = 0;
|
|
22
66
|
// For each parameter that has gradients
|
|
23
67
|
for (const [paramName, gradient] of gradients.gradients.entries()) {
|
|
24
68
|
const paramType = env.getOrThrow(paramName);
|
|
@@ -28,21 +72,21 @@ export class GradientChecker {
|
|
|
28
72
|
}
|
|
29
73
|
if (Types.isScalar(paramType)) {
|
|
30
74
|
// Scalar parameter
|
|
75
|
+
totalChecks++;
|
|
31
76
|
if (typeof paramValue !== 'number') {
|
|
32
77
|
throw new Error(`Expected scalar value for ${paramName}`);
|
|
33
78
|
}
|
|
34
79
|
const analytical = this.evaluateExpression(gradient, testPoint);
|
|
35
80
|
const numerical = this.numericalGradientScalar(func, testPoint, paramName);
|
|
36
|
-
const
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
});
|
|
81
|
+
const checkResult = this.compareGradients(analytical, numerical, paramName);
|
|
82
|
+
if (checkResult.type === 'singularity') {
|
|
83
|
+
singularities.push(checkResult.singularity);
|
|
84
|
+
}
|
|
85
|
+
else if (checkResult.type === 'error') {
|
|
86
|
+
errors.push(checkResult.error);
|
|
87
|
+
}
|
|
88
|
+
else if (checkResult.error) {
|
|
89
|
+
maxError = Math.max(maxError, checkResult.error.error);
|
|
46
90
|
}
|
|
47
91
|
}
|
|
48
92
|
else {
|
|
@@ -52,32 +96,74 @@ export class GradientChecker {
|
|
|
52
96
|
}
|
|
53
97
|
const structGrad = gradient;
|
|
54
98
|
for (const [comp, expr] of structGrad.components.entries()) {
|
|
99
|
+
totalChecks++;
|
|
55
100
|
const analytical = this.evaluateExpression(expr, testPoint);
|
|
56
101
|
const numerical = this.numericalGradientComponent(func, testPoint, paramName, comp);
|
|
57
|
-
const
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
relativeError
|
|
67
|
-
});
|
|
102
|
+
const checkResult = this.compareGradients(analytical, numerical, paramName, comp);
|
|
103
|
+
if (checkResult.type === 'singularity') {
|
|
104
|
+
singularities.push(checkResult.singularity);
|
|
105
|
+
}
|
|
106
|
+
else if (checkResult.type === 'error') {
|
|
107
|
+
errors.push(checkResult.error);
|
|
108
|
+
}
|
|
109
|
+
else if (checkResult.error) {
|
|
110
|
+
maxError = Math.max(maxError, checkResult.error.error);
|
|
68
111
|
}
|
|
69
112
|
}
|
|
70
113
|
}
|
|
71
114
|
}
|
|
72
|
-
const maxError = errors.length > 0 ? Math.max(...errors.map(e => e.error)) : 0;
|
|
73
115
|
const meanError = errors.length > 0
|
|
74
116
|
? errors.reduce((sum, e) => sum + e.error, 0) / errors.length
|
|
75
117
|
: 0;
|
|
76
118
|
return {
|
|
77
119
|
passed: errors.length === 0,
|
|
78
120
|
errors,
|
|
121
|
+
singularities,
|
|
79
122
|
maxError,
|
|
80
|
-
meanError
|
|
123
|
+
meanError,
|
|
124
|
+
totalChecks
|
|
125
|
+
};
|
|
126
|
+
}
|
|
127
|
+
/**
|
|
128
|
+
* Compare analytical and numerical gradients
|
|
129
|
+
* Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
|
|
130
|
+
*/
|
|
131
|
+
compareGradients(analytical, numerical, parameter, component) {
|
|
132
|
+
const analyticalBad = !isFinite(analytical);
|
|
133
|
+
const numericalBad = !isFinite(numerical);
|
|
134
|
+
// Both produce NaN/Inf: singularity (not a bug)
|
|
135
|
+
if (analyticalBad && numericalBad) {
|
|
136
|
+
return {
|
|
137
|
+
type: 'singularity',
|
|
138
|
+
singularity: { parameter, component, analytical, numerical }
|
|
139
|
+
};
|
|
140
|
+
}
|
|
141
|
+
// One produces NaN/Inf, the other doesn't: actual bug
|
|
142
|
+
if (analyticalBad || numericalBad) {
|
|
143
|
+
return {
|
|
144
|
+
type: 'error',
|
|
145
|
+
error: {
|
|
146
|
+
parameter,
|
|
147
|
+
component,
|
|
148
|
+
analytical,
|
|
149
|
+
numerical,
|
|
150
|
+
error: Infinity,
|
|
151
|
+
relativeError: Infinity
|
|
152
|
+
}
|
|
153
|
+
};
|
|
154
|
+
}
|
|
155
|
+
// Both finite: compare values
|
|
156
|
+
const error = Math.abs(analytical - numerical);
|
|
157
|
+
const relativeError = Math.abs(error / (Math.abs(numerical) + 1e-10));
|
|
158
|
+
if (error > this.tolerance && relativeError > this.tolerance) {
|
|
159
|
+
return {
|
|
160
|
+
type: 'error',
|
|
161
|
+
error: { parameter, component, analytical, numerical, error, relativeError }
|
|
162
|
+
};
|
|
163
|
+
}
|
|
164
|
+
return {
|
|
165
|
+
type: 'pass',
|
|
166
|
+
error: { parameter, component, analytical, numerical, error, relativeError }
|
|
81
167
|
};
|
|
82
168
|
}
|
|
83
169
|
/**
|
package/dist/dsl/Guards.d.ts
CHANGED
|
@@ -8,6 +8,8 @@ export interface Guard {
|
|
|
8
8
|
expression: Expression;
|
|
9
9
|
description: string;
|
|
10
10
|
suggestion: string;
|
|
11
|
+
variableName?: string;
|
|
12
|
+
line?: number;
|
|
11
13
|
}
|
|
12
14
|
export interface GuardAnalysisResult {
|
|
13
15
|
guards: Guard[];
|
|
@@ -20,7 +22,7 @@ export declare function analyzeGuards(func: FunctionDef): GuardAnalysisResult;
|
|
|
20
22
|
/**
|
|
21
23
|
* Format guard analysis for display
|
|
22
24
|
*/
|
|
23
|
-
export declare function formatGuardWarnings(result: GuardAnalysisResult): string;
|
|
25
|
+
export declare function formatGuardWarnings(result: GuardAnalysisResult, asComments?: boolean): string;
|
|
24
26
|
/**
|
|
25
27
|
* Generate guard code snippets for common cases
|
|
26
28
|
*/
|
package/dist/dsl/Guards.js
CHANGED
|
@@ -8,11 +8,11 @@
|
|
|
8
8
|
export function analyzeGuards(func) {
|
|
9
9
|
const guards = [];
|
|
10
10
|
// Analyze return expression
|
|
11
|
-
collectGuards(func.returnExpr, guards);
|
|
11
|
+
collectGuards(func.returnExpr, guards, undefined);
|
|
12
12
|
// Analyze intermediate expressions
|
|
13
13
|
for (const stmt of func.body) {
|
|
14
14
|
if (stmt.kind === 'assignment') {
|
|
15
|
-
collectGuards(stmt.expression, guards);
|
|
15
|
+
collectGuards(stmt.expression, guards, stmt.variable, stmt.loc?.line);
|
|
16
16
|
}
|
|
17
17
|
}
|
|
18
18
|
return {
|
|
@@ -23,7 +23,7 @@ export function analyzeGuards(func) {
|
|
|
23
23
|
/**
|
|
24
24
|
* Collect potential guards from an expression
|
|
25
25
|
*/
|
|
26
|
-
function collectGuards(expr, guards) {
|
|
26
|
+
function collectGuards(expr, guards, variableName, line) {
|
|
27
27
|
switch (expr.kind) {
|
|
28
28
|
case 'binary':
|
|
29
29
|
if (expr.operator === '/') {
|
|
@@ -31,47 +31,60 @@ function collectGuards(expr, guards) {
|
|
|
31
31
|
type: 'division_by_zero',
|
|
32
32
|
expression: expr.right,
|
|
33
33
|
description: `Division by zero if denominator becomes zero`,
|
|
34
|
-
suggestion: `Add check: if (denominator
|
|
34
|
+
suggestion: `Add check: if (Math.abs(denominator) < epsilon) return {...};`,
|
|
35
|
+
variableName,
|
|
36
|
+
line: line || expr.loc?.line
|
|
35
37
|
});
|
|
36
38
|
}
|
|
37
|
-
collectGuards(expr.left, guards);
|
|
38
|
-
collectGuards(expr.right, guards);
|
|
39
|
+
collectGuards(expr.left, guards, variableName, line);
|
|
40
|
+
collectGuards(expr.right, guards, variableName, line);
|
|
39
41
|
break;
|
|
40
42
|
case 'unary':
|
|
41
|
-
collectGuards(expr.operand, guards);
|
|
43
|
+
collectGuards(expr.operand, guards, variableName, line);
|
|
42
44
|
break;
|
|
43
45
|
case 'call':
|
|
44
|
-
analyzeCallGuards(expr, guards);
|
|
46
|
+
analyzeCallGuards(expr, guards, variableName, line || expr.loc?.line);
|
|
45
47
|
for (const arg of expr.args) {
|
|
46
|
-
collectGuards(arg, guards);
|
|
48
|
+
collectGuards(arg, guards, variableName, line);
|
|
47
49
|
}
|
|
48
50
|
break;
|
|
49
51
|
case 'component':
|
|
50
|
-
collectGuards(expr.object, guards);
|
|
52
|
+
collectGuards(expr.object, guards, variableName, line);
|
|
51
53
|
break;
|
|
52
54
|
}
|
|
53
55
|
}
|
|
54
56
|
/**
|
|
55
57
|
* Analyze function calls for specific edge cases
|
|
56
58
|
*/
|
|
57
|
-
function analyzeCallGuards(expr, guards) {
|
|
59
|
+
function analyzeCallGuards(expr, guards, variableName, line) {
|
|
58
60
|
switch (expr.name) {
|
|
59
61
|
case 'sqrt':
|
|
62
|
+
// Check if it's sqrt of sum of squares (always safe)
|
|
63
|
+
const arg = expr.args[0];
|
|
64
|
+
const isSumOfSquares = arg.kind === 'binary' && arg.operator === '+' &&
|
|
65
|
+
isSqExpression(arg.left) && isSqExpression(arg.right);
|
|
60
66
|
guards.push({
|
|
61
67
|
type: 'sqrt_negative',
|
|
62
68
|
expression: expr.args[0],
|
|
63
|
-
description:
|
|
64
|
-
|
|
69
|
+
description: isSumOfSquares
|
|
70
|
+
? `sqrt of sum of squares (safe, but can be zero)`
|
|
71
|
+
: `sqrt of negative number produces NaN`,
|
|
72
|
+
suggestion: isSumOfSquares
|
|
73
|
+
? `Add epsilon for numerical stability: sqrt(max(dx*dx + dy*dy, epsilon))`
|
|
74
|
+
: `Guard negative values: sqrt(max(0, value))`,
|
|
75
|
+
variableName,
|
|
76
|
+
line
|
|
65
77
|
});
|
|
66
78
|
break;
|
|
67
79
|
case 'magnitude2d':
|
|
68
80
|
case 'magnitude3d':
|
|
69
|
-
// magnitude uses sqrt internally
|
|
70
81
|
guards.push({
|
|
71
82
|
type: 'sqrt_negative',
|
|
72
83
|
expression: expr,
|
|
73
|
-
description: `magnitude
|
|
74
|
-
suggestion: `
|
|
84
|
+
description: `magnitude uses sqrt internally (safe, but can be zero)`,
|
|
85
|
+
suggestion: `Gradients may have division by zero when magnitude is zero`,
|
|
86
|
+
variableName,
|
|
87
|
+
line
|
|
75
88
|
});
|
|
76
89
|
break;
|
|
77
90
|
case 'normalize2d':
|
|
@@ -80,15 +93,19 @@ function analyzeCallGuards(expr, guards) {
|
|
|
80
93
|
type: 'normalize_zero',
|
|
81
94
|
expression: expr,
|
|
82
95
|
description: `Normalizing zero vector causes division by zero`,
|
|
83
|
-
suggestion: `
|
|
96
|
+
suggestion: `if (magnitude < epsilon) return zero vector or skip normalization`,
|
|
97
|
+
variableName,
|
|
98
|
+
line
|
|
84
99
|
});
|
|
85
100
|
break;
|
|
86
101
|
case 'atan2':
|
|
87
102
|
guards.push({
|
|
88
103
|
type: 'atan2_zero',
|
|
89
104
|
expression: expr,
|
|
90
|
-
description: `atan2(0, 0) is undefined`,
|
|
91
|
-
suggestion: `
|
|
105
|
+
description: `atan2(0, 0) is undefined and gradients have division by zero`,
|
|
106
|
+
suggestion: `if (y === 0 && x === 0) return 0 with zero gradients`,
|
|
107
|
+
variableName,
|
|
108
|
+
line
|
|
92
109
|
});
|
|
93
110
|
break;
|
|
94
111
|
case 'log':
|
|
@@ -96,7 +113,9 @@ function analyzeCallGuards(expr, guards) {
|
|
|
96
113
|
type: 'division_by_zero',
|
|
97
114
|
expression: expr.args[0],
|
|
98
115
|
description: `log(0) is -Infinity, log(negative) is NaN`,
|
|
99
|
-
suggestion: `
|
|
116
|
+
suggestion: `Clamp to positive: log(max(epsilon, value))`,
|
|
117
|
+
variableName,
|
|
118
|
+
line
|
|
100
119
|
});
|
|
101
120
|
break;
|
|
102
121
|
case 'asin':
|
|
@@ -105,42 +124,66 @@ function analyzeCallGuards(expr, guards) {
|
|
|
105
124
|
type: 'division_by_zero',
|
|
106
125
|
expression: expr.args[0],
|
|
107
126
|
description: `${expr.name} requires argument in [-1, 1]`,
|
|
108
|
-
suggestion: `Clamp
|
|
127
|
+
suggestion: `Clamp: ${expr.name}(max(-1, min(1, value)))`,
|
|
128
|
+
variableName,
|
|
129
|
+
line
|
|
109
130
|
});
|
|
110
131
|
break;
|
|
111
132
|
}
|
|
112
133
|
}
|
|
134
|
+
/**
|
|
135
|
+
* Check if expression is a squared term (x^2 or x*x)
|
|
136
|
+
*/
|
|
137
|
+
function isSqExpression(expr) {
|
|
138
|
+
if (expr.kind === 'binary') {
|
|
139
|
+
if (expr.operator === '*') {
|
|
140
|
+
// Check if x * x
|
|
141
|
+
if (expr.left.kind === 'variable' && expr.right.kind === 'variable') {
|
|
142
|
+
return expr.left.name === expr.right.name;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
else if (expr.operator === '^' || expr.operator === '**') {
|
|
146
|
+
// Check if x^2
|
|
147
|
+
if (expr.right.kind === 'number' && expr.right.value === 2) {
|
|
148
|
+
return true;
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
return false;
|
|
153
|
+
}
|
|
113
154
|
/**
 * Render the guard analysis as a warning block for display.
 *
 * Returns '' when `result.hasIssues` is false. Otherwise emits a header, one
 * entry per guard (bullet with optional "Line N:" and "name =" context, the
 * guard type with its description, and the suggested fix) and a closing hint.
 *
 * @param result - Guard analysis; `hasIssues` and `guards` are read.
 * @param asComments - When true every line is prefixed with '// ' so the block
 *                     can be embedded in generated code.
 * @returns The lines joined with '\n'.
 */
export function formatGuardWarnings(result, asComments = false) {
    if (!result.hasIssues) {
        return '';
    }
    const prefix = asComments ? '// ' : '';
    const out = [
        `${prefix}⚠️ EDGE CASE WARNINGS:`,
        prefix,
        `${prefix}The generated code may encounter edge cases that produce`,
        `${prefix}NaN, Infinity, or incorrect gradients:`,
        prefix
    ];
    // One entry per guard, each followed by a spacer line.
    for (const guard of result.guards) {
        const typeLabel = formatGuardType(guard.type);
        // Bullet line: location and assigned variable are appended when known.
        const bullet = [`${prefix} •`];
        if (guard.line) {
            bullet.push(`Line ${guard.line}:`);
        }
        if (guard.variableName) {
            bullet.push(`${guard.variableName} =`);
        }
        out.push(
            bullet.join(' '),
            `${prefix} ${typeLabel}: ${guard.description}`,
            `${prefix} 💡 Fix: ${guard.suggestion}`,
            prefix
        );
    }
    out.push(
        `${prefix}Add runtime checks or ensure inputs are within valid ranges.`,
        `${prefix}Use --guards --epsilon 1e-10 to automatically emit epsilon guards.`,
        prefix
    );
    return out.join('\n');
}
|
|
146
189
|
function formatGuardType(type) {
|
package/dist/dsl/Inliner.d.ts
CHANGED
|
@@ -8,3 +8,8 @@ import { Expression, FunctionDef } from './AST.js';
|
|
|
8
8
|
* Returns a new expression with all intermediate variables substituted
|
|
9
9
|
*/
|
|
10
10
|
export declare function inlineIntermediateVariables(func: FunctionDef): Expression;
|
|
11
|
+
/**
|
|
12
|
+
* Inline an expression using a substitution map
|
|
13
|
+
* Used to get the fully-expanded form of forward pass expressions
|
|
14
|
+
*/
|
|
15
|
+
export declare function inlineExpression(expr: Expression, substitutions: Map<string, Expression>): Expression;
|