gradient-script 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/README.md +515 -0
  2. package/dist/cli.d.ts +2 -0
  3. package/dist/cli.js +136 -0
  4. package/dist/dsl/AST.d.ts +123 -0
  5. package/dist/dsl/AST.js +23 -0
  6. package/dist/dsl/BuiltIns.d.ts +58 -0
  7. package/dist/dsl/BuiltIns.js +181 -0
  8. package/dist/dsl/CSE.d.ts +21 -0
  9. package/dist/dsl/CSE.js +194 -0
  10. package/dist/dsl/CodeGen.d.ts +60 -0
  11. package/dist/dsl/CodeGen.js +474 -0
  12. package/dist/dsl/Differentiation.d.ts +45 -0
  13. package/dist/dsl/Differentiation.js +421 -0
  14. package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
  15. package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
  16. package/dist/dsl/Errors.d.ts +22 -0
  17. package/dist/dsl/Errors.js +49 -0
  18. package/dist/dsl/Expander.d.ts +13 -0
  19. package/dist/dsl/Expander.js +220 -0
  20. package/dist/dsl/ExpressionTransformer.d.ts +54 -0
  21. package/dist/dsl/ExpressionTransformer.js +102 -0
  22. package/dist/dsl/ExpressionUtils.d.ts +55 -0
  23. package/dist/dsl/ExpressionUtils.js +175 -0
  24. package/dist/dsl/GradientChecker.d.ts +71 -0
  25. package/dist/dsl/GradientChecker.js +258 -0
  26. package/dist/dsl/Guards.d.ts +27 -0
  27. package/dist/dsl/Guards.js +206 -0
  28. package/dist/dsl/Inliner.d.ts +10 -0
  29. package/dist/dsl/Inliner.js +40 -0
  30. package/dist/dsl/Lexer.d.ts +63 -0
  31. package/dist/dsl/Lexer.js +243 -0
  32. package/dist/dsl/Parser.d.ts +92 -0
  33. package/dist/dsl/Parser.js +328 -0
  34. package/dist/dsl/Simplify.d.ts +17 -0
  35. package/dist/dsl/Simplify.js +276 -0
  36. package/dist/dsl/TypeInference.d.ts +39 -0
  37. package/dist/dsl/TypeInference.js +147 -0
  38. package/dist/dsl/Types.d.ts +58 -0
  39. package/dist/dsl/Types.js +114 -0
  40. package/dist/index.d.ts +13 -0
  41. package/dist/index.js +11 -0
  42. package/dist/symbolic/AST.d.ts +113 -0
  43. package/dist/symbolic/AST.js +128 -0
  44. package/dist/symbolic/CodeGen.d.ts +35 -0
  45. package/dist/symbolic/CodeGen.js +280 -0
  46. package/dist/symbolic/Parser.d.ts +64 -0
  47. package/dist/symbolic/Parser.js +329 -0
  48. package/dist/symbolic/Simplify.d.ts +10 -0
  49. package/dist/symbolic/Simplify.js +244 -0
  50. package/dist/symbolic/SymbolicDiff.d.ts +35 -0
  51. package/dist/symbolic/SymbolicDiff.js +339 -0
  52. package/package.json +56 -0
@@ -0,0 +1,123 @@
1
+ /**
2
+ * AST nodes for GradientScript DSL
3
+ * Supports function definitions with structured types
4
+ */
5
+ import { Type } from './Types.js';
6
/**
 * Shape shared by AST nodes that extend it.
 */
export interface ASTNode {
    /** Type attached to the node; optional, so a node may carry none. */
    type?: Type;
}

/**
 * Top-level node: the list of function definitions in a program.
 */
export interface Program extends ASTNode {
    kind: 'program';
    functions: FunctionDef[];
}

/**
 * One function definition: its name, its parameters, the statements of its
 * body, and the expression it returns.
 */
export interface FunctionDef extends ASTNode {
    kind: 'function';
    name: string;
    parameters: Parameter[];
    body: Statement[];
    returnExpr: Expression;
}

