scalar-autograd 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +127 -2
- package/dist/CompiledFunctions.d.ts +111 -0
- package/dist/CompiledFunctions.js +268 -0
- package/dist/CompiledResiduals.d.ts +74 -0
- package/dist/CompiledResiduals.js +94 -0
- package/dist/EigenvalueHelpers.d.ts +14 -0
- package/dist/EigenvalueHelpers.js +93 -0
- package/dist/Geometry.d.ts +131 -0
- package/dist/Geometry.js +213 -0
- package/dist/GraphBuilder.d.ts +64 -0
- package/dist/GraphBuilder.js +237 -0
- package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
- package/dist/GraphCanonicalizerNoSort.js +190 -0
- package/dist/GraphHashCanonicalizer.d.ts +46 -0
- package/dist/GraphHashCanonicalizer.js +220 -0
- package/dist/GraphSignature.d.ts +7 -0
- package/dist/GraphSignature.js +7 -0
- package/dist/KernelPool.d.ts +55 -0
- package/dist/KernelPool.js +124 -0
- package/dist/LBFGS.d.ts +84 -0
- package/dist/LBFGS.js +313 -0
- package/dist/LinearSolver.d.ts +69 -0
- package/dist/LinearSolver.js +213 -0
- package/dist/Losses.d.ts +9 -0
- package/dist/Losses.js +42 -37
- package/dist/Matrix3x3.d.ts +50 -0
- package/dist/Matrix3x3.js +146 -0
- package/dist/NonlinearLeastSquares.d.ts +33 -0
- package/dist/NonlinearLeastSquares.js +252 -0
- package/dist/Optimizers.d.ts +70 -14
- package/dist/Optimizers.js +42 -19
- package/dist/V.d.ts +0 -0
- package/dist/V.js +0 -0
- package/dist/Value.d.ts +84 -2
- package/dist/Value.js +296 -58
- package/dist/ValueActivation.js +10 -14
- package/dist/ValueArithmetic.d.ts +1 -0
- package/dist/ValueArithmetic.js +58 -50
- package/dist/ValueComparison.js +9 -13
- package/dist/ValueRegistry.d.ts +38 -0
- package/dist/ValueRegistry.js +88 -0
- package/dist/ValueTrig.js +14 -18
- package/dist/Vec2.d.ts +45 -0
- package/dist/Vec2.js +93 -0
- package/dist/Vec3.d.ts +78 -0
- package/dist/Vec3.js +169 -0
- package/dist/Vec4.d.ts +45 -0
- package/dist/Vec4.js +126 -0
- package/dist/__tests__/duplicate-inputs.test.js +33 -0
- package/dist/cli/gradient-gen.d.ts +19 -0
- package/dist/cli/gradient-gen.js +264 -0
- package/dist/compileIndirectKernel.d.ts +24 -0
- package/dist/compileIndirectKernel.js +148 -0
- package/dist/index.d.ts +20 -0
- package/dist/index.js +20 -0
- package/dist/scalar-autograd.d.ts +1157 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/dist/tsdoc-metadata.json +11 -0
- package/package.json +29 -5
- package/dist/Losses.spec.js +0 -54
- package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
- package/dist/Optimizers.edge-cases.spec.js +0 -29
- package/dist/Optimizers.spec.d.ts +0 -1
- package/dist/Optimizers.spec.js +0 -56
- package/dist/Value.edge-cases.spec.d.ts +0 -1
- package/dist/Value.edge-cases.spec.js +0 -54
- package/dist/Value.grad-flow.spec.d.ts +0 -1
- package/dist/Value.grad-flow.spec.js +0 -24
- package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
- package/dist/Value.losses-edge-cases.spec.js +0 -30
- package/dist/Value.memory.spec.d.ts +0 -1
- package/dist/Value.memory.spec.js +0 -23
- package/dist/Value.nn.spec.d.ts +0 -1
- package/dist/Value.nn.spec.js +0 -111
- package/dist/Value.spec.d.ts +0 -1
- package/dist/Value.spec.js +0 -245
- /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
/**
|
|
3
|
+
* CLI tool for symbolic gradient generation.
|
|
4
|
+
* Accepts mathematical expressions and outputs gradient computation code.
|
|
5
|
+
*/
|
|
6
|
+
import * as fs from 'fs';
|
|
7
|
+
import { parse } from '../symbolic/Parser';
|
|
8
|
+
import { computeGradients } from '../symbolic/SymbolicDiff';
|
|
9
|
+
import { simplify } from '../symbolic/Simplify';
|
|
10
|
+
import { generateGradientCode, generateGradientFunction } from '../symbolic/CodeGen';
|
|
11
|
+
const VERSION = '1.0.0';
|
|
12
|
+
/**
 * Prints the CLI usage text (banner with the tool version, options,
 * examples, input format and output description) to stdout.
 */
