gradient-script 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/README.md +515 -0
  2. package/dist/cli.d.ts +2 -0
  3. package/dist/cli.js +136 -0
  4. package/dist/dsl/AST.d.ts +123 -0
  5. package/dist/dsl/AST.js +23 -0
  6. package/dist/dsl/BuiltIns.d.ts +58 -0
  7. package/dist/dsl/BuiltIns.js +181 -0
  8. package/dist/dsl/CSE.d.ts +21 -0
  9. package/dist/dsl/CSE.js +194 -0
  10. package/dist/dsl/CodeGen.d.ts +60 -0
  11. package/dist/dsl/CodeGen.js +474 -0
  12. package/dist/dsl/Differentiation.d.ts +45 -0
  13. package/dist/dsl/Differentiation.js +421 -0
  14. package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
  15. package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
  16. package/dist/dsl/Errors.d.ts +22 -0
  17. package/dist/dsl/Errors.js +49 -0
  18. package/dist/dsl/Expander.d.ts +13 -0
  19. package/dist/dsl/Expander.js +220 -0
  20. package/dist/dsl/ExpressionTransformer.d.ts +54 -0
  21. package/dist/dsl/ExpressionTransformer.js +102 -0
  22. package/dist/dsl/ExpressionUtils.d.ts +55 -0
  23. package/dist/dsl/ExpressionUtils.js +175 -0
  24. package/dist/dsl/GradientChecker.d.ts +71 -0
  25. package/dist/dsl/GradientChecker.js +258 -0
  26. package/dist/dsl/Guards.d.ts +27 -0
  27. package/dist/dsl/Guards.js +206 -0
  28. package/dist/dsl/Inliner.d.ts +10 -0
  29. package/dist/dsl/Inliner.js +40 -0
  30. package/dist/dsl/Lexer.d.ts +63 -0
  31. package/dist/dsl/Lexer.js +243 -0
  32. package/dist/dsl/Parser.d.ts +92 -0
  33. package/dist/dsl/Parser.js +328 -0
  34. package/dist/dsl/Simplify.d.ts +17 -0
  35. package/dist/dsl/Simplify.js +276 -0
  36. package/dist/dsl/TypeInference.d.ts +39 -0
  37. package/dist/dsl/TypeInference.js +147 -0
  38. package/dist/dsl/Types.d.ts +58 -0
  39. package/dist/dsl/Types.js +114 -0
  40. package/dist/index.d.ts +13 -0
  41. package/dist/index.js +11 -0
  42. package/dist/symbolic/AST.d.ts +113 -0
  43. package/dist/symbolic/AST.js +128 -0
  44. package/dist/symbolic/CodeGen.d.ts +35 -0
  45. package/dist/symbolic/CodeGen.js +280 -0
  46. package/dist/symbolic/Parser.d.ts +64 -0
  47. package/dist/symbolic/Parser.js +329 -0
  48. package/dist/symbolic/Simplify.d.ts +10 -0
  49. package/dist/symbolic/Simplify.js +244 -0
  50. package/dist/symbolic/SymbolicDiff.d.ts +35 -0
  51. package/dist/symbolic/SymbolicDiff.js +339 -0
  52. package/package.json +56 -0
@@ -0,0 +1,276 @@
1
+ /**
2
+ * Expression simplification for GradientScript DSL
3
+ * Applies algebraic simplification rules
4
+ */
5
+ import { ExpressionTransformer } from './ExpressionTransformer.js';
6
/**
 * Simplifier - applies algebraic simplification rules bottom-up.
 *
 * Every visit method first simplifies its children through `this.transform`
 * (inherited from ExpressionTransformer) and then tries local rewrite rules on
 * the rebuilt node. One pass is not guaranteed to reach a fixed point;
 * `simplify()` re-runs passes until the tree stops changing.
 */
