sonusai 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +32 -362
- sonusai/data/genmixdb.yml +2 -0
- sonusai/doc/doc.py +45 -4
- sonusai/genmetrics.py +137 -109
- sonusai/lsdb.py +2 -2
- sonusai/metrics/__init__.py +4 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_pesq.py +12 -8
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_snr_f.py +34 -0
- sonusai/metrics/calc_speech.py +312 -0
- sonusai/metrics/calc_wer.py +2 -3
- sonusai/metrics/calc_wsdr.py +0 -59
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +6 -5
- sonusai/mixture/config.py +13 -0
- sonusai/mixture/constants.py +1 -0
- sonusai/mixture/datatypes.py +33 -0
- sonusai/mixture/generation.py +6 -2
- sonusai/mixture/mixdb.py +261 -122
- sonusai/mixture/soundfile_audio.py +8 -6
- sonusai/mixture/sox_audio.py +16 -13
- sonusai/mixture/torchaudio_audio.py +6 -4
- sonusai/mixture/truth_functions/energy.py +40 -28
- sonusai/mixture/truth_functions/target.py +0 -1
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr.py +26 -39
- sonusai/utils/asr_functions/aaware_whisper.py +3 -3
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/METADATA +1 -1
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/RECORD +34 -31
- sonusai/mixture/mapped_snr_f.py +0 -100
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/WHEEL +0 -0
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,312 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
from sonusai.mixture.constants import SAMPLE_RATE
|
4
|
+
from sonusai.mixture.datatypes import SpeechMetrics
|
5
|
+
from .calc_pesq import calc_pesq
|
6
|
+
|
7
|
+
|
8
|
+
def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE) -> SpeechMetrics:
    """Calculate the speech metrics pesq, c_sig, c_bak, and c_ovl.

    These are all related and thus included in one function. Reference: matlab script "compute_metrics.m".

    The weighted spectral slope, log likelihood ratio, and segmental SNR measures are computed as
    intermediate terms of the composite measures; they are not part of the returned tuple.

    :param hypothesis: estimated audio
    :param reference: reference audio
    :param sample_rate: sample rate of audio
    :return: SpeechMetrics named tuple (pesq, c_sig, c_bak, c_ovl)
    """
    # Value from CMGAN reference implementation: only the lowest 95% of the sorted per-frame
    # distortions contribute to the WSS and LLR means.
    alpha = 0.95

    # Weighted spectral slope measure.
    # sample_rate must be forwarded; otherwise the frame size is always derived from SAMPLE_RATE.
    wss_dist_vec = _calc_weighted_spectral_slope_measure(hypothesis=hypothesis,
                                                         reference=reference,
                                                         sample_rate=sample_rate)
    wss_dist_vec = np.sort(wss_dist_vec)
    wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])

    # Log likelihood ratio measure
    llr_dist = _calc_log_likelihood_ratio_measure(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
    ll_rs = np.sort(llr_dist)
    llr_len = round(np.size(llr_dist) * alpha)
    llr_mean = np.mean(ll_rs[:llr_len])

    # Segmental SNR (the overall SNR returned alongside it is not used by the composite measures)
    _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
    seg_snr = np.mean(segsnr_dist)

    # PESQ
    _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)

    # Now compute the composite measures, each limited to the range [1, 5]
    c_sig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
    c_bak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
    c_ovl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)

    return SpeechMetrics(_pesq, c_sig, c_bak, c_ovl)
