tensorgrad 0.0.5 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/adam.ts CHANGED
@@ -95,6 +95,11 @@ export function appendAdam(
95
95
  paramGrads: Record<string, Tensor>,
96
96
  paramTensors: Record<string, Tensor>,
97
97
  config: AdamConfig,
98
+ /** Per-param decay flags from `materializeParams`. When supplied, overrides
99
+ * `config.decayFilter` for any name in the map; falls back to `decayFilter`
100
+ * for names not present (e.g., for low-level callers using `compile()`
101
+ * directly without a Module). */
102
+ decayFlags?: Record<string, boolean>,
98
103
  ): AdamResult {
99
104
  const lrIsScheduled = typeof config.lr === 'function'
100
105
  const lrFn = lrIsScheduled
@@ -119,13 +124,22 @@ export function appendAdam(
119
124
  return traceInto(graph, () => {
120
125
  const lrt = tensorInput(lrtInputName, [], 'f32')
121
126
 
122
- // Decide up-front whether we need a runtime decayShrink scalar. Only does
123
- // something when both (a) lr varies per step and (b) some param is decayed.
124
- const needsDynamicShrink = lrIsScheduled
125
- && fullConfig.weightDecay > 0
126
- && Object.keys(paramGrads).some(name => fullConfig.decayFilter(name))
127
+ // Up-front: which params receive weight decay? Per-param decayFlags (set
128
+ // by Module.param's options) wins; falls back to decayFilter for names
129
+ // not in the map. Empty when weightDecay = 0 so the rest of the function
130
+ // can just ask "is this name in the set?".
131
+ const decayedNames = new Set<string>(
132
+ fullConfig.weightDecay > 0
133
+ ? Object.keys(paramGrads).filter(name =>
134
+ (decayFlags && name in decayFlags) ? decayFlags[name]! : fullConfig.decayFilter(name))
135
+ : [],
136
+ )
137
+
138
+ // We only need a runtime decayShrink scalar when lr varies per step AND
139
+ // at least one param is being decayed. Otherwise the value is constant
140
+ // and bakes into the kernel as a literal.
127
141
  let decayShrinkScalar: Tensor | null = null
128
- if (needsDynamicShrink) {
142
+ if (lrIsScheduled && decayedNames.size > 0) {
129
143
  decayShrinkInputName = '_adam_decay_shrink'
130
144
  decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32')
131
145
  }
@@ -141,17 +155,12 @@ export function appendAdam(
141
155
 
142
156
  // Choose the decayShrink form per param:
143
157
  // - non-decayed params: literal 1 (kernel multiply folds out).
144
- // - decayed + static lr: literal `1 - lr * wd` baked at compile.
145
158
  // - decayed + scheduled lr: tensor input updated per step.
146
- const isDecayed = fullConfig.weightDecay > 0 && fullConfig.decayFilter(name)
147
- let decayShrink: number | Tensor
148
- if (!isDecayed) {
149
- decayShrink = 1
150
- } else if (decayShrinkScalar !== null) {
151
- decayShrink = decayShrinkScalar
152
- } else {
153
- decayShrink = 1 - initialLr * fullConfig.weightDecay
154
- }
159
+ // - decayed + static lr: literal `1 - lr * wd` baked at compile.
160
+ const decayShrink: number | Tensor =
161
+ !decayedNames.has(name) ? 1
162
+ : decayShrinkScalar !== null ? decayShrinkScalar
163
+ : 1 - initialLr * fullConfig.weightDecay
155
164
 
156
165
  // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
157
166
  const newM = adamUpdateM(mState, g, fullConfig.b1)
package/src/compile.ts CHANGED
@@ -93,7 +93,7 @@ export async function compileModule<M extends Module>(
93
93
  ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
94
94
  const inputDecls = opts.inputs ?? []
95
95
  const model = modelFactory()
96
- let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
96
+ let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
97
97
  const graph = trace(() => {
98
98
  materialized = materializeParams(model)
99
99
  const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
@@ -104,7 +104,7 @@ export async function compileModule<M extends Module>(
104
104
 
105
105
  let adamResult: ReturnType<typeof appendAdam> | undefined
106
106
  if (opts.adam) {
107
- adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
107
+ adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
108
108
  }
109
109
 
110
110
  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -144,16 +144,8 @@ export async function compileModule<M extends Module>(
144
144
  }
145
145
  }
146
146
 
147
- const { initFns } = materialized
148
147
  const uploadInitialParams = () => {
149
- const out: Record<string, Float32Array> = {}
150
- for (const [name, bufId] of plan.paramsByName) {
151
- const shape = plan.buffers[bufId]!.shape
152
- const size = shape.reduce((a, b) => a * b, 1)
153
- const initFn = initFns[name]
154
- if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
155
- out[name] = initFn(size, shape)
156
- }
148
+ const out = buildInitialParamUploads(plan, materialized.initFns)
157
149
  runtime.uploadParams(out)
158
150
  }
159
151
 
@@ -161,6 +153,28 @@ export async function compileModule<M extends Module>(
161
153
  return Object.assign(runtime, { ir, uploadInitialParams })
162
154
  }
163
155
 
156
+ // Build a Record<paramName, Float32Array> by running each param's init
157
+ // function against its shape. Shared by compileModule and compileForward.
158
+ // `sharedParams`, when supplied, skips any name it covers (those are owned
159
+ // by the sibling compile and already initialized there).
160
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array
161
+ function buildInitialParamUploads(
162
+ plan: BufferPlan,
163
+ initFns: Record<string, InitFn>,
164
+ sharedParams?: Map<string, GPUBuffer>,
165
+ ): Record<string, Float32Array> {
166
+ const out: Record<string, Float32Array> = {}
167
+ for (const [name, bufId] of plan.paramsByName) {
168
+ if (sharedParams?.has(name)) continue
169
+ const shape = plan.buffers[bufId]!.shape
170
+ const size = shape.reduce((a, b) => a * b, 1)
171
+ const initFn = initFns[name]
172
+ if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
173
+ out[name] = initFn(size, shape)
174
+ }
175
+ return out
176
+ }
177
+
164
178
  // ============================================================================
165
179
  // Forward-only compile
166
180
  // ============================================================================
@@ -187,7 +201,7 @@ export async function compileForward<M extends Module>(
187
201
  ): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
188
202
  const inputDecls = opts.inputs ?? []
189
203
  const model = modelFactory()
190
- let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
204
+ let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
191
205
  const graph = trace(() => {
192
206
  materialized = materializeParams(model)
193
207
  const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
@@ -201,22 +215,9 @@ export async function compileForward<M extends Module>(
201
215
  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
202
216
 
203
217
  const sharedParams = opts.sharedParams
204
- const { initFns } = materialized
205
218
  const uploadInitialParams = () => {
206
- const out: Record<string, Float32Array> = {}
207
- let needsUpload = false
208
- for (const [name, bufId] of plan.paramsByName) {
209
- // Skip params covered by sharedParams — those are owned by the providing
210
- // compile and already initialized there.
211
- if (sharedParams?.has(name)) continue
212
- const shape = plan.buffers[bufId]!.shape
213
- const size = shape.reduce((a, b) => a * b, 1)
214
- const initFn = initFns[name]
215
- if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
216
- out[name] = initFn(size, shape)
217
- needsUpload = true
218
- }
219
- if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
219
+ const out = buildInitialParamUploads(plan, materialized.initFns, sharedParams)
220
+ if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
220
221
  }
221
222
 
222
223
  // CompiledIR.loss is the field name; for forward-only, it carries the user's
package/src/grad.ts CHANGED
@@ -18,7 +18,7 @@
18
18
  import type { Graph, OpNode, Tensor, Shape } from './ir.js'
19
19
  import {
20
20
  add, sub, mul, div, mulScalar,
21
- matmul, matmulBatched, transpose, reshape,
21
+ matmul, matmulBatched, transpose, swapAxes, reshape,
22
22
  exp,
23
23
  broadcastTo, sumToShape,
24
24
  constScalar, reluGrad,
@@ -280,14 +280,10 @@ function runTransposeRule(
280
280
  // leading batch dims to get [K, N].
281
281
  const a = tensorOf(op.a), b = tensorOf(op.b)
282
282
  // dA = dC @ B^T
283
- const bT = transpose(b, [1, 0])
284
- accumulate(cotangents, op.a, matmul(outCotan, bT))
283
+ accumulate(cotangents, op.a, matmul(outCotan, swapAxes(b, -1, -2)))
285
284
  // dB: per-batch A^T @ dC, then sum over batch dims.
286
285
  // A is [..., M, K]; transpose last two axes.
287
- const aTPerm = identityPerm(a.shape.length)
288
- ;[aTPerm[a.shape.length - 1], aTPerm[a.shape.length - 2]] =
289
- [aTPerm[a.shape.length - 2]!, aTPerm[a.shape.length - 1]!]
290
- const aT = transpose(a, aTPerm) // [..., K, M]
286
+ const aT = swapAxes(a, -1, -2) // [..., K, M]
291
287
  // matmul_batched needs same rank on both sides. dC has rank `a.rank`;
292
288
  // aT has rank `a.rank`; use matmul_batched if rank > 2, else matmul.
293
289
  let perBatchDb: Tensor
@@ -305,15 +301,8 @@ function runTransposeRule(
305
301
  // dA = dC @ B^T (per-batch, all batch dims preserved)
306
302
  // dB = A^T @ dC (per-batch)
307
303
  const a = tensorOf(op.a), b = tensorOf(op.b)
308
- const lastTwoSwap = (rank: number) => {
309
- const p = identityPerm(rank)
310
- ;[p[rank - 1], p[rank - 2]] = [p[rank - 2]!, p[rank - 1]!]
311
- return p
312
- }
313
- const bT = transpose(b, lastTwoSwap(b.shape.length))
314
- const aT = transpose(a, lastTwoSwap(a.shape.length))
315
- accumulate(cotangents, op.a, matmulBatched(outCotan, bT))
316
- accumulate(cotangents, op.b, matmulBatched(aT, outCotan))
304
+ accumulate(cotangents, op.a, matmulBatched(outCotan, swapAxes(b, -1, -2)))
305
+ accumulate(cotangents, op.b, matmulBatched(swapAxes(a, -1, -2), outCotan))
317
306
  return
318
307
  }
319
308
 
package/src/index.ts CHANGED
@@ -15,13 +15,13 @@ export {
15
15
  // Comparisons + select
16
16
  less, greater, where,
17
17
  // Reductions over the last axis (other axes via reshape/transpose first)
18
- meanLast, sumLast,
18
+ meanLast, sumLast, sumAll,
19
19
  // Shape ops
20
- reshape, transpose,
20
+ reshape, transpose, swapAxes,
21
21
  // Linear algebra
22
22
  matmul, matmulBatched,
23
23
  // Indexing / casting
24
- oneHot, arange,
24
+ oneHot, arange, embedding,
25
25
  // ML primitives — fused for the transformer
26
26
  softmaxCausalLast, logSoftmaxLast, whereCausal,
27
27
  // Slicing
package/src/module.ts CHANGED
@@ -52,6 +52,11 @@ export interface ParamOptions {
52
52
  init?: InitSpec
53
53
  /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
54
54
  scale?: number
55
+ /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
56
+ * decay to this param. Default: `true` for `'randn'` init (weight matrices,
57
+ * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
58
+ * to force or skip. Replaces `adam.decayFilter` for the common case. */
59
+ decay?: boolean
55
60
  }
56
61
 
57
62
  type InitFn = (size: number, shape: readonly number[]) => Float32Array
@@ -76,6 +81,17 @@ function resolveInit(opts: ParamOptions | undefined): InitFn {
76
81
  throw new Error(`Unknown init: ${String(init)}`)
77
82
  }
78
83
 
84
+ /** Resolve the decay default for a param. Decay weight matrices and
85
+ * embedding tables (randn-initialized); skip biases (zeros) and LN gains
86
+ * (ones). Custom init functions default to "decay" — most user-supplied
87
+ * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
88
+ function resolveDecay(opts: ParamOptions | undefined): boolean {
89
+ if (opts?.decay !== undefined) return opts.decay
90
+ const init = opts?.init ?? 'randn'
91
+ if (init === 'zeros' || init === 'ones') return false
92
+ return true // 'randn' or function
93
+ }
94
+
79
95
  // ============================================================================
80
96
  // Internals: param sentinel
81
97
  // ============================================================================
@@ -90,6 +106,7 @@ class ParamSentinel {
90
106
  public readonly shape: Shape,
91
107
  public readonly dtype: Dtype,
92
108
  public readonly initFn: InitFn,
109
+ public readonly decay: boolean,
93
110
  ) {}
94
111
  }
95
112
 
@@ -110,7 +127,7 @@ export abstract class Module {
110
127
  protected param(shape: Shape, opts?: ParamOptions): Tensor {
111
128
  const dtype = opts?.dtype ?? 'f32'
112
129
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
113
- return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
130
+ return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
114
131
  }
115
132
  }
116
133
 
@@ -123,6 +140,9 @@ export interface MaterializedParams {
123
140
  tensors: Record<string, Tensor>
124
141
  /** Init function per param path. Used by `uploadInitialParams`. */
125
142
  initFns: Record<string, InitFn>
143
+ /** Whether this param should receive AdamW weight decay. Resolved at
144
+ * `param()` time from `ParamOptions.decay` (with init-based default). */
145
+ decayFlags: Record<string, boolean>
126
146
  }
127
147
 
128
148
  /**
@@ -136,15 +156,17 @@ export interface MaterializedParams {
136
156
  export function materializeParams(root: Module): MaterializedParams {
137
157
  const tensors: Record<string, Tensor> = {}
138
158
  const initFns: Record<string, InitFn> = {}
159
+ const decayFlags: Record<string, boolean> = {}
139
160
  visit(root, '', (path, val, owner, key) => {
140
161
  if (val instanceof ParamSentinel) {
141
162
  const t = paramInput(path, val.shape, val.dtype)
142
163
  ;(owner as any)[key] = t
143
164
  tensors[path] = t
144
165
  initFns[path] = val.initFn
166
+ decayFlags[path] = val.decay
145
167
  }
146
168
  })
147
- return { tensors, initFns }
169
+ return { tensors, initFns, decayFlags }
148
170
  }
149
171
 
150
172
  // ----------------------------------------------------------------------------
package/src/nn.ts CHANGED
@@ -15,7 +15,9 @@
15
15
 
16
16
  import { Module } from './module.js'
17
17
  import type { Tensor } from './ir.js'
18
- import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
18
+ import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
19
+ import { ShapeError } from './shape.js'
20
+ import { captureSite } from './ir.js'
19
21
 
20
22
  // ----------------------------------------------------------------------------
21
23
  // Linear: y = x @ W (+ b)
@@ -57,3 +59,84 @@ export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
57
59
  const stdev = sqrt(add(v, p.eps))
58
60
  return add(mul(div(c, stdev), p.g), p.b)
59
61
  }
62
+
63
+ // ----------------------------------------------------------------------------
64
+ // Multi-head attention shape helpers — split the last (model) axis into
65
+ // [nHeads, headDim] and bring heads ahead of the sequence axis.
66
+ // ----------------------------------------------------------------------------
67
+
68
+ /** [..., T, D] → [..., H, T, D/H]. Folds the standard
69
+ * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
70
+ * call. Last dim of `x` must divide evenly by `nHeads`. */
71
+ export function splitHeads(x: Tensor, nHeads: number): Tensor {
72
+ const site = captureSite('splitHeads')
73
+ const r = x.shape.length
74
+ if (r < 2) throw new ShapeError(`splitHeads: requires rank >= 2, got ${r}`, site)
75
+ const T = x.shape[r - 2]!
76
+ const D = x.shape[r - 1]!
77
+ if (D % nHeads !== 0) {
78
+ throw new ShapeError(`splitHeads: last dim ${D} not divisible by nHeads ${nHeads}`, site)
79
+ }
80
+ const lead = x.shape.slice(0, r - 2)
81
+ const reshaped = reshape(x, [...lead, T, nHeads, D / nHeads])
82
+ // Swap T (axis lead.length) with H (axis lead.length + 1).
83
+ return swapAxes(reshaped, lead.length, lead.length + 1)
84
+ }
85
+
86
+ /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
87
+ export function mergeHeads(x: Tensor): Tensor {
88
+ const site = captureSite('mergeHeads')
89
+ const r = x.shape.length
90
+ if (r < 3) throw new ShapeError(`mergeHeads: requires rank >= 3, got ${r}`, site)
91
+ const H = x.shape[r - 3]!
92
+ const T = x.shape[r - 2]!
93
+ const d = x.shape[r - 1]!
94
+ const lead = x.shape.slice(0, r - 3)
95
+ // Swap H (axis r-3) and T (axis r-2): [..., H, T, d] → [..., T, H, d]
96
+ const swapped = swapAxes(x, r - 3, r - 2)
97
+ return reshape(swapped, [...lead, T, H * d])
98
+ }
99
+
100
+ /** Slice a flat capture readback of shape `[H, ..., ...]` into one
101
+ * Float32Array per head. The leading axis is treated as the head axis;
102
+ * pass the shape from `compiled.captureShapes[name]`. Result: `H` arrays,
103
+ * each holding the row-major data for that head (size = product of trailing
104
+ * axes). For B>1 graphs, prefix the result by the batch — this helper
105
+ * assumes the leading axis is heads, which matches how `splitHeads` lays
106
+ * out captures at B=1 (the typical capture-readback shape). */
107
+ export function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[] {
108
+ if (shape.length < 2) {
109
+ throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
110
+ }
111
+ // For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
112
+ // the leading 1 if present so callers can pass captureShapes[name] directly.
113
+ const s = shape[0] === 1 ? shape.slice(1) : shape
114
+ const H = s[0]!
115
+ let stride = 1
116
+ for (let i = 1; i < s.length; i++) stride *= s[i]!
117
+ const expected = H * stride
118
+ if (flat.length !== expected) {
119
+ throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`)
120
+ }
121
+ return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
122
+ }
123
+
124
+ // ----------------------------------------------------------------------------
125
+ // Loss helpers
126
+ // ----------------------------------------------------------------------------
127
+
128
+ /** Per-position cross-entropy along the last (vocab) axis: returns
129
+ * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
130
+ * `[...]` of i32; result is `[...]` (one rank less than logits). The user
131
+ * applies their own masking + reduction downstream — useful when only some
132
+ * positions contribute (e.g. result-digit masking) or for label smoothing. */
133
+ export function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor {
134
+ const site = captureSite('crossEntropyLast')
135
+ if (targets.dtype !== 'i32') {
136
+ throw new ShapeError(`crossEntropyLast: targets must be i32, got ${targets.dtype}`, site)
137
+ }
138
+ const vocab = logits.shape[logits.shape.length - 1]!
139
+ const lp = logSoftmaxLast(logits) // [..., V]
140
+ const targetLp = sumLast(mul(lp, oneHot(targets, vocab, 'f32'))) // [...]
141
+ return mul(targetLp, -1)
142
+ }
package/src/ops.ts CHANGED
@@ -17,7 +17,7 @@ import {
17
17
  inferReshape, inferTranspose, inferMatmul, inferMatmulBatched,
18
18
  inferOneHot, inferWhereCausal, inferSliceLastRange,
19
19
  inferBroadcastTo, inferSumToShape, inferReluGrad, inferWhere,
20
- ShapeError,
20
+ ShapeError, showShape,
21
21
  } from './shape.js'
22
22
 
23
23
  // ----------------------------------------------------------------------------
@@ -112,6 +112,11 @@ export function sumLast(a: Tensor): Tensor {
112
112
  return addOp(currentGraph(), 'sum_last', outShape, a.dtype, site, { a: a.id })
113
113
  }
114
114
 
115
+ /** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
116
+ export function sumAll(a: Tensor): Tensor {
117
+ return sumLast(reshape(a, [-1]))
118
+ }
119
+
115
120
  // ----------------------------------------------------------------------------
116
121
  // Shape ops.
117
122
  // ----------------------------------------------------------------------------
@@ -128,6 +133,26 @@ export function transpose(a: Tensor, perm: readonly number[]): Tensor {
128
133
  return addOp(currentGraph(), 'transpose', outShape, a.dtype, site, { a: a.id, perm })
129
134
  }
130
135
 
136
+ /** Swap two axes of a tensor. Negative indices count from the end (so
137
+ * `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
138
+ * All other axes keep their position. Implemented as `transpose` with the
139
+ * permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
140
+ export function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor {
141
+ const r = a.shape.length
142
+ const norm = (axis: number): number => axis < 0 ? r + axis : axis
143
+ const i1 = norm(axis1)
144
+ const i2 = norm(axis2)
145
+ const site = captureSite('swapAxes')
146
+ if (i1 < 0 || i1 >= r || i2 < 0 || i2 >= r) {
147
+ throw new ShapeError(`swapAxes: axis out of range — got (${axis1}, ${axis2}) for rank-${r} tensor`, site)
148
+ }
149
+ if (i1 === i2) return a
150
+ const perm = Array.from({ length: r }, (_, k) => k)
151
+ perm[i1] = i2
152
+ perm[i2] = i1
153
+ return transpose(a, perm)
154
+ }
155
+
131
156
  // ----------------------------------------------------------------------------
132
157
  // Linear algebra.
133
158
  // ----------------------------------------------------------------------------
@@ -163,6 +188,22 @@ export function oneHot(indices: Tensor, depth: number, dtype: Dtype = 'f32'): Te
163
188
  return addOp(currentGraph(), 'one_hot', outShape, dtype, site, { indices: indices.id, depth, dtype })
164
189
  }
165
190
 
191
+ /** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
192
+ * to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
193
+ * scatter-with-atomic-add backward — the matmul transpose rule handles it.
194
+ * `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
195
+ * is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
196
+ export function embedding(table: Tensor, indices: Tensor): Tensor {
197
+ const site = captureSite('embedding')
198
+ if (table.shape.length !== 2) {
199
+ throw new ShapeError(`embedding: table must be 2-d [vocab, dim], got ${showShape(table.shape)}`, site)
200
+ }
201
+ if (indices.dtype !== 'i32') {
202
+ throw new ShapeError(`embedding: indices must be i32, got ${indices.dtype}`, site)
203
+ }
204
+ return matmul(oneHot(indices, table.shape[0]!, 'f32'), table)
205
+ }
206
+
166
207
  // arange(n) → [n] of values [0, 1, ..., n-1]. Used for position embeddings.
167
208
  export function arange(n: number, dtype: Dtype = 'i32'): Tensor {
168
209
  const site = captureSite('arange')
package/src/runtime.ts CHANGED
@@ -43,17 +43,38 @@ export interface RunWithCaptures {
43
43
  captures: Record<string, Float32Array>
44
44
  }
45
45
 
46
- export interface CompiledRuntime {
47
- /** Map of param name -> the underlying GPUBuffer. Pass to a sibling compile
48
- * via `sharedParams` to share without copies every step on this runtime
49
- * is immediately visible to anyone reading these buffers. */
46
+ /** Common surface for both training and forward-only compiled runtimes. */
47
+ export interface CompiledBase {
48
+ /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
49
+ * `sharedParams` to share without copies. */
50
50
  params: Map<string, GPUBuffer>
51
+ /** Shape of each tensor registered via `capture(name, t)`. Static after
52
+ * compile — reshape readbacks without recomputing strides. */
53
+ captureShapes: Record<string, number[]>
54
+ /** Shape of the graph's output (loss scalar `[]` for training; the user's
55
+ * returned tensor for forward-only compiles). */
56
+ outputShape: number[]
51
57
  /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
52
58
  * *all* params to be present; throws on any unknown or missing key. Pass
53
59
  * `{ partial: true }` to skip the missing-key check. */
54
60
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
55
61
  /** Read all parameters back as Float32Arrays — used for UI panels. */
56
62
  downloadParams(): Promise<Record<string, Float32Array>>
63
+ /** Free GPU resources. */
64
+ destroy(): void
65
+ }
66
+
67
+ /** Run a dispatch and read back the full output tensor (and any registered
68
+ * captures if requested). Forward-only compiles use this as their primary
69
+ * surface; training compiles also expose it but `step()` is more convenient
70
+ * there because the output is a scalar loss. */
71
+ export interface RunFn {
72
+ (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
73
+ (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
74
+ (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
75
+ }
76
+
77
+ export interface CompiledRuntime extends CompiledBase {
57
78
  /** Read all parameter gradients back. Mostly for verification / debugging. */
58
79
  downloadParamGrads(): Promise<Record<string, Float32Array>>
59
80
  /**
@@ -68,32 +89,19 @@ export interface CompiledRuntime {
68
89
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
69
90
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
70
91
  step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
71
- /** Like `step()` but returns the full output Float32Array instead of just
72
- * its first element. For training graphs this is rarely useful (the output
73
- * *is* a scalar loss); it's the primary API for forward-only compiles. */
74
- run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
75
- run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
76
- run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
92
+ /** Same dispatch as step() but returns the full output Float32Array. For
93
+ * training graphs the output is a scalar loss, so step() is usually more
94
+ * convenient. Provided for parity with `compileForward`. */
95
+ run: RunFn
77
96
  /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
78
97
  * `uploadInitialParams()` for a full training reset without recompile. */
79
98
  resetOptimizerState(): void
80
- /** Free GPU resources. */
81
- destroy(): void
82
99
  }
83
100
 
84
101
  /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
85
102
  * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
86
- export interface CompiledForward {
87
- params: Map<string, GPUBuffer>
88
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
89
- downloadParams(): Promise<Record<string, Float32Array>>
90
- /** Forward-only dispatch. Returns the graph's output tensor as a Float32Array
91
- * (the user's returned tensor from the forward function, in row-major order).
92
- * With `{ withCaptures: true }`, returns `{ output, captures }`. */
93
- run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
94
- run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
95
- run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
96
- destroy(): void
103
+ export interface CompiledForward extends CompiledBase {
104
+ run: RunFn
97
105
  }
98
106
 
99
107
  export interface RuntimeOpts {
@@ -147,14 +155,7 @@ export async function createRuntime(
147
155
  label: spec.name ?? `t${spec.id}-${spec.kind}`,
148
156
  })
149
157
  buffers.set(spec.id, buf)
150
- if (spec.kind === 'state') {
151
- // Fill with initValue (typically 0). Float and int both 4 bytes per element.
152
- const elements = spec.byteSize / 4
153
- const init = spec.dtype === 'f32'
154
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
155
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
156
- queue.writeBuffer(buf, 0, init as unknown as BufferSource)
157
- }
158
+ if (spec.kind === 'state') fillStateBuffer(spec, buf)
158
159
  }
159
160
  // Track which params are externally owned — those are skipped on destroy().
160
161
  const ownedBufferIds = new Set<number>()
@@ -404,14 +405,20 @@ export async function createRuntime(
404
405
  return out
405
406
  }
406
407
 
408
+ // Fill a state buffer with its declared initValue (typically 0). Float and
409
+ // int both serialize to 4 bytes per element. Used at allocation time and on
410
+ // resetOptimizerState() — same logic, two callers.
411
+ function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
412
+ const elements = spec.byteSize / 4
413
+ const init = spec.dtype === 'f32'
414
+ ? new Float32Array(elements).fill(spec.initValue ?? 0)
415
+ : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
416
+ queue.writeBuffer(target, 0, init as unknown as BufferSource)
417
+ }
418
+
407
419
  function resetOptimizerState() {
408
420
  for (const spec of plan.buffers) {
409
- if (spec.kind !== 'state') continue
410
- const elements = spec.byteSize / 4
411
- const init = spec.dtype === 'f32'
412
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
413
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
414
- queue.writeBuffer(buffers.get(spec.id)!, 0, init as unknown as BufferSource)
421
+ if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
415
422
  }
416
423
  }
417
424
 
@@ -421,6 +428,13 @@ export async function createRuntime(
421
428
  for (const [name, bufId] of plan.paramsByName) {
422
429
  params.set(name, buffers.get(bufId)!)
423
430
  }
431
+ // Static-after-compile shape metadata so users don't have to recompute
432
+ // strides to interpret a flat capture readback.
433
+ const captureShapes: Record<string, number[]> = {}
434
+ for (const [name, bufId] of plan.capturesByName) {
435
+ captureShapes[name] = [...plan.buffers[bufId]!.shape]
436
+ }
437
+ const outputShape = [...plan.buffers[lossBufferId]!.shape]
424
438
 
425
439
  const destroy = () => {
426
440
  for (const [id, b] of buffers) {
@@ -432,6 +446,8 @@ export async function createRuntime(
432
446
 
433
447
  return {
434
448
  params,
449
+ captureShapes,
450
+ outputShape,
435
451
  uploadParams,
436
452
  downloadParams: () => downloadFromMap(plan.paramsByName),
437
453
  downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
@@ -442,22 +458,17 @@ export async function createRuntime(
442
458
  }
443
459
  }
444
460
 
445
- /** Same machinery as `createRuntime`, narrower public API: no step,
446
- * no resetOptimizerState, no downloadParamGrads. Used by `compileForward`. */
461
+ /** Same machinery as `createRuntime`, narrower public type: a forward-only
462
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
463
+ * loss readback). The full runtime object is built once and projected by
464
+ * `compileForward` to the public shape. */
447
465
  export async function createForwardRuntime(
448
466
  plan: BufferPlan,
449
467
  kernels: KernelSpec[],
450
468
  outputBufferId: number,
451
469
  opts: RuntimeOpts = {},
452
470
  ): Promise<CompiledForward> {
453
- const full = await createRuntime(plan, kernels, outputBufferId, opts)
454
- return {
455
- params: full.params,
456
- uploadParams: full.uploadParams,
457
- downloadParams: full.downloadParams,
458
- run: full.run,
459
- destroy: full.destroy,
460
- }
471
+ return await createRuntime(plan, kernels, outputBufferId, opts)
461
472
  }
462
473
 
463
474
  async function acquireDevice(): Promise<GPUDevice> {