class Simplifier extends ExpressionTransformer {
    // Stored by the constructor; none of the rules in this class read it.
    fixedPoint = false;
    /**
     * @param {boolean} [fixedPoint=false] - stored on the instance (not consulted by the rules)
     */
    constructor(fixedPoint = false) {
        super();
        this.fixedPoint = fixedPoint;
    }
    /**
     * Simplify a binary node: fold numeric constants, then apply identity,
     * annihilator and a few pattern-specific rewrites for each operator.
     * Rules such as `0 * x -> 0` and `x / x -> 1` are the usual algebraic
     * identities; they assume finite, non-zero operands where relevant.
     */
    visitBinaryOp(expr) {
        const left = this.transform(expr.left);
        const right = this.transform(expr.right);
        const leftNum = isNumber(left) ? left.value : null;
        const rightNum = isNumber(right) ? right.value : null;
        // Constant folding. An operator without a folding rule is not folded
        // (it is rebuilt unchanged at the end) instead of being turned into a
        // number node whose value is `undefined`.
        if (leftNum !== null && rightNum !== null) {
            let result;
            switch (expr.operator) {
                case '+':
                    result = leftNum + rightNum;
                    break;
                case '-':
                    result = leftNum - rightNum;
                    break;
                case '*':
                    result = leftNum * rightNum;
                    break;
                case '/':
                    result = leftNum / rightNum;
                    break;
                case '^':
                case '**':
                    result = Math.pow(leftNum, rightNum);
                    break;
            }
            if (result !== undefined) {
                return { kind: 'number', value: result };
            }
        }
        // Addition rules
        if (expr.operator === '+') {
            if (leftNum === 0)
                return right;
            if (rightNum === 0)
                return left;
        }
        // Subtraction rules
        if (expr.operator === '-') {
            if (rightNum === 0)
                return left;
            if (leftNum === 0) {
                return { kind: 'unary', operator: '-', operand: right };
            }
            if (expressionsEqual(left, right)) {
                return { kind: 'number', value: 0 };
            }
        }
        // Multiplication rules
        if (expr.operator === '*') {
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (rightNum === 0)
                return { kind: 'number', value: 0 };
            if (leftNum === 1)
                return right;
            if (rightNum === 1)
                return left;
            // (x / x) * y → y
            if (left.kind === 'binary' && left.operator === '/') {
                if (expressionsEqual(left.left, left.right)) {
                    return right;
                }
            }
            // 0.5 * (a + a) → a
            if (leftNum === 0.5 && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (expressionsEqual(l1, r1)) {
                    return l1;
                }
            }
            // 0.5 * (a*b + b*a) → a*b
            if (leftNum === 0.5 && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (l1.kind === 'binary' && l1.operator === '*' &&
                    r1.kind === 'binary' && r1.operator === '*') {
                    if (expressionsEqual(l1.left, r1.right) && expressionsEqual(l1.right, r1.left)) {
                        return l1;
                    }
                }
            }
            // c * (a*b + b*a) → (2c) * (a*b)
            if (leftNum !== null && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (l1.kind === 'binary' && l1.operator === '*' &&
                    r1.kind === 'binary' && r1.operator === '*') {
                    if (expressionsEqual(l1.left, r1.right) && expressionsEqual(l1.right, r1.left)) {
                        return {
                            kind: 'binary',
                            operator: '*',
                            left: { kind: 'number', value: 2 * leftNum },
                            right: l1
                        };
                    }
                }
            }
        }
        // Division rules
        if (expr.operator === '/') {
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (rightNum === 1)
                return left;
            if (expressionsEqual(left, right)) {
                return { kind: 'number', value: 1 };
            }
        }
        // Power rules (NOTE(review): `0 ^ x -> 0` assumes x > 0)
        if (expr.operator === '^' || expr.operator === '**') {
            if (rightNum === 0)
                return { kind: 'number', value: 1 };
            if (rightNum === 1)
                return left;
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (leftNum === 1)
                return { kind: 'number', value: 1 };
        }
        return {
            kind: 'binary',
            operator: expr.operator,
            left,
            right
        };
    }
    /**
     * Simplify a unary node: remove double negation, fold negated literals,
     * and drop unary plus.
     */
    visitUnaryOp(expr) {
        const operand = this.transform(expr.operand);
        if (expr.operator === '-') {
            // Double negation: --x = x
            if (operand.kind === 'unary' && operand.operator === '-') {
                return operand.operand;
            }
            // Fold a negated literal into a single number node
            if (isNumber(operand)) {
                return { kind: 'number', value: -operand.value };
            }
        }
        if (expr.operator === '+') {
            return operand;
        }
        return {
            kind: 'unary',
            operator: expr.operator,
            operand
        };
    }
    /**
     * Simplify a call node: fold `sqrt` of a non-negative literal and `abs`
     * of any literal; other calls are rebuilt with simplified arguments.
     */
    visitFunctionCall(expr) {
        const args = expr.args.map(arg => this.transform(arg));
        if (expr.name === 'sqrt' && args.length === 1) {
            const arg = args[0];
            if (isNumber(arg) && arg.value >= 0) {
                return { kind: 'number', value: Math.sqrt(arg.value) };
            }
        }
        if (expr.name === 'abs' && args.length === 1) {
            const arg = args[0];
            if (isNumber(arg)) {
                return { kind: 'number', value: Math.abs(arg.value) };
            }
        }
        return {
            kind: 'call',
            name: expr.name,
            args
        };
    }
    /**
     * Simplify a component access. Access on a binary expression is pushed
     * into the operands: (u + v).x -> u.x + v.x. Scalar operands of a
     * broadcast operation are left as they are, so (2 * v).x -> 2 * v.x
     * rather than the ill-formed (2).x * v.x.
     */
    visitComponentAccess(expr) {
        const object = this.transform(expr.object);
        if (object.kind === 'binary') {
            return this.transform({
                kind: 'binary',
                operator: object.operator,
                left: projectComponent(object.left, expr.component),
                right: projectComponent(object.right, expr.component)
            });
        }
        return {
            kind: 'component',
            object,
            component: expr.component
        };
    }
}
/**
 * Apply a component access to one operand of a distributed binary operation.
 *
 * Number literals are always scalars and are returned unchanged (they
 * broadcast). The same happens for any node carrying a `type` annotation of
 * kind 'scalar' (set by type inference when present). Every other operand
 * is wrapped in a component access.
 */
