tensorgrad 0.0.15 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,39 +1,2208 @@
1
- // Public surface. Bulb code imports from here.
2
- //
3
- // Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
4
- // codegen / compile() (Phase 3+) come later.
5
- export { ShapeError } from './shape.js';
6
- export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
7
- export { capture } from './capture.js';
8
- export {
9
- // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
10
- add, sub, mul, div,
11
- // Element-wise unary
12
- sqrt, rsqrt, log, exp, relu,
13
- // Comparisons + select
14
- less, greater, where,
15
- // Reductions over the last axis (other axes via reshape/transpose first)
16
- meanLast, sumLast, sumAll,
17
- // Shape ops
18
- reshape, transpose, swapAxes,
19
- // Linear algebra
20
- matmul, matmulBatched,
21
- // Indexing / casting
22
- oneHot, arange, embedding,
23
- // ML primitives — fused for the transformer
24
- softmaxCausalLast, logSoftmaxLast, whereCausal,
25
- // Slicing
26
- sliceLastRange, } from './ops.js';
27
- // Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
28
- // are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
29
- // adam.ts can import them) but aren't part of the public API — `add`/`mul`
30
- // overload on JS numbers, `where` subsumes the rest.
31
- export { appendGrad } from './grad.js';
32
- export { appendAdam } from './adam.js';
33
- export { planBuffers } from './buffers.js';
34
- export { emitKernels } from './codegen.js';
35
- export { createRuntime, createForwardRuntime, Captures } from './runtime.js';
36
- export { compile, compileToIR, compileModule, compileForward, } from './compile.js';
37
- export { Module, materializeParams } from './module.js';
38
- export * as nn from './nn.js';
39
- //# sourceMappingURL=index.js.map
1
// Bundler interop helpers. `__export(target, all)` installs every enumerable
// key of `all` on `target` as an enumerable accessor whose getter is `all[key]`.
var __defProp = Object.defineProperty;
var __export = (target, all) => {
  for (const key in all) {
    const getter = all[key];
    __defProp(target, key, { enumerable: true, get: getter });
  }
};
6
+
7
+ // src/ir.ts
8
/**
 * Create an empty IR graph.
 * `captures` maps capture names to tensor ids (filled by capture()).
 * @returns {{ops: object[], tensors: object[], outputs: number[], captures: Map<string, number>}}
 */
function makeGraph() {
  const graph = {
    ops: [],
    tensors: [],
    outputs: [],
    captures: /* @__PURE__ */ new Map(),
  };
  return graph;
}
11
/**
 * Append a tensor record to `g.tensors`. The tensor's id is its index there.
 * @param {object} g - graph created by makeGraph()
 * @param {number[]} shape
 * @param {string} dtype
 * @param {number} source - index of the producing op in `g.ops`
 * @param {{opName: string, stack: string}} site - call-site record
 * @returns {object} the newly created tensor record
 */
function addTensor(g, shape, dtype, source, site) {
  const tensor = { id: g.tensors.length, shape, dtype, source, site };
  g.tensors.push(tensor);
  return tensor;
}
17
/**
 * Record an op in `g` and allocate its output tensor.
 * The output tensor's `source` is the op's index in `g.ops`. The op node holds
 * `kind`, the output tensor id as `out`, and every entry of `fields`.
 * @returns {object} the output tensor record
 */
function addOp(g, kind, shape, dtype, site, fields) {
  const out = addTensor(g, shape, dtype, g.ops.length, site);
  g.ops.push({ kind, out: out.id, ...fields });
  return out;
}
24
/**
 * Snapshot the current JS stack so later errors can point at the caller.
 * @param {string} opName - op label shown in formatted messages
 * @returns {{opName: string, stack: string}} stack is "" when unavailable
 */
function captureSite(opName) {
  const probe = new Error();
  return { opName, stack: probe.stack ?? "" };
}
28
/**
 * Render a call-site record as `[opName]` followed by up to three stack frames
 * that do not belong to tensorgrad itself.
 *
 * Library frames are recognised by path. Both the TypeScript sources
 * (`tensorgrad/src/`) and the published bundle (`tensorgrad/dist/`) are
 * skipped, with POSIX and Windows separators. This module itself ships as
 * `dist/index.js`, so filtering only `src/` would report tensorgrad's own
 * frames (captureSite, binopOp, the op wrapper) as the "user" frames.
 *
 * @param {{opName: string, stack: string}} site - record from captureSite()
 * @returns {string}
 */
function formatSite(site) {
  const internalPathMarkers = [
    "/tensorgrad/src/",
    "\\tensorgrad\\src\\",
    "/tensorgrad/dist/",
    "\\tensorgrad\\dist\\",
  ];
  const userFrames = [];
  // The first line is skipped (the error header in V8-style stacks).
  for (const line of site.stack.split("\n").slice(1)) {
    if (internalPathMarkers.some((marker) => line.includes(marker))) continue;
    userFrames.push(line.trim());
    if (userFrames.length >= 3) break;
  }
  if (userFrames.length === 0) return `[${site.opName}] (no user frame found)`;
  return `[${site.opName}]\n${userFrames.join("\n ")}`;
}
40
+
41
+ // src/shape.ts
42
/**
 * Error thrown by op constructors and shape inference for invalid shapes,
 * dtypes, or arguments. When a call-site record is given, the message is
 * extended with "at " plus the summary produced by formatSite().
 */
var ShapeError = class extends Error {
  /**
   * @param {string} message - description of the problem
   * @param {{opName: string, stack: string}} [site] - optional record from captureSite()
   */
  constructor(message, site) {
    super(site ? `${message}\nat ${formatSite(site)}` : message);
    this.name = "ShapeError";
  }
};
50
/**
 * Throw a ShapeError built from `message` and the optional call-site record.
 * Never returns normally.
 */
function fail(message, site) {
  const error = new ShapeError(message, site);
  throw error;
}
53
/**
 * True when two shapes have the same rank and identical dims (compared with ===).
 * @param {number[]} a
 * @param {number[]} b
 * @returns {boolean}
 */
function shapesEqual(a, b) {
  if (a.length !== b.length) return false;
  let axis = a.length;
  while (axis-- > 0) {
    if (a[axis] !== b[axis]) return false;
  }
  return true;
}
58
/**
 * Number of elements in a tensor of the given shape. A 0-d shape ([]) has size 1.
 * @param {number[]} shape
 * @returns {number}
 */
function shapeSize(shape) {
  return shape.reduce((product, dim) => product * dim, 1);
}
63
/**
 * Human-readable shape for error messages, e.g. [2, 3] -> "[2, 3]".
 * @param {number[]} shape
 * @returns {string}
 */
function showShape(shape) {
  const dims = shape.join(", ");
  return `[${dims}]`;
}
66
/**
 * Broadcast two shapes by aligning them on their trailing axes.
 * Missing leading axes count as 1. A size-1 axis stretches to the other
 * operand's size.
 * @param {number[]} a
 * @param {number[]} b
 * @returns {number[] | null} the broadcast shape, or null if incompatible
 */
function broadcastTrailing(a, b) {
  const rank = Math.max(a.length, b.length);
  const padA = rank - a.length;
  const padB = rank - b.length;
  const result = [];
  for (let axis = 0; axis < rank; axis++) {
    const da = axis < padA ? 1 : a[axis - padA];
    const db = axis < padB ? 1 : b[axis - padB];
    if (da !== db && da !== 1 && db !== 1) return null;
    result.push(da === 1 ? db : da);
  }
  return result;
}
81
/**
 * Output shape of an element-wise binary op under trailing-suffix broadcasting.
 * @returns {number[]}
 * @throws {ShapeError} when the shapes cannot be broadcast together
 */
function inferElementwiseBinop(opName, aShape, bShape, site) {
  const broadcast = broadcastTrailing(aShape, bShape);
  if (broadcast === null) {
    fail(
      `${opName}: incompatible shapes ${showShape(aShape)} and ${showShape(bShape)}. Trailing-suffix broadcasting only \u2014 the smaller shape must be a suffix of the larger, with size-1 axes broadcasting to any size.`,
      site
    );
  }
  return broadcast;
}
91
/**
 * Element-wise unary ops keep their input shape. The same array is returned,
 * not a copy. `_opName` and `_site` are unused.
 */
function inferUnary(_opName, aShape, _site) {
  const outShape = aShape;
  return outShape;
}
94
/**
 * Shape after averaging over the last axis. The axis is kept with size 1.
 * @throws {ShapeError} for a 0-d input
 */
function inferMeanLast(opName, aShape, site) {
  if (aShape.length === 0) {
    fail(`${opName}: cannot reduce a 0-d tensor`, site);
  }
  const kept = aShape.slice(0, aShape.length - 1);
  kept.push(1);
  return kept;
}
98
/**
 * Shape after summing over the last axis. The axis is dropped.
 * @throws {ShapeError} for a 0-d input
 */
function inferSumLast(opName, aShape, site) {
  if (aShape.length === 0) {
    fail(`${opName}: cannot reduce a 0-d tensor`, site);
  }
  return aShape.slice(0, aShape.length - 1);
}
102
/**
 * Infer the output shape of `reshape`.
 *
 * `newShape` may contain at most one -1. That dim is solved so the element
 * count is preserved. Every other entry must be a positive integer.
 *
 * @param {string} opName - label used in error messages
 * @param {number[]} aShape - input shape
 * @param {number[]} newShape - requested shape, optionally with one -1
 * @param {object} site - call-site record for errors
 * @returns {number[]} a fresh array with any -1 resolved
 * @throws {ShapeError} on multiple -1s, non-positive or non-integer dims, or size mismatch
 */
function inferReshape(opName, aShape, newShape, site) {
  let inferIdx = -1;
  let knownSize = 1;
  for (let i = 0; i < newShape.length; i++) {
    const d = newShape[i];
    if (d === -1) {
      if (inferIdx !== -1) fail(`${opName}: at most one -1 dim allowed in newShape ${showShape(newShape)}`, site);
      inferIdx = i;
    } else if (!Number.isInteger(d) || d <= 0) {
      // Reject fractional dims too: e.g. [1.5, 4] has the same product as [6],
      // so a plain `d <= 0` check would accept it and emit a fractional shape.
      fail(`${opName}: invalid dim ${d} in newShape ${showShape(newShape)}`, site);
    } else {
      knownSize *= d;
    }
  }
  const totalIn = shapeSize(aShape);
  const out = [...newShape];
  if (inferIdx !== -1) {
    if (totalIn % knownSize !== 0) {
      fail(`${opName}: cannot reshape ${showShape(aShape)} (size ${totalIn}) to ${showShape(newShape)} \u2014 known dims multiply to ${knownSize}`, site);
    }
    out[inferIdx] = totalIn / knownSize;
  } else if (knownSize !== totalIn) {
    fail(`${opName}: size mismatch \u2014 input ${showShape(aShape)} has ${totalIn} elements but newShape ${showShape(newShape)} has ${knownSize}`, site);
  }
  return out;
}
128
/**
 * Infer the output shape of `transpose`: output axis i takes input axis perm[i].
 *
 * @param {string} opName - label used in error messages
 * @param {number[]} aShape - input shape
 * @param {number[]} perm - permutation of 0..rank-1
 * @param {object} site - call-site record for errors
 * @returns {number[]}
 * @throws {ShapeError} on a wrong-length perm, or a non-integer, out-of-range, or duplicate entry
 */
function inferTranspose(opName, aShape, perm, site) {
  if (perm.length !== aShape.length) {
    fail(`${opName}: perm length ${perm.length} must equal input rank ${aShape.length}`, site);
  }
  const seen = /* @__PURE__ */ new Set();
  for (const p of perm) {
    // Fractional or NaN entries pass the range check below and would index
    // aShape with a non-integer key, producing `undefined` dims.
    if (!Number.isInteger(p)) fail(`${opName}: perm index ${p} is not an integer`, site);
    if (p < 0 || p >= aShape.length) fail(`${opName}: perm index ${p} out of range for rank ${aShape.length}`, site);
    if (seen.has(p)) fail(`${opName}: perm has duplicate index ${p}`, site);
    seen.add(p);
  }
  return perm.map((p) => aShape[p]);
}
140
/**
 * Output shape of `matmul`: [...batch, M, K] x [K, N] -> [...batch, M, N].
 * The rhs must be exactly rank 2.
 * @throws {ShapeError} on bad ranks or mismatched inner dims
 */
function inferMatmul(opName, aShape, bShape, site) {
  const rank = aShape.length;
  if (rank < 2) fail(`${opName}: lhs must have rank >= 2, got ${showShape(aShape)}`, site);
  if (bShape.length !== 2) fail(`${opName}: rhs must have rank 2, got ${showShape(bShape)} \u2014 use matmulBatched for batched rhs`, site);
  const [Kb, N] = bShape;
  const M = aShape[rank - 2];
  const Ka = aShape[rank - 1];
  if (Ka !== Kb) {
    fail(`${opName}: inner dims don't match \u2014 ${showShape(aShape)} \xB7 ${showShape(bShape)} (last axis of lhs = ${Ka}, first axis of rhs = ${Kb})`, site);
  }
  const out = aShape.slice(0, -2);
  out.push(M, N);
  return out;
}
150
/**
 * Output shape of `matmulBatched`: [...batch, M, K] x [...batch, K, N] -> [...batch, M, N].
 * Both operands need the same rank and identical batch dims. There is no broadcasting.
 * @throws {ShapeError} on bad ranks, mismatched batch dims, or mismatched inner dims
 */
function inferMatmulBatched(opName, aShape, bShape, site) {
  const rank = aShape.length;
  if (rank < 2 || bShape.length < 2) {
    fail(`${opName}: both inputs must have rank >= 2, got ${showShape(aShape)} and ${showShape(bShape)}`, site);
  }
  if (rank !== bShape.length) {
    fail(`${opName}: ranks must match (got ${aShape.length} vs ${bShape.length}). Reshape if you need different batch dims.`, site);
  }
  const aBatch = aShape.slice(0, -2);
  const bBatch = bShape.slice(0, -2);
  const batchMismatch = aBatch.some((dim, i) => dim !== bBatch[i]);
  if (batchMismatch) {
    fail(`${opName}: batch dims must match \u2014 ${showShape(aShape)} vs ${showShape(bShape)}`, site);
  }
  const [M, Ka] = aShape.slice(-2);
  const [Kb, N] = bShape.slice(-2);
  if (Ka !== Kb) fail(`${opName}: inner dims don't match \u2014 last axis of lhs = ${Ka}, second-to-last of rhs = ${Kb}`, site);
  return [...aBatch, M, N];
}
171
/**
 * Output shape of `oneHot`: the indices shape with a trailing `depth` axis.
 * `depth` must be a positive integer. This matches `arange`'s check on `n`.
 * A fractional depth would otherwise produce a fractional output dim.
 * @param {string} opName - label used in error messages
 * @param {number[]} indicesShape
 * @param {number} depth
 * @param {object} site - call-site record for errors
 * @returns {number[]}
 * @throws {ShapeError} when depth is not a positive integer
 */
function inferOneHot(opName, indicesShape, depth, site) {
  if (!Number.isInteger(depth) || depth <= 0) {
    fail(`${opName}: depth must be a positive integer, got ${depth}`, site);
  }
  return [...indicesShape, depth];
}
175
/**
 * Validate a shape for causal-mask ops: rank >= 2 with equal last two axes.
 * Returns the input shape unchanged (same array).
 * @throws {ShapeError} on rank < 2 or a non-square trailing pair
 */
function inferWhereCausal(opName, aShape, site) {
  const rank = aShape.length;
  if (rank < 2) fail(`${opName}: requires rank >= 2, got ${showShape(aShape)}`, site);
  const [rows, cols] = aShape.slice(-2);
  if (rows !== cols) {
    fail(`${opName}: last two axes must be equal (square mask), got ${showShape(aShape)}`, site);
  }
  return aShape;
}
182
/**
 * Output shape of slicing the last axis to the half-open range [start, end).
 * `start` and `end` must be integers with 0 <= start < end <= size of the last axis.
 * @param {string} opName - label used in error messages
 * @param {number[]} aShape
 * @param {number} start
 * @param {number} end
 * @param {object} site - call-site record for errors
 * @returns {number[]}
 * @throws {ShapeError} for a 0-d input or an invalid range
 */
function inferSliceLastRange(opName, aShape, start, end, site) {
  if (aShape.length === 0) fail(`${opName}: cannot slice 0-d tensor`, site);
  const last = aShape[aShape.length - 1];
  // Bounds must be integers. Otherwise e.g. [0.5, 2) yields an output dim of 1.5.
  const inRange = Number.isInteger(start) && Number.isInteger(end) && start >= 0 && end <= last && start < end;
  if (!inRange) {
    fail(`${opName}: invalid range [${start}, ${end}) for last axis of size ${last}`, site);
  }
  return [...aShape.slice(0, -1), end - start];
}
190
/**
 * Validate broadcasting `aShape` to `targetShape`, aligned on trailing axes.
 * Each source axis must equal the target axis or be 1. Returns `targetShape` itself.
 * @throws {ShapeError} if the source rank exceeds the target rank or an axis is incompatible
 */
function inferBroadcastTo(opName, aShape, targetShape, site) {
  const offset = targetShape.length - aShape.length;
  if (offset < 0) {
    fail(`${opName}: source rank ${aShape.length} > target rank ${targetShape.length}`, site);
  }
  for (const [i, av] of aShape.entries()) {
    const tv = targetShape[offset + i];
    if (av !== 1 && av !== tv) {
      fail(`${opName}: cannot broadcast ${showShape(aShape)} to ${showShape(targetShape)} \u2014 axis ${i} (size ${av}) doesn't match target axis ${offset + i} (size ${tv}) and isn't 1`, site);
    }
  }
  return targetShape;
}
204
/**
 * Validate sum-reducing `aShape` down to `targetShape`, aligned on trailing axes.
 * Each target axis must be 1 or equal the source axis. Returns `targetShape` itself.
 * @throws {ShapeError} if the target rank exceeds the source rank or an axis is incompatible
 */
