gradient-script 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/README.md +3 -1
  2. package/dist/cli.js +80 -3
  3. package/dist/dsl/CodeGen.d.ts +1 -1
  4. package/dist/dsl/CodeGen.js +332 -74
  5. package/dist/dsl/ExpressionUtils.d.ts +8 -2
  6. package/dist/dsl/ExpressionUtils.js +34 -2
  7. package/dist/dsl/GradientChecker.d.ts +21 -0
  8. package/dist/dsl/GradientChecker.js +109 -23
  9. package/dist/dsl/Guards.d.ts +1 -1
  10. package/dist/dsl/Guards.js +14 -13
  11. package/dist/dsl/Inliner.d.ts +5 -0
  12. package/dist/dsl/Inliner.js +8 -0
  13. package/dist/dsl/Simplify.d.ts +7 -0
  14. package/dist/dsl/Simplify.js +136 -0
  15. package/dist/dsl/egraph/Convert.d.ts +23 -0
  16. package/dist/dsl/egraph/Convert.js +84 -0
  17. package/dist/dsl/egraph/EGraph.d.ts +93 -0
  18. package/dist/dsl/egraph/EGraph.js +292 -0
  19. package/dist/dsl/egraph/ENode.d.ts +63 -0
  20. package/dist/dsl/egraph/ENode.js +94 -0
  21. package/dist/dsl/egraph/Extractor.d.ts +49 -0
  22. package/dist/dsl/egraph/Extractor.js +1068 -0
  23. package/dist/dsl/egraph/Optimizer.d.ts +50 -0
  24. package/dist/dsl/egraph/Optimizer.js +88 -0
  25. package/dist/dsl/egraph/Pattern.d.ts +80 -0
  26. package/dist/dsl/egraph/Pattern.js +325 -0
  27. package/dist/dsl/egraph/Rewriter.d.ts +44 -0
  28. package/dist/dsl/egraph/Rewriter.js +131 -0
  29. package/dist/dsl/egraph/Rules.d.ts +44 -0
  30. package/dist/dsl/egraph/Rules.js +187 -0
  31. package/dist/dsl/egraph/index.d.ts +15 -0
  32. package/dist/dsl/egraph/index.js +21 -0
  33. package/package.json +1 -1
  34. package/dist/dsl/CSE.d.ts +0 -21
  35. package/dist/dsl/CSE.js +0 -168
  36. package/dist/symbolic/AST.d.ts +0 -113
  37. package/dist/symbolic/AST.js +0 -128
  38. package/dist/symbolic/CodeGen.d.ts +0 -35
  39. package/dist/symbolic/CodeGen.js +0 -280
  40. package/dist/symbolic/Parser.d.ts +0 -64
  41. package/dist/symbolic/Parser.js +0 -329
  42. package/dist/symbolic/Simplify.d.ts +0 -10
  43. package/dist/symbolic/Simplify.js +0 -244
  44. package/dist/symbolic/SymbolicDiff.d.ts +0 -35
  45. package/dist/symbolic/SymbolicDiff.js +0 -339
@@ -174,8 +174,8 @@ export function expressionDepth(expr) {
174
174
  }
175
175
  }
176
176
  /**
177
- * Serializes an expression to canonical string representation.
178
- * Used for expression comparison and hashing (CSE, CodeGen forward reuse).
177
+ * Serializes an expression to structural string representation.
178
+ * Used for exact expression comparison - operand order matters.
179
179
  *
180
180
  * This ensures consistent string representation of expressions across different
181
181
  * parts of the codebase.
@@ -197,3 +197,35 @@ export function serializeExpression(expr) {
197
197
  return `comp(${serializeExpression(expr.object)},${expr.component})`;
198
198
  }
199
199
  }
200
/**
 * Produces a canonical string key for an expression, intended for CSE matching.
 *
 * Unlike the structural serializer, the two operands of a commutative binary
 * operator ('+' or '*') are emitted in string-comparison order, so `a*b` and
 * `b*a` yield the same key. Only the immediate operand pair is reordered;
 * nested sums/products are not flattened, so `(a+b)+c` and `a+(b+c)` still
 * produce different keys.
 *
 * @param expr - Expression node to serialize.
 * @returns The canonical key, or undefined for an unrecognized `kind`.
 */
