gradient-script 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/README.md +3 -1
  2. package/dist/cli.js +219 -6
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +336 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -0,0 +1,292 @@
1
+ /**
2
+ * E-Graph: Equality Graph for expression optimization
3
+ *
4
+ * An e-graph efficiently represents equivalence classes of expressions.
5
+ * It supports:
6
+ * - Adding expressions (returns e-class ID)
7
+ * - Merging e-classes (union)
8
+ * - Finding canonical e-class (find)
9
+ * - Rebuilding after merges (maintains congruence)
10
+ */
11
+ import { enodeKey, enodeChildren, enodeWithChildren } from './ENode.js';
12
+ /**
13
+ * E-Graph: The main data structure
14
+ */
15
+ export class EGraph {
16
+ nextId = 0;
17
+ classes = new Map();
18
+ parent = new Map(); // Union-find parent
19
+ rank = new Map(); // Union-find rank
20
+ hashcons = new Map(); // E-node key -> e-class
21
+ nodeStore = new Map(); // Key -> actual node
22
+ pending = []; // Classes needing rebuild
23
+ /**
24
+ * Find the canonical e-class ID (with path compression)
25
+ */
26
+ find(id) {
27
+ let root = id;
28
+ while (this.parent.get(root) !== root) {
29
+ root = this.parent.get(root);
30
+ }
31
+ // Path compression
32
+ let current = id;
33
+ while (this.parent.get(current) !== root) {
34
+ const next = this.parent.get(current);
35
+ this.parent.set(current, root);
36
+ current = next;
37
+ }
38
+ return root;
39
+ }
40
+ /**
41
+ * Add an e-node to the e-graph, returning its e-class ID
42
+ * If the node already exists, returns the existing class
43
+ */
44
+ add(node) {
45
+ // Canonicalize children first
46
+ const canonNode = this.canonicalize(node);
47
+ const key = enodeKey(canonNode);
48
+ // Check if already exists
49
+ const existing = this.hashcons.get(key);
50
+ if (existing !== undefined) {
51
+ return this.find(existing);
52
+ }
53
+ // Create new e-class
54
+ const id = this.nextId++;
55
+ this.parent.set(id, id);
56
+ this.rank.set(id, 0);
57
+ const eclass = {
58
+ id,
59
+ nodes: new Set([key]),
60
+ parents: new Set()
61
+ };
62
+ this.classes.set(id, eclass);
63
+ this.hashcons.set(key, id);
64
+ this.nodeStore.set(key, canonNode);
65
+ // Register as parent of children
66
+ for (const childId of enodeChildren(canonNode)) {
67
+ const childClass = this.classes.get(this.find(childId));
68
+ if (childClass) {
69
+ childClass.parents.add(key);
70
+ }
71
+ }
72
+ return id;
73
+ }
74
+ /**
75
+ * Merge two e-classes, returning the new canonical ID
76
+ */
77
+ merge(id1, id2) {
78
+ const root1 = this.find(id1);
79
+ const root2 = this.find(id2);
80
+ if (root1 === root2) {
81
+ return root1;
82
+ }
83
+ // Union by rank
84
+ const rank1 = this.rank.get(root1);
85
+ const rank2 = this.rank.get(root2);
86
+ let newRoot;
87
+ let oldRoot;
88
+ if (rank1 < rank2) {
89
+ newRoot = root2;
90
+ oldRoot = root1;
91
+ }
92
+ else if (rank1 > rank2) {
93
+ newRoot = root1;
94
+ oldRoot = root2;
95
+ }
96
+ else {
97
+ newRoot = root1;
98
+ oldRoot = root2;
99
+ this.rank.set(newRoot, rank1 + 1);
100
+ }
101
+ this.parent.set(oldRoot, newRoot);
102
+ // Merge e-class data
103
+ const newClass = this.classes.get(newRoot);
104
+ const oldClass = this.classes.get(oldRoot);
105
+ for (const nodeKey of oldClass.nodes) {
106
+ newClass.nodes.add(nodeKey);
107
+ }
108
+ for (const parentKey of oldClass.parents) {
109
+ newClass.parents.add(parentKey);
110
+ }
111
+ // Mark for rebuild
112
+ this.pending.push(newRoot);
113
+ return newRoot;
114
+ }
115
+ /**
116
+ * Rebuild the e-graph to restore congruence invariants
117
+ * Must be called after a batch of merges
118
+ */
119
+ rebuild() {
120
+ while (this.pending.length > 0) {
121
+ const todo = [...this.pending];
122
+ this.pending = [];
123
+ for (const classId of todo) {
124
+ this.repair(this.find(classId));
125
+ }
126
+ }
127
+ }
128
+ /**
129
+ * Repair an e-class after merges
130
+ */
131
+ repair(classId) {
132
+ const eclass = this.classes.get(classId);
133
+ if (!eclass)
134
+ return;
135
+ // Collect parent nodes that need re-canonicalization
136
+ const oldParents = new Set(eclass.parents);
137
+ eclass.parents.clear();
138
+ for (const parentKey of oldParents) {
139
+ const parentNode = this.nodeStore.get(parentKey);
140
+ if (!parentNode)
141
+ continue;
142
+ // Remove old hashcons entry
143
+ this.hashcons.delete(parentKey);
144
+ // Re-canonicalize and re-add
145
+ const canonNode = this.canonicalize(parentNode);
146
+ const newKey = enodeKey(canonNode);
147
+ const existingClass = this.hashcons.get(newKey);
148
+ if (existingClass !== undefined) {
149
+ // Node already exists in another class - merge
150
+ const parentClassId = this.findClassForNode(parentKey);
151
+ if (parentClassId !== undefined) {
152
+ this.merge(parentClassId, existingClass);
153
+ }
154
+ }
155
+ else {
156
+ // Update hashcons with new key
157
+ const parentClassId = this.findClassForNode(parentKey);
158
+ if (parentClassId !== undefined) {
159
+ this.hashcons.set(newKey, parentClassId);
160
+ this.nodeStore.set(newKey, canonNode);
161
+ const parentClass = this.classes.get(this.find(parentClassId));
162
+ if (parentClass) {
163
+ parentClass.nodes.delete(parentKey);
164
+ parentClass.nodes.add(newKey);
165
+ }
166
+ }
167
+ }
168
+ }
169
+ // Re-register parents for this class
170
+ for (const nodeKey of eclass.nodes) {
171
+ const node = this.nodeStore.get(nodeKey);
172
+ if (!node)
173
+ continue;
174
+ for (const childId of enodeChildren(node)) {
175
+ const childClass = this.classes.get(this.find(childId));
176
+ if (childClass) {
177
+ childClass.parents.add(nodeKey);
178
+ }
179
+ }
180
+ }
181
+ }
182
+ /**
183
+ * Find which e-class contains a node (by key)
184
+ */
185
+ findClassForNode(nodeKey) {
186
+ for (const [id, eclass] of this.classes) {
187
+ if (eclass.nodes.has(nodeKey)) {
188
+ return this.find(id);
189
+ }
190
+ }
191
+ return undefined;
192
+ }
193
+ /**
194
+ * Canonicalize an e-node (update children to canonical IDs)
195
+ */
196
+ canonicalize(node) {
197
+ const children = enodeChildren(node);
198
+ if (children.length === 0) {
199
+ return node;
200
+ }
201
+ const canonChildren = children.map(id => this.find(id));
202
+ return enodeWithChildren(node, canonChildren);
203
+ }
204
+ /**
205
+ * Get all e-class IDs
206
+ */
207
+ getClassIds() {
208
+ const canonical = new Set();
209
+ for (const id of this.classes.keys()) {
210
+ canonical.add(this.find(id));
211
+ }
212
+ return [...canonical];
213
+ }
214
+ /**
215
+ * Get an e-class by ID
216
+ */
217
+ getClass(id) {
218
+ return this.classes.get(this.find(id));
219
+ }
220
+ /**
221
+ * Get all e-nodes in an e-class
222
+ */
223
+ getNodes(classId) {
224
+ const eclass = this.classes.get(this.find(classId));
225
+ if (!eclass)
226
+ return [];
227
+ const nodes = [];
228
+ for (const key of eclass.nodes) {
229
+ const node = this.nodeStore.get(key);
230
+ if (node) {
231
+ nodes.push(this.canonicalize(node));
232
+ }
233
+ }
234
+ return nodes;
235
+ }
236
+ /**
237
+ * Get the number of e-classes
238
+ */
239
+ get size() {
240
+ return this.getClassIds().length;
241
+ }
242
+ /**
243
+ * Get a node by its key
244
+ */
245
+ getNodeByKey(key) {
246
+ return this.nodeStore.get(key);
247
+ }
248
+ /**
249
+ * Lookup e-class by node (if it exists)
250
+ */
251
+ lookup(node) {
252
+ const canonNode = this.canonicalize(node);
253
+ const key = enodeKey(canonNode);
254
+ const id = this.hashcons.get(key);
255
+ return id !== undefined ? this.find(id) : undefined;
256
+ }
257
+ /**
258
+ * Debug: print e-graph state
259
+ */
260
+ dump() {
261
+ const lines = ['E-Graph:'];
262
+ for (const classId of this.getClassIds()) {
263
+ const eclass = this.classes.get(classId);
264
+ if (!eclass)
265
+ continue;
266
+ const nodeStrs = [...eclass.nodes].map(key => {
267
+ const node = this.nodeStore.get(key);
268
+ return node ? this.nodeToString(node) : key;
269
+ });
270
+ lines.push(` [${classId}]: ${nodeStrs.join(' = ')}`);
271
+ }
272
+ return lines.join('\n');
273
+ }
274
+ /**
275
+ * Convert e-node to readable string
276
+ */
277
+ nodeToString(node) {
278
+ switch (node.tag) {
279
+ case 'num': return `${node.value}`;
280
+ case 'var': return node.name;
281
+ case 'add': return `(+ e${node.children[0]} e${node.children[1]})`;
282
+ case 'mul': return `(* e${node.children[0]} e${node.children[1]})`;
283
+ case 'sub': return `(- e${node.children[0]} e${node.children[1]})`;
284
+ case 'div': return `(/ e${node.children[0]} e${node.children[1]})`;
285
+ case 'pow': return `(^ e${node.children[0]} e${node.children[1]})`;
286
+ case 'neg': return `(neg e${node.child})`;
287
+ case 'inv': return `(inv e${node.child})`;
288
+ case 'call': return `(${node.name} ${node.children.map(c => `e${c}`).join(' ')})`;
289
+ case 'component': return `(. e${node.object} ${node.field})`;
290
+ }
291
+ }
292
+ }
@@ -0,0 +1,63 @@
1
+ /**
2
+ * E-Node: Expression nodes in an e-graph
3
+ *
4
+ * E-nodes are hash-consed (deduplicated) and reference e-classes by ID.
5
+ * This allows the e-graph to represent equivalence classes efficiently.
6
+ */
7
/** Identifier of an e-class. */
export type EClassId = number;
/**
 * A single e-node, discriminated by `tag`.
 *
 * Operands are e-class IDs rather than nested expressions. One e-node
 * therefore stands for every expression that can be built from its
 * children's equivalence classes.
 */
