tensorgrad 0.0.11 → 0.0.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/runtime.ts CHANGED
@@ -1,515 +1,523 @@
1
- // WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
2
- // allocates real GPU buffers and pipelines, and provides a `step()` method
3
- // that uploads inputs, dispatches all kernels, and reads back outputs.
4
- //
5
- // Browser-only: this module needs `navigator.gpu` at runtime.
6
-
7
- import type { BufferPlan } from './buffers.js'
8
- import type { KernelSpec } from './codegen.js'
9
-
10
- // TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
11
- // Provided by the browser per WebGPU spec; declare just what we use.
12
- declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
-
14
- export interface UploadParamsOptions {
15
- /** Skip the "missing param" check, allowing the caller to update only some
16
- * params and leave the rest at their current GPU values. Extra (unknown)
17
- * keys are still rejected — that's always a typo. Default: false. */
18
- partial?: boolean
19
- }
20
-
21
- /**
22
- * Activation readbacks for one `step()`/`run()` call. Keyed by the names
23
- * passed to `capture(name, t)` during the trace. `get(name)` throws if the
24
- * name isn't registered or wasn't read back this call (i.e., the call was
25
- * made without `{ withCaptures: true }`); use `has(name)` if you need to
26
- * branch. `shapeOf(name)` returns the static-after-compile shape and works
27
- * regardless of whether captures were read back.
28
- */
29
- export class Captures {
30
- constructor(
31
- private readonly shapes: Record<string, readonly number[]>,
32
- private readonly data: Map<string, Float32Array>,
33
- ) {}
34
- get(name: string): Float32Array {
35
- const d = this.data.get(name)
36
- if (!d) {
37
- const known = [...this.data.keys()].sort().join(', ')
38
- const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
39
- throw new Error(`Captures.get: '${name}' not present. ${detail}`)
40
- }
41
- return d
42
- }
43
- shapeOf(name: string): readonly number[] {
44
- const s = this.shapes[name]
45
- if (!s) {
46
- const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
47
- throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
48
- }
49
- return s
50
- }
51
- has(name: string): boolean { return this.data.has(name) }
52
- names(): string[] { return [...this.data.keys()].sort() }
53
- }
54
-
55
- export interface RunResult {
56
- output: Float32Array
57
- captures: Captures
58
- }
59
-
60
- export interface StepResult {
61
- loss: number
62
- captures: Captures
63
- }
64
-
65
- export interface RunOptions {
66
- /** Read back tensors registered via `capture(name, t)` during the trace.
67
- * Default false. When false, the returned `captures` is empty (calling
68
- * `.get` throws); when true, captures are read back and accessible. */
69
- withCaptures?: boolean
70
- }
71
-
72
- /** Common surface for both training and forward-only compiled runtimes. */
73
- export interface CompiledBase {
74
- /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
75
- * share the device, or use directly for other GPU work. */
76
- device: GPUDevice
77
- /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
78
- * `sharedParams` to share without copies. */
79
- params: Map<string, GPUBuffer>
80
- /** Shape of the graph's output (loss scalar `[]` for training; the user's
81
- * returned tensor for forward-only compiles). */
82
- outputShape: number[]
83
- /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
84
- * *all* params to be present; throws on any unknown or missing key. Pass
85
- * `{ partial: true }` to skip the missing-key check. */
86
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
87
- /** Read all parameters back as Float32Arrays — used for UI panels. */
88
- downloadParams(): Promise<Record<string, Float32Array>>
89
- /** Free GPU resources. */
90
- destroy(): void
91
- }
92
-
93
- /** Run a dispatch and read back the full output tensor. Captures are always
94
- * returned; their data is empty unless `{ withCaptures: true }` is passed. */
95
- export type RunFn = (
96
- inputs: Record<string, Int32Array | Float32Array>,
97
- opts?: RunOptions,
98
- ) => Promise<RunResult>
99
-
100
- export interface CompiledRuntime extends CompiledBase {
101
- /** Read all parameter gradients back. Mostly for verification / debugging. */
102
- downloadParamGrads(): Promise<Record<string, Float32Array>>
103
- /**
104
- * One full forward+backward step.
105
- * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
106
- * 2. Dispatches every kernel in order.
107
- * 3. Reads back the loss scalar (and any registered captures, if requested).
108
- * Default returns the loss as a JS number; with `{ withCaptures: true }`
109
- * returns `{ loss, captures }`.
110
- */
111
- step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
112
- step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
113
- step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
114
- /** Same dispatch as step() but returns the full output Float32Array — for
115
- * training graphs the output is a scalar loss, so step() is usually more
116
- * convenient. Provided for parity with `compileForward`. */
117
- run: RunFn
118
- /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
119
- * `uploadInitialParams()` for a full training reset without recompile. */
120
- resetOptimizerState(): void
121
- }
122
-
123
- /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
124
- * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
125
- export interface CompiledForward extends CompiledBase {
126
- run: RunFn
127
- }
128
-
129
- export interface RuntimeOpts {
130
- /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
131
- device?: GPUDevice
132
- /** External param buffers to bind in place of allocating fresh ones, keyed
133
- * by param name. Used to share params between a training compile and a
134
- * sibling forward-only compile (e.g., a B=1 inference graph). When a name
135
- * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
136
- * allocates as usual. */
137
- sharedParams?: Map<string, GPUBuffer>
138
- }
139
-
140
- // Inlined numeric values (per WebGPU spec) so this module is importable in Node
141
- // for codegen-only usage. The browser provides GPUBufferUsage as a global, but
142
- // referencing it at module scope would crash before any browser code runs.
143
- const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4 /*COPY_SRC*/
144
- const READBACK = 0x1 /*MAP_READ*/ | 0x8 /*COPY_DST*/
145
-
146
- export async function createRuntime(
147
- plan: BufferPlan,
148
- kernels: KernelSpec[],
149
- lossBufferId: number,
150
- opts: RuntimeOpts = {},
151
- ): Promise<CompiledRuntime> {
152
- const device = opts.device ?? await acquireDevice()
153
- const queue = device.queue
154
-
155
- // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
156
- // State buffers also get filled with their initValue at allocation time.
157
- // Param buffers may be supplied externally via opts.sharedParams; in that
158
- // case we reuse the provided GPUBuffer instead of allocating, and the
159
- // sibling compile that owns it is responsible for upload + lifetime.
160
- const buffers = new Map<number, GPUBuffer>()
161
- const sharedParams = opts.sharedParams
162
- for (const spec of plan.buffers) {
163
- if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
164
- const shared = sharedParams.get(spec.name!)!
165
- if (shared.size !== spec.byteSize) {
166
- throw new Error(
167
- `sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
168
- `compiled graph expects ${spec.byteSize}.`,
169
- )
170
- }
171
- buffers.set(spec.id, shared)
172
- continue
173
- }
174
- const buf = device.createBuffer({
175
- size: spec.byteSize,
176
- usage: STORAGE_RW,
177
- label: spec.name ?? `t${spec.id}-${spec.kind}`,
178
- })
179
- buffers.set(spec.id, buf)
180
- if (spec.kind === 'state') fillStateBuffer(spec, buf)
181
- }
182
- // Track which params are externally owned — those are skipped on destroy().
183
- const ownedBufferIds = new Set<number>()
184
- for (const spec of plan.buffers) {
185
- if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
186
- ownedBufferIds.add(spec.id)
187
- }
188
- }
189
-
190
- // ---- Compile pipelines per kernel; cache by WGSL source -------------------
191
- // Push an error scope around each shader+pipeline creation so we can surface
192
- // the actual compile error rather than the cryptic "previous error" that
193
- // comes from using an invalid pipeline at dispatch time.
194
- const moduleCache = new Map<string, GPUShaderModule>()
195
- const pipelines: (GPUComputePipeline | null)[] = []
196
- type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
197
- const probes: ErrorProbe[] = []
198
- for (const k of kernels) {
199
- if (!k.wgsl) { pipelines.push(null); continue }
200
- let module = moduleCache.get(k.wgsl)
201
- if (!module) {
202
- module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
203
- moduleCache.set(k.wgsl, module)
204
- }
205
- device.pushErrorScope('validation')
206
- const pipeline = device.createComputePipeline({
207
- layout: 'auto',
208
- compute: { module, entryPoint: 'main' },
209
- label: k.opKind,
210
- })
211
- pipelines.push(pipeline)
212
- probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
213
- }
214
- const probeResults = await Promise.all(probes)
215
- const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
216
- if (failures.length > 0) {
217
- const reports: string[] = []
218
- for (const { k, module, err } of failures) {
219
- const info = await module.getCompilationInfo()
220
- const messages = info.messages
221
- .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
222
- .join('\n')
223
- reports.push(
224
- `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
225
- (messages || ' (no compilation messages)') +
226
- `\n--- WGSL ---\n${k.wgsl}\n-----------`,
227
- )
228
- }
229
- // eslint-disable-next-line no-console
230
- console.error(reports.join('\n\n'))
231
- throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
232
- }
233
-
234
- // ---- Pre-build bind groups (static — buffer ids don't change per step) ---
235
- const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
236
- const pipeline = pipelines[i]
237
- if (!pipeline) return null
238
- return device.createBindGroup({
239
- layout: pipeline.getBindGroupLayout(0),
240
- entries: k.bindings.map((bufId, idx) => ({
241
- binding: idx,
242
- resource: { buffer: buffers.get(bufId)! },
243
- })),
244
- })
245
- })
246
-
247
- // ---- Output readback staging buffer ---------------------------------------
248
- // `outputBufferId` is the graph's main output (loss for training, the user's
249
- // returned tensor for forward-only). step() reads back its first element;
250
- // run() reads back the full Float32Array.
251
- const outputSpec = plan.buffers[lossBufferId]!
252
- const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })
253
-
254
- // ---- Capture readback staging buffers (lazy) ------------------------------
255
- // Allocated on first `step({ withCaptures: true })` call and reused across
256
- // subsequent calls. When the graph has no captures registered or when the
257
- // caller never opts in, no extra GPU memory is allocated.
258
- let captureStagings: Map<string, GPUBuffer> | null = null
259
- function ensureCaptureStagings(): Map<string, GPUBuffer> {
260
- if (captureStagings) return captureStagings
261
- captureStagings = new Map()
262
- for (const [name, bufId] of plan.capturesByName) {
263
- const spec = plan.buffers[bufId]!
264
- const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
265
- captureStagings.set(name, staging)
266
- }
267
- return captureStagings
268
- }
269
-
270
- // ---- dispatch() — shared core for step() and run() -----------------------
271
- // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
272
- // the output buffer into its staging, optionally copies captures into theirs,
273
- // submits, and reads back. Returns the full output Float32Array; step() takes
274
- // [0] for scalar loss, run() returns it whole.
275
- //
276
- // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
277
- // runtime would otherwise both try to `mapAsync` the shared output staging
278
- // buffer at the same time and trip "Buffer already has an outstanding map
279
- // pending." We chain each new dispatch onto the prior one's promise so they
280
- // run sequentially even when fired from independent async paths (e.g., a
281
- // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
282
- let pending: Promise<unknown> = Promise.resolve()
283
- async function dispatch(
284
- inputs: Record<string, Int32Array | Float32Array>,
285
- wantCaptures: boolean,
286
- ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
287
- const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
288
- pending = turn
289
- return turn
290
- }
291
- async function dispatchUnsynchronized(
292
- inputs: Record<string, Int32Array | Float32Array>,
293
- wantCaptures: boolean,
294
- ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
295
- if (wantCaptures && plan.capturesByName.size === 0) {
296
- throw new Error(
297
- `withCaptures=true but no capture(...) calls were registered during ` +
298
- `the trace. Add capture('name', tensor) inside your forward pass for ` +
299
- `the intermediates you want read back.`,
300
- )
301
- }
302
- for (const [name, bufId] of plan.inputsByName) {
303
- const data = inputs[name]
304
- if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
305
- const expectedBytes = plan.buffers[bufId]!.byteSize
306
- if (data.byteLength !== expectedBytes) {
307
- throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
308
- }
309
- // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
310
- // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
311
- queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
312
- }
313
-
314
- const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
315
- for (let i = 0; i < kernels.length; i++) {
316
- const k = kernels[i]!
317
- if (!k.wgsl || k.threads === 0) continue
318
- const pipeline = pipelines[i]!
319
- const bindGroup = bindGroups[i]!
320
- const pass = encoder.beginComputePass({ label: k.opKind })
321
- pass.setPipeline(pipeline)
322
- pass.setBindGroup(0, bindGroup)
323
- // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
324
- // when a kernel needs more than that on the X axis. Kernels compute their
325
- // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
326
- // stride we set here. For dispatches that fit in one row, gid.y is 0.
327
- const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
328
- const MAX_X = 65535
329
- const wgX = Math.min(wgCount, MAX_X)
330
- const wgY = Math.ceil(wgCount / MAX_X)
331
- pass.dispatchWorkgroups(wgX, wgY, 1)
332
- pass.end()
333
- }
334
- // After all dispatches: writebacks (Adam state, updated params). Empty for
335
- // forward-only compiles.
336
- for (const wb of plan.writebacks) {
337
- encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
338
- }
339
- encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
340
- // Capture readbacks (only when opted in). Queued before submit so they
341
- // observe the same kernel outputs as the main output.
342
- let stagings: Map<string, GPUBuffer> | null = null
343
- if (wantCaptures) {
344
- stagings = ensureCaptureStagings()
345
- for (const [name, bufId] of plan.capturesByName) {
346
- const spec = plan.buffers[bufId]!
347
- encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
348
- }
349
- }
350
- queue.submit([encoder.finish()])
351
-
352
- await outputReadback.mapAsync(GPUMapMode.READ)
353
- const output = new Float32Array(outputReadback.getMappedRange().slice(0))
354
- outputReadback.unmap()
355
-
356
- const captures = new Map<string, Float32Array>()
357
- if (wantCaptures) {
358
- for (const [name, staging] of stagings!) {
359
- await staging.mapAsync(GPUMapMode.READ)
360
- captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
361
- staging.unmap()
362
- }
363
- }
364
- return { output, captures }
365
- }
366
-
367
- // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
368
- function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
369
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
370
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
371
- async function step(
372
- inputs: Record<string, Int32Array | Float32Array>,
373
- opts?: RunOptions,
374
- ): Promise<number | StepResult> {
375
- const r = await dispatch(inputs, opts?.withCaptures === true)
376
- if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
377
- return r.output[0]!
378
- }
379
-
380
- // ---- run() — forward-mode wrapper, returns { output, captures } ----------
381
- async function run(
382
- inputs: Record<string, Int32Array | Float32Array>,
383
- opts?: RunOptions,
384
- ): Promise<RunResult> {
385
- const r = await dispatch(inputs, opts?.withCaptures === true)
386
- return { output: r.output, captures: new Captures(captureShapes, r.captures) }
387
- }
388
-
389
- // ---- uploadParams ---------------------------------------------------------
390
- function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
391
- const partial = opts?.partial ?? false
392
- for (const name of Object.keys(params)) {
393
- if (!plan.paramsByName.has(name)) {
394
- throw new Error(
395
- `uploadParams: unknown param '${name}'. ` +
396
- `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
397
- )
398
- }
399
- }
400
- if (!partial) {
401
- for (const name of plan.paramsByName.keys()) {
402
- if (!(name in params)) {
403
- throw new Error(
404
- `uploadParams: missing param '${name}'. ` +
405
- `Pass { partial: true } if you mean to update only some params.`,
406
- )
407
- }
408
- }
409
- }
410
- for (const [name, bufId] of plan.paramsByName) {
411
- const data = params[name]
412
- if (!data) continue
413
- const expected = plan.buffers[bufId]!.byteSize / 4
414
- if (data.length !== expected) {
415
- throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
416
- }
417
- queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
418
- }
419
- }
420
-
421
- // ---- download helpers -----------------------------------------------------
422
- async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
423
- const stagings: { name: string; buf: GPUBuffer; bytes: number }[] = []
424
- const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
425
- for (const [name, bufId] of map) {
426
- const spec = plan.buffers[bufId]!
427
- const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
428
- encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
429
- stagings.push({ name, buf: staging, bytes: spec.byteSize })
430
- }
431
- queue.submit([encoder.finish()])
432
- const out: Record<string, Float32Array> = {}
433
- for (const s of stagings) {
434
- await s.buf.mapAsync(GPUMapMode.READ)
435
- out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
436
- s.buf.unmap()
437
- s.buf.destroy()
438
- }
439
- return out
440
- }
441
-
442
- // Fill a state buffer with its declared initValue (typically 0). Float and
443
- // int both serialize to 4 bytes per element. Used at allocation time and on
444
- // resetOptimizerState() — same logic, two callers.
445
- function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
446
- const elements = spec.byteSize / 4
447
- const init = spec.dtype === 'f32'
448
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
449
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
450
- queue.writeBuffer(target, 0, init as unknown as BufferSource)
451
- }
452
-
453
- function resetOptimizerState() {
454
- for (const spec of plan.buffers) {
455
- if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
456
- }
457
- }
458
-
459
- // Build the params map AFTER buffer allocation so it points at the actual
460
- // GPUBuffers (shared or freshly allocated).
461
- const params = new Map<string, GPUBuffer>()
462
- for (const [name, bufId] of plan.paramsByName) {
463
- params.set(name, buffers.get(bufId)!)
464
- }
465
- // Static-after-compile shape metadata so users don't have to recompute
466
- // strides to interpret a flat capture readback.
467
- const captureShapes: Record<string, number[]> = {}
468
- for (const [name, bufId] of plan.capturesByName) {
469
- captureShapes[name] = [...plan.buffers[bufId]!.shape]
470
- }
471
- const outputShape = [...plan.buffers[lossBufferId]!.shape]
472
-
473
- const destroy = () => {
474
- for (const [id, b] of buffers) {
475
- if (ownedBufferIds.has(id)) b.destroy()
476
- }
477
- outputReadback.destroy()
478
- if (captureStagings) for (const b of captureStagings.values()) b.destroy()
479
- }
480
-
481
- return {
482
- device,
483
- params,
484
- outputShape,
485
- uploadParams,
486
- downloadParams: () => downloadFromMap(plan.paramsByName),
487
- downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
488
- step,
489
- run,
490
- resetOptimizerState,
491
- destroy,
492
- }
493
- }
494
-
495
- /** Same machinery as `createRuntime`, narrower public type: a forward-only
496
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
497
- * loss readback). The full runtime object is built once and projected by
498
- * `compileForward` to the public shape. */
499
- export async function createForwardRuntime(
500
- plan: BufferPlan,
501
- kernels: KernelSpec[],
502
- outputBufferId: number,
503
- opts: RuntimeOpts = {},
504
- ): Promise<CompiledForward> {
505
- return await createRuntime(plan, kernels, outputBufferId, opts)
506
- }
507
-
508
- async function acquireDevice(): Promise<GPUDevice> {
509
- if (typeof navigator === 'undefined' || !navigator.gpu) {
510
- throw new Error('tensorgrad: WebGPU not available in this environment')
511
- }
512
- const adapter = await navigator.gpu.requestAdapter()
513
- if (!adapter) throw new Error('tensorgrad: no WebGPU adapter')
514
- return await adapter.requestDevice()
515
- }
1
+ // WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
2
+ // allocates real GPU buffers and pipelines, and provides a `step()` method
3
+ // that uploads inputs, dispatches all kernels, and reads back outputs.
4
+ //
5
+ // Browser-only: this module needs `navigator.gpu` at runtime.
6
+
7
+ import type { BufferPlan } from './buffers.js'
8
+ import type { KernelSpec } from './codegen.js'
9
+
10
+ // TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
11
+ // Provided by the browser per WebGPU spec; declare just what we use.
12
+ declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
+
14
+ export interface UploadParamsOptions {
15
+ /** Skip the "missing param" check, allowing the caller to update only some
16
+ * params and leave the rest at their current GPU values. Extra (unknown)
17
+ * keys are still rejected — that's always a typo. Default: false. */
18
+ partial?: boolean
19
+ }
20
+
21
+ /**
22
+ * Activation readbacks for one `step()`/`run()` call. Keyed by the names
23
+ * passed to `capture(name, t)` during the trace. `get(name)` throws if the
24
+ * name isn't registered or wasn't read back this call (i.e., the call was
25
+ * made without `{ withCaptures: true }`); use `has(name)` if you need to
26
+ * branch. `shapeOf(name)` returns the static-after-compile shape and works
27
+ * regardless of whether captures were read back.
28
+ */
29
+ export class Captures {
30
+ constructor(
31
+ private readonly shapes: Record<string, readonly number[]>,
32
+ private readonly data: Map<string, Float32Array>,
33
+ ) {}
34
+ get(name: string): Float32Array {
35
+ const d = this.data.get(name)
36
+ if (!d) {
37
+ const known = [...this.data.keys()].sort().join(', ')
38
+ const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
39
+ throw new Error(`Captures.get: '${name}' not present. ${detail}`)
40
+ }
41
+ return d
42
+ }
43
+ shapeOf(name: string): readonly number[] {
44
+ const s = this.shapes[name]
45
+ if (!s) {
46
+ const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
47
+ throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
48
+ }
49
+ return s
50
+ }
51
+ has(name: string): boolean { return this.data.has(name) }
52
+ names(): string[] { return [...this.data.keys()].sort() }
53
+ }
54
+
55
+ export interface RunResult {
56
+ output: Float32Array
57
+ captures: Captures
58
+ }
59
+
60
+ export interface StepResult {
61
+ loss: number
62
+ captures: Captures
63
+ }
64
+
65
+ export interface RunOptions {
66
+ /** Read back tensors registered via `capture(name, t)` during the trace.
67
+ * Default false. When false, the returned `captures` is empty (calling
68
+ * `.get` throws); when true, captures are read back and accessible. */
69
+ withCaptures?: boolean
70
+ }
71
+
72
+ /** Common surface for both training and forward-only compiled runtimes. */
73
+ export interface CompiledBase {
74
+ /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
75
+ * share the device, or use directly for other GPU work. */
76
+ device: GPUDevice
77
+ /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
78
+ * `sharedParams` to share without copies. */
79
+ params: Map<string, GPUBuffer>
80
+ /** Shape of the graph's output (loss scalar `[]` for training; the user's
81
+ * returned tensor for forward-only compiles). */
82
+ outputShape: number[]
83
+ /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
84
+ * *all* params to be present; throws on any unknown or missing key. Pass
85
+ * `{ partial: true }` to skip the missing-key check. */
86
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
87
+ /** Read all parameters back as Float32Arrays — used for UI panels. */
88
+ downloadParams(): Promise<Record<string, Float32Array>>
89
+ /** Free GPU resources. */
90
+ destroy(): void
91
+ }
92
+
93
+ /** Run a dispatch and read back the full output tensor. Default returns the
94
+ * output as a `Float32Array`; with `{ withCaptures: true }` returns
95
+ * `{ output, captures }`. Same shape as `step()`'s overloads. */
96
+ export interface RunFn {
97
+ (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
98
+ (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
99
+ (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
100
+ }
101
+
102
+ export interface CompiledRuntime extends CompiledBase {
103
+ /** Read all parameter gradients back. Mostly for verification / debugging. */
104
+ downloadParamGrads(): Promise<Record<string, Float32Array>>
105
+ /**
106
+ * One full forward+backward step.
107
+ * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
108
+ * 2. Dispatches every kernel in order.
109
+ * 3. Reads back the loss scalar (and any registered captures, if requested).
110
+ * Default returns the loss as a JS number; with `{ withCaptures: true }`
111
+ * returns `{ loss, captures }`.
112
+ */
113
+ step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
114
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
115
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
116
+ /** Same dispatch as step() but returns the full output Float32Array — for
117
+ * training graphs the output is a scalar loss, so step() is usually more
118
+ * convenient. Provided for parity with `compileForward`. */
119
+ run: RunFn
120
+ /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
121
+ * `uploadInitialParams()` for a full training reset without recompile. */
122
+ resetOptimizerState(): void
123
+ }
124
+
125
+ /** Forward-only compiled runtime produced by `compileForward`. No optimizer,
126
+ * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
127
+ export interface CompiledForward extends CompiledBase {
128
+ run: RunFn
129
+ }
130
+
131
+ export interface RuntimeOpts {
132
+ /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
133
+ device?: GPUDevice
134
+ /** External param buffers to bind in place of allocating fresh ones, keyed
135
+ * by param name. Used to share params between a training compile and a
136
+ * sibling forward-only compile (e.g., a B=1 inference graph). When a name
137
+ * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
138
+ * allocates as usual. */
139
+ sharedParams?: Map<string, GPUBuffer>
140
+ }
141
+
142
+ // Inlined numeric values (per WebGPU spec) so this module is importable in Node
143
+ // for codegen-only usage. The browser provides GPUBufferUsage as a global, but
144
+ // referencing it at module scope would crash before any browser code runs.
145
+ const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4 /*COPY_SRC*/
146
+ const READBACK = 0x1 /*MAP_READ*/ | 0x8 /*COPY_DST*/
147
+
148
/**
 * Build the full compiled runtime for one graph: allocates a GPUBuffer per
 * `BufferSpec` in `plan`, compiles one compute pipeline per `KernelSpec`
 * (with error-scope probing so shader failures surface with real messages),
 * and returns the `step`/`run`/upload/download API bound to those resources.
 *
 * @param plan          buffer layout from codegen: sizes, kinds, and the
 *                      name→buffer-id maps (params, inputs, captures, grads)
 * @param kernels       ordered kernel list; dispatched in array order each call
 * @param lossBufferId  id of the graph's main output buffer — scalar loss for
 *                      training graphs, the user's returned tensor for
 *                      forward-only graphs
 * @param opts          optional pre-acquired device + shared param buffers
 */
export async function createRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  lossBufferId: number,
  opts: RuntimeOpts = {},
): Promise<CompiledRuntime> {
  const device = opts.device ?? await acquireDevice()
  const queue = device.queue

  // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
  // State buffers also get filled with their initValue at allocation time.
  // Param buffers may be supplied externally via opts.sharedParams; in that
  // case we reuse the provided GPUBuffer instead of allocating, and the
  // sibling compile that owns it is responsible for upload + lifetime.
  const buffers = new Map<number, GPUBuffer>()
  const sharedParams = opts.sharedParams
  for (const spec of plan.buffers) {
    if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
      const shared = sharedParams.get(spec.name!)!
      // Fail fast on shape drift between the two compiles sharing this param.
      if (shared.size !== spec.byteSize) {
        throw new Error(
          `sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
          `compiled graph expects ${spec.byteSize}.`,
        )
      }
      buffers.set(spec.id, shared)
      continue
    }
    const buf = device.createBuffer({
      size: spec.byteSize,
      usage: STORAGE_RW,
      label: spec.name ?? `t${spec.id}-${spec.kind}`,
    })
    buffers.set(spec.id, buf)
    if (spec.kind === 'state') fillStateBuffer(spec, buf)
  }
  // Track which params are externally owned — those are skipped on destroy().
  const ownedBufferIds = new Set<number>()
  for (const spec of plan.buffers) {
    if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
      ownedBufferIds.add(spec.id)
    }
  }

  // ---- Compile pipelines per kernel; cache by WGSL source -------------------
  // Push an error scope around each shader+pipeline creation so we can surface
  // the actual compile error rather than the cryptic "previous error" that
  // comes from using an invalid pipeline at dispatch time.
  const moduleCache = new Map<string, GPUShaderModule>()
  const pipelines: (GPUComputePipeline | null)[] = []
  type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
  const probes: ErrorProbe[] = []
  for (const k of kernels) {
    // Kernels without WGSL (e.g. pure copies handled elsewhere) keep the
    // pipelines array index-aligned with `kernels` via a null slot.
    if (!k.wgsl) { pipelines.push(null); continue }
    let module = moduleCache.get(k.wgsl)
    if (!module) {
      module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
      moduleCache.set(k.wgsl, module)
    }
    device.pushErrorScope('validation')
    const pipeline = device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'main' },
      label: k.opKind,
    })
    pipelines.push(pipeline)
    probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
  }
  const probeResults = await Promise.all(probes)
  const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
  if (failures.length > 0) {
    const reports: string[] = []
    for (const { k, module, err } of failures) {
      // Pull per-line WGSL diagnostics so the console report points at the
      // offending shader line, not just the pipeline-level validation error.
      const info = await module.getCompilationInfo()
      const messages = info.messages
        .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
        .join('\n')
      reports.push(
        `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
        (messages || ' (no compilation messages)') +
        `\n--- WGSL ---\n${k.wgsl}\n-----------`,
      )
    }
    // eslint-disable-next-line no-console
    console.error(reports.join('\n\n'))
    throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
  }

  // ---- Pre-build bind groups (static — buffer ids don't change per step) ---
  const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
    const pipeline = pipelines[i]
    if (!pipeline) return null
    return device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: k.bindings.map((bufId, idx) => ({
        binding: idx,
        resource: { buffer: buffers.get(bufId)! },
      })),
    })
  })

  // ---- Output readback staging buffer ---------------------------------------
  // `outputBufferId` is the graph's main output (loss for training, the user's
  // returned tensor for forward-only). step() reads back its first element;
  // run() reads back the full Float32Array.
  // NOTE(review): `plan.buffers` is indexed by position here while buffers are
  // keyed by `spec.id` elsewhere — assumes ids are dense and match array
  // order. Confirm codegen guarantees this invariant.
  const outputSpec = plan.buffers[lossBufferId]!
  const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })

  // ---- Capture readback staging buffers (lazy) ------------------------------
  // Allocated on first `step({ withCaptures: true })` call and reused across
  // subsequent calls. When the graph has no captures registered or when the
  // caller never opts in, no extra GPU memory is allocated.
  let captureStagings: Map<string, GPUBuffer> | null = null
  function ensureCaptureStagings(): Map<string, GPUBuffer> {
    if (captureStagings) return captureStagings
    captureStagings = new Map()
    for (const [name, bufId] of plan.capturesByName) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
      captureStagings.set(name, staging)
    }
    return captureStagings
  }

  // ---- dispatch() shared core for step() and run() -----------------------
  // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
  // the output buffer into its staging, optionally copies captures into theirs,
  // submits, and reads back. Returns the full output Float32Array; step() takes
  // [0] for scalar loss, run() returns it whole.
  //
  // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
  // runtime would otherwise both try to `mapAsync` the shared output staging
  // buffer at the same time and trip "Buffer already has an outstanding map
  // pending." We chain each new dispatch onto the prior one's promise so they
  // run sequentially even when fired from independent async paths (e.g., a
  // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
  let pending: Promise<unknown> = Promise.resolve()
  async function dispatch(
    inputs: Record<string, Int32Array | Float32Array>,
    wantCaptures: boolean,
  ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
    // A rejected predecessor must not poison the chain — swallow it here;
    // the caller that awaited it already saw the error.
    const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
    pending = turn
    return turn
  }
  async function dispatchUnsynchronized(
    inputs: Record<string, Int32Array | Float32Array>,
    wantCaptures: boolean,
  ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
    if (wantCaptures && plan.capturesByName.size === 0) {
      throw new Error(
        `withCaptures=true but no capture(...) calls were registered during ` +
        `the trace. Add capture('name', tensor) inside your forward pass for ` +
        `the intermediates you want read back.`,
      )
    }
    // Upload every declared input; size-check against the compiled byteSize
    // so shape mismatches fail loudly instead of reading garbage on-GPU.
    for (const [name, bufId] of plan.inputsByName) {
      const data = inputs[name]
      if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
      const expectedBytes = plan.buffers[bufId]!.byteSize
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
      }
      // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
      // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }

    const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
    for (let i = 0; i < kernels.length; i++) {
      const k = kernels[i]!
      if (!k.wgsl || k.threads === 0) continue
      const pipeline = pipelines[i]!
      const bindGroup = bindGroups[i]!
      const pass = encoder.beginComputePass({ label: k.opKind })
      pass.setPipeline(pipeline)
      pass.setBindGroup(0, bindGroup)
      // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
      // when a kernel needs more than that on the X axis. Kernels compute their
      // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
      // stride we set here. For dispatches that fit in one row, gid.y is 0.
      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
      const MAX_X = 65535
      const wgX = Math.min(wgCount, MAX_X)
      const wgY = Math.ceil(wgCount / MAX_X)
      pass.dispatchWorkgroups(wgX, wgY, 1)
      pass.end()
    }
    // After all dispatches: writebacks (Adam state, updated params). Empty for
    // forward-only compiles.
    for (const wb of plan.writebacks) {
      encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
    }
    encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
    // Capture readbacks (only when opted in). Queued before submit so they
    // observe the same kernel outputs as the main output.
    let stagings: Map<string, GPUBuffer> | null = null
    if (wantCaptures) {
      stagings = ensureCaptureStagings()
      for (const [name, bufId] of plan.capturesByName) {
        const spec = plan.buffers[bufId]!
        encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
      }
    }
    queue.submit([encoder.finish()])

    // `.slice(0)` copies out of the mapped range: the ArrayBuffer returned by
    // getMappedRange is detached on unmap, so a view into it would go stale.
    await outputReadback.mapAsync(GPUMapMode.READ)
    const output = new Float32Array(outputReadback.getMappedRange().slice(0))
    outputReadback.unmap()

    const captures = new Map<string, Float32Array>()
    if (wantCaptures) {
      for (const [name, staging] of stagings!) {
        await staging.mapAsync(GPUMapMode.READ)
        captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
        staging.unmap()
      }
    }
    return { output, captures }
  }

  // ---- step() training-mode wrapper, returns scalar [0] of output ---------
  function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
  function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
  async function step(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: RunOptions,
  ): Promise<number | StepResult> {
    const r = await dispatch(inputs, opts?.withCaptures === true)
    if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
    return r.output[0]!
  }

  // ---- run() — forward-mode wrapper, returns Float32Array by default -------
  // Same overloaded shape as step(): scalar-shaped result (here Float32Array,
  // there a JS number) is the default; { ..., captures } is the opt-in form.
  function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
  async function run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: RunOptions,
  ): Promise<Float32Array | RunResult> {
    const r = await dispatch(inputs, opts?.withCaptures === true)
    if (opts?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) }
    return r.output
  }

  // ---- uploadParams ---------------------------------------------------------
  // Validates names both ways (unknown keys always rejected; missing keys
  // rejected unless { partial: true }), then writes each supplied array into
  // its param buffer after an element-count check.
  function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
    const partial = opts?.partial ?? false
    for (const name of Object.keys(params)) {
      if (!plan.paramsByName.has(name)) {
        throw new Error(
          `uploadParams: unknown param '${name}'. ` +
          `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
        )
      }
    }
    if (!partial) {
      for (const name of plan.paramsByName.keys()) {
        if (!(name in params)) {
          throw new Error(
            `uploadParams: missing param '${name}'. ` +
            `Pass { partial: true } if you mean to update only some params.`,
          )
        }
      }
    }
    for (const [name, bufId] of plan.paramsByName) {
      const data = params[name]
      if (!data) continue
      const expected = plan.buffers[bufId]!.byteSize / 4
      if (data.length !== expected) {
        throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
      }
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }
  }

  // ---- download helpers -----------------------------------------------------
  // Copies every buffer named in `map` into a throwaway staging buffer, reads
  // each back as a Float32Array, and destroys the stagings. Allocates fresh
  // stagings per call, so this does NOT contend with dispatch()'s shared
  // output staging and needs no serialization.
  async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
    const stagings: { name: string; buf: GPUBuffer; bytes: number }[] = []
    const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
    for (const [name, bufId] of map) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
      encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
      stagings.push({ name, buf: staging, bytes: spec.byteSize })
    }
    queue.submit([encoder.finish()])
    const out: Record<string, Float32Array> = {}
    for (const s of stagings) {
      await s.buf.mapAsync(GPUMapMode.READ)
      out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
      s.buf.unmap()
      s.buf.destroy()
    }
    return out
  }

  // Fill a state buffer with its declared initValue (typically 0). Float and
  // int both serialize to 4 bytes per element. Used at allocation time and on
  // resetOptimizerState() — same logic, two callers.
  // (i32 init values are truncated toward zero; 'bool' state takes the i32 path.)
  function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
    const elements = spec.byteSize / 4
    const init = spec.dtype === 'f32'
      ? new Float32Array(elements).fill(spec.initValue ?? 0)
      : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
    queue.writeBuffer(target, 0, init as unknown as BufferSource)
  }

  // Re-fill every 'state' buffer (optimizer state, e.g. Adam moments) with
  // its declared initValue, without touching params or activations.
  function resetOptimizerState() {
    for (const spec of plan.buffers) {
      if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
    }
  }

  // Build the params map AFTER buffer allocation so it points at the actual
  // GPUBuffers (shared or freshly allocated).
  const params = new Map<string, GPUBuffer>()
  for (const [name, bufId] of plan.paramsByName) {
    params.set(name, buffers.get(bufId)!)
  }
  // Static-after-compile shape metadata so users don't have to recompute
  // strides to interpret a flat capture readback.
  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId]!.shape]
  }
  const outputShape = [...plan.buffers[lossBufferId]!.shape]

  // Releases every buffer this runtime owns; shared params (see
  // ownedBufferIds above) are left alive for their owning compile.
  const destroy = () => {
    for (const [id, b] of buffers) {
      if (ownedBufferIds.has(id)) b.destroy()
    }
    outputReadback.destroy()
    if (captureStagings) for (const b of captureStagings.values()) b.destroy()
  }

  return {
    device,
    params,
    outputShape,
    uploadParams,
    downloadParams: () => downloadFromMap(plan.paramsByName),
    downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
    step,
    run,
    resetOptimizerState,
    destroy,
  }
}

503
+ /** Same machinery as `createRuntime`, narrower public type: a forward-only
504
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
505
+ * loss readback). The full runtime object is built once and projected by
506
+ * `compileForward` to the public shape. */
507
+ export async function createForwardRuntime(
508
+ plan: BufferPlan,
509
+ kernels: KernelSpec[],
510
+ outputBufferId: number,
511
+ opts: RuntimeOpts = {},
512
+ ): Promise<CompiledForward> {
513
+ return await createRuntime(plan, kernels, outputBufferId, opts)
514
+ }
515
+
516
+ async function acquireDevice(): Promise<GPUDevice> {
517
+ if (typeof navigator === 'undefined' || !navigator.gpu) {
518
+ throw new Error('tensorgrad: WebGPU not available in this environment')
519
+ }
520
+ const adapter = await navigator.gpu.requestAdapter()
521
+ if (!adapter) throw new Error('tensorgrad: no WebGPU adapter')
522
+ return await adapter.requestDevice()
523
+ }