@octoseq/mir 0.1.0-main.0d2814e

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/dist/chunk-DUWYCAVG.js +1525 -0
  2. package/dist/chunk-DUWYCAVG.js.map +1 -0
  3. package/dist/index.d.ts +450 -0
  4. package/dist/index.js +1234 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/runMir-CSIBwNZ3.d.ts +84 -0
  7. package/dist/runner/runMir.d.ts +2 -0
  8. package/dist/runner/runMir.js +3 -0
  9. package/dist/runner/runMir.js.map +1 -0
  10. package/dist/runner/workerProtocol.d.ts +169 -0
  11. package/dist/runner/workerProtocol.js +11 -0
  12. package/dist/runner/workerProtocol.js.map +1 -0
  13. package/dist/types-BE3py4fZ.d.ts +83 -0
  14. package/package.json +55 -0
  15. package/src/dsp/fft.ts +22 -0
  16. package/src/dsp/fftBackend.ts +53 -0
  17. package/src/dsp/fftBackendFftjs.ts +60 -0
  18. package/src/dsp/hpss.ts +152 -0
  19. package/src/dsp/hpssGpu.ts +101 -0
  20. package/src/dsp/mel.ts +219 -0
  21. package/src/dsp/mfcc.ts +119 -0
  22. package/src/dsp/onset.ts +205 -0
  23. package/src/dsp/peakPick.ts +112 -0
  24. package/src/dsp/spectral.ts +95 -0
  25. package/src/dsp/spectrogram.ts +176 -0
  26. package/src/gpu/README.md +34 -0
  27. package/src/gpu/context.ts +44 -0
  28. package/src/gpu/helpers.ts +87 -0
  29. package/src/gpu/hpssMasks.ts +116 -0
  30. package/src/gpu/kernels/hpssMasks.wgsl.ts +137 -0
  31. package/src/gpu/kernels/melProject.wgsl.ts +48 -0
  32. package/src/gpu/kernels/onsetEnvelope.wgsl.ts +56 -0
  33. package/src/gpu/melProject.ts +98 -0
  34. package/src/gpu/onsetEnvelope.ts +81 -0
  35. package/src/gpu/webgpu.d.ts +176 -0
  36. package/src/index.ts +121 -0
  37. package/src/runner/runMir.ts +431 -0
  38. package/src/runner/workerProtocol.ts +189 -0
  39. package/src/search/featureVectorV1.ts +123 -0
  40. package/src/search/fingerprintV1.ts +230 -0
  41. package/src/search/refinedModelV1.ts +321 -0
  42. package/src/search/searchTrackV1.ts +206 -0
  43. package/src/search/searchTrackV1Guided.ts +863 -0
  44. package/src/search/similarity.ts +98 -0
  45. package/src/types.ts +105 -0
  46. package/src/util/display.ts +80 -0
  47. package/src/util/normalise.ts +58 -0
  48. package/src/util/stats.ts +25 -0