export type ENode =
    | { tag: 'num'; value: number }
    | { tag: 'var'; name: string }
    | { tag: 'add'; children: [EClassId, EClassId] }
    | { tag: 'mul'; children: [EClassId, EClassId] }
    | { tag: 'sub'; children: [EClassId, EClassId] }
    | { tag: 'div'; children: [EClassId, EClassId] }
    | { tag: 'pow'; children: [EClassId, EClassId] }
    | { tag: 'neg'; child: EClassId }
    | { tag: 'inv'; child: EClassId }
    | { tag: 'call'; name: string; children: EClassId[] }
    | { tag: 'component'; object: EClassId; field: string };
/**
 * Build the hash-consing key of an e-node. Structurally identical nodes share a key.
 * @param node - node to key
 * @returns the canonical key string
 */
export declare function enodeKey(node: ENode): string;
/**
 * List the e-class IDs a node refers to, i.e. its children.
 * @param node - node to inspect
 * @returns child e-class IDs (empty for leaves)
 */
export declare function enodeChildren(node: ENode): EClassId[];
/**
 * Produce a copy of `node` whose children are replaced, typically by their canonical IDs.
 * @param node - template node (tag and payload are kept)
 * @param newChildren - replacement child IDs, in the same order as `enodeChildren(node)`
 * @returns the rebuilt node
 */