function printHelp() {
    // The text starts and ends with an empty line, hence the empty first and
    // last entries.
    const helpLines = [
        '',
        `ScalarAutograd Symbolic Gradient Generator v${VERSION}`,
        '',
        'Generate symbolic gradient formulas from mathematical expressions.',
        '',
        'USAGE:',
        'npx scalar-grad [options] [expression]',
        '',
        'OPTIONS:',
        '-i, --input <file> Input file containing expressions',
        '-o, --output <file> Output file (default: stdout)',
        '--wrt <params> Comma-separated list of parameters to differentiate',
        '(default: auto-detect from expressions)',
        '--format <type> Output format: js, ts, function, inline (default: js)',
        '--no-simplify Disable expression simplification',
        '--function <name> Function name when using --format=function (default: computeGradient)',
        '-h, --help Show this help message',
        '-v, --version Show version',
        '',
        'EXAMPLES:',
        '# Simple expression',
        'npx scalar-grad "x = 2; y = 3; output = x*x + y*y" --wrt x,y',
        '',
        '# From file',
        'npx scalar-grad --input forward.txt --wrt a,b,c --output gradients.js',
        '',
        '# Generate function',
        'npx scalar-grad "z = x*x + y*y; output = sqrt(z)" --format function --function distanceGradient',
        '',
        '# Inline expression (no file)',
        'echo "output = sin(x) * cos(y)" | npx scalar-grad --wrt x,y',
        '',
        'INPUT FORMAT:',
        'Expressions use operator overloading syntax:',
        'c = a + b // Addition',
        'd = c * 2 // Multiplication',
        'e = sin(d) // Function call',
        'output = e // Mark output variable',
        '',
        'Supported operators: +, -, *, /, ** (power)',
        'Supported functions: sin, cos, tan, exp, log, sqrt, abs, asin, acos, atan, etc.',
        '',
        'Vec2/Vec3 support:',
        'v = Vec2(x, y) // Vector constructor',
        'mag = v.magnitude // Vector property',
        'dot = u.dot(v) // Vector method',
        '',
        'OUTPUT:',
        'Generated code includes:',
        '- Forward pass computation',
        '- Gradient formulas with mathematical notation in comments',
        '- Executable JavaScript/TypeScript code',
        '',
    ];
    console.log(helpLines.join('\n'));
}
|
|
67
|
+
/**
 * Parses the command-line arguments of the gradient generator.
 *
 * Value-taking long options may be written either as two arguments
 * (`--format function`) or in the `--name=value` form used by the help text
 * (`--format=function`). Any argument that does not start with `-` is stored
 * as `input`; the last one wins.
 *
 * @param {string[]} args - Arguments after the executable and script name.
 * @returns {{format: string, simplify: boolean, help: boolean, version: boolean,
 *            input?: string, output?: string, wrt?: string[], functionName?: string}}
 *          The parsed options, with defaults applied.
 * @throws {Error} On an unknown option, an invalid `--format` value, or a
 *         value-taking option that is not followed by a value.
 */
