sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +157 -61
- sonusai/calc_metric_spenh-save.py +1334 -0
- sonusai/calc_metric_spenh.py +15 -8
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +14 -6
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/mkmanifest.py +14 -6
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict-old.py +240 -0
- sonusai/onnx_predict-save.py +487 -0
- sonusai/onnx_predict.py +446 -182
- sonusai/ovino_predict.py +508 -0
- sonusai/ovino_query_devices.py +47 -0
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/torchl_onnx-old.py +216 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/onnx_utils.py +128 -39
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/METADATA +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/RECORD +26 -19
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/WHEEL +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1334 @@
"""sonusai calc_metric_spenh

usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e WER] [-m WMNAME] PLOC TLOC

options:
    -h, --help
    -v, --verbose                 Be verbose.
    -i MIXID, --mixid MIXID       Mixture ID(s) to process, can be a range like 0:maxmix+1 [default: *].
    -t, --truth-est-mode          Calculate extraction and metrics using truth (instead of prediction).
    -p, --plot                    Enable PDF plots file generation per mixture.
    -w, --wav                     Generate WAV files per mixture.
    -s, --summary                 Enable summary files generation.
    -e WER, --wer-method WER      Word-Error-Rate method: deepgram, google, aixplain_whisper,
                                  or whisper (locally run) [default: none]
    -m WMNAME, --whisper-model    Whisper model name used in aixplain_whisper and whisper WER methods.
                                  [default: tiny]

Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data
in TLOC as truth/label reference. Metric and extraction data files are written into PLOC.

PLOC  directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
TLOC  directory with SonusAI mixture database of truth/label mixture data

For whisper WER methods, the possible models used in local processing (WER = whisper) are:
{tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
but note most are very computationally demanding and can overwhelm/hang a local system.

Outputs the following to PLOC (where <id> is the mixture number 0:num_mixtures):
    <id>_metric_spenh.txt

    If --plot:
        <id>_metric_spenh.pdf

    If --wav:
        <id>_target.wav
        <id>_target_est.wav
        <id>_noise.wav
        <id>_noise_est.wav
        <id>_mixture.wav

        If --truth-est-mode:
            <id>_target_truth_est.wav
            <id>_noise_truth_est.wav

    If --summary:
        metric_spenh_targetf_summary.txt
        metric_spenh_targetf_summary.csv
        metric_spenh_targetf_list.csv
        metric_spenh_targetf_estats_list.csv

        If --truth-est-mode:
            metric_spenh_targetf_truth_list.csv
            metric_spenh_targetf_estats_truth_list.csv

TBD
    Metric and extraction data are written into prediction location PLOC as separate files per mixture.

    -d PLOC, --ploc PLOC          Location of SonusAI predict data.

Inputs:

"""
from dataclasses import dataclass
from typing import Optional

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sonusai import logger
from sonusai.mixture import AudioF
from sonusai.mixture import AudioT
from sonusai.mixture import Feature
from sonusai.mixture import MixtureDatabase
from sonusai.mixture import Predict

matplotlib.use('SVG')


@dataclass
class MPGlobal:
    mixdb: Optional[MixtureDatabase] = None
    predict_location: Optional[str] = None
    predwav_mode: Optional[bool] = None
    truth_est_mode: Optional[bool] = None
    enable_plot: Optional[bool] = None
    enable_wav: Optional[bool] = None
    wer_method: Optional[str] = None
    whisper_model: Optional[str] = None


MP_GLOBAL = MPGlobal()


def power_compress(spec):
    mag = np.abs(spec)
    phase = np.angle(spec)
    mag = mag ** 0.3
    real_compress = mag * np.cos(phase)
    imag_compress = mag * np.sin(phase)
    return real_compress + 1j * imag_compress


def power_uncompress(spec):
    mag = np.abs(spec)
    phase = np.angle(spec)
    mag = mag ** (1. / 0.3)
    real_uncompress = mag * np.cos(phase)
    imag_uncompress = mag * np.sin(phase)
    return real_uncompress + 1j * imag_uncompress
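
# A minimal round-trip check (sketch): power_uncompress inverts power_compress for any
# complex spectrogram, since (mag ** 0.3) ** (1 / 0.3) == mag and phase passes through:
#
#   >>> rng = np.random.default_rng(0)
#   >>> spec = rng.standard_normal((4, 9)) + 1j * rng.standard_normal((4, 9))
#   >>> bool(np.allclose(power_uncompress(power_compress(spec)), spec))
#   True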


def snr(clean_speech, processed_speech, sample_rate):
    # Check the length of the clean and processed speech. Must be the same.
    clean_length = len(clean_speech)
    processed_length = len(processed_speech)
    if clean_length != processed_length:
        raise ValueError('Both Speech Files must be same length.')

    overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))

    # Global Variables
    winlength = round(30 * sample_rate / 1000)  # window length in samples
    skiprate = int(np.floor(winlength / 4))  # window skip in samples
    MIN_SNR = -10  # minimum SNR in dB
    MAX_SNR = 35  # maximum SNR in dB

    # For each frame of input speech, calculate the Segmental SNR
    num_frames = int(clean_length / skiprate - (winlength / skiprate))  # number of frames
    start = 0  # starting sample
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))

    segmental_snr = np.empty(num_frames)
    EPS = np.spacing(1)
    for frame_count in range(num_frames):
        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
        clean_frame = clean_speech[start:start + winlength]
        processed_frame = processed_speech[start:start + winlength]
        clean_frame = np.multiply(clean_frame, window)
        processed_frame = np.multiply(processed_frame, window)

        # (2) Compute the Segmental SNR
        signal_energy = np.sum(np.square(clean_frame))
        noise_energy = np.sum(np.square(clean_frame - processed_frame))
        segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + EPS) + EPS)
        segmental_snr[frame_count] = max(segmental_snr[frame_count], MIN_SNR)
        segmental_snr[frame_count] = min(segmental_snr[frame_count], MAX_SNR)

        start = start + skiprate

    return overall_snr, segmental_snr
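
# Example (sketch): for one second of 16 kHz audio this uses 30 ms windows with a
# 7.5 ms skip, so snr() returns the overall SNR in dB plus a 129-element
# segmental-SNR vector clamped to [-10, 35] dB:
#
#   >>> rng = np.random.default_rng(0)
#   >>> clean = rng.standard_normal(16000)
#   >>> overall, seg = snr(clean, clean + 0.1 * rng.standard_normal(16000), 16000)
#   >>> seg.shape
#   (129,)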


def lpcoeff(speech_frame, model_order):
    # (1) Compute Autocorrelation Lags
    winlength = np.size(speech_frame)
    R = np.empty(model_order + 1)
    E = np.empty(model_order + 1)
    for k in range(model_order + 1):
        R[k] = np.dot(speech_frame[0:winlength - k], speech_frame[k:winlength])

    # (2) Levinson-Durbin
    a = np.ones(model_order)
    a_past = np.empty(model_order)
    rcoeff = np.empty(model_order)
    E[0] = R[0]
    for i in range(model_order):
        a_past[0:i] = a[0:i]
        sum_term = np.dot(a_past[0:i], R[i:0:-1])
        rcoeff[i] = (R[i + 1] - sum_term) / E[i]
        a[i] = rcoeff[i]
        if i == 0:
            a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1:-1:-1], rcoeff[i])
        else:
            a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1::-1], rcoeff[i])
        E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i]
    acorr = R
    refcoeff = rcoeff
    lpparams = np.concatenate((np.array([1]), -a))
    return acorr, refcoeff, lpparams
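
# Example (sketch): lpcoeff() returns the autocorrelation lags, the reflection
# coefficients, and the LPC polynomial [1, -a_1, ..., -a_p]:
#
#   >>> frame = np.sin(2 * np.pi * 0.1 * np.arange(240))
#   >>> acorr, refl, lpc = lpcoeff(frame, 10)
#   >>> lpc.shape[0], float(lpc[0])
#   (11, 1.0)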


def llr(clean_speech, processed_speech, sample_rate):
    from scipy.linalg import toeplitz

    # Check the length of the clean and processed speech. Must be the same.
    clean_length = np.size(clean_speech)
    processed_length = np.size(processed_speech)
    if clean_length != processed_length:
        raise ValueError('Both Speech Files must be same length.')

    # Global Variables
    winlength = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
    skiprate = (np.floor(winlength / 4)).astype(int)  # window skip in samples
    if sample_rate < 10000:
        P = 10  # LPC Analysis Order
    else:
        P = 16  # this could vary depending on sampling frequency.

    # For each frame of input speech, calculate the Log Likelihood Ratio
    num_frames = int((clean_length - winlength) / skiprate)  # number of frames
    start = 0  # starting sample
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))

    distortion = np.empty(num_frames)
    for frame_count in range(num_frames):
        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
        clean_frame = clean_speech[start:start + winlength]
        processed_frame = processed_speech[start:start + winlength]
        clean_frame = np.multiply(clean_frame, window)
        processed_frame = np.multiply(processed_frame, window)

        # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
        R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P)
        R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P)

        # (3) Compute the LLR measure
        numerator = np.dot(np.matmul(A_processed, toeplitz(R_clean)), A_processed)
        denominator = np.dot(np.matmul(A_clean, toeplitz(R_clean)), A_clean)
        distortion[frame_count] = np.log(numerator / denominator)
        start = start + skiprate
    return distortion
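
# The per-frame value above is the classic Itakura log-likelihood ratio,
#   LLR = log( (a_p R_c a_p^T) / (a_c R_c a_c^T) ),
# where a_c and a_p are the LPC coefficient vectors of the clean and processed
# frames and R_c is the Toeplitz autocorrelation matrix of the clean frame.
# Identical frames score 0; larger values mean more spectral-envelope distortion.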


def wss(clean_speech, processed_speech, sample_rate):
    from scipy.fftpack import fft

    # Check the length of the clean and processed speech, which must be the same.
    clean_length = np.size(clean_speech)
    processed_length = np.size(processed_speech)
    if clean_length != processed_length:
        raise ValueError('Files must have same length.')

    # Global variables
    winlength = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
    skiprate = (np.floor(np.divide(winlength, 4))).astype(int)  # window skip in samples
    max_freq = (np.divide(sample_rate, 2)).astype(int)  # maximum bandwidth
    num_crit = 25  # number of critical bands

    USE_FFT_SPECTRUM = 1  # defaults to 10th order LP spectrum
    n_fft = (np.power(2, np.ceil(np.log2(2 * winlength)))).astype(int)
    n_fftby2 = (np.multiply(0.5, n_fft)).astype(int)  # FFT size/2
    Kmax = 20.0  # value suggested by Klatt, pg 1280
    Klocmax = 1.0  # value suggested by Klatt, pg 1280

    # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
    cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
                          540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
                          1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
                          2701.97, 2978.04, 3276.17, 3597.63])
    bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
                          77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
                          153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
                          276.072, 298.126, 321.465, 346.136])

    bw_min = bandwidth[0]  # minimum critical bandwidth

    # Set up the critical band filters.
    # Note here that Gaussianly shaped filters are used.
    # Also, the sum of the filter weights are equivalent for each critical band filter.
    # Filter less than -30 dB and set to zero.
    min_factor = np.exp(-30.0 / (2.0 * 2.303))  # -30 dB point of filter
    crit_filter = np.empty((num_crit, n_fftby2))
    for i in range(num_crit):
        f0 = (cent_freq[i] / max_freq) * n_fftby2
        bw = (bandwidth[i] / max_freq) * n_fftby2
        norm_factor = np.log(bw_min) - np.log(bandwidth[i])
        j = np.arange(n_fftby2)
        crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
        cond = np.greater(crit_filter[i, :], min_factor)
        crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)

    # For each frame of input speech, calculate the Weighted Spectral Slope Measure
    num_frames = int(clean_length / skiprate - (winlength / skiprate))  # number of frames
    start = 0  # starting sample
    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))

    distortion = np.empty(num_frames)
    for frame_count in range(num_frames):
        # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
        clean_frame = clean_speech[start:start + winlength] / 32768
        processed_frame = processed_speech[start:start + winlength] / 32768
        clean_frame = np.multiply(clean_frame, window)
        processed_frame = np.multiply(processed_frame, window)

        # (2) Compute the Power Spectrum of Clean and Processed
        # if USE_FFT_SPECTRUM:
        clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
        processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))

        # (3) Compute Filterbank Output Energies (in dB scale)
        clean_energy = np.matmul(crit_filter, clean_spec[0:n_fftby2])
        processed_energy = np.matmul(crit_filter, processed_spec[0:n_fftby2])

        clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
        processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))

        # (4) Compute Spectral Slope (dB[i+1]-dB[i])
        clean_slope = clean_energy[1:num_crit] - clean_energy[0:num_crit - 1]
        processed_slope = processed_energy[1:num_crit] - processed_energy[0:num_crit - 1]

        # (5) Find the nearest peak locations in the spectra to each critical band.
        # If the slope is negative, we search to the left. If positive, we search to the right.
        clean_loc_peak = np.empty(num_crit - 1)
        processed_loc_peak = np.empty(num_crit - 1)

        for i in range(num_crit - 1):
            # find the peaks in the clean speech signal
            if clean_slope[i] > 0:  # search to the right
                n = i
                while (n < num_crit - 1) and (clean_slope[n] > 0):
                    n = n + 1
                clean_loc_peak[i] = clean_energy[n - 1]
            else:  # search to the left
                n = i
                while (n >= 0) and (clean_slope[n] <= 0):
                    n = n - 1
                clean_loc_peak[i] = clean_energy[n + 1]

            # find the peaks in the processed speech signal
            if processed_slope[i] > 0:  # search to the right
                n = i
                while (n < num_crit - 1) and (processed_slope[n] > 0):
                    n = n + 1
                processed_loc_peak[i] = processed_energy[n - 1]
            else:  # search to the left
                n = i
                while (n >= 0) and (processed_slope[n] <= 0):
                    n = n - 1
                processed_loc_peak[i] = processed_energy[n + 1]

        # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
        dBMax_clean = np.max(clean_energy)
        dBMax_processed = np.max(processed_energy)
        '''
        The weights are calculated by averaging individual weighting factors from the clean and processed frame.
        These weights W_clean and W_processed should range from 0 to 1 and place more emphasis on spectral peaks
        and less emphasis on slope differences in spectral valleys.
        This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
        '''
        Wmax_clean = np.divide(Kmax, Kmax + dBMax_clean - clean_energy[0:num_crit - 1])
        Wlocmax_clean = np.divide(Klocmax, Klocmax + clean_loc_peak - clean_energy[0:num_crit - 1])
        W_clean = np.multiply(Wmax_clean, Wlocmax_clean)

        Wmax_processed = np.divide(Kmax, Kmax + dBMax_processed - processed_energy[0:num_crit - 1])
        Wlocmax_processed = np.divide(Klocmax, Klocmax + processed_loc_peak - processed_energy[0:num_crit - 1])
        W_processed = np.multiply(Wmax_processed, Wlocmax_processed)

        W = np.divide(np.add(W_clean, W_processed), 2.0)
        slope_diff = np.subtract(clean_slope, processed_slope)[0:num_crit - 1]
        distortion[frame_count] = np.dot(W, np.square(slope_diff)) / np.sum(W)
        # this normalization is not part of Klatt's paper, but helps to normalize the measure.
        # Here we scale the measure by the sum of the weights.
        start = start + skiprate
    return distortion
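
# Example (sketch): wss() frames the signals the same way snr() does, so one
# second of 16 kHz audio yields one slope-distortion value per 30 ms frame:
#
#   >>> rng = np.random.default_rng(0)
#   >>> clean = rng.standard_normal(16000)
#   >>> wss(clean, clean + 0.01 * rng.standard_normal(16000), 16000).shape
#   (129,)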


def calc_speech_metrics(hypothesis: np.ndarray,
                        reference: np.ndarray) -> tuple[float, float, float, float, float]:
    """
    Calculate speech metrics pesq_mos, CSIG, CBAK, COVL, segSNR. These are all related and thus included
    in one function. Reference: MATLAB script "compute_metrics.m".

    Usage:
        pesq, csig, cbak, covl, ssnr = calc_speech_metrics(hypothesis, reference)
        reference: clean audio as array
        hypothesis: enhanced audio as array
        Audio must have sampling rate = 16000 Hz.

    Example call:
        pesq_output, csig_output, cbak_output, covl_output, ssnr_output = \
            calc_speech_metrics(predicted_audio, target_audio)
    """
    from sonusai.metrics import calc_pesq

    Fs = 16000

    # compute the WSS measure
    wss_dist_vec = wss(reference, hypothesis, Fs)
    wss_dist_vec = np.sort(wss_dist_vec)
    alpha = 0.95  # value from CMGAN reference implementation
    wss_dist = np.mean(wss_dist_vec[0:round(np.size(wss_dist_vec) * alpha)])

    # compute the LLR measure
    llr_dist = llr(reference, hypothesis, Fs)
    ll_rs = np.sort(llr_dist)
    llr_len = round(np.size(llr_dist) * alpha)
    llr_mean = np.mean(ll_rs[0:llr_len])

    # compute the SNRseg
    snr_dist, segsnr_dist = snr(reference, hypothesis, Fs)
    snr_mean = snr_dist
    segSNR = np.mean(segsnr_dist)

    # compute the PESQ (use the SonusAI wrapper; only fs=16k, mode=wb supported)
    pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
    # pesq_mos = pesq(sampling_rate1, data1, data2, 'wb')

    # now compute the composite measures
    CSIG = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
    CSIG = max(1, CSIG)
    CSIG = min(5, CSIG)  # limit values to [1, 5]
    CBAK = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * segSNR
    CBAK = max(1, CBAK)
    CBAK = min(5, CBAK)  # limit values to [1, 5]
    COVL = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
    COVL = max(1, COVL)
    COVL = min(5, COVL)  # limit values to [1, 5]

    return pesq_mos, CSIG, CBAK, COVL, segSNR
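
# Example (sketch): both arguments are equal-length 16 kHz time-domain arrays; the
# unprocessed mixture can serve as a "no enhancement" baseline, as done in
# _process_mixture() below:
#
#   pesq_mx, csig_mx, cbak_mx, covl_mx, ssnr_mx = calc_speech_metrics(mixture, target)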


def mean_square_error(hypothesis: np.ndarray,
                      reference: np.ndarray,
                      squared: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate root-mean-square error or mean square error

    :param hypothesis: [frames, bins]
    :param reference: [frames, bins]
    :param squared: calculate mean square rather than root-mean-square
    :return: mean, mean per bin, mean per frame
    """
    sq_err = np.square(reference - hypothesis)

    # mean over frames for value per bin
    err_b = np.mean(sq_err, axis=0)
    # mean over bins for value per frame
    err_f = np.mean(sq_err, axis=1)
    # mean over all
    err = np.mean(sq_err)

    if not squared:
        err_b = np.sqrt(err_b)
        err_f = np.sqrt(err_f)
        err = np.sqrt(err)

    return err, err_b, err_f
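
# Example (sketch): with a single mismatched cell, the RMSE over all four cells
# is sqrt(mean([0, 0, 0, 4])) == 1.0:
#
#   >>> ref = np.array([[1.0, 2.0], [3.0, 4.0]])
#   >>> hyp = np.array([[1.0, 2.0], [3.0, 2.0]])
#   >>> err, err_b, err_f = mean_square_error(hyp, ref)
#   >>> float(err)
#   1.0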


def mean_abs_percentage_error(hypothesis: np.ndarray,
                              reference: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate mean absolute percentage error

    If inputs are complex, calculates the average: mape(real)/2 + mape(imag)/2

    :param hypothesis: [frames, bins]
    :param reference: [frames, bins]
    :return: mean, mean per bin, mean per frame
    """
    if not np.iscomplexobj(reference) and not np.iscomplexobj(hypothesis):
        abs_err = 100 * np.abs((reference - hypothesis) / (reference + np.finfo(np.float32).eps))
    else:
        reference_r = np.real(reference)
        reference_i = np.imag(reference)
        hypothesis_r = np.real(hypothesis)
        hypothesis_i = np.imag(hypothesis)
        abs_err_r = 100 * np.abs((reference_r - hypothesis_r) / (reference_r + np.finfo(np.float32).eps))
        abs_err_i = 100 * np.abs((reference_i - hypothesis_i) / (reference_i + np.finfo(np.float32).eps))
        abs_err = (abs_err_r / 2) + (abs_err_i / 2)

    # mean over frames for value per bin
    err_b = np.around(np.mean(abs_err, axis=0), 3)
    # mean over bins for value per frame
    err_f = np.around(np.mean(abs_err, axis=1), 3)
    # mean over all
    err = np.around(np.mean(abs_err), 3)

    return err, err_b, err_f


def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate log error

    :param reference: complex or real [frames, bins]
    :param hypothesis: complex or real [frames, bins]
    :return: mean, mean per bin, mean per frame
    """
    reference_sq = np.real(reference * np.conjugate(reference))
    hypothesis_sq = np.real(hypothesis * np.conjugate(hypothesis))
    log_err = abs(10 * np.log10((reference_sq + np.finfo(np.float32).eps) / (hypothesis_sq + np.finfo(np.float32).eps)))
    # log_err = abs(10 * np.log10(reference_sq / (hypothesis_sq + np.finfo(np.float32).eps) + np.finfo(np.float32).eps))

    # mean over frames for value per bin
    err_b = np.around(np.mean(log_err, axis=0), 3)
    # mean over bins for value per frame
    err_f = np.around(np.mean(log_err, axis=1), 3)
    # mean over all
    err = np.around(np.mean(log_err), 3)

    return err, err_b, err_f
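
# The value above is a symmetric log-spectral distortion per time-frequency cell,
#   |10 * log10(|R|^2 / |H|^2)|,
# with float32 eps guarding both power terms so that silent cells do not
# produce infinities.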


def phase_distance(reference: np.ndarray,
                   hypothesis: np.ndarray,
                   eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate weighted phase distance error (weight normalization over bins per frame)

    :param reference: complex [frames, bins]
    :param hypothesis: complex [frames, bins]
    :param eps: epsilon value
    :return: mean, mean per bin, mean per frame
    """
    ang_diff = np.angle(reference) - np.angle(hypothesis)
    phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
    rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg

    # Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
    # hyp_real = np.real(hypothesis)
    # near_zeros = np.real(hyp_real) < eps
    # hyp_real = hyp_real * (np.logical_not(near_zeros))
    # hyp_real = hyp_real + (near_zeros * eps)
    # hypothesis = hyp_real + 1j * np.imag(hypothesis)
    # rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi  # angle diff +/-180

    # weighted mean over all (scalar)
    reference_mag = np.abs(reference)
    ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
    err = np.around(np.sum(ref_weight * rh_angle_diff), 3)

    # weighted mean over frames (value per bin)
    err_b = np.zeros(reference.shape[1])
    for bi in range(reference.shape[1]):
        ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
        err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)

    # weighted mean over bins (value per frame)
    err_f = np.zeros(reference.shape[0])
    for fi in range(reference.shape[0]):
        ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
        err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)

    return err, err_b, err_f
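
# Example (sketch): the modulo arithmetic above keeps differences wrapped to
# +/-180 degrees, and the magnitude weighting reduces to the raw angle when the
# reference is a single unit-magnitude cell:
#
#   >>> ref = np.array([[1.0 + 0j]])
#   >>> hyp = np.array([[0.0 - 1j]])  # 90 degrees behind the reference
#   >>> err, _, _ = phase_distance(ref, hyp)
#   >>> float(err)
#   90.0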


def plot_mixpred(mixture: AudioT,
                 mixture_f: AudioF,
                 target: Optional[AudioT] = None,
                 feature: Optional[Feature] = None,
                 predict: Optional[Predict] = None,
                 tp_title: str = '') -> plt.Figure:
    from sonusai.mixture import SAMPLE_RATE

    num_plots = 2
    if feature is not None:
        num_plots += 1
    if predict is not None:
        num_plots += 1

    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the waveform
    p = 0
    x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
    ax[p].plot(x_axis, mixture, label='Mixture', color='mistyrose')
    ax[p].set_ylabel('magnitude', color='tab:blue')
    ax[p].set_xlim(x_axis[0], x_axis[-1])
    # ax[p].set_ylim([-1.025, 1.025])
    if target is not None:  # Plot target time-domain waveform on top of mixture
        ax[p].plot(x_axis, target, label='Target', color='tab:blue')
        # ax[p].tick_params(axis='y', labelcolor=color)
    ax[p].set_title('Waveform')

    # Plot the mixture spectrogram
    p += 1
    ax[p].imshow(np.transpose(mixture_f), aspect='auto', interpolation='nearest', origin='lower')
    ax[p].set_title('Mixture')

    if feature is not None:
        p += 1
        ax[p].imshow(np.transpose(feature), aspect='auto', interpolation='nearest', origin='lower')
        ax[p].set_title('Feature')

    if predict is not None:
        p += 1
        im = ax[p].imshow(np.transpose(predict), aspect='auto', interpolation='nearest', origin='lower')
        ax[p].set_title('Predict ' + tp_title)
        plt.colorbar(im, location='bottom')

    return fig


def plot_pdb_predtruth(predict: np.ndarray,
                       truth_f: Optional[np.ndarray] = None,
                       metric: Optional[np.ndarray] = None,
                       tp_title: str = '') -> plt.Figure:
    """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
    num_plots = 2
    if truth_f is not None:
        num_plots += 1

    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the predict spectrogram
    p = 0
    tmp = 10 * np.log10(predict.transpose() + np.finfo(np.float32).eps)
    im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
    ax[p].set_title('Predict')
    plt.colorbar(im, location='bottom')

    if truth_f is not None:
        p += 1
        tmp = 10 * np.log10(truth_f.transpose() + np.finfo(np.float32).eps)
        im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
        ax[p].set_title('Truth')
        plt.colorbar(im, location='bottom')

    # Plot the predict avg, and optionally truth avg and metric lines
    pred_avg = 10 * np.log10(np.mean(predict, axis=-1) + np.finfo(np.float32).eps)
    p += 1
    x_axis = np.arange(len(pred_avg), dtype=np.float32)  # / SAMPLE_RATE
    ax[p].plot(x_axis, pred_avg, color='black', linestyle='dashed', label='Predict mean over freq.')
    ax[p].set_ylabel('mean db', color='black')
    ax[p].set_xlim(x_axis[0], x_axis[-1])
    if truth_f is not None:
        truth_avg = 10 * np.log10(np.mean(truth_f, axis=-1) + np.finfo(np.float32).eps)
        ax[p].plot(x_axis, truth_avg, color='green', linestyle='dashed', label='Truth mean over freq.')

    if metric is not None:  # instantiate 2nd y-axis that shares the same x-axis
        ax2 = ax[p].twinx()
        color2 = 'red'
        ax2.plot(x_axis, metric, color=color2, label='sig distortion (mse db)')
        ax2.set_xlim(x_axis[0], x_axis[-1])
        ax2.set_ylim([0, np.max(metric)])
        ax2.set_ylabel('spectral distortion (mse db)', color=color2)
        ax2.tick_params(axis='y', labelcolor=color2)
        ax[p].set_title('SNR and SNR mse (mean over freq. db)')
    else:
        ax[p].set_title('SNR (mean over freq. db)')
    # ax[0].tick_params(axis='y', labelcolor=color)
    return fig


def plot_epredtruth(predict: np.ndarray,
                    predict_wav: np.ndarray,
                    truth_f: Optional[np.ndarray] = None,
                    truth_wav: Optional[np.ndarray] = None,
                    metric: Optional[np.ndarray] = None,
                    tp_title: str = '') -> plt.Figure:
    """Plot predict spectrogram and waveform and optionally truth and a metric"""
    num_plots = 2
    if truth_f is not None:
        num_plots += 1
    if metric is not None:
        num_plots += 1

    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the predict spectrogram
    p = 0
    im = ax[p].imshow(predict.transpose(), aspect='auto', interpolation='nearest', origin='lower')
    ax[p].set_title('Predict')
    plt.colorbar(im, location='bottom')

    if truth_f is not None:  # plot truth if provided and use same colormap as predict
        p += 1
        ax[p].imshow(truth_f.transpose(), im.cmap, aspect='auto', interpolation='nearest', origin='lower')
        ax[p].set_title('Truth')

    # Plot the predict wav, and optionally truth avg and metric lines
    p += 1
    x_axis = np.arange(len(predict_wav), dtype=np.float32)  # / SAMPLE_RATE
    ax[p].plot(x_axis, predict_wav, color='black', linestyle='dashed', label='Speech Estimate')
    ax[p].set_ylabel('Amplitude', color='black')
    ax[p].set_xlim(x_axis[0], x_axis[-1])
    if truth_wav is not None:
        ntrim = len(truth_wav) - len(predict_wav)
        if ntrim > 0:
            truth_wav = truth_wav[0:-ntrim]
        ax[p].plot(x_axis, truth_wav, color='green', linestyle='dashed', label='True Target')

    # Plot the metric lines
    if metric is not None:
        p += 1
        if metric.ndim > 1:  # if it has multiple dims, plot the 1st
            metric1 = metric[:, 0]
        else:
            metric1 = metric  # if single dim, plot it as the 1st
        x_axis = np.arange(len(metric1), dtype=np.float32)  # / SAMPLE_RATE
        ax[p].plot(x_axis, metric1, color='red', label='Target LogErr')
        ax[p].set_ylabel('log error db', color='red')
        ax[p].set_xlim(x_axis[0], x_axis[-1])
        ax[p].set_ylim([-0.01, np.max(metric1) + .01])
        if metric.ndim > 1:
            if metric.shape[1] > 1:
                metr2 = metric[:, 1]
                ax2 = ax[p].twinx()
                color2 = 'blue'
                ax2.plot(x_axis, metr2, color=color2, label='phase dist (deg)')
                # ax2.set_ylim([-180.0, +180.0])
                if np.max(metr2) - np.min(metr2) > .1:
                    ax2.set_ylim([np.min(metr2), np.max(metr2)])
                ax2.set_ylabel('phase dist (deg)', color=color2)
                ax2.tick_params(axis='y', labelcolor=color2)
                # ax[p].set_title('SNR and SNR mse (mean over freq. db)')

    return fig


def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
    from os.path import basename
    from os.path import join
    from os.path import splitext

    import h5py
    from numpy import inf
    from pystoi import stoi

    from sonusai import SonusAIError
    from sonusai import logger
    from sonusai.metrics import calc_pcm
    from sonusai.metrics import calc_wer
    from sonusai.metrics import calc_wsdr
    from sonusai.mixture import forward_transform
    from sonusai.mixture import inverse_transform
    from sonusai.mixture import read_audio
    from sonusai.utils import calc_asr
    from sonusai.utils import float_to_int16
    from sonusai.utils import reshape_outputs
    from sonusai.utils import stack_complex
    from sonusai.utils import unstack_complex
    from sonusai.utils import write_wav

    mixdb = MP_GLOBAL.mixdb
    predict_location = MP_GLOBAL.predict_location
    predwav_mode = MP_GLOBAL.predwav_mode
    truth_est_mode = MP_GLOBAL.truth_est_mode
    enable_plot = MP_GLOBAL.enable_plot
    enable_wav = MP_GLOBAL.enable_wav
    wer_method = MP_GLOBAL.wer_method
    whisper_model = MP_GLOBAL.whisper_model

    # 1) Read predict data: var predict with shape [BatchSize, Classes] or [BatchSize, Tsteps, Classes]
    output_name = join(predict_location, mixdb.mixture(mixid).name)
    predict = None
    if truth_est_mode:
        # In truth estimation mode we use the truth in place of prediction to see metrics with perfect input.
        # Don't bother to read prediction; the predict var will get assigned to truth later.
        # Mark outputs with a 'truest' suffix, i.e. 0000_truest_*
        base_name = splitext(output_name)[0] + '_truest'
    else:
        base_name, ext = splitext(output_name)  # base_name used later
        if not predwav_mode:
            try:
                with h5py.File(output_name, 'r') as f:
                    predict = np.array(f['predict'])
            except Exception as e:
                raise SonusAIError(f'Error reading {output_name}: {e}')
            # Reshape to always be [frames, classes]; in the ndim == 3 case, frames = batch * tsteps
            if predict.ndim > 2:  # TBD generalize to somehow detect if a timestep dim exists; some cases > 2 don't have one
                # logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
                predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
        else:
            base_name, ext = splitext(output_name)
            prfname = join(base_name + '.wav')
            audio = read_audio(prfname)
            predict = forward_transform(audio, mixdb.ft_config)
            if mixdb.feature[0:1] == 'h':
                predict = power_compress(predict)
            predict = stack_complex(predict)

    # 2) Collect true target, noise, mixture data; trim to predict size if needed
    tmp = mixdb.mixture_targets(mixid)  # targets is a list of pre-IR and pre-specaugment targets
    target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
    target = tmp[0]
    mixture = mixdb.mixture_mixture(mixid)  # note: gives the full reverberated/distorted target, but no specaugment
    # noise_wodist = mixdb.mixture_noise(mixid)  # noise without specaugment and distortion
    # noise_wodist_f = mixdb.mixture_noise_f(mixid, noise=noise_wodist)
    noise = mixture - target  # has time-domain distortion (IR, etc.) but does not have specaugment
    # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
    segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)  # note: uses pre-IR, pre-specaugment audio
    mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
    noise_f = mixture_f - target_f  # true noise in freq domain includes specaugment and time-domain IR distortions
    # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
    segsnr_f[segsnr_f == inf] = 7.944e8  # 99db
    segsnr_f[segsnr_f == -inf] = 1.258e-10  # -99db
    # Need to use inv-tf to match the sample count and latency-shift properties of the predict inverse transform
    targetfi = inverse_transform(target_f, mixdb.it_config)
    noisefi = inverse_transform(noise_f, mixdb.it_config)
    # mixturefi = mixdb.inverse_transform(mixture_f)

    # Generate feature and truth - note feature is only used for plots
    # TBD parse truth_f for different formats and also multi-truth
    feature, truth_f = mixdb.mixture_ft(mixid, mixture=mixture)
    truth_type = mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings[0].function
    if truth_type == 'target_mixture_f':
        half = truth_f.shape[-1] // 2
        truth_f = truth_f[..., :half]  # extract target_f only

    if not truth_est_mode:
        if predict.shape[0] < target_f.shape[0]:  # target_f, truth_f, mixture_f, etc. are the same size
            trimf = target_f.shape[0] - predict.shape[0]
            logger.debug(f'Warning: prediction frames less than mixture, trimming {trimf} frames from all truth.')
            target_f = target_f[0:-trimf, :]
            targetfi, _ = inverse_transform(target_f, mixdb.it_config)
            trimt = target.shape[0] - targetfi.shape[0]
            target = target[0:-trimt]
            noise_f = noise_f[0:-trimf, :]
            noise = noise[0:-trimt]
            mixture_f = mixture_f[0:-trimf, :]
            mixture = mixture[0:-trimt]
            truth_f = truth_f[0:-trimf, :]
        elif predict.shape[0] > target_f.shape[0]:
            raise SonusAIError(
                f'Error: prediction has more frames than true mixture: {predict.shape[0]} vs {truth_f.shape[0]}')

    # 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
    if truth_est_mode:
        predict = truth_f  # substitute truth for the prediction (for test/debug)
        predict_complex = unstack_complex(predict)  # unstack
        # If the feature has compressed mag and truth does not, compress it
        if mixdb.feature[0:1] == 'h' and mixdb.target_file(1).truth_settings[0].function[0:10] != 'targetcmpr':
            predict_complex = power_compress(predict_complex)  # from uncompressed truth
    else:
        predict_complex = unstack_complex(predict)

    truth_f_complex = unstack_complex(truth_f)
    if mixdb.feature[0:1] == 'h':  # 'hn' or 'ha' or 'hd', etc.: the feature has compressed mag
        # estimate noise in uncompressed-mag domain
        noise_est_complex = mixture_f - power_uncompress(predict_complex)
        predict_complex = power_uncompress(predict_complex)  # uncompress if truth is compressed
    else:  # cn, c8, ...
        noise_est_complex = mixture_f - predict_complex

    target_est_wav = inverse_transform(predict_complex, mixdb.it_config)
    noise_est_wav = inverse_transform(noise_est_complex, mixdb.it_config)

    # 4) Metrics
    # Target/speech logerr - PSD estimation accuracy, symmetric mean log-spectral distortion
    lerr_tg, lerr_tg_bin, lerr_tg_frame = log_error(reference=truth_f_complex, hypothesis=predict_complex)
    # Noise logerr - PSD estimation accuracy
    lerr_n, lerr_n_bin, lerr_n_frame = log_error(reference=noise_f, hypothesis=noise_est_complex)
    # PCM loss metric
    ytrue_f = np.concatenate((truth_f_complex[:, np.newaxis, :], noise_f[:, np.newaxis, :]), axis=1)
    ypred_f = np.concatenate((predict_complex[:, np.newaxis, :], noise_est_complex[:, np.newaxis, :]), axis=1)
    pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)

    # Phase distance
    phd, phd_bin, phd_frame = phase_distance(hypothesis=predict_complex, reference=truth_f_complex)

    # Noise td logerr
    # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noisefi, noise_truth_est_audio)

    # SA-SDR (time-domain source-aggregated SDR)
    ytrue = np.concatenate((targetfi[:, np.newaxis], noisefi[:, np.newaxis]), axis=1)
    ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
    # note: without scale is the more pessimistic number
    # sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
    target_stoi = stoi(targetfi, target_est_wav, 16000, extended=False)

    wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
    # logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
    # logger.debug(f'wsdr cweights = {wsdr_cw}.')
    # logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')

    # Speech intelligibility measure - PESQ
    if int(mixdb.mixture(mixid).snr) > -99:
        # len = target_est_wav.shape[0]
        pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, targetfi)
        pesq_mixture, csig_mx, cbak_mx, covl_mx, sgsnr_mx = calc_speech_metrics(mixture, target)
        # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
        # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
        # PESQ improvement
        pesq_impr = pesq_speech - pesq_mixture
        # PESQ improvement %
        pesq_impr_pc = pesq_impr / (pesq_mixture + np.finfo(np.float32).eps) * 100
    else:
        pesq_speech = 0
        pesq_mixture = 0
        pesq_impr_pc = np.float32(0)
        csig_mx = 0
        csig_tg = 0
        cbak_mx = 0
        cbak_tg = 0
        covl_mx = 0
        covl_tg = 0

    # Calculate WER
    asr_tt = ''
    asr_mx = ''
    asr_tge = ''
    if wer_method == 'none' or mixdb.mixture(mixid).snr == -99:  # noise only, ignore/reset target asr
        wer_mx = float('nan')
        wer_tge = float('nan')
        wer_pi = float('nan')
    else:
        if MP_GLOBAL.mixdb.asr_manifests:
            asr_tt = MP_GLOBAL.mixdb.mixture_asr_data(mixid)[0]  # ignore mixup
        else:
            asr_tt = calc_asr(target, engine=wer_method, whisper_model_name=whisper_model).text  # target truth

        if asr_tt:
            asr_mx = calc_asr(mixture, engine=wer_method, whisper_model_name=whisper_model).text
            asr_tge = calc_asr(target_est_wav, engine=wer_method, whisper_model_name=whisper_model).text

            wer_mx = calc_wer(asr_mx, asr_tt).wer * 100  # mixture wer
            wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target estimate wer
            if wer_mx == 0.0:
                if wer_tge == 0.0:
                    wer_pi = 0.0
                else:
                    wer_pi = -999.0
            else:
                wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
        else:
            print(f'Warning: mixid {mixid} asr truth is empty, setting to 0% wer')
            wer_mx = float(0)
            wer_tge = float(0)
            wer_pi = float(0)

    # 5) Save per-mixture metric results
    # Single row in table of scalar metrics per mixture
    mtable1_col = ['MXSNR', 'MXPESQ', 'PESQ', 'PESQi%', 'MXWER', 'WER', 'WERi%', 'WSDR', 'STOI',
                   'PCM', 'SPLERR', 'NLERR', 'PD', 'MXCSIG', 'CSIG', 'MXCBAK', 'CBAK', 'MXCOVL', 'COVL',
                   'SPFILE', 'NFILE']
    ti = mixdb.mixture(mixid).targets[0].file_id
    ni = mixdb.mixture(mixid).noise.file_id
    metr1 = [mixdb.mixture(mixid).snr, pesq_mixture, pesq_speech, pesq_impr_pc, wer_mx, wer_tge, wer_pi, wsdr,
             target_stoi, pcm, lerr_tg, lerr_n, phd, csig_mx, csig_tg, cbak_mx, cbak_tg, covl_mx, covl_tg,
             basename(mixdb.target_file(ti).name), basename(mixdb.noise_file(ni).name)]
    mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[mixid])

    # Stats of per-frame estimation metrics
    metr2 = pd.DataFrame({'SSNR': segsnr_f,
                          'PCM': pcm_frame,
                          'SLERR': lerr_tg_frame,
                          'NLERR': lerr_n_frame,
                          'SPD': phd_frame})
    metr2 = metr2.describe()  # use the pandas stat function
    # Change SSNR stats to dB, except count. SSNR is index 0; pandas requires using iloc
    # metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
    metr2.iloc[1:, 0] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
    # Create a single row with a multi-column header
    new_labels = pd.MultiIndex.from_product([metr2.columns,
                                             ['Avg', 'Min', 'Med', 'Max', 'Std']],
                                            names=['Metric', 'Stat'])
    dat1row = metr2.loc[['mean', 'min', '50%', 'max', 'std'], :].T.stack().to_numpy().reshape((1, -1))
    mtab2 = pd.DataFrame(dat1row,
                         index=[mixid],
                         columns=new_labels)
    mtab2.insert(0, 'MXSNR', mixdb.mixture(mixid).snr, False)  # add MXSNR as the first metric column

    all_metrics_table_1 = mtab1  # return to be collected by process
    all_metrics_table_2 = mtab2  # return to be collected by process

    metric_name = base_name + '_metric_spenh.txt'
    with open(metric_name, 'w') as f:
        print('Speech enhancement metrics:', file=f)
        print(mtab1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
        print('', file=f)
        print(f'Extraction statistics over {mixture_f.shape[0]} frames:', file=f)
        print(metr2.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
        print('', file=f)
        print(f'Target path: {mixdb.target_file(ti).name}', file=f)
        print(f'Noise path: {mixdb.noise_file(ni).name}', file=f)
        if wer_method != 'none':
            print(f'WER method: {wer_method} and whisper model (if used): {whisper_model}', file=f)
            if mixdb.asr_manifests:
                print(f'ASR truth from manifest: {asr_tt}', file=f)
            else:
                print(f'ASR truth from wer method: {asr_tt}', file=f)
            print(f'ASR result for mixture: {asr_mx}', file=f)
            print(f'ASR result for prediction: {asr_tge}', file=f)

        print(f'Augmentations: {mixdb.mixture(mixid)}', file=f)

    # 7) Write wav files
    if enable_wav:
        write_wav(name=base_name + '_mixture.wav', audio=float_to_int16(mixture))
        write_wav(name=base_name + '_target.wav', audio=float_to_int16(target))
        # write_wav(name=base_name + '_targetfi.wav', audio=float_to_int16(targetfi))
        write_wav(name=base_name + '_noise.wav', audio=float_to_int16(noise))
        write_wav(name=base_name + '_target_est.wav', audio=float_to_int16(target_est_wav))
        write_wav(name=base_name + '_noise_est.wav', audio=float_to_int16(noise_est_wav))

        # Debug code to test for perfect reconstruction of the extraction method;
        # note both 75% olsa-hanns and 50% olsa-hann modes checked to have perfect reconstruction.
        # target_r = mixdb.inverse_transform(target_f)
        # noise_r = mixdb.inverse_transform(noise_f)
        # _write_wav(name=base_name + '_target_r.wav', audio=float_to_int16(target_r))
        # _write_wav(name=base_name + '_noise_r.wav', audio=float_to_int16(noise_r))  # chk perfect rec

    # 8) Write out the plot file
    if enable_plot:
        from matplotlib.backends.backend_pdf import PdfPages
        plot_fname = base_name + '_metric_spenh.pdf'

        # Reshape feature to eliminate overlap redundancy for an easier-to-understand spectrogram view.
        # Original size (frames, stride, num_bands); decimates in the stride dimension only if step > 1.
        # Reshape to get (frames * decimated_stride, num_bands).
        step = int(mixdb.feature_samples / mixdb.feature_step_samples)
        if feature.ndim != 3:
            raise SonusAIError('feature does not have 3 dimensions: frames, stride, num_bands')

        # for feature cn*00n**
        feat_sgram = unstack_complex(feature)
        feat_sgram = 20 * np.log10(abs(feat_sgram) + np.finfo(np.float32).eps)
        feat_sgram = feat_sgram[:, -step:, :]  # decimate, Fx1xB
        feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))

        with PdfPages(plot_fname) as pdf:
            # Page 1: we always have a mixture and prediction; target is optional if truth is provided
            tfunc_name = mixdb.target_file(1).truth_settings[0].function  # first target; assumes all have the same
            if tfunc_name == 'mapped_snr_f':
                # leave as unmapped snr
                predplot = predict
                tfunc_name = mixdb.target_file(1).truth_settings[0].function
            elif tfunc_name in ('target_f', 'target_mixture_f'):
                predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
            else:
                # use dB scale
                predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
                tfunc_name = tfunc_name + ' (db)'

            mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
            pdf.savefig(plot_mixpred(mixture=mixture,
                                     mixture_f=mixspec,
                                     target=target,
                                     feature=feat_sgram,
                                     predict=predplot,
                                     tp_title=tfunc_name))

            # ----- Page 2: plot unmapped predict, optional truth reconstructed, and line plots of mean-over-f
            # pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))

            # Page 3: speech extraction
            tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
            tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
            # n_spec = np.reshape(n_spec, (n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
            pdf.savefig(plot_epredtruth(predict=tg_est_spec,
                                        predict_wav=target_est_wav,
                                        truth_f=tg_spec,
                                        truth_wav=targetfi,
                                        metric=np.vstack((lerr_tg_frame, phd_frame)).T,
                                        tp_title='speech estimate'))

            # Page 4: noise extraction
            n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
            n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
            pdf.savefig(plot_epredtruth(predict=n_est_spec,
                                        predict_wav=noise_est_wav,
                                        truth_f=n_spec,
                                        truth_wav=noisefi,
                                        metric=lerr_n_frame,
                                        tp_title='noise estimate'))

            # Plot error waveforms
            # tg_err_wav = targetfi - target_est_wav
            # tg_err_spec = 20 * np.log10(np.abs(target_f - predict_complex))

    plt.close('all')

    return all_metrics_table_1, all_metrics_table_2
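
# _process_mixture() takes only a mixture id and reads everything else from
# MP_GLOBAL so that pp_tqdm_imap() in main() can fan it out across worker
# processes with a minimal payload; each worker returns its two single-row
# DataFrames, which main() then concatenates.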
|
1052
|
+
|
1053
|
+
|
1054
|
+
def main():
|
1055
|
+
from docopt import docopt
|
1056
|
+
|
1057
|
+
import sonusai
|
1058
|
+
from sonusai.utils import trim_docstring
|
1059
|
+
|
1060
|
+
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
1061
|
+
|
1062
|
+
verbose = args['--verbose']
|
1063
|
+
mixids = args['--mixid']
|
1064
|
+
predict_location = args['PLOC']
|
1065
|
+
wer_method = args['--wer-method'].lower()
|
1066
|
+
truth_est_mode = args['--truth-est-mode']
|
1067
|
+
enable_plot = args['--plot']
|
1068
|
+
enable_wav = args['--wav']
|
1069
|
+
enable_summary = args['--summary']
|
1070
|
+
truth_location = args['TLOC']
|
1071
|
+
whisper_model = args['--whisper-model'].lower()
|
1072
|
+
|
1073
|
+
import glob
|
1074
|
+
from os.path import basename
|
1075
|
+
from os.path import isdir
|
1076
|
+
from os.path import join
|
1077
|
+
from os.path import split
|
1078
|
+
|
1079
|
+
from tqdm import tqdm
|
1080
|
+
|
1081
|
+
from sonusai import create_file_handler
|
1082
|
+
from sonusai import initial_log_messages
|
1083
|
+
from sonusai import logger
|
1084
|
+
from sonusai import update_console_handler
|
1085
|
+
from sonusai.mixture import DEFAULT_NOISE
|
1086
|
+
from sonusai.mixture import MixtureDatabase
|
1087
|
+
from sonusai.mixture import read_audio
|
1088
|
+
from sonusai.utils import calc_asr
|
1089
|
+
from sonusai.utils import pp_tqdm_imap
|
1090
|
+
|
1091
|
+
# Check prediction subdirectory
|
1092
|
+
if not isdir(predict_location):
|
1093
|
+
print(f'The specified predict location {predict_location} is not a valid subdirectory path, exiting ...')
|
1094
|
+
|
1095
|
+
# allpfiles = listdir(predict_location)
|
1096
|
+
allpfiles = glob.glob(predict_location + "/*.h5")
|
1097
|
+
predict_logfile = glob.glob(predict_location + "/*predict.log")
|
1098
|
+
predwav_mode = False
|
1099
|
+
if len(allpfiles) <= 0 and not truth_est_mode:
|
1100
|
+
allpfiles = glob.glob(predict_location + "/*.wav") # check for wav files
|
1101
|
+
if len(allpfiles) <= 0:
|
1102
|
+
print(f'Subdirectory {predict_location} has no .h5 or .wav files, exiting ...')
|
1103
|
+
else:
|
1104
|
+
logger.info(f'Found {len(allpfiles)} prediction .wav files.')
|
1105
|
+
predwav_mode = True
|
1106
|
+
else:
|
1107
|
+
logger.info(f'Found {len(allpfiles)} prediction .h5 files.')
|
1108
|
+
|
1109
|
+
if len(predict_logfile) == 0:
|
1110
|
+
logger.info(f'Warning, predict location {predict_location} has no prediction log files.')
|
1111
|
+
else:
|
1112
|
+
logger.info(f'Found predict log {basename(predict_logfile[0])} in predict location.')
|
1113
|
+
|
1114
|
+
# Setup logging file
|
1115
|
+
create_file_handler(join(predict_location, 'calc_metric_spenh.log'))
|
1116
|
+
update_console_handler(verbose)
|
1117
|
+
initial_log_messages('calc_metric_spenh')
|
1118
|
+
|
1119
|
+
mixdb = MixtureDatabase(truth_location)
|
1120
|
+
mixids = mixdb.mixids_to_list(mixids)
|
1121
|
+
logger.info(
|
1122
|
+
f'Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}')
|
1123
|
+
logger.info(f'Only running specified subset of {len(mixids)} mixtures')
|
1124
|
+
|
1125
|
+
enable_asr_warmup = False
if wer_method == 'none':
    fnb = 'metric_spenh_'
elif wer_method == 'google':
    fnb = 'metric_spenh_ggl_'
    logger.info(f'WER enabled with method {wer_method}')
    enable_asr_warmup = True
elif wer_method == 'deepgram':
    fnb = 'metric_spenh_dgram_'
    logger.info(f'WER enabled with method {wer_method}')
    enable_asr_warmup = True
elif wer_method == 'aixplain_whisper':
    fnb = 'metric_spenh_whspx_' + whisper_model + '_'
    logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
    enable_asr_warmup = True
elif wer_method == 'whisper':
    fnb = 'metric_spenh_whspl_' + whisper_model + '_'
    logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
    enable_asr_warmup = True
elif wer_method == 'aaware_whisper':
    fnb = 'metric_spenh_whspaaw_' + whisper_model + '_'
    logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
    enable_asr_warmup = True
elif wer_method == 'fastwhisper':
    fnb = 'metric_spenh_fwhsp_' + whisper_model + '_'
    logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
    enable_asr_warmup = True
else:
    logger.error(f'Unrecognized WER method: {wer_method}')
    return

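# The chain above pairs each WER method with an output file-name base (fnb)
# and flags whether an ASR warm-up is needed. The same mapping as a lookup
# table, for reference (a sketch, not the original code; whisper-based
# methods fold the model name into the base):
_FNB_BY_METHOD = {
    'none': 'metric_spenh_',
    'google': 'metric_spenh_ggl_',
    'deepgram': 'metric_spenh_dgram_',
    'aixplain_whisper': f'metric_spenh_whspx_{whisper_model}_',
    'whisper': f'metric_spenh_whspl_{whisper_model}_',
    'aaware_whisper': f'metric_spenh_whspaaw_{whisper_model}_',
    'fastwhisper': f'metric_spenh_fwhsp_{whisper_model}_',
}
# e.g. fnb = _FNB_BY_METHOD[wer_method]; enable_asr_warmup = wer_method != 'none'
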
if enable_asr_warmup:
    # Run one ASR call up front so model loading (or cloud session setup) is
    # not charged to the first mixture's timing.
    DEFAULT_SPEECH = split(DEFAULT_NOISE)[0] + '/speech_ma01_01.wav'
    audio = read_audio(DEFAULT_SPEECH)
    logger.info('Warming up ASR method; for a cloud service this can take a few minutes ...')
    asr_chk = calc_asr(audio, engine=wer_method, whisper_model_name=whisper_model)
    logger.info(f'Warmup completed, results: {asr_chk}')

# Stash shared, read-only state for the worker processes used below.
MP_GLOBAL.mixdb = mixdb
MP_GLOBAL.predict_location = predict_location
MP_GLOBAL.predwav_mode = predwav_mode
MP_GLOBAL.truth_est_mode = truth_est_mode
MP_GLOBAL.enable_plot = enable_plot
MP_GLOBAL.enable_wav = enable_wav
MP_GLOBAL.wer_method = wer_method
MP_GLOBAL.whisper_model = whisper_model

# Individual mixtures use pandas print; precision set to 2 decimal places
# pd.set_option('float_format', '{:.2f}'.format)
progress = tqdm(total=len(mixids), desc='calc_metric_spenh')
all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, num_cpus=8)
progress.close()

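# MP_GLOBAL above is shared with the _process_mixture workers; the exact
# hand-off is inside pp_tqdm_imap (presumably fork inheritance or a pool
# initializer). A generic standard-library sketch of the same pattern, with
# illustrative names only:
from multiprocessing import Pool

_worker_state: dict = {}

def _init_worker(state: dict) -> None:
    # Runs once per worker process; stashes shared, read-only configuration.
    _worker_state.update(state)

def _scaled_square(x: int) -> float:
    return x * x * _worker_state['scale']

def _demo_pool() -> list[float]:
    with Pool(processes=2, initializer=_init_worker, initargs=({'scale': 0.5},)) as pool:
        return pool.map(_scaled_square, range(4))  # [0.0, 0.5, 2.0, 4.5]
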
all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
all_metrics_table_2 = pd.concat([item[1] for item in all_metrics_tables])

if not enable_summary:
    return

# 9) Done with mixtures, write out summary metrics
# Calculate SNR summary: average of each non-random SNR
all_mtab1_sorted = all_metrics_table_1.sort_values(by=['MXSNR', 'SPFILE'])
all_mtab2_sorted = all_metrics_table_2.sort_values(by=['MXSNR'])
mtab_snr_summary = None
mtab_snr_summary_em = None
for snr in mixdb.snrs:
    tmp = all_mtab1_sorted.query('MXSNR==' + str(snr)).mean(numeric_only=True).to_frame().T
    # skip nan rows (first numeric column is nan when a subset of mixids leaves this SNR empty)
    if not np.isnan(tmp.iloc[0].to_numpy()[0]):
        mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])

    tmp = all_mtab2_sorted[all_mtab2_sorted['MXSNR'] == snr].mean(numeric_only=True).to_frame().T
    # skip nan rows (MXSNR is nan if there is no data for this SNR)
    if not np.isnan(tmp.iloc[0].to_numpy()[0]):
        mtab_snr_summary_em = pd.concat([mtab_snr_summary_em, tmp])

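# For reference, the per-SNR averaging above can also be expressed with a
# groupby (an equivalent sketch, up to row order; groupby only emits SNRs
# actually present, so no nan guard is needed):
def _per_snr_mean(table: pd.DataFrame) -> pd.DataFrame:
    return table.groupby('MXSNR', as_index=False).mean(numeric_only=True)
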
mtab_snr_summary = mtab_snr_summary.sort_values(by=['MXSNR'], ascending=False)
# Correct the percentage columns in the SNR summary table
mtab_snr_summary['PESQi%'] = 100 * (mtab_snr_summary['PESQ'] - mtab_snr_summary['MXPESQ']) / np.maximum(
    mtab_snr_summary['MXPESQ'], 0.01)
wer_col = mtab_snr_summary.columns.get_loc('WERi%')  # avoid chained-assignment writes below
for i in range(len(mtab_snr_summary)):
    if mtab_snr_summary['MXWER'].iloc[i] == 0.0:
        if mtab_snr_summary['WER'].iloc[i] == 0.0:
            mtab_snr_summary.iloc[i, wer_col] = 0.0
        else:  # enhancement introduced errors where the baseline had none
            mtab_snr_summary.iloc[i, wer_col] = -999.0
    else:
        if not np.isnan(mtab_snr_summary['WER'].iloc[i]) and not np.isnan(mtab_snr_summary['MXWER'].iloc[i]):
            # relative WER reduction in percent
            mtab_snr_summary.iloc[i, wer_col] = (100 * (mtab_snr_summary['MXWER'].iloc[i]
                                                        - mtab_snr_summary['WER'].iloc[i])
                                                 / mtab_snr_summary['MXWER'].iloc[i])

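# The WERi% rules above (0 -> 0, errors appearing from a 0-WER baseline -> a
# -999.0 sentinel, otherwise relative reduction) are repeated below for the
# overall average. The same rule as a helper (a sketch; the original file
# inlines it):
def _wer_improvement_pct(mxwer: float, wer: float) -> float:
    """Relative WER reduction in percent; -999.0 flags errors appearing where the baseline had none."""
    if mxwer == 0.0:
        return 0.0 if wer == 0.0 else -999.0
    return 100 * (mxwer - wer) / mxwer
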
# Calculate average metrics over all mixtures except MXSNR == -99
all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
all_nom99_mean = all_mtab1_sorted_nom99.mean(numeric_only=True)

# Correct the percentage averages with a direct calculation, since the mean of
# per-mixture percentages (PESQi%, WERi%) is not the percentage computed from
# the averaged metrics:
all_nom99_mean['PESQi%'] = (100 * (all_nom99_mean['PESQ'] - all_nom99_mean['MXPESQ'])
                            / np.maximum(all_nom99_mean['MXPESQ'], 0.01))
if all_nom99_mean['MXWER'] == 0.0:
    if all_nom99_mean['WER'] == 0.0:
        all_nom99_mean['WERi%'] = 0.0
    else:
        all_nom99_mean['WERi%'] = -999.0
else:
    all_nom99_mean['WERi%'] = 100 * (all_nom99_mean['MXWER'] - all_nom99_mean['WER']) / all_nom99_mean['MXWER']

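# Why recompute rather than average the per-mixture percentages? The mean of
# ratios is not the ratio of means. A tiny check with made-up values (np is
# assumed imported at module level, as elsewhere in this file):
def _demo_ratio_of_means() -> tuple[float, float]:
    mxwer = np.array([0.5, 0.1])
    wer = np.array([0.25, 0.1])
    per_mix = 100 * (mxwer - wer) / mxwer  # [50.0, 0.0] -> mean 25.0
    from_means = 100 * (mxwer.mean() - wer.mean()) / mxwer.mean()  # ~41.7
    return float(per_mix.mean()), float(from_means)
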
num_mix = len(mixids)
if num_mix > 1:
    # Print pandas data to files with 2-decimal precision
    if not truth_est_mode:
        ofname = join(predict_location, fnb + 'summary.txt')
    else:
        ofname = join(predict_location, fnb + 'summary_truest.txt')

    with open(ofname, 'w') as f:
        print(f'WER enabled with method {wer_method}, whisper model (if used): {whisper_model}', file=f)
        print(f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:',
              file=f)
        print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: "{:.2f}".format(x),
                                                             index=False), file=f)
        print('\nSpeech enhancement metrics avg over each SNR:', file=f)
        print(mtab_snr_summary.round(2).to_string(float_format=lambda x: "{:.2f}".format(x), index=False), file=f)
        print('', file=f)
        print('Extraction statistics avg over each SNR:', file=f)
        print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: "{:.1f}".format(x), index=False),
              file=f)
        print('', file=f)

        print(f'Speech enhancement metrics stats over all {num_mix} mixtures:', file=f)
        print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
        print('', file=f)
        print(f'Extraction statistics stats over all {num_mix} mixtures:', file=f)
        print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)
        print('', file=f)

        print('Speech enhancement metrics all-mixtures list:', file=f)
        print(all_metrics_table_1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
        print('', file=f)
        print('Extraction statistics all-mixtures list:', file=f)
        print(all_metrics_table_2.round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)

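    # The report pattern above (title, round, float_format, blank line) as a
    # small helper for reference (a sketch, not the original code):
    def _print_table(f, title: str, table: pd.DataFrame, decimals: int = 2) -> None:
        fmt = ('{:.%df}' % decimals).format  # e.g. '{:.2f}'.format
        print(title, file=f)
        print(table.round(decimals).to_string(float_format=fmt), file=f)
        print('', file=f)
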
    # Write summary to .csv file
    if not truth_est_mode:
        csv_name = join(predict_location, fnb + 'summary.csv')
    else:
        csv_name = join(predict_location, fnb + 'summary_truest.csv')
    header_args = {
        'mode': 'a',
        'encoding': 'utf-8',
        'index': False,
        'header': False,
    }
    table_args = {
        'mode': 'a',
        'encoding': 'utf-8',
    }
    label = f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:'
    pd.DataFrame([label]).to_csv(csv_name, header=False, index=False)  # mode defaults to 'w': start a fresh file
    all_nom99_mean.to_frame().T.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame(['']).to_csv(csv_name, **header_args)
    pd.DataFrame(['Speech enhancement metrics avg over each SNR:']).to_csv(csv_name, **header_args)
    mtab_snr_summary.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame(['']).to_csv(csv_name, **header_args)
    pd.DataFrame(['Extraction statistics avg over each SNR:']).to_csv(csv_name, **header_args)
    mtab_snr_summary_em.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame(['']).to_csv(csv_name, **header_args)
    pd.DataFrame(['']).to_csv(csv_name, **header_args)
    label = f'Speech enhancement metrics stats over {num_mix} mixtures:'
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_1.describe().round(2).to_csv(csv_name, **table_args)
    pd.DataFrame(['']).to_csv(csv_name, **header_args)
    label = f'Extraction statistics stats over {num_mix} mixtures:'
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
    label = f'WER enabled with method {wer_method}, whisper model (if used): {whisper_model}'
    pd.DataFrame([label]).to_csv(csv_name, **header_args)

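    # The label/table appends above repeat one pattern; a compact helper doing
    # the same thing (a sketch, not in the original file):
    def _append_section(csv_path: str, label: str, table: pd.DataFrame) -> None:
        pd.DataFrame([label]).to_csv(csv_path, mode='a', header=False, index=False, encoding='utf-8')
        table.to_csv(csv_path, mode='a', index=False, encoding='utf-8')
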
    if not truth_est_mode:
        csv_name = join(predict_location, fnb + 'list.csv')
    else:
        csv_name = join(predict_location, fnb + 'list_truest.csv')
    pd.DataFrame(['Speech enhancement metrics list:']).to_csv(csv_name, header=False, index=False)  # fresh file
    all_metrics_table_1.round(2).to_csv(csv_name, **table_args)

    if not truth_est_mode:
        csv_name = join(predict_location, fnb + 'estats_list.csv')
    else:
        csv_name = join(predict_location, fnb + 'estats_list_truest.csv')
    pd.DataFrame(['Extraction statistics list:']).to_csv(csv_name, header=False, index=False)  # fresh file
    all_metrics_table_2.round(2).to_csv(csv_name, **table_args)


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        exit()