@octoseq/mir 0.1.0-main.2e286ce
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/chunk-DUWYCAVG.js +1525 -0
- package/dist/chunk-DUWYCAVG.js.map +1 -0
- package/dist/index.d.ts +450 -0
- package/dist/index.js +1234 -0
- package/dist/index.js.map +1 -0
- package/dist/runMir-CSIBwNZ3.d.ts +84 -0
- package/dist/runner/runMir.d.ts +2 -0
- package/dist/runner/runMir.js +3 -0
- package/dist/runner/runMir.js.map +1 -0
- package/dist/runner/workerProtocol.d.ts +169 -0
- package/dist/runner/workerProtocol.js +11 -0
- package/dist/runner/workerProtocol.js.map +1 -0
- package/dist/types-BE3py4fZ.d.ts +83 -0
- package/package.json +55 -0
- package/src/dsp/fft.ts +22 -0
- package/src/dsp/fftBackend.ts +53 -0
- package/src/dsp/fftBackendFftjs.ts +60 -0
- package/src/dsp/hpss.ts +152 -0
- package/src/dsp/hpssGpu.ts +101 -0
- package/src/dsp/mel.ts +219 -0
- package/src/dsp/mfcc.ts +119 -0
- package/src/dsp/onset.ts +205 -0
- package/src/dsp/peakPick.ts +112 -0
- package/src/dsp/spectral.ts +95 -0
- package/src/dsp/spectrogram.ts +176 -0
- package/src/gpu/README.md +34 -0
- package/src/gpu/context.ts +44 -0
- package/src/gpu/helpers.ts +87 -0
- package/src/gpu/hpssMasks.ts +116 -0
- package/src/gpu/kernels/hpssMasks.wgsl.ts +137 -0
- package/src/gpu/kernels/melProject.wgsl.ts +48 -0
- package/src/gpu/kernels/onsetEnvelope.wgsl.ts +56 -0
- package/src/gpu/melProject.ts +98 -0
- package/src/gpu/onsetEnvelope.ts +81 -0
- package/src/gpu/webgpu.d.ts +176 -0
- package/src/index.ts +121 -0
- package/src/runner/runMir.ts +431 -0
- package/src/runner/workerProtocol.ts +189 -0
- package/src/search/featureVectorV1.ts +123 -0
- package/src/search/fingerprintV1.ts +230 -0
- package/src/search/refinedModelV1.ts +321 -0
- package/src/search/searchTrackV1.ts +206 -0
- package/src/search/searchTrackV1Guided.ts +863 -0
- package/src/search/similarity.ts +98 -0
- package/src/types.ts +105 -0
- package/src/util/display.ts +80 -0
- package/src/util/normalise.ts +58 -0
- package/src/util/stats.ts +25 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import type { MirFingerprintV1 } from "./fingerprintV1";
|
|
2
|
+
|
|
3
|
+
/** A contiguous range `[offset, offset + length)` inside a flat feature vector. */
export type MirFeatureVectorSlice = {
  /** Index of the first element of the slice. */
  offset: number;
  /** Number of elements covered by the slice. */
  length: number;
};

/**
 * Describes where each feature group lives inside a flat v1 feature vector.
 *
 * The foreground slices are always present. MFCC and local-contrast slices are
 * optional and only exist when the layout was built with them.
 */
export type MirFeatureVectorLayoutV1 = {
  /** Total number of elements in the vector. */
  dim: number;

  // --- Foreground (query-length) window features
  melMeanFg: MirFeatureVectorSlice;
  melVarianceFg: MirFeatureVectorSlice;
  onsetFg: MirFeatureVectorSlice;
  mfccMeanFg?: MirFeatureVectorSlice;
  mfccVarianceFg?: MirFeatureVectorSlice;

  // --- Local contrast features (foreground minus background-without-foreground)
  melContrast?: MirFeatureVectorSlice;
  onsetContrast?: MirFeatureVectorSlice;
  mfccMeanContrast?: MirFeatureVectorSlice;
  mfccVarianceContrast?: MirFeatureVectorSlice;
};
|
|
24
|
+
|
|
25
|
+
/**
 * Build the flat feature-vector layout used by v1 search.
 *
 * Slices are packed back to back in this order:
 *   melMeanFg, melVarianceFg, onsetFg (3),
 *   [mfccMeanFg, mfccVarianceFg]                      when mfccDim > 0,
 *   [melContrast, onsetContrast (3),
 *    [mfccMeanContrast, mfccVarianceContrast]]        when includeContrast (default true).
 *
 * @param params.melDim Number of mel bands per mel statistic.
 * @param params.mfccDim Number of MFCC coefficients per MFCC statistic (default 0 = no MFCC slices).
 * @param params.includeContrast Whether to append the local-contrast slices (default true).
 * @returns The layout, with `dim` set to the total packed length.
 *
 * Dimensions are coerced to non-negative integers (NaN/±Infinity become 0, fractional
 * values are floored). Offsets and `dim` are used as typed-array indices and sizes, and
 * a fractional or non-finite `dim` makes `new Float32Array(dim)` throw.
 */
