gradient-script 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/cli.js +80 -3
- package/dist/dsl/CodeGen.d.ts +1 -1
- package/dist/dsl/CodeGen.js +332 -74
- package/dist/dsl/ExpressionUtils.d.ts +8 -2
- package/dist/dsl/ExpressionUtils.js +34 -2
- package/dist/dsl/GradientChecker.d.ts +21 -0
- package/dist/dsl/GradientChecker.js +109 -23
- package/dist/dsl/Guards.d.ts +1 -1
- package/dist/dsl/Guards.js +14 -13
- package/dist/dsl/Inliner.d.ts +5 -0
- package/dist/dsl/Inliner.js +8 -0
- package/dist/dsl/Simplify.d.ts +7 -0
- package/dist/dsl/Simplify.js +136 -0
- package/dist/dsl/egraph/Convert.d.ts +23 -0
- package/dist/dsl/egraph/Convert.js +84 -0
- package/dist/dsl/egraph/EGraph.d.ts +93 -0
- package/dist/dsl/egraph/EGraph.js +292 -0
- package/dist/dsl/egraph/ENode.d.ts +63 -0
- package/dist/dsl/egraph/ENode.js +94 -0
- package/dist/dsl/egraph/Extractor.d.ts +49 -0
- package/dist/dsl/egraph/Extractor.js +1068 -0
- package/dist/dsl/egraph/Optimizer.d.ts +50 -0
- package/dist/dsl/egraph/Optimizer.js +88 -0
- package/dist/dsl/egraph/Pattern.d.ts +80 -0
- package/dist/dsl/egraph/Pattern.js +325 -0
- package/dist/dsl/egraph/Rewriter.d.ts +44 -0
- package/dist/dsl/egraph/Rewriter.js +131 -0
- package/dist/dsl/egraph/Rules.d.ts +44 -0
- package/dist/dsl/egraph/Rules.js +187 -0
- package/dist/dsl/egraph/index.d.ts +15 -0
- package/dist/dsl/egraph/index.js +21 -0
- package/package.json +1 -1
- package/dist/dsl/CSE.d.ts +0 -21
- package/dist/dsl/CSE.js +0 -168
- package/dist/symbolic/AST.d.ts +0 -113
- package/dist/symbolic/AST.js +0 -128
- package/dist/symbolic/CodeGen.d.ts +0 -35
- package/dist/symbolic/CodeGen.js +0 -280
- package/dist/symbolic/Parser.d.ts +0 -64
- package/dist/symbolic/Parser.js +0 -329
- package/dist/symbolic/Simplify.d.ts +0 -10
- package/dist/symbolic/Simplify.js +0 -244
- package/dist/symbolic/SymbolicDiff.d.ts +0 -35
- package/dist/symbolic/SymbolicDiff.js +0 -339
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph: Equality Graph for expression optimization
|
|
3
|
+
*
|
|
4
|
+
* An e-graph efficiently represents equivalence classes of expressions.
|
|
5
|
+
* It supports:
|
|
6
|
+
* - Adding expressions (returns e-class ID)
|
|
7
|
+
* - Merging e-classes (union)
|
|
8
|
+
* - Finding canonical e-class (find)
|
|
9
|
+
* - Rebuilding after merges (maintains congruence)
|
|
10
|
+
*/
|
|
11
|
+
import { ENode, EClassId } from './ENode.js';
|
|
12
|
+
/**
 * One equivalence class inside an e-graph.
 *
 * Member e-nodes and parent e-nodes are tracked by their hash-cons key
 * strings rather than by object reference.
 */
export interface EClass {
    /** Keys of the e-nodes that belong to this class. */
    nodes: Set<string>;
    /** Keys of e-nodes (in any class) that use this class as an operand. */
    parents: Set<string>;
    /** ID this class was created with; after a merge it may no longer be the canonical root. */
    id: EClassId;
}
|
|
20
|
+
/**
 * Equality graph (e-graph): a union-find over e-class IDs combined with a
 * hash-cons table from e-node keys to e-classes.
 *
 * Usage pattern: `add` expressions, `merge` classes that are known to be
 * equal, then call `rebuild` once per batch of merges so that congruent
 * e-nodes end up in the same class.
 */
export declare class EGraph {
    // ---- Core operations ---------------------------------------------------

