tensorgrad 0.0.16 → 0.0.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,119 +1,253 @@
1
- # tensorgrad
2
-
3
- A tiny TypeScript-native tensor library with autograd that compiles directly
4
- to WebGPU. Designed for training small models in the browser without
5
- hand-writing WGSL kernels and without dragging in a 5 MB ML framework.
6
-
7
- ```sh
8
- npm i tensorgrad
9
- ```
10
-
11
- Roughly **3000 lines of zero-dependency TypeScript**, ~10 KB gzipped after
12
- build. Targets WebGPU only. Static shapes only. Forward + reverse-mode
13
- autograd; Adam optimizer; the whole training pipeline runs as compiled WGSL.
14
-
15
- ## Quick example
16
-
17
- A 2-layer MLP fitting `y = sin(x)`:
18
-
19
- ```ts
20
- import {
21
- Module, compileModule,
22
- add, mul, sub, sumLast, reshape, matmul, relu,
23
- type Tensor,
24
- } from 'tensorgrad'
25
-
26
- class Linear extends Module {
27
- W: Tensor; b: Tensor
28
- constructor(public inDim: number, public outDim: number) {
29
- super()
30
- this.W = this.param([inDim, outDim]) // randn, scale 0.02
31
- this.b = this.param([outDim], { init: 'zeros' })
32
- }
33
- }
34
-
35
- class MLP extends Module {
36
- l1 = new Linear(1, 64)
37
- l2 = new Linear(64, 64)
38
- l3 = new Linear(64, 1)
39
- }
40
-
41
- const linear = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
42
-
43
- function forward(m: MLP, x: Tensor): Tensor {
44
- return linear(m.l3, relu(linear(m.l2, relu(linear(m.l1, x)))))
45
- }
46
-
47
- function loss(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
48
- const diff = sub(forward(m, x), y)
49
- return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
50
- }
51
-
52
- const B = 256
53
- const compiled = await compileModule(() => new MLP(), loss, {
54
- adam: { lr: 0.005 },
55
- inputs: {
56
- x: { shape: [B, 1], dtype: 'f32' },
57
- y: { shape: [B, 1], dtype: 'f32' },
58
- },
59
- })
60
-
61
- // Initial params are uploaded automatically — no manual step needed.
62
-
63
- for (let step = 0; step < 1000; step++) {
64
- const { x, y } = generateBatch()
65
- const lossVal = await compiled.step({ x, y })
66
- if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
67
- }
68
- ```
69
-
70
- That's the whole user-facing surface for this model: `Module` for parameter
71
- storage, plain functions for the forward pass, `compileModule` to JIT-compile
72
- to WGSL with autograd + Adam wired in. No decorators, no `tf.GradientTape`,
73
- no `register_pytree_node`.
74
-
75
- For a more involved example a 3-layer transformer trained from scratch on
76
- 2-digit addition — see the [`samples/`](./samples) workspace
77
- (`pnpm --filter samples dev`).
78
-
79
- ## What this library is for
80
-
81
- Small browser-side ML where you want to *train* the model, not just run
82
- inference of a pretrained model. Educational artifacts, interactive
83
- demos, on-device personalization, "transformer from scratch in your browser"
84
- blog posts. Roughly the niche where the model is small enough to fit
85
- comfortably in a browser tab but where you still want autograd and a real
86
- optimizer.
87
-
88
- If you want to ship inference of a pretrained model, use
89
- [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) or
90
- [transformers.js](https://github.com/xenova/transformers.js).
91
- If you need full JAX (vmap / pmap / dynamic shapes / multi-backend), use
92
- [jax-js](https://github.com/jax-js/jax).
93
-
94
- ## Scope (deliberately small)
95
-
96
- The library only does what it does because of what it doesn't do. The
97
- load-bearing "out of scope" decisions are:
98
-
99
- - **WebGPU only** no Wasm or WebGL fallback.
100
- - **Static shapes only** — every shape is fixed at compile time. This is
101
- what lets us bake constants into the WGSL instead of carrying shape
102
- uniforms.
103
- - **`grad` is the only transformation** — no `vmap`, `pmap`, `jvp`,
104
- `custom_vjp`. Batch your data explicitly.
105
- - **`f32` only** no dtype promotion, no mixed precision.
106
- - **Closed op set** — about 25 ops, listed in `SPEC.md`. Compositions of
107
- those handle most needs (GELU, RMS norm, etc. are a few lines on top).
108
- - **Adam lives in the IR** — bias correction included; no CPU↔GPU
109
- round-trip per step.
110
-
111
- ## Status
112
-
113
- Alpha. Two real working models (a transformer training to <0.1 loss on
114
- addition, an MLP fitting `sin`). API may change before 1.0. Filing issues
115
- welcome.
116
-
117
- ## License
118
-
119
- MIT
1
+ # tensorgrad
2
+
3
+ A tiny TypeScript-native tensor library with autograd that compiles to WebGPU.
4
+ For training small models in the browser without hand-writing WGSL kernels and
5
+ without dragging in a multi-megabyte ML framework.
6
+
7
+ ```sh
8
+ npm i tensorgrad
9
+ ```
10
+
11
+ Roughly 3000 lines of zero-dependency TypeScript. Static shapes, `f32`, Adam
12
+ optimizer, ~25 ops, forward + reverse-mode autograd. Browser-only (uses
13
+ WebGPU). All training/inference work runs in a library-internal Web Worker, so
14
+ every method on a compiled module returns a `Promise`.
15
+
16
+ ## Minimal example
17
+
18
+ A 2-layer MLP fitting `y = sin(x)`:
19
+
20
+ ```ts
21
+ import {
22
+ Module, compileModule, init,
23
+ add, mul, sub, sumLast, reshape, matmul, relu,
24
+ type Tensor,
25
+ } from 'tensorgrad'
26
+
27
+ const B = 256
28
+
29
+ class Linear extends Module {
30
+ W: Tensor; b: Tensor
31
+ constructor(public inDim: number, public outDim: number) {
32
+ super()
33
+ this.W = this.param([inDim, outDim], { init: init.kaiming() })
34
+ this.b = this.param([outDim], { init: 'zeros' })
35
+ }
36
+ }
37
+
38
+ class MLP extends Module {
39
+ l1 = new Linear(1, 64)
40
+ l2 = new Linear(64, 64)
41
+ l3 = new Linear(64, 1)
42
+ }
43
+
44
+ const linearFwd = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
45
+
46
+ function modelFwd(m: MLP, x: Tensor): Tensor {
47
+ return linearFwd(m.l3, relu(linearFwd(m.l2, relu(linearFwd(m.l1, x)))))
48
+ }
49
+
50
+ function lossFn(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
51
+ const diff = sub(modelFwd(m, x), y)
52
+ return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
53
+ }
54
+
55
+ const compiled = await compileModule(() => new MLP(), lossFn, {
56
+ adam: { lr: 0.005 },
57
+ inputs: {
58
+ x: { shape: [B, 1], dtype: 'f32' },
59
+ y: { shape: [B, 1], dtype: 'f32' },
60
+ },
61
+ })
62
+
63
+ for (let step = 0; step < 1000; step++) {
64
+ const { x, y } = generateBatch()
65
+ const lossVal = await compiled.step({ x, y })
66
+ if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
67
+ }
68
+ ```
69
+
70
+ ## Mental model
71
+
72
+ - A `Module` subclass declares parameters via `this.param([shape], opts)` and
73
+ composes child modules as plain fields. The class is a tree of params.
74
+ - A *forward function* takes the materialized module + a record of named
75
+ input tensors and returns a tensor (the loss for `compileModule`, or any
76
+ output for `compileForward`).
77
+ - `compileModule(factory, forward, opts)` traces the forward, derives
78
+ gradients, wires Adam, plans buffers, generates WGSL, spawns a worker, and
79
+ returns a `CompiledModule`. The factory `() => new Model()` is invoked
80
+ once during compile; the model instance is consumed (its param sentinels
81
+ are mutated into `Tensor`s).
82
+ - Every method on the compiled module is async. `await compiled.step(...)`
83
+ resolves with the loss after the worker's GPU work finishes.
84
+
85
+ ## Public API
86
+
87
+ ### Compile entry points
88
+
89
+ ```ts
90
+ compileModule(factory, forward, { adam?, inputs? }): Promise<CompiledModule>
91
+ compileForward(factory, forward, { inputs? }): Promise<CompiledForwardModule>
92
+ ```
93
+
94
+ `compileForward` produces a forward-only graph in its own worker. To share
95
+ params with an existing training graph, use the sibling method:
96
+
97
+ ```ts
98
+ const train = await compileModule(() => new Model(), lossFn, { ... })
99
+ const infer = await train.compileForward(predictFn, {
100
+ inputs: { tokens: { shape: [1, T], dtype: 'i32' } },
101
+ })
102
+ // infer runs in train's worker — every step's param updates are visible.
103
+ ```
104
+
105
+ ### CompiledModule methods (all `Promise`-returning)
106
+
107
+ ```ts
108
+ compiled.step(inputs) // loss: number
109
+ compiled.step(inputs, { withCaptures: true }) // → { loss, captures }
110
+ compiled.run(inputs) // → Float32Array
111
+ compiled.run(inputs, { withCaptures: true }) // → { output, captures }
112
+ compiled.uploadParams(record, { partial? })
113
+ compiled.downloadParams() // Record<name, Float32Array>
114
+ compiled.downloadParamGrads() // Record<name, Float32Array>
115
+ compiled.reset() // re-init params + zero Adam state
116
+ compiled.resetOptimizerState()
117
+ compiled.compileForward(forward, { inputs? }) // sibling forward graph
118
+ compiled.destroy() // tear down worker + GPU
119
+ ```
120
+
121
+ `compiled.kernelCount`, `compiled.outputShape`, `compiled.paramNames`, and
122
+ `compiled.ir` are sync properties for inspection.
123
+
124
+ ### Operators
125
+
126
+ Imported from `'tensorgrad'`:
127
+
128
+ - Element-wise: `add`, `sub`, `mul`, `div`, `sqrt`, `rsqrt`, `log`, `exp`, `relu`
129
+ - Comparisons / select: `less`, `greater`, `where`
130
+ - Reductions (last axis): `meanLast`, `sumLast`, `sumAll`
131
+ - Shape: `reshape`, `transpose`, `swapAxes`
132
+ - Linear algebra: `matmul`, `matmulBatched`
133
+ - Indexing / casting: `oneHot`, `arange`, `embedding`
134
+ - Slicing: `sliceLastRange`
135
+ - Fused ML primitives: `softmaxCausalLast`, `logSoftmaxLast`, `whereCausal`
136
+
137
+ `add`, `sub`, `mul`, `div` accept `(Tensor, Tensor)` or `(Tensor, number)`.
138
+
139
+ ### `nn` namespace
140
+
141
+ ```ts
142
+ import { nn } from 'tensorgrad'
143
+
144
+ nn.Linear(inDim, outDim, { bias? }) // .fwd(x)
145
+ nn.LayerNorm(dim) // .fwd(x)
146
+ nn.splitHeads(x, nHeads) // [B, T, D] → [B, H, T, D/H]
147
+ nn.mergeHeads(x) // inverse of splitHeads
148
+ nn.unsplitHeads(captures, name) // pull per-head slices off a capture
149
+ nn.crossEntropyLast(logits, targets) // standard CE
150
+ ```
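`nn.crossEntropyLast` is standard cross-entropy over the last axis. As a plain-TypeScript sketch of the per-row math it computes (scalar stand-in for illustration, not the library's implementation):

```typescript
// Cross-entropy for one row of logits against an integer class target,
// via a numerically stable log-softmax: loss = logSumExp(logits) − logits[target].
function crossEntropyRow(logits: number[], target: number): number {
  const m = Math.max(...logits) // subtract the max before exponentiating
  const logSumExp = m + Math.log(logits.reduce((s, v) => s + Math.exp(v - m), 0))
  return logSumExp - logits[target] // −log softmax(logits)[target]
}
```

Uniform logits over `k` classes give `log(k)`; a confidently correct row gives a loss near zero.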
151
+
152
+ Convention: leaf modules (`Linear`, `LayerNorm`) expose `.fwd(x)` for ergonomic
153
+ chaining. For composite modules you write yourself, the forward pass is
154
+ typically a free function taking `(p: ModuleType, x: Tensor)`.
155
+
156
+ ### LR schedules (`lr` namespace)
157
+
158
+ ```ts
159
+ import { lr } from 'tensorgrad'
160
+
161
+ adam: { lr: 0.005 } // constant
162
+ adam: { lr: lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 }) }
163
+ adam: { lr: lr.cosineDecay({ peak: 0.005, final: 0.0001, steps: 5000 }) }
164
+ adam: { lr: lr.warmup({ peakLr: 0.001, warmupSteps: 200, after: lr.constant(0.001) }) }
165
+ ```
166
+
167
+ LR schedules are serializable shapes, not closures (they cross the worker
168
+ boundary). Use a `number` for constant LR, or one of the constructors above.
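For reference, the per-step learning rates these shapes resolve to follow the usual decay formulas. A plain-TypeScript sketch of the math (the library evaluates this internally on the worker side; `warmup` additionally prepends a linear ramp before delegating to its `after` schedule):

```typescript
// What a schedule shape resolves to at a given step — a stand-alone sketch
// of the documented constructors, not the library's internal API.
type Schedule =
  | number
  | { kind: 'constant'; value: number }
  | { kind: 'linearDecay'; peak: number; final: number; steps: number }
  | { kind: 'cosineDecay'; peak: number; final: number; steps: number }

function lrAt(s: Schedule, step: number): number {
  if (typeof s === 'number') return s
  const f = 'steps' in s ? Math.min(step / s.steps, 1) : 0 // decay progress in [0, 1]
  switch (s.kind) {
    case 'constant':
      return s.value
    case 'linearDecay':
      // straight line from peak (step 0) down to final (step >= steps)
      return s.peak + (s.final - s.peak) * f
    case 'cosineDecay':
      // half-cosine from peak down to final
      return s.final + 0.5 * (s.peak - s.final) * (1 + Math.cos(Math.PI * f))
  }
}
```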
169
+
170
+ ### Param init (`init` namespace)
171
+
172
+ ```ts
173
+ import { init } from 'tensorgrad'
174
+
175
+ this.param([D, D], { init: init.kaiming() }) // gain=sqrt(2), fan_in=D
176
+ this.param([D, D], { init: init.kaiming({ gain: 1 }) })
177
+ this.param([D], { init: 'zeros' })
178
+ this.param([D], { init: 'ones' })
179
+ this.param([D, D], { init: init.randn({ scale: 0.02 }) })
180
+ this.param([D], { init: init.literal(myFloat32Array) })
181
+ ```
182
+
183
+ The default init is `'randn'` (std 0.02). AdamW weight decay defaults to `true` for
184
+ randn/kaiming/literal init, `false` for zeros/ones — override per-param with
185
+ `{ decay: true | false }`.
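As a rough sketch of what `init.kaiming()` computes, assuming the standard Kaiming-normal formula `std = gain / sqrt(fan_in)` with the documented default `gain = sqrt(2)` (the actual sampling is internal to the library):

```typescript
// Hypothetical stand-in for init.kaiming(): Kaiming-normal samples with
// std = gain / sqrt(fanIn). Box–Muller supplies the normal draws.
function kaimingInit(fanIn: number, fanOut: number, gain = Math.sqrt(2)): Float32Array {
  const std = gain / Math.sqrt(fanIn)
  const out = new Float32Array(fanIn * fanOut)
  for (let i = 0; i < out.length; i++) {
    const u = 1 - Math.random() // keep u in (0, 1] to avoid log(0)
    const v = Math.random()
    out[i] = std * Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v)
  }
  return out
}
```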
186
+
187
+ ### Captures (debugging / mech-interp)
188
+
189
+ Wrap any tensor inside a forward to expose its activation post-run:
190
+
191
+ ```ts
192
+ import { capture } from 'tensorgrad'
193
+
194
+ const attn = capture(`attn.${i}`, softmaxCausalLast(scores))
195
+ ```
196
+
197
+ ```ts
198
+ const { output, captures } = await compiled.run(inputs, { withCaptures: true })
199
+ const attn0 = captures.get('attn.0') // Float32Array
200
+ captures.shapeOf('attn.0') // readonly number[]
201
+ ```
202
+
203
+ Captures add zero overhead unless `{ withCaptures: true }` is passed; even
204
+ then they cost only a single batched `mapAsync` on the readback.
205
+
206
+ ## Constraints
207
+
208
+ The library is small because of what it doesn't do. Plan accordingly:
209
+
210
+ - **WebGPU only.** No Wasm, WebGL, or native fallback.
211
+ - **Static shapes.** Every shape is fixed at compile time. Changing a batch
212
+ size means recompiling.
213
+ - **`f32` only.** No mixed precision. Inputs may be `i32` for indices.
214
+ - **One transformation: `grad`.** No `vmap`, `pmap`, `jvp`, `custom_vjp`.
215
+ Batch your data explicitly.
216
+ - **Loss must be a scalar.** `compileModule`'s forward returns a rank-0 tensor.
217
+ - **Closures don't cross the worker boundary.** LR schedules and inits are
218
+ serializable shapes, not functions. Anything per-step you write into a
219
+ user-defined optimizer (see *Extending* below) follows the same rule.
220
+ - **One model per `compileModule` call.** Sibling forward graphs share params
221
+ via the method form; otherwise each compile spawns its own worker.
222
+
223
+ ## Extending
224
+
225
+ The IR is open. Adam is built in only because it's the most common starting
226
+ point — other optimizers, custom losses, or extra ops are user code following
227
+ the same pattern as `appendAdam`:
228
+
229
+ ```ts
230
+ import { appendAdam, appendGrad, compileToIR } from 'tensorgrad'
231
+ ```
232
+
233
+ A custom optimizer is a function that takes the autograd output (graph +
234
+ `paramGrads`) and the materialized param tensors, appends its update ops
235
+ to the graph, and returns writeback declarations the buffer planner uses
236
+ to wire each new value back into its persistent home. SGD, Lion, and RMSProp
237
+ all fit this shape; see `src/adam.ts` for the canonical example.
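For intuition, the update ops such an optimizer appends are ordinary element-wise math over the param and grad buffers. Plain SGD, written out over arrays for illustration (a scalar sketch of the update rule, not the IR-building code):

```typescript
// The element-wise update an SGD-style optimizer would append to the graph:
// p ← p − lr·g, computed here over plain arrays for illustration.
function sgdStep(params: Float32Array, grads: Float32Array, lr: number): Float32Array {
  const next = new Float32Array(params.length)
  for (let i = 0; i < params.length; i++) {
    next[i] = params[i] - lr * grads[i]
  }
  return next // in the real pattern, a writeback declaration routes this into the param buffer
}
```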
238
+
239
+ The same applies to ops: anything missing from the built-in set can be
240
+ expressed as a composition of existing ops (GELU, RMSNorm, etc. are a few
241
+ lines), or — if you need a new primitive — added to the IR with a
242
+ forward + backward + WGSL emit.
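For example, RMSNorm falls out of the existing element-wise and reduction ops. The math such a composition computes, written over a plain array for illustration (the `eps` value here is a typical choice, not one the library prescribes):

```typescript
// RMSNorm over the last axis: y_i = x_i · rsqrt(mean(x²) + eps) · gain_i,
// written in plain TypeScript to show what the op composition computes.
function rmsNorm(x: Float32Array, gain: Float32Array, eps = 1e-5): Float32Array {
  let meanSq = 0
  for (let i = 0; i < x.length; i++) meanSq += x[i] * x[i]
  meanSq /= x.length
  const inv = 1 / Math.sqrt(meanSq + eps) // rsqrt(mean(x²) + eps)
  const out = new Float32Array(x.length)
  for (let i = 0; i < x.length; i++) out[i] = x[i] * inv * gain[i]
  return out
}
```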
243
+
244
+ ## When not to use this
245
+
246
+ - **Inference of pretrained models** → use ONNX Runtime Web or
247
+ transformers.js.
248
+ - **Full JAX surface** (vmap, dynamic shapes, multi-backend) → use jax-js.
249
+ - **Server-side training** → use PyTorch or JAX.
250
+
251
+ ## License
252
+
253
+ MIT
package/dist/index.js CHANGED
@@ -1805,7 +1805,7 @@ async function compileModule(modelFactory, forward, opts = {}) {
1805
1805
  const kernels = emitKernels(graph, plan);
1806
1806
  const ir = { graph, paramGrads, loss, plan, kernels };
1807
1807
  const initialParams = buildInitialParams(plan, materialized.initFns);
1808
- const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
+ const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
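The substantive change in this new worker bundle is chunked dispatch: rather than encoding every kernel into one command buffer, `dispatchUnsynchronized` now submits at most 32 kernels per encoder and awaits `queue.onSubmittedWorkDone()` between submissions (presumably to keep individual command buffers short on large graphs). Each kernel's workgroup count is also split across x/y to stay under WebGPU's 65535 per-dimension limit. A minimal, GPU-free sketch of just that index arithmetic (helper names here are illustrative, not part of the package API):

```javascript
// Sketch of the chunking / workgroup-split arithmetic used by the new
// dispatchUnsynchronized loop above. Helper names are illustrative.
const CHUNK_SIZE = 32;   // kernels per command buffer
const MAX_X = 65535;     // WebGPU per-dimension workgroup limit

// Split [0, kernelCount) into [start, end) chunks of at most CHUNK_SIZE.
function chunkRanges(kernelCount) {
  const ranges = [];
  let start = 0;
  while (start < kernelCount) {
    const end = Math.min(start + CHUNK_SIZE, kernelCount);
    ranges.push({ start, end, isLast: end === kernelCount });
    start = end;
  }
  return ranges;
}

// Map a flat thread count onto a 2D dispatch that stays within MAX_X,
// always dispatching at least one workgroup.
function workgroupDims(threads, workgroupSize) {
  const wgCount = Math.max(1, Math.ceil(threads / workgroupSize));
  return { x: Math.min(wgCount, MAX_X), y: Math.ceil(wgCount / MAX_X) };
}
```

Note that in the bundle only the last chunk appends the writeback copies and the loss/capture readback copies, so staging happens once per step regardless of how many chunks were submitted.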
  const wireIR = { graph, plan, kernels };
  const wireAdam = adamResult ? wireAdamConfig(adamResult) : null;
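`wireAdam` carries the optimizer config that `injectAdamScalars` (in the worker bundle above) consumes each step: Adam's bias correction is folded into a single effective learning-rate scalar, and decoupled weight decay into a shrink factor, both shipped to the GPU as one-element inputs. A sketch of that scalar math (function names here are illustrative):

```javascript
// Bias-corrected effective learning rate, as computed by injectAdamScalars
// in the worker bundle: lr_t = lr * sqrt(1 - b2^t) / (1 - b1^t).
function adamEffectiveLR(lr, b1, b2, t) {
  return lr * Math.sqrt(1 - Math.pow(b2, t)) / (1 - Math.pow(b1, t));
}

// Decoupled weight decay is likewise folded into one scalar: params are
// multiplied by (1 - lr_now * weightDecay) each step.
function decayShrink(lrNow, weightDecay) {
  return 1 - lrNow * weightDecay;
}
```

For early steps the correction boosts the step size (at t = 1 with the usual b1 = 0.9, b2 = 0.999 it is about 0.316 × lr rather than 0.1 × lr uncorrected); as t grows both power terms vanish and the effective rate converges to the scheduled lr.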
  const transfers = transferablesOfRecord(initialParams);
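`transferablesOfRecord` here, like `collectTransfers` in the worker bundle, gathers the backing ArrayBuffers of a record of Float32Arrays so replies cross the worker boundary zero-copy via a transfer list. A sketch of the pattern, using Node's `structuredClone` to stand in for `postMessage` (assumes Node ≥ 17, where `structuredClone` accepts a `transfer` option):

```javascript
// Collect the backing ArrayBuffers of a record of Float32Arrays,
// mirroring collectTransfers in the worker bundle above.
function collectTransfers(rec) {
  if (!rec) return [];
  return Object.values(rec).map((v) => v.buffer);
}

// Transferring (rather than copying) moves ownership: the clone sees the
// data, and the source buffer is detached (byteLength drops to 0).
const params = { w: new Float32Array([1, 2, 3]) };
const transfers = collectTransfers(params);
const cloned = structuredClone(params, { transfer: transfers });
```

Detaching the source is harmless here because the bundle only ever transfers freshly sliced readback copies, never the live GPU staging views.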
@@ -1843,7 +1843,7 @@ async function compileForward(modelFactory, forward, opts = {}) {
  const kernels = emitKernels(graph, plan);
  const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
  const initialParams = buildInitialParams(plan, materialized.initFns);
- const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
+ const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
  const wireIR = { graph, plan, kernels };
  const transfers = transferablesOfRecord(initialParams);
  let meta;