gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* AST nodes for GradientScript DSL
|
|
3
|
+
* Supports function definitions with structured types
|
|
4
|
+
*/
|
|
5
|
+
import { Type } from './Types.js';
|
|
6
|
+
/**
 * Base AST node shared by every node interface in this module.
 *
 * `type` is optional; presumably it is filled in by a type-inference pass after
 * parsing (NOTE(review): confirm against TypeInference).
 */
export interface ASTNode {
    /** Inferred/annotated type of this node, when one has been assigned. */
    type?: Type;
}
|
|
12
|
+
/**
 * Program (top-level): the root node, holding every function definition.
 */
export interface Program extends ASTNode {
    /** Discriminant tag for the root node. */
    kind: 'program';
    /** All function definitions in the program. */
    functions: FunctionDef[];
}
/**
 * Function definition: a name, a parameter list, a list of body statements and
 * a single return expression.
 */
export interface FunctionDef extends ASTNode {
    /** Discriminant tag. */
    kind: 'function';
    /** Function name. */
    name: string;
    /** Declared parameters. */
    parameters: Parameter[];
    /** Body statements; `Statement` is currently only `Assignment`. */
    body: Statement[];
    /** Expression whose value the function returns. */
    returnExpr: Expression;
}
|
|
29
|
+
/**
 * Function parameter.
 */
export interface Parameter {
    /** Parameter name. */
    name: string;
    /** Presumably marks parameters that gradients are generated for — confirm in Parser/Differentiation. */
    requiresGrad: boolean;
    /** Optional structured type annotation; presumably absent for scalar parameters — confirm. */
    paramType?: StructTypeAnnotation;
}
/**
 * Type annotation for structured types, given as the list of component names.
 */
export interface StructTypeAnnotation {
    /** Component names, e.g. ['x', 'y'] for a value accessed as `u.x` / `u.y`. */
    components: string[];
}
|
|
43
|
+
/**
 * Statement types. Assignment is currently the only statement form.
 */
export type Statement = Assignment;
/**
 * Assignment statement: binds `variable` to the value of `expression`.
 */
export interface Assignment extends ASTNode {
    /** Discriminant tag. */
    kind: 'assignment';
    /** Name of the variable being assigned. */
    variable: string;
    /** Right-hand side expression. */
    expression: Expression;
}
|
|
55
|
+
/**
 * Expression types: discriminated union of every expression node, tagged by `kind`.
 */
export type Expression = NumberLiteral | Variable | BinaryOp | UnaryOp | FunctionCall | ComponentAccess;
/**
 * Number literal.
 */
export interface NumberLiteral extends ASTNode {
    kind: 'number';
    /** Literal numeric value. */
    value: number;
}
/**
 * Variable reference, by name.
 */
export interface Variable extends ASTNode {
    kind: 'variable';
    /** Name of the referenced variable or parameter. */
    name: string;
}
/**
 * Binary operation `left operator right`.
 */
export interface BinaryOp extends ASTNode {
    kind: 'binary';
    // NOTE(review): '^' and '**' presumably both denote exponentiation — confirm in Parser.
    operator: '+' | '-' | '*' | '/' | '^' | '**';
    /** Left operand. */
    left: Expression;
    /** Right operand. */
    right: Expression;
}
/**
 * Unary sign operation.
 */
export interface UnaryOp extends ASTNode {
    kind: 'unary';
    operator: '-' | '+';
    /** Operand the sign applies to. */
    operand: Expression;
}
/**
 * Function call, e.g. a built-in such as `sin(x)` or `dot2d(u, v)`.
 */
export interface FunctionCall extends ASTNode {
    kind: 'call';
    /** Callee name. */
    name: string;
    /** Positional arguments. */
    args: Expression[];
}
/**
 * Component access (e.g., u.x, v.y).
 */
export interface ComponentAccess extends ASTNode {
    kind: 'component';
    /** Expression whose component is read; may be any expression, not only a variable. */
    object: Expression;
    /** Component name, e.g. 'x'. */
    component: string;
}
|
|
106
|
+
/**
 * Visitor pattern for AST traversal: one method per node kind, each returning `T`.
 */