/**
 * A parameter of a function definition.
 * Note that this interface does not extend ASTNode.
 */
export interface Parameter {
    name: string;
    /** Gradient flag for this parameter (meaning taken from the name; consumers live in other files). */
    requiresGrad: boolean;
    /** Structured-type annotation, when the parameter has one. */
    paramType?: StructTypeAnnotation;
}

/**
 * Annotation describing a structured type by its component names.
 */
export interface StructTypeAnnotation {
    components: string[];
}

/**
 * Statement kinds. Assignment is currently the only member.
 */
export type Statement = Assignment;

/**
 * Assignment of an expression to a named variable.
 */
export interface Assignment extends ASTNode {
    kind: 'assignment';
    variable: string;
    expression: Expression;
}
55
/**
 * Every expression node kind, discriminated by the `kind` property.
 */
export type Expression = NumberLiteral | Variable | BinaryOp | UnaryOp | FunctionCall | ComponentAccess;

/**
 * Numeric literal.
 */
export interface NumberLiteral extends ASTNode {
    kind: 'number';
    value: number;
}

/**
 * Reference to a variable by name.
 */
export interface Variable extends ASTNode {
    kind: 'variable';
    name: string;
}

/**
 * Operation with two operands.
 */
export interface BinaryOp extends ASTNode {
    kind: 'binary';
    /** Both '^' and '**' are accepted operator spellings. */
    operator: '+' | '-' | '*' | '/' | '^' | '**';
    left: Expression;
    right: Expression;
}

/**
 * Operation with a single operand: a leading '-' or '+'.
 */
export interface UnaryOp extends ASTNode {
    kind: 'unary';
    operator: '-' | '+';
    operand: Expression;
}

/**
 * Call of a named function with a list of argument expressions.
 */
export interface FunctionCall extends ASTNode {
    kind: 'call';
    name: string;
    args: Expression[];
}

/**
 * Access to a named component of an expression (e.g., u.x, v.y).
 */
export interface ComponentAccess extends ASTNode {
    kind: 'component';
    object: Expression;
    component: string;
}
106
/**
 * Visitor over the AST: one method per node kind, each producing a `T`.
 */
export interface ASTVisitor<T> {
    // Program structure
    visitProgram(node: Program): T;
    visitFunction(node: FunctionDef): T;
    visitAssignment(node: Assignment): T;
    // Expressions
    visitNumber(node: NumberLiteral): T;
    visitVariable(node: Variable): T;
    visitBinary(node: BinaryOp): T;
    visitUnary(node: UnaryOp): T;
    visitCall(node: FunctionCall): T;
    visitComponent(node: ComponentAccess): T;
}
120
/**
 * Dispatches an expression node to a visitor.
 *
 * @param visitor - Visitor that receives the node.
 * @param expr - Expression node to visit.
 * @returns The value produced by the visitor.
 */
export declare function visitExpression<T>(visitor: ASTVisitor<T>, expr: Expression): T;
@@ -0,0 +1,23 @@
1
+ /**
2
+ * AST nodes for GradientScript DSL
3
+ * Supports function definitions with structured types
4
+ */
5
/**
 * Dispatch an expression node to the visitor method matching its `kind`.
 *
 * @param visitor - Object providing visitNumber, visitVariable, visitBinary,
 *   visitUnary, visitCall and visitComponent.
 * @param expr - Expression node; `expr.kind` selects the method to call.
 * @returns Whatever the selected visitor method returns.
 * @throws {Error} If `expr.kind` is not one of the six expression kinds.
 */
