gradient-script 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cli.js +164 -28
- package/dist/dsl/CodeGen.js +4 -0
- package/package.json +1 -1
package/dist/cli.js
CHANGED
|
@@ -6,7 +6,7 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
|
6
6
|
import { generateComplete } from './dsl/CodeGen.js';
|
|
7
7
|
import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
|
|
8
8
|
import { ParseError, formatParseError } from './dsl/Errors.js';
|
|
9
|
-
import
|
|
9
|
+
// GradientChecker import removed - verification now executes generated code directly
|
|
10
10
|
import { Types } from './dsl/Types.js';
|
|
11
11
|
/**
|
|
12
12
|
* Generate random test points for gradient verification.
|
|
@@ -38,33 +38,167 @@ function generateTestPoints(func, env) {
|
|
|
38
38
|
return testPoints;
|
|
39
39
|
}
|
|
40
40
|
/**
|
|
41
|
-
* Verify gradients
|
|
42
|
-
*
|
|
41
|
+
* Verify gradients by executing the GENERATED CODE against numerical differentiation.
|
|
42
|
+
* This catches code generation bugs (like missing parentheses) that AST-based checking misses.
|
|
43
43
|
*/
|
|
44
|
-
function
|
|
45
|
-
const checker = new GradientChecker(1e-5, 1e-4);
|
|
44
|
+
function verifyGeneratedCode(func, generatedCode, env) {
|
|
46
45
|
const testPoints = generateTestPoints(func, env);
|
|
46
|
+
const epsilon = 1e-6;
|
|
47
|
+
const tolerance = 1e-4;
|
|
48
|
+
// Build parameter list for function calls
|
|
49
|
+
const paramNames = func.parameters.map(p => p.name);
|
|
50
|
+
// Create executable functions from generated code
|
|
51
|
+
// The generated code defines both funcName and funcName_grad
|
|
52
|
+
let forwardFn;
|
|
53
|
+
let gradFn;
|
|
54
|
+
try {
|
|
55
|
+
// Create a function that returns both the forward and grad functions
|
|
56
|
+
const factory = new Function(`
|
|
57
|
+
${generatedCode}
|
|
58
|
+
return { forward: ${func.name}, grad: ${func.name}_grad };
|
|
59
|
+
`);
|
|
60
|
+
const fns = factory();
|
|
61
|
+
forwardFn = fns.forward;
|
|
62
|
+
gradFn = fns.grad;
|
|
63
|
+
}
|
|
64
|
+
catch (e) {
|
|
65
|
+
console.error(`// Gradient verification FAILED for "${func.name}": code execution error`);
|
|
66
|
+
console.error(`// ${e}`);
|
|
67
|
+
return false;
|
|
68
|
+
}
|
|
47
69
|
let allPassed = true;
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
const
|
|
57
|
-
|
|
58
|
-
.
|
|
59
|
-
|
|
60
|
-
|
|
70
|
+
let maxError = 0;
|
|
71
|
+
let totalChecks = 0;
|
|
72
|
+
const errors = [];
|
|
73
|
+
for (let pointIdx = 0; pointIdx < testPoints.length; pointIdx++) {
|
|
74
|
+
const testPoint = testPoints[pointIdx];
|
|
75
|
+
// Build args array (flattening structured types)
|
|
76
|
+
const args = [];
|
|
77
|
+
for (const param of func.parameters) {
|
|
78
|
+
const val = testPoint.get(param.name);
|
|
79
|
+
if (typeof val === 'number') {
|
|
80
|
+
args.push(val);
|
|
81
|
+
}
|
|
82
|
+
else {
|
|
83
|
+
// Structured: pass components in order
|
|
84
|
+
const paramType = env.getOrThrow(param.name);
|
|
85
|
+
if (!Types.isScalar(paramType)) {
|
|
86
|
+
for (const comp of paramType.components) {
|
|
87
|
+
args.push(val[comp]);
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
// Run forward function at test point
|
|
93
|
+
let f0;
|
|
94
|
+
try {
|
|
95
|
+
f0 = forwardFn(...args);
|
|
96
|
+
}
|
|
97
|
+
catch (e) {
|
|
98
|
+
errors.push(`Test point ${pointIdx + 1}: forward function threw: ${e}`);
|
|
61
99
|
allPassed = false;
|
|
100
|
+
continue;
|
|
101
|
+
}
|
|
102
|
+
// Run gradient function
|
|
103
|
+
let gradResult;
|
|
104
|
+
try {
|
|
105
|
+
// Pass original (non-flattened) args for grad function
|
|
106
|
+
const gradArgs = [];
|
|
107
|
+
for (const param of func.parameters) {
|
|
108
|
+
const val = testPoint.get(param.name);
|
|
109
|
+
if (typeof val === 'number') {
|
|
110
|
+
gradArgs.push(val);
|
|
111
|
+
}
|
|
112
|
+
else {
|
|
113
|
+
// For structured types, pass as object with named components
|
|
114
|
+
gradArgs.push(val);
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
gradResult = gradFn(...gradArgs);
|
|
118
|
+
}
|
|
119
|
+
catch (e) {
|
|
120
|
+
errors.push(`Test point ${pointIdx + 1}: gradient function threw: ${e}`);
|
|
121
|
+
allPassed = false;
|
|
122
|
+
continue;
|
|
123
|
+
}
|
|
124
|
+
// For each gradient parameter, compare analytical vs numerical
|
|
125
|
+
for (const param of func.parameters) {
|
|
126
|
+
if (!param.requiresGrad)
|
|
127
|
+
continue;
|
|
128
|
+
const paramType = env.getOrThrow(param.name);
|
|
129
|
+
const gradKey = `d${param.name}`;
|
|
130
|
+
if (Types.isScalar(paramType)) {
|
|
131
|
+
// Scalar gradient
|
|
132
|
+
totalChecks++;
|
|
133
|
+
const analytical = gradResult[gradKey];
|
|
134
|
+
// Compute numerical gradient
|
|
135
|
+
const paramIdx = func.parameters.findIndex(p => p.name === param.name);
|
|
136
|
+
let flatIdx = 0;
|
|
137
|
+
for (let i = 0; i < paramIdx; i++) {
|
|
138
|
+
const pt = env.getOrThrow(func.parameters[i].name);
|
|
139
|
+
flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
|
|
140
|
+
}
|
|
141
|
+
const argsPlus = [...args];
|
|
142
|
+
const argsMinus = [...args];
|
|
143
|
+
argsPlus[flatIdx] += epsilon;
|
|
144
|
+
argsMinus[flatIdx] -= epsilon;
|
|
145
|
+
const fPlus = forwardFn(...argsPlus);
|
|
146
|
+
const fMinus = forwardFn(...argsMinus);
|
|
147
|
+
const numerical = (fPlus - fMinus) / (2 * epsilon);
|
|
148
|
+
const error = Math.abs(analytical - numerical);
|
|
149
|
+
const relError = error / (Math.abs(numerical) + 1e-10);
|
|
150
|
+
maxError = Math.max(maxError, error);
|
|
151
|
+
if (error > tolerance && relError > tolerance) {
|
|
152
|
+
if (!isNaN(analytical) || !isNaN(numerical)) {
|
|
153
|
+
errors.push(`${param.name}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
|
|
154
|
+
allPassed = false;
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
}
|
|
158
|
+
else {
|
|
159
|
+
// Structured gradient
|
|
160
|
+
const gradStruct = gradResult[gradKey];
|
|
161
|
+
const paramIdx = func.parameters.findIndex(p => p.name === param.name);
|
|
162
|
+
let flatIdx = 0;
|
|
163
|
+
for (let i = 0; i < paramIdx; i++) {
|
|
164
|
+
const pt = env.getOrThrow(func.parameters[i].name);
|
|
165
|
+
flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
|
|
166
|
+
}
|
|
167
|
+
for (let compIdx = 0; compIdx < paramType.components.length; compIdx++) {
|
|
168
|
+
const comp = paramType.components[compIdx];
|
|
169
|
+
totalChecks++;
|
|
170
|
+
const analytical = gradStruct[comp];
|
|
171
|
+
const argsPlus = [...args];
|
|
172
|
+
const argsMinus = [...args];
|
|
173
|
+
argsPlus[flatIdx + compIdx] += epsilon;
|
|
174
|
+
argsMinus[flatIdx + compIdx] -= epsilon;
|
|
175
|
+
const fPlus = forwardFn(...argsPlus);
|
|
176
|
+
const fMinus = forwardFn(...argsMinus);
|
|
177
|
+
const numerical = (fPlus - fMinus) / (2 * epsilon);
|
|
178
|
+
const error = Math.abs(analytical - numerical);
|
|
179
|
+
const relError = error / (Math.abs(numerical) + 1e-10);
|
|
180
|
+
maxError = Math.max(maxError, error);
|
|
181
|
+
if (error > tolerance && relError > tolerance) {
|
|
182
|
+
if (!isNaN(analytical) || !isNaN(numerical)) {
|
|
183
|
+
errors.push(`${param.name}.${comp}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
|
|
184
|
+
allPassed = false;
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
}
|
|
62
189
|
}
|
|
63
190
|
}
|
|
64
191
|
if (allPassed) {
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
192
|
+
console.error(`// ✓ ${func.name}: ${totalChecks} gradients verified (max error: ${maxError.toExponential(2)})`);
|
|
193
|
+
}
|
|
194
|
+
else {
|
|
195
|
+
console.error(`// ✗ ${func.name}: gradient verification FAILED`);
|
|
196
|
+
for (const err of errors.slice(0, 5)) {
|
|
197
|
+
console.error(`// ${err}`);
|
|
198
|
+
}
|
|
199
|
+
if (errors.length > 5) {
|
|
200
|
+
console.error(`// ... and ${errors.length - 5} more errors`);
|
|
201
|
+
}
|
|
68
202
|
}
|
|
69
203
|
return allPassed;
|
|
70
204
|
}
|
|
@@ -208,8 +342,15 @@ function main() {
|
|
|
208
342
|
program.functions.forEach((func, index) => {
|
|
209
343
|
const env = inferFunction(func);
|
|
210
344
|
const gradients = computeFunctionGradients(func, env);
|
|
211
|
-
|
|
212
|
-
|
|
345
|
+
const perFunctionOptions = { ...options };
|
|
346
|
+
if (index > 0 && perFunctionOptions.includeComments !== false) {
|
|
347
|
+
perFunctionOptions.includeComments = false;
|
|
348
|
+
}
|
|
349
|
+
// Generate code FIRST
|
|
350
|
+
const code = generateComplete(func, gradients, env, perFunctionOptions);
|
|
351
|
+
// MANDATORY: Verify the GENERATED CODE against numerical differentiation
|
|
352
|
+
// This catches code generation bugs that AST-based checking would miss
|
|
353
|
+
const verified = verifyGeneratedCode(func, code, env);
|
|
213
354
|
if (!verified) {
|
|
214
355
|
hasVerificationFailure = true;
|
|
215
356
|
}
|
|
@@ -219,11 +360,6 @@ function main() {
|
|
|
219
360
|
console.error('// Function "' + func.name + '" may have edge cases:');
|
|
220
361
|
console.error(formatGuardWarnings(guardAnalysis, true));
|
|
221
362
|
}
|
|
222
|
-
const perFunctionOptions = { ...options };
|
|
223
|
-
if (index > 0 && perFunctionOptions.includeComments !== false) {
|
|
224
|
-
perFunctionOptions.includeComments = false;
|
|
225
|
-
}
|
|
226
|
-
const code = generateComplete(func, gradients, env, perFunctionOptions);
|
|
227
363
|
outputs.push(code);
|
|
228
364
|
});
|
|
229
365
|
if (hasVerificationFailure) {
|
package/dist/dsl/CodeGen.js
CHANGED
|
@@ -169,6 +169,10 @@ export class ExpressionCodeGen {
|
|
|
169
169
|
}
|
|
170
170
|
genUnary(expr) {
|
|
171
171
|
const operand = this.generate(expr.operand);
|
|
172
|
+
// Parenthesize binary operands to avoid precedence bugs: -(a + b) not -a + b
|
|
173
|
+
if (expr.operand.kind === 'binary') {
|
|
174
|
+
return `${expr.operator}(${operand})`;
|
|
175
|
+
}
|
|
172
176
|
return `${expr.operator}${operand}`;
|
|
173
177
|
}
|
|
174
178
|
genCall(expr) {
|