export interface ASTVisitor<T> {
    visitProgram(node: Program): T;
    visitFunction(node: FunctionDef): T;
    visitAssignment(node: Assignment): T;
    visitNumber(node: NumberLiteral): T;
    visitVariable(node: Variable): T;
    visitBinary(node: BinaryOp): T;
    visitUnary(node: UnaryOp): T;
    visitCall(node: FunctionCall): T;
    visitComponent(node: ComponentAccess): T;
}
/**
 * Helper to visit any expression node.
 *
 * Dispatches on `expr.kind` to the matching `visit*` method of `visitor` and
 * returns its result. Only the six expression kinds are dispatched;
 * `visitProgram`, `visitFunction` and `visitAssignment` are never called here.
 */
export declare function visitExpression<T>(visitor: ASTVisitor<T>, expr: Expression): T;
|
package/dist/dsl/AST.js
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* AST nodes for GradientScript DSL
|
|
3
|
+
* Supports function definitions with structured types
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Dispatch an expression node to the visitor method matching its `kind`.
 *
 * @param visitor - Object implementing the expression methods of ASTVisitor
 *   (`visitNumber`, `visitVariable`, `visitBinary`, `visitUnary`, `visitCall`,
 *   `visitComponent`).
 * @param expr - Expression node to visit.
 * @returns Whatever the selected visitor method returns.
 * @throws Error when `expr.kind` is not a known expression kind.
 */
export function visitExpression(visitor, expr) {
    switch (expr.kind) {
        case 'number':
            return visitor.visitNumber(expr);
        case 'variable':
            return visitor.visitVariable(expr);
        case 'binary':
            return visitor.visitBinary(expr);
        case 'unary':
            return visitor.visitUnary(expr);
        case 'call':
            return visitor.visitCall(expr);
        case 'component':
            return visitor.visitComponent(expr);
        default:
            // Fail loudly instead of silently returning undefined, which would
            // violate the declared return type and hide malformed nodes.
            throw new Error(`Unknown expression kind: ${expr.kind}`);
    }
}
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Built-in functions for GradientScript DSL
|
|
3
|
+
* Defines dot2d, cross2d, magnitude2d, etc.
|
|
4
|
+
*/
|
|
5
|
+
import { Type } from './Types.js';
|
|
6
|
+
/**
 * Information about where a built-in function is not smooth (e.g. a kink or
 * branch cut), so gradients there may be unreliable.
 */
export interface DiscontinuityInfo {
    /** Human-readable description, e.g. 'Branch cut discontinuity'. */
    description: string;
    /** Human-readable condition under which it occurs, e.g. 'a = b'. */
    condition: string;
}
/**
 * Signature of one overload of a built-in function.
 */
export interface BuiltInSignature {
    /** Function name as written in the DSL. */
    name: string;
    /** Parameter types, matched positionally and exactly against argument types by `lookup`. */
    params: Type[];
    /** Result type. */
    returnType: Type;
    // Optional evaluator; none of the built-ins registered by BuiltInRegistry sets it.
    // NOTE(review): `any` here leaks into consumers — consider `unknown` in the source .ts.
    implementation?: (args: any[]) => any;
    /** Known non-smooth points of this overload, if any. */
    discontinuities?: DiscontinuityInfo[];
}
|
|
23
|
+
/**
 * Registry of all built-in functions, keyed by name; each name maps to a list
 * of overloads.
 */