export declare function enodeWithChildren(node: ENode, newChildren: EClassId[]): ENode;
/**
 * Structural equality of two e-nodes.
 * @returns true when both nodes are structurally equal
 */
export declare function enodesEqual(a: ENode, b: ENode): boolean;
@@ -0,0 +1,94 @@
1
+ /**
2
+ * E-Node: Expression nodes in an e-graph
3
+ *
4
+ * E-nodes are hash-consed (deduplicated) and reference e-classes by ID.
5
+ * This allows the e-graph to represent equivalence classes efficiently.
6
+ */
7
+ /**
8
+ * Create a canonical string key for an e-node (for hash-consing)
9
+ * This key is used to detect structurally identical nodes.
10
+ */
11
+ export function enodeKey(node) {
12
+ switch (node.tag) {
13
+ case 'num':
14
+ return `num:${node.value}`;
15
+ case 'var':
16
+ return `var:${node.name}`;
17
+ case 'add':
18
+ return `add:${node.children[0]},${node.children[1]}`;
19
+ case 'mul':
20
+ return `mul:${node.children[0]},${node.children[1]}`;
21
+ case 'sub':
22
+ return `sub:${node.children[0]},${node.children[1]}`;
23
+ case 'div':
24
+ return `div:${node.children[0]},${node.children[1]}`;
25
+ case 'pow':
26
+ return `pow:${node.children[0]},${node.children[1]}`;
27
+ case 'neg':
28
+ return `neg:${node.child}`;
29
+ case 'inv':
30
+ return `inv:${node.child}`;
31
+ case 'call':
32
+ return `call:${node.name}(${node.children.join(',')})`;
33
+ case 'component':
34
+ return `comp:${node.object}.${node.field}`;
35
+ }
36
+ }
37
/**
 * Collect the e-class IDs an e-node points at.
 *
 * Leaves (`num`, `var`) have none. Unary nodes (`neg`, `inv`) and
 * `component` contribute their single operand. Binary operators and `call`
 * contribute a fresh copy of their `children` array, so callers may mutate
 * the result freely. Unknown tags yield `undefined`.
 */