export function visitExpression(visitor, expr) {
    switch (expr.kind) {
        case 'number':
            return visitor.visitNumber(expr);
        case 'variable':
            return visitor.visitVariable(expr);
        case 'binary':
            return visitor.visitBinary(expr);
        case 'unary':
            return visitor.visitUnary(expr);
        case 'call':
            return visitor.visitCall(expr);
        case 'component':
            return visitor.visitComponent(expr);
        default:
            // Without this branch an unrecognised node silently produced
            // `undefined`; fail here, where the bad node is still at hand.
            throw new Error(`Unknown expression kind: ${String(expr.kind)}`);
    }
}
@@ -0,0 +1,58 @@
1
+ /**
2
+ * Built-in functions for GradientScript DSL
3
+ * Defines dot2d, cross2d, magnitude2d, etc.
4
+ */
5
+ import { Type } from './Types.js';
6
/**
 * Describes one discontinuity of a function.
 */
export interface DiscontinuityInfo {
    /** Short summary of the discontinuity. */
    description: string;
    /** Free-text statement of when it occurs. */
    condition: string;
}

/**
 * Signature of a built-in function: name, parameter types and return type,
 * with optional implementation and discontinuity metadata.
 */
export interface BuiltInSignature {
    name: string;
    params: Type[];
    returnType: Type;
    implementation?: (args: any[]) => any;
    discontinuities?: DiscontinuityInfo[];
}
23
/**
 * Registry holding the signatures of all built-in functions.
 */
export declare class BuiltInRegistry {
    private functions;
    constructor();
    /**
     * Registers the full set of built-in functions.
     */
    private registerAll;
    /**
     * Adds one built-in signature to the registry.
     */
    register(sig: BuiltInSignature): void;
    /**
     * Finds the signature registered under `name` whose parameters match
     * `argTypes`; yields undefined when there is none.
     */
    lookup(name: string, argTypes: Type[]): BuiltInSignature | undefined;
    /**
     * Tests argument types against parameter types.
     */
    private matchesSignature;
    /**
     * Reports whether `name` is a built-in function.
     */
    isBuiltIn(name: string): boolean;
    /**
     * Returns every overload registered under `name`.
     */
    getOverloads(name: string): BuiltInSignature[];
    /**
     * Returns the discontinuity information recorded for `name`.
     */
    getDiscontinuities(name: string): DiscontinuityInfo[];
}
58
/** Module-level BuiltInRegistry instance exported for shared use. */
export declare const builtIns: BuiltInRegistry;
@@ -0,0 +1,181 @@
1
+ /**
2
+ * Built-in functions for GradientScript DSL
3
+ * Defines dot2d, cross2d, magnitude2d, etc.
4
+ */
5
+ import { Types } from './Types.js';
6
/**
 * Registry of all built-in functions.
 *
 * Signatures are grouped by function name; each name maps to its overloads in
 * registration order. Constructing a registry registers the whole built-in set.
 */