@@ -0,0 +1,95 @@
1
+ import type { Spectrogram } from "./spectrogram";
2
+
3
+ /**
4
+ * Spectral centroid per frame (Hz).
5
+ *
6
+ * Output is aligned 1:1 with `spec.times`.
7
+ */
8
+ export function spectralCentroid(spec: Spectrogram): Float32Array {
9
+ const nFrames = spec.times.length;
10
+ const out = new Float32Array(nFrames);
11
+
12
+ const nBins = (spec.fftSize >>> 1) + 1;
13
+ const binHz = spec.sampleRate / spec.fftSize;
14
+
15
+ for (let t = 0; t < nFrames; t++) {
16
+ const mags = spec.magnitudes[t];
17
+ if (!mags) {
18
+ out[t] = 0;
19
+ continue;
20
+ }
21
+
22
+ let num = 0;
23
+ let den = 0;
24
+
25
+ // DC..Nyquist inclusive.
26
+ for (let k = 0; k < nBins; k++) {
27
+ const m = mags[k] ?? 0;
28
+ const f = k * binHz;
29
+ num += f * m;
30
+ den += m;
31
+ }
32
+
33
+ out[t] = den > 0 ? num / den : 0;
34
+ }
35
+
36
+ return out;
37
+ }
38
+
39
+ /**
40
+ * Spectral flux per frame (unitless).
41
+ *
42
+ * Definition used here:
43
+ * - L1 distance between successive *normalised* magnitude spectra.
44
+ * - First frame flux is 0.
45
+ *
46
+ * Output is aligned 1:1 with `spec.times`.
47
+ */
48
+ export function spectralFlux(spec: Spectrogram): Float32Array {
49
+ const nFrames = spec.times.length;
50
+ const out = new Float32Array(nFrames);
51
+
52
+ const nBins = (spec.fftSize >>> 1) + 1;
53
+
54
+ let prev: Float32Array | null = null;
55
+
56
+ for (let t = 0; t < nFrames; t++) {
57
+ const mags = spec.magnitudes[t];
58
+ if (!mags) {
59
+ out[t] = 0;
60
+ prev = null;
61
+ continue;
62
+ }
63
+
64
+ // Normalise to reduce sensitivity to overall level.
65
+ let sum = 0;
66
+ for (let k = 0; k < nBins; k++) sum += mags[k] ?? 0;
67
+
68
+ if (sum <= 0) {
69
+ out[t] = 0;
70
+ prev = null;
71
+ continue;
72
+ }
73
+
74
+ const cur = new Float32Array(nBins);
75
+ const inv = 1 / sum;
76
+ for (let k = 0; k < nBins; k++) cur[k] = (mags[k] ?? 0) * inv;
77
+
78
+ if (!prev) {
79
+ out[t] = 0;
80
+ prev = cur;
81
+ continue;
82
+ }
83
+
84
+ let flux = 0;
85
+ for (let k = 0; k < nBins; k++) {
86
+ const d = (cur[k] ?? 0) - (prev[k] ?? 0);
87
+ flux += Math.abs(d);
88
+ }
89
+
90
+ out[t] = flux;
91
+ prev = cur;
92
+ }
93
+
94
+ return out;
95
+ }
@@ -0,0 +1,176 @@
1
+ import type { MirGPU } from "../gpu/context";
2
+
3
+ import { hannWindow } from "./fft";
4
+ import { getFftBackend } from "./fftBackend";
5
+
6
// AudioBufferLike is re-exported from the package root.
// Keeping this local type avoids importing from ../index (which can create circular deps).

/** Structural subset of the Web Audio `AudioBuffer` interface used by this module. */
export type AudioBufferLike = {
  sampleRate: number;
  getChannelData(channel: number): Float32Array;
  numberOfChannels: number;
};

/**
 * STFT parameters for `spectrogram()`.
 * `fftSize` must be a positive power of two; `hopSize` must be a positive
 * integer <= `fftSize` (both enforced at runtime in `spectrogram()`).
 */
export type SpectrogramConfig = {
  fftSize: number;
  hopSize: number;
  window: "hann"; // v0.1 supports only the Hann window
};

export type SpectrogramOptions = {
  /** Optional cancellation hook; checked once per frame. */
  isCancelled?: () => boolean;
};