export function enodeChildren(node) {
    const { tag } = node;
    if (tag === 'num' || tag === 'var') {
        return [];
    }
    if (tag === 'neg' || tag === 'inv') {
        return [node.child];
    }
    if (tag === 'component') {
        return [node.object];
    }
    const hasChildArray = tag === 'add' || tag === 'mul' || tag === 'sub' ||
        tag === 'div' || tag === 'pow' || tag === 'call';
    if (hasChildArray) {
        return node.children.slice();
    }
    return undefined;
}
61
/**
 * Rebuild an e-node around a new list of child e-class IDs.
 *
 * The tag and any payload (`name`, `field`) are carried over. Leaves are
 * returned as-is. Binary operators take the first two IDs, unary nodes and
 * `component` take the first, and `call` adopts `newChildren` itself (not a
 * copy). Unknown tags yield `undefined`.
 *
 * @param node - template node
 * @param newChildren - replacement IDs, ordered like `enodeChildren(node)`
 */
export function enodeWithChildren(node, newChildren) {
    const { tag } = node;
    if (tag === 'num' || tag === 'var') {
        return node;
    }
    if (tag === 'neg' || tag === 'inv') {
        return { tag, child: newChildren[0] };
    }
    if (tag === 'component') {
        return { tag, object: newChildren[0], field: node.field };
    }
    if (tag === 'call') {
        return { tag, name: node.name, children: newChildren };
    }
    if (tag === 'add' || tag === 'mul' || tag === 'sub' || tag === 'div' || tag === 'pow') {
        return { tag, children: [newChildren[0], newChildren[1]] };
    }
    return undefined;
}
89
+ /**
90
+ * Check if two e-nodes are structurally equal
91
+ */
92
+ export function enodesEqual(a, b) {
93
+ return enodeKey(a) === enodeKey(b);
94
+ }
@@ -0,0 +1,49 @@
1
+ /**
2
+ * Cost-Based Extraction from E-Graphs
3
+ *
4
+ * Extracts the lowest-cost expression from each e-class.
5
+ * Also detects common subexpressions (shared e-classes) for CSE.
6
+ */
7
+ import { EGraph } from './EGraph.js';
8
+ import { EClassId } from './ENode.js';
9
+ import { Expression } from '../AST.js';
10
/**
 * Per-operation costs used when choosing a representative expression.
 * A larger number makes an operation less attractive to the extractor.
 */
export interface CostModel {
    /** Leaves */
    num: number;
    var: number;
    component: number;
    /** Binary operators */
    add: number;
    sub: number;
    mul: number;
    div: number;
    pow: number;
    /** Unary operators */
    neg: number;
    inv: number;
    /** Function calls */
    call: number;
}
/**
 * Cost model used when none is supplied; division is treated as expensive.
 */
export declare const defaultCostModel: CostModel;
/**
 * Output of an extraction that also performs common-subexpression elimination.
 */
export interface ExtractionResult {
    /** Extracted expressions, keyed by root e-class ID; they may reference temps */
    expressions: Map<EClassId, Expression>;
    /** Definitions of the temporaries introduced for shared subexpressions */
    temps: Map<string, Expression>;
    /** Total cost of the extraction */
    totalCost: number;
}
/**
 * Pick the cheapest expression represented by one e-class.
 * @param egraph - the e-graph to extract from
 * @param rootId - e-class whose best representative is wanted
 * @param costModel - optional cost table
 */
export declare function extractBest(egraph: EGraph, rootId: EClassId, costModel?: CostModel): Expression;
/**
 * Extract several roots at once, turning shared subexpressions into temporaries.
 * @param egraph - the e-graph to extract from
 * @param roots - e-classes to extract
 * @param costModel - optional cost table
 * @param minSharedCost - optional threshold related to which shared subexpressions become temps
 */
export declare function extractWithCSE(egraph: EGraph, roots: EClassId[], costModel?: CostModel, minSharedCost?: number): ExtractionResult;