@framecast/rt 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (6) hide show
  1. package/LICENSE +18 -0
  2. package/README.md +78 -0
  3. package/index.d.ts +57 -0
  4. package/package.json +43 -0
  5. package/rt.js +1254 -0
  6. package/sr.js +205 -0
package/sr.js ADDED
@@ -0,0 +1,205 @@
1
+ // Framecast tiny 2x super-resolution pass (anime upscale) for the present path.
2
+ // Residual-vs-bilinear: out = bilinear2x(src) + detail, detail = 4 tiny 3x3 convs (c1..c4; C channels, read from the weights, typically 16) + pixel shuffle.
3
+ // Texture in -> texture out (2x size); everything stays on the GPU.
4
+ // Weights: assets/rt_sr.{bin,json} (tools/export_sr_weights.py).
5
+
6
+ import { wgslConvRB, WGSL_TO_F16 } from './rt.js';
7
+
8
// Workgroup edge length (threads per side) for wgslIn's @workgroup_size(WG, WG).
// createSR's gx() divides dispatch extents by the same value, so the two must agree.
const WG = 8;
9
+
10
// Builds the WGSL for the input conv c1 (3 -> C channels, 3x3 taps) + PReLU.
// One invocation per input pixel:
// - Gathers the 3x3 neighbourhood once (clamped to the image edge).
// - Reorders RGB to BGR.
// - Writes all C outputs as f16 into `dst`, laid out [C,H,W]
//   (index co*H*W + y*W + x).
// Weights are read as wgt[co*27 + ci*9 + k], with ci 0/1/2 = B/G/R.
// W and H come from the `dims` buffer at run time; only C is baked into
// the source.
function wgslIn(C) {
  // C becomes the literal bound of the output-channel loop.
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var src: texture_2d<f32>;
@group(0) @binding(1) var samp: sampler;
@group(0) @binding(2) var<storage, read> wgt: array<f32>; // [C,3,3,3]
@group(0) @binding(3) var<storage, read> bias: array<f32>; // [C]
@group(0) @binding(4) var<storage, read> alpha: array<f32>; // [C]
@group(0) @binding(5) var<storage, read_write> dst: array<f16>; // [C,H,W]
@group(0) @binding(6) var<storage, read> dims: array<u32>; // [W,H]

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let W = i32(dims[0]); let H = i32(dims[1]);
let x = i32(gid.x); let y = i32(gid.y);
if (x >= W || y >= H) { return; }
var px: array<vec3<f32>, 9>;
for (var ky = 0; ky < 3; ky++) {
for (var kx = 0; kx < 3; kx++) {
let sx = clamp(x + kx - 1, 0, W - 1);
let sy = clamp(y + ky - 1, 0, H - 1);
let uv = (vec2<f32>(f32(sx), f32(sy)) + 0.5) / vec2<f32>(f32(W), f32(H));
let c = textureSampleLevel(src, samp, uv, 0.0).rgb;
px[ky * 3 + kx] = vec3<f32>(c.b, c.g, c.r); // BGR domain like the rest
}
}
for (var co = 0; co < ${C}; co++) {
var acc = bias[co];
for (var k = 0; k < 9; k++) {
let wb = co * 27 + k;
acc += px[k].x * wgt[wb] + px[k].y * wgt[wb + 9] + px[k].z * wgt[wb + 18];
}
let v = select(alpha[co] * acc, acc, acc >= 0.0);
dst[co * H * W + y * W + x] = f16(v);
}
}`;
}
47
+
48
+ // NOTE: torch Conv2d weight is [CO,CI,3,3]; the wgslIn indexing above expects
49
+ // [CO][CI][k] flattened as co*27 + ci*9 + k - matches torch layout directly.
50
+ // mid convs use the register-blocked kernel from rt.js (2x2 patch x 4 output
51
+ // channels per thread, shared-memory tiles) - the naive per-pixel loops cost
52
+ // 3.5x more at c=32. The out conv (C->12) runs on the same kernel as 3 z-slices
53
+ // of 4 channels; since co = ch*4 + sub, each slice is one colour's four 2x
54
+ // quadrants, so no per-quadrant thread re-reads the same CxHxW window.
55
// Builds the WGSL for the pixel-shuffle + residual pass, specialised for a
// W x H input. One invocation per OUTPUT pixel (2W x 2H). Each invocation:
// - Picks its 2x2 sub-position `sub`.
// - Reads the b/g/r detail at channels sub, sub+4, sub+8 of the f16
//   [12,H,W] out-conv result (torch PixelShuffle layout).
// - Adds them to a sample of srcTex through `samp` at the 2x pixel centre.
// - Clamps to [0,1] and stores RGBA into the rgba8unorm storage texture.
// W and H are baked in as literals, so every input size needs its own
// pipeline.
// The workgroup size uses WG because the caller dispatches
// ceil(2W / WG) x ceil(2H / WG) groups. A hard-coded size would leave part
// of the output unwritten as soon as WG changed.
function wgslShuffle(W, H) {
  return /* wgsl */`