export function makeFeatureVectorLayoutV1(params: {
  melDim: number;
  mfccDim?: number;
  includeContrast?: boolean;
}): MirFeatureVectorLayoutV1 {
  const toDim = (n: number): number => (Number.isFinite(n) ? Math.max(0, Math.floor(n)) : 0);

  const melDim = toDim(params.melDim);
  const mfccDim = toDim(params.mfccDim ?? 0);
  const includeContrast = params.includeContrast ?? true;
  const ONSET_DIM = 3; // onset mean, max, peak density

  let offset = 0;
  // Reserve `length` elements at the current offset and advance past them.
  const take = (length: number): MirFeatureVectorSlice => {
    const slice: MirFeatureVectorSlice = { offset, length };
    offset += length;
    return slice;
  };

  // Object-literal properties are evaluated in source order, so these are packed sequentially.
  const layout: MirFeatureVectorLayoutV1 = {
    dim: 0,
    melMeanFg: take(melDim),
    melVarianceFg: take(melDim),
    onsetFg: take(ONSET_DIM),
  };

  if (mfccDim > 0) {
    layout.mfccMeanFg = take(mfccDim);
    layout.mfccVarianceFg = take(mfccDim);
  }

  if (includeContrast) {
    layout.melContrast = take(melDim);
    layout.onsetContrast = take(ONSET_DIM);
    if (mfccDim > 0) {
      layout.mfccMeanContrast = take(mfccDim);
      layout.mfccVarianceContrast = take(mfccDim);
    }
  }

  layout.dim = offset;
  return layout;
}
|
|
70
|
+
|
|
71
|
+
/**
 * Serialise a fingerprint's raw (un-normalised) features into `out` according to `layout`.
 *
 * @param fp Fingerprint to write.
 * @param out Destination buffer; slices are written at `offset + slice.offset`.
 * @param offset Base index in `out` where this vector begins.
 * @param layout Layout describing where each feature group goes.
 *
 * Slice elements beyond the end of the fingerprint's arrays (or in an absent MFCC block)
 * are written as 0. Contrast slices are always zeroed because a fingerprint carries no
 * local-contrast information, so the output is deterministic.
 */
export function writeFingerprintToFeatureVectorRawV1(
  fp: MirFingerprintV1,
  out: Float32Array,
  offset: number,
  layout: MirFeatureVectorLayoutV1
): void {
  type Span = { offset: number; length: number };

  // Copy `src` into `span`, zero-padding once `src` runs out (or is missing entirely).
  const copyInto = (span: Span, src: ArrayLike<number> | undefined): void => {
    const base = offset + span.offset;
    for (let k = 0; k < span.length; k++) out[base + k] = src?.[k] ?? 0;
  };

  // Zero an optional span; absent spans are left alone.
  const zeroSpan = (span: Span | undefined): void => {
    if (!span) return;
    const from = offset + span.offset;
    out.fill(0, from, from + span.length);
  };

  copyInto(layout.melMeanFg, fp.mel.mean);
  copyInto(layout.melVarianceFg, fp.mel.variance);

  const onsetBase = offset + layout.onsetFg.offset;
  out[onsetBase] = fp.onset.mean;
  out[onsetBase + 1] = fp.onset.max;
  out[onsetBase + 2] = fp.onset.peakDensityHz;

  const { mfccMeanFg, mfccVarianceFg } = layout;
  if (mfccMeanFg && mfccVarianceFg) {
    copyInto(mfccMeanFg, fp.mfcc?.mean);
    copyInto(mfccVarianceFg, fp.mfcc?.variance);
  }

  const contrastSpans = [
    layout.melContrast,
    layout.onsetContrast,
    layout.mfccMeanContrast,
    layout.mfccVarianceContrast,
  ];
  for (const span of contrastSpans) zeroSpan(span);
}
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import type { MelSpectrogram } from "../dsp/mel";
|
|
2
|
+
import type { Features2D } from "../dsp/mfcc";
|
|
3
|
+
import { peakPick } from "../dsp/peakPick";
|
|
4
|
+
|
|
5
|
+
/** Deterministic v1 summary of a time region, produced by `fingerprintV1()`. */
export type MirFingerprintV1 = {
  version: "v1";

  /** Query window time bounds (seconds) – informational/debug only. */
  t0: number;
  t1: number;

  // A) Mel-spectrogram statistics
  mel: {
    /**
     * Weighted mean of the L2-normalised mel frames, each frame weighted by its original
     * L2 norm. This captures spectral shape independent of overall gain. The mean itself
     * is NOT re-normalised to unit length.
     */
    mean: Float32Array;
    /** Weighted variance of the L2-normalised mel frames around `mean`, using the same weights. */
    variance: Float32Array;
  };

  // B) Transient/activity statistics
  onset: {
    /** Mean onset-envelope value inside [t0, t1]; 0 when no envelope samples fall inside. */
    mean: number;
    /** Max onset-envelope value inside [t0, t1]; 0 when no envelope samples fall inside. */
    max: number;
    /** Peaks per second, computed using peakPick() on the onset envelope. */
    peakDensityHz: number;
  };

  // Optional: MFCC statistics (coeffs 1–12, exclude C0), computed the same way as `mel`.
  mfcc?: {
    mean: Float32Array;
    variance: Float32Array;
  };
};

