tensorgrad 0.0.4 → 0.0.7

package/src/grad.ts CHANGED
@@ -18,7 +18,7 @@
  import type { Graph, OpNode, Tensor, Shape } from './ir.js'
  import {
  add, sub, mul, div, mulScalar,
- matmul, matmulBatched, transpose, reshape,
+ matmul, matmulBatched, transpose, swapAxes, reshape,
  exp,
  broadcastTo, sumToShape,
  constScalar, reluGrad,
@@ -280,14 +280,10 @@ function runTransposeRule(
  // leading batch dims to get [K, N].
  const a = tensorOf(op.a), b = tensorOf(op.b)
  // dA = dC @ B^T
- const bT = transpose(b, [1, 0])
- accumulate(cotangents, op.a, matmul(outCotan, bT))
+ accumulate(cotangents, op.a, matmul(outCotan, swapAxes(b, -1, -2)))
  // dB: per-batch A^T @ dC, then sum over batch dims.
  // A is [..., M, K]; transpose last two axes.
- const aTPerm = identityPerm(a.shape.length)
- ;[aTPerm[a.shape.length - 1], aTPerm[a.shape.length - 2]] =
- [aTPerm[a.shape.length - 2]!, aTPerm[a.shape.length - 1]!]
- const aT = transpose(a, aTPerm) // [..., K, M]
+ const aT = swapAxes(a, -1, -2) // [..., K, M]
  // matmul_batched needs same rank on both sides. dC has rank `a.rank`;
  // aT has rank `a.rank`; use matmul_batched if rank > 2, else matmul.
  let perBatchDb: Tensor
@@ -305,15 +301,8 @@ function runTransposeRule(
  // dA = dC @ B^T (per-batch, all batch dims preserved)
  // dB = A^T @ dC (per-batch)
  const a = tensorOf(op.a), b = tensorOf(op.b)
- const lastTwoSwap = (rank: number) => {
- const p = identityPerm(rank)
- ;[p[rank - 1], p[rank - 2]] = [p[rank - 2]!, p[rank - 1]!]
- return p
- }
- const bT = transpose(b, lastTwoSwap(b.shape.length))
- const aT = transpose(a, lastTwoSwap(a.shape.length))
- accumulate(cotangents, op.a, matmulBatched(outCotan, bT))
- accumulate(cotangents, op.b, matmulBatched(aT, outCotan))
+ accumulate(cotangents, op.a, matmulBatched(outCotan, swapAxes(b, -1, -2)))
+ accumulate(cotangents, op.b, matmulBatched(swapAxes(a, -1, -2), outCotan))
  return
  }
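Editorial note (not part of the diff): both rewritten branches implement the standard matmul gradient, dA = dC @ B^T and dB = A^T @ dC, with `swapAxes(x, -1, -2)` standing in for "transpose the last two axes" in place of the hand-built permutations. A minimal sketch of the same rule against the public exports, assuming the `Tensor` type is importable from the package root:

```ts
import { matmul, swapAxes } from 'tensorgrad'
import type { Tensor } from 'tensorgrad' // assumed type re-export

// C = A @ B with A: [..., M, K], B: [..., K, N], upstream cotangent dC: [..., M, N].
function matmulGrads(a: Tensor, b: Tensor, dC: Tensor): { dA: Tensor; dB: Tensor } {
  const dA = matmul(dC, swapAxes(b, -1, -2)) // dC @ B^T -> [..., M, K]
  const dB = matmul(swapAxes(a, -1, -2), dC) // A^T @ dC -> [..., K, N]
  return { dA, dB }
}
```

In the library itself the batched branch uses `matmulBatched` for rank > 2 and then sums `dB` over the leading batch dims, as the surrounding code in this hunk notes.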
 
package/src/index.ts CHANGED
@@ -15,13 +15,13 @@ export {
  // Comparisons + select
  less, greater, where,
  // Reductions over the last axis (other axes via reshape/transpose first)
- meanLast, sumLast,
+ meanLast, sumLast, sumAll,
  // Shape ops
- reshape, transpose,
+ reshape, transpose, swapAxes,
  // Linear algebra
  matmul, matmulBatched,
  // Indexing / casting
- oneHot, arange,
+ oneHot, arange, embedding,
  // ML primitives — fused for the transformer
  softmaxCausalLast, logSoftmaxLast, whereCausal,
  // Slicing
package/src/ir.ts CHANGED
@@ -113,9 +113,21 @@ export type OpNode =
  // `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
  // already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
  // `decayShrink` is the decoupled-weight-decay factor (Loshchilov & Hutter,
- // "AdamW") baked at compile time: 1 - lr * weightDecay when the param is being
- // decayed, 1 otherwise. eps and decayShrink are both baked into the kernel.
- | { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number; decayShrink: number }
+ // "AdamW"): 1 - lr * weightDecay when the param is being decayed, 1 otherwise.
+ // It can be either a compile-time literal (number) for fixed-lr training, or a
+ // tensor id pointing at a scalar input that the runtime updates per step (used
+ // when the user supplies an lr schedule via `adam: { lr: (step) => ... }`).
+ | {
+ kind: 'adam_update_p'
+ out: number
+ p: number
+ mNew: number
+ vNew: number
+ lrt: number
+ eps: number
+ decayShrink: number // literal (used when decayShrinkTensor is null)
+ decayShrinkTensor: number | null // tensor id of a scalar input; takes precedence when set
+ }

  // ---- Slicing / broadcasting / autograd infrastructure -------------------
  // Slice [start, end) along the last axis. Output shape: input shape with
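Editorial note (not part of the diff): the node now carries both fields, and `decayShrinkTensor` wins when it is set. A sketch of the two forms, with made-up tensor ids:

```ts
// Fixed lr: decayShrink is baked as a literal and no per-step tensor exists.
const fixedLr = {
  kind: 'adam_update_p', out: 41, p: 7, mNew: 38, vNew: 39, lrt: 12,
  eps: 1e-8, decayShrink: 1 - 3e-4 * 0.1, decayShrinkTensor: null,
}

// lr schedule: the runtime rewrites the scalar input (id 13) each step and the
// kernel reads it; the literal field stays at its neutral value of 1.
const scheduledLr = {
  kind: 'adam_update_p', out: 41, p: 7, mNew: 38, vNew: 39, lrt: 12,
  eps: 1e-8, decayShrink: 1, decayShrinkTensor: 13,
}
```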
package/src/module.ts CHANGED
@@ -52,6 +52,11 @@ export interface ParamOptions {
  init?: InitSpec
  /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
  scale?: number
+ /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
+ * decay to this param. Default: `true` for `'randn'` init (weight matrices,
+ * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
+ * to force or skip. Replaces `adam.decayFilter` for the common case. */
+ decay?: boolean
  }

  type InitFn = (size: number, shape: readonly number[]) => Float32Array
@@ -76,6 +81,17 @@ function resolveInit(opts: ParamOptions | undefined): InitFn {
  throw new Error(`Unknown init: ${String(init)}`)
  }

+ /** Resolve the decay default for a param. Decay weight matrices and
+ * embedding tables (randn-initialized); skip biases (zeros) and LN gains
+ * (ones). Custom init functions default to "decay" — most user-supplied
+ * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
+ function resolveDecay(opts: ParamOptions | undefined): boolean {
+ if (opts?.decay !== undefined) return opts.decay
+ const init = opts?.init ?? 'randn'
+ if (init === 'zeros' || init === 'ones') return false
+ return true // 'randn' or function
+ }
+
  // ============================================================================
  // Internals: param sentinel
  // ============================================================================
@@ -90,6 +106,7 @@ class ParamSentinel {
  public readonly shape: Shape,
  public readonly dtype: Dtype,
  public readonly initFn: InitFn,
+ public readonly decay: boolean,
  ) {}
  }

@@ -110,7 +127,7 @@ export abstract class Module {
  protected param(shape: Shape, opts?: ParamOptions): Tensor {
  const dtype = opts?.dtype ?? 'f32'
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
- return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
+ return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
  }
  }

@@ -123,6 +140,9 @@ export interface MaterializedParams {
  tensors: Record<string, Tensor>
  /** Init function per param path. Used by `uploadInitialParams`. */
  initFns: Record<string, InitFn>
+ /** Whether this param should receive AdamW weight decay. Resolved at
+ * `param()` time from `ParamOptions.decay` (with init-based default). */
+ decayFlags: Record<string, boolean>
  }

  /**
@@ -136,15 +156,17 @@ export interface MaterializedParams {
  export function materializeParams(root: Module): MaterializedParams {
  const tensors: Record<string, Tensor> = {}
  const initFns: Record<string, InitFn> = {}
+ const decayFlags: Record<string, boolean> = {}
  visit(root, '', (path, val, owner, key) => {
  if (val instanceof ParamSentinel) {
  const t = paramInput(path, val.shape, val.dtype)
  ;(owner as any)[key] = t
  tensors[path] = t
  initFns[path] = val.initFn
+ decayFlags[path] = val.decay
  }
  })
- return { tensors, initFns }
+ return { tensors, initFns, decayFlags }
  }

  // ----------------------------------------------------------------------------
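Editorial note (not part of the diff): how the new `decay` default plays out when declaring params. The module below is a sketch; the import paths are assumptions, and the resolved paths in `decayFlags` are illustrative.

```ts
import { Module, materializeParams } from 'tensorgrad' // assumed re-exports; the diff only shows module.ts
import type { Tensor } from 'tensorgrad'               // assumed type re-export

class TinyBlock extends Module {
  w: Tensor = this.param([64, 64])                      // 'randn' default -> decayed by AdamW
  b: Tensor = this.param([64], { init: 'zeros' })       // bias -> not decayed
  g: Tensor = this.param([64], { init: 'ones' })        // LN gain -> not decayed
  emb: Tensor = this.param([100, 64], { decay: false }) // explicit override beats the randn default
}

// materializeParams resolves the flags alongside the tensors, roughly:
// decayFlags ≈ { w: true, b: false, g: false, emb: false }
const { decayFlags } = materializeParams(new TinyBlock())
```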
package/src/nn.ts CHANGED
@@ -15,7 +15,9 @@

  import { Module } from './module.js'
  import type { Tensor } from './ir.js'
- import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
+ import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
+ import { ShapeError } from './shape.js'
+ import { captureSite } from './ir.js'

  // ----------------------------------------------------------------------------
  // Linear: y = x @ W (+ b)
@@ -57,3 +59,60 @@ export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
  const stdev = sqrt(add(v, p.eps))
  return add(mul(div(c, stdev), p.g), p.b)
  }
+
+ // ----------------------------------------------------------------------------
+ // Multi-head attention shape helpers — split the last (model) axis into
+ // [nHeads, headDim] and bring heads ahead of the sequence axis.
+ // ----------------------------------------------------------------------------
+
+ /** [..., T, D] → [..., H, T, D/H]. Folds the standard
+ * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
+ * call. Last dim of `x` must divide evenly by `nHeads`. */
+ export function splitHeads(x: Tensor, nHeads: number): Tensor {
+ const site = captureSite('splitHeads')
+ const r = x.shape.length
+ if (r < 2) throw new ShapeError(`splitHeads: requires rank >= 2, got ${r}`, site)
+ const T = x.shape[r - 2]!
+ const D = x.shape[r - 1]!
+ if (D % nHeads !== 0) {
+ throw new ShapeError(`splitHeads: last dim ${D} not divisible by nHeads ${nHeads}`, site)
+ }
+ const lead = x.shape.slice(0, r - 2)
+ const reshaped = reshape(x, [...lead, T, nHeads, D / nHeads])
+ // Swap T (axis lead.length) with H (axis lead.length + 1).
+ return swapAxes(reshaped, lead.length, lead.length + 1)
+ }
+
+ /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
+ export function mergeHeads(x: Tensor): Tensor {
+ const site = captureSite('mergeHeads')
+ const r = x.shape.length
+ if (r < 3) throw new ShapeError(`mergeHeads: requires rank >= 3, got ${r}`, site)
+ const H = x.shape[r - 3]!
+ const T = x.shape[r - 2]!
+ const d = x.shape[r - 1]!
+ const lead = x.shape.slice(0, r - 3)
+ // Swap H (axis r-3) and T (axis r-2): [..., H, T, d] → [..., T, H, d]
+ const swapped = swapAxes(x, r - 3, r - 2)
+ return reshape(swapped, [...lead, T, H * d])
+ }
+
+ // ----------------------------------------------------------------------------
+ // Loss helpers
+ // ----------------------------------------------------------------------------
+
+ /** Per-position cross-entropy along the last (vocab) axis: returns
+ * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
+ * `[...]` of i32; result is `[...]` (one rank less than logits). The user
+ * applies their own masking + reduction downstream — useful when only some
+ * positions contribute (e.g. result-digit masking) or for label smoothing. */
+ export function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor {
+ const site = captureSite('crossEntropyLast')
+ if (targets.dtype !== 'i32') {
+ throw new ShapeError(`crossEntropyLast: targets must be i32, got ${targets.dtype}`, site)
+ }
+ const vocab = logits.shape[logits.shape.length - 1]!
+ const lp = logSoftmaxLast(logits) // [..., V]
+ const targetLp = sumLast(mul(lp, oneHot(targets, vocab, 'f32'))) // [...]
+ return mul(targetLp, -1)
+ }
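Editorial note (not part of the diff): the shape round-trip the two head helpers are meant for, plus the masked-loss pattern `crossEntropyLast` leaves to the caller. Import paths are assumptions; these are nn.ts exports, and whether they are re-exported at the package root is not shown in this diff.

```ts
import type { Tensor } from 'tensorgrad' // assumed type re-export
import { matmulBatched, swapAxes, mul, div, sumAll } from 'tensorgrad' // mul/div assumed re-exported
import { splitHeads, mergeHeads, crossEntropyLast } from 'tensorgrad'  // assumed re-exports

// [B, T, D] -> heads -> scores -> context -> back to [B, T, D].
// Scaling and the causal softmax are omitted; this shows only the reshape pattern.
function headsRoundTrip(q: Tensor, k: Tensor, v: Tensor, nHeads: number): Tensor {
  const qh = splitHeads(q, nHeads)                         // [B, H, T, d]
  const kh = splitHeads(k, nHeads)                         // [B, H, T, d]
  const scores = matmulBatched(qh, swapAxes(kh, -1, -2))   // [B, H, T, T]
  const ctx = matmulBatched(scores, splitHeads(v, nHeads)) // [B, H, T, d]
  return mergeHeads(ctx)                                   // [B, T, H*d] = [B, T, D]
}

// Masked mean of the per-position loss; mask is 0/1 with the same [B, T] shape as the result.
function maskedCrossEntropy(logits: Tensor, targets: Tensor, mask: Tensor): Tensor {
  const nll = crossEntropyLast(logits, targets) // [B, T]
  return div(sumAll(mul(nll, mask)), sumAll(mask))
}
```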
package/src/ops.ts CHANGED
@@ -17,7 +17,7 @@ import {
  inferReshape, inferTranspose, inferMatmul, inferMatmulBatched,
  inferOneHot, inferWhereCausal, inferSliceLastRange,
  inferBroadcastTo, inferSumToShape, inferReluGrad, inferWhere,
- ShapeError,
+ ShapeError, showShape,
  } from './shape.js'

  // ----------------------------------------------------------------------------
@@ -112,6 +112,11 @@ export function sumLast(a: Tensor): Tensor {
  return addOp(currentGraph(), 'sum_last', outShape, a.dtype, site, { a: a.id })
  }

+ /** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
+ export function sumAll(a: Tensor): Tensor {
+ return sumLast(reshape(a, [-1]))
+ }
+
  // ----------------------------------------------------------------------------
  // Shape ops.
  // ----------------------------------------------------------------------------
@@ -128,6 +133,26 @@ export function transpose(a: Tensor, perm: readonly number[]): Tensor {
  return addOp(currentGraph(), 'transpose', outShape, a.dtype, site, { a: a.id, perm })
  }

+ /** Swap two axes of a tensor. Negative indices count from the end (so
+ * `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
+ * All other axes keep their position. Implemented as `transpose` with the
+ * permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
+ export function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor {
+ const r = a.shape.length
+ const norm = (axis: number): number => axis < 0 ? r + axis : axis
+ const i1 = norm(axis1)
+ const i2 = norm(axis2)
+ const site = captureSite('swapAxes')
+ if (i1 < 0 || i1 >= r || i2 < 0 || i2 >= r) {
+ throw new ShapeError(`swapAxes: axis out of range — got (${axis1}, ${axis2}) for rank-${r} tensor`, site)
+ }
+ if (i1 === i2) return a
+ const perm = Array.from({ length: r }, (_, k) => k)
+ perm[i1] = i2
+ perm[i2] = i1
+ return transpose(a, perm)
+ }
+
  // ----------------------------------------------------------------------------
  // Linear algebra.
  // ----------------------------------------------------------------------------
@@ -163,6 +188,22 @@ export function oneHot(indices: Tensor, depth: number, dtype: Dtype = 'f32'): Te
  return addOp(currentGraph(), 'one_hot', outShape, dtype, site, { indices: indices.id, depth, dtype })
  }

+ /** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
+ to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
+ scatter-with-atomic-add backward — the matmul transpose rule handles it.
+ `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
+ is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
+ export function embedding(table: Tensor, indices: Tensor): Tensor {
+ const site = captureSite('embedding')
+ if (table.shape.length !== 2) {
+ throw new ShapeError(`embedding: table must be 2-d [vocab, dim], got ${showShape(table.shape)}`, site)
+ }
+ if (indices.dtype !== 'i32') {
+ throw new ShapeError(`embedding: indices must be i32, got ${indices.dtype}`, site)
+ }
+ return matmul(oneHot(indices, table.shape[0]!, 'f32'), table)
+ }
+

  // arange(n) → [n] of values [0, 1, ..., n-1]. Used for position embeddings.
  export function arange(n: number, dtype: Dtype = 'i32'): Tensor {
@@ -297,7 +338,14 @@ export function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor {
  return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 })
  }

- export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink: number = 1): Tensor {
+ export function adamUpdateP(
+ p: Tensor,
+ mNew: Tensor,
+ vNew: Tensor,
+ lrt: Tensor,
+ eps: number,
+ decayShrink: number | Tensor = 1,
+ ): Tensor {
  const site = captureSite('adamUpdateP')
  if (p.dtype !== 'f32') throw new ShapeError(`adamUpdateP: requires f32`, site)
  if (lrt.dtype !== 'f32' || lrt.shape.length !== 0) {
@@ -306,6 +354,22 @@ export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor,
  if (p.shape.length !== mNew.shape.length || p.shape.some((d, i) => d !== mNew.shape[i])) {
  throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site)
  }
- return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site,
- { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps, decayShrink })
+ // decayShrink is either a literal (baked into the kernel) or a 0-d scalar
+ // tensor input the runtime updates per step. The kernel binds at most one,
+ // chosen by whichever the caller provided.
+ const isTensor = typeof decayShrink === 'object'
+ if (isTensor) {
+ if (decayShrink.dtype !== 'f32' || decayShrink.shape.length !== 0) {
+ throw new ShapeError(`adamUpdateP: decayShrink tensor must be a 0-d f32 scalar`, site)
+ }
+ }
+ return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site, {
+ p: p.id,
+ mNew: mNew.id,
+ vNew: vNew.id,
+ lrt: lrt.id,
+ eps,
+ decayShrink: isTensor ? 1 : decayShrink,
+ decayShrinkTensor: isTensor ? decayShrink.id : null,
+ })
  }
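Editorial note (not part of the diff): the new ops as a caller might use them. Only the function names, shapes, and the decomposition comments come from the diff; import paths and the training-graph context are assumptions.

```ts
import { embedding, sumAll, mul } from 'tensorgrad' // mul assumed re-exported
import type { Tensor } from 'tensorgrad'            // assumed type re-export

// Token embedding for a [B, T] i32 batch; table is a [V, D] param. Internally
// this is oneHot(tokens, V) @ table, so the matmul transpose rule produces the
// table gradient without a scatter kernel.
function tokenEmbed(table: Tensor, tokens: Tensor): Tensor {
  return embedding(table, tokens) // [B, T, D]
}

// Collapse a masked per-position loss to the 0-d scalar a training graph
// returns; sumAll is reshape([-1]) + sumLast under the hood.
function maskedSum(perPosLoss: Tensor, mask: Tensor): Tensor {
  return sumAll(mul(perPosLoss, mask))
}

// AdamW with a scheduled lr: pass a 0-d f32 tensor as decayShrink and the op
// records it as decayShrinkTensor for the runtime to update each step
// (hypothetical call; adamUpdateP is an ops.ts export, root re-export not shown):
// adamUpdateP(p, mNew, vNew, lrt, 1e-8, decayShrinkScalar)
```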
package/src/runtime.ts CHANGED
@@ -43,17 +43,38 @@ export interface RunWithCaptures {
  captures: Record<string, Float32Array>
  }

- export interface CompiledRuntime {
- /** Map of param name -> the underlying GPUBuffer. Pass to a sibling compile
- * via `sharedParams` to share without copies; every step on this runtime
- * is immediately visible to anyone reading these buffers. */
+ /** Common surface for both training and forward-only compiled runtimes. */
+ export interface CompiledBase {
+ /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
+ * `sharedParams` to share without copies. */
  params: Map<string, GPUBuffer>
+ /** Shape of each tensor registered via `capture(name, t)`. Static after
+ * compile — reshape readbacks without recomputing strides. */
+ captureShapes: Record<string, number[]>
+ /** Shape of the graph's output (loss scalar `[]` for training; the user's
+ * returned tensor for forward-only compiles). */
+ outputShape: number[]
  /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
  * *all* params to be present; throws on any unknown or missing key. Pass
  * `{ partial: true }` to skip the missing-key check. */
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
  /** Read all parameters back as Float32Arrays — used for UI panels. */
  downloadParams(): Promise<Record<string, Float32Array>>
+ /** Free GPU resources. */
+ destroy(): void
+ }
+
+ /** Run a dispatch and read back the full output tensor (and any registered
+ * captures if requested). Forward-only compiles use this as their primary
+ * surface; training compiles also expose it but `step()` is more convenient
+ * there because the output is a scalar loss. */
+ export interface RunFn {
+ (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+ (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
+ (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
+ }
+
+ export interface CompiledRuntime extends CompiledBase {
  /** Read all parameter gradients back. Mostly for verification / debugging. */
  downloadParamGrads(): Promise<Record<string, Float32Array>>
  /**
@@ -68,32 +89,19 @@ export interface CompiledRuntime {
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
- /** Like `step()` but returns the full output Float32Array instead of just
- * its first element. For training graphs this is rarely useful (the output
- * *is* a scalar loss); it's the primary API for forward-only compiles. */
- run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
- run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
- run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
+ /** Same dispatch as step() but returns the full output Float32Array; for
+ * training graphs the output is a scalar loss, so step() is usually more
+ * convenient. Provided for parity with `compileForward`. */
+ run: RunFn
  /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
  * `uploadInitialParams()` for a full training reset without recompile. */
  resetOptimizerState(): void
- /** Free GPU resources. */
- destroy(): void
  }

  /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
  * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
- export interface CompiledForward {
- params: Map<string, GPUBuffer>
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
- downloadParams(): Promise<Record<string, Float32Array>>
- /** Forward-only dispatch. Returns the graph's output tensor as a Float32Array
- * (the user's returned tensor from the forward function, in row-major order).
- * With `{ withCaptures: true }`, returns `{ output, captures }`. */
- run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
- run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
- run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
- destroy(): void
+ export interface CompiledForward extends CompiledBase {
+ run: RunFn
  }

  export interface RuntimeOpts {
@@ -147,14 +155,7 @@ export async function createRuntime(
  label: spec.name ?? `t${spec.id}-${spec.kind}`,
  })
  buffers.set(spec.id, buf)
- if (spec.kind === 'state') {
- // Fill with initValue (typically 0). Float and int both 4 bytes per element.
- const elements = spec.byteSize / 4
- const init = spec.dtype === 'f32'
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
- queue.writeBuffer(buf, 0, init as unknown as BufferSource)
- }
+ if (spec.kind === 'state') fillStateBuffer(spec, buf)
  }
  // Track which params are externally owned — those are skipped on destroy().
  const ownedBufferIds = new Set<number>()
@@ -404,14 +405,20 @@ export async function createRuntime(
  return out
  }

+ // Fill a state buffer with its declared initValue (typically 0). Float and
+ // int both serialize to 4 bytes per element. Used at allocation time and on
+ // resetOptimizerState() — same logic, two callers.
+ function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
+ const elements = spec.byteSize / 4
+ const init = spec.dtype === 'f32'
+ ? new Float32Array(elements).fill(spec.initValue ?? 0)
+ : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
+ queue.writeBuffer(target, 0, init as unknown as BufferSource)
+ }
+
  function resetOptimizerState() {
  for (const spec of plan.buffers) {
- if (spec.kind !== 'state') continue
- const elements = spec.byteSize / 4
- const init = spec.dtype === 'f32'
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
- queue.writeBuffer(buffers.get(spec.id)!, 0, init as unknown as BufferSource)
+ if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
  }
  }

@@ -421,6 +428,13 @@ export async function createRuntime(
  for (const [name, bufId] of plan.paramsByName) {
  params.set(name, buffers.get(bufId)!)
  }
+ // Static-after-compile shape metadata so users don't have to recompute
+ // strides to interpret a flat capture readback.
+ const captureShapes: Record<string, number[]> = {}
+ for (const [name, bufId] of plan.capturesByName) {
+ captureShapes[name] = [...plan.buffers[bufId]!.shape]
+ }
+ const outputShape = [...plan.buffers[lossBufferId]!.shape]

  const destroy = () => {
  for (const [id, b] of buffers) {
@@ -432,6 +446,8 @@ export async function createRuntime(

  return {
  params,
+ captureShapes,
+ outputShape,
  uploadParams,
  downloadParams: () => downloadFromMap(plan.paramsByName),
  downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
@@ -442,22 +458,17 @@ export async function createRuntime(
  }
  }

- /** Same machinery as `createRuntime`, narrower public API: no step,
- * no resetOptimizerState, no downloadParamGrads. Used by `compileForward`. */
+ /** Same machinery as `createRuntime`, narrower public type: a forward-only
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
+ * loss readback). The full runtime object is built once and projected by
+ * `compileForward` to the public shape. */
  export async function createForwardRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  outputBufferId: number,
  opts: RuntimeOpts = {},
  ): Promise<CompiledForward> {
- const full = await createRuntime(plan, kernels, outputBufferId, opts)
- return {
- params: full.params,
- uploadParams: full.uploadParams,
- downloadParams: full.downloadParams,
- run: full.run,
- destroy: full.destroy,
- }
+ return await createRuntime(plan, kernels, outputBufferId, opts)
  }

  async function acquireDevice(): Promise<GPUDevice> {
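Editorial note (not part of the diff): the new `outputShape` / `captureShapes` metadata exists so callers can reinterpret the flat readbacks without recomputing strides. The `CompiledForward` import path and the `tokens` input name below are assumptions.

```ts
import type { CompiledForward } from 'tensorgrad' // assumed type re-export

// Flat, row-major output plus its static shape (e.g. [B, T, vocab] for an LM head).
async function readOutput(fwd: CompiledForward, tokens: Int32Array) {
  const data = await fwd.run({ tokens })
  return { data, shape: fwd.outputShape }
}

// Same idea for captures: each Float32Array is flat; captureShapes[name] says how to view it.
async function readCaptures(fwd: CompiledForward, tokens: Int32Array) {
  const { output, captures } = await fwd.run({ tokens }, { withCaptures: true })
  const shaped = Object.fromEntries(
    Object.entries(captures).map(([name, flat]) => [name, { flat, shape: fwd.captureShapes[name] }]),
  )
  return { output, shaped }
}
```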