/** Magnitude spectrogram as produced by `spectrogram()`. */
export type Spectrogram = {
  sampleRate: number;
  fftSize: number;
  hopSize: number;
  times: Float32Array; // seconds (center of each frame)
  magnitudes: Float32Array[]; // [frame][bin]
};
32
+
33
+ function assertPositiveInt(name: string, value: number): void {
34
+ if (!Number.isFinite(value) || value <= 0 || (value | 0) !== value) {
35
+ throw new Error(`@octoseq/mir: ${name} must be a positive integer`);
36
+ }
37
+ }
38
+
39
+ function mixToMono(audio: AudioBufferLike): Float32Array {
40
+ const nCh = audio.numberOfChannels;
41
+ if (nCh <= 0) {
42
+ throw new Error("@octoseq/mir: audio.numberOfChannels must be >= 1");
43
+ }
44
+
45
+ if (nCh === 1) {
46
+ return audio.getChannelData(0);
47
+ }
48
+
49
+ const length = audio.getChannelData(0).length;
50
+ const out = new Float32Array(length);
51
+
52
+ for (let ch = 0; ch < nCh; ch++) {
53
+ const data = audio.getChannelData(ch);
54
+ if (data.length !== length) {
55
+ throw new Error(
56
+ "@octoseq/mir: all channels must have equal length (AudioBuffer-like invariant)"
57
+ );
58
+ }
59
+ for (let i = 0; i < length; i++) {
60
+ // `out[i]` is `number|undefined` under `noUncheckedIndexedAccess`.
61
+ out[i] = (out[i] ?? 0) + (data[i] ?? 0);
62
+ }
63
+ }
64
+
65
+ const inv = 1 / nCh;
66
+ for (let i = 0; i < length; i++) out[i] = (out[i] ?? 0) * inv;
67
+
68
+ return out;
69
+ }
70
+
71
+ /**
72
+ * Compute a magnitude spectrogram.
73
+ *
74
+ * v0.1 implementation:
75
+ * - CPU STFT + FFT for correctness.
76
+ * - The function accepts an optional MirGPU to match the future API.
77
+ * (STFT/FFT is the largest dense math block and can be ported to WebGPU later.)
78
+ */
79
+ export async function spectrogram(
80
+ audio: AudioBufferLike,
81
+ config: SpectrogramConfig,
82
+ gpu?: MirGPU,
83
+ options: SpectrogramOptions = {}
84
+ ): Promise<Spectrogram> {
85
+ // Keep the parameter to make the expensive step explicitly reusable.
86
+ // (v0.1 computes FFT on CPU; GPU is accepted for future acceleration.)
87
+ void gpu;
88
+
89
+ assertPositiveInt("config.fftSize", config.fftSize);
90
+ assertPositiveInt("config.hopSize", config.hopSize);
91
+
92
+ if (config.window !== "hann") {
93
+ throw new Error(
94
+ `@octoseq/mir: unsupported window '${config.window}'. v0.1 supports only 'hann'.`
95
+ );
96
+ }
97
+
98
+ const fftSize = config.fftSize;
99
+ if ((fftSize & (fftSize - 1)) !== 0) {
100
+ throw new Error("@octoseq/mir: config.fftSize must be a power of two");
101
+ }
102
+
103
+ const hopSize = config.hopSize;
104
+ if (hopSize > fftSize) {
105
+ throw new Error(
106
+ "@octoseq/mir: config.hopSize must be <= config.fftSize"
107
+ );
108
+ }
109
+
110
+ const sr = audio.sampleRate;
111
+ const mono = mixToMono(audio);
112
+
113
+ // Number of frames with 'valid' windows.
114
+ // We prefer explicitness over padding for v0.1.
115
+ const nFrames = Math.max(0, 1 + Math.floor((mono.length - fftSize) / hopSize));
116
+
117
+ const times = new Float32Array(nFrames);
118
+ const mags: Float32Array[] = new Array(nFrames);
119
+
120
+ const window = hannWindow(fftSize);
121
+
122
+ // Reuse FFT plan and buffers across frames.
123
+ const fft = getFftBackend(fftSize);
124
+
125
+ // Preallocate frame buffer so we don't allocate per frame.
126
+ const windowedFrame = new Float32Array(fftSize);
127
+
128
+ // Minimal timing instrumentation. This is deliberately lightweight and only used
129
+ // for optional debug logging by callers (e.g. worker). We don't expose this in the public API.
130
+ let totalFftMs = 0;
131
+ const nowMs = (): number =>
132
+ typeof performance !== "undefined" ? performance.now() : Date.now();
133
+
134
+ for (let frame = 0; frame < nFrames; frame++) {
135
+ if (options.isCancelled?.()) {
136
+ throw new Error("@octoseq/mir: cancelled");
137
+ }
138
+ const start = frame * hopSize;
139
+
140
+ // time is the center of the analysis window.
141
+ times[frame] = (start + fftSize / 2) / sr;
142
+
143
+ // Apply windowing before FFT (same as previous implementation).
144
+ for (let i = 0; i < fftSize; i++) {
145
+ const s = mono[start + i] ?? 0;
146
+ windowedFrame[i] = s * (window[i] ?? 0);
147
+ }
148
+
149
+ const t0 = nowMs();
150
+ const { real, imag } = fft.forwardReal(windowedFrame);
151
+ totalFftMs += nowMs() - t0;
152
+
153
+ // Magnitudes: only keep the real-input half-spectrum [0..N/2] inclusive.
154
+ const nBins = (fftSize >>> 1) + 1;
155
+ const out = new Float32Array(nBins);
156
+ for (let k = 0; k < nBins; k++) {
157
+ const re = real[k] ?? 0;
158
+ const im = imag[k] ?? 0;
159
+ out[k] = Math.hypot(re, im);
160
+ }
161
+ mags[frame] = out;
162
+ }
163
+
164
+ // Attach as a non-enumerable debug field to avoid API changes.
165
+ // Consumers can optionally read it for profiling in development.
166
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
167
+ (mags as any).cpuFftTotalMs = totalFftMs;
168
+
169
+ return {
170
+ sampleRate: sr,
171
+ fftSize,
172
+ hopSize,
173
+ times,
174
+ magnitudes: mags
175
+ };
176
+ }
@@ -0,0 +1,34 @@
1
+ # WebGPU acceleration
2
+
3
+ This folder contains the WebGPU compute implementation that powers optional GPU paths in `@octoseq/mir`.
4
+
5
+ ## What runs on GPU?
6
+
7
+ - **Mel filterbank projection** (spectrogram magnitudes → mel bands) — real WGSL compute kernel.
8
+ - **HPSS mask estimation** — WGSL kernels for soft harmonic/percussive masks (see `hpssMasks.wgsl.ts`).
9
+ - FFT/STFT remains on CPU (see `src/dsp/spectrogram.ts`); GPU is used as an acceleration stage rather than a full pipeline.
10
+
11
+ ## Timing / observability
12
+
13
+ The GPU stage measures **submit → readback completion** time (`gpuSubmitToReadbackMs`) by awaiting `GPUBuffer.mapAsync()` on a readback buffer.
14
+ This timing is surfaced through:
15
+
16
+ - `MelSpectrogram.gpuTimings.gpuSubmitToReadbackMs`
17
+ - `MirResult.meta.timings.gpuMs` (prefers the submit→readback timing when present)
18
+
19
+ ## Numeric tolerance
20
+
21
+ GPU and CPU results may differ slightly due to floating point order-of-operations.
22
+ These differences should be small and not visually significant.
23
+ A reasonable tolerance for comparison is:
24
+
25
+ - `absDiff <= 1e-4` for individual mel bin values (after log10)
26
+ - HPSS masks are soft probabilities; expect small differences in the 1e-3 range.
27
+
28
+ ## Files
29
+
30
+ - `kernels/melProject.wgsl.ts` — WGSL kernel source
31
+ - `kernels/hpssMasks.wgsl.ts` — WGSL kernels for harmonic/percussive mask estimation
32
+ - `helpers.ts` — small buffer/dispatch/readback helpers
33
+ - `melProject.ts` — kernel wrapper that runs the projection and reads back `Float32Array`
34
+ - `hpssMasks.ts` — GPU HPSS mask orchestration + readback
@@ -0,0 +1,44 @@
1
+ /**
2
+ * WebGPU context wrapper for MIR computations.
3
+ *
4
+ * v0.1 scope:
5
+ * - Provide a safe, explicit way for callers to opt into GPU usage.
6
+ * - Throw a clear error when called outside the browser or when WebGPU is unavailable.
7
+ */
8
+ export class MirGPU {
9
+ public readonly device: GPUDevice;
10
+ public readonly queue: GPUQueue;
11
+
12
+ private constructor(device: GPUDevice) {
13
+ this.device = device;
14
+ this.queue = device.queue;
15
+ }
16
+
17
+ static async create(): Promise<MirGPU> {
18
+ // Next.js note: callers must create MirGPU from a client component.
19
+ if (typeof navigator === "undefined") {
20
+ throw new Error(
21
+ "@octoseq/mir: WebGPU is only available in the browser (navigator is undefined)."
22
+ );
23
+ }
24
+
25
+ const nav = navigator as Navigator & { gpu?: GPU };
26
+ if (!nav.gpu) {
27
+ throw new Error(
28
+ "@octoseq/mir: WebGPU is unavailable (navigator.gpu is missing). Use CPU mode or a WebGPU-capable browser."
29
+ );
30
+ }
31
+
32
+ const adapter = await nav.gpu.requestAdapter();
33
+ if (!adapter) {
34
+ throw new Error(
35
+ "@octoseq/mir: Failed to acquire a WebGPU adapter. WebGPU may be disabled or unsupported."
36
+ );
37
+ }
38
+
39
+ // We keep this minimal: no required features for v0.1.
40
+ const device = await adapter.requestDevice();
41
+
42
+ return new MirGPU(device);
43
+ }
44
+ }
@@ -0,0 +1,87 @@
1
+ import type { MirGPU } from "./context";
2
+
3
+ export function nowMs(): number {
4
+ return typeof performance !== "undefined" ? performance.now() : Date.now();
5
+ }
6
+
7
/** Timing captured for a single GPU dispatch. */
export type GpuStageTiming = {
  /** wall-clock time from queue.submit() to readback completion */
  gpuSubmitToReadbackMs: number;
};