|
46
|
+
|
47
|
+
|
48
|
+
def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
                                          reference: np.ndarray,
                                          sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    """Calculate the per-frame weighted spectral slope (WSS) distortion.

    Both signals are scaled by 1/32768, split into Hanning-windowed 30 ms frames that advance by a
    quarter of the window length, and passed through 25 Gaussian-shaped critical band filters. The
    slopes between adjacent band energies (in dB) are compared using Klatt's weighting.

    :param hypothesis: estimated audio
    :param reference: reference audio
    :param sample_rate: sample rate of audio
    :return: array holding one distortion value per frame
    :raises ValueError: if hypothesis and reference differ in length
    """
    from scipy.fftpack import fft

    # The lengths of the reference and hypothesis must be the same.
    reference_length = np.size(reference)
    hypothesis_length = np.size(hypothesis)
    if reference_length != hypothesis_length:
        raise ValueError('Hypothesis and reference must be the same length.')

    # Window length in samples
    win_length = int(np.round(30 * sample_rate / 1000))
    # Window skip in samples
    skip_rate = int(np.floor(np.divide(win_length, 4)))
    # Maximum bandwidth
    max_freq = int(np.divide(sample_rate, 2))
    num_crit = 25

    n_fft = int(np.power(2, np.ceil(np.log2(2 * win_length))))
    n_fft_by_2 = int(np.multiply(0.5, n_fft))
    # Value suggested by Klatt, pg 1280
    k_max = 20.0
    # Value suggested by Klatt, pg 1280
    k_loc_max = 1.0

    # Critical band filter definitions (center frequency and bandwidths in Hz)
    cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
                          540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
                          1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
                          2701.97, 2978.04, 3276.17, 3597.63])
    bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
                          77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
                          153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
                          276.072, 298.126, 321.465, 346.136])

    # Minimum critical bandwidth
    bw_min = bandwidth[0]

    # Set up the critical band filters.
    # Note here that Gaussian-ly shaped filters are used.
    # Also, the sum of the filter weights are equivalent for each critical band filter.
    # Filter less than -30 dB and set to zero.

    # -30 dB point of filter
    min_factor = np.exp(-30.0 / (2.0 * 2.303))
    crit_filter = np.empty((num_crit, n_fft_by_2))
    # FFT bin indices are the same for every band, so build them once.
    bins = np.arange(n_fft_by_2)
    for i in range(num_crit):
        f0 = (cent_freq[i] / max_freq) * n_fft_by_2
        bw = (bandwidth[i] / max_freq) * n_fft_by_2
        norm_factor = np.log(bw_min) - np.log(bandwidth[i])
        response = np.exp(-11 * np.square(np.divide(bins - np.floor(f0), bw)) + norm_factor)
        crit_filter[i, :] = np.where(np.greater(response, min_factor), response, 0)

    # For each frame of input speech, calculate the weighted spectral slope measure
    num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

    distortion = np.empty(num_frames)
    for frame_count in range(num_frames):
        start = frame_count * skip_rate

        # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
        reference_frame = np.multiply(reference[start: start + win_length] / 32768, window)
        hypothesis_frame = np.multiply(hypothesis[start: start + win_length] / 32768, window)

        # (2) Compute the power spectrum of reference and hypothesis
        reference_spec = np.square(np.abs(fft(reference_frame, n_fft)))
        hypothesis_spec = np.square(np.abs(fft(hypothesis_frame, n_fft)))

        # (3) Compute filter bank output energies (in dB scale)
        reference_energy = np.matmul(crit_filter, reference_spec[0:n_fft_by_2])
        hypothesis_energy = np.matmul(crit_filter, hypothesis_spec[0:n_fft_by_2])

        reference_energy = 10 * np.log10(np.maximum(reference_energy, 1E-10))
        hypothesis_energy = 10 * np.log10(np.maximum(hypothesis_energy, 1E-10))

        # (4) Compute spectral slope (dB[i+1]-dB[i])
        reference_slope = reference_energy[1:num_crit] - reference_energy[0: num_crit - 1]
        hypothesis_slope = hypothesis_energy[1:num_crit] - hypothesis_energy[0: num_crit - 1]

        # (5) Find the nearest peak locations in the spectra to each critical band.
        reference_loc_peak = _wss_nearest_peak_energy(reference_energy, reference_slope)
        hypothesis_loc_peak = _wss_nearest_peak_energy(hypothesis_energy, hypothesis_slope)

        # (6) Compute the weighted spectral slope measure for this frame.
        # This includes determination of the weighting function.
        db_max_reference = np.max(reference_energy)
        db_max_hypothesis = np.max(hypothesis_energy)

        # The weights are calculated by averaging individual weighting factors from the reference and hypothesis frame.
        # These weights w_reference and w_hypothesis should range from 0 to 1 and place more emphasis on spectral peaks
        # and less emphasis on slope differences in spectral valleys.
        # This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
        w_max_reference = np.divide(k_max, k_max + db_max_reference - reference_energy[0: num_crit - 1])
        w_loc_max_reference = np.divide(k_loc_max, k_loc_max + reference_loc_peak - reference_energy[0: num_crit - 1])
        w_reference = np.multiply(w_max_reference, w_loc_max_reference)

        w_max_hypothesis = np.divide(k_max, k_max + db_max_hypothesis - hypothesis_energy[0: num_crit - 1])
        w_loc_max_hypothesis = np.divide(k_loc_max,
                                         k_loc_max + hypothesis_loc_peak - hypothesis_energy[0: num_crit - 1])
        w_hypothesis = np.multiply(w_max_hypothesis, w_loc_max_hypothesis)

        w = np.divide(np.add(w_reference, w_hypothesis), 2.0)
        slope_diff = np.subtract(reference_slope, hypothesis_slope)

        # This normalization is not part of Klatt's paper, but helps to normalize the measure.
        # Here we scale the measure by the sum of the weights.
        distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)

    return distortion


