tensorgrad 0.0.9 → 0.0.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/runtime.ts CHANGED
@@ -1,497 +1,523 @@
1
- // WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
2
- // allocates real GPU buffers and pipelines, and provides a `step()` method
3
- // that uploads inputs, dispatches all kernels, and reads back outputs.
4
- //
5
- // Browser-only: this module needs `navigator.gpu` at runtime.
6
-
7
- import type { BufferPlan } from './buffers.js'
8
- import type { KernelSpec } from './codegen.js'
9
-
10
- // TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
11
- // Provided by the browser per WebGPU spec; declare just what we use.
12
- declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
-
14
- export interface UploadParamsOptions {
15
- /** Skip the "missing param" check, allowing the caller to update only some
16
- * params and leave the rest at their current GPU values. Extra (unknown)
17
- * keys are still rejected — that's always a typo. Default: false. */
18
- partial?: boolean
19
- }
20
-
21
- export interface StepOptions {
22
- /** Read back tensors registered via `capture(name, t)` during the trace.
23
- * When false/unset, `step()` returns just the loss number; staging buffers
24
- * for captures are not allocated. When true, returns `{ loss, captures }`
25
- * and lazily allocates one staging buffer per capture on first use. */
26
- withCaptures?: boolean
27
- }
28
-
29
- export interface StepWithCaptures {
30
- loss: number
31
- captures: Record<string, Float32Array>
32
- }
33
-
34
- export interface RunOptions {
35
- /** Read back tensors registered via `capture(name, t)` during the trace.
36
- * When false/unset, `run()` returns just the output Float32Array.
37
- * When true, returns `{ output, captures }`. */
38
- withCaptures?: boolean
39
- }
40
-
41
- export interface RunWithCaptures {
42
- output: Float32Array
43
- captures: Record<string, Float32Array>
44
- }
45
-
46
- /** Common surface for both training and forward-only compiled runtimes. */
47
- export interface CompiledBase {
48
- /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
49
- * `sharedParams` to share without copies. */
50
- params: Map<string, GPUBuffer>
51
- /** Shape of each tensor registered via `capture(name, t)`. Static after
52
- * compile reshape readbacks without recomputing strides. */
53
- captureShapes: Record<string, number[]>
54
- /** Shape of the graph's output (loss scalar `[]` for training; the user's
55
- * returned tensor for forward-only compiles). */
56
- outputShape: number[]
57
- /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
58
- * *all* params to be present; throws on any unknown or missing key. Pass
59
- * `{ partial: true }` to skip the missing-key check. */
60
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
61
- /** Read all parameters back as Float32Arrays — used for UI panels. */
62
- downloadParams(): Promise<Record<string, Float32Array>>
63
- /** Free GPU resources. */
64
- destroy(): void
65
- }
66
-
67
- /** Run a dispatch and read back the full output tensor (and any registered
68
- * captures if requested). Forward-only compiles use this as their primary
69
- * surface; training compiles also expose it but `step()` is more convenient
70
- * there because the output is a scalar loss. */
71
- export interface RunFn {
72
- (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
73
- (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
74
- (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
75
- }
76
-
77
- export interface CompiledRuntime extends CompiledBase {
78
- /** Read all parameter gradients back. Mostly for verification / debugging. */
79
- downloadParamGrads(): Promise<Record<string, Float32Array>>
80
- /**
81
- * One full forward+backward step.
82
- * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
83
- * 2. Dispatches every kernel in order.
84
- * 3. Reads back the loss scalar (and any registered captures, if requested).
85
- * Default returns the loss as a JS number; with `{ withCaptures: true }`
86
- * returns `{ loss, captures }` where `captures` is keyed by the names passed
87
- * to `capture(...)` during the trace.
88
- */
89
- step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
90
- step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
91
- step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
92
- /** Same dispatch as step() but returns the full output Float32Array — for
93
- * training graphs the output is a scalar loss, so step() is usually more
94
- * convenient. Provided for parity with `compileForward`. */
95
- run: RunFn
96
- /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
97
- * `uploadInitialParams()` for a full training reset without recompile. */
98
- resetOptimizerState(): void
99
- }
100
-
101
- /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
102
- * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
103
- export interface CompiledForward extends CompiledBase {
104
- run: RunFn
105
- }
106
-
107
- export interface RuntimeOpts {
108
- /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
109
- device?: GPUDevice
110
- /** External param buffers to bind in place of allocating fresh ones, keyed
111
- * by param name. Used to share params between a training compile and a
112
- * sibling forward-only compile (e.g., a B=1 inference graph). When a name
113
- * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
114
- * allocates as usual. */
115
- sharedParams?: Map<string, GPUBuffer>
116
- }
117
-
118
- // Inlined numeric values (per WebGPU spec) so this module is importable in Node
119
- // for codegen-only usage. The browser provides GPUBufferUsage as a global, but
120
- // referencing it at module scope would crash before any browser code runs.
121
- const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4 /*COPY_SRC*/
122
- const READBACK = 0x1 /*MAP_READ*/ | 0x8 /*COPY_DST*/
123
-
124
- export async function createRuntime(
125
- plan: BufferPlan,
126
- kernels: KernelSpec[],
127
- lossBufferId: number,
128
- opts: RuntimeOpts = {},
129
- ): Promise<CompiledRuntime> {
130
- const device = opts.device ?? await acquireDevice()
131
- const queue = device.queue
132
-
133
- // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
134
- // State buffers also get filled with their initValue at allocation time.
135
- // Param buffers may be supplied externally via opts.sharedParams; in that
136
- // case we reuse the provided GPUBuffer instead of allocating, and the
137
- // sibling compile that owns it is responsible for upload + lifetime.
138
- const buffers = new Map<number, GPUBuffer>()
139
- const sharedParams = opts.sharedParams
140
- for (const spec of plan.buffers) {
141
- if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
142
- const shared = sharedParams.get(spec.name!)!
143
- if (shared.size !== spec.byteSize) {
144
- throw new Error(
145
- `sharedParams: size mismatch for '${spec.name}' supplied ${shared.size} bytes, ` +
146
- `compiled graph expects ${spec.byteSize}.`,
147
- )
148
- }
149
- buffers.set(spec.id, shared)
150
- continue
151
- }
152
- const buf = device.createBuffer({
153
- size: spec.byteSize,
154
- usage: STORAGE_RW,
155
- label: spec.name ?? `t${spec.id}-${spec.kind}`,
156
- })
157
- buffers.set(spec.id, buf)
158
- if (spec.kind === 'state') fillStateBuffer(spec, buf)
159
- }
160
- // Track which params are externally owned those are skipped on destroy().
161
- const ownedBufferIds = new Set<number>()
162
- for (const spec of plan.buffers) {
163
- if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
164
- ownedBufferIds.add(spec.id)
165
- }
166
- }
167
-
168
- // ---- Compile pipelines per kernel; cache by WGSL source -------------------
169
- // Push an error scope around each shader+pipeline creation so we can surface
170
- // the actual compile error rather than the cryptic "previous error" that
171
- // comes from using an invalid pipeline at dispatch time.
172
- const moduleCache = new Map<string, GPUShaderModule>()
173
- const pipelines: (GPUComputePipeline | null)[] = []
174
- type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
175
- const probes: ErrorProbe[] = []
176
- for (const k of kernels) {
177
- if (!k.wgsl) { pipelines.push(null); continue }
178
- let module = moduleCache.get(k.wgsl)
179
- if (!module) {
180
- module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
181
- moduleCache.set(k.wgsl, module)
182
- }
183
- device.pushErrorScope('validation')
184
- const pipeline = device.createComputePipeline({
185
- layout: 'auto',
186
- compute: { module, entryPoint: 'main' },
187
- label: k.opKind,
188
- })
189
- pipelines.push(pipeline)
190
- probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
191
- }
192
- const probeResults = await Promise.all(probes)
193
- const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
194
- if (failures.length > 0) {
195
- const reports: string[] = []
196
- for (const { k, module, err } of failures) {
197
- const info = await module.getCompilationInfo()
198
- const messages = info.messages
199
- .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
200
- .join('\n')
201
- reports.push(
202
- `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
203
- (messages || ' (no compilation messages)') +
204
- `\n--- WGSL ---\n${k.wgsl}\n-----------`,
205
- )
206
- }
207
- // eslint-disable-next-line no-console
208
- console.error(reports.join('\n\n'))
209
- throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
210
- }
211
-
212
- // ---- Pre-build bind groups (static — buffer ids don't change per step) ---
213
- const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
214
- const pipeline = pipelines[i]
215
- if (!pipeline) return null
216
- return device.createBindGroup({
217
- layout: pipeline.getBindGroupLayout(0),
218
- entries: k.bindings.map((bufId, idx) => ({
219
- binding: idx,
220
- resource: { buffer: buffers.get(bufId)! },
221
- })),
222
- })
223
- })
224
-
225
- // ---- Output readback staging buffer ---------------------------------------
226
- // `outputBufferId` is the graph's main output (loss for training, the user's
227
- // returned tensor for forward-only). step() reads back its first element;
228
- // run() reads back the full Float32Array.
229
- const outputSpec = plan.buffers[lossBufferId]!
230
- const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })
231
-
232
- // ---- Capture readback staging buffers (lazy) ------------------------------
233
- // Allocated on first `step({ withCaptures: true })` call and reused across
234
- // subsequent calls. When the graph has no captures registered or when the
235
- // caller never opts in, no extra GPU memory is allocated.
236
- let captureStagings: Map<string, GPUBuffer> | null = null
237
- function ensureCaptureStagings(): Map<string, GPUBuffer> {
238
- if (captureStagings) return captureStagings
239
- captureStagings = new Map()
240
- for (const [name, bufId] of plan.capturesByName) {
241
- const spec = plan.buffers[bufId]!
242
- const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
243
- captureStagings.set(name, staging)
244
- }
245
- return captureStagings
246
- }
247
-
248
- // ---- dispatch() — shared core for step() and run() -----------------------
249
- // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
250
- // the output buffer into its staging, optionally copies captures into theirs,
251
- // submits, and reads back. Returns the full output Float32Array; step() takes
252
- // [0] for scalar loss, run() returns it whole.
253
- //
254
- // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
255
- // runtime would otherwise both try to `mapAsync` the shared output staging
256
- // buffer at the same time and trip "Buffer already has an outstanding map
257
- // pending." We chain each new dispatch onto the prior one's promise so they
258
- // run sequentially even when fired from independent async paths (e.g., a
259
- // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
260
- let pending: Promise<unknown> = Promise.resolve()
261
- async function dispatch(
262
- inputs: Record<string, Int32Array | Float32Array>,
263
- wantCaptures: boolean,
264
- ): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
265
- const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
266
- pending = turn
267
- return turn
268
- }
269
- async function dispatchUnsynchronized(
270
- inputs: Record<string, Int32Array | Float32Array>,
271
- wantCaptures: boolean,
272
- ): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
273
- if (wantCaptures && plan.capturesByName.size === 0) {
274
- throw new Error(
275
- `withCaptures=true but no capture(...) calls were registered during ` +
276
- `the trace. Add capture('name', tensor) inside your forward pass for ` +
277
- `the intermediates you want read back.`,
278
- )
279
- }
280
- for (const [name, bufId] of plan.inputsByName) {
281
- const data = inputs[name]
282
- if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
283
- const expectedBytes = plan.buffers[bufId]!.byteSize
284
- if (data.byteLength !== expectedBytes) {
285
- throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
286
- }
287
- // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
288
- // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
289
- queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
290
- }
291
-
292
- const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
293
- for (let i = 0; i < kernels.length; i++) {
294
- const k = kernels[i]!
295
- if (!k.wgsl || k.threads === 0) continue
296
- const pipeline = pipelines[i]!
297
- const bindGroup = bindGroups[i]!
298
- const pass = encoder.beginComputePass({ label: k.opKind })
299
- pass.setPipeline(pipeline)
300
- pass.setBindGroup(0, bindGroup)
301
- // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
302
- // when a kernel needs more than that on the X axis. Kernels compute their
303
- // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
304
- // stride we set here. For dispatches that fit in one row, gid.y is 0.
305
- const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
306
- const MAX_X = 65535
307
- const wgX = Math.min(wgCount, MAX_X)
308
- const wgY = Math.ceil(wgCount / MAX_X)
309
- pass.dispatchWorkgroups(wgX, wgY, 1)
310
- pass.end()
311
- }
312
- // After all dispatches: writebacks (Adam state, updated params). Empty for
313
- // forward-only compiles.
314
- for (const wb of plan.writebacks) {
315
- encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
316
- }
317
- encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
318
- // Capture readbacks (only when opted in). Queued before submit so they
319
- // observe the same kernel outputs as the main output.
320
- let stagings: Map<string, GPUBuffer> | null = null
321
- if (wantCaptures) {
322
- stagings = ensureCaptureStagings()
323
- for (const [name, bufId] of plan.capturesByName) {
324
- const spec = plan.buffers[bufId]!
325
- encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
326
- }
327
- }
328
- queue.submit([encoder.finish()])
329
-
330
- await outputReadback.mapAsync(GPUMapMode.READ)
331
- const output = new Float32Array(outputReadback.getMappedRange().slice(0))
332
- outputReadback.unmap()
333
-
334
- if (!wantCaptures) return { output, captures: null }
335
-
336
- const captures: Record<string, Float32Array> = {}
337
- for (const [name, staging] of stagings!) {
338
- await staging.mapAsync(GPUMapMode.READ)
339
- captures[name] = new Float32Array(staging.getMappedRange().slice(0))
340
- staging.unmap()
341
- }
342
- return { output, captures }
343
- }
344
-
345
- // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
346
- function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
347
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
348
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
349
- async function step(
350
- inputs: Record<string, Int32Array | Float32Array>,
351
- opts?: StepOptions,
352
- ): Promise<number | StepWithCaptures> {
353
- const r = await dispatch(inputs, opts?.withCaptures === true)
354
- if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures! }
355
- return r.output[0]!
356
- }
357
-
358
- // ---- run() forward-mode wrapper, returns full output Float32Array -------
359
- function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
360
- function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
361
- function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
362
- async function run(
363
- inputs: Record<string, Int32Array | Float32Array>,
364
- opts?: RunOptions,
365
- ): Promise<Float32Array | RunWithCaptures> {
366
- const r = await dispatch(inputs, opts?.withCaptures === true)
367
- if (opts?.withCaptures) return { output: r.output, captures: r.captures! }
368
- return r.output
369
- }
370
-
371
- // ---- uploadParams ---------------------------------------------------------
372
- function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
373
- const partial = opts?.partial ?? false
374
- for (const name of Object.keys(params)) {
375
- if (!plan.paramsByName.has(name)) {
376
- throw new Error(
377
- `uploadParams: unknown param '${name}'. ` +
378
- `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
379
- )
380
- }
381
- }
382
- if (!partial) {
383
- for (const name of plan.paramsByName.keys()) {
384
- if (!(name in params)) {
385
- throw new Error(
386
- `uploadParams: missing param '${name}'. ` +
387
- `Pass { partial: true } if you mean to update only some params.`,
388
- )
389
- }
390
- }
391
- }
392
- for (const [name, bufId] of plan.paramsByName) {
393
- const data = params[name]
394
- if (!data) continue
395
- const expected = plan.buffers[bufId]!.byteSize / 4
396
- if (data.length !== expected) {
397
- throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
398
- }
399
- queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
400
- }
401
- }
402
-
403
- // ---- download helpers -----------------------------------------------------
404
- async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
405
- const stagings: { name: string; buf: GPUBuffer; bytes: number }[] = []
406
- const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
407
- for (const [name, bufId] of map) {
408
- const spec = plan.buffers[bufId]!
409
- const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
410
- encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
411
- stagings.push({ name, buf: staging, bytes: spec.byteSize })
412
- }
413
- queue.submit([encoder.finish()])
414
- const out: Record<string, Float32Array> = {}
415
- for (const s of stagings) {
416
- await s.buf.mapAsync(GPUMapMode.READ)
417
- out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
418
- s.buf.unmap()
419
- s.buf.destroy()
420
- }
421
- return out
422
- }
423
-
424
- // Fill a state buffer with its declared initValue (typically 0). Float and
425
- // int both serialize to 4 bytes per element. Used at allocation time and on
426
- // resetOptimizerState() — same logic, two callers.
427
- function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
428
- const elements = spec.byteSize / 4
429
- const init = spec.dtype === 'f32'
430
- ? new Float32Array(elements).fill(spec.initValue ?? 0)
431
- : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
432
- queue.writeBuffer(target, 0, init as unknown as BufferSource)
433
- }
434
-
435
- function resetOptimizerState() {
436
- for (const spec of plan.buffers) {
437
- if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
438
- }
439
- }
440
-
441
- // Build the params map AFTER buffer allocation so it points at the actual
442
- // GPUBuffers (shared or freshly allocated).
443
- const params = new Map<string, GPUBuffer>()
444
- for (const [name, bufId] of plan.paramsByName) {
445
- params.set(name, buffers.get(bufId)!)
446
- }
447
- // Static-after-compile shape metadata so users don't have to recompute
448
- // strides to interpret a flat capture readback.
449
- const captureShapes: Record<string, number[]> = {}
450
- for (const [name, bufId] of plan.capturesByName) {
451
- captureShapes[name] = [...plan.buffers[bufId]!.shape]
452
- }
453
- const outputShape = [...plan.buffers[lossBufferId]!.shape]
454
-
455
- const destroy = () => {
456
- for (const [id, b] of buffers) {
457
- if (ownedBufferIds.has(id)) b.destroy()
458
- }
459
- outputReadback.destroy()
460
- if (captureStagings) for (const b of captureStagings.values()) b.destroy()
461
- }
462
-
463
- return {
464
- params,
465
- captureShapes,
466
- outputShape,
467
- uploadParams,
468
- downloadParams: () => downloadFromMap(plan.paramsByName),
469
- downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
470
- step,
471
- run,
472
- resetOptimizerState,
473
- destroy,
474
- }
475
- }
476
-
477
- /** Same machinery as `createRuntime`, narrower public type: a forward-only
478
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
479
- * loss readback). The full runtime object is built once and projected by
480
- * `compileForward` to the public shape. */
481
- export async function createForwardRuntime(
482
- plan: BufferPlan,
483
- kernels: KernelSpec[],
484
- outputBufferId: number,
485
- opts: RuntimeOpts = {},
486
- ): Promise<CompiledForward> {
487
- return await createRuntime(plan, kernels, outputBufferId, opts)
488
- }
489
-
490
- async function acquireDevice(): Promise<GPUDevice> {
491
- if (typeof navigator === 'undefined' || !navigator.gpu) {
492
- throw new Error('tensorgrad: WebGPU not available in this environment')
493
- }
494
- const adapter = await navigator.gpu.requestAdapter()
495
- if (!adapter) throw new Error('tensorgrad: no WebGPU adapter')
496
- return await adapter.requestDevice()
497
- }
1
+ // WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
2
+ // allocates real GPU buffers and pipelines, and provides a `step()` method
3
+ // that uploads inputs, dispatches all kernels, and reads back outputs.
4
+ //
5
+ // Browser-only: this module needs `navigator.gpu` at runtime.
6
+
7
+ import type { BufferPlan } from './buffers.js'
8
+ import type { KernelSpec } from './codegen.js'
9
+
10
+ // TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
11
+ // Provided by the browser per WebGPU spec; declare just what we use.
12
+ declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
+
14
+ export interface UploadParamsOptions {
15
+ /** Skip the "missing param" check, allowing the caller to update only some
16
+ * params and leave the rest at their current GPU values. Extra (unknown)
17
+ * keys are still rejected — that's always a typo. Default: false. */
18
+ partial?: boolean
19
+ }
20
+
21
+ /**
22
+ * Activation readbacks for one `step()`/`run()` call. Keyed by the names
23
+ * passed to `capture(name, t)` during the trace. `get(name)` throws if the
24
+ * name isn't registered or wasn't read back this call (i.e., the call was
25
+ * made without `{ withCaptures: true }`); use `has(name)` if you need to
26
+ * branch. `shapeOf(name)` returns the static-after-compile shape and works
27
+ * regardless of whether captures were read back.
28
+ */
29
+ export class Captures {
30
+ constructor(
31
+ private readonly shapes: Record<string, readonly number[]>,
32
+ private readonly data: Map<string, Float32Array>,
33
+ ) {}
34
+ get(name: string): Float32Array {
35
+ const d = this.data.get(name)
36
+ if (!d) {
37
+ const known = [...this.data.keys()].sort().join(', ')
38
+ const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
39
+ throw new Error(`Captures.get: '${name}' not present. ${detail}`)
40
+ }
41
+ return d
42
+ }
43
+ shapeOf(name: string): readonly number[] {
44
+ const s = this.shapes[name]
45
+ if (!s) {
46
+ const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
47
+ throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
48
+ }
49
+ return s
50
+ }
51
+ has(name: string): boolean { return this.data.has(name) }
52
+ names(): string[] { return [...this.data.keys()].sort() }
53
+ }
54
+
55
+ export interface RunResult {
56
+ output: Float32Array
57
+ captures: Captures
58
+ }
59
+
60
+ export interface StepResult {
61
+ loss: number
62
+ captures: Captures
63
+ }
64
+
65
+ export interface RunOptions {
66
+ /** Read back tensors registered via `capture(name, t)` during the trace.
67
+ * Default false. When false, the returned `captures` is empty (calling
68
+ * `.get` throws); when true, captures are read back and accessible. */
69
+ withCaptures?: boolean
70
+ }
71
+
72
+ /** Common surface for both training and forward-only compiled runtimes. */
73
+ export interface CompiledBase {
74
+ /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
75
+ * share the device, or use directly for other GPU work. */
76
+ device: GPUDevice
77
+ /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
78
+ * `sharedParams` to share without copies. */
79
+ params: Map<string, GPUBuffer>
80
+ /** Shape of the graph's output (loss scalar `[]` for training; the user's
81
+ * returned tensor for forward-only compiles). */
82
+ outputShape: number[]
83
+ /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
84
+ * *all* params to be present; throws on any unknown or missing key. Pass
85
+ * `{ partial: true }` to skip the missing-key check. */
86
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
87
+ /** Read all parameters back as Float32Arrays — used for UI panels. */
88
+ downloadParams(): Promise<Record<string, Float32Array>>
89
+ /** Free GPU resources. */
90
+ destroy(): void
91
+ }
92
+
93
+ /** Run a dispatch and read back the full output tensor. Default returns the
94
+ * output as a `Float32Array`; with `{ withCaptures: true }` returns
95
+ * `{ output, captures }`. Same shape as `step()`'s overloads. */
96
+ export interface RunFn {
97
+ (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
98
+ (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
99
+ (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
100
+ }
101
+
102
+ export interface CompiledRuntime extends CompiledBase {
103
+ /** Read all parameter gradients back. Mostly for verification / debugging. */
104
+ downloadParamGrads(): Promise<Record<string, Float32Array>>
105
+ /**
106
+ * One full forward+backward step.
107
+ * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
108
+ * 2. Dispatches every kernel in order.
109
+ * 3. Reads back the loss scalar (and any registered captures, if requested).
110
+ * Default returns the loss as a JS number; with `{ withCaptures: true }`
111
+ * returns `{ loss, captures }`.
112
+ */
113
+ step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
114
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
115
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
116
+ /** Same dispatch as step() but returns the full output Float32Array — for
117
+ * training graphs the output is a scalar loss, so step() is usually more
118
+ * convenient. Provided for parity with `compileForward`. */
119
+ run: RunFn
120
+ /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
121
+ * `uploadInitialParams()` for a full training reset without recompile. */
122
+ resetOptimizerState(): void
123
+ }
124
+
125
+ /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
126
+ * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
127
+ export interface CompiledForward extends CompiledBase {
128
+ run: RunFn
129
+ }
130
+
131
+ export interface RuntimeOpts {
132
+ /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
133
+ device?: GPUDevice
134
+ /** External param buffers to bind in place of allocating fresh ones, keyed
135
+ * by param name. Used to share params between a training compile and a
136
+ * sibling forward-only compile (e.g., a B=1 inference graph). When a name
137
+ * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
138
+ * allocates as usual. */
139
+ sharedParams?: Map<string, GPUBuffer>
140
+ }
141
+
142
+ // Inlined numeric values (per WebGPU spec) so this module is importable in Node
143
+ // for codegen-only usage. The browser provides GPUBufferUsage as a global, but
144
+ // referencing it at module scope would crash before any browser code runs.
145
+ const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4 /*COPY_SRC*/
146
+ const READBACK = 0x1 /*MAP_READ*/ | 0x8 /*COPY_DST*/
147
+
148
/**
 * Instantiate a compiled graph on a real GPU: allocates one GPUBuffer per
 * BufferSpec, compiles one compute pipeline per kernel (cached by WGSL
 * source), pre-builds bind groups, and returns the runtime API
 * (step/run, param upload/download, optimizer-state reset, destroy).
 *
 * @param plan - Buffer layout from codegen: specs, name→id maps, writebacks.
 * @param kernels - Kernel specs, dispatched in array order on every step.
 * @param lossBufferId - Id of the graph's main output buffer (scalar loss for
 *   training compiles; the user's returned tensor for forward-only compiles).
 * @param opts - Optional pre-acquired device and shared param buffers.
 * @returns The live runtime; callers should eventually invoke `destroy()`.
 * @throws Error when any shader fails WGSL validation (details on console).
 */
export async function createRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  lossBufferId: number,
  opts: RuntimeOpts = {},
): Promise<CompiledRuntime> {
  const device = opts.device ?? await acquireDevice()
  const queue = device.queue

  // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
  // State buffers also get filled with their initValue at allocation time.
  // Param buffers may be supplied externally via opts.sharedParams; in that
  // case we reuse the provided GPUBuffer instead of allocating, and the
  // sibling compile that owns it is responsible for upload + lifetime.
  const buffers = new Map<number, GPUBuffer>()
  const sharedParams = opts.sharedParams
  for (const spec of plan.buffers) {
    if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
      const shared = sharedParams.get(spec.name!)!
      if (shared.size !== spec.byteSize) {
        throw new Error(
          `sharedParams: size mismatch for '${spec.name}' supplied ${shared.size} bytes, ` +
          `compiled graph expects ${spec.byteSize}.`,
        )
      }
      buffers.set(spec.id, shared)
      continue
    }
    const buf = device.createBuffer({
      size: spec.byteSize,
      usage: STORAGE_RW,
      label: spec.name ?? `t${spec.id}-${spec.kind}`,
    })
    buffers.set(spec.id, buf)
    if (spec.kind === 'state') fillStateBuffer(spec, buf)
  }
  // Track which params are externally owned — those are skipped on destroy().
  const ownedBufferIds = new Set<number>()
  for (const spec of plan.buffers) {
    if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
      ownedBufferIds.add(spec.id)
    }
  }

  // ---- Compile pipelines per kernel; cache by WGSL source -------------------
  // Push an error scope around each shader+pipeline creation so we can surface
  // the actual compile error rather than the cryptic "previous error" that
  // comes from using an invalid pipeline at dispatch time.
  const moduleCache = new Map<string, GPUShaderModule>()
  const pipelines: (GPUComputePipeline | null)[] = []
  type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
  const probes: ErrorProbe[] = []
  for (const k of kernels) {
    // Kernels with no WGSL (e.g. pure bookkeeping ops) keep a null slot so
    // pipeline/bind-group arrays stay index-aligned with `kernels`.
    if (!k.wgsl) { pipelines.push(null); continue }
    let module = moduleCache.get(k.wgsl)
    if (!module) {
      module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
      moduleCache.set(k.wgsl, module)
    }
    device.pushErrorScope('validation')
    const pipeline = device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'main' },
      label: k.opKind,
    })
    pipelines.push(pipeline)
    probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
  }
  const probeResults = await Promise.all(probes)
  const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
  if (failures.length > 0) {
    const reports: string[] = []
    for (const { k, module, err } of failures) {
      const info = await module.getCompilationInfo()
      const messages = info.messages
        .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
        .join('\n')
      reports.push(
        `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
        (messages || ' (no compilation messages)') +
        `\n--- WGSL ---\n${k.wgsl}\n-----------`,
      )
    }
    // eslint-disable-next-line no-console
    console.error(reports.join('\n\n'))
    throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
  }

  // ---- Pre-build bind groups (static buffer ids don't change per step) ---
  const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
    const pipeline = pipelines[i]
    if (!pipeline) return null
    return device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: k.bindings.map((bufId, idx) => ({
        binding: idx,
        resource: { buffer: buffers.get(bufId)! },
      })),
    })
  })

  // ---- Output readback staging buffer ---------------------------------------
  // `outputBufferId` is the graph's main output (loss for training, the user's
  // returned tensor for forward-only). step() reads back its first element;
  // run() reads back the full Float32Array.
  // NOTE(review): plan.buffers is indexed by buffer id here (array position ==
  // spec.id), matching the `plan.buffers[bufId]!` pattern used below.
  const outputSpec = plan.buffers[lossBufferId]!
  const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })

  // ---- Capture readback staging buffers (lazy) ------------------------------
  // Allocated on first `step({ withCaptures: true })` call and reused across
  // subsequent calls. When the graph has no captures registered or when the
  // caller never opts in, no extra GPU memory is allocated.
  let captureStagings: Map<string, GPUBuffer> | null = null
  function ensureCaptureStagings(): Map<string, GPUBuffer> {
    if (captureStagings) return captureStagings
    captureStagings = new Map()
    for (const [name, bufId] of plan.capturesByName) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
      captureStagings.set(name, staging)
    }
    return captureStagings
  }

  // ---- dispatch() shared core for step() and run() -----------------------
  // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
  // the output buffer into its staging, optionally copies captures into theirs,
  // submits, and reads back. Returns the full output Float32Array; step() takes
  // [0] for scalar loss, run() returns it whole.
  //
  // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
  // runtime would otherwise both try to `mapAsync` the shared output staging
  // buffer at the same time and trip "Buffer already has an outstanding map
  // pending." We chain each new dispatch onto the prior one's promise so they
  // run sequentially even when fired from independent async paths (e.g., a
  // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
  let pending: Promise<unknown> = Promise.resolve()
  async function dispatch(
    inputs: Record<string, Int32Array | Float32Array>,
    wantCaptures: boolean,
  ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
    // The .catch(() => {}) keeps a prior failed dispatch from poisoning the
    // chain: each new call still runs, after the prior one settles either way.
    const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
    pending = turn
    return turn
  }
  async function dispatchUnsynchronized(
    inputs: Record<string, Int32Array | Float32Array>,
    wantCaptures: boolean,
  ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
    if (wantCaptures && plan.capturesByName.size === 0) {
      throw new Error(
        `withCaptures=true but no capture(...) calls were registered during ` +
        `the trace. Add capture('name', tensor) inside your forward pass for ` +
        `the intermediates you want read back.`,
      )
    }
    for (const [name, bufId] of plan.inputsByName) {
      const data = inputs[name]
      if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
      const expectedBytes = plan.buffers[bufId]!.byteSize
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
      }
      // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
      // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }

    const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
    for (let i = 0; i < kernels.length; i++) {
      const k = kernels[i]!
      if (!k.wgsl || k.threads === 0) continue
      const pipeline = pipelines[i]!
      const bindGroup = bindGroups[i]!
      const pass = encoder.beginComputePass({ label: k.opKind })
      pass.setPipeline(pipeline)
      pass.setBindGroup(0, bindGroup)
      // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
      // when a kernel needs more than that on the X axis. Kernels compute their
      // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
      // stride we set here. For dispatches that fit in one row, gid.y is 0.
      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
      const MAX_X = 65535
      const wgX = Math.min(wgCount, MAX_X)
      const wgY = Math.ceil(wgCount / MAX_X)
      pass.dispatchWorkgroups(wgX, wgY, 1)
      pass.end()
    }
    // After all dispatches: writebacks (Adam state, updated params). Empty for
    // forward-only compiles.
    for (const wb of plan.writebacks) {
      encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
    }
    encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
    // Capture readbacks (only when opted in). Queued before submit so they
    // observe the same kernel outputs as the main output.
    let stagings: Map<string, GPUBuffer> | null = null
    if (wantCaptures) {
      stagings = ensureCaptureStagings()
      for (const [name, bufId] of plan.capturesByName) {
        const spec = plan.buffers[bufId]!
        encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
      }
    }
    queue.submit([encoder.finish()])

    // slice(0) copies out of the mapped range before unmap invalidates it.
    await outputReadback.mapAsync(GPUMapMode.READ)
    const output = new Float32Array(outputReadback.getMappedRange().slice(0))
    outputReadback.unmap()

    const captures = new Map<string, Float32Array>()
    if (wantCaptures) {
      for (const [name, staging] of stagings!) {
        await staging.mapAsync(GPUMapMode.READ)
        captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
        staging.unmap()
      }
    }
    return { output, captures }
  }

  // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
  function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
  function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
  async function step(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: RunOptions,
  ): Promise<number | StepResult> {
    const r = await dispatch(inputs, opts?.withCaptures === true)
    if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
    return r.output[0]!
  }

  // ---- run() — forward-mode wrapper, returns Float32Array by default -------
  // Same overloaded shape as step(): scalar-shaped result (here Float32Array,
  // there a JS number) is the default; { ..., captures } is the opt-in form.
  function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
  async function run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: RunOptions,
  ): Promise<Float32Array | RunResult> {
    const r = await dispatch(inputs, opts?.withCaptures === true)
    if (opts?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) }
    return r.output
  }

  // ---- uploadParams ---------------------------------------------------------
  // Validates names/sizes first, then writes each provided param buffer.
  // Unknown names always throw; missing names throw unless { partial: true }.
  function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
    const partial = opts?.partial ?? false
    for (const name of Object.keys(params)) {
      if (!plan.paramsByName.has(name)) {
        throw new Error(
          `uploadParams: unknown param '${name}'. ` +
          `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
        )
      }
    }
    if (!partial) {
      for (const name of plan.paramsByName.keys()) {
        if (!(name in params)) {
          throw new Error(
            `uploadParams: missing param '${name}'. ` +
            `Pass { partial: true } if you mean to update only some params.`,
          )
        }
      }
    }
    for (const [name, bufId] of plan.paramsByName) {
      const data = params[name]
      if (!data) continue
      const expected = plan.buffers[bufId]!.byteSize / 4
      if (data.length !== expected) {
        throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
      }
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }
  }

  // ---- download helpers -----------------------------------------------------
  // Copy each named buffer into a fresh staging buffer, submit once, then map
  // and read back each staging; stagings are destroyed after use.
  async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
    const stagings: { name: string; buf: GPUBuffer; bytes: number }[] = []
    const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
    for (const [name, bufId] of map) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
      encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
      stagings.push({ name, buf: staging, bytes: spec.byteSize })
    }
    queue.submit([encoder.finish()])
    const out: Record<string, Float32Array> = {}
    for (const s of stagings) {
      await s.buf.mapAsync(GPUMapMode.READ)
      out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
      s.buf.unmap()
      s.buf.destroy()
    }
    return out
  }

  // Fill a state buffer with its declared initValue (typically 0). Float and
  // int both serialize to 4 bytes per element (bool takes the i32 path). Used
  // at allocation time and on resetOptimizerState() — same logic, two callers.
  function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
    const elements = spec.byteSize / 4
    const init = spec.dtype === 'f32'
      ? new Float32Array(elements).fill(spec.initValue ?? 0)
      : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
    queue.writeBuffer(target, 0, init as unknown as BufferSource)
  }

  // Re-zero (or re-init) every 'state' buffer, e.g. Adam moments/step count.
  function resetOptimizerState() {
    for (const spec of plan.buffers) {
      if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
    }
  }

  // Build the params map AFTER buffer allocation so it points at the actual
  // GPUBuffers (shared or freshly allocated).
  const params = new Map<string, GPUBuffer>()
  for (const [name, bufId] of plan.paramsByName) {
    params.set(name, buffers.get(bufId)!)
  }
  // Static-after-compile shape metadata so users don't have to recompute
  // strides to interpret a flat capture readback.
  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId]!.shape]
  }
  const outputShape = [...plan.buffers[lossBufferId]!.shape]

  // Release all GPU memory this runtime owns. Shared (externally-owned) param
  // buffers are intentionally left alive for their owning compile.
  const destroy = () => {
    for (const [id, b] of buffers) {
      if (ownedBufferIds.has(id)) b.destroy()
    }
    outputReadback.destroy()
    if (captureStagings) for (const b of captureStagings.values()) b.destroy()
  }

  return {
    device,
    params,
    outputShape,
    uploadParams,
    downloadParams: () => downloadFromMap(plan.paramsByName),
    downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
    step,
    run,
    resetOptimizerState,
    destroy,
  }
}
502
+
503
+ /** Same machinery as `createRuntime`, narrower public type: a forward-only
504
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
505
+ * loss readback). The full runtime object is built once and projected by
506
+ * `compileForward` to the public shape. */
507
+ export async function createForwardRuntime(
508
+ plan: BufferPlan,
509
+ kernels: KernelSpec[],
510
+ outputBufferId: number,
511
+ opts: RuntimeOpts = {},
512
+ ): Promise<CompiledForward> {
513
+ return await createRuntime(plan, kernels, outputBufferId, opts)
514
+ }
515
+
516
+ async function acquireDevice(): Promise<GPUDevice> {
517
+ if (typeof navigator === 'undefined' || !navigator.gpu) {
518
+ throw new Error('tensorgrad: WebGPU not available in this environment')
519
+ }
520
+ const adapter = await navigator.gpu.requestAdapter()
521
+ if (!adapter) throw new Error('tensorgrad: no WebGPU adapter')
522
+ return await adapter.requestDevice()
523
+ }