/** A GPU stage's result value paired with its dispatch timing. */
export type GpuDispatchResult<T> = {
  value: T;
  timing: GpuStageTiming;
};
16
+
17
+ export function byteSizeF32(n: number): number {
18
+ return n * 4;
19
+ }
20
+
21
+ export function createAndWriteStorageBuffer(gpu: MirGPU, data: Float32Array): GPUBuffer {
22
+ const buf = gpu.device.createBuffer({
23
+ size: byteSizeF32(data.length),
24
+ usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
25
+ });
26
+ // Some TS lib definitions make BufferSource incompatible with ArrayBufferLike.
27
+ // WebGPU implementations accept typed arrays; cast to keep this package dependency-free.
28
+ gpu.queue.writeBuffer(buf, 0, data as unknown as BufferSource);
29
+ return buf;
30
+ }
31
+
32
+ export function createUniformBufferU32x4(gpu: MirGPU, u32x4: Uint32Array): GPUBuffer {
33
+ if (u32x4.length !== 4) throw new Error("@octoseq/mir: uniform buffer must be 4 u32 values");
34
+ const buf = gpu.device.createBuffer({
35
+ size: 16,
36
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
37
+ });
38
+ gpu.queue.writeBuffer(buf, 0, u32x4 as unknown as BufferSource);
39
+ return buf;
40
+ }
41
+
42
+ export function createStorageOutBuffer(gpu: MirGPU, byteLength: number): GPUBuffer {
43
+ return gpu.device.createBuffer({
44
+ size: byteLength,
45
+ usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
46
+ });
47
+ }
48
+
49
+ export function createReadbackBuffer(gpu: MirGPU, byteLength: number): GPUBuffer {
50
+ return gpu.device.createBuffer({
51
+ size: byteLength,
52
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
53
+ });
54
+ }
55
+
56
+ /**
57
+ * Submit an encoder that copies `outBuffer` to a MAP_READ buffer and returns the mapped bytes.
58
+ *
59
+ * Note: mapping is the completion signal; we intentionally measure submit->map to validate
60
+ * real GPU work end-to-end.
61
+ */
62
+ export async function submitAndReadback(
63
+ gpu: MirGPU,
64
+ encoder: GPUCommandEncoder,
65
+ outBuffer: GPUBuffer,
66
+ readback: GPUBuffer,
67
+ byteLength: number
68
+ ): Promise<GpuDispatchResult<ArrayBuffer>> {
69
+ encoder.copyBufferToBuffer(outBuffer, 0, readback, 0, byteLength);
70
+
71
+ const tSubmit = nowMs();
72
+ gpu.queue.submit([encoder.finish()]);
73
+
74
+ await readback.mapAsync(GPUMapMode.READ);
75
+ const tDone = nowMs();
76
+
77
+ const mapped = readback.getMappedRange();
78
+ const copy = mapped.slice(0); // copies to standalone ArrayBuffer
79
+ readback.unmap();
80
+
81
+ return {
82
+ value: copy,
83
+ timing: {
84
+ gpuSubmitToReadbackMs: tDone - tSubmit,
85
+ },
86
+ };
87
+ }
@@ -0,0 +1,116 @@
1
+ import type { MirGPU } from "./context";
2
+
3
+ import {
4
+ byteSizeF32,
5
+ createAndWriteStorageBuffer,
6
+ createReadbackBuffer,
7
+ createStorageOutBuffer,
8
+ createUniformBufferU32x4,
9
+ nowMs,
10
+ type GpuDispatchResult,
11
+ } from "./helpers";
12
+
13
+ import { hpssMasksWGSL } from "./kernels/hpssMasks.wgsl";
14
+
15
/** Input to the GPU HPSS mask stage. */
export type GpuHpssMasksInput = {
  nFrames: number;
  nBins: number;
  magsFlat: Float32Array; // [frame][bin] row-major, length=nFrames*nBins
  // true => soft ratio masks h/(h+p); false => hard winner-take-all masks
  // (matches the `softMask` uniform consumed by the WGSL kernel).
  softMask: boolean;
};