export class BuiltInRegistry {
    /** Function name -> overload signatures, in registration order. */
    functions = new Map();
    constructor() {
        this.registerAll();
    }
    /**
     * Register all built-in functions: 2D and 3D vector operations, unary
     * scalar math, and multi-argument scalar math.
     */
    registerAll() {
        // Factories, not shared values: every parameter slot and return type
        // comes from its own call into Types.
        const scalar = () => Types.scalar();
        const vec2 = () => Types.vec2();
        const vec3 = () => Types.vec3();
        // Builds one signature and registers it. The `discontinuities` key is
        // only present on the signature when metadata is supplied.
        const define = (name, paramFactories, returnFactory, discontinuities = undefined) => {
            this.register({
                name,
                params: paramFactories.map((make) => make()),
                returnType: returnFactory(),
                ...(discontinuities ? { discontinuities } : {})
            });
        };
        // 2D operations
        define('dot2d', [vec2, vec2], scalar);
        define('cross2d', [vec2, vec2], scalar);
        define('magnitude2d', [vec2], scalar);
        define('normalize2d', [vec2], vec2);
        define('distance2d', [vec2, vec2], scalar);
        // 3D operations
        define('dot3d', [vec3, vec3], scalar);
        define('cross3d', [vec3, vec3], vec3);
        define('magnitude3d', [vec3], scalar);
        define('normalize3d', [vec3], vec3);
        // Math functions (scalar only)
        const scalarMath = [
            'sin', 'cos', 'tan',
            'asin', 'acos', 'atan',
            'exp', 'log', 'sqrt',
            'abs'
        ];
        for (const name of scalarMath) {
            define(name, [scalar], scalar);
        }
        // Binary math functions
        define('atan2', [scalar, scalar], scalar, [{
                description: 'Branch cut discontinuity',
                condition: 'x < 0 and y ≈ 0 (near ±180°)'
            }]);
        define('pow', [scalar, scalar], scalar);
        // min has the same kink at a = b as max; it previously carried no
        // metadata, so getDiscontinuities('min') wrongly reported none.
        define('min', [scalar, scalar], scalar, [{
                description: 'Non-smooth at equality',
                condition: 'a = b'
            }]);
        define('max', [scalar, scalar], scalar, [{
                description: 'Non-smooth at equality',
                condition: 'a = b'
            }]);
        define('clamp', [scalar, scalar, scalar], scalar, [{
                description: 'Non-smooth at boundaries',
                condition: 'x = min or x = max'
            }]);
    }
    /**
     * Register a built-in function.
     *
     * @param sig - Signature to add; it is appended to the overload list kept
     *   under `sig.name`, creating that list on first use.
     */
    register(sig) {
        const existing = this.functions.get(sig.name);
        if (existing) {
            existing.push(sig);
        }
        else {
            this.functions.set(sig.name, [sig]);
        }
    }
    /**
     * Look up a built-in function by name and argument types.
     *
     * @param name - Function name.
     * @param argTypes - Types of the call's arguments, in order.
     * @returns The first registered overload whose parameters match, or
     *   undefined when the name is unknown or no overload matches.
     */
    lookup(name, argTypes) {
        return this.functions
            .get(name)
            ?.find((sig) => this.matchesSignature(argTypes, sig.params));
    }
    /**
     * Check if argument types match parameter types: same count, and each
     * pair equal according to Types.equals.
     */
    matchesSignature(argTypes, paramTypes) {
        if (argTypes.length !== paramTypes.length) {
            return false;
        }
        return argTypes.every((argType, i) => Types.equals(argType, paramTypes[i]));
    }
    /**
     * Check if a function is built-in, i.e. has at least one registered
     * signature under `name`.
     */
    isBuiltIn(name) {
        return this.functions.has(name);
    }
    /**
     * Get all overloads for a function name.
     *
     * @returns The stored overload list itself (not a copy), or a new empty
     *   array when the name is unknown.
     */
    getOverloads(name) {
        return this.functions.get(name) || [];
    }
    /**
     * Get discontinuity information for a function.
     *
     * @returns A new array with the discontinuities of every overload of
     *   `name`, in registration order; empty when there are none.
     */
    getDiscontinuities(name) {
        const overloads = this.functions.get(name) || [];
        return overloads.flatMap((sig) => sig.discontinuities || []);
    }
}
180
// Shared registry instance; its constructor registers every built-in when this module loads
export const builtIns = new BuiltInRegistry();
@@ -0,0 +1,21 @@
1
+ /**
2
+ * Common Subexpression Elimination (CSE)
3
+ * Identifies repeated expressions and factors them out
4
+ */
5
+ import { Expression } from './AST.js';
6
/**
 * Result of running CSE on a single expression.
 */
export interface CSEResult {
    /** Temporary variable name -> the subexpression it stands for. */
    intermediates: Map<string, Expression>;
    /** The input expression after substitution. */
    simplified: Expression;
}

/**
 * Result of running CSE across the components of a structured value.
 */
export interface StructuredCSEResult {
    /** Temporary variable name -> the subexpression it stands for. */
    intermediates: Map<string, Expression>;
    /** Component name -> that component's expression after substitution. */
    components: Map<string, Expression>;
}
14
/**
 * Runs CSE on one expression.
 *
 * @param expr - Expression to process.
 * @param minCount - Optional occurrence threshold for extracting a subexpression.
 */
