gradient-script 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +49 -8
- package/dist/cli.js +57 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CSE.js +5 -31
- package/dist/dsl/CodeGen.d.ts +7 -2
- package/dist/dsl/CodeGen.js +259 -66
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +8 -0
- package/dist/dsl/ExpressionUtils.js +24 -0
- package/dist/dsl/Guards.d.ts +2 -0
- package/dist/dsl/Guards.js +78 -36
- package/dist/dsl/Inliner.js +3 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.js +47 -0
- package/package.json +1 -1
package/dist/dsl/Guards.js
CHANGED
|
@@ -8,11 +8,11 @@
|
|
|
8
8
|
export function analyzeGuards(func) {
|
|
9
9
|
const guards = [];
|
|
10
10
|
// Analyze return expression
|
|
11
|
-
collectGuards(func.returnExpr, guards);
|
|
11
|
+
collectGuards(func.returnExpr, guards, undefined);
|
|
12
12
|
// Analyze intermediate expressions
|
|
13
13
|
for (const stmt of func.body) {
|
|
14
14
|
if (stmt.kind === 'assignment') {
|
|
15
|
-
collectGuards(stmt.expression, guards);
|
|
15
|
+
collectGuards(stmt.expression, guards, stmt.variable, stmt.loc?.line);
|
|
16
16
|
}
|
|
17
17
|
}
|
|
18
18
|
return {
|
|
@@ -23,7 +23,7 @@ export function analyzeGuards(func) {
|
|
|
23
23
|
/**
|
|
24
24
|
* Collect potential guards from an expression
|
|
25
25
|
*/
|
|
26
|
-
function collectGuards(expr, guards) {
|
|
26
|
+
function collectGuards(expr, guards, variableName, line) {
|
|
27
27
|
switch (expr.kind) {
|
|
28
28
|
case 'binary':
|
|
29
29
|
if (expr.operator === '/') {
|
|
@@ -31,47 +31,60 @@ function collectGuards(expr, guards) {
|
|
|
31
31
|
type: 'division_by_zero',
|
|
32
32
|
expression: expr.right,
|
|
33
33
|
description: `Division by zero if denominator becomes zero`,
|
|
34
|
-
suggestion: `Add check: if (denominator
|
|
34
|
+
suggestion: `Add check: if (Math.abs(denominator) < epsilon) return {...};`,
|
|
35
|
+
variableName,
|
|
36
|
+
line: line || expr.loc?.line
|
|
35
37
|
});
|
|
36
38
|
}
|
|
37
|
-
collectGuards(expr.left, guards);
|
|
38
|
-
collectGuards(expr.right, guards);
|
|
39
|
+
collectGuards(expr.left, guards, variableName, line);
|
|
40
|
+
collectGuards(expr.right, guards, variableName, line);
|
|
39
41
|
break;
|
|
40
42
|
case 'unary':
|
|
41
|
-
collectGuards(expr.operand, guards);
|
|
43
|
+
collectGuards(expr.operand, guards, variableName, line);
|
|
42
44
|
break;
|
|
43
45
|
case 'call':
|
|
44
|
-
analyzeCallGuards(expr, guards);
|
|
46
|
+
analyzeCallGuards(expr, guards, variableName, line || expr.loc?.line);
|
|
45
47
|
for (const arg of expr.args) {
|
|
46
|
-
collectGuards(arg, guards);
|
|
48
|
+
collectGuards(arg, guards, variableName, line);
|
|
47
49
|
}
|
|
48
50
|
break;
|
|
49
51
|
case 'component':
|
|
50
|
-
collectGuards(expr.object, guards);
|
|
52
|
+
collectGuards(expr.object, guards, variableName, line);
|
|
51
53
|
break;
|
|
52
54
|
}
|
|
53
55
|
}
|
|
54
56
|
/**
|
|
55
57
|
* Analyze function calls for specific edge cases
|
|
56
58
|
*/
|
|
57
|
-
function analyzeCallGuards(expr, guards) {
|
|
59
|
+
function analyzeCallGuards(expr, guards, variableName, line) {
|
|
58
60
|
switch (expr.name) {
|
|
59
61
|
case 'sqrt':
|
|
62
|
+
// Check if it's sqrt of sum of squares (always safe)
|
|
63
|
+
const arg = expr.args[0];
|
|
64
|
+
const isSumOfSquares = arg.kind === 'binary' && arg.operator === '+' &&
|
|
65
|
+
isSqExpression(arg.left) && isSqExpression(arg.right);
|
|
60
66
|
guards.push({
|
|
61
67
|
type: 'sqrt_negative',
|
|
62
68
|
expression: expr.args[0],
|
|
63
|
-
description:
|
|
64
|
-
|
|
69
|
+
description: isSumOfSquares
|
|
70
|
+
? `sqrt of sum of squares (safe, but can be zero)`
|
|
71
|
+
: `sqrt of negative number produces NaN`,
|
|
72
|
+
suggestion: isSumOfSquares
|
|
73
|
+
? `Add epsilon for numerical stability: sqrt(max(dx*dx + dy*dy, epsilon))`
|
|
74
|
+
: `Guard negative values: sqrt(max(0, value))`,
|
|
75
|
+
variableName,
|
|
76
|
+
line
|
|
65
77
|
});
|
|
66
78
|
break;
|
|
67
79
|
case 'magnitude2d':
|
|
68
80
|
case 'magnitude3d':
|
|
69
|
-
// magnitude uses sqrt internally
|
|
70
81
|
guards.push({
|
|
71
82
|
type: 'sqrt_negative',
|
|
72
83
|
expression: expr,
|
|
73
|
-
description: `magnitude
|
|
74
|
-
suggestion: `
|
|
84
|
+
description: `magnitude uses sqrt internally (safe, but can be zero)`,
|
|
85
|
+
suggestion: `Gradients may have division by zero when magnitude is zero`,
|
|
86
|
+
variableName,
|
|
87
|
+
line
|
|
75
88
|
});
|
|
76
89
|
break;
|
|
77
90
|
case 'normalize2d':
|
|
@@ -80,15 +93,19 @@ function analyzeCallGuards(expr, guards) {
|
|
|
80
93
|
type: 'normalize_zero',
|
|
81
94
|
expression: expr,
|
|
82
95
|
description: `Normalizing zero vector causes division by zero`,
|
|
83
|
-
suggestion: `
|
|
96
|
+
suggestion: `if (magnitude < epsilon) return zero vector or skip normalization`,
|
|
97
|
+
variableName,
|
|
98
|
+
line
|
|
84
99
|
});
|
|
85
100
|
break;
|
|
86
101
|
case 'atan2':
|
|
87
102
|
guards.push({
|
|
88
103
|
type: 'atan2_zero',
|
|
89
104
|
expression: expr,
|
|
90
|
-
description: `atan2(0, 0) is undefined`,
|
|
91
|
-
suggestion: `
|
|
105
|
+
description: `atan2(0, 0) is undefined and gradients have division by zero`,
|
|
106
|
+
suggestion: `if (y === 0 && x === 0) return 0 with zero gradients`,
|
|
107
|
+
variableName,
|
|
108
|
+
line
|
|
92
109
|
});
|
|
93
110
|
break;
|
|
94
111
|
case 'log':
|
|
@@ -96,7 +113,9 @@ function analyzeCallGuards(expr, guards) {
|
|
|
96
113
|
type: 'division_by_zero',
|
|
97
114
|
expression: expr.args[0],
|
|
98
115
|
description: `log(0) is -Infinity, log(negative) is NaN`,
|
|
99
|
-
suggestion: `
|
|
116
|
+
suggestion: `Clamp to positive: log(max(epsilon, value))`,
|
|
117
|
+
variableName,
|
|
118
|
+
line
|
|
100
119
|
});
|
|
101
120
|
break;
|
|
102
121
|
case 'asin':
|
|
@@ -105,11 +124,33 @@ function analyzeCallGuards(expr, guards) {
|
|
|
105
124
|
type: 'division_by_zero',
|
|
106
125
|
expression: expr.args[0],
|
|
107
126
|
description: `${expr.name} requires argument in [-1, 1]`,
|
|
108
|
-
suggestion: `Clamp
|
|
127
|
+
suggestion: `Clamp: ${expr.name}(max(-1, min(1, value)))`,
|
|
128
|
+
variableName,
|
|
129
|
+
line
|
|
109
130
|
});
|
|
110
131
|
break;
|
|
111
132
|
}
|
|
112
133
|
}
|
|
134
|
+
/**
|
|
135
|
+
* Check if expression is a squared term (x^2 or x*x)
|
|
136
|
+
*/
|
|
137
|
+
function isSqExpression(expr) {
|
|
138
|
+
if (expr.kind === 'binary') {
|
|
139
|
+
if (expr.operator === '*') {
|
|
140
|
+
// Check if x * x
|
|
141
|
+
if (expr.left.kind === 'variable' && expr.right.kind === 'variable') {
|
|
142
|
+
return expr.left.name === expr.right.name;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
else if (expr.operator === '^' || expr.operator === '**') {
|
|
146
|
+
// Check if x^2
|
|
147
|
+
if (expr.right.kind === 'number' && expr.right.value === 2) {
|
|
148
|
+
return true;
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
return false;
|
|
153
|
+
}
|
|
113
154
|
/**
|
|
114
155
|
* Format guard analysis for display
|
|
115
156
|
*/
|
|
@@ -121,25 +162,26 @@ export function formatGuardWarnings(result) {
|
|
|
121
162
|
lines.push('⚠️ EDGE CASE WARNINGS:');
|
|
122
163
|
lines.push('');
|
|
123
164
|
lines.push('The generated code may encounter edge cases that produce');
|
|
124
|
-
lines.push('NaN, Infinity, or incorrect
|
|
165
|
+
lines.push('NaN, Infinity, or incorrect gradients:');
|
|
125
166
|
lines.push('');
|
|
126
|
-
//
|
|
127
|
-
const byType = new Map();
|
|
167
|
+
// Show each guard individually with context
|
|
128
168
|
for (const guard of result.guards) {
|
|
129
|
-
const
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
lines.push(
|
|
139
|
-
lines.push(`
|
|
169
|
+
const typeLabel = formatGuardType(guard.type);
|
|
170
|
+
// Show location and variable if available
|
|
171
|
+
let location = ' •';
|
|
172
|
+
if (guard.line) {
|
|
173
|
+
location += ` Line ${guard.line}:`;
|
|
174
|
+
}
|
|
175
|
+
if (guard.variableName) {
|
|
176
|
+
location += ` ${guard.variableName} =`;
|
|
177
|
+
}
|
|
178
|
+
lines.push(location);
|
|
179
|
+
lines.push(` ${typeLabel}: ${guard.description}`);
|
|
180
|
+
lines.push(` 💡 Fix: ${guard.suggestion}`);
|
|
140
181
|
lines.push('');
|
|
141
182
|
}
|
|
142
|
-
lines.push('
|
|
183
|
+
lines.push('Add runtime checks or ensure inputs are within valid ranges.');
|
|
184
|
+
lines.push('Use --guards --epsilon 1e-10 to automatically emit epsilon guards.');
|
|
143
185
|
lines.push('');
|
|
144
186
|
return lines.join('\n');
|
|
145
187
|
}
|
package/dist/dsl/Inliner.js
CHANGED
|
@@ -6,8 +6,9 @@ import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
|
6
6
|
/**
|
|
7
7
|
* Expression transformer that substitutes variables from a substitution map
|
|
8
8
|
* Handles recursive inlining by reprocessing substituted expressions
|
|
9
|
+
* Used for inlining intermediate variables to eliminate assignments
|
|
9
10
|
*/
|
|
10
|
-
class
|
|
11
|
+
class VariableSubstitutionTransformer extends ExpressionTransformer {
|
|
11
12
|
substitutions;
|
|
12
13
|
constructor(substitutions) {
|
|
13
14
|
super();
|
|
@@ -35,6 +36,6 @@ export function inlineIntermediateVariables(func) {
|
|
|
35
36
|
}
|
|
36
37
|
}
|
|
37
38
|
// Use transformer to inline all variables
|
|
38
|
-
const transformer = new
|
|
39
|
+
const transformer = new VariableSubstitutionTransformer(substitutions);
|
|
39
40
|
return transformer.transform(func.returnExpr);
|
|
40
41
|
}
|
package/dist/dsl/Lexer.js
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
* Lexer for GradientScript DSL
|
|
3
3
|
* Tokenizes input with support for ∇, ^, **, and structured types
|
|
4
4
|
*/
|
|
5
|
+
import { ParseError } from './Errors.js';
|
|
5
6
|
export var TokenType;
|
|
6
7
|
(function (TokenType) {
|
|
7
8
|
// Literals
|
|
@@ -126,7 +127,8 @@ export class Lexer {
|
|
|
126
127
|
this.advance();
|
|
127
128
|
return { type: TokenType.NEWLINE, value: '\n', line, column };
|
|
128
129
|
}
|
|
129
|
-
|
|
130
|
+
// Create a more helpful error for common mistakes
|
|
131
|
+
throw new ParseError(`Unexpected character '${char}'`, line, column, char);
|
|
130
132
|
}
|
|
131
133
|
number() {
|
|
132
134
|
const line = this.line;
|
package/dist/dsl/Parser.js
CHANGED
|
@@ -123,13 +123,15 @@ export class Parser {
|
|
|
123
123
|
* variable = expression
|
|
124
124
|
*/
|
|
125
125
|
assignment() {
|
|
126
|
-
const
|
|
126
|
+
const varToken = this.consume(TokenType.IDENTIFIER, 'Expected variable name');
|
|
127
|
+
const variable = varToken.value;
|
|
127
128
|
this.consume(TokenType.EQUALS, "Expected '=' in assignment");
|
|
128
129
|
const expression = this.expression();
|
|
129
130
|
return {
|
|
130
131
|
kind: 'assignment',
|
|
131
132
|
variable,
|
|
132
|
-
expression
|
|
133
|
+
expression,
|
|
134
|
+
loc: { line: varToken.line, column: varToken.column }
|
|
133
135
|
};
|
|
134
136
|
}
|
|
135
137
|
/**
|
|
@@ -161,13 +163,15 @@ export class Parser {
|
|
|
161
163
|
multiplicative() {
|
|
162
164
|
let expr = this.power();
|
|
163
165
|
while (this.match(TokenType.MULTIPLY, TokenType.DIVIDE)) {
|
|
164
|
-
const
|
|
166
|
+
const opToken = this.previous();
|
|
167
|
+
const operator = opToken.value;
|
|
165
168
|
const right = this.power();
|
|
166
169
|
expr = {
|
|
167
170
|
kind: 'binary',
|
|
168
171
|
operator,
|
|
169
172
|
left: expr,
|
|
170
|
-
right
|
|
173
|
+
right,
|
|
174
|
+
loc: { line: opToken.line, column: opToken.column }
|
|
171
175
|
};
|
|
172
176
|
}
|
|
173
177
|
return expr;
|
|
@@ -213,6 +217,7 @@ export class Parser {
|
|
|
213
217
|
while (true) {
|
|
214
218
|
if (this.match(TokenType.LPAREN)) {
|
|
215
219
|
// Function call
|
|
220
|
+
const startLoc = expr.loc || { line: this.previous().line, column: this.previous().column };
|
|
216
221
|
const args = this.argumentList();
|
|
217
222
|
this.consume(TokenType.RPAREN, "Expected ')' after arguments");
|
|
218
223
|
if (expr.kind !== 'variable') {
|
|
@@ -222,7 +227,8 @@ export class Parser {
|
|
|
222
227
|
expr = {
|
|
223
228
|
kind: 'call',
|
|
224
229
|
name: expr.name,
|
|
225
|
-
args
|
|
230
|
+
args,
|
|
231
|
+
loc: startLoc
|
|
226
232
|
};
|
|
227
233
|
}
|
|
228
234
|
else if (this.match(TokenType.DOT)) {
|
package/dist/dsl/Simplify.js
CHANGED
|
@@ -116,6 +116,53 @@ class Simplifier extends ExpressionTransformer {
|
|
|
116
116
|
if (expressionsEqual(left, right)) {
|
|
117
117
|
return { kind: 'number', value: 1 };
|
|
118
118
|
}
|
|
119
|
+
// (a + a) / 2 → a
|
|
120
|
+
if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
|
|
121
|
+
if (expressionsEqual(left.left, left.right)) {
|
|
122
|
+
return left.left;
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
// (a + a) / (2 * b) → a / b
|
|
126
|
+
if (right.kind === 'binary' && right.operator === '*') {
|
|
127
|
+
const rightLeft = right.left;
|
|
128
|
+
const rightRight = right.right;
|
|
129
|
+
const rightLeftNum = isNumber(rightLeft) ? rightLeft.value : null;
|
|
130
|
+
if (rightLeftNum === 2 && left.kind === 'binary' && left.operator === '+') {
|
|
131
|
+
if (expressionsEqual(left.left, left.right)) {
|
|
132
|
+
return {
|
|
133
|
+
kind: 'binary',
|
|
134
|
+
operator: '/',
|
|
135
|
+
left: left.left,
|
|
136
|
+
right: rightRight
|
|
137
|
+
};
|
|
138
|
+
}
|
|
139
|
+
// (-1 * a + a * -1) / (2 * b) → -a / b
|
|
140
|
+
const leftLeft = left.left;
|
|
141
|
+
const leftRight = left.right;
|
|
142
|
+
if (leftLeft.kind === 'binary' && leftLeft.operator === '*' &&
|
|
143
|
+
leftRight.kind === 'binary' && leftRight.operator === '*') {
|
|
144
|
+
const ll_left = leftLeft.left;
|
|
145
|
+
const ll_right = leftLeft.right;
|
|
146
|
+
const lr_left = leftRight.left;
|
|
147
|
+
const lr_right = leftRight.right;
|
|
148
|
+
const ll_leftNum = isNumber(ll_left) ? ll_left.value : null;
|
|
149
|
+
const lr_rightNum = isNumber(lr_right) ? lr_right.value : null;
|
|
150
|
+
// (-1 * a) + (a * -1)
|
|
151
|
+
if (ll_leftNum === -1 && lr_rightNum === -1 && expressionsEqual(ll_right, lr_left)) {
|
|
152
|
+
return {
|
|
153
|
+
kind: 'unary',
|
|
154
|
+
operator: '-',
|
|
155
|
+
operand: {
|
|
156
|
+
kind: 'binary',
|
|
157
|
+
operator: '/',
|
|
158
|
+
left: ll_right,
|
|
159
|
+
right: rightRight
|
|
160
|
+
}
|
|
161
|
+
};
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
}
|
|
119
166
|
}
|
|
120
167
|
// Power rules
|
|
121
168
|
if (expr.operator === '^' || expr.operator === '**') {
|