/** Half-open range of frame indices `[startFrame, endFrameExclusive)`. */
export type FingerprintFrameWindow = {
  startFrame: number;
  endFrameExclusive: number;
};
|
|
39
|
+
|
|
40
|
+
/** Euclidean (L2) norm of `v`. */
function l2Norm(v: Float32Array): number {
  let sum = 0;
  for (let i = 0; i < v.length; i++) {
    const x = v[i] ?? 0;
    sum += x * x;
  }
  return Math.sqrt(sum);
}

/**
 * Magnitude-weighted mean and variance of L2-normalised frames in `[start, endExclusive)`.
 *
 * Each frame is divided by its L2 norm, so only its shape matters. It then contributes
 * with weight equal to that norm, so louder frames dominate while the result stays
 * independent of overall gain.
 *
 * @param frames Frame vectors; missing (undefined) entries are treated as silent frames.
 * @param start First frame index (inclusive).
 * @param endExclusive One past the last frame index.
 * @param dimHint Output dimension to use when the first frame in the window is missing or empty.
 * @returns `mean` and `variance`, each of length `dim`; all zeros for an empty or silent window.
 */
function weightedStats(
  frames: Float32Array[], // raw frames
  start: number,
  endExclusive: number,
  dimHint = 0
): { mean: Float32Array; variance: Float32Array } {
  const nFrames = Math.max(0, endExclusive - start);

  // Output width comes from the first frame in the window. Fall back to the hint when that
  // frame is missing OR empty, so one bad leading frame cannot collapse the result to dim 0.
  const first = frames[start];
  const dim = first && first.length > 0 ? first.length : dimHint;

  const mean = new Float32Array(dim);
  const variance = new Float32Array(dim);

  if (nFrames <= 0 || dim <= 0) return { mean, variance };

  // 1. Per-frame weights (L2 norms) and L2-normalised frame shapes.
  const weights = new Float32Array(nFrames);
  const normFrames: Float32Array[] = new Array(nFrames);
  let totalWeight = 0;

  for (let i = 0; i < nFrames; i++) {
    const nf = new Float32Array(dim);
    normFrames[i] = nf;

    const f = frames[start + i];
    if (!f) continue; // missing frame: zero weight, zero shape

    const w = l2Norm(f);
    // A NaN/Infinity frame would otherwise turn totalWeight, and thus every output, into NaN.
    if (!Number.isFinite(w)) continue;
    weights[i] = w;
    totalWeight += w;

    // Near-silent frames are left unscaled rather than divided by ~0.
    const d = w > 1e-12 ? w : 1;
    // Frames shorter than `dim` are zero-padded instead of producing NaN.
    for (let j = 0; j < dim; j++) nf[j] = (f[j] ?? 0) / d;
  }

  // (Near-)silent window: avoid dividing by ~0; the statistics stay (near) zero.
  if (totalWeight <= 1e-12) totalWeight = 1;

  // 2. Weighted mean: sum(w_i * x_i) / sum(w_i), where x_i is the normalised frame.
  for (let i = 0; i < nFrames; i++) {
    const w = weights[i] ?? 0;
    const nf = normFrames[i];
    if (!nf || !(w > 0)) continue;
    const scale = w / totalWeight;
    for (let j = 0; j < dim; j++) {
      mean[j] = (mean[j] ?? 0) + (nf[j] ?? 0) * scale;
    }
  }

  // 3. Weighted variance: sum(w_i * (x_i - mean)^2) / sum(w_i).
  for (let i = 0; i < nFrames; i++) {
    const w = weights[i] ?? 0;
    const nf = normFrames[i];
    if (!nf || !(w > 0)) continue;
    const scale = w / totalWeight;
    for (let j = 0; j < dim; j++) {
      const diff = (nf[j] ?? 0) - (mean[j] ?? 0);
      variance[j] = (variance[j] ?? 0) + diff * diff * scale;
    }
  }

  return { mean, variance };
}
|
|
118
|
+
|
|
119
|
+
/**
 * Find the half-open frame range whose frame-centre times lie within `[t0, t1]`.
 *
 * Scans forward to the first time >= t0, then extends while times stay <= t1. This is
 * only correct when `times` is non-decreasing (presumably frame-centre times from the
 * analysis — confirm at call sites). Missing entries are read as 0.
 */