function parseArgs(args) {
    const validFormats = ['js', 'ts', 'function', 'inline'];
    const longValueOptions = new Set(['--input', '--output', '--wrt', '--format', '--function']);
    const options = {
        format: 'js',
        simplify: true,
        help: false,
        version: false
    };

    // Expand `--name=value` into `--name`, `value` for the options that take a
    // value. Anything else (including expressions such as "x = 2") is left alone.
    const argv = [];
    for (const arg of args) {
        const eq = arg.indexOf('=');
        if (arg.startsWith('--') && eq !== -1 && longValueOptions.has(arg.slice(0, eq))) {
            argv.push(arg.slice(0, eq), arg.slice(eq + 1));
        }
        else {
            argv.push(arg);
        }
    }

    let i = 0;
    // Consumes and returns the argument following the current option. Without
    // this check a trailing `--wrt` crashed with a TypeError and a trailing
    // `--input`/`--output`/`--function` silently stored `undefined`.
    const takeValue = (flag) => {
        i++;
        if (i >= argv.length) {
            throw new Error(`Option ${flag} requires a value`);
        }
        return argv[i];
    };

    while (i < argv.length) {
        const arg = argv[i];
        if (arg === '-h' || arg === '--help') {
            options.help = true;
        }
        else if (arg === '-v' || arg === '--version') {
            options.version = true;
        }
        else if (arg === '-i' || arg === '--input') {
            options.input = takeValue(arg);
        }
        else if (arg === '-o' || arg === '--output') {
            options.output = takeValue(arg);
        }
        else if (arg === '--wrt') {
            // Drop empty names produced by stray commas ("x,,y" or "x,").
            options.wrt = takeValue(arg)
                .split(',')
                .map(s => s.trim())
                .filter(s => s.length > 0);
        }
        else if (arg === '--format') {
            const format = takeValue(arg);
            if (!validFormats.includes(format)) {
                throw new Error(`Invalid format: ${format}. Must be js, ts, function, or inline`);
            }
            options.format = format;
        }
        else if (arg === '--no-simplify') {
            options.simplify = false;
        }
        else if (arg === '--function') {
            options.functionName = takeValue(arg);
        }
        else if (arg.startsWith('-')) {
            throw new Error(`Unknown option: ${arg}`);
        }
        else {
            // Positional argument - treat as expression (or a file path; the
            // caller decides which).
            options.input = arg;
        }
        i++;
    }
    return options;
}
|
|
124
|
+
/**
 * Resolves the expression text to process.
 *
 * When `options.input` is set and names an existing regular file, the file's
 * contents are returned; otherwise `options.input` itself is returned as the
 * expression. When no input was given, everything on stdin is read.
 *
 * @param {{input?: string}} options - Parsed CLI options.
 * @returns {string} The expression source text.
 */
function readInput(options) {
    const input = options.input;
    if (input) {
        // Only regular files are read. A plain existence check also matches
        // directories, which made readFileSync fail with EISDIR instead of
        // falling back to treating the argument as an expression.
        let isFile = false;
        try {
            isFile = fs.statSync(input).isFile();
        }
        catch (error) {
            // Missing or otherwise un-stat-able path: not a file.
            isFile = false;
        }
        return isFile ? fs.readFileSync(input, 'utf-8') : input;
    }
    // Read from stdin (file descriptor 0)
    return fs.readFileSync(0, 'utf-8');
}
|
|
139
|
+
/**
 * Emits the generated code.
 *
 * @param {string} content - Text to emit.
 * @param {{output?: string}} options - Parsed CLI options; when `output` is
 *        set the text is written to that file (UTF-8) and a confirmation is
 *        logged to stderr, otherwise the text is printed to stdout.
 */
function writeOutput(content, options) {
    if (!options.output) {
        console.log(content);
        return;
    }
    fs.writeFileSync(options.output, content, 'utf-8');
    console.error(`Output written to ${options.output}`);
}
|
|
148
|
+
/**
 * Heuristically detects the free variables of an expression program: the
 * identifiers that are used but never appear on the left of an `=`.
 *
 * The scan is purely textual. It ignores:
 *  - text after `//` on a line (the help text shows such comments in the
 *    input format),
 *  - the built-in names listed below,
 *  - names reached through member access (`v.magnitude`, `u.dot(v)`),
 *  - names immediately followed by `(`, i.e. function or constructor calls
 *    such as `asin(x)` that are not in the built-in list.
 *
 * @param {string} input - Expression source text.
 * @returns {string[]} Detected parameter names, de-duplicated and sorted.
 */
