tensorgrad 0.0.14 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/codegen.js DELETED
@@ -1,724 +0,0 @@
1
- // WGSL codegen: one kernel per IR op.
2
- //
3
- // All shapes are baked into the WGSL as compile-time constants — no shape
4
- // uniforms. This means each shape combination produces a distinct shader
5
- // (so `add([B, T, D], [D])` and `add([B, T, D], [B, T, D])` get different
6
- // kernels), which is fine for our static-shape model and gives the WGSL
7
- // compiler full freedom to specialize.
8
- //
9
- // Most kernels are direct ports of `transformer-gpu.bulb.md`'s WGSL — those
10
- // are already debugged and tuned. The autograd ops (broadcast_to, sum_to_shape,
11
- // relu_grad, etc.) are new.
12
- import { shapeSize } from './shape.js';
13
// Workgroup size of 256 means even our biggest kernel (~8M threads in
// matmul_bwd_dW) needs only ~32K workgroups, well under WebGPU's 65535-per-dim
// dispatch cap. Smaller WG_SIZE forced 2D dispatch with significant over-dispatch.
const WG_SIZE = 256;
// WebGPU's per-dimension workgroup dispatch cap. Must stay in sync with the
// MAX_X used by the dispatch-stride math in runtime.ts.
const MAX_DISPATCH_X = 65535;
// Global thread index, packed across the 2D dispatch grid that lets us route
// past the per-dim cap. Every kernel uses this exact line. The per-row stride
// is MAX_DISPATCH_X * WG_SIZE (16776960 at WG_SIZE = 256). It is derived rather
// than hard-coded so that changing WG_SIZE cannot desynchronize it. It is
// interpolated into each WGSL string (not a WGSL function), so the WGSL
// compiler still sees a literal constant.
const GID_LINE = `let i = gid.x + gid.y * ${MAX_DISPATCH_X * WG_SIZE}u;`;
24
- // ============================================================================
25
- // Public entry point
26
- // ============================================================================
27
/**
 * Build the kernel list for a graph: one KernelSpec per entry of
 * `graph.ops`, kept in the same (dispatch) order.
 *
 * Spec `k` comes from `emitKernel(graph.ops[k], graph, plan, k)`.
 *
 * @param {object} graph - IR graph whose `ops` array is walked in order.
 * @param {object} plan - buffer plan forwarded unchanged to `emitKernel`.
 * @returns {Array<object>} KernelSpecs, index-aligned with `graph.ops`.
 */
export function emitKernels(graph, plan) {
  return Array.from(graph.ops, (op, opIndex) => emitKernel(op, graph, plan, opIndex));
}
37
/**
 * Emit the KernelSpec for one IR op.
 *
 * Shapes are baked into the WGSL text as literals. Every thread-per-element
 * kernel derives its flat index with GID_LINE and bounds-checks it against
 * the element count. `bindings` lists buffers in @binding order.
 *
 * @param {object} op - IR op; `op.kind` selects the template, and the other
 *   fields (tensor ids, scalars, `perm`, ...) depend on the kind.
 * @param {object} graph - IR graph; `graph.tensors[id]` supplies `shape` and `dtype`.
 * @param {object} plan - buffer plan; `plan.tensorToBuffer.get(id)` is the
 *   buffer bound for tensor `id`.
 * @param {number} opIndex - position of `op` in `graph.ops`, copied into the spec.
 * @returns {{opIndex: number, opKind: string, wgsl: string, bindings: Array,
 *   threads: number, workgroupSize: number}} Leaf ops (param/tensor/state
 *   inputs) return an empty spec with `wgsl: ''` and `threads: 0`.
 * @throws {Error} When `op.kind` has no codegen template.
 */