def _wss_nearest_peak_energy(energy: np.ndarray, slope: np.ndarray) -> np.ndarray:
    """Return, for each band slope, the band energy found at the end of the run of same-signed slopes.

    If the slope is positive the search moves right (towards higher bands); otherwise it moves left.

    :param energy: band energies in dB
    :param slope: differences between adjacent band energies; one element shorter than energy
    :return: array with one peak energy per slope entry
    """
    num_slopes = np.size(slope)
    loc_peak = np.empty(num_slopes)
    for i in range(num_slopes):
        n = i
        if slope[i] > 0:
            # search to the right
            while (n < num_slopes) and (slope[n] > 0):
                n = n + 1
            loc_peak[i] = energy[n - 1]
        else:
            # search to the left
            while (n >= 0) and (slope[n] <= 0):
                n = n - 1
            loc_peak[i] = energy[n + 1]
    return loc_peak
|
194
|
+
|
195
|
+
|
196
|
+
def _calc_log_likelihood_ratio_measure(hypothesis: np.ndarray,
                                       reference: np.ndarray,
                                       sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    """Calculate the per-frame log likelihood ratio (LLR) distortion.

    Both signals are split into Hanning-windowed 30 ms frames that advance by a quarter of the
    window length. For every frame the LPC coefficients of both signals are evaluated against the
    autocorrelation matrix of the reference frame and the log of the resulting ratio is stored.

    :param hypothesis: estimated audio
    :param reference: reference audio
    :param sample_rate: sample rate of audio
    :return: array holding one distortion value per frame
    :raises ValueError: if hypothesis and reference differ in length
    """
    from scipy.linalg import toeplitz

    # The lengths of the reference and hypothesis must be the same.
    reference_length = np.size(reference)
    hypothesis_length = np.size(hypothesis)
    if reference_length != hypothesis_length:
        raise ValueError('Hypothesis and reference must be the same length.')

    # window length in samples
    win_length = int(np.round(30 * sample_rate / 1000))
    # window skip in samples
    skip_rate = int(np.floor(win_length / 4))
    # LPC analysis order; this could vary depending on sampling frequency.
    p = 10 if sample_rate < 10000 else 16

    # For each frame of input speech, calculate the log likelihood ratio
    num_frames = int((reference_length - win_length) / skip_rate)
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

    distortion = np.empty(num_frames)
    for frame_count in range(num_frames):
        start = frame_count * skip_rate

        # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
        reference_frame = np.multiply(reference[start: start + win_length], window)
        hypothesis_frame = np.multiply(hypothesis[start: start + win_length], window)

        # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
        # Only the reference autocorrelation is needed; the reflection coefficients are unused here.
        r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
        _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)

        # (3) Compute the log likelihood ratio measure.
        # Both quadratic forms use the same reference autocorrelation matrix, so build it once.
        r_matrix = toeplitz(r_reference)
        numerator = np.dot(np.matmul(a_hypothesis, r_matrix), a_hypothesis)
        denominator = np.dot(np.matmul(a_reference, r_matrix), a_reference)
        distortion[frame_count] = np.log(numerator / denominator)

    return distortion
