@framecast/rt 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (6) hide show
  1. package/LICENSE +18 -0
  2. package/README.md +78 -0
  3. package/index.d.ts +57 -0
  4. package/package.json +43 -0
  5. package/rt.js +1254 -0
  6. package/sr.js +205 -0
package/rt.js ADDED
@@ -0,0 +1,1254 @@
1
+ // Framecast custom WebGPU runtime - hand-rolled forward of the 1-block student
2
+ // (block0 of IFNet_m, scale=4, timestep=0.5). The whole frame is ONE command buffer
3
+ // of ~13 dispatches; no per-op JS, no runtime glue. Weights: assets/rt_1blk.{bin,json}
4
+ // (tools/export_rt_weights.py).
5
+ //
6
+ // Graph replicated from model/IFNet_m.py:
7
+ // x = cat(img0,img1,t) BGR /255 [7,H,W]
8
+ // xq = bilinear(x, 1/4, align_corners=false) [7,H/4,W/4]
9
+ // f8 = prelu(conv3x3s2(xq)) [120,H/8,W/8]
10
+ // f16 = prelu(conv3x3s2(f8)) [240,H/16,W/16]
11
+ // f16 = convblock(f16) + f16 (8x conv3x3s1+prelu, residual)
12
+ // tmp8 = deconv4x4s2(f16) [5,H/8,W/8]
13
+ // tmpF = bilinear(tmp8, x8, align_corners=false); flow = tmpF[0:4]*8; mask = tmpF[4]
14
+ // mid = s*warp(img0,flow01) + (1-s)*warp(img1,flow23), where s = sigmoid(mask)
15
+ // Requires H,W divisible by 16.
16
+
17
// Workgroup edge length (WG x WG invocations) used by most of the 2D compute
// shaders generated below.
const WG = 8;

// Identity stamps for texture-keyed caches. These live at MODULE scope, not per
// createRT: runtimes get rebuilt (res/model/tune switches) while the caller's
// textures survive with their stamps. A per-instance counter restarting at 1
// would re-issue ids already stamped on live textures, and the next pool realloc
// would hand destroyed-texture bind groups out of the cache again.
let texBgSeq = 0;

/**
 * Return the cache id stamped on `t`, stamping the next module-wide id
 * (1, 2, 3, ...) onto `t.__rtBgId` the first time this object is seen.
 * @param {object} t  texture (any extensible object)
 * @returns {number}
 */
const texBgId = (t) => {
  if (!t.__rtBgId) {
    texBgSeq += 1;
    t.__rtBgId = texBgSeq;
  }
  return t.__rtBgId;
};
26
+
27
+ // NOTE: with layout:'auto' WebGPU strips unused bindings from the layout, so each
28
+ // entry point gets its own shader with exactly the bindings it touches.
29
/**
 * Full-resolution unpack of two packed-u32 frames into six planar f32 planes.
 *
 * Each invocation handles one pixel. It decodes rgba0[o] / rgba1[o] with
 * unpack4x8unorm and writes the channels in reverse order (z, y, x), giving
 * the "b0 g0 r0 b1 g1 r1" plane layout noted in the binding comment.
 *
 * @param {number} W  frame width in pixels
 * @param {number} H  frame height in pixels
 * @returns {string} WGSL source
 */
function wgslPrepFull(W, H) {
  const planeSize = H * W; // elements per output plane
  const shader = /* wgsl */`
@group(0) @binding(0) var<storage, read> rgba0: array<u32>;
@group(0) @binding(1) var<storage, read> rgba1: array<u32>;
@group(0) @binding(2) var<storage, read_write> imgs: array<f32>; // [6,${H},${W}] b0 g0 r0 b1 g1 r1

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${W} || y >= ${H}) { return; }
  let o = y * ${W} + x;
  let c0 = unpack4x8unorm(rgba0[o]).xyz;
  let c1 = unpack4x8unorm(rgba1[o]).xyz;
  let P = ${planeSize};
  imgs[o] = c0.z; imgs[P + o] = c0.y; imgs[2 * P + o] = c0.x;
  imgs[3 * P + o] = c1.z; imgs[4 * P + o] = c1.y; imgs[5 * P + o] = c1.x;
}`;
  return shader;
}
47
+
48
/**
 * Quarter-resolution model input built from two packed-u32 frames plus the timestep.
 *
 * Each output texel is a bilinear sample of the full-res frames at
 * src = (dst + 0.5) * 4 - 0.5 (align_corners=false), with taps clamped to the
 * frame. Channels are written planar into xq [7, H/4, W/4] in the order
 * b0 g0 r0 b1 g1 r1 t, where t is read from tstep[0].
 *
 * @param {number} W  full-res width (expected divisible by 4)
 * @param {number} H  full-res height (expected divisible by 4)
 * @param {boolean} f16  store xq as f16 instead of f32
 * @returns {string} WGSL source
 */
