gradient-script 0.1.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +52 -9
- package/dist/cli.js +134 -19
- package/dist/dsl/AST.d.ts +8 -0
- package/dist/dsl/CodeGen.d.ts +8 -3
- package/dist/dsl/CodeGen.js +583 -132
- package/dist/dsl/Errors.d.ts +6 -1
- package/dist/dsl/Errors.js +70 -1
- package/dist/dsl/Expander.js +5 -2
- package/dist/dsl/ExpressionUtils.d.ts +14 -0
- package/dist/dsl/ExpressionUtils.js +56 -0
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +3 -1
- package/dist/dsl/Guards.js +86 -43
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +11 -2
- package/dist/dsl/Lexer.js +3 -1
- package/dist/dsl/Parser.js +11 -5
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +183 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -194
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
package/dist/dsl/Inliner.js
CHANGED
|
@@ -6,8 +6,9 @@ import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
|
6
6
|
/**
|
|
7
7
|
* Expression transformer that substitutes variables from a substitution map
|
|
8
8
|
* Handles recursive inlining by reprocessing substituted expressions
|
|
9
|
+
* Used for inlining intermediate variables to eliminate assignments
|
|
9
10
|
*/
|
|
10
|
-
class
|
|
11
|
+
class VariableSubstitutionTransformer extends ExpressionTransformer {
|
|
11
12
|
substitutions;
|
|
12
13
|
constructor(substitutions) {
|
|
13
14
|
super();
|
|
@@ -35,6 +36,14 @@ export function inlineIntermediateVariables(func) {
|
|
|
35
36
|
}
|
|
36
37
|
}
|
|
37
38
|
// Use transformer to inline all variables
|
|
38
|
-
const transformer = new
|
|
39
|
+
const transformer = new VariableSubstitutionTransformer(substitutions);
|
|
39
40
|
return transformer.transform(func.returnExpr);
|
|
40
41
|
}
|
|
42
|
+
/**
 * Apply a substitution map to a single expression.
 *
 * Serves to obtain the fully-expanded form of forward pass expressions.
 *
 * @param expr - Expression to rewrite.
 * @param substitutions - Substitution map handed to VariableSubstitutionTransformer.
 * @returns Whatever the transformer's `transform` returns for `expr`.
 */
