gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Expression simplification for GradientScript DSL
|
|
3
|
+
* Applies algebraic simplification rules
|
|
4
|
+
*/
|
|
5
|
+
import { ExpressionTransformer } from './ExpressionTransformer.js';
|
|
6
|
+
/**
 * Simplifier - applies algebraic simplification rules recursively.
 *
 * Each visitor first simplifies its children and then applies local rewrite
 * rules to the rebuilt node, so one `transform()` call is a single bottom-up
 * pass. A rule can expose further opportunities for the next pass (e.g. it may
 * produce `1 * x`), which is why `simplify()` repeats passes until the tree
 * stops changing.
 */
class Simplifier extends ExpressionTransformer {
    fixedPoint = false;
    /**
     * @param {boolean} [fixedPoint=false] - stored on the instance; no visitor
     *   in this class currently reads it.
     */
    constructor(fixedPoint = false) {
        super();
        this.fixedPoint = fixedPoint;
    }
    /**
     * Simplify a binary operation.
     *
     * Steps:
     * 1. Fold two numeric literals into one literal.
     * 2. Apply identity and annihilator rules for `+ - * /` and `^`/`**`.
     * 3. Apply a few product-of-sum patterns that symbolic differentiation tends to emit.
     */
    visitBinaryOp(expr) {
        const left = this.transform(expr.left);
        const right = this.transform(expr.right);
        const leftNum = isNumber(left) ? left.value : null;
        const rightNum = isNumber(right) ? right.value : null;
        // Constant folding
        if (leftNum !== null && rightNum !== null) {
            let result;
            switch (expr.operator) {
                case '+':
                    result = leftNum + rightNum;
                    break;
                case '-':
                    result = leftNum - rightNum;
                    break;
                case '*':
                    result = leftNum * rightNum;
                    break;
                case '/':
                    result = leftNum / rightNum;
                    break;
                case '^':
                case '**':
                    result = Math.pow(leftNum, rightNum);
                    break;
                default:
                    // No folding rule for this operator: keep the binary node
                    // rather than emitting a literal whose value is undefined.
                    result = undefined;
            }
            if (result !== undefined) {
                return { kind: 'number', value: result };
            }
        }
        // Addition rules
        if (expr.operator === '+') {
            if (leftNum === 0)
                return right;
            if (rightNum === 0)
                return left;
        }
        // Subtraction rules
        if (expr.operator === '-') {
            if (rightNum === 0)
                return left;
            if (leftNum === 0) {
                return { kind: 'unary', operator: '-', operand: right };
            }
            if (expressionsEqual(left, right)) {
                return { kind: 'number', value: 0 };
            }
        }
        // Multiplication rules
        if (expr.operator === '*') {
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (rightNum === 0)
                return { kind: 'number', value: 0 };
            if (leftNum === 1)
                return right;
            if (rightNum === 1)
                return left;
            // (x / x) * y → y
            if (left.kind === 'binary' && left.operator === '/') {
                if (expressionsEqual(left.left, left.right)) {
                    return right;
                }
            }
            // 0.5 * (a + a) → a
            if (leftNum === 0.5 && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (expressionsEqual(l1, r1)) {
                    return l1;
                }
            }
            // 0.5 * (a*b + b*a) → a*b
            if (leftNum === 0.5 && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (l1.kind === 'binary' && l1.operator === '*' &&
                    r1.kind === 'binary' && r1.operator === '*') {
                    if (expressionsEqual(l1.left, r1.right) && expressionsEqual(l1.right, r1.left)) {
                        return l1;
                    }
                }
            }
            // c * (a*b + b*a) → 2*c*a*b
            if (leftNum !== null && right.kind === 'binary' && right.operator === '+') {
                const { left: l1, right: r1 } = right;
                if (l1.kind === 'binary' && l1.operator === '*' &&
                    r1.kind === 'binary' && r1.operator === '*') {
                    if (expressionsEqual(l1.left, r1.right) && expressionsEqual(l1.right, r1.left)) {
                        return {
                            kind: 'binary',
                            operator: '*',
                            left: { kind: 'number', value: 2 * leftNum },
                            right: l1
                        };
                    }
                }
            }
        }
        // Division rules
        if (expr.operator === '/') {
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (rightNum === 1)
                return left;
            if (expressionsEqual(left, right)) {
                return { kind: 'number', value: 1 };
            }
        }
        // Power rules
        if (expr.operator === '^' || expr.operator === '**') {
            if (rightNum === 0)
                return { kind: 'number', value: 1 };
            if (rightNum === 1)
                return left;
            if (leftNum === 0)
                return { kind: 'number', value: 0 };
            if (leftNum === 1)
                return { kind: 'number', value: 1 };
        }
        return {
            kind: 'binary',
            operator: expr.operator,
            left,
            right
        };
    }
    /**
     * Simplify a unary operation.
     *
     * Rules: `--x → x`, `-<literal>` → negated literal, `+x → x`.
     * Any other unary node is rebuilt around its simplified operand.
     */
    visitUnaryOp(expr) {
        const operand = this.transform(expr.operand);
        if (expr.operator === '-') {
            // Double negation: --x = x
            if (operand.kind === 'unary' && operand.operator === '-') {
                return operand.operand;
            }
            // Negate number: -5 = -5
            if (isNumber(operand)) {
                return { kind: 'number', value: -operand.value };
            }
        }
        if (expr.operator === '+') {
            return operand;
        }
        return {
            kind: 'unary',
            operator: expr.operator,
            operand
        };
    }
    /**
     * Simplify a function call.
     *
     * Folds `sqrt(<non-negative literal>)` and `abs(<literal>)`. Every other
     * call is rebuilt with its simplified arguments.
     */
    visitFunctionCall(expr) {
        const args = expr.args.map(arg => this.transform(arg));
        if (expr.name === 'sqrt' && args.length === 1) {
            const arg = args[0];
            if (isNumber(arg) && arg.value >= 0) {
                return { kind: 'number', value: Math.sqrt(arg.value) };
            }
        }
        if (expr.name === 'abs' && args.length === 1) {
            const arg = args[0];
            if (isNumber(arg)) {
                return { kind: 'number', value: Math.abs(arg.value) };
            }
        }
        return {
            kind: 'call',
            name: expr.name,
            args
        };
    }
    /**
     * Simplify a component access.
     *
     * A component of a binary node is distributed over its operands, since
     * operators act element-wise on structs: `(u + v).x → u.x + v.x`.
     */
    visitComponentAccess(expr) {
        const object = this.transform(expr.object);
        // (u + v).x -> u.x + v.x
        if (object.kind === 'binary') {
            // An operand is treated as scalar when it is a number literal, or
            // when its type annotation (if present) says so.
            const isScalarOperand = (node) => isNumber(node) || node.type?.kind === 'scalar';
            const leftScalar = isScalarOperand(object.left);
            const rightScalar = isScalarOperand(object.right);
            // A scalar operand is broadcast across components, so it must not
            // be wrapped: (2 * v).x -> 2 * v.x, never 2.x * v.x. When BOTH
            // sides are scalar, the access was ill-typed to begin with. Both
            // sides are then wrapped as before, so that problem is not hidden.
            const broadcast = leftScalar !== rightScalar;
            const wrap = (operand) => ({
                kind: 'component',
                object: operand,
                component: expr.component
            });
            return this.transform({
                kind: 'binary',
                operator: object.operator,
                left: broadcast && leftScalar ? object.left : wrap(object.left),
                right: broadcast && rightScalar ? object.right : wrap(object.right)
            });
        }
        return {
            kind: 'component',
            object,
            component: expr.component
        };
    }
}
|
|
205
|
+
/**
 * Simplify an expression using algebraic rules.
 *
 * Repeats full simplification passes until one pass returns a tree that is
 * structurally equal to its input, then returns that result.
 *
 * @param {object} expr - expression node to simplify
 * @returns {object} the simplified expression
 */