function wgslPrepQuarter(W, H, f16) {
  const qw = W / 4;
  const qh = H / 4;
  const elem = f16 ? 'f16' : 'f32';
  const enableExt = f16 ? 'enable f16;' : '';
  const lastCol = W - 1;
  const lastRow = H - 1;
  const planeSize = qh * qw;
  return /* wgsl */`
${enableExt}
@group(0) @binding(0) var<storage, read> rgba0: array<u32>;
@group(0) @binding(1) var<storage, read> rgba1: array<u32>;
@group(0) @binding(2) var<storage, read_write> xq: array<${elem}>; // [7,${qh},${qw}]
@group(0) @binding(3) var<storage, read> tstep: array<f32>; // [1] timestep

fn px(buf: i32, x: i32, y: i32) -> vec3<f32> {
  let v = select(rgba1[y * ${W} + x], rgba0[y * ${W} + x], buf == 0);
  return unpack4x8unorm(v).xyz; // r,g,b in 0..1
}

fn sampleQ(buf: i32, sx: f32, sy: f32) -> vec3<f32> {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  let xa = clamp(x0, 0, ${lastCol}); let xb = clamp(x0 + 1, 0, ${lastCol});
  let ya = clamp(y0, 0, ${lastRow}); let yb = clamp(y0 + 1, 0, ${lastRow});
  let v00 = px(buf, xa, ya); let v10 = px(buf, xb, ya);
  let v01 = px(buf, xa, yb); let v11 = px(buf, xb, yb);
  return mix(mix(v00, v10, fx), mix(v01, v11, fx), fy);
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${qw} || y >= ${qh}) { return; }
  // align_corners=false: src = (dst+0.5)*4 - 0.5
  let sx = (f32(x) + 0.5) * 4.0 - 0.5;
  let sy = (f32(y) + 0.5) * 4.0 - 0.5;
  let c0 = sampleQ(0, sx, sy);
  let c1 = sampleQ(1, sx, sy);
  let o = y * ${qw} + x;
  let P = ${planeSize};
  xq[o] = ${elem}(c0.z); xq[P + o] = ${elem}(c0.y); xq[2 * P + o] = ${elem}(c0.x);
  xq[3 * P + o] = ${elem}(c1.z); xq[4 * P + o] = ${elem}(c1.y); xq[5 * P + o] = ${elem}(c1.x);
  xq[6 * P + o] = ${elem}(tstep[0]); // timestep
}`;
}
89
+
90
// conv3x3 (stride 1/2) + bias + PReLU, optional residual add (post-activation).
// v3: COC output channels per thread (src reads amortized), weight slab staged through
// workgroup memory, and optional f16 storage for activations+weights (accumulation
// stays f32) - halves the traffic on the bandwidth-bound conv stack.
//
// Generates WGSL for one 3x3 convolution with padding 1 (out-of-range taps read 0).
//   CI, CO    input / output channel counts
//   IW, IH    input plane size; OW, OH output plane size
//   stride    1 takes the shared-tile path; any other value reads src directly
//             at (x*stride + kx - 1, y*stride + ky - 1)
//   residual  when truthy, binds `res` at binding 5 and adds it AFTER the PReLU
//   f16       store src/wgt/dst/res as f16 (bias/alpha stay f32, accumulators f32)
// Each invocation computes COC consecutive output channels starting at gid.z * COC,
// so the z dispatch extent is presumably CO / COC - TODO confirm at the dispatch site.
// NOTE(review): the stride-1 path stages the input tile at the OUTPUT workgroup
// origin, so it assumes OW == IW and OH == IH; its 10x10 tile (8 + 2 halo) is
// hard-coded and only matches WG === 8.
function wgslConv(CI, CO, IW, IH, OW, OH, stride, residual, f16) {
  const COC = CO % 4 === 0 ? 4 : 1; // channels per thread
  const SLAB = Math.min(CI, 30); // ci per staging round (fits 16KB wg memory)
  const slabFloats = COC * SLAB * 9;
  const T = f16 ? 'f16' : 'f32';
  // Barriers sit outside every `inb` branch, so out-of-range threads still join
  // the cooperative loads and only return after the slab loop.
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>; // [${CI},${IH},${IW}]
@group(0) @binding(1) var<storage, read> wgt: array<${T}>; // [${CO},${CI},3,3]
@group(0) @binding(2) var<storage, read> bias: array<f32>; // [${CO}]
@group(0) @binding(3) var<storage, read> alpha: array<f32>; // [${CO}] prelu
@group(0) @binding(4) var<storage, read_write> dst: array<${T}>; // [${CO},${OH},${OW}]
${residual ? `@group(0) @binding(5) var<storage, read> res: array<${T}>;` : ``}

var<workgroup> wsh: array<${T}, ${slabFloats}>; // [COC, SLAB, 9] slab of weights
${stride === 1 ? `var<workgroup> tile: array<${T}, ${SLAB * 100}>; // [SLAB, 10, 10] input tiles (8x8 out + halo)` : ''}

@compute @workgroup_size(${WG}, ${WG}, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>,
        @builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_index) li: u32) {
  let x = i32(gid.x); let y = i32(gid.y);
  let lx = i32(lid.x); let ly = i32(lid.y);
  let wx0 = i32(wid.x) * ${WG}; let wy0 = i32(wid.y) * ${WG};
  let cb = i32(gid.z) * ${COC}; // first output channel of this thread's block
  let inb = x < ${OW} && y < ${OH};
  var acc: array<f32, ${COC}>;
  for (var c = 0; c < ${COC}; c++) { acc[c] = bias[cb + c]; }

  for (var s = 0; s < ${CI}; s += ${SLAB}) {
    let sl = min(${SLAB}, ${CI} - s);
    let n = ${COC} * sl * 9;
    workgroupBarrier();
    // cooperative load: weights for [cb..cb+COC) x [s..s+sl) x 9
    var idx = i32(li);
    while (idx < n) {
      let c = idx / (sl * 9);
      let r = idx % (sl * 9);
      wsh[idx] = wgt[(cb + c) * ${CI * 9} + (s + r / 9) * 9 + r % 9];
      idx += ${WG * WG};
    }
    workgroupBarrier();
    ${stride === 1 ? `
    // stride-1 path: stage 10x10 input tiles for the WHOLE ci-slab at once - barriers
    // per slab (16/conv) instead of per channel (480/conv), values reused 9x from shared
    var ti = i32(li);
    let tn = sl * 100;
    while (ti < tn) {
      let ci = ti / 100;
      let r = ti % 100;
      let ty = wy0 + r / 10 - 1;
      let tx = wx0 + r % 10 - 1;
      var v = ${T}(0.0);
      if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
        v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
      }
      tile[ti] = v;
      ti += ${WG * WG};
    }
    workgroupBarrier();
    if (inb) {
      for (var ci = 0; ci < sl; ci++) {
        let tb = ci * 100;
        for (var ky = 0; ky < 3; ky++) {
          for (var kx = 0; kx < 3; kx++) {
            let sv = f32(tile[tb + (ly + ky) * 10 + lx + kx]);
            let wb = ci * 9 + ky * 3 + kx;
            for (var c = 0; c < ${COC}; c++) {
              acc[c] += sv * f32(wsh[c * (sl * 9) + wb]);
            }
          }
        }
      }
    }
  }` : `
    if (inb) {
      for (var ci = 0; ci < sl; ci++) {
        let sbase = (s + ci) * ${IH * IW};
        for (var ky = 0; ky < 3; ky++) {
          let iy = y * ${stride} + ky - 1;
          if (iy < 0 || iy >= ${IH}) { continue; }
          for (var kx = 0; kx < 3; kx++) {
            let ix = x * ${stride} + kx - 1;
            if (ix < 0 || ix >= ${IW}) { continue; }
            let sv = f32(src[sbase + iy * ${IW} + ix]);
            let wb = ci * 9 + ky * 3 + kx;
            for (var c = 0; c < ${COC}; c++) {
              acc[c] += sv * f32(wsh[c * (sl * 9) + wb]);
            }
          }
        }
      }
    }
  }`}
  if (!inb) { return; }
  for (var c = 0; c < ${COC}; c++) {
    let co = cb + c;
    let v = select(alpha[co] * acc[c], acc[c], acc[c] >= 0.0);
    let o = co * ${OH * OW} + y * ${OW} + x;
    dst[o] = ${T}(${residual ? `v + f32(res[o])` : `v`});
  }
}`;
}
198
+
199
// register-blocked conv3x3 s1 (f16 storage): each thread computes a 2x2 pixel patch x
// COC output channels (tune.coc, default 4 -> 16 accumulators) - every shared read
// feeds 4 FMAs instead of ~1.
// Workgroup = 8x8 threads = 16x16 output tile; input tiles 18x18 per ci staged per slab.
//
// Generates WGSL for a stride-1, padding-1 (zero-filled) 3x3 conv + bias + PReLU,
// optionally adding `res` (binding 5) after the activation.
//   CI           input channel count
//   CO           output channel count - not referenced by the generator; the z
//                dispatch extent is presumably CO / COC (TODO confirm at caller,
//                and that CO is a multiple of COC)
//   IW, IH       input plane size; OW, OH output plane size. The input tile is
//                staged at the OUTPUT tile origin, so this assumes IW == OW, IH == OH.
//   residual     add res[o] after PReLU when truthy
//   tune         optional {coc, slab}; defaults coc = 4, slab = 20
// No thread returns before the slab loop, so every barrier is reached by the whole
// workgroup; out-of-range pixels are masked only at the final store.
export function wgslConvRB(CI, CO, IW, IH, OW, OH, residual, tune) {
  // tune: {coc, slab} - shared memory must fit slab*324*2 + coc*slab*9*2 <= 16384
  const COC = (tune && tune.coc) || 4, SLAB = (tune && tune.slab) || 20;
  const slabW = COC * SLAB * 9; // f16 weights in shared
  const slabT = SLAB * 324; // 18x18 tiles
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> src: array<f16>;
@group(0) @binding(1) var<storage, read> wgt: array<f16>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read> alpha: array<f32>;
@group(0) @binding(4) var<storage, read_write> dst: array<f16>;
${residual ? `@group(0) @binding(5) var<storage, read> res: array<f16>;` : ``}

var<workgroup> wsh: array<f16, ${slabW}>;
var<workgroup> tile: array<f16, ${slabT}>;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
        @builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_index) li: u32) {
  let lx = i32(lid.x); let ly = i32(lid.y);
  let ox0 = i32(wid.x) * 16; let oy0 = i32(wid.y) * 16; // wg output origin
  let x0 = ox0 + lx * 2; let y0 = oy0 + ly * 2; // this thread's 2x2 patch
  let cb = i32(wid.z) * ${COC};
  // 16 scalar accumulators (unrolled - arrays may spill out of registers in WGSL)
${Array.from({ length: COC }, (_, c) =>
  `  var a${c}0 = bias[cb + ${c}]; var a${c}1 = a${c}0; var a${c}2 = a${c}0; var a${c}3 = a${c}0;`).join('\n')}

  for (var s = 0; s < ${CI}; s += ${SLAB}) {
    let sl = min(${SLAB}, ${CI} - s);
    workgroupBarrier();
    var idx = i32(li);
    let wn = ${COC} * sl * 9;
    while (idx < wn) {
      let c = idx / (sl * 9);
      let r = idx % (sl * 9);
      wsh[idx] = wgt[(cb + c) * ${CI * 9} + (s + r / 9) * 9 + r % 9];
      idx += 64;
    }
    var ti = i32(li);
    let tn = sl * 324;
    while (ti < tn) {
      let ci = ti / 324;
      let r = ti % 324;
      let ty = oy0 + r / 18 - 1;
      let tx = ox0 + r % 18 - 1;
      var v = f16(0.0);
      if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
        v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
      }
      tile[ti] = v;
      ti += 64;
    }
    workgroupBarrier();
    for (var ci = 0; ci < sl; ci++) {
      let tb = ci * 324 + (ly * 2) * 18 + lx * 2; // top-left of this thread's 4x4 window
      for (var ky = 0; ky < 3; ky++) {
        let rb = tb + ky * 18;
        for (var kx = 0; kx < 3; kx++) {
          let t00 = f32(tile[rb + kx]);
          let t01 = f32(tile[rb + kx + 1]);
          let t10 = f32(tile[rb + kx + 18]);
          let t11 = f32(tile[rb + kx + 19]);
          let wb = ci * 9 + ky * 3 + kx;
${Array.from({ length: COC }, (_, c) => `          {
            let wv = f32(wsh[${c} * (sl * 9) + wb]);
            a${c}0 += t00 * wv; a${c}1 += t01 * wv; a${c}2 += t10 * wv; a${c}3 += t11 * wv;
          }`).join('\n')}
        }
      }
    }
  }
${Array.from({ length: COC }, (_, c) => `  {
    let co = cb + ${c};
    let al = alpha[co];
${[0, 1, 2, 3].map(p => `    {
      let x = x0 + ${p & 1};
      let y = y0 + ${p >> 1};
      if (x < ${OW} && y < ${OH}) {
        let a = a${c}${p};
        let v = select(al * a, a, a >= 0.0);
        let o = co * ${OH * OW} + y * ${OW} + x;
        dst[o] = f16(${residual ? `v + f32(res[o])` : `v`});
      }
    }`).join('\n')}
  }`).join('\n')}
}`;
}
291
+
292
// exp/subgroups: like wgslConvRB, but weights are NOT staged through shared
// memory - every lane computes the same weight index, so the first lane loads it
// from global (L2) and subgroupBroadcastFirst hands it to the wave. Saves the
// cooperative staging loop and shrinks the barrier to the input tile only.
//
// Same parameters, bindings, tile geometry (8x8 threads, 16x16 outputs, 18x18 input
// tiles, zero padding) and output semantics as wgslConvRB; requires both the `f16`
// and `subgroups` WGSL extensions (both are enabled in the generated source).
// NOTE(review): as written, every lane still evaluates the wgt[] load before the
// broadcast - whether only one lane actually touches memory is up to the compiler.
// NOTE(review): CO is not referenced by the generator; like wgslConvRB this assumes
// IW == OW and IH == OH and that the z dispatch covers CO / COC channel blocks.
export function wgslConvRBSg(CI, CO, IW, IH, OW, OH, residual, tune) {
  const COC = (tune && tune.coc) || 4, SLAB = (tune && tune.slab) || 20;
  const slabT = SLAB * 324; // [SLAB, 18, 18] f16 input tiles
  return /* wgsl */`
enable f16;
enable subgroups;
@group(0) @binding(0) var<storage, read> src: array<f16>;
@group(0) @binding(1) var<storage, read> wgt: array<f16>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read> alpha: array<f32>;
@group(0) @binding(4) var<storage, read_write> dst: array<f16>;
${residual ? `@group(0) @binding(5) var<storage, read> res: array<f16>;` : ``}

var<workgroup> tile: array<f16, ${slabT}>;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
        @builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_index) li: u32) {
  let lx = i32(lid.x); let ly = i32(lid.y);
  let ox0 = i32(wid.x) * 16; let oy0 = i32(wid.y) * 16;
  let x0 = ox0 + lx * 2; let y0 = oy0 + ly * 2;
  let cb = i32(wid.z) * ${COC};
${Array.from({ length: COC }, (_, c) =>
  `  var a${c}0 = bias[cb + ${c}]; var a${c}1 = a${c}0; var a${c}2 = a${c}0; var a${c}3 = a${c}0;`).join('\n')}

  for (var s = 0; s < ${CI}; s += ${SLAB}) {
    let sl = min(${SLAB}, ${CI} - s);
    workgroupBarrier();
    var ti = i32(li);
    let tn = sl * 324;
    while (ti < tn) {
      let ci = ti / 324;
      let r = ti % 324;
      let ty = oy0 + r / 18 - 1;
      let tx = ox0 + r % 18 - 1;
      var v = f16(0.0);
      if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
        v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
      }
      tile[ti] = v;
      ti += 64;
    }
    workgroupBarrier();
    for (var ci = 0; ci < sl; ci++) {
      let tb = ci * 324 + (ly * 2) * 18 + lx * 2;
      let wrow = (s + ci) * 9;
      for (var ky = 0; ky < 3; ky++) {
        let rb = tb + ky * 18;
        for (var kx = 0; kx < 3; kx++) {
          let t00 = f32(tile[rb + kx]);
          let t01 = f32(tile[rb + kx + 1]);
          let t10 = f32(tile[rb + kx + 18]);
          let t11 = f32(tile[rb + kx + 19]);
          let wk = wrow + ky * 3 + kx;
${Array.from({ length: COC }, (_, c) => `          {
            let wv = subgroupBroadcastFirst(f32(wgt[(cb + ${c}) * ${CI * 9} + wk]));
            a${c}0 += t00 * wv; a${c}1 += t01 * wv; a${c}2 += t10 * wv; a${c}3 += t11 * wv;
          }`).join('\n')}
        }
      }
    }
  }
${Array.from({ length: COC }, (_, c) => `  {
    let co = cb + ${c};
    let al = alpha[co];
${[0, 1, 2, 3].map(p => `    {
      let x = x0 + ${p & 1};
      let y = y0 + ${p >> 1};
      if (x < ${OW} && y < ${OH}) {
        let a = a${c}${p};
        let v = select(al * a, a, a >= 0.0);
        let o = co * ${OH * OW} + y * ${OW} + x;
        dst[o] = f16(${residual ? `v + f32(res[o])` : `v`});
      }
    }`).join('\n')}
  }`).join('\n')}
}`;
}
375
+
376
+ // texture-input prep variants: the video frame lives in a GPU texture (uploaded via
377
+ // copyExternalImageToTexture) and never touches the CPU; the sampler also does the
378
+ // display->model resize for free. Sampling at texel centers == exact texel values.
379
/**
 * Full-resolution prep from two sampled textures into six planar f32 planes.
 *
 * Each invocation samples tex0/tex1 (mip 0) at the centre of model pixel (x, y),
 * normalised by the W x H model grid, and writes b0 g0 r0 b1 g1 r1 into imgs.
 *
 * @param {number} W  model-grid width
 * @param {number} H  model-grid height
 * @returns {string} WGSL source
 */
