@framecast/rt 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +18 -0
- package/README.md +78 -0
- package/index.d.ts +57 -0
- package/package.json +43 -0
- package/rt.js +1254 -0
- package/sr.js +205 -0
package/rt.js
ADDED
|
@@ -0,0 +1,1254 @@
|
|
|
1
|
+
// Framecast custom WebGPU runtime - hand-rolled forward of the 1-block student
|
|
2
|
+
// (block0 of IFNet_m, scale=4, timestep=0.5). The whole frame is ONE command buffer
|
|
3
|
+
// of ~13 dispatches; no per-op JS, no runtime glue. Weights: assets/rt_1blk.{bin,json}
|
|
4
|
+
// (tools/export_rt_weights.py).
|
|
5
|
+
//
|
|
6
|
+
// Graph replicated from model/IFNet_m.py:
|
|
7
|
+
// x = cat(img0,img1,t) BGR /255 [7,H,W]
|
|
8
|
+
// xq = bilinear(x, 1/4, align_corners=false) [7,H/4,W/4]
|
|
9
|
+
// f8 = prelu(conv3x3s2(xq)) [120,H/8,W/8]
|
|
10
|
+
// f16 = prelu(conv3x3s2(f8)) [240,H/16,W/16]
|
|
11
|
+
// f16 = convblock(f16) + f16 (8x conv3x3s1+prelu, residual)
|
|
12
|
+
// tmp8 = deconv4x4s2(f16) [5,H/8,W/8]
|
|
13
|
+
// tmpF = bilinear(tmp8, x8, align_corners=false); flow = tmpF[0:4]*8; mask = tmpF[4]
|
|
14
|
+
// m = sigmoid(mask); mid = m*warp(img0,flow01) + (1-m)*warp(img1,flow23)
|
|
15
|
+
// Requires H,W divisible by 16.
|
|
16
|
+
|
|
17
|
+
// Side length of the square WG x WG workgroups that the 2-D kernels below
// interpolate into their @workgroup_size.
const WG = 8;

// Texture identity stamps, used as keys by texture-keyed bind-group caches.
// The counter deliberately lives at MODULE scope rather than inside createRT:
// a runtime can be torn down and rebuilt (resolution / model / tuning switch)
// while the caller's textures - and the stamps already written onto them -
// outlive it. A counter restarting per instance would re-issue ids that live
// textures already carry, and after the next pool reallocation the cache
// could hand back bind groups that still reference destroyed textures.
let texBgSeq = 0;

// Returns the identity stamp of texture `t`. The first time a texture is seen
// it receives the next sequence number (starting at 1), stored on the object
// itself as `__rtBgId`; later calls return that same number.
const texBgId = (t) => {
  if (!t.__rtBgId) {
    t.__rtBgId = ++texBgSeq;
  }
  return t.__rtBgId;
};
|
|
26
|
+
|
|
27
|
+
// NOTE: with layout:'auto' WebGPU strips unused bindings from the layout, so each
|
|
28
|
+
// entry point gets its own shader with exactly the bindings it touches.
|
|
29
|
+
// Full-resolution prep kernel (buffer input). Each invocation handles one
// pixel (x, y) of the W x H grid. It unpacks the u32 at index y*W + x of both
// rgba0 and rgba1 with unpack4x8unorm (components in 0..1) and scatters the
// colors into `imgs` as six f32 planes in the order b0 g0 r0 b1 g1 r1, i.e. BGR
// of frame 0 followed by BGR of frame 1.
// NOTE(review): assumes rgba0/rgba1 each hold W*H packed RGBA8 pixels,
// row-major - the generator does not check buffer sizes.
// W, H: grid size, interpolated into the shader as integer literals.
// Returns the WGSL source (entry point `main`, WG x WG workgroups).
function wgslPrepFull(W, H) {
  return /* wgsl */`
@group(0) @binding(0) var<storage, read> rgba0: array<u32>;
@group(0) @binding(1) var<storage, read> rgba1: array<u32>;
@group(0) @binding(2) var<storage, read_write> imgs: array<f32>; // [6,${H},${W}] b0 g0 r0 b1 g1 r1

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${W} || y >= ${H}) { return; }
let o = y * ${W} + x;
let c0 = unpack4x8unorm(rgba0[o]).xyz;
let c1 = unpack4x8unorm(rgba1[o]).xyz;
let P = ${H * W};
imgs[o] = c0.z; imgs[P + o] = c0.y; imgs[2 * P + o] = c0.x;
imgs[3 * P + o] = c1.z; imgs[4 * P + o] = c1.y; imgs[5 * P + o] = c1.x;
}`;
}
|
|
47
|
+
|
|
48
|
+
// Quarter-resolution prep kernel (buffer input). Builds the 7-channel network
// input xq of size (W/4) x (H/4) directly from the two packed RGBA8 frames.
// Each invocation bilinearly samples both frames at the align_corners=false
// source position (dst + 0.5) * 4 - 0.5, with edge taps clamped to the frame.
// It writes the planes b0 g0 r0 b1 g1 r1 and fills plane 6 with the timestep
// tstep[0].
// xq is stored as f16 when `f16` is set (the shader then starts with
// `enable f16;`), otherwise as f32. The sampling math is always f32.
//
// W, H: full-resolution size; both must be multiples of 4.
// f16:  emit f16 storage for xq.
// Returns the WGSL source (entry point `main`).
// Throws RangeError when W or H is not a multiple of 4. The quarter-grid size
// is interpolated into the shader as a literal, and a fractional value
// (e.g. `250.5`) would yield WGSL that only fails later at pipeline creation.
function wgslPrepQuarter(W, H, f16) {
  if (!Number.isInteger(W / 4) || !Number.isInteger(H / 4)) {
    throw new RangeError(`wgslPrepQuarter: W and H must be multiples of 4 (got ${W}x${H})`);
  }
  const QW = W / 4;
  const QH = H / 4;
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> rgba0: array<u32>;
@group(0) @binding(1) var<storage, read> rgba1: array<u32>;
@group(0) @binding(2) var<storage, read_write> xq: array<${T}>; // [7,${QH},${QW}]
@group(0) @binding(3) var<storage, read> tstep: array<f32>; // [1] timestep

fn px(buf: i32, x: i32, y: i32) -> vec3<f32> {
let v = select(rgba1[y * ${W} + x], rgba0[y * ${W} + x], buf == 0);
return unpack4x8unorm(v).xyz; // r,g,b in 0..1
}

fn sampleQ(buf: i32, sx: f32, sy: f32) -> vec3<f32> {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
let xa = clamp(x0, 0, ${W - 1}); let xb = clamp(x0 + 1, 0, ${W - 1});
let ya = clamp(y0, 0, ${H - 1}); let yb = clamp(y0 + 1, 0, ${H - 1});
let v00 = px(buf, xa, ya); let v10 = px(buf, xb, ya);
let v01 = px(buf, xa, yb); let v11 = px(buf, xb, yb);
return mix(mix(v00, v10, fx), mix(v01, v11, fx), fy);
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${QW} || y >= ${QH}) { return; }
// align_corners=false: src = (dst+0.5)*4 - 0.5
let sx = (f32(x) + 0.5) * 4.0 - 0.5;
let sy = (f32(y) + 0.5) * 4.0 - 0.5;
let c0 = sampleQ(0, sx, sy);
let c1 = sampleQ(1, sx, sy);
let o = y * ${QW} + x;
let P = ${QH * QW};
xq[o] = ${T}(c0.z); xq[P + o] = ${T}(c0.y); xq[2 * P + o] = ${T}(c0.x);
xq[3 * P + o] = ${T}(c1.z); xq[4 * P + o] = ${T}(c1.y); xq[5 * P + o] = ${T}(c1.x);
xq[6 * P + o] = ${T}(tstep[0]); // timestep
}`;
}
|
|
89
|
+
|
|
90
|
+
// conv3x3 (zero padding 1, stride 1 or 2) + bias + PReLU, with an optional
// residual added AFTER the activation:
//   dst = prelu(conv(src) + bias) [+ res]
// v3 design:
//   - Each invocation computes COC output channels for one pixel, so every
//     source read is amortized over COC FMAs.
//   - The [COC, slab, 9] weight slice is staged through workgroup memory,
//     one ci-slab at a time.
//   - Activations and weights may be stored as f16 while accumulation stays
//     f32, halving traffic on the bandwidth-bound conv stack.
//   - The stride-1 path also stages TS x TS input tiles (TS = WG + 2: the
//     WG x WG outputs plus a 1-pixel halo) for the whole slab, so each staged
//     value is reused by up to 9 taps.
// The tile geometry is derived from WG, so the staging stays consistent if WG
// changes. With WG == 8 the tiles are 10 x 10, as before.
// NOTE(review): SLAB = 30 was sized so that f32 wsh + tile fit 16 KiB of
// workgroup storage at WG == 8. A larger WG grows the tile and may need a
// smaller SLAB.
// NOTE(review): cb = gid.z * COC with a z workgroup size of 1, so the caller
// presumably dispatches CO / COC workgroups along z - TODO confirm.
//
// CI, CO:   input / output channel counts.
// IW, IH:   input size. OW, OH: output size.
// stride:   1 selects the shared-tile path; any other value reads src directly.
// residual: adds binding 5 `res`, shaped like dst.
// f16:      f16 storage for src / wgt / dst / res.
// Returns the WGSL source (entry point `main`).
function wgslConv(CI, CO, IW, IH, OW, OH, stride, residual, f16) {
  const COC = CO % 4 === 0 ? 4 : 1; // output channels per thread
  const SLAB = Math.min(CI, 30); // input channels per staging round
  const slabFloats = COC * SLAB * 9;
  const TS = WG + 2; // staged tile edge: WG outputs + 1-pixel halo on each side
  const TA = TS * TS; // staged elements per input channel
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>; // [${CI},${IH},${IW}]
@group(0) @binding(1) var<storage, read> wgt: array<${T}>; // [${CO},${CI},3,3]
@group(0) @binding(2) var<storage, read> bias: array<f32>; // [${CO}]
@group(0) @binding(3) var<storage, read> alpha: array<f32>; // [${CO}] prelu
@group(0) @binding(4) var<storage, read_write> dst: array<${T}>; // [${CO},${OH},${OW}]
${residual ? `@group(0) @binding(5) var<storage, read> res: array<${T}>;` : ``}

var<workgroup> wsh: array<${T}, ${slabFloats}>; // [COC, SLAB, 9] slab of weights
${stride === 1 ? `var<workgroup> tile: array<${T}, ${SLAB * TA}>; // [SLAB, ${TS}, ${TS}] input tiles (${WG}x${WG} out + halo)` : ''}

