tensorgrad 0.0.8 → 0.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/dist/compile.d.ts +77 -28
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +132 -81
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/nn.d.ts +14 -11
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +28 -33
- package/dist/nn.js.map +1 -1
- package/dist/runtime.d.ts +33 -32
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +59 -12
- package/dist/runtime.js.map +1 -1
- package/package.json +1 -1
- package/src/compile.ts +245 -114
- package/src/index.ts +7 -2
- package/src/nn.ts +34 -32
- package/src/runtime.ts +86 -52
package/src/runtime.ts
CHANGED
|
@@ -18,39 +18,65 @@ export interface UploadParamsOptions {
|
|
|
18
18
|
partial?: boolean
|
|
19
19
|
}
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
/**
|
|
22
|
+
* Activation readbacks for one `step()`/`run()` call. Keyed by the names
|
|
23
|
+
* passed to `capture(name, t)` during the trace. `get(name)` throws if the
|
|
24
|
+
* name isn't registered or wasn't read back this call (i.e., the call was
|
|
25
|
+
* made without `{ withCaptures: true }`); use `has(name)` if you need to
|
|
26
|
+
* branch. `shapeOf(name)` returns the static-after-compile shape and works
|
|
27
|
+
* regardless of whether captures were read back.
|
|
28
|
+
*/
|
|
29
|
+
export class Captures {
|
|
30
|
+
constructor(
|
|
31
|
+
private readonly shapes: Record<string, readonly number[]>,
|
|
32
|
+
private readonly data: Map<string, Float32Array>,
|
|
33
|
+
) {}
|
|
34
|
+
get(name: string): Float32Array {
|
|
35
|
+
const d = this.data.get(name)
|
|
36
|
+
if (!d) {
|
|
37
|
+
const known = [...this.data.keys()].sort().join(', ')
|
|
38
|
+
const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
|
|
39
|
+
throw new Error(`Captures.get: '${name}' not present. ${detail}`)
|
|
40
|
+
}
|
|
41
|
+
return d
|
|
42
|
+
}
|
|
43
|
+
shapeOf(name: string): readonly number[] {
|
|
44
|
+
const s = this.shapes[name]
|
|
45
|
+
if (!s) {
|
|
46
|
+
const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
|
|
47
|
+
throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
|
|
48
|
+
}
|
|
49
|
+
return s
|
|
50
|
+
}
|
|
51
|
+
has(name: string): boolean { return this.data.has(name) }
|
|
52
|
+
names(): string[] { return [...this.data.keys()].sort() }
|
|
27
53
|
}
|
|
28
54
|
|
|
29
|
-
export interface
|
|
55
|
+
export interface RunResult {
|
|
56
|
+
output: Float32Array
|
|
57
|
+
captures: Captures
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
export interface StepResult {
|
|
30
61
|
loss: number
|
|
31
|
-
captures:
|
|
62
|
+
captures: Captures
|
|
32
63
|
}
|
|
33
64
|
|
|
34
65
|
export interface RunOptions {
|
|
35
66
|
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
36
|
-
* When false
|
|
37
|
-
*
|
|
67
|
+
* Default false. When false, `run()` returns an empty `captures` (calling
|
|
68
|
+
* `.get` throws) and `step()` returns only the loss; when true, captures are read back and accessible. */
|
|
38
69
|
withCaptures?: boolean
|
|
39
70
|
}
|
|
40
71
|
|
|
41
|
-
export interface RunWithCaptures {
|
|
42
|
-
output: Float32Array
|
|
43
|
-
captures: Record<string, Float32Array>
|
|
44
|
-
}
|
|
45
|
-
|
|
46
72
|
/** Common surface for both training and forward-only compiled runtimes. */
|
|
47
73
|
export interface CompiledBase {
|
|
74
|
+
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
75
|
+
* share the device, or use directly for other GPU work. */
|
|
76
|
+
device: GPUDevice
|
|
48
77
|
/** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
|
|
49
78
|
* `sharedParams` to share without copies. */
|
|
50
79
|
params: Map<string, GPUBuffer>
|
|
51
|
-
/** Shape of each tensor registered via `capture(name, t)`. Static after
|
|
52
|
-
* compile — reshape readbacks without recomputing strides. */
|
|
53
|
-
captureShapes: Record<string, number[]>
|
|
54
80
|
/** Shape of the graph's output (loss scalar `[]` for training; the user's
|
|
55
81
|
* returned tensor for forward-only compiles). */
|
|
56
82
|
outputShape: number[]
|
|
@@ -64,15 +90,12 @@ export interface CompiledBase {
|
|
|
64
90
|
destroy(): void
|
|
65
91
|
}
|
|
66
92
|
|
|
67
|
-
/** Run a dispatch and read back the full output tensor
|
|
68
|
-
*
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
74
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
75
|
-
}
|
|
93
|
+
/** Run a dispatch and read back the full output tensor. Captures are always
|
|
94
|
+
* returned; their data is empty unless `{ withCaptures: true }` is passed. */
|
|
95
|
+
export type RunFn = (
|
|
96
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
97
|
+
opts?: RunOptions,
|
|
98
|
+
) => Promise<RunResult>
|
|
76
99
|
|
|
77
100
|
export interface CompiledRuntime extends CompiledBase {
|
|
78
101
|
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
@@ -83,12 +106,11 @@ export interface CompiledRuntime extends CompiledBase {
|
|
|
83
106
|
* 2. Dispatches every kernel in order.
|
|
84
107
|
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
85
108
|
* By default, returns the loss as a JS number; with `{ withCaptures: true }`
|
|
86
|
-
* returns `{ loss, captures }
|
|
87
|
-
* to `capture(...)` during the trace.
|
|
109
|
+
* returns `{ loss, captures }`.
|
|
88
110
|
*/
|
|
89
111
|
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
90
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<
|
|
91
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
112
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
113
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
|
|
92
114
|
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
93
115
|
* training graphs the output is a scalar loss, so step() is usually more
|
|
94
116
|
* convenient. Provided for parity with `compileForward`. */
|
|
@@ -250,10 +272,26 @@ export async function createRuntime(
|
|
|
250
272
|
// the output buffer into its staging, optionally copies captures into theirs,
|
|
251
273
|
// submits, and reads back. Returns the full output Float32Array plus any captures; step() takes
|
|
252
274
|
// [0] for scalar loss, run() returns the output whole.
|
|
275
|
+
//
|
|
276
|
+
// **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
|
|
277
|
+
// runtime would otherwise both try to `mapAsync` the shared output staging
|
|
278
|
+
// buffer at the same time and trip "Buffer already has an outstanding map
|
|
279
|
+
// pending." We chain each new dispatch onto the prior one's promise so they
|
|
280
|
+
// run sequentially even when fired from independent async paths (e.g., a
|
|
281
|
+
// training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
|
|
282
|
+
let pending: Promise<unknown> = Promise.resolve()
|
|
253
283
|
async function dispatch(
|
|
254
284
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
255
285
|
wantCaptures: boolean,
|
|
256
|
-
): Promise<{ output: Float32Array; captures:
|
|
286
|
+
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
|
|
287
|
+
const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
|
|
288
|
+
pending = turn
|
|
289
|
+
return turn
|
|
290
|
+
}
|
|
291
|
+
async function dispatchUnsynchronized(
|
|
292
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
293
|
+
wantCaptures: boolean,
|
|
294
|
+
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
|
|
257
295
|
if (wantCaptures && plan.capturesByName.size === 0) {
|
|
258
296
|
throw new Error(
|
|
259
297
|
`withCaptures=true but no capture(...) calls were registered during ` +
|
|
@@ -315,41 +353,37 @@ export async function createRuntime(
|
|
|
315
353
|
const output = new Float32Array(outputReadback.getMappedRange().slice(0))
|
|
316
354
|
outputReadback.unmap()
|
|
317
355
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
356
|
+
const captures = new Map<string, Float32Array>()
|
|
357
|
+
if (wantCaptures) {
|
|
358
|
+
for (const [name, staging] of stagings!) {
|
|
359
|
+
await staging.mapAsync(GPUMapMode.READ)
|
|
360
|
+
captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
|
|
361
|
+
staging.unmap()
|
|
362
|
+
}
|
|
325
363
|
}
|
|
326
364
|
return { output, captures }
|
|
327
365
|
}
|
|
328
366
|
|
|
329
367
|
// ---- step() — training-mode wrapper, returns scalar [0] of output ---------
|
|
330
368
|
function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
331
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<
|
|
332
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
369
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
370
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
|
|
333
371
|
async function step(
|
|
334
372
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
335
|
-
opts?:
|
|
336
|
-
): Promise<number |
|
|
373
|
+
opts?: RunOptions,
|
|
374
|
+
): Promise<number | StepResult> {
|
|
337
375
|
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
338
|
-
if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures
|
|
376
|
+
if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
|
|
339
377
|
return r.output[0]!
|
|
340
378
|
}
|
|
341
379
|
|
|
342
|
-
// ---- run() — forward-mode wrapper, returns
|
|
343
|
-
function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
|
|
344
|
-
function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
345
|
-
function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
380
|
+
// ---- run() — forward-mode wrapper, returns { output, captures } ----------
|
|
346
381
|
async function run(
|
|
347
382
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
348
383
|
opts?: RunOptions,
|
|
349
|
-
): Promise<
|
|
384
|
+
): Promise<RunResult> {
|
|
350
385
|
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
351
|
-
|
|
352
|
-
return r.output
|
|
386
|
+
return { output: r.output, captures: new Captures(captureShapes, r.captures) }
|
|
353
387
|
}
|
|
354
388
|
|
|
355
389
|
// ---- uploadParams ---------------------------------------------------------
|
|
@@ -445,8 +479,8 @@ export async function createRuntime(
|
|
|
445
479
|
}
|
|
446
480
|
|
|
447
481
|
return {
|
|
482
|
+
device,
|
|
448
483
|
params,
|
|
449
|
-
captureShapes,
|
|
450
484
|
outputShape,
|
|
451
485
|
uploadParams,
|
|
452
486
|
downloadParams: () => downloadFromMap(plan.paramsByName),
|