gradient-script 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,11 +8,11 @@
8
8
  export function analyzeGuards(func) {
9
9
  const guards = [];
10
10
  // Analyze return expression
11
- collectGuards(func.returnExpr, guards);
11
+ collectGuards(func.returnExpr, guards, undefined);
12
12
  // Analyze intermediate expressions
13
13
  for (const stmt of func.body) {
14
14
  if (stmt.kind === 'assignment') {
15
- collectGuards(stmt.expression, guards);
15
+ collectGuards(stmt.expression, guards, stmt.variable, stmt.loc?.line);
16
16
  }
17
17
  }
18
18
  return {
@@ -23,7 +23,7 @@ export function analyzeGuards(func) {
23
23
  /**
24
24
  * Collect potential guards from an expression
25
25
  */
26
- function collectGuards(expr, guards) {
26
+ function collectGuards(expr, guards, variableName, line) {
27
27
  switch (expr.kind) {
28
28
  case 'binary':
29
29
  if (expr.operator === '/') {
@@ -31,47 +31,60 @@ function collectGuards(expr, guards) {
31
31
  type: 'division_by_zero',
32
32
  expression: expr.right,
33
33
  description: `Division by zero if denominator becomes zero`,
34
- suggestion: `Add check: if (denominator === 0) return { value: 0, gradients: {...} };`
34
+ suggestion: `Add check: if (Math.abs(denominator) < epsilon) return {...};`,
35
+ variableName,
36
+ line: line || expr.loc?.line
35
37
  });
36
38
  }
37
- collectGuards(expr.left, guards);
38
- collectGuards(expr.right, guards);
39
+ collectGuards(expr.left, guards, variableName, line);
40
+ collectGuards(expr.right, guards, variableName, line);
39
41
  break;
40
42
  case 'unary':
41
- collectGuards(expr.operand, guards);
43
+ collectGuards(expr.operand, guards, variableName, line);
42
44
  break;
43
45
  case 'call':
44
- analyzeCallGuards(expr, guards);
46
+ analyzeCallGuards(expr, guards, variableName, line || expr.loc?.line);
45
47
  for (const arg of expr.args) {
46
- collectGuards(arg, guards);
48
+ collectGuards(arg, guards, variableName, line);
47
49
  }
48
50
  break;
49
51
  case 'component':
50
- collectGuards(expr.object, guards);
52
+ collectGuards(expr.object, guards, variableName, line);
51
53
  break;
52
54
  }
53
55
  }
54
56
  /**
55
57
  * Analyze function calls for specific edge cases
56
58
  */
57
- function analyzeCallGuards(expr, guards) {
59
+ function analyzeCallGuards(expr, guards, variableName, line) {
58
60
  switch (expr.name) {
59
61
  case 'sqrt':
62
+ // Check if it's sqrt of sum of squares (always safe)
63
+ const arg = expr.args[0];
64
+ const isSumOfSquares = arg.kind === 'binary' && arg.operator === '+' &&
65
+ isSqExpression(arg.left) && isSqExpression(arg.right);
60
66
  guards.push({
61
67
  type: 'sqrt_negative',
62
68
  expression: expr.args[0],
63
- description: `sqrt of negative number produces NaN`,
64
- suggestion: `Add check: Math.max(0, value) or abs(value)`
69
+ description: isSumOfSquares
70
+ ? `sqrt of sum of squares (safe, but can be zero)`
71
+ : `sqrt of negative number produces NaN`,
72
+ suggestion: isSumOfSquares
73
+ ? `Add epsilon for numerical stability: sqrt(max(dx*dx + dy*dy, epsilon))`
74
+ : `Guard negative values: sqrt(max(0, value))`,
75
+ variableName,
76
+ line
65
77
  });
66
78
  break;
67
79
  case 'magnitude2d':
68
80
  case 'magnitude3d':
69
- // magnitude uses sqrt internally
70
81
  guards.push({
71
82
  type: 'sqrt_negative',
72
83
  expression: expr,
73
- description: `magnitude of vector (uses sqrt internally)`,
74
- suggestion: `Ensure vector components are valid`
84
+ description: `magnitude uses sqrt internally (safe, but can be zero)`,
85
+ suggestion: `Gradients may have division by zero when magnitude is zero`,
86
+ variableName,
87
+ line
75
88
  });
76
89
  break;
77
90
  case 'normalize2d':
@@ -80,15 +93,19 @@ function analyzeCallGuards(expr, guards) {
80
93
  type: 'normalize_zero',
81
94
  expression: expr,
82
95
  description: `Normalizing zero vector causes division by zero`,
83
- suggestion: `Check if magnitude > epsilon before normalizing`
96
+ suggestion: `if (magnitude < epsilon) return zero vector or skip normalization`,
97
+ variableName,
98
+ line
84
99
  });
85
100
  break;
86
101
  case 'atan2':
87
102
  guards.push({
88
103
  type: 'atan2_zero',
89
104
  expression: expr,
90
- description: `atan2(0, 0) is undefined`,
91
- suggestion: `Check if both arguments are zero: if (y === 0 && x === 0) return 0;`
105
+ description: `atan2(0, 0) is undefined and gradients have division by zero`,
106
+ suggestion: `if (y === 0 && x === 0) return 0 with zero gradients`,
107
+ variableName,
108
+ line
92
109
  });
93
110
  break;
94
111
  case 'log':
@@ -96,7 +113,9 @@ function analyzeCallGuards(expr, guards) {
96
113
  type: 'division_by_zero',
97
114
  expression: expr.args[0],
98
115
  description: `log(0) is -Infinity, log(negative) is NaN`,
99
- suggestion: `Add check: Math.max(epsilon, value)`
116
+ suggestion: `Clamp to positive: log(max(epsilon, value))`,
117
+ variableName,
118
+ line
100
119
  });
101
120
  break;
102
121
  case 'asin':
@@ -105,11 +124,33 @@ function analyzeCallGuards(expr, guards) {
105
124
  type: 'division_by_zero',
106
125
  expression: expr.args[0],
107
126
  description: `${expr.name} requires argument in [-1, 1]`,
108
- suggestion: `Clamp value: Math.max(-1, Math.min(1, value))`
127
+ suggestion: `Clamp: ${expr.name}(max(-1, min(1, value)))`,
128
+ variableName,
129
+ line
109
130
  });
110
131
  break;
111
132
  }
112
133
  }
134
+ /**
135
+ * Check if expression is a squared term (x^2 or x*x)
136
+ */
137
+ function isSqExpression(expr) {
138
+ if (expr.kind === 'binary') {
139
+ if (expr.operator === '*') {
140
+ // Check if x * x
141
+ if (expr.left.kind === 'variable' && expr.right.kind === 'variable') {
142
+ return expr.left.name === expr.right.name;
143
+ }
144
+ }
145
+ else if (expr.operator === '^' || expr.operator === '**') {
146
+ // Check if x^2
147
+ if (expr.right.kind === 'number' && expr.right.value === 2) {
148
+ return true;
149
+ }
150
+ }
151
+ }
152
+ return false;
153
+ }
113
154
  /**
114
155
  * Format guard analysis for display
115
156
  */
@@ -121,25 +162,26 @@ export function formatGuardWarnings(result) {
121
162
  lines.push('⚠️ EDGE CASE WARNINGS:');
122
163
  lines.push('');
123
164
  lines.push('The generated code may encounter edge cases that produce');
124
- lines.push('NaN, Infinity, or incorrect results:');
165
+ lines.push('NaN, Infinity, or incorrect gradients:');
125
166
  lines.push('');
126
- // Group by type
127
- const byType = new Map();
167
+ // Show each guard individually with context
128
168
  for (const guard of result.guards) {
129
- const existing = byType.get(guard.type) || [];
130
- existing.push(guard);
131
- byType.set(guard.type, existing);
132
- }
133
- for (const [type, guards] of byType.entries()) {
134
- const typeLabel = formatGuardType(type);
135
- lines.push(` • ${typeLabel} (${guards.length} occurrence${guards.length > 1 ? 's' : ''})`);
136
- // Show first occurrence
137
- const first = guards[0];
138
- lines.push(` ${first.description}`);
139
- lines.push(` 💡 ${first.suggestion}`);
169
+ const typeLabel = formatGuardType(guard.type);
170
+ // Show location and variable if available
171
+ let location = ' •';
172
+ if (guard.line) {
173
+ location += ` Line ${guard.line}:`;
174
+ }
175
+ if (guard.variableName) {
176
+ location += ` ${guard.variableName} =`;
177
+ }
178
+ lines.push(location);
179
+ lines.push(` ${typeLabel}: ${guard.description}`);
180
+ lines.push(` 💡 Fix: ${guard.suggestion}`);
140
181
  lines.push('');
141
182
  }
142
- lines.push('Consider adding runtime checks or ensuring inputs are within valid ranges.');
183
+ lines.push('Add runtime checks or ensure inputs are within valid ranges.');
184
+ lines.push('Use --guards --epsilon 1e-10 to automatically emit epsilon guards.');
143
185
  lines.push('');
144
186
  return lines.join('\n');
145
187
  }
@@ -6,8 +6,9 @@ import { ExpressionTransformer } from './ExpressionTransformer.js';
6
6
  /**
7
7
  * Expression transformer that substitutes variables from a substitution map
8
8
  * Handles recursive inlining by reprocessing substituted expressions
9
+ * Used for inlining intermediate variables to eliminate assignments
9
10
  */
10
- class SubstitutionTransformer extends ExpressionTransformer {
11
+ class VariableSubstitutionTransformer extends ExpressionTransformer {
11
12
  substitutions;
12
13
  constructor(substitutions) {
13
14
  super();
@@ -35,6 +36,6 @@ export function inlineIntermediateVariables(func) {
35
36
  }
36
37
  }
37
38
  // Use transformer to inline all variables
38
- const transformer = new SubstitutionTransformer(substitutions);
39
+ const transformer = new VariableSubstitutionTransformer(substitutions);
39
40
  return transformer.transform(func.returnExpr);
40
41
  }
package/dist/dsl/Lexer.js CHANGED
@@ -2,6 +2,7 @@
2
2
  * Lexer for GradientScript DSL
3
3
  * Tokenizes input with support for ∇, ^, **, and structured types
4
4
  */
5
+ import { ParseError } from './Errors.js';
5
6
  export var TokenType;
6
7
  (function (TokenType) {
7
8
  // Literals
@@ -126,7 +127,8 @@ export class Lexer {
126
127
  this.advance();
127
128
  return { type: TokenType.NEWLINE, value: '\n', line, column };
128
129
  }
129
- throw new Error(`Unexpected character '${char}' at line ${line}, column ${column}`);
130
+ // Create a more helpful error for common mistakes
131
+ throw new ParseError(`Unexpected character '${char}'`, line, column, char);
130
132
  }
131
133
  number() {
132
134
  const line = this.line;
@@ -123,13 +123,15 @@ export class Parser {
123
123
  * variable = expression
124
124
  */
125
125
  assignment() {
126
- const variable = this.consume(TokenType.IDENTIFIER, 'Expected variable name').value;
126
+ const varToken = this.consume(TokenType.IDENTIFIER, 'Expected variable name');
127
+ const variable = varToken.value;
127
128
  this.consume(TokenType.EQUALS, "Expected '=' in assignment");
128
129
  const expression = this.expression();
129
130
  return {
130
131
  kind: 'assignment',
131
132
  variable,
132
- expression
133
+ expression,
134
+ loc: { line: varToken.line, column: varToken.column }
133
135
  };
134
136
  }
135
137
  /**
@@ -161,13 +163,15 @@ export class Parser {
161
163
  multiplicative() {
162
164
  let expr = this.power();
163
165
  while (this.match(TokenType.MULTIPLY, TokenType.DIVIDE)) {
164
- const operator = this.previous().value;
166
+ const opToken = this.previous();
167
+ const operator = opToken.value;
165
168
  const right = this.power();
166
169
  expr = {
167
170
  kind: 'binary',
168
171
  operator,
169
172
  left: expr,
170
- right
173
+ right,
174
+ loc: { line: opToken.line, column: opToken.column }
171
175
  };
172
176
  }
173
177
  return expr;
@@ -213,6 +217,7 @@ export class Parser {
213
217
  while (true) {
214
218
  if (this.match(TokenType.LPAREN)) {
215
219
  // Function call
220
+ const startLoc = expr.loc || { line: this.previous().line, column: this.previous().column };
216
221
  const args = this.argumentList();
217
222
  this.consume(TokenType.RPAREN, "Expected ')' after arguments");
218
223
  if (expr.kind !== 'variable') {
@@ -222,7 +227,8 @@ export class Parser {
222
227
  expr = {
223
228
  kind: 'call',
224
229
  name: expr.name,
225
- args
230
+ args,
231
+ loc: startLoc
226
232
  };
227
233
  }
228
234
  else if (this.match(TokenType.DOT)) {
@@ -116,6 +116,53 @@ class Simplifier extends ExpressionTransformer {
116
116
  if (expressionsEqual(left, right)) {
117
117
  return { kind: 'number', value: 1 };
118
118
  }
119
+ // (a + a) / 2 → a
120
+ if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
121
+ if (expressionsEqual(left.left, left.right)) {
122
+ return left.left;
123
+ }
124
+ }
125
+ // (a + a) / (2 * b) → a / b
126
+ if (right.kind === 'binary' && right.operator === '*') {
127
+ const rightLeft = right.left;
128
+ const rightRight = right.right;
129
+ const rightLeftNum = isNumber(rightLeft) ? rightLeft.value : null;
130
+ if (rightLeftNum === 2 && left.kind === 'binary' && left.operator === '+') {
131
+ if (expressionsEqual(left.left, left.right)) {
132
+ return {
133
+ kind: 'binary',
134
+ operator: '/',
135
+ left: left.left,
136
+ right: rightRight
137
+ };
138
+ }
139
+ // (-1 * a + a * -1) / (2 * b) → -a / b
140
+ const leftLeft = left.left;
141
+ const leftRight = left.right;
142
+ if (leftLeft.kind === 'binary' && leftLeft.operator === '*' &&
143
+ leftRight.kind === 'binary' && leftRight.operator === '*') {
144
+ const ll_left = leftLeft.left;
145
+ const ll_right = leftLeft.right;
146
+ const lr_left = leftRight.left;
147
+ const lr_right = leftRight.right;
148
+ const ll_leftNum = isNumber(ll_left) ? ll_left.value : null;
149
+ const lr_rightNum = isNumber(lr_right) ? lr_right.value : null;
150
+ // (-1 * a) + (a * -1)
151
+ if (ll_leftNum === -1 && lr_rightNum === -1 && expressionsEqual(ll_right, lr_left)) {
152
+ return {
153
+ kind: 'unary',
154
+ operator: '-',
155
+ operand: {
156
+ kind: 'binary',
157
+ operator: '/',
158
+ left: ll_right,
159
+ right: rightRight
160
+ }
161
+ };
162
+ }
163
+ }
164
+ }
165
+ }
119
166
  }
120
167
  // Power rules
121
168
  if (expr.operator === '^' || expr.operator === '**') {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "gradient-script",
3
- "version": "0.1.0",
3
+ "version": "0.2.0",
4
4
  "description": "Symbolic differentiation for structured types with a simple DSL",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",