tensorgrad 0.0.15 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/module.ts CHANGED
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
32
32
  // Init metadata
33
33
  // ============================================================================
34
34
 
35
- /** How a parameter's initial values are produced.
36
- * - `'randn'` Gaussian, with `scale` (default 0.02). The common case for
37
- * weight matrices and embeddings.
38
- * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
39
- * - `'ones'` — fill with 1. Common for LayerNorm gain.
40
- * - Custom function receives total element count and shape, returns the
41
- * Float32Array. Use for fan-in scaling or any non-standard scheme.
35
+ /** How a parameter's initial values are produced. Serializable shape — no
36
+ * closures, since the initial values cross the worker boundary at compile
37
+ * time. Use the `init` helpers for ergonomic construction.
38
+ *
39
+ * String shorthands:
40
+ * - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
41
+ * - `'zeros'` — fill with 0 (biases, LayerNorm beta).
42
+ * - `'ones'` — fill with 1 (LayerNorm gain).
43
+ *
44
+ * Object shapes:
45
+ * - `{ kind: 'randn', scale }` — randn with explicit std.
46
+ * - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
47
+ * gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
48
+ * - `{ kind: 'literal', data }` — explicit Float32Array; length must
49
+ * match the parameter's element count.
42
50
  */
43
51
  export type InitSpec =
44
52
  | 'randn'
45
53
  | 'zeros'
46
54
  | 'ones'
47
- | ((size: number, shape: readonly number[]) => Float32Array)
55
+ | { readonly kind: 'randn'; readonly scale: number }
56
+ | { readonly kind: 'kaiming'; readonly gain?: number }
57
+ | { readonly kind: 'literal'; readonly data: Float32Array }
58
+
59
+ /** Ergonomic constructors for InitSpec object shapes. */
60
+ export const init = {
61
+ randn: (opts: { scale?: number } = {}): InitSpec => ({ kind: 'randn', scale: opts.scale ?? 0.02 }),
62
+ kaiming: (opts: { gain?: number } = {}): InitSpec =>
63
+ opts.gain !== undefined ? { kind: 'kaiming', gain: opts.gain } : { kind: 'kaiming' },
64
+ literal: (data: Float32Array): InitSpec => ({ kind: 'literal', data }),
65
+ }
48
66
 
