gradient-script 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/README.md +52 -9
  2. package/dist/cli.js +134 -19
  3. package/dist/dsl/AST.d.ts +8 -0
  4. package/dist/dsl/CodeGen.d.ts +8 -3
  5. package/dist/dsl/CodeGen.js +583 -132
  6. package/dist/dsl/Errors.d.ts +6 -1
  7. package/dist/dsl/Errors.js +70 -1
  8. package/dist/dsl/Expander.js +5 -2
  9. package/dist/dsl/ExpressionUtils.d.ts +14 -0
  10. package/dist/dsl/ExpressionUtils.js +56 -0
  11. package/dist/dsl/GradientChecker.d.ts +21 -0
  12. package/dist/dsl/GradientChecker.js +109 -23
  13. package/dist/dsl/Guards.d.ts +3 -1
  14. package/dist/dsl/Guards.js +86 -43
  15. package/dist/dsl/Inliner.d.ts +5 -0
  16. package/dist/dsl/Inliner.js +11 -2
  17. package/dist/dsl/Lexer.js +3 -1
  18. package/dist/dsl/Parser.js +11 -5
  19. package/dist/dsl/Simplify.d.ts +7 -0
  20. package/dist/dsl/Simplify.js +183 -0
  21. package/dist/dsl/egraph/Convert.d.ts +23 -0
  22. package/dist/dsl/egraph/Convert.js +84 -0
  23. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  24. package/dist/dsl/egraph/EGraph.js +292 -0
  25. package/dist/dsl/egraph/ENode.d.ts +63 -0
  26. package/dist/dsl/egraph/ENode.js +94 -0
  27. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  28. package/dist/dsl/egraph/Extractor.js +1068 -0
  29. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  30. package/dist/dsl/egraph/Optimizer.js +88 -0
  31. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  32. package/dist/dsl/egraph/Pattern.js +325 -0
  33. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  34. package/dist/dsl/egraph/Rewriter.js +131 -0
  35. package/dist/dsl/egraph/Rules.d.ts +44 -0
  36. package/dist/dsl/egraph/Rules.js +187 -0
  37. package/dist/dsl/egraph/index.d.ts +15 -0
  38. package/dist/dsl/egraph/index.js +21 -0
  39. package/package.json +1 -1
  40. package/dist/dsl/CSE.d.ts +0 -21
  41. package/dist/dsl/CSE.js +0 -194
  42. package/dist/symbolic/AST.d.ts +0 -113
  43. package/dist/symbolic/AST.js +0 -128
  44. package/dist/symbolic/CodeGen.d.ts +0 -35
  45. package/dist/symbolic/CodeGen.js +0 -280
  46. package/dist/symbolic/Parser.d.ts +0 -64
  47. package/dist/symbolic/Parser.js +0 -329
  48. package/dist/symbolic/Simplify.d.ts +0 -10
  49. package/dist/symbolic/Simplify.js +0 -244
  50. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  51. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -2,8 +2,13 @@ export declare class ParseError extends Error {
2
2
  line: number;
3
3
  column: number;
4
4
  token?: string | undefined;
5
- constructor(message: string, line: number, column: number, token?: string | undefined);
5
+ sourceContext?: string | undefined;
6
+ constructor(message: string, line: number, column: number, token?: string | undefined, sourceContext?: string | undefined);
6
7
  }
