tensorgrad 0.0.1

Files changed (64)
  1. package/LICENSE +21 -0
  2. package/README.md +121 -0
  3. package/SPEC.md +293 -0
  4. package/dist/adam.d.ts +31 -0
  5. package/dist/adam.d.ts.map +1 -0
  6. package/dist/adam.js +66 -0
  7. package/dist/adam.js.map +1 -0
  8. package/dist/buffers.d.ts +56 -0
  9. package/dist/buffers.d.ts.map +1 -0
  10. package/dist/buffers.js +114 -0
  11. package/dist/buffers.js.map +1 -0
  12. package/dist/codegen.d.ts +23 -0
  13. package/dist/codegen.d.ts.map +1 -0
  14. package/dist/codegen.js +709 -0
  15. package/dist/codegen.js.map +1 -0
  16. package/dist/compile.d.ts +53 -0
  17. package/dist/compile.d.ts.map +1 -0
  18. package/dist/compile.js +76 -0
  19. package/dist/compile.js.map +1 -0
  20. package/dist/grad.d.ts +8 -0
  21. package/dist/grad.d.ts.map +1 -0
  22. package/dist/grad.js +404 -0
  23. package/dist/grad.js.map +1 -0
  24. package/dist/index.d.ts +12 -0
  25. package/dist/index.d.ts.map +1 -0
  26. package/dist/index.js +37 -0
  27. package/dist/index.js.map +1 -0
  28. package/dist/ir.d.ts +204 -0
  29. package/dist/ir.d.ts.map +1 -0
  30. package/dist/ir.js +60 -0
  31. package/dist/ir.js.map +1 -0
  32. package/dist/module.d.ts +21 -0
  33. package/dist/module.d.ts.map +1 -0
  34. package/dist/module.js +113 -0
  35. package/dist/module.js.map +1 -0
  36. package/dist/ops.d.ts +35 -0
  37. package/dist/ops.d.ts.map +1 -0
  38. package/dist/ops.js +270 -0
  39. package/dist/ops.js.map +1 -0
  40. package/dist/runtime.d.ts +26 -0
  41. package/dist/runtime.d.ts.map +1 -0
  42. package/dist/runtime.js +190 -0
  43. package/dist/runtime.js.map +1 -0
  44. package/dist/shape.d.ts +24 -0
  45. package/dist/shape.d.ts.map +1 -0
  46. package/dist/shape.js +259 -0
  47. package/dist/shape.js.map +1 -0
  48. package/dist/trace.d.ts +8 -0
  49. package/dist/trace.d.ts.map +1 -0
  50. package/dist/trace.js +93 -0
  51. package/dist/trace.js.map +1 -0
  52. package/package.json +62 -0
  53. package/src/adam.ts +95 -0
  54. package/src/buffers.ts +173 -0
  55. package/src/codegen.ts +758 -0
  56. package/src/compile.ts +120 -0
  57. package/src/grad.ts +459 -0
  58. package/src/index.ts +40 -0
  59. package/src/ir.ts +197 -0
  60. package/src/module.ts +126 -0
  61. package/src/ops.ts +311 -0
  62. package/src/runtime.ts +232 -0
  63. package/src/shape.ts +263 -0
  64. package/src/trace.ts +101 -0
package/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Ben Albahari
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
package/README.md ADDED
@@ -0,0 +1,121 @@
1
+ # tensorgrad
2
+
3
+ A tiny TypeScript-native tensor library with autograd that compiles directly
4
+ to WebGPU. Designed for training small models in the browser — without
5
+ hand-writing WGSL kernels and without dragging in a 5 MB ML framework.
6
+
7
+ ```sh
8
+ npm i tensorgrad
9
+ ```
10
+
11
+ Roughly **3000 lines of zero-dependency TypeScript**, ~10 KB gzipped after
12
+ build. Targets WebGPU only. Static shapes only. Reverse-mode autograd and an
13
+ Adam optimizer; the whole training pipeline runs as compiled WGSL.
14
+
15
+ ## Quick example
16
+
17
+ A 3-layer MLP fitting `y = sin(x)`:
18
+
19
+ ```ts
20
+ import {
+   Module, compileModule,
+   add, mul, sub, sumLast, reshape, matmul, relu,
+   type Tensor,
+ } from 'tensorgrad'
+
+ class Linear extends Module {
+   W: Tensor; b: Tensor
+   constructor(public inDim: number, public outDim: number) {
+     super()
+     this.W = this.param([inDim, outDim])
+     this.b = this.param([outDim])
+   }
+ }
+
+ class MLP extends Module {
+   l1 = new Linear(1, 64)
+   l2 = new Linear(64, 64)
+   l3 = new Linear(64, 1)
+ }
+
+ const linear = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
+
+ function forward(m: MLP, x: Tensor): Tensor {
+   return linear(m.l3, relu(linear(m.l2, relu(linear(m.l1, x)))))
+ }
+
+ const B = 256
+
+ function loss(m: MLP, x: Tensor, y: Tensor): Tensor {
+   const diff = sub(forward(m, x), y)
+   return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
+ }
+
+ const model = new MLP()
+ const compiled = await compileModule(model, loss, {
+   adam: { lr: 0.005 },
+   inputs: [
+     { name: 'x', shape: [B, 1], dtype: 'f32' },
+     { name: 'y', shape: [B, 1], dtype: 'f32' },
+   ],
+ })
+
+ // Initialize `initialParams` however you like (random init, etc.), then upload + train.
+ compiled.uploadParams(initialParams)
+ for (let step = 0; step < 1000; step++) {
+   const { x, y } = generateBatch() // your own data source
+   const lossVal = await compiled.step({ x, y })
+   if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
+ }
69
+ ```
70
+
71
+ That's the whole user-facing surface for this model: `Module` for parameter
72
+ storage, plain functions for the forward pass, `compileModule` to JIT-compile
73
+ to WGSL with autograd + Adam wired in. No decorators, no `tf.GradientTape`,
74
+ no `register_pytree_node`.
75
+
76
+ For a more involved example — a 3-layer transformer trained from scratch on
77
+ 2-digit addition — see the [`samples/`](./samples) workspace
78
+ (`pnpm --filter samples dev`).
79
+
80
+ ## What this library is for
81
+
82
+ Small browser-side ML where you want to *train* the model, not just run
83
+ inference on a pretrained model. Educational artifacts, interactive
84
+ demos, on-device personalization, "transformer from scratch in your browser"
85
+ blog posts. Roughly the niche where the model is small enough to fit
86
+ comfortably in a browser tab but where you still want autograd and a real
87
+ optimizer.
88
+
89
+ If you want to ship inference on a pretrained model, use
90
+ [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) or
91
+ [transformers.js](https://github.com/xenova/transformers.js).
92
+ If you need full JAX (vmap / pmap / dynamic shapes / multi-backend), use
93
+ [jax-js](https://github.com/jax-js/jax).
94
+
95
+ ## Scope (deliberately small)
96
+
97
+ The library only does what it does because of what it doesn't do.
98
+ [`SPEC.md`](./SPEC.md) has the full design notes; the load-bearing
99
+ "out of scope" decisions are:
100
+
101
+ - **WebGPU only** — no Wasm or WebGL fallback.
102
+ - **Static shapes only** — every shape is fixed at compile time. This is
103
+ what lets us bake constants into the WGSL instead of carrying shape
104
+ uniforms.
105
+ - **`grad` is the only transformation** — no `vmap`, `pmap`, `jvp`,
106
+ `custom_vjp`. Batch your data explicitly.
107
+ - **`f32` only** — no dtype promotion, no mixed precision.
108
+ - **Closed op set** — about 25 ops, listed in `SPEC.md`. Compositions of
109
+ those handle most needs (GELU, RMS norm, etc. are a few lines on top; see the sketch below).
110
+ - **Adam lives in the IR** — bias correction included; no CPU↔GPU
111
+ round-trip per step.
112
+
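+ For a flavor of those compositions, here's a sketch (not shipped library
+ code; it assumes `div`, `exp`, `rsqrt`, and a `meanLast` reduction are
+ exported alongside the ops used in the example above):
+
+ ```ts
+ import { add, mul, div, exp, rsqrt, meanLast, reshape, type Tensor } from 'tensorgrad'
+
+ // Sigmoid-approximated GELU: x * sigmoid(1.702x) = x / (1 + e^(-1.702x)).
+ // The scalar operands use the JS-number overloads of add/mul.
+ const gelu = (x: Tensor) => div(x, add(exp(mul(x, -1.702)), 1))
+
+ // RMS norm for a [B, d] activation. Shapes are static, so B is a plain
+ // number; the [B, 1] reshape makes the reduced axis broadcastable again.
+ const rmsNorm = (x: Tensor, B: number, eps = 1e-5) =>
+   mul(x, reshape(rsqrt(add(meanLast(mul(x, x)), eps)), [B, 1]))
+ ```
+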
113
+ ## Status
114
+
115
+ Alpha. Two real working models (a transformer training to <0.1 loss on
116
+ addition, an MLP fitting `sin`). The API may change before 1.0. Issues
117
+ welcome.
118
+
119
+ ## License
120
+
121
+ MIT
package/SPEC.md ADDED
@@ -0,0 +1,293 @@
1
+ # Tensorgrad — Architecture
2
+
3
+ This document covers the design decisions, IR, and internals of tensorgrad.
4
+ For installation and user-facing API, see `README.md`. For pre-1.0
5
+ implementation status and what to pick up next, see `HANDOFF.md`.
6
+
7
+ ## Scope and non-goals (load-bearing)
8
+
9
+ The library only does what it does because of what it doesn't do. Each
10
+ "out of scope" decision is the *precondition* that lets the rest stay small.
11
+
12
+ **In scope:**
13
+ - Static-shape models. Every shape is fixed at compile time.
14
+ - WebGPU only.
15
+ - f32 only.
16
+ - `grad` (reverse-mode autograd) as the only transformation.
17
+ - A closed set of ~25 ops covering transformers + MLPs.
18
+ - Adam optimizer in-IR.
19
+
20
+ **Out of scope (deliberately):**
21
+ - Wasm or WebGL fallback.
22
+ - Dynamic shapes, shape polymorphism.
23
+ - `vmap`, `pmap`, `jvp`, `custom_vjp`, higher-order gradients.
24
+ - Dtype promotion, mixed precision.
25
+ - General PyTree machinery (we use `Module` + property paths instead).
26
+ - Inference of pre-trained models (use ONNX Runtime Web or transformers.js).
27
+ - ONNX import / safetensors / model loaders.
28
+ - Distributed training, gradient accumulation across devices.
29
+
30
+ The non-goals are load-bearing. Adding any of them back without rethinking
31
+ the IR would push complexity into every layer.
32
+
33
+ ## Architecture overview
34
+
35
+ ```
36
+ ┌────────────────────────────────────────────────────────────┐
37
+ │ User code │
38
+ │ class Model extends Module { /* params */ } │
39
+ │ function forward(m: Model, x: Tensor): Tensor { /* ... */ }│
40
+ └────────────────────────────────────────────────────────────┘
41
+
42
+ ▼ trace()
43
+ ┌────────────────────────────────────────────────────────────┐
44
+ │ Forward IR build (src/trace.ts, src/ops.ts, src/shape.ts) │
45
+ │ Each op call appends a node to the Graph and returns a │
46
+ │ fresh Tensor handle. Shapes inferred + validated per op. │
47
+ └────────────────────────────────────────────────────────────┘
48
+
49
+ ▼ appendGrad()
50
+ ┌────────────────────────────────────────────────────────────┐
51
+ │ Reverse-mode autograd (src/grad.ts) │
52
+ │ Topological walk; each forward op contributes its │
53
+ │ transpose rule, building the backward graph in-place. │
54
+ └────────────────────────────────────────────────────────────┘
55
+
56
+ ▼ appendAdam() (optional)
57
+ ┌────────────────────────────────────────────────────────────┐
58
+ │ Optimizer (src/adam.ts) │
59
+ │ Per-param: m_state, v_state state_inputs + fused │
60
+ │ adam_update_{m,v,p} ops. Writebacks declared. │
61
+ └────────────────────────────────────────────────────────────┘
62
+
63
+ ▼ planBuffers()
64
+ ┌────────────────────────────────────────────────────────────┐
65
+ │ Buffer plan (src/buffers.ts) │
66
+ │ One GPU buffer per IR tensor, categorized: │
67
+ │ param / param_grad / state / tensor_input / intermediate. │
68
+ │ Writebacks resolved to (source_buf → dest_buf) pairs. │
69
+ └────────────────────────────────────────────────────────────┘
70
+
71
+ ▼ emitKernels()
72
+ ┌────────────────────────────────────────────────────────────┐
73
+ │ WGSL codegen (src/codegen.ts) │
74
+ │ Per op kind: a kernel template with shapes baked in. │
75
+ │ Returns dispatch-ready KernelSpec[]. │
76
+ └────────────────────────────────────────────────────────────┘
77
+
78
+ ▼ createRuntime()
79
+ ┌────────────────────────────────────────────────────────────┐
80
+ │ Runtime (src/runtime.ts) │
81
+ │ GPUDevice setup, pipeline cache, bind groups, │
82
+ │ step(): upload inputs → dispatch all kernels → writebacks │
83
+ │ → loss readback. Compile errors surface via │
84
+ │ pushErrorScope+getCompilationInfo. │
85
+ └────────────────────────────────────────────────────────────┘
86
+ ```
87
+
88
+ ## Key design decisions
89
+
90
+ **D1. Runtime tracing, not build-time.** The forward function is traced once
91
+ on first compile; the IR is built from those op calls. Build-time tracing via
92
+ a TypeScript transformer plugin would be cleaner but adds a build-step
93
+ requirement. v2 candidate.
94
+
95
+ **D2. Tensors are opaque handles, not Proxies.** Each op returns a fresh
96
+ `Tensor` (just `{ id, shape, dtype, source, site }`). Proxy-based tracing
97
+ gives nicer error UX but couples the IR to runtime introspection.
98
+
99
+ **D3. No reference counting.** Every IR tensor gets its own GPU buffer,
100
+ allocated once and never freed. Our scope (one model, fixed shapes,
101
+ training in a loop) means there's nothing to gain from refcount discipline.
102
+ Memory cost is bounded; buffer pooling is a v2 optimization, not v1
103
+ correctness.
104
+
105
+ **D4. Closed op set.** The IR knows about exactly the ops it supports.
106
+ Adding a new op means adding (a) shape rule, (b) WGSL kernel template,
107
+ (c) autograd transpose rule. This is intentional — a closed op set makes
108
+ each piece tractable to write and verify by hand.
109
+
110
+ **D5. Shapes checked at trace time, not at type level.** Type-level shape
111
+ encoding in TypeScript is possible but hits recursion limits and adds
112
+ significant generic complexity to user code. Runtime shape errors at trace
113
+ time, with call-site capture, are good enough.
114
+
115
+ **D6. Adam state in the IR.** Optimizer state (m, v) lives in dedicated
116
+ `state_input` buffers that persist across `step()` calls; the per-step `lrt`
117
+ scalar (Adam's bias-corrected effective LR) arrives as a fresh `tensor_input`
118
+ each step. End-of-step writebacks copy the new m, v, and param values into
+ their persistent homes; the update math never round-trips through the CPU.
119
+
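+ The per-step scalar is cheap to compute host-side. A minimal sketch of the
+ formula (it matches the doc comment on `AdamResult.lrtInputName` in
+ `adam.d.ts`; the helper name is illustrative):
+
+ ```ts
+ function lrtForStep(step: number, cfg: { lr: number; b1: number; b2: number }): number {
+   const t = step + 1 // bias-correction terms use a 1-indexed step count
+   return cfg.lr * Math.sqrt(1 - cfg.b2 ** t) / (1 - cfg.b1 ** t)
+ }
+ ```
+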
120
+ **D7. Module separates from forward.** Mutable parameter storage lives in
121
+ `Module` subclasses. Forward functions are pure, take the materialized
122
+ model as the first argument, and call ordinary op functions. State and
123
+ computation never mix — JAX's lesson, applied to TypeScript with
124
+ class-based ergonomics.
125
+
126
+ **D8. JS-number scalar overloads.** `add(x, 1e-5)` and `add(x, y)` both
127
+ work. The scalar variants dispatch to fused IR ops internally.
128
+
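+ The overload is the usual TypeScript union-dispatch pattern. A sketch, where
+ `addScalar` and `addTensor` are illustrative internal names rather than the
+ real ones:
+
+ ```ts
+ function add(a: Tensor, b: Tensor | number): Tensor {
+   // A JS number routes to the fused `add_scalar` IR op; a Tensor goes
+   // through the broadcasting element-wise `add` op.
+   return typeof b === 'number' ? addScalar(a, b) : addTensor(a, b)
+ }
+ ```
+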
129
+ ## IR
130
+
131
+ ```ts
132
+ interface Tensor {
133
+   readonly id: number
133
+   readonly shape: Shape
134
+   readonly dtype: Dtype
135
+   readonly source: number | null // op index, or null for leaves
136
+   readonly site: CallSite | null // user's stack at op-call time
138
+ }
139
+
140
+ type OpNode =
141
+   | { kind: 'param_input'; ... } | { kind: 'tensor_input'; ... }
141
+   | { kind: 'state_input'; ... } | { kind: 'arange'; ... }
142
+   | { kind: 'const_scalar'; ... }
143
+   | { kind: 'add' | 'sub' | 'mul' | 'div'; ... }
144
+   | { kind: 'add_scalar' | 'mul_scalar'; ... }
145
+   | { kind: 'sqrt' | 'rsqrt' | 'log' | 'exp' | 'relu'; ... }
146
+   | { kind: 'less' | 'greater'; ... } | { kind: 'where'; ... }
147
+   | { kind: 'mean_last' | 'sum_last'; ... }
148
+   | { kind: 'reshape' | 'transpose' | 'slice_last_range'; ... }
149
+   | { kind: 'broadcast_to' | 'sum_to_shape'; ... }
150
+   | { kind: 'matmul' | 'matmul_batched'; ... }
151
+   | { kind: 'one_hot'; ... }
152
+   | { kind: 'softmax_causal_last' | 'log_softmax_last' | 'where_causal'; ... }
153
+   | { kind: 'relu_grad'; ... }
154
+   | { kind: 'adam_update_m' | 'adam_update_v' | 'adam_update_p'; ... }
156
+
157
+ interface Graph {
158
+ readonly ops: OpNode[]
159
+ readonly tensors: Tensor[]
160
+ readonly outputs: number[] // tensor ids — typically just the loss
161
+ }
162
+ ```
163
+
164
+ The op kinds are intentionally fine-grained (`mean_last`, not
165
+ `mean(axis)`) because each kind maps to a hand-written WGSL kernel. Adding
166
+ generality later is fine; pretending to be more general than we are isn't.
167
+
168
+ ## Op set (current)
169
+
170
+ **Leaves:** `param_input`, `tensor_input`, `state_input`, `arange`, `const_scalar`
171
+
172
+ **Element-wise binops** (NumPy broadcasting): `add`, `sub`, `mul`, `div`,
173
+ plus fused `add_scalar`, `mul_scalar`
174
+
175
+ **Element-wise unary:** `sqrt`, `rsqrt`, `log`, `exp`, `relu`
176
+
177
+ **Comparisons + select:** `less`, `greater`, `where`
178
+
179
+ **Reductions over last axis:** `mean_last`, `sum_last`
180
+
181
+ **Shape:** `reshape`, `transpose`, `slice_last_range`, `broadcast_to`, `sum_to_shape`
182
+
183
+ **Linear algebra:** `matmul` (2D rhs), `matmul_batched` (both batched)
184
+
185
+ **Indexing / casting:** `one_hot`
186
+
187
+ **ML primitives** (fused for clean autograd): `softmax_causal_last`,
188
+ `log_softmax_last`, `where_causal`
189
+
190
+ **Autograd-internal:** `relu_grad`
191
+
192
+ **Adam-internal:** `adam_update_m`, `adam_update_v`, `adam_update_p`
193
+
194
+ ## Module abstraction
195
+
196
+ The `Module` base class enables Domeleon-style auto-discovery of nested
197
+ modules and parameters via property reflection:
198
+
199
+ ```ts
200
+ class Linear extends Module {
201
+ W: Tensor; b: Tensor
202
+ constructor(public inDim: number, public outDim: number) {
203
+ super()
204
+ this.W = this.param([inDim, outDim]) // returns ParamSentinel cast to Tensor
205
+ this.b = this.param([outDim])
206
+ }
207
+ }
208
+ ```
209
+
210
+ `this.param(shape)` returns a `ParamSentinel` typed as `Tensor`. At compile
211
+ time, `materializeParams(root)` walks enumerable properties of the model
212
+ tree (recursing into nested `Module` instances and arrays of modules),
213
+ replaces every sentinel with a real `paramInput` tensor whose name is the
214
+ property path (`layers.0.attn.q.W`), and returns a flat `Record<path,
215
+ Tensor>` for autograd to use.
216
+
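+ Condensed, the walk looks something like this (a sketch: `isParamSentinel`
+ is an assumed type guard, the `paramInput` signature is illustrative, and
+ the real pass also swaps each sentinel for the materialized tensor in place):
+
+ ```ts
+ function collect(node: unknown, path: string, out: Record<string, Tensor>): void {
+   if (isParamSentinel(node)) {
+     // Name the param_input by its property path, e.g. 'layers.0.attn.q.W'.
+     out[path] = paramInput(path, node.shape, 'f32')
+   } else if (Array.isArray(node)) {
+     node.forEach((child, i) => collect(child, `${path}.${i}`, out))
+   } else if (node instanceof Module) {
+     for (const [key, value] of Object.entries(node))
+       collect(value, path ? `${path}.${key}` : key, out)
+   }
+ }
+ ```
+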
217
+ This is the JAX/Equinox separation: parameter storage is mutable
218
+ (state-bearing components), forward computation is pure (functions over
219
+ materialized parameters and inputs).
220
+
221
+ ## WGSL codegen
222
+
223
+ Each op kind has a kernel template in `codegen.ts`. Shapes are **baked into
224
+ the WGSL as compile-time constants** rather than passed as uniforms — this
225
+ gives the WGSL compiler full freedom to specialize and means each shape
226
+ combination produces a distinct shader. Fine for our static-shape model.
227
+
228
+ **Dispatch:** WebGPU caps each dimension at 65535 workgroups. The runtime
229
+ dispatches as `(min(N, 65535), ceil(N/65535), 1)`; kernels compute their
230
+ global index as `gid.x + gid.y * (65535 * workgroup_size)`. Workgroup size
231
+ is 256 — large enough that our biggest kernel (~8M threads in
232
+ `matmul_bwd_dW`) fits in 1D.
233
+
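+ In TypeScript terms the dispatch computation is tiny (a sketch of the rule
+ above, not the runtime's literal code):
+
+ ```ts
+ const WORKGROUP_SIZE = 256
+ const MAX_PER_DIM = 65535
+
+ function dispatchDims(totalThreads: number): [number, number, number] {
+   const groups = Math.ceil(totalThreads / WORKGROUP_SIZE)
+   // x is capped at 65535 workgroups; y absorbs the overflow. Kernels then
+   // rebuild the flat index as gid.x + gid.y * (65535 * WORKGROUP_SIZE) and
+   // bounds-check it against the true element count.
+   return [Math.min(groups, MAX_PER_DIM), Math.ceil(groups / MAX_PER_DIM), 1]
+ }
+ ```
+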
234
+ **Error reporting:** `runtime.ts` wraps each pipeline creation in
235
+ `pushErrorScope('validation')` and pulls `getCompilationInfo()` on
236
+ failure, so shader bugs surface with file/line/message rather than the
237
+ useless "previous error" you get when a broken pipeline is dispatched.
238
+
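+ The pattern, roughly (a sketch built on standard WebGPU APIs, not
+ `runtime.ts` verbatim):
+
+ ```ts
+ async function createPipelineChecked(device: GPUDevice, code: string, entryPoint: string) {
+   device.pushErrorScope('validation')
+   const module = device.createShaderModule({ code })
+   const pipeline = device.createComputePipeline({ layout: 'auto', compute: { module, entryPoint } })
+   const err = await device.popErrorScope()
+   if (err) {
+     // Pull per-message line/column info out of the shader module itself.
+     const info = await module.getCompilationInfo()
+     const detail = info.messages.map(m => `${m.lineNum}:${m.linePos} ${m.message}`).join('\n')
+     throw new Error(`WGSL compilation failed:\n${detail}`)
+   }
+   return pipeline
+ }
+ ```
+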
239
+ ## Buffer plan
240
+
241
+ `planBuffers(graph, paramGrads, writebacks)` walks every tensor and
242
+ categorizes it:
243
+
244
+ | Kind | Purpose | Lifetime |
245
+ |---|---|---|
246
+ | `param` | trainable parameter | persistent |
247
+ | `param_grad` | gradient w.r.t. a param | one step |
248
+ | `state` | optimizer state (Adam m, v) | persistent |
249
+ | `tensor_input` | data input (tokens, targets) | one step |
250
+ | `intermediate` | any other op output | one step |
251
+ | `output` | exposed graph output (loss) | one step |
252
+
253
+ State buffers are zero-initialized at runtime creation. Writebacks (declared
254
+ by `appendAdam`) describe end-of-step `copyBufferToBuffer` operations from
255
+ freshly computed values into their persistent homes.
256
+
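+ Resolving a name-based `WritebackDecl` into a buffer-id `Writeback` pair is
+ a couple of map lookups (a sketch against the types in `buffers.d.ts`):
+
+ ```ts
+ import type { BufferPlan, Writeback, WritebackDecl } from './buffers.js'
+
+ function resolveWriteback(decl: WritebackDecl, plan: BufferPlan): Writeback {
+   const source = plan.tensorToBuffer.get(decl.source.id)!
+   const dest = decl.destKind === 'param'
+     ? plan.paramsByName.get(decl.destName)!
+     : plan.statesByName.get(decl.destName)!
+   // Buffer ids are 1:1 with tensor ids today (see the tensorToBuffer note).
+   const bytes = plan.buffers.find(b => b.id === source)!.byteSize
+   return { source, dest, bytes }
+ }
+ ```
+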
257
+ ## Autograd
258
+
259
+ `appendGrad(graph)` walks the forward ops in reverse and emits backward ops
260
+ into the same graph. Each op's transpose rule is hand-written in
261
+ `grad.ts`. The cotangents map (`tensorId → Tensor`) accumulates
262
+ contributions from multiple consumers via `add`.
263
+
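+ The accumulation step is small enough to show whole (a sketch; `add` here is
+ the library's element-wise op, so the sum itself becomes a graph node):
+
+ ```ts
+ function accumulate(cotangents: Map<number, Tensor>, id: number, ct: Tensor): void {
+   // A tensor with several consumers gets one contribution per consumer.
+   const prev = cotangents.get(id)
+   cotangents.set(id, prev ? add(prev, ct) : ct)
+ }
+ ```
+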
264
+ Two notable workarounds:
265
+
266
+ - **Embedding lookup is implemented as `oneHot @ table`** rather than
267
+ `gather`. Gather has no transpose rule (it'd need scatter-with-atomic-add
268
+ or similar), but `oneHot @ table` decomposes into ops that *do* have
269
+ rules, so autograd works through it for free (see the sketch after this list).
270
+ - **`slice_last_range` has no backward yet.** Forward works (used for
271
+ last-axis slicing); backward is unimplemented because it'd need a
272
+ scatter-style "place into zeros" op. Workaround: use multiple separate
273
+ matmuls (e.g. `W_q`/`W_k`/`W_v`) instead of a fused `W_qkv`.
274
+
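+ The embedding decomposition from the first bullet, spelled out (a sketch;
+ the `oneHot` export name and argument order are assumptions):
+
+ ```ts
+ // ids: [B, T] token indices; table: [vocab, dModel] embedding matrix.
+ // oneHot(ids, vocab) is [B, T, vocab]; matmul with a 2D rhs gives
+ // [B, T, dModel]. matmul's existing transpose rule turns d(table) into
+ // exactly the scatter-add a gather op would otherwise need.
+ const embed = (ids: Tensor, table: Tensor, vocab: number): Tensor =>
+   matmul(oneHot(ids, vocab), table)
+ ```
+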
275
+ ## Verification approach
276
+
277
+ Two layers:
278
+
279
+ 1. **Smoke test** (`pnpm test` → `test/smoke.ts`) — runs in Node without
280
+ GPU. Builds the IR, attaches grad, plans buffers, emits all WGSL, and
281
+ verifies structure (kernel count, binding count, shape errors). Catches
282
+ codegen regressions without needing a browser.
283
+ 2. **Live samples** (`pnpm --filter samples dev`) — Vite dev server with
284
+ a `/__log` endpoint that streams browser logs to stdout, used during
285
+ development to bypass copy-paste-from-console debugging.
286
+
287
+ ## What this spec is not
288
+
289
+ A contract. The API will change before 1.0. The load-bearing decisions
290
+ are in **Scope and non-goals** and **Key design decisions** above —
291
+ those are the points where the design deliberately diverges from JAX or
292
+ PyTorch, and where reverting any of them effectively re-creates the
293
+ failure mode that motivated this library.
package/dist/adam.d.ts ADDED
@@ -0,0 +1,31 @@
1
+ import type { Tensor } from './ir.js';
2
+ import type { Graph } from './ir.js';
3
+ import type { WritebackDecl } from './buffers.js';
4
+ export interface AdamConfig {
5
+ lr: number;
6
+ b1?: number;
7
+ b2?: number;
8
+ eps?: number;
9
+ }
10
+ export interface AdamResult {
11
+ /** Writebacks the buffer planner should wire into the runtime. */
12
+ writebacks: WritebackDecl[];
13
+ /** Name of the per-step scalar tensor_input. The runtime fills this each call
14
+ * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
15
+ lrtInputName: string;
16
+ /** Hyperparameters as captured (so the runtime can compute lrt). */
17
+ config: Required<AdamConfig>;
18
+ }
19
+ /**
20
+ * Append Adam update ops to `graph`. Must be called inside an active trace
21
+ * context (or after a trace, since traceInto re-enters the graph).
22
+ *
23
+ * @param graph the graph (already containing forward + backward)
24
+ * @param paramGrads param name -> gradient tensor (output of `appendGrad`)
25
+ * @param paramTensors param name -> the param's leaf Tensor (the param_input).
26
+ * Needed because the param_input lives in the graph but we
27
+ * don't have a direct map by name in `Graph` — caller passes it.
28
+ * @param config Adam hyperparameters
29
+ */
30
+ export declare function appendAdam(graph: Graph, paramGrads: Record<string, Tensor>, paramTensors: Record<string, Tensor>, config: AdamConfig): AdamResult;
31
+ //# sourceMappingURL=adam.d.ts.map
package/dist/adam.d.ts.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"adam.d.ts","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAoBA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AACrC,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AACpC,OAAO,KAAK,EAAE,aAAa,EAAE,MAAM,cAAc,CAAA;AAIjD,MAAM,WAAW,UAAU;IACzB,EAAE,EAAE,MAAM,CAAA;IACV,EAAE,CAAC,EAAE,MAAM,CAAA;IACX,EAAE,CAAC,EAAE,MAAM,CAAA;IACX,GAAG,CAAC,EAAE,MAAM,CAAA;CACb;AAED,MAAM,WAAW,UAAU;IACzB,kEAAkE;IAClE,UAAU,EAAE,aAAa,EAAE,CAAA;IAC3B;iFAC6E;IAC7E,YAAY,EAAE,MAAM,CAAA;IACpB,oEAAoE;IACpE,MAAM,EAAE,QAAQ,CAAC,UAAU,CAAC,CAAA;CAC7B;AAED;;;;;;;;;;GAUG;AACH,wBAAgB,UAAU,CACxB,KAAK,EAAE,KAAK,EACZ,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EAClC,YAAY,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EACpC,MAAM,EAAE,UAAU,GACjB,UAAU,CAmCZ"}
package/dist/adam.js ADDED
@@ -0,0 +1,66 @@
1
+ // Adam optimizer, in-graph.
2
+ //
3
+ // `appendAdam` extends a graph that already has a forward pass + autograd-emitted
4
+ // backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
5
+ //
6
+ // Per parameter P with gradient g:
7
+ // m_new = b1 * m + (1 - b1) * g
8
+ // v_new = b2 * v + (1 - b2) * g²
9
+ // p_new = p - lr * m_new / (sqrt(v_new) + eps)
10
+ //
11
+ // This is "Adam without bias correction" — the `1 / (1 - β^t)` factors are
12
+ // dropped because computing them in-graph requires per-step uniforms or
13
+ // awkward exp/log tricks. In practice the omission only affects the first
14
+ // ~100 steps; convergence is unaffected.
15
+ //
16
+ // Returns writeback declarations the buffer planner uses to wire up the
17
+ // "after step, copy the new value into the persistent home" path. m and v
18
+ // are state_inputs (zero-initialized, persistent across steps); the param
19
+ // updates are aliased back to the param buffers.
20
+ import { traceInto, stateInput, tensorInput } from './trace.js';
21
+ import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js';
22
+ /**
23
+ * Append Adam update ops to `graph`. Must be called inside an active trace
24
+ * context (or after a trace, since traceInto re-enters the graph).
25
+ *
26
+ * @param graph the graph (already containing forward + backward)
27
+ * @param paramGrads param name -> gradient tensor (output of `appendGrad`)
28
+ * @param paramTensors param name -> the param's leaf Tensor (the param_input).
29
+ * Needed because the param_input lives in the graph but we
30
+ * don't have a direct map by name in `Graph` — caller passes it.
31
+ * @param config Adam hyperparameters
32
+ */
33
+ export function appendAdam(graph, paramGrads, paramTensors, config) {
34
+ const fullConfig = {
35
+ lr: config.lr,
36
+ b1: config.b1 ?? 0.9,
37
+ b2: config.b2 ?? 0.999,
38
+ eps: config.eps ?? 1e-8,
39
+ };
40
+ const writebacks = [];
41
+ const lrtInputName = '_adam_lrt';
42
+ return traceInto(graph, () => {
43
+ // One scalar lrt input shared by every adam_update_p call. Runtime supplies
44
+ // it per step as `lr * sqrt(1-b2^t) / (1-b1^t)`.
45
+ const lrt = tensorInput(lrtInputName, [], 'f32');
46
+ for (const name of Object.keys(paramGrads)) {
47
+ const p = paramTensors[name];
48
+ const g = paramGrads[name];
49
+ if (!p)
50
+ throw new Error(`appendAdam: missing param tensor for '${name}'`);
51
+ if (!g)
52
+ throw new Error(`appendAdam: missing gradient for '${name}'`);
53
+ const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0);
54
+ const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0);
55
+ // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
56
+ const newM = adamUpdateM(mState, g, fullConfig.b1);
57
+ const newV = adamUpdateV(vState, g, fullConfig.b2);
58
+ const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps);
59
+ writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' });
60
+ writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' });
61
+ writebacks.push({ source: newP, destName: name, destKind: 'param' });
62
+ }
63
+ return { writebacks, lrtInputName, config: fullConfig };
64
+ });
65
+ }
66
+ //# sourceMappingURL=adam.js.map
package/dist/adam.js.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"adam.js","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAAA,4BAA4B;AAC5B,EAAE;AACF,kFAAkF;AAClF,+EAA+E;AAC/E,EAAE;AACF,mCAAmC;AACnC,kCAAkC;AAClC,mCAAmC;AACnC,iDAAiD;AACjD,EAAE;AACF,2EAA2E;AAC3E,wEAAwE;AACxE,0EAA0E;AAC1E,yCAAyC;AACzC,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,iDAAiD;AAKjD,OAAO,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/D,OAAO,EAAE,WAAW,EAAE,WAAW,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAmBhE;;;;;;;;;;GAUG;AACH,MAAM,UAAU,UAAU,CACxB,KAAY,EACZ,UAAkC,EAClC,YAAoC,EACpC,MAAkB;IAElB,MAAM,UAAU,GAAyB;QACvC,EAAE,EAAE,MAAM,CAAC,EAAE;QACb,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,GAAG;QACpB,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,KAAK;QACtB,GAAG,EAAE,MAAM,CAAC,GAAG,IAAI,IAAI;KACxB,CAAA;IACD,MAAM,UAAU,GAAoB,EAAE,CAAA;IACtC,MAAM,YAAY,GAAG,WAAW,CAAA;IAEhC,OAAO,SAAS,CAAC,KAAK,EAAE,GAAG,EAAE;QAC3B,4EAA4E;QAC5E,iDAAiD;QACjD,MAAM,GAAG,GAAG,WAAW,CAAC,YAAY,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAEhD,KAAK,MAAM,IAAI,IAAI,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC3C,MAAM,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAA;YAC5B,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,CAAC,CAAA;YAC1B,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,GAAG,CAAC,CAAA;YACzE,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,qCAAqC,IAAI,GAAG,CAAC,CAAA;YAErE,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAC9D,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAE9D,6EAA6E;YAC7E,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,UAAU,CAAC,GAAG,CAAC,CAAA;YAE5D,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAc,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;QAClF,CAAC;QACD,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,MAAM,EAAE,UAAU,EAAE,CAAA;IACzD,CAAC,CAAC,CAAA;AACJ,CAAC"}
package/dist/buffers.d.ts ADDED
@@ -0,0 +1,56 @@
1
+ import type { Graph, Tensor, Dtype, Shape } from './ir.js';
2
+ export interface BufferSpec {
3
+ /** Matches tensor.id. */
4
+ id: number;
5
+ byteSize: number;
6
+ dtype: Dtype;
7
+ shape: Shape;
8
+ kind: 'param' | 'param_grad' | 'tensor_input' | 'state' | 'intermediate' | 'output';
9
+ /** External name for param/param_grad/tensor_input/state bindings. null otherwise. */
10
+ name: string | null;
11
+ /** For state buffers: the value to fill on initial allocation. 0 by default. */
12
+ initValue?: number;
13
+ }
14
+ /**
15
+ * After step(), copy `source`'s buffer into `dest`'s buffer.
16
+ * Used to write back updated optimizer state and updated parameters into
17
+ * their persistent home buffers.
18
+ */
19
+ export interface Writeback {
20
+ source: number;
21
+ dest: number;
22
+ bytes: number;
23
+ }
24
+ export interface BufferPlan {
25
+ buffers: BufferSpec[];
26
+ /** Tensor id -> buffer id (currently 1:1 but kept opaque for future pooling). */
27
+ tensorToBuffer: Map<number, number>;
28
+ /** Easy lookup tables for the runtime. */
29
+ paramsByName: Map<string, number>;
30
+ inputsByName: Map<string, number>;
31
+ paramGradsByName: Map<string, number>;
32
+ statesByName: Map<string, number>;
33
+ outputBufferIds: number[];
34
+ /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
35
+ writebacks: Writeback[];
36
+ }
37
+ /**
38
+ * Caller-supplied writeback declarations: "after each step, copy this Tensor's
39
+ * buffer into the persistent home of this param/state."
40
+ */
41
+ export interface WritebackDecl {
42
+ /** The Tensor (output of some op) holding the new value to write back. */
43
+ source: Tensor;
44
+ /** Either a param name (writes to that param's home buffer) or a state name. */
45
+ destName: string;
46
+ destKind: 'param' | 'state';
47
+ }
48
+ /**
49
+ * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
50
+ * @param graph the full graph (forward + backward + any optimizer ops)
51
+ * @param paramGrads map from param name -> the Tensor that holds its gradient
52
+ * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
53
+ * Empty when there's no optimizer in the graph.
54
+ */
55
+ export declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
56
+ //# sourceMappingURL=buffers.d.ts.map
package/dist/buffers.d.ts.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"buffers.d.ts","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAcA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAElE,MAAM,WAAW,UAAU;IACzB,yBAAyB;IACzB,EAAE,EAAE,MAAM,CAAA;IACV,QAAQ,EAAE,MAAM,CAAA;IAChB,KAAK,EAAE,KAAK,CAAA;IACZ,KAAK,EAAE,KAAK,CAAA;IACZ,IAAI,EAAE,OAAO,GAAG,YAAY,GAAG,cAAc,GAAG,OAAO,GAAG,cAAc,GAAG,QAAQ,CAAA;IACnF,sFAAsF;IACtF,IAAI,EAAE,MAAM,GAAG,IAAI,CAAA;IACnB,gFAAgF;IAChF,SAAS,CAAC,EAAE,MAAM,CAAA;CACnB;AAED;;;;GAIG;AACH,MAAM,WAAW,SAAS;IACxB,MAAM,EAAE,MAAM,CAAA;IACd,IAAI,EAAE,MAAM,CAAA;IACZ,KAAK,EAAE,MAAM,CAAA;CACd;AAED,MAAM,WAAW,UAAU;IACzB,OAAO,EAAE,UAAU,EAAE,CAAA;IACrB,iFAAiF;IACjF,cAAc,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACnC,0CAA0C;IAC1C,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,gBAAgB,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACrC,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,eAAe,EAAE,MAAM,EAAE,CAAA;IACzB,mEAAmE;IACnE,UAAU,EAAE,SAAS,EAAE,CAAA;CACxB;AAUD;;;GAGG;AACH,MAAM,WAAW,aAAa;IAC5B,0EAA0E;IAC1E,MAAM,EAAE,MAAM,CAAA;IACd,gFAAgF;IAChF,QAAQ,EAAE,MAAM,CAAA;IAChB,QAAQ,EAAE,OAAO,GAAG,OAAO,CAAA;CAC5B;AAED;;;;;;GAMG;AACH,wBAAgB,WAAW,CACzB,KAAK,EAAE,KAAK,EACZ,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EAClC,cAAc,GAAE,aAAa,EAAO,GACnC,UAAU,CAuFZ"}