tensorgrad 0.0.8 → 0.0.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/runtime.ts CHANGED
@@ -18,39 +18,65 @@ export interface UploadParamsOptions {
18
18
  partial?: boolean
19
19
  }
20
20
 
21
- export interface StepOptions {
22
- /** Read back tensors registered via `capture(name, t)` during the trace.
23
- * When false/unset, `step()` returns just the loss number; staging buffers
24
- * for captures are not allocated. When true, returns `{ loss, captures }`
25
- * and lazily allocates one staging buffer per capture on first use. */
26
- withCaptures?: boolean
21
+ /**
22
+ * Activation readbacks for one `step()`/`run()` call. Keyed by the names
23
+ * passed to `capture(name, t)` during the trace. `get(name)` throws if the
24
+ * name isn't registered or wasn't read back this call (i.e., the call was
25
+ * made without `{ withCaptures: true }`); use `has(name)` if you need to
26
+ * branch. `shapeOf(name)` returns the static-after-compile shape and works
27
+ * regardless of whether captures were read back.
28
+ */
29
+ export class Captures {
30
+ constructor(
31
+ private readonly shapes: Record<string, readonly number[]>,
32
+ private readonly data: Map<string, Float32Array>,
33
+ ) {}
34
+ get(name: string): Float32Array {
35
+ const d = this.data.get(name)
36
+ if (!d) {
37
+ const known = [...this.data.keys()].sort().join(', ')
38
+ const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`
39
+ throw new Error(`Captures.get: '${name}' not present. ${detail}`)
40
+ }
41
+ return d
42
+ }
43
+ shapeOf(name: string): readonly number[] {
44
+ const s = this.shapes[name]
45
+ if (!s) {
46
+ const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)'
47
+ throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`)
48
+ }
49
+ return s
50
+ }
51
+ has(name: string): boolean { return this.data.has(name) }
52
+ names(): string[] { return [...this.data.keys()].sort() }
27
53
  }
28
54
 
