tensorgrad 0.0.11 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +119 -119
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +9 -11
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +67 -61
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +45 -91
- package/src/grad.ts +1 -11
- package/src/index.ts +47 -47
- package/src/runtime.ts +520 -515
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -108
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/src/runtime.ts
CHANGED
|
@@ -1,515 +1,520 @@
|
|
|
1
|
-
// WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
|
|
2
|
-
// allocates real GPU buffers and pipelines, and provides a `step()` method
|
|
3
|
-
// that uploads inputs, dispatches all kernels, and reads back outputs.
|
|
4
|
-
//
|
|
5
|
-
// Browser-only: this module needs `navigator.gpu` at runtime.
|
|
6
|
-
|
|
7
|
-
import type { BufferPlan } from './buffers.js'
|
|
8
|
-
import type { KernelSpec } from './codegen.js'
|
|
9
|
-
|
|
10
|
-
// TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
|
|
11
|
-
// Provided by the browser per WebGPU spec; declare just what we use.
|
|
12
|
-
declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
|
|
13
|
-
|
|
14
|
-
/** Options accepted by `uploadParams()` on a compiled runtime. */
export interface UploadParamsOptions {
  /** Skip the "missing param" check, allowing the caller to update only some
   * params and leave the rest at their current GPU values. Extra (unknown)
   * keys are still rejected — that's always a typo. Default: false. */
  partial?: boolean
}
|
|
20
|
-
|
|
21
|
-
/**
 * Activation readbacks for one `step()`/`run()` call. Keyed by the names
 * passed to `capture(name, t)` during the trace. `get(name)` throws if the
 * name isn't registered or wasn't read back this call (i.e., the call was
 * made without `{ withCaptures: true }`); use `has(name)` if you need to
 * branch. `shapeOf(name)` returns the static-after-compile shape and works
 * regardless of whether captures were read back.
 */
export class Captures {
  /**
   * @param shapes Capture name -> static shape, fixed at compile time.
   * @param data   Capture name -> values read back during this call.
   */
  constructor(
    private readonly shapes: Record<string, readonly number[]>,
    private readonly data: Map<string, Float32Array>,
  ) {}

  /** Values read back for `name` this call. Throws if not present. */
  get(name: string): Float32Array {
    const d = this.data.get(name)
    if (!d) {
      const known = [...this.data.keys()].sort().join(', ')
      const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
      throw new Error(`Captures.get: '${name}' not present. ${detail}`)
    }
    return d
  }

  /** Static shape registered for `name`. Throws if `name` is not a registered capture. */
  shapeOf(name: string): readonly number[] {
    // Own-property check: a bare `this.shapes[name]` would also resolve
    // inherited Object.prototype members ('toString', 'constructor', ...) and
    // return them as if they were shapes.
    const s = Object.prototype.hasOwnProperty.call(this.shapes, name) ? this.shapes[name] : undefined
    if (!s) {
      const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
      throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
    }
    return s
  }

  /** True when `name` was read back this call. */
  has(name: string): boolean { return this.data.has(name) }

  /** Sorted names of the captures read back this call. */
  names(): string[] { return [...this.data.keys()].sort() }
}
|
|
54
|
-
|
|
55
|
-
/** Result shape returned by `run()` when captures are requested. */
export interface RunResult {
  /** The graph's output tensor, read back as a flat array. */
  output: Float32Array
  /** Captures read back during the same call. */
  captures: Captures
}

/** Result shape returned by `step()` when captures are requested. */
export interface StepResult {
  /** The scalar loss. */
  loss: number
  /** Captures read back during the same call. */
  captures: Captures
}

/** Per-call options for `step()` / `run()`. */
export interface RunOptions {
  /** Read back tensors registered via `capture(name, t)` during the trace.
   * Default false. When false, the returned `captures` is empty (calling
   * `.get` throws); when true, captures are read back and accessible. */
  withCaptures?: boolean
}