function projectComponent(operand, component) {
    if (isNumber(operand) || operand.type?.kind === 'scalar') {
        return operand;
    }
    return { kind: 'component', object: operand, component };
}
205
/**
 * Simplify an expression by running simplification passes until one pass
 * leaves the tree structurally unchanged.
 *
 * @param {object} expr - expression tree to simplify
 * @returns {object} the simplified expression tree
 */
export function simplify(expr) {
    let previous = expr;
    let next = new Simplifier(false).transform(previous);
    while (!expressionsEqual(next, previous)) {
        previous = next;
        next = new Simplifier(false).transform(previous);
    }
    return next;
}
217
/**
 * Report whether an expression node is a numeric literal.
 *
 * @param {object} expr - expression node
 * @returns {boolean} true for nodes of kind 'number'
 */
function isNumber(expr) {
    const { kind } = expr;
    return kind === 'number';
}
223
/**
 * Check whether two expression trees are structurally equal.
 *
 * Number literals compare by value, with NaN treated as equal to NaN. Without
 * that, a NaN produced by constant folding (e.g. 0/0) never compares equal
 * to itself, and simplify()'s fixed-point loop would never terminate. 0 and
 * -0 still compare equal, as with `===`.
 *
 * @param {object} a - first expression
 * @param {object} b - second expression
 * @returns {boolean} true when both trees have the same shape and contents
 */
function expressionsEqual(a, b) {
    // The same node object is trivially equal to itself.
    if (a === b)
        return true;
    if (a.kind !== b.kind)
        return false;
    switch (a.kind) {
        case 'number':
            return a.value === b.value ||
                (Number.isNaN(a.value) && Number.isNaN(b.value));
        case 'variable':
            return a.name === b.name;
        case 'binary':
            return a.operator === b.operator &&
                expressionsEqual(a.left, b.left) &&
                expressionsEqual(a.right, b.right);
        case 'unary':
            return a.operator === b.operator &&
                expressionsEqual(a.operand, b.operand);
        case 'call':
            return a.name === b.name &&
                a.args.length === b.args.length &&
                a.args.every((arg, i) => expressionsEqual(arg, b.args[i]));
        case 'component':
            return a.component === b.component &&
                expressionsEqual(a.object, b.object);
        default:
            // Unknown node kinds are never considered structurally equal.
            return false;
    }
}
258
/**
 * Simplify every gradient in a map, preserving key order.
 *
 * A structured gradient (an object with a `components` map) has each of its
 * component expressions simplified and is rebuilt as a new `{ components }`
 * object. Any other value is simplified as a single expression.
 *
 * @param {Map} gradients - gradient name -> expression or { components: Map }
 * @returns {Map} a new map with the simplified gradients
 */