function findFrameWindow(times: Float32Array, t0: number, t1: number): FingerprintFrameWindow {
  const count = times.length;

  let first = 0;
  for (; first < count; first++) {
    if (!((times[first] ?? 0) < t0)) break;
  }

  let stop = first;
  for (; stop < count; stop++) {
    if (!((times[stop] ?? 0) <= t1)) break;
  }

  return { startFrame: first, endFrameExclusive: Math.max(first, stop) };
}
|
|
129
|
+
|
|
130
|
+
/**
 * Compute a deterministic v1 fingerprint for a time region [t0, t1].
 *
 * Loudness independence:
 * - Each mel/MFCC frame is L2-normalised before statistics are taken, so overall gain
 *   does not change the shape descriptors.
 * - Frames are weighted by their original L2 norm, so louder frames contribute more
 *   to the mean/variance than quiet ones.
 *
 * @param params.t0 One region bound in seconds (bounds may be given in either order).
 * @param params.t1 The other region bound in seconds.
 * @param params.mel Mel spectrogram; frames are selected by `mel.times`.
 * @param params.onsetEnvelope Onset envelope; samples are selected by their time.
 * @param params.mfcc Optional MFCC features; only coefficients 1..12 are used (C0 excluded).
 * @param params.peakPick Optional settings forwarded to `peakPick()` for the peak-density stat.
 * @returns The fingerprint. Onset mean/max are 0 when no envelope samples fall in the region.
 */
export function fingerprintV1(params: {
  t0: number;
  t1: number;
  mel: MelSpectrogram;
  onsetEnvelope: { times: Float32Array; values: Float32Array };
  mfcc?: Features2D; // { times, values: Float32Array[] }
  peakPick?: {
    minIntervalSec?: number;
    threshold?: number;
    adaptiveFactor?: number;
  };
}): MirFingerprintV1 {
  const { t0, t1, mel, onsetEnvelope, mfcc } = params;

  // Accept bounds in either order; floor the duration so peak density never divides by 0.
  const tt0 = Math.min(t0, t1);
  const tt1 = Math.max(t0, t1);
  const dur = Math.max(1e-6, tt1 - tt0);

  // --- Mel stats
  // The dim hint keeps the output length stable when the window contains no frames.
  const melDimHint = mel.melBands.find((f) => f?.length)?.length ?? 0;
  const melWindow = findFrameWindow(mel.times, tt0, tt1);
  const melStats = weightedStats(mel.melBands, melWindow.startFrame, melWindow.endFrameExclusive, melDimHint);

  // --- Onset stats (1D)
  // onsetEnvelope.times are expected to line up with mel.times, but equality is not assumed:
  // samples are selected by time rather than by frame index.
  let onsetSum = 0;
  let onsetMax = -Infinity;
  let onsetN = 0;
  for (let i = 0; i < onsetEnvelope.times.length; i++) {
    const t = onsetEnvelope.times[i] ?? 0;
    if (t < tt0 || t > tt1) continue;
    const v = onsetEnvelope.values[i] ?? 0;
    onsetSum += v;
    onsetN++;
    if (v > onsetMax) onsetMax = v;
  }
  const onsetMean = onsetN > 0 ? onsetSum / onsetN : 0;
  const onsetMaxSafe = Number.isFinite(onsetMax) ? onsetMax : 0;

  // Peaks per second. Peaks are picked on the whole envelope (presumably so an adaptive
  // threshold sees full-track context) and only then counted inside the region.
  const peaks = peakPick(onsetEnvelope.times, onsetEnvelope.values, {
    minIntervalSec: params.peakPick?.minIntervalSec,
    threshold: params.peakPick?.threshold,
    adaptive: params.peakPick?.adaptiveFactor
      ? { method: "meanStd", factor: params.peakPick.adaptiveFactor }
      : undefined,
    strict: true,
  });
  const peaksInWindow = peaks.filter((p) => p.time >= tt0 && p.time <= tt1);
  const peakDensityHz = peaksInWindow.length / dur;

  // --- Optional MFCC stats over coefficients [MFCC_FIRST, MFCC_END) = 1..12 (C0 excluded).
  let mfccStats: MirFingerprintV1["mfcc"] | undefined;
  if (mfcc) {
    const MFCC_FIRST = 1;
    const MFCC_END = 13;
    const mfccWindow = findFrameWindow(mfcc.times, tt0, tt1);

    // weightedStats consumes whole frames, so narrow each frame to the kept coefficients
    // first (subarray returns a view, not a copy).
    const mfccFramesSliced: Float32Array[] = [];
    for (let i = mfccWindow.startFrame; i < mfccWindow.endFrameExclusive; i++) {
      const full = mfcc.values[i] ?? new Float32Array(0);
      mfccFramesSliced.push(full.subarray(Math.min(MFCC_FIRST, full.length), Math.min(MFCC_END, full.length)));
    }

    // The hint must describe the *sliced* width. Using "coefficient count - 1" (e.g. 19 for
    // 20 coefficients) made an empty window return a longer vector than a populated one (12).
    const mfccDimHint = mfcc.values.find((f) => f?.length)?.length ?? 0;
    const slicedDimHint = Math.max(0, Math.min(MFCC_END, mfccDimHint) - MFCC_FIRST);

    const s = weightedStats(mfccFramesSliced, 0, mfccFramesSliced.length, slicedDimHint);
    mfccStats = { mean: s.mean, variance: s.variance };
  }

  return {
    version: "v1",
    t0: tt0,
    t1: tt1,
    mel: {
      mean: melStats.mean,
      variance: melStats.variance,
    },
    onset: {
      mean: onsetMean,
      max: onsetMaxSafe,
      peakDensityHz,
    },
    ...(mfccStats ? { mfcc: mfccStats } : {}),
  };
}
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
import type { MirFeatureVectorLayoutV1 } from "./featureVectorV1";
|
|
2
|
+
|
|
3
|
+
/** Which refinement strategy a model uses. */
export type MirRefinedModelKindV1 = "baseline" | "prototype" | "logistic";