function inferSumToShape(opName, aShape, targetShape, site) {
  const offset = aShape.length - targetShape.length;
  if (offset < 0) {
    fail(`${opName}: target rank ${targetShape.length} > source rank ${aShape.length}`, site);
  }
  for (const [i, tv] of targetShape.entries()) {
    const av = aShape[offset + i];
    if (tv !== 1 && av !== tv) {
      fail(`${opName}: cannot sum-reduce ${showShape(aShape)} to ${showShape(targetShape)} \u2014 target axis ${i} (size ${tv}) must be 1 or match source`, site);
    }
  }
  return targetShape;
}
218
/**
 * Output shape of `where(cond, a, b)`. a and b are broadcast first, then the
 * result is broadcast with cond.
 * @throws {ShapeError} if either broadcast fails
 */
function inferWhere(opName, condShape, aShape, bShape, site) {
  const valueShape = broadcastTrailing(aShape, bShape);
  if (valueShape === null) {
    fail(`${opName}: a/b incompatible: ${showShape(aShape)} vs ${showShape(bShape)}`, site);
  }
  const outShape = broadcastTrailing(condShape, valueShape);
  if (outShape === null) {
    fail(`${opName}: cond ${showShape(condShape)} incompatible with broadcast(a, b) ${showShape(valueShape)}`, site);
  }
  return outShape;
}
225
/**
 * Validate that `x` and `dy` of a relu backward op have identical shapes.
 * Returns `xShape`.
 * @throws {ShapeError} on a mismatch
 */
function inferReluGrad(opName, xShape, dyShape, site) {
  const matches = shapesEqual(xShape, dyShape);
  if (!matches) {
    fail(`${opName}: x and dy must have matching shapes, got ${showShape(xShape)} and ${showShape(dyShape)}`, site);
  }
  return xShape;
}
231
+
232
+ // src/trace.ts
233
// Module-level trace state. `_current` is the graph that ops append to
// (null outside trace()/traceInto()). `_captureEnabled` gates capture(), and
// only trace() turns it on.
let _current = null;
let _captureEnabled = false;
235
/**
 * Return the graph of the active trace.
 * @throws {Error} when called outside trace()/traceInto()
 */
function currentGraph() {
  if (_current) return _current;
  throw new Error(
    "tensorgrad: ops can only be called inside trace(). Did you forget to wrap your forward pass?"
  );
}
243
/** Whether capture() should record tensors (true only while trace() runs its callback). */
function isCaptureEnabled() {
  const enabled = _captureEnabled;
  return enabled;
}
246
/**
 * Run `fn` with a fresh graph installed as the active trace, and record the
 * returned tensor(s) as the graph outputs. capture() is enabled while `fn`
 * runs. The trace state is always reset afterwards, even if `fn` throws.
 *
 * @param {() => object | object[]} fn - forward pass returning a tensor or an array of tensors
 * @returns {object} the traced graph
 * @throws {Error} if a trace is already active, or if `fn` returns a value that
 *   is not a tensor. Previously such a value was pushed as an `undefined`
 *   output id or surfaced as an opaque TypeError.
 */
function trace(fn) {
  if (_current) {
    throw new Error("tensorgrad: nested trace() is not supported");
  }
  const g = makeGraph();
  _current = g;
  _captureEnabled = true;
  try {
    const result = fn();
    const outputs = Array.isArray(result) ? result : [result];
    for (const t of outputs) {
      // Tensor ids are numeric indices into g.tensors (see addTensor).
      if (t == null || typeof t.id !== "number") {
        throw new Error("tensorgrad: trace() callback must return a tensor or an array of tensors");
      }
      g.outputs.push(t.id);
    }
  } finally {
    _current = null;
    _captureEnabled = false;
  }
  return g;
}
266
/**
 * Run `fn` with an existing graph `g` installed as the active trace, so new
 * ops are appended to it. Returns whatever `fn` returns. Unlike trace(), this
 * does not enable capture() and does not record outputs. The active graph is
 * cleared afterwards, even on error.
 * @throws {Error} if another trace is active
 */
function traceInto(g, fn) {
  if (_current) throw new Error("tensorgrad: traceInto() called while another trace is active");
  _current = g;
  try {
    const result = fn();
    return result;
  } finally {
    _current = null;
  }
}
277
/**
 * Throw if an op of one of `kinds` with the same `name` already exists in `g`.
 * @param {object} g - graph to scan
 * @param {string} name - candidate name
 * @param {string[]} kinds - op kinds that share this namespace
 * @param {string} label - noun used in the error message
 */
function assertNameUnused(g, name, kinds, label) {
  for (const op of g.ops) {
    if (op.name === name && kinds.includes(op.kind)) {
      throw new Error(`tensorgrad: ${label} name '${name}' already used in this trace`);
    }
  }
}
282
/**
 * Declare a named trainable-parameter input in the active trace.
 * Names are shared with tensorInput() and must be unique across both.
 * @returns {object} the input tensor
 */
function paramInput(name, shape, dtype = "f32") {
  const graph = currentGraph();
  assertNameUnused(graph, name, ["param_input", "tensor_input"], "input");
  return addOp(graph, "param_input", shape, dtype, captureSite("paramInput"), { name });
}
288
/**
 * Declare a named data input in the active trace.
 * Names are shared with paramInput() and must be unique across both.
 * @returns {object} the input tensor
 */
function tensorInput(name, shape, dtype = "f32") {
  const graph = currentGraph();
  assertNameUnused(graph, name, ["param_input", "tensor_input"], "input");
  return addOp(graph, "tensor_input", shape, dtype, captureSite("tensorInput"), { name });
}
294
/**
 * Declare a named state input in the active trace. `initValue` is stored on
 * the op node. State names form their own namespace, separate from
 * param/tensor input names.
 * @returns {object} the state tensor
 */
function stateInput(name, shape, dtype = "f32", initValue = 0) {
  const graph = currentGraph();
  assertNameUnused(graph, name, ["state_input"], "state");
  return addOp(graph, "state_input", shape, dtype, captureSite("stateInput"), { name, initValue });
}
300
+
301
+ // src/capture.ts
302
/**
 * Register tensor `t` under `name` in the active graph's captures map, and
 * return `t` unchanged. This is a no-op unless capture is enabled (i.e. inside trace()).
 * @throws {Error} if `name` was already captured in this trace
 */
function capture(name, t) {
  if (isCaptureEnabled()) {
    const { captures } = currentGraph();
    if (captures.has(name)) {
      throw new Error(
        `capture: name '${name}' already registered. Use unique names (e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`
      );
    }
    captures.set(name, t.id);
  }
  return t;
}
313
+
314
+ // src/ops.ts
315
/**
 * Shared builder for tensor-tensor element-wise ops. Both operands must
 * share a dtype, and shapes broadcast on trailing axes.
 * @param {string} name - user-facing op name for errors
 * @param {string} kind - IR op kind
 * @param {string} [outDtype] - result dtype (defaults to a's dtype; "bool" for comparisons)
 * @returns {object} the output tensor
 */
function binopOp(name, kind, a, b, outDtype = a.dtype) {
  const site = captureSite(name);
  if (a.dtype !== b.dtype) {
    throw new ShapeError(`${name}: dtype mismatch (${a.dtype} vs ${b.dtype})`, site);
  }
  const outShape = inferElementwiseBinop(name, a.shape, b.shape, site);
  const fields = { a: a.id, b: b.id };
  return addOp(currentGraph(), kind, outShape, outDtype, site, fields);
}
321
/** Element-wise a + b. A JS-number `b` is lowered to an add_scalar op. */
function add(a, b) {
  if (typeof b === "number") return addScalar(a, b);
  return binopOp("add", "add", a, b);
}
324
/** Element-wise a - b. A JS-number `b` becomes add_scalar with -b. */
function sub(a, b) {
  if (typeof b === "number") return addScalar(a, -b);
  return binopOp("sub", "sub", a, b);
}
327
/** Element-wise a * b. A JS-number `b` is lowered to a mul_scalar op. */
function mul(a, b) {
  if (typeof b === "number") return mulScalar(a, b);
  return binopOp("mul", "mul", a, b);
}
330
/**
 * Element-wise a / b. A JS-number `b` becomes mul_scalar by 1 / b.
 * @throws {ShapeError} when `b` is the number 0
 */
function div(a, b) {
  if (typeof b !== "number") return binopOp("div", "div", a, b);
  if (b === 0) {
    throw new ShapeError(`div: scalar divisor cannot be zero`, captureSite("div"));
  }
  return mulScalar(a, 1 / b);
}
337
/** Multiply every element of `a` by the JS number `scalar`. Shape and dtype are kept. */
function mulScalar(a, scalar) {
  const site = captureSite("mulScalar");
  const graph = currentGraph();
  return addOp(graph, "mul_scalar", a.shape, a.dtype, site, { a: a.id, scalar });
}
341
/** Add the JS number `scalar` to every element of `a`. Shape and dtype are kept. */
function addScalar(a, scalar) {
  const site = captureSite("addScalar");
  const graph = currentGraph();
  return addOp(graph, "add_scalar", a.shape, a.dtype, site, { a: a.id, scalar });
}
345
/**
 * Shared builder for element-wise unary ops. `name` doubles as the IR op kind.
 * @throws {ShapeError} unless `a` is f32
 */
function unary(name, a) {
  const site = captureSite(name);
  if (a.dtype !== "f32") {
    throw new ShapeError(`${name}: requires f32, got ${a.dtype}`, site);
  }
  const graph = currentGraph();
  const outShape = inferUnary(name, a.shape, site);
  return addOp(graph, name, outShape, "f32", site, { a: a.id });
}
350
// Element-wise f32 unary ops. Each forwards to unary() with its kind name.
var sqrt = (x) => unary("sqrt", x);
var rsqrt = (x) => unary("rsqrt", x);
var log = (x) => unary("log", x);
var exp = (x) => unary("exp", x);
var relu = (x) => unary("relu", x);
355
/**
 * Mean over the last axis, which is kept with size 1.
 * @throws {ShapeError} unless `a` is f32 and at least rank 1
 */
function meanLast(a) {
  const site = captureSite("meanLast");
  if (a.dtype !== "f32") {
    throw new ShapeError(`meanLast: requires f32, got ${a.dtype}`, site);
  }
  const reducedShape = inferMeanLast("meanLast", a.shape, site);
  return addOp(currentGraph(), "mean_last", reducedShape, a.dtype, site, { a: a.id });
}
361
/**
 * Sum over the last axis, which is dropped.
 * @throws {ShapeError} unless `a` is f32 and at least rank 1
 */
function sumLast(a) {
  const site = captureSite("sumLast");
  if (a.dtype !== "f32") {
    throw new ShapeError(`sumLast: requires f32, got ${a.dtype}`, site);
  }
  const reducedShape = inferSumLast("sumLast", a.shape, site);
  return addOp(currentGraph(), "sum_last", reducedShape, a.dtype, site, { a: a.id });
}
367
/** Sum of every element: flatten to 1-d, then sumLast, giving a 0-d result. */
function sumAll(a) {
  const flat = reshape(a, [-1]);
  return sumLast(flat);
}
370
/**
 * Reshape `a`. At most one -1 dim is allowed, and it is inferred.
 * The op stores the resolved shape, so no -1 remains in it.
 */
function reshape(a, newShape) {
  const site = captureSite("reshape");
  const resolved = inferReshape("reshape", a.shape, newShape, site);
  return addOp(currentGraph(), "reshape", resolved, a.dtype, site, { a: a.id, newShape: resolved });
}
375
/** Permute the axes of `a`: output axis i is input axis perm[i]. */
function transpose(a, perm) {
  const site = captureSite("transpose");
  const permuted = inferTranspose("transpose", a.shape, perm, site);
  return addOp(currentGraph(), "transpose", permuted, a.dtype, site, { a: a.id, perm });
}
380
/**
 * Exchange two axes of `a`. Negative axes count from the end.
 * Returns `a` itself when both axes resolve to the same index.
 * @throws {ShapeError} if either axis is out of range
 */
function swapAxes(a, axis1, axis2) {
  const rank = a.shape.length;
  const resolve = (axis) => (axis < 0 ? rank + axis : axis);
  const i1 = resolve(axis1);
  const i2 = resolve(axis2);
  const site = captureSite("swapAxes");
  const inRange = (i) => i >= 0 && i < rank;
  if (!inRange(i1) || !inRange(i2)) {
    throw new ShapeError(`swapAxes: axis out of range \u2014 got (${axis1}, ${axis2}) for rank-${rank} tensor`, site);
  }
  if (i1 === i2) return a;
  const perm = [...Array(rank).keys()];
  [perm[i1], perm[i2]] = [i2, i1];
  return transpose(a, perm);
}
395
/**
 * Matrix product of an f32 lhs of rank >= 2 with an f32 rank-2 rhs.
 * @throws {ShapeError} on non-f32 inputs or incompatible shapes
 */
function matmul(a, b) {
  const site = captureSite("matmul");
  const bothF32 = a.dtype === "f32" && b.dtype === "f32";
  if (!bothF32) {
    throw new ShapeError(`matmul: requires f32, got ${a.dtype} and ${b.dtype}`, site);
  }
  const productShape = inferMatmul("matmul", a.shape, b.shape, site);
  return addOp(currentGraph(), "matmul", productShape, "f32", site, { a: a.id, b: b.id });
}
403
/**
 * Batched matrix product of two f32 tensors with equal rank and identical batch dims.
 * @throws {ShapeError} on non-f32 inputs or incompatible shapes
 */
function matmulBatched(a, b) {
  const site = captureSite("matmulBatched");
  const bothF32 = a.dtype === "f32" && b.dtype === "f32";
  if (!bothF32) {
    throw new ShapeError(`matmulBatched: requires f32, got ${a.dtype} and ${b.dtype}`, site);
  }
  const productShape = inferMatmulBatched("matmulBatched", a.shape, b.shape, site);
  return addOp(currentGraph(), "matmul_batched", productShape, "f32", site, { a: a.id, b: b.id });
}
411
/**
 * One-hot encode i32 `indices` along a new trailing axis of size `depth`.
 * @throws {ShapeError} if indices are not i32, or depth is invalid
 */
function oneHot(indices, depth, dtype = "f32") {
  const site = captureSite("oneHot");
  if (indices.dtype !== "i32") {
    throw new ShapeError(`oneHot: indices must be i32, got ${indices.dtype}`, site);
  }
  const encodedShape = inferOneHot("oneHot", indices.shape, depth, site);
  const fields = { indices: indices.id, depth, dtype };
  return addOp(currentGraph(), "one_hot", encodedShape, dtype, site, fields);
}
419
/**
 * Embedding lookup, expressed as matmul(oneHot(indices, vocab), table).
 * @param {object} table - rank-2 tensor whose first axis is used as the one-hot depth
 * @param {object} indices - i32 tensor
 * @throws {ShapeError} if table is not rank 2 or indices are not i32
 */
function embedding(table, indices) {
  const site = captureSite("embedding");
  if (table.shape.length !== 2) {
    throw new ShapeError(`embedding: table must be 2-d [vocab, dim], got ${showShape(table.shape)}`, site);
  }
  if (indices.dtype !== "i32") {
    throw new ShapeError(`embedding: indices must be i32, got ${indices.dtype}`, site);
  }
  const vocab = table.shape[0];
  const encoded = oneHot(indices, vocab, "f32");
  return matmul(encoded, table);
}
429
/**
 * 1-d tensor of shape [n] built by an "arange" op.
 * @throws {ShapeError} unless n is a positive integer
 */
function arange(n, dtype = "i32") {
  const site = captureSite("arange");
  const isPositiveInt = Number.isInteger(n) && n > 0;
  if (!isPositiveInt) {
    throw new ShapeError(`arange: n must be a positive integer, got ${n}`, site);
  }
  return addOp(currentGraph(), "arange", [n], dtype, site, { n, dtype });
}
436
/**
 * Fused causal softmax over the last axis (op kind "softmax_causal_last").
 * The input must be f32 with square trailing axes. The output shape equals the input shape.
 */
function softmaxCausalLast(a) {
  const site = captureSite("softmaxCausalLast");
  if (a.dtype !== "f32") {
    throw new ShapeError(`softmaxCausalLast: requires f32, got ${a.dtype}`, site);
  }
  // Only validates; the output keeps a.shape.
  inferWhereCausal("softmaxCausalLast", a.shape, site);
  return addOp(currentGraph(), "softmax_causal_last", a.shape, "f32", site, { a: a.id });
}
442
/** Log-softmax over the last axis of an f32 tensor. The output has the same shape. */
function logSoftmaxLast(a) {
  const site = captureSite("logSoftmaxLast");
  if (a.dtype !== "f32") {
    throw new ShapeError(`logSoftmaxLast: requires f32, got ${a.dtype}`, site);
  }
  const graph = currentGraph();
  return addOp(graph, "log_softmax_last", a.shape, "f32", site, { a: a.id });
}
447
/**
 * Causal masking op (kind "where_causal") that carries `fillValue`.
 * The input must be f32 with square trailing axes. The output keeps the input shape.
 */
function whereCausal(a, fillValue) {
  const site = captureSite("whereCausal");
  if (a.dtype !== "f32") {
    throw new ShapeError(`whereCausal: requires f32, got ${a.dtype}`, site);
  }
  inferWhereCausal("whereCausal", a.shape, site);
  return addOp(currentGraph(), "where_causal", a.shape, "f32", site, { a: a.id, fillValue });
}
453
/** Slice the last axis of `a` to the half-open range [start, end). */
function sliceLastRange(a, start, end) {
  const site = captureSite("sliceLastRange");
  const slicedShape = inferSliceLastRange("sliceLastRange", a.shape, start, end, site);
  const fields = { a: a.id, start, end };
  return addOp(currentGraph(), "slice_last_range", slicedShape, a.dtype, site, fields);
}
458
/** Broadcast `a` to `targetShape` (trailing-axis alignment; size-1 axes stretch). */
function broadcastTo(a, targetShape) {
  const site = captureSite("broadcastTo");
  inferBroadcastTo("broadcastTo", a.shape, targetShape, site);
  const graph = currentGraph();
  return addOp(graph, "broadcast_to", targetShape, a.dtype, site, { a: a.id, targetShape });
}
463
/** Sum-reduce `a` down to `targetShape`. This is the inverse of broadcastTo, used by autograd. */
function sumToShape(a, targetShape) {
  const site = captureSite("sumToShape");
  inferSumToShape("sumToShape", a.shape, targetShape, site);
  const graph = currentGraph();
  return addOp(graph, "sum_to_shape", targetShape, a.dtype, site, { a: a.id, targetShape });
}
468
/** A 0-d constant tensor holding `value`. */
function constScalar(value, dtype = "f32") {
  const site = captureSite("constScalar");
  const scalarShape = [];
  return addOp(currentGraph(), "const_scalar", scalarShape, dtype, site, { value, dtype });
}
472
// Element-wise comparisons. Operands broadcast like other binops, and the result dtype is "bool".
var less = (lhs, rhs) => binopOp("less", "less", lhs, rhs, "bool");
var greater = (lhs, rhs) => binopOp("greater", "greater", lhs, rhs, "bool");
474
/**
 * Element-wise select: `a` where `cond` holds, else `b`.
 * `cond` must be bool. `a` and `b` must share a dtype. All three broadcast together.
 */