8
+ /**
9
+ * Format a user-friendly error message with source context
10
+ */
11
+ export declare function formatParseError(error: ParseError, sourceCode: string, verbose?: boolean): string;
7
12
  export declare class TypeError extends Error {
8
13
  expression: string;
9
14
  expectedType?: string | undefined;
@@ -2,14 +2,83 @@ export class ParseError extends Error {
2
2
  line;
3
3
  column;
4
4
  token;
5
- constructor(message, line, column, token) {
5
+ sourceContext;
6
+ constructor(message, line, column, token, sourceContext) {
6
7
  super(`Parse error at ${line}:${column}: ${message}`);
7
8
  this.line = line;
8
9
  this.column = column;
9
10
  this.token = token;
11
+ this.sourceContext = sourceContext;
10
12
  this.name = 'ParseError';
11
13
  }
12
14
  }
15
/**
 * Format a user-friendly parse error message with source context.
 *
 * Layout of the returned text:
 *   1. "Error: <message>" — the "Parse error at L:C: " prefix that the
 *      ParseError constructor prepends is stripped.
 *   2. The offending source line with a caret under the reported column
 *      (omitted when `error.line` does not name an existing line).
 *   3. Contextual guidance from formatErrorGuidance (may be empty).
 *   4. The stack trace, only when `verbose` is set and a stack exists.
 *
 * @param error - ParseError carrying 1-based `line` and `column`.
 * @param sourceCode - The full source text that was parsed.
 * @param verbose - When true, append `error.stack`.
 * @returns The formatted multi-line message.
 */
export function formatParseError(error, sourceCode, verbose = false) {
    // Accept LF and CRLF sources; splitting on '\n' alone would leave a stray
    // '\r' at the end of the echoed source line.
    const lines = sourceCode.split(/\r?\n/);
    const errorLine = lines[error.line - 1];
    let output = `Error: ${error.message.replace(/^Parse error at \d+:\d+: /, '')}\n`;
    // Show the source line with the error
    if (errorLine) {
        output += `\n ${errorLine}\n`;
        const caretPos = Math.max(0, error.column - 1);
        // Mirror the source prefix, keeping tabs as tabs so the caret stays
        // aligned under tab-indented code; pad with spaces when the column
        // lies past the end of the line (e.g. an unexpected end of line).
        const padding = errorLine.slice(0, caretPos).replace(/[^\t]/g, ' ') +
            ' '.repeat(Math.max(0, caretPos - errorLine.length));
        output += ` ${padding}^\n`;
    }
    // Add helpful tips based on the error
    output += formatErrorGuidance(error);
    // Only show stack trace in verbose mode
    if (verbose && error.stack) {
        output += '\n\nStack trace:\n' + error.stack;
    }
    return output;
}
37
/**
 * Choose a block of contextual help text for a parse error.
 *
 * Rules are tried in this order and the first match wins:
 *   1. offending token is ';'                        -> semicolon explanation
 *   2. message contains "expected ':'"               -> type-annotation colon help
 *   3. message contains "expected parameter name"
 *      or "unexpected"                               -> parameter-format tip
 * Message matching is case-insensitive. Returns '' when no rule matches.
 */
function formatErrorGuidance(error) {
    const lowered = error.message.toLowerCase();
    const isSemicolon = error.token === ';';
    const isMissingColon = lowered.includes("expected ':'");
    const isParameterProblem = lowered.includes('expected parameter name') || lowered.includes('unexpected');
    if (isSemicolon) {
        return `
Semicolons are not part of gradient-script syntax.
Each statement should be on its own line.

Correct syntax:
function example(x∇, y∇) {
result = x + y
return result
}

💡 Tip: gradient-script uses newline-delimited statements (like Python),
not semicolons (like JavaScript/C#).
`;
    }
    if (isMissingColon) {
        return `
Type annotations require a colon before the type.

Correct syntax:
function distance(point∇: {x, y}) {
^

💡 Tip: Parameters marked with ∇ need type annotations to specify structure.
`;
    }
    if (isParameterProblem) {
        return `
💡 Tip: Make sure all parameters are properly formatted.
Variables that need gradients must be marked with ∇.

Example: function f(a∇: {x, y}, b) { ... }
`;
    }
    return '';
}
13
82
  export class TypeError extends Error {
14
83
  expression;
15
84
  expectedType;
@@ -2,6 +2,7 @@
2
2
  * Expander for GradientScript DSL
3
3
  * Expands built-in functions and struct operations into scalar operations
4
4
  */
5
+ import { DifferentiationError } from './Errors.js';
5
6
  /**
6
7
  * Expand built-in function calls to scalar expressions
7
8
  */
@@ -15,13 +16,15 @@ export function expandBuiltIn(call) {
15
16
  case 'magnitude2d':
16
17
  return expandMagnitude2d(args[0]);
17
18
  case 'normalize2d':
18
- throw new Error('normalize2d not yet supported in differentiation');
19
+ throw new DifferentiationError('normalize2d not yet supported', 'normalize2d', 'Vector normalization requires special handling for zero-length vectors. ' +
20
+ 'Use magnitude2d() and division for now.');
19
21
  case 'distance2d':
20
22
  return expandDistance2d(args[0], args[1]);
21
23
  case 'dot3d':
22
24
  return expandDot3d(args[0], args[1]);
23
25
  case 'cross3d':
24
- throw new Error('cross3d returns vector - not yet supported');
26
+ throw new DifferentiationError('cross3d returns vector - not yet supported', 'cross3d', 'Cross product returns a 3D vector, which requires structured gradient support. ' +
27
+ 'This feature is not yet implemented.');
25
28
  case 'magnitude3d':
26
29
  return expandMagnitude3d(args[0]);
27
30
  default:
@@ -53,3 +53,17 @@ export declare function containsVariable(expr: Expression, varName: string): boo
53
53
  * Calculate the maximum nesting depth of an expression
54
54
  */
55
55
  export declare function expressionDepth(expr: Expression): number;
56
+ /**
57
+ * Serializes an expression to structural string representation.
58
+ * Used for exact expression comparison - operand order matters.
59
+ *
60
+ * This ensures consistent string representation of expressions across different
61
+ * parts of the codebase.
62
+ */
63
+ export declare function serializeExpression(expr: Expression): string;
64
+ /**
65
+ * Serializes an expression to canonical form for CSE matching.
66
+ * Commutative operations (+ and *) have operands sorted lexicographically,
67
+ * so a*b and b*a produce the same canonical string.
68
+ */
69
+ export declare function serializeCanonical(expr: Expression): string;
@@ -173,3 +173,59 @@ export function expressionDepth(expr) {
173
173
  return 1 + expressionDepth(expr.object);
174
174
  }
175
175
  }
176
/**
 * Serialize an expression to an exact, order-sensitive structural string.
 *
 * Encoding per node kind:
 *   number    -> num(<value>)
 *   variable  -> var(<name>)
 *   binary    -> bin(<op>,<left>,<right>)
 *   unary     -> un(<op>,<operand>)
 *   call      -> call(<name>,<arg1>,<arg2>,...)
 *   component -> comp(<object>,<component>)
 * Operand order is kept, so a*b and b*a serialize differently (see
 * serializeCanonical for the order-insensitive variant). Any other kind
 * yields undefined.
 */
export function serializeExpression(expr) {
    const { kind } = expr;
    if (kind === 'number') {
        return `num(${expr.value})`;
    }
    if (kind === 'variable') {
        return `var(${expr.name})`;
    }
    if (kind === 'binary') {
        const lhs = serializeExpression(expr.left);
        const rhs = serializeExpression(expr.right);
        return `bin(${expr.operator},${lhs},${rhs})`;
    }
    if (kind === 'unary') {
        return `un(${expr.operator},${serializeExpression(expr.operand)})`;
    }
    if (kind === 'call') {
        const argList = expr.args.map((arg) => serializeExpression(arg)).join(',');
        return `call(${expr.name},${argList})`;
    }
    if (kind === 'component') {
        return `comp(${serializeExpression(expr.object)},${expr.component})`;
    }
    return undefined;
}
200
/**
 * Serialize an expression to a canonical string for CSE matching.
 *
 * Same encoding as serializeExpression, except that the two operands of a
 * commutative binary operator ('+' or '*') are emitted in lexicographic
 * order of their own canonical strings, so a*b and b*a yield the same key.
 * Only direct operand swaps are normalized; associativity is not
 * ((a+b)+c and a+(b+c) still differ). Any other kind yields undefined.
 */
export function serializeCanonical(expr) {
    const { kind } = expr;
    if (kind === 'number') {
        return `num(${expr.value})`;
    }
    if (kind === 'variable') {
        return `var(${expr.name})`;
    }
    if (kind === 'binary') {
        let first = serializeCanonical(expr.left);
        let second = serializeCanonical(expr.right);
        const commutative = expr.operator === '+' || expr.operator === '*';
        // Swap unless already ordered; non-commutative operators keep order.
        if (commutative && !(first <= second)) {
            [first, second] = [second, first];
        }
        return `bin(${expr.operator},${first},${second})`;
    }
    if (kind === 'unary') {
        return `un(${expr.operator},${serializeCanonical(expr.operand)})`;
    }
    if (kind === 'call') {
        const argList = expr.args.map((arg) => serializeCanonical(arg)).join(',');
        return `call(${expr.name},${argList})`;
    }
    if (kind === 'component') {
        return `comp(${serializeCanonical(expr.object)},${expr.component})`;
    }
    return undefined;
}
@@ -17,9 +17,25 @@ type NumValue = number | {
17
17
  export interface GradCheckResult {
18
18
  passed: boolean;
19
19
  errors: GradCheckError[];
20
+ singularities: GradCheckSingularity[];
20
21
  maxError: number;
21
22
  meanError: number;
23
+ totalChecks: number;
22
24
  }
25
+ /**
26
+ * Singularity detected during gradient checking
27
+ * When both analytical and numerical produce NaN/Inf, it's a singularity, not a bug
28
+ */
29
+ export interface GradCheckSingularity {
30
+ parameter: string;
31
+ component?: string;
32
+ analytical: number;
33
+ numerical: number;
34
+ }
35
+ /**
36
+ * Format gradient check results as a human-readable string
37
+ */
38
+ export declare function formatGradCheckResult(result: GradCheckResult, funcName: string): string;
23
39
  export interface GradCheckError {
24
40
  parameter: string;
25
41
  component?: string;
@@ -39,6 +55,11 @@ export declare class GradientChecker {
39
55
  * Check gradients for a function
40
56
  */
41
57
  check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
58
+ /**
59
+ * Compare analytical and numerical gradients
60
+ * Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
61
+ */
62
+ private compareGradients;
42
63
  /**
43
64
  * Compute numerical gradient for scalar parameter using finite differences
44
65
  */
@@ -4,6 +4,47 @@
4
4
  */
5
5
  import { Types } from './Types.js';
6
6
  import { expandBuiltIn, shouldExpand } from './Expander.js';
7
/**
 * Format gradient check results as a human-readable string.
 *
 * On success: one "✓" summary line with the number of verified gradients
 * (total minus failures minus singularities) and the max error, plus an
 * optional singularity note. On failure: a "✗" header followed by one line
 * per parameter, in first-seen order. A scalar parameter line shows the
 * analytical/numerical values; a structured parameter line lists
 * `component:error` pairs. A trailing singularity note is added when
 * singularities exist.
 */
export function formatGradCheckResult(result, funcName) {
    const singularityCount = result.singularities.length;
    const failedCount = result.errors.length;
    const verifiedCount = result.totalChecks - failedCount - singularityCount;
    if (result.passed) {
        const summary = [
            `✓ ${funcName}: ${verifiedCount} gradients verified (max error: ${result.maxError.toExponential(2)})`
        ];
        if (singularityCount > 0) {
            summary.push(` ⚠ ${singularityCount} singularities detected (both analytical and numerical produce NaN/Inf)`);
        }
        return summary.join('\n');
    }
    // Bucket failures per parameter, preserving first-seen order.
    const errorsByParam = new Map();
    for (const failure of result.errors) {
        const bucket = errorsByParam.get(failure.parameter);
        if (bucket) {
            bucket.push(failure);
        }
        else {
            errorsByParam.set(failure.parameter, [failure]);
        }
    }
    const report = [`✗ ${funcName}: ${failedCount}/${result.totalChecks} gradients FAILED`];
    errorsByParam.forEach((failures, param) => {
        const [only] = failures;
        const isScalar = failures.length === 1 && !only.component;
        if (isScalar) {
            report.push(` ${param}: analytical=${only.analytical.toFixed(6)}, numerical=${only.numerical.toFixed(6)}, error=${only.error.toExponential(2)}`);
            return;
        }
        const perComponent = failures
            .map((f) => `${f.component}:${f.error.toExponential(1)}`)
            .join(', ');
        report.push(` ${param}: {${perComponent}}`);
    });
    if (singularityCount > 0) {
        report.push(` ⚠ ${singularityCount} singularities also detected`);
    }
    return report.join('\n');
}
7
48
  /**
8
49
  * Gradient checker
9
50
  */
@@ -19,6 +60,9 @@ export class GradientChecker {
19
60
  */
20
61
  check(func, gradients, env, testPoint) {
21
62
  const errors = [];
63
+ const singularities = [];
64
+ let totalChecks = 0;
65
+ let maxError = 0;
22
66
  // For each parameter that has gradients
23
67
  for (const [paramName, gradient] of gradients.gradients.entries()) {
24
68
  const paramType = env.getOrThrow(paramName);
@@ -28,21 +72,21 @@ export class GradientChecker {
28
72
  }
29
73
  if (Types.isScalar(paramType)) {
30
74
  // Scalar parameter
75
+ totalChecks++;
31
76
  if (typeof paramValue !== 'number') {
32
77
  throw new Error(`Expected scalar value for ${paramName}`);
33
78
  }
34
79
  const analytical = this.evaluateExpression(gradient, testPoint);
35
80
  const numerical = this.numericalGradientScalar(func, testPoint, paramName);
36
- const error = Math.abs(analytical - numerical);
37
- const relativeError = Math.abs(error / (numerical + 1e-10));
38
- if (error > this.tolerance && relativeError > this.tolerance) {
39
- errors.push({
40
- parameter: paramName,
41
- analytical,
42
- numerical,
43
- error,
44
- relativeError
45
- });
81
+ const checkResult = this.compareGradients(analytical, numerical, paramName);
82
+ if (checkResult.type === 'singularity') {
83
+ singularities.push(checkResult.singularity);
84
+ }
85
+ else if (checkResult.type === 'error') {
86
+ errors.push(checkResult.error);
87
+ }
88
+ else if (checkResult.error) {
89
+ maxError = Math.max(maxError, checkResult.error.error);
46
90
  }
47
91
  }
48
92
  else {
@@ -52,32 +96,74 @@ export class GradientChecker {
52
96
  }
53
97
  const structGrad = gradient;
54
98
  for (const [comp, expr] of structGrad.components.entries()) {
99
+ totalChecks++;
55
100
  const analytical = this.evaluateExpression(expr, testPoint);
56
101
  const numerical = this.numericalGradientComponent(func, testPoint, paramName, comp);
57
- const error = Math.abs(analytical - numerical);
58
- const relativeError = Math.abs(error / (numerical + 1e-10));
59
- if (error > this.tolerance && relativeError > this.tolerance) {
60
- errors.push({
61
- parameter: paramName,
62
- component: comp,
63
- analytical,
64
- numerical,
65
- error,
66
- relativeError
67
- });
102
+ const checkResult = this.compareGradients(analytical, numerical, paramName, comp);
103
+ if (checkResult.type === 'singularity') {
104
+ singularities.push(checkResult.singularity);
105
+ }
106
+ else if (checkResult.type === 'error') {
107
+ errors.push(checkResult.error);
108
+ }
109
+ else if (checkResult.error) {
110
+ maxError = Math.max(maxError, checkResult.error.error);
68
111
  }
69
112
  }
70
113
  }
71
114
  }
72
- const maxError = errors.length > 0 ? Math.max(...errors.map(e => e.error)) : 0;
73
115
  const meanError = errors.length > 0
74
116
  ? errors.reduce((sum, e) => sum + e.error, 0) / errors.length
75
117
  : 0;
76
118
  return {
77
119
  passed: errors.length === 0,
78
120
  errors,
121
+ singularities,
79
122
  maxError,
80
- meanError
123
+ meanError,
124
+ totalChecks
125
+ };
126
+ }
127
+ /**
128
+ * Compare analytical and numerical gradients
129
+ * Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
130
+ */
131
+ compareGradients(analytical, numerical, parameter, component) {
132
+ const analyticalBad = !isFinite(analytical);
133
+ const numericalBad = !isFinite(numerical);
134
+ // Both produce NaN/Inf: singularity (not a bug)
135
+ if (analyticalBad && numericalBad) {
136
+ return {
137
+ type: 'singularity',
138
+ singularity: { parameter, component, analytical, numerical }
139
+ };
140
+ }
141
+ // One produces NaN/Inf, the other doesn't: actual bug
142
+ if (analyticalBad || numericalBad) {
143
+ return {
144
+ type: 'error',
145
+ error: {
146
+ parameter,
147
+ component,
148
+ analytical,
149
+ numerical,
150
+ error: Infinity,
151
+ relativeError: Infinity
152
+ }
153
+ };
154
+ }
155
+ // Both finite: compare values
156
+ const error = Math.abs(analytical - numerical);
157
+ const relativeError = Math.abs(error / (Math.abs(numerical) + 1e-10));
158
+ if (error > this.tolerance && relativeError > this.tolerance) {
159
+ return {
160
+ type: 'error',
161
+ error: { parameter, component, analytical, numerical, error, relativeError }
162
+ };
163
+ }
164
+ return {
165
+ type: 'pass',
166
+ error: { parameter, component, analytical, numerical, error, relativeError }
81
167
  };
82
168
  }
83
169
  /**
@@ -8,6 +8,8 @@ export interface Guard {
8
8
  expression: Expression;
9
9
  description: string;
10
10
  suggestion: string;
11
+ variableName?: string;
12
+ line?: number;
11
13
  }
12
14
  export interface GuardAnalysisResult {
13
15
  guards: Guard[];
@@ -20,7 +22,7 @@ export declare function analyzeGuards(func: FunctionDef): GuardAnalysisResult;
20
22
  /**
21
23
  * Format guard analysis for display
22
24
  */
23
- export declare function formatGuardWarnings(result: GuardAnalysisResult): string;
25
+ export declare function formatGuardWarnings(result: GuardAnalysisResult, asComments?: boolean): string;
24
26
  /**
25
27
  * Generate guard code snippets for common cases
26
28
  */
@@ -8,11 +8,11 @@
8
8
  export function analyzeGuards(func) {
9
9
  const guards = [];
10
10
  // Analyze return expression
11
- collectGuards(func.returnExpr, guards);
11
+ collectGuards(func.returnExpr, guards, undefined);
12
12
  // Analyze intermediate expressions
13
13
  for (const stmt of func.body) {
14
14
  if (stmt.kind === 'assignment') {
15
- collectGuards(stmt.expression, guards);
15
+ collectGuards(stmt.expression, guards, stmt.variable, stmt.loc?.line);
16
16
  }
17
17
  }
18
18
  return {
@@ -23,7 +23,7 @@ export function analyzeGuards(func) {
23
23
  /**
24
24
  * Collect potential guards from an expression
25
25
  */
26
- function collectGuards(expr, guards) {
26
+ function collectGuards(expr, guards, variableName, line) {
27
27
  switch (expr.kind) {
28
28
  case 'binary':
29
29
  if (expr.operator === '/') {
@@ -31,47 +31,60 @@ function collectGuards(expr, guards) {
31
31
  type: 'division_by_zero',
32
32
  expression: expr.right,
33
33
  description: `Division by zero if denominator becomes zero`,
34
- suggestion: `Add check: if (denominator === 0) return { value: 0, gradients: {...} };`
34
+ suggestion: `Add check: if (Math.abs(denominator) < epsilon) return {...};`,
35
+ variableName,
36
+ line: line || expr.loc?.line
35
37
  });
36
38
  }
37
- collectGuards(expr.left, guards);
38
- collectGuards(expr.right, guards);
39
+ collectGuards(expr.left, guards, variableName, line);
40
+ collectGuards(expr.right, guards, variableName, line);
39
41
  break;
40
42
  case 'unary':
41
- collectGuards(expr.operand, guards);
43
+ collectGuards(expr.operand, guards, variableName, line);
42
44
  break;
43
45
  case 'call':
44
- analyzeCallGuards(expr, guards);
46
+ analyzeCallGuards(expr, guards, variableName, line || expr.loc?.line);
45
47
  for (const arg of expr.args) {
46
- collectGuards(arg, guards);
48
+ collectGuards(arg, guards, variableName, line);
47
49
  }
48
50
  break;
49
51
  case 'component':
50
- collectGuards(expr.object, guards);
52
+ collectGuards(expr.object, guards, variableName, line);
51
53
  break;
52
54
  }
53
55
  }
54
56
  /**
55
57
  * Analyze function calls for specific edge cases
56
58
  */
57
- function analyzeCallGuards(expr, guards) {
59
+ function analyzeCallGuards(expr, guards, variableName, line) {
58
60
  switch (expr.name) {
59
61
  case 'sqrt':
62
+ // Check if it's sqrt of sum of squares (always safe)
63
+ const arg = expr.args[0];
64
+ const isSumOfSquares = arg.kind === 'binary' && arg.operator === '+' &&
65
+ isSqExpression(arg.left) && isSqExpression(arg.right);
60
66
  guards.push({
61
67
  type: 'sqrt_negative',
62
68
  expression: expr.args[0],
63
- description: `sqrt of negative number produces NaN`,
64
- suggestion: `Add check: Math.max(0, value) or abs(value)`
69
+ description: isSumOfSquares
70
+ ? `sqrt of sum of squares (safe, but can be zero)`
71
+ : `sqrt of negative number produces NaN`,
72
+ suggestion: isSumOfSquares
73
+ ? `Add epsilon for numerical stability: sqrt(max(dx*dx + dy*dy, epsilon))`
74
+ : `Guard negative values: sqrt(max(0, value))`,
75
+ variableName,
76
+ line
65
77
  });
66
78
  break;
67
79
  case 'magnitude2d':
68
80
  case 'magnitude3d':
69
- // magnitude uses sqrt internally
70
81
  guards.push({
71
82
  type: 'sqrt_negative',
72
83
  expression: expr,
73
- description: `magnitude of vector (uses sqrt internally)`,
74
- suggestion: `Ensure vector components are valid`
84
+ description: `magnitude uses sqrt internally (safe, but can be zero)`,
85
+ suggestion: `Gradients may have division by zero when magnitude is zero`,
86
+ variableName,
87
+ line
75
88
  });
76
89
  break;
77
90
  case 'normalize2d':
@@ -80,15 +93,19 @@ function analyzeCallGuards(expr, guards) {
80
93
  type: 'normalize_zero',
81
94
  expression: expr,
82
95
  description: `Normalizing zero vector causes division by zero`,
83
- suggestion: `Check if magnitude > epsilon before normalizing`
96
+ suggestion: `if (magnitude < epsilon) return zero vector or skip normalization`,
97
+ variableName,
98
+ line
84
99
  });
85
100
  break;
86
101
  case 'atan2':
87
102
  guards.push({
88
103
  type: 'atan2_zero',
89
104
  expression: expr,
90
- description: `atan2(0, 0) is undefined`,
91
- suggestion: `Check if both arguments are zero: if (y === 0 && x === 0) return 0;`
105
+ description: `atan2(0, 0) is undefined and gradients have division by zero`,
106
+ suggestion: `if (y === 0 && x === 0) return 0 with zero gradients`,
107
+ variableName,
108
+ line
92
109
  });
93
110
  break;
94
111
  case 'log':
@@ -96,7 +113,9 @@ function analyzeCallGuards(expr, guards) {
96
113
  type: 'division_by_zero',
97
114
  expression: expr.args[0],
98
115
  description: `log(0) is -Infinity, log(negative) is NaN`,
99
- suggestion: `Add check: Math.max(epsilon, value)`
116
+ suggestion: `Clamp to positive: log(max(epsilon, value))`,
117
+ variableName,
118
+ line
100
119
  });
101
120
  break;
102
121
  case 'asin':
@@ -105,42 +124,66 @@ function analyzeCallGuards(expr, guards) {
105
124
  type: 'division_by_zero',
106
125
  expression: expr.args[0],
107
126
  description: `${expr.name} requires argument in [-1, 1]`,
108
- suggestion: `Clamp value: Math.max(-1, Math.min(1, value))`
127
+ suggestion: `Clamp: ${expr.name}(max(-1, min(1, value)))`,
128
+ variableName,
129
+ line
109
130
  });
110
131
  break;
111
132
  }
112
133
  }
134
/**
 * Check whether an expression is a squared term, i.e. `e * e` where both
 * operands are structurally identical, or `e ^ 2` / `e ** 2`.
 *
 * Used by the sqrt guard analysis to recognise sqrt(a*a + b*b) as a
 * non-negative argument. Operands may be any expression, not just bare
 * variables, so the common DSL forms `p.x * p.x` and `(a - b) * (a - b)`
 * are recognised too.
 */
function isSqExpression(expr) {
    if (expr.kind !== 'binary') {
        return false;
    }
    if (expr.operator === '*') {
        // x * x for any structurally identical x (variable, component, sub-expression).
        return isSameExpressionForGuards(expr.left, expr.right);
    }
    if (expr.operator === '^' || expr.operator === '**') {
        // x^2
        return expr.right.kind === 'number' && expr.right.value === 2;
    }
    return false;
}
/**
 * Structural equality of two expressions. Source locations (`loc`) are
 * ignored, so two separately parsed occurrences of the same sub-expression
 * compare equal. Unknown node kinds never compare equal.
 */
function isSameExpressionForGuards(a, b) {
    if (a.kind !== b.kind) {
        return false;
    }
    switch (a.kind) {
        case 'number':
            return a.value === b.value;
        case 'variable':
            return a.name === b.name;
        case 'binary':
            return a.operator === b.operator &&
                isSameExpressionForGuards(a.left, b.left) &&
                isSameExpressionForGuards(a.right, b.right);
        case 'unary':
            return a.operator === b.operator && isSameExpressionForGuards(a.operand, b.operand);
        case 'call':
            return a.name === b.name &&
                a.args.length === b.args.length &&
                a.args.every((arg, i) => isSameExpressionForGuards(arg, b.args[i]));
        case 'component':
            return a.component === b.component && isSameExpressionForGuards(a.object, b.object);
        default:
            return false;
    }
}
113
154
/**
 * Render guard-analysis results as a human-readable warning block.
 *
 * Returns '' when `result.hasIssues` is falsy. Otherwise the output is a
 * header, then one entry per guard, then a closing hint. Each entry is:
 *   - a location line with an optional "Line N:" and "<variable> ="
 *   - the guard type label and its description
 *   - the suggested fix
 *   - a blank line
 *
 * @param result - Guard analysis result (hasIssues flag plus guards list).
 * @param asComments - When true, every output line is prefixed with '// '
 *   so the block can be embedded in generated source.
 * @returns Newline-joined warning text.
 */
export function formatGuardWarnings(result, asComments = false) {
    if (!result.hasIssues) {
        return '';
    }
    const prefix = asComments ? '// ' : '';
    const header = [
        '⚠️ EDGE CASE WARNINGS:',
        '',
        'The generated code may encounter edge cases that produce',
        'NaN, Infinity, or incorrect gradients:',
        '',
    ];
    // One entry per guard, keeping each occurrence with its own location.
    const entries = result.guards.flatMap((guard) => {
        const label = formatGuardType(guard.type);
        const where = [' •'];
        if (guard.line) {
            where.push(` Line ${guard.line}:`);
        }
        if (guard.variableName) {
            where.push(` ${guard.variableName} =`);
        }
        return [
            where.join(''),
            ` ${label}: ${guard.description}`,
            ` 💡 Fix: ${guard.suggestion}`,
            '',
        ];
    });
    const footer = [
        'Add runtime checks or ensure inputs are within valid ranges.',
        'Use --guards --epsilon 1e-10 to automatically emit epsilon guards.',
        '',
    ];
    return [...header, ...entries, ...footer]
        .map((text) => prefix + text)
        .join('\n');
}
146
189
  function formatGuardType(type) {
@@ -8,3 +8,8 @@ import { Expression, FunctionDef } from './AST.js';
8
8
  * Returns a new expression with all intermediate variables substituted
9
9
  */
10
10
  export declare function inlineIntermediateVariables(func: FunctionDef): Expression;
11
+ /**
12
+ * Inline an expression using a substitution map
13
+ * Used to get the fully-expanded form of forward pass expressions
14
+ */
15
+ export declare function inlineExpression(expr: Expression, substitutions: Map<string, Expression>): Expression;