export declare class BuiltInRegistry {
    private functions;
    /** Creates the registry and registers every built-in. */
    constructor();
    /**
     * Register all built-in functions
     */
    private registerAll;
    /**
     * Register a built-in function: appends `sig` to the overload list for `sig.name`.
     */
    register(sig: BuiltInSignature): void;
    /**
     * Look up a built-in function by name and argument types.
     * Returns the first overload (in registration order) whose parameter count and
     * types match exactly, or `undefined` if the name is unknown or nothing matches.
     */
    lookup(name: string, argTypes: Type[]): BuiltInSignature | undefined;
    /**
     * Check if argument types match parameter types (same length, pairwise equal).
     */
    private matchesSignature;
    /**
     * Check if a function is built-in (by name only, regardless of argument types).
     */
    isBuiltIn(name: string): boolean;
    /**
     * Get all overloads for a function name; `[]` for unknown names.
     * NOTE(review): returns the registry's internal array, so callers can mutate it.
     */
    getOverloads(name: string): BuiltInSignature[];
    /**
     * Get discontinuity information for a function, concatenated across all of its
     * overloads; `[]` for unknown names or smooth functions.
     */
    getDiscontinuities(name: string): DiscontinuityInfo[];
}
/** Module-level registry instance, created when the module is first evaluated. */
export declare const builtIns: BuiltInRegistry;
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Built-in functions for GradientScript DSL
|
|
3
|
+
* Defines dot2d, cross2d, magnitude2d, etc.
|
|
4
|
+
*/
|
|
5
|
+
import { Types } from './Types.js';
|
|
6
|
+
/**
 * Registry of all built-in functions.
 *
 * Each name maps to a list of overloads (`BuiltInSignature`s) in registration
 * order. `lookup` returns the first overload whose parameter types equal the
 * argument types exactly (via `Types.equals`).
 */
export class BuiltInRegistry {
    /** name -> overloads, in registration order. */
    functions = new Map();
    constructor() {
        this.registerAll();
    }
    /**
     * Register all built-in functions.
     *
     * Functions that are not smooth everywhere carry `discontinuities` metadata,
     * presumably consumed by the discontinuity analysis to warn about points where
     * gradients are unreliable.
     */
    registerAll() {
        // 2D operations
        this.register({
            name: 'dot2d',
            params: [Types.vec2(), Types.vec2()],
            returnType: Types.scalar()
        });
        this.register({
            name: 'cross2d',
            params: [Types.vec2(), Types.vec2()],
            returnType: Types.scalar()
        });
        this.register({
            name: 'magnitude2d',
            params: [Types.vec2()],
            returnType: Types.scalar()
        });
        this.register({
            name: 'normalize2d',
            params: [Types.vec2()],
            returnType: Types.vec2()
        });
        this.register({
            name: 'distance2d',
            params: [Types.vec2(), Types.vec2()],
            returnType: Types.scalar()
        });
        // 3D operations
        this.register({
            name: 'dot3d',
            params: [Types.vec3(), Types.vec3()],
            returnType: Types.scalar()
        });
        this.register({
            name: 'cross3d',
            params: [Types.vec3(), Types.vec3()],
            returnType: Types.vec3()
        });
        this.register({
            name: 'magnitude3d',
            params: [Types.vec3()],
            returnType: Types.scalar()
        });
        this.register({
            name: 'normalize3d',
            params: [Types.vec3()],
            returnType: Types.vec3()
        });
        // Math functions (scalar only)
        const scalarMath = [
            'sin', 'cos', 'tan',
            'asin', 'acos', 'atan',
            'exp', 'log', 'sqrt',
            'abs'
        ];
        // Non-smooth points of the unary functions above. abs(x) has a kink at 0,
        // the same kind of non-smoothness that max/min have at a = b.
        const scalarMathDiscontinuities = new Map([
            ['abs', [{
                    description: 'Non-smooth at zero',
                    condition: 'x = 0'
                }]]
        ]);
        for (const name of scalarMath) {
            const discontinuities = scalarMathDiscontinuities.get(name);
            this.register({
                name,
                params: [Types.scalar()],
                returnType: Types.scalar(),
                // Only attach the key when there is something to report, like the
                // other smooth built-ins.
                ...(discontinuities ? { discontinuities } : {})
            });
        }
        // Binary math functions
        this.register({
            name: 'atan2',
            params: [Types.scalar(), Types.scalar()],
            returnType: Types.scalar(),
            discontinuities: [{
                    description: 'Branch cut discontinuity',
                    condition: 'x < 0 and y ≈ 0 (near ±180°)'
                }]
        });
        this.register({
            name: 'pow',
            params: [Types.scalar(), Types.scalar()],
            returnType: Types.scalar()
        });
        // min is non-smooth at a = b exactly like max; previously it declared nothing.
        this.register({
            name: 'min',
            params: [Types.scalar(), Types.scalar()],
            returnType: Types.scalar(),
            discontinuities: [{
                    description: 'Non-smooth at equality',
                    condition: 'a = b'
                }]
        });
        this.register({
            name: 'max',
            params: [Types.scalar(), Types.scalar()],
            returnType: Types.scalar(),
            discontinuities: [{
                    description: 'Non-smooth at equality',
                    condition: 'a = b'
                }]
        });
        this.register({
            name: 'clamp',
            params: [Types.scalar(), Types.scalar(), Types.scalar()],
            returnType: Types.scalar(),
            discontinuities: [{
                    description: 'Non-smooth at boundaries',
                    condition: 'x = min or x = max'
                }]
        });
    }
    /**
     * Register a built-in function, appending it as a new overload of `sig.name`.
     */
    register(sig) {
        if (!this.functions.has(sig.name)) {
            this.functions.set(sig.name, []);
        }
        this.functions.get(sig.name).push(sig);
    }
    /**
     * Look up a built-in function by name and argument types.
     *
     * @returns The first matching overload, or undefined if the name is unknown or
     *   no overload's parameter types match exactly.
     */
    lookup(name, argTypes) {
        const overloads = this.functions.get(name);
        if (!overloads)
            return undefined;
        // Find matching overload
        for (const sig of overloads) {
            if (this.matchesSignature(argTypes, sig.params)) {
                return sig;
            }
        }
        return undefined;
    }
    /**
     * Check if argument types match parameter types: same arity and pairwise
     * `Types.equals`.
     */
    matchesSignature(argTypes, paramTypes) {
        if (argTypes.length !== paramTypes.length)
            return false;
        return argTypes.every((argType, i) => {
            return Types.equals(argType, paramTypes[i]);
        });
    }
    /**
     * Check if a function is built-in (by name only).
     */
    isBuiltIn(name) {
        return this.functions.has(name);
    }
    /**
     * Get all overloads for a function name; [] for unknown names.
     * NOTE(review): returns the internal array, so callers can mutate the registry.
     */
    getOverloads(name) {
        return this.functions.get(name) ?? [];
    }
    /**
     * Get discontinuity information for a function, concatenated across all of
     * its overloads; [] for unknown or smooth functions.
     */
    getDiscontinuities(name) {
        const overloads = this.functions.get(name);
        if (!overloads)
            return [];
        const discontinuities = [];
        for (const sig of overloads) {
            if (sig.discontinuities) {
                discontinuities.push(...sig.discontinuities);
            }
        }
        return discontinuities;
    }
}
|
|
180
|
+
/**
 * Module-level built-in registry, created once when this module is first
 * evaluated and shared by every importer of `builtIns`.
 */
