tensorgrad 0.0.9 → 0.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/dist/compile.d.ts +77 -28
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +132 -81
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/nn.d.ts +14 -11
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +28 -33
- package/dist/nn.js.map +1 -1
- package/dist/runtime.d.ts +33 -32
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +46 -12
- package/dist/runtime.js.map +1 -1
- package/package.json +1 -1
- package/src/compile.ts +245 -114
- package/src/index.ts +7 -2
- package/src/nn.ts +34 -32
- package/src/runtime.ts +71 -53
package/src/runtime.ts
CHANGED
|
@@ -18,39 +18,65 @@ export interface UploadParamsOptions {
|
|
|
18
18
|
partial?: boolean
|
|
19
19
|
}
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
/**
|
|
22
|
+
* Activation readbacks for one `step()`/`run()` call. Keyed by the names
|
|
23
|
+
* passed to `capture(name, t)` during the trace. `get(name)` throws if the
|
|
24
|
+
* name isn't registered or wasn't read back this call (i.e., the call was
|
|
25
|
+
* made without `{ withCaptures: true }`); use `has(name)` if you need to
|
|
26
|
+
* branch. `shapeOf(name)` returns the static-after-compile shape and works
|
|
27
|
+
* regardless of whether captures were read back.
|
|
28
|
+
*/
|
|
29
|
+
export class Captures {
|
|
30
|
+
constructor(
|
|
31
|
+
private readonly shapes: Record<string, readonly number[]>,
|
|
32
|
+
private readonly data: Map<string, Float32Array>,
|
|
33
|
+
) {}
|
|
34
|
+
get(name: string): Float32Array {
|
|
35
|
+
const d = this.data.get(name)
|
|
36
|
+
if (!d) {
|
|
37
|
+
const known = [...this.data.keys()].sort().join(', ')
|
|
38
|
+
const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
|
|
39
|
+
throw new Error(`Captures.get: '${name}' not present. ${detail}`)
|
|
40
|
+
}
|
|
41
|
+
return d
|
|
42
|
+
}
|
|
43
|
+
shapeOf(name: string): readonly number[] {
|
|
44
|
+
const s = this.shapes[name]
|
|
45
|
+
if (!s) {
|
|
46
|
+
const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
|
|
47
|
+
throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
|
|
48
|
+
}
|
|
49
|
+
return s
|
|
50
|
+
}
|
|
51
|
+
has(name: string): boolean { return this.data.has(name) }
|
|
52
|
+
names(): string[] { return [...this.data.keys()].sort() }
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
export interface RunResult {
|
|
56
|
+
output: Float32Array
|
|
57
|
+
captures: Captures
|
|
27
58
|
}
|
|
28
59
|
|
|
29
|
-
export interface
|
|
60
|
+
export interface StepResult {
|
|
30
61
|
loss: number
|
|
31
|
-
captures:
|
|
62
|
+
captures: Captures
|
|
32
63
|
}
|
|
33
64
|
|
|
34
65
|
export interface RunOptions {
|
|
35
66
|
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
36
|
-
* When false
|
|
37
|
-
*
|
|
67
|
+
* Default false. When false, the returned `captures` is empty (calling
|
|
68
|
+
* `.get` throws); when true, captures are read back and accessible. */
|
|
38
69
|
withCaptures?: boolean
|
|
39
70
|
}
|
|
40
71
|
|
|
41
|
-
export interface RunWithCaptures {
|
|
42
|
-
output: Float32Array
|
|
43
|
-
captures: Record<string, Float32Array>
|
|
44
|
-
}
|
|
45
|
-
|
|
46
72
|
/** Common surface for both training and forward-only compiled runtimes. */
|
|
47
73
|
export interface CompiledBase {
|
|
74
|
+
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
75
|
+
* share the device, or use directly for other GPU work. */
|
|
76
|
+
device: GPUDevice
|
|
48
77
|
/** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
|
|
49
78
|
* `sharedParams` to share without copies. */
|
|
50
79
|
params: Map<string, GPUBuffer>
|
|
51
|
-
/** Shape of each tensor registered via `capture(name, t)`. Static after
|
|
52
|
-
* compile — reshape readbacks without recomputing strides. */
|
|
53
|
-
captureShapes: Record<string, number[]>
|
|
54
80
|
/** Shape of the graph's output (loss scalar `[]` for training; the user's
|
|
55
81
|
* returned tensor for forward-only compiles). */
|
|
56
82
|
outputShape: number[]
|
|
@@ -64,15 +90,12 @@ export interface CompiledBase {
|
|
|
64
90
|
destroy(): void
|
|
65
91
|
}
|
|
66
92
|
|
|
67
|
-
/** Run a dispatch and read back the full output tensor
|
|
68
|
-
*
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
74
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
75
|
-
}
|
|
93
|
+
/** Run a dispatch and read back the full output tensor. Captures are always
|
|
94
|
+
* returned; their data is empty unless `{ withCaptures: true }` is passed. */
|
|
95
|
+
export type RunFn = (
|
|
96
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
97
|
+
opts?: RunOptions,
|
|
98
|
+
) => Promise<RunResult>
|
|
76
99
|
|
|
77
100
|
export interface CompiledRuntime extends CompiledBase {
|
|
78
101
|
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
@@ -83,12 +106,11 @@ export interface CompiledRuntime extends CompiledBase {
|
|
|
83
106
|
* 2. Dispatches every kernel in order.
|
|
84
107
|
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
85
108
|
* Default returns the loss as a JS number; with `{ withCaptures: true }`
|
|
86
|
-
* returns `{ loss, captures }
|
|
87
|
-
* to `capture(...)` during the trace.
|
|
109
|
+
* returns `{ loss, captures }`.
|
|
88
110
|
*/
|
|
89
111
|
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
90
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<
|
|
91
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
112
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
113
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
|
|
92
114
|
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
93
115
|
* training graphs the output is a scalar loss, so step() is usually more
|
|
94
116
|
* convenient. Provided for parity with `compileForward`. */
|
|
@@ -261,7 +283,7 @@ export async function createRuntime(
|
|
|
261
283
|
async function dispatch(
|
|
262
284
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
263
285
|
wantCaptures: boolean,
|
|
264
|
-
): Promise<{ output: Float32Array; captures:
|
|
286
|
+
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
|
|
265
287
|
const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
|
|
266
288
|
pending = turn
|
|
267
289
|
return turn
|
|
@@ -269,7 +291,7 @@ export async function createRuntime(
|
|
|
269
291
|
async function dispatchUnsynchronized(
|
|
270
292
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
271
293
|
wantCaptures: boolean,
|
|
272
|
-
): Promise<{ output: Float32Array; captures:
|
|
294
|
+
): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
|
|
273
295
|
if (wantCaptures && plan.capturesByName.size === 0) {
|
|
274
296
|
throw new Error(
|
|
275
297
|
`withCaptures=true but no capture(...) calls were registered during ` +
|
|
@@ -331,41 +353,37 @@ export async function createRuntime(
|
|
|
331
353
|
const output = new Float32Array(outputReadback.getMappedRange().slice(0))
|
|
332
354
|
outputReadback.unmap()
|
|
333
355
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
356
|
+
const captures = new Map<string, Float32Array>()
|
|
357
|
+
if (wantCaptures) {
|
|
358
|
+
for (const [name, staging] of stagings!) {
|
|
359
|
+
await staging.mapAsync(GPUMapMode.READ)
|
|
360
|
+
captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
|
|
361
|
+
staging.unmap()
|
|
362
|
+
}
|
|
341
363
|
}
|
|
342
364
|
return { output, captures }
|
|
343
365
|
}
|
|
344
366
|
|
|
345
367
|
// ---- step() — training-mode wrapper, returns scalar [0] of output ---------
|
|
346
368
|
function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
347
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<
|
|
348
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
369
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
370
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
|
|
349
371
|
async function step(
|
|
350
372
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
351
|
-
opts?:
|
|
352
|
-
): Promise<number |
|
|
373
|
+
opts?: RunOptions,
|
|
374
|
+
): Promise<number | StepResult> {
|
|
353
375
|
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
354
|
-
if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures
|
|
376
|
+
if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
|
|
355
377
|
return r.output[0]!
|
|
356
378
|
}
|
|
357
379
|
|
|
358
|
-
// ---- run() — forward-mode wrapper, returns
|
|
359
|
-
function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
|
|
360
|
-
function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
361
|
-
function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
380
|
+
// ---- run() — forward-mode wrapper, returns { output, captures } ----------
|
|
362
381
|
async function run(
|
|
363
382
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
364
383
|
opts?: RunOptions,
|
|
365
|
-
): Promise<
|
|
384
|
+
): Promise<RunResult> {
|
|
366
385
|
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
367
|
-
|
|
368
|
-
return r.output
|
|
386
|
+
return { output: r.output, captures: new Captures(captureShapes, r.captures) }
|
|
369
387
|
}
|
|
370
388
|
|
|
371
389
|
// ---- uploadParams ---------------------------------------------------------
|
|
@@ -461,8 +479,8 @@ export async function createRuntime(
|
|
|
461
479
|
}
|
|
462
480
|
|
|
463
481
|
return {
|
|
482
|
+
device,
|
|
464
483
|
params,
|
|
465
|
-
captureShapes,
|
|
466
484
|
outputShape,
|
|
467
485
|
uploadParams,
|
|
468
486
|
downloadParams: () => downloadFromMap(plan.paramsByName),
|