49
67
  export interface ParamOptions {
50
68
  dtype?: Dtype
51
- /** Init kind. Default: `'randn'`. */
69
+ /** Init shape. Default: `'randn'` (std 0.02). */
52
70
  init?: InitSpec
53
- /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
54
- scale?: number
55
71
  /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
56
- * decay to this param. Default: `true` for `'randn'` init (weight matrices,
57
- * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
58
- * to force or skip. Replaces `adam.decayFilter` for the common case. */
72
+ * decay to this param. Default: `true` for randn/kaiming/literal init
73
+ * (weight matrices, embeddings); `false` for zeros/ones (biases, LN
74
+ * gains). Override to force or skip. Replaces `adam.decayFilter` for
75
+ * the common case. */
59
76
  decay?: boolean
60
77
  }
61
78
 
@@ -65,31 +82,52 @@ function boxMuller(): number {
65
82
  return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
66
83
  }
67
84
 
68
- function resolveInit(opts: ParamOptions | undefined): InitFn {
69
- const init = opts?.init ?? 'randn'
70
- if (init === 'randn') {
71
- const scale = opts?.scale ?? 0.02
72
- return (size) => {
73
- const arr = new Float32Array(size)
74
- for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
75
- return arr
85
+ function randnFn(scale: number): InitFn {
86
+ return (size) => {
87
+ const arr = new Float32Array(size)
88
+ for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
89
+ return arr
90
+ }
91
+ }
92
+
93
+ /** Compile-time-only: resolve an InitSpec shape into the closure that
94
+ * generates the initial Float32Array for a given parameter shape. Runs
95
+ * on the main thread before initial values are transferred to the worker. */
96
+ function resolveInit(spec: InitSpec | undefined): InitFn {
97
+ if (!spec || spec === 'randn') return randnFn(0.02)
98
+ if (spec === 'zeros') return (size) => new Float32Array(size)
99
+ if (spec === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
100
+ switch (spec.kind) {
101
+ case 'randn': return randnFn(spec.scale)
102
+ case 'kaiming': {
103
+ const gain = spec.gain ?? Math.sqrt(2)
104
+ return (size, shape) => {
105
+ const fanIn = shape[0] ?? size
106
+ const std = gain / Math.sqrt(fanIn)
107
+ const arr = new Float32Array(size)
108
+ for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
109
+ return arr
110
+ }
111
+ }
112
+ case 'literal': {
113
+ const data = spec.data
114
+ return (size) => {
115
+ if (data.length !== size) {
116
+ throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
117
+ }
118
+ return new Float32Array(data)
119
+ }
76
120
  }
77
121
  }
78
- if (init === 'zeros') return (size) => new Float32Array(size)
79
- if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
80
- if (typeof init === 'function') return init
81
- throw new Error(`Unknown init: ${String(init)}`)
82
122
  }
83
123
 
84
- /** Resolve the decay default for a param. Decay weight matrices and
85
- * embedding tables (randn-initialized); skip biases (zeros) and LN gains
86
- * (ones). Custom init functions default to "decay" most user-supplied
87
- * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
124
+ /** Resolve the decay default for a param. Weight-shaped inits (randn,
125
+ * kaiming, literal) default to decay=true; ones/zeros default to false
126
+ * (biases, LN gains). Explicit `decay` opt overrides. */
88
127
  function resolveDecay(opts: ParamOptions | undefined): boolean {
89
128
  if (opts?.decay !== undefined) return opts.decay
90
- const init = opts?.init ?? 'randn'
91
- if (init === 'zeros' || init === 'ones') return false
92
- return true // 'randn' or function
129
+ const spec = opts?.init ?? 'randn'
130
+ return spec !== 'zeros' && spec !== 'ones'
93
131
  }
94
132
 
95
133
  // ============================================================================
@@ -127,7 +165,7 @@ export abstract class Module {
127
165
  protected param(shape: Shape, opts?: ParamOptions): Tensor {
128
166
  const dtype = opts?.dtype ?? 'f32'
129
167
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
130
- return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
168
+ return new ParamSentinel(shape, dtype, resolveInit(opts?.init), resolveDecay(opts)) as unknown as Tensor
131
169
  }
132
170
  }
133
171
 
@@ -0,0 +1,183 @@
1
+ // Wire format for the main-thread ↔ worker postMessage channel.
2
+ //
3
+ // All requests carry a numeric `id` assigned by the main thread; responses
4
+ // echo it back so the proxy can match concurrent in-flight calls. Every
5
+ // response is either `{ ok: true, result }` or `{ ok: false, error }`.
6
+ // Errors carry serialized name/message/stack so the proxy can reconstitute
7
+ // an Error with a working `instanceof` check on the receiving side.
8
+ //
9
+ // Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
10
+ // rather than copied — see the per-request notes for which fields go on the
11
+ // transfer list. A single worker may host multiple compiled graphs (a train
12
+ // graph plus sibling forward graphs); each has a `graphId` issued by the
13
+ // main thread at compile time.
14
+
15
+ import type { Graph } from './ir.js'
16
+ import type { BufferPlan } from './buffers.js'
17
+ import type { KernelSpec } from './codegen.js'
18
+ import type { LRSchedule } from './adam.js'
19
+
20
+ // ============================================================================
21
+ // Serializable config (subset of AdamResolvedConfig that crosses the wire).
22
+ // `decayFilter` (a function, used only at compile time) is NOT part of this —
23
+ // the per-param decay decision is already baked into the IR by appendAdam
24
+ // before the IR ships to the worker.
25
+ // ============================================================================
26
+
27
+ export interface WireAdamConfig {
28
+ lr: LRSchedule
29
+ b1: number
30
+ b2: number
31
+ eps: number
32
+ weightDecay: number
33
+ lrIsScheduled: boolean
34
+ /** Names of the per-step scalar inputs the worker must populate before
35
+ * every step (`_adam_lrt`, optionally `_adam_decay_shrink`). Mirrors
36
+ * AdamResult so the worker can update them without re-deriving. */
37
+ lrtInputName: string
38
+ decayShrinkInputName: string | null
39
+ }
40
+
41
+ /** Compile output that crosses to the worker. Same fields as CompiledIR
42
+ * minus the `loss` tensor (carried by graph.outputs[0]). */
43
+ export interface WireIR {
44
+ graph: Graph
45
+ plan: BufferPlan
46
+ kernels: KernelSpec[]
47
+ }
48
+
49
+ // ============================================================================
50
+ // Requests (main → worker)
51
+ // ============================================================================
52
+
53
+ export type Req =
54
+ | { id: number; kind: 'createRuntime'; payload: CreateRuntimePayload }
55
+ | { id: number; kind: 'compileForward'; payload: CompileForwardPayload }
56
+ | { id: number; kind: 'step'; payload: StepPayload }
57
+ | { id: number; kind: 'run'; payload: RunPayload }
58
+ | { id: number; kind: 'uploadParams'; payload: UploadParamsPayload }
59
+ | { id: number; kind: 'downloadParams'; payload: { graphId: number } }
60
+ | { id: number; kind: 'downloadParamGrads'; payload: { graphId: number } }
61
+ | { id: number; kind: 'resetOptimizer'; payload: { graphId: number } }
62
+ | { id: number; kind: 'destroy'; payload: { graphId: number } }
63
+
64
+ /** Build the training runtime. Always graphId=0 for a fresh worker. */
65
+ export interface CreateRuntimePayload {
66
+ graphId: number
67
+ ir: WireIR
68
+ /** Initial param values per name. Transferred (zero-copy) — the main
69
+ * thread loses access after postMessage. */
70
+ initialParams: Record<string, Float32Array>
71
+ /** Adam config when training; absent for forward-only compiles. */
72
+ adam: WireAdamConfig | null
73
+ }
74
+
75
+ /** Build a sibling forward-only graph that shares param buffers with an
76
+ * existing graph (typically the training graph at graphId=0). */
77
+ export interface CompileForwardPayload {
78
+ graphId: number
79
+ parentGraphId: number
80
+ ir: WireIR
81
+ }
82
+
83
+ /** One training step. Inputs are transferred; the caller's typed arrays
84
+ * become detached after postMessage. */
85
+ export interface StepPayload {
86
+ graphId: number
87
+ inputs: Record<string, Int32Array | Float32Array>
88
+ withCaptures: boolean
89
+ }
90
+
91
+ /** Forward-only run. Same transfer semantics as `step`. */
92
+ export interface RunPayload {
93
+ graphId: number
94
+ inputs: Record<string, Int32Array | Float32Array>
95
+ withCaptures: boolean
96
+ }
97
+
98
+ export interface UploadParamsPayload {
99
+ graphId: number
100
+ params: Record<string, Float32Array> // transferred
101
+ partial: boolean
102
+ }
103
+
104
+ // ============================================================================
105
+ // Responses (worker → main)
106
+ // ============================================================================
107
+
108
+ export type Res<R = unknown> =
109
+ | { id: number; ok: true; result: R }
110
+ | { id: number; ok: false; error: WireError }
111
+
112
+ export interface WireError {
113
+ name: string
114
+ message: string
115
+ stack: string
116
+ }
117
+
118
+ // Per-request result shapes:
119
+
120
+ export interface CreateRuntimeResult {
121
+ paramNames: string[]
122
+ outputShape: number[]
123
+ kernelCount: number
124
+ captureShapes: Record<string, number[]>
125
+ }
126
+
127
+ export interface CompileForwardResult {
128
+ paramNames: string[]
129
+ outputShape: number[]
130
+ kernelCount: number
131
+ captureShapes: Record<string, number[]>
132
+ }
133
+
134
+ /** Step without `withCaptures` returns just `loss`. With captures, also
135
+ * populates `captures` (per-name Float32Array, all transferred back). */
136
+ export interface StepResultWire {
137
+ loss: number
138
+ captures: Record<string, Float32Array> | null
139
+ }
140
+
141
+ /** Run without `withCaptures` returns `{ output, captures: null }`.
142
+ * With captures, also populates `captures`. */
143
+ export interface RunResultWire {
144
+ output: Float32Array
145
+ captures: Record<string, Float32Array> | null
146
+ }
147
+
148
+ export interface DownloadParamsResult {
149
+ params: Record<string, Float32Array> // transferred
150
+ }
151
+
152
+ // ============================================================================
153
+ // Transfer-list helpers
154
+ // ============================================================================
155
+
156
+ /** Collect the underlying ArrayBuffers from a Record of typed arrays so we
157
+ * can pass them on `postMessage`'s transfer list. The values themselves
158
+ * stay in the Record; only their backing buffers move. */
159
+ export function transferablesOfRecord(
160
+ rec: Record<string, Int32Array | Float32Array>,
161
+ ): ArrayBuffer[] {
162
+ const out: ArrayBuffer[] = []
163
+ for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
164
+ return out
165
+ }
166
+
167
+ /** Serialize an Error to a wire-friendly shape, preserving stack + name so
168
+ * the receiving side can reconstitute an Error that an `instanceof`-aware
169
+ * caller (e.g., for `ShapeError`) can still pattern-match by name. */
170
+ export function wireError(e: unknown): WireError {
171
+ if (e instanceof Error) {
172
+ return { name: e.name, message: e.message, stack: e.stack ?? '' }
173
+ }
174
+ return { name: 'Error', message: String(e), stack: '' }
175
+ }
176
+
177
+ /** Reconstitute an Error from the wire shape on the receiving (main) side. */
178
+ export function reconstituteError(w: WireError): Error {
179
+ const err = new Error(w.message)
180
+ err.name = w.name
181
+ err.stack = w.stack
182
+ return err
183
+ }
@@ -0,0 +1,76 @@
1
+ // Main-thread half of the worker channel: request/response correlation,
2
+ // promise wiring, error reconstitution. Knows nothing about Adam, captures,
3
+ // IR, etc. — just shuttles typed messages.
4
+
5
+ import type { Req, Res, WireError } from './worker-protocol.js'
6
+ import { reconstituteError } from './worker-protocol.js'
7
+
8
+ interface PendingHandlers {
9
+ resolve: (v: unknown) => void
10
+ reject: (e: Error) => void
11
+ }
12
+
13
+ /** Spawn a worker from an inlined source string and provide a typed
14
+ * request/response channel. One WorkerProxy = one Worker = one GPUDevice
15
+ * on the worker side. Sibling graphs share the same WorkerProxy. */
16
+ export class WorkerProxy {
17
+ private worker: Worker
18
+ private nextId = 1
19
+ private pending = new Map<number, PendingHandlers>()
20
+ private terminated = false
21
+
22
+ constructor(workerSource: string) {
23
+ const blob = new Blob([workerSource], { type: 'application/javascript' })
24
+ const url = URL.createObjectURL(blob)
25
+ this.worker = new Worker(url, { type: 'module' })
26
+ // The Blob URL keeps memory alive as long as it's referenced; revoke
27
+ // once the worker has loaded its source. Browsers tolerate revoke
28
+ // immediately after construction in practice.
29
+ URL.revokeObjectURL(url)
30
+
31
+ this.worker.onmessage = (ev: MessageEvent<Res>) => {
32
+ const reply = ev.data
33
+ const handlers = this.pending.get(reply.id)
34
+ if (!handlers) return // stale reply; ignore
35
+ this.pending.delete(reply.id)
36
+ if (reply.ok) handlers.resolve(reply.result)
37
+ else handlers.reject(reconstituteError(reply.error))
38
+ }
39
+
40
+ this.worker.onerror = (ev: ErrorEvent) => {
41
+ const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
42
+ const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
43
+ // Reject everything in flight; subsequent calls will fail too.
44
+ for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire))
45
+ this.pending.clear()
46
+ }
47
+ }
48
+
49
+ /** Send a request and await its matching response. `transfer` lists the
50
+ * ArrayBuffers to move (zero-copy) into the worker. */
51
+ request<R>(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): Promise<R> {
52
+ if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
53
+ const id = this.nextId++
54
+ return new Promise<R>((resolve, reject) => {
55
+ this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
56
+ this.worker.postMessage({ ...req, id } as Req, transfer)
57
+ })
58
+ }
59
+
60
+ /** Fire-and-forget variant for cases where the caller doesn't need a reply
61
+ * (currently unused; keep for symmetry / future use). */
62
+ send(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): void {
63
+ if (this.terminated) return
64
+ const id = this.nextId++
65
+ this.worker.postMessage({ ...req, id } as Req, transfer)
66
+ }
67
+
68
+ terminate(): void {
69
+ if (this.terminated) return
70
+ this.terminated = true
71
+ this.worker.terminate()
72
+ const err = new Error('tensorgrad: worker terminated')
73
+ for (const handlers of this.pending.values()) handlers.reject(err)
74
+ this.pending.clear()
75
+ }
76
+ }
package/src/worker.ts ADDED
@@ -0,0 +1,281 @@
1
+ // Worker entry point. Holds the GPUDevice + CompiledRuntime for one or more
2
+ // graphs and proxies main-thread requests via postMessage. See
3
+ // specs/WorkerArchitecture.md for the rationale.
4
+ //
5
+ // Keep this file dependency-free of anything DOM-y: it bundles into a Blob
6
+ // URL and runs in a Web Worker context where `window`/`document` don't
7
+ // exist. WebGPU IS available in workers (Chrome 113+, Safari 17.4+).
8
+
9
+ import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
10
+ import { resolveLR, type LRSchedule } from './adam.js'
11
+ import type { Req, Res, WireIR, WireAdamConfig, WireError } from './worker-protocol.js'
12
+ import { wireError } from './worker-protocol.js'
13
+
14
+ // ----------------------------------------------------------------------------
15
+ // Per-graph state
16
+ // ----------------------------------------------------------------------------
17
+
18
+ interface GraphSlot {
19
+ runtime: CompiledRuntime
20
+ paramNames: readonly string[]
21
+ outputShape: number[]
22
+ kernelCount: number
23
+ captureShapes: Record<string, number[]>
24
+ /** Adam state for this graph, if it's a training graph. The wrapped step
25
+ * uses these to populate the per-step lrt and decayShrink scalars. */
26
+ adam: AdamState | null
27
+ }
28
+
29
+ interface AdamState {
30
+ config: WireAdamConfig
31
+ t: number
32
+ lrtBuf: Float32Array
33
+ decayShrinkBuf: Float32Array | null
34
+ }
35
+
36
+ const graphs = new Map<number, GraphSlot>()
37
+
38
+ // Worker holds one device shared across all graphs (sibling forward graphs
39
+ // must share param GPUBuffers, which means sharing a device).
40
+ let device: GPUDevice | null = null
41
+
42
+ async function ensureDevice(): Promise<GPUDevice> {
43
+ if (device) return device
44
+ if (typeof navigator === 'undefined' || !navigator.gpu) {
45
+ throw new Error('tensorgrad worker: WebGPU not available in this environment')
46
+ }
47
+ const adapter = await navigator.gpu.requestAdapter()
48
+ if (!adapter) throw new Error('tensorgrad worker: no WebGPU adapter')
49
+ device = await adapter.requestDevice()
50
+ return device
51
+ }
52
+
53
+ // ----------------------------------------------------------------------------
54
+ // Request handlers
55
+ // ----------------------------------------------------------------------------
56
+
57
+ async function handleCreateRuntime(payload: {
58
+ graphId: number
59
+ ir: WireIR
60
+ initialParams: Record<string, Float32Array>
61
+ adam: WireAdamConfig | null
62
+ }): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
63
+ const dev = await ensureDevice()
64
+ const { graph, plan, kernels } = payload.ir
65
+ const outputTensorId = graph.outputs[0]!
66
+ const outputBufferId = plan.tensorToBuffer.get(outputTensorId)!
67
+ const opts: RuntimeOpts = { device: dev }
68
+ const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
69
+
70
+ // Upload initial params.
71
+ if (Object.keys(payload.initialParams).length > 0) {
72
+ runtime.uploadParams(payload.initialParams)
73
+ }
74
+
75
+ // Capture shape metadata for return.
76
+ const captureShapes: Record<string, number[]> = {}
77
+ for (const [name, bufId] of plan.capturesByName) {
78
+ captureShapes[name] = [...plan.buffers[bufId]!.shape]
79
+ }
80
+
81
+ const slot: GraphSlot = {
82
+ runtime,
83
+ paramNames: [...plan.paramsByName.keys()],
84
+ outputShape: [...runtime.outputShape],
85
+ kernelCount: kernels.filter(k => k.wgsl).length,
86
+ captureShapes,
87
+ adam: payload.adam ? createAdamState(payload.adam) : null,
88
+ }
89
+ graphs.set(payload.graphId, slot)
90
+
91
+ return {
92
+ paramNames: [...slot.paramNames],
93
+ outputShape: slot.outputShape,
94
+ kernelCount: slot.kernelCount,
95
+ captureShapes: slot.captureShapes,
96
+ }
97
+ }
98
+
99
+ async function handleCompileForward(payload: {
100
+ graphId: number
101
+ parentGraphId: number
102
+ ir: WireIR
103
+ }): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
104
+ const dev = await ensureDevice()
105
+ const parent = graphs.get(payload.parentGraphId)
106
+ if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`)
107
+
108
+ const { graph, plan, kernels } = payload.ir
109
+ const outputTensorId = graph.outputs[0]!
110
+ const outputBufferId = plan.tensorToBuffer.get(outputTensorId)!
111
+ const opts: RuntimeOpts = { device: dev, sharedParams: parent.runtime.params }
112
+ const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
113
+ // No initial-param upload — sharedParams covers everything.
114
+
115
+ const captureShapes: Record<string, number[]> = {}
116
+ for (const [name, bufId] of plan.capturesByName) {
117
+ captureShapes[name] = [...plan.buffers[bufId]!.shape]
118
+ }
119
+
120
+ const slot: GraphSlot = {
121
+ runtime,
122
+ paramNames: [...plan.paramsByName.keys()],
123
+ outputShape: [...runtime.outputShape],
124
+ kernelCount: kernels.filter(k => k.wgsl).length,
125
+ captureShapes,
126
+ adam: null,
127
+ }
128
+ graphs.set(payload.graphId, slot)
129
+
130
+ return {
131
+ paramNames: [...slot.paramNames],
132
+ outputShape: slot.outputShape,
133
+ kernelCount: slot.kernelCount,
134
+ captureShapes: slot.captureShapes,
135
+ }
136
+ }
137
+
138
+ function createAdamState(cfg: WireAdamConfig): AdamState {
139
+ return {
140
+ config: cfg,
141
+ t: 0,
142
+ lrtBuf: new Float32Array(1),
143
+ decayShrinkBuf: cfg.decayShrinkInputName ? new Float32Array(1) : null,
144
+ }
145
+ }
146
+
147
+ /** Inject Adam's per-step lrt + decayShrink scalars into the inputs map.
148
+ * Called before every step on a training graph. The buffers are reused
149
+ * across steps to avoid allocation. */
150
+ function injectAdamScalars(slot: GraphSlot, inputs: Record<string, Int32Array | Float32Array>): Record<string, Int32Array | Float32Array> {
151
+ const a = slot.adam
152
+ if (!a) return inputs
153
+ a.t++
154
+ const lrNow = resolveLR(a.config.lr as LRSchedule, a.t)
155
+ a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t))
156
+ const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [a.config.lrtInputName]: a.lrtBuf }
157
+ if (a.decayShrinkBuf && a.config.decayShrinkInputName) {
158
+ a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay
159
+ merged[a.config.decayShrinkInputName] = a.decayShrinkBuf
160
+ }
161
+ return merged
162
+ }
163
+
164
+ async function handleStep(payload: {
165
+ graphId: number
166
+ inputs: Record<string, Int32Array | Float32Array>
167
+ withCaptures: boolean
168
+ }): Promise<{ loss: number; captures: Record<string, Float32Array> | null }> {
169
+ const slot = mustGet(payload.graphId)
170
+ const merged = injectAdamScalars(slot, payload.inputs)
171
+ if (payload.withCaptures) {
172
+ const r = await slot.runtime.step(merged, { withCaptures: true })
173
+ return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) }
174
+ }
175
+ const loss = await slot.runtime.step(merged)
176
+ return { loss, captures: null }
177
+ }
178
+
179
+ async function handleRun(payload: {
180
+ graphId: number
181
+ inputs: Record<string, Int32Array | Float32Array>
182
+ withCaptures: boolean
183
+ }): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
184
+ const slot = mustGet(payload.graphId)
185
+ if (payload.withCaptures) {
186
+ const r = await slot.runtime.run(payload.inputs, { withCaptures: true })
187
+ return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) }
188
+ }
189
+ const output = await slot.runtime.run(payload.inputs)
190
+ return { output, captures: null }
191
+ }
192
+
193
+ /** Captures (a class instance with a private Map) → a plain Record so the
194
+ * worker can transfer Float32Arrays back without serializing the class. */
195
+ function capturesToRecord(
196
+ captures: { get(name: string): Float32Array; has(name: string): boolean; names(): string[] },
197
+ // Iterate the known capture names from `shapes` rather than calling
198
+ // captures.names(), in case the latter is ever filtered (it isn't today).
199
+ shapes: Record<string, number[]>,
200
+ ): Record<string, Float32Array> {
201
+ const out: Record<string, Float32Array> = {}
202
+ for (const name of Object.keys(shapes)) {
203
+ if (captures.has(name)) out[name] = captures.get(name)
204
+ }
205
+ return out
206
+ }
207
+
208
+ function handleUploadParams(payload: {
209
+ graphId: number
210
+ params: Record<string, Float32Array>
211
+ partial: boolean
212
+ }): void {
213
+ const slot = mustGet(payload.graphId)
214
+ slot.runtime.uploadParams(payload.params, { partial: payload.partial })
215
+ }
216
+
217
+ async function handleDownloadParams(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
218
+ const slot = mustGet(payload.graphId)
219
+ return { params: await slot.runtime.downloadParams() }
220
+ }
221
+
222
+ async function handleDownloadParamGrads(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
223
+ const slot = mustGet(payload.graphId)
224
+ return { params: await slot.runtime.downloadParamGrads() }
225
+ }
226
+
227
+ function handleResetOptimizer(payload: { graphId: number }): void {
228
+ const slot = mustGet(payload.graphId)
229
+ slot.runtime.resetOptimizerState()
230
+ if (slot.adam) slot.adam.t = 0
231
+ }
232
+
233
+ function handleDestroy(payload: { graphId: number }): void {
234
+ const slot = graphs.get(payload.graphId)
235
+ if (!slot) return
236
+ slot.runtime.destroy()
237
+ graphs.delete(payload.graphId)
238
+ }
239
+
240
+ function mustGet(graphId: number): GraphSlot {
241
+ const slot = graphs.get(graphId)
242
+ if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`)
243
+ return slot
244
+ }
245
+
246
+ // ----------------------------------------------------------------------------
247
+ // Message dispatch
248
+ // ----------------------------------------------------------------------------
249
+
250
+ self.onmessage = async (ev: MessageEvent<Req>) => {
251
+ const req = ev.data
252
+ try {
253
+ let result: unknown
254
+ let transferList: ArrayBuffer[] = []
255
+ switch (req.kind) {
256
+ case 'createRuntime': result = await handleCreateRuntime(req.payload); break
257
+ case 'compileForward': result = await handleCompileForward(req.payload); break
258
+ case 'step': result = await handleStep(req.payload); transferList = collectTransfers((result as any).captures); break
259
+ case 'run': { const r = await handleRun(req.payload); result = r; transferList = [r.output.buffer as ArrayBuffer, ...collectTransfers(r.captures)]; break }
260
+ case 'uploadParams': handleUploadParams(req.payload); result = null; break
261
+ case 'downloadParams': { const r = await handleDownloadParams(req.payload); result = r; transferList = collectTransfers(r.params); break }
262
+ case 'downloadParamGrads':{ const r = await handleDownloadParamGrads(req.payload); result = r; transferList = collectTransfers(r.params); break }
263
+ case 'resetOptimizer': handleResetOptimizer(req.payload); result = null; break
264
+ case 'destroy': handleDestroy(req.payload); result = null; break
265
+ default: throw new Error(`unknown request kind: ${(req as { kind: string }).kind}`)
266
+ }
267
+ const reply: Res = { id: req.id, ok: true, result }
268
+ self.postMessage(reply, { transfer: transferList })
269
+ } catch (e) {
270
+ const error: WireError = wireError(e)
271
+ const reply: Res = { id: req.id, ok: false, error }
272
+ self.postMessage(reply)
273
+ }
274
+ }
275
+
276
+ function collectTransfers(rec: Record<string, Float32Array> | null | undefined): ArrayBuffer[] {
277
+ if (!rec) return []
278
+ const out: ArrayBuffer[] = []
279
+ for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
280
+ return out
281
+ }