function where(cond, a, b) {
  const site = captureSite("where");
  if (cond.dtype !== "bool") {
    throw new ShapeError(`where: cond must be bool, got ${cond.dtype}`, site);
  }
  if (a.dtype !== b.dtype) {
    throw new ShapeError(`where: a/b dtype mismatch (${a.dtype} vs ${b.dtype})`, site);
  }
  const selectedShape = inferWhere("where", cond.shape, a.shape, b.shape, site);
  const fields = { cond: cond.id, a: a.id, b: b.id };
  return addOp(currentGraph(), "where", selectedShape, a.dtype, site, fields);
}
481
/**
 * Backward op for relu, built from the forward input `x` and the upstream
 * gradient `dy`. Both must be f32 with identical shapes.
 */
function reluGrad(x, dy) {
  const site = captureSite("reluGrad");
  const bothF32 = x.dtype === "f32" && dy.dtype === "f32";
  if (!bothF32) {
    throw new ShapeError(`reluGrad: requires f32, got ${x.dtype} and ${dy.dtype}`, site);
  }
  const gradShape = inferReluGrad("reluGrad", x.shape, dy.shape, site);
  return addOp(currentGraph(), "relu_grad", gradShape, "f32", site, { x: x.id, dy: dy.id });
}
489
/**
 * Adam first-moment update op built from `m`, gradient `g`, and coefficient `b1`.
 * `m` and `g` must both be f32 with identical shapes.
 */
function adamUpdateM(m, g, b1) {
  const site = captureSite("adamUpdateM");
  if (m.dtype !== "f32" || g.dtype !== "f32") {
    throw new ShapeError(`adamUpdateM: requires f32`, site);
  }
  const sameShape = m.shape.length === g.shape.length && m.shape.every((d, i) => d === g.shape[i]);
  if (!sameShape) {
    throw new ShapeError(`adamUpdateM: shape mismatch`, site);
  }
  return addOp(currentGraph(), "adam_update_m", m.shape, "f32", site, { m: m.id, g: g.id, b1 });
}
497
/**
 * Adam second-moment update op built from `v`, gradient `g`, and coefficient `b2`.
 * `v` and `g` must both be f32 with identical shapes.
 */
function adamUpdateV(v, g, b2) {
  const site = captureSite("adamUpdateV");
  if (v.dtype !== "f32" || g.dtype !== "f32") {
    throw new ShapeError(`adamUpdateV: requires f32`, site);
  }
  const sameShape = v.shape.length === g.shape.length && v.shape.every((d, i) => d === g.shape[i]);
  if (!sameShape) {
    throw new ShapeError(`adamUpdateV: shape mismatch`, site);
  }
  return addOp(currentGraph(), "adam_update_v", v.shape, "f32", site, { v: v.id, g: g.id, b2 });
}
505
/**
 * Adam parameter-update op.
 *
 * @param {object} p - f32 parameter tensor
 * @param {object} mNew - f32 tensor with p's shape (updated first moment)
 * @param {object} vNew - f32 tensor with p's shape (updated second moment)
 * @param {object} lrt - 0-d f32 tensor
 * @param {number} eps
 * @param {number|object} [decayShrink=1] - JS number, or a 0-d f32 tensor.
 *   A tensor is stored as `decayShrinkTensor`, with `decayShrink` set to 1.
 * @returns {object} output tensor with p's shape
 * @throws {ShapeError} on dtype/shape violations. `vNew`'s shape and the
 *   `mNew`/`vNew` dtypes are now validated too; previously a mismatched `vNew`
 *   passed through unchecked.
 */
function adamUpdateP(p, mNew, vNew, lrt, eps, decayShrink = 1) {
  const site = captureSite("adamUpdateP");
  const sameShape = (x, y) => x.shape.length === y.shape.length && x.shape.every((d, i) => d === y.shape[i]);
  if (p.dtype !== "f32" || mNew.dtype !== "f32" || vNew.dtype !== "f32") {
    throw new ShapeError(`adamUpdateP: requires f32`, site);
  }
  if (lrt.dtype !== "f32" || lrt.shape.length !== 0) {
    throw new ShapeError(`adamUpdateP: lrt must be a 0-d f32 scalar`, site);
  }
  if (!sameShape(p, mNew)) {
    throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site);
  }
  if (!sameShape(p, vNew)) {
    throw new ShapeError(`adamUpdateP: p/vNew shape mismatch`, site);
  }
  const isTensor = typeof decayShrink === "object";
  if (isTensor && (decayShrink.dtype !== "f32" || decayShrink.shape.length !== 0)) {
    throw new ShapeError(`adamUpdateP: decayShrink tensor must be a 0-d f32 scalar`, site);
  }
  return addOp(currentGraph(), "adam_update_p", p.shape, "f32", site, {
    p: p.id,
    mNew: mNew.id,
    vNew: vNew.id,
    lrt: lrt.id,
    eps,
    decayShrink: isTensor ? 1 : decayShrink,
    decayShrinkTensor: isTensor ? decayShrink.id : null,
  });
}
530
+
531
+ // src/grad.ts
532
/**
 * Append reverse-mode gradient ops to a traced graph, in place.
 *
 * The graph must have exactly one output: a rank-0 loss. The loss cotangent
 * is seeded with constScalar(1). The forward ops are then walked in reverse,
 * calling runTransposeRule for every op whose output received a cotangent.
 * Each param_input gets its accumulated cotangent as its gradient. Params the
 * loss never reaches get zeros broadcast to the param's shape.
 *
 * @param {object} graph - traced graph (mutated: backward ops are appended)
 * @returns {{graph: object, paramGrads: Object<string, object>, loss: object}}
 * @throws {Error} if the output count is not 1 or the loss is not rank 0
 */
function appendGrad(graph) {
  const { outputs } = graph;
  if (outputs.length !== 1) {
    throw new Error(`autograd: expected graph with exactly 1 output (the loss); got ${graph.outputs.length}`);
  }
  const [lossId] = outputs;
  const lossTensor = graph.tensors[lossId];
  if (lossTensor.shape.length !== 0) {
    throw new Error(
      `autograd: loss must be a rank-0 scalar; got shape [${lossTensor.shape.join(", ")}]. Reduce with sumLast / mulScalar to a scalar before calling appendGrad.`
    );
  }
  // Snapshot the forward ops: traceInto appends backward ops to graph.ops,
  // and those must not be revisited.
  const forwardOps = graph.ops.slice();
  const cotangents = /* @__PURE__ */ new Map();
  return traceInto(graph, () => {
    cotangents.set(lossId, constScalar(1, "f32"));
    let i = forwardOps.length;
    while (i-- > 0) {
      const op = forwardOps[i];
      const outCotan = cotangents.get(op.out);
      if (outCotan) runTransposeRule(op, outCotan, graph, cotangents);
    }
    const paramGrads = {};
    for (const op of forwardOps) {
      if (op.kind !== "param_input") continue;
      const cotan = cotangents.get(op.out);
      if (cotan) {
        paramGrads[op.name] = cotan;
      } else {
        const param = graph.tensors[op.out];
        paramGrads[op.name] = broadcastTo(constScalar(0, param.dtype), param.shape);
      }
    }
    return { graph, paramGrads, loss: lossTensor };
  });
}
568
/**
 * Add `contribution` to the cotangent stored for `inputId`. If none exists
 * yet, `contribution` is stored as-is; otherwise it is summed with an `add` op.
 */
function accumulate(cotangents, inputId, contribution) {
  const prior = cotangents.get(inputId);
  cotangents.set(inputId, prior ? add(prior, contribution) : contribution);
}
576
/**
 * Reduce a broadcast cotangent back to an operand's shape. The cotangent is
 * returned unchanged when the shapes already match.
 */