export declare function eliminateCommonSubexpressions(expr: Expression, minCount?: number): CSEResult;
/**
 * Runs CSE over structured gradients (structured types like {x, y}), so that
 * all components share one set of intermediates.
 *
 * @param components - Component name -> expression.
 * @param minCount - Optional occurrence threshold for extracting a subexpression.
 */
export declare function eliminateCommonSubexpressionsStructured(components: Map<string, Expression>, minCount?: number): StructuredCSEResult;
@@ -0,0 +1,194 @@
1
+ /**
2
+ * Common Subexpression Elimination (CSE)
3
+ * Identifies repeated expressions and factors them out
4
+ */
5
+ import { ExpressionTransformer } from './ExpressionTransformer.js';
6
/**
 * Turns an expression tree into a canonical string so that two expressions
 * can be compared by comparing their strings.
 * Kept as a standalone serializer rather than bending the type system to do it.
 */
class ExpressionSerializer {
    /**
     * @param expr - Expression node to serialize.
     * @returns The canonical string, or undefined for an unrecognised `kind`.
     */
    serialize(expr) {
        const { kind } = expr;
        if (kind === 'number') {
            return `num(${expr.value})`;
        }
        if (kind === 'variable') {
            return `var(${expr.name})`;
        }
        if (kind === 'binary') {
            const lhs = this.serialize(expr.left);
            const rhs = this.serialize(expr.right);
            return `bin(${expr.operator},${lhs},${rhs})`;
        }
        if (kind === 'unary') {
            const inner = this.serialize(expr.operand);
            return `un(${expr.operator},${inner})`;
        }
        if (kind === 'call') {
            const parts = expr.args.map((arg) => this.serialize(arg));
            return `call(${expr.name},${parts.join(',')})`;
        }
        if (kind === 'component') {
            const target = this.serialize(expr.object);
            return `comp(${target},${expr.component})`;
        }
        return undefined;
    }
}
33
/**
 * Run CSE on a single expression.
 *
 * Counts every subexpression, assigns a `_tmpN` name to each extractable one
 * seen at least `minCount` times, then substitutes those names into `expr`.
 *
 * @param expr - Expression to process.
 * @param minCount - Minimum number of occurrences required for extraction.
 * @returns `intermediates` (temp name -> subexpression) and the `simplified` expression.
 */
export function eliminateCommonSubexpressions(expr, minCount = 2) {
    const tally = new ExpressionCounter();
    tally.count(expr);
    const intermediates = new Map();
    const replacements = new Map();
    let nextTemp = 0;
    tally.counts.forEach((occurrences, key) => {
        if (!(occurrences >= minCount)) {
            return;
        }
        const candidate = tally.expressions.get(key);
        if (!candidate || !shouldExtract(candidate)) {
            return;
        }
        const tempName = `_tmp${nextTemp++}`;
        intermediates.set(tempName, candidate);
        replacements.set(key, tempName);
    });
    const simplified = substituteExpressions(expr, replacements, tally);
    return { intermediates, simplified };
}
55
/**
 * Run CSE across the components of a structured gradient (structured types
 * like {x, y}).
 *
 * Occurrences are counted over all components together, so a subexpression
 * shared between components is extracted once and reused by each of them.
 *
 * @param components - Component name -> expression.
 * @param minCount - Minimum number of occurrences required for extraction.
 * @returns `intermediates` (temp name -> subexpression) and the substituted `components`.
 */
