tensorgrad 0.0.2 → 0.0.5
- package/README.md +7 -9
- package/dist/adam.d.ts +22 -5
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +42 -10
- package/dist/adam.js.map +1 -1
- package/dist/buffers.d.ts +1 -0
- package/dist/buffers.d.ts.map +1 -1
- package/dist/buffers.js +12 -1
- package/dist/buffers.js.map +1 -1
- package/dist/capture.d.ts +3 -0
- package/dist/capture.d.ts.map +1 -0
- package/dist/capture.js +33 -0
- package/dist/capture.js.map +1 -0
- package/dist/codegen.js +16 -5
- package/dist/codegen.js.map +1 -1
- package/dist/compile.d.ts +33 -5
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +106 -14
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +5 -3
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +4 -2
- package/dist/index.js.map +1 -1
- package/dist/ir.d.ts +2 -0
- package/dist/ir.d.ts.map +1 -1
- package/dist/ir.js +1 -1
- package/dist/ir.js.map +1 -1
- package/dist/module.d.ts +30 -4
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +39 -13
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +19 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +60 -0
- package/dist/nn.js.map +1 -0
- package/dist/ops.d.ts +1 -1
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +18 -1
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +79 -4
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +153 -19
- package/dist/runtime.js.map +1 -1
- package/dist/trace.d.ts +1 -0
- package/dist/trace.d.ts.map +1 -1
- package/dist/trace.js +12 -0
- package/dist/trace.js.map +1 -1
- package/package.json +1 -2
- package/src/adam.ts +65 -14
- package/src/buffers.ts +14 -1
- package/src/capture.ts +36 -0
- package/src/codegen.ts +16 -5
- package/src/compile.ts +122 -16
- package/src/index.ts +5 -3
- package/src/ir.ts +20 -4
- package/src/module.ts +75 -11
- package/src/nn.ts +59 -0
- package/src/ops.ts +26 -3
- package/src/runtime.ts +260 -22
- package/src/trace.ts +13 -0
- package/SPEC.md +0 -293
package/SPEC.md
DELETED
@@ -1,293 +0,0 @@
# Tensorgrad — Architecture

This document covers the design decisions, IR, and internals of tensorgrad.
For installation and user-facing API, see `README.md`. For pre-1.0
implementation status and what to pick up next, see `HANDOFF.md`.

## Scope and non-goals (load-bearing)

The library only does what it does because of what it doesn't do. Each
"out of scope" decision is the *precondition* that lets the rest stay small.

**In scope:**
- Static-shape models. Every shape is fixed at compile time.
- WebGPU only.
- f32 only.
- `grad` (reverse-mode autograd) as the only transformation.
- A closed set of ~25 ops covering transformers + MLPs.
- Adam optimizer in-IR.

**Out of scope (deliberately):**
- Wasm or WebGL fallback.
- Dynamic shapes, shape polymorphism.
- `vmap`, `pmap`, `jvp`, `custom_vjp`, higher-order gradients.
- Dtype promotion, mixed precision.
- General PyTree machinery (we use `Module` + property paths instead).
- Inference of pre-trained models (use ONNX Runtime Web or transformers.js).
- ONNX import / safetensors / model loaders.
- Distributed training, gradient accumulation across devices.

The non-goals are load-bearing. Trying to add any of them without rethinking
the IR forces complexity throughout.

## Architecture overview

```
┌────────────────────────────────────────────────────────────┐
│ User code                                                  │
│ class Model extends Module { /* params */ }                │
│ function forward(m: Model, x: Tensor): Tensor { /* ... */ }│
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ trace()
┌────────────────────────────────────────────────────────────┐
│ Forward IR build (src/trace.ts, src/ops.ts, src/shape.ts)  │
│ Each op call appends a node to the Graph and returns a     │
│ fresh Tensor handle. Shapes inferred + validated per op.   │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ appendGrad()
┌────────────────────────────────────────────────────────────┐
│ Reverse-mode autograd (src/grad.ts)                        │
│ Topological walk; each forward op contributes its          │
│ transpose rule, building the backward graph in-place.      │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ appendAdam() (optional)
┌────────────────────────────────────────────────────────────┐
│ Optimizer (src/adam.ts)                                    │
│ Per-param: m_state, v_state state_inputs + fused           │
│ adam_update_{m,v,p} ops. Writebacks declared.              │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ planBuffers()
┌────────────────────────────────────────────────────────────┐
│ Buffer plan (src/buffers.ts)                               │
│ One GPU buffer per IR tensor, categorized:                 │
│ param / param_grad / state / tensor_input / intermediate.  │
│ Writebacks resolved to (source_buf → dest_buf) pairs.      │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ emitKernels()
┌────────────────────────────────────────────────────────────┐
│ WGSL codegen (src/codegen.ts)                              │
│ Per op kind: a kernel template with shapes baked in.       │
│ Returns dispatch-ready KernelSpec[].                       │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼ createRuntime()
┌────────────────────────────────────────────────────────────┐
│ Runtime (src/runtime.ts)                                   │
│ GPUDevice setup, pipeline cache, bind groups,              │
│ step(): upload inputs → dispatch all kernels → writebacks  │
│ → loss readback. Compile errors surface via                │
│ pushErrorScope+getCompilationInfo.                         │
└────────────────────────────────────────────────────────────┘
```

## Key design decisions

**D1. Runtime tracing, not build-time.** The forward function is traced once,
on first compile, and the IR is built from the op calls it makes. Build-time
tracing via a TypeScript transformer plugin would be cleaner, but it would add
a build-step requirement for users; it is a v2 candidate.

**D2. Tensors are opaque handles, not Proxies.** Each op returns a fresh
`Tensor` (just `{ id, shape, dtype, source, site }`). Proxy-based tracing
gives nicer error UX but couples the IR to runtime introspection.

**D3. No reference counting.** Every IR tensor gets its own GPU buffer,
allocated once and never freed. Our scope (one model, fixed shapes,
training in a loop) means there's nothing to gain from refcount discipline.
Memory cost is bounded by the total size of all IR tensors, which is known at
compile time. Buffer pooling is a v2 optimization, not a v1 correctness
requirement.

**D4. Closed op set.** The IR knows about exactly the ops it supports.
Adding a new op means adding (a) a shape rule, (b) a WGSL kernel template,
and (c) an autograd transpose rule. This is intentional: a closed op set makes
each piece tractable to write and verify by hand.

**D5. Shapes checked at trace time, not at type level.** Type-level shape
encoding in TypeScript is possible, but it hits recursion limits and adds
significant generic complexity to user code. Runtime shape errors at trace
time, with call-site capture, are good enough.

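Here is a per-op shape rule of the kind D5 describes. It is a minimal sketch of NumPy-style broadcasting for the element-wise binops, with a heuristic call-site capture from `new Error().stack`. `broadcastShapes` and its error format are hypothetical illustrations, not the contents of `src/shape.ts`.

```typescript
type Shape = readonly number[]

// Hypothetical shape rule for element-wise binops using NumPy broadcasting:
// align shapes from the right; each dim pair must match, or one side must be 1.
function broadcastShapes(a: Shape, b: Shape): number[] {
  const rank = Math.max(a.length, b.length)
  const out: number[] = []
  for (let i = 0; i < rank; i++) {
    // Missing leading dims (negative index -> undefined) count as 1.
    const da = a[a.length - rank + i] ?? 1
    const db = b[b.length - rank + i] ?? 1
    if (da !== db && da !== 1 && db !== 1) {
      // Heuristic call-site capture: in V8, stack frame [2] is the caller of this rule.
      const site = new Error().stack?.split('\n')[2]?.trim() ?? '<unknown site>'
      throw new Error(
        `shape mismatch: [${a.join(', ')}] vs [${b.join(', ')}] at dim ${i} (${site})`,
      )
    }
    // A size-1 dim stretches to the other side's size (this also keeps 0-sized dims).
    out.push(da === 1 ? db : da)
  }
  return out
}
```

`broadcastShapes([4, 1, 8], [3, 1])` returns `[4, 3, 8]`. `broadcastShapes([2, 3], [4])` throws while the graph is being traced, before any kernel is generated.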
**D6. Adam state in the IR.** Optimizer state (m, v, plus a per-step `lrt`
scalar for bias correction) lives in dedicated `state_input` buffers that
persist across `step()` calls. Writebacks at the end of each step copy new
values into their persistent homes, so optimizer state never makes a
CPU↔GPU round-trip.

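To make the bias-correction scalar concrete, here is a CPU reference sketch of what the three fused updates compute. It assumes `lrt` follows the standard Adam rearrangement, `lr * sqrt(1 - beta2^t) / (1 - beta1^t)`. That form folds both bias corrections into one scalar, so no corrected copies of `m` and `v` are needed. `AdamConfig`, `adamLrt`, and `adamStep` are hypothetical names, and the placement of `eps` is an assumption.

```typescript
interface AdamConfig {
  lr: number
  beta1: number
  beta2: number
  eps: number
}

// Assumed form of the per-step scalar: both bias corrections folded into the learning rate.
function adamLrt(cfg: AdamConfig, t: number): number {
  return (cfg.lr * Math.sqrt(1 - cfg.beta2 ** t)) / (1 - cfg.beta1 ** t)
}

// CPU reference for adam_update_m / adam_update_v / adam_update_p on flat f32 arrays.
// p, m and v are updated in place, standing in for the end-of-step writebacks.
function adamStep(
  p: Float32Array,
  g: Float32Array,
  m: Float32Array,
  v: Float32Array,
  cfg: AdamConfig,
  t: number, // 1-based step count
): void {
  const lrt = adamLrt(cfg, t)
  for (let i = 0; i < p.length; i++) {
    m[i] = cfg.beta1 * m[i] + (1 - cfg.beta1) * g[i] // adam_update_m
    v[i] = cfg.beta2 * v[i] + (1 - cfg.beta2) * g[i] * g[i] // adam_update_v
    p[i] -= (lrt * m[i]) / (Math.sqrt(v[i]) + cfg.eps) // adam_update_p
  }
}
```

On the first step (`t = 1`, zero-initialized `m` and `v`), the corrections cancel, so each parameter moves by roughly `lr` against the sign of its gradient.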
**D7. Modules are separate from forward functions.** Mutable parameter
storage lives in `Module` subclasses. Forward functions are pure: they take the
materialized model as the first argument and call ordinary op functions. State
and computation never mix. This is JAX's lesson, applied to TypeScript with
class-based ergonomics.

**D8. JS-number scalar overloads.** `add(x, 1e-5)` and `add(x, y)` both
work. The scalar variants dispatch internally to the fused scalar IR ops
(`add_scalar`, `mul_scalar`).

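Here is a minimal sketch of how a single `add` can route a JS number to a fused scalar op. It assumes a global op list that stands in for the active trace. `OpRecord`, `fresh`, and `tensorInput` are hypothetical, and broadcasting is left out.

```typescript
interface Tensor {
  readonly id: number
  readonly shape: readonly number[]
}

// Hypothetical, simplified IR records; real op nodes carry more fields.
type OpRecord =
  | { kind: 'tensor_input'; out: number }
  | { kind: 'add'; a: number; b: number; out: number }
  | { kind: 'add_scalar'; a: number; scalar: number; out: number }

const ops: OpRecord[] = [] // stands in for the graph being traced
let nextId = 0

function fresh(shape: readonly number[]): Tensor {
  return { id: nextId++, shape }
}

function tensorInput(shape: readonly number[]): Tensor {
  const t = fresh(shape)
  ops.push({ kind: 'tensor_input', out: t.id })
  return t
}

// One user-facing function; the operand's runtime type picks the IR op kind.
function add(x: Tensor, y: Tensor | number): Tensor {
  if (typeof y === 'number') {
    const out = fresh(x.shape)
    // The scalar is stored in the op record itself, not as a separate input tensor.
    ops.push({ kind: 'add_scalar', a: x.id, scalar: y, out: out.id })
    return out
  }
  // Broadcasting omitted for brevity: this sketch requires equal shapes.
  if (x.shape.join(',') !== y.shape.join(',')) {
    throw new Error(`add: shape mismatch [${x.shape.join(', ')}] vs [${y.shape.join(', ')}]`)
  }
  const out = fresh(x.shape)
  ops.push({ kind: 'add', a: x.id, b: y.id, out: out.id })
  return out
}
```

With this shape, the overload costs nothing at run time. The choice between `add` and `add_scalar` is made once, during tracing.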
## IR

```ts
// Supporting types, assumed here so the block stands alone (f32 is the only dtype).
type Shape = readonly number[]
type Dtype = 'f32'
type CallSite = string // assumed representation of the captured user stack frame

interface Tensor {
  readonly id: number
  readonly shape: Shape
  readonly dtype: Dtype
  readonly source: number | null // op index, or null for leaves
  readonly site: CallSite | null // user's stack at op-call time
}

// Per-kind payload fields are elided; only the `kind` discriminants are shown.
type OpNode =
  | { kind: 'param_input' /* ... */ } | { kind: 'tensor_input' /* ... */ }
  | { kind: 'state_input' /* ... */ } | { kind: 'arange' /* ... */ }
  | { kind: 'const_scalar' /* ... */ }
  | { kind: 'add' | 'sub' | 'mul' | 'div' /* ... */ }
  | { kind: 'add_scalar' | 'mul_scalar' /* ... */ }
  | { kind: 'sqrt' | 'rsqrt' | 'log' | 'exp' | 'relu' /* ... */ }
  | { kind: 'less' | 'greater' /* ... */ } | { kind: 'where' /* ... */ }
  | { kind: 'mean_last' | 'sum_last' /* ... */ }
  | { kind: 'reshape' | 'transpose' | 'slice_last_range' /* ... */ }
  | { kind: 'broadcast_to' | 'sum_to_shape' /* ... */ }
  | { kind: 'matmul' | 'matmul_batched' /* ... */ }
  | { kind: 'one_hot' /* ... */ }
  | { kind: 'softmax_causal_last' | 'log_softmax_last' | 'where_causal' /* ... */ }
  | { kind: 'relu_grad' /* ... */ }
  | { kind: 'adam_update_m' | 'adam_update_v' | 'adam_update_p' /* ... */ }

interface Graph {
  readonly ops: OpNode[]
  readonly tensors: Tensor[]
  readonly outputs: number[] // tensor ids, typically just the loss
}
```

The op kinds are intentionally fine-grained (`mean_last`, not
`mean(axis)`) because each kind maps to a hand-written WGSL kernel. Adding
generality later is fine; pretending to be more general than we are isn't.

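Here is a sketch of how these structures fill up during tracing. The code appends one node per op call and returns a fresh handle whose `source` points at that node. It covers only three op kinds. The node fields (`name`, `inputs`, `out`) are assumed for illustration, not taken from tensorgrad's real `OpNode` definitions.

```typescript
type Shape = readonly number[]

interface Tensor {
  readonly id: number
  readonly shape: Shape
  readonly dtype: 'f32'
  readonly source: number | null // op index, or null for leaves
  readonly site: string | null // call site (not captured in this sketch)
}

// Simplified op nodes with assumed field names; the real OpNode has many more kinds.
type OpNode =
  | { kind: 'tensor_input'; name: string; out: number }
  | { kind: 'exp'; inputs: [number]; out: number }
  | { kind: 'matmul'; inputs: [number, number]; out: number }

interface Graph {
  readonly ops: OpNode[]
  readonly tensors: Tensor[]
  readonly outputs: number[]
}

const graph: Graph = { ops: [], tensors: [], outputs: [] }

// Appends one op node and returns a fresh handle for its output tensor.
function append(makeNode: (out: number) => OpNode, shape: Shape, leaf = false): Tensor {
  const id = graph.tensors.length
  const t: Tensor = { id, shape, dtype: 'f32', source: leaf ? null : graph.ops.length, site: null }
  graph.ops.push(makeNode(id))
  graph.tensors.push(t)
  return t
}

function tensorInput(name: string, shape: Shape): Tensor {
  return append((out) => ({ kind: 'tensor_input', name, out }), shape, true)
}

function exp(x: Tensor): Tensor {
  return append((out) => ({ kind: 'exp', inputs: [x.id], out }), x.shape)
}

// matmul with a 2D rhs: [..., k] @ [k, n] -> [..., n], validated while tracing.
function matmul(x: Tensor, w: Tensor): Tensor {
  const k = x.shape[x.shape.length - 1]
  if (w.shape.length !== 2 || w.shape[0] !== k) {
    throw new Error(`matmul: cannot multiply [${x.shape.join(', ')}] by [${w.shape.join(', ')}]`)
  }
  const outShape = [...x.shape.slice(0, -1), w.shape[1]]
  return append((out) => ({ kind: 'matmul', inputs: [x.id, w.id], out }), outShape)
}
```

Tracing `exp(matmul(x, w))` with `x: [4, 8]` and `w: [8, 3]` leaves four nodes in `graph.ops` and a `[4, 3]` output. A mismatched inner dimension throws before any GPU work happens.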
## Op set (current)

**Leaves:** `param_input`, `tensor_input`, `state_input`, `arange`, `const_scalar`

**Element-wise binops** (NumPy broadcasting): `add`, `sub`, `mul`, `div`,
plus fused `add_scalar`, `mul_scalar`

**Element-wise unary:** `sqrt`, `rsqrt`, `log`, `exp`, `relu`

**Comparisons + select:** `less`, `greater`, `where`

**Reductions over last axis:** `mean_last`, `sum_last`

**Shape:** `reshape`, `transpose`, `slice_last_range`, `broadcast_to`, `sum_to_shape`

**Linear algebra:** `matmul` (2D rhs), `matmul_batched` (both operands batched)

**Indexing / casting:** `one_hot`

**ML primitives** (fused for clean autograd): `softmax_causal_last`,
`log_softmax_last`, `where_causal`

**Autograd-internal:** `relu_grad`

**Adam-internal:** `adam_update_m`, `adam_update_v`, `adam_update_p`

## Module abstraction

The `Module` base class enables Domeleon-style auto-discovery of nested
modules and parameters via property reflection:

```ts
class Linear extends Module {
  W: Tensor
  b: Tensor
  constructor(public inDim: number, public outDim: number) {
    super()
    this.W = this.param([inDim, outDim]) // returns ParamSentinel cast to Tensor
    this.b = this.param([outDim])
  }
}
```

`this.param(shape)` returns a `ParamSentinel` typed as `Tensor`. At compile
time, `materializeParams(root)` walks the enumerable properties of the model
tree (recursing into nested `Module` instances and arrays of modules). It
replaces every sentinel with a real `paramInput` tensor whose name is the
property path (`layers.0.attn.q.W`), and returns a flat
`Record<path, Tensor>` for autograd to use.

This is the JAX/Equinox separation: parameter storage is mutable
(state-bearing components), while forward computation is pure (functions over
materialized parameters and inputs).

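The sentinel-replacement walk can be sketched as follows. The sketch assumes `ParamSentinel` is a plain marker class. It represents a materialized parameter as a `{ name, shape }` record (`ParamTensor`, hypothetical) instead of a real `paramInput` IR tensor. The traversal order and the handling of non-module arrays are also assumptions.

```typescript
type Shape = readonly number[]

class ParamSentinel {
  readonly shape: Shape
  constructor(shape: Shape) {
    this.shape = shape
  }
}

// Hypothetical stand-in for a materialized paramInput tensor.
interface ParamTensor {
  readonly name: string
  readonly shape: Shape
}

class Module {
  // Returns a sentinel typed as a tensor, mirroring the Linear example above.
  protected param(shape: Shape): ParamTensor {
    return new ParamSentinel(shape) as unknown as ParamTensor
  }
}

// Walks enumerable properties, recursing into modules and arrays, and replaces
// each sentinel in place with a tensor named by its dotted property path.
function materializeParams(root: Module): Record<string, ParamTensor> {
  const out: Record<string, ParamTensor> = {}
  const visit = (obj: object, prefix: string): void => {
    const rec = obj as Record<string, unknown>
    for (const key of Object.keys(rec)) {
      const value = rec[key]
      const path = prefix ? `${prefix}.${key}` : key
      if (value instanceof ParamSentinel) {
        const t: ParamTensor = { name: path, shape: value.shape }
        rec[key] = t
        out[path] = t
      } else if (value instanceof Module || Array.isArray(value)) {
        visit(value, path)
      }
    }
  }
  visit(root, '')
  return out
}

class Linear extends Module {
  W: ParamTensor
  b: ParamTensor
  constructor(inDim: number, outDim: number) {
    super()
    this.W = this.param([inDim, outDim])
    this.b = this.param([outDim])
  }
}

class MLP extends Module {
  layers = [new Linear(4, 8), new Linear(8, 2)]
}
```

Array indices become path segments because `Object.keys` on an array yields `'0'`, `'1'`, and so on. That is how names like `layers.0.W` arise.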
## WGSL codegen

Each op kind has a kernel template in `codegen.ts`. Shapes are **baked into
the WGSL as compile-time constants** rather than passed as uniforms. This
gives the WGSL compiler full freedom to specialize, at the cost of a distinct
shader for each shape combination. That cost is acceptable under our
static-shape model.

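A hypothetical template for `relu` shows what baking shapes in means. The element count becomes a WGSL `const`, so each distinct size yields distinct shader source. The binding layout, the `main` entry point, and `reluKernel` itself are assumptions, not code from `codegen.ts`.

```typescript
const WORKGROUP_SIZE = 256

// Emits WGSL for an element-wise relu over `numel` f32 values.
// The element count is a module-scope WGSL const, not a uniform.
function reluKernel(numel: number): string {
  return `
const NUMEL: u32 = ${numel}u;

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> y: array<f32>;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  // Fold the 2D dispatch back into one linear index.
  let i = gid.x + gid.y * (65535u * ${WORKGROUP_SIZE}u);
  if (i >= NUMEL) { return; }
  y[i] = max(x[i], 0.0);
}
`
}
```

A pipeline cache keyed on the generated source would therefore hold one pipeline per (op kind, shape) pair.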
**Dispatch:** WebGPU caps each dispatch dimension at 65535 workgroups (the
default `maxComputeWorkgroupsPerDimension` limit). For a kernel that needs N
workgroups, the runtime dispatches `(min(N, 65535), ceil(N/65535), 1)`.
Kernels then compute their global index as
`gid.x + gid.y * (65535 * workgroup_size)`. The workgroup size is 256. That is
large enough that even our biggest kernel (~8M threads in `matmul_bwd_dW`, or
roughly 32K workgroups) fits in the x dimension alone.

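The host-side arithmetic can be sketched as below. `dispatchSize` and `globalIndex` are hypothetical helpers that mirror the formulas above, using a workgroup size of 256 and 65535 as the per-dimension cap.

```typescript
const MAX_PER_DIM = 65535
const WORKGROUP_SIZE = 256

// Workgroup counts for `numel` threads, folded into (x, y) when x would overflow.
function dispatchSize(numel: number): [number, number, number] {
  const groups = Math.ceil(numel / WORKGROUP_SIZE)
  return [Math.min(groups, MAX_PER_DIM), Math.ceil(groups / MAX_PER_DIM), 1]
}

// Mirrors the kernel-side index; gid.x ranges over MAX_PER_DIM * WORKGROUP_SIZE threads.
function globalIndex(gidX: number, gidY: number): number {
  return gidX + gidY * (MAX_PER_DIM * WORKGROUP_SIZE)
}
```

For an 8M-thread (8 × 1024 × 1024) kernel, `dispatchSize` returns `[32768, 1, 1]`, so the y fold is unused. A 20M-thread kernel needs `[65535, 2, 1]`, and the kernel must bounds-check the surplus threads in the last row.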
**Error reporting:** `runtime.ts` wraps each pipeline creation in
`pushErrorScope('validation')` and pulls `getCompilationInfo()` on
failure, so shader bugs surface with file/line/message rather than the
useless "previous error" you get when a broken pipeline is dispatched.

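Below is a sketch of that pattern. To stay self-contained, it types the device against a small structural subset of the WebGPU API (`pushErrorScope`, `popErrorScope`, `createShaderModule`, `getCompilationInfo`, `createComputePipeline`) rather than `@webgpu/types`. `createCheckedPipeline`, the `'main'` entry point, and the `label:line:col` message format are assumptions.

```typescript
// Minimal structural subset of the WebGPU types used below.
interface CompilationMessage {
  type: 'error' | 'warning' | 'info'
  lineNum: number
  linePos: number
  message: string
}
interface ShaderModuleLike {
  getCompilationInfo(): Promise<{ messages: readonly CompilationMessage[] }>
}
interface DeviceLike<P> {
  pushErrorScope(filter: 'validation' | 'out-of-memory' | 'internal'): void
  popErrorScope(): Promise<{ message: string } | null>
  createShaderModule(desc: { code: string }): ShaderModuleLike
  createComputePipeline(desc: {
    layout: 'auto'
    compute: { module: ShaderModuleLike; entryPoint: string }
  }): P
}

// Creates a pipeline inside a validation scope. On failure it reports the shader
// compiler's own messages instead of a later, uninformative "previous error".
async function createCheckedPipeline<P>(
  device: DeviceLike<P>,
  label: string,
  code: string,
): Promise<P> {
  device.pushErrorScope('validation')
  const module = device.createShaderModule({ code })
  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: { module, entryPoint: 'main' },
  })
  const error = await device.popErrorScope()
  if (error) {
    const info = await module.getCompilationInfo()
    const details = info.messages
      .filter((m) => m.type === 'error')
      .map((m) => `${label}:${m.lineNum}:${m.linePos}: ${m.message}`)
      .join('\n')
    throw new Error(details || `${label}: ${error.message}`)
  }
  return pipeline
}
```

Labelling each pipeline with its op kind makes the thrown message point at the kernel template that produced the bad WGSL.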
## Buffer plan

`planBuffers(graph, paramGrads, writebacks)` walks every tensor and
categorizes it:

| Kind | Purpose | Lifetime |
|---|---|---|
| `param` | trainable parameter | persistent |
| `param_grad` | gradient w.r.t. a param | one step |
| `state` | optimizer state (Adam m, v) | persistent |
| `tensor_input` | data input (tokens, targets) | one step |
| `intermediate` | any other op output | one step |
| `output` | exposed graph output (loss) | one step |

State buffers are zero-initialized at runtime creation. Writebacks (declared
by `appendAdam`) describe end-of-step `copyBufferToBuffer` operations from
freshly-computed values into their persistent homes.

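A simplified sketch of the categorization and writeback resolution follows. It assumes each tensor is summarized by the kind of the op that produced it plus its element count. It also assumes writebacks arrive as `[source, destination]` tensor-id pairs. This `planBuffers` signature is hypothetical and differs from the real one.

```typescript
type BufferKind = 'param' | 'param_grad' | 'state' | 'tensor_input' | 'intermediate' | 'output'

interface TensorInfo {
  id: number
  sourceKind: string // kind of the op that produced the tensor
  numel: number
}

interface BufferPlan {
  kinds: Map<number, BufferKind>
  byteSize: Map<number, number>
  copies: Array<{ src: number; dst: number; bytes: number }> // end-of-step copyBufferToBuffer
}

function planBuffers(
  tensors: TensorInfo[],
  paramGrads: Set<number>, // ids of tensors holding d(loss)/d(param)
  outputs: Set<number>,
  writebacks: Array<[src: number, dst: number]>,
): BufferPlan {
  const kinds = new Map<number, BufferKind>()
  const byteSize = new Map<number, number>()
  for (const t of tensors) {
    let kind: BufferKind
    if (t.sourceKind === 'param_input') kind = 'param'
    else if (t.sourceKind === 'state_input') kind = 'state'
    else if (t.sourceKind === 'tensor_input') kind = 'tensor_input'
    else if (paramGrads.has(t.id)) kind = 'param_grad'
    else if (outputs.has(t.id)) kind = 'output'
    else kind = 'intermediate'
    kinds.set(t.id, kind)
    byteSize.set(t.id, t.numel * 4) // f32 only: 4 bytes per element
  }
  // Each writeback becomes a same-size buffer-to-buffer copy.
  const copies = writebacks.map(([src, dst]) => {
    const bytes = byteSize.get(src)
    if (bytes === undefined || bytes !== byteSize.get(dst)) {
      throw new Error(`writeback ${src} -> ${dst}: missing tensor or size mismatch`)
    }
    return { src, dst, bytes }
  })
  return { kinds, byteSize, copies }
}
```

Summing `byteSize` gives the fixed memory footprint that D3 relies on. With no buffer reuse, it is 4 bytes times the total element count of every IR tensor.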
## Autograd

`appendGrad(graph)` walks the forward ops in reverse and emits backward ops
into the same graph. Each op's transpose rule is hand-written in
`grad.ts`. The cotangents map (`tensorId → Tensor`) accumulates
contributions from multiple consumers via `add`.

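The reverse walk and cotangent accumulation can be illustrated numerically on a tiny tape of scalar `add`/`mul` ops. This CPU sketch shows the principle only. tensorgrad emits backward *ops* into the graph instead of computing numbers, and `TapeOp` and `backward` are hypothetical.

```typescript
type TapeOp =
  | { kind: 'add'; a: number; b: number; out: number }
  | { kind: 'mul'; a: number; b: number; out: number }

// Returns d(output)/d(value) for every value id reachable backwards from `output`.
function backward(ops: TapeOp[], values: number[], output: number): Map<number, number> {
  const cot = new Map<number, number>([[output, 1]])
  // Accumulate by addition: a value read by several ops receives several contributions.
  const accum = (id: number, g: number) => cot.set(id, (cot.get(id) ?? 0) + g)
  for (let i = ops.length - 1; i >= 0; i--) {
    const op = ops[i]
    const g = cot.get(op.out)
    if (g === undefined) continue // this op does not influence the output
    if (op.kind === 'add') {
      accum(op.a, g)
      accum(op.b, g)
    } else {
      // Transpose rule for mul: each input's cotangent is g times the other input.
      accum(op.a, g * values[op.b])
      accum(op.b, g * values[op.a])
    }
  }
  return cot
}
```

For `f(x, y) = x * y + x`, `x` feeds two ops, so its cotangent is the sum `y + 1`. At x = 3, y = 4, the sketch returns 5 for `x` and 3 for `y`.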
Two notable workarounds:

- **Embedding lookup is implemented as `oneHot @ table`** rather than
  `gather`. Gather has no transpose rule (it would need a scatter with
  atomic add, or similar), but `oneHot @ table` decomposes into ops that *do*
  have rules, so autograd works through it for free.
- **`slice_last_range` has no backward yet.** The forward pass works (it is
  used in any axis-2 slicing pattern). The backward pass is unimplemented
  because it would need a scatter-style "place into zeros" op. Workaround: use
  separate matmuls (e.g. `W_q`/`W_k`/`W_v`) instead of a fused `W_qkv`.

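The embedding workaround can be checked with plain arrays. `oneHot(ids) @ table` selects the same rows a gather would. The matmul transpose rule, `transpose(oneHot(ids)) @ dOut`, yields the scatter-add that a gather backward would otherwise need. The helpers operate on CPU arrays and are hypothetical, not tensorgrad ops.

```typescript
type Matrix = number[][]

function oneHot(ids: number[], vocab: number): Matrix {
  return ids.map((id) => Array.from({ length: vocab }, (_, j) => (j === id ? 1 : 0)))
}

function matmul(a: Matrix, b: Matrix): Matrix {
  return a.map((row) => b[0].map((_, j) => row.reduce((sum, aik, k) => sum + aik * b[k][j], 0)))
}

function transpose(a: Matrix): Matrix {
  return a[0].map((_, j) => a.map((row) => row[j]))
}

// vocab = 3, embedding dim = 2
const table: Matrix = [
  [0.1, 0.2],
  [1.0, 2.0],
  [3.0, 4.0],
]
const ids = [2, 0, 2]

// Forward: the same rows a gather of table[2], table[0], table[2] would return.
const embedded = matmul(oneHot(ids, 3), table)

// Backward w.r.t. the table: rows of dOut for repeated ids are summed into one
// table row, which is a scatter-add without atomics.
const dOut: Matrix = [
  [1, 1],
  [10, 10],
  [100, 100],
]
const dTable = matmul(transpose(oneHot(ids, 3)), dOut)
```

The price is an `ids.length × vocab` one-hot matrix and a dense matmul. That is wasteful for large vocabularies, but it keeps everything inside the closed, differentiable op set.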
## Verification approach

Two layers:

1. **Smoke test** (`pnpm test` → `test/smoke.ts`): runs in Node without a
   GPU. Builds the IR, attaches grad, plans buffers, emits all WGSL, and
   verifies structure (kernel count, binding count, shape errors). Catches
   codegen regressions without needing a browser.
2. **Live samples** (`pnpm --filter samples dev`): a Vite dev server with
   a `/__log` endpoint that streams browser logs to stdout. It is used during
   development so that debugging doesn't require copy-pasting from the
   browser console.

## What this spec is not

This spec is not a contract. The API will change before 1.0. The load-bearing
decisions are in **Scope and non-goals** and **Key design decisions** above.
Those are the points where the design deliberately diverges from JAX or
PyTorch, and where reverting any of them effectively re-creates the
failure mode that motivated this library.