tensorgrad 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +121 -0
- package/SPEC.md +293 -0
- package/dist/adam.d.ts +31 -0
- package/dist/adam.d.ts.map +1 -0
- package/dist/adam.js +66 -0
- package/dist/adam.js.map +1 -0
- package/dist/buffers.d.ts +56 -0
- package/dist/buffers.d.ts.map +1 -0
- package/dist/buffers.js +114 -0
- package/dist/buffers.js.map +1 -0
- package/dist/codegen.d.ts +23 -0
- package/dist/codegen.d.ts.map +1 -0
- package/dist/codegen.js +709 -0
- package/dist/codegen.js.map +1 -0
- package/dist/compile.d.ts +53 -0
- package/dist/compile.d.ts.map +1 -0
- package/dist/compile.js +76 -0
- package/dist/compile.js.map +1 -0
- package/dist/grad.d.ts +8 -0
- package/dist/grad.d.ts.map +1 -0
- package/dist/grad.js +404 -0
- package/dist/grad.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +37 -0
- package/dist/index.js.map +1 -0
- package/dist/ir.d.ts +204 -0
- package/dist/ir.d.ts.map +1 -0
- package/dist/ir.js +60 -0
- package/dist/ir.js.map +1 -0
- package/dist/module.d.ts +21 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +113 -0
- package/dist/module.js.map +1 -0
- package/dist/ops.d.ts +35 -0
- package/dist/ops.d.ts.map +1 -0
- package/dist/ops.js +270 -0
- package/dist/ops.js.map +1 -0
- package/dist/runtime.d.ts +26 -0
- package/dist/runtime.d.ts.map +1 -0
- package/dist/runtime.js +190 -0
- package/dist/runtime.js.map +1 -0
- package/dist/shape.d.ts +24 -0
- package/dist/shape.d.ts.map +1 -0
- package/dist/shape.js +259 -0
- package/dist/shape.js.map +1 -0
- package/dist/trace.d.ts +8 -0
- package/dist/trace.d.ts.map +1 -0
- package/dist/trace.js +93 -0
- package/dist/trace.js.map +1 -0
- package/package.json +62 -0
- package/src/adam.ts +95 -0
- package/src/buffers.ts +173 -0
- package/src/codegen.ts +758 -0
- package/src/compile.ts +120 -0
- package/src/grad.ts +459 -0
- package/src/index.ts +40 -0
- package/src/ir.ts +197 -0
- package/src/module.ts +126 -0
- package/src/ops.ts +311 -0
- package/src/runtime.ts +232 -0
- package/src/shape.ts +263 -0
- package/src/trace.ts +101 -0
package/src/codegen.ts
ADDED
|
@@ -0,0 +1,758 @@
|
|
|
1
|
+
// WGSL codegen: one kernel per IR op.
|
|
2
|
+
//
|
|
3
|
+
// All shapes are baked into the WGSL as compile-time constants — no shape
|
|
4
|
+
// uniforms. This means each shape combination produces a distinct shader
|
|
5
|
+
// (so `add([B, T, D], [D])` and `add([B, T, D], [B, T, D])` get different
|
|
6
|
+
// kernels), which is fine for our static-shape model and gives the WGSL
|
|
7
|
+
// compiler full freedom to specialize.
|
|
8
|
+
//
|
|
9
|
+
// Most kernels are direct ports of `transformer-gpu.bulb.md`'s WGSL — those
|
|
10
|
+
// are already debugged and tuned. The autograd ops (broadcast_to, sum_to_shape,
|
|
11
|
+
// relu_grad, etc.) are new.
|
|
12
|
+
|
|
13
|
+
import type { Graph, OpNode, Tensor, Shape } from './ir.js'
|
|
14
|
+
import type { BufferPlan } from './buffers.js'
|
|
15
|
+
|
|
16
|
+
// Workgroup size of 256 means even our biggest kernel (~8M threads in
// matmul_bwd_dW) needs only ~32K workgroups, well under WebGPU's 65535-per-dim
// dispatch cap. Smaller WG_SIZE forced 2D dispatch with significant over-dispatch.
// (Note: 65535 * 256 = 16776960, the `gid.y` stride constant baked into every
// generated kernel's flat-index reconstruction.)
const WG_SIZE = 256

/**
 * Everything the runtime needs to build and dispatch one compute pipeline for
 * one IR op: WGSL source, buffer bindings in declaration order, and the 1-D
 * dispatch geometry.
 */