@compute @workgroup_size(${WG}, ${WG}, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_index) li: u32) {
let x = i32(gid.x); let y = i32(gid.y);
let lx = i32(lid.x); let ly = i32(lid.y);
let wx0 = i32(wid.x) * ${WG}; let wy0 = i32(wid.y) * ${WG};
let cb = i32(gid.z) * ${COC}; // first output channel of this thread's block
let inb = x < ${OW} && y < ${OH};
var acc: array<f32, ${COC}>;
for (var c = 0; c < ${COC}; c++) { acc[c] = bias[cb + c]; }

for (var s = 0; s < ${CI}; s += ${SLAB}) {
let sl = min(${SLAB}, ${CI} - s);
let n = ${COC} * sl * 9;
workgroupBarrier();
// cooperative load: weights for [cb..cb+COC) x [s..s+sl) x 9
var idx = i32(li);
while (idx < n) {
let c = idx / (sl * 9);
let r = idx % (sl * 9);
wsh[idx] = wgt[(cb + c) * ${CI * 9} + (s + r / 9) * 9 + r % 9];
idx += ${WG * WG};
}
workgroupBarrier();
${stride === 1 ? `
// stride-1 path: stage ${TS}x${TS} input tiles for the WHOLE ci-slab at once - barriers
// per slab (16/conv) instead of per channel (480/conv), values reused 9x from shared
var ti = i32(li);
let tn = sl * ${TA};
while (ti < tn) {
let ci = ti / ${TA};
let r = ti % ${TA};
let ty = wy0 + r / ${TS} - 1;
let tx = wx0 + r % ${TS} - 1;
var v = ${T}(0.0);
if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
}
tile[ti] = v;
ti += ${WG * WG};
}
workgroupBarrier();
if (inb) {
for (var ci = 0; ci < sl; ci++) {
let tb = ci * ${TA};
for (var ky = 0; ky < 3; ky++) {
for (var kx = 0; kx < 3; kx++) {
let sv = f32(tile[tb + (ly + ky) * ${TS} + lx + kx]);
let wb = ci * 9 + ky * 3 + kx;
for (var c = 0; c < ${COC}; c++) {
acc[c] += sv * f32(wsh[c * (sl * 9) + wb]);
}
}
}
}
}
}` : `
if (inb) {
for (var ci = 0; ci < sl; ci++) {
let sbase = (s + ci) * ${IH * IW};
for (var ky = 0; ky < 3; ky++) {
let iy = y * ${stride} + ky - 1;
if (iy < 0 || iy >= ${IH}) { continue; }
for (var kx = 0; kx < 3; kx++) {
let ix = x * ${stride} + kx - 1;
if (ix < 0 || ix >= ${IW}) { continue; }
let sv = f32(src[sbase + iy * ${IW} + ix]);
let wb = ci * 9 + ky * 3 + kx;
for (var c = 0; c < ${COC}; c++) {
acc[c] += sv * f32(wsh[c * (sl * 9) + wb]);
}
}
}
}
}
}`}
if (!inb) { return; }
for (var c = 0; c < ${COC}; c++) {
let co = cb + c;
let v = select(alpha[co] * acc[c], acc[c], acc[c] >= 0.0);
let o = co * ${OH * OW} + y * ${OW} + x;
dst[o] = ${T}(${residual ? `v + f32(res[o])` : `v`});
}
}`;
}
|
|
198
|
+
|
|
199
|
+
// Register-blocked conv3x3, stride 1, zero padding 1. Storage is f16 and math
// is f32:
//   dst = prelu(conv(src) + bias) [+ res]
// Each invocation computes a 2x2 pixel patch x COC output channels (4*COC
// unrolled scalar accumulators), so every shared-memory tile read feeds COC
// FMAs instead of ~1. A workgroup is 8x8 threads, covering a 16x16 output
// tile. For each ci-slab the workgroup cooperatively stages two things into
// workgroup memory: the [COC, slab, 9] weight slice, and one 18x18 input tile
// per input channel (16x16 outputs + 1-pixel halo).
//
// tune: { coc, slab }, defaulting to 4 and 20. Workgroup memory use is
// slab*324*2 + coc*slab*9*2 bytes and must fit the device's
// maxComputeWorkgroupStorageSize (16384 by default).
//
// When CO is not a multiple of COC, the last channel block extends past CO.
// Its stores are masked with `co < CO` so they cannot land outside dst's
// [CO, OH, OW] extent. The out-of-range bias/alpha reads are never stored.
// When CO % COC === 0 the emitted WGSL is unchanged.
// NOTE(review): cb = wid.z * COC, so the caller presumably dispatches
// ceil(OW/16) x ceil(OH/16) x ceil(CO/COC) workgroups - TODO confirm.
//
// CI, CO: input / output channel counts. IW, IH / OW, OH: input / output size.
// residual: adds binding 5 `res` (f16, shaped like dst), added after PReLU.
// Returns the WGSL source (entry point `main`).
export function wgslConvRB(CI, CO, IW, IH, OW, OH, residual, tune) {
  const COC = (tune && tune.coc) || 4;
  const SLAB = (tune && tune.slab) || 20;
  const slabW = COC * SLAB * 9; // f16 weights in shared
  const slabT = SLAB * 324; // 18x18 tiles
  // mask stores of a ragged last channel block; empty when CO divides evenly
  const coGuard = CO % COC === 0 ? '' : ` && co < ${CO}`;
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> src: array<f16>;
@group(0) @binding(1) var<storage, read> wgt: array<f16>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read> alpha: array<f32>;
@group(0) @binding(4) var<storage, read_write> dst: array<f16>;
${residual ? `@group(0) @binding(5) var<storage, read> res: array<f16>;` : ``}

var<workgroup> wsh: array<f16, ${slabW}>;
var<workgroup> tile: array<f16, ${slabT}>;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_index) li: u32) {
let lx = i32(lid.x); let ly = i32(lid.y);
let ox0 = i32(wid.x) * 16; let oy0 = i32(wid.y) * 16; // wg output origin
let x0 = ox0 + lx * 2; let y0 = oy0 + ly * 2; // this thread's 2x2 patch
let cb = i32(wid.z) * ${COC};
// 16 scalar accumulators (unrolled - arrays may spill out of registers in WGSL)
${Array.from({ length: COC }, (_, c) =>
` var a${c}0 = bias[cb + ${c}]; var a${c}1 = a${c}0; var a${c}2 = a${c}0; var a${c}3 = a${c}0;`).join('\n')}

for (var s = 0; s < ${CI}; s += ${SLAB}) {
let sl = min(${SLAB}, ${CI} - s);
workgroupBarrier();
var idx = i32(li);
let wn = ${COC} * sl * 9;
while (idx < wn) {
let c = idx / (sl * 9);
let r = idx % (sl * 9);
wsh[idx] = wgt[(cb + c) * ${CI * 9} + (s + r / 9) * 9 + r % 9];
idx += 64;
}
var ti = i32(li);
let tn = sl * 324;
while (ti < tn) {
let ci = ti / 324;
let r = ti % 324;
let ty = oy0 + r / 18 - 1;
let tx = ox0 + r % 18 - 1;
var v = f16(0.0);
if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
}
tile[ti] = v;
ti += 64;
}
workgroupBarrier();
for (var ci = 0; ci < sl; ci++) {
let tb = ci * 324 + (ly * 2) * 18 + lx * 2; // top-left of this thread's 4x4 window
for (var ky = 0; ky < 3; ky++) {
let rb = tb + ky * 18;
for (var kx = 0; kx < 3; kx++) {
let t00 = f32(tile[rb + kx]);
let t01 = f32(tile[rb + kx + 1]);
let t10 = f32(tile[rb + kx + 18]);
let t11 = f32(tile[rb + kx + 19]);
let wb = ci * 9 + ky * 3 + kx;
${Array.from({ length: COC }, (_, c) => ` {
let wv = f32(wsh[${c} * (sl * 9) + wb]);
a${c}0 += t00 * wv; a${c}1 += t01 * wv; a${c}2 += t10 * wv; a${c}3 += t11 * wv;
}`).join('\n')}
}
}
}
}
${Array.from({ length: COC }, (_, c) => ` {
let co = cb + ${c};
let al = alpha[co];
${[0, 1, 2, 3].map(p => ` {
let x = x0 + ${p & 1};
let y = y0 + ${p >> 1};
if (x < ${OW} && y < ${OH}${coGuard}) {
let a = a${c}${p};
let v = select(al * a, a, a >= 0.0);
let o = co * ${OH * OW} + y * ${OW} + x;
dst[o] = f16(${residual ? `v + f32(res[o])` : `v`});
}
}`).join('\n')}
}`).join('\n')}
}`;
}
|
|
291
|
+
|
|
292
|
+
// exp/subgroups: the same register-blocked conv3x3 (stride 1, zero padding 1,
// f16 storage, f32 math, 2x2 pixels x COC channels per thread, 8x8 threads =
// 16x16 output tile) as the shared-weight variant. The difference is that the
// weights are NOT staged through shared memory. The weight index depends only
// on values that are uniform across the workgroup (cb, s, ci, ky, kx), so every
// lane reads the same global address, and subgroupBroadcastFirst makes the
// value explicitly subgroup-uniform. This drops the cooperative weight-staging
// loop; only the 18x18 input tiles remain behind the barriers. The broadcast
// operates on f32 values, so the shader enables `f16` and `subgroups`.
//
// tune: { coc, slab }, defaulting to 4 and 20. Workgroup memory use is
// slab*324*2 bytes (input tiles only).
//
// When CO is not a multiple of COC, the last channel block extends past CO.
// Its stores are masked with `co < CO` so they cannot land outside dst's
// [CO, OH, OW] extent. The out-of-range bias/alpha/wgt reads are never stored.
// When CO % COC === 0 the emitted WGSL is unchanged.
// NOTE(review): cb = wid.z * COC, so the caller presumably dispatches
// ceil(OW/16) x ceil(OH/16) x ceil(CO/COC) workgroups - TODO confirm.
//
// CI, CO: input / output channel counts. IW, IH / OW, OH: input / output size.
// residual: adds binding 5 `res` (f16, shaped like dst), added after PReLU.
// Returns the WGSL source (entry point `main`).
export function wgslConvRBSg(CI, CO, IW, IH, OW, OH, residual, tune) {
  const COC = (tune && tune.coc) || 4;
  const SLAB = (tune && tune.slab) || 20;
  const slabT = SLAB * 324; // 18x18 tiles
  // mask stores of a ragged last channel block; empty when CO divides evenly
  const coGuard = CO % COC === 0 ? '' : ` && co < ${CO}`;
  return /* wgsl */`
enable f16;
enable subgroups;
@group(0) @binding(0) var<storage, read> src: array<f16>;
@group(0) @binding(1) var<storage, read> wgt: array<f16>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read> alpha: array<f32>;
@group(0) @binding(4) var<storage, read_write> dst: array<f16>;
${residual ? `@group(0) @binding(5) var<storage, read> res: array<f16>;` : ``}

var<workgroup> tile: array<f16, ${slabT}>;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_index) li: u32) {
let lx = i32(lid.x); let ly = i32(lid.y);
let ox0 = i32(wid.x) * 16; let oy0 = i32(wid.y) * 16;
let x0 = ox0 + lx * 2; let y0 = oy0 + ly * 2;
let cb = i32(wid.z) * ${COC};
${Array.from({ length: COC }, (_, c) =>
` var a${c}0 = bias[cb + ${c}]; var a${c}1 = a${c}0; var a${c}2 = a${c}0; var a${c}3 = a${c}0;`).join('\n')}

for (var s = 0; s < ${CI}; s += ${SLAB}) {
let sl = min(${SLAB}, ${CI} - s);
workgroupBarrier();
var ti = i32(li);
let tn = sl * 324;
while (ti < tn) {
let ci = ti / 324;
let r = ti % 324;
let ty = oy0 + r / 18 - 1;
let tx = ox0 + r % 18 - 1;
var v = f16(0.0);
if (ty >= 0 && ty < ${IH} && tx >= 0 && tx < ${IW}) {
v = src[(s + ci) * ${IH * IW} + ty * ${IW} + tx];
}
tile[ti] = v;
ti += 64;
}
workgroupBarrier();
for (var ci = 0; ci < sl; ci++) {
let tb = ci * 324 + (ly * 2) * 18 + lx * 2;
let wrow = (s + ci) * 9;
for (var ky = 0; ky < 3; ky++) {
let rb = tb + ky * 18;
for (var kx = 0; kx < 3; kx++) {
let t00 = f32(tile[rb + kx]);
let t01 = f32(tile[rb + kx + 1]);
let t10 = f32(tile[rb + kx + 18]);
let t11 = f32(tile[rb + kx + 19]);
let wk = wrow + ky * 3 + kx;
${Array.from({ length: COC }, (_, c) => ` {
let wv = subgroupBroadcastFirst(f32(wgt[(cb + ${c}) * ${CI * 9} + wk]));
a${c}0 += t00 * wv; a${c}1 += t01 * wv; a${c}2 += t10 * wv; a${c}3 += t11 * wv;
}`).join('\n')}
}
}
}
}
${Array.from({ length: COC }, (_, c) => ` {
let co = cb + ${c};
let al = alpha[co];
${[0, 1, 2, 3].map(p => ` {
let x = x0 + ${p & 1};
let y = y0 + ${p >> 1};
if (x < ${OW} && y < ${OH}${coGuard}) {
let a = a${c}${p};
let v = select(al * a, a, a >= 0.0);
let o = co * ${OH * OW} + y * ${OW} + x;
dst[o] = f16(${residual ? `v + f32(res[o])` : `v`});
}
}`).join('\n')}
}`).join('\n')}
}`;
}
|
|
375
|
+
|
|
376
|
+
// Texture-input prep variants: the video frame lives in a GPU texture
// (uploaded via copyExternalImageToTexture) and never touches the CPU. The
// sampler also performs the display->model resize for free. When the texture
// already has the model size, sampling at texel centers returns the exact
// texel values.
// NOTE(review): the resize/exactness behavior depends on the sampler's filter
// mode, which the caller configures - not visible here.
//
// Full-resolution texture prep. Each invocation samples tex0 and tex1 at the
// center of model pixel (x, y) (normalized uv, mip level 0). It writes six f32
// planes into `imgs` in the order b0 g0 r0 b1 g1 r1.
// W, H: model-grid size. Each is interpolated as `${W}.0` / `${H}.0`, so it
// must be an integer.
// Returns the WGSL source (entry point `main`, WG x WG workgroups).
function wgslPrepFullTex(W, H) {
  return /* wgsl */`
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> imgs: array<f32>; // [6,${H},${W}] BGR

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${W} || y >= ${H}) { return; }
let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${W}.0, ${H}.0);
let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
let o = y * ${W} + x;
let P = ${H * W};
imgs[o] = c0.b; imgs[P + o] = c0.g; imgs[2 * P + o] = c0.r;
imgs[3 * P + o] = c1.b; imgs[4 * P + o] = c1.g; imgs[5 * P + o] = c1.r;
}`;
}
|
|
399
|
+
|
|
400
|
+
// Quarter-resolution texture prep. Builds the 7-channel network input xq of
// size (W/4) x (H/4) on the MODEL grid. Each invocation samples tex0/tex1 at
// the center of quarter-res pixel (x, y) in normalized uv, so the sampler maps
// through whatever size the textures actually have. It writes the planes
// b0 g0 r0 b1 g1 r1 and fills plane 6 with the timestep tstep[0].
// Arithmetic note: for a W x H texture, the quarter-pixel center lands at
// full-res texel coordinate 4x + 1.5, the align_corners=false position.
// NOTE(review): the filtering actually applied depends on the caller's sampler.
//
// W, H: model-grid size; both must be multiples of 4.
// f16:  store xq as f16 (the shader then starts with `enable f16;`).
// Returns the WGSL source (entry point `main`).
// Throws RangeError when W or H is not a multiple of 4. The quarter-grid size
// is interpolated as a literal, and a fractional value would yield WGSL that
// only fails later at pipeline creation.
function wgslPrepQuarterTex(W, H, f16) {
  if (!Number.isInteger(W / 4) || !Number.isInteger(H / 4)) {
    throw new RangeError(`wgslPrepQuarterTex: W and H must be multiples of 4 (got ${W}x${H})`);
  }
  const QW = W / 4;
  const QH = H / 4;
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> xq: array<${T}>; // [7,${QH},${QW}]
@group(0) @binding(4) var<storage, read> tstep: array<f32>;

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${QW} || y >= ${QH}) { return; }
// quarter of the MODEL grid; the sampler maps through whatever the texture size is
let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${QW}.0, ${QH}.0);
let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
let o = y * ${QW} + x;
let P = ${QH * QW};
xq[o] = ${T}(c0.b); xq[P + o] = ${T}(c0.g); xq[2 * P + o] = ${T}(c0.r);
xq[3 * P + o] = ${T}(c1.b); xq[4 * P + o] = ${T}(c1.g); xq[5 * P + o] = ${T}(c1.r);
xq[6 * P + o] = ${T}(tstep[0]);
}`;
}
|
|
426
|
+
|
|
427
|
+
// t-factored quarter-resolution texture prep: 6 channels and NO timestep
// plane. The trunk is t-free, and t enters later via FiLM. Each invocation
// samples tex0/tex1 at the center of quarter-res pixel (x, y) of the MODEL
// grid (normalized uv, mip level 0) and writes b0 g0 r0 b1 g1 r1 into xq.
//
// W, H: model-grid size; both must be multiples of 4.
// f16:  store xq as f16 (the shader then starts with `enable f16;`).
// Returns the WGSL source (entry point `main`).
// Throws RangeError when W or H is not a multiple of 4. The quarter-grid size
// is interpolated as a literal, and a fractional value would yield WGSL that
// only fails later at pipeline creation.
function wgslPrepQuarterTex6(W, H, f16) {
  if (!Number.isInteger(W / 4) || !Number.isInteger(H / 4)) {
    throw new RangeError(`wgslPrepQuarterTex6: W and H must be multiples of 4 (got ${W}x${H})`);
  }
  const QW = W / 4;
  const QH = H / 4;
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var tex0: texture_2d<f32>;
@group(0) @binding(1) var tex1: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var<storage, read_write> xq: array<${T}>; // [6,${QH},${QW}]

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${QW} || y >= ${QH}) { return; }
let uv = (vec2<f32>(f32(x), f32(y)) + 0.5) / vec2<f32>(${QW}.0, ${QH}.0);
let c0 = textureSampleLevel(tex0, samp, uv, 0.0).rgb;
let c1 = textureSampleLevel(tex1, samp, uv, 0.0).rgb;
let o = y * ${QW} + x;
let P = ${QH * QW};
xq[o] = ${T}(c0.b); xq[P + o] = ${T}(c0.g); xq[2 * P + o] = ${T}(c0.r);
xq[3 * P + o] = ${T}(c1.b); xq[4 * P + o] = ${T}(c1.g); xq[5 * P + o] = ${T}(c1.r);
}`;
}
|
|
451
|
+
|
|
452
|
+
// FiLM conditioning, applied element-wise to a [C, N] feature map:
//   x' = x * (1 + scale[c]) + shift[c],  where c = i / N.
// `prm` holds 2*C f32 values: the C scales first, then the C shifts. Per the
// original note, they are the output of the tiny t-MLP, computed on the CPU
// per timestep.
// 1-D kernel with 256 invocations per workgroup, one element per invocation.
// The math is f32; src/dst use f16 storage when `f16` is set.
// C: channel count. N: elements per channel.
// Returns the WGSL source (entry point `main`).
function wgslFilm(C, N, f16) {
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>; // [C,N] trunk features
@group(0) @binding(1) var<storage, read> prm: array<f32>; // [2C] scale, shift
@group(0) @binding(2) var<storage, read_write> dst: array<${T}>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = i32(gid.x);
if (i >= ${C * N}) { return; }
let c = i / ${N};
dst[i] = ${T}(f32(src[i]) * (1.0 + prm[c]) + prm[${C} + c]);
}`;
}
|
|
470
|
+
|
|
471
|
+
// Flowout variant that writes the interpolated frame straight into an
// rgba8unorm storage texture (GPU-resident presentation: the mid frame never
// leaves the GPU). The rgba8unorm store rounds instead of truncating, which
// differs by +-1 LSB from the buffer path - invisible. The correctness
// harness keeps using the buffer path.
//
// Per full-res pixel (x, y):
//   1. Bilinearly upsample the 5-channel tmp8 ([5, H/8, W/8]) at the
//      align_corners=false position. Channels 0..3 x 8 give flow0 / flow1,
//      and sigmoid(channel 4) gives the blend mask m.
//   2. Bilinearly warp frame 0 (planes 0..2) by flow0 and frame 1
//      (planes 3..5) by flow1, with edge-clamped taps, then blend:
//      bgr = m*w0 + (1-m)*w1.
//   3. withRes: add the quarter-res refine residual, upsampled x4, BEFORE the
//      clamp to [0, 1].
//   4. staticGuard: blend toward the mean of the two source frames where they
//      are locally identical (see GUARD below).
//   5. Store RGBA = (r, g, b, 1).
//
// W, H: full-res size; both must be multiples of 8 (tmp8 is W/8 x H/8 and the
//       residual W/4 x H/4, all interpolated as literals).
// staticGuard: enable the static-region protection.
// withRes: add binding 3 `res` ([3, H/4, W/4] f32) and fold it in.
// Returns the WGSL source (entry point `main`).
// Throws RangeError when W or H is not a multiple of 8. A fractional grid size
// would yield WGSL that only fails later at pipeline creation.
function wgslFlowOutTex(W, H, staticGuard = false, withRes = false) {
  if (!Number.isInteger(W / 8) || !Number.isInteger(H / 8)) {
    throw new RangeError(`wgslFlowOutTex: W and H must be multiples of 8 (got ${W}x${H})`);
  }
  const TW = W / 8;
  const TH = H / 8;
  const RW = W / 4;
  const RH = H / 4;
  // tfact2: the refine residual (quarter res) is folded straight into this pass -
  // bilinear x4 upsample (align_corners=False grid), added BEFORE the clamp
  const RES_DECL = withRes ? /* wgsl */`
