gradient-script 0.2.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +219 -6
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +336 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/README.md
CHANGED
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
|
|
13
13
|
GradientScript is a source-to-source compiler that automatically generates gradient functions from your mathematical code. Unlike numerical AD frameworks (JAX, PyTorch), it produces clean, human-readable gradient formulas you can inspect, optimize, and integrate directly into your codebase.
|
|
14
14
|
|
|
15
|
+
It is well suited to LLM-assisted work: an LLM can verify existing gradients, or construct the gradients you need with less risk of making errors.
|
|
16
|
+
|
|
15
17
|
## Why GradientScript?
|
|
16
18
|
|
|
17
19
|
- **From real code to gradients**: Write natural math code, get symbolic derivatives
|
|
@@ -40,7 +42,7 @@ function distance(u: Vec2, v: Vec2): number {
|
|
|
40
42
|
}
|
|
41
43
|
```
|
|
42
44
|
|
|
43
|
-
Convert it to GradientScript by marking what you need gradients for:
|
|
45
|
+
Convert it to GradientScript by marking what you need gradients for (in practice, you can have your LLM do the conversion — point it at this README and/or let it run the CLI freely):
|
|
44
46
|
|
|
45
47
|
```typescript
|
|
46
48
|
// distance.gs
|
package/dist/cli.js
CHANGED
|
@@ -6,6 +6,202 @@ import { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
|
6
6
|
import { generateComplete } from './dsl/CodeGen.js';
|
|
7
7
|
import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
|
|
8
8
|
import { ParseError, formatParseError } from './dsl/Errors.js';
|
|
9
|
+
// GradientChecker import removed - verification now executes generated code directly
|
|
10
|
+
import { Types } from './dsl/Types.js';
|
|
11
|
+
/**
 * Generate random test points for gradient verification.
 *
 * Produces one point per scale in `[1.0, 0.1, 10.0]` (three points per call), so the
 * generated code is exercised at several magnitudes. Every scalar parameter, and every
 * component of a structured parameter, is drawn from
 * `[-scale, -0.1 * scale] ∪ [0.1 * scale, scale)` — i.e. bounded away from zero, so
 * divisions, square roots, etc. are less likely to be evaluated at a singularity.
 *
 * @param func - Function definition; only `func.parameters[].name` is read.
 * @param env - Type environment; `env.getOrThrow(name)` supplies each parameter's type.
 * @param random - Source of uniform values in [0, 1). Defaults to `Math.random`;
 *   pass a seeded generator to make verification runs reproducible.
 * @returns One Map per scale, from parameter name to a number (scalar parameters) or
 *   a plain object keyed by component name (structured parameters).
 */
function generateTestPoints(func, env, random = Math.random) {
    const scales = [1.0, 0.1, 10.0];
    // Map u in [-1, 1) to a signed magnitude in [0.1, 1], then scale it.
    // The previous formula `u * scale + 0.1 * scale` could land arbitrarily close to 0.
    const sample = (scale) => {
        const u = random() * 2 - 1;
        const magnitude = 0.1 + 0.9 * Math.abs(u);
        return (u < 0 ? -magnitude : magnitude) * scale;
    };
    return scales.map((scale) => {
        const point = new Map();
        for (const param of func.parameters) {
            const paramType = env.getOrThrow(param.name);
            if (Types.isScalar(paramType)) {
                point.set(param.name, sample(scale));
            }
            else {
                // Structured type: sample each named component independently.
                const struct = {};
                for (const comp of paramType.components) {
                    struct[comp] = sample(scale);
                }
                point.set(param.name, struct);
            }
        }
        return point;
    });
}
|
|
40
|
+
/**
 * Flatten one test point into the positional argument list for the generated forward
 * function: scalars as-is, structured values expanded into their components in order.
 */
function flattenVerificationArgs(func, env, testPoint) {
    const args = [];
    for (const param of func.parameters) {
        const val = testPoint.get(param.name);
        if (typeof val === 'number') {
            args.push(val);
            continue;
        }
        const paramType = env.getOrThrow(param.name);
        if (!Types.isScalar(paramType)) {
            for (const comp of paramType.components) {
                args.push(val[comp]);
            }
        }
    }
    return args;
}
/**
 * Map each parameter name to the index of its first slot in the flattened argument list.
 */
function computeFlatOffsets(func, env) {
    const offsets = new Map();
    let offset = 0;
    for (const param of func.parameters) {
        if (!offsets.has(param.name)) {
            offsets.set(param.name, offset);
        }
        const paramType = env.getOrThrow(param.name);
        offset += Types.isScalar(paramType) ? 1 : paramType.components.length;
    }
    return offsets;
}
/**
 * Central finite difference of `forwardFn` with respect to flattened argument `index`.
 */
function centralDifference(forwardFn, args, index, epsilon) {
    const argsPlus = [...args];
    const argsMinus = [...args];
    argsPlus[index] += epsilon;
    argsMinus[index] -= epsilon;
    return (forwardFn(...argsPlus) - forwardFn(...argsMinus)) / (2 * epsilon);
}
/**
 * Compare one analytical gradient component against its numerical estimate.
 * Returns `{ error, message }` where `message` is null when the component passes.
 * Passes when absolute OR relative error is within tolerance, when both values are
 * identical (including equal infinities), or when both are NaN. Fails when the
 * analytical value is not a number, or when exactly one side is NaN/non-finite.
 */
function compareGradientComponent(label, analytical, numerical, tolerance) {
    if (typeof analytical !== 'number') {
        return { error: NaN, message: `${label}: missing from gradient result (got ${typeof analytical})` };
    }
    if (Number.isNaN(analytical) && Number.isNaN(numerical)) {
        // Both sides undefined at this point (e.g. a shared singularity): nothing to compare.
        return { error: NaN, message: null };
    }
    if (analytical === numerical) {
        return { error: 0, message: null };
    }
    const error = Math.abs(analytical - numerical);
    const relError = error / (Math.abs(numerical) + 1e-10);
    // NaN comparisons are always false, so non-finite values must be rejected explicitly;
    // otherwise a NaN gradient would silently "pass" the tolerance check.
    const bothFinite = Number.isFinite(analytical) && Number.isFinite(numerical);
    if (bothFinite && (error <= tolerance || relError <= tolerance)) {
        return { error, message: null };
    }
    return {
        error,
        message: `${label}: analytical=${analytical.toExponential(2)}, numerical=${numerical.toExponential(2)}, error=${error.toExponential(2)}`,
    };
}
/**
 * Verify gradients by executing the GENERATED CODE against numerical differentiation.
 * This catches code generation bugs (like missing parentheses) that AST-based checking misses.
 *
 * The generated source is evaluated with `new Function` and must define both
 * `<func.name>` (forward) and `<func.name>_grad`. At each point from `generateTestPoints`,
 * every component of every `requiresGrad` parameter is compared with a central finite
 * difference of the forward function (epsilon 1e-6, tolerance 1e-4 absolute or relative).
 * A gradient that is missing from the result, or NaN/non-finite on only one side, is a
 * failure; a component that is NaN on both sides is skipped.
 *
 * The forward function is called with structured parameters flattened into positional
 * components; the gradient function receives the original values (numbers / objects).
 * NOTE(review): assumes the generated forward signature really takes flattened
 * components — confirm against CodeGen.
 * NOTE(review): `new Function` can only evaluate JavaScript; TypeScript/Python/C#
 * output will be reported as a code execution error — confirm callers pass JS.
 *
 * Results are written to stderr as `//`-prefixed lines (at most 5 errors listed).
 *
 * @param func - Function definition (`name`, `parameters[].name`, `parameters[].requiresGrad`).
 * @param generatedCode - Source text defining the forward and `_grad` functions.
 * @param env - Type environment used to distinguish scalar from structured parameters.
 * @returns `true` when every checked component passed, otherwise `false`.
 */
function verifyGeneratedCode(func, generatedCode, env) {
    const testPoints = generateTestPoints(func, env);
    const epsilon = 1e-6;
    const tolerance = 1e-4;
    // Create executable functions from generated code.
    // The generated code defines both funcName and funcName_grad.
    let forwardFn;
    let gradFn;
    try {
        const factory = new Function(`
${generatedCode}
return { forward: ${func.name}, grad: ${func.name}_grad };
`);
        const fns = factory();
        forwardFn = fns.forward;
        gradFn = fns.grad;
    }
    catch (e) {
        console.error(`// Gradient verification FAILED for "${func.name}": code execution error`);
        console.error(`// ${e}`);
        return false;
    }
    const flatOffsets = computeFlatOffsets(func, env);
    let maxError = 0;
    let totalChecks = 0;
    const errors = [];
    for (const [pointIdx, testPoint] of testPoints.entries()) {
        const pointLabel = `Test point ${pointIdx + 1}`;
        const args = flattenVerificationArgs(func, env, testPoint);
        // Smoke-test the forward function at the base point before differencing around it.
        try {
            forwardFn(...args);
        }
        catch (e) {
            errors.push(`${pointLabel}: forward function threw: ${e}`);
            continue;
        }
        let gradResult;
        try {
            // Gradient function takes the original (non-flattened) parameter values.
            gradResult = gradFn(...func.parameters.map((p) => testPoint.get(p.name)));
        }
        catch (e) {
            errors.push(`${pointLabel}: gradient function threw: ${e}`);
            continue;
        }
        for (const param of func.parameters) {
            if (!param.requiresGrad) {
                continue;
            }
            const paramType = env.getOrThrow(param.name);
            const offset = flatOffsets.get(param.name);
            // Guard against a null/primitive result instead of crashing on property access.
            const gradValue = gradResult == null ? undefined : gradResult[`d${param.name}`];
            const components = Types.isScalar(paramType)
                ? [{ label: param.name, index: offset, analytical: gradValue }]
                : paramType.components.map((comp, compIdx) => ({
                    label: `${param.name}.${comp}`,
                    index: offset + compIdx,
                    analytical: gradValue == null ? undefined : gradValue[comp],
                }));
            for (const { label, index, analytical } of components) {
                totalChecks++;
                let numerical;
                try {
                    numerical = centralDifference(forwardFn, args, index, epsilon);
                }
                catch (e) {
                    errors.push(`${pointLabel}: forward function threw: ${e}`);
                    continue;
                }
                const result = compareGradientComponent(label, analytical, numerical, tolerance);
                if (Number.isFinite(result.error)) {
                    maxError = Math.max(maxError, result.error);
                }
                if (result.message !== null) {
                    errors.push(result.message);
                }
            }
        }
    }
    const allPassed = errors.length === 0;
    if (allPassed) {
        console.error(`// ✓ ${func.name}: ${totalChecks} gradients verified (max error: ${maxError.toExponential(2)})`);
    }
    else {
        console.error(`// ✗ ${func.name}: gradient verification FAILED`);
        for (const err of errors.slice(0, 5)) {
            console.error(`// ${err}`);
        }
        if (errors.length > 5) {
            console.error(`// ... and ${errors.length - 5} more errors`);
        }
    }
    return allPassed;
}
|
|
9
205
|
function printUsage() {
|
|
10
206
|
console.log(`
|
|
11
207
|
GradientScript - Symbolic Differentiation for Structured Types
|
|
@@ -16,7 +212,7 @@ Usage:
|
|
|
16
212
|
Options:
|
|
17
213
|
--format <format> Output format: typescript (default), javascript, python, csharp
|
|
18
214
|
--no-simplify Disable gradient simplification
|
|
19
|
-
--no-cse Disable
|
|
215
|
+
--no-cse Disable optimization (e-graph CSE)
|
|
20
216
|
--no-comments Omit comments in generated code
|
|
21
217
|
--guards Emit runtime guards for division by zero (experimental)
|
|
22
218
|
--epsilon <value> Epsilon value for guards (default: 1e-10)
|
|
@@ -45,6 +241,9 @@ For more information and examples:
|
|
|
45
241
|
|
|
46
242
|
README (raw, LLM-friendly):
|
|
47
243
|
https://raw.githubusercontent.com/mfagerlund/gradient-script/main/README.md
|
|
244
|
+
|
|
245
|
+
LLM Optimization Guide (for AI agents writing .gs files):
|
|
246
|
+
https://raw.githubusercontent.com/mfagerlund/gradient-script/main/docs/LLM-OPTIMIZATION-GUIDE.md
|
|
48
247
|
`.trim());
|
|
49
248
|
}
|
|
50
249
|
function main() {
|
|
@@ -64,6 +263,7 @@ function main() {
|
|
|
64
263
|
simplify: true,
|
|
65
264
|
cse: true
|
|
66
265
|
};
|
|
266
|
+
let skipVerify = false;
|
|
67
267
|
for (let i = 1; i < args.length; i++) {
|
|
68
268
|
const arg = args[i];
|
|
69
269
|
if (arg === '--format') {
|
|
@@ -138,21 +338,34 @@ function main() {
|
|
|
138
338
|
process.exit(1);
|
|
139
339
|
}
|
|
140
340
|
const outputs = [];
|
|
341
|
+
let hasVerificationFailure = false;
|
|
141
342
|
program.functions.forEach((func, index) => {
|
|
142
343
|
const env = inferFunction(func);
|
|
143
344
|
const gradients = computeFunctionGradients(func, env);
|
|
144
|
-
const guardAnalysis = analyzeGuards(func);
|
|
145
|
-
if (guardAnalysis.hasIssues) {
|
|
146
|
-
console.error('Function "' + func.name + '" may have edge cases:');
|
|
147
|
-
console.error(formatGuardWarnings(guardAnalysis));
|
|
148
|
-
}
|
|
149
345
|
const perFunctionOptions = { ...options };
|
|
150
346
|
if (index > 0 && perFunctionOptions.includeComments !== false) {
|
|
151
347
|
perFunctionOptions.includeComments = false;
|
|
152
348
|
}
|
|
349
|
+
// Generate code FIRST
|
|
153
350
|
const code = generateComplete(func, gradients, env, perFunctionOptions);
|
|
351
|
+
// MANDATORY: Verify the GENERATED CODE against numerical differentiation
|
|
352
|
+
// This catches code generation bugs that AST-based checking would miss
|
|
353
|
+
const verified = verifyGeneratedCode(func, code, env);
|
|
354
|
+
if (!verified) {
|
|
355
|
+
hasVerificationFailure = true;
|
|
356
|
+
}
|
|
357
|
+
const guardAnalysis = analyzeGuards(func);
|
|
358
|
+
if (guardAnalysis.hasIssues) {
|
|
359
|
+
// Format warnings as comments so output remains valid code even if stderr is captured
|
|
360
|
+
console.error('// Function "' + func.name + '" may have edge cases:');
|
|
361
|
+
console.error(formatGuardWarnings(guardAnalysis, true));
|
|
362
|
+
}
|
|
154
363
|
outputs.push(code);
|
|
155
364
|
});
|
|
365
|
+
if (hasVerificationFailure) {
|
|
366
|
+
console.error('// ERROR: Gradient verification failed. Output may contain incorrect gradients!');
|
|
367
|
+
process.exit(1);
|
|
368
|
+
}
|
|
156
369
|
console.log(outputs.join('\n\n'));
|
|
157
370
|
}
|
|
158
371
|
catch (err) {
|
package/dist/dsl/CodeGen.d.ts
CHANGED
|
@@ -56,7 +56,7 @@ export declare class ExpressionCodeGen {
|
|
|
56
56
|
*/
|
|
57
57
|
export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
|
|
58
58
|
/**
|
|
59
|
-
* Generate the original forward function
|
|
59
|
+
* Generate the original forward function (with optional e-graph optimization)
|
|
60
60
|
*/
|
|
61
61
|
export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
|
|
62
62
|
/**
|