function extractParameters(input) {
    // Known functions and keywords that are never parameters.
    const reserved = new Set(['Vec2', 'Vec3', 'sin', 'cos', 'tan', 'exp', 'log', 'sqrt', 'abs',
        'min', 'max', 'output', 'Math', 'const', 'let', 'var']);
    // Strip line comments so that their words are not mistaken for variables.
    const source = input.replace(/\/\/.*$/gm, '');

    // Match assignments: var = expr
    const assignedVars = new Set();
    const assignmentRegex = /(\w+)\s*=/g;
    let match;
    while ((match = assignmentRegex.exec(source)) !== null) {
        assignedVars.add(match[1]);
    }

    // Match all identifiers, also capturing a preceding `.` (member access)
    // and a following `(` (call) so those occurrences can be skipped.
    const params = new Set();
    const identifierRegex = /(\.\s*)?\b([a-zA-Z_]\w*)\b([ \t]*\()?/g;
    while ((match = identifierRegex.exec(source)) !== null) {
        const memberDot = match[1];
        const ident = match[2];
        const callParen = match[3];
        if (memberDot !== undefined || callParen !== undefined) {
            continue;
        }
        // Parameters are variables used but never assigned.
        if (!reserved.has(ident) && !assignedVars.has(ident)) {
            params.add(ident);
        }
    }
    return [...params].sort();
}
|
|
179
|
+
/**
 * CLI entry point: parses `process.argv`, reads the expression source,
 * computes (and optionally simplifies) the symbolic gradients, generates code
 * in the requested format and writes it out.
 *
 * Progress messages go to stderr. The process exits with status 0 after
 * printing help or the version, and with status 1 on missing input, an empty
 * parameter list, or any thrown error (the stack is printed when the `DEBUG`
 * environment variable is set).
 */
function main() {
    try {
        const cliArgs = process.argv.slice(2);
        if (cliArgs.length === 0) {
            printHelp();
            process.exit(0);
        }

        const options = parseArgs(cliArgs);
        if (options.help) {
            printHelp();
            process.exit(0);
        }
        if (options.version) {
            console.log(`ScalarAutograd Symbolic Gradient Generator v${VERSION}`);
            process.exit(0);
        }

        // Read input
        const input = readInput(options);
        if (input.trim().length === 0) {
            console.error('Error: No input provided');
            process.exit(1);
        }

        // Parse expression
        console.error('Parsing expressions...');
        const program = parse(input);

        // Use the explicit --wrt list when given, otherwise auto-detect
        const autoDetect = !options.wrt;
        const parameters = autoDetect ? extractParameters(input) : options.wrt;
        if (autoDetect) {
            console.error(`Auto-detected parameters: ${parameters.join(', ')}`);
        }
        if (parameters.length === 0) {
            console.error('Error: No parameters to differentiate with respect to');
            console.error('Hint: Use --wrt to specify parameters');
            process.exit(1);
        }

        // Compute gradients
        console.error('Computing symbolic gradients...');
        let gradients = computeGradients(program, parameters);

        // Simplify if enabled
        if (options.simplify) {
            console.error('Simplifying expressions...');
            gradients = new Map(Array.from(gradients.entries(), ([param, gradExpr]) => [param, simplify(gradExpr)]));
        }

        // Generate code
        console.error('Generating code...');
        const generated = options.format === 'function'
            ? generateGradientFunction(program, gradients, options.functionName || 'computeGradient', parameters, {
                includeMath: true
            })
            : generateGradientCode(program, gradients, {
                includeMath: true,
                varStyle: 'const',
                includeForward: true
            });

        // The 'ts' format currently only adds a marker line; no type
        // annotations are generated.
        const output = options.format === 'ts'
            ? `// TypeScript version\n${generated}`
            : generated;

        // Write output
        writeOutput(output, options);
        console.error('✓ Success!');
    }
    catch (error) {
        console.error('Error:', error.message);
        if (process.env.DEBUG) {
            console.error(error.stack);
        }
        process.exit(1);
    }
}
|
|
260
|
+
// Run if called directly.
// `require` and `module` only exist when this file is evaluated as CommonJS.
// The file itself uses import/export syntax, so when it is loaded as a native
// ES module a bare `require.main` would throw a ReferenceError at import time;
// the typeof guards make the check a no-op there instead.
if (typeof require !== 'undefined' && typeof module !== 'undefined' && require.main === module) {
    main();
}
|
|
264
|
+
export { main, parseArgs, extractParameters };
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import { Value } from './Value';
|
|
2
|
+
import { ValueRegistry } from './ValueRegistry';
|
|
3
|
+
/**
|
|
4
|
+
* Compiles a residual function with indirect value indexing for kernel reuse.
|
|
5
|
+
*
|
|
6
|
+
* Instead of hardcoding parameter positions, this generates code that:
|
|
7
|
+
* 1. Accepts a global value array
|
|
8
|
+
* 2. Accepts an indices array specifying which values this graph uses
|
|
9
|
+
* 3. Looks up values via indices: allValues[indices[i]]
|
|
10
|
+
*
|
|
11
|
+
* This allows topologically identical graphs to share the same kernel.
|
|
12
|
+
*
|
|
13
|
+
* @param residual - Output Value of the residual computation
|
|
14
|
+
* @param params - Parameter Values (for Jacobian computation)
|
|
15
|
+
* @param registry - ValueRegistry tracking all unique values
|
|
16
|
+
* @returns Compiled function: (allValues, indices, gradientIndices, gradient) => number
|
|
17
|
+
* @internal
|
|
18
|
+
*/
|
|
19
|
+
export declare function compileIndirectKernel(residual: Value, params: Value[], registry: ValueRegistry): (allValues: number[], indices: number[], gradientIndices: number[], gradient: number[]) => number;
|
|
20
|
+
/**
|
|
21
|
+
* Extract input indices for a residual graph.
|
|
22
|
+
* Returns array of registry IDs for all leaf nodes in topological order.
|
|
23
|
+
*/
|
|
24
|
+
export declare function extractInputIndices(residual: Value, registry: ValueRegistry): number[];
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
/**
 * Compiles a residual function with indirect value indexing for kernel reuse.
 *
 * Instead of hardcoding parameter positions, this generates code that:
 * 1. Accepts a global value array
 * 2. Accepts an indices array specifying which values this graph uses
 * 3. Looks up values via indices: allValues[indices[i]]
 *
 * This allows topologically identical graphs to share the same kernel.
 *
 * Side effect: every leaf node (a node whose `prev` is empty) reachable from
 * `residual` is passed to `registry.register`.
 *
 * The returned function computes the residual value and, for each leaf `i`
 * (in traversal order) with `requiresGrad` set, adds that leaf's gradient to
 * `gradient[gradientIndices[i]]`. It accumulates with `+=`; it does not reset
 * `gradient`.
 *
 * @param residual - Output Value of the residual computation
 * @param params - Parameter Values (for Jacobian computation). Not read by
 *                 this function; kept for interface compatibility.
 * @param registry - ValueRegistry tracking all unique values
 * @returns Compiled function: (allValues, indices, gradientIndices, gradient) => number
 * @internal
 */
export function compileIndirectKernel(residual, params, registry) {
    const visited = new Set();
    const topoOrder = [];
    const graphInputs = []; // Leaf nodes, in the order they enter topoOrder
    const forwardCode = [];
    const backwardCode = [];
    let varCounter = 0;
    const nodeToVar = new Map();
    const nodeToIndexVar = new Map(); // Maps Value -> index variable name

    // Build topo order (children before parents) and collect/register the
    // leaf nodes. `visited` guarantees each node is pushed exactly once, so
    // no separate duplicate check is needed when collecting graph inputs.
    function buildTopoOrder(node) {
        if (visited.has(node))
            return;
        visited.add(node);
        for (const child of node.prev) {
            buildTopoOrder(child);
        }
        topoOrder.push(node);
        // Only leaf nodes (graph inputs: constants and parameters) are
        // registered. Intermediates (prev.length > 0) are computed within the
        // kernel, not looked up in the registry.
        if (node.prev.length === 0) {
            registry.register(node);
            graphInputs.push(node);
        }
    }
    buildTopoOrder(residual);

    // Assign index variables to inputs
    graphInputs.forEach((input, i) => {
        nodeToIndexVar.set(input, `idx_${i}`);
    });

    // Returns the generated local variable name for a node, allocating one on
    // first use.
    function getVarName(node) {
        if (!nodeToVar.has(node)) {
            nodeToVar.set(node, `_v${varCounter++}`);
        }
        return nodeToVar.get(node);
    }

    // Generate forward code for each node in topo order
    for (const node of topoOrder) {
        const prev = node.prev;
        if (prev.length === 0) {
            // Leaf node - load from allValues via index
            const indexVar = nodeToIndexVar.get(node);
            forwardCode.push(`const ${getVarName(node)} = allValues[${indexVar}];`);
        }
        else {
            // Computed node
            const childCodes = prev.map(c => getVarName(c));
            const code = node.getForwardCode(childCodes);
            forwardCode.push(`const ${getVarName(node)} = ${code};`);
        }
    }

    // One gradient accumulator per named node
    const gradDeclarations = Array.from(nodeToVar.values())
        .map((varName) => `let grad_${varName} = 0;`);

    // Generate backward pass - iterate REVERSE topological order (same as Value.backward)
    for (let i = topoOrder.length - 1; i >= 0; i--) {
        const node = topoOrder[i];
        const prev = node.prev;
        if (prev.length === 0 || !node.requiresGrad)
            continue;
        const gradVar = `grad_${getVarName(node)}`;
        const childGrads = prev.map((c) => `grad_${getVarName(c)}`);
        const childVars = prev.map((c) => getVarName(c));
        backwardCode.push(node.getBackwardCode(gradVar, childGrads, childVars));
    }
    const outputVar = getVarName(residual);

    // Extract indices for graph inputs
    const indexExtractions = graphInputs
        .map((input, i) => `const ${nodeToIndexVar.get(input)} = indices[${i}];`)
        .join('\n ');

    // Build gradient updates - only for inputs that require gradients.
    // Constants and any values with requiresGrad=false are skipped, but they
    // still occupy their position in `indices` / `gradientIndices`.
    const gradientUpdates = graphInputs
        .map((input, inputIdx) => {
            if (!input.requiresGrad) {
                return null; // Skip - this input doesn't need gradients
            }
            const gradVar = `grad_${getVarName(input)}`;
            return `gradient[gradientIndices[${inputIdx}]] += ${gradVar};`;
        })
        .filter((line) => line !== null)
        .join('\n ');

    const functionBody = [
        '',
        indexExtractions,
        forwardCode.join('\n '),
        gradDeclarations.join('\n '),
        `grad_${outputVar} = 1;`,
        backwardCode.join('\n '),
        gradientUpdates,
        `return ${outputVar};`,
        '',
    ].join('\n');

    try {
        // NOTE(review): `new Function` evaluates the generated source, which
        // is assembled from the nodes' getForwardCode/getBackwardCode output.
        // This is only safe while those methods emit trusted code - confirm
        // that no externally supplied text can reach them.
        return new Function('allValues', 'indices', 'gradientIndices', 'gradient', functionBody);
    }
    catch (error) {
        console.error(`[compileIndirectKernel] COMPILATION ERROR:`, error);
        console.error(`[compileIndirectKernel] Function body:\n${functionBody}`);
        throw error;
    }
}
|
|
123
|
+
/**
 * Extract input indices for a residual graph.
 *
 * Walks the graph depth-first from `residual`, following `prev` in order, and
 * collects each leaf node (empty `prev`) once, in the order first reached.
 *
 * @param residual - Output Value of the residual computation
 * @param registry - ValueRegistry used to look up each leaf via `getId`
 * @returns Array of `registry.getId(leaf)` results, one per distinct leaf
 */
export function extractInputIndices(residual, registry) {
    const visited = new Set();
    const inputs = [];
    // `visited` already guarantees each node is handled once, so leaves can
    // be appended without a separate duplicate scan of `inputs`.
    function collectInputs(node) {
        if (visited.has(node))
            return;
        visited.add(node);
        const prev = node.prev;
        if (prev.length === 0) {
            inputs.push(node);
            return;
        }
        for (const child of prev) {
            collectInputs(child);
        }
    }
    collectInputs(residual);
    return inputs.map(v => registry.getId(v));
}
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import './EigenvalueHelpers';
|
|
2
|
+
export { CompiledFunctions } from './CompiledFunctions';
|
|
3
|
+
export { CompiledResiduals } from './CompiledResiduals';
|
|
4
|
+
export { Geometry } from './Geometry';
|
|
5
|
+
export { GraphBuilder, type GraphSignature } from './GraphBuilder';
|
|
6
|
+
export { lbfgs, type LBFGSOptions, type LBFGSResult } from './LBFGS';
|
|
7
|
+
export { Losses } from './Losses';
|
|
8
|
+
export { Matrix3x3 } from './Matrix3x3';
|
|
9
|
+
export { nonlinearLeastSquares, type NonlinearLeastSquaresOptions, type NonlinearLeastSquaresResult } from './NonlinearLeastSquares';
|
|
10
|
+
export { Adam, AdamW, Optimizer, SGD } from './Optimizers';
|
|
11
|
+
export { V } from './V';
|
|
12
|
+
export { Value } from './Value';
|
|
13
|
+
export { Vec2 } from './Vec2';
|
|
14
|
+
export { Vec3 } from './Vec3';
|
|
15
|
+
export { Vec4 } from './Vec4';
|
|
16
|
+
export { parse, Parser } from './symbolic/Parser';
|
|
17
|
+
export { differentiate, computeGradients } from './symbolic/SymbolicDiff';
|
|
18
|
+
export { simplify } from './symbolic/Simplify';
|
|
19
|
+
export { generateCode, generateMathNotation, generateGradientCode, generateGradientFunction } from './symbolic/CodeGen';
|
|
20
|
+
export type { ASTNode, NumberNode, VariableNode, BinaryOpNode, UnaryOpNode, FunctionCallNode, VectorAccessNode, VectorConstructorNode, Assignment, Program } from './symbolic/AST';
|
package/dist/index.js
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import './EigenvalueHelpers';
|
|
2
|
+
export { CompiledFunctions } from './CompiledFunctions';
|
|
3
|
+
export { CompiledResiduals } from './CompiledResiduals';
|
|
4
|
+
export { Geometry } from './Geometry';
|
|
5
|
+
export { GraphBuilder } from './GraphBuilder';
|
|
6
|
+
export { lbfgs } from './LBFGS';
|
|
7
|
+
export { Losses } from './Losses';
|
|
8
|
+
export { Matrix3x3 } from './Matrix3x3';
|
|
9
|
+
export { nonlinearLeastSquares } from './NonlinearLeastSquares';
|
|
10
|
+
export { Adam, AdamW, Optimizer, SGD } from './Optimizers';
|
|
11
|
+
export { V } from './V';
|
|
12
|
+
export { Value } from './Value';
|
|
13
|
+
export { Vec2 } from './Vec2';
|
|
14
|
+
export { Vec3 } from './Vec3';
|
|
15
|
+
export { Vec4 } from './Vec4';
|
|
16
|
+
// Symbolic gradient generation
|
|
17
|
+
export { parse, Parser } from './symbolic/Parser';
|
|
18
|
+
export { differentiate, computeGradients } from './symbolic/SymbolicDiff';
|
|
19
|
+
export { simplify } from './symbolic/Simplify';
|
|
20
|
+
export { generateCode, generateMathNotation, generateGradientCode, generateGradientFunction } from './symbolic/CodeGen';
|