/** Common surface for both training and forward-only compiled runtimes. */
export interface CompiledBase {
  /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
   * share the device, or use directly for other GPU work. */
  device: GPUDevice
  /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
   * `sharedParams` to share without copies. */
  params: Map<string, GPUBuffer>
  /** Shape of the graph's output (loss scalar `[]` for training; the user's
   * returned tensor for forward-only compiles). */
  outputShape: number[]
  /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
   * *all* params to be present; throws on any unknown or missing key. Pass
   * `{ partial: true }` to skip the missing-key check. */
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
  /** Read all parameters back as Float32Arrays — used for UI panels. */
  downloadParams(): Promise<Record<string, Float32Array>>
  /** Free GPU resources. */
  destroy(): void
}
|
|
92
|
-
|
|
93
|
-
/** Run a dispatch and read back the full output tensor.
|
|
94
|
-
*
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
/**
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
*
|
|
107
|
-
*
|
|
108
|
-
*
|
|
109
|
-
*
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
step(inputs: Record<string, Int32Array | Float32Array
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
run
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
/**
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
*
|
|
136
|
-
*
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
//
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
//
|
|
158
|
-
//
|
|
159
|
-
//
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
}
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
//
|
|
191
|
-
//
|
|
192
|
-
//
|
|
193
|
-
|
|
194
|
-
const
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
const
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
device.
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
const
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
const
|
|
218
|
-
|
|
219
|
-
const
|
|
220
|
-
|
|
221
|
-
.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
//
|
|
248
|
-
//
|
|
249
|
-
//
|
|
250
|
-
|
|
251
|
-
const
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
//
|
|
255
|
-
//
|
|
256
|
-
//
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
const
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
//
|
|
271
|
-
//
|
|
272
|
-
//
|
|
273
|
-
//
|
|
274
|
-
//
|
|
275
|
-
//
|
|
276
|
-
//
|
|
277
|
-
//
|
|
278
|
-
//
|
|
279
|
-
//
|
|
280
|
-
//
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
`
|
|
298
|
-
`the
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
//
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
const
|
|
319
|
-
const
|
|
320
|
-
|
|
321
|
-
pass.
|
|
322
|
-
|
|
323
|
-
//
|
|
324
|
-
//
|
|
325
|
-
//
|
|
326
|
-
|
|
327
|
-
const
|
|
328
|
-
const
|
|
329
|
-
const
|
|
330
|
-
|
|
331
|
-
pass.
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
//
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
//
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
}
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
function step(inputs: Record<string, Int32Array | Float32Array
|
|
369
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
370
|
-
function step(
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
//
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
): Promise<RunResult>
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
//
|
|
466
|
-
|
|
467
|
-
const
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
const
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
1
|
+
// WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
|
|
2
|
+
// allocates real GPU buffers and pipelines, and provides a `step()` method
|
|
3
|
+
// that uploads inputs, dispatches all kernels, and reads back outputs.
|
|
4
|
+
//
|
|
5
|
+
// Browser-only: this module needs `navigator.gpu` at runtime.
|
|
6
|
+
|
|
7
|
+
import type { BufferPlan } from './buffers.js'
|
|
8
|
+
import type { KernelSpec } from './codegen.js'
|
|
9
|
+
|
|
10
|
+
// TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
|
|
11
|
+
// Provided by the browser per WebGPU spec; declare just what we use.
|
|
12
|
+
declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
|
|
13
|
+
|
|
14
|
+
/** Options accepted by `CompiledBase.uploadParams()`. */
export interface UploadParamsOptions {
  /** Skip the "missing param" check, allowing the caller to update only some
   * params and leave the rest at their current GPU values. Extra (unknown)
   * keys are still rejected — that's always a typo. Default: false. */
  partial?: boolean
}
|
|
20
|
+
|
|
21
|
+
/**
 * Activation readbacks for one `step()`/`run()` call. Keyed by the names
 * passed to `capture(name, t)` during the trace. `get(name)` throws if the
 * name isn't registered or wasn't read back this call (i.e., the call was
 * made without `{ withCaptures: true }`); use `has(name)` if you need to
 * branch. `shapeOf(name)` returns the static-after-compile shape and works
 * regardless of whether captures were read back.
 */
export class Captures {
  /**
   * @param shapes Capture name -> static shape, fixed at compile time.
   * @param data   Capture name -> values read back during this call (the
   *               runtime passes an empty map when captures weren't requested).
   */
  constructor(
    private readonly shapes: Record<string, readonly number[]>,
    private readonly data: Map<string, Float32Array>,
  ) {}