export function simplifyGradients(gradients) {
    const out = new Map();
    for (const [name, grad] of gradients.entries()) {
        const isStructured = 'components' in grad;
        const simplifiedGrad = isStructured
            ? {
                components: new Map(Array.from(grad.components.entries(), ([comp, e]) => [comp, simplify(e)]))
            }
            : simplify(grad);
        out.set(name, simplifiedGrad);
    }
    return out;
}
@@ -0,0 +1,39 @@
1
+ /**
2
+ * Type inference for GradientScript DSL
3
+ * Infers types for all expressions and validates type correctness
4
+ */
5
+ import { Program, FunctionDef, Expression, Statement } from './AST.js';
6
+ import { Type, TypeEnv } from './Types.js';
7
/**
 * Expression type inference.
 *
 * Resolves variable references through the TypeEnv given at construction.
 */
export declare class TypeInferenceVisitor {
    private env;
    /**
     * @param env - variable bindings used to resolve variable references
     */
    constructor(env: TypeEnv);
    /**
     * Compute the type of `expr` and return it.
     */
    inferExpression(expr: Expression): Type;
    private inferNumber;
    private inferVariable;
    private inferBinary;
    private inferUnary;
    private inferCall;
    private inferComponent;
}
24
/**
 * Type a single statement, recording in `env` any variable it introduces.
 */
export declare function inferStatement(
    stmt: Statement,
    env: TypeEnv,
): void;
28
/**
 * Type a whole function definition.
 *
 * @returns the type environment built while processing the function
 */
export declare function inferFunction(
    func: FunctionDef,
): TypeEnv;
32
/**
 * Type every function in a program.
 */
export declare function inferProgram(
    program: Program,
): void;
36
/**
 * Shorthand for typing a program and getting the program back.
 */
export declare function inferTypes(
    program: Program,
): Program;
@@ -0,0 +1,147 @@
1
+ /**
2
+ * Type inference for GradientScript DSL
3
+ * Infers types for all expressions and validates type correctness
4
+ */
5
+ import { Types, TypeEnv } from './Types.js';
6
+ import { builtIns } from './BuiltIns.js';
7
+ import { TypeError } from './Errors.js';
8
/**
 * Type inference visitor.
 *
 * Walks an expression tree, stores the inferred type on each visited node's
 * `type` property, and returns the type of the root. Variable types are read
 * from the TypeEnv passed to the constructor; this class does not modify it.
 */