function unbroadcast(cotan, toShape) {
  return shapesEqual(cotan.shape, toShape) ? cotan : sumToShape(cotan, toShape);
}
580
+ function runTransposeRule(op, outCotan, graph, cotangents) {
581
+ const tensorOf = (id) => graph.tensors[id];
582
+ switch (op.kind) {
583
+ // ---- Leaves: no inputs to accumulate into. -----------------------------
584
+ case "param_input":
585
+ case "tensor_input":
586
+ case "state_input":
587
+ case "arange":
588
+ case "const_scalar":
589
+ return;
590
+ // ---- Element-wise binops (with broadcast) ------------------------------
591
+ // c = a op b; reduce cotan back to each operand's shape.
592
+ case "add": {
593
+ const a = tensorOf(op.a), b = tensorOf(op.b);
594
+ accumulate(cotangents, op.a, unbroadcast(outCotan, a.shape));
595
+ accumulate(cotangents, op.b, unbroadcast(outCotan, b.shape));
596
+ return;
597
+ }
598
+ case "sub": {
599
+ const a = tensorOf(op.a), b = tensorOf(op.b);
600
+ accumulate(cotangents, op.a, unbroadcast(outCotan, a.shape));
601
+ accumulate(cotangents, op.b, unbroadcast(mulScalar(outCotan, -1), b.shape));
602
+ return;
603
+ }
604
+ case "mul": {
605
+ const a = tensorOf(op.a), b = tensorOf(op.b);
606
+ accumulate(cotangents, op.a, unbroadcast(mul(outCotan, b), a.shape));
607
+ accumulate(cotangents, op.b, unbroadcast(mul(outCotan, a), b.shape));
608
+ return;
609
+ }
610
+ case "div": {
611
+ const a = tensorOf(op.a), b = tensorOf(op.b);
612
+ accumulate(cotangents, op.a, unbroadcast(div(outCotan, b), a.shape));
613
+ const numer = mul(outCotan, a);
614
+ const bSq = mul(b, b);
615
+ accumulate(cotangents, op.b, unbroadcast(mulScalar(div(numer, bSq), -1), b.shape));
616
+ return;
617
+ }
618
+ // ---- Element-wise scalar binops (scalar is a JS number, not a tensor) -
619
+ case "mul_scalar": {
620
+ accumulate(cotangents, op.a, mulScalar(outCotan, op.scalar));
621
+ return;
622
+ }
623
+ case "add_scalar": {
624
+ accumulate(cotangents, op.a, outCotan);
625
+ return;
626
+ }
627
+ // ---- Unary -------------------------------------------------------------
628
+ case "sqrt": {
629
+ const c = tensorOf(op.out);
630
+ accumulate(cotangents, op.a, mulScalar(div(outCotan, c), 0.5));
631
+ return;
632
+ }
633
+ case "rsqrt": {
634
+ const c = tensorOf(op.out);
635
+ const c3 = mul(mul(c, c), c);
636
+ accumulate(cotangents, op.a, mulScalar(mul(outCotan, c3), -0.5));
637
+ return;
638
+ }
639
+ case "log": {
640
+ const a = tensorOf(op.a);
641
+ accumulate(cotangents, op.a, div(outCotan, a));
642
+ return;
643
+ }
644
+ case "exp": {
645
+ const c = tensorOf(op.out);
646
+ accumulate(cotangents, op.a, mul(outCotan, c));
647
+ return;
648
+ }
649
+ case "relu": {
650
+ const a = tensorOf(op.a);
651
+ accumulate(cotangents, op.a, reluGrad(a, outCotan));
652
+ return;
653
+ }
654
+ // ---- Reductions over last axis ---------------------------------------
655
+ case "mean_last": {
656
+ const a = tensorOf(op.a);
657
+ const D = a.shape[a.shape.length - 1];
658
+ const expanded = broadcastTo(outCotan, a.shape);
659
+ accumulate(cotangents, op.a, mulScalar(expanded, 1 / D));
660
+ return;
661
+ }
662
+ case "sum_last": {
663
+ const a = tensorOf(op.a);
664
+ const withKeep = reshape(outCotan, [...outCotan.shape, 1]);
665
+ accumulate(cotangents, op.a, broadcastTo(withKeep, a.shape));
666
+ return;
667
+ }
668
+ // ---- Shape ------------------------------------------------------------
669
+ case "reshape": {
670
+ const a = tensorOf(op.a);
671
+ accumulate(cotangents, op.a, reshape(outCotan, a.shape));
672
+ return;
673
+ }
674
+ case "transpose": {
675
+ const inv = invertPerm(op.perm);
676
+ accumulate(cotangents, op.a, transpose(outCotan, inv));
677
+ return;
678
+ }
679
+ // ---- Linear algebra ---------------------------------------------------
680
+ case "matmul": {
681
+ const a = tensorOf(op.a), b = tensorOf(op.b);
682
+ accumulate(cotangents, op.a, matmul(outCotan, swapAxes(b, -1, -2)));
683
+ const aT = swapAxes(a, -1, -2);
684
+ let perBatchDb;
685
+ if (a.shape.length > 2) {
686
+ perBatchDb = matmulBatched(aT, outCotan);
687
+ } else {
688
+ perBatchDb = matmul(aT, outCotan);
689
+ }
690
+ accumulate(cotangents, op.b, sumToShape(perBatchDb, b.shape));
691
+ return;
692
+ }
693
+ case "matmul_batched": {
694
+ const a = tensorOf(op.a), b = tensorOf(op.b);
695
+ accumulate(cotangents, op.a, matmulBatched(outCotan, swapAxes(b, -1, -2)));
696
+ accumulate(cotangents, op.b, matmulBatched(swapAxes(a, -1, -2), outCotan));
697
+ return;
698
+ }
699
+ // ---- Indexing / casting (no gradient through integer indices) --------
700
+ case "one_hot":
701
+ return;
702
+ // ---- Slicing ---------------------------------------------------------
703
+ case "slice_last_range": {
704
+ const a = tensorOf(op.a);
705
+ throw new Error(
706
+ `autograd: slice_last_range backward not implemented yet (would need a scatter-style op or a Concat op). Workaround for now: avoid taking gradients through slices by using separate matmuls for Q/K/V instead of a fused W_qkv. Tensor: ${a.shape} -> ${tensorOf(op.out).shape}`
707
+ );
708
+ }
709
+ // ---- Broadcast / un-broadcast (autograd infrastructure) ---------------
710
+ case "broadcast_to": {
711
+ const a = tensorOf(op.a);
712
+ accumulate(cotangents, op.a, sumToShape(outCotan, a.shape));
713
+ return;
714
+ }
715
+ case "sum_to_shape": {
716
+ const a = tensorOf(op.a);
717
+ accumulate(cotangents, op.a, broadcastTo(outCotan, a.shape));
718
+ return;
719
+ }
720
+ // ---- ML primitives ---------------------------------------------------
721
+ case "log_softmax_last": {
722
+ const c = tensorOf(op.out);
723
+ const sm = exp(c);
724
+ const sumDc = sumLast(outCotan);
725
+ const sumDcKeep = reshape(sumDc, [...sumDc.shape, 1]);
726
+ const term = mul(sm, broadcastTo(sumDcKeep, c.shape));
727
+ accumulate(cotangents, op.a, sub(outCotan, term));
728
+ return;
729
+ }
730
+ case "softmax_causal_last": {
731
+ const c = tensorOf(op.out);
732
+ const dcXc = mul(outCotan, c);
733
+ const s = sumLast(dcXc);
734
+ const sKeep = reshape(s, [...s.shape, 1]);
735
+ const inner = sub(outCotan, broadcastTo(sKeep, c.shape));
736
+ accumulate(cotangents, op.a, mul(inner, c));
737
+ return;
738
+ }
739
+ // ---- Comparisons + select ---------------------------------------------
740
+ case "less":
741
+ case "greater":
742
+ return;
743
+ case "where": {
744
+ const cond = tensorOf(op.cond);
745
+ const a = tensorOf(op.a);
746
+ const b = tensorOf(op.b);
747
+ const zeroA = broadcastTo(constScalar(0, a.dtype), outCotan.shape);
748
+ const zeroB = broadcastTo(constScalar(0, b.dtype), outCotan.shape);
749
+ accumulate(cotangents, op.a, unbroadcast(where(cond, outCotan, zeroA), a.shape));
750
+ accumulate(cotangents, op.b, unbroadcast(where(cond, zeroB, outCotan), b.shape));
751
+ return;
752
+ }
753
+ case "where_causal": {
754
+ throw new Error(
755
+ `autograd: where_causal backward not yet implemented. Use softmax_causal_last (which fuses the mask + softmax) instead.`
756
+ );
757
+ }
758
+ // ---- Adam ops are post-autograd; no backward through them. ----------
759
+ case "adam_update_m":
760
+ case "adam_update_v":
761
+ case "adam_update_p":
762
+ throw new Error(`autograd: cannot differentiate through ${op.kind}`);
763
+ // ---- relu_grad has no further backward (autograd-internal) ----------
764
+ case "relu_grad": {
765
+ throw new Error(
766
+ `autograd: cannot take second-order gradient through relu_grad. Phase 2 does not support higher-order autodiff.`
767
+ );
768
+ }
769
+ default: {
770
+ const _exhaustive = op;
771
+ void _exhaustive;
772
+ throw new Error(`autograd: unhandled op kind ${op.kind}`);
773
+ }
774
+ }
775
+ }
776
+ function invertPerm(perm) {
777
+ const inv = new Array(perm.length);
778
+ for (let i = 0; i < perm.length; i++) inv[perm[i]] = i;
779
+ return inv;
780
+ }
781
+
782
+ // src/adam.ts
783
+ var lr = {
784
+ constant: (value) => ({ kind: "constant", value }),
785
+ /** Linearly interpolate from `peak` at step 1 to `final` at step `steps`,
786
+ * then hold at `final`. Matches `peak + (final - peak) * min(step/steps, 1)`. */
787
+ linearDecay: (opts) => ({ kind: "linearDecay", ...opts }),
788
+ /** Half-cosine from `peak` at step 1 down to `final` at step `steps`,
789
+ * then hold at `final`. */
790
+ cosineDecay: (opts) => ({ kind: "cosineDecay", ...opts }),
791
+ /** Linear ramp from 0 to `peakLr` over `warmupSteps` steps, then hand off
792
+ * to `after` (offset so step 1 of `after` = first post-warmup step). */
793
+ warmup: (opts) => ({ kind: "warmup", ...opts })
794
+ };
795
+ function resolveLR(schedule, step) {
796
+ if (typeof schedule === "number") return schedule;
797
+ switch (schedule.kind) {
798
+ case "constant":
799
+ return schedule.value;
800
+ case "linearDecay": {
801
+ const f = Math.min(step / schedule.steps, 1);
802
+ return schedule.peak + (schedule.final - schedule.peak) * f;
803
+ }
804
+ case "cosineDecay": {
805
+ const f = Math.min(step / schedule.steps, 1);
806
+ return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));
807
+ }
808
+ case "warmup": {
809
+ if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);
810
+ return resolveLR(schedule.after, step - schedule.warmupSteps);
811
+ }
812
+ }
813
+ }
814
+ function isLRDynamic(schedule) {
815
+ if (typeof schedule === "number") return false;
816
+ return schedule.kind !== "constant";
817
+ }
818
+ function appendAdam(graph, paramGrads, paramTensors, config, decayFlags) {
819
+ const lrIsScheduled = isLRDynamic(config.lr);
820
+ const initialLr = resolveLR(config.lr, 1);
821
+ const fullConfig = {
822
+ lr: config.lr,
823
+ b1: config.b1 ?? 0.9,
824
+ b2: config.b2 ?? 0.999,
825
+ eps: config.eps ?? 1e-8,
826
+ weightDecay: config.weightDecay ?? 0,
827
+ decayFilter: config.decayFilter ?? (() => true),
828
+ lrIsScheduled
829
+ };
830
+ const writebacks = [];
831
+ const lrtInputName = "_adam_lrt";
832
+ let decayShrinkInputName = null;
833
+ return traceInto(graph, () => {
834
+ const lrt = tensorInput(lrtInputName, [], "f32");
835
+ const decayedNames = new Set(
836
+ fullConfig.weightDecay > 0 ? Object.keys(paramGrads).filter((name) => decayFlags && name in decayFlags ? decayFlags[name] : fullConfig.decayFilter(name)) : []
837
+ );
838
+ let decayShrinkScalar = null;
839
+ if (lrIsScheduled && decayedNames.size > 0) {
840
+ decayShrinkInputName = "_adam_decay_shrink";
841
+ decayShrinkScalar = tensorInput(decayShrinkInputName, [], "f32");
842
+ }
843
+ for (const name of Object.keys(paramGrads)) {
844
+ const p = paramTensors[name];
845
+ const g = paramGrads[name];
846
+ if (!p) throw new Error(`appendAdam: missing param tensor for '${name}'`);
847
+ if (!g) throw new Error(`appendAdam: missing gradient for '${name}'`);
848
+ const mState = stateInput(`adam_m_${name}`, p.shape, "f32", 0);
849
+ const vState = stateInput(`adam_v_${name}`, p.shape, "f32", 0);
850
+ const decayShrink = !decayedNames.has(name) ? 1 : decayShrinkScalar !== null ? decayShrinkScalar : 1 - initialLr * fullConfig.weightDecay;
851
+ const newM = adamUpdateM(mState, g, fullConfig.b1);
852
+ const newV = adamUpdateV(vState, g, fullConfig.b2);
853
+ const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink);
854
+ writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: "state" });
855
+ writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: "state" });
856
+ writebacks.push({ source: newP, destName: name, destKind: "param" });
857
+ }
858
+ return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig };
859
+ });
860
+ }
861
+
862
+ // src/buffers.ts
863
+ var dtypeBytes = { f32: 4, i32: 4, bool: 4 };
864
+ function planBuffers(graph, paramGrads, writebackDecls = []) {
865
+ const buffers = [];
866
+ const tensorToBuffer = /* @__PURE__ */ new Map();
867
+ const paramsByName = /* @__PURE__ */ new Map();
868
+ const inputsByName = /* @__PURE__ */ new Map();
869
+ const paramGradsByName = /* @__PURE__ */ new Map();
870
+ const statesByName = /* @__PURE__ */ new Map();
871
+ const gradTensorIdToName = /* @__PURE__ */ new Map();
872
+ for (const [name, tensor] of Object.entries(paramGrads)) {
873
+ gradTensorIdToName.set(tensor.id, name);
874
+ }
875
+ const opByOutId = /* @__PURE__ */ new Map();
876
+ for (const op of graph.ops) opByOutId.set(op.out, op);
877
+ const outputSet = new Set(graph.outputs);
878
+ for (const t of graph.tensors) {
879
+ const op = opByOutId.get(t.id);
880
+ let kind = "intermediate";
881
+ let name = null;
882
+ let initValue;
883
+ if (op?.kind === "param_input") {
884
+ kind = "param";
885
+ name = op.name;
886
+ } else if (op?.kind === "tensor_input") {
887
+ kind = "tensor_input";
888
+ name = op.name;
889
+ } else if (op?.kind === "state_input") {
890
+ kind = "state";
891
+ name = op.name;
892
+ initValue = op.initValue;
893
+ } else if (gradTensorIdToName.has(t.id)) {
894
+ kind = "param_grad";
895
+ name = gradTensorIdToName.get(t.id);
896
+ } else if (outputSet.has(t.id)) {
897
+ kind = "output";
898
+ }
899
+ const spec = {
900
+ id: t.id,
901
+ byteSize: Math.max(4, shapeSize(t.shape) * dtypeBytes[t.dtype]),
902
+ dtype: t.dtype,
903
+ shape: t.shape,
904
+ kind,
905
+ name,
906
+ ...initValue !== void 0 ? { initValue } : {}
907
+ };
908
+ buffers.push(spec);
909
+ tensorToBuffer.set(t.id, t.id);
910
+ if (kind === "param") paramsByName.set(name, t.id);
911
+ if (kind === "tensor_input") inputsByName.set(name, t.id);
912
+ if (kind === "param_grad") paramGradsByName.set(name, t.id);
913
+ if (kind === "state") statesByName.set(name, t.id);
914
+ }
915
+ const outputBufferIds = graph.outputs.map((id) => tensorToBuffer.get(id));
916
+ const writebacks = writebackDecls.map((decl) => {
917
+ const sourceBufId = tensorToBuffer.get(decl.source.id);
918
+ if (sourceBufId === void 0) {
919
+ throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
920
+ }
921
+ const destBufId = decl.destKind === "param" ? paramsByName.get(decl.destName) : statesByName.get(decl.destName);
922
+ if (destBufId === void 0) {
923
+ throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
924
+ }
925
+ const sourceSpec = buffers[sourceBufId];
926
+ const destSpec = buffers[destBufId];
927
+ if (sourceSpec.byteSize !== destSpec.byteSize) {
928
+ throw new Error(
929
+ `planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' (source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`
930
+ );
931
+ }
932
+ return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
933
+ });
934
+ const capturesByName = /* @__PURE__ */ new Map();
935
+ for (const [name, tensorId] of graph.captures) {
936
+ const bufId = tensorToBuffer.get(tensorId);
937
+ if (bufId === void 0) {
938
+ throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`);
939
+ }
940
+ capturesByName.set(name, bufId);
941
+ }
942
+ return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks };
943
+ }
944
+
945
+ // src/codegen.ts
946
+ var WG_SIZE = 256;
947
+ var GID_LINE = "let i = gid.x + gid.y * 16776960u;";
948
+ function emitKernels(graph, plan) {
949
+ const out = [];
950
+ for (let i = 0; i < graph.ops.length; i++) {
951
+ const op = graph.ops[i];
952
+ const spec = emitKernel(op, graph, plan, i);
953
+ out.push(spec);
954
+ }
955
+ return out;
956
+ }
957
+ function emitKernel(op, graph, plan, opIndex) {
958
+ const tof = (id) => graph.tensors[id];
959
+ const buf = (tensorId) => plan.tensorToBuffer.get(tensorId);
960
+ const empty = () => ({ opIndex, opKind: op.kind, wgsl: "", bindings: [], threads: 0, workgroupSize: WG_SIZE });
961
+ switch (op.kind) {
962
+ // ---- Leaves: data is supplied externally; no kernel ---------------------
963
+ case "param_input":
964
+ case "tensor_input":
965
+ case "state_input":
966
+ return empty();
967
+ // ---- arange / const_scalar: kernel that fills the buffer once -----------
968
+ case "arange": {
969
+ const out = tof(op.out);
970
+ const wgsl = `
971
+ @group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(out.dtype)}>;
972
+ @compute @workgroup_size(${WG_SIZE})
973
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
974
+ ${GID_LINE}
975
+ if (i >= ${op.n}u) { return; }
976
+ buf[i] = ${castFromI32("i32(i)", out.dtype)};
977
+ }`.trim();
978
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: op.n, workgroupSize: WG_SIZE };
979
+ }
980
+ case "const_scalar": {
981
+ const wgsl = `
982
+ @group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(op.dtype)}>;
983
+ @compute @workgroup_size(1)
984
+ fn main() {
985
+ buf[0] = ${wgslLiteral(op.value, op.dtype)};
986
+ }`.trim();
987
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: 1, workgroupSize: 1 };
988
+ }
989
+ // ---- Element-wise binops with broadcast --------------------------------
990
+ case "add":
991
+ case "sub":
992
+ case "mul":
993
+ case "div": {
994
+ const out = tof(op.out);
995
+ const a = tof(op.a);
996
+ const b = tof(op.b);
997
+ const opStr = { add: "+", sub: "-", mul: "*", div: "/" }[op.kind];
998
+ const total = shapeSize(out.shape);
999
+ const wgsl = `
1000
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1001
+ @group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
1002
+ @group(0) @binding(2) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1003
+ @compute @workgroup_size(${WG_SIZE})
1004
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1005
+ ${GID_LINE}
1006
+ if (i >= ${total}u) { return; }
1007
+ ${broadcastIndexBlock("i", out.shape, a.shape, "aIdx")}
1008
+ ${broadcastIndexBlock("i", out.shape, b.shape, "bIdx")}
1009
+ out[i] = a[aIdx] ${opStr} b[bIdx];
1010
+ }`.trim();
1011
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1012
+ }
1013
+ // ---- Element-wise scalar binops (scalar baked into WGSL) ---------------
1014
+ case "mul_scalar":
1015
+ case "add_scalar": {
1016
+ const out = tof(op.out);
1017
+ const a = tof(op.a);
1018
+ const opStr = op.kind === "mul_scalar" ? "*" : "+";
1019
+ const total = shapeSize(out.shape);
1020
+ const lit = wgslLiteral(op.scalar, out.dtype);
1021
+ const wgsl = `
1022
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1023
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1024
+ @compute @workgroup_size(${WG_SIZE})
1025
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1026
+ ${GID_LINE}
1027
+ if (i >= ${total}u) { return; }
1028
+ out[i] = a[i] ${opStr} ${lit};
1029
+ }`.trim();
1030
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1031
+ }
1032
+ // ---- Unary -------------------------------------------------------------
1033
+ case "sqrt":
1034
+ case "rsqrt":
1035
+ case "log":
1036
+ case "exp":
1037
+ case "relu": {
1038
+ const out = tof(op.out);
1039
+ const a = tof(op.a);
1040
+ const total = shapeSize(out.shape);
1041
+ const expr = op.kind === "sqrt" ? "sqrt(x)" : op.kind === "rsqrt" ? "1.0 / sqrt(x)" : op.kind === "log" ? "log(x)" : op.kind === "exp" ? "exp(x)" : (
1042
+ /* relu */
1043
+ "max(x, 0.0)"
1044
+ );
1045
+ const wgsl = `
1046
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1047
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1048
+ @compute @workgroup_size(${WG_SIZE})
1049
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1050
+ ${GID_LINE}
1051
+ if (i >= ${total}u) { return; }
1052
+ let x = a[i];
1053
+ out[i] = ${expr};
1054
+ }`.trim();
1055
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1056
+ }
1057
+ // ---- Comparisons + select --------------------------------------------
1058
+ case "less":
1059
+ case "greater": {
1060
+ const out = tof(op.out);
1061
+ const a = tof(op.a);
1062
+ const b = tof(op.b);
1063
+ const opStr = op.kind === "less" ? "<" : ">";
1064
+ const total = shapeSize(out.shape);
1065
+ const wgsl = `
1066
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1067
+ @group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
1068
+ @group(0) @binding(2) var<storage, read_write> out : array<u32>;
1069
+ @compute @workgroup_size(${WG_SIZE})
1070
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1071
+ ${GID_LINE}
1072
+ if (i >= ${total}u) { return; }
1073
+ ${broadcastIndexBlock("i", out.shape, a.shape, "aIdx")}
1074
+ ${broadcastIndexBlock("i", out.shape, b.shape, "bIdx")}
1075
+ out[i] = select(0u, 1u, a[aIdx] ${opStr} b[bIdx]);
1076
+ }`.trim();
1077
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1078
+ }
1079
+ case "where": {
1080
+ const out = tof(op.out);
1081
+ const cond = tof(op.cond);
1082
+ const a = tof(op.a);
1083
+ const b = tof(op.b);
1084
+ const total = shapeSize(out.shape);
1085
+ const wgsl = `
1086
+ @group(0) @binding(0) var<storage, read> cond : array<u32>;
1087
+ @group(0) @binding(1) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1088
+ @group(0) @binding(2) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
1089
+ @group(0) @binding(3) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1090
+ @compute @workgroup_size(${WG_SIZE})
1091
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1092
+ ${GID_LINE}
1093
+ if (i >= ${total}u) { return; }
1094
+ ${broadcastIndexBlock("i", out.shape, cond.shape, "cIdx")}
1095
+ ${broadcastIndexBlock("i", out.shape, a.shape, "aIdx")}
1096
+ ${broadcastIndexBlock("i", out.shape, b.shape, "bIdx")}
1097
+ out[i] = select(b[bIdx], a[aIdx], cond[cIdx] != 0u);
1098
+ }`.trim();
1099
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.cond), buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1100
+ }
1101
+ case "relu_grad": {
1102
+ const out = tof(op.out);
1103
+ const total = shapeSize(out.shape);
1104
+ const wgsl = `
1105
+ @group(0) @binding(0) var<storage, read> x : array<f32>;
1106
+ @group(0) @binding(1) var<storage, read> dy : array<f32>;
1107
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
1108
+ @compute @workgroup_size(${WG_SIZE})
1109
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1110
+ ${GID_LINE}
1111
+ if (i >= ${total}u) { return; }
1112
+ out[i] = select(0.0, dy[i], x[i] > 0.0);
1113
+ }`.trim();
1114
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.x), buf(op.dy), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1115
+ }
1116
+ // ---- Reductions over last axis -----------------------------------------
1117
+ case "mean_last":
1118
+ case "sum_last": {
1119
+ const a = tof(op.a);
1120
+ const D = a.shape[a.shape.length - 1];
1121
+ const outerSize = shapeSize(a.shape) / D;
1122
+ const divisor = op.kind === "mean_last" ? `f32(${D}u)` : "1.0";
1123
+ const wgsl = `
1124
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1125
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
1126
+ @compute @workgroup_size(${WG_SIZE})
1127
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1128
+ ${GID_LINE}
1129
+ if (i >= ${outerSize}u) { return; }
1130
+ let base = i * ${D}u;
1131
+ var s : f32 = 0.0;
1132
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
1133
+ s = s + a[base + j];
1134
+ }
1135
+ out[i] = s / ${divisor};
1136
+ }`.trim();
1137
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
1138
+ }
1139
+ // ---- Shape ---------------------------------------------------------------
1140
+ // reshape: no kernel needed if buffers can alias (shape change only). For
1141
+ // v1 simplicity we emit a memcpy-style kernel rather than aliasing buffers,
1142
+ // because aliasing complicates the buffer plan and we have memory headroom.
1143
+ case "reshape": {
1144
+ const out = tof(op.out);
1145
+ const a = tof(op.a);
1146
+ const total = shapeSize(out.shape);
1147
+ const wgsl = `
1148
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1149
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1150
+ @compute @workgroup_size(${WG_SIZE})
1151
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1152
+ ${GID_LINE}
1153
+ if (i >= ${total}u) { return; }
1154
+ out[i] = a[i];
1155
+ }`.trim();
1156
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1157
+ }
1158
+ case "transpose": {
1159
+ const out = tof(op.out);
1160
+ const a = tof(op.a);
1161
+ const total = shapeSize(out.shape);
1162
+ const aStrides = computeStrides(a.shape);
1163
+ const outDimDecls = decomposeFlatIndexBlock("i", out.shape, "oIdx");
1164
+ const srcExpr = [];
1165
+ for (let k = 0; k < a.shape.length; k++) {
1166
+ const srcAxis = op.perm.indexOf(k);
1167
+ srcExpr.push(`oIdx_${srcAxis} * ${aStrides[k]}u`);
1168
+ }
1169
+ const wgsl = `
1170
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1171
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1172
+ @compute @workgroup_size(${WG_SIZE})
1173
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1174
+ ${GID_LINE}
1175
+ if (i >= ${total}u) { return; }
1176
+ ${outDimDecls}
1177
+ let srcIdx = ${srcExpr.join(" + ")};
1178
+ out[i] = a[srcIdx];
1179
+ }`.trim();
1180
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1181
+ }
1182
+ // ---- Linear algebra ----------------------------------------------------
1183
+ // matmul: a [..., M, K] · b [K, N] -> [..., M, N]. b is unbatched.
1184
+ case "matmul": {
1185
+ const out = tof(op.out);
1186
+ const a = tof(op.a);
1187
+ const b = tof(op.b);
1188
+ const M = a.shape[a.shape.length - 2];
1189
+ const K = a.shape[a.shape.length - 1];
1190
+ const N = b.shape[1];
1191
+ const batch = shapeSize(a.shape) / (M * K);
1192
+ const total = batch * M * N;
1193
+ const wgsl = `
1194
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1195
+ @group(0) @binding(1) var<storage, read> b : array<f32>;
1196
+ @group(0) @binding(2) var<storage, read_write> c : array<f32>;
1197
+ @compute @workgroup_size(${WG_SIZE})
1198
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1199
+ ${GID_LINE}
1200
+ if (i >= ${total}u) { return; }
1201
+ let bi = i / ${M * N}u; // batch index
1202
+ let mn = i % ${M * N}u;
1203
+ let m = mn / ${N}u;
1204
+ let n = mn % ${N}u;
1205
+ let aBase = bi * ${M * K}u + m * ${K}u;
1206
+ var s : f32 = 0.0;
1207
+ for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
1208
+ s = s + a[aBase + k] * b[k * ${N}u + n];
1209
+ }
1210
+ c[i] = s;
1211
+ }`.trim();
1212
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1213
+ }
1214
+ case "matmul_batched": {
1215
+ const out = tof(op.out);
1216
+ const a = tof(op.a);
1217
+ const b = tof(op.b);
1218
+ const M = a.shape[a.shape.length - 2];
1219
+ const K = a.shape[a.shape.length - 1];
1220
+ const N = b.shape[b.shape.length - 1];
1221
+ const batch = shapeSize(a.shape) / (M * K);
1222
+ const total = batch * M * N;
1223
+ const wgsl = `
1224
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1225
+ @group(0) @binding(1) var<storage, read> b : array<f32>;
1226
+ @group(0) @binding(2) var<storage, read_write> c : array<f32>;
1227
+ @compute @workgroup_size(${WG_SIZE})
1228
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1229
+ ${GID_LINE}
1230
+ if (i >= ${total}u) { return; }
1231
+ let bi = i / ${M * N}u;
1232
+ let mn = i % ${M * N}u;
1233
+ let m = mn / ${N}u;
1234
+ let n = mn % ${N}u;
1235
+ let aBase = bi * ${M * K}u + m * ${K}u;
1236
+ let bBase = bi * ${K * N}u;
1237
+ var s : f32 = 0.0;
1238
+ for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
1239
+ s = s + a[aBase + k] * b[bBase + k * ${N}u + n];
1240
+ }
1241
+ c[i] = s;
1242
+ }`.trim();
1243
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1244
+ }
1245
+ // ---- One-hot ------------------------------------------------------------
1246
+ case "one_hot": {
1247
+ const out = tof(op.out);
1248
+ const indices = tof(op.indices);
1249
+ const total = shapeSize(out.shape);
1250
+ const depth = op.depth;
1251
+ const zeroLit = wgslLiteral(0, out.dtype);
1252
+ const oneLit = wgslLiteral(1, out.dtype);
1253
+ const wgsl = `
1254
+ @group(0) @binding(0) var<storage, read> indices : array<i32>;
1255
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1256
+ @compute @workgroup_size(${WG_SIZE})
1257
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1258
+ ${GID_LINE}
1259
+ if (i >= ${total}u) { return; }
1260
+ let outerIdx = i / ${depth}u;
1261
+ let depthIdx = i % ${depth}u;
1262
+ let tgt = u32(indices[outerIdx]);
1263
+ out[i] = select(${zeroLit}, ${oneLit}, tgt == depthIdx);
1264
+ }`.trim();
1265
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.indices), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1266
+ }
1267
+ // ---- ML primitives -----------------------------------------------------
1268
+ case "log_softmax_last": {
1269
+ const a = tof(op.a);
1270
+ const D = a.shape[a.shape.length - 1];
1271
+ const outerSize = shapeSize(a.shape) / D;
1272
+ const wgsl = `
1273
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1274
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
1275
+ @compute @workgroup_size(${WG_SIZE})
1276
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1277
+ ${GID_LINE}
1278
+ if (i >= ${outerSize}u) { return; }
1279
+ let base = i * ${D}u;
1280
+ var m : f32 = -1.0e30;
1281
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
1282
+ let v = a[base + j];
1283
+ if (v > m) { m = v; }
1284
+ }
1285
+ var s : f32 = 0.0;
1286
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
1287
+ s = s + exp(a[base + j] - m);
1288
+ }
1289
+ let logZ = m + log(s);
1290
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
1291
+ out[base + j] = a[base + j] - logZ;
1292
+ }
1293
+ }`.trim();
1294
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
1295
+ }
1296
+ case "softmax_causal_last": {
1297
+ const a = tof(op.a);
1298
+ const T = a.shape[a.shape.length - 1];
1299
+ const outerSize = shapeSize(a.shape) / T;
1300
+ const wgsl = `
1301
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1302
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
1303
+ @compute @workgroup_size(${WG_SIZE})
1304
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1305
+ // Each thread handles one (..., qpos)-row, softmaxing over kpos\u2208[0..qpos].
1306
+ ${GID_LINE}
1307
+ if (i >= ${outerSize}u) { return; }
1308
+ let qpos = i % ${T}u;
1309
+ let base = i * ${T}u;
1310
+ var m : f32 = -1.0e30;
1311
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
1312
+ let v = a[base + k];
1313
+ if (v > m) { m = v; }
1314
+ }
1315
+ var s : f32 = 0.0;
1316
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
1317
+ let e = exp(a[base + k] - m);
1318
+ out[base + k] = e;
1319
+ s = s + e;
1320
+ }
1321
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
1322
+ out[base + k] = out[base + k] / s;
1323
+ }
1324
+ for (var k : u32 = qpos + 1u; k < ${T}u; k = k + 1u) {
1325
+ out[base + k] = 0.0;
1326
+ }
1327
+ }`.trim();
1328
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
1329
+ }
1330
+ case "where_causal": {
1331
+ const a = tof(op.a);
1332
+ const T = a.shape[a.shape.length - 1];
1333
+ const total = shapeSize(a.shape);
1334
+ const fillLit = wgslLiteral(op.fillValue, "f32");
1335
+ const wgsl = `
1336
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
1337
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
1338
+ @compute @workgroup_size(${WG_SIZE})
1339
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1340
+ ${GID_LINE}
1341
+ if (i >= ${total}u) { return; }
1342
+ let kpos = i % ${T}u;
1343
+ let qpos = (i / ${T}u) % ${T}u;
1344
+ if (kpos > qpos) {
1345
+ out[i] = ${fillLit};
1346
+ } else {
1347
+ out[i] = a[i];
1348
+ }
1349
+ }`.trim();
1350
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1351
+ }
1352
+ // ---- Slicing -----------------------------------------------------------
1353
+ case "slice_last_range": {
1354
+ const out = tof(op.out);
1355
+ const a = tof(op.a);
1356
+ const D_in = a.shape[a.shape.length - 1];
1357
+ const D_out = op.end - op.start;
1358
+ const total = shapeSize(out.shape);
1359
+ const wgsl = `
1360
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1361
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1362
+ @compute @workgroup_size(${WG_SIZE})
1363
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1364
+ ${GID_LINE}
1365
+ if (i >= ${total}u) { return; }
1366
+ let outer = i / ${D_out}u;
1367
+ let inner = i % ${D_out}u;
1368
+ out[i] = a[outer * ${D_in}u + ${op.start}u + inner];
1369
+ }`.trim();
1370
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1371
+ }
1372
+ // ---- Broadcast / un-broadcast (autograd infrastructure) ----------------
1373
+ case "broadcast_to": {
1374
+ const out = tof(op.out);
1375
+ const a = tof(op.a);
1376
+ const total = shapeSize(out.shape);
1377
+ const wgsl = `
1378
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
1379
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
1380
+ @compute @workgroup_size(${WG_SIZE})
1381
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1382
+ ${GID_LINE}
1383
+ if (i >= ${total}u) { return; }
1384
+ ${broadcastIndexBlock("i", out.shape, a.shape, "srcIdx")}
1385
+ out[i] = a[srcIdx];
1386
+ }`.trim();
1387
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1388
+ }
1389
+ // ---- Adam (fused per-element) -----------------------------------------
1390
+ case "adam_update_m": {
1391
+ const out = tof(op.out);
1392
+ const total = shapeSize(out.shape);
1393
+ const b1 = op.b1;
1394
+ const oneMinusB1 = 1 - b1;
1395
+ const wgsl = `
1396
+ @group(0) @binding(0) var<storage, read> m : array<f32>;
1397
+ @group(0) @binding(1) var<storage, read> g : array<f32>;
1398
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
1399
+ @compute @workgroup_size(${WG_SIZE})
1400
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1401
+ ${GID_LINE}
1402
+ if (i >= ${total}u) { return; }
1403
+ out[i] = ${wgslLiteral(b1, "f32")} * m[i] + ${wgslLiteral(oneMinusB1, "f32")} * g[i];
1404
+ }`.trim();
1405
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.m), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1406
+ }
1407
+ case "adam_update_v": {
1408
+ const out = tof(op.out);
1409
+ const total = shapeSize(out.shape);
1410
+ const b2 = op.b2;
1411
+ const oneMinusB2 = 1 - b2;
1412
+ const wgsl = `
1413
+ @group(0) @binding(0) var<storage, read> v : array<f32>;
1414
+ @group(0) @binding(1) var<storage, read> g : array<f32>;
1415
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
1416
+ @compute @workgroup_size(${WG_SIZE})
1417
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1418
+ ${GID_LINE}
1419
+ if (i >= ${total}u) { return; }
1420
+ let gv = g[i];
1421
+ out[i] = ${wgslLiteral(b2, "f32")} * v[i] + ${wgslLiteral(oneMinusB2, "f32")} * gv * gv;
1422
+ }`.trim();
1423
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1424
+ }
1425
+ case "adam_update_p": {
1426
+ const out = tof(op.out);
1427
+ const total = shapeSize(out.shape);
1428
+ const dynamicShrink = op.decayShrinkTensor !== null;
1429
+ const shrinkExpr = dynamicShrink ? "decayShrink[0]" : wgslLiteral(op.decayShrink, "f32");
1430
+ const shrinkBinding = dynamicShrink ? `@group(0) @binding(4) var<storage, read> decayShrink : array<f32>;
1431
+ @group(0) @binding(5) var<storage, read_write> out : array<f32>;` : `@group(0) @binding(4) var<storage, read_write> out : array<f32>;`;
1432
+ const wgsl = `
1433
+ @group(0) @binding(0) var<storage, read> p : array<f32>;
1434
+ @group(0) @binding(1) var<storage, read> mNew : array<f32>;
1435
+ @group(0) @binding(2) var<storage, read> vNew : array<f32>;
1436
+ @group(0) @binding(3) var<storage, read> lrt : array<f32>;
1437
+ ${shrinkBinding}
1438
+ @compute @workgroup_size(${WG_SIZE})
1439
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
1440
+ ${GID_LINE}
1441
+ if (i >= ${total}u) { return; }
1442
+ out[i] = ${shrinkExpr} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, "f32")});
1443
+ }`.trim();
1444
+ const bindings = dynamicShrink ? [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.decayShrinkTensor), buf(op.out)] : [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)];
1445
+ return { opIndex, opKind: op.kind, wgsl, bindings, threads: total, workgroupSize: WG_SIZE };
1446
+ }
1447
+ case "sum_to_shape": {
1448
+ const out = tof(op.out);
1449
+ const a = tof(op.a);
1450
+ const wgsl = emitSumToShape(a.shape, out.shape, a.dtype);
1451
+ const total = shapeSize(out.shape);
1452
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
1453
+ }
1454
+ }
1455
+ }
1456
/**
 * Map a tensorgrad dtype to the WGSL scalar type used for its storage buffer.
 * "bool" tensors are stored as u32; every other dtype name is used unchanged.
 */
function wgslDtype(d) {
  return d === "bool" ? "u32" : d;
}
1460
/**
 * Render a JS value as a WGSL literal of the requested dtype.
 *
 * - "f32": finite numbers are printed with a decimal point (or exponent) and an
 *   `f` suffix. Non-finite values become the sentinels `1.0e30f` (for +Infinity)
 *   or `-1.0e30f` (for -Infinity and NaN).
 * - "i32": truncated toward zero with an `i` suffix.
 * - anything else: treated as a boolean, rendered as `1u` / `0u` by truthiness.
 *
 * @param {number} value
 * @param {string} dtype
 * @returns {string} WGSL source text for the literal.
 */
function wgslLiteral(value, dtype) {
  switch (dtype) {
    case "f32": {
      if (!Number.isFinite(value)) return value > 0 ? "1.0e30f" : "-1.0e30f";
      const text = String(value);
      const alreadyFloatLike = text.includes(".") || text.includes("e");
      return alreadyFloatLike ? `${text}f` : `${text}.0f`;
    }
    case "i32":
      return `${Math.trunc(value)}i`;
    default:
      return value ? "1u" : "0u";
  }
}
1470
/**
 * Wrap a WGSL expression in a conversion to the storage type of `dtype`.
 * "f32" and "i32" convert to themselves; any other dtype converts to u32.
 */
function castFromI32(expr, dtype) {
  const target = dtype === "f32" || dtype === "i32" ? dtype : "u32";
  return `${target}(${expr})`;
}
1475
/**
 * Row-major (C-order) element strides for `shape`.
 * The last axis has stride 1; each earlier axis strides over the product of
 * all later dims. A rank-0 shape yields [].
 */
function computeStrides(shape) {
  const strides = new Array(shape.length);
  let running = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    strides[axis] = running;
    running *= shape[axis];
  }
  return strides;
}
1482
/**
 * Emit WGSL statements that split the flat row-major index `flatVar` into
 * per-axis coordinates `${outVar}_0 .. ${outVar}_{rank-1}`.
 * Intermediate remainders are named `${outVar}_rem{axis}`. A rank-0 shape
 * emits a single `${outVar}_0 : u32 = 0u`.
 */
function decomposeFlatIndexBlock(flatVar, shape, outVar) {
  const rank = shape.length;
  if (rank === 0) return ` let ${outVar}_0 : u32 = 0u;`;
  const strides = computeStrides(shape);
  const emitted = [];
  let rem = flatVar;
  for (let axis = 0; axis < rank - 1; axis++) {
    const nextRem = `${outVar}_rem${axis}`;
    emitted.push(
      ` let ${outVar}_${axis} = ${rem} / ${strides[axis]}u;`,
      ` let ${nextRem} = ${rem} % ${strides[axis]}u;`,
    );
    rem = nextRem;
  }
  // The innermost coordinate is whatever remains after the outer divisions.
  emitted.push(` let ${outVar}_${rank - 1} = ${rem};`);
  return emitted.join("\n");
}
1499
/**
 * Emit WGSL that maps the flat output index `flatVar` (over `outShape`) to a
 * flat index `srcVar` into a source broadcast against `outShape`.
 * Shapes are right-aligned, and source axes of size 1 contribute 0 (broadcast).
 * A rank-0 source always reads element 0.
 */
function broadcastIndexBlock(flatVar, outShape, srcShape, srcVar) {
  const axisPrefix = `${srcVar}_ax`;
  const decomposed = decomposeFlatIndexBlock(flatVar, outShape, axisPrefix);
  if (srcShape.length === 0) {
    return `${decomposed}\n let ${srcVar} : u32 = 0u;`;
  }
  const srcStrides = computeStrides(srcShape);
  // Number of leading output axes that the source does not have.
  const rankGap = outShape.length - srcShape.length;
  const terms = srcShape.map((dim, axis) =>
    dim === 1 ? "0u" : `${axisPrefix}_${axis + rankGap} * ${srcStrides[axis]}u`,
  );
  return `${decomposed}\n let ${srcVar} = ${terms.join(" + ")};`;
}
1518
/**
 * Emit a WGSL compute shader that sums tensor `a` (shape `srcShape`) down to
 * `out` (shape `tgtShape`). This is the reduction side of broadcasting.
 *
 * One invocation handles one target element. Shapes are right-aligned. An axis
 * is reduced (looped over) if it is a leading source axis absent from the
 * target, or if the target dim is 1 while the source dim is > 1.
 * NOTE(review): assumes tgtShape is broadcast-compatible with srcShape; this
 * is not validated here.
 *
 * @param {number[]} srcShape - shape of the input buffer `a`.
 * @param {number[]} tgtShape - shape of the output buffer `out`.
 * @param {string} dtype - element dtype, mapped to WGSL via wgslDtype.
 * @returns {string} WGSL source (trimmed).
 */
function emitSumToShape(srcShape, tgtShape, dtype) {
  const srcStrides = computeStrides(srcShape);
  const tgtStrides = computeStrides(tgtShape);
  // Number of leading source axes that have no counterpart in the target.
  const offset = srcShape.length - tgtShape.length;
  // Declares tgt_0.. as the coordinates of output element `i`.
  const decompose = decomposeFlatIndexBlock("i", tgtShape, "tgt");
  const reducedAxes = [];
  for (let k = 0; k < srcShape.length; k++) {
    if (k < offset) {
      reducedAxes.push(k);
      continue;
    }
    const tDim = tgtShape[k - offset];
    const sDim = srcShape[k];
    if (tDim === 1 && sDim > 1) reducedAxes.push(k);
  }
  // Source offset contributed by the kept (non-reduced) axes.
  const baseTerms = [];
  for (let k = 0; k < srcShape.length; k++) {
    if (reducedAxes.includes(k)) continue;
    const tAxis = k - offset;
    baseTerms.push(`tgt_${tAxis} * ${srcStrides[k]}u`);
  }
  const baseExpr = baseTerms.length > 0 ? baseTerms.join(" + ") : "0u";
  const indent = (depth) => " ".repeat(depth + 1);
  // One nested WGSL for-loop per reduced axis; loop var r{k} walks source axis k.
  const loops = [];
  for (let depth = 0; depth < reducedAxes.length; depth++) {
    const k = reducedAxes[depth];
    const dim = srcShape[k];
    loops.push(`${indent(depth)}for (var r${k} : u32 = 0u; r${k} < ${dim}u; r${k} = r${k} + 1u) {`);
  }
  const reducedTerms = reducedAxes.map((k) => `r${k} * ${srcStrides[k]}u`);
  const fullExpr = reducedTerms.length > 0 ? `${baseExpr} + ${reducedTerms.join(" + ")}` : baseExpr;
  loops.push(`${indent(reducedAxes.length)}s = s + a[${fullExpr}];`);
  // Close the loops innermost-first.
  for (let depth = reducedAxes.length - 1; depth >= 0; depth--) {
    loops.push(`${indent(depth)}}`);
  }
  // Number of target elements, i.e. the invocation count that does real work.
  const total = tgtShape.length === 0 ? 1 : tgtStrides[0] * tgtShape[0];
  const loopBody = reducedAxes.length === 0 ? ` s = s + a[${baseExpr}];` : loops.join("\n");
  // GID_LINE and WG_SIZE are defined elsewhere in this file; GID_LINE is
  // expected to declare the flat invocation index `i` used below.
  return `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${decompose}