const sharedRegistry = new BuiltInRegistry();
export { sharedRegistry as builtIns };
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Common Subexpression Elimination (CSE)
|
|
3
|
+
* Identifies repeated expressions and factors them out
|
|
4
|
+
*/
|
|
5
|
+
import { Expression } from './AST.js';
|
|
6
|
+
/**
 * Result of `eliminateCommonSubexpressions`.
 */
export interface CSEResult {
    /** Hoisted temporaries: generated name (`_tmp0`, `_tmp1`, ...) -> the expression it stands for. */
    intermediates: Map<string, Expression>;
    /** The input expression with hoisted subtrees replaced by references to those temporaries. */
    simplified: Expression;
}
/**
 * Result of `eliminateCommonSubexpressionsStructured`.
 */
export interface StructuredCSEResult {
    /** Temporaries shared by all components: generated name -> expression. */
    intermediates: Map<string, Expression>;
    /** Component name -> that component's expression rewritten to use the temporaries. */
    components: Map<string, Expression>;
}
|
|
14
|
+
/**
 * Perform CSE on an expression.
 *
 * Subexpressions that occur at least `minCount` times (default 2) and are not
 * trivial (numbers, variables, `variable.component`, and unary signs of those)
 * are hoisted into temporaries named `_tmp0`, `_tmp1`, ...
 *
 * NOTE(review): temporary names restart at `_tmp0` on every call, so results of
 * separate calls cannot share one scope without renaming — confirm how callers use them.
 */