|
240
|
+
|
241
|
+
|
242
|
+
def _calc_snr(hypothesis: np.ndarray,
              reference: np.ndarray,
              sample_rate: int = SAMPLE_RATE) -> tuple[float, np.ndarray]:
    """Calculate the overall SNR and the per-frame segmental SNR in dB.

    The "noise" is the difference between reference and hypothesis. Segmental values are computed
    on Hanning-windowed 30 ms frames that advance by a quarter of the window length and are
    clipped to the range [-10, 35] dB.

    :param hypothesis: estimated audio
    :param reference: reference audio
    :param sample_rate: sample rate of audio
    :return: tuple of (overall SNR, array of segmental SNR values, one per frame)
    :raises ValueError: if hypothesis and reference differ in length
    """
    # The lengths of the reference and hypothesis must be the same.
    reference_length = len(reference)
    hypothesis_length = len(hypothesis)
    if reference_length != hypothesis_length:
        raise ValueError('Hypothesis and reference must be the same length.')

    eps = np.spacing(1)

    # Guard the overall SNR the same way as the segmental SNR below so that a perfect estimate
    # (zero noise energy) or silent input does not divide by zero or take the log of zero.
    signal_energy = np.sum(np.square(reference))
    noise_energy = np.sum(np.square(reference - hypothesis))
    overall_snr = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)

    # window length in samples
    win_length = round(30 * sample_rate / 1000)
    # window skip in samples
    skip_rate = int(np.floor(win_length / 4))
    # minimum SNR in dB
    min_snr = -10
    # maximum SNR in dB
    max_snr = 35

    # For each frame of input speech, calculate the segmental SNR
    num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

    segmental_snr = np.empty(num_frames)
    for frame_count in range(num_frames):
        start = frame_count * skip_rate

        # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
        reference_frame = np.multiply(reference[start:start + win_length], window)
        hypothesis_frame = np.multiply(hypothesis[start:start + win_length], window)

        # (2) Compute the segmental SNR
        frame_signal_energy = np.sum(np.square(reference_frame))
        frame_noise_energy = np.sum(np.square(reference_frame - hypothesis_frame))
        segmental_snr[frame_count] = np.clip(10 * np.log10(frame_signal_energy / (frame_noise_energy + eps) + eps),
                                             min_snr,
                                             max_snr)

    return overall_snr, segmental_snr
