@framecast/rt 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +18 -0
- package/README.md +78 -0
- package/index.d.ts +57 -0
- package/package.json +43 -0
- package/rt.js +1254 -0
- package/sr.js +205 -0
package/sr.js
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
// Framecast tiny 2x super-resolution pass (anime upscale) for the present path.
// Residual-vs-bilinear: out = bilinear2x(src) + detail, where detail comes from
// four small convs (c1 3->C, c2 C->C, c3 C->C, c4 C->12) followed by a 2x pixel
// shuffle. C is taken from the weights (c1's output width), 16 if absent.
// Texture in -> texture out (2x size); everything stays on the GPU.
// Weights: assets/rt_sr.{bin,json} (tools/export_sr_weights.py).
|
|
5
|
+
|
|
6
|
+
import { wgslConvRB, WGSL_TO_F16 } from './rt.js';
|
|
7
|
+
|
|
8
|
+
// Workgroup edge length: wgslIn runs WG x WG threads per workgroup, and
// createSR's `gx` helper sizes per-pixel dispatches as ceil(n / WG), so any
// shader dispatched through `gx` must be compiled with this same size.
const WG = 8;
|
|
9
|
+
|
|
10
|
+
/**
 * WGSL source for the input conv (3 -> C channels, 3x3 taps, PReLU),
 * one thread per low-res pixel.
 *
 * The source texture is sampled at texel centres, swizzled to BGR, and
 * C f16 feature planes are written to `dst` laid out [C,H,W]. W and H are
 * read from the `dims` buffer at run time, so a single pipeline serves every
 * input size; only the channel count is baked in.
 *
 * @param {number} C - number of output channels (loop bound in the shader).
 * @returns {string} WGSL module text with entry point `main`.
 */