export function eliminateCommonSubexpressionsStructured(components, minCount = 2) {
    const tally = new ExpressionCounter();
    components.forEach((componentExpr) => {
        tally.count(componentExpr);
    });
    const intermediates = new Map();
    const replacements = new Map();
    let nextTemp = 0;
    tally.counts.forEach((occurrences, key) => {
        if (!(occurrences >= minCount)) {
            return;
        }
        const candidate = tally.expressions.get(key);
        if (!candidate || !shouldExtract(candidate)) {
            return;
        }
        const tempName = `_tmp${nextTemp++}`;
        intermediates.set(tempName, candidate);
        replacements.set(key, tempName);
    });
    const rewritten = new Map();
    components.forEach((componentExpr, componentName) => {
        rewritten.set(componentName, substituteExpressions(componentExpr, replacements, tally));
    });
    return { intermediates, components: rewritten };
}
82
/**
 * Decide whether an expression is worth pulling out into a temporary.
 *
 * Binary operations and calls always qualify; numbers and variables never do.
 * A component access qualifies unless it reads directly from a variable, and
 * a unary operation qualifies exactly when its operand does.
 */
function shouldExtract(expr) {
    let node = expr;
    // A unary wrapper defers to its operand, so peel wrappers off first.
    while (node.kind === 'unary') {
        node = node.operand;
    }
    if (node.kind === 'binary' || node.kind === 'call') {
        return true;
    }
    if (node.kind === 'component') {
        return node.object.kind !== 'variable';
    }
    return false;
}
102
/**
 * Tallies how often each subexpression appears while walking an expression.
 *
 * `counts` maps a serialized expression to its number of occurrences and
 * `expressions` maps the same key to the first node recorded for it.
 * NOTE(review): the compound visit methods delegate to ExpressionTransformer,
 * which presumably recurses into child nodes — confirm against the base class.
 */
class ExpressionCounter extends ExpressionTransformer {
    counts = new Map();
    expressions = new Map();
    serializer = new ExpressionSerializer();
    /** Walk `expr`, recording it and the nodes visited beneath it. */
    count(expr) {
        this.transform(expr);
    }
    /** Canonical string key for `expr`. */
    serialize(expr) {
        return this.serializer.serialize(expr);
    }
    /** Bump the tally for `expr`, remembering the node the first time its key is seen. */
    recordExpression(expr) {
        const key = this.serialize(expr);
        const seenSoFar = this.counts.get(key) || 0;
        this.counts.set(key, seenSoFar + 1);
        if (this.expressions.has(key)) {
            return;
        }
        this.expressions.set(key, expr);
    }
    // Leaf nodes: record and hand the node back untouched.
    visitNumber(node) {
        this.recordExpression(node);
        return node;
    }
    visitVariable(node) {
        this.recordExpression(node);
        return node;
    }
    // Compound nodes: record first, then let the base class continue the walk.
    visitBinaryOp(node) {
        this.recordExpression(node);
        return super.visitBinaryOp(node);
    }
    visitUnaryOp(node) {
        this.recordExpression(node);
        return super.visitUnaryOp(node);
    }
    visitFunctionCall(node) {
        this.recordExpression(node);
        return super.visitFunctionCall(node);
    }
    visitComponentAccess(node) {
        this.recordExpression(node);
        return super.visitComponentAccess(node);
    }
}
148
/**
 * Replace common subexpressions inside `expr` with their temporary variables.
 *
 * @param expr - Expression to rewrite.
 * @param subexprMap - Serialized subexpression -> temporary variable name.
 * @param counter - Supplies `serialize` and the `expressions` map of recorded nodes.
 * @returns A variable node when `expr` itself is a mapped subexpression;
 *   otherwise `expr` with each mapped subexpression substituted in map order.
 */
function substituteExpressions(expr, subexprMap, counter) {
    const wholeKey = counter.serialize(expr);
    if (subexprMap.has(wholeKey)) {
        return { kind: 'variable', name: subexprMap.get(wholeKey) };
    }
    let current = expr;
    subexprMap.forEach((varName, exprStr) => {
        const target = counter.expressions.get(exprStr);
        // Skip keys without a recorded node, and never collapse the whole
        // running result into a single temporary.
        if (!target || counter.serialize(current) === exprStr) {
            return;
        }
        const standIn = { kind: 'variable', name: varName };
        current = substituteInExpression(current, target, standIn, counter);
    });
    return current;
}
168
/**
 * Transformer that substitutes a pattern with a replacement expression.
 *
 * A node matches when its serialized form equals the pattern's. The pattern
 * is serialized once at construction instead of once per visited node.
 */
