tensorgrad 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +121 -0
  3. package/SPEC.md +293 -0
  4. package/dist/adam.d.ts +31 -0
  5. package/dist/adam.d.ts.map +1 -0
  6. package/dist/adam.js +66 -0
  7. package/dist/adam.js.map +1 -0
  8. package/dist/buffers.d.ts +56 -0
  9. package/dist/buffers.d.ts.map +1 -0
  10. package/dist/buffers.js +114 -0
  11. package/dist/buffers.js.map +1 -0
  12. package/dist/codegen.d.ts +23 -0
  13. package/dist/codegen.d.ts.map +1 -0
  14. package/dist/codegen.js +709 -0
  15. package/dist/codegen.js.map +1 -0
  16. package/dist/compile.d.ts +53 -0
  17. package/dist/compile.d.ts.map +1 -0
  18. package/dist/compile.js +76 -0
  19. package/dist/compile.js.map +1 -0
  20. package/dist/grad.d.ts +8 -0
  21. package/dist/grad.d.ts.map +1 -0
  22. package/dist/grad.js +404 -0
  23. package/dist/grad.js.map +1 -0
  24. package/dist/index.d.ts +12 -0
  25. package/dist/index.d.ts.map +1 -0
  26. package/dist/index.js +37 -0
  27. package/dist/index.js.map +1 -0
  28. package/dist/ir.d.ts +204 -0
  29. package/dist/ir.d.ts.map +1 -0
  30. package/dist/ir.js +60 -0
  31. package/dist/ir.js.map +1 -0
  32. package/dist/module.d.ts +21 -0
  33. package/dist/module.d.ts.map +1 -0
  34. package/dist/module.js +113 -0
  35. package/dist/module.js.map +1 -0
  36. package/dist/ops.d.ts +35 -0
  37. package/dist/ops.d.ts.map +1 -0
  38. package/dist/ops.js +270 -0
  39. package/dist/ops.js.map +1 -0
  40. package/dist/runtime.d.ts +26 -0
  41. package/dist/runtime.d.ts.map +1 -0
  42. package/dist/runtime.js +190 -0
  43. package/dist/runtime.js.map +1 -0
  44. package/dist/shape.d.ts +24 -0
  45. package/dist/shape.d.ts.map +1 -0
  46. package/dist/shape.js +259 -0
  47. package/dist/shape.js.map +1 -0
  48. package/dist/trace.d.ts +8 -0
  49. package/dist/trace.d.ts.map +1 -0
  50. package/dist/trace.js +93 -0
  51. package/dist/trace.js.map +1 -0
  52. package/package.json +62 -0
  53. package/src/adam.ts +95 -0
  54. package/src/buffers.ts +173 -0
  55. package/src/codegen.ts +758 -0
  56. package/src/compile.ts +120 -0
  57. package/src/grad.ts +459 -0
  58. package/src/index.ts +40 -0
  59. package/src/ir.ts +197 -0
  60. package/src/module.ts +126 -0
  61. package/src/ops.ts +311 -0
  62. package/src/runtime.ts +232 -0
  63. package/src/shape.ts +263 -0
  64. package/src/trace.ts +101 -0
package/src/codegen.ts ADDED
@@ -0,0 +1,758 @@
1
+ // WGSL codegen: one kernel per IR op.
2
+ //
3
+ // All shapes are baked into the WGSL as compile-time constants — no shape
4
+ // uniforms. This means each shape combination produces a distinct shader
5
+ // (so `add([B, T, D], [D])` and `add([B, T, D], [B, T, D])` get different
6
+ // kernels), which is fine for our static-shape model and gives the WGSL
7
+ // compiler full freedom to specialize.
8
+ //
9
+ // Most kernels are direct ports of `transformer-gpu.bulb.md`'s WGSL — those
10
+ // are already debugged and tuned. The autograd ops (broadcast_to, sum_to_shape,
11
+ // relu_grad, etc.) are new.
12
+
13
+ import type { Graph, OpNode, Tensor, Shape } from './ir.js'
14
+ import type { BufferPlan } from './buffers.js'
15
+
16
// Workgroup size of 256 means even our biggest kernel (~8M threads in
// matmul_bwd_dW) needs only ~32K workgroups, well under WebGPU's 65535-per-dim
// dispatch cap. Smaller WG_SIZE forced 2D dispatch with significant over-dispatch.
const WG_SIZE = 256

/**
 * Description of one generated compute kernel for a single IR op: the WGSL
 * source plus everything the runtime needs to dispatch it (buffer bindings,
 * thread count, workgroup size). "Logical" ops that need no GPU work carry an
 * empty `wgsl` and `threads: 0`.
 */
