gradient-script 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +515 -0
- package/dist/cli.d.ts +2 -0
- package/dist/cli.js +136 -0
- package/dist/dsl/AST.d.ts +123 -0
- package/dist/dsl/AST.js +23 -0
- package/dist/dsl/BuiltIns.d.ts +58 -0
- package/dist/dsl/BuiltIns.js +181 -0
- package/dist/dsl/CSE.d.ts +21 -0
- package/dist/dsl/CSE.js +194 -0
- package/dist/dsl/CodeGen.d.ts +60 -0
- package/dist/dsl/CodeGen.js +474 -0
- package/dist/dsl/Differentiation.d.ts +45 -0
- package/dist/dsl/Differentiation.js +421 -0
- package/dist/dsl/DiscontinuityAnalyzer.d.ts +18 -0
- package/dist/dsl/DiscontinuityAnalyzer.js +75 -0
- package/dist/dsl/Errors.d.ts +22 -0
- package/dist/dsl/Errors.js +49 -0
- package/dist/dsl/Expander.d.ts +13 -0
- package/dist/dsl/Expander.js +220 -0
- package/dist/dsl/ExpressionTransformer.d.ts +54 -0
- package/dist/dsl/ExpressionTransformer.js +102 -0
- package/dist/dsl/ExpressionUtils.d.ts +55 -0
- package/dist/dsl/ExpressionUtils.js +175 -0
- package/dist/dsl/GradientChecker.d.ts +71 -0
- package/dist/dsl/GradientChecker.js +258 -0
- package/dist/dsl/Guards.d.ts +27 -0
- package/dist/dsl/Guards.js +206 -0
- package/dist/dsl/Inliner.d.ts +10 -0
- package/dist/dsl/Inliner.js +40 -0
- package/dist/dsl/Lexer.d.ts +63 -0
- package/dist/dsl/Lexer.js +243 -0
- package/dist/dsl/Parser.d.ts +92 -0
- package/dist/dsl/Parser.js +328 -0
- package/dist/dsl/Simplify.d.ts +17 -0
- package/dist/dsl/Simplify.js +276 -0
- package/dist/dsl/TypeInference.d.ts +39 -0
- package/dist/dsl/TypeInference.js +147 -0
- package/dist/dsl/Types.d.ts +58 -0
- package/dist/dsl/Types.js +114 -0
- package/dist/index.d.ts +13 -0
- package/dist/index.js +11 -0
- package/dist/symbolic/AST.d.ts +113 -0
- package/dist/symbolic/AST.js +128 -0
- package/dist/symbolic/CodeGen.d.ts +35 -0
- package/dist/symbolic/CodeGen.js +280 -0
- package/dist/symbolic/Parser.d.ts +64 -0
- package/dist/symbolic/Parser.js +329 -0
- package/dist/symbolic/Simplify.d.ts +10 -0
- package/dist/symbolic/Simplify.js +244 -0
- package/dist/symbolic/SymbolicDiff.d.ts +35 -0
- package/dist/symbolic/SymbolicDiff.js +339 -0
- package/package.json +56 -0
package/README.md
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
# GradientScript
|
|
2
|
+
|
|
3
|
+
**Symbolic automatic differentiation for structured types**
|
|
4
|
+
|
|
5
|
+
GradientScript is a source-to-source compiler that automatically generates gradient functions from your mathematical code. Unlike runtime AD frameworks (JAX, PyTorch) that trace or record operations while your program runs, it produces clean, human-readable gradient formulas you can inspect, optimize, and integrate directly into your codebase.
|
|
6
|
+
|
|
7
|
+
## Why GradientScript?
|
|
8
|
+
|
|
9
|
+
- **From real code to gradients**: Write natural math code, get symbolic derivatives
|
|
10
|
+
- **Verified correctness**: Every gradient automatically checked against numerical differentiation
|
|
11
|
+
- **Structured types**: Work with vectors `{x, y}` and custom structures, not just scalars
|
|
12
|
+
- **Zero runtime overhead**: No tape, no graph - just pure gradient functions
|
|
13
|
+
- **Multiple output languages**: TypeScript, JavaScript, or Python
|
|
14
|
+
- **Readable output**: Human-reviewable formulas with automatic optimization
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
npm install -g gradient-script
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Quick Example
|
|
23
|
+
|
|
24
|
+
You have TypeScript code computing 2D vector distance:
|
|
25
|
+
|
|
26
|
+
```typescript
|
|
27
|
+
// Your original TypeScript code
|
|
28
|
+
function distance(u: Vec2, v: Vec2): number {
|
|
29
|
+
const dx = u.x - v.x;
|
|
30
|
+
const dy = u.y - v.y;
|
|
31
|
+
return Math.sqrt(dx * dx + dy * dy);
|
|
32
|
+
}
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
Convert it to GradientScript by marking what you need gradients for:
|
|
36
|
+
|
|
37
|
+
```typescript
|
|
38
|
+
// distance.gs
|
|
39
|
+
function distance(u∇: {x, y}, v∇: {x, y}) {
|
|
40
|
+
dx = u.x - v.x
|
|
41
|
+
dy = u.y - v.y
|
|
42
|
+
return sqrt(dx * dx + dy * dy)
|
|
43
|
+
}
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
Generate gradients:
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
gradient-script distance.gs
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
Get complete forward and gradient functions:
|
|
53
|
+
|
|
54
|
+
```typescript
|
|
55
|
+
// Forward function
|
|
56
|
+
function distance(u, v) {
|
|
57
|
+
const dx = u.x - v.x;
|
|
58
|
+
const dy = u.y - v.y;
|
|
59
|
+
return Math.sqrt(dx * dx + dy * dy);
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
// Gradient function - returns { value, du, dv }
|
|
63
|
+
function distance_grad(u, v) {
|
|
64
|
+
const dx = u.x - v.x;
|
|
65
|
+
const dy = u.y - v.y;
|
|
66
|
+
const value = Math.sqrt(dx * dx + dy * dy);
|
|
67
|
+
|
|
68
|
+
const _tmp0 = 2 * Math.sqrt(dx * dx + dy * dy);
|
|
69
|
+
|
|
70
|
+
const du = {
|
|
71
|
+
x: (2 * dx) / _tmp0,
|
|
72
|
+
y: (2 * dy) / _tmp0,
|
|
73
|
+
};
|
|
74
|
+
const dv = {
|
|
75
|
+
x: (2 * (-dx)) / _tmp0,
|
|
76
|
+
y: (2 * (-dy)) / _tmp0,
|
|
77
|
+
};
|
|
78
|
+
|
|
79
|
+
return { value, du, dv };
|
|
80
|
+
}
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Now use it in your optimizer, physics engine, or neural network!
|
|
84
|
+
|
|
85
|
+
## More Examples
|
|
86
|
+
|
|
87
|
+
### From C++ Physics Code
|
|
88
|
+
|
|
89
|
+
**Original C++ spring force calculation:**
|
|
90
|
+
```cpp
|
|
91
|
+
float spring_energy(Vec2 p1, Vec2 p2, float rest_length, float k) {
|
|
92
|
+
float dx = p2.x - p1.x;
|
|
93
|
+
float dy = p2.y - p1.y;
|
|
94
|
+
float dist = sqrt(dx*dx + dy*dy);
|
|
95
|
+
float stretch = dist - rest_length;
|
|
96
|
+
return 0.5f * k * stretch * stretch;
|
|
97
|
+
}
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
**GradientScript version:**
|
|
101
|
+
```typescript
|
|
102
|
+
function spring_energy(p1∇: {x, y}, p2∇: {x, y}, rest_length, k) {
|
|
103
|
+
dx = p2.x - p1.x
|
|
104
|
+
dy = p2.y - p1.y
|
|
105
|
+
dist = sqrt(dx * dx + dy * dy)
|
|
106
|
+
stretch = dist - rest_length
|
|
107
|
+
return 0.5 * k * stretch^2
|
|
108
|
+
}
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
**Generated gradient (for physics simulation):**
|
|
112
|
+
```typescript
|
|
113
|
+
function spring_energy_grad(p1, p2, rest_length, k) {
|
|
114
|
+
const dx = p2.x - p1.x;
|
|
115
|
+
const dy = p2.y - p1.y;
|
|
116
|
+
const dist = Math.sqrt(dx * dx + dy * dy);
|
|
117
|
+
const stretch = dist - rest_length;
|
|
118
|
+
const value = 0.5 * k * stretch * stretch;
|
|
119
|
+
|
|
120
|
+
const _tmp0 = 2 * Math.sqrt(dx * dx + dy * dy);
|
|
121
|
+
|
|
122
|
+
const dp1 = {
|
|
123
|
+
x: k * stretch * (-(2 * dx) / _tmp0),
|
|
124
|
+
y: k * stretch * (-(2 * dy) / _tmp0),
|
|
125
|
+
};
|
|
126
|
+
const dp2 = {
|
|
127
|
+
x: k * stretch * (2 * dx) / _tmp0,
|
|
128
|
+
y: k * stretch * (2 * dy) / _tmp0,
|
|
129
|
+
};
|
|
130
|
+
|
|
131
|
+
return { value, dp1, dp2 };
|
|
132
|
+
}
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
`dp1` and `dp2` are gradients of the energy; negate them to get the spring forces on `p1` and `p2` (F = -∇E) in your physics simulation!
|
|
136
|
+
|
|
137
|
+
### From C# Graphics Code
|
|
138
|
+
|
|
139
|
+
**Original C# normalized dot product:**
|
|
140
|
+
```csharp
|
|
141
|
+
float NormalizedDotProduct(Vector2 u, Vector2 v) {
|
|
142
|
+
float dot = u.X * v.X + u.Y * v.Y;
|
|
143
|
+
float u_mag = (float)Math.Sqrt(u.X * u.X + u.Y * u.Y);
|
|
144
|
+
float v_mag = (float)Math.Sqrt(v.X * v.X + v.Y * v.Y);
|
|
145
|
+
return dot / (u_mag * v_mag);
|
|
146
|
+
}
|
|
147
|
+
```
|
|
148
|
+
|
|
149
|
+
**GradientScript version:**
|
|
150
|
+
```typescript
|
|
151
|
+
function normalized_dot(u∇: {x, y}, v∇: {x, y}) {
|
|
152
|
+
dot = u.x * v.x + u.y * v.y
|
|
153
|
+
u_mag = sqrt(u.x * u.x + u.y * u.y)
|
|
154
|
+
v_mag = sqrt(v.x * v.x + v.y * v.y)
|
|
155
|
+
return dot / (u_mag * v_mag)
|
|
156
|
+
}
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
**Generated gradient:**
|
|
160
|
+
```typescript
|
|
161
|
+
function normalized_dot_grad(u, v) {
|
|
162
|
+
const dot = u.x * v.x + u.y * v.y;
|
|
163
|
+
const u_mag = Math.sqrt(u.x * u.x + u.y * u.y);
|
|
164
|
+
const v_mag = Math.sqrt(v.x * v.x + v.y * v.y);
|
|
165
|
+
const value = dot / (u_mag * v_mag);
|
|
166
|
+
|
|
167
|
+
const _tmp0 = u_mag * v_mag;
|
|
168
|
+
const _tmp1 = 2 * u_mag;
|
|
169
|
+
const _tmp2 = 2 * v_mag;
|
|
170
|
+
const _tmp3 = _tmp0 * _tmp0;
|
|
171
|
+
|
|
172
|
+
const du = {
|
|
173
|
+
x: (v.x * _tmp0 - dot * (2 * u.x) / _tmp1 * v_mag) / _tmp3,
|
|
174
|
+
y: (v.y * _tmp0 - dot * (2 * u.y) / _tmp1 * v_mag) / _tmp3,
|
|
175
|
+
};
|
|
176
|
+
const dv = {
|
|
177
|
+
x: (u.x * _tmp0 - dot * u_mag * (2 * v.x) / _tmp2) / _tmp3,
|
|
178
|
+
y: (u.y * _tmp0 - dot * u_mag * (2 * v.y) / _tmp2) / _tmp3,
|
|
179
|
+
};
|
|
180
|
+
|
|
181
|
+
return { value, du, dv };
|
|
182
|
+
}
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
### From JavaScript Robotics
|
|
186
|
+
|
|
187
|
+
**Original JavaScript angle between vectors:**
|
|
188
|
+
```javascript
|
|
189
|
+
function angleBetween(u, v) {
|
|
190
|
+
const cross = u.x * v.y - u.y * v.x;
|
|
191
|
+
const dot = u.x * v.x + u.y * v.y;
|
|
192
|
+
return Math.atan2(cross, dot);
|
|
193
|
+
}
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
**GradientScript version:**
|
|
197
|
+
```typescript
|
|
198
|
+
function angle_between(u∇: {x, y}, v∇: {x, y}) {
|
|
199
|
+
cross = u.x * v.y - u.y * v.x
|
|
200
|
+
dot = u.x * v.x + u.y * v.y
|
|
201
|
+
return atan2(cross, dot)
|
|
202
|
+
}
|
|
203
|
+
```
|
|
204
|
+
|
|
205
|
+
**Generated gradient:**
|
|
206
|
+
```typescript
|
|
207
|
+
function angle_between_grad(u, v) {
|
|
208
|
+
const cross = u.x * v.y - u.y * v.x;
|
|
209
|
+
const dot = u.x * v.x + u.y * v.y;
|
|
210
|
+
const value = Math.atan2(cross, dot);
|
|
211
|
+
|
|
212
|
+
const _tmp0 = dot * dot + cross * cross;
|
|
213
|
+
|
|
214
|
+
const du = {
|
|
215
|
+
x: (dot * v.y - cross * v.x) / _tmp0,
|
|
216
|
+
y: (dot * (-v.x) - cross * v.y) / _tmp0,
|
|
217
|
+
};
|
|
218
|
+
const dv = {
|
|
219
|
+
x: (dot * (-u.y) - cross * u.x) / _tmp0,
|
|
220
|
+
y: (dot * u.x - cross * u.y) / _tmp0,
|
|
221
|
+
};
|
|
222
|
+
|
|
223
|
+
return { value, du, dv };
|
|
224
|
+
}
|
|
225
|
+
```
|
|
226
|
+
|
|
227
|
+
## Command Line Options
|
|
228
|
+
|
|
229
|
+
```bash
|
|
230
|
+
gradient-script <file.gs> [options]
|
|
231
|
+
|
|
232
|
+
Options:
|
|
233
|
+
--format <format> typescript (default), javascript, python
|
|
234
|
+
--no-simplify Disable gradient simplification
|
|
235
|
+
--no-cse Disable common subexpression elimination
|
|
236
|
+
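--no-comments Omit comments in generated code
--guards Emit runtime guards for division by zero (experimental)
--epsilon <value> Epsilon value for guards (default: 1e-10)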
|
|
237
|
+
--help, -h Show help message
|
|
238
|
+
```
|
|
239
|
+
|
|
240
|
+
**Examples:**
|
|
241
|
+
```bash
|
|
242
|
+
# Generate TypeScript (default)
|
|
243
|
+
gradient-script spring.gs
|
|
244
|
+
|
|
245
|
+
# Generate Python
|
|
246
|
+
gradient-script spring.gs --format python
|
|
247
|
+
|
|
248
|
+
# Generate JavaScript without CSE optimization
|
|
249
|
+
gradient-script spring.gs --format javascript --no-cse
|
|
250
|
+
```
|
|
251
|
+
|
|
252
|
+
## Language Syntax
|
|
253
|
+
|
|
254
|
+
### Function Declaration
|
|
255
|
+
|
|
256
|
+
```typescript
|
|
257
|
+
function name(param1∇: {x, y}, param2∇, param3) {
|
|
258
|
+
local1 = expression
|
|
259
|
+
local2 = expression
|
|
260
|
+
return expression
|
|
261
|
+
}
|
|
262
|
+
```
|
|
263
|
+
|
|
264
|
+
- The `∇` symbol marks parameters that need gradients
|
|
265
|
+
- Type annotations `{x, y}` specify structured types
|
|
266
|
+
- Parameters without `∇` are treated as constants
|
|
267
|
+
- Use `=` for assignments, not `const` or `let`
|
|
268
|
+
|
|
269
|
+
### Structured Types
|
|
270
|
+
|
|
271
|
+
```typescript
|
|
272
|
+
// 2D vectors
|
|
273
|
+
u∇: {x, y}
|
|
274
|
+
|
|
275
|
+
// 3D vectors
|
|
276
|
+
v∇: {x, y, z}
|
|
277
|
+
|
|
278
|
+
// Scalars (no annotation)
|
|
279
|
+
param∇
|
|
280
|
+
```
|
|
281
|
+
|
|
282
|
+
### Built-in Functions
|
|
283
|
+
|
|
284
|
+
**Vector operations:**
|
|
285
|
+
- `dot2d(u, v)` - dot product (expands to `u.x*v.x + u.y*v.y`)
|
|
286
|
+
- `cross2d(u, v)` - 2D cross product (expands to `u.x*v.y - u.y*v.x`)
|
|
287
|
+
- `magnitude2d(v)` - vector length (expands to `sqrt(v.x*v.x + v.y*v.y)`)
|
|
288
|
+
- `normalize2d(v)` - unit vector
|
|
289
|
+
|
|
290
|
+
**Math functions:**
|
|
291
|
+
- `sqrt(x)`, `sin(x)`, `cos(x)`, `tan(x)`
|
|
292
|
+
- `asin(x)`, `acos(x)`, `atan(x)`
|
|
293
|
+
- `atan2(y, x)` - two-argument arctangent
|
|
294
|
+
- `exp(x)`, `log(x)`, `abs(x)`
|
|
295
|
+
|
|
296
|
+
**Non-smooth functions (with subgradients):**
|
|
297
|
+
- `min(a, b)` - minimum of two values
|
|
298
|
+
- `max(a, b)` - maximum of two values
|
|
299
|
+
- `clamp(x, lo, hi)` - clamp x to range [lo, hi]
|
|
300
|
+
|
|
301
|
+
**Operators:**
|
|
302
|
+
- Arithmetic: `+`, `-`, `*`, `/`
|
|
303
|
+
- Power: `x^2` (converts to `x * x` for better performance)
|
|
304
|
+
- Negation: `-x`
|
|
305
|
+
|
|
306
|
+
### Output Formats
|
|
307
|
+
|
|
308
|
+
**TypeScript (default):**
|
|
309
|
+
```typescript
|
|
310
|
+
const du = { x: expr1, y: expr2 };
|
|
311
|
+
```
|
|
312
|
+
|
|
313
|
+
**JavaScript:**
|
|
314
|
+
```javascript
|
|
315
|
+
const du = { x: expr1, y: expr2 };
|
|
316
|
+
```
|
|
317
|
+
|
|
318
|
+
**Python:**
|
|
319
|
+
```python
|
|
320
|
+
du = { "x": expr1, "y": expr2 }
|
|
321
|
+
```
|
|
322
|
+
|
|
323
|
+
## How It Works
|
|
324
|
+
|
|
325
|
+
GradientScript uses **symbolic differentiation** with the chain rule:
|
|
326
|
+
|
|
327
|
+
1. **Parse** your function into an expression tree
|
|
328
|
+
2. **Type inference** determines scalar vs structured gradients
|
|
329
|
+
3. **Symbolic differentiation** applies calculus rules (product rule, chain rule, etc.)
|
|
330
|
+
4. **Simplification** reduces complex expressions
|
|
331
|
+
5. **CSE optimization** eliminates redundant subexpressions
|
|
332
|
+
6. **Code generation** emits clean TypeScript/JavaScript/Python
|
|
333
|
+
|
|
334
|
+
### Common Subexpression Elimination (CSE)
|
|
335
|
+
|
|
336
|
+
GradientScript automatically factors out repeated expressions:
|
|
337
|
+
|
|
338
|
+
**Before CSE:**
|
|
339
|
+
```typescript
|
|
340
|
+
const du_x = v.x / sqrt(u.x*u.x + u.y*u.y) - dot * u.x / (2 * sqrt(u.x*u.x + u.y*u.y));
|
|
341
|
+
const du_y = v.y / sqrt(u.x*u.x + u.y*u.y) - dot * u.y / (2 * sqrt(u.x*u.x + u.y*u.y));
|
|
342
|
+
```
|
|
343
|
+
|
|
344
|
+
**After CSE:**
|
|
345
|
+
```typescript
|
|
346
|
+
const _tmp0 = Math.sqrt(u.x * u.x + u.y * u.y);
|
|
347
|
+
const _tmp1 = 2 * _tmp0;
|
|
348
|
+
const du = {
|
|
349
|
+
x: v.x / _tmp0 - dot * u.x / _tmp1,
|
|
350
|
+
y: v.y / _tmp0 - dot * u.y / _tmp1,
|
|
351
|
+
};
|
|
352
|
+
```
|
|
353
|
+
|
|
354
|
+
This improves both performance and readability.
|
|
355
|
+
|
|
356
|
+
### Non-Smooth Functions & Subgradients
|
|
357
|
+
|
|
358
|
+
GradientScript supports **non-smooth functions** (`min`, `max`, `clamp`) using **subgradient** differentiation. These are essential for constrained optimization, robust losses, and geometric queries.
|
|
359
|
+
|
|
360
|
+
**Example: Point-to-Segment Distance**
|
|
361
|
+
```typescript
|
|
362
|
+
function distance_point_segment(p∇: {x, y}, a: {x, y}, b: {x, y}) {
|
|
363
|
+
vx = b.x - a.x
|
|
364
|
+
vy = b.y - a.y
|
|
365
|
+
wx = p.x - a.x
|
|
366
|
+
wy = p.y - a.y
|
|
367
|
+
t = (wx * vx + wy * vy) / (vx * vx + vy * vy)
|
|
368
|
+
t_clamped = clamp(t, 0, 1) // Project onto segment
|
|
369
|
+
qx = a.x + t_clamped * vx
|
|
370
|
+
qy = a.y + t_clamped * vy
|
|
371
|
+
dx = p.x - qx
|
|
372
|
+
dy = p.y - qy
|
|
373
|
+
return sqrt(dx * dx + dy * dy)
|
|
374
|
+
}
|
|
375
|
+
```
|
|
376
|
+
|
|
377
|
+
Generated code correctly handles the non-smooth boundaries at segment endpoints:
|
|
378
|
+
```typescript
|
|
379
|
+
const t_clamped = Math.max(0, Math.min(1, t)); // clamp expansion
|
|
380
|
+
```
|
|
381
|
+
|
|
382
|
+
**How subgradients work:**
|
|
383
|
+
- At smooth points: standard gradient
|
|
384
|
+
- At non-smooth points (e.g., `min(a,b)` when `a=b`): any valid subgradient
|
|
385
|
+
- Subgradient methods converge for convex objectives (given a suitably diminishing step size)
|
|
386
|
+
- Common in L1 regularization, SVM, robust losses
|
|
387
|
+
|
|
388
|
+
**Use cases:**
|
|
389
|
+
- Constrained optimization (clamp parameters to valid ranges)
|
|
390
|
+
- Robust losses (Huber-like functions with min/max)
|
|
391
|
+
- Geometric queries (distance to segments, boxes, polytopes)
|
|
392
|
+
- Activation functions (ReLU = `max(0, x)`)
|
|
393
|
+
|
|
394
|
+
## Use Cases
|
|
395
|
+
|
|
396
|
+
- **Physics simulations** - Get force gradients for constraint solvers
|
|
397
|
+
- **Robotics** - Compute Jacobians for inverse kinematics
|
|
398
|
+
- **Machine learning** - Custom loss functions with analytical gradients
|
|
399
|
+
- **Computer graphics** - Optimize shader parameters
|
|
400
|
+
- **Game engines** - Procedural animation with gradient-based optimization
|
|
401
|
+
- **Scientific computing** - Sensitivity analysis and optimization
|
|
402
|
+
|
|
403
|
+
## Edge Case Detection
|
|
404
|
+
|
|
405
|
+
GradientScript analyzes your functions and warns about potential issues:
|
|
406
|
+
|
|
407
|
+
```
|
|
408
|
+
⚠️ EDGE CASE WARNINGS:
|
|
409
|
+
|
|
410
|
+
• Division by zero (1 occurrence)
|
|
411
|
+
Division by zero if denominator becomes zero
|
|
412
|
+
💡 Add check: if (denominator === 0) return { value: 0, gradients: {...} };
|
|
413
|
+
|
|
414
|
+
• Square root of negative (2 occurrences)
|
|
415
|
+
magnitude of vector (uses sqrt internally)
|
|
416
|
+
💡 Ensure vector components are valid
|
|
417
|
+
```
|
|
418
|
+
|
|
419
|
+
You can then add appropriate guards in your code that uses the generated functions.
|
|
420
|
+
|
|
421
|
+
## Architecture
|
|
422
|
+
|
|
423
|
+
GradientScript uses a **source-to-source compilation** approach with the following pipeline:
|
|
424
|
+
|
|
425
|
+
```
|
|
426
|
+
Input (.gs file)
|
|
427
|
+
↓
|
|
428
|
+
Lexer & Parser → AST
|
|
429
|
+
↓
|
|
430
|
+
Type Inference → Scalar vs Structured types
|
|
431
|
+
↓
|
|
432
|
+
Built-in Expansion → dot2d(), magnitude2d(), etc.
|
|
433
|
+
↓
|
|
434
|
+
Symbolic Differentiation → Product rule, chain rule, quotient rule
|
|
435
|
+
↓
|
|
436
|
+
Algebraic Simplification → 0.5*(a+a) → a, etc.
|
|
437
|
+
↓
|
|
438
|
+
CSE Optimization → Extract common subexpressions
|
|
439
|
+
↓
|
|
440
|
+
Code Generation → TypeScript/JavaScript/Python
|
|
441
|
+
↓
|
|
442
|
+
Output (gradient functions)
|
|
443
|
+
```
|
|
444
|
+
|
|
445
|
+
The test suite verifies generated gradients against numerical differentiation (finite differences); the CLI itself does not run this check.
|
|
446
|
+
|
|
447
|
+
## Testing & Correctness
|
|
448
|
+
|
|
449
|
+
**Every gradient is automatically verified against numerical differentiation.**
|
|
450
|
+
|
|
451
|
+
GradientScript includes a comprehensive test suite that validates all generated gradients using finite differences. This means you can trust that the symbolic derivatives are mathematically correct.
|
|
452
|
+
|
|
453
|
+
```bash
|
|
454
|
+
npm test
|
|
455
|
+
```
|
|
456
|
+
|
|
457
|
+
Current status: **78 tests passing**
|
|
458
|
+
|
|
459
|
+
Test suite includes:
|
|
460
|
+
|
|
461
|
+
### Gradient Verification Tests
|
|
462
|
+
- **Numerical gradient checking**: All symbolic gradients compared against finite differences
|
|
463
|
+
- Basic scalar differentiation (power, product, chain rules)
|
|
464
|
+
- Structured type gradients (2D/3D vectors)
|
|
465
|
+
- Built-in function derivatives (sin, cos, atan2, sqrt, etc.)
|
|
466
|
+
- Complex compositions and chain rule applications
|
|
467
|
+
|
|
468
|
+
### Property-Based Tests
|
|
469
|
+
- **Singularity handling**: Near-zero denominators, parallel vectors, origin points
|
|
470
|
+
- **Rotation invariance**: Rotating inputs rotates gradients consistently
|
|
471
|
+
- **Scale invariance**: Functions like cosine similarity maintain invariance properties
|
|
472
|
+
- **Symmetry**: Distance function has symmetric gradients
|
|
473
|
+
- **Translation invariance**: Relative functions have zero gradient sum
|
|
474
|
+
- **SE(2) transformations**: Zero gradients at exact match, proper gradient direction
|
|
475
|
+
- **Reprojection invariants**: Uniform scaling maintains structure
|
|
476
|
+
- **Bearing properties**: Rotation shifts angle, gradient perpendicular to input
|
|
477
|
+
|
|
478
|
+
### Code Generation Tests
|
|
479
|
+
- CSE optimization correctness
|
|
480
|
+
- Operator precedence preservation
|
|
481
|
+
- Power optimization (x*x vs Math.pow)
|
|
482
|
+
- Multiple output formats (TypeScript, JavaScript, Python)
|
|
483
|
+
- Algebraic simplification correctness
|
|
484
|
+
|
|
485
|
+
**Key guarantee**: If a test passes, the generated gradient is correct to within numerical precision (~10 decimal places).
|
|
486
|
+
|
|
487
|
+
## Comparison with Other Tools
|
|
488
|
+
|
|
489
|
+
| Feature | GradientScript | JAX/PyTorch | SymPy | Manual Math |
|
|
490
|
+
|---------|----------------|-------------|-------|-------------|
|
|
491
|
+
| **Output** | Clean source code | Tape/Graph | Symbolic expr | Pen & paper |
|
|
492
|
+
| **Runtime** | Zero overhead | Tape overhead | Symbolic eval | Zero |
|
|
493
|
+
| **Readability** | High | Low | Medium | High |
|
|
494
|
+
| **Structured types** | Native | Tensors only | Limited | Natural |
|
|
495
|
+
| **Integration** | Copy/paste code | Framework required | Eval strings | Type by hand |
|
|
496
|
+
| **Speed** | Native JS/TS/Py | JIT optimized | Slow | Native |
|
|
497
|
+
| **Debugging** | Standard debugger | Special tools | Hard | Standard |
|
|
498
|
+
|
|
499
|
+
## Contributing
|
|
500
|
+
|
|
501
|
+
GradientScript is under active development. Contributions welcome!
|
|
502
|
+
|
|
503
|
+
**Roadmap:**
|
|
504
|
+
- More property-based tests for mathematical invariants
|
|
505
|
+
- Additional output formats (C, Rust, GLSL)
|
|
506
|
+
- Web playground for live gradient generation
|
|
507
|
+
- Benchmarking suite
|
|
508
|
+
|
|
509
|
+
## License
|
|
510
|
+
|
|
511
|
+
MIT
|
|
512
|
+
|
|
513
|
+
## Credits
|
|
514
|
+
|
|
515
|
+
Inspired by symbolic differentiation in SymPy, the ergonomics of JAX, and the practicality of writing math code by hand.
|
package/dist/cli.d.ts
ADDED
package/dist/cli.js
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
import { readFileSync } from 'fs';
|
|
3
|
+
import { parse } from './dsl/Parser.js';
|
|
4
|
+
import { inferFunction } from './dsl/TypeInference.js';
|
|
5
|
+
import { computeFunctionGradients } from './dsl/Differentiation.js';
|
|
6
|
+
import { generateComplete } from './dsl/CodeGen.js';
|
|
7
|
+
import { analyzeGuards, formatGuardWarnings } from './dsl/Guards.js';
|
|
8
|
+
/**
 * Print the CLI help text.
 *
 * @param {(text: string) => void} [write=console.log] Sink that receives the
 *   whole help text as one string. Defaults to stdout. Error paths should pass
 *   `console.error` so the help text cannot end up mixed into generated code
 *   that the user redirected from stdout.
 */
function printUsage(write = console.log) {
    // The template starts with a newline so the text is readable here; trim()
    // removes that newline and the trailing one before printing.
    const usageText = `
GradientScript - Symbolic Differentiation for Structured Types

Usage:
gradient-script <file.gs> [options]

Options:
--format <format> Output format: typescript (default), javascript, python
--no-simplify Disable gradient simplification
--no-cse Disable common subexpression elimination
--no-comments Omit comments in generated code
--guards Emit runtime guards for division by zero (experimental)
--epsilon <value> Epsilon value for guards (default: 1e-10)
--help, -h Show this help message

Examples:
gradient-script angle.gs
gradient-script angle.gs --format python
gradient-script angle.gs --format javascript --no-comments

Input File Format (.gs):
function name(param1∇: {x, y}, param2∇) {
// intermediate calculations
local = expression
return expression
}

The ∇ symbol marks parameters that need gradients computed.
Type annotations like {x, y} specify structured types.
`.trim();
    write(usageText);
}
|
|
40
|
+
/**
 * Parse the option flags that follow the input file on the command line.
 *
 * Pure helper: it never prints or exits, so the caller decides how to report
 * problems (and it can be exercised in isolation).
 *
 * @param {string[]} optionArgs The CLI arguments after the input file.
 * @returns {{ options: object } | { error: string, showUsage?: boolean }}
 *   `{ options }` on success. On failure, `error` is the message to print and
 *   `showUsage` is true when the usage text should follow it.
 */
function parseOptions(optionArgs) {
    const options = {
        format: 'typescript',
        includeComments: true,
        simplify: true,
        cse: true
    };
    for (let i = 0; i < optionArgs.length; i++) {
        const arg = optionArgs[i];
        if (arg === '--format') {
            // Value-taking flags consume the next argument.
            const format = optionArgs[++i];
            if (format === undefined) {
                return { error: 'Error: --format requires a value: typescript, javascript, or python' };
            }
            if (format !== 'typescript' && format !== 'javascript' && format !== 'python') {
                return { error: `Error: Invalid format "${format}". Must be: typescript, javascript, or python` };
            }
            options.format = format;
        }
        else if (arg === '--no-simplify') {
            options.simplify = false;
        }
        else if (arg === '--no-cse') {
            options.cse = false;
        }
        else if (arg === '--no-comments') {
            options.includeComments = false;
        }
        else if (arg === '--guards') {
            options.emitGuards = true;
        }
        else if (arg === '--epsilon') {
            // Number() rejects trailing garbage such as "1e-3x" that parseFloat()
            // would silently truncate, and a missing value becomes NaN;
            // Number.isFinite() additionally rejects NaN and Infinity.
            const epsilonValue = Number(optionArgs[++i]);
            if (!Number.isFinite(epsilonValue) || epsilonValue <= 0) {
                return { error: 'Error: Invalid epsilon value. Must be a positive number.' };
            }
            options.epsilon = epsilonValue;
        }
        else {
            return { error: `Error: Unknown option "${arg}"`, showUsage: true };
        }
    }
    return { options };
}

/**
 * CLI entry point: compile one `.gs` file and print the generated code.
 *
 * Reads `process.argv`. The first argument must be the `.gs` input file, and
 * the remaining arguments are option flags. Only the first function in the
 * file is processed. Edge-case warnings and all errors go to stderr, and the
 * generated code goes to stdout. Exits with status 0 after printing help and
 * with status 1 on any error.
 */
function main() {
    const args = process.argv.slice(2);
    if (args.length === 0 || args.includes('--help') || args.includes('-h')) {
        printUsage();
        process.exit(0);
    }
    const inputFile = args[0];
    if (!inputFile.endsWith('.gs')) {
        console.error('Error: Input file must have .gs extension');
        process.exit(1);
    }
    const parsed = parseOptions(args.slice(1));
    if (parsed.error !== undefined) {
        console.error(parsed.error);
        if (parsed.showUsage) {
            // Usage goes to stderr on this path so it never pollutes generated
            // code that the user redirected from stdout.
            printUsage(console.error);
        }
        process.exit(1);
    }
    const { options } = parsed;
    let input;
    try {
        input = readFileSync(inputFile, 'utf-8');
    }
    catch (err) {
        console.error(`Error: Could not read file "${inputFile}"`);
        if (err instanceof Error) {
            console.error(err.message);
        }
        process.exit(1);
    }
    try {
        const program = parse(input);
        if (program.functions.length === 0) {
            console.error('Error: No functions found in input file');
            process.exit(1);
        }
        const func = program.functions[0];
        if (program.functions.length > 1) {
            console.warn(`Warning: Multiple functions found, processing only "${func.name}"`);
        }
        const env = inferFunction(func);
        const gradients = computeFunctionGradients(func, env);
        // Analyze for edge cases; warnings go to stderr so stdout stays pure code.
        const guardAnalysis = analyzeGuards(func);
        if (guardAnalysis.hasIssues) {
            console.error(formatGuardWarnings(guardAnalysis));
        }
        const code = generateComplete(func, gradients, env, options);
        console.log(code);
    }
    catch (err) {
        console.error('Error: Failed to process input file');
        if (err instanceof Error) {
            console.error(err.message);
            if (err.stack) {
                console.error('\nStack trace:');
                console.error(err.stack);
            }
        }
        process.exit(1);
    }
}
|
|
136
|
+
// Run the CLI immediately when this module is executed (the file starts with a
// node shebang, so it is meant to be invoked directly as a command).
main();
|