gradient-script 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/README.md +3 -1
  2. package/dist/cli.js +219 -6
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +336 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
package/README.md CHANGED
@@ -12,6 +12,8 @@
 
 GradientScript is a source-to-source compiler that automatically generates gradient functions from your mathematical code. Unlike numerical AD frameworks (JAX, PyTorch), it produces clean, human-readable gradient formulas you can inspect, optimize, and integrate directly into your codebase.
 
+It's well suited to LLM workflows, where the LLM can verify existing gradients or construct the gradients you need with less risk of error.
+
 ## Why GradientScript?
 
 - **From real code to gradients**: Write natural math code, get symbolic derivatives
@@ -40,7 +42,7 @@ function distance(u: Vec2, v: Vec2): number {
 }
 ```
 
-Convert it to GradientScript by marking what you need gradients for:
+Convert it to GradientScript (realistically, let your LLM do the conversion, giving it this reference and/or free use of the CLI) by marking what you need gradients for:
 
 ```typescript
 // distance.gs
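
For context, the second hunk only shows the boundaries of the README's TypeScript example; the diff reveals just the signature `distance(u: Vec2, v: Vec2): number` in the hunk header. A plausible reconstruction of that pre-conversion starting point might look like this (the body and the `Vec2` shape are assumptions, not the README's verbatim text):

```typescript
// Hypothetical reconstruction of the README's pre-conversion example.
// Only the signature appears in the hunk header above; the body is assumed.
interface Vec2 { x: number; y: number; }

function distance(u: Vec2, v: Vec2): number {
  const dx = u.x - v.x;
  const dy = u.y - v.y;
  return Math.sqrt(dx * dx + dy * dy);
}
```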
package/dist/cli.js CHANGED
@@ -6,6 +6,202 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
 import { generateComplete } from './dsl/CodeGen.js';
 import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
 import { ParseError, formatParseError } from './dsl/Errors.js';
+// GradientChecker import removed - verification now executes generated code directly
+import { Types } from './dsl/Types.js';
+/**
+ * Generate random test points for gradient verification.
+ * Uses multiple test points to catch errors at different values.
+ */
+function generateTestPoints(func, env) {
+    const testPoints = [];
+    // Generate 3 different test points with varying scales
+    const scales = [1.0, 0.1, 10.0];
+    for (const scale of scales) {
+        const point = new Map();
+        for (const param of func.parameters) {
+            const paramType = env.getOrThrow(param.name);
+            if (Types.isScalar(paramType)) {
+                // Random scalar in range [-scale, scale], avoid zero
+                point.set(param.name, (Math.random() * 2 - 1) * scale + 0.1 * scale);
+            }
+            else {
+                // Structured type - get components
+                const struct = {};
+                for (const comp of paramType.components) {
+                    struct[comp] = (Math.random() * 2 - 1) * scale + 0.1 * scale;
+                }
+                point.set(param.name, struct);
+            }
+        }
+        testPoints.push(point);
+    }
+    return testPoints;
+}
+/**
+ * Verify gradients by executing the GENERATED CODE against numerical differentiation.
+ * This catches code generation bugs (like missing parentheses) that AST-based checking misses.
+ */
+function verifyGeneratedCode(func, generatedCode, env) {
+    const testPoints = generateTestPoints(func, env);
+    const epsilon = 1e-6;
+    const tolerance = 1e-4;
+    // Build parameter list for function calls
+    const paramNames = func.parameters.map(p => p.name);
+    // Create executable functions from generated code
+    // The generated code defines both funcName and funcName_grad
+    let forwardFn;
+    let gradFn;
+    try {
+        // Create a function that returns both the forward and grad functions
+        const factory = new Function(`
+            ${generatedCode}
+            return { forward: ${func.name}, grad: ${func.name}_grad };
+        `);
+        const fns = factory();
+        forwardFn = fns.forward;
+        gradFn = fns.grad;
+    }
+    catch (e) {
+        console.error(`// Gradient verification FAILED for "${func.name}": code execution error`);
+        console.error(`// ${e}`);
+        return false;
+    }
+    let allPassed = true;
+    let maxError = 0;
+    let totalChecks = 0;
+    const errors = [];
+    for (let pointIdx = 0; pointIdx < testPoints.length; pointIdx++) {
+        const testPoint = testPoints[pointIdx];
+        // Build args array (flattening structured types)
+        const args = [];
+        for (const param of func.parameters) {
+            const val = testPoint.get(param.name);
+            if (typeof val === 'number') {
+                args.push(val);
+            }
+            else {
+                // Structured: pass components in order
+                const paramType = env.getOrThrow(param.name);
+                if (!Types.isScalar(paramType)) {
+                    for (const comp of paramType.components) {
+                        args.push(val[comp]);
+                    }
+                }
+            }
+        }
+        // Run forward function at test point
+        let f0;
+        try {
+            f0 = forwardFn(...args);
+        }
+        catch (e) {
+            errors.push(`Test point ${pointIdx + 1}: forward function threw: ${e}`);
+            allPassed = false;
+            continue;
+        }
+        // Run gradient function
+        let gradResult;
+        try {
+            // Pass original (non-flattened) args for grad function
+            const gradArgs = [];
+            for (const param of func.parameters) {
+                const val = testPoint.get(param.name);
+                if (typeof val === 'number') {
+                    gradArgs.push(val);
+                }
+                else {
+                    // For structured types, pass as object with named components
+                    gradArgs.push(val);
+                }
+            }
+            gradResult = gradFn(...gradArgs);
+        }
+        catch (e) {
+            errors.push(`Test point ${pointIdx + 1}: gradient function threw: ${e}`);
+            allPassed = false;
+            continue;
+        }
+        // For each gradient parameter, compare analytical vs numerical
+        for (const param of func.parameters) {
+            if (!param.requiresGrad)
+                continue;
+            const paramType = env.getOrThrow(param.name);
+            const gradKey = `d${param.name}`;
+            if (Types.isScalar(paramType)) {
+                // Scalar gradient
+                totalChecks++;
+                const analytical = gradResult[gradKey];
+                // Compute numerical gradient
+                const paramIdx = func.parameters.findIndex(p => p.name === param.name);
+                let flatIdx = 0;
+                for (let i = 0; i < paramIdx; i++) {
+                    const pt = env.getOrThrow(func.parameters[i].name);
+                    flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
+                }
+                const argsPlus = [...args];
+                const argsMinus = [...args];
+                argsPlus[flatIdx] += epsilon;
+                argsMinus[flatIdx] -= epsilon;
+                const fPlus = forwardFn(...argsPlus);
+                const fMinus = forwardFn(...argsMinus);
+                const numerical = (fPlus - fMinus) / (2 * epsilon);
+                const error = Math.abs(analytical - numerical);
+                const relError = error / (Math.abs(numerical) + 1e-10);
+                maxError = Math.max(maxError, error);
+                if (error > tolerance && relError > tolerance) {
+                    if (!isNaN(analytical) || !isNaN(numerical)) {
+                        errors.push(`${param.name}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
+                        allPassed = false;
+                    }
+                }
+            }
+            else {
+                // Structured gradient
+                const gradStruct = gradResult[gradKey];
+                const paramIdx = func.parameters.findIndex(p => p.name === param.name);
+                let flatIdx = 0;
+                for (let i = 0; i < paramIdx; i++) {
+                    const pt = env.getOrThrow(func.parameters[i].name);
+                    flatIdx += Types.isScalar(pt) ? 1 : pt.components.length;
+                }
+                for (let compIdx = 0; compIdx < paramType.components.length; compIdx++) {
+                    const comp = paramType.components[compIdx];
+                    totalChecks++;
+                    const analytical = gradStruct[comp];
+                    const argsPlus = [...args];
+                    const argsMinus = [...args];
+                    argsPlus[flatIdx + compIdx] += epsilon;
+                    argsMinus[flatIdx + compIdx] -= epsilon;
+                    const fPlus = forwardFn(...argsPlus);
+                    const fMinus = forwardFn(...argsMinus);
+                    const numerical = (fPlus - fMinus) / (2 * epsilon);
+                    const error = Math.abs(analytical - numerical);
+                    const relError = error / (Math.abs(numerical) + 1e-10);
+                    maxError = Math.max(maxError, error);
+                    if (error > tolerance && relError > tolerance) {
+                        if (!isNaN(analytical) || !isNaN(numerical)) {
+                            errors.push(`${param.name}.${comp}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`);
+                            allPassed = false;
+                        }
+                    }
+                }
+            }
+        }
+    }
+    if (allPassed) {
+        console.error(`// ✓ ${func.name}: ${totalChecks} gradients verified (max error: ${maxError.toExponential(2)})`);
+    }
+    else {
+        console.error(`// ✗ ${func.name}: gradient verification FAILED`);
+        for (const err of errors.slice(0, 5)) {
+            console.error(`// ${err}`);
+        }
+        if (errors.length > 5) {
+            console.error(`// ... and ${errors.length - 5} more errors`);
+        }
+    }
+    return allPassed;
+}
 function printUsage() {
     console.log(`
 GradientScript - Symbolic Differentiation for Structured Types
@@ -16,7 +212,7 @@ Usage:
 Options:
   --format <format>   Output format: typescript (default), javascript, python, csharp
   --no-simplify       Disable gradient simplification
-  --no-cse            Disable common subexpression elimination
+  --no-cse            Disable optimization (e-graph CSE)
   --no-comments       Omit comments in generated code
   --guards            Emit runtime guards for division by zero (experimental)
   --epsilon <value>   Epsilon value for guards (default: 1e-10)
@@ -45,6 +241,9 @@ For more information and examples:
 
 README (raw, LLM-friendly):
   https://raw.githubusercontent.com/mfagerlund/gradient-script/main/README.md
+
+LLM Optimization Guide (for AI agents writing .gs files):
+  https://raw.githubusercontent.com/mfagerlund/gradient-script/main/docs/LLM-OPTIMIZATION-GUIDE.md
 `.trim());
 }
 function main() {
@@ -64,6 +263,7 @@ function main() {
        simplify: true,
        cse: true
    };
+   let skipVerify = false;
    for (let i = 1; i < args.length; i++) {
        const arg = args[i];
        if (arg === '--format') {
@@ -138,21 +338,34 @@ function main() {
        process.exit(1);
    }
    const outputs = [];
+   let hasVerificationFailure = false;
    program.functions.forEach((func, index) => {
        const env = inferFunction(func);
        const gradients = computeFunctionGradients(func, env);
-       const guardAnalysis = analyzeGuards(func);
-       if (guardAnalysis.hasIssues) {
-           console.error('Function "' + func.name + '" may have edge cases:');
-           console.error(formatGuardWarnings(guardAnalysis));
-       }
        const perFunctionOptions = { ...options };
        if (index > 0 && perFunctionOptions.includeComments !== false) {
            perFunctionOptions.includeComments = false;
        }
+       // Generate code FIRST
        const code = generateComplete(func, gradients, env, perFunctionOptions);
+       // MANDATORY: Verify the GENERATED CODE against numerical differentiation
+       // This catches code generation bugs that AST-based checking would miss
+       const verified = verifyGeneratedCode(func, code, env);
+       if (!verified) {
+           hasVerificationFailure = true;
+       }
+       const guardAnalysis = analyzeGuards(func);
+       if (guardAnalysis.hasIssues) {
+           // Format warnings as comments so output remains valid code even if stderr is captured
+           console.error('// Function "' + func.name + '" may have edge cases:');
+           console.error(formatGuardWarnings(guardAnalysis, true));
+       }
        outputs.push(code);
    });
+   if (hasVerificationFailure) {
+       console.error('// ERROR: Gradient verification failed. Output may contain incorrect gradients!');
+       process.exit(1);
+   }
    console.log(outputs.join('\n\n'));
 }
 catch (err) {
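
The verification added above is a standard central-difference check: each analytical partial derivative reported by the generated `_grad` function is compared against (f(x + εeᵢ) − f(x − εeᵢ)) / (2ε), failing only when both the absolute and relative errors exceed the tolerance. A minimal standalone sketch of the same technique (illustrative only; `checkGradient` and its signature are not part of the package):

```typescript
// Sketch of the central-difference check used by the new CLI verification
// (illustrative, not gradient-script's API). Compares an analytical gradient
// against numerical differentiation at one test point.
function checkGradient(
  f: (xs: number[]) => number,
  grad: (xs: number[]) => number[],
  xs: number[],
  epsilon = 1e-6,
  tolerance = 1e-4
): boolean {
  const analytical = grad(xs);
  return xs.every((_, i) => {
    const plus = [...xs];
    const minus = [...xs];
    plus[i] += epsilon;
    minus[i] -= epsilon;
    const numerical = (f(plus) - f(minus)) / (2 * epsilon);
    const error = Math.abs(analytical[i] - numerical);
    const relError = error / (Math.abs(numerical) + 1e-10);
    // Pass if either the absolute or the relative error is within tolerance,
    // mirroring the `error > tolerance && relError > tolerance` failure test above.
    return error <= tolerance || relError <= tolerance;
  });
}

// Example: f(x, y) = x * y has gradient (y, x).
checkGradient(([x, y]) => x * y, ([x, y]) => [y, x], [1.5, -2.0]); // => true
```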
package/dist/dsl/CodeGen.d.ts CHANGED
@@ -56,7 +56,7 @@ export declare class ExpressionCodeGen {
  */
 export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
 /**
- * Generate the original forward function
+ * Generate the original forward function (with optional e-graph optimization)
 */
 export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
 /**
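
The "e-graph" referenced here and in the new `dist/dsl/egraph/` module (EGraph, ENode, Pattern, Rewriter, Rules, Extractor) is a standard data structure for equality saturation: hash-consed expression nodes are grouped into equivalence classes via a union-find, rewrite rules merge classes instead of destructively editing the tree, and an extractor then picks the cheapest representative, which subsumes classical CSE. A toy sketch of the core idea follows (illustrative only, not gradient-script's implementation; a real e-graph also restores congruence after unions and adds pattern matching and extraction):

```typescript
// Toy e-graph core (illustrative; all names here are assumptions).
type EClassId = number;

class EGraph {
  private parent: EClassId[] = [];            // union-find parents
  private memo = new Map<string, EClassId>(); // hashcons: e-node key -> e-class

  find(id: EClassId): EClassId {
    // Path-halving union-find lookup.
    while (this.parent[id] !== id) id = this.parent[id] = this.parent[this.parent[id]];
    return id;
  }

  // Add an operator applied to child e-classes; reuse the existing class if an
  // identical e-node is already present (this is what de-duplicates subexpressions).
  add(op: string, children: EClassId[] = []): EClassId {
    const key = op + '(' + children.map(c => this.find(c)).join(',') + ')';
    const hit = this.memo.get(key);
    if (hit !== undefined) return this.find(hit);
    const id = this.parent.length;
    this.parent.push(id);
    this.memo.set(key, id);
    return id;
  }

  // Record that two e-classes denote equal expressions,
  // e.g. after a rewrite rule such as x + 0 -> x fires.
  union(a: EClassId, b: EClassId): EClassId {
    a = this.find(a);
    b = this.find(b);
    if (a !== b) this.parent[b] = a;
    return a;
  }
}

// Usage: (x + 0) and x end up in the same class once the rewrite is applied.
const g = new EGraph();
const x = g.add('x');
const zero = g.add('0');
const xPlus0 = g.add('+', [x, zero]);
g.union(xPlus0, x);
console.log(g.find(xPlus0) === g.find(x)); // true
```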