tensorgrad 0.0.1 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. package/README.md +7 -9
  2. package/dist/adam.d.ts +14 -2
  3. package/dist/adam.d.ts.map +1 -1
  4. package/dist/adam.js +19 -8
  5. package/dist/adam.js.map +1 -1
  6. package/dist/buffers.d.ts +1 -0
  7. package/dist/buffers.d.ts.map +1 -1
  8. package/dist/buffers.js +12 -1
  9. package/dist/buffers.js.map +1 -1
  10. package/dist/capture.d.ts +3 -0
  11. package/dist/capture.d.ts.map +1 -0
  12. package/dist/capture.js +33 -0
  13. package/dist/capture.js.map +1 -0
  14. package/dist/codegen.js +4 -2
  15. package/dist/codegen.js.map +1 -1
  16. package/dist/compile.d.ts +33 -5
  17. package/dist/compile.d.ts.map +1 -1
  18. package/dist/compile.js +96 -11
  19. package/dist/compile.js.map +1 -1
  20. package/dist/index.d.ts +5 -3
  21. package/dist/index.d.ts.map +1 -1
  22. package/dist/index.js +4 -2
  23. package/dist/index.js.map +1 -1
  24. package/dist/ir.d.ts +2 -0
  25. package/dist/ir.d.ts.map +1 -1
  26. package/dist/ir.js +1 -1
  27. package/dist/ir.js.map +1 -1
  28. package/dist/module.d.ts +30 -4
  29. package/dist/module.d.ts.map +1 -1
  30. package/dist/module.js +39 -13
  31. package/dist/module.js.map +1 -1
  32. package/dist/nn.d.ts +19 -0
  33. package/dist/nn.d.ts.map +1 -0
  34. package/dist/nn.js +60 -0
  35. package/dist/nn.js.map +1 -0
  36. package/dist/ops.d.ts +1 -1
  37. package/dist/ops.d.ts.map +1 -1
  38. package/dist/ops.js +2 -2
  39. package/dist/ops.js.map +1 -1
  40. package/dist/runtime.d.ts +79 -4
  41. package/dist/runtime.d.ts.map +1 -1
  42. package/dist/runtime.js +153 -19
  43. package/dist/runtime.js.map +1 -1
  44. package/dist/trace.d.ts +1 -0
  45. package/dist/trace.d.ts.map +1 -1
  46. package/dist/trace.js +12 -0
  47. package/dist/trace.js.map +1 -1
  48. package/package.json +1 -2
  49. package/src/adam.ts +31 -10
  50. package/src/buffers.ts +14 -1
  51. package/src/capture.ts +36 -0
  52. package/src/codegen.ts +4 -2
  53. package/src/compile.ts +112 -13
  54. package/src/index.ts +5 -3
  55. package/src/ir.ts +10 -4
  56. package/src/module.ts +75 -11
  57. package/src/nn.ts +59 -0
  58. package/src/ops.ts +2 -2
  59. package/src/runtime.ts +260 -22
  60. package/src/trace.ts +13 -0
  61. package/SPEC.md +0 -293
package/src/runtime.ts CHANGED
@@ -11,9 +11,47 @@ import type { KernelSpec } from './codegen.js'
11
11
  // Provided by the browser per WebGPU spec; declare just what we use.
12
12
  declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
13
 
