tensorgrad 0.0.16 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +253 -119
- package/dist/index.js +2 -2
- package/dist/index.js.map +2 -2
- package/dist/worker.debug.js +39 -23
- package/package.json +1 -1
- package/src/runtime.ts +56 -31
package/README.md
CHANGED
|
@@ -1,119 +1,253 @@
|
|
|
1
|
-
# tensorgrad
|
|
2
|
-
|
|
3
|
-
A tiny TypeScript-native tensor library with autograd that compiles
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
```sh
|
|
8
|
-
npm i tensorgrad
|
|
9
|
-
```
|
|
10
|
-
|
|
11
|
-
Roughly
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
}
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
for (let step = 0; step < 1000; step++) {
|
|
64
|
-
const { x, y } = generateBatch()
|
|
65
|
-
const lossVal = await compiled.step({ x, y })
|
|
66
|
-
if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
|
|
67
|
-
}
|
|
68
|
-
```
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
(`
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
1
|
+
# tensorgrad
|
|
2
|
+
|
|
3
|
+
A tiny TypeScript-native tensor library with autograd that compiles to WebGPU.
|
|
4
|
+
For training small models in the browser without hand-writing WGSL kernels and
|
|
5
|
+
without dragging in a multi-megabyte ML framework.
|
|
6
|
+
|
|
7
|
+
```sh
|
|
8
|
+
npm i tensorgrad
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Roughly 3000 lines of zero-dependency TypeScript. Static shapes, `f32`, Adam
|
|
12
|
+
optimizer, ~25 ops, forward pass + reverse-mode autograd. Browser-only (uses
|
|
13
|
+
WebGPU). All training/inference work runs in a library-internal Web Worker —
|
|
14
|
+
every method on a compiled module returns a `Promise`.
|
|
15
|
+
|
|
16
|
+
## Minimal example
|
|
17
|
+
|
|
18
|
+
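An MLP with two hidden layers fitting `y = sin(x)`. `generateBatch()` is a placeholder for your own data source; it should return `x` and `y` as `Float32Array`s of length `B`: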
|
|
19
|
+
|
|
20
|
+
```ts
|
|
21
|
+
import {
|
|
22
|
+
Module, compileModule, init,
|
|
23
|
+
add, mul, sub, sumLast, reshape, matmul, relu,
|
|
24
|
+
type Tensor,
|
|
25
|
+
} from 'tensorgrad'
|
|
26
|
+
|
|
27
|
+
const B = 256
|
|
28
|
+
|
|
29
|
+
class Linear extends Module {
|
|
30
|
+
W: Tensor; b: Tensor
|
|
31
|
+
constructor(public inDim: number, public outDim: number) {
|
|
32
|
+
super()
|
|
33
|
+
this.W = this.param([inDim, outDim], { init: init.kaiming() })
|
|
34
|
+
this.b = this.param([outDim], { init: 'zeros' })
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
class MLP extends Module {
|
|
39
|
+
l1 = new Linear(1, 64)
|
|
40
|
+
l2 = new Linear(64, 64)
|
|
41
|
+
l3 = new Linear(64, 1)
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
const linearFwd = (p: Linear, x: Tensor) => add(matmul(x, p.W), p.b)
|
|
45
|
+
|
|
46
|
+
function modelFwd(m: MLP, x: Tensor): Tensor {
|
|
47
|
+
return linearFwd(m.l3, relu(linearFwd(m.l2, relu(linearFwd(m.l1, x)))))
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
function lossFn(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
|
|
51
|
+
const diff = sub(modelFwd(m, x), y)
|
|
52
|
+
return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
const compiled = await compileModule(() => new MLP(), lossFn, {
|
|
56
|
+
adam: { lr: 0.005 },
|
|
57
|
+
inputs: {
|
|
58
|
+
x: { shape: [B, 1], dtype: 'f32' },
|
|
59
|
+
y: { shape: [B, 1], dtype: 'f32' },
|
|
60
|
+
},
|
|
61
|
+
})
|
|
62
|
+
|
|
63
|
+
for (let step = 0; step < 1000; step++) {
|
|
64
|
+
const { x, y } = generateBatch()
|
|
65
|
+
const lossVal = await compiled.step({ x, y })
|
|
66
|
+
if (step % 100 === 0) console.log('step', step, 'loss', lossVal)
|
|
67
|
+
}
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## Mental model
|
|
71
|
+
|
|
72
|
+
- A `Module` subclass declares parameters via `this.param([d0, d1, ...], opts)` and
|
|
73
|
+
composes child modules as plain fields. The class is a tree of params.
|
|
74
|
+
- A *forward function* takes the materialized module + a record of named
|
|
75
|
+
input tensors and returns a tensor (the loss for `compileModule`, or any
|
|
76
|
+
output for `compileForward`).
|
|
77
|
+
- `compileModule(factory, forward, opts)` traces the forward, derives
|
|
78
|
+
gradients, wires Adam, plans buffers, generates WGSL, spawns a worker, and
|
|
79
|
+
returns a `CompiledModule`. The factory `() => new Model()` is invoked
|
|
80
|
+
once during compile; the model instance is consumed (its param sentinels
|
|
81
|
+
are mutated into `Tensor`s).
|
|
82
|
+
- Every method on the compiled module is async. `await compiled.step(...)`
|
|
83
|
+
resolves with the loss after the worker's GPU work finishes.
|
|
84
|
+
|
|
85
|
+
## Public API
|
|
86
|
+
|
|
87
|
+
### Compile entry points
|
|
88
|
+
|
|
89
|
+
```ts
|
|
90
|
+
compileModule(factory, forward, { adam?, inputs? }): Promise<CompiledModule>
|
|
91
|
+
compileForward(factory, forward, { inputs? }): Promise<CompiledForwardModule>
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
`compileForward` produces a forward-only graph in its own worker. To share
|
|
95
|
+
params with an existing training graph, use the sibling method:
|
|
96
|
+
|
|
97
|
+
```ts
|
|
98
|
+
const train = await compileModule(() => new Model(), lossFn, { ... })
|
|
99
|
+
const infer = await train.compileForward(predictFn, {
|
|
100
|
+
inputs: { tokens: { shape: [1, T], dtype: 'i32' } },
|
|
101
|
+
})
|
|
102
|
+
// infer runs in train's worker — every step's param updates are visible.
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
### CompiledModule methods (all `Promise`-returning)
|
|
106
|
+
|
|
107
|
+
```ts
|
|
108
|
+
compiled.step(inputs) // → loss: number
|
|
109
|
+
compiled.step(inputs, { withCaptures: true }) // → { loss, captures }
|
|
110
|
+
compiled.run(inputs) // → Float32Array
|
|
111
|
+
compiled.run(inputs, { withCaptures: true }) // → { output, captures }
|
|
112
|
+
compiled.uploadParams(record, { partial? })
|
|
113
|
+
compiled.downloadParams() // → Record<name, Float32Array>
|
|
114
|
+
compiled.downloadParamGrads() // → Record<name, Float32Array>
|
|
115
|
+
compiled.reset() // re-init params + zero Adam state
|
|
116
|
+
compiled.resetOptimizerState()
|
|
117
|
+
compiled.compileForward(forward, { inputs? }) // sibling forward graph
|
|
118
|
+
compiled.destroy() // tear down worker + GPU
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
`compiled.kernelCount`, `compiled.outputShape`, `compiled.paramNames`, and
|
|
122
|
+
`compiled.ir` are sync properties for inspection.
|
|
123
|
+
|
|
124
|
+
### Operators
|
|
125
|
+
|
|
126
|
+
Imported from `'tensorgrad'`:
|
|
127
|
+
|
|
128
|
+
- Element-wise: `add`, `sub`, `mul`, `div`, `sqrt`, `rsqrt`, `log`, `exp`, `relu`
|
|
129
|
+
- Comparisons / select: `less`, `greater`, `where`
|
|
130
|
+
- Reductions (last axis): `meanLast`, `sumLast`, `sumAll`
|
|
131
|
+
- Shape: `reshape`, `transpose`, `swapAxes`
|
|
132
|
+
- Linear algebra: `matmul`, `matmulBatched`
|
|
133
|
+
- Indexing / casting: `oneHot`, `arange`, `embedding`
|
|
134
|
+
- Slicing: `sliceLastRange`
|
|
135
|
+
- Fused ML primitives: `softmaxCausalLast`, `logSoftmaxLast`, `whereCausal`
|
|
136
|
+
|
|
137
|
+
`add`, `sub`, `mul`, `div` accept `(Tensor, Tensor)` or `(Tensor, number)`.
|
|
138
|
+
|
|
139
|
+
### `nn` namespace
|
|
140
|
+
|
|
141
|
+
```ts
|
|
142
|
+
import { nn } from 'tensorgrad'
|
|
143
|
+
|
|
144
|
+
nn.Linear(inDim, outDim, { bias? }) // .fwd(x)
|
|
145
|
+
nn.LayerNorm(dim) // .fwd(x)
|
|
146
|
+
nn.splitHeads(x, nHeads) // [B, T, D] → [B, H, T, D/H]
|
|
147
|
+
nn.mergeHeads(x) // inverse of splitHeads
|
|
148
|
+
nn.unsplitHeads(captures, name) // pull per-head slices off a capture
|
|
149
|
+
nn.crossEntropyLast(logits, targets) // standard CE
|
|
150
|
+
```
|
|
151
|
+
|
|
152
|
+
Convention: leaf modules (`Linear`, `LayerNorm`) expose `.fwd(x)` for ergonomic
|
|
153
|
+
chaining. Composite modules you write yourself typically get their forward pass as a free function
|
|
154
|
+
taking `(p: ModuleType, x: Tensor)`.
|
|
155
|
+
|
|
156
|
+
### LR schedules (`lr` namespace)
|
|
157
|
+
|
|
158
|
+
```ts
|
|
159
|
+
import { lr } from 'tensorgrad'
|
|
160
|
+
|
|
161
|
+
adam: { lr: 0.005 } // constant
|
|
162
|
+
adam: { lr: lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 }) }
|
|
163
|
+
adam: { lr: lr.cosineDecay({ peak: 0.005, final: 0.0001, steps: 5000 }) }
|
|
164
|
+
adam: { lr: lr.warmup({ peakLr: 0.001, warmupSteps: 200, after: lr.constant(0.001) }) }
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
LR schedules are plain serializable objects, not closures (they cross the worker
|
|
168
|
+
boundary). Use a `number` for constant LR, or one of the constructors above.
|
|
169
|
+
|
|
170
|
+
### Param init (`init` namespace)
|
|
171
|
+
|
|
172
|
+
```ts
|
|
173
|
+
import { init } from 'tensorgrad'
|
|
174
|
+
|
|
175
|
+
this.param([D, D], { init: init.kaiming() }) // gain=sqrt(2), fan_in=D
|
|
176
|
+
this.param([D, D], { init: init.kaiming({ gain: 1 }) })
|
|
177
|
+
this.param([D], { init: 'zeros' })
|
|
178
|
+
this.param([D], { init: 'ones' })
|
|
179
|
+
this.param([D, D], { init: init.randn({ scale: 0.02 }) })
|
|
180
|
+
this.param([D], { init: init.literal(myFloat32Array) })
|
|
181
|
+
```
|
|
182
|
+
|
|
183
|
+
The default init is `'randn'` (std 0.02). AdamW weight decay defaults to `true` for
|
|
184
|
+
randn/kaiming/literal init, `false` for zeros/ones — override per-param with
|
|
185
|
+
`{ decay: true | false }`.
|
|
186
|
+
|
|
187
|
+
### Captures (debugging / mech-interp)
|
|
188
|
+
|
|
189
|
+
Wrap any tensor inside a forward to expose its activation post-run:
|
|
190
|
+
|
|
191
|
+
```ts
|
|
192
|
+
import { capture } from 'tensorgrad'
|
|
193
|
+
|
|
194
|
+
const attn = capture(`attn.${i}`, softmaxCausalLast(scores))
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
```ts
|
|
198
|
+
const { output, captures } = await compiled.run(inputs, { withCaptures: true })
|
|
199
|
+
const attn0 = captures.get('attn.0') // Float32Array
|
|
200
|
+
captures.shapeOf('attn.0') // readonly number[]
|
|
201
|
+
```
|
|
202
|
+
|
|
203
|
+
Captures are zero-overhead unless `{ withCaptures: true }` is passed; they
|
|
204
|
+
add a single batched mapAsync on the readback.
|
|
205
|
+
|
|
206
|
+
## Constraints
|
|
207
|
+
|
|
208
|
+
The library is small because of what it doesn't do. Plan accordingly:
|
|
209
|
+
|
|
210
|
+
- **WebGPU only.** No Wasm, WebGL, or native fallback.
|
|
211
|
+
- **Static shapes.** Every shape is fixed at compile time. Changing a batch
|
|
212
|
+
size means recompiling.
|
|
213
|
+
- **`f32` only.** No mixed precision. Inputs may be `i32` for indices.
|
|
214
|
+
- **One transformation: `grad`.** No `vmap`, `pmap`, `jvp`, `custom_vjp`.
|
|
215
|
+
Batch your data explicitly.
|
|
216
|
+
- **Loss must be a scalar.** `compileModule`'s forward must return a rank-0 tensor.
|
|
217
|
+
- **Closures don't cross the worker boundary.** LR schedules and inits are
|
|
218
|
+
plain serializable objects, not functions. Anything per-step you write into a
|
|
219
|
+
user-defined optimizer (see *Extending* below) follows the same rule.
|
|
220
|
+
- **One model per `compileModule` call.** Sibling forward graphs share params
|
|
221
|
+
via the method form; otherwise each compile spawns its own worker.
|
|
222
|
+
|
|
223
|
+
## Extending
|
|
224
|
+
|
|
225
|
+
The IR is open. Adam is built in only because it's the most common starting
|
|
226
|
+
point — other optimizers, custom losses, or extra ops are user code following
|
|
227
|
+
the same pattern as `appendAdam`:
|
|
228
|
+
|
|
229
|
+
```ts
|
|
230
|
+
import { appendAdam, appendGrad, compileToIR } from 'tensorgrad'
|
|
231
|
+
```
|
|
232
|
+
|
|
233
|
+
A custom optimizer is a function that takes the autograd output (graph +
|
|
234
|
+
`paramGrads`) and the materialized param tensors, appends its update ops
|
|
235
|
+
to the graph, and returns writeback declarations the buffer planner uses
|
|
236
|
+
to wire each new value back into its persistent home. SGD, Lion, RMSProp
|
|
237
|
+
all fit this pattern; see `src/adam.ts` for the canonical example.
|
|
238
|
+
|
|
239
|
+
The same applies to ops: anything missing from the built-in set can be
|
|
240
|
+
expressed as a composition of existing ops (GELU, RMSNorm, etc. are a few
|
|
241
|
+
lines), or — if you need a new primitive — added to the IR with a
|
|
242
|
+
forward + backward + WGSL emit.
|
|
243
|
+
|
|
244
|
+
## When not to use this
|
|
245
|
+
|
|
246
|
+
- **Inference of pretrained models** → use ONNX Runtime Web or
|
|
247
|
+
transformers.js.
|
|
248
|
+
- **Full JAX surface** (vmap, dynamic shapes, multi-backend) → use jax-js.
|
|
249
|
+
- **Server-side training** → use PyTorch or JAX.
|
|
250
|
+
|
|
251
|
+
## License
|
|
252
|
+
|
|
253
|
+
MIT
|
package/dist/index.js
CHANGED
|
@@ -1805,7 +1805,7 @@ async function compileModule(modelFactory, forward, opts = {}) {
|
|
|
1805
1805
|
const kernels = emitKernels(graph, plan);
|
|
1806
1806
|
const ir = { graph, paramGrads, loss, plan, kernels };
|
|
1807
1807
|
const initialParams = buildInitialParams(plan, materialized.initFns);
|
|
1808
|
-
const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
|
|
1808
|
+
const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
|
|
1809
1809
|
const wireIR = { graph, plan, kernels };
|
|
1810
1810
|
const wireAdam = adamResult ? wireAdamConfig(adamResult) : null;
|
|
1811
1811
|
const transfers = transferablesOfRecord(initialParams);
|
|
@@ -1843,7 +1843,7 @@ async function compileForward(modelFactory, forward, opts = {}) {
|
|
|
1843
1843
|
const kernels = emitKernels(graph, plan);
|
|
1844
1844
|
const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
|
|
1845
1845
|
const initialParams = buildInitialParams(plan, materialized.initFns);
|
|
1846
|
-
const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
|
|
1846
|
+
const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const CHUNK_SIZE = 32;\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n }\n let kernelIdx = 0;\n while (kernelIdx < kernels.length) {\n const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);\n const isLast = chunkEnd === kernels.length;\n const encoder = device2.createCommandEncoder({\n label: kernels.length > CHUNK_SIZE ? 
`tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"\n });\n for (let i = kernelIdx; i < chunkEnd; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n if (isLast) {\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n if (layout) {\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n }\n queue.submit([encoder.finish()]);\n if (!isLast) {\n await queue.onSubmittedWorkDone();\n }\n kernelIdx = chunkEnd;\n }\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new 
Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. 
Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
|
|
1847
1847
|
const wireIR = { graph, plan, kernels };
|
|
1848
1848
|
const transfers = transferablesOfRecord(initialParams);
|
|
1849
1849
|
let meta;
|