29
- export interface StepWithCaptures {
55
+ export interface RunResult {
56
+ output: Float32Array
57
+ captures: Captures
58
+ }
59
+
60
+ export interface StepResult {
30
61
  loss: number
31
- captures: Record<string, Float32Array>
62
+ captures: Captures
32
63
  }
33
64
 
34
65
  export interface RunOptions {
35
66
  /** Read back tensors registered via `capture(name, t)` during the trace.
36
- * When false/unset, `run()` returns just the output Float32Array.
37
- * When true, returns `{ output, captures }`. */
67
+ * Default false. When false, the returned `captures` is empty (calling
68
+ * `.get` throws); when true, captures are read back and accessible. */
38
69
  withCaptures?: boolean
39
70
  }
40
71
 
41
- export interface RunWithCaptures {
42
- output: Float32Array
43
- captures: Record<string, Float32Array>
44
- }
45
-
46
72
  /** Common surface for both training and forward-only compiled runtimes. */
47
73
  export interface CompiledBase {
74
+ /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
75
+ * share the device, or use directly for other GPU work. */
76
+ device: GPUDevice
48
77
  /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
49
78
  * `sharedParams` to share without copies. */
50
79
  params: Map<string, GPUBuffer>
51
- /** Shape of each tensor registered via `capture(name, t)`. Static after
52
- * compile — reshape readbacks without recomputing strides. */
53
- captureShapes: Record<string, number[]>
54
80
  /** Shape of the graph's output (loss scalar `[]` for training; the user's
55
81
  * returned tensor for forward-only compiles). */
56
82
  outputShape: number[]
@@ -64,15 +90,12 @@ export interface CompiledBase {
64
90
  destroy(): void
65
91
  }
66
92
 
67
- /** Run a dispatch and read back the full output tensor (and any registered
68
- * captures if requested). Forward-only compiles use this as their primary
69
- * surface; training compiles also expose it but `step()` is more convenient
70
- * there because the output is a scalar loss. */
71
- export interface RunFn {
72
- (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
73
- (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
74
- (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
75
- }
93
+ /** Run a dispatch and read back the full output tensor. Captures are always
94
+ * returned; their data is empty unless `{ withCaptures: true }` is passed. */
95
+ export type RunFn = (
96
+ inputs: Record<string, Int32Array | Float32Array>,
97
+ opts?: RunOptions,
98
+ ) => Promise<RunResult>
76
99
 
77
100
  export interface CompiledRuntime extends CompiledBase {
78
101
  /** Read all parameter gradients back. Mostly for verification / debugging. */
@@ -83,12 +106,11 @@ export interface CompiledRuntime extends CompiledBase {
83
106
  * 2. Dispatches every kernel in order.
84
107
  * 3. Reads back the loss scalar (and any registered captures, if requested).
85
108
  * Default returns the loss as a JS number; with `{ withCaptures: true }`
86
- * returns `{ loss, captures }` where `captures` is keyed by the names passed
87
- * to `capture(...)` during the trace.
109
+ * returns `{ loss, captures }`.
88
110
  */
89
111
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
90
- step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
91
- step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
112
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
113
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
92
114
  /** Same dispatch as step() but returns the full output Float32Array — for
93
115
  * training graphs the output is a scalar loss, so step() is usually more
94
116
  * convenient. Provided for parity with `compileForward`. */
@@ -250,10 +272,26 @@ export async function createRuntime(
250
272
  // the output buffer into its staging, optionally copies captures into theirs,
251
273
  // submits, and reads back. Returns the full output Float32Array; step() takes
252
274
  // [0] for scalar loss, run() returns it whole.
275
+ //
276
+ // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
277
+ // runtime would otherwise both try to `mapAsync` the shared output staging
278
+ // buffer at the same time and trip "Buffer already has an outstanding map
279
+ // pending." We chain each new dispatch onto the prior one's promise so they
280
+ // run sequentially even when fired from independent async paths (e.g., a
281
+ // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
282
+ let pending: Promise<unknown> = Promise.resolve()
253
283
  async function dispatch(
254
284
  inputs: Record<string, Int32Array | Float32Array>,
255
285
  wantCaptures: boolean,
256
- ): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
286
+ ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
287
+ const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
288
+ pending = turn
289
+ return turn
290
+ }
291
+ async function dispatchUnsynchronized(
292
+ inputs: Record<string, Int32Array | Float32Array>,
293
+ wantCaptures: boolean,
294
+ ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
257
295
  if (wantCaptures && plan.capturesByName.size === 0) {
258
296
  throw new Error(
259
297
  `withCaptures=true but no capture(...) calls were registered during ` +
@@ -315,41 +353,37 @@ export async function createRuntime(
315
353
  const output = new Float32Array(outputReadback.getMappedRange().slice(0))
316
354
  outputReadback.unmap()
317
355
 
318
- if (!wantCaptures) return { output, captures: null }
319
-
320
- const captures: Record<string, Float32Array> = {}
321
- for (const [name, staging] of stagings!) {
322
- await staging.mapAsync(GPUMapMode.READ)
323
- captures[name] = new Float32Array(staging.getMappedRange().slice(0))
324
- staging.unmap()
356
+ const captures = new Map<string, Float32Array>()
357
+ if (wantCaptures) {
358
+ for (const [name, staging] of stagings!) {
359
+ await staging.mapAsync(GPUMapMode.READ)
360
+ captures.set(name, new Float32Array(staging.getMappedRange().slice(0)))
361
+ staging.unmap()
362
+ }
325
363
  }
326
364
  return { output, captures }
327
365
  }
328
366
 
329
367
  // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
330
368
  function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
331
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
332
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
369
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
370
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
333
371
  async function step(
334
372
  inputs: Record<string, Int32Array | Float32Array>,
335
- opts?: StepOptions,
336
- ): Promise<number | StepWithCaptures> {
373
+ opts?: RunOptions,
374
+ ): Promise<number | StepResult> {
337
375
  const r = await dispatch(inputs, opts?.withCaptures === true)
338
- if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures! }
376
+ if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
339
377
  return r.output[0]!
340
378
  }
341
379
 
342
- // ---- run() — forward-mode wrapper, returns full output Float32Array -------
343
- function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
344
- function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
345
- function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
380
+ // ---- run() — forward-mode wrapper, returns { output, captures } ----------
346
381
  async function run(
347
382
  inputs: Record<string, Int32Array | Float32Array>,
348
383
  opts?: RunOptions,
349
- ): Promise<Float32Array | RunWithCaptures> {
384
+ ): Promise<RunResult> {
350
385
  const r = await dispatch(inputs, opts?.withCaptures === true)
351
- if (opts?.withCaptures) return { output: r.output, captures: r.captures! }
352
- return r.output
386
+ return { output: r.output, captures: new Captures(captureShapes, r.captures) }
353
387
  }
354
388
 
355
389
  // ---- uploadParams ---------------------------------------------------------
@@ -445,8 +479,8 @@ export async function createRuntime(
445
479
  }
446
480
 
447
481
  return {
482
+ device,
448
483
  params,
449
- captureShapes,
450
484
  outputShape,
451
485
  uploadParams,
452
486
  downloadParams: () => downloadFromMap(plan.paramsByName),