function wgslPrepFullTex(W, H) {
  const gridSize = `${W}.0, ${H}.0`; // float literals for the uv normalisation
  const planeSize = H * W;
  return /* wgsl */`
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> imgs: array<f32>; // [6,${H},${W}] BGR

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${W} || y >= ${H}) { return; }
  let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${gridSize});
  let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
  let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
  let o = y * ${W} + x;
  let P = ${planeSize};
  imgs[o] = c0.b; imgs[P + o] = c0.g; imgs[2 * P + o] = c0.r;
  imgs[3 * P + o] = c1.b; imgs[4 * P + o] = c1.g; imgs[5 * P + o] = c1.r;
}`;
}
399
+
400
/**
 * Quarter-resolution model input (7 channels) from two sampled textures.
 *
 * The uv is the centre of a texel on the (W/4) x (H/4) grid, so the bound sampler
 * maps it onto whatever size the textures actually are. Writes b0 g0 r0 b1 g1 r1
 * and then the timestep tstep[0] as the 7th plane of xq.
 *
 * @param {number} W  model-grid width (expected divisible by 4)
 * @param {number} H  model-grid height (expected divisible by 4)
 * @param {boolean} f16  store xq as f16 instead of f32
 * @returns {string} WGSL source
 */
function wgslPrepQuarterTex(W, H, f16) {
  const qw = W / 4;
  const qh = H / 4;
  const elem = f16 ? 'f16' : 'f32';
  const enableExt = f16 ? 'enable f16;' : '';
  const planeSize = qh * qw;
  return /* wgsl */`
${enableExt}
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> xq: array<${elem}>; // [7,${qh},${qw}]
@group(0) @binding(4) var<storage, read> tstep: array<f32>;

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${qw} || y >= ${qh}) { return; }
  // quarter of the MODEL grid; the sampler maps through whatever the texture size is
  let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${qw}.0, ${qh}.0);
  let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
  let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
  let o = y * ${qw} + x;
  let P = ${planeSize};
  xq[o] = ${elem}(c0.b); xq[P + o] = ${elem}(c0.g); xq[2 * P + o] = ${elem}(c0.r);
  xq[3 * P + o] = ${elem}(c1.b); xq[4 * P + o] = ${elem}(c1.g); xq[5 * P + o] = ${elem}(c1.r);
  xq[6 * P + o] = ${elem}(tstep[0]);
}`;
}
426
+
427
// t-factored prep: 6 channels, NO timestep - the trunk is t-free, t enters via FiLM.
//
// Same sampling as wgslPrepQuarterTex (texel centres of the (W/4) x (H/4) grid,
// mip 0), but writes only b0 g0 r0 b1 g1 r1 and binds no timestep buffer.
//   W, H  model-grid size (expected divisible by 4)
//   f16   store xq as f16 instead of f32
// Returns the WGSL source string.
function wgslPrepQuarterTex6(W, H, f16) {
  const qw = W / 4;
  const qh = H / 4;
  const elem = f16 ? 'f16' : 'f32';
  const enableExt = f16 ? 'enable f16;' : '';
  const planeSize = qh * qw;
  return /* wgsl */`
${enableExt}
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> xq: array<${elem}>; // [6,${qh},${qw}]

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${qw} || y >= ${qh}) { return; }
  let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${qw}.0, ${qh}.0);
  let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
  let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
  let o = y * ${qw} + x;
  let P = ${planeSize};
  xq[o] = ${elem}(c0.b); xq[P + o] = ${elem}(c0.g); xq[2 * P + o] = ${elem}(c0.r);
  xq[3 * P + o] = ${elem}(c1.b); xq[4 * P + o] = ${elem}(c1.g); xq[5 * P + o] = ${elem}(c1.r);
}`;
}
451
+
452
// FiLM conditioning: x' = x * (1 + scale[c]) + shift[c]; params are the tiny
// t-MLP's output, computed on the CPU per timestep (2*C floats).
//
// One invocation per element of the [C, N] feature tensor (1D workgroups of 256);
// the channel of flat index i is i / N. prm holds C scales followed by C shifts.
//   C    channel count
//   N    elements per channel
//   f16  store src/dst as f16 (prm stays f32, math in f32)
// Returns the WGSL source string.
function wgslFilm(C, N, f16) {
  const elem = f16 ? 'f16' : 'f32';
  const header = f16 ? 'enable f16;' : '';
  const total = C * N; // elements covered by the guard
  return /* wgsl */`
${header}
@group(0) @binding(0) var<storage, read> src: array<${elem}>; // [C,N] trunk features
@group(0) @binding(1) var<storage, read> prm: array<f32>; // [2C] scale, shift
@group(0) @binding(2) var<storage, read_write> dst: array<${elem}>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = i32(gid.x);
  if (i >= ${total}) { return; }
  let c = i / ${N};
  dst[i] = ${elem}(f32(src[i]) * (1.0 + prm[c]) + prm[${C} + c]);
}`;
}
470
+
471
// flowout variant writing straight into a storage texture (GPU-resident presentation:
// the mid never leaves the GPU). rgba8unorm store rounds instead of truncating - ±1 LSB
// vs the buffer path, invisible; the correctness harness keeps using the buffer path.
//
// Generates the final-frame shader. Per full-res pixel it:
//   1. bilinearly upsamples tmp8 x8 (align_corners=false grid, clamped taps),
//   2. scales the four flow channels by 8 and warps both source frames (planes 0-2
//      and 3-5 of imgs) with clamped bilinear taps,
//   3. blends them with m = sigmoid(channel 4): w0*m + w1*(1-m),
//   4. optionally adds the upsampled refine residual, clamps to [0,1],
//   5. optionally applies the static-region guard, then stores (r, g, b, 1).
//   W, H         full-res model grid; tmp8 is [5, H/8, W/8], imgs is [6, H, W]
//   staticGuard  enable the static-region blend (see GUARD below)
//   withRes      bind `res` [3, H/4, W/4] at binding 3 and add its x4 upsample
// Invocations outside W x H return immediately.
function wgslFlowOutTex(W, H, staticGuard = false, withRes = false) {
  const TW = W / 8, TH = H / 8, RW = W / 4, RH = H / 4;
  // tfact2: the refine residual (quarter res) is folded straight into this pass -
  // bilinear x4 upsample (align_corners=False grid), added BEFORE the clamp
  const RES_DECL = withRes ? /* wgsl */`
@group(0) @binding(3) var<storage, read> res: array<f32>; // [3,${RH},${RW}]
fn rtap(c: i32, x: i32, y: i32) -> f32 {
  return res[c * ${RH * RW} + clamp(y, 0, ${RH - 1}) * ${RW} + clamp(x, 0, ${RW - 1})];
}
fn rup(c: i32, sx: f32, sy: f32) -> f32 {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  return mix(mix(rtap(c, x0, y0), rtap(c, x0 + 1, y0), fx),
             mix(rtap(c, x0, y0 + 1), rtap(c, x0 + 1, y0 + 1), fx), fy);
}` : '';
  const RES_ADD = withRes ? /* wgsl */`
  let rx = (f32(x) + 0.5) / 4.0 - 0.5;
  let ry = (f32(y) + 0.5) / 4.0 - 0.5;
  bgr = bgr + vec3<f32>(rup(0, rx, ry), rup(1, rx, ry), rup(2, rx, ry));` : '';
  // static-region protection (SVP-style): where A and B are locally identical
  // (subtitles, logos, UI, frozen shots-in-motion) the warp can still DRAG other
  // content there - blend back to the untouched source instead. Soft ramp so
  // moving-edge pixels transition smoothly.
  // d = 3x3 mean of the per-pixel max |A-B| over the three channels; the blend
  // weight is 1 - smoothstep(0.03, 0.09, d) toward the mean of A and B.
  const GUARD = staticGuard ? /* wgsl */`
  var d = 0.0;
  for (var dy = -1; dy <= 1; dy++) {
    for (var dx = -1; dx <= 1; dx++) {
      let xx = x + dx; let yy = y + dy;
      d += max(abs(img(0, xx, yy) - img(3, xx, yy)),
           max(abs(img(1, xx, yy) - img(4, xx, yy)),
               abs(img(2, xx, yy) - img(5, xx, yy))));
    }
  }
  d *= (1.0 / 9.0);
  let wStatic = 1.0 - smoothstep(0.03, 0.09, d);
  if (wStatic > 0.001) {
    let stat = vec3<f32>(
      (img(0, x, y) + img(3, x, y)) * 0.5,
      (img(1, x, y) + img(4, x, y)) * 0.5,
      (img(2, x, y) + img(5, x, y)) * 0.5);
    bgr = mix(bgr, stat, wStatic);
  }` : '';
  return /* wgsl */`
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var outTex: texture_storage_2d<rgba8unorm, write>;
${RES_DECL}
fn tap(c: i32, x: i32, y: i32) -> f32 {
  return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
             mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
  return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  var r: vec3<f32>;
  for (var c = 0; c < 3; c++) {
    let p = base + c;
    r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
               mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
  }
  return r; // b,g,r
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${W} || y >= ${H}) { return; }
  let sx = (f32(x) + 0.5) / 8.0 - 0.5;
  let sy = (f32(y) + 0.5) / 8.0 - 0.5;
  let fx0 = up(0, sx, sy) * 8.0;
  let fy0 = up(1, sx, sy) * 8.0;
  let fx1 = up(2, sx, sy) * 8.0;
  let fy1 = up(3, sx, sy) * 8.0;
  let m = 1.0 / (1.0 + exp(-up(4, sx, sy)));
  let w0 = warp3(0, f32(x) + fx0, f32(y) + fy0);
  let w1 = warp3(3, f32(x) + fx1, f32(y) + fy1);
  var bgr = w0 * m + w1 * (1.0 - m);
  ${RES_ADD}
  bgr = clamp(bgr, vec3<f32>(0.0), vec3<f32>(1.0));
  ${GUARD}
  textureStore(outTex, vec2<i32>(x, y), vec4<f32>(bgr.z, bgr.y, bgr.x, 1.0));
}`;
}
565
+
566
// ---- refine head (tfact2): occlusion repair at QUARTER resolution ----
// Gathers [warped0(3), warped1(3), mask(1), flow*(0.25/20)(4)] = 11ch at H/4.
// F.interpolate(x, 0.25, bilinear, align_corners=False) samples the source at
// 4x+1.5 - i.e. the mean of the CENTER 2x2 of each 4x4 block; we warp those
// four full-res positions and average, matching the trainer exactly.
//
// Per quarter-res texel (hx, hy) the shader visits full-res (4hx+1..4hx+2,
// 4hy+1..4hy+2). It evaluates the same x8-upsampled flow/mask as wgslFlowOutTex
// at each of the four positions, then writes the averages (x 0.25) planar into rin:
//   planes 0-2  warped frame 0 (imgs planes 0-2 order)
//   planes 3-5  warped frame 1 (imgs planes 3-5 order)
//   plane  6    mean sigmoid(mask)
//   planes 7-10 mean flow (fx0, fy0, fx1, fy1) x 0.25/20 (FLOW_NORM = 20 per the
//               in-shader comment)
//   W, H  full-res model grid; tmp8 is [5, H/8, W/8], imgs is [6, H, W]
//   f16   store rin as f16 instead of f32
// Returns the WGSL source string.
function wgslRefinePrep(W, H, f16) {
  const TW = W / 8, TH = H / 8, HW = W / 4, HH = H / 4;
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var<storage, read_write> rin: array<${T}>; // [11,${HH},${HW}]

fn tap(c: i32, x: i32, y: i32) -> f32 {
  return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
             mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
  return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
  let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
  let fx = sx - f32(x0); let fy = sy - f32(y0);
  var r: vec3<f32>;
  for (var c = 0; c < 3; c++) {
    let p = base + c;
    r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
               mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
  }
  return r;
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let hx = i32(gid.x); let hy = i32(gid.y);
  if (hx >= ${HW} || hy >= ${HH}) { return; }
  var w0 = vec3<f32>(0.0); var w1 = vec3<f32>(0.0);
  var mk = 0.0; var fl = vec4<f32>(0.0);
  for (var sy = 1; sy <= 2; sy++) {
    for (var sxp = 1; sxp <= 2; sxp++) {
      let X = hx * 4 + sxp; let Y = hy * 4 + sy; // center 2x2 of the 4x4 block
      let gx8 = (f32(X) + 0.5) / 8.0 - 0.5;
      let gy8 = (f32(Y) + 0.5) / 8.0 - 0.5;
      let fx0 = up(0, gx8, gy8) * 8.0; let fy0 = up(1, gx8, gy8) * 8.0;
      let fx1 = up(2, gx8, gy8) * 8.0; let fy1 = up(3, gx8, gy8) * 8.0;
      mk += 1.0 / (1.0 + exp(-up(4, gx8, gy8)));
      w0 += warp3(0, f32(X) + fx0, f32(Y) + fy0);
      w1 += warp3(3, f32(X) + fx1, f32(Y) + fy1);
      fl += vec4<f32>(fx0, fy0, fx1, fy1);
    }
  }
  let P = ${HH * HW};
  let o = hy * ${HW} + hx;
  let q = 0.25;
  let fn_ = 0.25 * (0.25 / 20.0); // mean * (0.25/FLOW_NORM)
  rin[o] = ${T}(w0.x * q); rin[P + o] = ${T}(w0.y * q); rin[2 * P + o] = ${T}(w0.z * q);
  rin[3 * P + o] = ${T}(w1.x * q); rin[4 * P + o] = ${T}(w1.y * q); rin[5 * P + o] = ${T}(w1.z * q);
  rin[6 * P + o] = ${T}(mk * q);
  rin[7 * P + o] = ${T}(fl.x * fn_); rin[8 * P + o] = ${T}(fl.y * fn_);
  rin[9 * P + o] = ${T}(fl.z * fn_); rin[10 * P + o] = ${T}(fl.w * fn_);
}`;
}
634
+
635
// final refine conv (C->3) + sigmoid*2-1 residual, f32 out. Runs on whatever
// HW x HH grid the caller passes; the refine head in this runtime works at
// QUARTER resolution (see wgslRefinePrep and the `res` input of wgslFlowOutTex).
//
// Direct (unstaged) 3x3 conv over C input planes with f32 weights [3, C, 3, 3] and
// bias [3]; each invocation produces all three output channels of one pixel,
// squashed to (-1, 1) via 2*sigmoid(acc) - 1.
//   C       input channel count
//   HW, HH  plane width / height (input and output share the grid)
//   f16     read src as f16 instead of f32
// NOTE(review): border taps are clamped to the edge (replicate padding), unlike the
// zero padding used by wgslConv / wgslConvRB - confirm this matches the trainer.
function wgslRefineOut(C, HW, HH, f16) {
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>; // [${C},${HH},${HW}]
@group(0) @binding(1) var<storage, read> wgt: array<f32>; // [3,${C},3,3]
@group(0) @binding(2) var<storage, read> bias: array<f32>; // [3]
@group(0) @binding(3) var<storage, read_write> dst: array<f32>; // [3,${HH},${HW}]

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= ${HW} || y >= ${HH}) { return; }
  for (var co = 0; co < 3; co++) {
    var acc = bias[co];
    for (var ci = 0; ci < ${C}; ci++) {
      let sb = ci * ${HH * HW};
      let wb = (co * ${C} + ci) * 9;
      for (var ky = 0; ky < 3; ky++) {
        let sy = clamp(y + ky - 1, 0, ${HH - 1});
        for (var kx = 0; kx < 3; kx++) {
          let sx = clamp(x + kx - 1, 0, ${HW - 1});
          acc += f32(src[sb + sy * ${HW} + sx]) * wgt[wb + ky * 3 + kx];
        }
      }
    }
    dst[co * ${HH * HW} + y * ${HW} + x] = 2.0 / (1.0 + exp(-acc)) - 1.0;
  }
}`;
}
666
+
667
// One-shot f32 -> f16 conversion (weights at init): dst[i] = f16(src[i]) for every
// i < arrayLength(&src). 1D workgroups of 256 invocations, so the caller presumably
// dispatches ceil(len / 256) workgroups - TODO confirm at the dispatch site.
// NOTE(review): only src's length is checked; dst is assumed to hold at least as
// many elements.
export const WGSL_TO_F16 = /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f16>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i < arrayLength(&src)) { dst[i] = f16(src[i]); }
}`;
677
+
678
// ConvTranspose2d 4x4 stride2 pad1, no activation. Weight layout [CI, CO, 4, 4].
//
// Gather formulation: output (x, y) receives input (ix, iy) through tap (kx, ky)
// when x = 2*ix - 1 + kx (same for y), so each invocation walks the 4x4 taps,
// keeps only those where (x + 1 - kx) / (y + 1 - ky) are non-negative, even and
// in range, and sums over all CI channels. Output is always f32.
//   CI, CO   input / output channel counts
//   IW, IH   input plane size; OW, OH output plane size
//   f16src   read src as f16 instead of f32 (weights, bias and dst stay f32)
// The output channel is gid.z with no bounds check, so the z dispatch extent is
// presumably exactly CO - TODO confirm at the dispatch site.
function wgslDeconv(CI, CO, IW, IH, OW, OH, f16src) {
  const T = f16src ? 'f16' : 'f32';
  return /* wgsl */`