var s : ${wgslDtype(dtype)} = ${dtype === "f32" ? "0.0f" : dtype === "i32" ? "0i" : "0u"};
${loopBody}
out[i] = s;
}`.trim();
}
1568
+
1569
+ // src/runtime.ts
1570
/**
 * Named intermediate tensors read back from a single run/step call.
 *
 * `shapes` is a plain object mapping every capture name registered at trace
 * time to its shape. `data` is a Map of the values actually read back on this
 * call (empty unless the call requested captures).
 */
var Captures = class {
  shapes;
  data;
  constructor(shapes, data) {
    this.shapes = shapes;
    this.data = data;
  }
  /**
   * Captured values for `name` from this call.
   * @throws {Error} if `name` was not read back on this call.
   */
  get(name) {
    const d = this.data.get(name);
    if (!d) {
      const known = [...this.data.keys()].sort().join(", ");
      const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;
      throw new Error(`Captures.get: '${name}' not present. ${detail}`);
    }
    return d;
  }
  /**
   * Registered shape for capture `name`.
   * Only own keys of `shapes` count, so names such as "toString" or
   * "constructor" are reported as unregistered instead of returning inherited
   * Object.prototype members.
   * @throws {Error} if `name` was never registered.
   */
  shapeOf(name) {
    const s = Object.hasOwn(this.shapes, name) ? this.shapes[name] : undefined;
    if (!s) {
      const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";
      throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`);
    }
    return s;
  }
  /** Whether `name` was read back on this call. */
  has(name) {
    return this.data.has(name);
  }
  /** Sorted names of the captures read back on this call. */
  names() {
    return [...this.data.keys()].sort();
  }
};
1601
// WebGPU GPUBufferUsage bits: STORAGE (0x80) | COPY_DST (0x08) | COPY_SRC (0x04).
var STORAGE_RW = 0x80 | 0x08 | 0x04;
// WebGPU GPUBufferUsage bits: MAP_READ (0x01) | COPY_DST (0x08), i.e. a CPU-mappable staging buffer.
var READBACK = 0x01 | 0x08;
1603
+
1604
+ // src/module.ts
1605
/**
 * Builders for parameter init specs, consumed by Module#param({ init }).
 */