enable f16;
@group(0) @binding(0) var<storage, read> det: array<f16>; // [12,H,W] conv output
@group(0) @binding(1) var srcTex: texture_2d<f32>;
@group(0) @binding(2) var samp: sampler;
@group(0) @binding(3) var outTex: texture_storage_2d<rgba8unorm, write>;

@compute @workgroup_size(${WG}, ${WG})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let ox = i32(gid.x); let oy = i32(gid.y);
if (ox >= ${W * 2} || oy >= ${H * 2}) { return; }
let x = ox / 2; let y = oy / 2;
let sub = (oy & 1) * 2 + (ox & 1);
let p = y * ${W} + x;
// detail channels [b,g,r] live at co = ch*4 + sub (torch PixelShuffle layout)
let db = f32(det[u32(sub) * ${H * W}u + u32(p)]);
let dg = f32(det[(u32(sub) + 4u) * ${H * W}u + u32(p)]);
let dr = f32(det[(u32(sub) + 8u) * ${H * W}u + u32(p)]);
let uv = (vec2<f32>(f32(ox), f32(oy)) + 0.5) / vec2<f32>(${W * 2}.0, ${H * 2}.0);
let base = textureSampleLevel(srcTex, samp, uv, 0.0).rgb;
let b = clamp(base.b + db, 0.0, 1.0);
let g = clamp(base.g + dg, 0.0, 1.0);
let r = clamp(base.r + dr, 0.0, 1.0);
textureStore(outTex, vec2<i32>(ox, oy), vec4<f32>(r, g, b, 1.0));
}`;
}
82
+
83
+ export async function createSR(device, { weightsBin, weightsManifest, channels }) {
84
+ // channel width lives in the weights: c1 is the 3->C input conv
85
+ const C = channels || (weightsManifest['c1.weight'] ? weightsManifest['c1.weight'].shape[0] : 16);
86
+ const bufN = (bytes) => device.createBuffer({
87
+ size: Math.ceil(bytes / 4) * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST });
88
+
89
+ const wbuf = {};
90
+ for (const [name, m] of Object.entries(weightsManifest)) {
91
+ const n = m.shape.reduce((a, b) => a * b, 1);
92
+ wbuf[name] = bufN(n * 4);
93
+ device.queue.writeBuffer(wbuf[name], 0, weightsBin, m.offset * 4, n * 4);
94
+ }
95
+ const sampler = device.createSampler({ magFilter: 'linear', minFilter: 'linear',
96
+ addressModeU: 'clamp-to-edge', addressModeV: 'clamp-to-edge' });
97
+ const pipe = (code) => device.createComputePipeline({ layout: 'auto',
98
+ compute: { module: device.createShaderModule({ code }), entryPoint: 'main' } });
99
+ const pIn = pipe(wgslIn(C));
100
+ // f16 copies of the heavy conv weights (the RB kernels read f16)
101
+ const pToH = pipe(WGSL_TO_F16);
102
+ const wbufH = {};
103
+ {
104
+ const enc = device.createCommandEncoder();
105
+ for (const name of ['c2.weight', 'c3.weight', 'c4.weight']) {
106
+ const n = weightsManifest[name].shape.reduce((a, b) => a * b, 1);
107
+ wbufH[name] = device.createBuffer({ size: Math.ceil(n / 2) * 4,
108
+ usage: GPUBufferUsage.STORAGE });
109
+ const pass = enc.beginComputePass();
110
+ pass.setPipeline(pToH);
111
+ pass.setBindGroup(0, device.createBindGroup({ layout: pToH.getBindGroupLayout(0),
112
+ entries: [{ binding: 0, resource: { buffer: wbuf[name] } },
113
+ { binding: 1, resource: { buffer: wbufH[name] } }] }));
114
+ pass.dispatchWorkgroups(Math.ceil(n / 256));
115
+ pass.end();
116
+ }
117
+ device.queue.submit([enc.finish()]);
118
+ }
119
+ const onesAlpha = bufN(12 * 4);
120
+ device.queue.writeBuffer(onesAlpha, 0, new Float32Array(12).fill(1));
121
+
122
+ // per-input-size state (feature buffers + dims); keyed by "WxH"
123
+ const states = new Map();
124
+ function stateFor(w, h) {
125
+ const k = w + 'x' + h;
126
+ if (!states.has(k)) {
127
+ const dims = bufN(8);
128
+ device.queue.writeBuffer(dims, 0, new Uint32Array([w, h]));
129
+ const fa = bufN(C * w * h * 2);
130
+ const fb = bufN(C * w * h * 2);
131
+ const det = bufN(12 * w * h * 2);
132
+ const pMid = pipe(wgslConvRB(C, C, w, h, w, h, false));
133
+ const pOutConv = pipe(wgslConvRB(C, 12, w, h, w, h, false));
134
+ const pShuf = pipe(wgslShuffle(w, h));
135
+ const midBg = (wname, aname, sBuf, dBuf) => device.createBindGroup({
136
+ layout: pMid.getBindGroupLayout(0), entries: [
137
+ { binding: 0, resource: { buffer: sBuf } },
138
+ { binding: 1, resource: { buffer: wbufH[wname] } },
139
+ { binding: 2, resource: { buffer: wbuf[wname.replace('.weight', '.bias')] } },
140
+ { binding: 3, resource: { buffer: wbuf[aname] } },
141
+ { binding: 4, resource: { buffer: dBuf } }] });
142
+ states.set(k, {
143
+ dims, fa, fb, det, pMid, pOutConv, pShuf,
144
+ bgM2: midBg('c2.weight', 'a2.weight', fa, fb),
145
+ bgM3: midBg('c3.weight', 'a3.weight', fb, fa),
146
+ bgOut: device.createBindGroup({ layout: pOutConv.getBindGroupLayout(0), entries: [
147
+ { binding: 0, resource: { buffer: fa } },
148
+ { binding: 1, resource: { buffer: wbufH['c4.weight'] } },
149
+ { binding: 2, resource: { buffer: wbuf['c4.bias'] } },
150
+ { binding: 3, resource: { buffer: onesAlpha } },
151
+ { binding: 4, resource: { buffer: det } }] }),
152
+ });
153
+ if (states.size > 6) { // sizes changed wholesale
154
+ for (const [kk, s] of states) if (kk !== k) { s.fa.destroy(); s.fb.destroy(); s.det.destroy(); s.dims.destroy(); states.delete(kk); }
155
+ }
156
+ }
157
+ return states.get(k);
158
+ }
159
+
160
+ const gx = (n) => Math.ceil(n / WG);
161
+ // keyed by texture OBJECT, not label: the host rebuilds its texture pools on
162
+ // settings changes reusing the same labels - a label-keyed cache then binds
163
+ // views of DESTROYED textures and every SR'd frame comes out corrupted until
164
+ // page reload. WeakMaps also let dead textures drop their bind groups with GC.
165
+ const bgCache = new WeakMap();
166
+
167
+ // srcTex (w x h) -> dstTex (2w x 2h, rgba8unorm STORAGE_BINDING)
168
+ function process(srcTex, dstTex, w, h) {
169
+ const S = stateFor(w, h);
170
+ let perSrc = bgCache.get(srcTex);
171
+ if (!perSrc) { perSrc = new WeakMap(); bgCache.set(srcTex, perSrc); }
172
+ let bgs = perSrc.get(dstTex);
173
+ if (!bgs) {
174
+ const srcView = srcTex.createView();
175
+ bgs = {
176
+ in: device.createBindGroup({ layout: pIn.getBindGroupLayout(0), entries: [
177
+ { binding: 0, resource: srcView }, { binding: 1, resource: sampler },
178
+ { binding: 2, resource: { buffer: wbuf['c1.weight'] } },
179
+ { binding: 3, resource: { buffer: wbuf['c1.bias'] } },
180
+ { binding: 4, resource: { buffer: wbuf['a1.weight'] } },
181
+ { binding: 5, resource: { buffer: S.fa } },
182
+ { binding: 6, resource: { buffer: S.dims } }] }),
183
+ shuf: device.createBindGroup({ layout: S.pShuf.getBindGroupLayout(0), entries: [
184
+ { binding: 0, resource: { buffer: S.det } },
185
+ { binding: 1, resource: srcView }, { binding: 2, resource: sampler },
186
+ { binding: 3, resource: dstTex.createView() }] }),
187
+ };
188
+ perSrc.set(dstTex, bgs);
189
+ }
190
+ const enc = device.createCommandEncoder();
191
+ const pass = enc.beginComputePass();
192
+ pass.setPipeline(pIn); pass.setBindGroup(0, bgs.in); pass.dispatchWorkgroups(gx(w), gx(h));
193
+ pass.setPipeline(S.pMid);
194
+ pass.setBindGroup(0, S.bgM2); pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), C / 4);
195
+ pass.setBindGroup(0, S.bgM3); pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), C / 4);
196
+ pass.setPipeline(S.pOutConv); pass.setBindGroup(0, S.bgOut);
197
+ pass.dispatchWorkgroups(Math.ceil(w / 16), Math.ceil(h / 16), 3);
198
+ pass.setPipeline(S.pShuf); pass.setBindGroup(0, bgs.shuf);
199
+ pass.dispatchWorkgroups(gx(w * 2), gx(h * 2));
200
+ pass.end();
201
+ device.queue.submit([enc.finish()]);
202
+ }
203
+
204
+ return { process };
205
+ }