|
286
|
+
|
287
|
+
|
288
|
+
def _lp_coefficients(speech_frame, model_order):
|
289
|
+
# (1) Compute autocorrelation lags
|
290
|
+
win_length = np.size(speech_frame)
|
291
|
+
autocorrelation = np.empty(model_order + 1)
|
292
|
+
e = np.empty(model_order + 1)
|
293
|
+
for k in range(model_order + 1):
|
294
|
+
autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
|
295
|
+
|
296
|
+
# (2) Levinson-Durbin
|
297
|
+
a = np.ones(model_order)
|
298
|
+
a_past = np.empty(model_order)
|
299
|
+
ref_coefficients = np.empty(model_order)
|
300
|
+
e[0] = autocorrelation[0]
|
301
|
+
for i in range(model_order):
|
302
|
+
a_past[0: i] = a[0: i]
|
303
|
+
sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
|
304
|
+
ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
|
305
|
+
a[i] = ref_coefficients[i]
|
306
|
+
if i == 0:
|
307
|
+
a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
|
308
|
+
else:
|
309
|
+
a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
|
310
|
+
e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
|
311
|
+
lp_params = np.concatenate((np.array([1]), -a))
|
312
|
+
return autocorrelation, ref_coefficients, lp_params
|
sonusai/metrics/calc_wer.py
CHANGED
sonusai/metrics/calc_wsdr.py
CHANGED
@@ -52,62 +52,3 @@ def calc_wsdr(hypothesis: np.ndarray,
|
|
52
52
|
wsdr = 10 * np.log10(-1 / (wsdr - 1 - 1e-7)) # range -3 --> inf (or 1e-7 limit of 70db)
|
53
53
|
|
54
54
|
return float(wsdr), cc, cw
|
55
|
-
|
56
|
-
# From calc_sa_sdr:
|
57
|
-
# These should include a noise to be a complete mixture estimate, i.e.,
|
58
|
-
# noise_est = sum-over-all-srcs(s_est(0:nsamples, :) - sum-over-non-noisesrc(s_est(0:nsamples, n))
|
59
|
-
# should be one of the sources in reference (s_true) and hypothesis (s_est).
|
60
|
-
#
|
61
|
-
# Calculates -10*log10(sumn(||sn||^2) / sumn(||sn - shn||^2)
|
62
|
-
# Note: for SA method, sums are done independently on ref and error before division, vs. SDR and SI-SDR
|
63
|
-
# where sum over n is taken after divide (before log). This is more stable in noise-only cases and also
|
64
|
-
# when some sources are poorly estimated.
|
65
|
-
# TBD: add soft-max option with eps and tau params
|
66
|
-
#
|
67
|
-
# if with_scale:
|
68
|
-
# # calc 1 x nsrc scaling factors
|
69
|
-
# ref_energy = np.sum(reference ** 2, axis=0, keepdims=True)
|
70
|
-
# # if ref_energy is zero, just set scaling to 1.0
|
71
|
-
# with np.errstate(divide='ignore', invalid='ignore'):
|
72
|
-
# opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
|
73
|
-
# opt_scale[opt_scale == np.inf] = 1.0
|
74
|
-
# opt_scale = np.nan_to_num(opt_scale, nan=1.0)
|
75
|
-
# scaled_ref = opt_scale * reference
|
76
|
-
# else:
|
77
|
-
# scaled_ref = reference
|
78
|
-
# opt_scale = np.ones((1, reference.shape[1]), dtype=float)
|
79
|
-
#
|
80
|
-
# # Calculate Lsdr = −<y,yˆ>/∥y∥∥yˆ∥ always in range [1 --> -1], size [batch,]
|
81
|
-
# t_tru_sq = torch.sum(torch.square(t_tru), -1)
|
82
|
-
# t_denom = torch.sqrt(t_tru_sq) * torch.sqrt(torch.sum(torch.square(t_est), -1)) + 1e-7
|
83
|
-
# t_wsdr = -torch.divide(torch.sum(torch.multiply(t_tru, t_est), -1), t_denom)
|
84
|
-
# n_tru_sq = torch.sum(torch.square(n_tru), -1)
|
85
|
-
# n_denom = torch.sqrt(torch.sum(torch.square(n_tru), -1)) \
|
86
|
-
# * torch.sqrt(torch.sum(torch.square(n_est), -1)) + 1e-7
|
87
|
-
# n_wsdr = -torch.divide(torch.sum(torch.multiply(n_tru, n_est), -1), n_denom)
|
88
|
-
# if self.cl_noise_wght > 0:
|
89
|
-
# wsdr = self.cl_target_wght * t_wsdr + self.cl_noise_wght * n_wsdr
|
90
|
-
# else: # adaptive per relative strength of target vs noise: α = ||y||2/(||y||2 +||z||2)
|
91
|
-
# tweight = torch.divide(t_tru_sq, t_tru_sq + n_tru_sq + 1e-7) # energy ratio target vs. noise
|
92
|
-
# wsdr = tweight * t_wsdr + (1 - tweight) * n_wsdr
|
93
|
-
# wsdr = torch.mean(wsdr) # reduction to scalar
|
94
|
-
#
|
95
|
-
# # multisrc sa-sdr, inputs must be [samples, nsrc]
|
96
|
-
# err = scaled_ref - hypothesis
|
97
|
-
#
|
98
|
-
# # -10*log10(sumk(||sk||^2) / sumk(||sk - shk||^2)
|
99
|
-
# # sum over samples and sources
|
100
|
-
# num = np.sum(reference ** 2)
|
101
|
-
# den = np.sum(err ** 2)
|
102
|
-
# if num == 0 and den == 0:
|
103
|
-
# ratio = np.inf
|
104
|
-
# else:
|
105
|
-
# ratio = num / (den + np.finfo(np.float32).eps)
|
106
|
-
#
|
107
|
-
# sa_sdr = 10 * np.log10(ratio)
|
108
|
-
#
|
109
|
-
# if with_negate:
|
110
|
-
# # for use as a loss function
|
111
|
-
# sa_sdr = -sa_sdr
|
112
|
-
#
|
113
|
-
# return sa_sdr, opt_scale
|
sonusai/mixture/__init__.py
CHANGED
@@ -44,6 +44,7 @@ from .constants import VALID_CONFIGS
|
|
44
44
|
from .constants import VALID_NOISE_MIX_MODES
|
45
45
|
from .constants import VALID_TRUTH_SETTINGS
|
46
46
|
from .datatypes import AudioF
|
47
|
+
from .datatypes import AudioStatsMetrics
|
47
48
|
from .datatypes import AudioT
|
48
49
|
from .datatypes import AudiosF
|
49
50
|
from .datatypes import AudiosT
|
@@ -72,9 +73,11 @@ from .datatypes import NoiseFile
|
|
72
73
|
from .datatypes import NoiseFiles
|
73
74
|
from .datatypes import Predict
|
74
75
|
from .datatypes import Segsnr
|
76
|
+
from .datatypes import SnrFMetrics
|
75
77
|
from .datatypes import SpectralMask
|
76
78
|
from .datatypes import SpectralMasks
|
77
79
|
from .datatypes import SpeechMetadata
|
80
|
+
from .datatypes import SpeechMetrics
|
78
81
|
from .datatypes import TargetFile
|
79
82
|
from .datatypes import TargetFiles
|
80
83
|
from .datatypes import TransformConfig
|
@@ -113,8 +116,6 @@ from .helpers import read_mixture_data
|
|
113
116
|
from .helpers import write_mixture_data
|
114
117
|
from .helpers import write_mixture_metadata
|
115
118
|
from .log_duration_and_sizes import log_duration_and_sizes
|
116
|
-
from .mapped_snr_f import calculate_mapped_snr_f
|
117
|
-
from .mapped_snr_f import calculate_snr_f_statistics
|
118
119
|
from .mixdb import MixtureDatabase
|
119
120
|
from .mixdb import db_file
|
120
121
|
from .sox_audio import Transformer
|
sonusai/mixture/audio.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from functools import lru_cache
|
2
|
+
from pathlib import Path
|
2
3
|
|
3
4
|
from sonusai.mixture.datatypes import AudioT
|
4
5
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
@@ -28,7 +29,7 @@ def get_duration(audio: AudioT) -> float:
|
|
28
29
|
return len(audio) / SAMPLE_RATE
|
29
30
|
|
30
31
|
|
31
|
-
def validate_input_file(input_filepath: str) -> None:
|
32
|
+
def validate_input_file(input_filepath: str | Path) -> None:
|
32
33
|
from os.path import exists
|
33
34
|
from os.path import splitext
|
34
35
|
|
@@ -46,7 +47,7 @@ def validate_input_file(input_filepath: str) -> None:
|
|
46
47
|
|
47
48
|
|
48
49
|
@lru_cache
|
49
|
-
def get_sample_rate(name: str) -> int:
|
50
|
+
def get_sample_rate(name: str | Path) -> int:
|
50
51
|
"""Get sample rate from audio file
|
51
52
|
|
52
53
|
:param name: File name
|
@@ -58,7 +59,7 @@ def get_sample_rate(name: str) -> int:
|
|
58
59
|
|
59
60
|
|
60
61
|
@lru_cache
|
61
|
-
def read_audio(name: str) -> AudioT:
|
62
|
+
def read_audio(name: str | Path) -> AudioT:
|
62
63
|
"""Read audio data from a file
|
63
64
|
|
64
65
|
:param name: File name
|
@@ -70,7 +71,7 @@ def read_audio(name: str) -> AudioT:
|
|
70
71
|
|
71
72
|
|
72
73
|
@lru_cache
|
73
|
-
def read_ir(name: str) -> ImpulseResponseData:
|
74
|
+
def read_ir(name: str | Path) -> ImpulseResponseData:
|
74
75
|
"""Read impulse response data
|
75
76
|
|
76
77
|
:param name: File name
|
@@ -82,7 +83,7 @@ def read_ir(name: str) -> ImpulseResponseData:
|
|
82
83
|
|
83
84
|
|
84
85
|
@lru_cache
|
85
|
-
def get_num_samples(name: str) -> int:
|
86
|
+
def get_num_samples(name: str | Path) -> int:
|
86
87
|
"""Get the number of samples resampled to the SonusAI sample rate in the given file
|
87
88
|
|
88
89
|
:param name: File name
|
sonusai/mixture/config.py
CHANGED
@@ -90,6 +90,19 @@ def update_config_from_file(name: str, config: dict) -> dict:
|
|
90
90
|
|
91
91
|
updated_config['truth_settings'] = update_truth_settings(updated_config['truth_settings'], default)
|
92
92
|
|
93
|
+
# Handle 'asr_configs' special case
|
94
|
+
if 'asr_configs' in updated_config:
|
95
|
+
asr_configs = {}
|
96
|
+
for asr_config in updated_config['asr_configs']:
|
97
|
+
asr_name = asr_config.get('name', None)
|
98
|
+
asr_engine = asr_config.get('engine', None)
|
99
|
+
if asr_name is None or asr_engine is None:
|
100
|
+
raise SonusAIError(f'Invalid config parameter in {name}: asr_configs.\n'
|
101
|
+
f'asr_configs must contain both name and engine.')
|
102
|
+
del asr_config['name']
|
103
|
+
asr_configs[asr_name] = asr_config
|
104
|
+
updated_config['asr_configs'] = asr_configs
|
105
|
+
|
93
106
|
# Check for required keys
|
94
107
|
for key in REQUIRED_CONFIGS:
|
95
108
|
if key not in updated_config:
|
sonusai/mixture/constants.py
CHANGED
sonusai/mixture/datatypes.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
|
+
from typing import Any
|
3
|
+
from typing import NamedTuple
|
2
4
|
from typing import Optional
|
3
5
|
from typing import TypeAlias
|
4
6
|
|
@@ -309,8 +311,12 @@ class FeatureGeneratorInfo:
|
|
309
311
|
it_config: TransformConfig
|
310
312
|
|
311
313
|
|
314
|
+
ASRConfigs: TypeAlias = dict[str, dict[str, Any]]
|
315
|
+
|
316
|
+
|
312
317
|
@dataclass
|
313
318
|
class MixtureDatabaseConfig(DataClassSonusAIMixin):
|
319
|
+
asr_configs: Optional[ASRConfigs] = None
|
314
320
|
class_balancing: Optional[bool] = False
|
315
321
|
class_labels: Optional[list[str]] = None
|
316
322
|
class_weights_threshold: Optional[list[float]] = None
|
@@ -327,3 +333,30 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):
|
|
327
333
|
|
328
334
|
|
329
335
|
SpeechMetadata: TypeAlias = str | list[Interval] | None
|
336
|
+
|
337
|
+
|
338
|
+
class SnrFMetrics(NamedTuple):
|
339
|
+
mean: Optional[float] = None
|
340
|
+
var: Optional[float] = None
|
341
|
+
db_mean: Optional[float] = None
|
342
|
+
db_std: Optional[float] = None
|
343
|
+
|
344
|
+
|
345
|
+
class SpeechMetrics(NamedTuple):
|
346
|
+
pesq: Optional[float] = None
|
347
|
+
c_sig: Optional[float] = None
|
348
|
+
c_bak: Optional[float] = None
|
349
|
+
c_ovl: Optional[float] = None
|
350
|
+
|
351
|
+
|
352
|
+
class AudioStatsMetrics(NamedTuple):
|
353
|
+
dco: Optional[float] = None
|
354
|
+
min: Optional[float] = None
|
355
|
+
max: Optional[float] = None
|
356
|
+
pkdb: Optional[float] = None
|
357
|
+
lrms: Optional[float] = None
|
358
|
+
pkr: Optional[float] = None
|
359
|
+
tr: Optional[float] = None
|
360
|
+
cr: Optional[float] = None
|
361
|
+
fl: Optional[float] = None
|
362
|
+
pkc: Optional[float] = None
|
sonusai/mixture/generation.py
CHANGED
@@ -59,6 +59,7 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
59
59
|
CREATE TABLE top (
|
60
60
|
id INTEGER PRIMARY KEY NOT NULL,
|
61
61
|
version INTEGER NOT NULL,
|
62
|
+
asr_configs TEXT NOT NULL,
|
62
63
|
class_balancing BOOLEAN NOT NULL,
|
63
64
|
feature TEXT NOT NULL,
|
64
65
|
noise_mix_mode TEXT NOT NULL,
|
@@ -149,6 +150,8 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
149
150
|
def populate_top_table(location: str, config: dict, test: bool = False) -> None:
|
150
151
|
"""Populate top table
|
151
152
|
"""
|
153
|
+
import json
|
154
|
+
|
152
155
|
from sonusai import SonusAIError
|
153
156
|
from .mixdb import db_connection
|
154
157
|
|
@@ -158,11 +161,12 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
|
|
158
161
|
|
159
162
|
con = db_connection(location=location, readonly=False, test=test)
|
160
163
|
con.execute("""
|
161
|
-
INSERT INTO top (version, class_balancing, feature, noise_mix_mode, num_classes,
|
164
|
+
INSERT INTO top (version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
|
162
165
|
seed, truth_mutex, truth_reduction_function, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
|
163
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
166
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
164
167
|
""", (
|
165
168
|
1,
|
169
|
+
json.dumps(config['asr_configs']),
|
166
170
|
config['class_balancing'],
|
167
171
|
config['feature'],
|
168
172
|
config['noise_mix_mode'],
|