scalar-autograd 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +127 -2
- package/dist/CompiledFunctions.d.ts +111 -0
- package/dist/CompiledFunctions.js +268 -0
- package/dist/CompiledResiduals.d.ts +74 -0
- package/dist/CompiledResiduals.js +94 -0
- package/dist/EigenvalueHelpers.d.ts +14 -0
- package/dist/EigenvalueHelpers.js +93 -0
- package/dist/Geometry.d.ts +131 -0
- package/dist/Geometry.js +213 -0
- package/dist/GraphBuilder.d.ts +64 -0
- package/dist/GraphBuilder.js +237 -0
- package/dist/GraphCanonicalizerNoSort.d.ts +20 -0
- package/dist/GraphCanonicalizerNoSort.js +190 -0
- package/dist/GraphHashCanonicalizer.d.ts +46 -0
- package/dist/GraphHashCanonicalizer.js +220 -0
- package/dist/GraphSignature.d.ts +7 -0
- package/dist/GraphSignature.js +7 -0
- package/dist/KernelPool.d.ts +55 -0
- package/dist/KernelPool.js +124 -0
- package/dist/LBFGS.d.ts +84 -0
- package/dist/LBFGS.js +313 -0
- package/dist/LinearSolver.d.ts +69 -0
- package/dist/LinearSolver.js +213 -0
- package/dist/Losses.d.ts +9 -0
- package/dist/Losses.js +42 -37
- package/dist/Matrix3x3.d.ts +50 -0
- package/dist/Matrix3x3.js +146 -0
- package/dist/NonlinearLeastSquares.d.ts +33 -0
- package/dist/NonlinearLeastSquares.js +252 -0
- package/dist/Optimizers.d.ts +70 -14
- package/dist/Optimizers.js +42 -19
- package/dist/V.d.ts +0 -0
- package/dist/V.js +0 -0
- package/dist/Value.d.ts +84 -2
- package/dist/Value.js +296 -58
- package/dist/ValueActivation.js +10 -14
- package/dist/ValueArithmetic.d.ts +1 -0
- package/dist/ValueArithmetic.js +58 -50
- package/dist/ValueComparison.js +9 -13
- package/dist/ValueRegistry.d.ts +38 -0
- package/dist/ValueRegistry.js +88 -0
- package/dist/ValueTrig.js +14 -18
- package/dist/Vec2.d.ts +45 -0
- package/dist/Vec2.js +93 -0
- package/dist/Vec3.d.ts +78 -0
- package/dist/Vec3.js +169 -0
- package/dist/Vec4.d.ts +45 -0
- package/dist/Vec4.js +126 -0
- package/dist/__tests__/duplicate-inputs.test.js +33 -0
- package/dist/cli/gradient-gen.d.ts +19 -0
- package/dist/cli/gradient-gen.js +264 -0
- package/dist/compileIndirectKernel.d.ts +24 -0
- package/dist/compileIndirectKernel.js +148 -0
- package/dist/index.d.ts +20 -0
- package/dist/index.js +20 -0
- package/dist/scalar-autograd.d.ts +1157 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/dist/tsdoc-metadata.json +11 -0
- package/package.json +29 -5
- package/dist/Losses.spec.js +0 -54
- package/dist/Optimizers.edge-cases.spec.d.ts +0 -1
- package/dist/Optimizers.edge-cases.spec.js +0 -29
- package/dist/Optimizers.spec.d.ts +0 -1
- package/dist/Optimizers.spec.js +0 -56
- package/dist/Value.edge-cases.spec.d.ts +0 -1
- package/dist/Value.edge-cases.spec.js +0 -54
- package/dist/Value.grad-flow.spec.d.ts +0 -1
- package/dist/Value.grad-flow.spec.js +0 -24
- package/dist/Value.losses-edge-cases.spec.d.ts +0 -1
- package/dist/Value.losses-edge-cases.spec.js +0 -30
- package/dist/Value.memory.spec.d.ts +0 -1
- package/dist/Value.memory.spec.js +0 -23
- package/dist/Value.nn.spec.d.ts +0 -1
- package/dist/Value.nn.spec.js +0 -111
- package/dist/Value.spec.d.ts +0 -1
- package/dist/Value.spec.js +0 -245
- /package/dist/{Losses.spec.d.ts → __tests__/duplicate-inputs.test.d.ts} +0 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Abstract Syntax Tree (AST) node definitions for symbolic differentiation.
|
|
3
|
+
* @internal
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Minimal shape exposed by every node of the expression tree.
 */