export class TypeInferenceVisitor {
    env;
    /**
     * @param {TypeEnv} env - variable-name -> type bindings used to resolve variables
     */
    constructor(env) {
        this.env = env;
    }
    /**
     * Infer (and record on the node) the type of an expression.
     *
     * @param {Expression} expr - expression to type
     * @returns {Type} the inferred type
     * @throws {TypeError} (project error class from Errors.js) on operand
     *   mismatches, unknown functions or overloads, invalid component access,
     *   or an unrecognised expression kind
     * @throws {Error} when a variable is not bound in the environment, or when
     *   Types.binaryResultType rejects the operand types
     */
    inferExpression(expr) {
        switch (expr.kind) {
            case 'number':
                return this.inferNumber(expr);
            case 'variable':
                return this.inferVariable(expr);
            case 'binary':
                return this.inferBinary(expr);
            case 'unary':
                return this.inferUnary(expr);
            case 'call':
                return this.inferCall(expr);
            case 'component':
                return this.inferComponent(expr);
            default:
                // Fail here, not later: returning undefined would surface as an
                // unrelated crash when a caller reads `.kind` of the missing type.
                throw new TypeError(`Unknown expression kind`, String(expr.kind));
        }
    }
    /** Numeric literals are always scalar. */
    inferNumber(expr) {
        const type = Types.scalar();
        expr.type = type;
        return type;
    }
    /** Variables take the type bound in the environment (throws if unbound). */
    inferVariable(expr) {
        const type = this.env.getOrThrow(expr.name);
        expr.type = type;
        return type;
    }
    /** Both operands must be compatible (equal structs, or scalar broadcasting). */
    inferBinary(expr) {
        const leftType = this.inferExpression(expr.left);
        const rightType = this.inferExpression(expr.right);
        if (!Types.compatible(leftType, rightType)) {
            throw new TypeError(`Type mismatch in binary operation`, expr.operator, Types.toString(leftType), Types.toString(rightType));
        }
        const resultType = Types.binaryResultType(leftType, rightType, expr.operator);
        expr.type = resultType;
        return resultType;
    }
    /** Unary operators keep the operand's type. */
    inferUnary(expr) {
        const operandType = this.inferExpression(expr.operand);
        const resultType = Types.unaryResultType(operandType, expr.operator);
        expr.type = resultType;
        return resultType;
    }
    /** Calls resolve against the built-in overload table by argument types. */
    inferCall(expr) {
        // Infer argument types
        const argTypes = expr.args.map(arg => this.inferExpression(arg));
        // Look up built-in function
        const signature = builtIns.lookup(expr.name, argTypes);
        if (!signature) {
            if (builtIns.isBuiltIn(expr.name)) {
                // Known name but no overload matches: list what would have been accepted.
                const overloads = builtIns.getOverloads(expr.name);
                const expectedSigs = overloads.map(sig => `${sig.name}(${sig.params.map(p => Types.toString(p)).join(', ')})`).join(' or ');
                const actualSig = `${expr.name}(${argTypes.map(t => Types.toString(t)).join(', ')})`;
                throw new TypeError(`No matching overload. Expected: ${expectedSigs}`, actualSig);
            }
            else {
                throw new TypeError(`Unknown function`, expr.name);
            }
        }
        expr.type = signature.returnType;
        return signature.returnType;
    }
    /** Component access requires a struct that has the component; the result is scalar. */
    inferComponent(expr) {
        const objectType = this.inferExpression(expr.object);
        if (!Types.isStruct(objectType)) {
            throw new TypeError(`Cannot access component of scalar type`, expr.component, 'struct', 'scalar');
        }
        if (!objectType.components.includes(expr.component)) {
            throw new TypeError(`Component does not exist. Available: ${objectType.components.join(', ')}`, expr.component, objectType.components.join('|'), expr.component);
        }
        const resultType = Types.scalar();
        expr.type = resultType;
        return resultType;
    }
}
93
/**
 * Infer types for one statement.
 *
 * Only assignments are handled: the right-hand side is typed against `env`,
 * and the target variable is then bound to that type so later statements can
 * refer to it. Other statement kinds are ignored.
 *
 * @param {Statement} stmt - statement to process
 * @param {TypeEnv} env - environment to read from and extend (mutated)
 */
export function inferStatement(stmt, env) {
    if (stmt.kind !== 'assignment') {
        return;
    }
    const rhsType = new TypeInferenceVisitor(env).inferExpression(stmt.expression);
    env.set(stmt.variable, rhsType);
}
104
/**
 * Infer types for a single function definition.
 *
 * Annotated parameters are bound to a struct type built from the annotated
 * components; unannotated parameters are assumed scalar (no usage-based
 * inference yet). Body statements are then typed in order, each assignment
 * extending the environment. The return expression's type is stored on
 * `func.type`.
 *
 * @param {FunctionDef} func - function to type (annotated in place)
 * @returns {TypeEnv} the environment after processing the whole body
 */
export function inferFunction(func) {
    const env = new TypeEnv();
    for (const param of func.parameters) {
        const annotation = param.paramType;
        env.set(param.name, annotation ? Types.struct(annotation.components) : Types.scalar());
    }
    for (const statement of func.body) {
        inferStatement(statement, env);
    }
    func.type = new TypeInferenceVisitor(env).inferExpression(func.returnExpr);
    return env;
}
133
/**
 * Infer types for every function of a program, in declaration order.
 *
 * The results are recorded on the AST nodes; the per-function environments
 * returned by inferFunction are discarded.
 *
 * @param {Program} program - program whose functions are typed in place
 */