export declare function eliminateCommonSubexpressions(expr: Expression, minCount?: number): CSEResult;
/**
 * Perform CSE on structured gradients (for structured types like {x, y}).
 *
 * Occurrences are counted across all components together, so a subexpression
 * repeated between components is hoisted once. Same `minCount` default and
 * `_tmpN` naming as `eliminateCommonSubexpressions`.
 */
export declare function eliminateCommonSubexpressionsStructured(components: Map<string, Expression>, minCount?: number): StructuredCSEResult;
|
package/dist/dsl/CSE.js
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Common Subexpression Elimination (CSE)
|
|
3
|
+
* Identifies repeated expressions and factors them out
|
|
4
|
+
*/
|
|
5
|
+
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
6
|
+
/**
 * Serializes expressions to a canonical string form, so that structurally
 * identical subtrees map to the same key (used as the identity for CSE).
 */
class ExpressionSerializer {
    /**
     * @param expr - Expression node to serialize.
     * @returns Canonical key string for `expr`.
     * @throws Error when `expr.kind` is not a known expression kind.
     */
    serialize(expr) {
        switch (expr.kind) {
            case 'number':
                return `num(${expr.value})`;
            case 'variable':
                return `var(${expr.name})`;
            case 'binary': {
                const left = this.serialize(expr.left);
                const right = this.serialize(expr.right);
                return `bin(${expr.operator},${left},${right})`;
            }
            case 'unary': {
                const operand = this.serialize(expr.operand);
                return `un(${expr.operator},${operand})`;
            }
            case 'call': {
                const args = expr.args.map(arg => this.serialize(arg)).join(',');
                return `call(${expr.name},${args})`;
            }
            case 'component': {
                const object = this.serialize(expr.object);
                return `comp(${object},${expr.component})`;
            }
            default:
                // Without this, unknown nodes serialize to undefined. They would all
                // share a single key and be wrongly merged by CSE.
                throw new Error(`Cannot serialize unknown expression kind: ${expr.kind}`);
        }
    }
}
|
|
33
|
+
/**
 * Perform common subexpression elimination on a single expression.
 *
 * Every subexpression seen at least `minCount` times that `shouldExtract`
 * accepts is bound to a fresh temporary named `_tmp0`, `_tmp1`, ... in the
 * order the counter first encountered it.
 *
 * @param expr - Root expression to optimize.
 * @param minCount - Minimum number of occurrences before hoisting (default 2).
 * @returns `intermediates` (temp name -> hoisted expression) and `simplified`
 *   (the root with hoisted subtrees replaced by references to the temps).
 */
export function eliminateCommonSubexpressions(expr, minCount = 2) {
    const counter = new ExpressionCounter();
    counter.count(expr);
    const intermediates = new Map();
    const subexprMap = new Map();
    let nextId = 0;
    counter.counts.forEach((occurrences, key) => {
        if (occurrences < minCount) {
            return;
        }
        const candidate = counter.expressions.get(key);
        if (!candidate || !shouldExtract(candidate)) {
            return;
        }
        const tempName = `_tmp${nextId}`;
        nextId += 1;
        intermediates.set(tempName, candidate);
        subexprMap.set(key, tempName);
    });
    return { intermediates, simplified: substituteExpressions(expr, subexprMap, counter) };
}
|
|
55
|
+
/**
 * Perform CSE jointly over the components of a structured value (e.g. {x, y}).
 *
 * Occurrences are counted across every component, so a subexpression shared
 * between components is hoisted once into a `_tmpN` temporary.
 *
 * @param components - Component name -> expression.
 * @param minCount - Minimum number of occurrences before hoisting (default 2).
 * @returns `intermediates` (temp name -> hoisted expression) and `components`
 *   (component name -> rewritten expression, in the input's order).
 */
