gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Expander for GradientScript DSL
|
|
3
|
+
* Expands built-in functions and struct operations into scalar operations
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Rewrite a built-in vector function call as an equivalent scalar expression.
 *
 * The 2D/3D helpers dot2d, cross2d, magnitude2d, distance2d, dot3d and
 * magnitude3d are expanded into binary/call nodes over the `.x`/`.y`/`.z`
 * components of their arguments. Every other call name (sin, cos, ...) needs
 * no expansion, and the very same `call` object is handed back.
 *
 * @param call - call node providing `name` and `args`
 * @returns the expanded scalar expression, or `call` itself when not expandable
 * @throws Error for normalize2d and cross3d, which are not supported yet
 */
export function expandBuiltIn(call) {
    const fnName = call.name;
    const callArgs = call.args;
    if (fnName === 'dot2d') {
        return expandDot2d(callArgs[0], callArgs[1]);
    }
    if (fnName === 'cross2d') {
        return expandCross2d(callArgs[0], callArgs[1]);
    }
    if (fnName === 'magnitude2d') {
        return expandMagnitude2d(callArgs[0]);
    }
    if (fnName === 'normalize2d') {
        throw new Error('normalize2d not yet supported in differentiation');
    }
    if (fnName === 'distance2d') {
        return expandDistance2d(callArgs[0], callArgs[1]);
    }
    if (fnName === 'dot3d') {
        return expandDot3d(callArgs[0], callArgs[1]);
    }
    if (fnName === 'cross3d') {
        throw new Error('cross3d returns vector - not yet supported');
    }
    if (fnName === 'magnitude3d') {
        return expandMagnitude3d(callArgs[0]);
    }
    // Plain math functions are already scalar; pass the node through untouched.
    return call;
}
|
|
32
|
+
/**
 * Expand dot2d(u, v) into the scalar sum u.x * v.x + u.y * v.y.
 *
 * @param u - first 2D vector expression
 * @param v - second 2D vector expression
 * @returns a binary '+' node whose operands are the per-axis products
 */
function expandDot2d(u, v) {
    const xTerm = {
        kind: 'binary',
        operator: '*',
        left: component(u, 'x'),
        right: component(v, 'x')
    };
    const yTerm = {
        kind: 'binary',
        operator: '*',
        left: component(u, 'y'),
        right: component(v, 'y')
    };
    return { kind: 'binary', operator: '+', left: xTerm, right: yTerm };
}
|
|
53
|
+
/**
 * Expand cross2d(u, v) into the scalar u.x * v.y - u.y * v.x
 * (the z-component of the 3D cross product of two planar vectors).
 *
 * @param u - first 2D vector expression
 * @param v - second 2D vector expression
 * @returns a binary '-' node: (u.x * v.y) - (u.y * v.x)
 */
function expandCross2d(u, v) {
    const product = (a, b) => ({ kind: 'binary', operator: '*', left: a, right: b });
    const minuend = product(component(u, 'x'), component(v, 'y'));
    const subtrahend = product(component(u, 'y'), component(v, 'x'));
    return { kind: 'binary', operator: '-', left: minuend, right: subtrahend };
}
|
|
74
|
+
/**
 * Expand magnitude2d(v) into sqrt(v.x^2 + v.y^2).
 *
 * @param v - 2D vector expression
 * @returns a `sqrt` call node over the sum of squared components
 */
function expandMagnitude2d(v) {
    // Each square gets its own fresh literal node for the exponent.
    const square = (axis) => ({
        kind: 'binary',
        operator: '^',
        left: component(v, axis),
        right: { kind: 'number', value: 2 }
    });
    const sumOfSquares = { kind: 'binary', operator: '+', left: square('x'), right: square('y') };
    return { kind: 'call', name: 'sqrt', args: [sumOfSquares] };
}
|
|
99
|
+
/**
 * Expand distance2d(p1, p2) into sqrt((p2.x - p1.x)^2 + (p2.y - p1.y)^2).
 *
 * Conceptually this is magnitude2d(p2 - p1), but struct subtraction is not
 * available at this stage, so the per-axis differences are spelled out.
 *
 * @param p1 - start point expression
 * @param p2 - end point expression
 * @returns a `sqrt` call node over the sum of squared axis differences
 */