@group(0) @binding(3) var<storage, read> res: array<f32>; // [3,${RH},${RW}]
fn rtap(c: i32, x: i32, y: i32) -> f32 {
return res[c * ${RH * RW} + clamp(y, 0, ${RH - 1}) * ${RW} + clamp(x, 0, ${RW - 1})];
}
fn rup(c: i32, sx: f32, sy: f32) -> f32 {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
return mix(mix(rtap(c, x0, y0), rtap(c, x0 + 1, y0), fx),
mix(rtap(c, x0, y0 + 1), rtap(c, x0 + 1, y0 + 1), fx), fy);
}` : '';
  const RES_ADD = withRes ? /* wgsl */`
let rx = (f32(x) + 0.5) / 4.0 - 0.5;
let ry = (f32(y) + 0.5) / 4.0 - 0.5;
bgr = bgr + vec3<f32>(rup(0, rx, ry), rup(1, rx, ry), rup(2, rx, ry));` : '';
  // Static-region protection (SVP-style): where A and B are locally identical
  // (subtitles, logos, UI, frozen shots-in-motion) the warp can still DRAG other
  // content there, so blend back to the untouched sources instead. The 3x3 mean
  // of the per-pixel max channel difference drives a smoothstep(0.03, 0.09)
  // ramp, so pixels on moving edges transition smoothly.
  const GUARD = staticGuard ? /* wgsl */`
