tensorgrad 0.0.14 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/module.ts CHANGED
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
32
32
  // Init metadata
33
33
  // ============================================================================
34
34
 
35
- /** How a parameter's initial values are produced.
36
- * - `'randn'` Gaussian, with `scale` (default 0.02). The common case for
37
- * weight matrices and embeddings.
38
- * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
39
- * - `'ones'` — fill with 1. Common for LayerNorm gain.
40
- * - Custom function receives total element count and shape, returns the
41
- * Float32Array. Use for fan-in scaling or any non-standard scheme.
35
+ /** How a parameter's initial values are produced. Serializable shape — no
36
+ * closures, since the initial values cross the worker boundary at compile
37
+ * time. Use the `init` helpers for ergonomic construction.
38
+ *
39
+ * String shorthands:
40
+ * - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
41
+ * - `'zeros'` — fill with 0 (biases, LayerNorm beta).
42
+ * - `'ones'` — fill with 1 (LayerNorm gain).
43
+ *
44
+ * Object shapes:
45
+ * - `{ kind: 'randn', scale }` — randn with explicit std.
46
+ * - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
47
+ * gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
48
+ * - `{ kind: 'literal', data }` — explicit Float32Array; length must
49
+ * match the parameter's element count.
42
50
  */
43
51
  export type InitSpec =
44
52
  | 'randn'
45
53
  | 'zeros'
46
54
  | 'ones'
47
- | ((size: number, shape: readonly number[]) => Float32Array)
55
+ | { readonly kind: 'randn'; readonly scale: number }
56
+ | { readonly kind: 'kaiming'; readonly gain?: number }
57
+ | { readonly kind: 'literal'; readonly data: Float32Array }
58
+
59
+ /** Ergonomic constructors for InitSpec object shapes. */
60
+ export const init = {
61
+ randn: (opts: { scale?: number } = {}): InitSpec => ({ kind: 'randn', scale: opts.scale ?? 0.02 }),
62
+ kaiming: (opts: { gain?: number } = {}): InitSpec =>
63
+ opts.gain !== undefined ? { kind: 'kaiming', gain: opts.gain } : { kind: 'kaiming' },
64
+ literal: (data: Float32Array): InitSpec => ({ kind: 'literal', data }),
65
+ }
48
66
 