export function inferProgram(program) {
    program.functions.forEach((fn) => {
        inferFunction(fn);
    });
}
141
/**
 * Convenience wrapper: type the whole program and return the same
 * (now annotated) program object.
 *
 * @param {Program} program - program to type in place
 * @returns {Program} the same program object
 */
export function inferTypes(program) {
    inferProgram(program);
    return program;
}
@@ -0,0 +1,58 @@
1
+ /**
2
+ * Type system for GradientScript DSL
3
+ * Handles scalar vs structured types and type inference
4
+ */
5
/**
 * Any DSL type: either a scalar or a struct with named components.
 */
export type Type =
    | ScalarType
    | StructType;
9
/**
 * The scalar (plain number) type.
 */
export interface ScalarType {
    kind: 'scalar';
}
15
/**
 * A struct type described by its ordered component names,
 * for example `{x, y}` or `{x, y, z}`.
 */
export interface StructType {
    kind: 'struct';
    /** Component names, in declaration order. */
    components: string[];
}
23
/**
 * Helpers for constructing, inspecting and combining DSL types.
 */
export declare const Types: {
    /** A fresh scalar type. */
    scalar(): ScalarType;
    /** A struct type with the given component names. */
    struct(components: string[]): StructType;
    /** The `{x, y}` struct type. */
    vec2(): StructType;
    /** The `{x, y, z}` struct type. */
    vec3(): StructType;
    /** Narrow `type` to ScalarType. */
    isScalar(type: Type): type is ScalarType;
    /** Narrow `type` to StructType. */
    isStruct(type: Type): type is StructType;
    /** Structural equality (struct components compared in order). */
    equals(a: Type, b: Type): boolean;
    /** Readable rendering, e.g. `scalar` or `{x, y}`. */
    toString(type: Type): string;
    /**
     * Whether `a` and `b` may be combined by a binary operator.
     */
    compatible(a: Type, b: Type): boolean;
    /**
     * Type produced by applying binary operator `op` to `a` and `b`.
     */
    binaryResultType(a: Type, b: Type, op: string): Type;
    /**
     * Type produced by applying unary operator `op` to `type`.
     */
    unaryResultType(type: Type, op: string): Type;
};
48
/**
 * Mapping from variable names to their types.
 */
export declare class TypeEnv {
    private types;
    /** Bind (or rebind) `name` to `type`. */
    set(name: string, type: Type): void;
    /** The type bound to `name`, or undefined if unbound. */
    get(name: string): Type | undefined;
    /** Whether `name` is bound. */
    has(name: string): boolean;
    /** An independent copy of the bindings. */
    clone(): TypeEnv;
    /** The type bound to `name`; throws if it is unbound. */
    getOrThrow(name: string): Type;
}
@@ -0,0 +1,114 @@
1
+ /**
2
+ * Type system for GradientScript DSL
3
+ * Handles scalar vs structured types and type inference
4
+ */
5
/**
 * Type utilities: constructors, predicates, equality, printing, and the
 * typing rules for binary and unary operators.
 */
