gradient-script 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/README.md +52 -9
  2. package/dist/cli.js +134 -19
  3. package/dist/dsl/AST.d.ts +8 -0
  4. package/dist/dsl/CodeGen.d.ts +8 -3
  5. package/dist/dsl/CodeGen.js +583 -132
  6. package/dist/dsl/Errors.d.ts +6 -1
  7. package/dist/dsl/Errors.js +70 -1
  8. package/dist/dsl/Expander.js +5 -2
  9. package/dist/dsl/ExpressionUtils.d.ts +14 -0
  10. package/dist/dsl/ExpressionUtils.js +56 -0
  11. package/dist/dsl/GradientChecker.d.ts +21 -0
  12. package/dist/dsl/GradientChecker.js +109 -23
  13. package/dist/dsl/Guards.d.ts +3 -1
  14. package/dist/dsl/Guards.js +86 -43
  15. package/dist/dsl/Inliner.d.ts +5 -0
  16. package/dist/dsl/Inliner.js +11 -2
  17. package/dist/dsl/Lexer.js +3 -1
  18. package/dist/dsl/Parser.js +11 -5
  19. package/dist/dsl/Simplify.d.ts +7 -0
  20. package/dist/dsl/Simplify.js +183 -0
  21. package/dist/dsl/egraph/Convert.d.ts +23 -0
  22. package/dist/dsl/egraph/Convert.js +84 -0
  23. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  24. package/dist/dsl/egraph/EGraph.js +292 -0
  25. package/dist/dsl/egraph/ENode.d.ts +63 -0
  26. package/dist/dsl/egraph/ENode.js +94 -0
  27. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  28. package/dist/dsl/egraph/Extractor.js +1068 -0
  29. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  30. package/dist/dsl/egraph/Optimizer.js +88 -0
  31. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  32. package/dist/dsl/egraph/Pattern.js +325 -0
  33. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  34. package/dist/dsl/egraph/Rewriter.js +131 -0
  35. package/dist/dsl/egraph/Rules.d.ts +44 -0
  36. package/dist/dsl/egraph/Rules.js +187 -0
  37. package/dist/dsl/egraph/index.d.ts +15 -0
  38. package/dist/dsl/egraph/index.js +21 -0
  39. package/package.json +1 -1
  40. package/dist/dsl/CSE.d.ts +0 -21
  41. package/dist/dsl/CSE.js +0 -194
  42. package/dist/symbolic/AST.d.ts +0 -113
  43. package/dist/symbolic/AST.js +0 -128
  44. package/dist/symbolic/CodeGen.d.ts +0 -35
  45. package/dist/symbolic/CodeGen.js +0 -280
  46. package/dist/symbolic/Parser.d.ts +0 -64
  47. package/dist/symbolic/Parser.js +0 -329
  48. package/dist/symbolic/Simplify.d.ts +0 -10
  49. package/dist/symbolic/Simplify.js +0 -244
  50. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  51. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -6,8 +6,9 @@ import { ExpressionTransformer } from './ExpressionTransformer.js';
6
6
  /**
7
7
  * Expression transformer that substitutes variables from a substitution map
8
8
  * Handles recursive inlining by reprocessing substituted expressions
9
+ * Used for inlining intermediate variables to eliminate assignments
9
10
  */