export interface KernelSpec {
  /** Index into graph.ops. */
  opIndex: number
  /** Op kind (for debugging / pipeline cache key). */
  opKind: OpNode['kind']
  /** Generated WGSL source. Empty string for "logical" ops with no kernel. */
  wgsl: string
  /**
   * Buffer ids in binding-index order. The runtime creates a bind group with
   * these in @binding(0..N) on @group(0). Inputs come first (read), output last
   * (read_write).
   */
  bindings: number[]
  /** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
  threads: number
  /** Workgroup size; usually WG_SIZE. */
  workgroupSize: number
}
39
+
40
+ // ============================================================================
41
+ // Public entry point
42
+ // ============================================================================
43
+
44
+ /** Generate a KernelSpec per compute op in graph.ops (in dispatch order). */
45
+ export function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[] {
46
+ const out: KernelSpec[] = []
47
+ for (let i = 0; i < graph.ops.length; i++) {
48
+ const op = graph.ops[i]!
49
+ const spec = emitKernel(op, graph, plan, i)
50
+ out.push(spec)
51
+ }
52
+ return out
53
+ }
54
+
55
+ function shapeSize(shape: Shape): number {
56
+ let n = 1
57
+ for (const d of shape) n *= d
58
+ return n
59
+ }
60
+
61
+ function emitKernel(op: OpNode, graph: Graph, plan: BufferPlan, opIndex: number): KernelSpec {
62
+ const tof = (id: number) => graph.tensors[id]!
63
+ const buf = (tensorId: number) => plan.tensorToBuffer.get(tensorId)!
64
+ const empty = (): KernelSpec => ({ opIndex, opKind: op.kind, wgsl: '', bindings: [], threads: 0, workgroupSize: WG_SIZE })
65
+
66
+ switch (op.kind) {
67
+ // ---- Leaves: data is supplied externally; no kernel ---------------------
68
+ case 'param_input':
69
+ case 'tensor_input':
70
+ case 'state_input':
71
+ return empty()
72
+
73
+ // ---- arange / const_scalar: kernel that fills the buffer once -----------
74
+ case 'arange': {
75
+ const out = tof(op.out)
76
+ const wgsl = `
77
+ @group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(out.dtype)}>;
78
+ @compute @workgroup_size(${WG_SIZE})
79
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
80
+ let i = gid.x + gid.y * 16776960u;
81
+ if (i >= ${op.n}u) { return; }
82
+ buf[i] = ${castFromI32('i32(i)', out.dtype)};
83
+ }`.trim()
84
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: op.n, workgroupSize: WG_SIZE }
85
+ }
86
+ case 'const_scalar': {
87
+ const wgsl = `
88
+ @group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(op.dtype)}>;
89
+ @compute @workgroup_size(1)
90
+ fn main() {
91
+ buf[0] = ${wgslLiteral(op.value, op.dtype)};
92
+ }`.trim()
93
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: 1, workgroupSize: 1 }
94
+ }
95
+
96
+ // ---- Element-wise binops with broadcast --------------------------------
97
+ case 'add':
98
+ case 'sub':
99
+ case 'mul':
100
+ case 'div': {
101
+ const out = tof(op.out)
102
+ const a = tof(op.a)
103
+ const b = tof(op.b)
104
+ const opStr = { add: '+', sub: '-', mul: '*', div: '/' }[op.kind]
105
+ const total = shapeSize(out.shape)
106
+ const wgsl = `
107
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
108
+ @group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
109
+ @group(0) @binding(2) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
110
+ @compute @workgroup_size(${WG_SIZE})
111
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
112
+ let i = gid.x + gid.y * 16776960u;
113
+ if (i >= ${total}u) { return; }
114
+ ${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
115
+ ${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
116
+ out[i] = a[aIdx] ${opStr} b[bIdx];
117
+ }`.trim()
118
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
119
+ }
120
+
121
+ // ---- Element-wise scalar binops (scalar baked into WGSL) ---------------
122
+ case 'mul_scalar':
123
+ case 'add_scalar': {
124
+ const out = tof(op.out)
125
+ const a = tof(op.a)
126
+ const opStr = op.kind === 'mul_scalar' ? '*' : '+'
127
+ const total = shapeSize(out.shape)
128
+ const lit = wgslLiteral(op.scalar, out.dtype)
129
+ const wgsl = `
130
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
131
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
132
+ @compute @workgroup_size(${WG_SIZE})
133
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
134
+ let i = gid.x + gid.y * 16776960u;
135
+ if (i >= ${total}u) { return; }
136
+ out[i] = a[i] ${opStr} ${lit};
137
+ }`.trim()
138
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
139
+ }
140
+
141
+ // ---- Unary -------------------------------------------------------------
142
+ case 'sqrt':
143
+ case 'rsqrt':
144
+ case 'log':
145
+ case 'exp':
146
+ case 'relu': {
147
+ const out = tof(op.out)
148
+ const a = tof(op.a)
149
+ const total = shapeSize(out.shape)
150
+ const expr =
151
+ op.kind === 'sqrt' ? 'sqrt(x)' :
152
+ op.kind === 'rsqrt' ? '1.0 / sqrt(x)' :
153
+ op.kind === 'log' ? 'log(x)' :
154
+ op.kind === 'exp' ? 'exp(x)' :
155
+ /* relu */ 'max(x, 0.0)'
156
+ const wgsl = `
157
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
158
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
159
+ @compute @workgroup_size(${WG_SIZE})
160
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
161
+ let i = gid.x + gid.y * 16776960u;
162
+ if (i >= ${total}u) { return; }
163
+ let x = a[i];
164
+ out[i] = ${expr};
165
+ }`.trim()
166
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
167
+ }
168
+
169
+ // ---- Comparisons + select --------------------------------------------
170
+ case 'less':
171
+ case 'greater': {
172
+ const out = tof(op.out)
173
+ const a = tof(op.a)
174
+ const b = tof(op.b)
175
+ const opStr = op.kind === 'less' ? '<' : '>'
176
+ const total = shapeSize(out.shape)
177
+ // bool tensors lower to u32 in storage (1 if true, 0 if false).
178
+ const wgsl = `
179
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
180
+ @group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
181
+ @group(0) @binding(2) var<storage, read_write> out : array<u32>;
182
+ @compute @workgroup_size(${WG_SIZE})
183
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
184
+ let i = gid.x + gid.y * 16776960u;
185
+ if (i >= ${total}u) { return; }
186
+ ${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
187
+ ${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
188
+ out[i] = select(0u, 1u, a[aIdx] ${opStr} b[bIdx]);
189
+ }`.trim()
190
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
191
+ }
192
+ case 'where': {
193
+ const out = tof(op.out)
194
+ const cond = tof(op.cond)
195
+ const a = tof(op.a)
196
+ const b = tof(op.b)
197
+ const total = shapeSize(out.shape)
198
+ const wgsl = `
199
+ @group(0) @binding(0) var<storage, read> cond : array<u32>;
200
+ @group(0) @binding(1) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
201
+ @group(0) @binding(2) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
202
+ @group(0) @binding(3) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
203
+ @compute @workgroup_size(${WG_SIZE})
204
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
205
+ let i = gid.x + gid.y * 16776960u;
206
+ if (i >= ${total}u) { return; }
207
+ ${broadcastIndexBlock('i', out.shape, cond.shape, 'cIdx')}
208
+ ${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
209
+ ${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
210
+ out[i] = select(b[bIdx], a[aIdx], cond[cIdx] != 0u);
211
+ }`.trim()
212
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.cond), buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
213
+ }
214
+
215
+ case 'relu_grad': {
216
+ const out = tof(op.out)
217
+ const total = shapeSize(out.shape)
218
+ const wgsl = `
219
+ @group(0) @binding(0) var<storage, read> x : array<f32>;
220
+ @group(0) @binding(1) var<storage, read> dy : array<f32>;
221
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
222
+ @compute @workgroup_size(${WG_SIZE})
223
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
224
+ let i = gid.x + gid.y * 16776960u;
225
+ if (i >= ${total}u) { return; }
226
+ out[i] = select(0.0, dy[i], x[i] > 0.0);
227
+ }`.trim()
228
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.x), buf(op.dy), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
229
+ }
230
+
231
+ // ---- Reductions over last axis -----------------------------------------
232
+ case 'mean_last':
233
+ case 'sum_last': {
234
+ const a = tof(op.a)
235
+ const D = a.shape[a.shape.length - 1]!
236
+ const outerSize = shapeSize(a.shape) / D
237
+ const divisor = op.kind === 'mean_last' ? `f32(${D}u)` : '1.0'
238
+ const wgsl = `
239
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
240
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
241
+ @compute @workgroup_size(${WG_SIZE})
242
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
243
+ let i = gid.x + gid.y * 16776960u;
244
+ if (i >= ${outerSize}u) { return; }
245
+ let base = i * ${D}u;
246
+ var s : f32 = 0.0;
247
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
248
+ s = s + a[base + j];
249
+ }
250
+ out[i] = s / ${divisor};
251
+ }`.trim()
252
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
253
+ }
254
+
255
+ // ---- Shape ---------------------------------------------------------------
256
+ // reshape: no kernel needed if buffers can alias (shape change only). For
257
+ // v1 simplicity we emit a memcpy-style kernel rather than aliasing buffers,
258
+ // because aliasing complicates the buffer plan and we have memory headroom.
259
+ case 'reshape': {
260
+ const out = tof(op.out)
261
+ const a = tof(op.a)
262
+ const total = shapeSize(out.shape)
263
+ const wgsl = `
264
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
265
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
266
+ @compute @workgroup_size(${WG_SIZE})
267
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
268
+ let i = gid.x + gid.y * 16776960u;
269
+ if (i >= ${total}u) { return; }
270
+ out[i] = a[i];
271
+ }`.trim()
272
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
273
+ }
274
+
275
+ case 'transpose': {
276
+ const out = tof(op.out)
277
+ const a = tof(op.a)
278
+ const total = shapeSize(out.shape)
279
+ // Emit per-axis index computation. For each output flat index i, decompose
280
+ // into per-axis output indices, then use op.perm to find the source axis order.
281
+ // Source flat index = sum(outIdx[perm.invert()[k]] * a_stride[k] for k).
282
+ const aStrides = computeStrides(a.shape)
283
+ const outDimDecls = decomposeFlatIndexBlock('i', out.shape, 'oIdx')
284
+ const srcExpr: string[] = []
285
+ for (let k = 0; k < a.shape.length; k++) {
286
+ const srcAxis = op.perm.indexOf(k) // which output axis came from input axis k
287
+ srcExpr.push(`oIdx_${srcAxis} * ${aStrides[k]}u`)
288
+ }
289
+ const wgsl = `
290
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
291
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
292
+ @compute @workgroup_size(${WG_SIZE})
293
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
294
+ let i = gid.x + gid.y * 16776960u;
295
+ if (i >= ${total}u) { return; }
296
+ ${outDimDecls}
297
+ let srcIdx = ${srcExpr.join(' + ')};
298
+ out[i] = a[srcIdx];
299
+ }`.trim()
300
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
301
+ }
302
+
303
+ // ---- Linear algebra ----------------------------------------------------
304
+ // matmul: a [..., M, K] · b [K, N] -> [..., M, N]. b is unbatched.
305
+ case 'matmul': {
306
+ const out = tof(op.out)
307
+ const a = tof(op.a)
308
+ const b = tof(op.b)
309
+ const M = a.shape[a.shape.length - 2]!
310
+ const K = a.shape[a.shape.length - 1]!
311
+ const N = b.shape[1]!
312
+ const batch = shapeSize(a.shape) / (M * K)
313
+ const total = batch * M * N
314
+ const wgsl = `
315
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
316
+ @group(0) @binding(1) var<storage, read> b : array<f32>;
317
+ @group(0) @binding(2) var<storage, read_write> c : array<f32>;
318
+ @compute @workgroup_size(${WG_SIZE})
319
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
320
+ let i = gid.x + gid.y * 16776960u;
321
+ if (i >= ${total}u) { return; }
322
+ let bi = i / ${M * N}u; // batch index
323
+ let mn = i % ${M * N}u;
324
+ let m = mn / ${N}u;
325
+ let n = mn % ${N}u;
326
+ let aBase = bi * ${M * K}u + m * ${K}u;
327
+ var s : f32 = 0.0;
328
+ for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
329
+ s = s + a[aBase + k] * b[k * ${N}u + n];
330
+ }
331
+ c[i] = s;
332
+ }`.trim()
333
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
334
+ }
335
+
336
+ case 'matmul_batched': {
337
+ const out = tof(op.out)
338
+ const a = tof(op.a)
339
+ const b = tof(op.b)
340
+ const M = a.shape[a.shape.length - 2]!
341
+ const K = a.shape[a.shape.length - 1]!
342
+ const N = b.shape[b.shape.length - 1]!
343
+ const batch = shapeSize(a.shape) / (M * K)
344
+ const total = batch * M * N
345
+ const wgsl = `
346
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
347
+ @group(0) @binding(1) var<storage, read> b : array<f32>;
348
+ @group(0) @binding(2) var<storage, read_write> c : array<f32>;
349
+ @compute @workgroup_size(${WG_SIZE})
350
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
351
+ let i = gid.x + gid.y * 16776960u;
352
+ if (i >= ${total}u) { return; }
353
+ let bi = i / ${M * N}u;
354
+ let mn = i % ${M * N}u;
355
+ let m = mn / ${N}u;
356
+ let n = mn % ${N}u;
357
+ let aBase = bi * ${M * K}u + m * ${K}u;
358
+ let bBase = bi * ${K * N}u;
359
+ var s : f32 = 0.0;
360
+ for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
361
+ s = s + a[aBase + k] * b[bBase + k * ${N}u + n];
362
+ }
363
+ c[i] = s;
364
+ }`.trim()
365
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
366
+ }
367
+
368
+ // ---- One-hot ------------------------------------------------------------
369
+ case 'one_hot': {
370
+ const out = tof(op.out)
371
+ const indices = tof(op.indices)
372
+ const total = shapeSize(out.shape)
373
+ const depth = op.depth
374
+ const zeroLit = wgslLiteral(0, out.dtype)
375
+ const oneLit = wgslLiteral(1, out.dtype)
376
+ const wgsl = `
377
+ @group(0) @binding(0) var<storage, read> indices : array<i32>;
378
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
379
+ @compute @workgroup_size(${WG_SIZE})
380
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
381
+ let i = gid.x + gid.y * 16776960u;
382
+ if (i >= ${total}u) { return; }
383
+ let outerIdx = i / ${depth}u;
384
+ let depthIdx = i % ${depth}u;
385
+ let tgt = u32(indices[outerIdx]);
386
+ out[i] = select(${zeroLit}, ${oneLit}, tgt == depthIdx);
387
+ }`.trim()
388
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.indices), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
389
+ }
390
+
391
+ // ---- ML primitives -----------------------------------------------------
392
+ case 'log_softmax_last': {
393
+ const a = tof(op.a)
394
+ const D = a.shape[a.shape.length - 1]!
395
+ const outerSize = shapeSize(a.shape) / D
396
+ const wgsl = `
397
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
398
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
399
+ @compute @workgroup_size(${WG_SIZE})
400
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
401
+ let i = gid.x + gid.y * 16776960u;
402
+ if (i >= ${outerSize}u) { return; }
403
+ let base = i * ${D}u;
404
+ var m : f32 = -1.0e30;
405
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
406
+ let v = a[base + j];
407
+ if (v > m) { m = v; }
408
+ }
409
+ var s : f32 = 0.0;
410
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
411
+ s = s + exp(a[base + j] - m);
412
+ }
413
+ let logZ = m + log(s);
414
+ for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
415
+ out[base + j] = a[base + j] - logZ;
416
+ }
417
+ }`.trim()
418
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
419
+ }
420
+
421
+ case 'softmax_causal_last': {
422
+ const a = tof(op.a)
423
+ const T = a.shape[a.shape.length - 1]! // == second-to-last (square)
424
+ // Outer size = (everything except last 2 axes) * (second-to-last axis)
425
+ const outerSize = shapeSize(a.shape) / T
426
+ const wgsl = `
427
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
428
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
429
+ @compute @workgroup_size(${WG_SIZE})
430
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
431
+ // Each thread handles one (..., qpos)-row, softmaxing over kpos∈[0..qpos].
432
+ let i = gid.x + gid.y * 16776960u;
433
+ if (i >= ${outerSize}u) { return; }
434
+ let qpos = i % ${T}u;
435
+ let base = i * ${T}u;
436
+ var m : f32 = -1.0e30;
437
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
438
+ let v = a[base + k];
439
+ if (v > m) { m = v; }
440
+ }
441
+ var s : f32 = 0.0;
442
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
443
+ let e = exp(a[base + k] - m);
444
+ out[base + k] = e;
445
+ s = s + e;
446
+ }
447
+ for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
448
+ out[base + k] = out[base + k] / s;
449
+ }
450
+ for (var k : u32 = qpos + 1u; k < ${T}u; k = k + 1u) {
451
+ out[base + k] = 0.0;
452
+ }
453
+ }`.trim()
454
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
455
+ }
456
+
457
+ case 'where_causal': {
458
+ const a = tof(op.a)
459
+ const T = a.shape[a.shape.length - 1]!
460
+ const total = shapeSize(a.shape)
461
+ const fillLit = wgslLiteral(op.fillValue, 'f32')
462
+ const wgsl = `
463
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
464
+ @group(0) @binding(1) var<storage, read_write> out : array<f32>;
465
+ @compute @workgroup_size(${WG_SIZE})
466
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
467
+ let i = gid.x + gid.y * 16776960u;
468
+ if (i >= ${total}u) { return; }
469
+ let kpos = i % ${T}u;
470
+ let qpos = (i / ${T}u) % ${T}u;
471
+ if (kpos > qpos) {
472
+ out[i] = ${fillLit};
473
+ } else {
474
+ out[i] = a[i];
475
+ }
476
+ }`.trim()
477
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
478
+ }
479
+
480
+ // ---- Slicing -----------------------------------------------------------
481
+ case 'slice_last_range': {
482
+ const out = tof(op.out)
483
+ const a = tof(op.a)
484
+ const D_in = a.shape[a.shape.length - 1]!
485
+ const D_out = op.end - op.start
486
+ const total = shapeSize(out.shape)
487
+ const wgsl = `
488
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
489
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
490
+ @compute @workgroup_size(${WG_SIZE})
491
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
492
+ let i = gid.x + gid.y * 16776960u;
493
+ if (i >= ${total}u) { return; }
494
+ let outer = i / ${D_out}u;
495
+ let inner = i % ${D_out}u;
496
+ out[i] = a[outer * ${D_in}u + ${op.start}u + inner];
497
+ }`.trim()
498
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
499
+ }
500
+
501
+ // ---- Broadcast / un-broadcast (autograd infrastructure) ----------------
502
+ case 'broadcast_to': {
503
+ const out = tof(op.out)
504
+ const a = tof(op.a)
505
+ const total = shapeSize(out.shape)
506
+ const wgsl = `
507
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
508
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
509
+ @compute @workgroup_size(${WG_SIZE})
510
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
511
+ let i = gid.x + gid.y * 16776960u;
512
+ if (i >= ${total}u) { return; }
513
+ ${broadcastIndexBlock('i', out.shape, a.shape, 'srcIdx')}
514
+ out[i] = a[srcIdx];
515
+ }`.trim()
516
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
517
+ }
518
+
519
+ // ---- Adam (fused per-element) -----------------------------------------
520
+ case 'adam_update_m': {
521
+ // m_new = b1 * m + (1 - b1) * g
522
+ const out = tof(op.out)
523
+ const total = shapeSize(out.shape)
524
+ const b1 = op.b1
525
+ const oneMinusB1 = 1 - b1
526
+ const wgsl = `
527
+ @group(0) @binding(0) var<storage, read> m : array<f32>;
528
+ @group(0) @binding(1) var<storage, read> g : array<f32>;
529
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
530
+ @compute @workgroup_size(${WG_SIZE})
531
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
532
+ let i = gid.x + gid.y * 16776960u;
533
+ if (i >= ${total}u) { return; }
534
+ out[i] = ${wgslLiteral(b1, 'f32')} * m[i] + ${wgslLiteral(oneMinusB1, 'f32')} * g[i];
535
+ }`.trim()
536
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.m), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
537
+ }
538
+ case 'adam_update_v': {
539
+ // v_new = b2 * v + (1 - b2) * g²
540
+ const out = tof(op.out)
541
+ const total = shapeSize(out.shape)
542
+ const b2 = op.b2
543
+ const oneMinusB2 = 1 - b2
544
+ const wgsl = `
545
+ @group(0) @binding(0) var<storage, read> v : array<f32>;
546
+ @group(0) @binding(1) var<storage, read> g : array<f32>;
547
+ @group(0) @binding(2) var<storage, read_write> out : array<f32>;
548
+ @compute @workgroup_size(${WG_SIZE})
549
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
550
+ let i = gid.x + gid.y * 16776960u;
551
+ if (i >= ${total}u) { return; }
552
+ let gv = g[i];
553
+ out[i] = ${wgslLiteral(b2, 'f32')} * v[i] + ${wgslLiteral(oneMinusB2, 'f32')} * gv * gv;
554
+ }`.trim()
555
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
556
+ }
557
+ case 'adam_update_p': {
558
+ // p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
559
+ // lrt is supplied per-step from CPU (already includes bias correction).
560
+ const out = tof(op.out)
561
+ const total = shapeSize(out.shape)
562
+ const wgsl = `
563
+ @group(0) @binding(0) var<storage, read> p : array<f32>;
564
+ @group(0) @binding(1) var<storage, read> mNew : array<f32>;
565
+ @group(0) @binding(2) var<storage, read> vNew : array<f32>;
566
+ @group(0) @binding(3) var<storage, read> lrt : array<f32>;
567
+ @group(0) @binding(4) var<storage, read_write> out : array<f32>;
568
+ @compute @workgroup_size(${WG_SIZE})
569
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
570
+ let i = gid.x + gid.y * 16776960u;
571
+ if (i >= ${total}u) { return; }
572
+ out[i] = p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
573
+ }`.trim()
574
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
575
+ }
576
+
577
+ case 'sum_to_shape': {
578
+ // Sum-reduce src down to target by summing over each axis where target=1
579
+ // or where target is missing (offset-prefix axes that get fully summed).
580
+ const out = tof(op.out)
581
+ const a = tof(op.a)
582
+ const wgsl = emitSumToShape(a.shape, out.shape, a.dtype)
583
+ const total = shapeSize(out.shape)
584
+ return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
585
+ }
586
+ }
587
+ }
588
+
589
+ // ============================================================================
590
+ // WGSL helpers
591
+ // ============================================================================
592
+
593
+ function wgslDtype(d: 'f32' | 'i32' | 'bool'): string {
594
+ // bool can't be in storage buffers in WGSL; we lower bool-typed tensors to
595
+ // u32 (0/1). For Phase 3a there are no bool-typed storage buffers in the
596
+ // forward+backward graph (causal mask is built inline in softmax kernels),
597
+ // so this only matters if the user explicitly creates a bool tensor.
598
+ if (d === 'bool') return 'u32'
599
+ return d
600
+ }
601
+
602
+ function wgslLiteral(value: number, dtype: 'f32' | 'i32' | 'bool'): string {
603
+ if (dtype === 'f32') {
604
+ if (Number.isFinite(value)) {
605
+ // WGSL requires `.` in float literals; force decimal form.
606
+ return value.toString().includes('.') || value.toString().includes('e')
607
+ ? `${value}f`
608
+ : `${value}.0f`
609
+ }
610
+ return value > 0 ? '1.0e30f' : '-1.0e30f'
611
+ }
612
+ if (dtype === 'i32') return `${Math.trunc(value)}i`
613
+ return value ? '1u' : '0u'
614
+ }
615
+
616
+ function castFromI32(expr: string, dtype: 'f32' | 'i32' | 'bool'): string {
617
+ if (dtype === 'f32') return `f32(${expr})`
618
+ if (dtype === 'i32') return `i32(${expr})`
619
+ return `u32(${expr})`
620
+ }
621
+
622
+ function computeStrides(shape: Shape): number[] {
623
+ const strides: number[] = new Array(shape.length).fill(1)
624
+ for (let i = shape.length - 2; i >= 0; i--) {
625
+ strides[i] = strides[i + 1]! * shape[i + 1]!
626
+ }
627
+ return strides
628
+ }
629
+
630
+ /**
631
+ * Generate WGSL that decomposes a flat index `flatVar` into per-axis indices
632
+ * `outVar_0, outVar_1, ...` according to `shape`.
633
+ */
634
+ function decomposeFlatIndexBlock(flatVar: string, shape: Shape, outVar: string): string {
635
+ if (shape.length === 0) return ` let ${outVar}_0 : u32 = 0u;` // not used but parser-safe
636
+ const strides = computeStrides(shape)
637
+ const lines: string[] = []
638
+ let remaining = flatVar
639
+ for (let i = 0; i < shape.length; i++) {
640
+ if (i === shape.length - 1) {
641
+ lines.push(` let ${outVar}_${i} = ${remaining};`)
642
+ } else {
643
+ lines.push(` let ${outVar}_${i} = ${remaining} / ${strides[i]}u;`)
644
+ const newRem = `${outVar}_rem${i}`
645
+ lines.push(` let ${newRem} = ${remaining} % ${strides[i]}u;`)
646
+ remaining = newRem
647
+ }
648
+ }
649
+ return lines.join('\n')
650
+ }
651
+
652
+ /**
653
+ * Generate WGSL that computes the source flat index in `srcVar` for an output
654
+ * flat index `flatVar`, given output shape `outShape` and source shape `srcShape`
655
+ * under right-aligned NumPy-style broadcasting (size-1 axes broadcast).
656
+ *
657
+ * Strategy:
658
+ * 1. Decompose flat output index into per-axis output indices.
659
+ * 2. For each output axis that maps onto a source axis (right-aligned), use
660
+ * the output index there if src.dim != 1, else 0 (broadcast).
661
+ * 3. Drop output-only axes (those with no corresponding source axis).
662
+ * 4. Combine source indices with source strides.
663
+ */
664
+ function broadcastIndexBlock(flatVar: string, outShape: Shape, srcShape: Shape, srcVar: string): string {
665
+ // Name the per-axis decomposition vars after `srcVar` so multiple
666
+ // broadcastIndexBlock calls in the same WGSL function don't collide.
667
+ const prefix = `${srcVar}_ax`
668
+ const decompose = decomposeFlatIndexBlock(flatVar, outShape, prefix)
669
+ const offset = outShape.length - srcShape.length
670
+ if (srcShape.length === 0) {
671
+ return `${decompose}\n let ${srcVar} : u32 = 0u;`
672
+ }
673
+ const srcStrides = computeStrides(srcShape)
674
+ const terms: string[] = []
675
+ for (let i = 0; i < srcShape.length; i++) {
676
+ const outAxis = i + offset
677
+ const srcDim = srcShape[i]!
678
+ const term = srcDim === 1 ? '0u' : `${prefix}_${outAxis} * ${srcStrides[i]}u`
679
+ terms.push(term)
680
+ }
681
+ return `${decompose}\n let ${srcVar} = ${terms.join(' + ')};`
682
+ }
683
+
684
+ /**
685
+ * sum_to_shape: each output cell sums over the source axes that are reduced.
686
+ * For source shape S and target shape T (right-aligned):
687
+ * - Axes in S not in T (leading prefix): fully reduced (sum over whole axis).
688
+ * - Axes where T=1 but S>1: reduced (sum over that axis).
689
+ * - Axes where T=S: passed through.
690
+ *
691
+ * Implementation: each thread = one output cell. It iterates over the reduced
692
+ * axes via nested-loop unrolling (we generate explicit nested for-loops).
693
+ */
694
+ function emitSumToShape(srcShape: Shape, tgtShape: Shape, dtype: 'f32' | 'i32' | 'bool'): string {
695
+ const srcStrides = computeStrides(srcShape)
696
+ const tgtStrides = computeStrides(tgtShape)
697
+ const offset = srcShape.length - tgtShape.length
698
+
699
+ // Decompose flat output index into per-axis target indices.
700
+ const decompose = decomposeFlatIndexBlock('i', tgtShape, 'tgt')
701
+
702
+ // Identify reduced axes of the SOURCE: axis k in src is reduced if either
703
+ // it's in the leading prefix (k < offset) or its corresponding target axis
704
+ // has size 1. For non-reduced axes (k >= offset and tgt=src), the source
705
+ // index is the target index along that axis.
706
+ const reducedAxes: number[] = []
707
+ for (let k = 0; k < srcShape.length; k++) {
708
+ if (k < offset) { reducedAxes.push(k); continue }
709
+ const tDim = tgtShape[k - offset]!
710
+ const sDim = srcShape[k]!
711
+ if (tDim === 1 && sDim > 1) reducedAxes.push(k)
712
+ }
713
+
714
+ // Build the source flat index expression. Initialize from the non-reduced axes.
715
+ const baseTerms: string[] = []
716
+ for (let k = 0; k < srcShape.length; k++) {
717
+ if (reducedAxes.includes(k)) continue // contributed by loop var instead
718
+ const tAxis = k - offset
719
+ baseTerms.push(`tgt_${tAxis} * ${srcStrides[k]}u`)
720
+ }
721
+ const baseExpr = baseTerms.length > 0 ? baseTerms.join(' + ') : '0u'
722
+
723
+ // Emit nested for loops over the reduced axes.
724
+ const indent = (depth: number) => ' '.repeat(depth + 1)
725
+ const loops: string[] = []
726
+ for (let depth = 0; depth < reducedAxes.length; depth++) {
727
+ const k = reducedAxes[depth]!
728
+ const dim = srcShape[k]!
729
+ loops.push(`${indent(depth)}for (var r${k} : u32 = 0u; r${k} < ${dim}u; r${k} = r${k} + 1u) {`)
730
+ }
731
+ // Inside innermost loop, compute source index.
732
+ const reducedTerms = reducedAxes.map(k => `r${k} * ${srcStrides[k]}u`)
733
+ const fullExpr = reducedTerms.length > 0
734
+ ? `${baseExpr} + ${reducedTerms.join(' + ')}`
735
+ : baseExpr
736
+ loops.push(`${indent(reducedAxes.length)}s = s + a[${fullExpr}];`)
737
+ for (let depth = reducedAxes.length - 1; depth >= 0; depth--) {
738
+ loops.push(`${indent(depth)}}`)
739
+ }
740
+
741
+ const total = tgtShape.length === 0 ? 1 : (tgtStrides[0]! * tgtShape[0]!)
742
+ const loopBody = reducedAxes.length === 0
743
+ ? ` s = s + a[${baseExpr}];`
744
+ : loops.join('\n')
745
+
746
+ return `
747
+ @group(0) @binding(0) var<storage, read> a : array<${wgslDtype(dtype)}>;
748
+ @group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(dtype)}>;
749
+ @compute @workgroup_size(${WG_SIZE})
750
+ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
751
+ let i = gid.x + gid.y * 16776960u;
752
+ if (i >= ${total}u) { return; }
753
+ ${decompose}
754
+ var s : ${wgslDtype(dtype)} = ${dtype === 'f32' ? '0.0f' : (dtype === 'i32' ? '0i' : '0u')};
755
+ ${loopBody}
756
+ out[i] = s;
757
+ }`.trim()
758
+ }