function expandDistance2d(p1, p2) {
    const delta = (axis) => ({
        kind: 'binary',
        operator: '-',
        left: component(p2, axis),
        right: component(p1, axis)
    });
    const squared = (base) => ({
        kind: 'binary',
        operator: '^',
        left: base,
        right: { kind: 'number', value: 2 }
    });
    const dx = delta('x');
    const dy = delta('y');
    const sumOfSquares = { kind: 'binary', operator: '+', left: squared(dx), right: squared(dy) };
    return { kind: 'call', name: 'sqrt', args: [sumOfSquares] };
}
|
|
138
|
+
/**
 * Expand dot3d(u, v) into u.x * v.x + u.y * v.y + u.z * v.z.
 *
 * The sum is built left-associatively: ((x-term + y-term) + z-term).
 *
 * @param u - first 3D vector expression
 * @param v - second 3D vector expression
 * @returns nested binary '+' nodes over the per-axis products
 */
function expandDot3d(u, v) {
    const products = ['x', 'y', 'z'].map((axis) => ({
        kind: 'binary',
        operator: '*',
        left: component(u, axis),
        right: component(v, axis)
    }));
    // reduce without a seed folds from the first product, giving the left-nested tree.
    return products.reduce((acc, term) => ({ kind: 'binary', operator: '+', left: acc, right: term }));
}
|
|
169
|
+
/**
 * Expand magnitude3d(v) into sqrt(v.x^2 + v.y^2 + v.z^2).
 *
 * The sum under the root is left-associative: ((x^2 + y^2) + z^2).
 *
 * @param v - 3D vector expression
 * @returns a `sqrt` call node over the sum of squared components
 */
function expandMagnitude3d(v) {
    const squares = ['x', 'y', 'z'].map((axis) => ({
        kind: 'binary',
        operator: '^',
        left: component(v, axis),
        right: { kind: 'number', value: 2 }
    }));
    const sumOfSquares = squares.reduce((acc, term) => ({ kind: 'binary', operator: '+', left: acc, right: term }));
    return { kind: 'call', name: 'sqrt', args: [sumOfSquares] };
}
|
|
204
|
+
/**
 * Build a component-access node (`obj.comp`).
 *
 * @param obj - expression being accessed; stored by reference, not copied
 * @param comp - component name, e.g. 'x'
 * @returns a node of kind 'component'
 */
function component(obj, comp) {
    const access = { kind: 'component', object: obj, component: comp };
    return access;
}
|
|
214
|
+
/**
 * Report whether a call with this name is rewritten into scalar operations
 * by expandBuiltIn.
 *
 * normalize2d and cross3d are deliberately absent: they are not expanded.
 *
 * @param name - function name from a call node
 * @returns true for the expandable vector built-ins, false otherwise
 */