/** Output masks; same row-major [frame][bin] layout as the input magnitudes. */
export type GpuHpssMasksOutput = {
  harmonicMaskFlat: Float32Array; // length=nFrames*nBins
  percussiveMaskFlat: Float32Array; // length=nFrames*nBins
};
26
+
27
/**
 * Compute HPSS (harmonic/percussive) masks on the GPU.
 *
 * Pipeline: upload magnitudes -> dispatch one compute pass writing both masks
 * -> copy both outputs into MAP_READ buffers in the same submission -> map,
 * copy to CPU, and destroy every GPU resource before returning.
 *
 * Notes:
 * - This stage intentionally only estimates masks. Applying masks to the original magnitude
 *   spectrogram is done on CPU for clarity and to preserve existing output shapes/types.
 * - Kernel uses a fixed median-of-9 approximation (see WGSL source for details).
 * - Timing is submit -> both readbacks mapped (`gpuSubmitToReadbackMs`).
 *
 * @throws if `magsFlat.length !== nFrames * nBins`.
 */
export async function gpuHpssMasks(
  gpu: MirGPU,
  input: GpuHpssMasksInput
): Promise<GpuDispatchResult<GpuHpssMasksOutput>> {
  const { device } = gpu;

  const { nFrames, nBins, magsFlat, softMask } = input;

  if (magsFlat.length !== nFrames * nBins) {
    throw new Error("@octoseq/mir: magsFlat length mismatch");
  }

  // Input magnitudes, uploaded as a storage buffer.
  const magsBuffer = createAndWriteStorageBuffer(gpu, magsFlat);

  const outByteLen = byteSizeF32(nFrames * nBins);
  const harmonicOutBuffer = createStorageOutBuffer(gpu, outByteLen);
  const percussiveOutBuffer = createStorageOutBuffer(gpu, outByteLen);

  const harmonicReadback = createReadbackBuffer(gpu, outByteLen);
  const percussiveReadback = createReadbackBuffer(gpu, outByteLen);

  // NOTE(review): shader module + pipeline are rebuilt on every call; caching
  // them per-device could be a future optimisation — confirm call frequency.
  const shader = device.createShaderModule({ code: hpssMasksWGSL });
  const pipeline = device.createComputePipeline({
    layout: "auto",
    compute: { module: shader, entryPoint: "main" },
  });

  // Params matches WGSL: (nBins, nFrames, softMaskU32, _pad)
  const params = createUniformBufferU32x4(gpu, new Uint32Array([nBins, nFrames, softMask ? 1 : 0, 0]));

  // Binding order mirrors the @binding indices declared in the WGSL kernel.
  const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: magsBuffer } },
      { binding: 1, resource: { buffer: harmonicOutBuffer } },
      { binding: 2, resource: { buffer: percussiveOutBuffer } },
      { binding: 3, resource: { buffer: params } },
    ],
  });

  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);

  // One invocation per (frame, bin); the WGSL workgroup size is 16x16,
  // hence the ceil-divide of each dimension.
  const wgX = Math.ceil(nFrames / 16);
  const wgY = Math.ceil(nBins / 16);
  pass.dispatchWorkgroups(wgX, wgY);
  pass.end();

  // Read back both masks from a single submission.
  encoder.copyBufferToBuffer(harmonicOutBuffer, 0, harmonicReadback, 0, outByteLen);
  encoder.copyBufferToBuffer(percussiveOutBuffer, 0, percussiveReadback, 0, outByteLen);

  const tSubmit = nowMs();
  gpu.queue.submit([encoder.finish()]);

  // Both mapAsync promises resolve only after the submitted GPU work completes.
  await Promise.all([harmonicReadback.mapAsync(GPUMapMode.READ), percussiveReadback.mapAsync(GPUMapMode.READ)]);
  const tDone = nowMs();

  // slice(0) copies out of the mapped ranges so the data survives unmap().
  const hBytes = harmonicReadback.getMappedRange().slice(0);
  const pBytes = percussiveReadback.getMappedRange().slice(0);
  harmonicReadback.unmap();
  percussiveReadback.unmap();

  // Release all GPU resources eagerly; the returned arrays are CPU-side copies.
  magsBuffer.destroy();
  harmonicOutBuffer.destroy();
  percussiveOutBuffer.destroy();
  params.destroy();
  harmonicReadback.destroy();
  percussiveReadback.destroy();

  return {
    value: {
      harmonicMaskFlat: new Float32Array(hBytes),
      percussiveMaskFlat: new Float32Array(pBytes),
    },
    timing: {
      gpuSubmitToReadbackMs: tDone - tSubmit,
    },
  };
}
@@ -0,0 +1,137 @@
1
/**
 * WGSL kernel: HPSS mask estimation (harmonic/percussive) from a linear magnitude spectrogram.
 *
 * CPU reference path in `src/dsp/hpss.ts` uses true median filters with configurable kernels
 * (defaults 17x17). That is too slow, and implementing a general median on GPU would require
 * per-pixel sorting / dynamic allocation.
 *
 * GPU strategy (approximation, intentionally fixed):
 * - Use a fixed 9-tap 1D robust smoother in time (harmonic estimate) and in frequency
 *   (percussive estimate).
 * - Robust smoother is implemented as the exact median-of-9 via a fixed compare–swap
 *   sorting network (no dynamic memory, no data-dependent branches, fixed cost).
 *
 * This yields masks that are structurally faithful for visualisation / musical use, while
 * allowing a large performance win from GPU parallelism.
 *
 * Shapes:
 * - Input mags: flattened row-major [frame][bin], length = nFrames * nBins
 * - Output harmonicMask, percussiveMask: same layout and length
 *
 * Dispatch contract (see `gpuHpssMasks`): workgroup size is 16x16 with
 * gid.x = frame and gid.y = bin; edge taps clamp to the spectrogram border.
 * The uniform layout must stay in sync with the Uint32Array built in
 * `gpuHpssMasks`: (nBins, nFrames, softMask, _pad).
 *
 * NOTE: everything inside the template literal below is runtime shader source
 * shipped to the WGSL compiler — including its comments.
 */