export function eliminateCommonSubexpressionsStructured(components, minCount = 2) {
    const counter = new ExpressionCounter();
    components.forEach(componentExpr => {
        counter.count(componentExpr);
    });
    const intermediates = new Map();
    const subexprMap = new Map();
    let nextId = 0;
    counter.counts.forEach((occurrences, key) => {
        if (occurrences < minCount) {
            return;
        }
        const candidate = counter.expressions.get(key);
        if (!candidate || !shouldExtract(candidate)) {
            return;
        }
        const tempName = `_tmp${nextId}`;
        nextId += 1;
        intermediates.set(tempName, candidate);
        subexprMap.set(key, tempName);
    });
    const rewrittenComponents = new Map();
    components.forEach((componentExpr, componentName) => {
        rewrittenComponents.set(componentName, substituteExpressions(componentExpr, subexprMap, counter));
    });
    return { intermediates, components: rewrittenComponents };
}
|
|
82
|
+
/**
 * Decide whether a repeated subexpression is worth hoisting into a temporary.
 *
 * - Binary operations and calls: always.
 * - Unary sign: only if its operand would be extracted.
 * - Component access: only when the object is not a plain variable (`u.x` is kept inline).
 * - Numbers, variables and anything unrecognized: never.
 */
function shouldExtract(expr) {
    const kind = expr.kind;
    if (kind === 'binary' || kind === 'call') {
        return true;
    }
    if (kind === 'unary') {
        return shouldExtract(expr.operand);
    }
    if (kind === 'component') {
        return expr.object.kind !== 'variable';
    }
    return false;
}
|
|
102
|
+
/**
 * Walks an expression tree using the ExpressionTransformer traversal and tallies
 * how often each structurally distinct subexpression occurs.
 *
 * State:
 * - `counts`: serialized key -> number of occurrences seen so far.
 * - `expressions`: serialized key -> the first node encountered with that key.
 * - `serializer`: produces the keys; `serialize` exposes it to the substitution pass.
 */
class ExpressionCounter extends ExpressionTransformer {
    counts = new Map();
    expressions = new Map();
    serializer = new ExpressionSerializer();
    /** Traverse `expr`, accumulating counts; the transformer's return value is ignored. */
    count(expr) {
        this.transform(expr);
    }
    /** Canonical string key for `expr`. */
    serialize(expr) {
        return this.serializer.serialize(expr);
    }
    /** Bump the tally for `expr`, keeping it as the representative node on first sight. */
    recordExpression(expr) {
        const key = this.serialize(expr);
        const seen = this.counts.get(key);
        if (seen === undefined) {
            this.counts.set(key, 1);
            this.expressions.set(key, expr);
        }
        else {
            this.counts.set(key, seen + 1);
        }
    }
    // Leaves: record and hand the node back unchanged.
    visitNumber(node) {
        this.recordExpression(node);
        return node;
    }
    visitVariable(node) {
        this.recordExpression(node);
        return node;
    }
    // Interior nodes: record this node, then defer to the base class
    // (presumably it recurses into the children).
    visitBinaryOp(node) {
        this.recordExpression(node);
        return super.visitBinaryOp(node);
    }
    visitUnaryOp(node) {
        this.recordExpression(node);
        return super.visitUnaryOp(node);
    }
    visitFunctionCall(node) {
        this.recordExpression(node);
        return super.visitFunctionCall(node);
    }
    visitComponentAccess(node) {
        this.recordExpression(node);
        return super.visitComponentAccess(node);
    }
}
|
|
148
|
+
/**
 * Replace hoisted subexpressions inside `expr` with references to their temporaries.
 *
 * @param expr - Expression to rewrite.
 * @param subexprMap - Serialized key -> temporary variable name.
 * @param counter - The ExpressionCounter from the counting pass; its `serialize`
 *   and `expressions` (key -> representative node) are used here.
 * @returns A `variable` node when the whole of `expr` was hoisted; otherwise the
 *   result of applying each substitution in `subexprMap` order.
 */
