tensorgrad 0.0.17 → 0.0.18

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1805,7 +1805,7 @@ async function compileModule(modelFactory, forward, opts = {}) {
1805
1805
  const kernels = emitKernels(graph, plan);
1806
1806
  const ir = { graph, paramGrads, loss, plan, kernels };
1807
1807
  const initialParams = buildInitialParams(plan, materialized.initFns);
1808
- const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
1808
+ const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
1809
1809
  const wireIR = { graph, plan, kernels };
1810
1810
  const wireAdam = adamResult ? wireAdamConfig(adamResult) : null;
1811
1811
  const transfers = transferablesOfRecord(initialParams);
@@ -1843,7 +1843,7 @@ async function compileForward(modelFactory, forward, opts = {}) {
1843
1843
  const kernels = emitKernels(graph, plan);
1844
1844
  const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
1845
1845
  const initialParams = buildInitialParams(plan, materialized.initFns);
1846
- const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
1846
+ const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
1847
1847
  const wireIR = { graph, plan, kernels };
1848
1848
  const transfers = transferablesOfRecord(initialParams);
1849
1849
  let meta;