export function serializeCanonical(expr) {
    const kind = expr.kind;
    if (kind === 'number') {
        return `num(${expr.value})`;
    }
    if (kind === 'variable') {
        return `var(${expr.name})`;
    }
    if (kind === 'binary') {
        const op = expr.operator;
        let lhs = serializeCanonical(expr.left);
        let rhs = serializeCanonical(expr.right);
        const commutative = op === '+' || op === '*';
        // Put the smaller key first; equal keys keep their original order.
        if (commutative && !(lhs <= rhs)) {
            [lhs, rhs] = [rhs, lhs];
        }
        return `bin(${op},${lhs},${rhs})`;
    }
    if (kind === 'unary') {
        const inner = serializeCanonical(expr.operand);
        return `un(${expr.operator},${inner})`;
    }
    if (kind === 'call') {
        const argKeys = [];
        for (const arg of expr.args) {
            argKeys.push(serializeCanonical(arg));
        }
        return `call(${expr.name},${argKeys.join(',')})`;
    }
    if (kind === 'component') {
        const objectKey = serializeCanonical(expr.object);
        return `comp(${objectKey},${expr.component})`;
    }
}
@@ -17,9 +17,25 @@ type NumValue = number | {
17
17
  export interface GradCheckResult {
18
18
  passed: boolean;
19
19
  errors: GradCheckError[];
20
+ singularities: GradCheckSingularity[];
20
21
  maxError: number;
21
22
  meanError: number;
23
+ totalChecks: number;
22
24
  }
25
+ /**
26
+ * Singularity detected during gradient checking
27
+ * When both analytical and numerical produce NaN/Inf, it's a singularity, not a bug
28
+ */
29
+ export interface GradCheckSingularity {
30
+ parameter: string;
31
+ component?: string;
32
+ analytical: number;
33
+ numerical: number;
34
+ }
35
+ /**
36
+ * Format gradient check results as a human-readable string
37
+ */
38
+ export declare function formatGradCheckResult(result: GradCheckResult, funcName: string): string;
23
39
  export interface GradCheckError {
24
40
  parameter: string;
25
41
  component?: string;
@@ -39,6 +55,11 @@ export declare class GradientChecker {
39
55
  * Check gradients for a function
40
56
  */
41
57
  check(func: FunctionDef, gradients: GradientResult, env: TypeEnv, testPoint: Map<string, NumValue>): GradCheckResult;
58
+ /**
59
+ * Compare analytical and numerical gradients
60
+ * Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
61
+ */
62
+ private compareGradients;
42
63
  /**
43
64
  * Compute numerical gradient for scalar parameter using finite differences
44
65
  */
@@ -4,6 +4,47 @@
4
4
  */
5
5
  import { Types } from './Types.js';
6
6
  import { expandBuiltIn, shouldExpand } from './Expander.js';
7
/**
 * Renders a gradient-check result as human-readable text.
 *
 * Passing result: one summary line with the number of verified checks
 * (total minus errors minus singularities) and `maxError`. A warning line is
 * appended when singularities were recorded.
 *
 * Failing result: a header with the failure count, then one line per
 * parameter in first-seen order. A parameter with a single component-less
 * error shows analytical/numerical/error values. Otherwise every failing
 * component is listed with its absolute error.
 *
 * @param result - Output of GradientChecker.check().
 * @param funcName - Function name used as the report prefix.
 * @returns Newline-separated report (no trailing newline).
 */
export function formatGradCheckResult(result, funcName) {
    const singular = result.singularities.length;
    if (!result.passed) {
        const report = [`✗ ${funcName}: ${result.errors.length}/${result.totalChecks} gradients FAILED`];
        // Bucket failures by parameter; Map preserves first-seen order.
        const grouped = result.errors.reduce((acc, err) => {
            const bucket = acc.get(err.parameter);
            if (bucket) {
                bucket.push(err);
            }
            else {
                acc.set(err.parameter, [err]);
            }
            return acc;
        }, new Map());
        for (const [param, errs] of grouped) {
            const first = errs[0];
            const scalarLike = errs.length === 1 && !first.component;
            report.push(scalarLike
                ? ` ${param}: analytical=${first.analytical.toFixed(6)}, numerical=${first.numerical.toFixed(6)}, error=${first.error.toExponential(2)}`
                : ` ${param}: {${errs.map(e => `${e.component}:${e.error.toExponential(1)}`).join(', ')}}`);
        }
        if (singular > 0) {
            report.push(` ⚠ ${singular} singularities also detected`);
        }
        return report.join('\n');
    }
    const verified = result.totalChecks - result.errors.length - singular;
    const summary = `✓ ${funcName}: ${verified} gradients verified (max error: ${result.maxError.toExponential(2)})`;
    return singular > 0
        ? `${summary}\n ⚠ ${singular} singularities detected (both analytical and numerical produce NaN/Inf)`
        : summary;
}
7
48
  /**
8
49
  * Gradient checker
9
50
  */
@@ -19,6 +60,9 @@ export class GradientChecker {
19
60
  */
20
61
  check(func, gradients, env, testPoint) {
21
62
  const errors = [];
63
+ const singularities = [];
64
+ let totalChecks = 0;
65
+ let maxError = 0;
22
66
  // For each parameter that has gradients
23
67
  for (const [paramName, gradient] of gradients.gradients.entries()) {
24
68
  const paramType = env.getOrThrow(paramName);
@@ -28,21 +72,21 @@ export class GradientChecker {
28
72
  }
29
73
  if (Types.isScalar(paramType)) {
30
74
  // Scalar parameter
75
+ totalChecks++;
31
76
  if (typeof paramValue !== 'number') {
32
77
  throw new Error(`Expected scalar value for ${paramName}`);
33
78
  }
34
79
  const analytical = this.evaluateExpression(gradient, testPoint);
35
80
  const numerical = this.numericalGradientScalar(func, testPoint, paramName);
36
- const error = Math.abs(analytical - numerical);
37
- const relativeError = Math.abs(error / (numerical + 1e-10));
38
- if (error > this.tolerance && relativeError > this.tolerance) {
39
- errors.push({
40
- parameter: paramName,
41
- analytical,
42
- numerical,
43
- error,
44
- relativeError
45
- });
81
+ const checkResult = this.compareGradients(analytical, numerical, paramName);
82
+ if (checkResult.type === 'singularity') {
83
+ singularities.push(checkResult.singularity);
84
+ }
85
+ else if (checkResult.type === 'error') {
86
+ errors.push(checkResult.error);
87
+ }
88
+ else if (checkResult.error) {
89
+ maxError = Math.max(maxError, checkResult.error.error);
46
90
  }
47
91
  }
48
92
  else {
@@ -52,32 +96,74 @@ export class GradientChecker {
52
96
  }
53
97
  const structGrad = gradient;
54
98
  for (const [comp, expr] of structGrad.components.entries()) {
99
+ totalChecks++;
55
100
  const analytical = this.evaluateExpression(expr, testPoint);
56
101
  const numerical = this.numericalGradientComponent(func, testPoint, paramName, comp);
57
- const error = Math.abs(analytical - numerical);
58
- const relativeError = Math.abs(error / (numerical + 1e-10));
59
- if (error > this.tolerance && relativeError > this.tolerance) {
60
- errors.push({
61
- parameter: paramName,
62
- component: comp,
63
- analytical,
64
- numerical,
65
- error,
66
- relativeError
67
- });
102
+ const checkResult = this.compareGradients(analytical, numerical, paramName, comp);
103
+ if (checkResult.type === 'singularity') {
104
+ singularities.push(checkResult.singularity);
105
+ }
106
+ else if (checkResult.type === 'error') {
107
+ errors.push(checkResult.error);
108
+ }
109
+ else if (checkResult.error) {
110
+ maxError = Math.max(maxError, checkResult.error.error);
68
111
  }
69
112
  }
70
113
  }
71
114
  }
72
- const maxError = errors.length > 0 ? Math.max(...errors.map(e => e.error)) : 0;
73
115
  const meanError = errors.length > 0
74
116
  ? errors.reduce((sum, e) => sum + e.error, 0) / errors.length
75
117
  : 0;
76
118
  return {
77
119
  passed: errors.length === 0,
78
120
  errors,
121
+ singularities,
79
122
  maxError,
80
- meanError
123
+ meanError,
124
+ totalChecks
125
+ };
126
+ }
127
+ /**
128
+ * Compare analytical and numerical gradients
129
+ * Distinguishes between: pass, error (mismatch), and singularity (both NaN/Inf)
130
+ */
131
+ compareGradients(analytical, numerical, parameter, component) {
132
+ const analyticalBad = !isFinite(analytical);
133
+ const numericalBad = !isFinite(numerical);
134
+ // Both produce NaN/Inf: singularity (not a bug)
135
+ if (analyticalBad && numericalBad) {
136
+ return {
137
+ type: 'singularity',
138
+ singularity: { parameter, component, analytical, numerical }
139
+ };
140
+ }
141
+ // One produces NaN/Inf, the other doesn't: actual bug
142
+ if (analyticalBad || numericalBad) {
143
+ return {
144
+ type: 'error',
145
+ error: {
146
+ parameter,
147
+ component,
148
+ analytical,
149
+ numerical,
150
+ error: Infinity,
151
+ relativeError: Infinity
152
+ }
153
+ };
154
+ }
155
+ // Both finite: compare values
156
+ const error = Math.abs(analytical - numerical);
157
+ const relativeError = Math.abs(error / (Math.abs(numerical) + 1e-10));
158
+ if (error > this.tolerance && relativeError > this.tolerance) {
159
+ return {
160
+ type: 'error',
161
+ error: { parameter, component, analytical, numerical, error, relativeError }
162
+ };
163
+ }
164
+ return {
165
+ type: 'pass',
166
+ error: { parameter, component, analytical, numerical, error, relativeError }
81
167
  };
82
168
  }
83
169
  /**
@@ -22,7 +22,7 @@ export declare function analyzeGuards(func: FunctionDef): GuardAnalysisResult;
22
22
  /**
23
23
  * Format guard analysis for display
24
24
  */
25
- export declare function formatGuardWarnings(result: GuardAnalysisResult): string;
25
+ export declare function formatGuardWarnings(result: GuardAnalysisResult, asComments?: boolean): string;
26
26
  /**
27
27
  * Generate guard code snippets for common cases
28
28
  */
@@ -154,21 +154,22 @@ function isSqExpression(expr) {
154
154
  /**
155
155
  * Format guard analysis for display
156
156
  */
157
- export function formatGuardWarnings(result) {
157
+ export function formatGuardWarnings(result, asComments = false) {
158
158
  if (!result.hasIssues) {
159
159
  return '';
160
160
  }
161
+ const prefix = asComments ? '// ' : '';
161
162
  const lines = [];
162
- lines.push('⚠️ EDGE CASE WARNINGS:');
163
- lines.push('');
164
- lines.push('The generated code may encounter edge cases that produce');
165
- lines.push('NaN, Infinity, or incorrect gradients:');
166
- lines.push('');
163
+ lines.push(`${prefix}⚠️ EDGE CASE WARNINGS:`);
164
+ lines.push(prefix);
165
+ lines.push(`${prefix}The generated code may encounter edge cases that produce`);
166
+ lines.push(`${prefix}NaN, Infinity, or incorrect gradients:`);
167
+ lines.push(prefix);
167
168
  // Show each guard individually with context
168
169
  for (const guard of result.guards) {
169
170
  const typeLabel = formatGuardType(guard.type);
170
171
  // Show location and variable if available
171
- let location = ' •';
172
+ let location = `${prefix} •`;
172
173
  if (guard.line) {
173
174
  location += ` Line ${guard.line}:`;
174
175
  }
@@ -176,13 +177,13 @@ export function formatGuardWarnings(result) {
176
177
  location += ` ${guard.variableName} =`;
177
178
  }
178
179
  lines.push(location);
179
- lines.push(` ${typeLabel}: ${guard.description}`);
180
- lines.push(` 💡 Fix: ${guard.suggestion}`);
181
- lines.push('');
180
+ lines.push(`${prefix} ${typeLabel}: ${guard.description}`);
181
+ lines.push(`${prefix} 💡 Fix: ${guard.suggestion}`);
182
+ lines.push(prefix);
182
183
  }
183
- lines.push('Add runtime checks or ensure inputs are within valid ranges.');
184
- lines.push('Use --guards --epsilon 1e-10 to automatically emit epsilon guards.');
185
- lines.push('');
184
+ lines.push(`${prefix}Add runtime checks or ensure inputs are within valid ranges.`);
185
+ lines.push(`${prefix}Use --guards --epsilon 1e-10 to automatically emit epsilon guards.`);
186
+ lines.push(prefix);
186
187
  return lines.join('\n');
187
188
  }
188
189
  function formatGuardType(type) {
@@ -8,3 +8,8 @@ import { Expression, FunctionDef } from './AST.js';
8
8
  * Returns a new expression with all intermediate variables substituted
9
9
  */
10
10
  export declare function inlineIntermediateVariables(func: FunctionDef): Expression;
11
/**
 * Applies a substitution map to an expression and returns the transformed
 * expression. Used to get the fully-expanded form of forward-pass expressions.
 *
 * @param expr - Expression to rewrite.
 * @param substitutions - Replacement expressions keyed by name (presumably
 *   variable names; confirm against VariableSubstitutionTransformer).
 * @returns The expression after substitution.
 */
export declare function inlineExpression(expr: Expression, substitutions: Map<string, Expression>): Expression;
@@ -39,3 +39,11 @@ export function inlineIntermediateVariables(func) {
39
39
  const transformer = new VariableSubstitutionTransformer(substitutions);
40
40
  return transformer.transform(func.returnExpr);
41
41
  }
42
/**
 * Runs `expr` through a VariableSubstitutionTransformer built from
 * `substitutions` and returns the result. Used to get the fully-expanded form
 * of forward-pass expressions.
 */
export function inlineExpression(expr, substitutions) {
    return new VariableSubstitutionTransformer(substitutions).transform(expr);
}
@@ -15,3 +15,10 @@ export declare function simplifyGradients(gradients: Map<string, Expression | {
15
15
  }>): Map<string, Expression | {
16
16
  components: Map<string, Expression>;
17
17
  }>;
18
/**
 * Simplification pass meant to run after common-subexpression elimination.
 *
 * It applies rewrites that the initial simplifier skips on purpose so they
 * do not interfere with CSE. Currently the only rewrite is `a + a → 2 * a`.
 *
 * @param expr - Expression to simplify.
 * @returns The simplified expression.
 */
export declare function simplifyPostCSE(expr: Expression): Expression;
@@ -46,6 +46,9 @@ class Simplifier extends ExpressionTransformer {
46
46
  return right;
47
47
  if (rightNum === 0)
48
48
  return left;
49
+ // Note: a + a → 2 * a rules are intentionally NOT applied here
50
+ // because they flatten expression structure and interfere with CSE.
51
+ // The CSE pass will extract common subexpressions instead.
49
52
  }
50
53
  // Subtraction rules
51
54
  if (expr.operator === '-') {
@@ -57,6 +60,25 @@ class Simplifier extends ExpressionTransformer {
57
60
  if (expressionsEqual(left, right)) {
58
61
  return { kind: 'number', value: 0 };
59
62
  }
63
+ // a - (-b) → a + b
64
+ if (right.kind === 'unary' && right.operator === '-') {
65
+ return this.transform({
66
+ kind: 'binary',
67
+ operator: '+',
68
+ left,
69
+ right: right.operand
70
+ });
71
+ }
72
+ // (-a) - (-b) → b - a
73
+ if (left.kind === 'unary' && left.operator === '-' &&
74
+ right.kind === 'unary' && right.operator === '-') {
75
+ return this.transform({
76
+ kind: 'binary',
77
+ operator: '-',
78
+ left: right.operand,
79
+ right: left.operand
80
+ });
81
+ }
60
82
  }
61
83
  // Multiplication rules
62
84
  if (expr.operator === '*') {
@@ -68,6 +90,50 @@ class Simplifier extends ExpressionTransformer {
68
90
  return right;
69
91
  if (rightNum === 1)
70
92
  return left;
93
+ // -1 * x → -x
94
+ if (leftNum === -1) {
95
+ return { kind: 'unary', operator: '-', operand: right };
96
+ }
97
+ // x * -1 → -x
98
+ if (rightNum === -1) {
99
+ return { kind: 'unary', operator: '-', operand: left };
100
+ }
101
+ // (-a) * (-b) → a * b
102
+ if (left.kind === 'unary' && left.operator === '-' &&
103
+ right.kind === 'unary' && right.operator === '-') {
104
+ return this.transform({
105
+ kind: 'binary',
106
+ operator: '*',
107
+ left: left.operand,
108
+ right: right.operand
109
+ });
110
+ }
111
+ // (-a) * b → -(a * b)
112
+ if (left.kind === 'unary' && left.operator === '-') {
113
+ return {
114
+ kind: 'unary',
115
+ operator: '-',
116
+ operand: this.transform({
117
+ kind: 'binary',
118
+ operator: '*',
119
+ left: left.operand,
120
+ right
121
+ })
122
+ };
123
+ }
124
+ // a * (-b) → -(a * b)
125
+ if (right.kind === 'unary' && right.operator === '-') {
126
+ return {
127
+ kind: 'unary',
128
+ operator: '-',
129
+ operand: this.transform({
130
+ kind: 'binary',
131
+ operator: '*',
132
+ left,
133
+ right: right.operand
134
+ })
135
+ };
136
+ }
71
137
  // (x / x) * y → y
72
138
  if (left.kind === 'binary' && left.operator === '/') {
73
139
  if (expressionsEqual(left.left, left.right)) {
@@ -116,6 +182,42 @@ class Simplifier extends ExpressionTransformer {
116
182
  if (expressionsEqual(left, right)) {
117
183
  return { kind: 'number', value: 1 };
118
184
  }
185
+ // (-a) / (-b) → a / b
186
+ if (left.kind === 'unary' && left.operator === '-' &&
187
+ right.kind === 'unary' && right.operator === '-') {
188
+ return this.transform({
189
+ kind: 'binary',
190
+ operator: '/',
191
+ left: left.operand,
192
+ right: right.operand
193
+ });
194
+ }
195
+ // (-a) / b → -(a / b)
196
+ if (left.kind === 'unary' && left.operator === '-') {
197
+ return {
198
+ kind: 'unary',
199
+ operator: '-',
200
+ operand: this.transform({
201
+ kind: 'binary',
202
+ operator: '/',
203
+ left: left.operand,
204
+ right
205
+ })
206
+ };
207
+ }
208
+ // a / (-b) → -(a / b)
209
+ if (right.kind === 'unary' && right.operator === '-') {
210
+ return {
211
+ kind: 'unary',
212
+ operator: '-',
213
+ operand: this.transform({
214
+ kind: 'binary',
215
+ operator: '/',
216
+ left,
217
+ right: right.operand
218
+ })
219
+ };
220
+ }
119
221
  // (a + a) / 2 → a
120
222
  if (rightNum === 2 && left.kind === 'binary' && left.operator === '+') {
121
223
  if (expressionsEqual(left.left, left.right)) {
@@ -321,3 +423,37 @@ export function simplifyGradients(gradients) {
321
423
  }
322
424
  return simplified;
323
425
  }
426
/**
 * Post-CSE simplification pass.
 *
 * Applies rewrites that the main Simplifier deliberately skips so they do not
 * interfere with common-subexpression elimination. Currently the only rule is
 * `a + a → 2 * a`, applied bottom-up with operands compared by expressionsEqual.
 *
 * @param expr - Expression to simplify (typically after CSE extracted temporaries).
 * @returns The simplified expression.
 */
export function simplifyPostCSE(expr) {
    const simplifier = new PostCSESimplifier();
    return simplifier.transform(expr);
}
class PostCSESimplifier extends ExpressionTransformer {
    /**
     * Simplifies both children first, then folds `x + x` into `2 * x`.
     * When neither child changed, the original node object is returned, so
     * untouched binary subtrees keep their identity instead of being copied.
     */
    visitBinaryOp(expr) {
        const left = this.transform(expr.left);
        const right = this.transform(expr.right);
        const isDoubling = expr.operator === '+' && expressionsEqual(left, right);
        if (isDoubling) {
            const two = { kind: 'number', value: 2 };
            return { kind: 'binary', operator: '*', left: two, right: left };
        }
        const childrenUnchanged = left === expr.left && right === expr.right;
        if (childrenUnchanged) {
            return expr;
        }
        return { kind: 'binary', operator: expr.operator, left, right };
    }
}
@@ -0,0 +1,23 @@
1
+ /**
2
+ * Conversion between AST Expressions and E-Graph
3
+ */
4
+ import { EGraph } from './EGraph.js';
5
+ import { EClassId } from './ENode.js';
6
+ import { Expression } from '../AST.js';
7
/**
 * Inserts an AST expression into the e-graph and returns the e-class ID of its root.
 */
export declare function addExpression(egraph: EGraph, expr: Expression): EClassId;
/**
 * Inserts every expression of `expressions` and returns a map from each key
 * to the e-class ID of that expression.
 */
export declare function addExpressions<K extends string>(egraph: EGraph, expressions: Map<K, Expression>): Map<K, EClassId>;
/**
 * Inserts a gradient table shaped Map<paramName, Map<component, Expression>>.
 * Returns the same shape with an e-class ID in place of each expression.
 */
export declare function addGradients(egraph: EGraph, gradients: Map<string, Map<string, Expression>>): Map<string, Map<string, EClassId>>;
/**
 * Flattens a gradient ID table into a list of root e-class IDs, in map iteration order.
 */
export declare function getRootIds(gradientIds: Map<string, Map<string, EClassId>>): EClassId[];
@@ -0,0 +1,84 @@
1
+ /**
2
+ * Conversion between AST Expressions and E-Graph
3
+ */
4
/**
 * Adds an AST Expression to the e-graph bottom-up and returns its e-class ID.
 *
 * Mapping from AST node to e-node:
 *  - number    → { tag: 'num', value }
 *  - variable  → { tag: 'var', name }
 *  - binary    → '+' add, '-' sub, '*' mul, '/' div, '^' and '**' pow,
 *                with children [left, right]
 *  - unary '-' → { tag: 'neg', child }; unary '+' is the identity and adds no node
 *  - call      → { tag: 'call', name, children }
 *  - component → { tag: 'component', object, field }
 *
 * @param egraph - Target e-graph; only its `add` method is used.
 * @param expr - Expression to insert.
 * @returns The e-class ID for `expr`.
 * @throws Error for an unsupported expression kind or operator.
 */
export function addExpression(egraph, expr) {
    switch (expr.kind) {
        case 'number':
            return egraph.add({ tag: 'num', value: expr.value });
        case 'variable':
            return egraph.add({ tag: 'var', name: expr.name });
        case 'binary': {
            // Resolve the tag before adding children so an invalid operator
            // leaves no orphan nodes in the e-graph.
            let tag;
            switch (expr.operator) {
                case '+':
                    tag = 'add';
                    break;
                case '-':
                    tag = 'sub';
                    break;
                case '*':
                    tag = 'mul';
                    break;
                case '/':
                    tag = 'div';
                    break;
                case '^':
                case '**':
                    tag = 'pow';
                    break;
                default:
                    // Without this, control fell through into the 'unary' case
                    // and crashed reading `expr.operand` (undefined).
                    throw new Error(`addExpression: unsupported binary operator '${expr.operator}'`);
            }
            const left = addExpression(egraph, expr.left);
            const right = addExpression(egraph, expr.right);
            return egraph.add({ tag, children: [left, right] });
        }
        case 'unary': {
            if (expr.operator !== '-' && expr.operator !== '+') {
                // Previously any other unary operator was silently dropped.
                throw new Error(`addExpression: unsupported unary operator '${expr.operator}'`);
            }
            const operand = addExpression(egraph, expr.operand);
            if (expr.operator === '-') {
                return egraph.add({ tag: 'neg', child: operand });
            }
            // Unary + is identity
            return operand;
        }
        case 'call': {
            const args = expr.args.map(arg => addExpression(egraph, arg));
            return egraph.add({ tag: 'call', name: expr.name, children: args });
        }
        case 'component': {
            const object = addExpression(egraph, expr.object);
            return egraph.add({ tag: 'component', object, field: expr.component });
        }
        default:
            // Previously this returned undefined, which silently poisoned callers.
            throw new Error(`addExpression: unsupported expression kind '${expr.kind}'`);
    }
}
48
/**
 * Inserts each expression with addExpression, in iteration order.
 *
 * @param egraph - Target e-graph.
 * @param expressions - Map from caller-chosen keys to expressions.
 * @returns Map from the same keys to the resulting e-class IDs.
 */
export function addExpressions(egraph, expressions) {
    return new Map(Array.from(expressions, ([key, expr]) => [key, addExpression(egraph, expr)]));
}
58
/**
 * Inserts a whole gradient table into the e-graph.
 *
 * @param egraph - Target e-graph.
 * @param gradients - Map<paramName, Map<component, Expression>>.
 * @returns Map<paramName, Map<component, EClassId>> with the same keys and
 *   iteration order as the input.
 */
export function addGradients(egraph, gradients) {
    return new Map(Array.from(gradients, ([paramName, components]) => [
        paramName,
        new Map(Array.from(components, ([comp, expr]) => [comp, addExpression(egraph, expr)])),
    ]));
}
73
/**
 * Flattens a gradient ID table (Map<paramName, Map<component, EClassId>>)
 * into a single list of e-class IDs, parameter by parameter and component by
 * component in map iteration order.
 */
export function getRootIds(gradientIds) {
    return Array.from(gradientIds.values()).flatMap(components => Array.from(components.values()));
}