/** Human-facing summary of a refined model, for display and debugging. */
export type MirRefinedModelExplainV1 = {
  kind: MirRefinedModelKindV1;
  /** Number of positive examples the model was built from. */
  positives: number;
  /** Number of negative examples the model was built from. */
  negatives: number;

  /** L2 norms per feature group (useful as a cheap, stable explainability hook). */
  weightL2?: {
    mel: number;
    melForeground: number;
    melContrast?: number;
    onset: number;
    onsetForeground: number;
    onsetContrast?: number;
    mfcc?: number;
    mfccForeground?: number;
    mfccContrast?: number;
  };

  /** Training diagnostics (only for logistic). */
  training?: {
    /** Gradient-descent iterations actually run (training may stop early). */
    iterations: number;
    /** Loss recorded at the last completed iteration. */
    finalLoss: number;
  };
};

/** Logistic-regression scorer: score = sigmoid(w · x + b). */
export type MirLogisticModelV1 = {
  kind: "logistic";
  w: Float32Array;
  b: number;
  explain: MirRefinedModelExplainV1;
};

/** Scores by cosine similarity to `prototype`, mapped into [0, 1]. */
export type MirPrototypeModelV1 = {
  kind: "prototype";
  prototype: Float32Array;
  explain: MirRefinedModelExplainV1;
};

/** No learned parameters. */
export type MirBaselineModelV1 = {
  kind: "baseline";
  explain: MirRefinedModelExplainV1;
};

/** Any refined model, discriminated by `kind`. */
export type MirRefinedModelV1 = MirBaselineModelV1 | MirPrototypeModelV1 | MirLogisticModelV1;

/**
 * Per-group breakdown of a logistic logit: `logit = mel + onset + (mfcc ?? 0) + bias`.
 * Each group total is the sum of its foreground and (optional) contrast parts.
 */