var init = {
  /** Gaussian spec; `scale` multiplies standard-normal samples (default 0.02). */
  randn(opts = {}) {
    return { kind: "randn", scale: opts.scale ?? 0.02 };
  },
  /** Kaiming spec; `gain` is included only when explicitly provided. */
  kaiming(opts = {}) {
    if (opts.gain === undefined) return { kind: "kaiming" };
    return { kind: "kaiming", gain: opts.gain };
  },
  /** Explicit values; `data` is kept by reference. */
  literal(data) {
    return { kind: "literal", data };
  },
};
1610
/**
 * One standard-normal sample via the Box–Muller transform (cosine branch).
 * The first uniform draw is clamped to >= 1e-10 so Math.log never receives 0.
 */
function boxMuller() {
  const u1 = Math.max(1e-10, Math.random());
  const u2 = Math.random();
  const radius = Math.sqrt(-2 * Math.log(u1));
  return radius * Math.cos(2 * Math.PI * u2);
}
1613
/**
 * Build an init function that returns `size` Gaussian samples scaled by `scale`.
 * @param {number} scale - multiplier applied to each standard-normal sample.
 * @returns {(size: number) => Float32Array}
 */
function randnFn(scale) {
  return (size) => {
    const samples = new Float32Array(size);
    for (let idx = 0; idx < samples.length; idx++) {
      samples[idx] = boxMuller() * scale;
    }
    return samples;
  };
}
1620
/**
 * Turn a param init spec into a function `(size, shape) => Float32Array`.
 *
 * Accepted specs:
 *  - falsy or "randn": Gaussian samples scaled by 0.02
 *  - "zeros" / "ones"
 *  - { kind: "randn", scale? }: scale defaults to 0.02, matching init.randn()
 *  - { kind: "kaiming", gain? }: std = gain / sqrt(fanIn), where gain
 *    defaults to sqrt(2) and fanIn = shape[0] (the element count for rank-0)
 *  - { kind: "literal", data }: copied; the length must equal the param size
 *
 * @throws {Error} for an unrecognized spec. Previously such a spec silently
 *   produced an undefined init function that only failed much later.
 */