export function simplify(expr) {
    let next = expr;
    for (;;) {
        const previous = next;
        next = new Simplifier(false).transform(previous);
        // Stop once a pass no longer changes the tree.
        if (expressionsEqual(next, previous)) {
            return next;
        }
    }
}
|
|
217
|
+
/**
 * Report whether an expression node is a numeric literal.
 *
 * @param {object} expr - expression node
 * @returns {boolean} true when `expr.kind` is `'number'`
 */
function isNumber(expr) {
    const { kind } = expr;
    return kind === 'number';
}
|
|
223
|
+
/**
 * Check if two expressions are structurally equal.
 *
 * The two trees must have the same node kinds, operators, names, components
 * and literal values.
 *
 * Two rules exist so that `simplify()` (which loops until a pass returns an
 * equal tree) is guaranteed to terminate:
 * - NaN literals compare equal to each other. A folded NaN such as `0 / 0`
 *   would otherwise never equal itself.
 * - Unknown node kinds fall back to reference identity instead of
 *   `undefined`, so a node always equals itself.
 *
 * @param {object} a - expression node
 * @param {object} b - expression node
 * @returns {boolean} whether the two trees are structurally equal
 */
function expressionsEqual(a, b) {
    if (a.kind !== b.kind)
        return false;
    switch (a.kind) {
        case 'number':
            return a.value === b.value ||
                (Number.isNaN(a.value) && Number.isNaN(b.value));
        case 'variable':
            return a.name === b.name;
        case 'binary':
            return a.operator === b.operator &&
                expressionsEqual(a.left, b.left) &&
                expressionsEqual(a.right, b.right);
        case 'unary':
            return a.operator === b.operator &&
                expressionsEqual(a.operand, b.operand);
        case 'call':
            return a.name === b.name &&
                a.args.length === b.args.length &&
                a.args.every((arg, i) => expressionsEqual(arg, b.args[i]));
        case 'component':
            return a.component === b.component &&
                expressionsEqual(a.object, b.object);
        default:
            // No structural rule for this kind: only the very same node is equal.
            return a === b;
    }
}
|
|
258
|
+
/**
 * Simplify all gradients in a map.
 *
 * Each value is handled according to its shape:
 * - A value with a `components` map (a structured gradient) has each of its
 *   component expressions simplified.
 * - Any other value is treated as a single expression and simplified.
 *
 * @param {Map} gradients - gradient name → expression or `{ components: Map }`
 * @returns {Map} a new map with the same keys and simplified values
 */