    /** Return the canonical (root) ID for `id`, compressing the path it walks. */
    find(id: EClassId): EClassId;
    /**
     * Insert `node` (after canonicalizing its children) and return its e-class ID.
     * A structurally identical node that is already present yields its existing class.
     */
    add(node: ENode): EClassId;
    /** Union the classes of `id1` and `id2`; returns the ID of the surviving root. */
    merge(id1: EClassId, id2: EClassId): EClassId;
    /**
     * Restore the congruence invariant. Must be called after a batch of merges.
     */
    rebuild(): void;

    // ---- Queries -----------------------------------------------------------

    /** Canonical e-class already containing `node`, or `undefined` if absent. */
    lookup(node: ENode): EClassId | undefined;
    /** IDs of every canonical e-class. */
    getClassIds(): EClassId[];
    /** The e-class that `id` currently resolves to, if any. */
    getClass(id: EClassId): EClass | undefined;
    /** Members of the class `classId` resolves to, with canonicalized children. */
    getNodes(classId: EClassId): ENode[];
    /** Stored e-node for a hash-cons key, if one was recorded. */
    getNodeByKey(key: string): ENode | undefined;
    /** Number of canonical e-classes. */
    get size(): number;
    /** Human-readable listing of every canonical class (for debugging). */
    dump(): string;

    // ---- Internal state ----------------------------------------------------

    /** Next fresh e-class ID. */
    private nextId;
    /** E-class ID -> e-class record. */
    private classes;
    /** Union-find parent pointers. */
    private parent;
    /** Union-find ranks. */
    private rank;
    /** E-node key -> e-class ID. */
    private hashcons;
    /** E-node key -> stored e-node. */
    private nodeStore;
    /** Classes awaiting repair by `rebuild`. */
    private pending;

    // ---- Internal helpers --------------------------------------------------