export type MirLogitContributionsByGroupV1 = {
  logit: number;
  bias: number;
  mel: number;
  melForeground: number;
  melContrast?: number;
  onset: number;
  onsetForeground: number;
  onsetContrast?: number;
  mfcc?: number;
  mfccForeground?: number;
  mfccContrast?: number;
};
|
|
63
|
+
|
|
64
|
+
/** Clamp `x` into [0, 1]. NaN passes through unchanged. */
function clamp01(x: number): number {
  if (x <= 0) return 0;
  if (x >= 1) return 1;
  return x;
}
|
|
67
|
+
|
|
68
|
+
/** Logistic function. The input is clamped to [-20, 20] so exp() cannot overflow; ±20 already saturates. */
function sigmoid(x: number): number {
  const bounded = Math.min(20, Math.max(-20, x));
  return 1 / (1 + Math.exp(-bounded));
}
|
|
73
|
+
|
|
74
|
+
/** Dot product over the common prefix of `a` and `b` (extra elements of the longer array are ignored). */
function dot(a: Float32Array, b: Float32Array): number {
  const shared = Math.min(a.length, b.length);
  let total = 0;
  let k = 0;
  while (k < shared) {
    total += (a[k] ?? 0) * (b[k] ?? 0);
    k += 1;
  }
  return total;
}
|
|
80
|
+
|
|
81
|
+
/** Dot product of `w` and `x` over indices `[offset, offset + length)`, clipped to both arrays. */
function sliceDot(w: Float32Array, x: Float32Array, offset: number, length: number): number {
  const stop = Math.min(w.length, x.length, offset + length);
  let acc = 0;
  let i = offset;
  while (i < stop) {
    acc += (w[i] ?? 0) * (x[i] ?? 0);
    i += 1;
  }
  return acc;
}

/**
 * Decompose a logistic logit `w · x + b` into per-feature-group contributions.
 *
 * @param w Model weights.
 * @param b Model bias.
 * @param x Feature vector laid out according to `layout`.
 * @param layout Layout that defines the feature groups.
 * @returns Group totals (foreground + contrast) plus their parts. Optional groups appear
 *   only when the layout contains them; `logit = mel + onset + mfcc + bias`.
 */
export function logitContributionsByGroupV1(
  w: Float32Array,
  b: number,
  x: Float32Array,
  layout: MirFeatureVectorLayoutV1
): MirLogitContributionsByGroupV1 {
  type Span = { offset: number; length: number };

  // Contribution of one slice; an absent slice contributes nothing.
  const part = (span: Span | undefined): number => (span ? sliceDot(w, x, span.offset, span.length) : 0);
  // Contribution of a mean/variance slice pair; counted only when both halves exist.
  const pair = (first: Span | undefined, second: Span | undefined): number =>
    first && second ? part(first) + part(second) : 0;

  const melForeground = part(layout.melMeanFg) + part(layout.melVarianceFg);
  const melContrast = part(layout.melContrast);
  const onsetForeground = part(layout.onsetFg);
  const onsetContrast = part(layout.onsetContrast);
  const mfccForeground = pair(layout.mfccMeanFg, layout.mfccVarianceFg);
  const mfccContrast = pair(layout.mfccMeanContrast, layout.mfccVarianceContrast);

  const mel = melForeground + melContrast;
  const onset = onsetForeground + onsetContrast;
  const mfcc = mfccForeground + mfccContrast;
  const logit = mel + onset + mfcc + b;

  const melExtra = layout.melContrast ? { melContrast } : {};
  const onsetExtra = layout.onsetContrast ? { onsetContrast } : {};
  const mfccExtra =
    layout.mfccMeanFg || layout.mfccMeanContrast
      ? { mfcc, mfccForeground, ...(layout.mfccMeanContrast ? { mfccContrast } : {}) }
      : {};

  return {
    logit,
    bias: b,
    mel,
    melForeground,
    ...melExtra,
    onset,
    onsetForeground,
    ...onsetExtra,
    ...mfccExtra,
  };
}
|
|
136
|
+
|
|
137
|
+
/** Euclidean length of `v`. */
function l2Norm(v: Float32Array): number {
  let squares = 0;
  for (const component of v) squares += component * component;
  return Math.sqrt(squares);
}
|
|
145
|
+
|
|
146
|
+
/**
 * Cosine similarity over the common prefix of `a` and `b`, rescaled from [-1, 1] to [0, 1].
 * Returns 0 when either vector has zero length (norm 0).
 */
function cosineSimilarity01(a: Float32Array, b: Float32Array): number {
  const shared = Math.min(a.length, b.length);
  let dotAB = 0;
  let sqA = 0;
  let sqB = 0;
  for (let k = 0; k < shared; k++) {
    const u = a[k] ?? 0;
    const v = b[k] ?? 0;
    dotAB += u * v;
    sqA += u * u;
    sqB += v * v;
  }

  const denom = Math.sqrt(sqA) * Math.sqrt(sqB);
  if (denom <= 0) return 0;

  // Guard against rounding pushing |cos| slightly above 1.
  const cos = Math.max(-1, Math.min(1, dotAB / denom));
  return (cos + 1) / 2;
}
|
|
164
|
+
|
|
165
|
+
/** Sum of squares of `w` over indices `[offset, offset + length)`, clipped to the end of `w`. */
function sliceSumSquares(w: Float32Array, offset: number, length: number): number {
  let sum = 0;
  const end = Math.min(w.length, offset + length);
  for (let i = offset; i < end; i++) {
    const x = w[i] ?? 0;
    sum += x * x;
  }
  return sum;
}

