tensorgrad 0.0.16 → 0.0.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/README.md +253 -119
  2. package/package.json +1 -1
package/README.md CHANGED
@@ -1,119 +1,253 @@
- # tensorgrad
-
- A tiny TypeScript-native tensor library with autograd that compiles directly
- to WebGPU. Designed for training small models in the browser without
- hand-writing WGSL kernels and without dragging in a 5 MB ML framework.
-
- ```sh
- npm i tensorgrad
- ```
-
- Roughly **3000 lines of zero-dependency TypeScript**, ~10 KB gzipped after
- build. Targets WebGPU only. Static shapes only. Forward + reverse-mode
- autograd; Adam optimizer; the whole training pipeline runs as compiled WGSL.
-
- ## Quick example
-
- A 2-layer MLP fitting `y = sin(x)`:
-
- ```ts
- import {
-   Module, compileModule,
-   add, mul, sub, sumLast, reshape, matmul, relu,
-   type Tensor,
- } from 'tensorgrad'
-
- class Linear extends Module {
-   W: Tensor; b: Tensor
-   constructor(public inDim: number, public outDim: number) {
-     super()
-     this.W = this.param([inDim, outDim]) // randn, scale 0.02
-     this.b = this.param([outDim], { init: 'zeros' })
-   }
- }
-
- class MLP extends Module {
-   l1 = new Linear(1, 64)
-   l2 = new Linear(64, 64)
-   l3 = new Linear(64, 1)
- }
-
- const linear = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
-
- function forward(m: MLP, x: Tensor): Tensor {
-   return linear(m.l3, relu(linear(m.l2, relu(linear(m.l1, x)))))
- }
-
- function loss(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
-   const diff = sub(forward(m, x), y)
-   return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
- }
-
- const B = 256
- const compiled = await compileModule(() => new MLP(), loss, {
-   adam: { lr: 0.005 },
-   inputs: {
-     x: { shape: [B, 1], dtype: 'f32' },
-     y: { shape: [B, 1], dtype: 'f32' },
-   },
- })
-
- // Initial params are uploaded automatically — no manual step needed.
-
- for (let step = 0; step < 1000; step++) {
-   const { x, y } = generateBatch()
-   const lossVal = await compiled.step({ x, y })
-   if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
- }
- ```
-
- That's the whole user-facing surface for this model: `Module` for parameter
- storage, plain functions for the forward pass, `compileModule` to JIT-compile
- to WGSL with autograd + Adam wired in. No decorators, no `tf.GradientTape`,
- no `register_pytree_node`.
-
- For a more involved example — a 3-layer transformer trained from scratch on
- 2-digit addition — see the [`samples/`](./samples) workspace
- (`pnpm --filter samples dev`).
-
- ## What this library is for
-
- Small browser-side ML where you want to *train* the model, not just run
- inference of a pretrained model. Educational artifacts, interactive
- demos, on-device personalization, "transformer from scratch in your browser"
- blog posts. Roughly the niche where the model is small enough to fit
- comfortably in a browser tab but where you still want autograd and a real
- optimizer.
-
- If you want to ship inference of a pretrained model, use
- [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) or
- [transformers.js](https://github.com/xenova/transformers.js).
- If you need full JAX (vmap / pmap / dynamic shapes / multi-backend), use
- [jax-js](https://github.com/jax-js/jax).
-
- ## Scope (deliberately small)
-
- The library only does what it does because of what it doesn't do. The
- load-bearing "out of scope" decisions are:
-
- - **WebGPU only** — no Wasm or WebGL fallback.
- - **Static shapes only** — every shape is fixed at compile time. This is
-   what lets us bake constants into the WGSL instead of carrying shape
-   uniforms.
- - **`grad` is the only transformation** — no `vmap`, `pmap`, `jvp`,
-   `custom_vjp`. Batch your data explicitly.
- - **`f32` only** — no dtype promotion, no mixed precision.
- - **Closed op set** — about 25 ops, listed in `SPEC.md`. Compositions of
-   those handle most needs (GELU, RMS norm, etc. are a few lines on top).
- - **Adam lives in the IR** — bias correction included; no CPU↔GPU
-   round-trip per step.
-
- ## Status
-
- Alpha. Two real working models (a transformer training to <0.1 loss on
- addition, an MLP fitting `sin`). API may change before 1.0. Filing issues
- welcome.
-
- ## License
-
- MIT
+ # tensorgrad
+
+ A tiny TypeScript-native tensor library with autograd that compiles to WebGPU.
+ For training small models in the browser without hand-writing WGSL kernels and
+ without dragging in a multi-megabyte ML framework.
+
+ ```sh
+ npm i tensorgrad
+ ```
+
+ Roughly 3000 lines of zero-dependency TypeScript. Static shapes, `f32`, Adam
+ optimizer, ~25 ops, forward + reverse-mode autograd. Browser-only (uses
+ WebGPU). All training/inference work runs in a library-internal Web Worker;
+ every method on a compiled module returns a `Promise`.
+
+ ## Minimal example
+
+ A 2-layer MLP fitting `y = sin(x)`:
+
+ ```ts
+ import {
+   Module, compileModule, init,
+   add, mul, sub, sumLast, reshape, matmul, relu,
+   type Tensor,
+ } from 'tensorgrad'
+
+ const B = 256
+
+ class Linear extends Module {
+   W: Tensor; b: Tensor
+   constructor(public inDim: number, public outDim: number) {
+     super()
+     this.W = this.param([inDim, outDim], { init: init.kaiming() })
+     this.b = this.param([outDim], { init: 'zeros' })
+   }
+ }
+
+ class MLP extends Module {
+   l1 = new Linear(1, 64)
+   l2 = new Linear(64, 64)
+   l3 = new Linear(64, 1)
+ }
+
+ const linearFwd = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
+
+ function modelFwd(m: MLP, x: Tensor): Tensor {
+   return linearFwd(m.l3, relu(linearFwd(m.l2, relu(linearFwd(m.l1, x)))))
+ }
+
+ function lossFn(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
+   const diff = sub(modelFwd(m, x), y)
+   return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
+ }
+
+ const compiled = await compileModule(() => new MLP(), lossFn, {
+   adam: { lr: 0.005 },
+   inputs: {
+     x: { shape: [B, 1], dtype: 'f32' },
+     y: { shape: [B, 1], dtype: 'f32' },
+   },
+ })
+
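+ // `generateBatch` is user code, not part of tensorgrad. A minimal sketch
+ // (assuming `step` accepts typed arrays): x uniform in [-π, π], y = sin(x).
+ function generateBatch() {
+   const x = new Float32Array(B)
+   const y = new Float32Array(B)
+   for (let i = 0; i < B; i++) {
+     x[i] = (Math.random() * 2 - 1) * Math.PI
+     y[i] = Math.sin(x[i])
+   }
+   return { x, y }
+ }
+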
+ for (let step = 0; step < 1000; step++) {
+   const { x, y } = generateBatch()
+   const lossVal = await compiled.step({ x, y })
+   if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
+ }
+ ```
+
+ ## Mental model
+
+ - A `Module` subclass declares parameters via `this.param([shape], opts)` and
+   composes child modules as plain fields. The class is a tree of params.
+ - A *forward function* takes the materialized module + a record of named
+   input tensors and returns a tensor (the loss for `compileModule`, or any
+   output for `compileForward`).
+ - `compileModule(factory, forward, opts)` traces the forward, derives
+   gradients, wires Adam, plans buffers, generates WGSL, spawns a worker, and
+   returns a `CompiledModule`. The factory `() => new Model()` is invoked
+   once during compile; the model instance is consumed (its param sentinels
+   are mutated into `Tensor`s).
+ - Every method on the compiled module is async. `await compiled.step(...)`
+   resolves with the loss after the worker's GPU work finishes.
+
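+ A sketch of these conventions, reusing `Linear`/`linearFwd` from the minimal
+ example (user code, not built-ins): children are plain fields, and the
+ forward pass is a free function over the materialized tree.
+
+ ```ts
+ class Block extends Module {
+   fc1 = new Linear(64, 64)  // child module: its params nest under this field
+   fc2 = new Linear(64, 64)
+ }
+
+ const blockFwd = (p: Block, x: Tensor) =>
+   linearFwd(p.fc2, relu(linearFwd(p.fc1, x)))
+ ```
+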
+ ## Public API
+
+ ### Compile entry points
+
+ ```ts
+ compileModule(factory, forward, { adam?, inputs? }): Promise<CompiledModule>
+ compileForward(factory, forward, { inputs? }): Promise<CompiledForwardModule>
+ ```
+
+ `compileForward` produces a forward-only graph in its own worker. To share
+ params with an existing training graph, use the sibling method:
+
+ ```ts
+ const train = await compileModule(() => new Model(), lossFn, { ... })
+ const infer = await train.compileForward(predictFn, {
+   inputs: { tokens: { shape: [1, T], dtype: 'i32' } },
+ })
+ // infer runs in train's worker — every step's param updates are visible.
+ ```
+
+ ### CompiledModule methods (all `Promise`-returning)
+
+ ```ts
+ compiled.step(inputs)                         // → loss: number
+ compiled.step(inputs, { withCaptures: true }) // → { loss, captures }
+ compiled.run(inputs)                          // → Float32Array
+ compiled.run(inputs, { withCaptures: true })  // → { output, captures }
+ compiled.uploadParams(record, { partial? })
+ compiled.downloadParams()                     // → Record<name, Float32Array>
+ compiled.downloadParamGrads()                 // → Record<name, Float32Array>
+ compiled.reset()                              // re-init params + zero Adam state
+ compiled.resetOptimizerState()
+ compiled.compileForward(forward, { inputs? }) // sibling forward graph
+ compiled.destroy()                            // tear down worker + GPU
+ ```
+
+ `compiled.kernelCount`, `compiled.outputShape`, `compiled.paramNames`, and
+ `compiled.ir` are sync properties for inspection.
+
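+ For example, params round-trip as plain `Float32Array`s, which makes
+ checkpointing straightforward (a sketch; the serialization choice is
+ illustrative, not part of the library):
+
+ ```ts
+ // Snapshot params as JSON-friendly arrays...
+ const snap = await compiled.downloadParams()
+ const saved = Object.fromEntries(
+   Object.entries(snap).map(([k, v]) => [k, Array.from(v)]))
+
+ // ...and restore them into a compatible compiled module later.
+ await compiled.uploadParams(Object.fromEntries(
+   Object.entries(saved).map(([k, v]) => [k, new Float32Array(v)])))
+ ```
+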
+ ### Operators
+
+ Imported from `'tensorgrad'`:
+
+ - Element-wise: `add`, `sub`, `mul`, `div`, `sqrt`, `rsqrt`, `log`, `exp`, `relu`
+ - Comparisons / select: `less`, `greater`, `where`
+ - Reductions: `meanLast`, `sumLast` (last axis), `sumAll` (all elements)
+ - Shape: `reshape`, `transpose`, `swapAxes`
+ - Linear algebra: `matmul`, `matmulBatched`
+ - Indexing / casting: `oneHot`, `arange`, `embedding`
+ - Slicing: `sliceLastRange`
+ - Fused ML primitives: `softmaxCausalLast`, `logSoftmaxLast`, `whereCausal`
+
+ `add`, `sub`, `mul`, `div` accept `(Tensor, Tensor)` or `(Tensor, number)`.
+
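+ Missing elementwise functions compose from these. A sketch of GELU via the
+ sigmoid approximation (the 1.702 constant is the standard approximation, not
+ a library export):
+
+ ```ts
+ // sigmoid(z) = e^z / (e^z + 1);  gelu(x) ≈ x * sigmoid(1.702 * x)
+ const sigmoid = (z: Tensor) => { const e = exp(z); return div(e, add(e, 1)) }
+ const gelu = (x: Tensor) => mul(x, sigmoid(mul(x, 1.702)))
+ ```
+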
+ ### `nn` namespace
+
+ ```ts
+ import { nn } from 'tensorgrad'
+
+ nn.Linear(inDim, outDim, { bias? })  // .fwd(x)
+ nn.LayerNorm(dim)                    // .fwd(x)
+ nn.splitHeads(x, nHeads)             // [B, T, D] → [B, H, T, D/H]
+ nn.mergeHeads(x)                     // inverse of splitHeads
+ nn.unsplitHeads(captures, name)      // pull per-head slices off a capture
+ nn.crossEntropyLast(logits, targets) // standard CE
+ ```
+
+ Convention: leaf modules (`Linear`, `LayerNorm`) expose `.fwd(x)` for ergonomic
+ chaining. Forward passes for composite modules you write yourself are typically
+ free functions taking `(p: ModuleType, x: Tensor)`.
+
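+ For example (a sketch; assumes `nn` leaf modules compose as fields like any
+ other child module):
+
+ ```ts
+ class Net extends Module {
+   l1 = nn.Linear(16, 32)
+   ln = nn.LayerNorm(32)
+ }
+
+ const netFwd = (m: Net, x: Tensor) => m.ln.fwd(relu(m.l1.fwd(x)))
+ ```
+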
+ ### LR schedules (`lr` namespace)
+
+ ```ts
+ import { lr } from 'tensorgrad'
+
+ adam: { lr: 0.005 } // constant
+ adam: { lr: lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 }) }
+ adam: { lr: lr.cosineDecay({ peak: 0.005, final: 0.0001, steps: 5000 }) }
+ adam: { lr: lr.warmup({ peakLr: 0.001, warmupSteps: 200, after: lr.constant(0.001) }) }
+ ```
+
+ LR schedules are serializable shapes, not closures (they cross the worker
+ boundary). Use a `number` for constant LR, or one of the constructors above.
+
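+ Since schedules are plain data, they nest; e.g. a warmup into a cosine decay
+ (assuming `after` accepts any schedule, as `lr.constant` above suggests):
+
+ ```ts
+ adam: { lr: lr.warmup({
+   peakLr: 0.005, warmupSteps: 200,
+   after: lr.cosineDecay({ peak: 0.005, final: 0.0001, steps: 5000 }),
+ }) }
+ ```
+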
+ ### Param init (`init` namespace)
+
+ ```ts
+ import { init } from 'tensorgrad'
+
+ this.param([D, D], { init: init.kaiming() })  // gain=sqrt(2), fan_in=D
+ this.param([D, D], { init: init.kaiming({ gain: 1 }) })
+ this.param([D], { init: 'zeros' })
+ this.param([D], { init: 'ones' })
+ this.param([D, D], { init: init.randn({ scale: 0.02 }) })
+ this.param([D], { init: init.literal(myFloat32Array) })
+ ```
+
+ Defaults: `'randn'` (std 0.02). AdamW weight decay defaults to `true` for
+ randn/kaiming/literal init, `false` for zeros/ones — override per-param with
+ `{ decay: true | false }`.
+
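+ For example, to keep a randn-initialized vector out of weight decay (a sketch
+ of the per-param override):
+
+ ```ts
+ this.param([D], { init: init.randn({ scale: 0.02 }), decay: false })
+ ```
+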
+ ### Captures (debugging / mech-interp)
+
+ Wrap any tensor inside a forward to expose its activation post-run:
+
+ ```ts
+ import { capture } from 'tensorgrad'
+
+ const attn = capture(`attn.${i}`, softmaxCausalLast(scores))
+ ```
+
+ ```ts
+ const { output, captures } = await compiled.run(inputs, { withCaptures: true })
+ const attn0 = captures.get('attn.0') // Float32Array
+ captures.shapeOf('attn.0')           // readonly number[]
+ ```
+
+ Captures are zero-overhead unless `{ withCaptures: true }` is passed; when it
+ is, they add a single batched `mapAsync` to the readback.
+
+ ## Constraints
+
+ The library is small because of what it doesn't do. Plan accordingly:
+
+ - **WebGPU only.** No Wasm, WebGL, or native fallback.
+ - **Static shapes.** Every shape is fixed at compile time. Changing a batch
+   size means recompiling.
+ - **`f32` only.** No mixed precision. Inputs may be `i32` for indices.
+ - **One transformation: `grad`.** No `vmap`, `pmap`, `jvp`, `custom_vjp`.
+   Batch your data explicitly.
+ - **Loss must be a scalar.** `compileModule`'s forward returns a rank-0 tensor.
+ - **Closures don't cross the worker boundary.** LR schedules and inits are
+   serializable shapes, not functions. Anything per-step you write into a
+   user-defined optimizer (see *Extending* below) follows the same rule.
+ - **One model per `compileModule` call.** Sibling forward graphs share params
+   via the method form; otherwise each compile spawns its own worker.
+
+ ## Extending
+
+ The IR is open. Adam is built in only because it's the most common starting
+ point — other optimizers, custom losses, or extra ops are user code following
+ the same pattern as `appendAdam`:
+
+ ```ts
+ import { appendAdam, appendGrad, compileToIR } from 'tensorgrad'
+ ```
+
+ A custom optimizer is a function that takes the autograd output (graph +
+ `paramGrads`) and the materialized param tensors, appends its update ops
+ to the graph, and returns writeback declarations the buffer planner uses
+ to wire each new value back into its persistent home. SGD, Lion, and RMSProp
+ all fit this shape; see `src/adam.ts` for the canonical example.
+
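+ A sketch of that shape for plain SGD (the type names here are illustrative
+ assumptions, not the real exported signatures; see `src/adam.ts` for those):
+
+ ```ts
+ import { sub, mul, type Tensor } from 'tensorgrad'
+
+ type Writeback = { param: string; value: Tensor } // assumed shape
+
+ // p ← p − lr·g: one writeback per param; the buffer planner wires each
+ // `value` back into that param's persistent buffer.
+ function appendSgd(
+   paramGrads: Record<string, Tensor>,
+   params: Record<string, Tensor>,
+   lr: number,
+ ): Writeback[] {
+   return Object.keys(params).map((name) => ({
+     param: name,
+     value: sub(params[name], mul(paramGrads[name], lr)),
+   }))
+ }
+ ```
+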
+ The same applies to ops: anything missing from the built-in set can be
+ expressed as a composition of existing ops (GELU, RMSNorm, etc. are a few
+ lines), or — if you need a new primitive — added to the IR with a
+ forward + backward + WGSL emit.
+
+ ## When not to use this
+
+ - **Inference of pretrained models** → use ONNX Runtime Web or
+   transformers.js.
+ - **Full JAX surface** (vmap, dynamic shapes, multi-backend) → use jax-js.
+ - **Server-side training** → use PyTorch or JAX.
+
+ ## License
+
+ MIT
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "tensorgrad",
- "version": "0.0.16",
+ "version": "0.0.18",
  "description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
  "license": "MIT",
  "author": "Ben Albahari",