export function inlineExpression(expr, substitutions) {
    return new VariableSubstitutionTransformer(substitutions).transform(expr);
}
|
package/dist/dsl/Lexer.js
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
* Lexer for GradientScript DSL
|
|
3
3
|
* Tokenizes input with support for ∇, ^, **, and structured types
|
|
4
4
|
*/
|
|
5
|
+
import { ParseError } from './Errors.js';
|
|
5
6
|
export var TokenType;
|
|
6
7
|
(function (TokenType) {
|
|
7
8
|
// Literals
|
|
@@ -126,7 +127,8 @@ export class Lexer {
|
|
|
126
127
|
this.advance();
|
|
127
128
|
return { type: TokenType.NEWLINE, value: '\n', line, column };
|
|
128
129
|
}
|
|
129
|
-
|
|
130
|
+
// Create a more helpful error for common mistakes
|
|
131
|
+
throw new ParseError(`Unexpected character '${char}'`, line, column, char);
|
|
130
132
|
}
|
|
131
133
|
number() {
|
|
132
134
|
const line = this.line;
|
package/dist/dsl/Parser.js
CHANGED
|
@@ -123,13 +123,15 @@ export class Parser {
|
|
|
123
123
|
* variable = expression
|
|
124
124
|
*/
|
|
125
125
|
assignment() {
|
|
126
|
-
const
|
|
126
|
+
const varToken = this.consume(TokenType.IDENTIFIER, 'Expected variable name');
|
|
127
|
+
const variable = varToken.value;
|
|
127
128
|
this.consume(TokenType.EQUALS, "Expected '=' in assignment");
|
|
128
129
|
const expression = this.expression();
|
|
129
130
|
return {
|
|
130
131
|
kind: 'assignment',
|
|
131
132
|
variable,
|
|
132
|
-
expression
|
|
133
|
+
expression,
|
|
134
|
+
loc: { line: varToken.line, column: varToken.column }
|
|
133
135
|
};
|
|
134
136
|
}
|
|
135
137
|
/**
|
|
@@ -161,13 +163,15 @@ export class Parser {
|
|
|
161
163
|
multiplicative() {
|
|
162
164
|
let expr = this.power();
|
|
163
165
|
while (this.match(TokenType.MULTIPLY, TokenType.DIVIDE)) {
|
|
164
|
-
const
|
|
166
|
+
const opToken = this.previous();
|
|
167
|
+
const operator = opToken.value;
|
|
165
168
|
const right = this.power();
|
|
166
169
|
expr = {
|
|
167
170
|
kind: 'binary',
|
|
168
171
|
operator,
|
|
169
172
|
left: expr,
|
|
170
|
-
right
|
|
173
|
+
right,
|
|
174
|
+
loc: { line: opToken.line, column: opToken.column }
|
|
171
175
|
};
|
|
172
176
|
}
|
|
173
177
|
return expr;
|
|
@@ -213,6 +217,7 @@ export class Parser {
|
|
|
213
217
|
while (true) {
|
|
214
218
|
if (this.match(TokenType.LPAREN)) {
|
|
215
219
|
// Function call
|
|
220
|
+
const startLoc = expr.loc || { line: this.previous().line, column: this.previous().column };
|
|
216
221
|
const args = this.argumentList();
|
|
217
222
|
this.consume(TokenType.RPAREN, "Expected ')' after arguments");
|
|
218
223
|
if (expr.kind !== 'variable') {
|
|
@@ -222,7 +227,8 @@ export class Parser {
|
|
|
222
227
|
expr = {
|
|
223
228
|
kind: 'call',
|
|
224
229
|
name: expr.name,
|
|
225
|
-
args
|
|
230
|
+
args,
|
|
231
|
+
loc: startLoc
|
|
226
232
|
};
|
|
227
233
|
}
|
|
228
234
|
else if (this.match(TokenType.DOT)) {
|
package/dist/dsl/Simplify.d.ts
CHANGED
|
@@ -15,3 +15,10 @@ export declare function simplifyGradients(gradients: Map<string, Expression | {
|
|
|
15
15
|
}>): Map<string, Expression | {
|
|
16
16
|
components: Map<string, Expression>;
|
|
17
17
|
}>;
|
|
18
|
+
/**
 * Simplification pass meant to run after CSE.
 *
 * It applies the rules that the initial simplification deliberately leaves
 * out so that they do not interfere with CSE. Concretely:
 *
 *     a + a → 2 * a
 *
 * which is safe at this point because CSE has already extracted its temps.
 */
export declare function simplifyPostCSE(
    expr: Expression
): Expression;
|
package/dist/dsl/Simplify.js
CHANGED
|
@@ -46,6 +46,9 @@ class Simplifier extends ExpressionTransformer {
|
|
|
46
46
|
return right;
|
|
47
47
|
if (rightNum === 0)
|
|
48
48
|
return left;
|
|
49
|
+
// Note: a + a → 2 * a rules are intentionally NOT applied here
|
|
50
|
+
// because they flatten expression structure and interfere with CSE.
|
|
51
|
+
// The CSE pass will extract common subexpressions instead.
|
|
49
52
|
}
|
|
50
53
|
// Subtraction rules
|
|
51
54
|
if (expr.operator === '-') {
|
|
@@ -57,6 +60,25 @@ class Simplifier extends ExpressionTransformer {
|
|
|
57
60
|
if (expressionsEqual(left, right)) {
|
|
58
61
|
return { kind: 'number', value: 0 };
|
|
59
62
|
}
|
|
63
|
+
// a - (-b) → a + b
|
|
64
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
65
|
+
return this.transform({
|
|
66
|
+
kind: 'binary',
|
|
67
|
+
operator: '+',
|
|
68
|
+
left,
|
|
69
|
+
right: right.operand
|
|
70
|
+
});
|
|
71
|
+
}
|
|
72
|
+
// (-a) - (-b) → b - a
|
|
73
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
74
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
75
|
+
return this.transform({
|
|
76
|
+
kind: 'binary',
|
|
77
|
+
operator: '-',
|
|
78
|
+
left: right.operand,
|
|
79
|
+
right: left.operand
|
|
80
|
+
});
|
|
81
|
+
}
|
|
60
82
|
}
|
|
61
83
|
// Multiplication rules
|
|
62
84
|
if (expr.operator === '*') {
|
|
@@ -68,6 +90,50 @@ class Simplifier extends ExpressionTransformer {
|
|
|
68
90
|
return right;
|
|
69
91
|
if (rightNum === 1)
|
|
70
92
|
return left;
|
|
93
|
+
// -1 * x → -x
|
|
94
|
+
if (leftNum === -1) {
|
|
95
|
+
return { kind: 'unary', operator: '-', operand: right };
|
|
96
|
+
}
|
|
97
|
+
// x * -1 → -x
|
|
98
|
+
if (rightNum === -1) {
|
|
99
|
+
return { kind: 'unary', operator: '-', operand: left };
|
|
100
|
+
}
|
|
101
|
+
// (-a) * (-b) → a * b
|
|
102
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
103
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
104
|
+
return this.transform({
|
|
105
|
+
kind: 'binary',
|
|
106
|
+
operator: '*',
|
|
107
|
+
left: left.operand,
|
|
108
|
+
right: right.operand
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
// (-a) * b → -(a * b)
|
|
112
|
+
if (left.kind === 'unary' && left.operator === '-') {
|
|
113
|
+
return {
|
|
114
|
+
kind: 'unary',
|
|
115
|
+
operator: '-',
|
|
116
|
+
operand: this.transform({
|
|
117
|
+
kind: 'binary',
|
|
118
|
+
operator: '*',
|
|
119
|
+
left: left.operand,
|
|
120
|
+
right
|
|
121
|
+
})
|
|
122
|
+
};
|
|
123
|
+
}
|
|
124
|
+
// a * (-b) → -(a * b)
|
|
125
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
126
|
+
return {
|
|
127
|
+
kind: 'unary',
|
|
128
|
+
operator: '-',
|
|
129
|
+
operand: this.transform({
|
|
130
|
+
kind: 'binary',
|
|
131
|
+
operator: '*',
|
|
132
|
+
left,
|
|
133
|
+
right: right.operand
|
|
134
|
+
})
|
|
135
|
+
};
|
|
136
|
+
}
|
|
71
137
|
// (x / x) * y → y
|
|
72
138
|
if (left.kind === 'binary' && left.operator === '/') {
|
|
73
139
|
if (expressionsEqual(left.left, left.right)) {
|
|
@@ -116,6 +182,89 @@ class Simplifier extends ExpressionTransformer {
|
|
|
116
182
|
if (expressionsEqual(left, right)) {
|
|
117
183
|
return { kind: 'number', value: 1 };
|
|
118
184
|
}
|
|
185
|
+
// (-a) / (-b) → a / b
|
|
186
|
+
if (left.kind === 'unary' && left.operator === '-' &&
|
|
187
|
+
right.kind === 'unary' && right.operator === '-') {
|
|
188
|
+
return this.transform({
|
|
189
|
+
kind: 'binary',
|
|
190
|
+
operator: '/',
|
|
191
|
+
left: left.operand,
|
|
192
|
+
right: right.operand
|
|
193
|
+
});
|
|
194
|
+
}
|
|
195
|
+
// (-a) / b → -(a / b)
|
|
196
|
+
if (left.kind === 'unary' && left.operator === '-') {
|
|
197
|
+
return {
|
|
198
|
+
kind: 'unary',
|
|
199
|
+
operator: '-',
|
|
200
|
+
operand: this.transform({
|
|
201
|
+
kind: 'binary',
|
|
202
|
+
operator: '/',
|
|
203
|
+
left: left.operand,
|
|
204
|
+
right
|
|
205
|
+
})
|
|
206
|
+
};
|
|
207
|
+
}
|
|
208
|
+
// a / (-b) → -(a / b)
|
|
209
|
+
if (right.kind === 'unary' && right.operator === '-') {
|
|
210
|
+
return {
|
|
211
|
+
kind: 'unary',
|
|
212
|
+
operator: '-',
|
|
213
|
+
operand: this.transform({
|
|
214
|
+
kind: 'binary',
|
|
215
|
+
operator: '/',
|
|
216
|
+
left,
|
|
217
|
+
right: right.operand
|
|
218
|
+
})
|
|
219
|
+
};
|
|
220
|
+
}
|
|
221
|
+
// (a + a) / 2 → a
|
|
222
|
+
if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
|
|
223
|
+
if (expressionsEqual(left.left, left.right)) {
|
|
224
|
+
return left.left;
|
|
225
|
+
}
|
|
226
|
+
}
|
|
227
|
+
// (a + a) / (2 * b) → a / b
|
|
228
|
+
if (right.kind === 'binary' && right.operator === '*') {
|
|
229
|
+
const rightLeft = right.left;
|
|
230
|
+
const rightRight = right.right;
|
|
231
|
+
const rightLeftNum = isNumber(rightLeft) ? rightLeft.value : null;
|
|
232
|
+
if (rightLeftNum === 2 && left.kind === 'binary' && left.operator === '+') {
|
|
233
|
+
if (expressionsEqual(left.left, left.right)) {
|
|
234
|
+
return {
|
|
235
|
+
kind: 'binary',
|
|
236
|
+
operator: '/',
|
|
237
|
+
left: left.left,
|
|
238
|
+
right: rightRight
|
|
239
|
+
};
|
|
240
|
+
}
|
|
241
|
+
// (-1 * a + a * -1) / (2 * b) → -a / b
|
|
242
|
+
const leftLeft = left.left;
|
|
243
|
+
const leftRight = left.right;
|
|
244
|
+
if (leftLeft.kind === 'binary' && leftLeft.operator === '*' &&
|
|
245
|
+
leftRight.kind === 'binary' && leftRight.operator === '*') {
|
|
246
|
+
const ll_left = leftLeft.left;
|
|
247
|
+
const ll_right = leftLeft.right;
|
|
248
|
+
const lr_left = leftRight.left;
|
|
249
|
+
const lr_right = leftRight.right;
|
|
250
|
+
const ll_leftNum = isNumber(ll_left) ? ll_left.value : null;
|
|
251
|
+
const lr_rightNum = isNumber(lr_right) ? lr_right.value : null;
|
|
252
|
+
// (-1 * a) + (a * -1)
|
|
253
|
+
if (ll_leftNum === -1 && lr_rightNum === -1 && expressionsEqual(ll_right, lr_left)) {
|
|
254
|
+
return {
|
|
255
|
+
kind: 'unary',
|
|
256
|
+
operator: '-',
|
|
257
|
+
operand: {
|
|
258
|
+
kind: 'binary',
|
|
259
|
+
operator: '/',
|
|
260
|
+
left: ll_right,
|
|
261
|
+
right: rightRight
|
|
262
|
+
}
|
|
263
|
+
};
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
}
|
|
119
268
|
}
|
|
120
269
|
// Power rules
|
|
121
270
|
if (expr.operator === '^' || expr.operator === '**') {
|
|
@@ -274,3 +423,37 @@ export function simplifyGradients(gradients) {
|
|
|
274
423
|
}
|
|
275
424
|
return simplified;
|
|
276
425
|
}
|
|
426
|
+
/**
 * Simplification pass meant to run after CSE.
 *
 * It applies the rules that the initial simplification deliberately leaves
 * out so that they do not interfere with CSE. Concretely:
 *
 *     a + a → 2 * a
 *
 * which is safe at this point because CSE has already extracted its temps.
 *
 * @param expr - Expression to simplify.
 * @returns The result of running a fresh PostCSESimplifier over `expr`.
 */
export function simplifyPostCSE(expr) {
    const simplifier = new PostCSESimplifier();
    return simplifier.transform(expr);
}
|
|
435
|
+
/**
 * Expression transformer implementing the post-CSE rule `a + a → 2 * a`.
 */
class PostCSESimplifier extends ExpressionTransformer {
    /**
     * Rewrite a binary node after transforming both operands (left first).
     *
     * @param expr - Binary expression node.
     * @returns `2 * operand` for a sum of two equal operands; the original
     *          node when neither operand changed; otherwise a new binary node
     *          with the same operator and the transformed operands.
     */
    visitBinaryOp(expr) {
        const lhs = this.transform(expr.left);
        const rhs = this.transform(expr.right);
        const isSelfSum = expr.operator === '+' && expressionsEqual(lhs, rhs);
        if (isSelfSum) {
            // a + a → 2 * a
            const two = { kind: 'number', value: 2 };
            return { kind: 'binary', operator: '*', left: two, right: lhs };
        }
        // Keep the original node identity when nothing underneath changed.
        const untouched = lhs === expr.left && rhs === expr.right;
        return untouched
            ? expr
            : { kind: 'binary', operator: expr.operator, left: lhs, right: rhs };
    }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Conversion between AST Expressions and E-Graph
|
|
3
|
+
*/
|
|
4
|
+
import { EGraph } from './EGraph.js';
|
|
5
|
+
import { EClassId } from './ENode.js';
|
|
6
|
+
import { Expression } from '../AST.js';
|
|
7
|
+
/**
 * Insert one AST expression into the e-graph.
 *
 * @returns The e-class ID of the inserted expression.
 */
export declare function addExpression(
    egraph: EGraph,
    expr: Expression
): EClassId;
|
|
11
|
+
/**
 * Insert several expressions into the e-graph.
 *
 * @returns A map from each original key to the e-class ID of its expression.
 */
export declare function addExpressions<K extends string>(
    egraph: EGraph,
    expressions: Map<K, Expression>
): Map<K, EClassId>;
|
|
15
|
+
/**
 * Insert every gradient expression into the e-graph.
 *
 * Input shape:  Map<paramName, Map<component, Expression>>
 * Output shape: Map<paramName, Map<component, EClassId>>
 */
export declare function addGradients(
    egraph: EGraph,
    gradients: Map<string, Map<string, Expression>>
): Map<string, Map<string, EClassId>>;
|
|
20
|
+
/**
 * Collect every root e-class ID held in a gradient ID structure.
 */
export declare function getRootIds(
    gradientIds: Map<string, Map<string, EClassId>>
): EClassId[];
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Conversion between AST Expressions and E-Graph
|
|
3
|
+
*/
|
|
4
|
+
/**
|
|
5
|
+
* Add an AST Expression to the e-graph, returning its e-class ID
|
|
6
|
+
*/
|
|
7
|
+
export function addExpression(egraph, expr) {
|
|
8
|
+
switch (expr.kind) {
|
|
9
|
+
case 'number':
|
|
10
|
+
return egraph.add({ tag: 'num', value: expr.value });
|
|
11
|
+
case 'variable':
|
|
12
|
+
return egraph.add({ tag: 'var', name: expr.name });
|
|
13
|
+
case 'binary': {
|
|
14
|
+
const left = addExpression(egraph, expr.left);
|
|
15
|
+
const right = addExpression(egraph, expr.right);
|
|
16
|
+
switch (expr.operator) {
|
|
17
|
+
case '+':
|
|
18
|
+
return egraph.add({ tag: 'add', children: [left, right] });
|
|
19
|
+
case '-':
|
|
20
|
+
return egraph.add({ tag: 'sub', children: [left, right] });
|
|
21
|
+
case '*':
|
|
22
|
+
return egraph.add({ tag: 'mul', children: [left, right] });
|
|
23
|
+
case '/':
|
|
24
|
+
return egraph.add({ tag: 'div', children: [left, right] });
|
|
25
|
+
case '^':
|
|
26
|
+
case '**':
|
|
27
|
+
return egraph.add({ tag: 'pow', children: [left, right] });
|
|
28
|
+
}
|
|
29
|
+
}
|
|
30
|
+
case 'unary': {
|
|
31
|
+
const operand = addExpression(egraph, expr.operand);
|
|
32
|
+
if (expr.operator === '-') {
|
|
33
|
+
return egraph.add({ tag: 'neg', child: operand });
|
|
34
|
+
}
|
|
35
|
+
// Unary + is identity
|
|
36
|
+
return operand;
|
|
37
|
+
}
|
|
38
|
+
case 'call': {
|
|
39
|
+
const args = expr.args.map(arg => addExpression(egraph, arg));
|
|
40
|
+
return egraph.add({ tag: 'call', name: expr.name, children: args });
|
|
41
|
+
}
|
|
42
|
+
case 'component': {
|
|
43
|
+
const object = addExpression(egraph, expr.object);
|
|
44
|
+
return egraph.add({ tag: 'component', object, field: expr.component });
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
/**
 * Insert several expressions into the e-graph.
 *
 * @param egraph - E-graph passed through to addExpression.
 * @param expressions - Iterable of [key, expression] entries, visited in order.
 * @returns A new Map from each original key to the e-class ID of its expression.
 */
export function addExpressions(egraph, expressions) {
    const entries = Array.from(expressions, ([key, expr]) => [
        key,
        addExpression(egraph, expr),
    ]);
    return new Map(entries);
}
|
|
58
|
+
/**
 * Insert every gradient expression into the e-graph.
 *
 * Input shape:  Map<paramName, Map<component, Expression>>
 * Output shape: Map<paramName, Map<component, EClassId>>
 *
 * @param egraph - E-graph passed through to addExpression.
 * @param gradients - Per-parameter maps of component expressions.
 * @returns A new nested Map with each expression replaced by its e-class ID.
 */
export function addGradients(egraph, gradients) {
    // Expressions are inserted in iteration order: parameter by parameter,
    // component by component.
    const toComponentIds = (components) => new Map(
        Array.from(components, ([comp, expr]) => [comp, addExpression(egraph, expr)])
    );
    return new Map(
        Array.from(gradients, ([paramName, components]) => [
            paramName,
            toComponentIds(components),
        ])
    );
}
|
|
73
|
+
/**
|
|
74
|
+
* Get all root e-class IDs from gradient structure
|
|
75
|
+
*/
|
|
76
|
+
export function getRootIds(gradientIds) {
|
|
77
|
+
const roots = [];
|
|
78
|
+
for (const components of gradientIds.values()) {
|
|
79
|
+
for (const id of components.values()) {
|
|
80
|
+
roots.push(id);
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
return roots;
|
|
84
|
+
}
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph: Equality Graph for expression optimization
|
|
3
|
+
*
|
|
4
|
+
* An e-graph efficiently represents equivalence classes of expressions.
|
|
5
|
+
* It supports:
|
|
6
|
+
* - Adding expressions (returns e-class ID)
|
|
7
|
+
* - Merging e-classes (union)
|
|
8
|
+
* - Finding canonical e-class (find)
|
|
9
|
+
* - Rebuilding after merges (maintains congruence)
|
|
10
|
+
*/
|
|
11
|
+
import { ENode, EClassId } from './ENode.js';
|
|
12
|
+
/**
 * E-Class: one equivalence class of expressions.
 */
export interface EClass {
    /** ID of this e-class. */
    id: EClassId;
    // NOTE(review): presumably string keys of member e-nodes (cf. EGraph.getNodeByKey) — confirm.
    nodes: Set<string>;
    // NOTE(review): presumably string keys of e-nodes that use this class as a child — confirm.
    parents: Set<string>;
}
|
|
20
|
+
/**
 * E-Graph: the main data structure.
 */
export declare class EGraph {
    // Internal state; types are not exposed by this declaration file.
    private nextId;
    private classes;
    private parent;
    private rank;
    private hashcons;
    private nodeStore;
    private pending;

    /** Resolve an e-class ID to its canonical ID (uses path compression). */
    find(id: EClassId): EClassId;

    /**
     * Add an e-node and return its e-class ID.
     * When the node is already present, the existing class is returned.
     */
    add(node: ENode): EClassId;

    /** Merge two e-classes and return the new canonical ID. */
    merge(id1: EClassId, id2: EClassId): EClassId;

    /**
     * Restore the congruence invariants of the e-graph.
     * Must be called after a batch of merges.
     */
    rebuild(): void;

    /** Repair an e-class after merges. */
    private repair;

    /** Find which e-class contains a node, looked up by key. */
    private findClassForNode;

    /** Canonicalize an e-node by updating its children to canonical IDs. */
    private canonicalize;

    /** All e-class IDs. */
    getClassIds(): EClassId[];

    /** The e-class with the given ID, if any. */
    getClass(id: EClassId): EClass | undefined;

    /** All e-nodes in the given e-class. */
    getNodes(classId: EClassId): ENode[];

    /** Number of e-classes. */
    get size(): number;

    /** The node stored under the given key, if any. */
    getNodeByKey(key: string): ENode | undefined;

    /** The e-class of a node, if the node exists. */
    lookup(node: ENode): EClassId | undefined;

    /** Debug helper: print the e-graph state. */
    dump(): string;

    /** Convert an e-node to a readable string. */
    private nodeToString;
}
|