10
- class SubstitutionTransformer extends ExpressionTransformer {
11
+ class VariableSubstitutionTransformer extends ExpressionTransformer {
11
12
  substitutions;
12
13
  constructor(substitutions) {
13
14
  super();
@@ -35,6 +36,14 @@ export function inlineIntermediateVariables(func) {
35
36
  }
36
37
  }
37
38
  // Use transformer to inline all variables
38
- const transformer = new SubstitutionTransformer(substitutions);
39
+ const transformer = new VariableSubstitutionTransformer(substitutions);
39
40
  return transformer.transform(func.returnExpr);
40
41
  }
/**
 * Expand `expr` by substituting the variables listed in `substitutions`.
 *
 * Substitution is delegated to VariableSubstitutionTransformer, which
 * reprocesses substituted expressions so chains of intermediate variables
 * are expanded recursively. Used to obtain the fully-expanded form of
 * forward-pass expressions.
 *
 * @param expr - Expression to expand.
 * @param substitutions - Substitution map handed to the transformer
 *   (presumably variable name → defining expression; confirm against
 *   inlineIntermediateVariables).
 * @returns The expanded expression.
 */
export function inlineExpression(expr, substitutions) {
    const inliner = new VariableSubstitutionTransformer(substitutions);
    const expanded = inliner.transform(expr);
    return expanded;
}
package/dist/dsl/Lexer.js CHANGED
@@ -2,6 +2,7 @@
2
2
  * Lexer for GradientScript DSL
3
3
  * Tokenizes input with support for ∇, ^, **, and structured types
4
4
  */
5
+ import { ParseError } from './Errors.js';
5
6
  export var TokenType;
6
7
  (function (TokenType) {
7
8
  // Literals
@@ -126,7 +127,8 @@ export class Lexer {
126
127
  this.advance();
127
128
  return { type: TokenType.NEWLINE, value: '\n', line, column };
128
129
  }
129
- throw new Error(`Unexpected character '${char}' at line ${line}, column ${column}`);
130
+ // Create a more helpful error for common mistakes
131
+ throw new ParseError(`Unexpected character '${char}'`, line, column, char);
130
132
  }
131
133
  number() {
132
134
  const line = this.line;
@@ -123,13 +123,15 @@ export class Parser {
123
123
  * variable = expression
124
124
  */
125
125
  assignment() {
126
- const variable = this.consume(TokenType.IDENTIFIER, 'Expected variable name').value;
126
+ const varToken = this.consume(TokenType.IDENTIFIER, 'Expected variable name');
127
+ const variable = varToken.value;
127
128
  this.consume(TokenType.EQUALS, "Expected '=' in assignment");
128
129
  const expression = this.expression();
129
130
  return {
130
131
  kind: 'assignment',
131
132
  variable,
132
- expression
133
+ expression,
134
+ loc: { line: varToken.line, column: varToken.column }
133
135
  };
134
136
  }
135
137
  /**
@@ -161,13 +163,15 @@ export class Parser {
161
163
  multiplicative() {
162
164
  let expr = this.power();
163
165
  while (this.match(TokenType.MULTIPLY, TokenType.DIVIDE)) {
164
- const operator = this.previous().value;
166
+ const opToken = this.previous();
167
+ const operator = opToken.value;
165
168
  const right = this.power();
166
169
  expr = {
167
170
  kind: 'binary',
168
171
  operator,
169
172
  left: expr,
170
- right
173
+ right,
174
+ loc: { line: opToken.line, column: opToken.column }
171
175
  };
172
176
  }
173
177
  return expr;
@@ -213,6 +217,7 @@ export class Parser {
213
217
  while (true) {
214
218
  if (this.match(TokenType.LPAREN)) {
215
219
  // Function call
220
+ const startLoc = expr.loc || { line: this.previous().line, column: this.previous().column };
216
221
  const args = this.argumentList();
217
222
  this.consume(TokenType.RPAREN, "Expected ')' after arguments");
218
223
  if (expr.kind !== 'variable') {
@@ -222,7 +227,8 @@ export class Parser {
222
227
  expr = {
223
228
  kind: 'call',
224
229
  name: expr.name,
225
- args
230
+ args,
231
+ loc: startLoc
226
232
  };
227
233
  }
228
234
  else if (this.match(TokenType.DOT)) {
@@ -15,3 +15,10 @@ export declare function simplifyGradients(gradients: Map<string, Expression | {
15
15
  }>): Map<string, Expression | {
16
16
  components: Map<string, Expression>;
17
17
  }>;
/**
 * Simplification pass meant to run after CSE.
 *
 * Applies rewrites that the main simplifier deliberately skips because they
 * flatten expression structure and would interfere with common-subexpression
 * extraction. Currently this is only `a + a → 2 * a`, which is safe once CSE
 * has already extracted its temporaries.
 *
 * @param expr - Expression to simplify.
 * @returns The simplified expression.
 */
export declare function simplifyPostCSE(expr: Expression): Expression;
@@ -46,6 +46,9 @@ class Simplifier extends ExpressionTransformer {
46
46
  return right;
47
47
  if (rightNum === 0)
48
48
  return left;
49
+ // Note: a + a → 2 * a rules are intentionally NOT applied here
50
+ // because they flatten expression structure and interfere with CSE.
51
+ // The CSE pass will extract common subexpressions instead.
49
52
  }
50
53
  // Subtraction rules
51
54
  if (expr.operator === '-') {
@@ -57,6 +60,25 @@ class Simplifier extends ExpressionTransformer {
57
60
  if (expressionsEqual(left, right)) {
58
61
  return { kind: 'number', value: 0 };
59
62
  }
63
+ // a - (-b) → a + b
64
+ if (right.kind === 'unary' && right.operator === '-') {
65
+ return this.transform({
66
+ kind: 'binary',
67
+ operator: '+',
68
+ left,
69
+ right: right.operand
70
+ });
71
+ }
72
+ // (-a) - (-b) → b - a
73
+ if (left.kind === 'unary' && left.operator === '-' &&
74
+ right.kind === 'unary' && right.operator === '-') {
75
+ return this.transform({
76
+ kind: 'binary',
77
+ operator: '-',
78
+ left: right.operand,
79
+ right: left.operand
80
+ });
81
+ }
60
82
  }
61
83
  // Multiplication rules
62
84
  if (expr.operator === '*') {
@@ -68,6 +90,50 @@ class Simplifier extends ExpressionTransformer {
68
90
  return right;
69
91
  if (rightNum === 1)
70
92
  return left;
93
+ // -1 * x → -x
94
+ if (leftNum === -1) {
95
+ return { kind: 'unary', operator: '-', operand: right };
96
+ }
97
+ // x * -1 → -x
98
+ if (rightNum === -1) {
99
+ return { kind: 'unary', operator: '-', operand: left };
100
+ }
101
+ // (-a) * (-b) → a * b
102
+ if (left.kind === 'unary' && left.operator === '-' &&
103
+ right.kind === 'unary' && right.operator === '-') {
104
+ return this.transform({
105
+ kind: 'binary',
106
+ operator: '*',
107
+ left: left.operand,
108
+ right: right.operand
109
+ });
110
+ }
111
+ // (-a) * b → -(a * b)
112
+ if (left.kind === 'unary' && left.operator === '-') {
113
+ return {
114
+ kind: 'unary',
115
+ operator: '-',
116
+ operand: this.transform({
117
+ kind: 'binary',
118
+ operator: '*',
119
+ left: left.operand,
120
+ right
121
+ })
122
+ };
123
+ }
124
+ // a * (-b) → -(a * b)
125
+ if (right.kind === 'unary' && right.operator === '-') {
126
+ return {
127
+ kind: 'unary',
128
+ operator: '-',
129
+ operand: this.transform({
130
+ kind: 'binary',
131
+ operator: '*',
132
+ left,
133
+ right: right.operand
134
+ })
135
+ };
136
+ }
71
137
  // (x / x) * y → y
72
138
  if (left.kind === 'binary' && left.operator === '/') {
73
139
  if (expressionsEqual(left.left, left.right)) {
@@ -116,6 +182,89 @@ class Simplifier extends ExpressionTransformer {
116
182
  if (expressionsEqual(left, right)) {
117
183
  return { kind: 'number', value: 1 };
118
184
  }
185
+ // (-a) / (-b) → a / b
186
+ if (left.kind === 'unary' && left.operator === '-' &&
187
+ right.kind === 'unary' && right.operator === '-') {
188
+ return this.transform({
189
+ kind: 'binary',
190
+ operator: '/',
191
+ left: left.operand,
192
+ right: right.operand
193
+ });
194
+ }
195
+ // (-a) / b → -(a / b)
196
+ if (left.kind === 'unary' && left.operator === '-') {
197
+ return {
198
+ kind: 'unary',
199
+ operator: '-',
200
+ operand: this.transform({
201
+ kind: 'binary',
202
+ operator: '/',
203
+ left: left.operand,
204
+ right
205
+ })
206
+ };
207
+ }
208
+ // a / (-b) → -(a / b)
209
+ if (right.kind === 'unary' && right.operator === '-') {
210
+ return {
211
+ kind: 'unary',
212
+ operator: '-',
213
+ operand: this.transform({
214
+ kind: 'binary',
215
+ operator: '/',
216
+ left,
217
+ right: right.operand
218
+ })
219
+ };
220
+ }
221
+ // (a + a) / 2 → a
222
+ if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
223
+ if (expressionsEqual(left.left, left.right)) {
224
+ return left.left;
225
+ }
226
+ }
227
+ // (a + a) / (2 * b) → a / b
228
+ if (right.kind === 'binary' && right.operator === '*') {
229
+ const rightLeft = right.left;
230
+ const rightRight = right.right;
231
+ const rightLeftNum = isNumber(rightLeft) ? rightLeft.value : null;
232
+ if (rightLeftNum === 2 && left.kind === 'binary' && left.operator === '+') {
233
+ if (expressionsEqual(left.left, left.right)) {
234
+ return {
235
+ kind: 'binary',
236
+ operator: '/',
237
+ left: left.left,
238
+ right: rightRight
239
+ };
240
+ }
241
+ // (-1 * a + a * -1) / (2 * b) → -a / b
242
+ const leftLeft = left.left;
243
+ const leftRight = left.right;
244
+ if (leftLeft.kind === 'binary' && leftLeft.operator === '*' &&
245
+ leftRight.kind === 'binary' && leftRight.operator === '*') {
246
+ const ll_left = leftLeft.left;
247
+ const ll_right = leftLeft.right;
248
+ const lr_left = leftRight.left;
249
+ const lr_right = leftRight.right;
250
+ const ll_leftNum = isNumber(ll_left) ? ll_left.value : null;
251
+ const lr_rightNum = isNumber(lr_right) ? lr_right.value : null;
252
+ // (-1 * a) + (a * -1)
253
+ if (ll_leftNum === -1 && lr_rightNum === -1 && expressionsEqual(ll_right, lr_left)) {
254
+ return {
255
+ kind: 'unary',
256
+ operator: '-',
257
+ operand: {
258
+ kind: 'binary',
259
+ operator: '/',
260
+ left: ll_right,
261
+ right: rightRight
262
+ }
263
+ };
264
+ }
265
+ }
266
+ }
267
+ }
119
268
  }
120
269
  // Power rules
121
270
  if (expr.operator === '^' || expr.operator === '**') {
@@ -274,3 +423,37 @@ export function simplifyGradients(gradients) {
274
423
  }
275
424
  return simplified;
276
425
  }
/**
 * Post-CSE simplification.
 *
 * Applies rewrites that the main Simplifier intentionally skips so that it
 * does not interfere with CSE. Currently the only rewrite is `a + a → 2 * a`,
 * which is safe once CSE has already extracted shared temporaries.
 *
 * @param expr - Expression to simplify.
 * @returns The rewritten expression.
 */
export function simplifyPostCSE(expr) {
    const simplifier = new PostCSESimplifier();
    return simplifier.transform(expr);
}
/**
 * Transformer that implements the post-CSE rewrites.
 *
 * Only binary nodes are overridden here; every other node kind goes through
 * ExpressionTransformer's inherited handling. When neither operand of a
 * binary node changes, the original node object is returned as-is.
 */
class PostCSESimplifier extends ExpressionTransformer {
    visitBinaryOp(expr) {
        // Rewrite operands first so equality is checked on their simplified form.
        const newLeft = this.transform(expr.left);
        const newRight = this.transform(expr.right);
        const isSelfSum = expr.operator === '+' && expressionsEqual(newLeft, newRight);
        if (isSelfSum) {
            // a + a → 2 * a
            return {
                kind: 'binary',
                operator: '*',
                left: { kind: 'number', value: 2 },
                right: newLeft
            };
        }
        const unchanged = newLeft === expr.left && newRight === expr.right;
        return unchanged
            ? expr
            : { kind: 'binary', operator: expr.operator, left: newLeft, right: newRight };
    }
}
@@ -0,0 +1,23 @@
/**
 * Type declarations for converting AST Expressions into an E-Graph.
 */
import type { EGraph } from './EGraph.js';
import type { EClassId } from './ENode.js';
import type { Expression } from '../AST.js';
/**
 * Add an AST expression, including all of its sub-expressions, to the
 * e-graph.
 *
 * @returns The e-class ID of the expression's root.
 */
export declare function addExpression(egraph: EGraph, expr: Expression): EClassId;
/**
 * Add every expression in `expressions` to the e-graph.
 *
 * @returns A map from the same keys to the e-class IDs of their expressions.
 */
export declare function addExpressions<K extends string>(egraph: EGraph, expressions: Map<K, Expression>): Map<K, EClassId>;
/**
 * Add all gradient component expressions to the e-graph.
 *
 * @param gradients - Map<paramName, Map<component, Expression>>
 * @returns Map<paramName, Map<component, EClassId>>
 */
export declare function addGradients(egraph: EGraph, gradients: Map<string, Map<string, Expression>>): Map<string, Map<string, EClassId>>;
/**
 * Flatten a gradient ID structure into the list of all its root e-class IDs,
 * in map iteration order.
 */
export declare function getRootIds(gradientIds: Map<string, Map<string, EClassId>>): EClassId[];
@@ -0,0 +1,84 @@
1
+ /**
2
+ * Conversion between AST Expressions and E-Graph
3
+ */
/**
 * E-node tag for each supported binary operator.
 * `^` and `**` are both exponentiation and share the `pow` tag.
 */
const BINARY_OPERATOR_TAGS = new Map([
    ['+', 'add'],
    ['-', 'sub'],
    ['*', 'mul'],
    ['/', 'div'],
    ['^', 'pow'],
    ['**', 'pow']
]);
/**
 * Add an AST Expression to the e-graph, returning its e-class ID.
 *
 * Operands are added before their parent (post-order), so every e-node
 * refers to the e-class IDs of already-inserted children. Unary `-` becomes
 * a `neg` node; any other unary operator is treated as identity and yields
 * the operand's e-class.
 *
 * @param egraph - E-graph to insert into (must provide `add(node)`).
 * @param expr - Expression to add.
 * @returns The e-class ID of `expr`.
 * @throws Error if `expr` has an unsupported kind or binary operator.
 */
export function addExpression(egraph, expr) {
    switch (expr.kind) {
        case 'number':
            return egraph.add({ tag: 'num', value: expr.value });
        case 'variable':
            return egraph.add({ tag: 'var', name: expr.name });
        case 'binary': {
            // Validate before recursing so an unsupported operator leaves no
            // orphaned child nodes in the e-graph. Previously an unknown
            // operator fell through into the 'unary' case and crashed on
            // the missing `operand`.
            const tag = BINARY_OPERATOR_TAGS.get(expr.operator);
            if (tag === undefined) {
                throw new Error(`addExpression: unsupported binary operator '${expr.operator}'`);
            }
            const left = addExpression(egraph, expr.left);
            const right = addExpression(egraph, expr.right);
            return egraph.add({ tag, children: [left, right] });
        }
        case 'unary': {
            const operand = addExpression(egraph, expr.operand);
            if (expr.operator === '-') {
                return egraph.add({ tag: 'neg', child: operand });
            }
            // Unary + is identity
            return operand;
        }
        case 'call': {
            const args = expr.args.map(arg => addExpression(egraph, arg));
            return egraph.add({ tag: 'call', name: expr.name, children: args });
        }
        case 'component': {
            const object = addExpression(egraph, expr.object);
            return egraph.add({ tag: 'component', object, field: expr.component });
        }
        default:
            // Returning undefined here would silently store an invalid
            // e-class ID in callers' maps.
            throw new Error(`addExpression: unsupported expression kind '${expr.kind}'`);
    }
}
/**
 * Add several expressions to the e-graph.
 *
 * @param egraph - E-graph to insert into.
 * @param expressions - Map from caller-chosen key to expression.
 * @returns A new Map from the same keys to the e-class IDs produced by
 *   addExpression, inserted in the input's iteration order.
 */
export function addExpressions(egraph, expressions) {
    const entries = Array.from(expressions, ([key, expression]) => [key, addExpression(egraph, expression)]);
    return new Map(entries);
}
/**
 * Add every gradient component expression to the e-graph.
 *
 * @param egraph - E-graph to insert into.
 * @param gradients - Map<paramName, Map<component, Expression>>
 * @returns Map<paramName, Map<component, EClassId>>, preserving the input's
 *   iteration order at both levels.
 */
export function addGradients(egraph, gradients) {
    const gradientIds = new Map();
    for (const [paramName, components] of gradients) {
        const componentIds = new Map(Array.from(components, ([component, expression]) => [component, addExpression(egraph, expression)]));
        gradientIds.set(paramName, componentIds);
    }
    return gradientIds;
}
/**
 * Flatten a gradient ID structure into a single list of root e-class IDs.
 *
 * @param gradientIds - Map<paramName, Map<component, EClassId>>
 * @returns Every component ID, in outer-then-inner map iteration order.
 */
export function getRootIds(gradientIds) {
    return [...gradientIds.values()].flatMap((componentIds) => [...componentIds.values()]);
}
@@ -0,0 +1,93 @@
/**
 * E-Graph: equality graph for expression optimization.
 *
 * An e-graph compactly represents equivalence classes of expressions and
 * supports:
 * - adding expressions (yielding an e-class ID),
 * - merging e-classes (union),
 * - finding the canonical e-class of an ID (find),
 * - rebuilding after merges to maintain congruence.
 */
import type { ENode, EClassId } from './ENode.js';
/**
 * An e-class: one equivalence class of expressions.
 */
export interface EClass {
    /** Identifier of this e-class. */
    id: EClassId;
    /** String keys of the e-nodes in this class (presumably the keys accepted by `getNodeByKey`; confirm in EGraph.js). */
    nodes: Set<string>;
    /** String keys of parent e-nodes (presumably nodes that use this class as a child; confirm in EGraph.js). */
    parents: Set<string>;
}
/**
 * The e-graph data structure.
 */
export declare class EGraph {
    private nextId;
    private classes;
    private parent;
    private rank;
    private hashcons;
    private nodeStore;
    private pending;
    /** Return the canonical e-class ID for `id`, applying path compression. */
    find(id: EClassId): EClassId;
    /**
     * Add an e-node and return the ID of its e-class.
     * If the node already exists, its existing e-class is returned.
     */
    add(node: ENode): EClassId;
    /** Union two e-classes and return the resulting canonical ID. */
    merge(id1: EClassId, id2: EClassId): EClassId;
    /**
     * Restore the e-graph's congruence invariants.
     * Must be called after a batch of merges.
     */
    rebuild(): void;
    /** Repair a single e-class after merges. */
    private repair;
    /** Locate the e-class that contains a node, given the node's key. */
    private findClassForNode;
    /** Rewrite an e-node's children to their canonical e-class IDs. */
    private canonicalize;
    /** IDs of all e-classes. */
    getClassIds(): EClassId[];
    /** The e-class with the given ID, if any. */
    getClass(id: EClassId): EClass | undefined;
    /** All e-nodes belonging to the given e-class. */
    getNodes(classId: EClassId): ENode[];
    /** Number of e-classes. */
    get size(): number;
    /** The e-node stored under `key`, if any. */
    getNodeByKey(key: string): ENode | undefined;
    /** The e-class ID of `node` if it is already present. */
    lookup(node: ENode): EClassId | undefined;
    /** Debug helper: render the e-graph state as a string. */
    dump(): string;
    /** Render an e-node as a readable string. */
    private nodeToString;
}