function wgslIn(C) {
  const wgSize = `${WG}, ${WG}`;
  const channelCount = C;
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var src: texture_2d<f32>;
@group(0) @binding(1) var samp: sampler;
@group(0) @binding(2) var<storage, read> wgt: array<f32>; // [C,3,3,3]
@group(0) @binding(3) var<storage, read> bias: array<f32>; // [C]
@group(0) @binding(4) var<storage, read> alpha: array<f32>; // [C]
@group(0) @binding(5) var<storage, read_write> dst: array<f16>; // [C,H,W]
@group(0) @binding(6) var<storage, read> dims: array<u32>; // [W,H]

@compute @workgroup_size(${wgSize})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let W = i32(dims[0]); let H = i32(dims[1]);
  let x = i32(gid.x); let y = i32(gid.y);
  if (x >= W || y >= H) { return; }
  var px: array<vec3<f32>, 9>;
  for (var ky = 0; ky < 3; ky++) {
    for (var kx = 0; kx < 3; kx++) {
      let sx = clamp(x + kx - 1, 0, W - 1);
      let sy = clamp(y + ky - 1, 0, H - 1);
      let uv = (vec2<f32>(f32(sx), f32(sy)) + 0.5) / vec2<f32>(f32(W), f32(H));
      let c = textureSampleLevel(src, samp, uv, 0.0).rgb;
      px[ky * 3 + kx] = vec3<f32>(c.b, c.g, c.r); // BGR domain like the rest
    }
  }
  for (var co = 0; co < ${channelCount}; co++) {
    var acc = bias[co];
    for (var k = 0; k < 9; k++) {
      let wb = co * 27 + k;
      acc += px[k].x * wgt[wb] + px[k].y * wgt[wb + 9] + px[k].z * wgt[wb + 18];
    }
    let v = select(alpha[co] * acc, acc, acc >= 0.0);
    dst[co * H * W + y * W + x] = f16(v);
  }
}`;
}
|
|
47
|
+
|
|
48
|
+
// NOTE: torch Conv2d weight is [CO,CI,3,3]; the wgslIn indexing above expects
// [CO][CI][k] flattened as co*27 + ci*9 + k - matches torch layout directly.
// Input channel ci = 0/1/2 is fed B/G/R (wgslIn swizzles the texture to BGR),
// so the exported c1 weights are expected in BGR channel order.
// The mid convs use the register-blocked kernel from rt.js (2x2 patch x 4 output
// channels per thread, shared-memory tiles) - the naive per-pixel loops cost
// 3.5x more at c=32. The out conv (C->12) reuses that kernel, dispatched over
// 3 z-slices of 4 output channels, and produces all 12 pixel-shuffle channels
// at LOW resolution; the shuffle pass below only gathers them. Evaluating the
// conv once per 2x output pixel instead would read the same CxHxW window four
// times.
|
|
55
|
+
/**
 * WGSL source for the final pass: pixel-shuffle gather + bilinear residual.
 *
 * One thread per 2x output pixel. Each thread picks its sub-pixel's three
 * detail values from `det` ([12,H,W] f16, channel = colour*4 + sub, colour
 * order B,G,R as in torch PixelShuffle). It adds them to the source sampled
 * bilinearly at the output pixel centre, clamps to [0,1], and stores RGBA
 * (alpha = 1) into `outTex`.
 *
 * The workgroup size is taken from WG because createSR dispatches this pass
 * with ceil(2w / WG) x ceil(2h / WG) workgroups. A hard-coded size would leave
 * output pixels unwritten if WG ever changed.
 *
 * @param {number} W - low-res (input) width, baked into the shader.
 * @param {number} H - low-res (input) height, baked into the shader.
 * @returns {string} WGSL module text with entry point `main`.
 */
function wgslShuffle(W, H) {
  const OW = W * 2;
  const OH = H * 2;
  const HW = H * W;
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> det: array<f16>; // [12,H,W] conv output
@group(0) @binding(1) var srcTex: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var outTex: texture_storage_2d<rgba8unorm, write>;

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let ox = i32(gid.x); let oy = i32(gid.y);
  if (ox >= ${OW} || oy >= ${OH}) { return; }
  let x = ox / 2; let y = oy / 2;
  let sub = (oy & 1) * 2 + (ox & 1);
  let p = y * ${W} + x;
  // detail channels [b,g,r] live at co = ch*4 + sub (torch PixelShuffle layout)
  let db = f32(det[u32(sub) * ${HW}u + u32(p)]);
  let dg = f32(det[(u32(sub) + 4u) * ${HW}u + u32(p)]);
  let dr = f32(det[(u32(sub) + 8u) * ${HW}u + u32(p)]);
  let uv = (vec2<f32>(f32(ox), f32(oy)) + 0.5) / vec2<f32>(${OW}.0, ${OH}.0);
  let base = textureSampleLevel(srcTex, samp, uv, 0.0).rgb;
  let b = clamp(base.b + db, 0.0, 1.0);
  let g = clamp(base.g + dg, 0.0, 1.0);
  let r = clamp(base.r + dr, 0.0, 1.0);
  textureStore(outTex, vec2<i32>(ox, oy), vec4<f32>(r, g, b, 1.0));
}`;
}
|
|
82
|
+
|
|
83
|
+
/**
 * Create the 2x super-resolution pass on `device`.
 *
 * Setup does three things:
 * - uploads every tensor in `weightsManifest` once as f32;
 * - makes f16 copies of the c2/c3/c4 conv weights for the register-blocked
 *   kernels;
 * - returns `process`, which runs the whole network (input conv, two mid
 *   convs, out conv, pixel-shuffle + bilinear residual) for one frame.
 *
 * Per-input-size resources (feature buffers, size-specialised pipelines) are
 * created lazily and evicted wholesale when more than 6 sizes accumulate.
 *
 * @param {GPUDevice} device
 * @param {object} opts
 * @param {ArrayBuffer | ArrayBufferView} opts.weightsBin - packed f32 weight data.
 * @param {Object<string, {shape: number[], offset: number}>} opts.weightsManifest -
 *   per-tensor shape and offset, the offset counted in f32 elements from the
 *   start of `weightsBin`.
 * @param {number} [opts.channels] - feature width C; falls back to
 *   `c1.weight.shape[0]`, then 16.
 * @returns {Promise<{process: (srcTex: GPUTexture, dstTex: GPUTexture, w: number, h: number) => void}>}
 * @throws {Error} if a tensor the passes bind is missing from the manifest.
 */