function resolveInit(spec) {
  if (!spec || spec === "randn") return randnFn(0.02);
  if (spec === "zeros") return (size) => new Float32Array(size);
  if (spec === "ones") return (size) => new Float32Array(size).fill(1);
  switch (spec.kind) {
    case "randn":
      // Without the default, a hand-written { kind: "randn" } would fill NaNs.
      return randnFn(spec.scale ?? 0.02);
    case "kaiming": {
      const gain = spec.gain ?? Math.sqrt(2);
      return (size, shape) => {
        const fanIn = shape[0] ?? size;
        const std = gain / Math.sqrt(fanIn);
        const arr = new Float32Array(size);
        for (let i = 0; i < size; i++) arr[i] = boxMuller() * std;
        return arr;
      };
    }
    case "literal": {
      const data = spec.data;
      return (size) => {
        if (data.length !== size) {
          throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`);
        }
        return new Float32Array(data);
      };
    }
    default: {
      const what = typeof spec === "string" ? `'${spec}'` : `kind '${String(spec.kind)}'`;
      throw new Error(`resolveInit: unknown init spec ${what}`);
    }
  }
}
1652
/**
 * Whether a param participates in weight decay.
 * An explicit `opts.decay` wins. Otherwise params initialised to constant
 * "zeros"/"ones" are excluded, and everything else (default init "randn") is decayed.
 */
function resolveDecay(opts) {
  const explicit = opts?.decay;
  if (explicit !== undefined) return explicit;
  const spec = opts?.init ?? "randn";
  return !(spec === "zeros" || spec === "ones");
}
1657
/**
 * Placeholder returned by Module#param(). It records the declared shape,
 * dtype, resolved init function and decay flag until materializeParams swaps
 * it for a real param tensor.
 */
var ParamSentinel = class {
  shape;
  dtype;
  initFn;
  decay;
  constructor(shape, dtype, initFn, decay) {
    Object.assign(this, { shape, dtype, initFn, decay });
  }
};
1669
var Module = class {
  /**
   * Declare a learnable parameter on this module. Call it from the
   * constructor, typically in a field assignment. The returned placeholder is
   * replaced by a real Tensor at compile time.
   *
   * The param's name is derived from its property path in the model tree
   * (e.g. `layers.0.attn.W_q`). Init metadata travels with the param; call
   * `compiled.uploadInitialParams()` to apply it after compile.
   *
   * @param {number[]} shape
   * @param {{ dtype?: string, init?: unknown, decay?: boolean }} [opts]
   *   dtype defaults to "f32" (also when null).
   */
  param(shape, opts) {
    const dtype = opts?.dtype ?? "f32";
    const initFn = resolveInit(opts?.init);
    const decay = resolveDecay(opts);
    return new ParamSentinel(shape, dtype, initFn, decay);
  }
};
1684
/**
 * Replace every ParamSentinel reachable from `root` with a param tensor named
 * by its property path. The model tree is mutated in place.
 *
 * @returns {{ tensors: object, initFns: object, decayFlags: object }}
 *   per-path param tensor, init function and weight-decay flag.
 */
function materializeParams(root) {
  const result = { tensors: {}, initFns: {}, decayFlags: {} };
  visit(root, "", (path, val, owner, key) => {
    if (!(val instanceof ParamSentinel)) return;
    const tensor = paramInput(path, val.shape, val.dtype);
    owner[key] = tensor;
    result.tensors[path] = tensor;
    result.initFns[path] = val.initFn;
    result.decayFlags[path] = val.decay;
  });
  return result;
}
1699
/**
 * Walk a Module tree or an array of children, calling visitChild for each own
 * enumerable field or array element. Child paths are dot-joined:
 * `layers.0.attn`. Other values, including null and plain objects, are ignored.
 * Array elements are passed to visitChild with a numeric key.
 */
function visit(node, path, visitor) {
  if (node == null || typeof node !== "object") return;
  const childPath = (key) => (path ? `${path}.${key}` : String(key));
  if (node instanceof Module) {
    for (const key of Object.keys(node)) {
      visitChild(node[key], childPath(key), node, key, visitor);
    }
    return;
  }
  if (Array.isArray(node)) {
    node.forEach((item, idx) => {
      visitChild(item, childPath(idx), node, idx, visitor);
    });
  }
}
1718
/**
 * Recurse into Modules and arrays. Any other value is a leaf and is handed to
 * `visitor(path, value, owner, key)`, so the visitor can replace it in place.
 */
function visitChild(child, path, owner, key, visitor) {
  if (child instanceof Module || Array.isArray(child)) {
    visit(child, path, visitor);
    return;
  }
  visitor(path, child, owner, key);
}
1725
+
1726
+ // src/worker-protocol.ts
1727
/**
 * Collect the ArrayBuffers backing the typed arrays in `rec`, for use as a
 * postMessage transfer list. The buffers are detached in the sender once posted.
 *
 * Buffers are de-duplicated in first-seen order. Several values may be views
 * over the same ArrayBuffer, and a transfer list that names a buffer twice
 * makes postMessage throw a DataCloneError.
 *
 * @param {Record<string, ArrayBufferView>} rec
 * @returns {ArrayBuffer[]}
 */
function transferablesOfRecord(rec) {
  return [...new Set(Object.values(rec).map((v) => v.buffer))];
}
1732
/**
 * Rebuild an Error on this thread from its wire form `{ name, message, stack }`.
 *
 * The remote stack replaces the local one only when it is non-empty. The
 * worker sends stack "" for thrown non-Error values, and overwriting with ""
 * would discard the locally captured stack, which is the only trace left.
 *
 * @param {{ name: string, message: string, stack?: string }} w
 * @returns {Error}
 */
function reconstituteError(w) {
  const err = new Error(w.message);
  err.name = w.name;
  if (w.stack) err.stack = w.stack;
  return err;
}
1738
+
1739
+ // src/worker-proxy.ts
1740
/**
 * Main-thread handle on a tensorgrad worker.
 *
 * Each request is tagged with an incrementing id and matched to the worker's
 * `{ id, ok, result | error }` reply. Its resolve/reject pair is held in
 * `pending` until the reply arrives, the worker reports an error, or the
 * proxy is terminated.
 */
var WorkerProxy = class {
  worker;
  nextId = 1;
  pending = /* @__PURE__ */ new Map();
  terminated = false;
  /**
   * Spawn a module worker from source text via a Blob URL.
   * @param {string} workerSource - complete JS source for the worker.
   */
  constructor(workerSource) {
    const blob = new Blob([workerSource], { type: "application/javascript" });
    const url = URL.createObjectURL(blob);
    this.worker = new Worker(url, { type: "module" });
    URL.revokeObjectURL(url);
    this.worker.onmessage = (ev) => {
      const reply = ev.data;
      const handlers = this.pending.get(reply.id);
      // Replies with no registered handlers are dropped: ids from send(), or
      // requests already rejected by onerror/terminate.
      if (!handlers) return;
      this.pending.delete(reply.id);
      if (reply.ok) handlers.resolve(reply.result);
      else handlers.reject(reconstituteError(reply.error));
    };
    this.worker.onerror = (ev) => {
      // A worker-level error can't be attributed to one request, so every
      // outstanding request is failed.
      const err = new Error(`tensorgrad worker error: ${ev.message || "unknown"}`);
      const wire = { name: "WorkerError", message: err.message, stack: err.stack ?? "" };
      for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire));
      this.pending.clear();
    };
  }
  /** Send a request and await its matching response. `transfer` lists the
   * ArrayBuffers to move (zero-copy) into the worker. */
  request(req, transfer = []) {
    if (this.terminated) return Promise.reject(new Error("tensorgrad: worker has been terminated"));
    const id = this.nextId++;
    return new Promise((resolve, reject) => {
      this.pending.set(id, { resolve, reject });
      try {
        this.worker.postMessage({ ...req, id }, transfer);
      } catch (e) {
        // postMessage can throw synchronously (e.g. DataCloneError for an
        // uncloneable payload or a bad transfer list). No reply will ever come
        // for this id, so drop its handlers rather than leak them in `pending`.
        this.pending.delete(id);
        reject(e);
      }
    });
  }
  /** Fire-and-forget variant for cases where the caller doesn't need a reply
   * (currently unused; keep for symmetry / future use). */
  send(req, transfer = []) {
    if (this.terminated) return;
    const id = this.nextId++;
    this.worker.postMessage({ ...req, id }, transfer);
  }
  /** Kill the worker and reject every outstanding request. Idempotent. */
  terminate() {
    if (this.terminated) return;
    this.terminated = true;
    this.worker.terminate();
    const err = new Error("tensorgrad: worker terminated");
    for (const handlers of this.pending.values()) handlers.reject(err);
    this.pending.clear();
  }
};
1791
+
1792
+ // src/compile.ts
1793
/**
 * Run the compile pipeline without starting a runtime: trace the forward
 * function, append the backward pass, plan buffers and emit kernels.
 * @returns {{ graph, paramGrads, loss, plan, kernels }}
 */
function compileToIR(traceFn) {
  const graph = trace(traceFn);
  const grad = appendGrad(graph);
  const plan = planBuffers(graph, grad.paramGrads);
  return {
    graph,
    paramGrads: grad.paramGrads,
    loss: grad.loss,
    plan,
    kernels: emitKernels(graph, plan),
  };
}
1800
/**
 * Compile a model + forward function into a worker-backed training runtime.
 *
 * Steps:
 *  1. Trace the model with `opts.inputs`.
 *  2. Append the backward pass, plus Adam update ops when `opts.adam` is set.
 *  3. Plan buffers and emit kernels.
 *  4. Start a worker from the inlined bundled source below (runtime, LR
 *     schedules and worker message loop).
 *  5. Ask the worker to create the runtime, seeded with the initial params.
 *
 * The initial-param ArrayBuffers are transferred to the worker, so they are
 * detached on this thread afterwards. If runtime creation fails, the worker is
 * terminated and the error is rethrown.
 *
 * @returns {Promise<CompiledModuleProxy>}
 */
async function compileModule(modelFactory, forward, opts = {}) {
  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {});
  const { paramGrads, loss } = appendGrad(graph);
  // Optimizer ops are only appended when an Adam config is supplied; its writebacks feed the buffer plan.
  const adamResult = opts.adam ? appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags) : void 0;
  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? []);
  const kernels = emitKernels(graph, plan);
  const ir = { graph, paramGrads, loss, plan, kernels };
  const initialParams = buildInitialParams(plan, materialized.initFns);
  // The string below is the complete worker bundle, kept verbatim; it must stay self-contained.
  const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? 
`t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function 
ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const 
wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new 
Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 
0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - 
schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? 
createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? 
new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() 
};\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n 
for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
  // Only graph/plan/kernels cross the thread boundary; paramGrads and loss stay on this side in `ir`.
  const wireIR = { graph, plan, kernels };
  const wireAdam = adamResult ? wireAdamConfig(adamResult) : null;
  // Backing buffers of the initial params are moved (not copied) into the worker.
  const transfers = transferablesOfRecord(initialParams);
  let meta;
  try {
    meta = await proxy.request(
      { kind: "createRuntime", payload: { graphId: 0, ir: wireIR, initialParams, adam: wireAdam } },
      transfers
    );
  } catch (e) {
    // Don't leave an orphaned worker behind when runtime creation fails.
    proxy.terminate();
    throw e;
  }
  // The root graph uses id 0; the mutable counter hands out later graph ids
  // (presumably for derived graphs such as forward-only runtimes — confirm in CompiledModuleProxy).
  return new CompiledModuleProxy(
    proxy,
    /* graphId */
    0,
    ir,
    meta,
    modelFactory,
    /* initFns */
    materialized.initFns,
    /* nextGraphId */
    { v: 1 }
  );
}
1835
/**
 * Compile a forward-only (inference) graph for `forward` in its own worker.
 *
 * Steps:
 * - Trace `forward(model, inputTensors)` on a fresh `modelFactory()` instance.
 * - Plan buffers with an empty param-gradient map and emit kernels.
 * - Build initial parameter values from the traced init functions.
 * - Ask a newly spawned worker to create the runtime. `adam: null` is sent,
 *   so no optimizer scalars are injected on step.
 *
 * @param modelFactory - zero-argument factory returning the model instance.
 * @param forward - `(model, inputTensors) => result`, traced once.
 * @param opts - `{ inputs?: Record<string, { shape, dtype? }> }`. Each entry is
 *   declared as a graph input; dtype falls back to "f32".
 * @returns a CompiledForwardModuleProxy on graph id 0 that owns its worker
 *   (its destroy() also terminates the worker).
 * @throws rethrows any error from the worker's createRuntime request, after
 *   terminating the worker. Tracing/compilation errors propagate before any
 *   worker is spawned.
 */
async function compileForward(modelFactory, forward, opts = {}) {
  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {});
  // graph.outputs[0] is the traced result; the worker also uses it to pick the
  // buffer it reads back as the run() output.
  const outputTensor = graph.tensors[graph.outputs[0]];
  const plan = planBuffers(
    graph,
    /* paramGrads */
    {}
  );
  const kernels = emitKernels(graph, plan);
  // Main-thread copy of the IR, kept on the returned proxy.
  const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
  const initialParams = buildInitialParams(plan, materialized.initFns);
  // Inline worker script: bundled runtime (createRuntime, Captures), the LR
  // schedule resolver, and the worker's message loop (self.onmessage).
  // WorkerProxy is defined outside this view; presumably it spawns the worker
  // from this source text.
  const proxy = new WorkerProxy('// src/runtime.ts\nvar Captures = class {\n constructor(shapes, data) {\n this.shapes = shapes;\n this.data = data;\n }\n shapes;\n data;\n get(name) {\n const d = this.data.get(name);\n if (!d) {\n const known = [...this.data.keys()].sort().join(", ");\n const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;\n throw new Error(`Captures.get: \'${name}\' not present. ${detail}`);\n }\n return d;\n }\n shapeOf(name) {\n const s = this.shapes[name];\n if (!s) {\n const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";\n throw new Error(`Captures.shapeOf: \'${name}\' not registered. Known: ${known}`);\n }\n return s;\n }\n has(name) {\n return this.data.has(name);\n }\n names() {\n return [...this.data.keys()].sort();\n }\n};\nvar STORAGE_RW = 128 | 8 | 4;\nvar READBACK = 1 | 8;\nasync function createRuntime(plan, kernels, lossBufferId, opts = {}) {\n const device2 = opts.device ?? await acquireDevice();\n const queue = device2.queue;\n const buffers = /* @__PURE__ */ new Map();\n const ownedBufferIds = /* @__PURE__ */ new Set();\n const sharedParams = opts.sharedParams;\n for (const spec of plan.buffers) {\n const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;\n if (shared) {\n if (shared.size !== spec.byteSize) {\n throw new Error(\n `sharedParams: size mismatch for \'${spec.name}\' \\u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`\n );\n }\n buffers.set(spec.id, shared);\n continue;\n }\n const buf = device2.createBuffer({\n size: spec.byteSize,\n usage: STORAGE_RW,\n label: spec.name ?? `t${spec.id}-${spec.kind}`\n });\n buffers.set(spec.id, buf);\n ownedBufferIds.add(spec.id);\n if (spec.kind === "state") fillStateBuffer(spec, buf);\n }\n const moduleCache = /* @__PURE__ */ new Map();\n const pipelines = [];\n const probes = [];\n for (const k of kernels) {\n if (!k.wgsl) {\n pipelines.push(null);\n continue;\n }\n let module = moduleCache.get(k.wgsl);\n if (!module) {\n module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });\n moduleCache.set(k.wgsl, module);\n }\n device2.pushErrorScope("validation");\n const pipeline = device2.createComputePipeline({\n layout: "auto",\n compute: { module, entryPoint: "main" },\n label: k.opKind\n });\n pipelines.push(pipeline);\n probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));\n }\n const probeResults = await Promise.all(probes);\n const failures = probeResults.filter((p) => p != null);\n if (failures.length > 0) {\n const reports = [];\n for (const { k, module, err } of failures) {\n const info = await module.getCompilationInfo();\n const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\\n");\n reports.push(\n `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` + (messages || " (no compilation messages)") + `\n--- WGSL ---\n${k.wgsl}\n-----------`\n );\n }\n console.error(reports.join("\\n\\n"));\n throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);\n }\n const bindGroups = kernels.map((k, i) => {\n const pipeline = pipelines[i];\n if (!pipeline) return null;\n return device2.createBindGroup({\n layout: pipeline.getBindGroupLayout(0),\n entries: k.bindings.map((bufId, idx) => ({\n binding: idx,\n resource: { buffer: buffers.get(bufId) }\n }))\n });\n });\n const outputSpec = plan.buffers[lossBufferId];\n const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });\n let captureStaging = null;\n function ensureCaptureStaging() {\n if (captureStaging) return captureStaging;\n let totalBytes = 0;\n const slices = [];\n for (const [name, bufId] of plan.capturesByName) {\n const spec = plan.buffers[bufId];\n slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });\n totalBytes += spec.byteSize;\n }\n const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });\n captureStaging = { buffer, slices };\n return captureStaging;\n }\n let pending = Promise.resolve();\n async function dispatch(inputs, opts2) {\n const turn = pending.catch(() => {\n }).then(() => dispatchUnsynchronized(inputs, opts2));\n pending = turn;\n return turn;\n }\n async function dispatchUnsynchronized(inputs, opts2) {\n const wantCaptures = opts2.wantCaptures;\n if (wantCaptures && plan.capturesByName.size === 0) {\n throw new Error(\n `withCaptures=true but no capture(...) calls were registered during the trace. Add capture(\'name\', tensor) inside your forward pass for the intermediates you want read back.`\n );\n }\n for (const [name, bufId] of plan.inputsByName) {\n const data = inputs[name];\n if (!data) throw new Error(`tensorgrad: missing input \'${name}\'`);\n const expectedBytes = plan.buffers[bufId].byteSize;\n if (data.byteLength !== expectedBytes) {\n throw new Error(`tensorgrad: input \'${name}\' has ${data.byteLength} bytes, expected ${expectedBytes}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });\n for (let i = 0; i < kernels.length; i++) {\n const k = kernels[i];\n if (!k.wgsl || k.threads === 0) continue;\n const pipeline = pipelines[i];\n const bindGroup = bindGroups[i];\n const pass = encoder.beginComputePass({ label: k.opKind });\n pass.setPipeline(pipeline);\n pass.setBindGroup(0, bindGroup);\n const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));\n const MAX_X = 65535;\n const wgX = Math.min(wgCount, MAX_X);\n const wgY = Math.ceil(wgCount / MAX_X);\n pass.dispatchWorkgroups(wgX, wgY, 1);\n pass.end();\n }\n for (const wb of plan.writebacks) {\n encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);\n }\n encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);\n let layout = null;\n if (wantCaptures) {\n layout = ensureCaptureStaging();\n for (const s of layout.slices) {\n encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);\n }\n }\n queue.submit([encoder.finish()]);\n if (!opts2.readback) return null;\n await outputReadback.mapAsync(GPUMapMode.READ);\n const output = new Float32Array(outputReadback.getMappedRange().slice(0));\n outputReadback.unmap();\n const captures = /* @__PURE__ */ new Map();\n if (layout) {\n await layout.buffer.mapAsync(GPUMapMode.READ);\n const range = layout.buffer.getMappedRange();\n for (const s of layout.slices) {\n captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());\n }\n layout.buffer.unmap();\n }\n return { output, captures };\n }\n async function step(inputs, opts2) {\n if (opts2?.readLoss === false) {\n await dispatch(inputs, { wantCaptures: false, readback: false });\n return;\n }\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };\n return r.output[0];\n }\n async function readLoss() {\n const turn = pending.catch(() => {\n }).then(async () => {\n await outputReadback.mapAsync(GPUMapMode.READ);\n const v = new Float32Array(outputReadback.getMappedRange())[0];\n outputReadback.unmap();\n return v;\n });\n pending = turn;\n return turn;\n }\n async function run(inputs, opts2) {\n const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });\n if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };\n return r.output;\n }\n function uploadParams(params2, opts2) {\n const partial = opts2?.partial ?? false;\n for (const name of Object.keys(params2)) {\n if (!plan.paramsByName.has(name)) {\n throw new Error(\n `uploadParams: unknown param \'${name}\'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`\n );\n }\n }\n if (!partial) {\n for (const name of plan.paramsByName.keys()) {\n if (!(name in params2)) {\n throw new Error(\n `uploadParams: missing param \'${name}\'. Pass { partial: true } if you mean to update only some params.`\n );\n }\n }\n }\n for (const [name, bufId] of plan.paramsByName) {\n const data = params2[name];\n if (!data) continue;\n const expected = plan.buffers[bufId].byteSize / 4;\n if (data.length !== expected) {\n throw new Error(`uploadParams: \'${name}\' has ${data.length} elements, expected ${expected}`);\n }\n queue.writeBuffer(buffers.get(bufId), 0, data);\n }\n }\n async function downloadFromMap(map) {\n const stagings = [];\n const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });\n for (const [name, bufId] of map) {\n const spec = plan.buffers[bufId];\n const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });\n encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);\n stagings.push({ name, buf: staging, bytes: spec.byteSize });\n }\n queue.submit([encoder.finish()]);\n const out = {};\n for (const s of stagings) {\n await s.buf.mapAsync(GPUMapMode.READ);\n out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));\n s.buf.unmap();\n s.buf.destroy();\n }\n return out;\n }\n function fillStateBuffer(spec, target) {\n const elements = spec.byteSize / 4;\n const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0));\n queue.writeBuffer(target, 0, init);\n }\n function resetOptimizerState() {\n for (const spec of plan.buffers) {\n if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));\n }\n }\n const params = /* @__PURE__ */ new Map();\n for (const [name, bufId] of plan.paramsByName) {\n params.set(name, buffers.get(bufId));\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const outputShape = [...plan.buffers[lossBufferId].shape];\n const destroy = () => {\n for (const [id, b] of buffers) {\n if (ownedBufferIds.has(id)) b.destroy();\n }\n outputReadback.destroy();\n if (captureStaging) captureStaging.buffer.destroy();\n };\n return {\n device: device2,\n params,\n outputShape,\n uploadParams,\n downloadParams: () => downloadFromMap(plan.paramsByName),\n downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),\n step,\n run,\n readLoss,\n resetOptimizerState,\n destroy\n };\n}\nasync function acquireDevice() {\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad: no WebGPU adapter");\n return await adapter.requestDevice();\n}\n\n// src/adam.ts\nfunction resolveLR(schedule, step) {\n if (typeof schedule === "number") return schedule;\n switch (schedule.kind) {\n case "constant":\n return schedule.value;\n case "linearDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.peak + (schedule.final - schedule.peak) * f;\n }\n case "cosineDecay": {\n const f = Math.min(step / schedule.steps, 1);\n return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));\n }\n case "warmup": {\n if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);\n return resolveLR(schedule.after, step - schedule.warmupSteps);\n }\n }\n}\n\n// src/worker-protocol.ts\nfunction wireError(e) {\n if (e instanceof Error) {\n return { name: e.name, message: e.message, stack: e.stack ?? "" };\n }\n return { name: "Error", message: String(e), stack: "" };\n}\n\n// src/worker.ts\nvar graphs = /* @__PURE__ */ new Map();\nvar device = null;\nasync function ensureDevice() {\n if (device) return device;\n if (typeof navigator === "undefined" || !navigator.gpu) {\n throw new Error("tensorgrad worker: WebGPU not available in this environment");\n }\n const adapter = await navigator.gpu.requestAdapter();\n if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");\n device = await adapter.requestDevice();\n return device;\n}\nasync function handleCreateRuntime(payload) {\n const dev = await ensureDevice();\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n if (Object.keys(payload.initialParams).length > 0) {\n runtime.uploadParams(payload.initialParams);\n }\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: payload.adam ? createAdamState(payload.adam) : null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nasync function handleCompileForward(payload) {\n const dev = await ensureDevice();\n const parent = graphs.get(payload.parentGraphId);\n if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);\n const { graph, plan, kernels } = payload.ir;\n const outputTensorId = graph.outputs[0];\n const outputBufferId = plan.tensorToBuffer.get(outputTensorId);\n const opts = { device: dev, sharedParams: parent.runtime.params };\n const runtime = await createRuntime(plan, kernels, outputBufferId, opts);\n const captureShapes = {};\n for (const [name, bufId] of plan.capturesByName) {\n captureShapes[name] = [...plan.buffers[bufId].shape];\n }\n const slot = {\n runtime,\n paramNames: [...plan.paramsByName.keys()],\n outputShape: [...runtime.outputShape],\n kernelCount: kernels.filter((k) => k.wgsl).length,\n captureShapes,\n adam: null\n };\n graphs.set(payload.graphId, slot);\n return {\n paramNames: [...slot.paramNames],\n outputShape: slot.outputShape,\n kernelCount: slot.kernelCount,\n captureShapes: slot.captureShapes\n };\n}\nfunction createAdamState(cfg) {\n return {\n config: cfg,\n t: 0,\n lrtBuf: new Float32Array(1),\n decayShrinkBuf: cfg.decayShrinkInputName ? new Float32Array(1) : null\n };\n}\nfunction injectAdamScalars(slot, inputs) {\n const a = slot.adam;\n if (!a) return inputs;\n a.t++;\n const lrNow = resolveLR(a.config.lr, a.t);\n a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t));\n const merged = { ...inputs, [a.config.lrtInputName]: a.lrtBuf };\n if (a.decayShrinkBuf && a.config.decayShrinkInputName) {\n a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay;\n merged[a.config.decayShrinkInputName] = a.decayShrinkBuf;\n }\n return merged;\n}\nasync function handleStep(payload) {\n const slot = mustGet(payload.graphId);\n const merged = injectAdamScalars(slot, payload.inputs);\n if (payload.withCaptures) {\n const r = await slot.runtime.step(merged, { withCaptures: true });\n return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const loss = await slot.runtime.step(merged);\n return { loss, captures: null };\n}\nasync function handleRun(payload) {\n const slot = mustGet(payload.graphId);\n if (payload.withCaptures) {\n const r = await slot.runtime.run(payload.inputs, { withCaptures: true });\n return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };\n }\n const output = await slot.runtime.run(payload.inputs);\n return { output, captures: null };\n}\nfunction capturesToRecord(captures, shapes) {\n const out = {};\n for (const name of Object.keys(shapes)) {\n if (captures.has(name)) out[name] = captures.get(name);\n }\n return out;\n}\nfunction handleUploadParams(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.uploadParams(payload.params, { partial: payload.partial });\n}\nasync function handleDownloadParams(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParams() };\n}\nasync function handleDownloadParamGrads(payload) {\n const slot = mustGet(payload.graphId);\n return { params: await slot.runtime.downloadParamGrads() };\n}\nfunction handleResetOptimizer(payload) {\n const slot = mustGet(payload.graphId);\n slot.runtime.resetOptimizerState();\n if (slot.adam) slot.adam.t = 0;\n}\nfunction handleDestroy(payload) {\n const slot = graphs.get(payload.graphId);\n if (!slot) return;\n slot.runtime.destroy();\n graphs.delete(payload.graphId);\n}\nfunction mustGet(graphId) {\n const slot = graphs.get(graphId);\n if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`);\n return slot;\n}\nself.onmessage = async (ev) => {\n const req = ev.data;\n try {\n let result;\n let transferList = [];\n switch (req.kind) {\n case "createRuntime":\n result = await handleCreateRuntime(req.payload);\n break;\n case "compileForward":\n result = await handleCompileForward(req.payload);\n break;\n case "step":\n result = await handleStep(req.payload);\n transferList = collectTransfers(result.captures);\n break;\n case "run": {\n const r = await handleRun(req.payload);\n result = r;\n transferList = [r.output.buffer, ...collectTransfers(r.captures)];\n break;\n }\n case "uploadParams":\n handleUploadParams(req.payload);\n result = null;\n break;\n case "downloadParams": {\n const r = await handleDownloadParams(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "downloadParamGrads": {\n const r = await handleDownloadParamGrads(req.payload);\n result = r;\n transferList = collectTransfers(r.params);\n break;\n }\n case "resetOptimizer":\n handleResetOptimizer(req.payload);\n result = null;\n break;\n case "destroy":\n handleDestroy(req.payload);\n result = null;\n break;\n default:\n throw new Error(`unknown request kind: ${req.kind}`);\n }\n const reply = { id: req.id, ok: true, result };\n self.postMessage(reply, { transfer: transferList });\n } catch (e) {\n const error = wireError(e);\n const reply = { id: req.id, ok: false, error };\n self.postMessage(reply);\n }\n};\nfunction collectTransfers(rec) {\n if (!rec) return [];\n const out = [];\n for (const v of Object.values(rec)) out.push(v.buffer);\n return out;\n}\n');
  // Only the parts the worker needs; structured-cloned into the request.
  const wireIR = { graph, plan, kernels };
  // NOTE(review): presumably the ArrayBuffers backing initialParams, handed
  // over as transferables (detached on this side afterwards) — confirm.
  const transfers = transferablesOfRecord(initialParams);
  let meta;
  try {
    meta = await proxy.request(
      { kind: "createRuntime", payload: { graphId: 0, ir: wireIR, initialParams, adam: null } },
      transfers
    );
  } catch (e) {
    // Don't leave the freshly spawned worker running if runtime creation fails.
    proxy.terminate();
    throw e;
  }
  return new CompiledForwardModuleProxy(
    proxy,
    /* graphId */
    0,
    ir,
    meta,
    /* ownsWorker */
    true
  );
}
1869
/**
 * Main-thread handle for a trainable module whose runtime lives in a worker.
 *
 * Each method posts one message through `proxy`. It uses `request` (awaits a
 * reply) or `send` (fire-and-forget), then unwraps the reply. No GPU state is
 * held on this side.
 *
 * Fields:
 * - proxy: worker channel exposing request / send / terminate.
 * - graphId: id of this module's graph inside the worker.
 * - ir: main-thread copy of the compiled IR; `reset` reads `ir.plan`.
 * - meta: the worker's createRuntime reply (kernelCount, outputShape,
 *   paramNames, captureShapes).
 * - modelFactory / initFns: used to re-trace (`compileForward`) and to
 *   rebuild initial parameter values (`reset`).
 * - nextGraphId: shared `{ v }` counter that hands out graph ids to forward
 *   modules compiled from this one on the same worker.
 */
var CompiledModuleProxy = class {
  proxy;
  graphId;
  ir;
  meta;
  modelFactory;
  initFns;
  nextGraphId;

  constructor(proxy, graphId, ir, meta, modelFactory, initFns, nextGraphId) {
    Object.assign(this, { proxy, graphId, ir, meta, modelFactory, initFns, nextGraphId });
  }

  /** Kernel count reported by the worker when the runtime was created. */
  get kernelCount() {
    return this.meta.kernelCount;
  }

  /** Shape of the graph output, as reported by the worker. */
  get outputShape() {
    return this.meta.outputShape;
  }

  /** Names of the trainable parameters, as reported by the worker. */
  get paramNames() {
    return this.meta.paramNames;
  }

  /**
   * Run one training step in the worker.
   * Resolves to the loss, or to `{ loss, captures }` when `opts.withCaptures`
   * is truthy.
   */
  async step(inputs, opts) {
    const withCaptures = opts?.withCaptures === true;
    const reply = await this.proxy.request(
      { kind: "step", payload: { graphId: this.graphId, inputs, withCaptures } }
    );
    if (!opts?.withCaptures) return reply.loss;
    const captures = makeCaptures(reply.captures, this.meta.captureShapes);
    return { loss: reply.loss, captures };
  }

  /**
   * Run the graph without an optimizer update.
   * Resolves to the output, or to `{ output, captures }` when
   * `opts.withCaptures` is truthy.
   */
  async run(inputs, opts) {
    const withCaptures = opts?.withCaptures === true;
    const reply = await this.proxy.request(
      { kind: "run", payload: { graphId: this.graphId, inputs, withCaptures } }
    );
    if (!opts?.withCaptures) return reply.output;
    const captures = makeCaptures(reply.captures, this.meta.captureShapes);
    return { output: reply.output, captures };
  }

  /** Upload parameter values; `opts.partial` allows a subset. Resolves to undefined. */
  uploadParams(params, opts) {
    const payload = { graphId: this.graphId, params, partial: !!opts?.partial };
    return this.proxy.request({ kind: "uploadParams", payload }).then(() => undefined);
  }

  /** Read back the current parameter values. */
  async downloadParams() {
    const { params } = await this.proxy.request(
      { kind: "downloadParams", payload: { graphId: this.graphId } }
    );
    return params;
  }

  /** Read back the parameter gradients from the last step. */
  async downloadParamGrads() {
    const { params } = await this.proxy.request(
      { kind: "downloadParamGrads", payload: { graphId: this.graphId } }
    );
    return params;
  }

  /** Re-initialize every parameter from `initFns`, then reset optimizer state. */
  async reset() {
    const freshParams = buildInitialParams(this.ir.plan, this.initFns);
    await this.uploadParams(freshParams);
    await this.resetOptimizerState();
  }

  /** Ask the worker to reset optimizer state for this graph. Resolves to undefined. */
  resetOptimizerState() {
    const payload = { graphId: this.graphId };
    return this.proxy.request({ kind: "resetOptimizer", payload }).then(() => undefined);
  }

  /**
   * Trace and compile a forward-only graph on the same worker.
   * The worker creates it with `sharedParams` from this module's runtime.
   * The returned proxy does not own the worker.
   */
  async compileForward(forward, opts = {}) {
    const { graph } = traceModule(this.modelFactory, forward, opts.inputs ?? {});
    const loss = graph.tensors[graph.outputs[0]];
    const plan = planBuffers(graph, /* paramGrads */ {});
    const kernels = emitKernels(graph, plan);
    const ir = { graph, paramGrads: {}, loss, plan, kernels };
    const graphId = this.nextGraphId.v++;
    const meta = await this.proxy.request({
      kind: "compileForward",
      payload: { graphId, parentGraphId: this.graphId, ir: { graph, plan, kernels } }
    });
    return new CompiledForwardModuleProxy(this.proxy, graphId, ir, meta, /* ownsWorker */ false);
  }

  /** Release this graph in the worker and terminate the worker itself. */
  destroy() {
    const { proxy, graphId } = this;
    proxy.send({ kind: "destroy", payload: { graphId } });
    proxy.terminate();
  }
};
1969
/**
 * Main-thread handle for a forward-only graph living in a worker.
 *
 * Fields:
 * - proxy: worker channel exposing request / send / terminate.
 * - graphId: id of this graph inside the worker.
 * - ir: main-thread copy of the compiled IR.
 * - meta: the worker's compile reply (kernelCount, outputShape, paramNames,
 *   captureShapes).
 * - ownsWorker: when true, `destroy` also terminates the worker. It is false
 *   for graphs compiled from a training module, which share that module's
 *   worker.
 */
var CompiledForwardModuleProxy = class {
  proxy;
  graphId;
  ir;
  meta;
  ownsWorker;

  constructor(proxy, graphId, ir, meta, ownsWorker) {
    Object.assign(this, { proxy, graphId, ir, meta, ownsWorker });
  }

  /** Kernel count reported by the worker. */
  get kernelCount() {
    return this.meta.kernelCount;
  }

  /** Shape of the graph output, as reported by the worker. */
  get outputShape() {
    return this.meta.outputShape;
  }

  /** Parameter names, as reported by the worker. */
  get paramNames() {
    return this.meta.paramNames;
  }

  /**
   * Run the graph.
   * Resolves to the output, or to `{ output, captures }` when
   * `opts.withCaptures` is truthy.
   */
  async run(inputs, opts) {
    const withCaptures = opts?.withCaptures === true;
    const reply = await this.proxy.request(
      { kind: "run", payload: { graphId: this.graphId, inputs, withCaptures } }
    );
    if (!opts?.withCaptures) return reply.output;
    const captures = makeCaptures(reply.captures, this.meta.captureShapes);
    return { output: reply.output, captures };
  }

  /** Upload parameter values; `opts.partial` allows a subset. Resolves to undefined. */
  uploadParams(params, opts) {
    const payload = { graphId: this.graphId, params, partial: !!opts?.partial };
    return this.proxy.request({ kind: "uploadParams", payload }).then(() => undefined);
  }

  /** Read back the current parameter values. */
  async downloadParams() {
    const { params } = await this.proxy.request(
      { kind: "downloadParams", payload: { graphId: this.graphId } }
    );
    return params;
  }

  /** Release this graph in the worker; terminate the worker only if we own it. */
  destroy() {
    this.proxy.send({ kind: "destroy", payload: { graphId: this.graphId } });
    if (!this.ownsWorker) return;
    this.proxy.terminate();
  }
};
2016
/**
 * Trace `forward` against a freshly built model.
 *
 * Builds the model once with `modelFactory()`. Inside the `trace(...)` callback
 * it then:
 * - materializes the model's parameters;
 * - declares one graph input per entry of `inputDecls` (dtype falls back to
 *   "f32" when `decl.dtype` is null/undefined);
 * - calls `forward(model, inputTensors)`.
 *
 * @returns {{ graph, materialized }} the traced graph and the
 *   materializeParams result. `materialized` is an empty placeholder record
 *   if the callback never ran.
 */
function traceModule(modelFactory, forward, inputDecls) {
  const model = modelFactory();
  // Placeholder; overwritten by the trace callback (presumably `trace` invokes
  // it synchronously before returning — confirm).
  let materialized = { tensors: {}, initFns: {}, decayFlags: {} };
  const traceBody = () => {
    materialized = materializeParams(model);
    const inputTensors = {};
    for (const [inputName, { shape, dtype }] of Object.entries(inputDecls)) {
      inputTensors[inputName] = tensorInput(inputName, shape, dtype ?? "f32");
    }
    return forward(model, inputTensors);
  };
  const graph = trace(traceBody);
  return { graph, materialized };
}
2029
/**
 * Compute the starting value of every parameter in a buffer plan.
 *
 * For each `[name, bufId]` in `plan.paramsByName`, it calls
 * `initFns[name](elementCount, shape)`. `elementCount` is the product of
 * `plan.buffers[bufId].shape` (1 for a rank-0 shape).
 *
 * @param {object} plan - buffer plan; reads `paramsByName` and `buffers[id].shape`.
 * @param {Record<string, Function>} initFns - per-parameter initializers.
 * @returns {Record<string, *>} initializer results keyed by param name, in
 *   `paramsByName` iteration order.
 * @throws {Error} `compile: no init for param '<name>'` when a param has no initializer.
 */
function buildInitialParams(plan, initFns) {
  const out = {};
  for (const [name, bufId] of plan.paramsByName) {
    const shape = plan.buffers[bufId].shape;
    const size = shape.reduce((a, b) => a * b, 1);
    const initFn = initFns[name];
    if (!initFn) throw new Error(`compile: no init for param '${name}'`);
    // Pass a copy: the plan's shape array is part of the compiled IR, and an
    // initializer that mutates its argument must not corrupt it (the runtime
    // likewise copies plan shapes before exposing them).
    out[name] = initFn(size, [...shape]);
  }
  return out;
}
2040
/**
 * Flatten an Adam setup result into the plain config record sent to the worker.
 *
 * Hyperparameters (lr, b1, b2, eps, weightDecay, lrIsScheduled) come from
 * `r.config`. The names of the per-step scalar graph inputs (lrtInputName,
 * decayShrinkInputName) come from `r` itself.
 * `lr` is passed through as-is; it may be a number or a schedule object.
 */
function wireAdamConfig(r) {
  const {
    config: { lr, b1, b2, eps, weightDecay, lrIsScheduled },
    lrtInputName,
    decayShrinkInputName
  } = r;
  return { lr, b1, b2, eps, weightDecay, lrIsScheduled, lrtInputName, decayShrinkInputName };
}
2053
/**
 * Wrap a worker reply's capture record in a main-thread `Captures` view.
 *
 * @param captures - `name -> Float32Array` record from the worker, or a falsy
 *   value when none were sent. A falsy value yields an empty view.
 * @param captureShapes - `name -> shape` map for every registered capture.
 */
function makeCaptures(captures, captureShapes) {
  const entries = captures ? Object.entries(captures) : [];
  return new Captures(captureShapes, new Map(entries));
}
2060
+
2061
+ // src/nn.ts
2062
// Namespace object for the layer and loss helpers defined below. It is
// exported from this module as `nn` in the export list at the end of the file.
// `__export` is a bundler helper defined outside this view; presumably it
// installs each entry as a getter on `nn_exports` — verify. The key order here
// presumably fixes the namespace's enumeration order.
var nn_exports = {};
__export(nn_exports, {
  LayerNorm: () => LayerNorm,
  Linear: () => Linear,
  crossEntropyLast: () => crossEntropyLast,
  mergeHeads: () => mergeHeads,
  splitHeads: () => splitHeads,
  unsplitHeads: () => unsplitHeads
});
2071
/**
 * Fully-connected layer: `fwd(x) = x @ W (+ b)`.
 *
 * @param inDim - size of the input's last axis; W has shape [inDim, outDim].
 * @param outDim - size of the output's last axis.
 * @param opts.bias - pass `false` to omit the bias (then `b` is null).
 *   Otherwise `b` has shape [outDim] with the "zeros" initializer.
 *
 * Fields are declared, and params requested, in the same order as before
 * (inDim, outDim, W, b). NOTE(review): presumably the Module base discovers
 * params through this own-property order — confirm before reordering.
 */
var Linear = class extends Module {
  inDim;
  outDim;
  W;
  b;

  constructor(inDim, outDim, opts = {}) {
    super();
    this.inDim = inDim;
    this.outDim = outDim;
    this.W = this.param([inDim, outDim]);
    const withBias = opts.bias !== false;
    this.b = withBias ? this.param([outDim], { init: "zeros" }) : null;
  }

  fwd(x) {
    const projected = matmul(x, this.W);
    if (!this.b) return projected;
    return add(projected, this.b);
  }
};
2088
/**
 * Layer normalization over the last axis, with a learned gain `g` ("ones"
 * init) and bias `b` ("zeros" init), both of shape [d].
 *
 * `fwd(x) = (x - mean) / sqrt(var + eps) * g + b`. `mean` and `var` are
 * `meanLast` of x and of the squared deviations, respectively.
 *
 * @param d - size of the normalized (last) axis.
 * @param eps - added to the variance before the square root (default 1e-5).
 */
var LayerNorm = class extends Module {
  d;
  eps;
  g;
  b;

  constructor(d, eps = 1e-5) {
    super();
    this.d = d;
    this.eps = eps;
    this.g = this.param([d], { init: "ones" });
    this.b = this.param([d], { init: "zeros" });
  }

  fwd(x) {
    // Keep this call order. Presumably each op call appends to the trace, so
    // reordering would change the emitted graph.
    const mean = meanLast(x);
    const centered = sub(x, mean);
    const squared = mul(centered, centered);
    const variance = meanLast(squared);
    const varPlusEps = add(variance, this.eps);
    const std = sqrt(varPlusEps);
    const normalized = div(centered, std);
    const scaled = mul(normalized, this.g);
    return add(scaled, this.b);
  }
};
2108
/**
 * Split the last axis into `nHeads` heads and move the head axis in front of
 * the sequence axis: [...lead, T, D] -> [...lead, nHeads, T, D / nHeads].
 *
 * @param x - tensor of rank >= 2.
 * @param nHeads - positive integer dividing the last dimension D.
 * @returns the reshaped and axis-swapped tensor.
 * @throws {ShapeError} if x has rank < 2, nHeads is not a positive integer,
 *   or D is not divisible by nHeads.
 */
function splitHeads(x, nHeads) {
  const site = captureSite("splitHeads");
  const r = x.shape.length;
  if (r < 2) throw new ShapeError(`splitHeads: requires rank >= 2, got ${r}`, site);
  // `D % nHeads === 0` alone accepts nHeads = -2 (it yields -0) or 2.5. Those
  // would produce negative or fractional dims in the reshape below, so reject
  // anything but a positive integer up front.
  if (!Number.isInteger(nHeads) || nHeads <= 0) {
    throw new ShapeError(`splitHeads: nHeads must be a positive integer, got ${nHeads}`, site);
  }
  const T = x.shape[r - 2];
  const D = x.shape[r - 1];
  if (D % nHeads !== 0) {
    throw new ShapeError(`splitHeads: last dim ${D} not divisible by nHeads ${nHeads}`, site);
  }
  const lead = x.shape.slice(0, r - 2);
  const reshaped = reshape(x, [...lead, T, nHeads, D / nHeads]);
  return swapAxes(reshaped, lead.length, lead.length + 1);
}
2121
/**
 * Inverse of splitHeads: [...lead, H, T, d] -> [...lead, T, H * d].
 * The head and sequence axes are swapped first, then the last two axes are
 * flattened together.
 *
 * @param x - tensor of rank >= 3.
 * @throws {ShapeError} if x has rank < 3.
 */
function mergeHeads(x) {
  const site = captureSite("mergeHeads");
  const rank = x.shape.length;
  if (rank < 3) throw new ShapeError(`mergeHeads: requires rank >= 3, got ${rank}`, site);
  const [H, T, d] = x.shape.slice(rank - 3);
  const lead = x.shape.slice(0, rank - 3);
  const headsAfterSeq = swapAxes(x, rank - 3, rank - 2);
  return reshape(headsAfterSeq, [...lead, T, H * d]);
}
2132
/**
 * Split a flat captured tensor into one array per head.
 *
 * It reads the data and shape for `name` from `captures`, in that order. A
 * leading dimension of size 1 is dropped (presumably a unit batch axis). The
 * next dimension is treated as the head count H, and each head gets the
 * product of the remaining dims as its element count.
 *
 * NOTE(review): a batch-less capture with H = 1 (e.g. [1, T, T]) is
 * indistinguishable from a unit batch axis and would be split by T instead.
 *
 * @returns an array of H slices (copies) of the flat data.
 * @throws {Error} if the shape has fewer than 2 dims or its product does not
 *   match the data length.
 */
function unsplitHeads(captures, name) {
  const flat = captures.get(name);
  const shape = captures.shapeOf(name);
  if (shape.length < 2) {
    throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(", ")}]`);
  }
  const dims = shape[0] === 1 ? shape.slice(1) : shape;
  const [H, ...perHeadDims] = dims;
  const stride = perHeadDims.reduce((acc, n) => acc * n, 1);
  const expected = H * stride;
  if (flat.length !== expected) {
    throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`);
  }
  return Array.from({ length: H }, (_, head) => {
    const start = head * stride;
    return flat.slice(start, start + stride);
  });
}
2148
/**
 * Per-position cross-entropy over the last axis of `logits`.
 *
 * It computes `-sum(logSoftmax(logits) * oneHot(targets, vocab))` along the
 * last axis, where `vocab` is the size of that axis.
 *
 * @param logits - tensor whose last axis is the vocabulary.
 * @param targets - i32 class indices.
 * @throws {ShapeError} if targets is not i32.
 */