var d = 0.0;
for (var dy = -1; dy <= 1; dy++) {
for (var dx = -1; dx <= 1; dx++) {
let xx = x + dx; let yy = y + dy;
d += max(abs(img(0, xx, yy) - img(3, xx, yy)),
max(abs(img(1, xx, yy) - img(4, xx, yy)),
abs(img(2, xx, yy) - img(5, xx, yy))));
}
}
d *= (1.0 / 9.0);
let wStatic = 1.0 - smoothstep(0.03, 0.09, d);
if (wStatic > 0.001) {
let stat = vec3<f32>(
(img(0, x, y) + img(3, x, y)) * 0.5,
(img(1, x, y) + img(4, x, y)) * 0.5,
(img(2, x, y) + img(5, x, y)) * 0.5);
bgr = mix(bgr, stat, wStatic);
}` : '';
  return /* wgsl */`
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var outTex: texture_storage_2d<rgba8unorm, write>;
${RES_DECL}
fn tap(c: i32, x: i32, y: i32) -> f32 {
return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
var r: vec3<f32>;
for (var c = 0; c < 3; c++) {
let p = base + c;
r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
}
return r; // b,g,r
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${W} || y >= ${H}) { return; }
let sx = (f32(x) + 0.5) / 8.0 - 0.5;
let sy = (f32(y) + 0.5) / 8.0 - 0.5;
let fx0 = up(0, sx, sy) * 8.0;
let fy0 = up(1, sx, sy) * 8.0;
let fx1 = up(2, sx, sy) * 8.0;
let fy1 = up(3, sx, sy) * 8.0;
let m = 1.0 / (1.0 + exp(-up(4, sx, sy)));
let w0 = warp3(0, f32(x) + fx0, f32(y) + fy0);
let w1 = warp3(3, f32(x) + fx1, f32(y) + fy1);
var bgr = w0 * m + w1 * (1.0 - m);
${RES_ADD}
bgr = clamp(bgr, vec3<f32>(0.0), vec3<f32>(1.0));
${GUARD}
textureStore(outTex, vec2<i32>(x, y), vec4<f32>(bgr.z, bgr.y, bgr.x, 1.0));
}`;
}
|
|
565
|
+
|
|
566
|
+
// ---- refine head (tfact2): occlusion repair at QUARTER resolution ----
// Gathers the 11-channel refine input rin of size (W/4) x (H/4):
//   [warped0 (3, BGR), warped1 (3, BGR), mask (1), flow * (0.25/20) (4)].
// F.interpolate(x, 0.25, bilinear, align_corners=False) samples the source at
// 4x+1.5, i.e. the mean of the CENTER 2x2 of each 4x4 block. Each invocation
// therefore evaluates the full-res flow, mask and warps at those four
// positions (offsets 1..2 in x and y) and averages them. Per the original
// note, this matches the trainer exactly:
//   - flow: tmp8 channels 0..3 upsampled and x8, then averaged and scaled by
//     0.25/20 (FLOW_NORM = 20 per the inline comment);
//   - mask: sigmoid of tmp8 channel 4, averaged;
//   - warped0/1: edge-clamped bilinear warps of frame 0 (planes 0..2) and
//     frame 1 (planes 3..5), averaged.
//
// W, H: full-res size; both must be multiples of 8 (tmp8 is W/8 x H/8, all
//       grid sizes are interpolated as literals).
// f16:  store rin as f16 (the shader then starts with `enable f16;`).
// Returns the WGSL source (entry point `main`).
// Throws RangeError when W or H is not a multiple of 8. A fractional grid size
// would yield WGSL that only fails later at pipeline creation.
function wgslRefinePrep(W, H, f16) {
  if (!Number.isInteger(W / 8) || !Number.isInteger(H / 8)) {
    throw new RangeError(`wgslRefinePrep: W and H must be multiples of 8 (got ${W}x${H})`);
  }
  const TW = W / 8;
  const TH = H / 8;
  const HW = W / 4;
  const HH = H / 4;
  const T = f16 ? 'f16' : 'f32';
  return /* wgsl */`
