sonusai 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +32 -362
- sonusai/data/genmixdb.yml +2 -0
- sonusai/doc/doc.py +45 -4
- sonusai/genmetrics.py +137 -109
- sonusai/lsdb.py +2 -2
- sonusai/metrics/__init__.py +4 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_pesq.py +12 -8
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_snr_f.py +34 -0
- sonusai/metrics/calc_speech.py +312 -0
- sonusai/metrics/calc_wer.py +2 -3
- sonusai/metrics/calc_wsdr.py +0 -59
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +6 -5
- sonusai/mixture/config.py +13 -0
- sonusai/mixture/constants.py +1 -0
- sonusai/mixture/datatypes.py +33 -0
- sonusai/mixture/generation.py +6 -2
- sonusai/mixture/mixdb.py +261 -122
- sonusai/mixture/soundfile_audio.py +8 -6
- sonusai/mixture/sox_audio.py +16 -13
- sonusai/mixture/torchaudio_audio.py +6 -4
- sonusai/mixture/truth_functions/energy.py +40 -28
- sonusai/mixture/truth_functions/target.py +0 -1
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr.py +26 -39
- sonusai/utils/asr_functions/aaware_whisper.py +3 -3
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/METADATA +1 -1
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/RECORD +34 -31
- sonusai/mixture/mapped_snr_f.py +0 -100
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/WHEEL +0 -0
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py
CHANGED
@@ -10,6 +10,7 @@ commands_doc = """
    calc_metric_spenh             Run speech enhancement and analysis
    doc                           Documentation
    genft                         Generate feature and truth data
+   genmetrics                    Generate mixture metrics data
    genmix                        Generate mixture and truth data
    genmixdb                      Generate a mixture database
    gentcst                       Generate target configuration from a subdirectory tree
sonusai/audiofe.py
CHANGED
@@ -142,7 +142,7 @@ def main() -> None:
     if hparams is None:
         logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
         raise SystemExit(1)
-    feature_mode = hparams
+    feature_mode = hparams['feature']
     in0name = sess_inputs[0].name
     in0type = sess_inputs[0].type
     out_names = [n.name for n in session.get_outputs()]
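The fix takes the 'feature' entry instead of assigning the whole hyperparameter dict. For context, a minimal sketch of reading SonusAI hyperparameters from an ONNX model with onnxruntime — the metadata key name 'hparams' and the model path are illustrative assumptions, not taken from this diff:

```python
import json

import onnxruntime as ort

# Hypothetical model path; the hyperparameters are assumed to live in the
# model's custom metadata under an assumed 'hparams' key.
session = ort.InferenceSession('model.onnx')
meta = session.get_modelmeta().custom_metadata_map  # dict[str, str]
hparams = json.loads(meta['hparams']) if 'hparams' in meta else None
if hparams is None:
    raise SystemExit(1)  # model lacks the required SonusAI hyperparameters
feature_mode = hparams['feature']  # the 0.18.4 fix: index the dict, not assign it whole
```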
sonusai/calc_metric_spenh.py
CHANGED
@@ -24,7 +24,7 @@ For whisper ASR methods, the possible models used in local processing (ASR = whi
    {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
    but note most are very computationally demanding and can overwhelm/hang a local system.
 
-Outputs the following to PLOC (where id is
+Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
    <id>_metric_spenh.txt
 
    If --plot:
@@ -74,6 +74,9 @@ from sonusai.mixture import Feature
 from sonusai.mixture import MixtureDatabase
 from sonusai.mixture import Predict
 
+DB_99 = np.power(10, 99 / 10)
+DB_N99 = np.power(10, -99 / 10)
+
 
 def signal_handler(_sig, _frame):
     import sys
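The two new constants are just ±99 dB expressed as linear power ratios; they are used later in this file to replace infinite segmental-SNR values so downstream statistics stay finite. A quick check of the arithmetic:

```python
import numpy as np

DB_99 = np.power(10, 99 / 10)    # +99 dB as a power ratio, about 7.9e9
DB_N99 = np.power(10, -99 / 10)  # -99 dB as a power ratio, about 1.3e-10

# Toy data: clamp +/-inf segmental SNR to finite extremes, as done later in this file
segsnr_f = np.array([0.5, np.inf, -np.inf])
segsnr_f[segsnr_f == np.inf] = DB_99
segsnr_f[segsnr_f == -np.inf] = DB_N99
assert np.all(np.isfinite(segsnr_f))
```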
@@ -122,298 +125,6 @@ def power_uncompress(spec):
     return real_uncompress + 1j * imag_uncompress
 
 
-def snr(clean_speech, processed_speech, sample_rate):
-    # Check the length of the clean and processed speech. Must be the same.
-    clean_length = len(clean_speech)
-    processed_length = len(processed_speech)
-    if clean_length != processed_length:
-        raise ValueError('Both Speech Files must be same length.')
-
-    overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))
-
-    # Global Variables
-    win_length = round(30 * sample_rate / 1000)  # window length in samples
-    skip_rate = int(np.floor(win_length / 4))  # window skip in samples
-    min_snr = -10  # minimum SNR in dB
-    max_snr = 35  # maximum SNR in dB
-
-    # For each frame of input speech, calculate the Segmental SNR
-    num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
-    start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-    segmental_snr = np.empty(num_frames)
-    eps = np.spacing(1)
-    for frame_count in range(num_frames):
-        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start:start + win_length]
-        processed_frame = processed_speech[start:start + win_length]
-        clean_frame = np.multiply(clean_frame, window)
-        processed_frame = np.multiply(processed_frame, window)
-
-        # (2) Compute the Segmental SNR
-        signal_energy = np.sum(np.square(clean_frame))
-        noise_energy = np.sum(np.square(clean_frame - processed_frame))
-        segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
-        segmental_snr[frame_count] = max(segmental_snr[frame_count], min_snr)
-        segmental_snr[frame_count] = min(segmental_snr[frame_count], max_snr)
-
-        start = start + skip_rate
-
-    return overall_snr, segmental_snr
-
-
-def lp_coefficients(speech_frame, model_order):
-    # (1) Compute Autocorrelation Lags
-    win_length = np.size(speech_frame)
-    autocorrelation = np.empty(model_order + 1)
-    e = np.empty(model_order + 1)
-    for k in range(model_order + 1):
-        autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
-
-    # (2) Levinson-Durbin
-    a = np.ones(model_order)
-    a_past = np.empty(model_order)
-    ref_coefficients = np.empty(model_order)
-    e[0] = autocorrelation[0]
-    for i in range(model_order):
-        a_past[0: i] = a[0: i]
-        sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
-        ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
-        a[i] = ref_coefficients[i]
-        if i == 0:
-            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
-        else:
-            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
-        e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
-    lp_params = np.concatenate((np.array([1]), -a))
-    return autocorrelation, ref_coefficients, lp_params
-
-
-def llr(clean_speech, processed_speech, sample_rate):
-    from scipy.linalg import toeplitz
-
-    # Check the length of the clean and processed speech. Must be the same.
-    clean_length = np.size(clean_speech)
-    processed_length = np.size(processed_speech)
-    if clean_length != processed_length:
-        raise ValueError('Both speech files must be same length.')
-
-    # Global Variables
-    win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-    skip_rate = (np.floor(win_length / 4)).astype(int)  # window skip in samples
-    if sample_rate < 10000:
-        p = 10  # LPC Analysis Order
-    else:
-        p = 16  # this could vary depending on sampling frequency.
-
-    # For each frame of input speech, calculate the Log Likelihood Ratio
-    num_frames = int((clean_length - win_length) / skip_rate)  # number of frames
-    start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-    distortion = np.empty(num_frames)
-    for frame_count in range(num_frames):
-        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start: start + win_length]
-        processed_frame = processed_speech[start: start + win_length]
-        clean_frame = np.multiply(clean_frame, window)
-        processed_frame = np.multiply(processed_frame, window)
-
-        # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
-        r_clean, ref_clean, a_clean = lp_coefficients(clean_frame, p)
-        r_processed, ref_processed, a_processed = lp_coefficients(processed_frame, p)
-
-        # (3) Compute the LLR measure
-        numerator = np.dot(np.matmul(a_processed, toeplitz(r_clean)), a_processed)
-        denominator = np.dot(np.matmul(a_clean, toeplitz(r_clean)), a_clean)
-        distortion[frame_count] = np.log(numerator / denominator)
-        start = start + skip_rate
-    return distortion
-
-
-def wss(clean_speech, processed_speech, sample_rate):
-    from scipy.fftpack import fft
-
-    # Check the length of the clean and processed speech, which must be the same.
-    clean_length = np.size(clean_speech)
-    processed_length = np.size(processed_speech)
-    if clean_length != processed_length:
-        raise ValueError('Files must have same length.')
-
-    # Global variables
-    win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-    skip_rate = (np.floor(np.divide(win_length, 4))).astype(int)  # window skip in samples
-    max_freq = (np.divide(sample_rate, 2)).astype(int)  # maximum bandwidth
-    num_crit = 25  # number of critical bands
-
-    n_fft = (np.power(2, np.ceil(np.log2(2 * win_length)))).astype(int)
-    n_fft_by_2 = (np.multiply(0.5, n_fft)).astype(int)  # FFT size/2
-    k_max = 20.0  # value suggested by Klatt, pg 1280
-    k_loc_max = 1.0  # value suggested by Klatt, pg 1280
-
-    # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
-    cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
-                          540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
-                          1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
-                          2701.97, 2978.04, 3276.17, 3597.63])
-    bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
-                          77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
-                          153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
-                          276.072, 298.126, 321.465, 346.136])
-
-    bw_min = bandwidth[0]  # minimum critical bandwidth
-
-    # Set up the critical band filters.
-    # Note here that Gaussian-ly shaped filters are used.
-    # Also, the sum of the filter weights are equivalent for each critical band filter.
-    # Filter less than -30 dB and set to zero.
-    min_factor = np.exp(-30.0 / (2.0 * 2.303))  # -30 dB point of filter
-    crit_filter = np.empty((num_crit, n_fft_by_2))
-    for i in range(num_crit):
-        f0 = (cent_freq[i] / max_freq) * n_fft_by_2
-        bw = (bandwidth[i] / max_freq) * n_fft_by_2
-        norm_factor = np.log(bw_min) - np.log(bandwidth[i])
-        j = np.arange(n_fft_by_2)
-        crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
-        cond = np.greater(crit_filter[i, :], min_factor)
-        crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
-    # For each frame of input speech, calculate the Weighted Spectral Slope Measure
-    num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
-    start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-    distortion = np.empty(num_frames)
-    for frame_count in range(num_frames):
-        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start: start + win_length] / 32768
-        processed_frame = processed_speech[start: start + win_length] / 32768
-        clean_frame = np.multiply(clean_frame, window)
-        processed_frame = np.multiply(processed_frame, window)
-        # (2) Compute the Power Spectrum of Clean and Processed
-        clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
-        processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))
-
-        # (3) Compute Filterbank Output Energies (in dB scale)
-        clean_energy = np.matmul(crit_filter, clean_spec[0:n_fft_by_2])
-        processed_energy = np.matmul(crit_filter, processed_spec[0:n_fft_by_2])
-
-        clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
-        processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))
-
-        # (4) Compute Spectral Slope (dB[i+1]-dB[i])
-        clean_slope = clean_energy[1:num_crit] - clean_energy[0: num_crit - 1]
-        processed_slope = processed_energy[1:num_crit] - processed_energy[0: num_crit - 1]
-
-        # (5) Find the nearest peak locations in the spectra to each critical band.
-        # If the slope is negative, we search to the left. If positive, we search to the right.
-        clean_loc_peak = np.empty(num_crit - 1)
-        processed_loc_peak = np.empty(num_crit - 1)
-
-        for i in range(num_crit - 1):
-            # find the peaks in the clean speech signal
-            if clean_slope[i] > 0:  # search to the right
-                n = i
-                while (n < num_crit - 1) and (clean_slope[n] > 0):
-                    n = n + 1
-                clean_loc_peak[i] = clean_energy[n - 1]
-            else:  # search to the left
-                n = i
-                while (n >= 0) and (clean_slope[n] <= 0):
-                    n = n - 1
-                clean_loc_peak[i] = clean_energy[n + 1]
-
-            # find the peaks in the processed speech signal
-            if processed_slope[i] > 0:  # search to the right
-                n = i
-                while (n < num_crit - 1) and (processed_slope[n] > 0):
-                    n = n + 1
-                processed_loc_peak[i] = processed_energy[n - 1]
-            else:  # search to the left
-                n = i
-                while (n >= 0) and (processed_slope[n] <= 0):
-                    n = n - 1
-                processed_loc_peak[i] = processed_energy[n + 1]
-
-        # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
-        db_max_clean = np.max(clean_energy)
-        db_max_processed = np.max(processed_energy)
-        '''
-        The weights are calculated by averaging individual weighting factors from the clean and processed frame.
-        These weights w_clean and w_processed should range from 0 to 1 and place more emphasis on spectral peaks
-        and less emphasis on slope differences in spectral valleys.
-        This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
-        '''
-        w_max_clean = np.divide(k_max, k_max + db_max_clean - clean_energy[0: num_crit - 1])
-        w_loc_max_clean = np.divide(k_loc_max, k_loc_max + clean_loc_peak - clean_energy[0: num_crit - 1])
-        w_clean = np.multiply(w_max_clean, w_loc_max_clean)
-
-        w_max_processed = np.divide(k_max, k_max + db_max_processed - processed_energy[0: num_crit - 1])
-        w_loc_max_processed = np.divide(k_loc_max, k_loc_max + processed_loc_peak - processed_energy[0: num_crit - 1])
-        w_processed = np.multiply(w_max_processed, w_loc_max_processed)
-
-        w = np.divide(np.add(w_clean, w_processed), 2.0)
-        slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1]
-        distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
-        # This normalization is not part of Klatt's paper, but helps to normalize the measure.
-        # Here we scale the measure by the sum of the weights.
-        start = start + skip_rate
-    return distortion
-
-
-def calc_speech_metrics(hypothesis: np.ndarray,
-                        reference: np.ndarray) -> tuple[float, int, int, int, float]:
-    """
-    Calculate speech metrics pesq_mos, c_sig, c_bak, c_ovl, seg_snr. These are all related and thus included
-    in one function. Reference: matlab script "compute_metrics.m".
-
-    Usage:
-        pesq, c_sig, c_bak, c_ovl, ssnr = compute_metrics(hypothesis, reference, fs, path)
-        reference: clean audio as array
-        hypothesis: enhanced audio as array
-        Audio must have sampling rate = 16000 Hz.
-
-    Example call:
-        pesq_output, csig_output, cbak_output, covl_output, ssnr_output = \
-            calc_speech_metrics(predicted_audio, target_audio)
-    """
-    from sonusai.metrics import calc_pesq
-
-    fs = 16000
-
-    # compute the WSS measure
-    wss_dist_vec = wss(reference, hypothesis, fs)
-    wss_dist_vec = np.sort(wss_dist_vec)
-    alpha = 0.95  # value from CMGAN ref implementation
-    wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
-
-    # compute the LLR measure
-    llr_dist = llr(reference, hypothesis, fs)
-    ll_rs = np.sort(llr_dist)
-    llr_len = round(np.size(llr_dist) * alpha)
-    llr_mean = np.mean(ll_rs[0: llr_len])
-
-    # compute the SNRseg
-    snr_dist, segsnr_dist = snr(reference, hypothesis, fs)
-    seg_snr = np.mean(segsnr_dist)
-
-    # compute the pesq (use Sonusai wrapper, only fs=16k, mode=wb support)
-    pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
-
-    # now compute the composite measures
-    c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
-    c_sig = max(1, c_sig)
-    c_sig = min(5, c_sig)  # limit values to [1, 5]
-    c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
-    c_bak = max(1, c_bak)
-    c_bak = min(5, c_bak)  # limit values to [1, 5]
-    c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
-    c_ovl = max(1, c_ovl)
-    c_ovl = min(5, c_ovl)  # limit values to [1, 5]
-
-    return pesq_mos, c_sig, c_bak, c_ovl, seg_snr
-
-
 def mean_square_error(hypothesis: np.ndarray,
                       reference: np.ndarray,
                       squared: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
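This whole block of intrusive speech-quality helpers (snr, lp_coefficients, llr, wss, calc_speech_metrics) leaves calc_metric_spenh.py; per the file list above, the equivalent functionality lands in the new sonusai/metrics/calc_speech.py (+312 lines). For reference, a minimal restatement of the composite CSIG/CBAK/COVL measures from the removed code — whether calc_speech.py keeps exactly this form is not shown in this diff:

```python
import numpy as np


def composite_measures(pesq_mos: float, llr_mean: float,
                       wss_dist: float, seg_snr: float) -> tuple[float, float, float]:
    """CSIG/CBAK/COVL composites as computed by the removed calc_speech_metrics."""
    c_sig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist, 1, 5))
    c_bak = float(np.clip(1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
    c_ovl = float(np.clip(1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
    return c_sig, c_bak, c_ovl
```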
@@ -494,48 +205,6 @@ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray
     return err, err_b, err_f
 
 
-def phase_distance(reference: np.ndarray,
-                   hypothesis: np.ndarray,
-                   eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Calculate weighted phase distance error (weight normalization over bins per frame)
-
-    :param reference: complex [frames, bins]
-    :param hypothesis: complex [frames, bins]
-    :param eps: epsilon value
-    :return: mean, mean per bin, mean per frame
-    """
-    ang_diff = np.angle(reference) - np.angle(hypothesis)
-    phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
-    rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg
-
-    # Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
-    # hyp_real = np.real(hypothesis)
-    # near_zeros = np.real(hyp_real) < eps
-    # hyp_real = hyp_real * (np.logical_not(near_zeros))
-    # hyp_real = hyp_real + (near_zeros * eps)
-    # hypothesis = hyp_real + 1j*np.imag(hypothesis)
-    # rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi  # angle diff +/-180
-
-    # weighted mean over all (scalar)
-    reference_mag = np.abs(reference)
-    ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
-    err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
-
-    # weighted mean over frames (value per bin)
-    err_b = np.zeros(reference.shape[1])
-    for bi in range(reference.shape[1]):
-        ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
-        err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)
-
-    # weighted mean over bins (value per frame)
-    err_f = np.zeros(reference.shape[0])
-    for fi in range(reference.shape[0]):
-        ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
-        err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)
-
-    return err, err_b, err_f
-
-
 def plot_mixpred(mixture: AudioT,
                  mixture_f: AudioF,
                  target: Optional[AudioT] = None,
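phase_distance likewise moves into the metrics package as the new sonusai/metrics/calc_phase_distance.py (+43 lines in the file list). The core of the removed computation, condensed into a self-contained sketch of the magnitude-weighted scalar error:

```python
import numpy as np


def phase_distance_scalar(reference: np.ndarray, hypothesis: np.ndarray,
                          eps: float = 1e-9) -> float:
    """Magnitude-weighted mean phase error in degrees over all frames and bins."""
    ang_diff = np.angle(reference) - np.angle(hypothesis)
    wrapped = (ang_diff + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
    deg = wrapped * 180 / np.pi
    weight = np.abs(reference) / (np.sum(np.abs(reference)) + eps)
    return float(np.around(np.sum(weight * deg), 3))


# Identical spectra give zero phase distance
spec = np.exp(1j * np.linspace(0, np.pi, 8)).reshape(2, 4)
assert phase_distance_scalar(spec, spec) == 0.0
```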
@@ -543,7 +212,6 @@ def plot_mixpred(mixture: AudioT,
                  predict: Optional[Predict] = None,
                  tp_title: str = '') -> plt.Figure:
     from sonusai.mixture import SAMPLE_RATE
-
     num_plots = 2
     if feature is not None:
         num_plots += 1
@@ -706,12 +374,13 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     import h5py
     import mgzip
     from matplotlib.backends.backend_pdf import PdfPages
-    from numpy import inf
     from pystoi import stoi
 
     from sonusai import SonusAIError
     from sonusai import logger
     from sonusai.metrics import calc_pcm
+    from sonusai.metrics import calc_phase_distance
+    from sonusai.metrics import calc_speech
     from sonusai.metrics import calc_wer
     from sonusai.metrics import calc_wsdr
     from sonusai.mixture import forward_transform
@@ -771,24 +440,28 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
     noise = mixture - target  # has time-domain distortion (ir,etc.) but does not have specaugment
     # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
-
+    # note: uses pre-IR, pre-specaug audio
+    segsnr_f: np.ndarray = mixdb.mixture_metrics(mixid, ['ssnr'])[0]  # type: ignore
     mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
     noise_f = mixture_f - target_f  # true noise in freq domain includes specaugment and time-domain ir,distortions
     # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
-    segsnr_f[segsnr_f == inf] =
-    segsnr_f
+    segsnr_f[segsnr_f == np.inf] = DB_99
+    # segsnr_f should never be -np.inf
+    segsnr_f[segsnr_f == -np.inf] = DB_N99
     # need to use inv-tf to match #samples & latency shift properties of predict inv tf
     target_fi = inverse_transform(target_f, mixdb.it_config)
     noise_fi = inverse_transform(noise_f, mixdb.it_config)
     # mixture_fi = mixdb.inverse_transform(mixture_f)
 
     # gen feature, truth - note feature only used for plots
-    #
-    feature, truth_f = mixdb.mixture_ft(mixid,
-
-
-
-
+    # TODO: parse truth_f for different formats
+    feature, truth_f = mixdb.mixture_ft(mixid, mixture_f=mixture_f)
+    # ignore mixup
+    for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
+        if truth_setting.function == 'target_mixture_f':
+            half = truth_f.shape[-1] // 2
+            # extract target_f only
+            truth_f = truth_f[..., :half]
 
     if not truth_est_mode:
         if predict.shape[0] < target_f.shape[0]:  # target_f, truth_f, mixture_f, etc. same size
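The new loop special-cases the 'target_mixture_f' truth function, whose truth vector concatenates the target and mixture spectra along the last axis, so keeping the first half recovers the target portion. A toy illustration of the slicing (the bin count here is made up):

```python
import numpy as np

bins = 257  # hypothetical; depends on the feature/transform configuration
truth_f = np.zeros((10, 2 * bins))  # frames x (target_f ++ mixture_f)
half = truth_f.shape[-1] // 2
truth_f = truth_f[..., :half]  # keep target_f only
assert truth_f.shape == (10, bins)
```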
@@ -843,12 +516,12 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)
 
     # Phase distance
-    phd, phd_bin, phd_frame =
+    phd, phd_bin, phd_frame = calc_phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
 
     # Noise td logerr
     # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)
 
-    # # SA-SDR (time-domain source-
+    # # SA-SDR (time-domain source-aggregated SDR)
     ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
     ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
     # # note: w/o scale is more pessimistic number
@@ -863,8 +536,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     # Speech intelligibility measure - PESQ
     if int(mixdb.mixture(mixid).snr) > -99:
         # len = target_est_wav.shape[0]
-        pesq_speech, csig_tg, cbak_tg, covl_tg
-        pesq_mixture, csig_mx, cbak_mx, covl_mx
+        pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
+        pesq_mixture, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(mixid, ['mxpesq', 'mxcsig', 'mxcbak', 'mxcovl'])
         # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
         # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
         # pesq improvement
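Together with the segmental-SNR hunk above, this moves per-mixture metric computation onto the new MixtureDatabase.mixture_metrics API (backed by the reworked sonusai/mixture/mixdb.py and the new genmetrics command). A hedged usage sketch using only metric names visible in this diff — 'ssnr', 'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', and 'mxwer.<asr_name>'; the database location is hypothetical:

```python
from sonusai.mixture import MixtureDatabase

mixdb = MixtureDatabase('path/to/mixdb')  # hypothetical location
mixid = 0

# mixture_metrics takes a list of metric names and returns one result per name
segsnr_f, = mixdb.mixture_metrics(mixid, ['ssnr'])

# Several at once: PESQ and composite measures of the noisy mixture
pesq_mx, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(
    mixid, ['mxpesq', 'mxcsig', 'mxcbak', 'mxcovl'])
```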
@@ -886,20 +559,15 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     asr_tt = None
     asr_mx = None
     asr_tge = None
-
-
-
-        wer_pi = float('nan')
-    else:
+    asr_engines = list(mixdb.asr_configs.keys())
+    if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96:  # noise only, ignore/reset target asr
+        wer_mx = float(mixdb.mixture_metrics(mixid, [f'mxwer.{asr_engines[0]}'])[0]) * 100
         asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, 'text')[0]  # ignore mixup
         if asr_tt is None:
             asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text  # target truth
 
         if asr_tt:
-            asr_mx = calc_asr(mixture, engine=asr_method, whisper_model_name=asr_model_name).text
             asr_tge = calc_asr(target_est_wav, engine=asr_method, whisper_model_name=asr_model_name).text
-
-            wer_mx = calc_wer(asr_mx, asr_tt).wer * 100  # mixture wer
             wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target estimate wer
             if wer_mx == 0.0:
                 if wer_tge == 0.0:
@@ -913,6 +581,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
                 wer_mx = float(0)
                 wer_tge = float(0)
                 wer_pi = float(0)
+    else:
+        wer_mx = float('nan')
+        wer_tge = float('nan')
+        wer_pi = float('nan')
 
     # 5) Save per mixture metric results
     # Single row in table of scalar metrics per mixture
@@ -1088,7 +760,6 @@ def main():
     from os.path import basename
     from os.path import isdir
     from os.path import join
-    from os.path import split
 
     import psutil
     from tqdm import tqdm
@@ -1097,7 +768,7 @@ def main():
     from sonusai import initial_log_messages
     from sonusai import logger
     from sonusai import update_console_handler
-    from sonusai.mixture import
+    from sonusai.mixture import DEFAULT_SPEECH
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import read_audio
     from sonusai.utils import calc_asr
@@ -1173,8 +844,7 @@ def main():
         return
 
     if enable_asr_warmup:
-
-        audio = read_audio(default_speech)
+        audio = read_audio(DEFAULT_SPEECH)
         logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
         asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
         logger.info(f'Warmup completed, results {asr_chk}')
sonusai/data/genmixdb.yml
CHANGED
sonusai/doc/doc.py
CHANGED
@@ -199,7 +199,6 @@ def get_truth_functions() -> str:
 def doc_truth_settings() -> str:
     import yaml
 
-    from sonusai.mixture import get_default_config
     default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['truth_settings'])}"
     return """
 'truth_settings' is a mixture database configuration parameter that sets the truth
@@ -375,7 +374,6 @@ This rule expands to 6 unique augmentations being applied to each target
 def doc_target_augmentations() -> str:
     import yaml
 
-    from sonusai.mixture import get_default_config
     default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['target_augmentations'])}"
     return """
 'target_augmentations' is a mixture database configuration parameter that
@@ -388,7 +386,6 @@ See 'augmentations' for details on augmentation rules.
 def doc_class_balancing_augmentation() -> str:
     import yaml
 
-    from sonusai.mixture import get_default_config
     default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['class_balancing_augmentation'])}"
     return """
 'class_balancing_augmentation' is a mixture database configuration parameter
@@ -436,7 +433,6 @@ Required field:
 def doc_noise_augmentations() -> str:
     import yaml
 
-    from sonusai.mixture import get_default_config
     default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['noise_augmentations'])}"
     return """
 'noise_augmentations' is a mixture database configuration parameter that
@@ -536,3 +532,48 @@ def doc_config() -> str:
     for c in VALID_CONFIGS:
         text += f' {c}\n'
     return text
+
+
+def doc_asr_configs() -> str:
+    from sonusai.utils import get_available_engines
+
+    default = f"\nDefault value: {get_default_config()['asr_configs']}"
+    engines = get_available_engines()
+    text = """
+'asr_configs' is a mixture database configuration parameter that sets the list of
+ASR engine(s) to use.
+
+Required fields:
+
+    'name'          Unique identifier for the ASR engine.
+    'engine'        ASR engine to use. Available engines:
+"""
+    text += f' {", ".join(engines)}\n'
+    text += """
+Optional fields:
+
+    'model'         Some ASR engines allow the specification of a model, but note most are
+                    very computationally demanding and can overwhelm/hang a local system.
+                    Available whisper ASR engines:
+                    tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large
+    'device'        Some ASR engines allow the specification of a device, either 'cpu' or 'cuda'.
+    'cpu_threads'   Some ASR engines allow the specification of the number of CPU threads to use.
+    'compute_type'  Some ASR engines allow the specification of a compute type, e.g. 'int8'.
+    'beam_size'     Some ASR engines allow the specification of a beam size.
+    <other>         Other parameters can be injected into the ASR engine as needed; all
+                    fields in each config are forwarded to the given engine.
+
+Example:
+
+asr_configs:
+  - name: faster_tiny_cuda
+    engine: faster_whisper
+    model: tiny
+    device: cuda
+    beam_size: 5
+  - name: google
+    engine: google
+
+Creates two ASR engines for use named faster_tiny_cuda and google.
+"""
+    return text + default