export async function createSR(device, { weightsBin, weightsManifest, channels }) {
  // channel width lives in the weights: c1 is the 3->C input conv
  const C = channels || (weightsManifest['c1.weight'] ? weightsManifest['c1.weight'].shape[0] : 16);

  // Every tensor the passes below bind. Checking up front gives one clear error
  // instead of an opaque TypeError from createBindGroup / the f16 loop later.
  const required = [
    'c1.weight', 'c1.bias', 'a1.weight',
    'c2.weight', 'c2.bias', 'a2.weight',
    'c3.weight', 'c3.bias', 'a3.weight',
    'c4.weight', 'c4.bias',
  ];
  const missing = required.filter((name) => !weightsManifest[name]);
  if (missing.length > 0) {
    throw new Error(`createSR: weights manifest is missing ${missing.join(', ')}`);
  }

  const numel = (shape) => shape.reduce((a, b) => a * b, 1);
  const bufN = (bytes) => device.createBuffer({
    size: Math.ceil(bytes / 4) * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST });

  // GPUQueue.writeBuffer counts dataOffset/size in ELEMENTS for a TypedArray
  // and in BYTES for an ArrayBuffer. The offsets below are bytes (f32 index * 4),
  // so always pass a byte view of the same memory. Otherwise a Float32Array
  // would be read from 4x too far in.
  const binBytes = ArrayBuffer.isView(weightsBin)
    ? new Uint8Array(weightsBin.buffer, weightsBin.byteOffset, weightsBin.byteLength)
    : new Uint8Array(weightsBin);

  // f32 copy of every tensor in the manifest
  const wbuf = {};
  for (const [name, m] of Object.entries(weightsManifest)) {
    const n = numel(m.shape);
    wbuf[name] = bufN(n * 4);
    device.queue.writeBuffer(wbuf[name], 0, binBytes, m.offset * 4, n * 4);
  }
  const sampler = device.createSampler({ magFilter: 'linear', minFilter: 'linear',
    addressModeU: 'clamp-to-edge', addressModeV: 'clamp-to-edge' });
  const pipe = (code) => device.createComputePipeline({ layout: 'auto',
    compute: { module: device.createShaderModule({ code }), entryPoint: 'main' } });
  const pIn = pipe(wgslIn(C));

  // f16 copies of the heavy conv weights (the RB kernels read f16)
  const pToH = pipe(WGSL_TO_F16);
  const wbufH = {};
  {
    const enc = device.createCommandEncoder();
    for (const name of ['c2.weight', 'c3.weight', 'c4.weight']) {
      const n = numel(weightsManifest[name].shape);
      wbufH[name] = device.createBuffer({ size: Math.ceil(n / 2) * 4,
        usage: GPUBufferUsage.STORAGE });
      const pass = enc.beginComputePass();
      pass.setPipeline(pToH);
      pass.setBindGroup(0, device.createBindGroup({ layout: pToH.getBindGroupLayout(0),
        entries: [{ binding: 0, resource: { buffer: wbuf[name] } },
          { binding: 1, resource: { buffer: wbufH[name] } }] }));
      pass.dispatchWorkgroups(Math.ceil(n / 256));
      pass.end();
    }
    device.queue.submit([enc.finish()]);
  }

  // alpha = 1 turns the RB kernel's PReLU into identity for the out conv
  const onesAlpha = bufN(12 * 4);
  device.queue.writeBuffer(onesAlpha, 0, new Float32Array(12).fill(1));

  // per-input-size state (feature buffers + dims + size-baked pipelines); keyed by "WxH"
  const states = new Map();
  function stateFor(w, h) {
    const k = `${w}x${h}`;
    if (!states.has(k)) {
      const dims = bufN(8);
      device.queue.writeBuffer(dims, 0, new Uint32Array([w, h]));
      const fa = bufN(C * w * h * 2);
      const fb = bufN(C * w * h * 2);
      const det = bufN(12 * w * h * 2);
      const pMid = pipe(wgslConvRB(C, C, w, h, w, h, false));
      const pOutConv = pipe(wgslConvRB(C, 12, w, h, w, h, false));
      const pShuf = pipe(wgslShuffle(w, h));
      const midBg = (wname, aname, sBuf, dBuf) => device.createBindGroup({
        layout: pMid.getBindGroupLayout(0), entries: [
          { binding: 0, resource: { buffer: sBuf } },
          { binding: 1, resource: { buffer: wbufH[wname] } },
          { binding: 2, resource: { buffer: wbuf[wname.replace('.weight', '.bias')] } },
          { binding: 3, resource: { buffer: wbuf[aname] } },
          { binding: 4, resource: { buffer: dBuf } }] });
      states.set(k, {
        dims, fa, fb, det, pMid, pOutConv, pShuf,
        bgM2: midBg('c2.weight', 'a2.weight', fa, fb),
        bgM3: midBg('c3.weight', 'a3.weight', fb, fa),
        bgOut: device.createBindGroup({ layout: pOutConv.getBindGroupLayout(0), entries: [
          { binding: 0, resource: { buffer: fa } },
          { binding: 1, resource: { buffer: wbufH['c4.weight'] } },
          { binding: 2, resource: { buffer: wbuf['c4.bias'] } },
          { binding: 3, resource: { buffer: onesAlpha } },
          { binding: 4, resource: { buffer: det } }] }),
      });
      if (states.size > 6) { // sizes changed wholesale
        for (const [kk, s] of states) {
          if (kk !== k) {
            s.fa.destroy(); s.fb.destroy(); s.det.destroy(); s.dims.destroy();
            states.delete(kk);
          }
        }
      }
    }
    return states.get(k);
  }

  const gx = (n) => Math.ceil(n / WG);
  // keyed by texture OBJECT, not label: the host rebuilds its texture pools on
  // settings changes reusing the same labels - a label-keyed cache then binds
  // views of DESTROYED textures and every SR'd frame comes out corrupted until
  // page reload. WeakMaps also let dead textures drop their bind groups with GC.
  const bgCache = new WeakMap();

  // srcTex (w x h) -> dstTex (2w x 2h, rgba8unorm STORAGE_BINDING).
  // w/h are trusted to match srcTex's size (not checked here).
  function process(srcTex, dstTex, w, h) {
    const S = stateFor(w, h);
    let perSrc = bgCache.get(srcTex);
    if (!perSrc) {
      perSrc = new WeakMap();
      bgCache.set(srcTex, perSrc);
    }
    let bgs = perSrc.get(dstTex);
    // These bind groups reference S's buffers (fa, dims, det) and S.pShuf's
    // 'auto' layout. If this size's state was evicted and rebuilt since they
    // were created, those buffers are destroyed and the layout belongs to a
    // dead pipeline, so rebuild rather than reuse.
    if (!bgs || bgs.state !== S) {
      const srcView = srcTex.createView();
      bgs = {
        state: S,
        in: device.createBindGroup({ layout: pIn.getBindGroupLayout(0), entries: [
          { binding: 0, resource: srcView }, { binding: 1, resource: sampler },
          { binding: 2, resource: { buffer: wbuf['c1.weight'] } },
          { binding: 3, resource: { buffer: wbuf['c1.bias'] } },
          { binding: 4, resource: { buffer: wbuf['a1.weight'] } },
          { binding: 5, resource: { buffer: S.fa } },
          { binding: 6, resource: { buffer: S.dims } }] }),
        shuf: device.createBindGroup({ layout: S.pShuf.getBindGroupLayout(0), entries: [
          { binding: 0, resource: { buffer: S.det } },
          { binding: 1, resource: srcView }, { binding: 2, resource: sampler },
          { binding: 3, resource: dstTex.createView() }] }),
      };
      perSrc.set(dstTex, bgs);
    }
    const enc = device.createCommandEncoder();
    const pass = enc.beginComputePass();
    // input conv: fa = c1(src)
    pass.setPipeline(pIn); pass.setBindGroup(0, bgs.in); pass.dispatchWorkgroups(gx(w), gx(h));
    // mid convs ping-pong: fb = c2(fa), fa = c3(fb)
    pass.setPipeline(S.pMid);
    pass.setBindGroup(0, S.bgM2); pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), C / 4);
    pass.setBindGroup(0, S.bgM3); pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), C / 4);
    // out conv: det = c4(fa), 12 channels in 3 z-slices
    pass.setPipeline(S.pOutConv); pass.setBindGroup(0, S.bgOut);
    pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), 3);
    // pixel shuffle + bilinear residual into dstTex
    pass.setPipeline(S.pShuf); pass.setBindGroup(0, bgs.shuf);
    pass.dispatchWorkgroups(gx(w * 2), gx(h * 2));
    pass.end();
    device.queue.submit([enc.finish()]);
  }

  return { process };
}
|