function crossEntropyLast(logits, targets) {
  const site = captureSite("crossEntropyLast");
  if (targets.dtype !== "i32") {
    throw new ShapeError(`crossEntropyLast: targets must be i32, got ${targets.dtype}`, site);
  }
  const vocab = logits.shape.at(-1);
  // Presumably each call appends an op to the trace, so keep this order.
  const logProbs = logSoftmaxLast(logits);
  const targetMask = oneHot(targets, vocab, "f32");
  const picked = mul(logProbs, targetMask);
  const targetLogProb = sumLast(picked);
  return mul(targetLogProb, -1);
}
2158
// Public surface of the bundle: the IR/tracing entry points, element-wise,
// reduction, shape and ML ops, compile helpers (compileModule, compileForward,
// compileToIR), and the `nn` namespace (nn_exports).
export {
  Captures,
  Module,
  ShapeError,
  add,
  appendAdam,
  appendGrad,
  arange,
  capture,
  compileForward,
  compileModule,
  compileToIR,
  div,
  embedding,
  emitKernels,
  exp,
  greater,
  init,
  less,
  log,
  logSoftmaxLast,
  lr,
  materializeParams,
  matmul,
  matmulBatched,
  meanLast,
  mul,
  nn_exports as nn,
  oneHot,
  paramInput,
  planBuffers,
  relu,
  reshape,
  resolveLR,
  rsqrt,
  sliceLastRange,
  softmaxCausalLast,
  sqrt,
  stateInput,
  sub,
  sumAll,
  sumLast,
  swapAxes,
  tensorInput,
  trace,
  traceInto,
  transpose,
  where,
  whereCausal
};
2208
+ //# sourceMappingURL=index.js.map