export const Types = {
    /** A fresh scalar type. */
    scalar() {
        return { kind: 'scalar' };
    },
    /** A struct type over the given component names (array stored as-is). */
    struct(components) {
        return { kind: 'struct', components };
    },
    /** The `{x, y}` struct type. */
    vec2() {
        return { kind: 'struct', components: ['x', 'y'] };
    },
    /** The `{x, y, z}` struct type. */
    vec3() {
        return { kind: 'struct', components: ['x', 'y', 'z'] };
    },
    isScalar(type) {
        return type.kind === 'scalar';
    },
    isStruct(type) {
        return type.kind === 'struct';
    },
    /**
     * Structural equality. Scalars are all equal; structs must list the
     * same component names in the same order.
     */
    equals(a, b) {
        if (a.kind !== b.kind) {
            return false;
        }
        if (a.kind === 'scalar') {
            return true;
        }
        const lhs = a.components;
        const rhs = b.components;
        return lhs.length === rhs.length && lhs.every((name, idx) => name === rhs[idx]);
    },
    /** Render as `scalar` or `{a, b, ...}`. */
    toString(type) {
        return type.kind === 'scalar' ? 'scalar' : `{${type.components.join(', ')}}`;
    },
    /**
     * Binary-operand compatibility: two structs must be equal; a scalar is
     * compatible with anything (it broadcasts over a struct).
     */
    compatible(a, b) {
        if (a.kind === 'struct' && b.kind === 'struct') {
            return Types.equals(a, b);
        }
        return a.kind === 'scalar' || b.kind === 'scalar';
    },
    /**
     * Result type of a binary operation: scalar for two scalars, the shared
     * struct for two equal structs, and the struct side when a scalar is
     * broadcast. Throws a plain Error for mismatched structs or unknown kinds.
     */
    binaryResultType(a, b, op) {
        const aIsStruct = a.kind === 'struct';
        const bIsStruct = b.kind === 'struct';
        if (aIsStruct && bIsStruct) {
            if (Types.equals(a, b)) {
                return a;
            }
            throw new Error(`Type mismatch: cannot perform ${op} on ${Types.toString(a)} and ${Types.toString(b)}`);
        }
        const aIsScalar = a.kind === 'scalar';
        const bIsScalar = b.kind === 'scalar';
        if (aIsScalar && bIsScalar) {
            return Types.scalar();
        }
        if (aIsScalar && bIsStruct) {
            return b;
        }
        if (aIsStruct && bIsScalar) {
            return a;
        }
        throw new Error(`Invalid types for ${op}: ${Types.toString(a)} and ${Types.toString(b)}`);
    },
    /** Unary operators leave the operand's type unchanged. */
    unaryResultType(type, op) {
        return type;
    }
};
88
/**
 * Variable-name -> type bindings used during type inference.
 */
export class TypeEnv {
    types = new Map();
    /** Bind (or rebind) `name` to `type`. */
    set(name, type) {
        this.types.set(name, type);
    }
    /** The type bound to `name`, or undefined. */
    get(name) {
        return this.types.get(name);
    }
    /** Whether `name` has a binding. */
    has(name) {
        return this.types.has(name);
    }
    /** A copy whose binding map is independent of this one (types are shared). */
    clone() {
        const copy = new TypeEnv();
        copy.types = new Map(this.types);
        return copy;
    }
    /** The type bound to `name`; throws a plain Error when unbound. */
    getOrThrow(name) {
        const found = this.get(name);
        if (found) {
            return found;
        }
        throw new Error(`Variable '${name}' is not defined`);
    }
}
@@ -0,0 +1,13 @@
1
+ /**
2
+ * GradientScript - Symbolic differentiation for structured types
3
+ *
4
+ * This library provides automatic differentiation for functions with
5
+ * structured types (vectors, custom structures).
6
+ */
7
+ export { parse } from './dsl/Parser.js';
8
+ export { inferFunction } from './dsl/TypeInference.js';
9
+ export { computeFunctionGradients } from './dsl/Differentiation.js';
10
+ export { generateComplete, generateGradientFunction, type CodeGenOptions } from './dsl/CodeGen.js';
11
+ export type { Expression, FunctionDef, Program, Parameter, Assignment } from './dsl/AST.js';
12
+ export type { Type, ScalarType, StructType, TypeEnv } from './dsl/Types.js';
13
+ export type { GradCheckResult, GradCheckError } from './dsl/GradientChecker.js';
package/dist/index.js ADDED
@@ -0,0 +1,11 @@
1
+ /**
2
+ * GradientScript - Symbolic differentiation for structured types
3
+ *
4
+ * This library provides automatic differentiation for functions with
5
+ * structured types (vectors, custom structures).
6
+ */
7
+ // Core API
8
+ export { parse } from './dsl/Parser.js';
9
+ export { inferFunction } from './dsl/TypeInference.js';
10
+ export { computeFunctionGradients } from './dsl/Differentiation.js';
11
+ export { generateComplete, generateGradientFunction } from './dsl/CodeGen.js';