export interface KernelSpec {
  /** Index into graph.ops. */
  opIndex: number
  /** Op kind (for debugging / pipeline cache key). */
  opKind: OpNode['kind']
  /** Generated WGSL source. Empty string for "logical" ops with no kernel. */
  wgsl: string
  /**
   * Buffer ids in binding-index order. The runtime creates a bind group with
   * these in @binding(0..N) on @group(0). Inputs come first (read), output last
   * (read_write).
   */
  bindings: number[]
  /** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
  threads: number
  /** Workgroup size; usually WG_SIZE. */
  workgroupSize: number
}
|
|
39
|
+
|
|
40
|
+
// ============================================================================
|
|
41
|
+
// Public entry point
|
|
42
|
+
// ============================================================================
|
|
43
|
+
|
|
44
|
+
/** Generate a KernelSpec per compute op in graph.ops (in dispatch order). */
|
|
45
|
+
export function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[] {
|
|
46
|
+
const out: KernelSpec[] = []
|
|
47
|
+
for (let i = 0; i < graph.ops.length; i++) {
|
|
48
|
+
const op = graph.ops[i]!
|
|
49
|
+
const spec = emitKernel(op, graph, plan, i)
|
|
50
|
+
out.push(spec)
|
|
51
|
+
}
|
|
52
|
+
return out
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
function shapeSize(shape: Shape): number {
|
|
56
|
+
let n = 1
|
|
57
|
+
for (const d of shape) n *= d
|
|
58
|
+
return n
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
function emitKernel(op: OpNode, graph: Graph, plan: BufferPlan, opIndex: number): KernelSpec {
|
|
62
|
+
const tof = (id: number) => graph.tensors[id]!
|
|
63
|
+
const buf = (tensorId: number) => plan.tensorToBuffer.get(tensorId)!
|
|
64
|
+
const empty = (): KernelSpec => ({ opIndex, opKind: op.kind, wgsl: '', bindings: [], threads: 0, workgroupSize: WG_SIZE })
|
|
65
|
+
|
|
66
|
+
switch (op.kind) {
|
|
67
|
+
// ---- Leaves: data is supplied externally; no kernel ---------------------
|
|
68
|
+
case 'param_input':
|
|
69
|
+
case 'tensor_input':
|
|
70
|
+
case 'state_input':
|
|
71
|
+
return empty()
|
|
72
|
+
|
|
73
|
+
// ---- arange / const_scalar: kernel that fills the buffer once -----------
|
|
74
|
+
case 'arange': {
|
|
75
|
+
const out = tof(op.out)
|
|
76
|
+
const wgsl = `
|
|
77
|
+
@group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(out.dtype)}>;
|
|
78
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
79
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
80
|
+
let i = gid.x + gid.y * 16776960u;
|
|
81
|
+
if (i >= ${op.n}u) { return; }
|
|
82
|
+
buf[i] = ${castFromI32('i32(i)', out.dtype)};
|
|
83
|
+
}`.trim()
|
|
84
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: op.n, workgroupSize: WG_SIZE }
|
|
85
|
+
}
|
|
86
|
+
case 'const_scalar': {
|
|
87
|
+
const wgsl = `
|
|
88
|
+
@group(0) @binding(0) var<storage, read_write> buf : array<${wgslDtype(op.dtype)}>;
|
|
89
|
+
@compute @workgroup_size(1)
|
|
90
|
+
fn main() {
|
|
91
|
+
buf[0] = ${wgslLiteral(op.value, op.dtype)};
|
|
92
|
+
}`.trim()
|
|
93
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.out)], threads: 1, workgroupSize: 1 }
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// ---- Element-wise binops with broadcast --------------------------------
|
|
97
|
+
case 'add':
|
|
98
|
+
case 'sub':
|
|
99
|
+
case 'mul':
|
|
100
|
+
case 'div': {
|
|
101
|
+
const out = tof(op.out)
|
|
102
|
+
const a = tof(op.a)
|
|
103
|
+
const b = tof(op.b)
|
|
104
|
+
const opStr = { add: '+', sub: '-', mul: '*', div: '/' }[op.kind]
|
|
105
|
+
const total = shapeSize(out.shape)
|
|
106
|
+
const wgsl = `
|
|
107
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
108
|
+
@group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
|
|
109
|
+
@group(0) @binding(2) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
110
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
111
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
112
|
+
let i = gid.x + gid.y * 16776960u;
|
|
113
|
+
if (i >= ${total}u) { return; }
|
|
114
|
+
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
|
|
115
|
+
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
|
|
116
|
+
out[i] = a[aIdx] ${opStr} b[bIdx];
|
|
117
|
+
}`.trim()
|
|
118
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// ---- Element-wise scalar binops (scalar baked into WGSL) ---------------
|
|
122
|
+
case 'mul_scalar':
|
|
123
|
+
case 'add_scalar': {
|
|
124
|
+
const out = tof(op.out)
|
|
125
|
+
const a = tof(op.a)
|
|
126
|
+
const opStr = op.kind === 'mul_scalar' ? '*' : '+'
|
|
127
|
+
const total = shapeSize(out.shape)
|
|
128
|
+
const lit = wgslLiteral(op.scalar, out.dtype)
|
|
129
|
+
const wgsl = `
|
|
130
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
131
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
132
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
133
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
134
|
+
let i = gid.x + gid.y * 16776960u;
|
|
135
|
+
if (i >= ${total}u) { return; }
|
|
136
|
+
out[i] = a[i] ${opStr} ${lit};
|
|
137
|
+
}`.trim()
|
|
138
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
// ---- Unary -------------------------------------------------------------
|
|
142
|
+
case 'sqrt':
|
|
143
|
+
case 'rsqrt':
|
|
144
|
+
case 'log':
|
|
145
|
+
case 'exp':
|
|
146
|
+
case 'relu': {
|
|
147
|
+
const out = tof(op.out)
|
|
148
|
+
const a = tof(op.a)
|
|
149
|
+
const total = shapeSize(out.shape)
|
|
150
|
+
const expr =
|
|
151
|
+
op.kind === 'sqrt' ? 'sqrt(x)' :
|
|
152
|
+
op.kind === 'rsqrt' ? '1.0 / sqrt(x)' :
|
|
153
|
+
op.kind === 'log' ? 'log(x)' :
|
|
154
|
+
op.kind === 'exp' ? 'exp(x)' :
|
|
155
|
+
/* relu */ 'max(x, 0.0)'
|
|
156
|
+
const wgsl = `
|
|
157
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
158
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
159
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
160
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
161
|
+
let i = gid.x + gid.y * 16776960u;
|
|
162
|
+
if (i >= ${total}u) { return; }
|
|
163
|
+
let x = a[i];
|
|
164
|
+
out[i] = ${expr};
|
|
165
|
+
}`.trim()
|
|
166
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
// ---- Comparisons + select --------------------------------------------
|
|
170
|
+
case 'less':
|
|
171
|
+
case 'greater': {
|
|
172
|
+
const out = tof(op.out)
|
|
173
|
+
const a = tof(op.a)
|
|
174
|
+
const b = tof(op.b)
|
|
175
|
+
const opStr = op.kind === 'less' ? '<' : '>'
|
|
176
|
+
const total = shapeSize(out.shape)
|
|
177
|
+
// bool tensors lower to u32 in storage (1 if true, 0 if false).
|
|
178
|
+
const wgsl = `
|
|
179
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
180
|
+
@group(0) @binding(1) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
|
|
181
|
+
@group(0) @binding(2) var<storage, read_write> out : array<u32>;
|
|
182
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
183
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
184
|
+
let i = gid.x + gid.y * 16776960u;
|
|
185
|
+
if (i >= ${total}u) { return; }
|
|
186
|
+
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
|
|
187
|
+
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
|
|
188
|
+
out[i] = select(0u, 1u, a[aIdx] ${opStr} b[bIdx]);
|
|
189
|
+
}`.trim()
|
|
190
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
191
|
+
}
|
|
192
|
+
case 'where': {
|
|
193
|
+
const out = tof(op.out)
|
|
194
|
+
const cond = tof(op.cond)
|
|
195
|
+
const a = tof(op.a)
|
|
196
|
+
const b = tof(op.b)
|
|
197
|
+
const total = shapeSize(out.shape)
|
|
198
|
+
const wgsl = `
|
|
199
|
+
@group(0) @binding(0) var<storage, read> cond : array<u32>;
|
|
200
|
+
@group(0) @binding(1) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
201
|
+
@group(0) @binding(2) var<storage, read> b : array<${wgslDtype(b.dtype)}>;
|
|
202
|
+
@group(0) @binding(3) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
203
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
204
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
205
|
+
let i = gid.x + gid.y * 16776960u;
|
|
206
|
+
if (i >= ${total}u) { return; }
|
|
207
|
+
${broadcastIndexBlock('i', out.shape, cond.shape, 'cIdx')}
|
|
208
|
+
${broadcastIndexBlock('i', out.shape, a.shape, 'aIdx')}
|
|
209
|
+
${broadcastIndexBlock('i', out.shape, b.shape, 'bIdx')}
|
|
210
|
+
out[i] = select(b[bIdx], a[aIdx], cond[cIdx] != 0u);
|
|
211
|
+
}`.trim()
|
|
212
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.cond), buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
case 'relu_grad': {
|
|
216
|
+
const out = tof(op.out)
|
|
217
|
+
const total = shapeSize(out.shape)
|
|
218
|
+
const wgsl = `
|
|
219
|
+
@group(0) @binding(0) var<storage, read> x : array<f32>;
|
|
220
|
+
@group(0) @binding(1) var<storage, read> dy : array<f32>;
|
|
221
|
+
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
|
|
222
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
223
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
224
|
+
let i = gid.x + gid.y * 16776960u;
|
|
225
|
+
if (i >= ${total}u) { return; }
|
|
226
|
+
out[i] = select(0.0, dy[i], x[i] > 0.0);
|
|
227
|
+
}`.trim()
|
|
228
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.x), buf(op.dy), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
// ---- Reductions over last axis -----------------------------------------
|
|
232
|
+
case 'mean_last':
|
|
233
|
+
case 'sum_last': {
|
|
234
|
+
const a = tof(op.a)
|
|
235
|
+
const D = a.shape[a.shape.length - 1]!
|
|
236
|
+
const outerSize = shapeSize(a.shape) / D
|
|
237
|
+
const divisor = op.kind === 'mean_last' ? `f32(${D}u)` : '1.0'
|
|
238
|
+
const wgsl = `
|
|
239
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
240
|
+
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
|
|
241
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
242
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
243
|
+
let i = gid.x + gid.y * 16776960u;
|
|
244
|
+
if (i >= ${outerSize}u) { return; }
|
|
245
|
+
let base = i * ${D}u;
|
|
246
|
+
var s : f32 = 0.0;
|
|
247
|
+
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
|
|
248
|
+
s = s + a[base + j];
|
|
249
|
+
}
|
|
250
|
+
out[i] = s / ${divisor};
|
|
251
|
+
}`.trim()
|
|
252
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
// ---- Shape ---------------------------------------------------------------
|
|
256
|
+
// reshape: no kernel needed if buffers can alias (shape change only). For
|
|
257
|
+
// v1 simplicity we emit a memcpy-style kernel rather than aliasing buffers,
|
|
258
|
+
// because aliasing complicates the buffer plan and we have memory headroom.
|
|
259
|
+
case 'reshape': {
|
|
260
|
+
const out = tof(op.out)
|
|
261
|
+
const a = tof(op.a)
|
|
262
|
+
const total = shapeSize(out.shape)
|
|
263
|
+
const wgsl = `
|
|
264
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
265
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
266
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
267
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
268
|
+
let i = gid.x + gid.y * 16776960u;
|
|
269
|
+
if (i >= ${total}u) { return; }
|
|
270
|
+
out[i] = a[i];
|
|
271
|
+
}`.trim()
|
|
272
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
case 'transpose': {
|
|
276
|
+
const out = tof(op.out)
|
|
277
|
+
const a = tof(op.a)
|
|
278
|
+
const total = shapeSize(out.shape)
|
|
279
|
+
// Emit per-axis index computation. For each output flat index i, decompose
|
|
280
|
+
// into per-axis output indices, then use op.perm to find the source axis order.
|
|
281
|
+
// Source flat index = sum(outIdx[perm.invert()[k]] * a_stride[k] for k).
|
|
282
|
+
const aStrides = computeStrides(a.shape)
|
|
283
|
+
const outDimDecls = decomposeFlatIndexBlock('i', out.shape, 'oIdx')
|
|
284
|
+
const srcExpr: string[] = []
|
|
285
|
+
for (let k = 0; k < a.shape.length; k++) {
|
|
286
|
+
const srcAxis = op.perm.indexOf(k) // which output axis came from input axis k
|
|
287
|
+
srcExpr.push(`oIdx_${srcAxis} * ${aStrides[k]}u`)
|
|
288
|
+
}
|
|
289
|
+
const wgsl = `
|
|
290
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
291
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
292
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
293
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
294
|
+
let i = gid.x + gid.y * 16776960u;
|
|
295
|
+
if (i >= ${total}u) { return; }
|
|
296
|
+
${outDimDecls}
|
|
297
|
+
let srcIdx = ${srcExpr.join(' + ')};
|
|
298
|
+
out[i] = a[srcIdx];
|
|
299
|
+
}`.trim()
|
|
300
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
// ---- Linear algebra ----------------------------------------------------
|
|
304
|
+
// matmul: a [..., M, K] · b [K, N] -> [..., M, N]. b is unbatched.
|
|
305
|
+
case 'matmul': {
|
|
306
|
+
const out = tof(op.out)
|
|
307
|
+
const a = tof(op.a)
|
|
308
|
+
const b = tof(op.b)
|
|
309
|
+
const M = a.shape[a.shape.length - 2]!
|
|
310
|
+
const K = a.shape[a.shape.length - 1]!
|
|
311
|
+
const N = b.shape[1]!
|
|
312
|
+
const batch = shapeSize(a.shape) / (M * K)
|
|
313
|
+
const total = batch * M * N
|
|
314
|
+
const wgsl = `
|
|
315
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
316
|
+
@group(0) @binding(1) var<storage, read> b : array<f32>;
|
|
317
|
+
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
|
|
318
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
319
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
320
|
+
let i = gid.x + gid.y * 16776960u;
|
|
321
|
+
if (i >= ${total}u) { return; }
|
|
322
|
+
let bi = i / ${M * N}u; // batch index
|
|
323
|
+
let mn = i % ${M * N}u;
|
|
324
|
+
let m = mn / ${N}u;
|
|
325
|
+
let n = mn % ${N}u;
|
|
326
|
+
let aBase = bi * ${M * K}u + m * ${K}u;
|
|
327
|
+
var s : f32 = 0.0;
|
|
328
|
+
for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
|
|
329
|
+
s = s + a[aBase + k] * b[k * ${N}u + n];
|
|
330
|
+
}
|
|
331
|
+
c[i] = s;
|
|
332
|
+
}`.trim()
|
|
333
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
case 'matmul_batched': {
|
|
337
|
+
const out = tof(op.out)
|
|
338
|
+
const a = tof(op.a)
|
|
339
|
+
const b = tof(op.b)
|
|
340
|
+
const M = a.shape[a.shape.length - 2]!
|
|
341
|
+
const K = a.shape[a.shape.length - 1]!
|
|
342
|
+
const N = b.shape[b.shape.length - 1]!
|
|
343
|
+
const batch = shapeSize(a.shape) / (M * K)
|
|
344
|
+
const total = batch * M * N
|
|
345
|
+
const wgsl = `
|
|
346
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
347
|
+
@group(0) @binding(1) var<storage, read> b : array<f32>;
|
|
348
|
+
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
|
|
349
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
350
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
351
|
+
let i = gid.x + gid.y * 16776960u;
|
|
352
|
+
if (i >= ${total}u) { return; }
|
|
353
|
+
let bi = i / ${M * N}u;
|
|
354
|
+
let mn = i % ${M * N}u;
|
|
355
|
+
let m = mn / ${N}u;
|
|
356
|
+
let n = mn % ${N}u;
|
|
357
|
+
let aBase = bi * ${M * K}u + m * ${K}u;
|
|
358
|
+
let bBase = bi * ${K * N}u;
|
|
359
|
+
var s : f32 = 0.0;
|
|
360
|
+
for (var k : u32 = 0u; k < ${K}u; k = k + 1u) {
|
|
361
|
+
s = s + a[aBase + k] * b[bBase + k * ${N}u + n];
|
|
362
|
+
}
|
|
363
|
+
c[i] = s;
|
|
364
|
+
}`.trim()
|
|
365
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.b), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
// ---- One-hot ------------------------------------------------------------
|
|
369
|
+
case 'one_hot': {
|
|
370
|
+
const out = tof(op.out)
|
|
371
|
+
const indices = tof(op.indices)
|
|
372
|
+
const total = shapeSize(out.shape)
|
|
373
|
+
const depth = op.depth
|
|
374
|
+
const zeroLit = wgslLiteral(0, out.dtype)
|
|
375
|
+
const oneLit = wgslLiteral(1, out.dtype)
|
|
376
|
+
const wgsl = `
|
|
377
|
+
@group(0) @binding(0) var<storage, read> indices : array<i32>;
|
|
378
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
379
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
380
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
381
|
+
let i = gid.x + gid.y * 16776960u;
|
|
382
|
+
if (i >= ${total}u) { return; }
|
|
383
|
+
let outerIdx = i / ${depth}u;
|
|
384
|
+
let depthIdx = i % ${depth}u;
|
|
385
|
+
let tgt = u32(indices[outerIdx]);
|
|
386
|
+
out[i] = select(${zeroLit}, ${oneLit}, tgt == depthIdx);
|
|
387
|
+
}`.trim()
|
|
388
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.indices), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
// ---- ML primitives -----------------------------------------------------
|
|
392
|
+
case 'log_softmax_last': {
|
|
393
|
+
const a = tof(op.a)
|
|
394
|
+
const D = a.shape[a.shape.length - 1]!
|
|
395
|
+
const outerSize = shapeSize(a.shape) / D
|
|
396
|
+
const wgsl = `
|
|
397
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
398
|
+
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
|
|
399
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
400
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
401
|
+
let i = gid.x + gid.y * 16776960u;
|
|
402
|
+
if (i >= ${outerSize}u) { return; }
|
|
403
|
+
let base = i * ${D}u;
|
|
404
|
+
var m : f32 = -1.0e30;
|
|
405
|
+
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
|
|
406
|
+
let v = a[base + j];
|
|
407
|
+
if (v > m) { m = v; }
|
|
408
|
+
}
|
|
409
|
+
var s : f32 = 0.0;
|
|
410
|
+
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
|
|
411
|
+
s = s + exp(a[base + j] - m);
|
|
412
|
+
}
|
|
413
|
+
let logZ = m + log(s);
|
|
414
|
+
for (var j : u32 = 0u; j < ${D}u; j = j + 1u) {
|
|
415
|
+
out[base + j] = a[base + j] - logZ;
|
|
416
|
+
}
|
|
417
|
+
}`.trim()
|
|
418
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
case 'softmax_causal_last': {
|
|
422
|
+
const a = tof(op.a)
|
|
423
|
+
const T = a.shape[a.shape.length - 1]! // == second-to-last (square)
|
|
424
|
+
// Outer size = (everything except last 2 axes) * (second-to-last axis)
|
|
425
|
+
const outerSize = shapeSize(a.shape) / T
|
|
426
|
+
const wgsl = `
|
|
427
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
428
|
+
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
|
|
429
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
430
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
431
|
+
// Each thread handles one (..., qpos)-row, softmaxing over kpos∈[0..qpos].
|
|
432
|
+
let i = gid.x + gid.y * 16776960u;
|
|
433
|
+
if (i >= ${outerSize}u) { return; }
|
|
434
|
+
let qpos = i % ${T}u;
|
|
435
|
+
let base = i * ${T}u;
|
|
436
|
+
var m : f32 = -1.0e30;
|
|
437
|
+
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
|
|
438
|
+
let v = a[base + k];
|
|
439
|
+
if (v > m) { m = v; }
|
|
440
|
+
}
|
|
441
|
+
var s : f32 = 0.0;
|
|
442
|
+
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
|
|
443
|
+
let e = exp(a[base + k] - m);
|
|
444
|
+
out[base + k] = e;
|
|
445
|
+
s = s + e;
|
|
446
|
+
}
|
|
447
|
+
for (var k : u32 = 0u; k <= qpos; k = k + 1u) {
|
|
448
|
+
out[base + k] = out[base + k] / s;
|
|
449
|
+
}
|
|
450
|
+
for (var k : u32 = qpos + 1u; k < ${T}u; k = k + 1u) {
|
|
451
|
+
out[base + k] = 0.0;
|
|
452
|
+
}
|
|
453
|
+
}`.trim()
|
|
454
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: outerSize, workgroupSize: WG_SIZE }
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
case 'where_causal': {
|
|
458
|
+
const a = tof(op.a)
|
|
459
|
+
const T = a.shape[a.shape.length - 1]!
|
|
460
|
+
const total = shapeSize(a.shape)
|
|
461
|
+
const fillLit = wgslLiteral(op.fillValue, 'f32')
|
|
462
|
+
const wgsl = `
|
|
463
|
+
@group(0) @binding(0) var<storage, read> a : array<f32>;
|
|
464
|
+
@group(0) @binding(1) var<storage, read_write> out : array<f32>;
|
|
465
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
466
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
467
|
+
let i = gid.x + gid.y * 16776960u;
|
|
468
|
+
if (i >= ${total}u) { return; }
|
|
469
|
+
let kpos = i % ${T}u;
|
|
470
|
+
let qpos = (i / ${T}u) % ${T}u;
|
|
471
|
+
if (kpos > qpos) {
|
|
472
|
+
out[i] = ${fillLit};
|
|
473
|
+
} else {
|
|
474
|
+
out[i] = a[i];
|
|
475
|
+
}
|
|
476
|
+
}`.trim()
|
|
477
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
// ---- Slicing -----------------------------------------------------------
|
|
481
|
+
case 'slice_last_range': {
|
|
482
|
+
const out = tof(op.out)
|
|
483
|
+
const a = tof(op.a)
|
|
484
|
+
const D_in = a.shape[a.shape.length - 1]!
|
|
485
|
+
const D_out = op.end - op.start
|
|
486
|
+
const total = shapeSize(out.shape)
|
|
487
|
+
const wgsl = `
|
|
488
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
489
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
490
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
491
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
492
|
+
let i = gid.x + gid.y * 16776960u;
|
|
493
|
+
if (i >= ${total}u) { return; }
|
|
494
|
+
let outer = i / ${D_out}u;
|
|
495
|
+
let inner = i % ${D_out}u;
|
|
496
|
+
out[i] = a[outer * ${D_in}u + ${op.start}u + inner];
|
|
497
|
+
}`.trim()
|
|
498
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
// ---- Broadcast / un-broadcast (autograd infrastructure) ----------------
|
|
502
|
+
case 'broadcast_to': {
|
|
503
|
+
const out = tof(op.out)
|
|
504
|
+
const a = tof(op.a)
|
|
505
|
+
const total = shapeSize(out.shape)
|
|
506
|
+
const wgsl = `
|
|
507
|
+
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(a.dtype)}>;
|
|
508
|
+
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(out.dtype)}>;
|
|
509
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
510
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
511
|
+
let i = gid.x + gid.y * 16776960u;
|
|
512
|
+
if (i >= ${total}u) { return; }
|
|
513
|
+
${broadcastIndexBlock('i', out.shape, a.shape, 'srcIdx')}
|
|
514
|
+
out[i] = a[srcIdx];
|
|
515
|
+
}`.trim()
|
|
516
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
// ---- Adam (fused per-element) -----------------------------------------
|
|
520
|
+
case 'adam_update_m': {
|
|
521
|
+
// m_new = b1 * m + (1 - b1) * g
|
|
522
|
+
const out = tof(op.out)
|
|
523
|
+
const total = shapeSize(out.shape)
|
|
524
|
+
const b1 = op.b1
|
|
525
|
+
const oneMinusB1 = 1 - b1
|
|
526
|
+
const wgsl = `
|
|
527
|
+
@group(0) @binding(0) var<storage, read> m : array<f32>;
|
|
528
|
+
@group(0) @binding(1) var<storage, read> g : array<f32>;
|
|
529
|
+
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
|
|
530
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
531
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
532
|
+
let i = gid.x + gid.y * 16776960u;
|
|
533
|
+
if (i >= ${total}u) { return; }
|
|
534
|
+
out[i] = ${wgslLiteral(b1, 'f32')} * m[i] + ${wgslLiteral(oneMinusB1, 'f32')} * g[i];
|
|
535
|
+
}`.trim()
|
|
536
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.m), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
537
|
+
}
|
|
538
|
+
case 'adam_update_v': {
|
|
539
|
+
// v_new = b2 * v + (1 - b2) * g²
|
|
540
|
+
const out = tof(op.out)
|
|
541
|
+
const total = shapeSize(out.shape)
|
|
542
|
+
const b2 = op.b2
|
|
543
|
+
const oneMinusB2 = 1 - b2
|
|
544
|
+
const wgsl = `
|
|
545
|
+
@group(0) @binding(0) var<storage, read> v : array<f32>;
|
|
546
|
+
@group(0) @binding(1) var<storage, read> g : array<f32>;
|
|
547
|
+
@group(0) @binding(2) var<storage, read_write> out : array<f32>;
|
|
548
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
549
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
550
|
+
let i = gid.x + gid.y * 16776960u;
|
|
551
|
+
if (i >= ${total}u) { return; }
|
|
552
|
+
let gv = g[i];
|
|
553
|
+
out[i] = ${wgslLiteral(b2, 'f32')} * v[i] + ${wgslLiteral(oneMinusB2, 'f32')} * gv * gv;
|
|
554
|
+
}`.trim()
|
|
555
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
556
|
+
}
|
|
557
|
+
case 'adam_update_p': {
|
|
558
|
+
// p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
|
|
559
|
+
// lrt is supplied per-step from CPU (already includes bias correction).
|
|
560
|
+
const out = tof(op.out)
|
|
561
|
+
const total = shapeSize(out.shape)
|
|
562
|
+
const wgsl = `
|
|
563
|
+
@group(0) @binding(0) var<storage, read> p : array<f32>;
|
|
564
|
+
@group(0) @binding(1) var<storage, read> mNew : array<f32>;
|
|
565
|
+
@group(0) @binding(2) var<storage, read> vNew : array<f32>;
|
|
566
|
+
@group(0) @binding(3) var<storage, read> lrt : array<f32>;
|
|
567
|
+
@group(0) @binding(4) var<storage, read_write> out : array<f32>;
|
|
568
|
+
@compute @workgroup_size(${WG_SIZE})
|
|
569
|
+
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
570
|
+
let i = gid.x + gid.y * 16776960u;
|
|
571
|
+
if (i >= ${total}u) { return; }
|
|
572
|
+
out[i] = p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
|
|
573
|
+
}`.trim()
|
|
574
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
case 'sum_to_shape': {
|
|
578
|
+
// Sum-reduce src down to target by summing over each axis where target=1
|
|
579
|
+
// or where target is missing (offset-prefix axes that get fully summed).
|
|
580
|
+
const out = tof(op.out)
|
|
581
|
+
const a = tof(op.a)
|
|
582
|
+
const wgsl = emitSumToShape(a.shape, out.shape, a.dtype)
|
|
583
|
+
const total = shapeSize(out.shape)
|
|
584
|
+
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.a), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
// ============================================================================
|
|
590
|
+
// WGSL helpers
|
|
591
|
+
// ============================================================================
|
|
592
|
+
|
|
593
|
+
function wgslDtype(d: 'f32' | 'i32' | 'bool'): string {
|
|
594
|
+
// bool can't be in storage buffers in WGSL; we lower bool-typed tensors to
|
|
595
|
+
// u32 (0/1). For Phase 3a there are no bool-typed storage buffers in the
|
|
596
|
+
// forward+backward graph (causal mask is built inline in softmax kernels),
|
|
597
|
+
// so this only matters if the user explicitly creates a bool tensor.
|
|
598
|
+
if (d === 'bool') return 'u32'
|
|
599
|
+
return d
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
function wgslLiteral(value: number, dtype: 'f32' | 'i32' | 'bool'): string {
|
|
603
|
+
if (dtype === 'f32') {
|
|
604
|
+
if (Number.isFinite(value)) {
|
|
605
|
+
// WGSL requires `.` in float literals; force decimal form.
|
|
606
|
+
return value.toString().includes('.') || value.toString().includes('e')
|
|
607
|
+
? `${value}f`
|
|
608
|
+
: `${value}.0f`
|
|
609
|
+
}
|
|
610
|
+
return value > 0 ? '1.0e30f' : '-1.0e30f'
|
|
611
|
+
}
|
|
612
|
+
if (dtype === 'i32') return `${Math.trunc(value)}i`
|
|
613
|
+
return value ? '1u' : '0u'
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
function castFromI32(expr: string, dtype: 'f32' | 'i32' | 'bool'): string {
|
|
617
|
+
if (dtype === 'f32') return `f32(${expr})`
|
|
618
|
+
if (dtype === 'i32') return `i32(${expr})`
|
|
619
|
+
return `u32(${expr})`
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
function computeStrides(shape: Shape): number[] {
|
|
623
|
+
const strides: number[] = new Array(shape.length).fill(1)
|
|
624
|
+
for (let i = shape.length - 2; i >= 0; i--) {
|
|
625
|
+
strides[i] = strides[i + 1]! * shape[i + 1]!
|
|
626
|
+
}
|
|
627
|
+
return strides
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
/**
|
|
631
|
+
* Generate WGSL that decomposes a flat index `flatVar` into per-axis indices
|
|
632
|
+
* `outVar_0, outVar_1, ...` according to `shape`.
|
|
633
|
+
*/
|
|
634
|
+
function decomposeFlatIndexBlock(flatVar: string, shape: Shape, outVar: string): string {
|
|
635
|
+
if (shape.length === 0) return ` let ${outVar}_0 : u32 = 0u;` // not used but parser-safe
|
|
636
|
+
const strides = computeStrides(shape)
|
|
637
|
+
const lines: string[] = []
|
|
638
|
+
let remaining = flatVar
|
|
639
|
+
for (let i = 0; i < shape.length; i++) {
|
|
640
|
+
if (i === shape.length - 1) {
|
|
641
|
+
lines.push(` let ${outVar}_${i} = ${remaining};`)
|
|
642
|
+
} else {
|
|
643
|
+
lines.push(` let ${outVar}_${i} = ${remaining} / ${strides[i]}u;`)
|
|
644
|
+
const newRem = `${outVar}_rem${i}`
|
|
645
|
+
lines.push(` let ${newRem} = ${remaining} % ${strides[i]}u;`)
|
|
646
|
+
remaining = newRem
|
|
647
|
+
}
|
|
648
|
+
}
|
|
649
|
+
return lines.join('\n')
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
/**
|
|
653
|
+
* Generate WGSL that computes the source flat index in `srcVar` for an output
|
|
654
|
+
* flat index `flatVar`, given output shape `outShape` and source shape `srcShape`
|
|
655
|
+
* under right-aligned NumPy-style broadcasting (size-1 axes broadcast).
|
|
656
|
+
*
|
|
657
|
+
* Strategy:
|
|
658
|
+
* 1. Decompose flat output index into per-axis output indices.
|
|
659
|
+
* 2. For each output axis that maps onto a source axis (right-aligned), use
|
|
660
|
+
* the output index there if src.dim != 1, else 0 (broadcast).
|
|
661
|
+
* 3. Drop output-only axes (those with no corresponding source axis).
|
|
662
|
+
* 4. Combine source indices with source strides.
|
|
663
|
+
*/
|
|
664
|
+
function broadcastIndexBlock(flatVar: string, outShape: Shape, srcShape: Shape, srcVar: string): string {
|
|
665
|
+
// Name the per-axis decomposition vars after `srcVar` so multiple
|
|
666
|
+
// broadcastIndexBlock calls in the same WGSL function don't collide.
|
|
667
|
+
const prefix = `${srcVar}_ax`
|
|
668
|
+
const decompose = decomposeFlatIndexBlock(flatVar, outShape, prefix)
|
|
669
|
+
const offset = outShape.length - srcShape.length
|
|
670
|
+
if (srcShape.length === 0) {
|
|
671
|
+
return `${decompose}\n let ${srcVar} : u32 = 0u;`
|
|
672
|
+
}
|
|
673
|
+
const srcStrides = computeStrides(srcShape)
|
|
674
|
+
const terms: string[] = []
|
|
675
|
+
for (let i = 0; i < srcShape.length; i++) {
|
|
676
|
+
const outAxis = i + offset
|
|
677
|
+
const srcDim = srcShape[i]!
|
|
678
|
+
const term = srcDim === 1 ? '0u' : `${prefix}_${outAxis} * ${srcStrides[i]}u`
|
|
679
|
+
terms.push(term)
|
|
680
|
+
}
|
|
681
|
+
return `${decompose}\n let ${srcVar} = ${terms.join(' + ')};`
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
/**
 * sum_to_shape: each output cell sums over the source axes that are reduced.
 * For source shape S and target shape T (right-aligned):
 * - Axes in S not in T (leading prefix): fully reduced (sum over whole axis).
 * - Axes where T=1 but S>1: reduced (sum over that axis).
 * - Axes where T=S: passed through.
 *
 * Implementation: each thread = one output cell. It iterates over the reduced
 * axes via nested-loop unrolling (we generate explicit nested for-loops).
 *
 * Returns the complete WGSL compute-shader text (bindings 0=src, 1=out).
 */
function emitSumToShape(srcShape: Shape, tgtShape: Shape, dtype: 'f32' | 'i32' | 'bool'): string {
  const srcStrides = computeStrides(srcShape)
  const tgtStrides = computeStrides(tgtShape)
  // Right-alignment offset: src axis k pairs with tgt axis (k - offset).
  const offset = srcShape.length - tgtShape.length

  // Decompose flat output index into per-axis target indices (tgt_0, tgt_1, ...).
  const decompose = decomposeFlatIndexBlock('i', tgtShape, 'tgt')

  // Identify reduced axes of the SOURCE: axis k in src is reduced if either
  // it's in the leading prefix (k < offset) or its corresponding target axis
  // has size 1. For non-reduced axes (k >= offset and tgt=src), the source
  // index is the target index along that axis.
  const reducedAxes: number[] = []
  for (let k = 0; k < srcShape.length; k++) {
    if (k < offset) { reducedAxes.push(k); continue }
    const tDim = tgtShape[k - offset]!
    const sDim = srcShape[k]!
    if (tDim === 1 && sDim > 1) reducedAxes.push(k)
  }
  // NOTE: when tDim === 1 and sDim === 1 the axis is treated as pass-through;
  // its tgt index is always 0, so the emitted term is harmless.

  // Build the source flat index expression. Initialize from the non-reduced axes.
  const baseTerms: string[] = []
  for (let k = 0; k < srcShape.length; k++) {
    if (reducedAxes.includes(k)) continue // contributed by loop var instead
    const tAxis = k - offset
    baseTerms.push(`tgt_${tAxis} * ${srcStrides[k]}u`)
  }
  const baseExpr = baseTerms.length > 0 ? baseTerms.join(' + ') : '0u'

  // Emit nested for loops over the reduced axes (one WGSL loop per axis).
  const indent = (depth: number) => ' '.repeat(depth + 1)
  const loops: string[] = []
  for (let depth = 0; depth < reducedAxes.length; depth++) {
    const k = reducedAxes[depth]!
    const dim = srcShape[k]!
    loops.push(`${indent(depth)}for (var r${k} : u32 = 0u; r${k} < ${dim}u; r${k} = r${k} + 1u) {`)
  }
  // Inside innermost loop, compute source index: pass-through base terms plus
  // one loop-variable term per reduced axis.
  const reducedTerms = reducedAxes.map(k => `r${k} * ${srcStrides[k]}u`)
  const fullExpr = reducedTerms.length > 0
    ? `${baseExpr} + ${reducedTerms.join(' + ')}`
    : baseExpr
  loops.push(`${indent(reducedAxes.length)}s = s + a[${fullExpr}];`)
  for (let depth = reducedAxes.length - 1; depth >= 0; depth--) {
    loops.push(`${indent(depth)}}`)
  }

  // Total output elements = tgtStrides[0] * tgtShape[0] (row-major), or 1 for rank-0.
  const total = tgtShape.length === 0 ? 1 : (tgtStrides[0]! * tgtShape[0]!)
  // No reduced axes: degenerate single-read "sum" (pure reshape-style copy).
  const loopBody = reducedAxes.length === 0
    ? ` s = s + a[${baseExpr}];`
    : loops.join('\n')

  // NOTE(review): 16776960 = 65535 * 256 — presumably the WebGPU per-dimension
  // workgroup limit times WG_SIZE, letting oversized dispatches spill into
  // gid.y; confirm against WG_SIZE's definition.
  return `
@group(0) @binding(0) var<storage, read> a : array<${wgslDtype(dtype)}>;
@group(0) @binding(1) var<storage, read_write> out : array<${wgslDtype(dtype)}>;
@compute @workgroup_size(${WG_SIZE})
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let i = gid.x + gid.y * 16776960u;
  if (i >= ${total}u) { return; }
${decompose}
  var s : ${wgslDtype(dtype)} = ${dtype === 'f32' ? '0.0f' : (dtype === 'i32' ? '0i' : '0u')};
${loopBody}
  out[i] = s;
}`.trim()
}
|