tensorgrad 0.0.15 → 0.0.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/module.ts CHANGED
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
 // Init metadata
 // ============================================================================
 
-/** How a parameter's initial values are produced.
- * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
- *   weight matrices and embeddings.
- * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
- * - `'ones'` — fill with 1. Common for LayerNorm gain.
- * - Custom function receives total element count and shape, returns the
- *   Float32Array. Use for fan-in scaling or any non-standard scheme.
+/** How a parameter's initial values are produced. Serializable shape — no
+ * closures, since the initial values cross the worker boundary at compile
+ * time. Use the `init` helpers for ergonomic construction.
+ *
+ * String shorthands:
+ * - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
+ * - `'zeros'` — fill with 0 (biases, LayerNorm beta).
+ * - `'ones'` — fill with 1 (LayerNorm gain).
+ *
+ * Object shapes:
+ * - `{ kind: 'randn', scale }` — randn with explicit std.
+ * - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
+ *   gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
+ * - `{ kind: 'literal', data }` — explicit Float32Array; length must
+ *   match the parameter's element count.
  */
 export type InitSpec =
   | 'randn'
   | 'zeros'
   | 'ones'
-  | ((size: number, shape: readonly number[]) => Float32Array)
+  | { readonly kind: 'randn'; readonly scale: number }
+  | { readonly kind: 'kaiming'; readonly gain?: number }
+  | { readonly kind: 'literal'; readonly data: Float32Array }
+
+/** Ergonomic constructors for InitSpec object shapes. */
+export const init = {
+  randn: (opts: { scale?: number } = {}): InitSpec => ({ kind: 'randn', scale: opts.scale ?? 0.02 }),
+  kaiming: (opts: { gain?: number } = {}): InitSpec =>
+    opts.gain !== undefined ? { kind: 'kaiming', gain: opts.gain } : { kind: 'kaiming' },
+  literal: (data: Float32Array): InitSpec => ({ kind: 'literal', data }),
+}
 
 export interface ParamOptions {
   dtype?: Dtype
-  /** Init kind. Default: `'randn'`. */
+  /** Init shape. Default: `'randn'` (std 0.02). */
   init?: InitSpec
-  /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
-  scale?: number
   /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
-   * decay to this param. Default: `true` for `'randn'` init (weight matrices,
-   * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
-   * to force or skip. Replaces `adam.decayFilter` for the common case. */
+   * decay to this param. Default: `true` for randn/kaiming/literal init
+   * (weight matrices, embeddings); `false` for zeros/ones (biases, LN
+   * gains). Override to force or skip. Replaces `adam.decayFilter` for
+   * the common case. */
   decay?: boolean
 }
 
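For orientation, the string shorthands and the `init` helpers produce interchangeable `InitSpec` values. A minimal usage sketch (only names from the diff above are used; the import path is an assumption):

    import { init, type InitSpec } from './module.js' // path assumed

    const a: InitSpec = 'zeros'                          // bias-style; decay defaults to false
    const b: InitSpec = init.randn()                     // { kind: 'randn', scale: 0.02 }
    const c: InitSpec = init.randn({ scale: 0.01 })      // explicit std
    const d: InitSpec = init.kaiming()                   // { kind: 'kaiming' }; std resolved from fan_in later
    const e: InitSpec = init.literal(new Float32Array([1, 0, 0, 1])) // length must equal element count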
@@ -65,31 +82,52 @@ function boxMuller(): number {
   return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
 }
 
-function resolveInit(opts: ParamOptions | undefined): InitFn {
-  const init = opts?.init ?? 'randn'
-  if (init === 'randn') {
-    const scale = opts?.scale ?? 0.02
-    return (size) => {
-      const arr = new Float32Array(size)
-      for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
-      return arr
+function randnFn(scale: number): InitFn {
+  return (size) => {
+    const arr = new Float32Array(size)
+    for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
+    return arr
+  }
+}
+
+/** Compile-time-only: resolve an InitSpec shape into the closure that
+ * generates the initial Float32Array for a given parameter shape. Runs
+ * on the main thread before initial values are transferred to the worker. */
+function resolveInit(spec: InitSpec | undefined): InitFn {
+  if (!spec || spec === 'randn') return randnFn(0.02)
+  if (spec === 'zeros') return (size) => new Float32Array(size)
+  if (spec === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
+  switch (spec.kind) {
+    case 'randn': return randnFn(spec.scale)
+    case 'kaiming': {
+      const gain = spec.gain ?? Math.sqrt(2)
+      return (size, shape) => {
+        const fanIn = shape[0] ?? size
+        const std = gain / Math.sqrt(fanIn)
+        const arr = new Float32Array(size)
+        for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
+        return arr
+      }
+    }
+    case 'literal': {
+      const data = spec.data
+      return (size) => {
+        if (data.length !== size) {
+          throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
+        }
+        return new Float32Array(data)
+      }
     }
   }
-  if (init === 'zeros') return (size) => new Float32Array(size)
-  if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
-  if (typeof init === 'function') return init
-  throw new Error(`Unknown init: ${String(init)}`)
 }
 
-/** Resolve the decay default for a param. Decay weight matrices and
- * embedding tables (randn-initialized); skip biases (zeros) and LN gains
- * (ones). Custom init functions default to "decay" — most user-supplied
- * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
+/** Resolve the decay default for a param. Weight-shaped inits (randn,
+ * kaiming, literal) default to decay=true; ones/zeros default to false
+ * (biases, LN gains). Explicit `decay` opt overrides. */
 function resolveDecay(opts: ParamOptions | undefined): boolean {
   if (opts?.decay !== undefined) return opts.decay
-  const init = opts?.init ?? 'randn'
-  if (init === 'zeros' || init === 'ones') return false
-  return true // 'randn' or function
+  const spec = opts?.init ?? 'randn'
+  return spec !== 'zeros' && spec !== 'ones'
 }
 
 // ============================================================================
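To make the kaiming arithmetic concrete: a `[256, 512]` parameter has `fan_in = shape[0] = 256`, so the default gain of `sqrt(2)` yields `std = sqrt(2)/16 ≈ 0.088`. A sketch of the resolution step, illustrative only since `resolveInit` is module-private:

    // Hypothetical call, assuming access to the private resolveInit:
    const fn = resolveInit({ kind: 'kaiming' })   // gain defaults to Math.sqrt(2)
    const values = fn(256 * 512, [256, 512])      // Float32Array with std ≈ 1.4142 / 16 ≈ 0.088
    // An explicit gain of 1 (e.g. for tanh) would give std = 1 / sqrt(256) = 0.0625:
    const fn2 = resolveInit({ kind: 'kaiming', gain: 1 })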
@@ -127,7 +165,7 @@ export abstract class Module {
   protected param(shape: Shape, opts?: ParamOptions): Tensor {
     const dtype = opts?.dtype ?? 'f32'
     // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
-    return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
+    return new ParamSentinel(shape, dtype, resolveInit(opts?.init), resolveDecay(opts)) as unknown as Tensor
   }
 }
 
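Putting the pieces together, a module might declare parameters like this (a usage sketch: `Linear`, the dimensions, and the import path are invented for illustration, while `param`, `init`, and the decay defaults come from the diff above):

    import { Module, init } from 'tensorgrad' // import path assumed

    const inDim = 256, outDim = 512

    class Linear extends Module {
      // kaiming is weight-shaped, so decay defaults to true
      w = this.param([inDim, outDim], { init: init.kaiming() })
      // zeros is bias-shaped, so decay defaults to false
      b = this.param([outDim], { init: 'zeros' })
    }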
package/src/runtime.ts CHANGED
@@ -348,42 +348,67 @@ export async function createRuntime(
     queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
   }
 
-  const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
-  for (let i = 0; i < kernels.length; i++) {
-    const k = kernels[i]!
-    if (!k.wgsl || k.threads === 0) continue
-    const pipeline = pipelines[i]!
-    const bindGroup = bindGroups[i]!
-    const pass = encoder.beginComputePass({ label: k.opKind })
-    pass.setPipeline(pipeline)
-    pass.setBindGroup(0, bindGroup)
-    // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
-    // when a kernel needs more than that on the X axis. Kernels compute their
-    // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
-    // stride we set here. For dispatches that fit in one row, gid.y is 0.
-    const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
-    const MAX_X = 65535
-    const wgX = Math.min(wgCount, MAX_X)
-    const wgY = Math.ceil(wgCount / MAX_X)
-    pass.dispatchWorkgroups(wgX, wgY, 1)
-    pass.end()
-  }
-  // After all dispatches: writebacks (Adam state, updated params). Empty for
-  // forward-only compiles.
-  for (const wb of plan.writebacks) {
-    encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
-  }
-  encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
-  // Capture readbacks (only when opted in). All captures concatenate into
-  // a single staging buffer so we mapAsync once instead of N times.
+  // Chunked submit. One queue.submit() of all 240 kernels monopolizes the
+  // GPU for the full step duration, blocking compositor frames the entire
+  // time. Splitting into chunks with an explicit GPU-drain await between
+  // them gives the compositor a slot at each chunk boundary. On graphs
+  // smaller than CHUNK_SIZE this collapses to a single submit (no
+  // overhead). See specs/WorkerArchitecture.md / mobile-jank investigation.
+  const CHUNK_SIZE = 32
   let layout: CaptureLayout | null = null
   if (wantCaptures) {
+    // Compute layout up front so the last chunk can append capture copies.
    layout = ensureCaptureStaging()
-    for (const s of layout.slices) {
-      encoder.copyBufferToBuffer(buffers.get(s.bufId)!, 0, layout.buffer, s.offset, s.byteSize)
+  }
+
+  let kernelIdx = 0
+  while (kernelIdx < kernels.length) {
+    const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length)
+    const isLast = chunkEnd === kernels.length
+    const encoder = device.createCommandEncoder({
+      label: kernels.length > CHUNK_SIZE ? `tensorgrad-chunk-${kernelIdx}` : 'tensorgrad-step',
+    })
+    for (let i = kernelIdx; i < chunkEnd; i++) {
+      const k = kernels[i]!
+      if (!k.wgsl || k.threads === 0) continue
+      const pipeline = pipelines[i]!
+      const bindGroup = bindGroups[i]!
+      const pass = encoder.beginComputePass({ label: k.opKind })
+      pass.setPipeline(pipeline)
+      pass.setBindGroup(0, bindGroup)
+      // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
+      // when a kernel needs more than that on the X axis. Kernels compute their
+      // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
+      // stride we set here. For dispatches that fit in one row, gid.y is 0.
+      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
+      const MAX_X = 65535
+      const wgX = Math.min(wgCount, MAX_X)
+      const wgY = Math.ceil(wgCount / MAX_X)
+      pass.dispatchWorkgroups(wgX, wgY, 1)
+      pass.end()
    }
+    if (isLast) {
+      // Writebacks (Adam state, updated params; empty for forward-only) +
+      // output readback copy + capture readback copies all go into the
+      // final chunk so a single mapAsync below sees everything.
+      for (const wb of plan.writebacks) {
+        encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
+      }
+      encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
+      if (layout) {
+        for (const s of layout.slices) {
+          encoder.copyBufferToBuffer(buffers.get(s.bufId)!, 0, layout.buffer, s.offset, s.byteSize)
+        }
+      }
+    }
+    queue.submit([encoder.finish()])
+    if (!isLast) {
+      // Drain the chunk before queuing the next one. This is the moment
+      // the compositor can interleave its own frame work onto the GPU.
+      await queue.onSubmittedWorkDone()
+    }
+    kernelIdx = chunkEnd
  }
-  queue.submit([encoder.finish()])
 
   // readback=false: training fire-and-forget. The encoder still copied
   // loss → outputReadback (and captures → staging), but we don't await
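Isolated from the buffer plumbing, the chunked-submit technique reduces to the following (a standalone sketch, not the library's actual code; `encodeKernel` stands in for the per-kernel pass encoding above):

    async function submitChunked(
      device: GPUDevice,
      kernelCount: number,
      encodeKernel: (enc: GPUCommandEncoder, i: number) => void,
    ): Promise<void> {
      const CHUNK_SIZE = 32
      for (let start = 0; start < kernelCount; start += CHUNK_SIZE) {
        const end = Math.min(start + CHUNK_SIZE, kernelCount)
        const enc = device.createCommandEncoder()
        for (let i = start; i < end; i++) encodeKernel(enc, i)
        device.queue.submit([enc.finish()])
        // Drain between chunks (never after the last) so the compositor can
        // slot frame work in; graphs under one chunk never hit this await.
        if (end < kernelCount) await device.queue.onSubmittedWorkDone()
      }
    }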
package/src/worker-protocol.ts ADDED
@@ -0,0 +1,183 @@
+// Wire format for the main-thread ↔ worker postMessage channel.
+//
+// All requests carry a numeric `id` assigned by the main thread; responses
+// echo it back so the proxy can match concurrent in-flight calls. Every
+// response is either `{ ok: true, result }` or `{ ok: false, error }`.
+// Errors carry serialized name/message/stack so the proxy can reconstitute
+// an Error with a working `instanceof` check on the receiving side.
+//
+// Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
+// rather than copied — see the per-request notes for which fields go on the
+// transfer list. A single worker may host multiple compiled graphs (a train
+// graph plus sibling forward graphs); each has a `graphId` issued by the
+// main thread at compile time.
+
+import type { Graph } from './ir.js'
+import type { BufferPlan } from './buffers.js'
+import type { KernelSpec } from './codegen.js'
+import type { LRSchedule } from './adam.js'
+
+// ============================================================================
+// Serializable config (subset of AdamResolvedConfig that crosses the wire).
+// `decayFilter` (a function, used only at compile time) is NOT part of this —
+// the per-param decay decision is already baked into the IR by appendAdam
+// before the IR ships to the worker.
+// ============================================================================
+
+export interface WireAdamConfig {
+  lr: LRSchedule
+  b1: number
+  b2: number
+  eps: number
+  weightDecay: number
+  lrIsScheduled: boolean
+  /** Names of the per-step scalar inputs the worker must populate before
+   * every step (`_adam_lrt`, optionally `_adam_decay_shrink`). Mirrors
+   * AdamResult so the worker can update them without re-deriving. */
+  lrtInputName: string
+  decayShrinkInputName: string | null
+}
+
+/** Compile output that crosses to the worker. Same fields as CompiledIR
+ * minus the `loss` tensor (carried by graph.outputs[0]). */
+export interface WireIR {
+  graph: Graph
+  plan: BufferPlan
+  kernels: KernelSpec[]
+}
+
+// ============================================================================
+// Requests (main → worker)
+// ============================================================================
+
+export type Req =
+  | { id: number; kind: 'createRuntime'; payload: CreateRuntimePayload }
+  | { id: number; kind: 'compileForward'; payload: CompileForwardPayload }
+  | { id: number; kind: 'step'; payload: StepPayload }
+  | { id: number; kind: 'run'; payload: RunPayload }
+  | { id: number; kind: 'uploadParams'; payload: UploadParamsPayload }
+  | { id: number; kind: 'downloadParams'; payload: { graphId: number } }
+  | { id: number; kind: 'downloadParamGrads'; payload: { graphId: number } }
+  | { id: number; kind: 'resetOptimizer'; payload: { graphId: number } }
+  | { id: number; kind: 'destroy'; payload: { graphId: number } }
+
+/** Build the training runtime. Always graphId=0 for a fresh worker. */
+export interface CreateRuntimePayload {
+  graphId: number
+  ir: WireIR
+  /** Initial param values per name. Transferred (zero-copy) — the main
+   * thread loses access after postMessage. */
+  initialParams: Record<string, Float32Array>
+  /** Adam config when training; absent for forward-only compiles. */
+  adam: WireAdamConfig | null
+}
+
+/** Build a sibling forward-only graph that shares param buffers with an
+ * existing graph (typically the training graph at graphId=0). */
+export interface CompileForwardPayload {
+  graphId: number
+  parentGraphId: number
+  ir: WireIR
+}
+
+/** One training step. Inputs are transferred; the caller's typed arrays
+ * become detached after postMessage. */
+export interface StepPayload {
+  graphId: number
+  inputs: Record<string, Int32Array | Float32Array>
+  withCaptures: boolean
+}
+
+/** Forward-only run. Same transfer semantics as `step`. */
+export interface RunPayload {
+  graphId: number
+  inputs: Record<string, Int32Array | Float32Array>
+  withCaptures: boolean
+}
+
+export interface UploadParamsPayload {
+  graphId: number
+  params: Record<string, Float32Array> // transferred
+  partial: boolean
+}
+
+// ============================================================================
+// Responses (worker → main)
+// ============================================================================
+
+export type Res<R = unknown> =
+  | { id: number; ok: true; result: R }
+  | { id: number; ok: false; error: WireError }
+
+export interface WireError {
+  name: string
+  message: string
+  stack: string
+}
+
+// Per-request result shapes:
+
+export interface CreateRuntimeResult {
+  paramNames: string[]
+  outputShape: number[]
+  kernelCount: number
+  captureShapes: Record<string, number[]>
+}
+
+export interface CompileForwardResult {
+  paramNames: string[]
+  outputShape: number[]
+  kernelCount: number
+  captureShapes: Record<string, number[]>
+}
+
+/** Step without `withCaptures` returns just `loss`. With captures, also
+ * populates `captures` (per-name Float32Array, all transferred back). */
+export interface StepResultWire {
+  loss: number
+  captures: Record<string, Float32Array> | null
+}
+
+/** Run without `withCaptures` returns `{ output, captures: null }`.
+ * With captures, also populates `captures`. */
+export interface RunResultWire {
+  output: Float32Array
+  captures: Record<string, Float32Array> | null
+}
+
+export interface DownloadParamsResult {
+  params: Record<string, Float32Array> // transferred
+}
+
+// ============================================================================
+// Transfer-list helpers
+// ============================================================================
+
+/** Collect the underlying ArrayBuffers from a Record of typed arrays so we
+ * can pass them on `postMessage`'s transfer list. The values themselves
+ * stay in the Record; only their backing buffers move. */
+export function transferablesOfRecord(
+  rec: Record<string, Int32Array | Float32Array>,
+): ArrayBuffer[] {
+  const out: ArrayBuffer[] = []
+  for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
+  return out
+}
+
+/** Serialize an Error to a wire-friendly shape, preserving stack + name so
+ * the receiving side can reconstitute an Error that an `instanceof`-aware
+ * caller (e.g., for `ShapeError`) can still pattern-match by name. */
+export function wireError(e: unknown): WireError {
+  if (e instanceof Error) {
+    return { name: e.name, message: e.message, stack: e.stack ?? '' }
+  }
+  return { name: 'Error', message: String(e), stack: '' }
+}
+
+/** Reconstitute an Error from the wire shape on the receiving (main) side. */
+export function reconstituteError(w: WireError): Error {
+  const err = new Error(w.message)
+  err.name = w.name
+  err.stack = w.stack
+  return err
+}
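The worker side of this protocol reduces to a dispatch loop of roughly this shape (a sketch: `handle` is a hypothetical dispatcher standing in for the real per-request handlers, and results carrying typed arrays would also need a transfer list):

    import type { Req, Res } from './worker-protocol.js'
    import { wireError } from './worker-protocol.js'

    // Hypothetical dispatcher mapping request kinds to handlers:
    declare function handle(kind: Req['kind'], payload: unknown): Promise<unknown>

    self.onmessage = async (ev: MessageEvent<Req>) => {
      const { id, kind, payload } = ev.data
      let res: Res
      try {
        res = { id, ok: true, result: await handle(kind, payload) }
      } catch (e) {
        res = { id, ok: false, error: wireError(e) } // serialized name/message/stack
      }
      ;(self as unknown as Worker).postMessage(res)
    }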
package/src/worker-proxy.ts ADDED
@@ -0,0 +1,76 @@
+// Main-thread half of the worker channel: request/response correlation,
+// promise wiring, error reconstitution. Knows nothing about Adam, captures,
+// IR, etc. — just shuttles typed messages.
+
+import type { Req, Res, WireError } from './worker-protocol.js'
+import { reconstituteError } from './worker-protocol.js'
+
+interface PendingHandlers {
+  resolve: (v: unknown) => void
+  reject: (e: Error) => void
+}
+
+/** Spawn a worker from an inlined source string and provide a typed
+ * request/response channel. One WorkerProxy = one Worker = one GPUDevice
+ * on the worker side. Sibling graphs share the same WorkerProxy. */
+export class WorkerProxy {
+  private worker: Worker
+  private nextId = 1
+  private pending = new Map<number, PendingHandlers>()
+  private terminated = false
+
+  constructor(workerSource: string) {
+    const blob = new Blob([workerSource], { type: 'application/javascript' })
+    const url = URL.createObjectURL(blob)
+    this.worker = new Worker(url, { type: 'module' })
+    // The Blob URL keeps memory alive as long as it's referenced; revoke
+    // once the worker has loaded its source. Browsers tolerate revoking
+    // immediately after construction in practice.
+    URL.revokeObjectURL(url)
+
+    this.worker.onmessage = (ev: MessageEvent<Res>) => {
+      const reply = ev.data
+      const handlers = this.pending.get(reply.id)
+      if (!handlers) return // stale reply; ignore
+      this.pending.delete(reply.id)
+      if (reply.ok) handlers.resolve(reply.result)
+      else handlers.reject(reconstituteError(reply.error))
+    }
+
+    this.worker.onerror = (ev: ErrorEvent) => {
+      const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
+      const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
+      // Reject everything in flight; subsequent calls will fail too.
+      for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire))
+      this.pending.clear()
+    }
+  }
+
+  /** Send a request and await its matching response. `transfer` lists the
+   * ArrayBuffers to move (zero-copy) into the worker. */
+  request<R>(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): Promise<R> {
+    if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
+    const id = this.nextId++
+    return new Promise<R>((resolve, reject) => {
+      this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
+      this.worker.postMessage({ ...req, id } as Req, transfer)
+    })
+  }
+
+  /** Fire-and-forget variant for cases where the caller doesn't need a reply
+   * (currently unused; kept for symmetry / future use). */
+  send(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): void {
+    if (this.terminated) return
+    const id = this.nextId++
+    this.worker.postMessage({ ...req, id } as Req, transfer)
+  }
+
+  terminate(): void {
+    if (this.terminated) return
+    this.terminated = true
+    this.worker.terminate()
+    const err = new Error('tensorgrad: worker terminated')
+    for (const handlers of this.pending.values()) handlers.reject(err)
+    this.pending.clear()
+  }
+}
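A round trip from the main thread then looks like this (a usage sketch: `workerSource` stands for the inlined worker bundle, and the input name `tokens` is a placeholder):

    import { WorkerProxy } from './worker-proxy.js' // file name assumed
    import { transferablesOfRecord, type StepResultWire } from './worker-protocol.js'

    declare const workerSource: string // inlined worker bundle (placeholder)

    const proxy = new WorkerProxy(workerSource)
    const inputs = { tokens: new Int32Array(1024) } // backing buffer detaches on postMessage
    const step = await proxy.request<StepResultWire>(
      { kind: 'step', payload: { graphId: 0, inputs, withCaptures: false } },
      transferablesOfRecord(inputs),
    )
    console.log(step.loss)
    proxy.terminate()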