  /** Values read back for `name` this call. Throws if not present. */
  get(name: string): Float32Array {
    const d = this.data.get(name)
    if (!d) {
      const known = [...this.data.keys()].sort().join(', ')
      const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
      throw new Error(`Captures.get: '${name}' not present. ${detail}`)
    }
    return d
  }

  /** Static shape registered for `name`. Throws if `name` is not a registered capture. */
  shapeOf(name: string): readonly number[] {
    // Own-property check: a bare `this.shapes[name]` would also resolve
    // inherited Object.prototype members ('toString', 'constructor', ...) and
    // return them as if they were shapes.
    const s = Object.prototype.hasOwnProperty.call(this.shapes, name) ? this.shapes[name] : undefined
    if (!s) {
      const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
      throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
    }
    return s
  }

  /** True when `name` was read back this call. */
  has(name: string): boolean { return this.data.has(name) }

  /** Sorted names of the captures read back this call. */
  names(): string[] { return [...this.data.keys()].sort() }
}
|
|
54
|
+
|
|
55
|
+
/** Resolved value of `run(inputs, { withCaptures: true })`. */
export interface RunResult {
  /** Full contents of the graph's output buffer, read back as a flat array. */
  output: Float32Array
  /** Captures read back during the same call. */
  captures: Captures
}

/** Resolved value of `step(inputs, { withCaptures: true })`. */
export interface StepResult {
  /** Element [0] of the graph's output buffer — the scalar loss. */
  loss: number
  /** Captures read back during the same call. */
  captures: Captures
}

/** Per-call options for `step()` / `run()`. */
export interface RunOptions {
  /** Read back tensors registered via `capture(name, t)` during the trace.
   * Default false. When false, the returned `captures` is empty (calling
   * `.get` throws); when true, captures are read back and accessible.
   * Requesting captures when none were registered during the trace makes
   * the call throw. */
  withCaptures?: boolean
}

/** Common surface for both training and forward-only compiled runtimes. */
export interface CompiledBase {
  /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
   * share the device, or use directly for other GPU work. */
  device: GPUDevice
  /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
   * `sharedParams` to share without copies. */
  params: Map<string, GPUBuffer>
  /** Shape of the graph's output (loss scalar `[]` for training; the user's
   * returned tensor for forward-only compiles). */
  outputShape: number[]
  /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
   * *all* params to be present; throws on any unknown or missing key. Pass
   * `{ partial: true }` to skip the missing-key check. */
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
  /** Read all parameters back as Float32Arrays — used for UI panels. */
  downloadParams(): Promise<Record<string, Float32Array>>
  /** Free GPU resources.
   * NOTE(review): presumably only destroys buffers this runtime allocated
   * itself (see `ownedBufferIds` in createRuntime), leaving `sharedParams`
   * buffers to their owning compile — confirm against the implementation. */
  destroy(): void
}

/** Run a dispatch and read back the full output tensor. Default returns the
 * output as a `Float32Array`; with `{ withCaptures: true }` returns
 * `{ output, captures }`. Same shape as `step()`'s overloads. */