/**
 * Summarise a weight vector as L2 norms per feature group (and per foreground/contrast part).
 *
 * Optional groups (contrast, MFCC) are reported whenever the layout contains them, even
 * when their weights are exactly zero. This mirrors `logitContributionsByGroupV1`, so a
 * "group with zero weight" can be told apart from a "group not in the layout".
 *
 * @param w Model weight vector laid out according to `layout`.
 * @param layout Layout that defines the feature groups.
 * @returns Per-group L2 norms; group totals combine foreground and contrast parts.
 */
export function summariseWeightL2ByGroup(w: Float32Array, layout: MirFeatureVectorLayoutV1): MirRefinedModelExplainV1["weightL2"] {
  const melForegroundSq =
    sliceSumSquares(w, layout.melMeanFg.offset, layout.melMeanFg.length) +
    sliceSumSquares(w, layout.melVarianceFg.offset, layout.melVarianceFg.length);
  const melContrastSq = layout.melContrast ? sliceSumSquares(w, layout.melContrast.offset, layout.melContrast.length) : 0;
  const onsetForegroundSq = sliceSumSquares(w, layout.onsetFg.offset, layout.onsetFg.length);
  const onsetContrastSq = layout.onsetContrast
    ? sliceSumSquares(w, layout.onsetContrast.offset, layout.onsetContrast.length)
    : 0;

  const mfccForegroundSq =
    layout.mfccMeanFg && layout.mfccVarianceFg
      ? sliceSumSquares(w, layout.mfccMeanFg.offset, layout.mfccMeanFg.length) +
        sliceSumSquares(w, layout.mfccVarianceFg.offset, layout.mfccVarianceFg.length)
      : 0;
  const mfccContrastSq =
    layout.mfccMeanContrast && layout.mfccVarianceContrast
      ? sliceSumSquares(w, layout.mfccMeanContrast.offset, layout.mfccMeanContrast.length) +
        sliceSumSquares(w, layout.mfccVarianceContrast.offset, layout.mfccVarianceContrast.length)
      : 0;

  // Presence of optional groups follows the layout (same rule as logitContributionsByGroupV1).
  const hasMfcc = Boolean(layout.mfccMeanFg || layout.mfccMeanContrast);

  const mel = Math.sqrt(melForegroundSq + melContrastSq);
  const onset = Math.sqrt(onsetForegroundSq + onsetContrastSq);
  const mfcc = Math.sqrt(mfccForegroundSq + mfccContrastSq);

  return {
    mel,
    melForeground: Math.sqrt(melForegroundSq),
    ...(layout.melContrast ? { melContrast: Math.sqrt(melContrastSq) } : {}),
    onset,
    onsetForeground: Math.sqrt(onsetForegroundSq),
    ...(layout.onsetContrast ? { onsetContrast: Math.sqrt(onsetContrastSq) } : {}),
    ...(hasMfcc
      ? {
          mfcc,
          mfccForeground: Math.sqrt(mfccForegroundSq),
          ...(layout.mfccMeanContrast ? { mfccContrast: Math.sqrt(mfccContrastSq) } : {}),
        }
      : {}),
  };
}
|
|
214
|
+
|
|
215
|
+
/**
 * Train a class-balanced, L2-regularised logistic-regression scorer with full-batch gradient descent.
 *
 * Positives and negatives each contribute half of the data loss regardless of their counts.
 * Weights start at zero, so training is deterministic. It stops after `iterations` steps or
 * once the loss changes by less than 1e-6 between consecutive iterations.
 *
 * @param params.positives Feature vectors labelled 1.
 * @param params.negatives Feature vectors labelled 0.
 * @param params.layout Feature layout; `layout.dim` sets the weight-vector length.
 * @param params.options.iterations Maximum iterations (default 80, at least 1).
 * @param params.options.learningRate Initial step size (default 0.15, at least 1e-4), decayed as lr / (1 + 0.01 * iter).
 * @param params.options.l2 L2 penalty applied to the weights but not the bias (default 0.01).
 * @returns The trained model with an explain block (group weight norms and training diagnostics).
 */