export function shouldExpand(name) {
    switch (name) {
        case 'dot2d':
        case 'cross2d':
        case 'magnitude2d':
        case 'distance2d':
        case 'dot3d':
        case 'magnitude3d':
            return true;
        default:
            return false;
    }
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ExpressionTransformer - Abstract base class for AST transformations
|
|
3
|
+
*
|
|
4
|
+
* This class eliminates duplicated AST traversal logic by providing:
|
|
5
|
+
* - Default recursive descent for all expression types
|
|
6
|
+
* - Protected visit methods that subclasses can override
|
|
7
|
+
* - Type-safe transformation pipeline
|
|
8
|
+
*
|
|
9
|
+
* Usage:
|
|
10
|
+
* class MyTransformer extends ExpressionTransformer {
|
|
11
|
+
* protected visitBinaryOp(node: BinaryOp): Expression {
|
|
12
|
+
* // Custom logic here
|
|
13
|
+
* return super.visitBinaryOp(node); // Or custom result
|
|
14
|
+
* }
|
|
15
|
+
* }
|
|
16
|
+
*/
|
|
17
|
+
import { Expression, NumberLiteral, Variable, BinaryOp, UnaryOp, FunctionCall, ComponentAccess } from './AST.js';
|
|
18
|
+
/**
 * Base class for expression-tree rewrites.
 *
 * `transform` routes each node to one protected `visit*` hook chosen by the
 * node's `kind`. Subclasses override only the hooks they need; the defaults
 * are identity for leaves and a child-transforming rebuild for composites.
 */
export declare abstract class ExpressionTransformer {
    /**
     * Entry point: pick the visit hook matching `expr.kind` and return its result.
     */
    transform(expr: Expression): Expression;
    /**
     * Hook for number literals. The default hands back `node` unchanged.
     */
    protected visitNumber(node: NumberLiteral): Expression;
    /**
     * Hook for variable references. The default hands back `node` unchanged.
     */
    protected visitVariable(node: Variable): Expression;
    /**
     * Hook for binary operations. The default transforms `left` and `right`
     * and returns a freshly built node with the same operator.
     */
    protected visitBinaryOp(node: BinaryOp): Expression;
    /**
     * Hook for unary operations. The default transforms the operand and
     * returns a freshly built node with the same operator.
     */
    protected visitUnaryOp(node: UnaryOp): Expression;
    /**
     * Hook for function calls. The default transforms every argument and
     * returns a freshly built call with the same name.
     */
    protected visitFunctionCall(node: FunctionCall): Expression;
    /**
     * Hook for component access such as `v.x`. The default transforms the
     * accessed object and keeps the component name.
     */
    protected visitComponentAccess(node: ComponentAccess): Expression;
}
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
/**
 * ExpressionTransformer - base class for AST-to-AST rewrites.
 *
 * `transform` dispatches on `expr.kind` to one visit hook per node type.
 * The default hooks return leaves (numbers, variables) unchanged and rebuild
 * composite nodes (binary, unary, call, component) from transformed children,
 * so a subclass only needs to override the hooks it cares about.
 *
 * Usage:
 *   class MyTransformer extends ExpressionTransformer {
 *     visitBinaryOp(node) {
 *       // custom logic here
 *       return super.visitBinaryOp(node); // or a custom result
 *     }
 *   }
 *
 * Note: rebuilt nodes contain only the fields written in each default hook;
 * any additional properties present on an input node are not carried over.
 */
export class ExpressionTransformer {
    /**
     * Transform an expression by dispatching to the visit hook for its kind.
     *
     * @param expr - expression node to transform
     * @returns the transformed expression
     * @throws Error when `expr.kind` is not one of 'number', 'variable',
     *   'binary', 'unary', 'call' or 'component'. Failing here keeps a
     *   malformed node from silently becoming `undefined` in the output tree.
     */
    transform(expr) {
        switch (expr.kind) {
            case 'number':
                return this.visitNumber(expr);
            case 'variable':
                return this.visitVariable(expr);
            case 'binary':
                return this.visitBinaryOp(expr);
            case 'unary':
                return this.visitUnaryOp(expr);
            case 'call':
                return this.visitFunctionCall(expr);
            case 'component':
                return this.visitComponentAccess(expr);
            default:
                throw new Error(`ExpressionTransformer: unknown expression kind '${String(expr.kind)}'`);
        }
    }
    /**
     * Visit a number literal.
     * Default: return the node unchanged (identity).
     */
    visitNumber(node) {
        return node;
    }
    /**
     * Visit a variable reference.
     * Default: return the node unchanged (identity).
     */
    visitVariable(node) {
        return node;
    }
    /**
     * Visit a binary operation.
     * Default: transform left, then right, and return a new node.
     */
    visitBinaryOp(node) {
        const left = this.transform(node.left);
        const right = this.transform(node.right);
        return {
            kind: 'binary',
            operator: node.operator,
            left,
            right
        };
    }
    /**
     * Visit a unary operation.
     * Default: transform the operand and return a new node.
     */
    visitUnaryOp(node) {
        const operand = this.transform(node.operand);
        return {
            kind: 'unary',
            operator: node.operator,
            operand
        };
    }
    /**
     * Visit a function call.
     * Default: transform each argument in order and return a new node.
     */
    visitFunctionCall(node) {
        const args = node.args.map(arg => this.transform(arg));
        return {
            kind: 'call',
            name: node.name,
            args
        };
    }
    /**
     * Visit a component access (e.g. v.x).
     * Default: transform the accessed object and return a new node.
     */
    visitComponentAccess(node) {
        const object = this.transform(node.object);
        return {
            kind: 'component',
            object,
            component: node.component
        };
    }
}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Shared utility functions for expression manipulation
|
|
3
|
+
* Eliminates code duplication across Differentiation, Inliner, and CSE modules
|
|
4
|
+
*/
|
|
5
|
+
import { Expression, NumberLiteral, Variable, BinaryOp } from './AST.js';
|
|
6
|
+
/**
 * Return a copy of `expr` in which every variable named `varName` is
 * replaced by `replacement`. Shared by the Differentiation, Inliner and CSE passes.
 */
export declare function substituteVariable(expr: Expression, varName: string, replacement: Expression): Expression;
/** True when `expr` is the number literal 0. */
export declare function isZero(expr: Expression): boolean;
/** True when `expr` is the number literal 1. */
export declare function isOne(expr: Expression): boolean;
/** True when `expr` is a number literal of any value. */
export declare function isConstant(expr: Expression): boolean;
/**
 * True when `expr` is a variable reference. When `name` is supplied, the
 * variable must also carry exactly that name.
 */
export declare function isVariable(expr: Expression, name?: string): boolean;
/** True when `expr` is a number literal whose value is below zero. */
export declare function isNegative(expr: Expression): boolean;
/** Build a number-literal node holding `value`. */
export declare function makeNumber(value: number): NumberLiteral;
/** Build the binary-operation node `left op right`. */
export declare function makeBinaryOp(op: '+' | '-' | '*' | '/' | '^' | '**', left: Expression, right: Expression): BinaryOp;
/** Build a variable-reference node for `name`. */
export declare function makeVariable(name: string): Variable;
/**
 * Collect the names of all variables referenced by `expr`. A component
 * access on a plain variable contributes both the base name and the dotted
 * form (for example `v` and `v.x`).
 */
export declare function getVariables(expr: Expression): Set<string>;
/**
 * True when `expr` references `varName`. For a component access on a plain
 * variable, both the base name (`v`) and the dotted form (`v.x`) match.
 */
export declare function containsVariable(expr: Expression, varName: string): boolean;
/** Maximum nesting depth of `expr`, where a leaf node has depth 1. */
export declare function expressionDepth(expr: Expression): number;
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Shared utility functions for expression manipulation
|
|
3
|
+
* Eliminates code duplication across Differentiation, Inliner, and CSE modules
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Replace every occurrence of the variable `varName` in `expr` with `replacement`.
 * Used by: Differentiation, Inliner, CSE.
 *
 * The input tree is never mutated. Composite nodes are rebuilt. Number
 * literals and non-matching variables are returned as the same objects.
 * `replacement` is inserted by reference, so every substitution site shares
 * that one object.
 *
 * Component accesses are traversed through their `object`: with
 * `varName === 'v'`, the expression `v.x` becomes `replacement.x`.
 * A dotted name such as 'v.x' does not match a component access.
 *
 * @param expr - expression to rewrite
 * @param varName - name of the variable to replace
 * @param replacement - expression inserted in place of each match
 * @returns the rewritten expression
 * @throws Error when a node with an unrecognized `kind` is encountered, so a
 *   malformed tree is reported instead of yielding `undefined` subtrees
 */
export function substituteVariable(expr, varName, replacement) {
    switch (expr.kind) {
        case 'number':
            return expr;
        case 'variable':
            return expr.name === varName ? replacement : expr;
        case 'binary':
            return {
                kind: 'binary',
                operator: expr.operator,
                left: substituteVariable(expr.left, varName, replacement),
                right: substituteVariable(expr.right, varName, replacement)
            };
        case 'unary':
            return {
                kind: 'unary',
                operator: expr.operator,
                operand: substituteVariable(expr.operand, varName, replacement)
            };
        case 'call':
            return {
                kind: 'call',
                name: expr.name,
                args: expr.args.map(arg => substituteVariable(arg, varName, replacement))
            };
        case 'component':
            return {
                kind: 'component',
                object: substituteVariable(expr.object, varName, replacement),
                component: expr.component
            };
        default:
            throw new Error(`substituteVariable: unknown expression kind '${String(expr.kind)}'`);
    }
}
|
|
42
|
+
/**
 * Test whether an expression is the literal number zero.
 *
 * @param expr - expression to inspect
 * @returns true only for a 'number' node whose value is 0
 */
export function isZero(expr) {
    if (expr.kind !== 'number') {
        return false;
    }
    return expr.value === 0;
}
|
|
48
|
+
/**
 * Test whether an expression is the literal number one.
 *
 * @param expr - expression to inspect
 * @returns true only for a 'number' node whose value is 1
 */
export function isOne(expr) {
    if (expr.kind !== 'number') {
        return false;
    }
    return expr.value === 1;
}
|
|
54
|
+
/**
 * Test whether an expression is a number literal, whatever its value.
 *
 * @param expr - expression to inspect
 * @returns true when `expr.kind` is 'number'
 */
export function isConstant(expr) {
    const { kind } = expr;
    return kind === 'number';
}
|
|
60
|
+
/**
 * Test whether an expression is a variable reference.
 *
 * @param expr - expression to inspect
 * @param name - optional; when given, the variable must also have this exact name
 * @returns true for a matching 'variable' node, false otherwise
 */
export function isVariable(expr, name) {
    const isVarNode = expr.kind === 'variable';
    const nameMatches = name === undefined || expr.name === name;
    return isVarNode && nameMatches;
}
|
|
70
|
+
/**
 * Test whether an expression is a number literal strictly below zero.
 *
 * @param expr - expression to inspect
 * @returns true for a 'number' node with a negative value; -0 is not negative
 */
export function isNegative(expr) {
    if (expr.kind !== 'number') {
        return false;
    }
    return expr.value < 0;
}
|
|
76
|
+
/**
 * Construct a number-literal node.
 *
 * @param value - numeric value to store
 * @returns a new `{ kind: 'number', value }` node
 */
export function makeNumber(value) {
    const literal = { kind: 'number', value: value };
    return literal;
}
|
|
82
|
+
/**
 * Construct a binary-operation node `left op right`.
 *
 * @param op - operator symbol ('+', '-', '*', '/', '^' or '**')
 * @param left - left operand, stored by reference
 * @param right - right operand, stored by reference
 * @returns a new node of kind 'binary'
 */
export function makeBinaryOp(op, left, right) {
    const node = { kind: 'binary', operator: op, left: left, right: right };
    return node;
}
|
|
93
|
+
/**
 * Construct a variable-reference node.
 *
 * @param name - variable name
 * @returns a new `{ kind: 'variable', name }` node
 */
export function makeVariable(name) {
    const ref = { kind: 'variable', name: name };
    return ref;
}
|
|
99
|
+
/**
 * Collect every variable name referenced in an expression.
 *
 * Traversal is depth-first, left to right, so the Set's insertion order
 * follows the order of first appearance. For a component access on a plain
 * variable (e.g. `v.x`), both `v` and the dotted name `v.x` are recorded,
 * in that order. Number literals contribute nothing.
 *
 * @param expr - expression to scan
 * @returns a Set of variable names (plain and dotted)
 */
export function getVariables(expr) {
    const names = new Set();
    const visit = (node) => {
        if (node.kind === 'variable') {
            names.add(node.name);
        } else if (node.kind === 'binary') {
            visit(node.left);
            visit(node.right);
        } else if (node.kind === 'unary') {
            visit(node.operand);
        } else if (node.kind === 'call') {
            node.args.forEach((arg) => visit(arg));
        } else if (node.kind === 'component') {
            visit(node.object);
            if (node.object.kind === 'variable') {
                names.add(`${node.object.name}.${node.component}`);
            }
        }
    };
    visit(expr);
    return names;
}
|
|
132
|
+
/**
 * Check whether an expression references a given variable.
 *
 * For a component access on a plain variable, the check matches either the
 * dotted name (`v.x`) or the base name (`v`). A component access on any
 * other expression is searched through its object.
 *
 * @param expr - expression to search
 * @param varName - plain or dotted variable name to look for
 * @returns true if found, false otherwise (nodes of an unrecognized kind
 *   fall through and yield undefined)
 */
export function containsVariable(expr, varName) {
    const kind = expr.kind;
    if (kind === 'number') {
        return false;
    }
    if (kind === 'variable') {
        return expr.name === varName;
    }
    if (kind === 'binary') {
        return containsVariable(expr.left, varName) || containsVariable(expr.right, varName);
    }
    if (kind === 'unary') {
        return containsVariable(expr.operand, varName);
    }
    if (kind === 'call') {
        return expr.args.some((arg) => containsVariable(arg, varName));
    }
    if (kind === 'component') {
        const target = expr.object;
        if (target.kind !== 'variable') {
            return containsVariable(target, varName);
        }
        const dotted = `${target.name}.${expr.component}`;
        return dotted === varName || target.name === varName;
    }
}
|
|
155
|
+
/**
 * Calculate the maximum nesting depth of an expression.
 *
 * Leaves (numbers, variables) have depth 1. Each enclosing binary, unary,
 * call or component node adds 1. A call with no arguments has depth 1.
 *
 * @param expr - expression to measure
 * @returns depth as a positive integer
 * @throws Error for a node with an unrecognized `kind` (otherwise the result
 *   would silently become `undefined` or NaN further up the tree)
 */
export function expressionDepth(expr) {
    switch (expr.kind) {
        case 'number':
        case 'variable':
            return 1;
        case 'binary':
            return 1 + Math.max(expressionDepth(expr.left), expressionDepth(expr.right));
        case 'unary':
            return 1 + expressionDepth(expr.operand);
        case 'call': {
            // Fold with a loop rather than spreading into Math.max, which is
            // bounded by the engine's maximum argument count for long lists.
            let deepest = 0;
            for (const arg of expr.args) {
                deepest = Math.max(deepest, expressionDepth(arg));
            }
            return 1 + deepest;
        }
        case 'component':
            return 1 + expressionDepth(expr.object);
        default:
            throw new Error(`expressionDepth: unknown expression kind '${String(expr.kind)}'`);
    }
}
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Numerical gradient checking for GradientScript DSL
|
|
3
|
+
* Validates symbolic gradients against finite difference approximations
|
|
4
|
+
*/
|
|
5
|
+
import { FunctionDef } from './AST.js';
|
|
6
|
+
import { TypeEnv } from './Types.js';
|
|
7
|
+
import { GradientResult } from './Differentiation.js';
|
|
8
|
+
/**
 * A value supplied at a test point. It is either a bare scalar or a
 * structured value whose named fields are numbers.
 */
type NumValue = number | Record<string, number>;
/**
 * Summary produced by GradientChecker.check.
 */
export interface GradCheckResult {
    /** Overall verdict of the check. */
    passed: boolean;
    /**
     * Per-entry comparison records.
     * NOTE(review): confirm in GradientChecker.js whether this lists every
     * comparison or only the failing ones.
     */
    errors: GradCheckError[];
    /** Largest error observed. TODO confirm whether absolute or relative. */
    maxError: number;
    /** Mean error across compared entries. TODO confirm the averaging basis. */
    meanError: number;
}
/**
 * One analytical-vs-numerical gradient comparison.
 */
export interface GradCheckError {
    /** Name of the parameter whose gradient was compared. */
    parameter: string;
    /** Component name for structured parameters (presumably absent for scalars). */
    component?: string;
    /** Gradient value produced by the symbolic derivative. */
    analytical: number;
    /** Gradient value estimated by finite differences. */
    numerical: number;
    /** Discrepancy between the two values. TODO confirm it is the absolute difference. */
    error: number;
    /** Discrepancy scaled relative to the gradient magnitude (exact formula not visible here). */
    relativeError: number;
}
/**
 * Checks symbolic gradients against finite-difference approximations.
 */
export declare class GradientChecker {
    private epsilon;
    private tolerance;
    /**
     * @param epsilon - optional; presumably the finite-difference step size
     * @param tolerance - optional; presumably the pass/fail error threshold
     */
    constructor(epsilon?: number, tolerance?: number);
    /**
     * Compare `gradients` for `func` against numerical estimates at `testPoint`.
     */
    check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
    /** Finite-difference gradient for a scalar parameter. */
    private numericalGradientScalar;
    /** Finite-difference gradient for one component of a structured parameter. */
    private numericalGradientComponent;
    /** Evaluate the function at a test point. */
    private evaluateFunction;
    /** Inline the function body using the test-point values. */
    private inlineWithValues;
    /** Substitute variables within an expression. */
    private substituteExpression;
    /** Numerically evaluate an expression. */
    private evaluateExpression;
    /** Numerically evaluate a math function call. */
    private evaluateMathFunction;
}
// Keeps this declaration file a module, so NumValue above stays module-private.
export {};
|