function emitKernel(op, graph, plan, opIndex) {
  const tof = (id) => graph.tensors[id];
  const buf = (tensorId) => plan.tensorToBuffer.get(tensorId);
  const empty = () => ({ opIndex, opKind: op.kind, wgsl: '', bindings: [], threads: 0, workgroupSize: WG_SIZE });
  switch (op.kind) {
    // ---- Leaves: data is supplied externally; no kernel ---------------------
    case 'param_input':
    case 'tensor_input':
    case 'state_input':
      return empty();

    // ---- arange / const_scalar: kernel that fills the buffer once -----------
    case 'arange': {
      // buf[i] = i, converted from i32 to the output dtype, for i in [0, n).
      const out = tof(op.out);
      const wgsl = `
@group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${op.n}u) { return; }
buf[i] = ${castFromI32('i32(i)', out.dtype)};
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: op.n, workgroupSize: WG_SIZE };
    }
    case 'const_scalar': {
      // A single invocation writes the baked literal into buf[0].
      const wgsl = `
@group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(op.dtype)}>;
@compute @workgroup_size(1)
fn main() {
buf[0] = ${wgslLiteral(op.value, op.dtype)};
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: 1, workgroupSize: 1 };
    }

    // ---- Element-wise binops with broadcast --------------------------------
    case 'add':
    case 'sub':
    case 'mul':
    case 'div': {
      const out = tof(op.out);
      const a = tof(op.a);
      const b = tof(op.b);
      const opStr = { add: '+', sub: '-', mul: '*', div: '/' }[op.kind];
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
@group(0) @binding(2) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
out[i] = a[aIdx] ${opStr} b[bIdx];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Element-wise scalar binops (scalar baked into WGSL) ---------------
    case 'mul_scalar':
    case 'add_scalar': {
      // No broadcasting: `a` is indexed with the output flat index directly.
      const out = tof(op.out);
      const a = tof(op.a);
      const opStr = op.kind === 'mul_scalar' ? '*' : '+';
      const total = shapeSize(out.shape);
      const lit = wgslLiteral(op.scalar, out.dtype);
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
out[i] = a[i] ${opStr} ${lit};
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Unary -------------------------------------------------------------
    case 'sqrt':
    case 'rsqrt':
    case 'log':
    case 'exp':
    case 'relu': {
      const out = tof(op.out);
      const a = tof(op.a);
      const total = shapeSize(out.shape);
      const expr = op.kind === 'sqrt' ? 'sqrt(x)' :
        op.kind === 'rsqrt' ? '1.0 / sqrt(x)' :
          op.kind === 'log' ? 'log(x)' :
            op.kind === 'exp' ? 'exp(x)' :
              /* relu */ 'max(x, 0.0)';
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let x = a[i];
out[i] = ${expr};
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Comparisons + select --------------------------------------------
    case 'less':
    case 'greater': {
      const out = tof(op.out);
      const a = tof(op.a);
      const b = tof(op.b);
      const opStr = op.kind === 'less' ? '<' : '>';
      const total = shapeSize(out.shape);
      // bool tensors lower to u32 in storage (1 if true, 0 if false).
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
@group(0) @binding(2) var<storage, read_write> out : array<u32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
out[i] = select(0u, 1u, a[aIdx] ${opStr} b[bIdx]);
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'where': {
      // out = cond != 0 ? a : b, with all three operands broadcast to out.shape.
      const out = tof(op.out);
      const cond = tof(op.cond);
      const a = tof(op.a);
      const b = tof(op.b);
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> cond : array<u32>;
@group(0) @binding(1) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(2) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
@group(0) @binding(3) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${broadcastIndexBlock('i', out.shape, cond.shape, 'cIdx')}
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
out[i] = select(b[bIdx], a[aIdx], cond[cIdx] != 0u);
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.cond), buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'relu_grad': {
      // out = x > 0 ? dy : 0 (element-wise, no broadcasting).
      const out = tof(op.out);
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> x : array<f32>;
@group(0) @binding(1) var<storage, read> dy : array<f32>;
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
out[i] = select(0.0, dy[i], x[i] > 0.0);
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.x), buf(op.dy), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Reductions over last axis -----------------------------------------
    case 'mean_last':
    case 'sum_last': {
      // One thread per row of length D; the mean divides the sum by D.
      const a = tof(op.a);
      const D = a.shape[a.shape.length - 1];
      const outerSize = shapeSize(a.shape) / D;
      const divisor = op.kind === 'mean_last' ? `f32(${D}u)` : '1.0';
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${outerSize}u) { return; }
let base = i * ${D}u;
var s : f32 = 0.0;
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
s = s + a[base + j];
}
out[i] = s / ${divisor};
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
    }

    // ---- Shape ---------------------------------------------------------------
    // reshape: no kernel needed if buffers can alias (shape change only). For
    // v1 simplicity we emit a memcpy-style kernel rather than aliasing buffers,
    // because aliasing complicates the buffer plan and we have memory headroom.
    case 'reshape': {
      const out = tof(op.out);
      const a = tof(op.a);
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
out[i] = a[i];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'transpose': {
      // Decompose each output flat index into per-axis output indices, then
      // rebuild the source flat index as
      //   sum_k oIdx_{perm.indexOf(k)} * aStride[k]
      // (perm.indexOf(k) is the output axis that came from input axis k).
      const out = tof(op.out);
      const a = tof(op.a);
      const total = shapeSize(out.shape);
      const aStrides = computeStrides(a.shape);
      const outDimDecls = decomposeFlatIndexBlock('i', out.shape, 'oIdx');
      const srcExpr = [];
      for (let k = 0; k < a.shape.length; k++) {
        const srcAxis = op.perm.indexOf(k); // which output axis came from input axis k
        srcExpr.push(`oIdx_${srcAxis} * ${aStrides[k]}u`);
      }
      // A rank-0 transpose has no terms; joining [] would emit the invalid
      // WGSL `let srcIdx = ;`, so fall back to element 0.
      const srcIdxExpr = srcExpr.length > 0 ? srcExpr.join(' + ') : '0u';
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${outDimDecls}
let srcIdx = ${srcIdxExpr};
out[i] = a[srcIdx];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Linear algebra ----------------------------------------------------
    // matmul: a [..., M, K] · b [K, N] -> [..., M, N]. b is unbatched.
    case 'matmul': {
      const out = tof(op.out);
      const a = tof(op.a);
      const b = tof(op.b);
      const M = a.shape[a.shape.length - 2];
      const K = a.shape[a.shape.length - 1];
      const N = b.shape[1];
      const batch = shapeSize(a.shape) / (M * K);
      const total = batch * M * N;
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let bi = i / ${M * N}u; // batch index
let mn = i % ${M * N}u;
let m = mn / ${N}u;
let n = mn % ${N}u;
let aBase = bi * ${M * K}u + m * ${K}u;
var s : f32 = 0.0;
for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
s = s + a[aBase + k] * b[k * ${N}u + n];
}
c[i] = s;
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'matmul_batched': {
      // Like matmul, but b carries the same leading batch as a: each batch
      // slice of b starts at bi * K * N.
      const out = tof(op.out);
      const a = tof(op.a);
      const b = tof(op.b);
      const M = a.shape[a.shape.length - 2];
      const K = a.shape[a.shape.length - 1];
      const N = b.shape[b.shape.length - 1];
      const batch = shapeSize(a.shape) / (M * K);
      const total = batch * M * N;
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let bi = i / ${M * N}u;
let mn = i % ${M * N}u;
let m = mn / ${N}u;
let n = mn % ${N}u;
let aBase = bi * ${M * K}u + m * ${K}u;
let bBase = bi * ${K * N}u;
var s : f32 = 0.0;
for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
s = s + a[aBase + k] * b[bBase + k * ${N}u + n];
}
c[i] = s;
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- One-hot ------------------------------------------------------------
    case 'one_hot': {
      // out[outer, d] = 1 if indices[outer] == d else 0. A negative index
      // wraps to a large u32 in the cast and therefore matches no column.
      const out = tof(op.out);
      const total = shapeSize(out.shape);
      const depth = op.depth;
      const zeroLit = wgslLiteral(0, out.dtype);
      const oneLit = wgslLiteral(1, out.dtype);
      const wgsl = `
@group(0) @binding(0) var<storage, read> indices : array<i32>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let outerIdx = i / ${depth}u;
let depthIdx = i % ${depth}u;
let tgt = u32(indices[outerIdx]);
out[i] = select(${zeroLit}, ${oneLit}, tgt == depthIdx);
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.indices), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- ML primitives -----------------------------------------------------
    case 'log_softmax_last': {
      // One thread per row: max-subtracted log-sum-exp, then a[j] - logZ.
      const a = tof(op.a);
      const D = a.shape[a.shape.length - 1];
      const outerSize = shapeSize(a.shape) / D;
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${outerSize}u) { return; }
let base = i * ${D}u;
var m : f32 = -1.0e30;
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
let v = a[base + j];
if (v > m) { m = v; }
}
var s : f32 = 0.0;
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
s = s + exp(a[base + j] - m);
}
let logZ = m + log(s);
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
out[base + j] = a[base + j] - logZ;
}
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
    }
    case 'softmax_causal_last': {
      const a = tof(op.a);
      const T = a.shape[a.shape.length - 1]; // == second-to-last (square)
      // Outer size = (everything except last 2 axes) * (second-to-last axis)
      const outerSize = shapeSize(a.shape) / T;
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
// Each thread handles one (..., qpos)-row, softmaxing over kpos∈[0..qpos].
${GID_LINE}
if (i >= ${outerSize}u) { return; }
let qpos = i % ${T}u;
let base = i * ${T}u;
var m : f32 = -1.0e30;
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
let v = a[base + k];
if (v > m) { m = v; }
}
var s : f32 = 0.0;
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
let e = exp(a[base + k] - m);
out[base + k] = e;
s = s + e;
}
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
out[base + k] = out[base + k] / s;
}
for (var k : u32 = qpos + 1u; k < ${T}u; k = k + 1u) {
out[base + k] = 0.0;
}
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE };
    }
    case 'where_causal': {
      // Entries above the diagonal (kpos > qpos) become op.fillValue.
      const a = tof(op.a);
      const T = a.shape[a.shape.length - 1];
      const total = shapeSize(a.shape);
      const fillLit = wgslLiteral(op.fillValue, 'f32');
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let kpos = i % ${T}u;
let qpos = (i / ${T}u) % ${T}u;
if (kpos > qpos) {
out[i] = ${fillLit};
} else {
out[i] = a[i];
}
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Slicing -----------------------------------------------------------
    case 'slice_last_range': {
      // Copies last-axis elements [start, end) of each row of `a`.
      const out = tof(op.out);
      const a = tof(op.a);
      const D_in = a.shape[a.shape.length - 1];
      const D_out = op.end - op.start;
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let outer = i / ${D_out}u;
let inner = i % ${D_out}u;
out[i] = a[outer * ${D_in}u + ${op.start}u + inner];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Broadcast / un-broadcast (autograd infrastructure) ----------------
    case 'broadcast_to': {
      const out = tof(op.out);
      const a = tof(op.a);
      const total = shapeSize(out.shape);
      const wgsl = `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${broadcastIndexBlock('i', out.shape, a.shape, 'srcIdx')}
out[i] = a[srcIdx];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }

    // ---- Adam (fused per-element) -----------------------------------------
    case 'adam_update_m': {
      // m_new = b1 * m + (1 - b1) * g
      const out = tof(op.out);
      const total = shapeSize(out.shape);
      const b1 = op.b1;
      const oneMinusB1 = 1 - b1;
      const wgsl = `
@group(0) @binding(0) var<storage, read> m : array<f32>;
@group(0) @binding(1) var<storage, read> g : array<f32>;
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
out[i] = ${wgslLiteral(b1, 'f32')} * m[i] + ${wgslLiteral(oneMinusB1, 'f32')} * g[i];
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.m), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'adam_update_v': {
      // v_new = b2 * v + (1 - b2) * g²
      const out = tof(op.out);
      const total = shapeSize(out.shape);
      const b2 = op.b2;
      const oneMinusB2 = 1 - b2;
      const wgsl = `
@group(0) @binding(0) var<storage, read> v : array<f32>;
@group(0) @binding(1) var<storage, read> g : array<f32>;
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
let gv = g[i];
out[i] = ${wgslLiteral(b2, 'f32')} * v[i] + ${wgslLiteral(oneMinusB2, 'f32')} * gv * gv;
}`.trim();
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    case 'adam_update_p': {
      // p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
      // lrt is supplied per-step from CPU (already includes bias correction).
      // decayShrink is either baked as a literal (no schedule, fixed lr) or
      // bound as a per-step scalar input (when the user supplies an lr
      // schedule via `adam: { lr: (step) => ... }`). When literal=1 the WGSL
      // compiler folds the multiply away.
      const out = tof(op.out);
      const total = shapeSize(out.shape);
      // `!= null` so that an absent field (undefined) takes the literal path
      // instead of binding an undefined buffer.
      const dynamicShrink = op.decayShrinkTensor != null;
      const shrinkExpr = dynamicShrink ? 'decayShrink[0]' : wgslLiteral(op.decayShrink, 'f32');
      const shrinkBinding = dynamicShrink
        ? `@group(0) @binding(4) var<storage, read> decayShrink : array<f32>;\n` +
          `@group(0) @binding(5) var<storage, read_write> out : array<f32>;`
        : `@group(0) @binding(4) var<storage, read_write> out : array<f32>;`;
      const wgsl = `
@group(0) @binding(0) var<storage, read> p : array<f32>;
@group(0) @binding(1) var<storage, read> mNew : array<f32>;
@group(0) @binding(2) var<storage, read> vNew : array<f32>;
@group(0) @binding(3) var<storage, read> lrt : array<f32>;
${shrinkBinding}
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
out[i] = ${shrinkExpr} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
}`.trim();
      const bindings = dynamicShrink
        ? [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.decayShrinkTensor), buf(op.out)]
        : [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)];
      return { opIndex, opKind: op.kind, wgsl, bindings, threads: total, workgroupSize: WG_SIZE };
    }
    case 'sum_to_shape': {
      // Sum-reduce src down to target by summing over each axis where target=1
      // or where target is missing (offset-prefix axes that get fully summed).
      const out = tof(op.out);
      const a = tof(op.a);
      const wgsl = emitSumToShape(a.shape, out.shape, a.dtype);
      const total = shapeSize(out.shape);
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE };
    }
    default:
      // Fail loudly here rather than pushing `undefined` into the spec list,
      // where the runtime would crash far from the cause.
      throw new Error(`emitKernel: no WGSL codegen for op kind '${op.kind}' (op #${opIndex})`);
  }
}
556
- // ============================================================================
557
- // WGSL helpers
558
- // ============================================================================
559
/**
 * Map an IR dtype name to the WGSL element type used for its storage buffer.
 *
 * WGSL does not allow `bool` in storage buffers, so bool tensors are stored
 * as `u32` holding 0/1. Every other dtype name is passed through unchanged.
 * (Phase 3a note: the forward+backward graph builds its causal mask inline in
 * the softmax kernels, so the bool path only matters for user-created bool
 * tensors.)
 *
 * @param {string} d - IR dtype name.
 * @returns {string} WGSL element type name.
 */
function wgslDtype(d) {
  return d === 'bool' ? 'u32' : d;
}
568
/**
 * Format a JS value as a WGSL literal of the given dtype.
 *
 * - `f32`: finite values use JS `toString()`. `.0` is appended when the text
 *   has neither a `.` nor an exponent, because WGSL needs a float form. An
 *   `f` suffix is added. Non-finite values become ±1.0e30f because WGSL has
 *   no infinity literal; NaN takes the negative branch.
 * - `i32`: truncated toward zero, `i` suffix.
 * - `u32`: truncated toward zero, `u` suffix. Throws a RangeError if the
 *   result is not in [0, 2^32 - 1].
 * - any other dtype (e.g. `bool`, stored as u32): `1u` if truthy, else `0u`.
 *
 * @param {number|boolean} value - value to encode.
 * @param {string} dtype - IR dtype name.
 * @returns {string} WGSL literal text.
 * @throws {RangeError} For a `u32` value that does not fit in u32.
 */
function wgslLiteral(value, dtype) {
  if (dtype === 'f32') {
    if (Number.isFinite(value)) {
      const text = value.toString();
      return text.includes('.') || text.includes('e') ? `${text}f` : `${text}.0f`;
    }
    return value > 0 ? '1.0e30f' : '-1.0e30f';
  }
  if (dtype === 'i32') {
    return `${Math.trunc(value)}i`;
  }
  if (dtype === 'u32') {
    // Previously u32 shared the bool path, so e.g. 5 was emitted as `1u`.
    const n = Math.trunc(Number(value));
    if (!Number.isFinite(n) || n < 0 || n > 0xffffffff) {
      throw new RangeError(`wgslLiteral: ${value} is not representable as u32`);
    }
    return `${n}u`;
  }
  return value ? '1u' : '0u';
}
582
/**
 * Wrap a WGSL expression in a conversion to the storage type of `dtype`.
 *
 * `f32` and `i32` convert to themselves; every other dtype converts to `u32`.
 *
 * @param {string} expr - WGSL expression text.
 * @param {string} dtype - IR dtype name of the destination.
 * @returns {string} WGSL conversion expression.
 */
function castFromI32(expr, dtype) {
  const target = dtype === 'f32' || dtype === 'i32' ? dtype : 'u32';
  return `${target}(${expr})`;
}
589
/**
 * Row-major (C-order) element strides for `shape`.
 *
 * The last axis has stride 1. Each earlier axis's stride is the product of
 * all extents to its right. An empty shape yields an empty array.
 *
 * @param {number[]} shape - tensor extents.
 * @returns {number[]} strides, index-aligned with `shape`.
 */
function computeStrides(shape) {
  const rank = shape.length;
  const strides = Array.from({ length: rank }, () => 1);
  let stride = 1;
  for (let axis = rank - 1; axis >= 1; axis--) {
    stride *= shape[axis];
    strides[axis - 1] = stride;
  }
  return strides;
}
596
/**
 * Emit WGSL `let` lines that split the flat index `flatVar` into per-axis
 * indices `${outVar}_0 … ${outVar}_{rank-1}` for a row-major `shape`.
 *
 * Each non-last axis gets a quotient line plus a remainder temp
 * `${outVar}_rem{axis}`. The last axis takes the final remainder. A rank-0
 * shape yields a single placeholder `${outVar}_0 = 0u`.
 *
 * @param {string} flatVar - WGSL name of the flat index.
 * @param {number[]} shape - shape to decompose against.
 * @param {string} outVar - prefix for the generated variable names.
 * @returns {string} newline-joined WGSL statements.
 */
function decomposeFlatIndexBlock(flatVar, shape, outVar) {
  if (shape.length === 0) {
    return ` let ${outVar}_0 : u32 = 0u;`; // placeholder; keeps the WGSL parseable
  }
  const strides = computeStrides(shape);
  const lastAxis = shape.length - 1;
  const out = [];
  let rest = flatVar;
  for (const [axis, stride] of strides.entries()) {
    if (axis === lastAxis) {
      out.push(` let ${outVar}_${axis} = ${rest};`);
      break;
    }
    const remName = `${outVar}_rem${axis}`;
    out.push(
      ` let ${outVar}_${axis} = ${rest} / ${stride}u;`,
      ` let ${remName} = ${rest} % ${stride}u;`,
    );
    rest = remName;
  }
  return out.join('\n');
}
619
/**
 * Emit WGSL that binds `srcVar` to the source flat index that the output
 * flat index `flatVar` reads under right-aligned, NumPy-style broadcasting.
 *
 * The output index is first decomposed per axis, using names prefixed
 * `${srcVar}_ax` so several calls in one WGSL function do not collide. Each
 * source axis is aligned to output axis `axis + (outRank - srcRank)`. A
 * size-1 source axis contributes `0u` (broadcast); any other axis
 * contributes `outIndex * srcStride`. Leading output-only axes are ignored.
 * A rank-0 source always reads element 0.
 *
 * @param {string} flatVar - WGSL name of the output flat index.
 * @param {number[]} outShape - output shape.
 * @param {number[]} srcShape - source shape (rank <= output rank).
 * @param {string} srcVar - WGSL name to bind the source flat index to.
 * @returns {string} newline-joined WGSL statements.
 */
function broadcastIndexBlock(flatVar, outShape, srcShape, srcVar) {
  const axisPrefix = `${srcVar}_ax`;
  const decompose = decomposeFlatIndexBlock(flatVar, outShape, axisPrefix);
  if (srcShape.length === 0) {
    return `${decompose}\n let ${srcVar} : u32 = 0u;`;
  }
  const lead = outShape.length - srcShape.length;
  const strides = computeStrides(srcShape);
  const terms = srcShape.map((dim, axis) =>
    dim === 1 ? '0u' : `${axisPrefix}_${axis + lead} * ${strides[axis]}u`);
  return `${decompose}\n let ${srcVar} = ${terms.join(' + ')};`;
}
650
/**
 * Build the WGSL module for `sum_to_shape`, reducing a tensor of `srcShape`
 * down to `tgtShape` (right-aligned).
 *
 * With lead = srcRank - tgtRank, source axis k maps to target axis k - lead.
 * The axis is:
 *   - summed when it has no target axis (k < lead),
 *   - summed when the target extent is 1 and the source extent is > 1,
 *   - otherwise indexed by the output coordinate `tgt_{k-lead}`.
 * One thread computes one output element. The summed axes become nested
 * for-loops over loop variables `r{k}`.
 *
 * @param {number[]} srcShape - shape being reduced.
 * @param {number[]} tgtShape - shape to reduce to.
 * @param {string} dtype - IR dtype of both source and output.
 * @returns {string} complete WGSL compute module.
 */
function emitSumToShape(srcShape, tgtShape, dtype) {
  const elemType = wgslDtype(dtype);
  const zeroLit = dtype === 'f32' ? '0.0f' : dtype === 'i32' ? '0i' : '0u';
  const srcStrides = computeStrides(srcShape);
  const tgtStrides = computeStrides(tgtShape);
  const lead = srcShape.length - tgtShape.length;
  const pad = (level) => ' '.repeat(level + 1);

  // Single pass: split source axes into summed ones and kept index terms.
  const summedAxes = [];
  const keptTerms = [];
  srcShape.forEach((srcDim, k) => {
    const tAxis = k - lead;
    if (tAxis < 0 || (tgtShape[tAxis] === 1 && srcDim > 1)) {
      summedAxes.push(k);
    } else {
      keptTerms.push(`tgt_${tAxis} * ${srcStrides[k]}u`);
    }
  });
  const baseExpr = keptTerms.length > 0 ? keptTerms.join(' + ') : '0u';

  let loopBody;
  if (summedAxes.length === 0) {
    loopBody = ` s = s + a[${baseExpr}];`;
  } else {
    const opens = summedAxes.map((k, level) =>
      `${pad(level)}for (var r${k} : u32 = 0u; r${k} < ${srcShape[k]}u; r${k} = r${k} + 1u) {`);
    const srcIndex = [baseExpr, ...summedAxes.map((k) => `r${k} * ${srcStrides[k]}u`)].join(' + ');
    const accumulate = `${pad(summedAxes.length)}s = s + a[${srcIndex}];`;
    const closes = summedAxes.map((_, level) => `${pad(level)}}`).reverse();
    loopBody = [...opens, accumulate, ...closes].join('\n');
  }

  const total = tgtShape.length === 0 ? 1 : tgtStrides[0] * tgtShape[0];
  const decompose = decomposeFlatIndexBlock('i', tgtShape, 'tgt');
  return `
@group(0) @binding(0) var<storage, read> a : array<${elemType}>;
@group(0) @binding(1) var<storage, read_write> out : array<${elemType}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
${GID_LINE}
if (i >= ${total}u) { return; }
${decompose}
var s : ${elemType} = ${zeroLit};
${loopBody}
out[i] = s;
}`.trim();
}
724
- //# sourceMappingURL=codegen.js.map