gradient-script 0.2.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +219 -6
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +336 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -4,6 +4,47 @@
|
|
|
4
4
|
*/
|
|
5
5
|
import { Types } from './Types.js';
|
|
6
6
|
import { expandBuiltIn, shouldExpand } from './Expander.js';
|
|
7
|
+
/**
 * Format gradient check results as a human-readable string.
 *
 * Passing result: a single "✓" summary line reporting how many gradients were
 * verified (totalChecks minus failures minus singularities) and the max error,
 * followed by a "⚠" line when singularities were recorded.
 *
 * Failing result: a "✗" header line, then one line per failing parameter —
 * full analytical/numerical/error values when the parameter is a scalar
 * (exactly one error without a component), otherwise a compact
 * `{component:error, ...}` list — and a trailing "⚠" line if singularities
 * were also recorded.
 *
 * @param result - check outcome; reads `passed`, `errors`, `singularities`,
 *   `maxError` and `totalChecks` (the shape returned by `GradientChecker.check`)
 * @param funcName - function name shown in the header line
 * @returns the report lines joined with "\n"
 */
export function formatGradCheckResult(result, funcName) {
    const { errors, singularities, totalChecks } = result;
    const singularityCount = singularities.length;
    if (result.passed) {
        const verified = totalChecks - errors.length - singularityCount;
        const header = `✓ ${funcName}: ${verified} gradients verified (max error: ${result.maxError.toExponential(2)})`;
        return singularityCount > 0
            ? `${header}\n ⚠ ${singularityCount} singularities detected (both analytical and numerical produce NaN/Inf)`
            : header;
    }
    const out = [`✗ ${funcName}: ${errors.length}/${totalChecks} gradients FAILED`];
    // Bucket failures per parameter, preserving first-seen order.
    const grouped = new Map();
    for (const err of errors) {
        const bucket = grouped.get(err.parameter);
        if (bucket) {
            bucket.push(err);
        }
        else {
            grouped.set(err.parameter, [err]);
        }
    }
    grouped.forEach((errs, param) => {
        const [first] = errs;
        // A lone error without a component name came from a scalar parameter.
        const isScalar = errs.length === 1 && !first.component;
        out.push(isScalar
            ? ` ${param}: analytical=${first.analytical.toFixed(6)}, numerical=${first.numerical.toFixed(6)}, error=${first.error.toExponential(2)}`
            : ` ${param}: {${errs.map(e => `${e.component}:${e.error.toExponential(1)}`).join(', ')}}`);
    });
    if (singularityCount > 0) {
        out.push(` ⚠ ${singularityCount} singularities also detected`);
    }
    return out.join('\n');
}
|
|
7
48
|
/**
|
|
8
49
|
* Gradient checker
|
|
9
50
|
*/
|
|
@@ -19,6 +60,9 @@ export class GradientChecker {
|
|
|
19
60
|
*/
|
|
20
61
|
check(func, gradients, env, testPoint) {
|
|
21
62
|
const errors = [];
|
|
63
|
+
const singularities = [];
|
|
64
|
+
let totalChecks = 0;
|
|
65
|
+
let maxError = 0;
|
|
22
66
|
// For each parameter that has gradients
|
|
23
67
|
for (const [paramName, gradient] of gradients.gradients.entries()) {
|
|
24
68
|
const paramType = env.getOrThrow(paramName);
|
|
@@ -28,21 +72,21 @@ export class GradientChecker {
|
|
|
28
72
|
}
|
|
29
73
|
if (Types.isScalar(paramType)) {
|
|
30
74
|
// Scalar parameter
|
|
75
|
+
totalChecks++;
|
|
31
76
|
if (typeof paramValue !== 'number') {
|
|
32
77
|
throw new Error(`Expected scalar value for ${paramName}`);
|
|
33
78
|
}
|
|
34
79
|
const analytical = this.evaluateExpression(gradient, testPoint);
|
|
35
80
|
const numerical = this.numericalGradientScalar(func, testPoint, paramName);
|
|
36
|
-
const
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
});
|
|
81
|
+
const checkResult = this.compareGradients(analytical, numerical, paramName);
|
|
82
|
+
if (checkResult.type === 'singularity') {
|
|
83
|
+
singularities.push(checkResult.singularity);
|
|
84
|
+
}
|
|
85
|
+
else if (checkResult.type === 'error') {
|
|
86
|
+
errors.push(checkResult.error);
|
|
87
|
+
}
|
|
88
|
+
else if (checkResult.error) {
|
|
89
|
+
maxError = Math.max(maxError, checkResult.error.error);
|
|
46
90
|
}
|
|
47
91
|
}
|
|
48
92
|
else {
|
|
@@ -52,32 +96,74 @@ export class GradientChecker {
|
|
|
52
96
|
}
|
|
53
97
|
const structGrad = gradient;
|
|
54
98
|
for (const [comp, expr] of structGrad.components.entries()) {
|
|
99
|
+
totalChecks++;
|
|
55
100
|
const analytical = this.evaluateExpression(expr, testPoint);
|
|
56
101
|
const numerical = this.numericalGradientComponent(func, testPoint, paramName, comp);
|
|
57
|
-
const
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
relativeError
|
|
67
|
-
});
|
|
102
|
+
const checkResult = this.compareGradients(analytical, numerical, paramName, comp);
|
|
103
|
+
if (checkResult.type === 'singularity') {
|
|
104
|
+
singularities.push(checkResult.singularity);
|
|
105
|
+
}
|
|
106
|
+
else if (checkResult.type === 'error') {
|
|
107
|
+
errors.push(checkResult.error);
|
|
108
|
+
}
|
|
109
|
+
else if (checkResult.error) {
|
|
110
|
+
maxError = Math.max(maxError, checkResult.error.error);
|
|
68
111
|
}
|
|
69
112
|
}
|
|
70
113
|
}
|
|
71
114
|
}
|
|
72
|
-
const maxError = errors.length > 0 ? Math.max(...errors.map(e => e.error)) : 0;
|
|
73
115
|
const meanError = errors.length > 0
|
|
74
116
|
? errors.reduce((sum, e) => sum + e.error, 0) / errors.length
|
|
75
117
|
: 0;
|
|
76
118
|
return {
|
|
77
119
|
passed: errors.length === 0,
|
|
78
120
|
errors,
|
|
121
|
+
singularities,
|
|
79
122
|
maxError,
|
|
80
|
-
meanError
|
|
123
|
+
meanError,
|
|
124
|
+
totalChecks
|
|
125
|
+
};
|
|
126
|
+
}
|
|
127
|
+
/**
|
|
128
|
+
* Compare analytical and numerical gradients
|
|
129
|
+
* Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
|
|
130
|
+
*/
|
|
131
|
+
compareGradients(analytical, numerical, parameter, component) {
|
|
132
|
+
const analyticalBad = !isFinite(analytical);
|
|
133
|
+
const numericalBad = !isFinite(numerical);
|
|
134
|
+
// Both produce NaN/Inf: singularity (not a bug)
|
|
135
|
+
if (analyticalBad && numericalBad) {
|
|
136
|
+
return {
|
|
137
|
+
type: 'singularity',
|
|
138
|
+
singularity: { parameter, component, analytical, numerical }
|
|
139
|
+
};
|
|
140
|
+
}
|
|
141
|
+
// One produces NaN/Inf, the other doesn't: actual bug
|
|
142
|
+
if (analyticalBad || numericalBad) {
|
|
143
|
+
return {
|
|
144
|
+
type: 'error',
|
|
145
|
+
error: {
|
|
146
|
+
parameter,
|
|
147
|
+
component,
|
|
148
|
+
analytical,
|
|
149
|
+
numerical,
|
|
150
|
+
error: Infinity,
|
|
151
|
+
relativeError: Infinity
|
|
152
|
+
}
|
|
153
|
+
};
|
|
154
|
+
}
|
|
155
|
+
// Both finite: compare values
|
|
156
|
+
const error = Math.abs(analytical - numerical);
|
|
157
|
+
const relativeError = Math.abs(error / (Math.abs(numerical) + 1e-10));
|
|
158
|
+
if (error > this.tolerance && relativeError > this.tolerance) {
|
|
159
|
+
return {
|
|
160
|
+
type: 'error',
|
|
161
|
+
error: { parameter, component, analytical, numerical, error, relativeError }
|
|
162
|
+
};
|
|
163
|
+
}
|
|
164
|
+
return {
|
|
165
|
+
type: 'pass',
|
|
166
|
+
error: { parameter, component, analytical, numerical, error, relativeError }
|
|
81
167
|
};
|
|
82
168
|
}
|
|
83
169
|
/**
|
package/dist/dsl/Guards.d.ts
CHANGED
|
@@ -22,7 +22,7 @@ export declare function analyzeGuards(func: FunctionDef): GuardAnalysisResult;
|
|
|
22
22
|
/**
|
|
23
23
|
* Format guard analysis for display
|
|
24
24
|
*/
|
|
25
|
-
export declare function formatGuardWarnings(result: GuardAnalysisResult): string;
|
|
25
|
+
export declare function formatGuardWarnings(result: GuardAnalysisResult, asComments?: boolean): string;
|
|
26
26
|
/**
|
|
27
27
|
* Generate guard code snippets for common cases
|
|
28
28
|
*/
|
package/dist/dsl/Guards.js
CHANGED
|
@@ -154,21 +154,22 @@ function isSqExpression(expr) {
|
|
|
154
154
|
/**
|
|
155
155
|
* Format guard analysis for display
|
|
156
156
|
*/
|
|
157
|
-
export function formatGuardWarnings(result) {
|
|
157
|
+
export function formatGuardWarnings(result, asComments = false) {
|
|
158
158
|
if (!result.hasIssues) {
|
|
159
159
|
return '';
|
|
160
160
|
}
|
|
161
|
+
const prefix = asComments ? '// ' : '';
|
|
161
162
|
const lines = [];
|
|
162
|
-
lines.push(
|
|
163
|
-
lines.push(
|
|
164
|
-
lines.push(
|
|
165
|
-
lines.push(
|
|
166
|
-
lines.push(
|
|
163
|
+
lines.push(`${prefix}⚠️ EDGE CASE WARNINGS:`);
|
|
164
|
+
lines.push(prefix);
|
|
165
|
+
lines.push(`${prefix}The generated code may encounter edge cases that produce`);
|
|
166
|
+
lines.push(`${prefix}NaN, Infinity, or incorrect gradients:`);
|
|
167
|
+
lines.push(prefix);
|
|
167
168
|
// Show each guard individually with context
|
|
168
169
|
for (const guard of result.guards) {
|
|
169
170
|
const typeLabel = formatGuardType(guard.type);
|
|
170
171
|
// Show location and variable if available
|
|
171
|
-
let location =
|
|
172
|
+
let location = `${prefix} •`;
|
|
172
173
|
if (guard.line) {
|
|
173
174
|
location += ` Line ${guard.line}:`;
|
|
174
175
|
}
|
|
@@ -176,13 +177,13 @@ export function formatGuardWarnings(result) {
|
|
|
176
177
|
location += ` ${guard.variableName} =`;
|
|
177
178
|
}
|
|
178
179
|
lines.push(location);
|
|
179
|
-
lines.push(
|
|
180
|
-
lines.push(
|
|
181
|
-
lines.push(
|
|
180
|
+
lines.push(`${prefix} ${typeLabel}: ${guard.description}`);
|
|
181
|
+
lines.push(`${prefix} 💡 Fix: ${guard.suggestion}`);
|
|
182
|
+
lines.push(prefix);
|
|
182
183
|
}
|
|
183
|
-
lines.push(
|
|
184
|
-
lines.push(
|
|
185
|
-
lines.push(
|
|
184
|
+
lines.push(`${prefix}Add runtime checks or ensure inputs are within valid ranges.`);
|
|
185
|
+
lines.push(`${prefix}Use --guards --epsilon 1e-10 to automatically emit epsilon guards.`);
|
|
186
|
+
lines.push(prefix);
|
|
186
187
|
return lines.join('\n');
|
|
187
188
|
}
|
|
188
189
|
function formatGuardType(type) {
|
package/dist/dsl/Inliner.d.ts
CHANGED
|
@@ -8,3 +8,8 @@ import { Expression, FunctionDef } from './AST.js';
|
|
|
8
8
|
* Returns a new expression with all intermediate variables substituted
|
|
9
9
|
*/
|
|
10
10
|
export declare function inlineIntermediateVariables(func: FunctionDef): Expression;
|
|
11
|
+
/**
 * Inline an expression using a substitution map
 * Used to get the fully-expanded form of forward pass expressions
 *
 * The implementation runs a `VariableSubstitutionTransformer` over `expr`.
 * NOTE(review): whether substituted expressions are themselves re-substituted
 * (transitive inlining) depends on that transformer — confirm in Inliner.js.
 *
 * @param expr - expression whose variable references should be replaced
 * @param substitutions - variable name → replacement expression
 * @returns the transformed expression
 */
export declare function inlineExpression(expr: Expression, substitutions: Map<string, Expression>): Expression;
|
package/dist/dsl/Inliner.js
CHANGED
|
@@ -39,3 +39,11 @@ export function inlineIntermediateVariables(func) {
|
|
|
39
39
|
const transformer = new VariableSubstitutionTransformer(substitutions);
|
|
40
40
|
return transformer.transform(func.returnExpr);
|
|
41
41
|
}
|
|
42
|
+
/**
 * Inline an expression using a substitution map.
 *
 * Used to obtain the fully-expanded form of forward-pass expressions: every
 * variable that has an entry in `substitutions` is replaced by the mapped
 * expression via a fresh `VariableSubstitutionTransformer`.
 *
 * @param expr - expression to rewrite
 * @param substitutions - variable name → replacement expression
 * @returns the rewritten expression
 */
export function inlineExpression(expr, substitutions) {
    return new VariableSubstitutionTransformer(substitutions).transform(expr);
}
|
package/dist/dsl/Simplify.d.ts
CHANGED
|
@@ -15,3 +15,10 @@ export declare function simplifyGradients(gradients: Map<string, Expression | {
|
|
|
15
15
|
}>): Map<string, Expression | {
|
|
16
16
|
components: Map<string, Expression>;
|
|
17
17
|
}>;
|
|
18
|
+
/**
 * Post-CSE simplification: applies rules that were intentionally skipped during
 * initial simplification to avoid interfering with CSE.
 *
 * Specifically: a + a → 2 * a (now safe because CSE has already extracted temps)
 *
 * Operands are simplified first. The rewrite therefore also fires when the two
 * operands of a `+` become structurally equal (per `expressionsEqual`) only
 * after their own simplification. Binary nodes whose children did not change
 * are returned as the same object.
 *
 * @param expr - expression to simplify (typically emitted after CSE)
 * @returns the simplified expression
 */
export declare function simplifyPostCSE(expr: Expression): Expression;
|
package/dist/dsl/Simplify.js
CHANGED
|
@@ -46,6 +46,9 @@ class Simplifier extends ExpressionTransformer {
|
|
|
46
46
|
return right;
|
|
47
47
|
if (rightNum === 0)
|
|
48
48
|
return left;
|
|
49
|
+
// Note: a + a → 2 * a rules are intentionally NOT applied here
|
|
50
|
+
// because they flatten expression structure and interfere with CSE.
|
|
51
|
+
// The CSE pass will extract common subexpressions instead.
|
|
49
52
|
}
|
|
50
53
|
// Subtraction rules
|
|
51
54
|
if (expr.operator === '-') {
|
|
@@ -57,6 +60,25 @@ class Simplifier extends ExpressionTransformer {
|
|
|
57
60
|
if (expressionsEqual(left, right)) {
|
|
58
61
|
return { kind: 'number', value: 0 };
|
|
59
62
|
}
|
|
63
|
+
// a - (-b) → a + b
|
|
64
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
65
|
+
return this.transform({
|
|
66
|
+
kind: 'binary',
|
|
67
|
+
operator: '+',
|
|
68
|
+
left,
|
|
69
|
+
right: right.operand
|
|
70
|
+
});
|
|
71
|
+
}
|
|
72
|
+
// (-a) - (-b) → b - a
|
|
73
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
74
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
75
|
+
return this.transform({
|
|
76
|
+
kind: 'binary',
|
|
77
|
+
operator: '-',
|
|
78
|
+
left: right.operand,
|
|
79
|
+
right: left.operand
|
|
80
|
+
});
|
|
81
|
+
}
|
|
60
82
|
}
|
|
61
83
|
// Multiplication rules
|
|
62
84
|
if (expr.operator === '*') {
|
|
@@ -68,6 +90,50 @@ class Simplifier extends ExpressionTransformer {
|
|
|
68
90
|
return right;
|
|
69
91
|
if (rightNum === 1)
|
|
70
92
|
return left;
|
|
93
|
+
// -1 * x → -x
|
|
94
|
+
if (leftNum === -1) {
|
|
95
|
+
return { kind: 'unary', operator: '-', operand: right };
|
|
96
|
+
}
|
|
97
|
+
// x * -1 → -x
|
|
98
|
+
if (rightNum === -1) {
|
|
99
|
+
return { kind: 'unary', operator: '-', operand: left };
|
|
100
|
+
}
|
|
101
|
+
// (-a) * (-b) → a * b
|
|
102
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
103
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
104
|
+
return this.transform({
|
|
105
|
+
kind: 'binary',
|
|
106
|
+
operator: '*',
|
|
107
|
+
left: left.operand,
|
|
108
|
+
right: right.operand
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
// (-a) * b → -(a * b)
|
|
112
|
+
if (left.kind === 'unary' && left.operator === '-') {
|
|
113
|
+
return {
|
|
114
|
+
kind: 'unary',
|
|
115
|
+
operator: '-',
|
|
116
|
+
operand: this.transform({
|
|
117
|
+
kind: 'binary',
|
|
118
|
+
operator: '*',
|
|
119
|
+
left: left.operand,
|
|
120
|
+
right
|
|
121
|
+
})
|
|
122
|
+
};
|
|
123
|
+
}
|
|
124
|
+
// a * (-b) → -(a * b)
|
|
125
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
126
|
+
return {
|
|
127
|
+
kind: 'unary',
|
|
128
|
+
operator: '-',
|
|
129
|
+
operand: this.transform({
|
|
130
|
+
kind: 'binary',
|
|
131
|
+
operator: '*',
|
|
132
|
+
left,
|
|
133
|
+
right: right.operand
|
|
134
|
+
})
|
|
135
|
+
};
|
|
136
|
+
}
|
|
71
137
|
// (x / x) * y → y
|
|
72
138
|
if (left.kind === 'binary' && left.operator === '/') {
|
|
73
139
|
if (expressionsEqual(left.left, left.right)) {
|
|
@@ -116,6 +182,42 @@ class Simplifier extends ExpressionTransformer {
|
|
|
116
182
|
if (expressionsEqual(left, right)) {
|
|
117
183
|
return { kind: 'number', value: 1 };
|
|
118
184
|
}
|
|
185
|
+
// (-a) / (-b) → a / b
|
|
186
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
187
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
188
|
+
return this.transform({
|
|
189
|
+
kind: 'binary',
|
|
190
|
+
operator: '/',
|
|
191
|
+
left: left.operand,
|
|
192
|
+
right: right.operand
|
|
193
|
+
});
|
|
194
|
+
}
|
|
195
|
+
// (-a) / b → -(a / b)
|
|
196
|
+
if (left.kind === 'unary' && left.operator === '-') {
|
|
197
|
+
return {
|
|
198
|
+
kind: 'unary',
|
|
199
|
+
operator: '-',
|
|
200
|
+
operand: this.transform({
|
|
201
|
+
kind: 'binary',
|
|
202
|
+
operator: '/',
|
|
203
|
+
left: left.operand,
|
|
204
|
+
right
|
|
205
|
+
})
|
|
206
|
+
};
|
|
207
|
+
}
|
|
208
|
+
// a / (-b) → -(a / b)
|
|
209
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
210
|
+
return {
|
|
211
|
+
kind: 'unary',
|
|
212
|
+
operator: '-',
|
|
213
|
+
operand: this.transform({
|
|
214
|
+
kind: 'binary',
|
|
215
|
+
operator: '/',
|
|
216
|
+
left,
|
|
217
|
+
right: right.operand
|
|
218
|
+
})
|
|
219
|
+
};
|
|
220
|
+
}
|
|
119
221
|
// (a + a) / 2 → a
|
|
120
222
|
if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
|
|
121
223
|
if (expressionsEqual(left.left, left.right)) {
|
|
@@ -321,3 +423,37 @@ export function simplifyGradients(gradients) {
|
|
|
321
423
|
}
|
|
322
424
|
return simplified;
|
|
323
425
|
}
|
|
426
|
+
/**
 * Post-CSE simplification pass.
 *
 * Applies rewrites that the main simplifier deliberately skips because they
 * would flatten structure before CSE runs. Currently this is only
 * `a + a → 2 * a`, which is safe once CSE has already extracted temporaries.
 *
 * @param expr - expression to simplify
 * @returns the simplified expression
 */
export function simplifyPostCSE(expr) {
    const simplifier = new PostCSESimplifier();
    return simplifier.transform(expr);
}
|
|
435
|
+
/**
 * Transformer implementing the post-CSE rewrite `a + a → 2 * a`.
 *
 * Only binary nodes are overridden here; all other node kinds are handled by
 * the base `ExpressionTransformer`.
 */
class PostCSESimplifier extends ExpressionTransformer {
    visitBinaryOp(expr) {
        // Simplify bottom-up so equality is tested on already-simplified operands.
        const left = this.transform(expr.left);
        const right = this.transform(expr.right);
        const isDoubling = expr.operator === '+' && expressionsEqual(left, right);
        if (isDoubling) {
            return {
                kind: 'binary',
                operator: '*',
                left: { kind: 'number', value: 2 },
                right: left
            };
        }
        // Preserve node identity when nothing underneath changed.
        const unchanged = left === expr.left && right === expr.right;
        return unchanged
            ? expr
            : { kind: 'binary', operator: expr.operator, left, right };
    }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Conversion between AST Expressions and E-Graph
|
|
3
|
+
*/
|
|
4
|
+
import { EGraph } from './EGraph.js';
|
|
5
|
+
import { EClassId } from './ENode.js';
|
|
6
|
+
import { Expression } from '../AST.js';
|
|
7
|
+
/**
 * Add an AST Expression to the e-graph, returning its e-class ID
 *
 * Sub-expressions are added first (bottom-up), so every child ID exists
 * before its parent node is added. `^` and `**` both become a `pow` node.
 * Unary `+` adds no node and yields the operand's ID.
 *
 * @param egraph - e-graph receiving the new nodes (mutated via `add`)
 * @param expr - expression to convert
 * @returns e-class ID representing `expr`
 */
export declare function addExpression(egraph: EGraph, expr: Expression): EClassId;
|
|
11
|
+
/**
 * Add multiple expressions, returning a map of original keys to e-class IDs
 *
 * Each value is added with `addExpression`, in the input map's iteration order.
 *
 * @param egraph - e-graph receiving the new nodes
 * @param expressions - key → expression to add
 * @returns the same keys mapped to their e-class IDs
 */
export declare function addExpressions<K extends string>(egraph: EGraph, expressions: Map<K, Expression>): Map<K, EClassId>;
|
|
15
|
+
/**
 * Add all gradients (Map<paramName, Map<component, Expression>>)
 * Returns Map<paramName, Map<component, EClassId>>
 *
 * Expressions are added with `addExpression`, parameter by parameter and,
 * within a parameter, component by component (map iteration order).
 *
 * @param egraph - e-graph receiving the new nodes
 * @param gradients - parameter name → (component name → gradient expression)
 * @returns the same nesting with each expression replaced by its e-class ID
 */
export declare function addGradients(egraph: EGraph, gradients: Map<string, Map<string, Expression>>): Map<string, Map<string, EClassId>>;
|
|
20
|
+
/**
 * Get all root e-class IDs from gradient structure
 *
 * Flattens the nested map in iteration order (parameters, then components).
 * IDs are not de-duplicated: if two gradients share an e-class ID, it appears
 * once per occurrence.
 *
 * @param gradientIds - parameter name → (component name → e-class ID)
 * @returns every ID in the nested map, as a flat array
 */
export declare function getRootIds(gradientIds: Map<string, Map<string, EClassId>>): EClassId[];
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Conversion between AST Expressions and E-Graph
|
|
3
|
+
*/
|
|
4
|
+
/**
 * Add an AST Expression to the e-graph, returning its e-class ID.
 *
 * Sub-expressions are added first (bottom-up), so every child ID exists before
 * its parent node is added. `^` and `**` both map to a `pow` node. Unary `+`
 * adds no node and yields the operand's ID.
 *
 * @param egraph - e-graph receiving the new nodes (mutated via `egraph.add`)
 * @param expr - expression to convert
 * @returns e-class ID representing `expr`
 * @throws Error for an expression kind, binary operator or unary operator this
 *   converter does not know. Previously an unknown binary operator fell through
 *   into the unary case and crashed on `expr.operand`, and an unknown kind
 *   silently returned `undefined`.
 */
export function addExpression(egraph, expr) {
    switch (expr.kind) {
        case 'number':
            return egraph.add({ tag: 'num', value: expr.value });
        case 'variable':
            return egraph.add({ tag: 'var', name: expr.name });
        case 'binary': {
            // Resolve the tag before touching the e-graph so an unsupported
            // operator does not leave orphan child nodes behind.
            let tag;
            switch (expr.operator) {
                case '+':
                    tag = 'add';
                    break;
                case '-':
                    tag = 'sub';
                    break;
                case '*':
                    tag = 'mul';
                    break;
                case '/':
                    tag = 'div';
                    break;
                case '^':
                case '**':
                    tag = 'pow';
                    break;
                default:
                    throw new Error(`addExpression: unsupported binary operator '${expr.operator}'`);
            }
            const left = addExpression(egraph, expr.left);
            const right = addExpression(egraph, expr.right);
            return egraph.add({ tag, children: [left, right] });
        }
        case 'unary': {
            if (expr.operator !== '-' && expr.operator !== '+') {
                throw new Error(`addExpression: unsupported unary operator '${expr.operator}'`);
            }
            const operand = addExpression(egraph, expr.operand);
            // Unary + is identity: reuse the operand's class.
            return expr.operator === '-'
                ? egraph.add({ tag: 'neg', child: operand })
                : operand;
        }
        case 'call': {
            const args = expr.args.map(arg => addExpression(egraph, arg));
            return egraph.add({ tag: 'call', name: expr.name, children: args });
        }
        case 'component': {
            const object = addExpression(egraph, expr.object);
            return egraph.add({ tag: 'component', object, field: expr.component });
        }
        default:
            throw new Error(`addExpression: unsupported expression kind '${expr.kind}'`);
    }
}
|
|
48
|
+
/**
 * Add several expressions to the e-graph.
 *
 * Expressions are added with `addExpression` in the input's iteration order.
 *
 * @param egraph - e-graph receiving the new nodes
 * @param expressions - key → expression
 * @returns a new Map from the same keys to the resulting e-class IDs
 */
export function addExpressions(egraph, expressions) {
    return new Map(Array.from(expressions, ([key, expr]) => [key, addExpression(egraph, expr)]));
}
|
|
58
|
+
/**
 * Add every gradient expression to the e-graph.
 *
 * Input shape: Map<paramName, Map<component, Expression>>.
 * Output shape: Map<paramName, Map<component, EClassId>>.
 * Parameters are processed in order, each one fully (all of its components)
 * before the next.
 *
 * @param egraph - e-graph receiving the new nodes
 * @param gradients - parameter name → (component name → expression)
 * @returns the same nesting with expressions replaced by e-class IDs
 */
export function addGradients(egraph, gradients) {
    return new Map(Array.from(gradients, ([paramName, components]) => [
        paramName,
        new Map(Array.from(components, ([comp, expr]) => [comp, addExpression(egraph, expr)])),
    ]));
}
|
|
73
|
+
/**
 * Collect every e-class ID from a gradient ID structure.
 *
 * Walks parameters, then their components, in iteration order. Duplicate IDs
 * are kept.
 *
 * @param gradientIds - parameter name → (component name → e-class ID)
 * @returns a flat array of all IDs
 */
export function getRootIds(gradientIds) {
    return Array.from(gradientIds.values()).flatMap(components => Array.from(components.values()));
}
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph: Equality Graph for expression optimization
|
|
3
|
+
*
|
|
4
|
+
* An e-graph efficiently represents equivalence classes of expressions.
|
|
5
|
+
* It supports:
|
|
6
|
+
* - Adding expressions (returns e-class ID)
|
|
7
|
+
* - Merging e-classes (union)
|
|
8
|
+
* - Finding canonical e-class (find)
|
|
9
|
+
* - Rebuilding after merges (maintains congruence)
|
|
10
|
+
*/
|
|
11
|
+
import { ENode, EClassId } from './ENode.js';
|
|
12
|
+
/**
 * E-Class: An equivalence class of expressions
 */
export interface EClass {
    /** ID of this e-class. */
    id: EClassId;
    /**
     * Member e-nodes, stored as string keys.
     * NOTE(review): presumably the same keys accepted by `EGraph.getNodeByKey` — confirm in EGraph.js.
     */
    nodes: Set<string>;
    /**
     * Presumably keys of e-nodes that reference this class as a child, used when
     * repairing congruence after merges — TODO confirm against EGraph.js.
     */
    parents: Set<string>;
}
|
|
20
|
+
/**
 * E-Graph: The main data structure
 *
 * Usage contract from the member docs below: add e-nodes with `add`, union
 * classes found equal with `merge`, and call `rebuild` after a batch of merges
 * to restore congruence invariants.
 */
export declare class EGraph {
    // NOTE(review): the private fields are untyped in this declaration file.
    // The notes below are inferred from names and method docs, not from the
    // implementation — confirm against EGraph.js.
    /** Presumably the next unused e-class ID. */
    private nextId;
    /** Presumably e-class ID → EClass. */
    private classes;
    /** Presumably union-find parent pointers consulted by `find`. */
    private parent;
    /** Presumably union-find ranks consulted by `merge`. */
    private rank;
    /** Presumably e-node key → e-class ID, backing the dedup in `add`/`lookup`. */
    private hashcons;
    /** Presumably e-node key → ENode, backing `getNodeByKey`. */
    private nodeStore;
    /** Presumably e-classes queued for `repair` by the next `rebuild`. */
    private pending;
    /**
     * Find the canonical e-class ID (with path compression)
     *
     * @param id - an e-class ID issued by this e-graph
     * @returns the canonical representative of `id`'s class
     */
    find(id: EClassId): EClassId;
    /**
     * Add an e-node to the e-graph, returning its e-class ID
     * If the node already exists, returns the existing class
     *
     * @param node - e-node to insert; child references are e-class IDs
     * @returns the e-class ID containing `node`
     */
    add(node: ENode): EClassId;
    /**
     * Merge two e-classes, returning the new canonical ID
     *
     * Call `rebuild()` after a batch of merges (see below).
     *
     * @param id1 - first e-class ID
     * @param id2 - second e-class ID
     * @returns canonical ID of the merged class
     */
    merge(id1: EClassId, id2: EClassId): EClassId;
    /**
     * Rebuild the e-graph to restore congruence invariants
     * Must be called after a batch of merges
     */
    rebuild(): void;
    /**
     * Repair an e-class after merges
     */
    private repair;
    /**
     * Find which e-class contains a node (by key)
     */
    private findClassForNode;
    /**
     * Canonicalize an e-node (update children to canonical IDs)
     */
    private canonicalize;
    /**
     * Get all e-class IDs
     *
     * NOTE(review): not clear from this declaration whether only canonical IDs are returned — confirm.
     */
    getClassIds(): EClassId[];
    /**
     * Get an e-class by ID
     *
     * @param id - e-class ID to look up
     * @returns the class, or `undefined` if none is stored under `id`
     *   (NOTE(review): confirm whether a non-canonical `id` is resolved via `find`)
     */
    getClass(id: EClassId): EClass | undefined;
    /**
     * Get all e-nodes in an e-class
     *
     * @param classId - e-class whose members are returned
     */
    getNodes(classId: EClassId): ENode[];
    /**
     * Get the number of e-classes
     */
    get size(): number;
    /**
     * Get a node by its key
     *
     * @param key - e-node key (the same string form stored in `EClass.nodes`, presumably)
     * @returns the node, or `undefined` if the key is unknown
     */
    getNodeByKey(key: string): ENode | undefined;
    /**
     * Lookup e-class by node (if it exists)
     *
     * @param node - e-node to search for
     * @returns its e-class ID, or `undefined` if the node is not in the e-graph
     */
    lookup(node: ENode): EClassId | undefined;
    /**
     * Debug: print e-graph state
     *
     * Despite the wording, the declared result is a string; the caller decides
     * whether to print it.
     */
    dump(): string;
    /**
     * Convert e-node to readable string
     */
    private nodeToString;
}
|