function substituteExpressions(expr, subexprMap, counter) {
    const rootKey = counter.serialize(expr);
    if (subexprMap.has(rootKey)) {
        // The entire expression is itself a hoisted temporary.
        return { kind: 'variable', name: subexprMap.get(rootKey) };
    }
    let rewritten = expr;
    subexprMap.forEach((tempName, key) => {
        const target = counter.expressions.get(key);
        if (!target || counter.serialize(rewritten) === key) {
            return;
        }
        rewritten = substituteInExpression(rewritten, target, { kind: 'variable', name: tempName }, counter);
    });
    return rewritten;
}
|
|
168
|
+
/**
 * Transformer that replaces every subtree structurally equal to `pattern`
 * (same serialized key) with `replacement`.
 */
class SubstitutionTransformer extends ExpressionTransformer {
    pattern;
    replacement;
    counter;
    /** Serialized key of `pattern`, computed once instead of at every visited node. */
    patternKey;
    /**
     * @param pattern - Subtree to look for.
     * @param replacement - Node substituted for each match.
     * @param counter - Provides `serialize`, the canonical key function used for matching.
     */
    constructor(pattern, replacement, counter) {
        super();
        this.pattern = pattern;
        this.replacement = replacement;
        this.counter = counter;
        // The pattern is loop-invariant during a transform; serializing it once removes
        // O(|pattern|) redundant work per node.
        this.patternKey = counter.serialize(pattern);
    }
    transform(expr) {
        if (this.counter.serialize(expr) === this.patternKey) {
            return this.replacement;
        }
        return super.transform(expr);
    }
}
|
|
188
|
+
/**
 * Rewrite `expr`, replacing every subtree that serializes identically to
 * `pattern` (according to `counter.serialize`) with `replacement`.
 */
function substituteInExpression(expr, pattern, replacement, counter) {
    return new SubstitutionTransformer(pattern, replacement, counter).transform(expr);
}
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Code generation for GradientScript DSL
|
|
3
|
+
* Generates TypeScript/JavaScript code with gradient functions
|
|
4
|
+
*/
|
|
5
|
+
import { Expression, FunctionDef } from './AST.js';
|
|
6
|
+
import { TypeEnv } from './Types.js';
|
|
7
|
+
import { GradientResult } from './Differentiation.js';
|
|
8
|
+
/**
 * Code generation options. Every field is optional; defaults live in CodeGen.js
 * (not visible in this review).
 */
export interface CodeGenOptions {
    /** Target language of the emitted source. */
    format?: 'typescript' | 'javascript' | 'python';
    /** Presumably toggles explanatory comments in the output — confirm in CodeGen.js. */
    includeComments?: boolean;
    /** Presumably runs algebraic simplification before emitting — confirm. */
    simplify?: boolean;
    /** Presumably enables common subexpression elimination (CSE.js) — confirm. */
    cse?: boolean;
    /** NOTE(review): presumably the tolerance used by emitted guards; confirm semantics and default. */
    epsilon?: number;
    /** Presumably emits runtime guards for numerically unsafe operations — confirm. */
    emitGuards?: boolean;
}
|
|
19
|
+
/**
 * Code generator for expressions: renders an AST `Expression` as source text
 * in the configured target language.
 */
export declare class ExpressionCodeGen {
    private format;
    /** `format` selects the target language; its default is defined in CodeGen.js. */
    constructor(format?: 'typescript' | 'javascript' | 'python');
    /**
     * Generate code for an expression.
     */
    generate(expr: Expression): string;
    private genNumber;
    private genVariable;
    private genBinary;
    /**
     * Generate expression with parentheses if needed based on precedence
     */
    private genWithPrecedence;
    /**
     * Determine if child expression needs parentheses
     */
    private needsParentheses;
    /**
     * Get operator precedence (higher number = higher precedence)
     */
    private getPrecedence;
    private genUnary;
    private genCall;
    private genComponent;
    // Presumably maps DSL function names to target-language equivalents — confirm in CodeGen.js.
    private mapFunctionName;
}
|
|
49
|
+
/**
 * Generate complete gradient function code for `func` from the computed `gradients`.
 * `env` presumably supplies the inferred types of parameters/variables — confirm.
 */
export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
/**
 * Generate the original forward function (no gradients).
 */
export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
/**
 * Generate complete output with both the forward and the gradient functions.
 */
export declare function generateComplete(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
|