${f16 ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var<storage, read_write> rin: array<${T}>; // [11,${HH},${HW}]

fn tap(c: i32, x: i32, y: i32) -> f32 {
return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
var r: vec3<f32>;
for (var c = 0; c < 3; c++) {
let p = base + c;
r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
}
return r;
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let hx = i32(gid.x); let hy = i32(gid.y);
if (hx >= ${HW} || hy >= ${HH}) { return; }
var w0 = vec3<f32>(0.0); var w1 = vec3<f32>(0.0);
var mk = 0.0; var fl = vec4<f32>(0.0);
for (var sy = 1; sy <= 2; sy++) {
for (var sxp = 1; sxp <= 2; sxp++) {
let X = hx * 4 + sxp; let Y = hy * 4 + sy; // center 2x2 of the 4x4 block
let gx8 = (f32(X) + 0.5) / 8.0 - 0.5;
let gy8 = (f32(Y) + 0.5) / 8.0 - 0.5;
let fx0 = up(0, gx8, gy8) * 8.0; let fy0 = up(1, gx8, gy8) * 8.0;
let fx1 = up(2, gx8, gy8) * 8.0; let fy1 = up(3, gx8, gy8) * 8.0;
mk += 1.0 / (1.0 + exp(-up(4, gx8, gy8)));
w0 += warp3(0, f32(X) + fx0, f32(Y) + fy0);
w1 += warp3(3, f32(X) + fx1, f32(Y) + fy1);
fl += vec4<f32>(fx0, fy0, fx1, fy1);
}
}
let P = ${HH * HW};
let o = hy * ${HW} + hx;
let q = 0.25;
let fn_ = 0.25 * (0.25 / 20.0); // mean * (0.25/FLOW_NORM)
rin[o] = ${T}(w0.x * q); rin[P + o] = ${T}(w0.y * q); rin[2 * P + o] = ${T}(w0.z * q);
rin[3 * P + o] = ${T}(w1.x * q); rin[4 * P + o] = ${T}(w1.y * q); rin[5 * P + o] = ${T}(w1.z * q);
rin[6 * P + o] = ${T}(mk * q);
rin[7 * P + o] = ${T}(fl.x * fn_); rin[8 * P + o] = ${T}(fl.y * fn_);
rin[9 * P + o] = ${T}(fl.z * fn_); rin[10 * P + o] = ${T}(fl.w * fn_);
}`;
}
|
|
634
|
+
|
|
635
|
+
// Final refine conv: C input channels -> 3 output channels (3x3 kernel), then
// 2 * sigmoid(acc) - 1, so every output lies in (-1, 1). The kernel works at
// whatever HW x HH plane size the caller passes (createRT runs it at quarter
// res for the tfact2 refine head) and always writes f32.
// Bindings: 0 = src [C, HH, HW] (f16 when `f16`, else f32),
//           1 = wgt [3, C, 3, 3] f32, 2 = bias [3] f32, 3 = dst [3, HH, HW] f32.
// Border taps are clamped to the edge row/column rather than zero-padded.
// One invocation per output pixel; it loops over all three output channels.
function wgslRefineOut(C, HW, HH, f16) {
  // every value below is spliced verbatim into the shader text
  const elemType = f16 ? 'f16' : 'f32';
  const enableDirective = f16 ? 'enable f16;' : '';
  const planeSize = HH * HW; // elements per channel plane
  const lastRow = HH - 1;
  const lastCol = HW - 1;
  return /* wgsl */`
${enableDirective}
@group(0) @binding(0) var<storage, read> src: array<${elemType}>; // [${C},${HH},${HW}]
@group(0) @binding(1) var<storage, read> wgt: array<f32>; // [3,${C},3,3]
@group(0) @binding(2) var<storage, read> bias: array<f32>; // [3]
@group(0) @binding(3) var<storage, read_write> dst: array<f32>; // [3,${HH},${HW}]

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${HW} || y >= ${HH}) { return; }
for (var co = 0; co < 3; co++) {
var acc = bias[co];
for (var ci = 0; ci < ${C}; ci++) {
let sb = ci * ${planeSize};
let wb = (co * ${C} + ci) * 9;
for (var ky = 0; ky < 3; ky++) {
let sy = clamp(y + ky - 1, 0, ${lastRow});
for (var kx = 0; kx < 3; kx++) {
let sx = clamp(x + kx - 1, 0, ${lastCol});
acc += f32(src[sb + sy * ${HW} + sx]) * wgt[wb + ky * 3 + kx];
}
}
}
dst[co * ${planeSize} + y * ${HW} + x] = 2.0 / (1.0 + exp(-acc)) - 1.0;
}
}`;
}
|
|
666
|
+
|
|
667
|
+
// One-shot f32 -> f16 conversion kernel, used at init to make half-precision
// copies of conv weights. Each invocation converts one element, indexed by
// gid.x, with workgroup_size 256. The caller therefore dispatches
// ceil(n / 256) workgroups (createRT's convW does exactly that). The
// arrayLength(&src) guard turns the tail of the last workgroup into a no-op.
// dst must hold at least as many f16 elements as src holds f32 ones.
// Bindings: 0 = src (array<f32>, read), 1 = dst (array<f16>, read_write).
// `enable f16;` means the device needs the 'shader-f16' feature; createRT only
// builds this pipeline when that feature is present.
// NOTE(review): values outside f16's finite range (|x| > 65504) are converted
// without clamping; the result for such inputs is presumably
// implementation-dependent - confirm exported weights stay in range.
// NOTE(review): single 1-D dispatch; under WebGPU's default
// maxComputeWorkgroupsPerDimension (65535) this caps n at ~16.7M elements.
export const WGSL_TO_F16 = /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f16>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < arrayLength(&src)) { dst[i] = f16(src[i]); }
}`;
|
|
677
|
+
|
|
678
|
+
// ConvTranspose2d, 4x4 kernel, stride 2, padding 1, no activation.
// Layouts:
//   wgt  [CI, CO, 4, 4] f32
//   bias [CO] f32
//   src  [CI, IH, IW] (f16 when `f16src`, else f32)
//   dst  [CO, OH, OW] f32
// One invocation per output pixel and channel. gid.z selects the output
// channel, so the caller dispatches CO workgroups in z (the workgroup
// z-size is 1).
// Gather form: out[co,y,x] = bias[co] + sum over ci,ky,kx of
// src[ci,iy,ix] * wgt[ci,co,ky,kx], where y = 2*iy - 1 + ky (same for x).
// So only taps with (y + 1 - ky) even, non-negative and < 2*IH contribute.
function wgslDeconv(CI, CO, IW, IH, OW, OH, f16src) {
  const T = f16src ? 'f16' : 'f32';
  // co is bounds-checked alongside x/y. An over-dispatch in z would otherwise
  // index bias, wgt and dst beyond CO channels.
  return /* wgsl */`
${f16src ? 'enable f16;' : ''}
@group(0) @binding(0) var<storage, read> src: array<${T}>;
@group(0) @binding(1) var<storage, read> wgt: array<f32>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> dst: array<f32>;

@compute @workgroup_size(${WG}, ${WG}, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y); let co = i32(gid.z);
if (x >= ${OW} || y >= ${OH} || co >= ${CO}) { return; }
var acc = bias[co];
for (var ky = 0; ky < 4; ky++) {
let ty = y + 1 - ky;
if (ty < 0 || (ty & 1) != 0) { continue; }
let iy = ty / 2;
if (iy >= ${IH}) { continue; }
for (var kx = 0; kx < 4; kx++) {
let tx = x + 1 - kx;
if (tx < 0 || (tx & 1) != 0) { continue; }
let ix = tx / 2;
if (ix >= ${IW}) { continue; }
let sp = iy * ${IW} + ix; // spatial offset inside one src plane
let wk = co * 16 + ky * 4 + kx; // tap offset inside one ci slab of wgt
for (var ci = 0; ci < ${CI}; ci++) {
acc += f32(src[ci * ${IH * IW} + sp]) * wgt[ci * ${CO * 16} + wk];
}
}
}
dst[co * ${OH * OW} + y * ${OW} + x] = acc;
}`;
}
|
|
712
|
+
|
|
713
|
+
// Buffer-output tail of the frame. The kernel:
//   - upsamples tmp8 x8 (bilinear, align_corners=false) and scales the
//     flows by 8;
//   - backward-warps both images and blends them with sigmoid(mask);
//   - packs the result to one RGBA8 u32 per pixel.
// Bindings:
//   0 = tmp8 [5, H/8, W/8] f32: ch 0,1 = img0 flow x/y, ch 2,3 = img1 flow x/y,
//       ch 4 = mask logit
//   1 = imgs [6, H, W] f32: planes 0-2 = img0, planes 3-5 = img1, BGR order
//   2 = outp [H*W] u32: R in the low byte, then G, B, alpha = 255
// Throws unless W and H are positive multiples of 8. The 1/8-res sizes are
// spliced into the shader as integer literals, and a fractional value
// (e.g. `12.5`) would only surface later as an opaque shader compile error.
function wgslFlowOut(W, H) {
  if (!(W > 0) || !(H > 0) || W % 8 !== 0 || H % 8 !== 0) {
    throw new Error(`wgslFlowOut: dims must be positive multiples of 8 (got ${W}x${H})`);
  }
  const TW = W / 8;
  const TH = H / 8;
  return /* wgsl */`
@group(0) @binding(0) var<storage, read> tmp8: array<f32>; // [5,${TH},${TW}]
@group(0) @binding(1) var<storage, read> imgs: array<f32>; // [6,${H},${W}]
@group(0) @binding(2) var<storage, read_write> outp: array<u32>; // rgba

fn tap(c: i32, x: i32, y: i32) -> f32 {
return tmp8[c * ${TH * TW} + clamp(y, 0, ${TH - 1}) * ${TW} + clamp(x, 0, ${TW - 1})];
}
fn up(c: i32, sx: f32, sy: f32) -> f32 {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
return mix(mix(tap(c, x0, y0), tap(c, x0 + 1, y0), fx),
mix(tap(c, x0, y0 + 1), tap(c, x0 + 1, y0 + 1), fx), fy);
}
fn img(plane: i32, x: i32, y: i32) -> f32 {
return imgs[plane * ${H * W} + clamp(y, 0, ${H - 1}) * ${W} + clamp(x, 0, ${W - 1})];
}
// grid_sample bilinear, border, align_corners=true == pixel-space bilinear with clamped taps
fn warp3(base: i32, sx: f32, sy: f32) -> vec3<f32> {
let x0 = i32(floor(sx)); let y0 = i32(floor(sy));
let fx = sx - f32(x0); let fy = sy - f32(y0);
var r: vec3<f32>;
for (var c = 0; c < 3; c++) {
let p = base + c;
r[c] = mix(mix(img(p, x0, y0), img(p, x0 + 1, y0), fx),
mix(img(p, x0, y0 + 1), img(p, x0 + 1, y0 + 1), fx), fy);
}
return r; // b,g,r
}

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x = i32(gid.x); let y = i32(gid.y);
if (x >= ${W} || y >= ${H}) { return; }
// align_corners=false x8 upsample: src = (dst+0.5)/8 - 0.5; flows also scaled by 8 (= scale*2, scale=4)
let sx = (f32(x) + 0.5) / 8.0 - 0.5;
let sy = (f32(y) + 0.5) / 8.0 - 0.5;
let fx0 = up(0, sx, sy) * 8.0;
let fy0 = up(1, sx, sy) * 8.0;
let fx1 = up(2, sx, sy) * 8.0;
let fy1 = up(3, sx, sy) * 8.0;
let m = 1.0 / (1.0 + exp(-up(4, sx, sy))); // sigmoid(mask)
let w0 = warp3(0, f32(x) + fx0, f32(y) + fy0);
let w1 = warp3(3, f32(x) + fx1, f32(y) + fy1);
let bgr = w0 * m + w1 * (1.0 - m);
// BGR -> RGB, *255 truncate (matches rife-core prepost), alpha 255
let r = u32(clamp(bgr.z, 0.0, 1.0) * 255.0);
let g = u32(clamp(bgr.y, 0.0, 1.0) * 255.0);
let b = u32(clamp(bgr.x, 0.0, 1.0) * 255.0);
outp[y * ${W} + x] = r | (g << 8u) | (b << 16u) | (255u << 24u);
}`;
}
|
|
768
|
+
|
|
769
|
+
export async function createRT(device, { w, h, weightsBin, weightsManifest, convTune,
|
|
770
|
+
textureInput = false, textureOutput = false,
|
|
771
|
+
staticGuard = false }) {
|
|
772
|
+
if (w % 16 || h % 16) throw new Error(`rt: dims must be /16 (got ${w}x${h})`);
|
|
773
|
+
const QW = w / 4, QH = h / 4, W8 = w / 8, H8 = h / 8, W16 = w / 16, H16 = h / 16;
|
|
774
|
+
const useF16 = device.features.has('shader-f16');
|
|
775
|
+
// channel widths come from the weights themselves (supports slim students)
|
|
776
|
+
const C1 = weightsManifest['block0.conv0.0.0.weight'].shape[0]; // conv0a out (120 full / 60 slim)
|
|
777
|
+
const C2 = weightsManifest['block0.conv0.1.0.weight'].shape[0]; // main width (240 full / 120 slim)
|
|
778
|
+
if (C2 % 4) throw new Error('rt: main width must be /4');
|
|
779
|
+
// t-factored graph: trunk (conv0 + 6 convblocks) is timestep-free and runs once
|
|
780
|
+
// per pair; FiLM(t) + convblocks 6,7 + lastconv run per mid. Detected by the
|
|
781
|
+
// film MLP in the manifest; input prep is 6ch (no t channel).
|
|
782
|
+
const tfact = 'film.2.weight' in weightsManifest;
|
|
783
|
+
const CI0 = weightsManifest['block0.conv0.0.0.weight'].shape[1]; // 7 classic, 6 tfact
|
|
784
|
+
if (tfact && (!textureInput || !textureOutput)) {
|
|
785
|
+
throw new Error('rt: tfact weights need texture input/output mode');
|
|
786
|
+
}
|
|
787
|
+
const refi = tfact && ('refine.c0.weight' in weightsManifest); // tfact2
|
|
788
|
+
const Z0A = C1 % 4 === 0 ? C1 / 4 : C1; // conv0a kernel packs 4 channels only when C1 is /4
|
|
789
|
+
|
|
790
|
+
const bufBytes = (bytes, usage = GPUBufferUsage.STORAGE) => device.createBuffer({
|
|
791
|
+
size: Math.ceil(bytes / 4) * 4,
|
|
792
|
+
usage: usage | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC });
|
|
793
|
+
const buf = (n) => bufBytes(n * 4);
|
|
794
|
+
const abuf = (n) => bufBytes(n * (useF16 ? 2 : 4)); // activation dtype
|
|
795
|
+
|
|
796
|
+
const mod = (code) => device.createShaderModule({ code });
|
|
797
|
+
const pipe = (code, entry = 'main') => device.createComputePipeline({
|
|
798
|
+
layout: 'auto', compute: { module: mod(code), entryPoint: entry } });
|
|
799
|
+
const bg = (p, entries) => device.createBindGroup({
|
|
800
|
+
layout: p.getBindGroupLayout(0),
|
|
801
|
+
entries: entries.map((b, i) => ({ binding: i, resource: { buffer: b } })) });
|
|
802
|
+
|
|
803
|
+
// weights (f32 upload; conv weights get f16 copies when supported)
|
|
804
|
+
const man = weightsManifest;
|
|
805
|
+
const wbuf = {};
|
|
806
|
+
for (const [name, m] of Object.entries(man)) {
|
|
807
|
+
const n = m.shape.reduce((a, b) => a * b, 1);
|
|
808
|
+
wbuf[name] = buf(n);
|
|
809
|
+
device.queue.writeBuffer(wbuf[name], 0, weightsBin, m.offset * 4, n * 4);
|
|
810
|
+
}
|
|
811
|
+
// one pipeline shared by all weight conversions
|
|
812
|
+
const pToF16 = useF16 ? pipe(WGSL_TO_F16) : null;
|
|
813
|
+
const convW = (name) => {
|
|
814
|
+
if (!useF16) return wbuf[name];
|
|
815
|
+
const n = man[name].shape.reduce((a, b) => a * b, 1);
|
|
816
|
+
const half = bufBytes(n * 2);
|
|
817
|
+
const p = pToF16;
|
|
818
|
+
const enc = device.createCommandEncoder();
|
|
819
|
+
const pass = enc.beginComputePass();
|
|
820
|
+
pass.setPipeline(p);
|
|
821
|
+
pass.setBindGroup(0, bg(p, [wbuf[name], half]));
|
|
822
|
+
pass.dispatchWorkgroups(Math.ceil(n / 256));
|
|
823
|
+
pass.end();
|
|
824
|
+
device.queue.submit([enc.finish()]);
|
|
825
|
+
return half;
|
|
826
|
+
};
|
|
827
|
+
|
|
828
|
+
// activations (f16 when supported), fixed f32 elsewhere
|
|
829
|
+
const tbuf = buf(1);
|
|
830
|
+
const rgba0 = textureInput ? null : buf(w * h);
|
|
831
|
+
const rgba1 = textureInput ? null : buf(w * h);
|
|
832
|
+
const imgs = buf(6 * w * h);
|
|
833
|
+
const xq = abuf(CI0 * QH * QW);
|
|
834
|
+
const f8 = abuf(C1 * H8 * W8);
|
|
835
|
+
const actBytes = C2 * H16 * W16 * (useF16 ? 2 : 4);
|
|
836
|
+
const f16a = bufBytes(actBytes), f16b = bufBytes(actBytes), f16r = bufBytes(actBytes);
|
|
837
|
+
const tmp8 = buf(5 * H8 * W8);
|
|
838
|
+
// readback plumbing exists only for the buffer-output path (rt_test harness);
|
|
839
|
+
// texture-output callers never read back - ~7*w*h*4 bytes of MAP_READ saved
|
|
840
|
+
const outp = textureOutput ? null : buf(w * h);
|
|
841
|
+
const staging = textureOutput ? null
|
|
842
|
+
: device.createBuffer({ size: w * h * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
|
|
843
|
+
// slots for batched multi-t runs (factor N: upload once, N-1 mids in ONE submit)
|
|
844
|
+
const MAXT = 5;
|
|
845
|
+
const tbufs = [], stagings = [];
|
|
846
|
+
for (let i = 0; i < MAXT; i++) {
|
|
847
|
+
tbufs.push(buf(1));
|
|
848
|
+
if (!textureOutput) {
|
|
849
|
+
stagings.push(device.createBuffer({ size: w * h * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }));
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
const sampler = textureInput
|
|
854
|
+
? device.createSampler({ magFilter: 'linear', minFilter: 'linear',
|
|
855
|
+
addressModeU: 'clamp-to-edge', addressModeV: 'clamp-to-edge' })
|
|
856
|
+
: null;
|
|
857
|
+
// the heavy shaders compile in PARALLEL on Dawn's worker pool (and without
|
|
858
|
+
// blocking the main thread) - createRT is async anyway, so batch-await them
|
|
859
|
+
const pipeAsync = (code, entry = 'main') => device.createComputePipelineAsync({
|
|
860
|
+
layout: 'auto', compute: { module: mod(code), entryPoint: entry } });
|
|
861
|
+
const sgTuned = convTune && convTune.sg && device.features.has('subgroups');
|
|
862
|
+
const [pPrepFull, pPrepQ, pConv0a, pConv0b, pConvB, pConvBR, pDeconv, pFlow] = await Promise.all([
|
|
863
|
+
pipeAsync(textureInput ? wgslPrepFullTex(w, h) : wgslPrepFull(w, h)),
|
|
864
|
+
pipeAsync(tfact ? wgslPrepQuarterTex6(w, h, useF16)
|
|
865
|
+
: (textureInput ? wgslPrepQuarterTex(w, h, useF16) : wgslPrepQuarter(w, h, useF16))),
|
|
866
|
+
pipeAsync(wgslConv(CI0, C1, QW, QH, W8, H8, 2, false, useF16)),
|
|
867
|
+
pipeAsync(wgslConv(C1, C2, W8, H8, W16, H16, 2, false, useF16)),
|
|
868
|
+
pipeAsync(useF16
|
|
869
|
+
? (sgTuned ? wgslConvRBSg : wgslConvRB)(C2, C2, W16, H16, W16, H16, false, convTune)
|
|
870
|
+
: wgslConv(C2, C2, W16, H16, W16, H16, 1, false, false)),
|
|
871
|
+
pipeAsync(useF16
|
|
872
|
+
? (sgTuned ? wgslConvRBSg : wgslConvRB)(C2, C2, W16, H16, W16, H16, true, convTune)
|
|
873
|
+
: wgslConv(C2, C2, W16, H16, W16, H16, 1, true, false)),
|
|
874
|
+
pipeAsync(wgslDeconv(C2, 5, W16, H16, W8, H8, useF16)),
|
|
875
|
+
pipeAsync(textureOutput ? wgslFlowOutTex(w, h, staticGuard, refi) : wgslFlowOut(w, h)),
|
|
876
|
+
]);
|
|
877
|
+
// texture-output mode: flow bind groups are per output texture (small ring - cache them)
|
|
878
|
+
const flowBgCache = new Map();
|
|
879
|
+
function flowBgFor(tex) {
|
|
880
|
+
if (!flowBgCache.has(tex)) {
|
|
881
|
+
// evict BEFORE inserting - clearing after would wipe the fresh entry and
|
|
882
|
+
// hand setBindGroup an undefined (latent until a caller rings >24 textures)
|
|
883
|
+
if (flowBgCache.size > 24) flowBgCache.clear();
|
|
884
|
+
const entries = [
|
|
885
|
+
{ binding: 0, resource: { buffer: tmp8 } },
|
|
886
|
+
{ binding: 1, resource: { buffer: imgs } },
|
|
887
|
+
{ binding: 2, resource: tex.createView() }];
|
|
888
|
+
if (refi) entries.push({ binding: 3, resource: { buffer: rRes } }); // tfact2 residual
|
|
889
|
+
flowBgCache.set(tex, device.createBindGroup({ layout: pFlow.getBindGroupLayout(0), entries }));
|
|
890
|
+
}
|
|
891
|
+
return flowBgCache.get(tex);
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
// buffer-input prep bind groups (unused in texture mode)
|
|
895
|
+
const bgPrepFull = textureInput ? null : bg(pPrepFull, [rgba0, rgba1, imgs]);
|
|
896
|
+
const bgPrepQ = textureInput ? null : bg(pPrepQ, [rgba0, rgba1, xq, tbuf]);
|
|
897
|
+
const bgPrepQt = textureInput ? null : tbufs.map(tb => bg(pPrepQ, [rgba0, rgba1, xq, tb]));
|
|
898
|
+
// texture-mode prep bind groups are built per texture pair and cached (ping-pong -> few combos).
|
|
899
|
+
// Keyed by texture IDENTITY, not label: callers recreate pools reusing the same
|
|
900
|
+
// labels (resolution change), and a label-keyed cache would keep serving bind
|
|
901
|
+
// groups of destroyed textures - every submit fails async validation and the
|
|
902
|
+
// mids silently replay stale content.
|
|
903
|
+
const texBgCache = new Map();
|
|
904
|
+
function texPrepBgs(texA, texB) {
|
|
905
|
+
const key = texBgId(texA) + '|' + texBgId(texB);
|
|
906
|
+
if (!texBgCache.has(key)) {
|
|
907
|
+
// evict BEFORE inserting - clearing after would wipe the fresh entry too
|
|
908
|
+
if (texBgCache.size > 12) texBgCache.clear(); // texture set changed wholesale
|
|
909
|
+
const va = texA.createView(), vb = texB.createView();
|
|
910
|
+
texBgCache.set(key, {
|
|
911
|
+
full: device.createBindGroup({ layout: pPrepFull.getBindGroupLayout(0), entries: [
|
|
912
|
+
{ binding: 0, resource: va }, { binding: 1, resource: vb },
|
|
913
|
+
{ binding: 2, resource: sampler }, { binding: 3, resource: { buffer: imgs } }] }),
|
|
914
|
+
q: tfact
|
|
915
|
+
? [device.createBindGroup({ layout: pPrepQ.getBindGroupLayout(0), entries: [
|
|
916
|
+
{ binding: 0, resource: va }, { binding: 1, resource: vb },
|
|
917
|
+
{ binding: 2, resource: sampler }, { binding: 3, resource: { buffer: xq } }] })]
|
|
918
|
+
: tbufs.map(tb => device.createBindGroup({ layout: pPrepQ.getBindGroupLayout(0), entries: [
|
|
919
|
+
{ binding: 0, resource: va }, { binding: 1, resource: vb },
|
|
920
|
+
{ binding: 2, resource: sampler }, { binding: 3, resource: { buffer: xq } },
|
|
921
|
+
{ binding: 4, resource: { buffer: tb } }] })),
|
|
922
|
+
});
|
|
923
|
+
}
|
|
924
|
+
return texBgCache.get(key);
|
|
925
|
+
}
|
|
926
|
+
|
|
927
|
+
// tfact-only state: trunk feature buffer + FiLM params (tiny t-MLP runs in JS)
|
|
928
|
+
const hbuf = tfact ? bufBytes(C2 * H16 * W16 * (useF16 ? 2 : 4)) : null;
|
|
929
|
+
const filmBuf = tfact ? buf(2 * C2) : null;
|
|
930
|
+
const pFilm = tfact ? pipe(wgslFilm(C2, H16 * W16, useF16)) : null;
|
|
931
|
+
let bgFilm = null, filmW = null;
|
|
932
|
+
if (tfact) {
|
|
933
|
+
bgFilm = bg(pFilm, [hbuf, filmBuf, f16a]);
|
|
934
|
+
const f32 = (name) => {
|
|
935
|
+
const m = weightsManifest[name];
|
|
936
|
+
return new Float32Array(weightsBin, m.offset * 4, m.shape.reduce((a, b) => a * b, 1));
|
|
937
|
+
};
|
|
938
|
+
filmW = { w0: f32('film.0.weight'), b0: f32('film.0.bias'),
|
|
939
|
+
w2: f32('film.2.weight'), b2: f32('film.2.bias') };
|
|
940
|
+
}
|
|
941
|
+
// tfact2 refine head: occlusion repair at QUARTER res; the residual is folded
|
|
942
|
+
// into the flowout pass (withRes)
|
|
943
|
+
const RW4 = w / 4, RH4 = h / 4;
|
|
944
|
+
let pRPrep = null, pRC0 = null, pRC1 = null, pROut = null;
|
|
945
|
+
let bgRPrep = null, bgRC0 = null, bgRC1 = null, bgRC2 = null, bgROut = null;
|
|
946
|
+
let RC = 0, rRes = null;
|
|
947
|
+
if (refi) {
|
|
948
|
+
RC = man['refine.c0.weight'].shape[0];
|
|
949
|
+
// the refine convs are dispatched with z = RC/4 below, and wgslConv only
|
|
950
|
+
// packs 4-wide when CO % 4 == 0 - a non-/4 head would silently mis-dispatch
|
|
951
|
+
if (RC % 4 !== 0) throw new Error('rt: refine channels must be a multiple of 4, got ' + RC);
|
|
952
|
+
const rIn = abuf(11 * RH4 * RW4);
|
|
953
|
+
const rA2 = abuf(RC * RH4 * RW4);
|
|
954
|
+
const rB2 = abuf(RC * RH4 * RW4);
|
|
955
|
+
rRes = buf(3 * RH4 * RW4);
|
|
956
|
+
[pRPrep, pRC0, pRC1, pROut] = await Promise.all([
|
|
957
|
+
pipeAsync(wgslRefinePrep(w, h, useF16)),
|
|
958
|
+
pipeAsync(wgslConv(11, RC, RW4, RH4, RW4, RH4, 1, false, useF16)),
|
|
959
|
+
pipeAsync(wgslConv(RC, RC, RW4, RH4, RW4, RH4, 1, false, useF16)),
|
|
960
|
+
pipeAsync(wgslRefineOut(RC, RW4, RH4, useF16)),
|
|
961
|
+
]);
|
|
962
|
+
bgRPrep = bg(pRPrep, [tmp8, imgs, rIn]);
|
|
963
|
+
bgRC0 = bg(pRC0, [rIn, convW('refine.c0.weight'), wbuf['refine.c0.bias'], wbuf['refine.a0.weight'], rA2]);
|
|
964
|
+
bgRC1 = bg(pRC1, [rA2, convW('refine.c1.weight'), wbuf['refine.c1.bias'], wbuf['refine.a1.weight'], rB2]);
|
|
965
|
+
bgRC2 = bg(pRC1, [rB2, convW('refine.c2.weight'), wbuf['refine.c2.bias'], wbuf['refine.a2.weight'], rA2]);
|
|
966
|
+
bgROut = bg(pROut, [rA2, wbuf['refine.c3.weight'], wbuf['refine.c3.bias'], rRes]);
|
|
967
|
+
}
|
|
968
|
+
|
|
969
|
+
// scratch reused across calls: these run once per mid (writeBuffer reads the
|
|
970
|
+
// array synchronously, so reuse is safe) - allocating per call is pure GC churn
|
|
971
|
+
const tScratch = new Float32Array(1);
|
|
972
|
+
const filmScratch = tfact ? { out: new Float32Array(2 * C2), h: new Float32Array(filmW.b0.length) } : null;
|
|
973
|
+
function filmParams(t) {
|
|
974
|
+
const HN = filmW.b0.length, { out, h } = filmScratch;
|
|
975
|
+
for (let j = 0; j < HN; j++) h[j] = Math.max(0, filmW.w0[j] * t + filmW.b0[j]);
|
|
976
|
+
for (let k = 0; k < 2 * C2; k++) {
|
|
977
|
+
let s = filmW.b2[k];
|
|
978
|
+
const row = k * HN;
|
|
979
|
+
for (let j = 0; j < HN; j++) s += filmW.w2[row + j] * h[j];
|
|
980
|
+
out[k] = s;
|
|
981
|
+
}
|
|
982
|
+
return out;
|
|
983
|
+
}
|
|
984
|
+
const bgConv0a = bg(pConv0a, [xq, convW('block0.conv0.0.0.weight'), wbuf['block0.conv0.0.0.bias'], wbuf['block0.conv0.0.1.weight'], f8]);
|
|
985
|
+
const bgConv0b = bg(pConv0b, [f8, convW('block0.conv0.1.0.weight'), wbuf['block0.conv0.1.0.bias'], wbuf['block0.conv0.1.1.weight'], f16a]);
|
|
986
|
+
// convblock ping-pong: a->b, b->a, ... 8th conv adds the residual (f16r = copy of f16a)
|
|
987
|
+
const bgB = [];
|
|
988
|
+
let src = f16a, dst = f16b;
|
|
989
|
+
for (let i = 0; i < 8; i++) {
|
|
990
|
+
const wn = `block0.convblock.${i}.0.weight`, bn = `block0.convblock.${i}.0.bias`, an = `block0.convblock.${i}.1.weight`;
|
|
991
|
+
if (i < 7) {
|
|
992
|
+
bgB.push({ p: pConvB, g: bg(pConvB, [src, convW(wn), wbuf[bn], wbuf[an], dst]) });
|
|
993
|
+
} else {
|
|
994
|
+
bgB.push({ p: pConvBR, g: bg(pConvBR, [src, convW(wn), wbuf[bn], wbuf[an], dst, f16r]) });
|
|
995
|
+
}
|
|
996
|
+
[src, dst] = [dst, src];
|
|
997
|
+
}
|
|
998
|
+
const f16out = src; // after 8 convs
|
|
999
|
+
const bgDeconv = bg(pDeconv, [f16out, wbuf['block0.lastconv.weight'], wbuf['block0.lastconv.bias'], tmp8]);
|
|
1000
|
+
const bgFlow = textureOutput ? null : bg(pFlow, [tmp8, imgs, outp]);
|
|
1001
|
+
|
|
1002
|
+
const gx = (n) => Math.ceil(n / WG);
|
|
1003
|
+
// register-blocked convblock kernel covers 16x16 output per workgroup
|
|
1004
|
+
const cbX = useF16 ? Math.ceil(W16 / 16) : gx(W16);
|
|
1005
|
+
const cbY = useF16 ? Math.ceil(H16 / 16) : gx(H16);
|
|
1006
|
+
const cbZ = useF16 ? C2 / ((convTune && convTune.coc) || 4) : C2 / 4;
|
|
1007
|
+
|
|
1008
|
+
// per-stage GPU times via timestamp queries (needs 'timestamp-query' on the device)
|
|
1009
|
+
async function profile(rgbaA, rgbaB) {
|
|
1010
|
+
if (!device.features.has('timestamp-query')) return 'no timestamp-query feature';
|
|
1011
|
+
const stages = [
|
|
1012
|
+
['prepFull', pPrepFull, bgPrepFull, [gx(w), gx(h), 1]],
|
|
1013
|
+
['prepQ', pPrepQ, bgPrepQ, [gx(QW), gx(QH), 1]],
|
|
1014
|
+
['conv0a', pConv0a, bgConv0a, [gx(W8), gx(H8), Z0A]],
|
|
1015
|
+
['conv0b', pConv0b, bgConv0b, [gx(W16), gx(H16), C2 / 4]],
|
|
1016
|
+
...bgB.map(({ p, g }, i) => [`convB${i}`, p, g, [cbX, cbY, cbZ]]),
|
|
1017
|
+
['deconv', pDeconv, bgDeconv, [gx(W8), gx(H8), 5]],
|
|
1018
|
+
['flow', pFlow, bgFlow, [gx(w), gx(h), 1]],
|
|
1019
|
+
];
|
|
1020
|
+
const qs = device.createQuerySet({ type: 'timestamp', count: stages.length * 2 });
|
|
1021
|
+
const qbuf = device.createBuffer({ size: stages.length * 16, usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC });
|
|
1022
|
+
const qread = device.createBuffer({ size: stages.length * 16, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
|
|
1023
|
+
device.queue.writeBuffer(tbuf, 0, new Float32Array([0.5]));
|
|
1024
|
+
device.queue.writeBuffer(rgba0, 0, rgbaA.buffer, rgbaA.byteOffset, w * h * 4);
|
|
1025
|
+
device.queue.writeBuffer(rgba1, 0, rgbaB.buffer, rgbaB.byteOffset, w * h * 4);
|
|
1026
|
+
const enc = device.createCommandEncoder();
|
|
1027
|
+
stages.forEach(([name, p, g, d], i) => {
|
|
1028
|
+
if (name === 'convB0') enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
|
|
1029
|
+
const pass = enc.beginComputePass({ timestampWrites: {
|
|
1030
|
+
querySet: qs, beginningOfPassWriteIndex: i * 2, endOfPassWriteIndex: i * 2 + 1 } });
|
|
1031
|
+
pass.setPipeline(p); pass.setBindGroup(0, g);
|
|
1032
|
+
pass.dispatchWorkgroups(d[0], d[1], d[2]);
|
|
1033
|
+
pass.end();
|
|
1034
|
+
});
|
|
1035
|
+
enc.resolveQuerySet(qs, 0, stages.length * 2, qbuf, 0);
|
|
1036
|
+
enc.copyBufferToBuffer(qbuf, 0, qread, 0, stages.length * 16);
|
|
1037
|
+
device.queue.submit([enc.finish()]);
|
|
1038
|
+
await qread.mapAsync(GPUMapMode.READ);
|
|
1039
|
+
const ts = new BigUint64Array(qread.getMappedRange().slice(0));
|
|
1040
|
+
qread.unmap();
|
|
1041
|
+
return stages.map(([name], i) =>
|
|
1042
|
+
`${name}: ${(Number(ts[i * 2 + 1] - ts[i * 2]) / 1e6).toFixed(2)}ms`).join(' · ');
|
|
1043
|
+
}
|
|
1044
|
+
|
|
1045
|
+
async function run(rgbaA, rgbaB, t = 0.5) {
|
|
1046
|
+
if (tfact) throw new Error('rt: tfact weights have no buffer-mode run()');
|
|
1047
|
+
device.queue.writeBuffer(tbuf, 0, new Float32Array([t]));
|
|
1048
|
+
device.queue.writeBuffer(rgba0, 0, rgbaA.buffer, rgbaA.byteOffset, w * h * 4);
|
|
1049
|
+
device.queue.writeBuffer(rgba1, 0, rgbaB.buffer, rgbaB.byteOffset, w * h * 4);
|
|
1050
|
+
const enc = device.createCommandEncoder();
|
|
1051
|
+
const pass = enc.beginComputePass();
|
|
1052
|
+
pass.setPipeline(pPrepFull); pass.setBindGroup(0, bgPrepFull); pass.dispatchWorkgroups(gx(w), gx(h));
|
|
1053
|
+
pass.setPipeline(pPrepQ); pass.setBindGroup(0, bgPrepQ); pass.dispatchWorkgroups(gx(QW), gx(QH));
|
|
1054
|
+
pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
|
|
1055
|
+
pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
|
|
1056
|
+
pass.end();
|
|
1057
|
+
// residual copy AFTER conv0b (f16r = f16a snapshot)
|
|
1058
|
+
enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
|
|
1059
|
+
const pass2 = enc.beginComputePass();
|
|
1060
|
+
for (const { p, g } of bgB) {
|
|
1061
|
+
pass2.setPipeline(p); pass2.setBindGroup(0, g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
|
|
1062
|
+
}
|
|
1063
|
+
pass2.setPipeline(pDeconv); pass2.setBindGroup(0, bgDeconv); pass2.dispatchWorkgroups(gx(W8), gx(H8), 5);
|
|
1064
|
+
pass2.setPipeline(pFlow); pass2.setBindGroup(0, bgFlow); pass2.dispatchWorkgroups(gx(w), gx(h));
|
|
1065
|
+
pass2.end();
|
|
1066
|
+
enc.copyBufferToBuffer(outp, 0, staging, 0, w * h * 4);
|
|
1067
|
+
device.queue.submit([enc.finish()]);
|
|
1068
|
+
await staging.mapAsync(GPUMapMode.READ);
|
|
1069
|
+
const out = new Uint8Array(staging.getMappedRange().slice(0));
|
|
1070
|
+
staging.unmap();
|
|
1071
|
+
return out;
|
|
1072
|
+
}
|
|
1073
|
+
|
|
1074
|
+
// batched: upload/bind the pair once, produce mids for every t in ONE submit.
|
|
1075
|
+
// Buffer mode: a/b are RGBA arrays. Texture mode: a/b are GPUTextures (zero CPU pixels).
|
|
1076
|
+
// With textureOutput, outTexs[i] receives mid i and NOTHING is read back - returns null.
|
|
1077
|
+
async function runMulti(a, b, ts, outTexs) {
|
|
1078
|
+
if (tfact) { // factored graph: trunk once, head per t
|
|
1079
|
+
prepPair(a, b);
|
|
1080
|
+
for (let i = 0; i < ts.length; i++) runT(ts[i], outTexs[i]);
|
|
1081
|
+
return null;
|
|
1082
|
+
}
|
|
1083
|
+
if (ts.length > MAXT) throw new Error('too many timesteps');
|
|
1084
|
+
for (let i = 0; i < ts.length; i++) {
|
|
1085
|
+
device.queue.writeBuffer(tbufs[i], 0, new Float32Array([ts[i]]));
|
|
1086
|
+
}
|
|
1087
|
+
let tbg = null;
|
|
1088
|
+
if (textureInput) {
|
|
1089
|
+
tbg = texPrepBgs(a, b);
|
|
1090
|
+
} else {
|
|
1091
|
+
device.queue.writeBuffer(rgba0, 0, a.buffer, a.byteOffset, w * h * 4);
|
|
1092
|
+
device.queue.writeBuffer(rgba1, 0, b.buffer, b.byteOffset, w * h * 4);
|
|
1093
|
+
}
|
|
1094
|
+
const enc = device.createCommandEncoder();
|
|
1095
|
+
{
|
|
1096
|
+
const pass = enc.beginComputePass();
|
|
1097
|
+
pass.setPipeline(pPrepFull); pass.setBindGroup(0, tbg ? tbg.full : bgPrepFull); pass.dispatchWorkgroups(gx(w), gx(h));
|
|
1098
|
+
pass.end();
|
|
1099
|
+
}
|
|
1100
|
+
for (let i = 0; i < ts.length; i++) {
|
|
1101
|
+
const pass = enc.beginComputePass();
|
|
1102
|
+
pass.setPipeline(pPrepQ); pass.setBindGroup(0, tbg ? tbg.q[i] : bgPrepQt[i]); pass.dispatchWorkgroups(gx(QW), gx(QH));
|
|
1103
|
+
pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
|
|
1104
|
+
pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
|
|
1105
|
+
pass.end();
|
|
1106
|
+
enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);
|
|
1107
|
+
const pass2 = enc.beginComputePass();
|
|
1108
|
+
for (const { p, g } of bgB) {
|
|
1109
|
+
pass2.setPipeline(p); pass2.setBindGroup(0, g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
|
|
1110
|
+
}
|
|
1111
|
+
pass2.setPipeline(pDeconv); pass2.setBindGroup(0, bgDeconv); pass2.dispatchWorkgroups(gx(W8), gx(H8), 5);
|
|
1112
|
+
pass2.setPipeline(pFlow);
|
|
1113
|
+
pass2.setBindGroup(0, textureOutput ? flowBgFor(outTexs[i]) : bgFlow);
|
|
1114
|
+
pass2.dispatchWorkgroups(gx(w), gx(h));
|
|
1115
|
+
pass2.end();
|
|
1116
|
+
if (!textureOutput) enc.copyBufferToBuffer(outp, 0, stagings[i], 0, w * h * 4);
|
|
1117
|
+
}
|
|
1118
|
+
device.queue.submit([enc.finish()]);
|
|
1119
|
+
if (textureOutput) return null; // mids live in outTexs, nothing crosses the bus
|
|
1120
|
+
// map all stagings concurrently - sequential awaits cost ~1ms each
|
|
1121
|
+
await Promise.all(stagings.slice(0, ts.length).map(s => s.mapAsync(GPUMapMode.READ)));
|
|
1122
|
+
const outs = [];
|
|
1123
|
+
for (let i = 0; i < ts.length; i++) {
|
|
1124
|
+
outs.push(new Uint8Array(stagings[i].getMappedRange().slice(0)));
|
|
1125
|
+
stagings[i].unmap();
|
|
1126
|
+
}
|
|
1127
|
+
return outs;
|
|
1128
|
+
}
|
|
1129
|
+
|
|
1130
|
+
// ---- lazy per-mid API (texture in/out mode) ----
|
|
1131
|
+
// The queue is FIFO: a mid's present blit executes after EVERYTHING submitted
|
|
1132
|
+
// before it. Batching all mids upfront therefore makes the FIRST mid wait for
|
|
1133
|
+
// the WHOLE batch on the GPU. prepPair + runT let the caller submit each mid
|
|
1134
|
+
// just-in-time so present blits interleave with computes - the required
|
|
1135
|
+
// presentation delay shrinks from ~2x batch time to ~one mid time.
|
|
1136
|
+
let curPrep = null;
|
|
1137
|
+
function prepPair(a, b) {
|
|
1138
|
+
if (!textureInput) throw new Error('prepPair: texture-input mode only');
|
|
1139
|
+
curPrep = texPrepBgs(a, b);
|
|
1140
|
+
const enc = device.createCommandEncoder();
|
|
1141
|
+
if (tfact) {
|
|
1142
|
+
// the WHOLE t-free trunk runs here, once per pair
|
|
1143
|
+
const pass = enc.beginComputePass();
|
|
1144
|
+
pass.setPipeline(pPrepFull); pass.setBindGroup(0, curPrep.full); pass.dispatchWorkgroups(gx(w), gx(h));
|
|
1145
|
+
pass.setPipeline(pPrepQ); pass.setBindGroup(0, curPrep.q[0]); pass.dispatchWorkgroups(gx(QW), gx(QH));
|
|
1146
|
+
pass.setPipeline(pConv0a); pass.setBindGroup(0, bgConv0a); pass.dispatchWorkgroups(gx(W8), gx(H8), Z0A);
|
|
1147
|
+
pass.setPipeline(pConv0b); pass.setBindGroup(0, bgConv0b); pass.dispatchWorkgroups(gx(W16), gx(H16), C2 / 4);
|
|
1148
|
+
pass.end();
|
|
1149
|
+
enc.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes); // feat0 residual for the head
|
|
1150
|
+
const pass2 = enc.beginComputePass();
|
|
1151
|
+
for (let i = 0; i < 6; i++) {
|
|
1152
|
+
pass2.setPipeline(bgB[i].p); pass2.setBindGroup(0, bgB[i].g); pass2.dispatchWorkgroups(cbX, cbY, cbZ);
|
|
1153
|
+
}
|
|
1154
|
+
pass2.end();
|
|
1155
|
+
enc.copyBufferToBuffer(f16a, 0, hbuf, 0, actBytes); // trunk features, reused per t
|
|
1156
|
+
} else {
|
|
1157
|
+
const pass = enc.beginComputePass();
|
|
1158
|
+
pass.setPipeline(pPrepFull); pass.setBindGroup(0, curPrep.full); pass.dispatchWorkgroups(gx(w), gx(h));
|
|
1159
|
+
pass.end();
|
|
1160
|
+
}
|
|
1161
|
+
device.queue.submit([enc.finish()]);
|
|
1162
|
+
}
|
|
1163
|
+
function runT(t, outTex) {
  // Per-timestep stage: renders the mid frame at t for the pair bound by the last
  // prepPair() call into outTex, as a single command buffer.
  if (!curPrep) throw new Error('runT before prepPair');
  if (!textureOutput) throw new Error('runT: texture-output mode only');
  const encoder = device.createCommandEncoder();

  // record one dispatch: pipeline + bind group 0 + workgroup grid
  const step = (pass, pipeline, group, ...grid) => {
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, group);
    pass.dispatchWorkgroups(...grid);
  };

  if (tfact) {
    // per-mid: FiLM(t) + convblocks 6,7 (+feat0 residual) + deconv + flow
    device.queue.writeBuffer(filmBuf, 0, filmParams(t));
    const head = encoder.beginComputePass();
    step(head, pFilm, bgFilm, Math.ceil((C2 * H16 * W16) / 256));
    step(head, bgB[6].p, bgB[6].g, cbX, cbY, cbZ);
    step(head, bgB[7].p, bgB[7].g, cbX, cbY, cbZ);
    step(head, pDeconv, bgDeconv, gx(W8), gx(H8), 5);
    if (refi) {
      // quarter-res refine chain; the flow pass below folds the residual in.
      // NOTE(review): the third conv binds bgRC2 on pRC1 - presumably RC1 and RC2
      // share one shader and differ only in weights; confirm.
      step(head, pRPrep, bgRPrep, gx(RW4), gx(RH4));
      step(head, pRC0, bgRC0, gx(RW4), gx(RH4), RC / 4);
      step(head, pRC1, bgRC1, gx(RW4), gx(RH4), RC / 4);
      step(head, pRC1, bgRC2, gx(RW4), gx(RH4), RC / 4);
      step(head, pROut, bgROut, gx(RW4), gx(RH4));
    }
    step(head, pFlow, flowBgFor(outTex), gx(w), gx(h));
    head.end();
  } else {
    // non-factored model: everything after the full-res prep re-runs per t.
    // A single tbuf is safe: writeBuffer and submits are queue-ordered.
    tScratch[0] = t;
    device.queue.writeBuffer(tbufs[0], 0, tScratch);

    const stem = encoder.beginComputePass();
    step(stem, pPrepQ, curPrep.q[0], gx(QW), gx(QH));
    step(stem, pConv0a, bgConv0a, gx(W8), gx(H8), Z0A);
    step(stem, pConv0b, bgConv0b, gx(W16), gx(H16), C2 / 4);
    stem.end();

    encoder.copyBufferToBuffer(f16a, 0, f16r, 0, actBytes);

    const tail = encoder.beginComputePass();
    for (const block of bgB) step(tail, block.p, block.g, cbX, cbY, cbZ);
    step(tail, pDeconv, bgDeconv, gx(W8), gx(H8), 5);
    step(tail, pFlow, flowBgFor(outTex), gx(w), gx(h));
    tail.end();
  }

  device.queue.submit([encoder.finish()]);
}
|
|
1206
|
+
|
|
1207
|
+
return { run, runMulti, prepPair, runT, profile, w, h };
|
|
1208
|
+
}
|
|
1209
|
+
|
|
1210
|
+
|
|
1211
|
+
// ---- one-shot conv autotune: bench wgslConvRB variants on this device ----
// Returns the fastest VALID variant as {coc, slab, ms} (plus sg: true when a
// subgroup variant wins), or null when no variant fits this shape or every
// variant fails to build. ~200-400ms of GPU time; call it once per
// (adapter, model width, resolution) and persist the answer - relative ranking
// is what matters, so light background load is tolerable.
export async function tuneConvRB(device, { ci, co, w16, h16 }) {
  // workgroup-memory budget: the device's actual limit (WebGPU guarantees >= 16384 bytes)
  const shmemLimit = device.limits?.maxComputeWorkgroupStorageSize ?? 16384;
  // per-variant shared-memory estimate in bytes (x2 = f16); presumably an 18x18
  // input tile (16x16 + 1px halo) per slab channel plus coc*slab 3x3 weights -
  // keep in sync with wgslConvRB
  const base = [{ coc: 4, slab: 20 }, { coc: 8, slab: 20 }, { coc: 8, slab: 12 }, { coc: 4, slab: 12 }]
    .filter((v) => co % v.coc === 0 && (v.slab * 324 + v.coc * v.slab * 9) * 2 <= shmemLimit);
  // subgroup variants (weights via subgroupBroadcastFirst, no shared staging):
  // measured +20% at the 360p grid and -19% at 720p/coc8 on a 4060 Ti - exactly
  // why they go through the tuner instead of being hardcoded
  const variants = device.features.has('subgroups')
    ? [...base, ...base.map((v) => ({ ...v, sg: true }))]
    : base;

  // Runs fn inside validation + internal error scopes and resolves to the first
  // captured GPUError (or null). An invalid shader/pipeline does NOT throw: its
  // command buffers are silently rejected at submit, so without this check a
  // broken variant would time at ~0ms and win the tune.
  const captureGpuError = async (fn) => {
    device.pushErrorScope('internal');
    device.pushErrorScope('validation');
    let validation = null;
    let internal = null;
    try {
      await fn();
    } finally {
      // always pop both, so a thrown exception can't leave scopes on the device
      validation = await device.popErrorScope();
      internal = await device.popErrorScope();
    }
    return validation ?? internal;
  };

  const buf = (bytes) => device.createBuffer({
    size: Math.ceil(bytes / 4) * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const src = buf(ci * w16 * h16 * 2);
  const dst = buf(co * w16 * h16 * 2);
  const wgt = buf(co * ci * 9 * 2);
  const bias = buf(co * 4);
  const alpha = buf(co * 4);

  // builds the variant's pipeline + bind group; returns run(k) = k back-to-back dispatches
  const build = (v) => {
    const gen = v.sg ? wgslConvRBSg : wgslConvRB;
    const p = device.createComputePipeline({ layout: 'auto', compute: {
      module: device.createShaderModule({ code: gen(ci, co, w16, h16, w16, h16, false, v) }),
      entryPoint: 'main' } });
    const bg = device.createBindGroup({ layout: p.getBindGroupLayout(0), entries: [
      { binding: 0, resource: { buffer: src } }, { binding: 1, resource: { buffer: wgt } },
      { binding: 2, resource: { buffer: bias } }, { binding: 3, resource: { buffer: alpha } },
      { binding: 4, resource: { buffer: dst } }] });
    return (k) => {
      const enc = device.createCommandEncoder();
      const pass = enc.beginComputePass();
      pass.setPipeline(p); pass.setBindGroup(0, bg);
      for (let i = 0; i < k; i++) pass.dispatchWorkgroups(Math.ceil(w16 / 16), Math.ceil(h16 / 16), co / v.coc);
      pass.end();
      device.queue.submit([enc.finish()]);
    };
  };

  let best = null;
  try {
    for (const v of variants) {
      let run = null;
      const err = await captureGpuError(async () => {
        run = build(v);
        run(3); // warm (incl pipeline compile); also surfaces submit-time validation errors
        await device.queue.onSubmittedWorkDone();
      });
      if (err) {
        // expected on some drivers (e.g. subgroup variants): exclude, don't time
        console.warn(`tuneConvRB: skipping variant ${JSON.stringify(v)}: ${err.message}`);
        continue;
      }
      const t0 = performance.now();
      run(30); await device.queue.onSubmittedWorkDone();
      const ms = (performance.now() - t0) / 30;
      if (!best || ms < best.ms) best = { ...v, ms };
    }
  } finally {
    [src, dst, wgt, bias, alpha].forEach((b) => b.destroy());
  }
  return best;
}
|