14
+ export interface UploadParamsOptions {
15
+ /** Skip the "missing param" check, allowing the caller to update only some
16
+ * params and leave the rest at their current GPU values. Extra (unknown)
17
+ * keys are still rejected — that's always a typo. Default: false. */
18
+ partial?: boolean
19
+ }
20
+
21
+ export interface StepOptions {
22
+ /** Read back tensors registered via `capture(name, t)` during the trace.
23
+ * When false/unset, `step()` returns just the loss number; staging buffers
24
+ * for captures are not allocated. When true, returns `{ loss, captures }`
25
+ * and lazily allocates one staging buffer per capture on first use. */
26
+ withCaptures?: boolean
27
+ }
28
+
29
+ export interface StepWithCaptures {
30
+ loss: number
31
+ captures: Record<string, Float32Array>
32
+ }
33
+
34
+ export interface RunOptions {
35
+ /** Read back tensors registered via `capture(name, t)` during the trace.
36
+ * When false/unset, `run()` returns just the output Float32Array.
37
+ * When true, returns `{ output, captures }`. */
38
+ withCaptures?: boolean
39
+ }
40
+
41
+ export interface RunWithCaptures {
42
+ output: Float32Array
43
+ captures: Record<string, Float32Array>
44
+ }
45
+
14
46
  export interface CompiledRuntime {
15
- /** Upload one or more parameter Float32Arrays to their GPU buffers. */
16
- uploadParams(params: Record<string, Float32Array>): void
47
+ /** Map of param name -> the underlying GPUBuffer. Pass to a sibling compile
48
+ * via `sharedParams` to share without copies — every step on this runtime
49
+ * is immediately visible to anyone reading these buffers. */
50
+ params: Map<string, GPUBuffer>
51
+ /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
52
+ * *all* params to be present; throws on any unknown or missing key. Pass
53
+ * `{ partial: true }` to skip the missing-key check. */
54
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
17
55
  /** Read all parameters back as Float32Arrays — used for UI panels. */
18
56
  downloadParams(): Promise<Record<string, Float32Array>>
19
57
  /** Read all parameter gradients back. Mostly for verification / debugging. */
@@ -22,17 +60,51 @@ export interface CompiledRuntime {
22
60
  * One full forward+backward step.
23
61
  * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
24
62
  * 2. Dispatches every kernel in order.
25
- * 3. Reads back the loss scalar.
26
- * Returns the loss as a JS number.
63
+ * 3. Reads back the loss scalar (and any registered captures, if requested).
64
+ * Default returns the loss as a JS number; with `{ withCaptures: true }`
65
+ * returns `{ loss, captures }` where `captures` is keyed by the names passed
66
+ * to `capture(...)` during the trace.
27
67
  */
28
68
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
69
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
70
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
71
+ /** Like `step()` but returns the full output Float32Array instead of just
72
+ * its first element. For training graphs this is rarely useful (the output
73
+ * *is* a scalar loss); it's the primary API for forward-only compiles. */
74
+ run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
75
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
76
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
77
+ /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
78
+ * `uploadInitialParams()` for a full training reset without recompile. */
79
+ resetOptimizerState(): void
29
80
  /** Free GPU resources. */
30
81
  destroy(): void
31
82
  }
32
83
 
84
+ /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
85
+ * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
86
+ export interface CompiledForward {
87
+ params: Map<string, GPUBuffer>
88
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
89
+ downloadParams(): Promise<Record<string, Float32Array>>
90
+ /** Forward-only dispatch. Returns the graph's output tensor as a Float32Array
91
+ * (the user's returned tensor from the forward function, in row-major order).
92
+ * With `{ withCaptures: true }`, returns `{ output, captures }`. */
93
+ run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
94
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
95
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
96
+ destroy(): void
97
+ }
98
+
33
99
  export interface RuntimeOpts {
34
100
  /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
35
101
  device?: GPUDevice
102
+ /** External param buffers to bind in place of allocating fresh ones, keyed
103
+ * by param name. Used to share params between a training compile and a
104
+ * sibling forward-only compile (e.g., a B=1 inference graph). When a name
105
+ * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
106
+ * allocates as usual. */
107
+ sharedParams?: Map<string, GPUBuffer>
36
108
  }
37
109
 
38
110
  // Inlined numeric values (per WebGPU spec) so this module is importable in Node
@@ -52,8 +124,23 @@ export async function createRuntime(
52
124
 
53
125
  // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
54
126
  // State buffers also get filled with their initValue at allocation time.
127
+ // Param buffers may be supplied externally via opts.sharedParams; in that
128
+ // case we reuse the provided GPUBuffer instead of allocating, and the
129
+ // sibling compile that owns it is responsible for upload + lifetime.
55
130
  const buffers = new Map<number, GPUBuffer>()
131
+ const sharedParams = opts.sharedParams
56
132
  for (const spec of plan.buffers) {
133
+ if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
134
+ const shared = sharedParams.get(spec.name!)!
135
+ if (shared.size !== spec.byteSize) {
136
+ throw new Error(
137
+ `sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
138
+ `compiled graph expects ${spec.byteSize}.`,
139
+ )
140
+ }
141
+ buffers.set(spec.id, shared)
142
+ continue
143
+ }
57
144
  const buf = device.createBuffer({
58
145
  size: spec.byteSize,
59
146
  usage: STORAGE_RW,
@@ -69,6 +156,13 @@ export async function createRuntime(
69
156
  queue.writeBuffer(buf, 0, init as unknown as BufferSource)
70
157
  }
71
158
  }
159
+ // Track which params are externally owned — those are skipped on destroy().
160
+ const ownedBufferIds = new Set<number>()
161
+ for (const spec of plan.buffers) {
162
+ if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
163
+ ownedBufferIds.add(spec.id)
164
+ }
165
+ }
72
166
 
73
167
  // ---- Compile pipelines per kernel; cache by WGSL source -------------------
74
168
  // Push an error scope around each shader+pipeline creation so we can surface
@@ -127,12 +221,45 @@ export async function createRuntime(
127
221
  })
128
222
  })
129
223
 
130
- // ---- Loss readback staging buffer -----------------------------------------
131
- const lossSpec = plan.buffers[lossBufferId]!
132
- const lossReadback = device.createBuffer({ size: lossSpec.byteSize, usage: READBACK })
224
+ // ---- Output readback staging buffer ---------------------------------------
225
+ // `outputBufferId` is the graph's main output (loss for training, the user's
226
+ // returned tensor for forward-only). step() reads back its first element;
227
+ // run() reads back the full Float32Array.
228
+ const outputSpec = plan.buffers[lossBufferId]!
229
+ const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })
230
+
231
+ // ---- Capture readback staging buffers (lazy) ------------------------------
232
+ // Allocated on first `step({ withCaptures: true })` call and reused across
233
+ // subsequent calls. When the graph has no captures registered or when the
234
+ // caller never opts in, no extra GPU memory is allocated.
235
+ let captureStagings: Map<string, GPUBuffer> | null = null
236
+ function ensureCaptureStagings(): Map<string, GPUBuffer> {
237
+ if (captureStagings) return captureStagings
238
+ captureStagings = new Map()
239
+ for (const [name, bufId] of plan.capturesByName) {
240
+ const spec = plan.buffers[bufId]!
241
+ const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
242
+ captureStagings.set(name, staging)
243
+ }
244
+ return captureStagings
245
+ }
133
246
 
134
- // ---- step() ---------------------------------------------------------------
135
- async function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number> {
247
+ // ---- dispatch() — shared core for step() and run() -----------------------
248
+ // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
249
+ // the output buffer into its staging, optionally copies captures into theirs,
250
+ // submits, and reads back. Returns the full output Float32Array; step() takes
251
+ // [0] for scalar loss, run() returns it whole.
252
+ async function dispatch(
253
+ inputs: Record<string, Int32Array | Float32Array>,
254
+ wantCaptures: boolean,
255
+ ): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
256
+ if (wantCaptures && plan.capturesByName.size === 0) {
257
+ throw new Error(
258
+ `withCaptures=true but no capture(...) calls were registered during ` +
259
+ `the trace. Add capture('name', tensor) inside your forward pass for ` +
260
+ `the intermediates you want read back.`,
261
+ )
262
+ }
136
263
  for (const [name, bufId] of plan.inputsByName) {
137
264
  const data = inputs[name]
138
265
  if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
@@ -165,26 +292,93 @@ export async function createRuntime(
165
292
  pass.dispatchWorkgroups(wgX, wgY, 1)
166
293
  pass.end()
167
294
  }
168
- // After all dispatches: writebacks (Adam state, updated params).
169
- // copyBufferToBuffer is queued onto the same encoder so it's ordered after
170
- // all kernel dispatches.
295
+ // After all dispatches: writebacks (Adam state, updated params). Empty for
296
+ // forward-only compiles.
171
297
  for (const wb of plan.writebacks) {
172
298
  encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
173
299
  }
174
- encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, lossReadback, 0, lossSpec.byteSize)
300
+ encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
301
+ // Capture readbacks (only when opted in). Queued before submit so they
302
+ // observe the same kernel outputs as the main output.
303
+ let stagings: Map<string, GPUBuffer> | null = null
304
+ if (wantCaptures) {
305
+ stagings = ensureCaptureStagings()
306
+ for (const [name, bufId] of plan.capturesByName) {
307
+ const spec = plan.buffers[bufId]!
308
+ encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
309
+ }
310
+ }
175
311
  queue.submit([encoder.finish()])
176
312
 
177
- await lossReadback.mapAsync(GPUMapMode.READ)
178
- const view = new Float32Array(lossReadback.getMappedRange().slice(0))
179
- lossReadback.unmap()
180
- return view[0]!
313
+ await outputReadback.mapAsync(GPUMapMode.READ)
314
+ const output = new Float32Array(outputReadback.getMappedRange().slice(0))
315
+ outputReadback.unmap()
316
+
317
+ if (!wantCaptures) return { output, captures: null }
318
+
319
+ const captures: Record<string, Float32Array> = {}
320
+ for (const [name, staging] of stagings!) {
321
+ await staging.mapAsync(GPUMapMode.READ)
322
+ captures[name] = new Float32Array(staging.getMappedRange().slice(0))
323
+ staging.unmap()
324
+ }
325
+ return { output, captures }
326
+ }
327
+
328
+ // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
329
+ function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
330
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
331
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
332
+ async function step(
333
+ inputs: Record<string, Int32Array | Float32Array>,
334
+ opts?: StepOptions,
335
+ ): Promise<number | StepWithCaptures> {
336
+ const r = await dispatch(inputs, opts?.withCaptures === true)
337
+ if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures! }
338
+ return r.output[0]!
339
+ }
340
+
341
+ // ---- run() — forward-mode wrapper, returns full output Float32Array -------
342
+ function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
343
+ function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
344
+ function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
345
+ async function run(
346
+ inputs: Record<string, Int32Array | Float32Array>,
347
+ opts?: RunOptions,
348
+ ): Promise<Float32Array | RunWithCaptures> {
349
+ const r = await dispatch(inputs, opts?.withCaptures === true)
350
+ if (opts?.withCaptures) return { output: r.output, captures: r.captures! }
351
+ return r.output
181
352
  }
182
353
 
183
354
  // ---- uploadParams ---------------------------------------------------------
184
- function uploadParams(params: Record<string, Float32Array>) {
355
+ function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
356
+ const partial = opts?.partial ?? false
357
+ for (const name of Object.keys(params)) {
358
+ if (!plan.paramsByName.has(name)) {
359
+ throw new Error(
360
+ `uploadParams: unknown param '${name}'. ` +
361
+ `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
362
+ )
363
+ }
364
+ }
365
+ if (!partial) {
366
+ for (const name of plan.paramsByName.keys()) {
367
+ if (!(name in params)) {
368
+ throw new Error(
369
+ `uploadParams: missing param '${name}'. ` +
370
+ `Pass { partial: true } if you mean to update only some params.`,
371
+ )
372
+ }
373
+ }
374
+ }
185
375
  for (const [name, bufId] of plan.paramsByName) {
186
376
  const data = params[name]
187
377
  if (!data) continue
378
+ const expected = plan.buffers[bufId]!.byteSize / 4
379
+ if (data.length !== expected) {
380
+ throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
381
+ }
188
382
  queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
189
383
  }
190
384
  }
@@ -210,15 +404,59 @@ export async function createRuntime(
210
404
  return out
211
405
  }
212
406
 
407
+ function resetOptimizerState() {
408
+ for (const spec of plan.buffers) {
409
+ if (spec.kind !== 'state') continue
410
+ const elements = spec.byteSize / 4
411
+ const init = spec.dtype === 'f32'
412
+ ? new Float32Array(elements).fill(spec.initValue ?? 0)
413
+ : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
414
+ queue.writeBuffer(buffers.get(spec.id)!, 0, init as unknown as BufferSource)
415
+ }
416
+ }
417
+
418
+ // Build the params map AFTER buffer allocation so it points at the actual
419
+ // GPUBuffers (shared or freshly allocated).
420
+ const params = new Map<string, GPUBuffer>()
421
+ for (const [name, bufId] of plan.paramsByName) {
422
+ params.set(name, buffers.get(bufId)!)
423
+ }
424
+
425
+ const destroy = () => {
426
+ for (const [id, b] of buffers) {
427
+ if (ownedBufferIds.has(id)) b.destroy()
428
+ }
429
+ outputReadback.destroy()
430
+ if (captureStagings) for (const b of captureStagings.values()) b.destroy()
431
+ }
432
+
213
433
  return {
434
+ params,
214
435
  uploadParams,
215
436
  downloadParams: () => downloadFromMap(plan.paramsByName),
216
437
  downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
217
438
  step,
218
- destroy: () => {
219
- for (const b of buffers.values()) b.destroy()
220
- lossReadback.destroy()
221
- },
439
+ run,
440
+ resetOptimizerState,
441
+ destroy,
442
+ }
443
+ }
444
+
445
+ /** Same machinery as `createRuntime`, narrower public API: no step,
446
+ * no resetOptimizerState, no downloadParamGrads. Used by `compileForward`. */
447
+ export async function createForwardRuntime(
448
+ plan: BufferPlan,
449
+ kernels: KernelSpec[],
450
+ outputBufferId: number,
451
+ opts: RuntimeOpts = {},
452
+ ): Promise<CompiledForward> {
453
+ const full = await createRuntime(plan, kernels, outputBufferId, opts)
454
+ return {
455
+ params: full.params,
456
+ uploadParams: full.uploadParams,
457
+ downloadParams: full.downloadParams,
458
+ run: full.run,
459
+ destroy: full.destroy,
222
460
  }
223
461
  }