export interface ASTNode {
    /** String tag identifying the node kind. */
    type: string;
    /** Hands this node to `visitor` and returns whatever the visitor produces. */
    accept<R>(visitor: ASTVisitor<R>): R;
}
|
|
12
|
+
/**
 * Visitor over the expression tree: one handler per concrete node class,
 * each returning a value of type `T`.
 */
export interface ASTVisitor<T> {
    /** Handler for numeric literals. */
    visitNumber(node: NumberNode): T;
    /** Handler for variable references. */
    visitVariable(node: VariableNode): T;
    /** Handler for two-operand operators. */
    visitBinaryOp(node: BinaryOpNode): T;
    /** Handler for prefix `+` / `-`. */
    visitUnaryOp(node: UnaryOpNode): T;
    /** Handler for named function calls. */
    visitFunctionCall(node: FunctionCallNode): T;
    /** Handler for vector component reads. */
    visitVectorAccess(node: VectorAccessNode): T;
    /** Handler for vector constructors. */
    visitVectorConstructor(node: VectorConstructorNode): T;
}
|
|
24
|
+
/**
 * Leaf node for a numeric literal.
 */
export declare class NumberNode implements ASTNode {
    /** The literal's value. */
    value: number;
    /** Tag identifying this node kind. */
    type: "Number";
    constructor(value: number);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
34
|
+
/**
 * Leaf node referring to a variable by name (e.g. `x`, `y`, `a`).
 */
export declare class VariableNode implements ASTNode {
    /** Identifier of the referenced variable. */
    name: string;
    /** Tag identifying this node kind. */
    type: "Variable";
    constructor(name: string);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
44
|
+
/**
 * Node for a two-operand operation (e.g. `a + b`, `x * y`).
 */
export declare class BinaryOpNode implements ASTNode {
    /** Operator applied to `left` and `right`. */
    op: '+' | '-' | '*' | '/' | '**' | 'pow';
    /** Left-hand operand. */
    left: ASTNode;
    /** Right-hand operand. */
    right: ASTNode;
    /** Tag identifying this node kind. */
    type: "BinaryOp";
    constructor(op: BinaryOpNode['op'], left: ASTNode, right: ASTNode);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
56
|
+
/**
 * Node for a prefix sign operator (e.g. `-x`, `+y`).
 */
export declare class UnaryOpNode implements ASTNode {
    /** The prefix operator. */
    op: '+' | '-';
    /** Expression the operator applies to. */
    operand: ASTNode;
    /** Tag identifying this node kind. */
    type: "UnaryOp";
    constructor(op: UnaryOpNode['op'], operand: ASTNode);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
67
|
+
/**
 * Node for a named function applied to arguments
 * (e.g. `sin(x)`, `sqrt(y)`, `max(a, b)`).
 */
export declare class FunctionCallNode implements ASTNode {
    /** Name of the called function. */
    name: string;
    /** Argument expressions, in call order. */
    args: Array<ASTNode>;
    /** Tag identifying this node kind. */
    type: "FunctionCall";
    constructor(name: string, args: Array<ASTNode>);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
78
|
+
/**
 * Node reading one component of a vector expression (e.g. `v.x`, `point.y`).
 */
export declare class VectorAccessNode implements ASTNode {
    /** Expression that yields the vector. */
    vector: ASTNode;
    /** Which component is read. */
    component: 'x' | 'y' | 'z';
    /** Tag identifying this node kind. */
    type: "VectorAccess";
    constructor(vector: ASTNode, component: VectorAccessNode['component']);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
89
|
+
/**
 * Node building a vector from component expressions
 * (e.g. `Vec2(x, y)`, `Vec3(1, 2, 3)`).
 */
export declare class VectorConstructorNode implements ASTNode {
    /** Which vector type is constructed. */
    vectorType: 'Vec2' | 'Vec3';
    /** Component expressions, in constructor-argument order. */
    components: Array<ASTNode>;
    /** Tag identifying this node kind. */
    type: "VectorConstructor";
    constructor(vectorType: VectorConstructorNode['vectorType'], components: Array<ASTNode>);
    accept<R>(visitor: ASTVisitor<R>): R;
    /** Textual form of the node. */
    toString(): string;
}
|
|
100
|
+
/**
 * One assignment statement of a program (e.g. `x = 5`, `y = x + 2`).
 */
export interface Assignment {
    /** Name of the variable being assigned. */
    variable: string;
    /** Right-hand side of the assignment. */
    expression: ASTNode;
}
|
|
107
|
+
/**
 * A whole program: a list of assignments plus the designated output.
 */
export interface Program {
    /** The program's assignment statements. */
    assignments: Assignment[];
    /** Name of the variable treated as the program's result. */
    output: string;
}
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Abstract Syntax Tree (AST) node definitions for symbolic differentiation.
|
|
3
|
+
* @internal
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Leaf node holding a numeric literal.
 */
export class NumberNode {
    value;
    type = 'Number';
    constructor(value) {
        this.value = value;
    }
    /** Forwards this node to the visitor's `visitNumber` handler. */
    accept(visitor) {
        const outcome = visitor.visitNumber(this);
        return outcome;
    }
    /** Renders the literal through `String()`. */
    toString() {
        const { value } = this;
        return String(value);
    }
}
|
|
21
|
+
/**
 * Leaf node naming a variable (e.g. 'x', 'y', 'a').
 */
export class VariableNode {
    name;
    type = 'Variable';
    constructor(name) {
        this.name = name;
    }
    /** Forwards this node to the visitor's `visitVariable` handler. */
    accept(visitor) {
        const outcome = visitor.visitVariable(this);
        return outcome;
    }
    /** The variable's name is its textual form. */
    toString() {
        const { name } = this;
        return name;
    }
}
|
|
37
|
+
/**
 * Node for a two-operand operation (e.g. a + b, x * y).
 */
export class BinaryOpNode {
    op;
    left;
    right;
    type = 'BinaryOp';
    constructor(op, left, right) {
        this.op = op;
        this.left = left;
        this.right = right;
    }
    /** Forwards this node to the visitor's `visitBinaryOp` handler. */
    accept(visitor) {
        const outcome = visitor.visitBinaryOp(this);
        return outcome;
    }
    /** Always-parenthesised infix form: "(left op right)". */
    toString() {
        const lhs = this.left.toString();
        const symbol = this.op;
        const rhs = this.right.toString();
        return `(${lhs} ${symbol} ${rhs})`;
    }
}
|
|
57
|
+
/**
 * Node for a prefix sign operator (e.g. -x, +y).
 */
export class UnaryOpNode {
    op;
    operand;
    type = 'UnaryOp';
    constructor(op, operand) {
        this.op = op;
        this.operand = operand;
    }
    /** Forwards this node to the visitor's `visitUnaryOp` handler. */
    accept(visitor) {
        const outcome = visitor.visitUnaryOp(this);
        return outcome;
    }
    /** Operator immediately followed by the operand's text, no separator. */
    toString() {
        const sign = this.op;
        const inner = this.operand.toString();
        return `${sign}${inner}`;
    }
}
|
|
75
|
+
/**
 * Node for a named function applied to arguments (e.g. sin(x), sqrt(y), max(a, b)).
 */
export class FunctionCallNode {
    name;
    args;
    type = 'FunctionCall';
    constructor(name, args) {
        this.name = name;
        this.args = args;
    }
    /** Forwards this node to the visitor's `visitFunctionCall` handler. */
    accept(visitor) {
        const outcome = visitor.visitFunctionCall(this);
        return outcome;
    }
    /** "name(arg1, arg2, ...)" using each argument's own toString(). */
    toString() {
        const callee = this.name;
        const renderArg = (arg) => arg.toString();
        const argList = this.args.map(renderArg).join(', ');
        return `${callee}(${argList})`;
    }
}
|
|
93
|
+
/**
 * Node reading one component of a vector expression (e.g. v.x, point.y).
 */
export class VectorAccessNode {
    vector;
    component;
    type = 'VectorAccess';
    constructor(vector, component) {
        this.vector = vector;
        this.component = component;
    }
    /** Forwards this node to the visitor's `visitVectorAccess` handler. */
    accept(visitor) {
        const outcome = visitor.visitVectorAccess(this);
        return outcome;
    }
    /** "<vector>.<component>" using the vector expression's own toString(). */
    toString() {
        const base = this.vector.toString();
        const field = this.component;
        return `${base}.${field}`;
    }
}
|
|
111
|
+
/**
 * Node building a vector from component expressions (e.g. Vec2(x, y), Vec3(1, 2, 3)).
 */
export class VectorConstructorNode {
    vectorType;
    components;
    type = 'VectorConstructor';
    constructor(vectorType, components) {
        this.vectorType = vectorType;
        this.components = components;
    }
    /** Forwards this node to the visitor's `visitVectorConstructor` handler. */
    accept(visitor) {
        const outcome = visitor.visitVectorConstructor(this);
        return outcome;
    }
    /** "<vectorType>(c1, c2, ...)" using each component's own toString(). */
    toString() {
        const kind = this.vectorType;
        const renderComponent = (component) => component.toString();
        const parts = this.components.map(renderComponent).join(', ');
        return `${kind}(${parts})`;
    }
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Code generator for symbolic gradients.
|
|
3
|
+
* Generates executable JavaScript/TypeScript code with mathematical annotations.
|
|
4
|
+
* @internal
|
|
5
|
+
*/
|
|
6
|
+
import { ASTNode, Program } from './AST';
|
|
7
|
+
/**
 * Render an expression tree as JavaScript source text.
 *
 * @param node - Root of the expression to render.
 * @returns The JavaScript expression as a string.
 */
export declare function generateCode(node: ASTNode): string;
|
|
11
|
+
/**
 * Render an expression tree as human-readable mathematical notation.
 *
 * @param node - Root of the expression to render.
 * @returns The notation as a string.
 */
export declare function generateMathNotation(node: ASTNode): string;
|
|
15
|
+
/**
 * Options accepted by the gradient code generators
 * (`generateGradientCode` and `generateGradientFunction`).
 * Every field is optional.
 */
export interface CodeGenOptions {
    /** Emit the mathematical notation of each statement as a comment. */
    includeMath?: boolean;
    /** Keyword used for emitted variable declarations. */
    varStyle?: 'const' | 'let' | 'var';
    /** Emit the forward-pass computation as well as the gradients. */
    includeForward?: boolean;
    /** Indentation unit used for nested lines of the emitted code. */
    indent?: string;
}
|
|
28
|
+
/**
 * Emit gradient code for a whole program as a flat sequence of statements.
 *
 * @param program - Assignments plus the name of the output variable.
 * @param gradients - Gradient expression per variable name.
 * @param options - Formatting switches; see {@link CodeGenOptions}.
 * @returns The generated source text.
 */
export declare function generateGradientCode(program: Program, gradients: Map<string, ASTNode>, options?: CodeGenOptions): string;
|
|
32
|
+
/**
 * Emit gradient code for a program wrapped in a named function declaration.
 *
 * @param program - Assignments plus the name of the output variable.
 * @param gradients - Gradient expression per variable name.
 * @param functionName - Name given to the emitted function.
 * @param parameters - Parameter names of the emitted function, in order.
 * @param options - Formatting switches; see {@link CodeGenOptions}.
 * @returns The generated source text.
 */
export declare function generateGradientFunction(program: Program, gradients: Map<string, ASTNode>, functionName: string, parameters: string[], options?: CodeGenOptions): string;
|
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Code generator for symbolic gradients.
|
|
3
|
+
* Generates executable JavaScript/TypeScript code with mathematical annotations.
|
|
4
|
+
* @internal
|
|
5
|
+
*/
|
|
6
|
+
/**
 * Function names with a dedicated JavaScript translation, mapped to the
 * callee text emitted for them. Kept in a Map, built once, so that names
 * which happen to exist on Object.prototype ('toString', 'constructor', ...)
 * are not mistaken for known functions.
 */
const CODEGEN_FUNCTION_CALLEES = new Map([
    ['sin', 'Math.sin'],
    ['cos', 'Math.cos'],
    ['tan', 'Math.tan'],
    ['asin', 'Math.asin'],
    ['acos', 'Math.acos'],
    ['atan', 'Math.atan'],
    ['sinh', 'Math.sinh'],
    ['cosh', 'Math.cosh'],
    ['tanh', 'Math.tanh'],
    ['exp', 'Math.exp'],
    ['log', 'Math.log'],
    ['ln', 'Math.log'],
    ['sqrt', 'Math.sqrt'],
    ['abs', 'Math.abs'],
    ['sign', 'Math.sign'],
    ['floor', 'Math.floor'],
    ['ceil', 'Math.ceil'],
    ['round', 'Math.round'],
    ['pow', 'Math.pow'],
    ['min', 'Math.min'],
    ['max', 'Math.max'],
    ['atan2', 'Math.atan2'],
    // These two have no Math.* equivalent; they are emitted as
    // immediately-invoked arrow functions.
    ['heaviside', '(x => x > 0 ? 1 : 0)'],
    ['sigmoid', '(x => 1 / (1 + Math.exp(-x)))']
]);
/**
 * Code generation visitor - generates JavaScript expressions.
 * Every handler returns the source text for the node it is given.
 */
class CodeGenVisitor {
    /** Numeric literal: its `String()` form. */
    visitNumber(node) {
        return String(node.value);
    }
    /** Variable reference: the bare identifier. */
    visitVariable(node) {
        return node.name;
    }
    /** Prefix operator; the operand is parenthesised when `needsParens` asks for it. */
    visitUnaryOp(node) {
        const operand = node.operand.accept(this);
        if (needsParens(node.operand)) {
            return `${node.op}(${operand})`;
        }
        return `${node.op}${operand}`;
    }
    /**
     * Binary operator. Exponentiation is emitted as a `Math.pow(...)` call;
     * every other operator is emitted infix.
     */
    visitBinaryOp(node) {
        // 'pow' is accepted as a binary operator, but JavaScript has no infix
        // `pow`; emitting it verbatim would be a syntax error. It is treated
        // as '**', matching the 'pow' -> Math.pow function translation.
        const op = node.op === 'pow' ? '**' : node.op;
        const left = node.left.accept(this);
        const right = node.right.accept(this);
        // Handle precedence with parentheses
        const leftStr = needsParens(node.left, op) ? `(${left})` : left;
        const rightStr = needsParens(node.right, op, 'right') ? `(${right})` : right;
        if (op === '**') {
            return `Math.pow(${leftStr}, ${rightStr})`;
        }
        return `${leftStr} ${op} ${rightStr}`;
    }
    /**
     * Function call. Known names are translated through the callee table;
     * unknown names are emitted unchanged.
     */
    visitFunctionCall(node) {
        const args = node.args.map(arg => arg.accept(this)).join(', ');
        const funcName = CODEGEN_FUNCTION_CALLEES.get(node.name) ?? node.name;
        return `${funcName}(${args})`;
    }
    /** Component read: `<vector>.<component>`. */
    visitVectorAccess(node) {
        const vector = node.vector.accept(this);
        return `${vector}.${node.component}`;
    }
    /** Vector constructor: `<vectorType>(c1, c2, ...)`. */
    visitVectorConstructor(node) {
        const components = node.components.map(c => c.accept(this)).join(', ');
        return `${node.vectorType}(${components})`;
    }
}
|
|
76
|
+
/**
 * Check if node needs parentheses when printed as an operand.
 *
 * Two call shapes are supported:
 *  - operand of a unary operator: call with `parentOp` omitted;
 *  - operand of a binary operator: pass the operator and which side the
 *    operand is on.
 *
 * @param node - Operand subtree. Only `type` is read, plus `op` for binary
 *   nodes and `value` for number nodes.
 * @param parentOp - Binary operator of the parent, or undefined for a unary parent.
 * @param position - 'left' or 'right' side of the parent binary operator.
 * @returns true when the operand's text must be wrapped in parentheses.
 */
function needsParens(node, parentOp, position = 'left') {
    const isNegativeLiteral = node.type === 'Number' && node.value < 0;
    const isSigned = node.type === 'UnaryOp' || isNegativeLiteral;
    if (!parentOp) {
        // Unary parent. The sign binds tighter than any binary operator, so a
        // bare binary operand would change meaning: -(a + b) must not print
        // as -a + b. A signed operand must be wrapped too, otherwise two
        // signs fuse into `--x` / `++x`, which JavaScript reads as
        // decrement / increment.
        return node.type === 'BinaryOp' || isSigned;
    }
    // 'pow' is handled as a spelling of '**' for precedence purposes.
    const parent = parentOp === 'pow' ? '**' : parentOp;
    if (isSigned) {
        // A signed base of a power must be wrapped: `-x^2` reads as -(x^2),
        // and `-x ** 2` is a syntax error in JavaScript.
        return parent === '**' && position === 'left';
    }
    if (node.type !== 'BinaryOp') {
        // Variables, non-negative numbers, calls, component reads and
        // vector constructors never need wrapping.
        return false;
    }
    const own = node.op === 'pow' ? '**' : node.op;
    const nodePrecedence = getPrecedence(own);
    const parentPrecedence = getPrecedence(parent);
    // Need parens if lower precedence
    if (nodePrecedence < parentPrecedence) {
        return true;
    }
    // For same precedence, check associativity
    if (nodePrecedence === parentPrecedence) {
        // Right-associative operator (power): the left operand needs parens
        if (parent === '**') {
            return position === 'left';
        }
        // Left-associative and not commutative: the right operand of - and /
        // needs parens
        return (parent === '-' || parent === '/') && position === 'right';
    }
    return false;
}
|
|
109
|
+
/**
 * Binding strength of a binary operator; a larger number binds tighter.
 *
 * @param op - Operator symbol.
 * @returns 3 for '**', 2 for '*' and '/', 1 for '+' and '-', and 0 for
 *   anything else.
 */
function getPrecedence(op) {
    if (op === '**') {
        return 3;
    }
    if (op === '*' || op === '/') {
        return 2;
    }
    if (op === '+' || op === '-') {
        return 1;
    }
    return 0;
}
|
|
126
|
+
/**
 * Visitor that renders an expression as plain-text mathematical notation.
 * It mirrors the code generator's output, except that '**' is written with
 * a caret and function names are kept as written.
 */
class MathNotationVisitor {
    /** Numeric literal: its `String()` form. */
    visitNumber(node) {
        return String(node.value);
    }
    /** Variable reference: the bare name. */
    visitVariable(node) {
        return node.name;
    }
    /** Prefix operator, wrapping the operand when `needsParens` asks for it. */
    visitUnaryOp(node) {
        const inner = node.operand.accept(this);
        const body = needsParens(node.operand) ? `(${inner})` : inner;
        return `${node.op}${body}`;
    }
    /** Infix operator; '**' is printed as `left^right` without spaces. */
    visitBinaryOp(node) {
        const lhs = node.left.accept(this);
        const rhs = node.right.accept(this);
        const lhsText = needsParens(node.left, node.op) ? `(${lhs})` : lhs;
        const rhsText = needsParens(node.right, node.op, 'right') ? `(${rhs})` : rhs;
        return node.op === '**'
            ? `${lhsText}^${rhsText}`
            : `${lhsText} ${node.op} ${rhsText}`;
    }
    /** Function call with the name exactly as it appears in the tree. */
    visitFunctionCall(node) {
        const rendered = node.args.map(arg => arg.accept(this));
        return `${node.name}(${rendered.join(', ')})`;
    }
    /** Component read: `<vector>.<component>`. */
    visitVectorAccess(node) {
        const base = node.vector.accept(this);
        return `${base}.${node.component}`;
    }
    /** Vector constructor: `<vectorType>(c1, c2, ...)`. */
    visitVectorConstructor(node) {
        const rendered = node.components.map(c => c.accept(this));
        return `${node.vectorType}(${rendered.join(', ')})`;
    }
}
|
|
166
|
+
/**
 * Render an expression tree as JavaScript source text by walking it with a
 * fresh CodeGenVisitor.
 *
 * @param node - Root of the expression.
 * @returns Whatever the visitor returns for the root node.
 */
export function generateCode(node) {
    return node.accept(new CodeGenVisitor());
}
|
|
173
|
+
/**
 * Render an expression tree as mathematical notation by walking it with a
 * fresh MathNotationVisitor.
 *
 * @param node - Root of the expression.
 * @returns Whatever the visitor returns for the root node.
 */
export function generateMathNotation(node) {
    return node.accept(new MathNotationVisitor());
}
|
|
180
|
+
/**
 * Emit gradient code for a program as a flat list of statements.
 *
 * The output has three parts: an optional forward pass (one declaration per
 * assignment), one `grad_<name>` declaration per gradient entry other than
 * the program output, and a final `result` object holding the output value
 * and the gradients.
 *
 * NOTE(review): a gradient entry whose node is falsy gets no `grad_<name>`
 * declaration but is still listed inside `result.gradients`. Likewise,
 * `result.value` refers to the output variable even when the forward pass
 * is disabled. Confirm whether callers rely on either.
 *
 * @param program - Assignments and the name of the output variable.
 * @param gradients - Map from variable name to its gradient expression.
 * @param options - `includeMath`, `varStyle`, `includeForward`, `indent`.
 * @returns The generated source, lines joined with '\n'.
 */
export function generateGradientCode(program, gradients, options = {}) {
    const {
        includeMath: withMath = true,
        varStyle: keyword = 'const',
        includeForward: withForward = true,
        indent: pad = ' '
    } = options;
    const out = [
        '// Auto-generated gradient computation',
        '// Generated by ScalarAutograd symbolic differentiation',
        ''
    ];
    if (withForward) {
        out.push('// Forward pass');
        for (const { variable, expression } of program.assignments) {
            const js = generateCode(expression);
            if (withMath) {
                out.push(`// ${variable} = ${generateMathNotation(expression)}`);
            }
            out.push(`${keyword} ${variable} = ${js};`, '');
        }
        out.push('// Gradient computation (reverse-mode autodiff)', '');
    }
    // Every gradient key except the output itself, in the map's own order.
    const targets = [...gradients.keys()].filter(key => key !== program.output);
    for (const target of targets) {
        const gradExpr = gradients.get(target);
        if (gradExpr) {
            const js = generateCode(gradExpr);
            if (withMath) {
                out.push(`// ∂${program.output}/∂${target} = ${generateMathNotation(gradExpr)}`);
            }
            out.push(`${keyword} grad_${target} = ${js};`, '');
        }
    }
    // Result object
    out.push(
        '// Result',
        `${keyword} result = {`,
        `${pad}value: ${program.output},`,
        `${pad}gradients: {`
    );
    for (const target of targets) {
        out.push(`${pad}${pad}${target}: grad_${target},`);
    }
    out.push(`${pad}}`, '};');
    return out.join('\n');
}
|
|
230
|
+
/**
 * Generate gradient code as a function.
 *
 * The emitted text is a documented function declaration that runs the
 * forward pass, declares one `grad_<param>` constant per parameter that has
 * a gradient expression, and returns `{ value, gradients }`.
 *
 * Parameters with no entry in `gradients` are left out of the returned
 * `gradients` object. Listing them would reference a `grad_<param>` binding
 * that was never declared, and the generated function would throw a
 * ReferenceError when called.
 *
 * Only `includeMath` and `indent` are read from `options`; declarations are
 * always emitted with `const`.
 *
 * @param program - Assignments and the name of the output variable.
 * @param gradients - Map from variable name to its gradient expression.
 * @param functionName - Name of the emitted function.
 * @param parameters - Parameter names of the emitted function, in order.
 * @param options - Formatting switches (`includeMath`, `indent`).
 * @returns The generated source, lines joined with '\n'.
 */
export function generateGradientFunction(program, gradients, functionName, parameters, options = {}) {
    const { includeMath = true, indent = ' ' } = options;
    const lines = [];
    // Doc comment and signature of the emitted function
    lines.push(`/**`);
    lines.push(` * Compute ${program.output} and its gradients`);
    for (const param of parameters) {
        lines.push(` * @param ${param} - Input parameter`);
    }
    lines.push(` * @returns Object with value and gradients`);
    lines.push(` */`);
    lines.push(`function ${functionName}(${parameters.join(', ')}) {`);
    // Forward pass
    if (includeMath) {
        lines.push(`${indent}// Forward pass`);
    }
    for (const assignment of program.assignments) {
        const code = generateCode(assignment.expression);
        if (includeMath) {
            const math = generateMathNotation(assignment.expression);
            lines.push(`${indent}// ${assignment.variable} = ${math}`);
        }
        lines.push(`${indent}const ${assignment.variable} = ${code};`);
    }
    lines.push('');
    if (includeMath) {
        lines.push(`${indent}// Gradient computation`);
    }
    // Gradients. Remember which parameters actually got a declaration so
    // the return object below only mentions bindings that exist.
    const declared = [];
    for (const param of parameters) {
        const gradNode = gradients.get(param);
        if (!gradNode)
            continue;
        const code = generateCode(gradNode);
        if (includeMath) {
            const math = generateMathNotation(gradNode);
            lines.push(`${indent}// ∂${program.output}/∂${param} = ${math}`);
        }
        lines.push(`${indent}const grad_${param} = ${code};`);
        declared.push(param);
    }
    lines.push('');
    lines.push(`${indent}return {`);
    lines.push(`${indent}${indent}value: ${program.output},`);
    lines.push(`${indent}${indent}gradients: { ${declared.map(p => `${p}: grad_${p}`).join(', ')} }`);
    lines.push(`${indent}};`);
    lines.push('}');
    return lines.join('\n');
}
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Expression parser for symbolic gradient generation.
|
|
3
|
+
* Parses operator-overloaded mathematical expressions into AST.
|
|
4
|
+
* @internal
|
|
5
|
+
*/
|
|
6
|
+
import { Program } from './AST';
|
|
7
|
+
/**
 * Parser: converts tokens into AST
 * Grammar (precedence from lowest to highest):
 * assignment → IDENTIFIER '=' expression
 * expression → term (('+' | '-') term)*
 * term → factor (('*' | '/') factor)*
 * factor → power
 * power → postfix ('**' postfix)*
 * postfix → primary ('.' IDENTIFIER)*
 * primary → NUMBER | IDENTIFIER | function_call | vector_constructor | '(' expression ')' | ('+' | '-') primary
 * function_call → IDENTIFIER '(' arg_list? ')'
 * vector_constructor → ('Vec2' | 'Vec3') '(' arg_list ')'
 * arg_list → expression (',' expression)*
 *
 * NOTE(review): the per-method descriptions below attribute
 * addition/subtraction to `parseTerm` and multiplication/division to
 * `parseFactor`, one level off from the grammar above, which puts '+'/'-'
 * in `expression` and '*'/'/' in `term`. Confirm against the implementation
 * which of the two is accurate.
 */
export declare class Parser {
    /** Private state; presumably the token list for the input — confirm in the implementation. */
    private tokens;
    /** Private state; presumably the read position within `tokens` — confirm in the implementation. */
    private current;
    /**
     * Creates a parser for the given source text.
     * @param text - Source text to parse.
     */
    constructor(text: string);
    private peek;
    private advance;
    private expect;
    /**
     * Parses the whole input and returns it as a Program.
     */
    parseProgram(): Program;
    /**
     * Parses an expression.
     */
    private parseExpression;
    /**
     * Parses a term (described as addition/subtraction; see the class note).
     */
    private parseTerm;
    /**
     * Parses a factor (described as multiplication/division; see the class note).
     */
    private parseFactor;
    /**
     * Parses a power (exponentiation).
     */
    private parsePower;
    /**
     * Parses a postfix expression (member access like v.x).
     */
    private parsePostfix;
    /**
     * Parses a primary expression.
     */
    private parsePrimary;
    /**
     * Parses a function argument list.
     */
    private parseArgList;
}
|
|
61
|
+
/**
 * Parses source text of the expression language into a Program.
 *
 * @param text - Source text to parse.
 * @returns The parsed program.
 */
export declare function parse(text: string): Program;
|