tensorgrad 0.0.2 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/README.md +7 -9
  2. package/dist/adam.d.ts +22 -5
  3. package/dist/adam.d.ts.map +1 -1
  4. package/dist/adam.js +42 -10
  5. package/dist/adam.js.map +1 -1
  6. package/dist/buffers.d.ts +1 -0
  7. package/dist/buffers.d.ts.map +1 -1
  8. package/dist/buffers.js +12 -1
  9. package/dist/buffers.js.map +1 -1
  10. package/dist/capture.d.ts +3 -0
  11. package/dist/capture.d.ts.map +1 -0
  12. package/dist/capture.js +33 -0
  13. package/dist/capture.js.map +1 -0
  14. package/dist/codegen.js +16 -5
  15. package/dist/codegen.js.map +1 -1
  16. package/dist/compile.d.ts +33 -5
  17. package/dist/compile.d.ts.map +1 -1
  18. package/dist/compile.js +106 -14
  19. package/dist/compile.js.map +1 -1
  20. package/dist/index.d.ts +5 -3
  21. package/dist/index.d.ts.map +1 -1
  22. package/dist/index.js +4 -2
  23. package/dist/index.js.map +1 -1
  24. package/dist/ir.d.ts +2 -0
  25. package/dist/ir.d.ts.map +1 -1
  26. package/dist/ir.js +1 -1
  27. package/dist/ir.js.map +1 -1
  28. package/dist/module.d.ts +30 -4
  29. package/dist/module.d.ts.map +1 -1
  30. package/dist/module.js +39 -13
  31. package/dist/module.js.map +1 -1
  32. package/dist/nn.d.ts +19 -0
  33. package/dist/nn.d.ts.map +1 -0
  34. package/dist/nn.js +60 -0
  35. package/dist/nn.js.map +1 -0
  36. package/dist/ops.d.ts +1 -1
  37. package/dist/ops.d.ts.map +1 -1
  38. package/dist/ops.js +18 -1
  39. package/dist/ops.js.map +1 -1
  40. package/dist/runtime.d.ts +79 -4
  41. package/dist/runtime.d.ts.map +1 -1
  42. package/dist/runtime.js +153 -19
  43. package/dist/runtime.js.map +1 -1
  44. package/dist/trace.d.ts +1 -0
  45. package/dist/trace.d.ts.map +1 -1
  46. package/dist/trace.js +12 -0
  47. package/dist/trace.js.map +1 -1
  48. package/package.json +1 -2
  49. package/src/adam.ts +65 -14
  50. package/src/buffers.ts +14 -1
  51. package/src/capture.ts +36 -0
  52. package/src/codegen.ts +16 -5
  53. package/src/compile.ts +122 -16
  54. package/src/index.ts +5 -3
  55. package/src/ir.ts +20 -4
  56. package/src/module.ts +75 -11
  57. package/src/nn.ts +59 -0
  58. package/src/ops.ts +26 -3
  59. package/src/runtime.ts +260 -22
  60. package/src/trace.ts +13 -0
  61. package/SPEC.md +0 -293
package/SPEC.md DELETED
@@ -1,293 +0,0 @@
- # Tensorgrad — Architecture
-
- This document covers the design decisions, IR, and internals of tensorgrad.
- For installation and user-facing API, see `README.md`. For pre-1.0
- implementation status and what to pick up next, see `HANDOFF.md`.
-
- ## Scope and non-goals (load-bearing)
-
- The library only does what it does because of what it doesn't do. Each
- "out of scope" decision is the *precondition* that lets the rest stay small.
-
- **In scope:**
- - Static-shape models. Every shape is fixed at compile time.
- - WebGPU only.
- - f32 only.
- - `grad` (reverse-mode autograd) as the only transformation.
- - A closed set of ~25 ops covering transformers + MLPs.
- - Adam optimizer in-IR.
-
- **Out of scope (deliberately):**
- - Wasm or WebGL fallback.
- - Dynamic shapes, shape polymorphism.
- - `vmap`, `pmap`, `jvp`, `custom_vjp`, higher-order gradients.
- - Dtype promotion, mixed precision.
- - General PyTree machinery (we use `Module` + property paths instead).
- - Inference of pre-trained models (use ONNX Runtime Web or transformers.js).
- - ONNX import / safetensors / model loaders.
- - Distributed training, gradient accumulation across devices.
-
- The non-goals are load-bearing. Trying to add any of them without rethinking
- the IR forces complexity throughout.
-
- ## Architecture overview
-
- ```
- ┌────────────────────────────────────────────────────────────┐
- │ User code                                                  │
- │ class Model extends Module { /* params */ }                │
- │ function forward(m: Model, x: Tensor): Tensor { /* ... */ }│
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ trace()
- ┌────────────────────────────────────────────────────────────┐
- │ Forward IR build (src/trace.ts, src/ops.ts, src/shape.ts)  │
- │ Each op call appends a node to the Graph and returns a     │
- │ fresh Tensor handle. Shapes inferred + validated per op.   │
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ appendGrad()
- ┌────────────────────────────────────────────────────────────┐
- │ Reverse-mode autograd (src/grad.ts)                        │
- │ Topological walk; each forward op contributes its          │
- │ transpose rule, building the backward graph in-place.      │
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ appendAdam() (optional)
- ┌────────────────────────────────────────────────────────────┐
- │ Optimizer (src/adam.ts)                                    │
- │ Per-param: m_state, v_state state_inputs + fused           │
- │ adam_update_{m,v,p} ops. Writebacks declared.              │
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ planBuffers()
- ┌────────────────────────────────────────────────────────────┐
- │ Buffer plan (src/buffers.ts)                               │
- │ One GPU buffer per IR tensor, categorized:                 │
- │ param / param_grad / state / tensor_input / intermediate.  │
- │ Writebacks resolved to (source_buf → dest_buf) pairs.      │
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ emitKernels()
- ┌────────────────────────────────────────────────────────────┐
- │ WGSL codegen (src/codegen.ts)                              │
- │ Per op kind: a kernel template with shapes baked in.       │
- │ Returns dispatch-ready KernelSpec[].                       │
- └────────────────────────────────────────────────────────────┘
-
-                             ▼ createRuntime()
- ┌────────────────────────────────────────────────────────────┐
- │ Runtime (src/runtime.ts)                                   │
- │ GPUDevice setup, pipeline cache, bind groups,              │
- │ step(): upload inputs → dispatch all kernels → writebacks  │
- │ → loss readback. Compile errors surface via                │
- │ pushErrorScope+getCompilationInfo.                         │
- └────────────────────────────────────────────────────────────┘
- ```
-
- ## Key design decisions
-
- **D1. Runtime tracing, not build-time.** The forward function is traced once
- on first compile; the IR is built from those op calls. Build-time tracing via
- a TypeScript transformer plugin would be cleaner but adds a build-step
- requirement; a v2 candidate.
-
- **D2. Tensors are opaque handles, not Proxies.** Each op returns a fresh
- `Tensor` (just `{ id, shape, dtype, source, site }`). Proxy-based tracing
- gives nicer error UX but couples the IR to runtime introspection.
-
- **D3. No reference counting.** Every IR tensor gets its own GPU buffer,
- allocated once and never freed. Our scope (one model, fixed shapes,
- training in a loop) means there's nothing to gain from refcount discipline.
- Memory cost is bounded; buffer pooling is a v2 optimization, not v1
- correctness.
-
- **D4. Closed op set.** The IR knows about exactly the ops it supports.
- Adding a new op means adding (a) a shape rule, (b) a WGSL kernel template,
- and (c) an autograd transpose rule. This is intentional — a closed op set
- makes each piece tractable to write and verify by hand.
-
- **D5. Shapes checked at trace time, not at type level.** Type-level shape
- encoding in TypeScript is feasible but hits recursion limits and adds
- significant generic complexity to user code. Runtime shape errors at trace
- time, with call-site capture, are good enough.
-
- **D6. Adam state in the IR.** Optimizer state (m, v, plus a per-step `lrt`
- scalar for bias correction) lives in dedicated `state_input` buffers that
- persist across `step()` calls. Writebacks at the end of each step copy new
- values into their persistent homes. No CPU↔GPU round-trip per step.
-
- **D7. Module is separate from forward.** Mutable parameter storage lives in
- `Module` subclasses. Forward functions are pure, take the materialized
- model as the first argument, and call ordinary op functions. State and
- computation never mix — JAX's lesson, applied to TypeScript with
- class-based ergonomics.
-
- **D8. JS-number scalar overloads.** `add(x, 1e-5)` and `add(x, y)` both
- work. The scalar variants dispatch to fused IR ops internally.
-
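The overload dispatch in D8 can be sketched as a union-typed entry point. This is an illustrative shape only, not the library's actual `add` signature; the `Kind` and `Node` types here are hypothetical stand-ins for the real IR:

```typescript
// Hypothetical, simplified IR node: only what's needed to show dispatch.
type Kind = "add" | "add_scalar"

interface Node { kind: Kind; scalar?: number }

// A JS number on either side selects the fused `add_scalar` op kind
// instead of materializing a constant tensor (add commutes, so the
// scalar-on-the-left case reuses the same fused kind).
function add(x: Node | number, y: Node | number): Node {
  if (typeof y === "number") return { kind: "add_scalar", scalar: y }
  if (typeof x === "number") return { kind: "add_scalar", scalar: x }
  return { kind: "add" }
}
```

The same pattern would cover `mul`/`mul_scalar`; non-commutative ops like `sub` would need a separate fused kind or a sign flip.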
- ## IR
-
- ```ts
- interface Tensor {
-   readonly id: number
-   readonly shape: Shape
-   readonly dtype: Dtype
-   readonly source: number | null  // op index, or null for leaves
-   readonly site: CallSite | null  // user's stack at op-call time
- }
-
- type OpNode =
-   | { kind: 'param_input'; ... } | { kind: 'tensor_input'; ... }
-   | { kind: 'state_input'; ... } | { kind: 'arange'; ... }
-   | { kind: 'const_scalar'; ... }
-   | { kind: 'add' | 'sub' | 'mul' | 'div'; ... }
-   | { kind: 'add_scalar' | 'mul_scalar'; ... }
-   | { kind: 'sqrt' | 'rsqrt' | 'log' | 'exp' | 'relu'; ... }
-   | { kind: 'less' | 'greater'; ... } | { kind: 'where'; ... }
-   | { kind: 'mean_last' | 'sum_last'; ... }
-   | { kind: 'reshape' | 'transpose' | 'slice_last_range'; ... }
-   | { kind: 'broadcast_to' | 'sum_to_shape'; ... }
-   | { kind: 'matmul' | 'matmul_batched'; ... }
-   | { kind: 'one_hot'; ... }
-   | { kind: 'softmax_causal_last' | 'log_softmax_last' | 'where_causal'; ... }
-   | { kind: 'relu_grad'; ... }
-   | { kind: 'adam_update_m' | 'adam_update_v' | 'adam_update_p'; ... }
-
- interface Graph {
-   readonly ops: OpNode[]
-   readonly tensors: Tensor[]
-   readonly outputs: number[]  // tensor ids — typically just the loss
- }
- ```
-
- The op kinds are deliberately fine-grained (`mean_last`, not
- `mean(axis)`) because each kind maps to a hand-written WGSL kernel. Adding
- generality later is fine; pretending to be more general than we are isn't.
-
- ## Op set (current)
-
- **Leaves:** `param_input`, `tensor_input`, `state_input`, `arange`, `const_scalar`
-
- **Element-wise binops** (NumPy broadcasting): `add`, `sub`, `mul`, `div`,
- plus fused `add_scalar`, `mul_scalar`
-
- **Element-wise unary:** `sqrt`, `rsqrt`, `log`, `exp`, `relu`
-
- **Comparisons + select:** `less`, `greater`, `where`
-
- **Reductions over last axis:** `mean_last`, `sum_last`
-
- **Shape:** `reshape`, `transpose`, `slice_last_range`, `broadcast_to`, `sum_to_shape`
-
- **Linear algebra:** `matmul` (2D rhs), `matmul_batched` (both batched)
-
- **Indexing / casting:** `one_hot`
-
- **ML primitives** (fused for clean autograd): `softmax_causal_last`,
- `log_softmax_last`, `where_causal`
-
- **Autograd-internal:** `relu_grad`
-
- **Adam-internal:** `adam_update_m`, `adam_update_v`, `adam_update_p`
-
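The NumPy broadcasting rule the element-wise binops follow can be sketched as a small shape function. This is an illustrative restatement of the standard rule (align shapes from the trailing axis; a missing or size-1 dim stretches), not the library's internal shape code:

```typescript
// Compute the broadcast result shape of two shapes, NumPy-style.
function broadcastShape(a: number[], b: number[]): number[] {
  const rank = Math.max(a.length, b.length)
  const out: number[] = []
  for (let i = 0; i < rank; i++) {
    // Walk from the trailing axis; a dim absent in the shorter shape counts as 1.
    const da = a[a.length - 1 - i] ?? 1
    const db = b[b.length - 1 - i] ?? 1
    if (da !== db && da !== 1 && db !== 1)
      throw new Error(`cannot broadcast dim ${da} with ${db}`)
    out.unshift(Math.max(da, db))
  }
  return out
}
```

For example, `broadcastShape([4, 1, 3], [2, 3])` yields `[4, 2, 3]`, while mismatched non-1 dims are rejected at trace time.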
- ## Module abstraction
-
- The `Module` base class enables Domeleon-style auto-discovery of nested
- modules and parameters via property reflection:
-
- ```ts
- class Linear extends Module {
-   W: Tensor; b: Tensor
-   constructor(public inDim: number, public outDim: number) {
-     super()
-     this.W = this.param([inDim, outDim]) // returns ParamSentinel cast to Tensor
-     this.b = this.param([outDim])
-   }
- }
- ```
-
- `this.param(shape)` returns a `ParamSentinel` typed as `Tensor`. At compile
- time, `materializeParams(root)` walks enumerable properties of the model
- tree (recursing into nested `Module` instances and arrays of modules),
- replaces every sentinel with a real `paramInput` tensor whose name is the
- property path (`layers.0.attn.q.W`), and returns a flat `Record<path,
- Tensor>` for autograd to use.
-
- This is the JAX/Equinox separation: parameter storage is mutable
- (state-bearing components), forward computation is pure (functions over
- materialized parameters and inputs).
-
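The property-path walk can be sketched roughly as below. This is a simplified model of the assumed mechanics, not the real implementation: the `PARAM` symbol stands in for `ParamSentinel`, and the collected string values stand in for real `paramInput` tensors:

```typescript
// Hypothetical sentinel marking a to-be-materialized parameter.
const PARAM = Symbol("param")

type Tree = { [key: string]: unknown }

// Walk enumerable properties depth-first, building dotted property paths.
// Arrays recurse too, so array indices become path segments ("layers.0...").
function materialize(node: Tree, prefix: string, out: Record<string, string>): void {
  for (const [key, val] of Object.entries(node)) {
    const path = prefix ? `${prefix}.${key}` : key
    if (val === PARAM) {
      out[path] = path // stand-in for creating a paramInput tensor named `path`
    } else if (typeof val === "object" && val !== null) {
      materialize(val as Tree, path, out) // nested modules and arrays of modules
    }
  }
}
```

Running this over `{ layers: [{ attn: { q: { W: PARAM } } }] }` produces the flat key `layers.0.attn.q.W`, matching the naming scheme described above.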
- ## WGSL codegen
-
- Each op kind has a kernel template in `codegen.ts`. Shapes are **baked into
- the WGSL as compile-time constants** rather than passed as uniforms — this
- gives the WGSL compiler full freedom to specialize and means each shape
- combination produces a distinct shader. Fine for our static-shape model.
-
- **Dispatch:** WebGPU caps each dimension at 65535 workgroups. The runtime
- dispatches as `(min(N, 65535), ceil(N/65535), 1)`; kernels compute their
- global index as `gid.x + gid.y * (65535 * workgroup_size)`. Workgroup size
- is 256 — large enough that our biggest kernel (~8M threads in
- `matmul_bwd_dW`) fits in 1D.
-
- **Error reporting:** `runtime.ts` wraps each pipeline creation in
- `pushErrorScope('validation')` and pulls `getCompilationInfo()` on
- failure, so shader bugs surface with file/line/message rather than the
- useless "previous error" you get when a broken pipeline is dispatched.
-
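The 2D dispatch-size computation above can be sketched as follows, assuming the constants stated in this section (a 65535-workgroup cap per dimension and workgroup size 256); the function name is illustrative, not the runtime's API:

```typescript
const MAX_DIM = 65535 // WebGPU default cap on workgroups per dispatch dimension
const WG_SIZE = 256   // threads per workgroup

// Split a flat thread count into (x, y, 1) workgroup counts so that
// x * y * WG_SIZE covers all threads while x stays under the cap.
function dispatchDims(totalThreads: number): [number, number, number] {
  const groups = Math.ceil(totalThreads / WG_SIZE)
  return [Math.min(groups, MAX_DIM), Math.ceil(groups / MAX_DIM), 1]
}
```

An ~8M-thread kernel needs 31250 workgroups, which fits in the x dimension alone (`y = 1`), matching the "fits in 1D" remark; past ~16.7M threads the y dimension kicks in and kernels recover the flat index as `gid.x + gid.y * (65535 * workgroup_size)`.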
- ## Buffer plan
-
- `planBuffers(graph, paramGrads, writebacks)` walks every tensor and
- categorizes it:
-
- | Kind | Purpose | Lifetime |
- |---|---|---|
- | `param` | trainable parameter | persistent |
- | `param_grad` | gradient w.r.t. a param | one step |
- | `state` | optimizer state (Adam m, v) | persistent |
- | `tensor_input` | data input (tokens, targets) | one step |
- | `intermediate` | any other op output | one step |
- | `output` | exposed graph output (loss) | one step |
-
- State buffers are zero-initialized at runtime creation. Writebacks (declared
- by `appendAdam`) describe end-of-step `copyBufferToBuffer` operations from
- freshly-computed values into their persistent homes.
-
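The categorization in the table can be read as a simple priority order, sketched here over hypothetical simplified tags (the real `planBuffers` works over the IR graph, not these flags):

```typescript
type OpKind = "param_input" | "tensor_input" | "state_input" | "other"
type BufKind = "param" | "tensor_input" | "state" | "param_grad" | "output" | "intermediate"

// Leaf kinds decide first; otherwise gradient-ness, then output-ness,
// then the intermediate fallback.
function categorize(op: OpKind, isParamGrad: boolean, isOutput: boolean): BufKind {
  if (op === "param_input") return "param"
  if (op === "tensor_input") return "tensor_input"
  if (op === "state_input") return "state"
  if (isParamGrad) return "param_grad"
  if (isOutput) return "output"
  return "intermediate"
}
```

The lifetimes then follow from the kind: `param` and `state` persist across steps, everything else is valid for one `step()` call.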
- ## Autograd
-
- `appendGrad(graph)` walks the forward ops in reverse and emits backward ops
- into the same graph. Each op's transpose rule is hand-written in
- `grad.ts`. The cotangents map (`tensorId → Tensor`) accumulates
- contributions from multiple consumers via `add`.
-
- Two notable workarounds:
-
- - **Embedding lookup is implemented as `oneHot @ table`** rather than
-   `gather`. Gather has no transpose rule (it'd need scatter-with-atomic-add
-   or similar), but `oneHot @ table` decomposes into ops that *do* have
-   rules, so autograd works through it for free.
- - **`slice_last_range` has no backward yet.** Forward works (used in any
-   axis-2 slicing pattern); backward is unimplemented because it'd need a
-   scatter-style "place into zeros" op. Workaround: use multiple separate
-   matmuls (e.g. `W_q`/`W_k`/`W_v`) instead of a fused `W_qkv`.
-
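Why `oneHot @ table` substitutes for `gather`: a one-hot row selects exactly one table row under matmul, and matmul already has a transpose rule. A minimal numeric sketch with plain arrays (not the library's ops) makes the equivalence concrete:

```typescript
// One-hot encode an index into a row of length `depth`.
function oneHot(idx: number, depth: number): number[] {
  return Array.from({ length: depth }, (_, i) => (i === idx ? 1 : 0))
}

// row (1 x V) times table (V x D) -> the selected embedding (1 x D).
// All terms where row[v] === 0 vanish, leaving exactly table[idx].
function matVec(row: number[], table: number[][]): number[] {
  const d = table[0].length
  const out = new Array(d).fill(0)
  for (let v = 0; v < row.length; v++)
    for (let j = 0; j < d; j++) out[j] += row[v] * table[v][j]
  return out
}
```

Since the lookup is now a matmul, its backward pass (the transposed matmul) routes the output cotangent into exactly the selected table row, which is the scatter-add that a direct `gather` transpose would otherwise require.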
- ## Verification approach
-
- Two layers:
-
- 1. **Smoke test** (`pnpm test` → `test/smoke.ts`) — runs in Node without a
-    GPU. Builds the IR, attaches grad, plans buffers, emits all WGSL, and
-    verifies structure (kernel count, binding count, shape errors). Catches
-    codegen regressions without needing a browser.
- 2. **Live samples** (`pnpm --filter samples dev`) — Vite dev server with
-    a `/__log` endpoint that streams browser logs to stdout, used during
-    development to bypass copy-paste-from-console debugging.
-
- ## What this spec is not
-
- A contract. The API will change before 1.0. The load-bearing decisions
- are in **Scope and non-goals** and **Key design decisions** above —
- those are the points where the design deliberately diverges from JAX or
- PyTorch, and where reverting any of them effectively re-creates the
- failure mode that motivated this library.