class SubstitutionTransformer extends ExpressionTransformer {
    pattern;
    replacement;
    counter;
    /** Serialized form of `pattern`, computed once in the constructor. */
    patternKey;
    /**
     * @param pattern - Expression to search for.
     * @param replacement - Expression returned in place of each match.
     * @param counter - Object whose `serialize` produces comparison keys.
     */
    constructor(pattern, replacement, counter) {
        super();
        this.pattern = pattern;
        this.replacement = replacement;
        this.counter = counter;
        this.patternKey = counter.serialize(pattern);
    }
    /**
     * Return the replacement when `expr` matches the pattern; otherwise defer
     * to the base class transform.
     */
    transform(expr) {
        if (this.counter.serialize(expr) === this.patternKey) {
            return this.replacement;
        }
        return super.transform(expr);
    }
}
188
/**
 * Replace every occurrence of `pattern` inside `expr` with `replacement`.
 * Thin wrapper that runs a one-off SubstitutionTransformer over `expr`.
 */
function substituteInExpression(expr, pattern, replacement, counter) {
    return new SubstitutionTransformer(pattern, replacement, counter).transform(expr);
}
@@ -0,0 +1,60 @@
1
+ /**
2
+ * Code generation for GradientScript DSL
3
+ * Generates TypeScript/JavaScript code with gradient functions
4
+ */
5
+ import { Expression, FunctionDef } from './AST.js';
6
+ import { TypeEnv } from './Types.js';
7
+ import { GradientResult } from './Differentiation.js';
8
/**
 * Options accepted by the code generation functions. Every field is optional.
 * NOTE(review): field meanings beyond their names and types are set by the
 * implementation, which is not part of this declaration — confirm there.
 */
export interface CodeGenOptions {
    /** Output language. */
    format?: 'typescript' | 'javascript' | 'python';
    includeComments?: boolean;
    simplify?: boolean;
    cse?: boolean;
    epsilon?: number;
    emitGuards?: boolean;
}
19
/**
 * Generates source code text from expressions.
 */
export declare class ExpressionCodeGen {
    private format;
    constructor(format?: 'typescript' | 'javascript' | 'python');
    /**
     * Produces the code string for `expr`.
     */
    generate(expr: Expression): string;
    private genNumber;
    private genVariable;
    private genBinary;
    /**
     * Emits an expression, adding parentheses when precedence requires them.
     */
    private genWithPrecedence;
    /**
     * Decides whether a child expression must be parenthesized.
     */
    private needsParentheses;
    /**
     * Returns an operator's precedence (a higher number means higher precedence).
     */
    private getPrecedence;
    private genUnary;
    private genCall;
    private genComponent;
    private mapFunctionName;
}
49
/**
 * Produces the code of the gradient function for `func`.
 *
 * @param func - Function definition the gradients belong to.
 * @param gradients - Gradient result to emit.
 * @param env - Type environment.
 * @param options - Optional code generation settings.
 * @returns The generated code as a string.
 */
export declare function generateGradientFunction(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;
/**
 * Produces the code of the original forward function.
 *
 * @param func - Function definition to emit.
 * @param options - Optional code generation settings.
 * @returns The generated code as a string.
 */
export declare function generateForwardFunction(func: FunctionDef, options?: CodeGenOptions): string;
/**
 * Produces the complete output containing both the forward and the gradient
 * function.
 *
 * @param func - Function definition to emit.
 * @param gradients - Gradient result to emit.
 * @param env - Type environment.
 * @param options - Optional code generation settings.
 * @returns The generated code as a string.
 */
export declare function generateComplete(func: FunctionDef, gradients: GradientResult, env: TypeEnv, options?: CodeGenOptions): string;