export const hpssMasksWGSL = /* wgsl */ `
struct Params {
  nBins: u32,
  nFrames: u32,
  softMask: u32, // 1 => soft, 0 => hard
  _pad: u32,
};

@group(0) @binding(0) var<storage, read> mags : array<f32>;
@group(0) @binding(1) var<storage, read_write> harmonicMask : array<f32>;
@group(0) @binding(2) var<storage, read_write> percussiveMask : array<f32>;
@group(0) @binding(3) var<uniform> params : Params;

fn clamp_i32(x: i32, lo: i32, hi: i32) -> i32 {
  return max(lo, min(hi, x));
}

fn swap_if_greater(a: ptr<function, f32>, b: ptr<function, f32>) {
  // Branchless compare–swap.
  let av = *a;
  let bv = *b;
  *a = min(av, bv);
  *b = max(av, bv);
}

// Sorting network for 9 values; returns the 5th smallest (median).
//
// Notes:
// - This is fixed-cost and data-independent.
// - For our HPSS approximation we only need a robust center value, and exact median-of-9
//   is a good tradeoff vs kernel size.
fn median9(v0: f32, v1: f32, v2: f32, v3: f32, v4: f32, v5: f32, v6: f32, v7: f32, v8: f32) -> f32 {
  var a0 = v0; var a1 = v1; var a2 = v2;
  var a3 = v3; var a4 = v4; var a5 = v5;
  var a6 = v6; var a7 = v7; var a8 = v8;

  // 9-input sorting network (compare–swap stages). This is a known minimal-ish network.
  // We fully sort then take middle; cost is acceptable for 9.
  // Stage 1
  swap_if_greater(&a0,&a1); swap_if_greater(&a3,&a4); swap_if_greater(&a6,&a7);
  // Stage 2
  swap_if_greater(&a1,&a2); swap_if_greater(&a4,&a5); swap_if_greater(&a7,&a8);
  // Stage 3
  swap_if_greater(&a0,&a1); swap_if_greater(&a3,&a4); swap_if_greater(&a6,&a7);
  // Stage 4
  swap_if_greater(&a0,&a3); swap_if_greater(&a3,&a6); swap_if_greater(&a0,&a3);
  // Stage 5
  swap_if_greater(&a1,&a4); swap_if_greater(&a4,&a7); swap_if_greater(&a1,&a4);
  // Stage 6
  swap_if_greater(&a2,&a5); swap_if_greater(&a5,&a8); swap_if_greater(&a2,&a5);
  // Stage 7
  swap_if_greater(&a1,&a3); swap_if_greater(&a5,&a7);
  // Stage 8
  swap_if_greater(&a2,&a6);
  // Stage 9
  swap_if_greater(&a2,&a3); swap_if_greater(&a4,&a6);
  // Stage 10
  swap_if_greater(&a2,&a4); swap_if_greater(&a4,&a6);
  // Stage 11
  swap_if_greater(&a3,&a5); swap_if_greater(&a5,&a7);
  // Stage 12
  swap_if_greater(&a3,&a4); swap_if_greater(&a5,&a6);
  // Stage 13
  swap_if_greater(&a4,&a5);

  return a4;
}

fn mag_at(frame: i32, bin: i32) -> f32 {
  let f = clamp_i32(frame, 0, i32(params.nFrames) - 1);
  let b = clamp_i32(bin, 0, i32(params.nBins) - 1);
  let idx = u32(f) * params.nBins + u32(b);
  return mags[idx];
}

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let frame = gid.x;
  let bin = gid.y;

  if (frame >= params.nFrames || bin >= params.nBins) {
    return;
  }

  let f = i32(frame);
  let b = i32(bin);

  // Harmonic estimate: median in time over 9 taps.
  let h = median9(
    mag_at(f-4,b), mag_at(f-3,b), mag_at(f-2,b), mag_at(f-1,b), mag_at(f,b),
    mag_at(f+1,b), mag_at(f+2,b), mag_at(f+3,b), mag_at(f+4,b)
  );

  // Percussive estimate: median in frequency over 9 taps.
  let p = median9(
    mag_at(f,b-4), mag_at(f,b-3), mag_at(f,b-2), mag_at(f,b-1), mag_at(f,b),
    mag_at(f,b+1), mag_at(f,b+2), mag_at(f,b+3), mag_at(f,b+4)
  );

  let eps: f32 = 1e-12;
  let denom = max(eps, h + p);

  var mh = h / denom;
  var mp = p / denom;

  // Optional hard mask (kept for compatibility with CPU options).
  if (params.softMask == 0u) {
    let isH = h >= p;
    mh = select(0.0, 1.0, isH);
    mp = select(1.0, 0.0, isH);
  }

  let idx = frame * params.nBins + bin;
  harmonicMask[idx] = mh;
  percussiveMask[idx] = mp;
}
`;