    /** Re-canonicalize the parents of one class after merges. */
    private repair;
    /** Locate the canonical class whose node set contains a given key. */
    private findClassForNode;
    /** Copy an e-node with its children replaced by canonical IDs. */
    private canonicalize;
    /** Render one e-node for `dump`. */
    private nodeToString;
}
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Graph: Equality Graph for expression optimization
|
|
3
|
+
*
|
|
4
|
+
* An e-graph efficiently represents equivalence classes of expressions.
|
|
5
|
+
* It supports:
|
|
6
|
+
* - Adding expressions (returns e-class ID)
|
|
7
|
+
* - Merging e-classes (union)
|
|
8
|
+
* - Finding canonical e-class (find)
|
|
9
|
+
* - Rebuilding after merges (maintains congruence)
|
|
10
|
+
*/
|
|
11
|
+
import { enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
|
|
12
|
+
/**
 * E-Graph: the main data structure.
 *
 * E-classes are integer IDs managed by a union-find (`parent` / `rank`).
 * Every e-node is stored under its string key (`enodeKey`) in `nodeStore`,
 * and `hashcons` maps a node key to the e-class that contains it.
 * Merges are recorded in `pending` and processed lazily by `rebuild`.
 */
export class EGraph {
    nextId = 0;
    classes = new Map(); // E-class ID -> { id, nodes: Set<key>, parents: Set<key> }
    parent = new Map(); // Union-find parent
    rank = new Map(); // Union-find rank
    hashcons = new Map(); // E-node key -> e-class
    nodeStore = new Map(); // Key -> actual node
    pending = []; // Classes needing rebuild
    /**
     * Find the canonical e-class ID (with path compression).
     * @param id - Any e-class ID previously returned by this e-graph.
     * @returns The root of `id`'s union-find tree.
     */
    find(id) {
        let root = id;
        while (this.parent.get(root) !== root) {
            root = this.parent.get(root);
        }
        // Path compression: point every node on the walked path straight at the root
        let current = id;
        while (this.parent.get(current) !== root) {
            const next = this.parent.get(current);
            this.parent.set(current, root);
            current = next;
        }
        return root;
    }
    /**
     * Add an e-node to the e-graph, returning its e-class ID.
     * If a node with the same canonical key already exists, returns its
     * (canonical) class instead of creating a new one.
     */
    add(node) {
        // Canonicalize children first so structurally equal nodes share a key
        const canonNode = this.canonicalize(node);
        const key = enodeKey(canonNode);
        // Check if already exists
        const existing = this.hashcons.get(key);
        if (existing !== undefined) {
            return this.find(existing);
        }
        // Create new e-class
        const id = this.nextId++;
        this.parent.set(id, id);
        this.rank.set(id, 0);
        const eclass = {
            id,
            nodes: new Set([key]),
            parents: new Set()
        };
        this.classes.set(id, eclass);
        this.hashcons.set(key, id);
        this.nodeStore.set(key, canonNode);
        // Register as parent of children
        for (const childId of enodeChildren(canonNode)) {
            const childClass = this.classes.get(this.find(childId));
            if (childClass) {
                childClass.parents.add(key);
            }
        }
        return id;
    }
    /**
     * Merge two e-classes, returning the new canonical ID.
     * Node and parent sets of the losing root are copied into the winner,
     * and the winner is queued for `rebuild`.
     */
    merge(id1, id2) {
        const root1 = this.find(id1);
        const root2 = this.find(id2);
        if (root1 === root2) {
            return root1;
        }
        // Union by rank
        const rank1 = this.rank.get(root1);
        const rank2 = this.rank.get(root2);
        let newRoot;
        let oldRoot;
        if (rank1 < rank2) {
            newRoot = root2;
            oldRoot = root1;
        }
        else if (rank1 > rank2) {
            newRoot = root1;
            oldRoot = root2;
        }
        else {
            newRoot = root1;
            oldRoot = root2;
            this.rank.set(newRoot, rank1 + 1);
        }
        this.parent.set(oldRoot, newRoot);
        // Merge e-class data
        const newClass = this.classes.get(newRoot);
        const oldClass = this.classes.get(oldRoot);
        for (const nodeKey of oldClass.nodes) {
            newClass.nodes.add(nodeKey);
        }
        for (const parentKey of oldClass.parents) {
            newClass.parents.add(parentKey);
        }
        // Mark for rebuild
        this.pending.push(newRoot);
        return newRoot;
    }
    /**
     * Rebuild the e-graph to restore congruence invariants.
     * Must be called after a batch of merges; repairs may trigger further
     * merges, so the worklist is drained until empty.
     */
    rebuild() {
        while (this.pending.length > 0) {
            const todo = [...this.pending];
            this.pending = [];
            for (const classId of todo) {
                this.repair(this.find(classId));
            }
        }
    }
    /**
     * Repair an e-class after merges.
     *
     * Each parent e-node of the class is re-canonicalized and re-hash-consed.
     * If its canonical form collides with an existing node, the two owning
     * classes are merged (congruence). The repaired parent keys are then
     * attached again to the class's current root, so that later merges
     * involving this class still revisit them.
     */
    repair(classId) {
        const eclass = this.classes.get(classId);
        if (!eclass)
            return;
        // Collect parent nodes that need re-canonicalization
        const oldParents = new Set(eclass.parents);
        eclass.parents.clear();
        // Canonical keys of the surviving parents. They must be re-attached
        // below. Dropping them leaves the class parent-less, and a later merge
        // involving it would then miss congruent parents.
        const repairedParents = new Set();
        for (const parentKey of oldParents) {
            const parentNode = this.nodeStore.get(parentKey);
            if (!parentNode)
                continue;
            // Owning class of the parent node. The lookup scans class node sets
            // (not hashcons), so resolving it before the delete is equivalent.
            const parentClassId = this.findClassForNode(parentKey);
            // Remove old hashcons entry
            this.hashcons.delete(parentKey);
            // Re-canonicalize and re-add
            const canonNode = this.canonicalize(parentNode);
            const newKey = enodeKey(canonNode);
            const existingClass = this.hashcons.get(newKey);
            if (existingClass !== undefined) {
                // Node already exists in another class - merge
                if (parentClassId !== undefined) {
                    this.merge(parentClassId, existingClass);
                }
                repairedParents.add(newKey);
            }
            else if (parentClassId !== undefined) {
                // Update hashcons with new key
                this.hashcons.set(newKey, parentClassId);
                this.nodeStore.set(newKey, canonNode);
                const parentClass = this.classes.get(this.find(parentClassId));
                if (parentClass) {
                    parentClass.nodes.delete(parentKey);
                    parentClass.nodes.add(newKey);
                }
                repairedParents.add(newKey);
            }
        }
        // Re-attach the repaired parents to whatever class is now the root
        // (a merge above may have folded this class into another one).
        const rootClass = this.classes.get(this.find(classId));
        if (rootClass) {
            for (const key of repairedParents) {
                rootClass.parents.add(key);
            }
        }
        // Re-register this class's own nodes as parents of their children
        for (const nodeKey of eclass.nodes) {
            const node = this.nodeStore.get(nodeKey);
            if (!node)
                continue;
            for (const childId of enodeChildren(node)) {
                const childClass = this.classes.get(this.find(childId));
                if (childClass) {
                    childClass.parents.add(nodeKey);
                }
            }
        }
    }
    /**
     * Find which e-class contains a node (by key).
     * Scans every stored class, including ones already merged away, and
     * returns the canonical ID of the first match (undefined if none).
     */
    findClassForNode(nodeKey) {
        for (const [id, eclass] of this.classes) {
            if (eclass.nodes.has(nodeKey)) {
                return this.find(id);
            }
        }
        return undefined;
    }
    /**
     * Canonicalize an e-node (update children to canonical IDs).
     * Leaf nodes are returned as-is.
     */
    canonicalize(node) {
        const children = enodeChildren(node);
        if (children.length === 0) {
            return node;
        }
        const canonChildren = children.map(id => this.find(id));
        return enodeWithChildren(node, canonChildren);
    }
    /**
     * Get all canonical e-class IDs (each root appears once).
     */
    getClassIds() {
        const canonical = new Set();
        for (const id of this.classes.keys()) {
            canonical.add(this.find(id));
        }
        return [...canonical];
    }
    /**
     * Get the e-class that `id` resolves to.
     */
    getClass(id) {
        return this.classes.get(this.find(id));
    }
    /**
     * Get all e-nodes in an e-class, with children canonicalized.
     */
    getNodes(classId) {
        const eclass = this.classes.get(this.find(classId));
        if (!eclass)
            return [];
        const nodes = [];
        for (const key of eclass.nodes) {
            const node = this.nodeStore.get(key);
            if (node) {
                nodes.push(this.canonicalize(node));
            }
        }
        return nodes;
    }
    /**
     * Get the number of canonical e-classes.
     */
    get size() {
        return this.getClassIds().length;
    }
    /**
     * Get a stored node by its key.
     */
    getNodeByKey(key) {
        return this.nodeStore.get(key);
    }
    /**
     * Lookup the canonical e-class containing `node` (if it exists).
     */
    lookup(node) {
        const canonNode = this.canonicalize(node);
        const key = enodeKey(canonNode);
        const id = this.hashcons.get(key);
        return id !== undefined ? this.find(id) : undefined;
    }
    /**
     * Debug: print e-graph state, one line per canonical class.
     */
    dump() {
        const lines = ['E-Graph:'];
        for (const classId of this.getClassIds()) {
            const eclass = this.classes.get(classId);
            if (!eclass)
                continue;
            const nodeStrs = [...eclass.nodes].map(key => {
                const node = this.nodeStore.get(key);
                return node ? this.nodeToString(node) : key;
            });
            lines.push(`  [${classId}]: ${nodeStrs.join(' = ')}`);
        }
        return lines.join('\n');
    }
    /**
     * Convert an e-node to a readable s-expression-like string.
     */
    nodeToString(node) {
        switch (node.tag) {
            case 'num': return `${node.value}`;
            case 'var': return node.name;
            case 'add': return `(+ e${node.children[0]} e${node.children[1]})`;
            case 'mul': return `(* e${node.children[0]} e${node.children[1]})`;
            case 'sub': return `(- e${node.children[0]} e${node.children[1]})`;
            case 'div': return `(/ e${node.children[0]} e${node.children[1]})`;
            case 'pow': return `(^ e${node.children[0]} e${node.children[1]})`;
            case 'neg': return `(neg e${node.child})`;
            case 'inv': return `(inv e${node.child})`;
            case 'call': return `(${node.name} ${node.children.map(c => `e${c}`).join(' ')})`;
            case 'component': return `(. e${node.object} ${node.field})`;
        }
    }
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Node: Expression nodes in an e-graph
|
|
3
|
+
*
|
|
4
|
+
* E-nodes are hash-consed (deduplicated) and reference e-classes by ID.
|
|
5
|
+
* This allows the e-graph to represent equivalence classes efficiently.
|
|
6
|
+
*/
|
|
7
|
+
/** Identifier of an e-class (an integer handed out by the e-graph). */
export type EClassId = number;
/**
 * One operator application stored in an e-graph, discriminated by `tag`.
 *
 * Operands are e-class IDs rather than sub-expressions, so a single e-node
 * stands for every expression obtained by applying its operator to members
 * of the referenced classes.
 */
export type ENode =
    // Leaves: numeric literal and named variable
    | { tag: 'num'; value: number; }
    | { tag: 'var'; name: string; }
    // Unary operators
    | { tag: 'neg'; child: EClassId; }
    | { tag: 'inv'; child: EClassId; }
    // Binary operators (operand order is significant)
    | { tag: 'add'; children: [EClassId, EClassId]; }
    | { tag: 'sub'; children: [EClassId, EClassId]; }
    | { tag: 'mul'; children: [EClassId, EClassId]; }
    | { tag: 'div'; children: [EClassId, EClassId]; }
    | { tag: 'pow'; children: [EClassId, EClassId]; }
    // Named function call with any number of arguments
    | { tag: 'call'; name: string; children: EClassId[]; }
    // Field access on an object-valued class
    | { tag: 'component'; object: EClassId; field: string; };
|
|
47
|
+
/**
 * Return the child e-class IDs of `node` in operand order, as a new array
 * (empty for 'num' and 'var' leaves).
 */
export declare function enodeChildren(node: ENode): EClassId[];
/**
 * Return a copy of `node` whose child IDs are replaced by `newChildren`
 * (used after canonicalization). Leaf nodes are returned unchanged.
 */
export declare function enodeWithChildren(node: ENode, newChildren: EClassId[]): ENode;
/**
 * Build the string key used for hash-consing. Structurally identical nodes
 * (same tag, same payload, same child IDs) map to the same key.
 */
export declare function enodeKey(node: ENode): string;
/** True when `a` and `b` produce the same hash-cons key. */
export declare function enodesEqual(a: ENode, b: ENode): boolean;
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* E-Node: Expression nodes in an e-graph
|
|
3
|
+
*
|
|
4
|
+
* E-nodes are hash-consed (deduplicated) and reference e-classes by ID.
|
|
5
|
+
* This allows the e-graph to represent equivalence classes efficiently.
|
|
6
|
+
*/
|
|
7
|
+
/**
 * Create a canonical string key for an e-node (for hash-consing).
 * Structurally identical nodes (same tag, payload and child IDs) get
 * identical keys; children are used as given, so callers canonicalize first.
 *
 * @param node - E-node to key.
 * @returns The hash-cons key string.
 * @throws Error if `node.tag` is not a known e-node tag. Previously such nodes
 *   silently returned `undefined`, which the e-graph would then use as a
 *   shared hashcons key, collapsing unrelated malformed nodes into one class.
 */
export function enodeKey(node) {
    switch (node.tag) {
        case 'num':
            return `num:${node.value}`;
        case 'var':
            return `var:${node.name}`;
        case 'add':
            return `add:${node.children[0]},${node.children[1]}`;
        case 'mul':
            return `mul:${node.children[0]},${node.children[1]}`;
        case 'sub':
            return `sub:${node.children[0]},${node.children[1]}`;
        case 'div':
            return `div:${node.children[0]},${node.children[1]}`;
        case 'pow':
            return `pow:${node.children[0]},${node.children[1]}`;
        case 'neg':
            return `neg:${node.child}`;
        case 'inv':
            return `inv:${node.child}`;
        case 'call':
            return `call:${node.name}(${node.children.join(',')})`;
        case 'component':
            return `comp:${node.object}.${node.field}`;
        default:
            throw new Error(`Unknown e-node tag: ${String(node.tag)}`);
    }
}
|
|
37
|
+
/**
 * List the e-class IDs an e-node refers to, in operand order.
 * A fresh array is returned on every call, so callers may mutate it.
 * 'num' and 'var' leaves have no operands and yield an empty array.
 */
export function enodeChildren(node) {
    const tag = node.tag;
    if (tag === 'num' || tag === 'var') {
        return [];
    }
    if (tag === 'neg' || tag === 'inv') {
        return [node.child];
    }
    if (tag === 'component') {
        return [node.object];
    }
    const hasChildArray = tag === 'add' || tag === 'mul' || tag === 'sub' ||
        tag === 'div' || tag === 'pow' || tag === 'call';
    if (hasChildArray) {
        // Copy so the stored node's array is never exposed to mutation
        return Array.from(node.children);
    }
    return undefined;
}
|
|
61
|
+
/**
 * Rebuild an e-node with its operand IDs replaced by `newChildren`
 * (typically the canonical IDs produced during canonicalization).
 *
 * - 'num'/'var' leaves are returned as the very same object.
 * - Binary operators take the first two entries of `newChildren`.
 * - Unary operators and 'component' take the first entry.
 * - 'call' stores `newChildren` itself (no copy) as its argument list.
 */
export function enodeWithChildren(node, newChildren) {
    const tag = node.tag;
    if (tag === 'num' || tag === 'var') {
        return node;
    }
    if (tag === 'add' || tag === 'mul' || tag === 'sub' || tag === 'div' || tag === 'pow') {
        const lhs = newChildren[0];
        const rhs = newChildren[1];
        return { tag, children: [lhs, rhs] };
    }
    if (tag === 'neg' || tag === 'inv') {
        return { tag, child: newChildren[0] };
    }
    if (tag === 'call') {
        return { tag, name: node.name, children: newChildren };
    }
    if (tag === 'component') {
        return { tag, object: newChildren[0], field: node.field };
    }
    return undefined;
}
|
|
89
|
+
/**
 * Structural equality of two e-nodes: true exactly when both produce the
 * same hash-cons key (same tag, payload and child IDs as given).
 */
export function enodesEqual(a, b) {
    const keyA = enodeKey(a);
    const keyB = enodeKey(b);
    return keyA === keyB;
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Cost-Based Extraction from E-Graphs
|
|
3
|
+
*
|
|
4
|
+
* Extracts the lowest-cost expression from each e-class.
|
|
5
|
+
* Also detects common subexpressions (shared e-classes) for CSE.
|
|
6
|
+
*/
|
|
7
|
+
import { EGraph } from './EGraph.js';
|
|
8
|
+
import { EClassId } from './ENode.js';
|
|
9
|
+
import { Expression } from '../AST.js';
|
|
10
|
+
/**
 * Per-operator weights used to rank the members of an e-class during
 * extraction: a higher cost makes an operator less preferred. Field names
 * mirror the e-node `tag` values.
 */
export interface CostModel {
    // Leaves
    num: number;
    var: number;
    // Unary operators
    neg: number;
    inv: number;
    // Binary operators
    add: number;
    sub: number;
    mul: number;
    div: number;
    pow: number;
    // Function calls and field access
    call: number;
    component: number;
}
/**
 * Built-in cost model. Its concrete weights live in the implementation file;
 * the original note says division is weighted as expensive.
 */
export declare const defaultCostModel: CostModel;
|
|
31
|
+
/**
 * What a CSE-aware extraction returns.
 */
export interface ExtractionResult {
    /** Extracted expressions keyed by e-class ID; they may refer to the temporaries in `temps`. */
    expressions: Map<EClassId, Expression>;
    /** Temporary definitions for shared subexpressions (temp name -> expression). */
    temps: Map<string, Expression>;
    /** Total cost of the extraction. */
    totalCost: number;
}
|
|
42
|
+
/**
 * Extract several e-classes at once; subexpressions shared between them are
 * bound to temporaries (see `ExtractionResult.temps`).
 *
 * @param egraph - E-graph to extract from.
 * @param roots - E-class IDs whose expressions are wanted.
 * @param costModel - Optional per-operator weights.
 * @param minSharedCost - Optional threshold; presumably a shared
 *   subexpression must cost at least this much to become a temp
 *   (TODO confirm against Extractor.js).
 */
export declare function extractWithCSE(egraph: EGraph, roots: EClassId[], costModel?: CostModel, minSharedCost?: number): ExtractionResult;
/**
 * Extract the lowest-cost expression represented by e-class `rootId`.
 *
 * @param egraph - E-graph to extract from.
 * @param rootId - E-class to extract.
 * @param costModel - Optional per-operator weights (presumably
 *   `defaultCostModel` when omitted; not visible in this declaration).
 */
export declare function extractBest(egraph: EGraph, rootId: EClassId, costModel?: CostModel): Expression;
|