49
67
  export interface ParamOptions {
50
68
  dtype?: Dtype
51
- /** Init kind. Default: `'randn'`. */
69
+ /** Init shape. Default: `'randn'` (std 0.02). */
52
70
  init?: InitSpec
53
- /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
54
- scale?: number
55
71
  /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
56
- * decay to this param. Default: `true` for `'randn'` init (weight matrices,
57
- * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
58
- * to force or skip. Replaces `adam.decayFilter` for the common case. */
72
+ * decay to this param. Default: `true` for randn/kaiming/literal init
73
+ * (weight matrices, embeddings); `false` for zeros/ones (biases, LN
74
+ * gains). Override to force or skip. Replaces `adam.decayFilter` for
75
+ * the common case. */
59
76
  decay?: boolean
60
77
  }
61
78
 
@@ -65,31 +82,52 @@ function boxMuller(): number {
65
82
  return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
66
83
  }
67
84
 
68
- function resolveInit(opts: ParamOptions | undefined): InitFn {
69
- const init = opts?.init ?? 'randn'
70
- if (init === 'randn') {
71
- const scale = opts?.scale ?? 0.02
72
- return (size) => {
73
- const arr = new Float32Array(size)
74
- for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
75
- return arr
85
+ function randnFn(scale: number): InitFn {
86
+ return (size) => {
87
+ const arr = new Float32Array(size)
88
+ for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
89
+ return arr
90
+ }
91
+ }
92
+
93
+ /** Compile-time-only: resolve an InitSpec shape into the closure that
94
+ * generates the initial Float32Array for a given parameter shape. Runs
95
+ * on the main thread before initial values are transferred to the worker. */
96
+ function resolveInit(spec: InitSpec | undefined): InitFn {
97
+ if (!spec || spec === 'randn') return randnFn(0.02)
98
+ if (spec === 'zeros') return (size) => new Float32Array(size)
99
+ if (spec === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
100
+ switch (spec.kind) {
101
+ case 'randn': return randnFn(spec.scale)
102
+ case 'kaiming': {
103
+ const gain = spec.gain ?? Math.sqrt(2)
104
+ return (size, shape) => {
105
+ const fanIn = shape[0] ?? size
106
+ const std = gain / Math.sqrt(fanIn)
107
+ const arr = new Float32Array(size)
108
+ for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
109
+ return arr
110
+ }
111
+ }
112
+ case 'literal': {
113
+ const data = spec.data
114
+ return (size) => {
115
+ if (data.length !== size) {
116
+ throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
117
+ }
118
+ return new Float32Array(data)
119
+ }
76
120
  }
77
121
  }
78
- if (init === 'zeros') return (size) => new Float32Array(size)
79
- if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
80
- if (typeof init === 'function') return init
81
- throw new Error(`Unknown init: ${String(init)}`)
82
122
  }
83
123
 
84
- /** Resolve the decay default for a param. Decay weight matrices and
85
- * embedding tables (randn-initialized); skip biases (zeros) and LN gains
86
- * (ones). Custom init functions default to "decay" most user-supplied
87
- * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
124
+ /** Resolve the decay default for a param. Weight-shaped inits (randn,
125
+ * kaiming, literal) default to decay=true; ones/zeros default to false
126
+ * (biases, LN gains). Explicit `decay` opt overrides. */
88
127
  function resolveDecay(opts: ParamOptions | undefined): boolean {
89
128
  if (opts?.decay !== undefined) return opts.decay
90
- const init = opts?.init ?? 'randn'
91
- if (init === 'zeros' || init === 'ones') return false
92
- return true // 'randn' or function
129
+ const spec = opts?.init ?? 'randn'
130
+ return spec !== 'zeros' && spec !== 'ones'
93
131
  }
94
132
 
95
133
  // ============================================================================
@@ -127,7 +165,7 @@ export abstract class Module {
127
165
  protected param(shape: Shape, opts?: ParamOptions): Tensor {
128
166
  const dtype = opts?.dtype ?? 'f32'
129
167
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
130
- return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
168
+ return new ParamSentinel(shape, dtype, resolveInit(opts?.init), resolveDecay(opts)) as unknown as Tensor
131
169
  }
132
170
  }
133
171
 
package/src/runtime.ts CHANGED
@@ -69,6 +69,23 @@ export interface RunOptions {
69
69
  withCaptures?: boolean
70
70
  }
71
71
 
72
+ export interface StepOptions extends RunOptions {
73
+ /** If false, the training submit is queued but the JS thread does not
74
+ * await `mapAsync` of the loss buffer. Returns `void` immediately.
75
+ * Use `runtime.readLoss()` to read the latest loss explicitly when
76
+ * you want it (e.g., every Nth step for UI display).
77
+ *
78
+ * Why: each `mapAsync` round-trip is ~1 ms on desktop but 10–30 ms on
79
+ * Android Chrome. A training loop that awaits per step pays N × that
80
+ * on the main thread, which on mobile starves the OS compositor and
81
+ * causes visible UI sluggishness. With `readLoss: false` plus a
82
+ * `requestAnimationFrame` yield between steps, the main thread stays
83
+ * responsive while training runs at GPU speed.
84
+ *
85
+ * Implies `withCaptures: false`. Default: true. */
86
+ readLoss?: boolean
87
+ }
88
+
72
89
  /** Common surface for both training and forward-only compiled runtimes. */
73
90
  export interface CompiledBase {
74
91
  /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
@@ -112,11 +129,16 @@ export interface CompiledRuntime extends CompiledBase {
112
129
  */
113
130
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
114
131
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
115
- step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
132
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: { readLoss: false }): Promise<void>
133
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>
116
134
  /** Same dispatch as step() but returns the full output Float32Array — for
117
135
  * training graphs the output is a scalar loss, so step() is usually more
118
136
  * convenient. Provided for parity with `compileForward`. */
119
137
  run: RunFn
138
+ /** Read the latest loss value from the GPU. Pair with `step({ readLoss: false })`
139
+ * fire-and-forget training: every Nth iteration, call `readLoss()` for the
140
+ * UI, but most iterations don't pay the `mapAsync` cost. */
141
+ readLoss(): Promise<number>
120
142
  /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
121
143
  * `uploadInitialParams()` for a full training reset without recompile. */
122
144
  resetOptimizerState(): void
@@ -292,18 +314,21 @@ export async function createRuntime(
292
314
  // run sequentially even when fired from independent async paths (e.g., a
293
315
  // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
294
316
  let pending: Promise<unknown> = Promise.resolve()
317
+ type DispatchOpts = { wantCaptures: boolean; readback: boolean }
318
+ type DispatchResult = { output: Float32Array; captures: Map<string, Float32Array> } | null
295
319
  async function dispatch(
296
320
  inputs: Record<string, Int32Array | Float32Array>,
297
- wantCaptures: boolean,
298
- ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
299
- const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, wantCaptures))
321
+ opts: DispatchOpts,
322
+ ): Promise<DispatchResult> {
323
+ const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, opts))
300
324
  pending = turn
301
325
  return turn
302
326
  }
303
327
  async function dispatchUnsynchronized(
304
328
  inputs: Record<string, Int32Array | Float32Array>,
305
- wantCaptures: boolean,
306
- ): Promise<{ output: Float32Array; captures: Map<string, Float32Array> }> {
329
+ opts: DispatchOpts,
330
+ ): Promise<DispatchResult> {
331
+ const wantCaptures = opts.wantCaptures
307
332
  if (wantCaptures && plan.capturesByName.size === 0) {
308
333
  throw new Error(
309
334
  `withCaptures=true but no capture(...) calls were registered during ` +
@@ -360,6 +385,12 @@ export async function createRuntime(
360
385
  }
361
386
  queue.submit([encoder.finish()])
362
387
 
388
+ // readback=false: training fire-and-forget. The encoder still copied
389
+ // loss → outputReadback (and captures → staging), but we don't await
390
+ // mapAsync. The caller can read the latest loss later via readLoss()
391
+ // when it actually wants to display it.
392
+ if (!opts.readback) return null
393
+
363
394
  await outputReadback.mapAsync(GPUMapMode.READ)
364
395
  const output = new Float32Array(outputReadback.getMappedRange().slice(0))
365
396
  outputReadback.unmap()
@@ -381,16 +412,37 @@ export async function createRuntime(
381
412
  // ---- step() — training-mode wrapper, returns scalar [0] of output ---------
382
413
  function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
383
414
  function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
384
- function step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>
415
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: { readLoss: false }): Promise<void>
416
+ function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>
385
417
  async function step(
386
418
  inputs: Record<string, Int32Array | Float32Array>,
387
- opts?: RunOptions,
388
- ): Promise<number | StepResult> {
389
- const r = await dispatch(inputs, opts?.withCaptures === true)
419
+ opts?: StepOptions,
420
+ ): Promise<number | StepResult | void> {
421
+ if (opts?.readLoss === false) {
422
+ await dispatch(inputs, { wantCaptures: false, readback: false })
423
+ return
424
+ }
425
+ const r = (await dispatch(inputs, { wantCaptures: opts?.withCaptures === true, readback: true }))!
390
426
  if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
391
427
  return r.output[0]!
392
428
  }
393
429
 
430
+ // ---- readLoss() — explicit late readback for fire-and-forget training -----
431
+ // Maps the output buffer (which step() always copies the latest loss into,
432
+ // even when readLoss:false) and returns the value. Goes through the same
433
+ // serialization chain as step()/run() so two readLoss() calls don't both
434
+ // try to mapAsync the same buffer.
435
+ async function readLoss(): Promise<number> {
436
+ const turn = pending.catch(() => {}).then(async () => {
437
+ await outputReadback.mapAsync(GPUMapMode.READ)
438
+ const v = new Float32Array(outputReadback.getMappedRange())[0]!
439
+ outputReadback.unmap()
440
+ return v
441
+ })
442
+ pending = turn
443
+ return turn
444
+ }
445
+
394
446
  // ---- run() — forward-mode wrapper, returns Float32Array by default -------
395
447
  // Same overloaded shape as step(): scalar-shaped result (here Float32Array,
396
448
  // there a JS number) is the default; { ..., captures } is the opt-in form.
@@ -401,7 +453,7 @@ export async function createRuntime(
401
453
  inputs: Record<string, Int32Array | Float32Array>,
402
454
  opts?: RunOptions,
403
455
  ): Promise<Float32Array | RunResult> {
404
- const r = await dispatch(inputs, opts?.withCaptures === true)
456
+ const r = (await dispatch(inputs, { wantCaptures: opts?.withCaptures === true, readback: true }))!
405
457
  if (opts?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) }
406
458
  return r.output
407
459
  }
@@ -507,6 +559,7 @@ export async function createRuntime(
507
559
  downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
508
560
  step,
509
561
  run,
562
+ readLoss,
510
563
  resetOptimizerState,
511
564
  destroy,
512
565
  }
@@ -0,0 +1,183 @@
1
+ // Wire format for the main-thread ↔ worker postMessage channel.
2
+ //
3
+ // All requests carry a numeric `id` assigned by the main thread; responses
4
+ // echo it back so the proxy can match concurrent in-flight calls. Every
5
+ // response is either `{ ok: true, result }` or `{ ok: false, error }`.
6
+ // Errors carry serialized name/message/stack so the proxy can reconstitute
7
+ // an Error with a working `instanceof` check on the receiving side.
8
+ //
9
+ // Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
10
+ // rather than copied — see the per-request notes for which fields go on the
11
+ // transfer list. A single worker may host multiple compiled graphs (a train
12
+ // graph plus sibling forward graphs); each has a `graphId` issued by the
13
+ // main thread at compile time.
14
+
15
+ import type { Graph } from './ir.js'
16
+ import type { BufferPlan } from './buffers.js'
17
+ import type { KernelSpec } from './codegen.js'
18
+ import type { LRSchedule } from './adam.js'
19
+
20
+ // ============================================================================
21
+ // Serializable config (subset of AdamResolvedConfig that crosses the wire).
22
+ // `decayFilter` (a function, used only at compile time) is NOT part of this —
23
+ // the per-param decay decision is already baked into the IR by appendAdam
24
+ // before the IR ships to the worker.
25
+ // ============================================================================
26
+
27
+ export interface WireAdamConfig {
28
+ lr: LRSchedule
29
+ b1: number
30
+ b2: number
31
+ eps: number
32
+ weightDecay: number
33
+ lrIsScheduled: boolean
34
+ /** Names of the per-step scalar inputs the worker must populate before
35
+ * every step (`_adam_lrt`, optionally `_adam_decay_shrink`). Mirrors
36
+ * AdamResult so the worker can update them without re-deriving. */
37
+ lrtInputName: string
38
+ decayShrinkInputName: string | null
39
+ }
40
+
41
+ /** Compile output that crosses to the worker. Same fields as CompiledIR
42
+ * minus the `loss` tensor (carried by graph.outputs[0]). */
43
+ export interface WireIR {
44
+ graph: Graph
45
+ plan: BufferPlan
46
+ kernels: KernelSpec[]
47
+ }
48
+
49
+ // ============================================================================
50
+ // Requests (main → worker)
51
+ // ============================================================================
52
+
53
+ export type Req =
54
+ | { id: number; kind: 'createRuntime'; payload: CreateRuntimePayload }
55
+ | { id: number; kind: 'compileForward'; payload: CompileForwardPayload }
56
+ | { id: number; kind: 'step'; payload: StepPayload }
57
+ | { id: number; kind: 'run'; payload: RunPayload }
58
+ | { id: number; kind: 'uploadParams'; payload: UploadParamsPayload }
59
+ | { id: number; kind: 'downloadParams'; payload: { graphId: number } }
60
+ | { id: number; kind: 'downloadParamGrads'; payload: { graphId: number } }
61
+ | { id: number; kind: 'resetOptimizer'; payload: { graphId: number } }
62
+ | { id: number; kind: 'destroy'; payload: { graphId: number } }
63
+
64
+ /** Build the training runtime. Always graphId=0 for a fresh worker. */
65
+ export interface CreateRuntimePayload {
66
+ graphId: number
67
+ ir: WireIR
68
+ /** Initial param values per name. Transferred (zero-copy) — the main
69
+ * thread loses access after postMessage. */
70
+ initialParams: Record<string, Float32Array>
71
+ /** Adam config when training; absent for forward-only compiles. */
72
+ adam: WireAdamConfig | null
73
+ }
74
+
75
+ /** Build a sibling forward-only graph that shares param buffers with an
76
+ * existing graph (typically the training graph at graphId=0). */
77
+ export interface CompileForwardPayload {
78
+ graphId: number
79
+ parentGraphId: number
80
+ ir: WireIR
81
+ }
82
+
83
+ /** One training step. Inputs are transferred; the caller's typed arrays
84
+ * become detached after postMessage. */
85
+ export interface StepPayload {
86
+ graphId: number
87
+ inputs: Record<string, Int32Array | Float32Array>
88
+ withCaptures: boolean
89
+ }
90
+
91
+ /** Forward-only run. Same transfer semantics as `step`. */
92
+ export interface RunPayload {
93
+ graphId: number
94
+ inputs: Record<string, Int32Array | Float32Array>
95
+ withCaptures: boolean
96
+ }
97
+
98
+ export interface UploadParamsPayload {
99
+ graphId: number
100
+ params: Record<string, Float32Array> // transferred
101
+ partial: boolean
102
+ }
103
+
104
+ // ============================================================================
105
+ // Responses (worker → main)
106
+ // ============================================================================
107
+
108
+ export type Res<R = unknown> =
109
+ | { id: number; ok: true; result: R }
110
+ | { id: number; ok: false; error: WireError }
111
+
112
+ export interface WireError {
113
+ name: string
114
+ message: string
115
+ stack: string
116
+ }
117
+
118
+ // Per-request result shapes:
119
+
120
+ export interface CreateRuntimeResult {
121
+ paramNames: string[]
122
+ outputShape: number[]
123
+ kernelCount: number
124
+ captureShapes: Record<string, number[]>
125
+ }
126
+
127
+ export interface CompileForwardResult {
128
+ paramNames: string[]
129
+ outputShape: number[]
130
+ kernelCount: number
131
+ captureShapes: Record<string, number[]>
132
+ }
133
+
134
+ /** Step without `withCaptures` returns just `loss`. With captures, also
135
+ * populates `captures` (per-name Float32Array, all transferred back). */
136
+ export interface StepResultWire {
137
+ loss: number
138
+ captures: Record<string, Float32Array> | null
139
+ }
140
+
141
+ /** Run without `withCaptures` returns `{ output, captures: null }`.
142
+ * With captures, also populates `captures`. */
143
+ export interface RunResultWire {
144
+ output: Float32Array
145
+ captures: Record<string, Float32Array> | null
146
+ }
147
+
148
+ export interface DownloadParamsResult {
149
+ params: Record<string, Float32Array> // transferred
150
+ }
151
+
152
+ // ============================================================================
153
+ // Transfer-list helpers
154
+ // ============================================================================
155
+
156
+ /** Collect the underlying ArrayBuffers from a Record of typed arrays so we
157
+ * can pass them on `postMessage`'s transfer list. The values themselves
158
+ * stay in the Record; only their backing buffers move. */
159
+ export function transferablesOfRecord(
160
+ rec: Record<string, Int32Array | Float32Array>,
161
+ ): ArrayBuffer[] {
162
+ const out: ArrayBuffer[] = []
163
+ for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
164
+ return out
165
+ }
166
+
167
+ /** Serialize an Error to a wire-friendly shape, preserving stack + name so
168
+ * the receiving side can reconstitute an Error that an `instanceof`-aware
169
+ * caller (e.g., for `ShapeError`) can still pattern-match by name. */
170
+ export function wireError(e: unknown): WireError {
171
+ if (e instanceof Error) {
172
+ return { name: e.name, message: e.message, stack: e.stack ?? '' }
173
+ }
174
+ return { name: 'Error', message: String(e), stack: '' }
175
+ }
176
+
177
+ /** Reconstitute an Error from the wire shape on the receiving (main) side. */
178
+ export function reconstituteError(w: WireError): Error {
179
+ const err = new Error(w.message)
180
+ err.name = w.name
181
+ err.stack = w.stack
182
+ return err
183
+ }
@@ -0,0 +1,76 @@
1
+ // Main-thread half of the worker channel: request/response correlation,
2
+ // promise wiring, error reconstitution. Knows nothing about Adam, captures,
3
+ // IR, etc. — just shuttles typed messages.
4
+
5
+ import type { Req, Res, WireError } from './worker-protocol.js'
6
+ import { reconstituteError } from './worker-protocol.js'
7
+
8
+ interface PendingHandlers {
9
+ resolve: (v: unknown) => void
10
+ reject: (e: Error) => void
11
+ }
12
+
13
+ /** Spawn a worker from an inlined source string and provide a typed
14
+ * request/response channel. One WorkerProxy = one Worker = one GPUDevice
15
+ * on the worker side. Sibling graphs share the same WorkerProxy. */
16
+ export class WorkerProxy {
17
+ private worker: Worker
18
+ private nextId = 1
19
+ private pending = new Map<number, PendingHandlers>()
20
+ private terminated = false
21
+
22
+ constructor(workerSource: string) {
23
+ const blob = new Blob([workerSource], { type: 'application/javascript' })
24
+ const url = URL.createObjectURL(blob)
25
+ this.worker = new Worker(url, { type: 'module' })
26
+ // The Blob URL keeps memory alive as long as it's referenced; revoke
27
+ // once the worker has loaded its source. Browsers tolerate revoke
28
+ // immediately after construction in practice.
29
+ URL.revokeObjectURL(url)
30
+
31
+ this.worker.onmessage = (ev: MessageEvent<Res>) => {
32
+ const reply = ev.data
33
+ const handlers = this.pending.get(reply.id)
34
+ if (!handlers) return // stale reply; ignore
35
+ this.pending.delete(reply.id)
36
+ if (reply.ok) handlers.resolve(reply.result)
37
+ else handlers.reject(reconstituteError(reply.error))
38
+ }
39
+
40
+ this.worker.onerror = (ev: ErrorEvent) => {
41
+ const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
42
+ const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
43
+ // Reject everything in flight; subsequent calls will fail too.
44
+ for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire))
45
+ this.pending.clear()
46
+ }
47
+ }
48
+
49
+ /** Send a request and await its matching response. `transfer` lists the
50
+ * ArrayBuffers to move (zero-copy) into the worker. */
51
+ request<R>(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): Promise<R> {
52
+ if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
53
+ const id = this.nextId++
54
+ return new Promise<R>((resolve, reject) => {
55
+ this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
56
+ this.worker.postMessage({ ...req, id } as Req, transfer)
57
+ })
58
+ }
59
+
60
+ /** Fire-and-forget variant for cases where the caller doesn't need a reply
61
+ * (currently unused; keep for symmetry / future use). */
62
+ send(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): void {
63
+ if (this.terminated) return
64
+ const id = this.nextId++
65
+ this.worker.postMessage({ ...req, id } as Req, transfer)
66
+ }
67
+
68
+ terminate(): void {
69
+ if (this.terminated) return
70
+ this.terminated = true
71
+ this.worker.terminate()
72
+ const err = new Error('tensorgrad: worker terminated')
73
+ for (const handlers of this.pending.values()) handlers.reject(err)
74
+ this.pending.clear()
75
+ }
76
+ }