export function trainLogisticModelV1(params: {
  positives: Float32Array[];
  negatives: Float32Array[];
  layout: MirFeatureVectorLayoutV1;
  options?: { iterations?: number; learningRate?: number; l2?: number };
}): MirLogisticModelV1 {
  const { positives: pos, negatives: neg, layout } = params;
  const dim = layout.dim;
  const opts = params.options;

  // Small, deterministic batch GD: fast enough for < 50 samples and a few hundred dims.
  const iterations = Math.max(1, opts?.iterations ?? 80);
  const learningRate = Math.max(1e-4, opts?.learningRate ?? 0.15);
  const l2 = Math.max(0, opts?.l2 ?? 0.01);

  const w = new Float32Array(dim);
  let b = 0;

  // [samples, label, per-sample weight]; each non-empty class carries a total weight of 0.5.
  const classes: Array<[Float32Array[], 0 | 1, number]> = [
    [pos, 1, pos.length > 0 ? 0.5 / pos.length : 0],
    [neg, 0, neg.length > 0 ? 0.5 / neg.length : 0],
  ];

  let lastLoss = Infinity;
  let itersUsed = 0;

  for (let iter = 0; iter < iterations; iter++) {
    itersUsed = iter + 1;

    const gradW = new Float32Array(dim);
    let gradB = 0;
    let loss = 0;

    for (const [samples, label, sampleWeight] of classes) {
      for (const x of samples) {
        const p = sigmoid(dot(w, x) + b);
        const err = p - label; // dL/ds for logistic loss

        gradB += sampleWeight * err;
        for (let j = 0; j < dim; j++) gradW[j] = (gradW[j] ?? 0) + sampleWeight * err * (x[j] ?? 0);

        // Weighted cross-entropy, with p kept away from 0/1 so log() stays finite.
        const pSafe = Math.min(1 - 1e-9, Math.max(1e-9, p));
        loss += sampleWeight * (label ? -Math.log(pSafe) : -Math.log(1 - pSafe));
      }
    }

    // L2 regularisation (bias is not regularised).
    if (l2 > 0) {
      for (let j = 0; j < dim; j++) gradW[j] = (gradW[j] ?? 0) + l2 * (w[j] ?? 0);
      loss += (l2 * (l2Norm(w) ** 2)) / 2;
    }

    // Simple learning-rate decay for stability on small datasets.
    const lr = learningRate / (1 + iter * 0.01);
    for (let j = 0; j < dim; j++) w[j] = (w[j] ?? 0) - lr * (gradW[j] ?? 0);
    b -= lr * gradB;

    const converged = Math.abs(lastLoss - loss) < 1e-6;
    if (converged) break;
    lastLoss = loss;
  }

  return {
    kind: "logistic",
    w,
    b,
    explain: {
      kind: "logistic",
      positives: pos.length,
      negatives: neg.length,
      weightL2: summariseWeightL2ByGroup(w, layout),
      training: { iterations: itersUsed, finalLoss: Number.isFinite(lastLoss) ? lastLoss : 0 },
    },
  };
}
|
|
292
|
+
|
|
293
|
+
/**
 * Build a prototype model whose prototype is the element-wise mean of the positive examples.
 *
 * @param params.positives Positive feature vectors; missing elements count as 0.
 * @param params.layout Feature layout; `layout.dim` sets the prototype length.
 * @returns The model. With no positives the prototype is all zeros.
 */
export function buildPrototypeModelV1(params: {
  positives: Float32Array[];
  layout: MirFeatureVectorLayoutV1;
}): MirPrototypeModelV1 {
  const { positives, layout } = params;
  const size = layout.dim;
  const prototype = new Float32Array(size);

  // Divide each contribution up front; guard the divisor so an empty set yields zeros.
  const divisor = Math.max(1, positives.length);
  for (let s = 0; s < positives.length; s++) {
    const sample = positives[s]!;
    for (let j = 0; j < size; j++) {
      prototype[j] = (prototype[j] ?? 0) + (sample[j] ?? 0) / divisor;
    }
  }

  return {
    kind: "prototype",
    prototype,
    explain: {
      kind: "prototype",
      positives: positives.length,
      negatives: 0,
    },
  };
}
|
|
315
|
+
|
|
316
|
+
/**
 * Score a feature vector with a refined model, returning a value in [0, 1].
 *
 * - baseline: always 0.
 * - prototype: cosine similarity to the prototype, mapped to [0, 1].
 * - logistic (and any other kind): sigmoid(w · x + b).
 */
export function scoreWithModelV1(model: MirRefinedModelV1, x: Float32Array): number {
  switch (model.kind) {
    case "baseline":
      return 0;
    case "prototype":
      return clamp01(cosineSimilarity01(model.prototype, x));
    default:
      // Only "logistic" remains after narrowing.
      return clamp01(sigmoid(dot(model.w, x) + model.b));
  }
}
|