224
462
 
package/src/trace.ts CHANGED
@@ -19,6 +19,10 @@ import { makeGraph, addOp, captureSite } from './ir.js'
19
19
 
20
20
  // Module-local: the graph being built right now, or null if no trace is active.
21
21
  let _current: Graph | null = null
22
+ // Module-local: whether `capture(name, t)` calls should register on the current
23
+ // graph. True only during the user's forward trace; false during `traceInto`
24
+ // (autograd / optimizer ops shouldn't accidentally publish gradient tensors).
25
+ let _captureEnabled = false
22
26
 
23
27
  export function currentGraph(): Graph {
24
28
  if (!_current) {
@@ -30,6 +34,10 @@ export function currentGraph(): Graph {
30
34
  return _current
31
35
  }
32
36
 
37
+ export function isCaptureEnabled(): boolean {
38
+ return _captureEnabled
39
+ }
40
+
33
41
  // Run `fn` with a fresh graph as the current one; capture and return the graph.
34
42
  // `fn` must return the tensor (or array of tensors) to mark as graph outputs.
35
43
  export function trace(fn: () => Tensor | Tensor[]): Graph {
@@ -38,6 +46,7 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
38
46
  }
39
47
  const g = makeGraph()
40
48
  _current = g
49
+ _captureEnabled = true
41
50
  try {
42
51
  const result = fn()
43
52
  const outputs = Array.isArray(result) ? result : [result]
@@ -46,6 +55,7 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
46
55
  }
47
56
  } finally {
48
57
  _current = null
58
+ _captureEnabled = false
49
59
  }
50
60
  return g
51
61
  }
@@ -53,12 +63,15 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
53
63
  // Re-enter an existing graph to append more ops. Used by autograd to add
54
64
  // backward ops to a graph that's already been traced. `fn` runs with the
55
65
  // supplied graph as the current one; any ops it calls append to that graph.
66
+ // Capture is intentionally disabled here — backward / optimizer rules
67
+ // shouldn't publish their internal tensors via `capture()`.
56
68
  // Returns whatever `fn` returns.
57
69
  export function traceInto<T>(g: Graph, fn: () => T): T {
58
70
  if (_current) {
59
71
  throw new Error('tensorgrad: traceInto() called while another trace is active')
60
72
  }
61
73
  _current = g
74
+ // _captureEnabled stays false (default) — explicit, but not toggled.
62
75
  try {
63
76
  return fn()
64
77
  } finally {