export interface RunFn {
  (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  // Fallback for a non-literal options object: which variant resolves depends
  // on `opts.withCaptures` at runtime.
  (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
}

/** Training-mode compiled runtime: forward + backward + optimizer dispatch. */
export interface CompiledRuntime extends CompiledBase {
  /** Read all parameter gradients back. Mostly for verification / debugging. */
  downloadParamGrads(): Promise<Record<string, Float32Array>>
  /**
   * One full forward+backward step.
   * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
   * 2. Dispatches every kernel in order.
   * 3. Reads back the loss scalar (and any registered captures, if requested).
   * Default returns the loss as a JS number; with `{ withCaptures: true }`
   * returns `{ loss, captures }`.
   * Overlapping calls on the same runtime are queued and run one at a time.
   */
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
  /** Same dispatch as step() but returns the full output Float32Array — for
   * training graphs the output is a scalar loss, so step() is usually more
   * convenient. Provided for parity with `compileForward`. */
  run: RunFn
  /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
   * `uploadInitialParams()` for a full training reset without recompile.
   * NOTE(review): `uploadInitialParams()` is not declared on this interface;
   * presumably it is added by the compile-level wrapper — confirm. */
  resetOptimizerState(): void
}

/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
 * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
export interface CompiledForward extends CompiledBase {
  run: RunFn
}

/** Options for `createRuntime`. */
export interface RuntimeOpts {
  /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
  device?: GPUDevice
  /** External param buffers to bind in place of allocating fresh ones, keyed
   * by param name. Used to share params between a training compile and a
   * sibling forward-only compile (e.g., a B=1 inference graph). When a name
   * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
   * allocates as usual. A supplied buffer whose size differs from the
   * compiled param's byte size makes createRuntime throw. */
  sharedParams?: Map<string, GPUBuffer>
}
|
|
141
|
+
|
|
142
|
+
// WebGPU buffer-usage flags, inlined as their spec values rather than read
// from the browser's `GPUBufferUsage` global. That global is absent outside
// the browser, and touching it at module scope would crash on import — so
// inlining keeps this module importable in Node for codegen-only usage.
//   MAP_READ = 0x0001, COPY_SRC = 0x0004, COPY_DST = 0x0008, STORAGE = 0x0080
const STORAGE_RW = 0x0004 /* COPY_SRC */ | 0x0008 /* COPY_DST */ | 0x0080 /* STORAGE */
const READBACK = 0x0008 /* COPY_DST */ | 0x0001 /* MAP_READ */
|
|
147
|
+
|
|
148
|
+
export async function createRuntime(
|
|
149
|
+
plan: BufferPlan,
|
|
150
|
+
kernels: KernelSpec[],
|
|
151
|
+
lossBufferId: number,
|
|
152
|
+
opts: RuntimeOpts = {},
|
|
153
|
+
): Promise<CompiledRuntime> {
|
|
154
|
+
const device = opts.device ?? await acquireDevice()
|
|
155
|
+
const queue = device.queue
|
|
156
|
+
|
|
157
|
+
// ---- Allocate one GPUBuffer per BufferSpec --------------------------------
|
|
158
|
+
// State buffers also get filled with their initValue at allocation time.
|
|
159
|
+
// Param buffers may be supplied externally via opts.sharedParams; in that
|
|
160
|
+
// case we reuse the provided GPUBuffer instead of allocating, and the
|
|
161
|
+
// sibling compile that owns it is responsible for upload + lifetime.
|
|
162
|
+
// ownedBufferIds tracks which buffers we allocated ourselves (and so must
|
|
163
|
+
// destroy on .destroy()) vs which were handed in by a sibling compile.
|
|
164
|
+
const buffers = new Map<number, GPUBuffer>()
|
|
165
|
+
const ownedBufferIds = new Set<number>()
|
|
166
|
+
const sharedParams = opts.sharedParams
|
|
167
|
+
for (const spec of plan.buffers) {
|
|
168
|
+
const shared = spec.kind === 'param' ? sharedParams?.get(spec.name!) : undefined
|
|
169
|
+
if (shared) {
|
|
170
|
+
if (shared.size !== spec.byteSize) {
|
|
171
|
+
throw new Error(
|
|
172
|
+
`sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
|
|
173
|
+
`compiled graph expects ${spec.byteSize}.`,
|
|
174
|
+
)
|
|
175
|
+
}
|
|
176
|
+
buffers.set(spec.id, shared)
|
|
177
|
+
continue
|
|
178
|
+
}
|
|
179
|
+
const buf = device.createBuffer({
|
|
180
|
+
size: spec.byteSize,
|
|
181
|
+
usage: STORAGE_RW,
|
|
182
|
+
label: spec.name ?? `t${spec.id}-${spec.kind}`,
|
|
183
|
+
})
|
|
184
|
+
buffers.set(spec.id, buf)
|
|
185
|
+
ownedBufferIds.add(spec.id)
|
|
186
|
+
if (spec.kind === 'state') fillStateBuffer(spec, buf)
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// ---- Compile pipelines per kernel; cache by WGSL source -------------------
|
|
190
|
+
// Push an error scope around each shader+pipeline creation so we can surface
|
|
191
|
+
// the actual compile error rather than the cryptic "previous error" that
|
|
192
|
+
// comes from using an invalid pipeline at dispatch time.
|
|
193
|
+
const moduleCache = new Map<string, GPUShaderModule>()
|
|
194
|
+
const pipelines: (GPUComputePipeline | null)[] = []
|
|
195
|
+
type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
|
|
196
|
+
const probes: ErrorProbe[] = []
|
|
197
|
+
for (const k of kernels) {
|
|
198
|
+
if (!k.wgsl) { pipelines.push(null); continue }
|
|
199
|
+
let module = moduleCache.get(k.wgsl)
|
|
200
|
+
if (!module) {
|
|
201
|
+
module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
|
|
202
|
+
moduleCache.set(k.wgsl, module)
|
|
203
|
+
}
|
|
204
|
+
device.pushErrorScope('validation')
|
|
205
|
+
const pipeline = device.createComputePipeline({
|
|
206
|
+
layout: 'auto',
|
|
207
|
+
compute: { module, entryPoint: 'main' },
|
|
208
|
+
label: k.opKind,
|
|
209
|
+
})
|
|
210
|
+
pipelines.push(pipeline)
|
|
211
|
+
probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
|
|
212
|
+
}
|
|
213
|
+
const probeResults = await Promise.all(probes)
|
|
214
|
+
const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
|
|
215
|
+
if (failures.length > 0) {
|
|
216
|
+
const reports: string[] = []
|
|
217
|
+
for (const { k, module, err } of failures) {
|
|
218
|
+
const info = await module.getCompilationInfo()
|
|
219
|
+
const messages = info.messages
|
|
220
|
+
.map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
|
|
221
|
+
.join('\n')
|
|
222
|
+
reports.push(
|
|
223
|
+
`[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
|
|
224
|
+
(messages || ' (no compilation messages)') +
|
|
225
|
+
`\n--- WGSL ---\n${k.wgsl}\n-----------`,
|
|
226
|
+
)
|
|
227
|
+
}
|
|
228
|
+
// eslint-disable-next-line no-console
|
|
229
|
+
console.error(reports.join('\n\n'))
|
|
230
|
+
throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
// ---- Pre-build bind groups (static — buffer ids don't change per step) ---
|
|
234
|
+
const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
|
|
235
|
+
const pipeline = pipelines[i]
|
|
236
|
+
if (!pipeline) return null
|
|
237
|
+
return device.createBindGroup({
|
|
238
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
239
|
+
entries: k.bindings.map((bufId, idx) => ({
|
|
240
|
+
binding: idx,
|
|
241
|
+
resource: { buffer: buffers.get(bufId)! },
|
|
242
|
+
})),
|
|
243
|
+
})
|
|
244
|
+
})
|
|
245
|
+
|
|
246
|
+
// ---- Output readback staging buffer ---------------------------------------
|
|
247
|
+
// `outputBufferId` is the graph's main output (loss for training, the user's
|
|
248
|
+
// returned tensor for forward-only). step() reads back its first element;
|
|
249
|
+
// run() reads back the full Float32Array.
|
|
250
|
+
const outputSpec = plan.buffers[lossBufferId]!
|
|
251
|
+
const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })
|
|
252
|
+
|
|
253
|
+
// ---- Capture readback staging buffers (lazy) ------------------------------
// One READBACK (MAP_READ | COPY_DST) staging buffer per registered capture,
// created the first time a call asks for captures and cached for every later
// call. Graphs without captures, or callers that never opt in, pay no extra
// GPU memory.
let captureStagings: Map<string, GPUBuffer> | null = null
function ensureCaptureStagings(): Map<string, GPUBuffer> {
  if (captureStagings !== null) return captureStagings
  const created = new Map<string, GPUBuffer>()
  captureStagings = created
  for (const [captureName, sourceId] of plan.capturesByName) {
    const { byteSize } = plan.buffers[sourceId]!
    created.set(captureName, device.createBuffer({ size: byteSize, usage: READBACK, label: `cap-${captureName}` }))
  }
  return created
}
|
|
268
|
+
|
|
269
|
+
// ---- dispatch() — shared core for step() and run() -----------------------
// dispatchUnsynchronized() does the actual work: upload inputs, encode every
// kernel in order, queue writebacks, copy the output (and, when requested,
// the captures) into staging buffers, submit, and read back. It returns the
// full output Float32Array; step() keeps element [0] as the scalar loss and
// run() returns the whole array.
//
// **Concurrent calls auto-serialize.** Without this, two overlapping
// `step()`/`run()` calls on one runtime would both `mapAsync` the shared
// output staging buffer and fail with "Buffer already has an outstanding map
// pending." Each call is chained behind the previous call's promise, so
// dispatches run one at a time even when fired from independent async paths
// (e.g., a training loop's auxiliary `refreshPrediction()` +
// `writeDiagnostic()`). A rejected predecessor does not block its successors.
let pending: Promise<unknown> = Promise.resolve()
async function dispatch(
  inputs: Record<string, Int32Array | Float32Array>,
  wantCaptures: boolean,
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
  // The predecessor's failure is swallowed only for sequencing; its own caller
  // still observes the rejection through the promise it was handed.
  const afterPrevious = pending.catch(() => undefined)
  const myTurn = afterPrevious.then(() => dispatchUnsynchronized(inputs, wantCaptures))
  pending = myTurn
  return myTurn
}
|
|
290
|
+
/**
 * One unserialized dispatch: validate + upload inputs, encode every kernel in
 * order, queue writebacks, stage the output (and captures, if requested),
 * submit, and read everything back. Callers go through `dispatch()`, which
 * serializes calls so the shared staging buffers are never mapped twice.
 *
 * @param inputs Input name -> host data; every name in `plan.inputsByName`
 *   must be present with exactly the planned byte size.
 * @param wantCaptures Also copy and read back every registered capture.
 * @returns The full output buffer contents and the captures map (empty when
 *   `wantCaptures` is false).
 * @throws If captures are requested but none were registered, if an input is
 *   missing or mis-sized, or if mapping any staging buffer fails.
 */
async function dispatchUnsynchronized(
  inputs: Record<string, Int32Array | Float32Array>,
  wantCaptures: boolean,
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
  if (wantCaptures && plan.capturesByName.size === 0) {
    throw new Error(
      `withCaptures=true but no capture(...) calls were registered during ` +
      `the trace. Add capture('name', tensor) inside your forward pass for ` +
      `the intermediates you want read back.`,
    )
  }
  for (const [name, bufId] of plan.inputsByName) {
    const data = inputs[name]
    if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
    const expectedBytes = plan.buffers[bufId]!.byteSize
    if (data.byteLength !== expectedBytes) {
      throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
    }
    // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
    // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
    queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
  }

  const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
  for (let i = 0; i < kernels.length; i++) {
    const k = kernels[i]!
    if (!k.wgsl || k.threads === 0) continue
    const pipeline = pipelines[i]!
    const bindGroup = bindGroups[i]!
    const pass = encoder.beginComputePass({ label: k.opKind })
    pass.setPipeline(pipeline)
    pass.setBindGroup(0, bindGroup)
    // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
    // when a kernel needs more than that on the X axis. Kernels compute their
    // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
    // stride we set here. For dispatches that fit in one row, gid.y is 0.
    const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
    const MAX_X = 65535
    const wgX = Math.min(wgCount, MAX_X)
    const wgY = Math.ceil(wgCount / MAX_X)
    pass.dispatchWorkgroups(wgX, wgY, 1)
    pass.end()
  }
  // After all dispatches: writebacks (Adam state, updated params). Empty for
  // forward-only compiles.
  for (const wb of plan.writebacks) {
    encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
  }
  encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
  // Capture readbacks (only when opted in). Queued before submit so they
  // observe the same kernel outputs as the main output.
  const stagings = wantCaptures ? ensureCaptureStagings() : null
  if (stagings) {
    for (const [name, bufId] of plan.capturesByName) {
      const spec = plan.buffers[bufId]!
      encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
    }
  }
  queue.submit([encoder.finish()])

  // Every staging buffer was filled by the same submission, so request all
  // maps at once instead of paying one mapAsync round-trip per capture.
  const captureTargets: [string, GPUBuffer][] = stagings ? [...stagings] : []
  const toMap = [outputReadback, ...captureTargets.map(([, staging]) => staging)]
  const settled = await Promise.allSettled(toMap.map(buf => buf.mapAsync(GPUMapMode.READ)))
  const failure = settled.find((s): s is PromiseRejectedResult => s.status === 'rejected')
  if (failure) {
    // Unmap whatever did map, so the next call can copy into these stagings
    // again instead of tripping a "buffer is mapped" validation error.
    settled.forEach((s, i) => { if (s.status === 'fulfilled') toMap[i]!.unmap() })
    throw failure.reason
  }

  const output = new Float32Array(outputReadback.getMappedRange().slice(0))
  outputReadback.unmap()
  const captures = new Map<string, Float32Array>()
  for (const [name, staging] of captureTargets) {
    captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
    staging.unmap()
  }
  return { output, captures }
}
|
|
365
|
+
|
|
366
|
+
// ---- step() — training-mode wrapper, returns scalar [0] of output ---------
// Runs one dispatch of the compiled graph. The default form resolves to the
// scalar output[0] (the loss); `{ withCaptures: true }` resolves to
// `{ loss, captures }` instead. The capture flag is normalized once so the
// readback decision handed to dispatch() and the result shape returned here
// always agree, even for untyped JS callers passing a truthy non-boolean.
function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
async function step(
  inputs: Record<string, Int32Array | Float32Array>,
  opts?: RunOptions,
): Promise<number | StepResult> {
  const withCaptures = Boolean(opts?.withCaptures)
  const r = await dispatch(inputs, withCaptures)
  if (withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
  return r.output[0]!
}
|
|
378
|
+
|
|
379
|
+
// ---- run() — forward-mode wrapper, returns Float32Array by default -------
// Same overloaded shape as step(): the plain result (here the full output
// Float32Array, there a JS number) is the default; `{ output, captures }` is
// the opt-in form. The capture flag is normalized once so dispatch() reads back
// captures exactly when the captures-shaped result is returned.
function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>
async function run(
  inputs: Record<string, Int32Array | Float32Array>,
  opts?: RunOptions,
): Promise<Float32Array | RunResult> {
  const withCaptures = Boolean(opts?.withCaptures)
  const r = await dispatch(inputs, withCaptures)
  if (withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) }
  return r.output
}
|
|
393
|
+
|
|
394
|
+
// ---- uploadParams ---------------------------------------------------------
/** Write host-side parameter values into the runtime's parameter buffers.
 *
 * By default every compiled param must be supplied. With `{ partial: true }`
 * only the supplied subset is written. Every name and element count is checked
 * before any `writeBuffer` is queued, so a call that throws leaves all
 * parameter buffers untouched.
 *
 * @param params name -> values; each array's length must equal the compiled
 *   buffer's byteSize / 4.
 * @param opts `partial` (default false) allows omitting params.
 * @throws Error on an unknown name, a missing name (when not partial), or an
 *   element-count mismatch. */
function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
  const partial = opts?.partial ?? false
  for (const name of Object.keys(params)) {
    if (!plan.paramsByName.has(name)) {
      throw new Error(
        `uploadParams: unknown param '${name}'. ` +
          `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
      )
    }
  }
  if (!partial) {
    for (const name of plan.paramsByName.keys()) {
      if (!(name in params)) {
        throw new Error(
          `uploadParams: missing param '${name}'. ` +
            `Pass { partial: true } if you mean to update only some params.`,
        )
      }
    }
  }
  // Validate every size before writing anything: a mismatch on a later param
  // must not leave earlier params already overwritten on the GPU.
  const writes: { bufId: number; data: Float32Array }[] = []
  for (const [name, bufId] of plan.paramsByName) {
    const data = params[name]
    if (!data) continue
    const expected = plan.buffers[bufId]!.byteSize / 4
    if (data.length !== expected) {
      throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
    }
    writes.push({ bufId, data })
  }
  for (const { bufId, data } of writes) {
    queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
  }
}
|
|
425
|
+
|
|
426
|
+
// ---- download helpers -----------------------------------------------------
/** Read back every buffer named in `map` (name -> plan buffer id) as a
 * Float32Array keyed by the same name.
 *
 * One command encoder copies each source buffer into a fresh staging buffer.
 * All staging buffers are then mapped concurrently. The staging buffers are
 * destroyed in `finally`, so they are released even if encoding, submission or
 * a `mapAsync` rejects. */
async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
  const stagings: { name: string; buf: GPUBuffer }[] = []
  try {
    const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
    for (const [name, bufId] of map) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
      // Track before encoding so the finally block can always destroy it.
      stagings.push({ name, buf: staging })
      encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
    }
    queue.submit([encoder.finish()])
    await Promise.all(stagings.map((s) => s.buf.mapAsync(GPUMapMode.READ)))
    const out: Record<string, Float32Array> = {}
    for (const s of stagings) {
      // slice(0) copies out of the mapped range before it is unmapped.
      out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
      s.buf.unmap()
    }
    return out
  } finally {
    for (const s of stagings) s.buf.destroy()
  }
}
|
|
446
|
+
|
|
447
|
+
// Overwrite `target` with `spec.initValue` (0 when unset) in every element.
// f32 buffers are filled as Float32Array. Every other dtype ('i32', 'bool') is
// filled as Int32Array with the value truncated toward zero. Both views use
// 4 bytes per element, so the element count is byteSize / 4. Shared by buffer
// allocation and resetOptimizerState().
function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
  const count = spec.byteSize / 4
  const value = spec.initValue ?? 0
  let payload: Float32Array | Int32Array
  if (spec.dtype === 'f32') {
    payload = new Float32Array(count).fill(value)
  } else {
    payload = new Int32Array(count).fill(Math.trunc(value))
  }
  queue.writeBuffer(target, 0, payload as unknown as BufferSource)
}
|
|
457
|
+
|
|
458
|
+
// Re-initialize every buffer the plan marks as kind 'state' back to its
// declared initValue. Other buffers are left as they are.
function resetOptimizerState() {
  for (const spec of plan.buffers) {
    if (spec.kind !== 'state') continue
    const target = buffers.get(spec.id)!
    fillStateBuffer(spec, target)
  }
}
|
|
463
|
+
|
|
464
|
+
// Public name -> GPUBuffer view of the parameters. This must be built after
// buffer allocation so each entry is the GPUBuffer actually in use (shared or
// freshly allocated).
const params = new Map<string, GPUBuffer>(
  [...plan.paramsByName].map(([name, bufId]): [string, GPUBuffer] => [name, buffers.get(bufId)!]),
)
// Shapes fixed at compile time, exposed so callers can interpret a flat capture
// readback without recomputing strides.
const captureShapes: Record<string, number[]> = {}
plan.capturesByName.forEach((bufId, name) => {
  captureShapes[name] = Array.from(plan.buffers[bufId]!.shape)
})
const outputShape = Array.from(plan.buffers[lossBufferId]!.shape)
|
|
477
|
+
|
|
478
|
+
// Release GPU resources held by this runtime. Only buffer ids in
// ownedBufferIds are destroyed; other entries in `buffers` are left alone.
// Also destroys the output readback buffer and any capture staging buffers
// that were created.
const destroy = () => {
  buffers.forEach((buf, id) => {
    if (ownedBufferIds.has(id)) buf.destroy()
  })
  outputReadback.destroy()
  captureStagings?.forEach((staging) => staging.destroy())
}
|
|
485
|
+
|
|
486
|
+
// Public surface of the runtime. downloadParams / downloadParamGrads read back
// the plan's param and param-gradient buffers by name.
return {
  device,
  params,
  outputShape,
  uploadParams,
  downloadParams() {
    return downloadFromMap(plan.paramsByName)
  },
  downloadParamGrads() {
    return downloadFromMap(plan.paramGradsByName)
  },
  step,
  run,
  resetOptimizerState,
  destroy,
}
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
/** Forward-only entry point built on `createRuntime`.
 *
 * The object returned is exactly what `createRuntime` builds. The narrower
 * `CompiledForward` return type is a compile-time projection only: a
 * forward-only graph is meant to be driven through `run()` rather than
 * `step()` (no optimizer state, no scalar-loss readback).
 *
 * @param outputBufferId plan buffer id whose contents `run()` reads back. */
export async function createForwardRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  outputBufferId: number,
  opts: RuntimeOpts = {},
): Promise<CompiledForward> {
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
  return runtime
}
|
|
512
|
+
|
|
513
|
+
// Obtain a GPUDevice from the default adapter.
// Throws when WebGPU is unavailable (no `navigator` or no `navigator.gpu`) or
// when no adapter is returned.
async function acquireDevice(): Promise<GPUDevice> {
  const gpu = typeof navigator !== 'undefined' ? navigator.gpu : undefined
  if (!gpu) throw new Error('tensorgrad: WebGPU not available in this environment')
  const adapter = await gpu.requestAdapter()
  if (adapter) return await adapter.requestDevice()
  throw new Error('tensorgrad: no WebGPU adapter')
}
|