export function simplifyGradients(gradients) {
    const out = new Map();
    for (const [name, grad] of gradients.entries()) {
        if (!('components' in grad)) {
            out.set(name, simplify(grad));
            continue;
        }
        const perComponent = new Map(
            Array.from(grad.components.entries(), ([comp, e]) => [comp, simplify(e)])
        );
        out.set(name, { components: perComponent });
    }
    return out;
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Type inference for GradientScript DSL
|
|
3
|
+
* Infers types for all expressions and validates type correctness
|
|
4
|
+
*/
|
|
5
|
+
import { Program, FunctionDef, Expression, Statement } from './AST.js';
|
|
6
|
+
import { Type, TypeEnv } from './Types.js';
|
|
7
|
+
/**
 * Type inference visitor.
 *
 * Computes the type of an expression tree bottom-up from the variable types
 * held in its `TypeEnv`. The inferred type is written to each visited node's
 * `type` field.
 */
export declare class TypeInferenceVisitor {
    /** Variable name → type bindings used to resolve `variable` nodes. */
    private env;
    constructor(env: TypeEnv);
    /**
     * Infer type for an expression, annotating the node (and its sub-nodes).
     *
     * Throws when:
     * - a referenced variable is not bound in the environment;
     * - the operand types of a binary operation are incompatible;
     * - no built-in overload matches a call, or the function is unknown;
     * - a component is read from a scalar, or the component does not exist
     *   on the struct.
     */
    inferExpression(expr: Expression): Type;
    private inferNumber;
    private inferVariable;
    private inferBinary;
    private inferUnary;
    private inferCall;
    private inferComponent;
}
|
|
24
|
+
/**
 * Infer types for a statement.
 *
 * Only `assignment` statements are processed. The type inferred for the
 * right-hand side is bound in `env` under the assigned variable name. Other
 * statement kinds are ignored.
 */
export declare function inferStatement(stmt: Statement, env: TypeEnv): void;
|
|
28
|
+
/**
 * Infer types for a function.
 *
 * Steps:
 * 1. Build a fresh environment from the parameters: annotated parameters
 *    become struct types with the listed components; unannotated parameters
 *    are assumed scalar.
 * 2. Process the body statements in order.
 * 3. Store the type of the return expression on `func.type`.
 *
 * Returns the final environment.
 */
export declare function inferFunction(func: FunctionDef): TypeEnv;
|
|
32
|
+
/**
 * Infer types for entire program.
 *
 * Runs `inferFunction` on every function in order, annotating the AST in
 * place.
 */
export declare function inferProgram(program: Program): void;
|
|
36
|
+
/**
 * Convenience function to infer types.
 *
 * Annotates `program` in place (via `inferProgram`) and returns the same
 * object.
 */
export declare function inferTypes(program: Program): Program;
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Type inference for GradientScript DSL
|
|
3
|
+
* Infers types for all expressions and validates type correctness
|
|
4
|
+
*/
|
|
5
|
+
import { Types, TypeEnv } from './Types.js';
|
|
6
|
+
import { builtIns } from './BuiltIns.js';
|
|
7
|
+
import { TypeError } from './Errors.js';
|
|
8
|
+
/**
 * Type inference visitor.
 *
 * Computes the type of an expression tree bottom-up from the variable types
 * held in `env`. Every visited node is annotated by writing its inferred type
 * to `node.type`.
 */
export class TypeInferenceVisitor {
    env;
    /**
     * @param {TypeEnv} env - variable name → type bindings used to resolve
     *   `variable` nodes
     */
    constructor(env) {
        this.env = env;
    }
    /**
     * Infer type for an expression, annotating it (and its sub-nodes) in place.
     *
     * @param {object} expr - expression node
     * @returns {object} the inferred type
     * @throws {TypeError} the DSL `TypeError` (from ./Errors.js), when:
     *   - operand types of a binary operation are incompatible;
     *   - a call has no matching overload or an unknown name;
     *   - a component access is invalid;
     *   - the node kind is not recognised.
     * @throws {Error} when a referenced variable is not bound in `env`
     */
    inferExpression(expr) {
        switch (expr.kind) {
            case 'number':
                return this.inferNumber(expr);
            case 'variable':
                return this.inferVariable(expr);
            case 'binary':
                return this.inferBinary(expr);
            case 'unary':
                return this.inferUnary(expr);
            case 'call':
                return this.inferCall(expr);
            case 'component':
                return this.inferComponent(expr);
            default:
                // Fail fast. Returning undefined here used to surface later as an
                // unrelated crash (e.g. reading `.kind` of undefined in
                // Types.compatible) or as a misleading "not defined" error.
                throw new TypeError(`Unknown expression kind`, String(expr.kind));
        }
    }
    // Number literals are always scalar.
    inferNumber(expr) {
        const type = Types.scalar();
        expr.type = type;
        return type;
    }
    // Variables take the type bound in the environment (throws if unbound).
    inferVariable(expr) {
        const type = this.env.getOrThrow(expr.name);
        expr.type = type;
        return type;
    }
    // Both operands must be compatible; the result type follows Types.binaryResultType.
    inferBinary(expr) {
        const leftType = this.inferExpression(expr.left);
        const rightType = this.inferExpression(expr.right);
        if (!Types.compatible(leftType, rightType)) {
            throw new TypeError(`Type mismatch in binary operation`, expr.operator, Types.toString(leftType), Types.toString(rightType));
        }
        const resultType = Types.binaryResultType(leftType, rightType, expr.operator);
        expr.type = resultType;
        return resultType;
    }
    inferUnary(expr) {
        const operandType = this.inferExpression(expr.operand);
        const resultType = Types.unaryResultType(operandType, expr.operator);
        expr.type = resultType;
        return resultType;
    }
    // Calls resolve against the built-in overload table using the argument types.
    inferCall(expr) {
        // Infer argument types
        const argTypes = expr.args.map(arg => this.inferExpression(arg));
        // Look up built-in function
        const signature = builtIns.lookup(expr.name, argTypes);
        if (!signature) {
            if (builtIns.isBuiltIn(expr.name)) {
                // Known function, wrong argument types: list every accepted signature.
                const overloads = builtIns.getOverloads(expr.name);
                const expectedSigs = overloads.map(sig => `${sig.name}(${sig.params.map(p => Types.toString(p)).join(', ')})`).join(' or ');
                const actualSig = `${expr.name}(${argTypes.map(t => Types.toString(t)).join(', ')})`;
                throw new TypeError(`No matching overload. Expected: ${expectedSigs}`, actualSig);
            }
            else {
                throw new TypeError(`Unknown function`, expr.name);
            }
        }
        expr.type = signature.returnType;
        return signature.returnType;
    }
    // Component access requires a struct that declares the component; the result is scalar.
    inferComponent(expr) {
        const objectType = this.inferExpression(expr.object);
        if (!Types.isStruct(objectType)) {
            throw new TypeError(`Cannot access component of scalar type`, expr.component, 'struct', 'scalar');
        }
        if (!objectType.components.includes(expr.component)) {
            throw new TypeError(`Component does not exist. Available: ${objectType.components.join(', ')}`, expr.component, objectType.components.join('|'), expr.component);
        }
        const resultType = Types.scalar();
        expr.type = resultType;
        return resultType;
    }
}
|
|
93
|
+
/**
 * Infer types for a statement.
 *
 * Only `assignment` statements are handled. The right-hand side's type is
 * inferred and bound in `env` under the assigned variable name. Every other
 * statement kind is ignored.
 *
 * @param {object} stmt - statement node
 * @param {TypeEnv} env - environment to read from and extend
 */
export function inferStatement(stmt, env) {
    if (stmt.kind !== 'assignment') {
        return;
    }
    const inferredType = new TypeInferenceVisitor(env).inferExpression(stmt.expression);
    // Bind the variable so later statements can reference it.
    env.set(stmt.variable, inferredType);
}
|
|
104
|
+
/**
 * Infer types for a function.
 *
 * Steps:
 * 1. Seed a fresh environment with the parameters: annotated parameters
 *    become struct types over their listed components; unannotated parameters
 *    are assumed scalar (usage-based inference is not implemented).
 * 2. Infer the body statements in order.
 * 3. Store the return expression's type on `func.type`.
 *
 * @param {object} func - function definition node
 * @returns {TypeEnv} the environment after all statements were processed
 */
export function inferFunction(func) {
    const env = new TypeEnv();
    for (const { name, paramType } of func.parameters) {
        env.set(name, paramType ? Types.struct(paramType.components) : Types.scalar());
    }
    func.body.forEach((stmt) => inferStatement(stmt, env));
    func.type = new TypeInferenceVisitor(env).inferExpression(func.returnExpr);
    return env;
}
|
|
133
|
+
/**
 * Infer types for entire program.
 *
 * Runs `inferFunction` on every function in order; the AST is annotated in
 * place and nothing is returned.
 *
 * @param {object} program - program node with a `functions` list
 */
export function inferProgram(program) {
    program.functions.forEach((fn) => {
        inferFunction(fn);
    });
}
|
|
141
|
+
/**
 * Convenience function to infer types.
 *
 * @param {object} program - program node; annotated in place
 * @returns {object} the same `program` object, for chaining
 */
export function inferTypes(program) {
    inferProgram(program);
    // Annotation happened in place; hand the input back to the caller.
    return program;
}
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Type system for GradientScript DSL
|
|
3
|
+
* Handles scalar vs structured types and type inference
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Represents a type in the DSL: either a scalar number or a struct with
 * named components.
 */
export type Type = ScalarType | StructType;
|
|
9
|
+
/**
 * Scalar type (numbers).
 *
 * The inferencer assigns it to number literals, unannotated function
 * parameters and component accesses.
 */
export interface ScalarType {
    kind: 'scalar';
}
|
|
15
|
+
/**
 * Structured type with named components,
 * e.g., {x, y} or {x, y, z}.
 *
 * Component order is significant: `Types.equals` compares the lists
 * position by position.
 */
export interface StructType {
    kind: 'struct';
    /** Component names, e.g. ['x', 'y']. */
    components: string[];
}
|
|
23
|
+
/**
 * Type utilities: constructors, predicates and the typing rules for DSL
 * operators.
 */
export declare const Types: {
    /** Create a scalar type. */
    scalar(): ScalarType;
    /** Create a struct type with the given component names (order preserved). */
    struct(components: string[]): StructType;
    /** Struct type with components `['x', 'y']`. */
    vec2(): StructType;
    /** Struct type with components `['x', 'y', 'z']`. */
    vec3(): StructType;
    isScalar(type: Type): type is ScalarType;
    isStruct(type: Type): type is StructType;
    /**
     * Structural equality. Two structs are equal only if they list the same
     * component names in the same order.
     */
    equals(a: Type, b: Type): boolean;
    /** Render as `'scalar'` or as `'{x, y}'` for structs. */
    toString(type: Type): string;
    /**
     * Check if two types are compatible for binary operations.
     *
     * Compatible pairs:
     * - scalar with scalar;
     * - scalar with struct (broadcasting);
     * - struct with struct only when `equals` holds.
     */
    compatible(a: Type, b: Type): boolean;
    /**
     * Result type of binary operation.
     *
     * Rules:
     * - scalar op scalar → scalar;
     * - struct op struct → the (equal) struct, element-wise;
     * - scalar op struct, or struct op scalar → the struct (broadcasting).
     *
     * Throws an `Error` when two structs differ, or when the operand kinds
     * fit none of these rules. `op` is only used in the error messages.
     */
    binaryResultType(a: Type, b: Type, op: string): Type;
    /**
     * Result type of unary operation: the operand type is returned unchanged;
     * `op` is not consulted.
     */
    unaryResultType(type: Type, op: string): Type;
};
|
|
48
|
+
/**
 * Type environment for tracking variable types (variable name → Type).
 */
export declare class TypeEnv {
    private types;
    /** Bind (or rebind) `name` to `type`. */
    set(name: string, type: Type): void;
    /** Type bound to `name`, or `undefined` when unbound. */
    get(name: string): Type | undefined;
    has(name: string): boolean;
    /**
     * Shallow copy: new bindings on the copy do not affect this environment,
     * but the Type objects themselves are shared.
     */
    clone(): TypeEnv;
    /** Type bound to `name`; throws an `Error` when the variable is not defined. */
    getOrThrow(name: string): Type;
}
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Type system for GradientScript DSL
|
|
3
|
+
* Handles scalar vs structured types and type inference
|
|
4
|
+
*/
|
|
5
|
+
/**
 * Type utilities: constructors, predicates, and the rules deciding which
 * operand types may be combined and what the result type is.
 */
export const Types = {
    /** @returns {{kind: 'scalar'}} a fresh scalar type */
    scalar() {
        return { kind: 'scalar' };
    },
    /**
     * @param {string[]} components - component names, order preserved
     * @returns {{kind: 'struct', components: string[]}}
     */
    struct(components) {
        return { kind: 'struct', components };
    },
    /** Fresh `{x, y}` struct type. */
    vec2() {
        return Types.struct(['x', 'y']);
    },
    /** Fresh `{x, y, z}` struct type. */
    vec3() {
        return Types.struct(['x', 'y', 'z']);
    },
    isScalar(type) {
        return type.kind === 'scalar';
    },
    isStruct(type) {
        return type.kind === 'struct';
    },
    /**
     * Structural equality. Structs must list identical component names in
     * the same order.
     */
    equals(a, b) {
        if (a.kind !== b.kind) {
            return false;
        }
        if (a.kind === 'scalar') {
            return true;
        }
        const lhs = a.components;
        const rhs = b.components;
        return lhs.length === rhs.length && lhs.every((name, idx) => name === rhs[idx]);
    },
    /** Human-readable form: `scalar` or `{x, y}`. */
    toString(type) {
        return type.kind === 'scalar' ? 'scalar' : `{${type.components.join(', ')}}`;
    },
    /**
     * Check if two types are compatible for binary operations.
     *
     * Any pairing that involves a scalar is allowed: scalar with scalar, or
     * broadcasting against a struct. Two structs need identical layouts.
     */
    compatible(a, b) {
        if (a.kind === 'scalar' || b.kind === 'scalar') {
            return true;
        }
        if (a.kind === 'struct' && b.kind === 'struct') {
            return Types.equals(a, b);
        }
        return false;
    },
    /**
     * Result type of binary operation.
     *
     * Rules:
     * - scalar op scalar → scalar;
     * - struct op struct → the struct (element-wise);
     * - a scalar and a struct → the struct (broadcasting).
     *
     * @throws {Error} on mismatched structs or unrecognised operand kinds
     */
    binaryResultType(a, b, op) {
        const lhsScalar = a.kind === 'scalar';
        const rhsScalar = b.kind === 'scalar';
        const lhsStruct = a.kind === 'struct';
        const rhsStruct = b.kind === 'struct';
        if (lhsScalar && rhsScalar) {
            return Types.scalar();
        }
        if (lhsStruct && rhsStruct) {
            if (Types.equals(a, b)) {
                return a;
            }
            throw new Error(`Type mismatch: cannot perform ${op} on ${Types.toString(a)} and ${Types.toString(b)}`);
        }
        if (lhsScalar && rhsStruct) {
            return b;
        }
        if (lhsStruct && rhsScalar) {
            return a;
        }
        throw new Error(`Invalid types for ${op}: ${Types.toString(a)} and ${Types.toString(b)}`);
    },
    /**
     * Result type of unary operation: the operand type passes through
     * unchanged (`op` is accepted for interface symmetry but not used).
     */
    unaryResultType(type, op) {
        return type;
    }
};
|
|
88
|
+
/**
 * Type environment for tracking variable types (variable name → Type).
 */
export class TypeEnv {
    types = new Map();
    /** Bind (or rebind) `name` to `type`. */
    set(name, type) {
        this.types.set(name, type);
    }
    /** @returns the bound type, or `undefined` when `name` is unbound */
    get(name) {
        return this.types.get(name);
    }
    has(name) {
        return this.types.has(name);
    }
    /**
     * Shallow copy: bindings added to the copy do not leak back, but the
     * Type objects themselves are shared.
     */
    clone() {
        const copy = new TypeEnv();
        this.types.forEach((type, name) => {
            copy.types.set(name, type);
        });
        return copy;
    }
    /**
     * @returns the bound type
     * @throws {Error} when `name` has no (truthy) binding
     */
    getOrThrow(name) {
        const found = this.get(name);
        if (found) {
            return found;
        }
        throw new Error(`Variable '${name}' is not defined`);
    }
}
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* GradientScript - Symbolic differentiation for structured types
|
|
3
|
+
*
|
|
4
|
+
* This library provides automatic differentiation for functions with
|
|
5
|
+
* structured types (vectors, custom structures).
|
|
6
|
+
*/
|
|
7
|
+
export { parse } from './dsl/Parser.js';
|
|
8
|
+
export { inferFunction } from './dsl/TypeInference.js';
|
|
9
|
+
export { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
10
|
+
export { generateComplete, generateGradientFunction, type CodeGenOptions } from './dsl/CodeGen.js';
|
|
11
|
+
export type { Expression, FunctionDef, Program, Parameter, Assignment } from './dsl/AST.js';
|
|
12
|
+
export type { Type, ScalarType, StructType, TypeEnv } from './dsl/Types.js';
|
|
13
|
+
export type { GradCheckResult, GradCheckError } from './dsl/GradientChecker.js';
|
package/dist/index.js
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* GradientScript - Symbolic differentiation for structured types
|
|
3
|
+
*
|
|
4
|
+
* This library provides automatic differentiation for functions with
|
|
5
|
+
* structured types (vectors, custom structures).
|
|
6
|
+
*/
|
|
7
|
+
// Core API
|
|
8
|
+
export { parse } from './dsl/Parser.js';
|
|
9
|
+
export { inferFunction } from './dsl/TypeInference.js';
|
|
10
|
+
export { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
11
|
+
export { generateComplete, generateGradientFunction } from './dsl/CodeGen.js';
|