${f16src ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>;
@group(0) @binding(1) var<storage, read> wgt: array<f32>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> dst: array<f32>;

@compute @workgroup_size(${WG}, ${WG}, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let x = i32(gid.x); let y = i32(gid.y); let co = i32(gid.z);
  if (x >= ${OW} || y >= ${OH}) { return; }
  var acc = bias[co];
  for (var ky = 0; ky < 4; ky++) {
    let ty = y + 1 - ky;
    if (ty < 0 || (ty & 1) != 0) { continue; }
    let iy = ty / 2;
    if (iy >= ${IH}) { continue; }
    for (var kx = 0; kx < 4; kx++) {
      let tx = x + 1 - kx;
      if (tx < 0 || (tx & 1) != 0) { continue; }
      let ix = tx / 2;
      if (ix >= ${IW}) { continue; }
      for (var ci = 0; ci < ${CI}; ci++) {
        acc += f32(src[ci * ${IH * IW} + iy * ${IW} + ix])
             * wgt[ci * ${CO * 16} + co * 16 + ky * 4 + kx];
      }
    }
  }
  dst[co * ${OH * OW} + y * ${OW} + x] = acc;
}`;
}
712
+
713
/**
 * Emits the WGSL for the buffer-output tail of the graph, as one full-resolution
 * pass (one invocation per output pixel, WG x WG workgroups). Per pixel it does:
 *
 * 1. Upsample tmp8 x8 (bilinear, align_corners=false). The source coordinate is
 *    (dst + 0.5) / 8 - 0.5, sampled with edge-clamped taps. Coordinates in (-1, 0)
 *    therefore collapse onto texel 0, which equals clamping the coordinate at 0.
 * 2. Multiply flow channels 0-3 by 8:
 *    - channels 0/1 are the (x, y) offsets into image 0;
 *    - channels 2/3 are the (x, y) offsets into image 1.
 * 3. Backward-warp both images with pixel-space bilinear sampling and clamped taps.
 *    Per the inline note, this is grid_sample with border padding and
 *    align_corners=true.
 * 4. Blend with sigmoid(channel 4): mid = s*warp0 + (1 - s)*warp1.
 * 5. Swap BGR -> RGB, clamp to [0,1], scale by 255 with truncation, and pack to a
 *    u32 as r | g<<8 | b<<16 | 0xff<<24.
 *
 * W and H are interpolated as literals and W/8, H/8 become array extents, so both
 * must be multiples of 8.
 *
 * @param {number} W output width
 * @param {number} H output height
 * @returns {string} WGSL, entry point `main`. Bindings (group 0):
 *   0 tmp8 [5, H/8, W/8] f32, 1 imgs [6, H, W] f32 (planes 0-2 image 0,
 *   3-5 image 1, BGR per the inline comments), 2 outp [H*W] u32
 */
function wgslFlowOut(W, H) {
  // low-res grid the head predicted on
  const TW = W / 8, TH = H / 8;
  return /* wgsl */`
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var<storage, read_write> outp: array<u32>; // rgba

fn tap(c: i32, x: i32, y: i32) -> f32 {
return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
// grid_sample bilinear, border, align_corners=true == pixel-space bilinear with clamped taps
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
var r: vec3<f32>;
for (var c = 0; c < 3; c++) {
let p = base + c;
r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
}
return r; // b,g,r
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${W} || y >= ${H}) { return; }
// align_corners=false x8 upsample: src = (dst+0.5)/8 - 0.5; flow scaled by 8 (=scale*2*... baked)
let sx = (f32(x) + 0.5) / 8.0 - 0.5;
let sy = (f32(y) + 0.5) / 8.0 - 0.5;
let fx0 = up(0, sx, sy) * 8.0;
let fy0 = up(1, sx, sy) * 8.0;
let fx1 = up(2, sx, sy) * 8.0;
let fy1 = up(3, sx, sy) * 8.0;
let m = 1.0 / (1.0 + exp(-up(4, sx, sy))); // sigmoid(mask)
let w0 = warp3(0, f32(x) + fx0, f32(y) + fy0);
let w1 = warp3(3, f32(x) + fx1, f32(y) + fy1);
let bgr = w0 * m + w1 * (1.0 - m);
// BGR -> RGB, *255 truncate (matches rife-core prepost), alpha 255
let r = u32(clamp(bgr.z, 0.0, 1.0) * 255.0);
let g = u32(clamp(bgr.y, 0.0, 1.0) * 255.0);
let b = u32(clamp(bgr.x, 0.0, 1.0) * 255.0);
outp[y * ${W} + x] = r | (g << 8u) | (b << 16u) | (255u << 24u);
}`;
}
768
+
769
/**
 * Builds the hand-rolled WebGPU forward of the 1-block student.
 *
 * All work is done up front so a frame is just a few passes in one submit:
 * - every weight and activation buffer is allocated;
 * - the pipelines are compiled (the heavy ones in parallel via
 *   createComputePipelineAsync);
 * - the static bind groups are pre-built.
 *
 * @param {GPUDevice} device
 * @param {object} opts
 * @param {number} opts.w                 frame width, multiple of 16
 * @param {number} opts.h                 frame height, multiple of 16
 * @param {ArrayBuffer} opts.weightsBin   packed f32 weights; manifest offsets are in
 *                                        floats (scaled by 4 to bytes below)
 * @param {object} opts.weightsManifest   name -> { shape, offset } into weightsBin
 * @param {object} [opts.convTune]        tuneConvRB() result ({ coc, slab, sg? }); a tune
 *                                        whose coc does not divide this model's width
 *                                        is ignored (default kernel)
 * @param {boolean} [opts.textureInput]   inputs are GPUTextures instead of RGBA arrays
 * @param {boolean} [opts.textureOutput]  mids go to GPUTextures, nothing is read back
 * @param {boolean} [opts.staticGuard]    forwarded to the texture flow-out shader
 * @returns {Promise<{run, runMulti, prepPair, runT, profile, w, h}>}
 * @throws {Error} on non-/16 dims, a non-/4 main width, a non-/4 refine head,
 *   or tfact weights without texture in/out mode
 */
export async function createRT(device, { w, h, weightsBin, weightsManifest, convTune,
                                         textureInput = false, textureOutput = false,
                                         staticGuard = false }) {
  if (w % 16 || h % 16) throw new Error(`rt: dims must be /16 (got ${w}x${h})`);
  const QW = w / 4, QH = h / 4, W8 = w / 8, H8 = h / 8, W16 = w / 16, H16 = h / 16;
  const useF16 = device.features.has('shader-f16');
  // channel widths come from the weights themselves (supports slim students)
  const C1 = weightsManifest['block0.conv0.0.0.weight'].shape[0]; // conv0a out (120 full / 60 slim)
  const C2 = weightsManifest['block0.conv0.1.0.weight'].shape[0]; // main width (240 full / 120 slim)
  if (C2 % 4) throw new Error('rt: main width must be /4');
  // A persisted tune from another model width can carry a coc that does not
  // divide C2. The convblock dispatch z = C2 / coc would then be fractional,
  // WebGPU's EnforceRange truncates it, and the trailing output channels are
  // silently never computed. Drop such a tune and use the default kernel.
  const tune = convTune && convTune.coc && C2 % convTune.coc ? undefined : convTune;
  // t-factored graph: trunk (conv0 + 6 convblocks) is timestep-free and runs once
  // per pair; FiLM(t) + convblocks 6,7 + lastconv run per mid. Detected by the
  // film MLP in the manifest; input prep is 6ch (no t channel).
  const tfact = 'film.2.weight' in weightsManifest;
  const CI0 = weightsManifest['block0.conv0.0.0.weight'].shape[1]; // 7 classic, 6 tfact
  if (tfact && (!textureInput || !textureOutput)) {
    throw new Error('rt: tfact weights need texture input/output mode');
  }
  const refi = tfact && ('refine.c0.weight' in weightsManifest); // tfact2
  const Z0A = C1 % 4 === 0 ? C1 / 4 : C1; // conv0a kernel packs 4 channels only when C1 is /4

  const bufBytes = (bytes, usage = GPUBufferUsage.STORAGE) => device.createBuffer({
    size: Math.ceil(bytes / 4) * 4,
    usage: usage | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC });
  const buf = (n) => bufBytes(n * 4);
  const abuf = (n) => bufBytes(n * (useF16 ? 2 : 4)); // activation dtype

  const mod = (code) => device.createShaderModule({ code });
  const pipe = (code, entry = 'main') => device.createComputePipeline({
    layout: 'auto', compute: { module: mod(code), entryPoint: entry } });
  const bg = (p, entries) => device.createBindGroup({
    layout: p.getBindGroupLayout(0),
    entries: entries.map((b, i) => ({ binding: i, resource: { buffer: b } })) });

  // weights (f32 upload; conv weights get f16 copies when supported)
  const man = weightsManifest;
  const wbuf = {};
  for (const [name, m] of Object.entries(man)) {
    const n = m.shape.reduce((a, b) => a * b, 1);
    wbuf[name] = buf(n);
    device.queue.writeBuffer(wbuf[name], 0, weightsBin, m.offset * 4, n * 4);
  }
  // one pipeline shared by all weight conversions
  const pToF16 = useF16 ? pipe(WGSL_TO_F16) : null;
  const convW = (name) => {
    if (!useF16) return wbuf[name];
    const n = man[name].shape.reduce((a, b) => a * b, 1);
    const half = bufBytes(n * 2);
    const p = pToF16;
    const enc = device.createCommandEncoder();
    const pass = enc.beginComputePass();
    pass.setPipeline(p);
    pass.setBindGroup(0, bg(p, [wbuf[name], half]));
    pass.dispatchWorkgroups(Math.ceil(n / 256));
    pass.end();
    device.queue.submit([enc.finish()]);
    return half;
  };

  // activations (f16 when supported), fixed f32 elsewhere
  const tbuf = buf(1);
  const rgba0 = textureInput ? null : buf(w * h);
  const rgba1 = textureInput ? null : buf(w * h);
  const imgs = buf(6 * w * h);
  const xq = abuf(CI0 * QH * QW);
  const f8 = abuf(C1 * H8 * W8);
  const actBytes = C2 * H16 * W16 * (useF16 ? 2 : 4);
  const f16a = bufBytes(actBytes), f16b = bufBytes(actBytes), f16r = bufBytes(actBytes);
  const tmp8 = buf(5 * H8 * W8);
  // readback plumbing exists only for the buffer-output path (rt_test harness);
  // texture-output callers never read back - ~7*w*h*4 bytes of MAP_READ saved
  const outp = textureOutput ? null : buf(w * h);
  const staging = textureOutput ? null
    : device.createBuffer({ size: w * h * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
  // slots for batched multi-t runs (factor N: upload once, N-1 mids in ONE submit)
  const MAXT = 5;
  const tbufs = [], stagings = [];
  for (let i = 0; i < MAXT; i++) {
    tbufs.push(buf(1));
    if (!textureOutput) {
      stagings.push(device.createBuffer({ size: w * h * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }));
    }
  }

  const sampler = textureInput
    ? device.createSampler({ magFilter: 'linear', minFilter: 'linear',
                             addressModeU: 'clamp-to-edge', addressModeV: 'clamp-to-edge' })
    : null;
  // the heavy shaders compile in PARALLEL on Dawn's worker pool (and without
  // blocking the main thread) - createRT is async anyway, so batch-await them
  const pipeAsync = (code, entry = 'main') => device.createComputePipelineAsync({
    layout: 'auto', compute: { module: mod(code), entryPoint: entry } });
  const sgTuned = tune && tune.sg && device.features.has('subgroups');
  const [pPrepFull, pPrepQ, pConv0a, pConv0b, pConvB, pConvBR, pDeconv, pFlow] = await Promise.all([
    pipeAsync(textureInput ? wgslPrepFullTex(w, h) : wgslPrepFull(w, h)),
    pipeAsync(tfact ? wgslPrepQuarterTex6(w, h, useF16)
      : (textureInput ? wgslPrepQuarterTex(w, h, useF16) : wgslPrepQuarter(w, h, useF16))),
    pipeAsync(wgslConv(CI0, C1, QW, QH, W8, H8, 2, false, useF16)),
    pipeAsync(wgslConv(C1, C2, W8, H8, W16, H16, 2, false, useF16)),
    pipeAsync(useF16
      ? (sgTuned ? wgslConvRBSg : wgslConvRB)(C2, C2, W16, H16, W16, H16, false, tune)
      : wgslConv(C2, C2, W16, H16, W16, H16, 1, false, false)),
    pipeAsync(useF16
      ? (sgTuned ? wgslConvRBSg : wgslConvRB)(C2, C2, W16, H16, W16, H16, true, tune)
      : wgslConv(C2, C2, W16, H16, W16, H16, 1, true, false)),
    pipeAsync(wgslDeconv(C2, 5, W16, H16, W8, H8, useF16)),
    pipeAsync(textureOutput ? wgslFlowOutTex(w, h, staticGuard, refi) : wgslFlowOut(w, h)),
  ]);
  // texture-output mode: flow bind groups are per output texture (small ring - cache them)
  const flowBgCache = new Map();
  function flowBgFor(tex) {
    if (!flowBgCache.has(tex)) {
      // evict BEFORE inserting - clearing after would wipe the fresh entry and
      // hand setBindGroup an undefined (latent until a caller rings >24 textures)
      if (flowBgCache.size > 24) flowBgCache.clear();
      const entries = [
        { binding: 0, resource: { buffer: tmp8 } },
        { binding: 1, resource: { buffer: imgs } },
        { binding: 2, resource: tex.createView() }];
      if (refi) entries.push({ binding: 3, resource: { buffer: rRes } }); // tfact2 residual
      flowBgCache.set(tex, device.createBindGroup({ layout: pFlow.getBindGroupLayout(0), entries }));
    }
    return flowBgCache.get(tex);
  }

  // buffer-input prep bind groups (unused in texture mode)
  const bgPrepFull = textureInput ? null : bg(pPrepFull, [rgba0, rgba1, imgs]);
  const bgPrepQ = textureInput ? null : bg(pPrepQ, [rgba0, rgba1, xq, tbuf]);
  const bgPrepQt = textureInput ? null : tbufs.map(tb => bg(pPrepQ, [rgba0, rgba1, xq, tb]));
  // texture-mode prep bind groups are built per texture pair and cached (ping-pong -> few combos).
  // Keyed by texture IDENTITY, not label: callers recreate pools reusing the same
  // labels (resolution change), and a label-keyed cache would keep serving bind
  // groups of destroyed textures - every submit fails async validation and the
  // mids silently replay stale content.
  const texBgCache = new Map();
  function texPrepBgs(texA, texB) {
    const key = texBgId(texA) + '|' + texBgId(texB);
    if (!texBgCache.has(key)) {
      // evict BEFORE inserting - clearing after would wipe the fresh entry too
      if (texBgCache.size > 12) texBgCache.clear(); // texture set changed wholesale
      const va = texA.createView(), vb = texB.createView();
      texBgCache.set(key, {
        full: device.createBindGroup({ layout: pPrepFull.getBindGroupLayout(0), entries: [
          { binding: 0, resource: va }, { binding: 1, resource: vb },
          { binding: 2, resource: sampler }, { binding: 3, resource: { buffer: imgs } }] }),
        q: tfact
          ? [device.createBindGroup({ layout: pPrepQ.getBindGroupLayout(0), entries: [
            { binding: 0, resource: va }, { binding: 1, resource: vb },
            { binding: 2, resource: sampler }, { binding: 3, resource: { buffer: xq } }] })]
          : tbufs.map(tb => device.createBindGroup({ layout: pPrepQ.getBindGroupLayout(0), entries: [
            { binding: 0, resource: va }, { binding: 1, resource: vb },
            { binding: 2, resource: sampler }, { binding: 3, resource: { buffer: xq } },
            { binding: 4, resource: { buffer: tb } }] })),
      });
    }
    return texBgCache.get(key);
  }

  // tfact-only state: trunk feature buffer + FiLM params (tiny t-MLP runs in JS)
  const hbuf = tfact ? bufBytes(C2 * H16 * W16 * (useF16 ? 2 : 4)) : null;
  const filmBuf = tfact ? buf(2 * C2) : null;
  const pFilm = tfact ? pipe(wgslFilm(C2, H16 * W16, useF16)) : null;
  let bgFilm = null, filmW = null;
  if (tfact) {
    bgFilm = bg(pFilm, [hbuf, filmBuf, f16a]);
    const f32 = (name) => {
      const m = weightsManifest[name];
      return new Float32Array(weightsBin, m.offset * 4, m.shape.reduce((a, b) => a * b, 1));
    };
    filmW = { w0: f32('film.0.weight'), b0: f32('film.0.bias'),
              w2: f32('film.2.weight'), b2: f32('film.2.bias') };
  }
  // tfact2 refine head: occlusion repair at QUARTER res; the residual is folded
  // into the flowout pass (withRes)
  const RW4 = w / 4, RH4 = h / 4;
  let pRPrep = null, pRC0 = null, pRC1 = null, pROut = null;
  let bgRPrep = null, bgRC0 = null, bgRC1 = null, bgRC2 = null, bgROut = null;
  let RC = 0, rRes = null;
  if (refi) {
    RC = man['refine.c0.weight'].shape[0];
    // the refine convs are dispatched with z = RC/4 below, and wgslConv only
    // packs 4-wide when CO % 4 == 0 - a non-/4 head would silently mis-dispatch
    if (RC % 4 !== 0) throw new Error('rt: refine channels must be a multiple of 4, got ' + RC);
    const rIn = abuf(11 * RH4 * RW4);
    const rA2 = abuf(RC * RH4 * RW4);
    const rB2 = abuf(RC * RH4 * RW4);
    rRes = buf(3 * RH4 * RW4);
    [pRPrep, pRC0, pRC1, pROut] = await Promise.all([
      pipeAsync(wgslRefinePrep(w, h, useF16)),
      pipeAsync(wgslConv(11, RC, RW4, RH4, RW4, RH4, 1, false, useF16)),
      pipeAsync(wgslConv(RC, RC, RW4, RH4, RW4, RH4, 1, false, useF16)),
      pipeAsync(wgslRefineOut(RC, RW4, RH4, useF16)),
    ]);
    bgRPrep = bg(pRPrep, [tmp8, imgs, rIn]);
    bgRC0 = bg(pRC0, [rIn, convW('refine.c0.weight'), wbuf['refine.c0.bias'], wbuf['refine.a0.weight'], rA2]);
    bgRC1 = bg(pRC1, [rA2, convW('refine.c1.weight'), wbuf['refine.c1.bias'], wbuf['refine.a1.weight'], rB2]);
    bgRC2 = bg(pRC1, [rB2, convW('refine.c2.weight'), wbuf['refine.c2.bias'], wbuf['refine.a2.weight'], rA2]);
    bgROut = bg(pROut, [rA2, wbuf['refine.c3.weight'], wbuf['refine.c3.bias'], rRes]);
  }

  // scratch reused across calls: these run once per mid (writeBuffer reads the
  // array synchronously, so reuse is safe) - allocating per call is pure GC churn
  const tScratch = new Float32Array(1);
  const filmScratch = tfact ? { out: new Float32Array(2 * C2), h: new Float32Array(filmW.b0.length) } : null;
  // FiLM MLP on the CPU: relu(w0*t + b0) -> w2 @ h + b2, 2*C2 outputs
  function filmParams(t) {
    const HN = filmW.b0.length, { out, h } = filmScratch;
    for (let j = 0; j < HN; j++) h[j] = Math.max(0, filmW.w0[j] * t + filmW.b0[j]);
    for (let k = 0; k < 2 * C2; k++) {
      let s = filmW.b2[k];
      const row = k * HN;
      for (let j = 0; j < HN; j++) s += filmW.w2[row + j] * h[j];
      out[k] = s;
    }
    return out;
  }
  const bgConv0a = bg(pConv0a, [xq, convW('block0.conv0.0.0.weight'), wbuf['block0.conv0.0.0.bias'], wbuf['block0.conv0.0.1.weight'], f8]);
  const bgConv0b = bg(pConv0b, [f8, convW('block0.conv0.1.0.weight'), wbuf['block0.conv0.1.0.bias'], wbuf['block0.conv0.1.1.weight'], f16a]);
  // convblock ping-pong: a->b, b->a, ... 8th conv adds the residual (f16r = copy of f16a)
  const bgB = [];
  let src = f16a, dst = f16b;
  for (let i = 0; i < 8; i++) {
    const wn = `block0.convblock.${i}.0.weight`, bn = `block0.convblock.${i}.0.bias`, an = `block0.convblock.${i}.1.weight`;
    if (i < 7) {
      bgB.push({ p: pConvB, g: bg(pConvB, [src, convW(wn), wbuf[bn], wbuf[an], dst]) });
    } else {
      bgB.push({ p: pConvBR, g: bg(pConvBR, [src, convW(wn), wbuf[bn], wbuf[an], dst, f16r]) });
    }
    [src, dst] = [dst, src];
  }
  const f16out = src; // after 8 convs
  const bgDeconv = bg(pDeconv, [f16out, wbuf['block0.lastconv.weight'], wbuf['block0.lastconv.bias'], tmp8]);
  const bgFlow = textureOutput ? null : bg(pFlow, [tmp8, imgs, outp]);

  const gx = (n) => Math.ceil(n / WG);
  // register-blocked convblock kernel covers 16x16 output per workgroup
  const cbX = useF16 ? Math.ceil(W16 / 16) : gx(W16);
  const cbY = useF16 ? Math.ceil(H16 / 16) : gx(H16);
  // integral by construction: C2 % 4 is checked and a tune with C2 % coc != 0 is dropped
  const cbZ = useF16 ? C2 / ((tune && tune.coc) || 4) : C2 / 4;

  // per-stage GPU times via timestamp queries (needs 'timestamp-query' on the device).
  // Buffer mode only: the stage list is the buffer-in/buffer-out graph.
  async function profile(rgbaA, rgbaB) {
    if (!device.features.has('timestamp-query')) return 'no timestamp-query feature';
    // in texture modes rgba0/rgba1, bgPrepFull/bgPrepQ and/or bgFlow are null and
    // the run below would die on them (tfact implies texture modes)
    if (textureInput || textureOutput) return 'profile: buffer mode only';
    const stages = [
      ['prepFull', pPrepFull, bgPrepFull, [gx(w), gx(h), 1]],
      ['prepQ', pPrepQ, bgPrepQ, [gx(QW), gx(QH), 1]],
      ['conv0a', pConv0a, bgConv0a, [gx(W8), gx(H8), Z0A]],
      ['conv0b', pConv0b, bgConv0b, [gx(W16), gx(H16), C2 / 4]],
      ...bgB.map(({ p, g }, i) => [`convB${i}`, p, g, [cbX, cbY, cbZ]]),
      ['deconv', pDeconv, bgDeconv, [gx(W8), gx(H8), 5]],
      ['flow', pFlow, bgFlow, [gx(w), gx(h), 1]],
    ];
    const qs = device.createQuerySet({ type: 'timestamp', count: stages.length * 2 });
    const qbuf = device.createBuffer({ size: stages.length * 16, usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC });
    const qread = device.createBuffer({ size: stages.length * 16, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
    try {
      tScratch[0] = 0.5;
      device.queue.writeBuffer(tbuf, 0, tScratch);
      device.queue.writeBuffer(rgba0, 0, rgbaA.buffer, rgbaA.byteOffset, w * h * 4);
      device.queue.writeBuffer(rgba1, 0, rgbaB.buffer, rgbaB.byteOffset, w * h * 4);
      const enc = device.createCommandEncoder();
      stages.forEach(([name, p, g, d], i) => {
        if (name === 'convB0') enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
        const pass = enc.beginComputePass({ timestampWrites: {
          querySet: qs, beginningOfPassWriteIndex: i * 2, endOfPassWriteIndex: i * 2 + 1 } });
        pass.setPipeline(p); pass.setBindGroup(0, g);
        pass.dispatchWorkgroups(d[0], d[1], d[2]);
        pass.end();
      });
      enc.resolveQuerySet(qs, 0, stages.length * 2, qbuf, 0);
      enc.copyBufferToBuffer(qbuf, 0, qread, 0, stages.length * 16);
      device.queue.submit([enc.finish()]);
      await qread.mapAsync(GPUMapMode.READ);
      const ts = new BigUint64Array(qread.getMappedRange().slice(0));
      qread.unmap();
      return stages.map(([name], i) =>
        `${name}: ${(Number(ts[i * 2 + 1] - ts[i * 2]) / 1e6).toFixed(2)}ms`).join(' · ');
    } finally {
      // a query set + two buffers per call - without this every profile() call
      // leaks them until GC gets around to it
      qs.destroy(); qbuf.destroy(); qread.destroy();
    }
  }

  // single pair, single t, buffer in / buffer out; resolves to the RGBA8 mid
  async function run(rgbaA, rgbaB, t = 0.5) {
    if (tfact) throw new Error('rt: tfact weights have no buffer-mode run()');
    // rgba0/rgba1/bgPrep* (texture input) or bgFlow/outp/staging (texture output)
    // are null in texture modes
    if (textureInput || textureOutput) throw new Error('rt: run() needs buffer input and output mode');
    tScratch[0] = t;
    device.queue.writeBuffer(tbuf, 0, tScratch);
    device.queue.writeBuffer(rgba0, 0, rgbaA.buffer, rgbaA.byteOffset, w * h * 4);
    device.queue.writeBuffer(rgba1, 0, rgbaB.buffer, rgbaB.byteOffset, w * h * 4);
    const enc = device.createCommandEncoder();
    const pass = enc.beginComputePass();
    pass.setPipeline(pPrepFull); pass.setBindGroup(0, bgPrepFull); pass.dispatchWorkgroups(gx(w), gx(h));
    pass.setPipeline(pPrepQ); pass.setBindGroup(0, bgPrepQ); pass.dispatchWorkgroups(gx(QW), gx(QH));
    pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
    pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
    pass.end();
    // residual copy AFTER conv0b (f16r = f16a snapshot)
    enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
    const pass2 = enc.beginComputePass();
    for (const { p, g } of bgB) {
      pass2.setPipeline(p); pass2.setBindGroup(0, g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
    }
    pass2.setPipeline(pDeconv); pass2.setBindGroup(0, bgDeconv); pass2.dispatchWorkgroups(gx(W8), gx(H8), 5);
    pass2.setPipeline(pFlow); pass2.setBindGroup(0, bgFlow); pass2.dispatchWorkgroups(gx(w), gx(h));
    pass2.end();
    enc.copyBufferToBuffer(outp, 0, staging, 0, w * h * 4);
    device.queue.submit([enc.finish()]);
    await staging.mapAsync(GPUMapMode.READ);
    const out = new Uint8Array(staging.getMappedRange().slice(0));
    staging.unmap();
    return out;
  }

  // batched: upload/bind the pair once, produce mids for every t in ONE submit.
  // Buffer mode: a/b are RGBA arrays. Texture mode: a/b are GPUTextures (zero CPU pixels).
  // With textureOutput, outTexs[i] receives mid i and NOTHING is read back - returns null.
  async function runMulti(a, b, ts, outTexs) {
    // validate up front: a short outTexs used to submit the trunk and then die
    // inside flowBgFor(undefined) halfway through the batch
    if (textureOutput && (!outTexs || outTexs.length < ts.length)) {
      throw new Error('rt: runMulti needs one output texture per timestep');
    }
    if (tfact) { // factored graph: trunk once, head per t
      prepPair(a, b);
      for (let i = 0; i < ts.length; i++) runT(ts[i], outTexs[i]);
      return null;
    }
    if (ts.length > MAXT) throw new Error('too many timesteps');
    for (let i = 0; i < ts.length; i++) {
      device.queue.writeBuffer(tbufs[i], 0, new Float32Array([ts[i]]));
    }
    let tbg = null;
    if (textureInput) {
      tbg = texPrepBgs(a, b);
    } else {
      device.queue.writeBuffer(rgba0, 0, a.buffer, a.byteOffset, w * h * 4);
      device.queue.writeBuffer(rgba1, 0, b.buffer, b.byteOffset, w * h * 4);
    }
    const enc = device.createCommandEncoder();
    {
      const pass = enc.beginComputePass();
      pass.setPipeline(pPrepFull); pass.setBindGroup(0, tbg ? tbg.full : bgPrepFull); pass.dispatchWorkgroups(gx(w), gx(h));
      pass.end();
    }
    for (let i = 0; i < ts.length; i++) {
      const pass = enc.beginComputePass();
      pass.setPipeline(pPrepQ); pass.setBindGroup(0, tbg ? tbg.q[i] : bgPrepQt[i]); pass.dispatchWorkgroups(gx(QW), gx(QH));
      pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
      pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
      pass.end();
      enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
      const pass2 = enc.beginComputePass();
      for (const { p, g } of bgB) {
        pass2.setPipeline(p); pass2.setBindGroup(0, g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
      }
      pass2.setPipeline(pDeconv); pass2.setBindGroup(0, bgDeconv); pass2.dispatchWorkgroups(gx(W8), gx(H8), 5);
      pass2.setPipeline(pFlow);
      pass2.setBindGroup(0, textureOutput ? flowBgFor(outTexs[i]) : bgFlow);
      pass2.dispatchWorkgroups(gx(w), gx(h));
      pass2.end();
      if (!textureOutput) enc.copyBufferToBuffer(outp, 0, stagings[i], 0, w * h * 4);
    }
    device.queue.submit([enc.finish()]);
    if (textureOutput) return null; // mids live in outTexs, nothing crosses the bus
    // map all stagings concurrently - sequential awaits cost ~1ms each
    await Promise.all(stagings.slice(0, ts.length).map(s => s.mapAsync(GPUMapMode.READ)));
    const outs = [];
    for (let i = 0; i < ts.length; i++) {
      outs.push(new Uint8Array(stagings[i].getMappedRange().slice(0)));
      stagings[i].unmap();
    }
    return outs;
  }

  // ---- lazy per-mid API (texture in/out mode) ----
  // The queue is FIFO: a mid's present blit executes after EVERYTHING submitted
  // before it. Batching all mids upfront therefore makes the FIRST mid wait for
  // the WHOLE batch on the GPU. prepPair + runT let the caller submit each mid
  // just-in-time so present blits interleave with computes - the required
  // presentation delay shrinks from ~2x batch time to ~one mid time.
  let curPrep = null;
  function prepPair(a, b) {
    if (!textureInput) throw new Error('prepPair: texture-input mode only');
    curPrep = texPrepBgs(a, b);
    const enc = device.createCommandEncoder();
    if (tfact) {
      // the WHOLE t-free trunk runs here, once per pair
      const pass = enc.beginComputePass();
      pass.setPipeline(pPrepFull); pass.setBindGroup(0, curPrep.full); pass.dispatchWorkgroups(gx(w), gx(h));
      pass.setPipeline(pPrepQ); pass.setBindGroup(0, curPrep.q[0]); pass.dispatchWorkgroups(gx(QW), gx(QH));
      pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
      pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
      pass.end();
      enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes); // feat0 residual for the head
      const pass2 = enc.beginComputePass();
      for (let i = 0; i < 6; i++) {
        pass2.setPipeline(bgB[i].p); pass2.setBindGroup(0, bgB[i].g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
      }
      pass2.end();
      enc.copyBufferToBuffer(f16a, 0, hbuf, 0, actBytes); // trunk features, reused per t
    } else {
      const pass = enc.beginComputePass();
      pass.setPipeline(pPrepFull); pass.setBindGroup(0, curPrep.full); pass.dispatchWorkgroups(gx(w), gx(h));
      pass.end();
    }
    device.queue.submit([enc.finish()]);
  }
  function runT(t, outTex) {
    if (!curPrep) throw new Error('runT before prepPair');
    if (!textureOutput) throw new Error('runT: texture-output mode only');
    const enc = device.createCommandEncoder();
    if (tfact) {
      // per-mid: FiLM(t) + convblocks 6,7 (+feat0 residual) + deconv + flow
      device.queue.writeBuffer(filmBuf, 0, filmParams(t));
      const pass = enc.beginComputePass();
      pass.setPipeline(pFilm); pass.setBindGroup(0, bgFilm);
      pass.dispatchWorkgroups(Math.ceil((C2 * H16 * W16) / 256));
      pass.setPipeline(bgB[6].p); pass.setBindGroup(0, bgB[6].g); pass.dispatchWorkgroups(cbX, cbY, cbZ);
      pass.setPipeline(bgB[7].p); pass.setBindGroup(0, bgB[7].g); pass.dispatchWorkgroups(cbX, cbY, cbZ);
      pass.setPipeline(pDeconv); pass.setBindGroup(0, bgDeconv); pass.dispatchWorkgroups(gx(W8), gx(H8), 5);
      if (refi) { // quarter-res refine chain; the flowout below folds the residual in
        pass.setPipeline(pRPrep); pass.setBindGroup(0, bgRPrep); pass.dispatchWorkgroups(gx(RW4), gx(RH4));
        pass.setPipeline(pRC0); pass.setBindGroup(0, bgRC0); pass.dispatchWorkgroups(gx(RW4), gx(RH4), RC / 4);
        pass.setPipeline(pRC1); pass.setBindGroup(0, bgRC1); pass.dispatchWorkgroups(gx(RW4), gx(RH4), RC / 4);
        pass.setPipeline(pRC1); pass.setBindGroup(0, bgRC2); pass.dispatchWorkgroups(gx(RW4), gx(RH4), RC / 4);
        pass.setPipeline(pROut); pass.setBindGroup(0, bgROut); pass.dispatchWorkgroups(gx(RW4), gx(RH4));
      }
      pass.setPipeline(pFlow); pass.setBindGroup(0, flowBgFor(outTex)); pass.dispatchWorkgroups(gx(w), gx(h));
      pass.end();
      device.queue.submit([enc.finish()]);
      return;
    }
    // single tbuf is safe: writeBuffer and submits are queue-ordered
    tScratch[0] = t;
    device.queue.writeBuffer(tbufs[0], 0, tScratch);
    const pass = enc.beginComputePass();
    pass.setPipeline(pPrepQ); pass.setBindGroup(0, curPrep.q[0]); pass.dispatchWorkgroups(gx(QW), gx(QH));
    pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
    pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
    pass.end();
    enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
    const pass2 = enc.beginComputePass();
    for (const { p, g } of bgB) {
      pass2.setPipeline(p); pass2.setBindGroup(0, g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
    }
    pass2.setPipeline(pDeconv); pass2.setBindGroup(0, bgDeconv); pass2.dispatchWorkgroups(gx(W8), gx(H8), 5);
    pass2.setPipeline(pFlow); pass2.setBindGroup(0, flowBgFor(outTex)); pass2.dispatchWorkgroups(gx(w), gx(h));
    pass2.end();
    device.queue.submit([enc.finish()]);
  }

  return { run, runMulti, prepPair, runT, profile, w, h };
}
1209
+
1210
+
1211
// ---- one-shot conv autotune: bench wgslConvRB variants on this device ----
// Returns the fastest {coc, slab, sg?, ms}. Returns null when no candidate both
// divides `co` and fits the 16384-byte budget checked below (presumably the
// kernel's shared-memory footprint - confirm against wgslConvRB), or when no
// candidate validates on this device.
// Costs ~200-400ms of GPU time; call it once per (adapter, model width,
// resolution) and persist the answer. Relative ranking is what matters, so
// light background load is tolerable.
export async function tuneConvRB(device, { ci, co, w16, h16 }) {
  const base = [{ coc: 4, slab: 20 }, { coc: 8, slab: 20 }, { coc: 8, slab: 12 }, { coc: 4, slab: 12 }]
    .filter(v => co % v.coc === 0 && (v.slab * 324 + v.coc * v.slab * 9) * 2 <= 16384);
  // subgroup variants (weights via subgroupBroadcastFirst, no shared staging):
  // measured +20% at the 360p grid and -19% at 720p/coc8 on a 4060 Ti - exactly
  // why they go through the tuner instead of being hardcoded
  const variants = device.features.has('subgroups')
    ? [...base, ...base.map(v => ({ ...v, sg: true }))]
    : base;
  // scratch buffers for the benchmark; contents are irrelevant (zero-initialised)
  const buf = (bytes) => device.createBuffer({ size: Math.ceil(bytes / 4) * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST });
  const src = buf(ci * w16 * h16 * 2), dst = buf(co * w16 * h16 * 2);
  const wgt = buf(co * ci * 9 * 2), bias = buf(co * 4), alpha = buf(co * 4);
  let best = null;
  try {
    for (const v of variants) {
      const gen = v.sg ? wgslConvRBSg : wgslConvRB;
      let run = null;
      // A variant the driver rejects (shader/pipeline error, unsupported op,
      // exceeded limit) yields an invalid pipeline. Its submits execute
      // NOTHING, so timed naively it would "win" with a near-zero ms and get
      // persisted. Scope its errors and skip it instead.
      device.pushErrorScope('internal');
      device.pushErrorScope('validation');
      let failed = null;
      try {
        const p = device.createComputePipeline({ layout: 'auto', compute: {
          module: device.createShaderModule({ code: gen(ci, co, w16, h16, w16, h16, false, v) }),
          entryPoint: 'main' } });
        const bg = device.createBindGroup({ layout: p.getBindGroupLayout(0), entries: [
          { binding: 0, resource: { buffer: src } }, { binding: 1, resource: { buffer: wgt } },
          { binding: 2, resource: { buffer: bias } }, { binding: 3, resource: { buffer: alpha } },
          { binding: 4, resource: { buffer: dst } }] });
        run = (k) => {
          const enc = device.createCommandEncoder();
          const pass = enc.beginComputePass();
          pass.setPipeline(p); pass.setBindGroup(0, bg);
          for (let i = 0; i < k; i++) pass.dispatchWorkgroups(Math.ceil(w16 / 16), Math.ceil(h16 / 16), co / v.coc);
          pass.end();
          device.queue.submit([enc.finish()]);
        };
        run(3); await device.queue.onSubmittedWorkDone(); // warm (incl pipeline compile)
      } finally {
        // pop in LIFO order; always balance the scopes, even if something threw
        const validationErr = await device.popErrorScope();
        const internalErr = await device.popErrorScope();
        failed = validationErr || internalErr;
      }
      if (failed) continue;
      const t0 = performance.now();
      run(30); await device.queue.onSubmittedWorkDone();
      const ms = (performance.now() - t0) / 30;
      if (!best || ms < best.ms) best = { ...v, ms };
    }
  } finally {
    // release the scratch buffers on every exit path, including a thrown error
    [src, dst, wgt, bias, alpha].forEach(b => b.destroy());
  }
  return best;
}