gradient-script 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/cli.js CHANGED
@@ -6,7 +6,7 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
  import { generateComplete } from './dsl/CodeGen.js';
  import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
  import { ParseError, formatParseError } from './dsl/Errors.js';
- import { GradientChecker, formatGradCheckResult } from './dsl/GradientChecker.js';
+ // GradientChecker import removed - verification now executes generated code directly
  import { Types } from './dsl/Types.js';
  /**
   * Generate random test points for gradient verification.
@@ -38,33 +38,167 @@ function generateTestPoints(func, env) {
      return testPoints;
  }
  /**
-  * Verify gradients for a function using numerical differentiation.
-  * Returns true if all gradients pass, false otherwise.
+  * Verify gradients by executing the GENERATED CODE against numerical differentiation.
+  * This catches code generation bugs (like missing parentheses) that AST-based checking misses.
   */
- function verifyGradients(func, gradients, env) {
-     const checker = new GradientChecker(1e-5, 1e-4);
+ function verifyGeneratedCode(func, generatedCode, env) {
      const testPoints = generateTestPoints(func, env);
+     const epsilon = 1e-6;
+     const tolerance = 1e-4;
+     // Build parameter list for function calls
+     const paramNames = func.parameters.map(p => p.name);
+     // Create executable functions from generated code
+     // The generated code defines both funcName and funcName_grad
+     let forwardFn;
+     let gradFn;
+     try {
+         // Create a function that returns both the forward and grad functions
+         const factory = new Function(`
+             ${generatedCode}
+             return { forward: ${func.name}, grad: ${func.name}_grad };
+         `);
+         const fns = factory();
+         forwardFn = fns.forward;
+         gradFn = fns.grad;
+     }
+     catch (e) {
+         console.error(`// Gradient verification FAILED for "${func.name}": code execution error`);
+         console.error(`// ${e}`);
+         return false;
+     }
      let allPassed = true;
-     for (let i = 0; i < testPoints.length; i++) {
-         const result = checker.check(func, gradients, env, testPoints[i]);
-         if (!result.passed) {
-             if (allPassed) {
-                 // First failure - print header (as comment for valid output)
-                 console.error(`// Gradient verification FAILED for "${func.name}":`);
-             }
-             // Prefix each line with // so output remains valid code
-             const formattedResult = formatGradCheckResult(result, func.name)
-                 .split('\n')
-                 .map(line => '// ' + line)
-                 .join('\n');
-             console.error(`// Test point ${i + 1}: ${formattedResult}`);
+     let maxError = 0;
+     let totalChecks = 0;
+     const errors = [];
+     for (let pointIdx = 0; pointIdx < testPoints.length; pointIdx++) {
+         const testPoint = testPoints[pointIdx];
+         // Build args array (flattening structured types)
+         const args = [];
+         for (const param of func.parameters) {
+             const val = testPoint.get(param.name);
+             if (typeof val === 'number') {
+                 args.push(val);
+             }
+             else {
+                 // Structured: pass components in order
+                 const paramType = env.getOrThrow(param.name);
+                 if (!Types.isScalar(paramType)) {
+                     for (const comp of paramType.components) {
+                         args.push(val[comp]);
+                     }
+                 }
+             }
+         }
+         // Run forward function at test point
+         let f0;
+         try {
+             f0 = forwardFn(...args);
+         }
+         catch (e) {
+             errors.push(`Test point ${pointIdx + 1}: forward function threw: ${e}`);
              allPassed = false;
+             continue;
+         }
+         // Run gradient function
+         let gradResult;
+         try {
+             // Pass original (non-flattened) args for grad function
+             const gradArgs = [];
+             for (const param of func.parameters) {
+                 const val = testPoint.get(param.name);
+                 if (typeof val === 'number') {
+                     gradArgs.push(val);
+                 }
+                 else {
+                     // For structured types, pass as object with named components
+                     gradArgs.push(val);
+                 }
+             }
+             gradResult = gradFn(...gradArgs);
+         }
+         catch (e) {
+             errors.push(`Test point ${pointIdx + 1}: gradient function threw: ${e}`);
+             allPassed = false;
+             continue;
+         }
+         // For each gradient parameter, compare analytical vs numerical
+         for (const param of func.parameters) {
+             if (!param.requiresGrad)
+                 continue;
+             const paramType = env.getOrThrow(param.name);
+             const gradKey = `d${param.name}`;
+             if (Types.isScalar(paramType)) {
+                 // Scalar gradient
+                 totalChecks++;
+                 const analytical = gradResult[gradKey];
+                 // Compute numerical gradient
+                 const paramIdx = func.parameters.findIndex(p => p.name === param.name);
+                 let flatIdx = 0;
+                 for (let i = 0; i < paramIdx; i++) {
+                     const pt = env.getOrThrow(func.parameters[i].name);
+                     flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
+                 }
+                 const argsPlus = [...args];
+                 const argsMinus = [...args];
+                 argsPlus[flatIdx] += epsilon;
+                 argsMinus[flatIdx] -= epsilon;
+                 const fPlus = forwardFn(...argsPlus);
+                 const fMinus = forwardFn(...argsMinus);
+                 const numerical = (fPlus - fMinus) / (2 * epsilon);
+                 const error = Math.abs(analytical - numerical);
+                 const relError = error / (Math.abs(numerical) + 1e-10);
+                 maxError = Math.max(maxError, error);
+                 if (error > tolerance && relError > tolerance) {
+                     if (!isNaN(analytical) || !isNaN(numerical)) {
+                         errors.push(`${param.name}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
+                         allPassed = false;
+                     }
+                 }
+             }
+             else {
+                 // Structured gradient
+                 const gradStruct = gradResult[gradKey];
+                 const paramIdx = func.parameters.findIndex(p => p.name === param.name);
+                 let flatIdx = 0;
+                 for (let i = 0; i < paramIdx; i++) {
+                     const pt = env.getOrThrow(func.parameters[i].name);
+                     flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
+                 }
+                 for (let compIdx = 0; compIdx < paramType.components.length; compIdx++) {
+                     const comp = paramType.components[compIdx];
+                     totalChecks++;
+                     const analytical = gradStruct[comp];
+                     const argsPlus = [...args];
+                     const argsMinus = [...args];
+                     argsPlus[flatIdx + compIdx] += epsilon;
+                     argsMinus[flatIdx + compIdx] -= epsilon;
+                     const fPlus = forwardFn(...argsPlus);
+                     const fMinus = forwardFn(...argsMinus);
+                     const numerical = (fPlus - fMinus) / (2 * epsilon);
+                     const error = Math.abs(analytical - numerical);
+                     const relError = error / (Math.abs(numerical) + 1e-10);
+                     maxError = Math.max(maxError, error);
+                     if (error > tolerance && relError > tolerance) {
+                         if (!isNaN(analytical) || !isNaN(numerical)) {
+                             errors.push(`${param.name}.${comp}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
+                             allPassed = false;
+                         }
+                     }
+                 }
+             }
          }
      }
      if (allPassed) {
-         const result = checker.check(func, gradients, env, testPoints[0]);
-         // Prefix with // so output is valid code
-         console.error('// ' + formatGradCheckResult(result, func.name));
+         console.error(`// ${func.name}: ${totalChecks} gradients verified (max error: ${maxError.toExponential(2)})`);
+     }
+     else {
+         console.error(`// ✗ ${func.name}: gradient verification FAILED`);
+         for (const err of errors.slice(0, 5)) {
+             console.error(`// ${err}`);
+         }
+         if (errors.length > 5) {
+             console.error(`// ... and ${errors.length - 5} more errors`);
+         }
      }
      return allPassed;
  }
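
At its core, the new verifier compiles the emitted source with `new Function`, then compares each analytical partial derivative returned by the generated `_grad` function against a central-difference estimate computed from the generated forward function. Reduced to its essentials, the comparison works roughly as in the sketch below (the helper names are illustrative, not part of the package API):

    // Central-difference estimate of the partial derivative with respect to
    // the flattened argument at position idx, using the generated forward function.
    function numericalPartial(forwardFn, args, idx, epsilon = 1e-6) {
        const plus = [...args];
        const minus = [...args];
        plus[idx] += epsilon;
        minus[idx] -= epsilon;
        return (forwardFn(...plus) - forwardFn(...minus)) / (2 * epsilon);
    }

    // An entry passes when either the absolute or the relative error is within
    // tolerance, mirroring the checks in verifyGeneratedCode above.
    function gradientEntryPasses(analytical, numerical, tolerance = 1e-4) {
        const error = Math.abs(analytical - numerical);
        const relError = error / (Math.abs(numerical) + 1e-10);
        return error <= tolerance || relError <= tolerance;
    }
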
@@ -208,8 +342,15 @@ function main() {
      program.functions.forEach((func, index) => {
          const env = inferFunction(func);
          const gradients = computeFunctionGradients(func, env);
-         // MANDATORY gradient verification
-         const verified = verifyGradients(func, gradients, env);
+         const perFunctionOptions = { ...options };
+         if (index > 0 && perFunctionOptions.includeComments !== false) {
+             perFunctionOptions.includeComments = false;
+         }
+         // Generate code FIRST
+         const code = generateComplete(func, gradients, env, perFunctionOptions);
+         // MANDATORY: Verify the GENERATED CODE against numerical differentiation
+         // This catches code generation bugs that AST-based checking would miss
+         const verified = verifyGeneratedCode(func, code, env);
          if (!verified) {
              hasVerificationFailure = true;
          }
@@ -219,11 +360,6 @@ function main() {
              console.error('// Function "' + func.name + '" may have edge cases:');
              console.error(formatGuardWarnings(guardAnalysis, true));
          }
-         const perFunctionOptions = { ...options };
-         if (index > 0 && perFunctionOptions.includeComments !== false) {
-             perFunctionOptions.includeComments = false;
-         }
-         const code = generateComplete(func, gradients, env, perFunctionOptions);
          outputs.push(code);
      });
      if (hasVerificationFailure) {
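
Taken together, the two hunks above reorder the per-function pipeline in main(): code generation now happens before verification, so the checker exercises exactly the source that will be emitted. Paraphrased (a simplified sketch, not the literal CLI code):

    const env = inferFunction(func);
    const gradients = computeFunctionGradients(func, env);
    const code = generateComplete(func, gradients, env, perFunctionOptions);
    const verified = verifyGeneratedCode(func, code, env); // executes the generated source
    if (!verified) {
        hasVerificationFailure = true; // checked once all functions have been processed
    }
    outputs.push(code);
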
@@ -169,6 +169,10 @@ export class ExpressionCodeGen {
      }
      genUnary(expr) {
          const operand = this.generate(expr.operand);
+         // Parenthesize binary operands to avoid precedence bugs: -(a + b) not -a + b
+         if (expr.operand.kind === 'binary') {
+             return `${expr.operator}(${operand})`;
+         }
          return `${expr.operator}${operand}`;
      }
      genCall(expr) {
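
The genUnary change in ExpressionCodeGen is exactly the class of code-generation bug the new verifier is meant to catch: emitting a unary operator directly in front of a binary expression silently changes precedence. A quick illustration in plain JavaScript (not taken from the package):

    const a = 2, b = 3;
    // Old emission for a DSL expression like -(a + b):
    const withoutParens = -a + b;    // evaluates to 1
    // New emission, with the binary operand parenthesized:
    const withParens = -(a + b);     // evaluates to -5
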
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "gradient-script",
-   "version": "0.3.0",
+   "version": "0.3.1",
    "description": "Symbolic differentiation for structured types with a simple DSL",
    "type": "module",
    "main": "dist/index.js",