sonusai 0.17.0__py3-none-any.whl → 0.17.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. The information is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sonusai/audiofe.py +22 -51
- sonusai/calc_metric_spenh.py +206 -213
- sonusai/doc/doc.py +1 -1
- sonusai/mixture/__init__.py +2 -0
- sonusai/mixture/audio.py +12 -0
- sonusai/mixture/datatypes.py +11 -3
- sonusai/mixture/mixdb.py +101 -0
- sonusai/mixture/soundfile_audio.py +39 -0
- sonusai/mixture/speaker_metadata.py +35 -0
- sonusai/mixture/torchaudio_audio.py +22 -0
- sonusai/mkmanifest.py +1 -1
- sonusai/onnx_predict.py +114 -410
- sonusai/queries/queries.py +1 -1
- sonusai/speech/__init__.py +3 -0
- sonusai/speech/l2arctic.py +116 -0
- sonusai/speech/librispeech.py +99 -0
- sonusai/speech/mcgill.py +70 -0
- sonusai/speech/textgrid.py +100 -0
- sonusai/speech/timit.py +135 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +52 -0
- sonusai/speech/voxceleb2.py +86 -0
- sonusai/utils/__init__.py +2 -1
- sonusai/utils/asr_manifest_functions/__init__.py +0 -1
- sonusai/utils/asr_manifest_functions/data.py +0 -8
- sonusai/utils/asr_manifest_functions/librispeech.py +1 -1
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +1 -1
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +1 -1
- sonusai/utils/braced_glob.py +7 -3
- sonusai/utils/onnx_utils.py +110 -106
- sonusai/utils/path_info.py +7 -0
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/METADATA +2 -1
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/RECORD +35 -30
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/WHEEL +1 -1
- sonusai/calc_metric_spenh-save.py +0 -1334
- sonusai/onnx_predict-old.py +0 -240
- sonusai/onnx_predict-save.py +0 -487
- sonusai/ovino_predict.py +0 -508
- sonusai/ovino_query_devices.py +0 -47
- sonusai/torchl_onnx-old.py +0 -216
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/entry_points.txt +0 -0
sonusai/calc_metric_spenh.py
CHANGED
```diff
@@ -1,27 +1,25 @@
 """sonusai calc_metric_spenh
 
-usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e
+usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-m MODEL] PLOC TLOC
 
 options:
     -h, --help
     -v, --verbose                Be verbose.
-    -i MIXID, --mixid MIXID      Mixture ID(s) to process, can be range like 0:maxmix+1 [default: *]
+    -i MIXID, --mixid MIXID      Mixture ID(s) to process, can be range like 0:maxmix+1. [default: *]
     -t, --truth-est-mode         Calculate extraction and metrics using truth (instead of prediction).
     -p, --plot                   Enable PDF plots file generation per mixture.
     -w, --wav                    Generate WAV files per mixture.
     -s, --summary                Enable summary files generation.
-    -e
-
-    -m WMNAME, --whisper-model   Whisper model name used in aixplain_whisper and whisper WER methods.
-                                 [default: tiny]
+    -e ASR, --asr-method ASR     ASR method: deepgram, google, aixplain_whisper, whisper, or sensory. [default: none]
+    -m MODEL, --model            ASR model name used in some ASR methods. [default: tiny]
 
-Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data
-
+Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
+reference. Metric and extraction data files are written into PLOC.
 
 PLOC  directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
 TLOC  directory with SonusAI mixture database of truth/label mixture data
 
-For whisper
+For whisper ASR methods, the possible models used in local processing (ASR = whisper) are:
 {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
 but note most are very computationally demanding and can overwhelm/hang a local system.
 
```
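The new usage line replaces the whisper-specific `-m WMNAME, --whisper-model` option with a generic `-e ASR, --asr-method` / `-m MODEL` pair. The usage block is docopt-style, and the `main()` changes later in this diff read the results as `args['--asr-method']`, `args['--model']`, `args['PLOC']`, and `args['TLOC']`. A minimal sketch of how the new surface parses, assuming docopt (which the args-dict access pattern suggests; this is an illustration, not the package code):

```python
from docopt import docopt

USAGE = """usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-m MODEL] PLOC TLOC

options:
    -i MIXID, --mixid MIXID   Mixture ID(s) to process. [default: *]
    -e ASR, --asr-method ASR  ASR method. [default: none]
    -m MODEL, --model MODEL   ASR model name. [default: tiny]
"""

# parse a sample command line instead of sys.argv
args = docopt(USAGE, argv=['-e', 'whisper', '-m', 'tiny', 'pred_dir', 'truth_dir'])
asr_method = args['--asr-method'].lower()  # 'whisper'
asr_model_name = args['--model'].lower()   # 'tiny'
predict_location = args['PLOC']            # 'pred_dir'
truth_location = args['TLOC']              # 'truth_dir'
```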
```diff
@@ -68,6 +66,7 @@ import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+
 from sonusai.mixture import AudioF
 from sonusai.mixture import AudioT
 from sonusai.mixture import Feature
```
```diff
@@ -93,12 +92,12 @@ matplotlib.use('SVG')
 class MPGlobal:
     mixdb: MixtureDatabase = None
     predict_location: str = None
-
+    predict_wav_mode: bool = None
     truth_est_mode: bool = None
     enable_plot: bool = None
     enable_wav: bool = None
-
-
+    asr_method: str = None
+    asr_model_name: str = None
 
 
 MP_GLOBAL = MPGlobal()
```
```diff
@@ -132,64 +131,62 @@ def snr(clean_speech, processed_speech, sample_rate):
     overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))
 
     # Global Variables
-
-
-
-
+    win_length = round(30 * sample_rate / 1000)  # window length in samples
+    skip_rate = int(np.floor(win_length / 4))  # window skip in samples
+    min_snr = -10  # minimum SNR in dB
+    max_snr = 35  # maximum SNR in dB
 
     # For each frame of input speech, calculate the Segmental SNR
-    num_frames = int(clean_length /
+    num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
     start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1,
+    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
 
     segmental_snr = np.empty(num_frames)
-
+    eps = np.spacing(1)
     for frame_count in range(num_frames):
         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start:start +
-        processed_frame = processed_speech[start:start +
+        clean_frame = clean_speech[start:start + win_length]
+        processed_frame = processed_speech[start:start + win_length]
         clean_frame = np.multiply(clean_frame, window)
         processed_frame = np.multiply(processed_frame, window)
 
         # (2) Compute the Segmental SNR
         signal_energy = np.sum(np.square(clean_frame))
         noise_energy = np.sum(np.square(clean_frame - processed_frame))
-        segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy +
-        segmental_snr[frame_count] = max(segmental_snr[frame_count],
-        segmental_snr[frame_count] = min(segmental_snr[frame_count],
+        segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
+        segmental_snr[frame_count] = np.max(segmental_snr[frame_count], min_snr)
+        segmental_snr[frame_count] = np.min(segmental_snr[frame_count], max_snr)
 
-        start = start +
+        start = start + skip_rate
 
     return overall_snr, segmental_snr
 
 
-def
+def lp_coefficients(speech_frame, model_order):
     # (1) Compute Autocorrelation Lags
-
-
-
+    win_length = np.size(speech_frame)
+    autocorrelation = np.empty(model_order + 1)
+    e = np.empty(model_order + 1)
     for k in range(model_order + 1):
-
+        autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
 
     # (2) Levinson-Durbin
     a = np.ones(model_order)
     a_past = np.empty(model_order)
-
-
+    ref_coefficients = np.empty(model_order)
+    e[0] = autocorrelation[0]
     for i in range(model_order):
         a_past[0: i] = a[0: i]
-        sum_term = np.dot(a_past[0: i],
-
-        a[i] =
+        sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
+        ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
+        a[i] = ref_coefficients[i]
         if i == 0:
-            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1],
+            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
         else:
-            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1],
-
-
-
-    lpparams = np.concatenate((np.array([1]), -a))
-    return acorr, refcoeff, lpparams
+            a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
+        e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
+    lp_params = np.concatenate((np.array([1]), -a))
+    return autocorrelation, ref_coefficients, lp_params
 
 
 def llr(clean_speech, processed_speech, sample_rate):
```
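The rewritten `snr()` gives the frame parameters readable names (`win_length`, `skip_rate`) and clamps each frame's SNR into `[min_snr, max_snr]`. A self-contained sketch of the same per-frame computation, offered as an illustration rather than the package code (it uses `np.clip` for the clamping step):

```python
import numpy as np

def segmental_snr(clean: np.ndarray, processed: np.ndarray, sample_rate: int = 16000):
    """Frame-based segmental SNR with a Hanning window, clamped to [-10, 35] dB."""
    win_length = round(30 * sample_rate / 1000)  # 30 ms window
    skip_rate = win_length // 4                  # 75% frame overlap
    min_snr, max_snr = -10.0, 35.0
    eps = np.spacing(1)

    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
    num_frames = int(len(clean) / skip_rate - win_length / skip_rate)

    seg_snr = np.empty(num_frames)
    start = 0
    for n in range(num_frames):
        c = clean[start:start + win_length] * window
        p = processed[start:start + win_length] * window
        signal_energy = np.sum(c ** 2)
        noise_energy = np.sum((c - p) ** 2)
        # np.clip keeps each frame's value inside [min_snr, max_snr]
        seg_snr[n] = np.clip(10 * np.log10(signal_energy / (noise_energy + eps) + eps),
                             min_snr, max_snr)
        start += skip_rate
    return seg_snr

# usage: seg = segmental_snr(clean_audio, enhanced_audio, 16000); seg.mean()
```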
```diff
@@ -199,38 +196,38 @@ def llr(clean_speech, processed_speech, sample_rate):
     clean_length = np.size(clean_speech)
     processed_length = np.size(processed_speech)
     if clean_length != processed_length:
-        raise ValueError('Both
+        raise ValueError('Both speech files must be same length.')
 
     # Global Variables
-
-
+    win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
+    skip_rate = (np.floor(win_length / 4)).astype(int)  # window skip in samples
     if sample_rate < 10000:
-
+        p = 10  # LPC Analysis Order
     else:
-
+        p = 16  # this could vary depending on sampling frequency.
 
     # For each frame of input speech, calculate the Log Likelihood Ratio
-    num_frames = int((clean_length -
+    num_frames = int((clean_length - win_length) / skip_rate)  # number of frames
     start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1,
+    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
 
     distortion = np.empty(num_frames)
     for frame_count in range(num_frames):
         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start: start +
-        processed_frame = processed_speech[start: start +
+        clean_frame = clean_speech[start: start + win_length]
+        processed_frame = processed_speech[start: start + win_length]
         clean_frame = np.multiply(clean_frame, window)
         processed_frame = np.multiply(processed_frame, window)
 
         # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
-
-
+        r_clean, ref_clean, a_clean = lp_coefficients(clean_frame, p)
+        r_processed, ref_processed, a_processed = lp_coefficients(processed_frame, p)
 
         # (3) Compute the LLR measure
-        numerator = np.dot(np.matmul(
-        denominator = np.dot(np.matmul(
+        numerator = np.dot(np.matmul(a_processed, toeplitz(r_clean)), a_processed)
+        denominator = np.dot(np.matmul(a_clean, toeplitz(r_clean)), a_clean)
         distortion[frame_count] = np.log(numerator / denominator)
-        start = start +
+        start = start + skip_rate
     return distortion
 
 
```
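`llr()` scores each frame as the log ratio of two quadratic forms over the clean frame's autocorrelation matrix: the processed-signal LP coefficients in the numerator and the clean-signal LP coefficients in the denominator. A toy numeric illustration of that ratio with `scipy.linalg.toeplitz` (illustrative values, not package code):

```python
import numpy as np
from scipy.linalg import toeplitz

# toy LP parameter vectors [1, -a1, -a2] and autocorrelation lags r[0..2]
a_clean = np.array([1.0, -0.9, 0.2])
a_processed = np.array([1.0, -0.7, 0.1])
r_clean = np.array([1.0, 0.8, 0.5])  # lags 0..2 of the clean frame

R = toeplitz(r_clean)                 # symmetric autocorrelation matrix
numerator = a_processed @ R @ a_processed
denominator = a_clean @ R @ a_clean
llr_frame = np.log(numerator / denominator)
print(llr_frame)  # 0.0 when the two LP models coincide; grows with mismatch
```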
```diff
@@ -244,16 +241,15 @@ def wss(clean_speech, processed_speech, sample_rate):
         raise ValueError('Files must have same length.')
 
     # Global variables
-
-
+    win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
+    skip_rate = (np.floor(np.divide(win_length, 4))).astype(int)  # window skip in samples
     max_freq = (np.divide(sample_rate, 2)).astype(int)  # maximum bandwidth
     num_crit = 25  # number of critical bands
 
-
-
-
-
-    Klocmax = 1.0  # value suggested by Klatt, pg 1280
+    n_fft = (np.power(2, np.ceil(np.log2(2 * win_length)))).astype(int)
+    n_fft_by_2 = (np.multiply(0.5, n_fft)).astype(int)  # FFT size/2
+    k_max = 20.0  # value suggested by Klatt, pg 1280
+    k_loc_max = 1.0  # value suggested by Klatt, pg 1280
 
     # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
     cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
```
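For the fs = 16000 Hz case used by `calc_speech_metrics`, these globals evaluate to win_length = round(30 * 16000 / 1000) = 480 samples, skip_rate = floor(480 / 4) = 120, n_fft = 2^ceil(log2(960)) = 1024, and n_fft_by_2 = 512, so each critical-band filter built in the next hunk spans 512 FFT bins.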
```diff
@@ -268,39 +264,38 @@ def wss(clean_speech, processed_speech, sample_rate):
     bw_min = bandwidth[0]  # minimum critical bandwidth
 
     # Set up the critical band filters.
-    # Note here that
+    # Note here that Gaussian-ly shaped filters are used.
     # Also, the sum of the filter weights are equivalent for each critical band filter.
     # Filter less than -30 dB and set to zero.
     min_factor = np.exp(-30.0 / (2.0 * 2.303))  # -30 dB point of filter
-    crit_filter = np.empty((num_crit,
+    crit_filter = np.empty((num_crit, n_fft_by_2))
     for i in range(num_crit):
-        f0 = (cent_freq[i] / max_freq) *
-        bw = (bandwidth[i] / max_freq) *
+        f0 = (cent_freq[i] / max_freq) * n_fft_by_2
+        bw = (bandwidth[i] / max_freq) * n_fft_by_2
         norm_factor = np.log(bw_min) - np.log(bandwidth[i])
-        j = np.arange(
+        j = np.arange(n_fft_by_2)
         crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
         cond = np.greater(crit_filter[i, :], min_factor)
         crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
     # For each frame of input speech, calculate the Weighted Spectral Slope Measure
-    num_frames = int(clean_length /
+    num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
     start = 0  # starting sample
-    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1,
+    window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
 
     distortion = np.empty(num_frames)
     for frame_count in range(num_frames):
         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-        clean_frame = clean_speech[start: start +
-        processed_frame = processed_speech[start: start +
+        clean_frame = clean_speech[start: start + win_length] / 32768
+        processed_frame = processed_speech[start: start + win_length] / 32768
         clean_frame = np.multiply(clean_frame, window)
         processed_frame = np.multiply(processed_frame, window)
         # (2) Compute the Power Spectrum of Clean and Processed
-        # if USE_FFT_SPECTRUM:
         clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
         processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))
 
         # (3) Compute Filterbank Output Energies (in dB scale)
-        clean_energy = np.matmul(crit_filter, clean_spec[0:
-        processed_energy = np.matmul(crit_filter, processed_spec[0:
+        clean_energy = np.matmul(crit_filter, clean_spec[0:n_fft_by_2])
+        processed_energy = np.matmul(crit_filter, processed_spec[0:n_fft_by_2])
 
         clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
         processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))
```
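The filter-setup loop fills a (num_crit, n_fft_by_2) matrix with Gaussian-shaped critical-band filters and zeroes any weight below the -30 dB point. A vectorized sketch of the same construction (the example bandwidths in the usage comment are illustrative; only the start of the cent_freq array appears in this diff):

```python
import numpy as np

def critical_band_filters(cent_freq, bandwidth, max_freq, n_fft_by_2):
    """Gaussian critical-band filterbank over the first n_fft_by_2 FFT bins."""
    bw_min = bandwidth[0]
    min_factor = np.exp(-30.0 / (2.0 * 2.303))          # -30 dB point of filter
    j = np.arange(n_fft_by_2)
    f0 = (cent_freq / max_freq)[:, None] * n_fft_by_2   # center bin per band
    bw = (bandwidth / max_freq)[:, None] * n_fft_by_2   # bandwidth in bins
    norm = (np.log(bw_min) - np.log(bandwidth))[:, None]
    filt = np.exp(-11 * ((j - np.floor(f0)) / bw) ** 2 + norm)
    return np.where(filt > min_factor, filt, 0.0)       # drop sub -30 dB weights

# usage with the first three bands from the hunk above (illustrative bandwidths):
# critical_band_filters(np.array([50., 120., 190.]), np.array([70., 70., 70.]), 8000, 512)
```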
```diff
@@ -340,39 +335,39 @@ def wss(clean_speech, processed_speech, sample_rate):
             processed_loc_peak[i] = processed_energy[n + 1]
 
         # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
-
-
+        db_max_clean = np.max(clean_energy)
+        db_max_processed = np.max(processed_energy)
         '''
         The weights are calculated by averaging individual weighting factors from the clean and processed frame.
-        These weights
+        These weights w_clean and w_processed should range from 0 to 1 and place more emphasis on spectral peaks
         and less emphasis on slope differences in spectral valleys.
         This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
         '''
-
-
-
+        w_max_clean = np.divide(k_max, k_max + db_max_clean - clean_energy[0: num_crit - 1])
+        w_loc_max_clean = np.divide(k_loc_max, k_loc_max + clean_loc_peak - clean_energy[0: num_crit - 1])
+        w_clean = np.multiply(w_max_clean, w_loc_max_clean)
 
-
-
-
+        w_max_processed = np.divide(k_max, k_max + db_max_processed - processed_energy[0: num_crit - 1])
+        w_loc_max_processed = np.divide(k_loc_max, k_loc_max + processed_loc_peak - processed_energy[0: num_crit - 1])
+        w_processed = np.multiply(w_max_processed, w_loc_max_processed)
 
-
+        w = np.divide(np.add(w_clean, w_processed), 2.0)
         slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1]
-        distortion[frame_count] = np.dot(
-        #
+        distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
+        # This normalization is not part of Klatt's paper, but helps to normalize the measure.
         # Here we scale the measure by the sum of the weights.
-        start = start +
+        start = start + skip_rate
     return distortion
 
 
 def calc_speech_metrics(hypothesis: np.ndarray,
                         reference: np.ndarray) -> tuple[float, int, int, int, float]:
     """
-    Calculate speech metrics pesq_mos,
+    Calculate speech metrics pesq_mos, c_sig, c_bak, c_ovl, seg_snr. These are all related and thus included
     in one function. Reference: matlab script "compute_metrics.m".
 
     Usage:
-        pesq,
+        pesq, c_sig, c_bak, c_ovl, ssnr = compute_metrics(hypothesis, reference, fs, path)
     reference: clean audio as array
     hypothesis: enhanced audio as array
     Audio must have sampling rate = 16000 Hz.
```
```diff
@@ -383,41 +378,39 @@ def calc_speech_metrics(hypothesis: np.ndarray,
     """
     from sonusai.metrics import calc_pesq
 
-
+    fs = 16000
 
     # compute the WSS measure
-    wss_dist_vec = wss(reference, hypothesis,
+    wss_dist_vec = wss(reference, hypothesis, fs)
     wss_dist_vec = np.sort(wss_dist_vec)
     alpha = 0.95  # value from CMGAN ref implementation
     wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
 
     # compute the LLR measure
-    llr_dist = llr(reference, hypothesis,
+    llr_dist = llr(reference, hypothesis, fs)
     ll_rs = np.sort(llr_dist)
     llr_len = round(np.size(llr_dist) * alpha)
     llr_mean = np.mean(ll_rs[0: llr_len])
 
     # compute the SNRseg
-    snr_dist, segsnr_dist = snr(reference, hypothesis,
-
-    segSNR = np.mean(segsnr_dist)
+    snr_dist, segsnr_dist = snr(reference, hypothesis, fs)
+    seg_snr = np.mean(segsnr_dist)
 
     # compute the pesq (use Sonusai wrapper, only fs=16k, mode=wb support)
     pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
-    # pesq_mos = pesq(sampling_rate1, data1, data2, 'wb')
 
     # now compute the composite measures
-
-
-
-
-
-
-
-
-
+    c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
+    c_sig = max(1, c_sig)
+    c_sig = min(5, c_sig)  # limit values to [1, 5]
+    c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
+    c_bak = max(1, c_bak)
+    c_bak = min(5, c_bak)  # limit values to [1, 5]
+    c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
+    c_ovl = max(1, c_ovl)
+    c_ovl = min(5, c_ovl)  # limit values to [1, 5]
 
-    return pesq_mos,
+    return pesq_mos, c_sig, c_bak, c_ovl, seg_snr
 
 
 def mean_square_error(hypothesis: np.ndarray,
```
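The restored block computes the classic composite quality measures (c_sig, c_bak, c_ovl) as linear combinations of PESQ, mean LLR, WSS distance, and segmental SNR, each clamped to the MOS range [1, 5], matching the "compute_metrics.m" reference named in the docstring. The same arithmetic, condensed with np.clip (a sketch, not the package code):

```python
import numpy as np

def composite_measures(pesq_mos, llr_mean, wss_dist, seg_snr):
    """Composite speech-quality measures, each limited to the MOS range [1, 5]."""
    c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
    c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
    c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
    return tuple(float(np.clip(c, 1.0, 5.0)) for c in (c_sig, c_bak, c_ovl))

# e.g. composite_measures(pesq_mos=2.5, llr_mean=0.6, wss_dist=40.0, seg_snr=3.0)
# -> (roughly 3.62, 2.74, 3.02)
```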
```diff
@@ -564,10 +557,8 @@ def plot_mixpred(mixture: AudioT,
     ax[p].plot(x_axis, mixture, label='Mixture', color='mistyrose')
     ax[0].set_ylabel('magnitude', color='tab:blue')
     ax[p].set_xlim(x_axis[0], x_axis[-1])
-    # ax[p].set_ylim([-1.025, 1.025])
     if target is not None:  # Plot target time-domain waveform on top of mixture
         ax[0].plot(x_axis, target, label='Target', color='tab:blue')
-        # ax[0].tick_params(axis='y', labelcolor=color)
     ax[p].set_title('Waveform')
 
     # Plot the mixture spectrogram
```
```diff
@@ -589,10 +580,10 @@ def plot_mixpred(mixture: AudioT,
     return fig
 
 
-def
-
-
-
+def plot_pdb_predict_truth(predict: np.ndarray,
+                           truth_f: Optional[np.ndarray] = None,
+                           metric: Optional[np.ndarray] = None,
+                           tp_title: str = '') -> plt.Figure:
     """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
     num_plots = 2
     if truth_f is not None:
```
```diff
@@ -636,16 +627,15 @@ def plot_pdb_predtruth(predict: np.ndarray,
             ax[p].set_title('SNR and SNR mse (mean over freq. db)')
         else:
             ax[p].set_title('SNR (mean over freq. db)')
-    # ax[0].tick_params(axis='y', labelcolor=color)
     return fig
 
 
-def
-
-
-
-
-
+def plot_e_predict_truth(predict: np.ndarray,
+                         predict_wav: np.ndarray,
+                         truth_f: Optional[np.ndarray] = None,
+                         truth_wav: Optional[np.ndarray] = None,
+                         metric: Optional[np.ndarray] = None,
+                         tp_title: str = '') -> plt.Figure:
     """Plot predict spectrogram and waveform and optionally truth and a metric)"""
     num_plots = 2
     if truth_f is not None:
```
```diff
@@ -666,7 +656,7 @@ def plot_epredtruth(predict: np.ndarray,
         ax[p].imshow(truth_f.transpose(), im.cmap, aspect='auto', interpolation='nearest', origin='lower')
         ax[p].set_title('Truth')
 
-    # Plot
+    # Plot predict wav, and optionally truth avg and metric lines
     p += 1
     x_axis = np.arange(len(predict_wav), dtype=np.float32)  # / SAMPLE_RATE
     ax[p].plot(x_axis, predict_wav, color='black', linestyle='dashed', label='Speech Estimate')
```
```diff
@@ -732,12 +722,12 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
 
     mixdb = MP_GLOBAL.mixdb
     predict_location = MP_GLOBAL.predict_location
-
+    predict_wav_mode = MP_GLOBAL.predict_wav_mode
     truth_est_mode = MP_GLOBAL.truth_est_mode
     enable_plot = MP_GLOBAL.enable_plot
     enable_wav = MP_GLOBAL.enable_wav
-
-
+    asr_method = MP_GLOBAL.asr_method
+    asr_model_name = MP_GLOBAL.asr_model_name
 
     # 1) Read predict data, var predict with shape [BatchSize,Classes] or [BatchSize,Tsteps,Classes]
     output_name = join(predict_location, mixdb.mixture(mixid).name)
```
```diff
@@ -749,7 +739,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
         base_name = splitext(output_name)[0] + '_truest'
     else:
         base_name, ext = splitext(output_name)  # base_name used later
-    if not
+    if not predict_wav_mode:
         try:
             with h5py.File(output_name, 'r') as f:
                 predict = np.array(f['predict'])
```
```diff
@@ -761,8 +751,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
             predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
     else:
         base_name, ext = splitext(output_name)
-
-        audio = read_audio(
+        predict_name = join(base_name + '.wav')
+        audio = read_audio(predict_name)
         predict = forward_transform(audio, mixdb.ft_config)
         if mixdb.feature[0:1] == 'h':
             predict = power_compress(predict)
```
```diff
@@ -773,8 +763,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
     target = tmp[0]
     mixture = mixdb.mixture_mixture(mixid)  # note: gives full reverberated/distorted target, but no specaugment
-    #
-    #
+    # noise_wo_dist = mixdb.mixture_noise(mixid)  # noise without specaugment and distortion
+    # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
     noise = mixture - target  # has time-domain distortion (ir,etc.) but does not have specaugment
     # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
     segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)  # note: uses pre-IR, pre-specaug audio
```
```diff
@@ -784,9 +774,9 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     segsnr_f[segsnr_f == inf] = 7.944e8  # 99db
     segsnr_f[segsnr_f == -inf] = 1.258e-10  # -99db
     # need to use inv-tf to match #samples & latency shift properties of predict inv tf
-
-
-    #
+    target_fi = inverse_transform(target_f, mixdb.it_config)
+    noise_fi = inverse_transform(noise_f, mixdb.it_config)
+    # mixture_fi = mixdb.inverse_transform(mixture_f)
 
     # gen feature, truth - note feature only used for plots
     # TBD parse truth_f for different formats and also multi-truth
```
```diff
@@ -798,17 +788,17 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
 
     if not truth_est_mode:
         if predict.shape[0] < target_f.shape[0]:  # target_f, truth_f, mixture_f, etc. same size
-
-            logger.debug(f'Warning: prediction frames less than mixture, trimming {
-            target_f = target_f[0:-
-
-
-            target = target[0:-
-            noise_f = noise_f[0:-
-            noise = noise[0:-
-            mixture_f = mixture_f[0:-
-            mixture = mixture[0:-
-            truth_f = truth_f[0:-
+            trim_f = target_f.shape[0] - predict.shape[0]
+            logger.debug(f'Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.')
+            target_f = target_f[0:-trim_f, :]
+            target_fi, _ = inverse_transform(target_f, mixdb.it_config)
+            trim_t = target.shape[0] - target_fi.shape[0]
+            target = target[0:-trim_t]
+            noise_f = noise_f[0:-trim_f, :]
+            noise = noise[0:-trim_t]
+            mixture_f = mixture_f[0:-trim_f, :]
+            mixture = mixture[0:-trim_t]
+            truth_f = truth_f[0:-trim_f, :]
         elif predict.shape[0] > target_f.shape[0]:
             raise SonusAIError(
                 f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
```
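The new trim logic derives one offset in frames (trim_f) for the spectral arrays and a second offset in samples (trim_t) from the inverse transform of the already-trimmed target_f, so the feature-domain and time-domain views stay aligned. A shape-only sketch of that two-domain alignment, where samples_per_frame stands in for the frame-to-sample relation that inverse_transform realizes in the diff:

```python
import numpy as np

def trim_to_predict(predict, target_f, target, samples_per_frame):
    """Trim frames in the feature domain, then samples in the time domain."""
    trim_f = target_f.shape[0] - predict.shape[0]  # frames to drop
    if trim_f > 0:
        target_f = target_f[:-trim_f, :]
        # time-domain length implied by the trimmed feature frames
        implied_samples = target_f.shape[0] * samples_per_frame
        trim_t = target.shape[0] - implied_samples  # samples to drop
        if trim_t > 0:
            target = target[:-trim_t]
    return target_f, target
```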
```diff
@@ -848,14 +838,14 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     phd, phd_bin, phd_frame = phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
 
     # Noise td logerr
-    # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(
+    # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)
 
     # # SA-SDR (time-domain source-aggragated SDR)
-    ytrue = np.concatenate((
+    ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
     ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
     # # note: w/o scale is more pessimistic number
     # sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
-    target_stoi = stoi(
+    target_stoi = stoi(target_fi, target_est_wav, 16000, extended=False)
 
     wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
     # logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
```
```diff
@@ -865,7 +855,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     # Speech intelligibility measure - PESQ
     if int(mixdb.mixture(mixid).snr) > -99:
         # len = target_est_wav.shape[0]
-        pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav,
+        pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, target_fi)
         pesq_mixture, csig_mx, cbak_mx, covl_mx, sgsnr_mx = calc_speech_metrics(mixture, target)
         # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
         # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
```
```diff
@@ -884,23 +874,26 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
         covl_mx = 0
         covl_tg = 0
 
-    # Calc
-    asr_tt =
-    asr_mx =
-    asr_tge =
-    if
+    # Calc ASR
+    asr_tt = None
+    asr_mx = None
+    asr_tge = None
+    if asr_method == 'none' or mixdb.mixture(mixid).snr == -99:  # noise only, ignore/reset target asr
         wer_mx = float('nan')
         wer_tge = float('nan')
         wer_pi = float('nan')
     else:
-
-
-
-
+        asr_tt = MP_GLOBAL.mixdb.get_speech_metadata(mixid, 'text')[0]  # ignore mixup
+        if asr_tt is None:
+            asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text  # target truth
+        # if MP_GLOBAL.mixdb.asr_manifests:
+        #     asr_tt = MP_GLOBAL.mixdb.mixture_asr_data(mixid)[0]  # ignore mixup
+        # else:
+        #     asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text  # target truth
 
         if asr_tt:
-            asr_mx = calc_asr(mixture, engine=
-            asr_tge = calc_asr(target_est_wav, engine=
+            asr_mx = calc_asr(mixture, engine=asr_method, whisper_model_name=asr_model_name).text
+            asr_tge = calc_asr(target_est_wav, engine=asr_method, whisper_model_name=asr_model_name).text
 
             wer_mx = calc_wer(asr_mx, asr_tt).wer * 100  # mixture wer
             wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target estimate wer
```
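calc_wer itself is not part of this diff; the quantity it reports (scaled by 100 above) is the standard word error rate: word-level edit distance divided by the reference word count. For reference, a minimal implementation of that definition (an illustration, not sonusai's code):

```python
def word_error_rate(hypothesis: str, reference: str) -> float:
    """WER = (substitutions + insertions + deletions) / reference word count."""
    hyp, ref = hypothesis.split(), reference.split()
    # classic dynamic-programming edit distance over words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution or match
    return d[len(ref)][len(hyp)] / max(len(ref), 1)

# word_error_rate('the cat sat', 'the cat sat down') -> 0.25 (one deletion, 4 reference words)
```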
```diff
@@ -962,10 +955,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
         print('', file=f)
         print(f'Target path: {mixdb.target_file(ti).name}', file=f)
         print(f'Noise path: {mixdb.noise_file(ni).name}', file=f)
-        if
-            print(f'
+        if asr_method != 'none':
+            print(f'ASR method: {asr_method} and whisper model (if used): {asr_model_name}', file=f)
             if mixdb.asr_manifests:
-                print(f'ASR truth from
+                print(f'ASR truth from metadata: {asr_tt}', file=f)
             else:
                 print(f'ASR truth from wer method: {asr_tt}', file=f)
             print(f'ASR result for mixture: {asr_mx}', file=f)
```
```diff
@@ -977,7 +970,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     if enable_wav:
         write_wav(name=base_name + '_mixture.wav', audio=float_to_int16(mixture))
         write_wav(name=base_name + '_target.wav', audio=float_to_int16(target))
-        # write_wav(name=base_name + '
+        # write_wav(name=base_name + '_target_fi.wav', audio=float_to_int16(target_fi))
         write_wav(name=base_name + '_noise.wav', audio=float_to_int16(noise))
         write_wav(name=base_name + '_target_est.wav', audio=float_to_int16(target_est_wav))
         write_wav(name=base_name + '_noise_est.wav', audio=float_to_int16(noise_est_wav))
```
```diff
@@ -992,7 +985,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
     # 8) Write out plot file
     if enable_plot:
         from matplotlib.backends.backend_pdf import PdfPages
-
+        plot_name = base_name + '_metric_spenh.pdf'
 
         # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
         # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
```
```diff
@@ -1007,7 +1000,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
             feat_sgram = feat_sgram[:, -step:, :]  # decimate, Fx1xB
         feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))
 
-        with PdfPages(
+        with PdfPages(plot_name) as pdf:
             # page1 we always have a mixture and prediction, target optional if truth provided
             tfunc_name = mixdb.target_file(1).truth_settings[0].function  # first target, assumes all have same
             if tfunc_name == 'mapped_snr_f':
```
```diff
@@ -1036,25 +1029,25 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
             tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
             tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
             # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
-            pdf.savefig(
-
-
-
-
-
+            pdf.savefig(plot_e_predict_truth(predict=tg_est_spec,
+                                             predict_wav=target_est_wav,
+                                             truth_f=tg_spec,
+                                             truth_wav=target_fi,
+                                             metric=np.vstack((lerr_tg_frame, phd_frame)).T,
+                                             tp_title='speech estimate'))
 
             # page 4 noise extraction
             n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
             n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
-            pdf.savefig(
-
-
-
-
-
+            pdf.savefig(plot_e_predict_truth(predict=n_est_spec,
+                                             predict_wav=noise_est_wav,
+                                             truth_f=n_spec,
+                                             truth_wav=noise_fi,
+                                             metric=lerr_n_frame,
+                                             tp_title='noise estimate'))
 
             # Plot error waveforms
-            # tg_err_wav =
+            # tg_err_wav = target_fi - target_est_wav
             # tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))
 
         plt.close('all')
```
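All pages for a mixture are written through one PdfPages context, so each pdf.savefig call appends a page to the per-mixture report. The pattern in isolation, using only standard matplotlib API (a standalone sketch, not package code):

```python
import matplotlib
matplotlib.use('SVG')  # headless backend, as the module configures above
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

with PdfPages('metric_spenh_example.pdf') as pdf:
    for title in ('speech estimate', 'noise estimate'):
        fig, ax = plt.subplots()
        ax.plot(np.random.default_rng(0).standard_normal(256))
        ax.set_title(title)
        pdf.savefig(fig)  # one page per savefig call
plt.close('all')
```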
```diff
@@ -1072,14 +1065,14 @@ def main():
 
     verbose = args['--verbose']
     mixids = args['--mixid']
-
-
+    asr_method = args['--asr-method'].lower()
+    asr_model_name = args['--model'].lower()
     truth_est_mode = args['--truth-est-mode']
     enable_plot = args['--plot']
     enable_wav = args['--wav']
     enable_summary = args['--summary']
+    predict_location = args['PLOC']
     truth_location = args['TLOC']
-    whisper_model = args['--whisper-model'].lower()
 
     import glob
     from os.path import basename
```
```diff
@@ -1103,19 +1096,19 @@ def main():
     if not isdir(predict_location):
         print(f'The specified predict location {predict_location} is not a valid subdirectory path, exiting ...')
 
-    #
-
+    # all_predict_files = listdir(predict_location)
+    all_predict_files = glob.glob(predict_location + "/*.h5")
     predict_logfile = glob.glob(predict_location + "/*predict.log")
-
-    if len(
-
-    if len(
+    predict_wav_mode = False
+    if len(all_predict_files) <= 0 and not truth_est_mode:
+        all_predict_files = glob.glob(predict_location + "/*.wav")  # check for wav files
+        if len(all_predict_files) <= 0:
             print(f'Subdirectory {predict_location} has no .h5 or .wav files, exiting ...')
         else:
-            logger.info(f'Found {len(
-
+            logger.info(f'Found {len(all_predict_files)} prediction .wav files.')
+            predict_wav_mode = True
     else:
-        logger.info(f'Found {len(
+        logger.info(f'Found {len(all_predict_files)} prediction .h5 files.')
 
     if len(predict_logfile) == 0:
         logger.info(f'Warning, predict location {predict_location} has no prediction log files.')
```
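The discovery logic prefers .h5 prediction files and only falls back to raw .wav files, setting predict_wav_mode, when no .h5 files exist and truth-est mode is off. The same decision condensed into a helper (a sketch, not package code):

```python
import glob

def find_predict_files(predict_location: str, truth_est_mode: bool):
    """Return (files, predict_wav_mode); .h5 takes precedence over .wav."""
    files = glob.glob(predict_location + '/*.h5')
    if files:
        return files, False
    if truth_est_mode:
        return [], False
    wavs = glob.glob(predict_location + '/*.wav')  # fall back to wav predictions
    return wavs, bool(wavs)
```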
```diff
@@ -1134,51 +1127,51 @@ def main():
         logger.info(f'Only running specified subset of {len(mixids)} mixtures')
 
     enable_asr_warmup = False
-    if
+    if asr_method == 'none':
         fnb = 'metric_spenh_'
-    elif
+    elif asr_method == 'google':
         fnb = 'metric_spenh_ggl_'
-        logger.info(f'
+        logger.info(f'ASR enabled with method {asr_method}')
         enable_asr_warmup = True
-    elif
+    elif asr_method == 'deepgram':
         fnb = 'metric_spenh_dgram_'
-        logger.info(f'
+        logger.info(f'ASR enabled with method {asr_method}')
         enable_asr_warmup = True
-    elif
-        fnb = 'metric_spenh_whspx_' +
-        logger.info(f'
+    elif asr_method == 'aixplain_whisper':
+        fnb = 'metric_spenh_whspx_' + asr_model_name + '_'
+        logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
         enable_asr_warmup = True
-    elif
-        fnb = 'metric_spenh_whspl_' +
-        logger.info(f'
+    elif asr_method == 'whisper':
+        fnb = 'metric_spenh_whspl_' + asr_model_name + '_'
+        logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
         enable_asr_warmup = True
-    elif
-        fnb = 'metric_spenh_whspaaw_' +
-        logger.info(f'
+    elif asr_method == 'aaware_whisper':
+        fnb = 'metric_spenh_whspaaw_' + asr_model_name + '_'
+        logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
         enable_asr_warmup = True
-    elif
-        fnb = 'metric_spenh_fwhsp_' +
-        logger.info(f'
+    elif asr_method == 'faster_whisper':
+        fnb = 'metric_spenh_fwhsp_' + asr_model_name + '_'
+        logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
         enable_asr_warmup = True
     else:
-        logger.error(f'Unrecognized
+        logger.error(f'Unrecognized ASR method: {asr_method}')
         return
 
     if enable_asr_warmup:
         DEFAULT_SPEECH = split(DEFAULT_NOISE)[0] + '/speech_ma01_01.wav'
         audio = read_audio(DEFAULT_SPEECH)
         logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
-        asr_chk = calc_asr(audio, engine=
+        asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
         logger.info(f'Warmup completed, results {asr_chk}')
 
     MP_GLOBAL.mixdb = mixdb
     MP_GLOBAL.predict_location = predict_location
-    MP_GLOBAL.
+    MP_GLOBAL.predict_wav_mode = predict_wav_mode
     MP_GLOBAL.truth_est_mode = truth_est_mode
     MP_GLOBAL.enable_plot = enable_plot
     MP_GLOBAL.enable_wav = enable_wav
-    MP_GLOBAL.
-    MP_GLOBAL.
+    MP_GLOBAL.asr_method = asr_method
+    MP_GLOBAL.asr_model_name = asr_model_name
 
     # Individual mixtures use pandas print, set precision to 2 decimal places
     # pd.set_option('float_format', '{:.2f}'.format)
```
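The elif chain pairs each ASR method with a summary-file prefix and decides whether an engine warmup is needed. A table-driven sketch of the same mapping (a hypothetical refactor for illustration; the package uses the chain above):

```python
# method -> (prefix stem, include model name in prefix, warm up ASR engine)
ASR_METHODS = {
    'none':             ('metric_spenh_',         False, False),
    'google':           ('metric_spenh_ggl_',     False, True),
    'deepgram':         ('metric_spenh_dgram_',   False, True),
    'aixplain_whisper': ('metric_spenh_whspx_',   True,  True),
    'whisper':          ('metric_spenh_whspl_',   True,  True),
    'aaware_whisper':   ('metric_spenh_whspaaw_', True,  True),
    'faster_whisper':   ('metric_spenh_fwhsp_',   True,  True),
}

def summary_prefix(asr_method: str, asr_model_name: str) -> tuple[str, bool]:
    """Return (filename prefix, warmup flag); KeyError means unrecognized method."""
    stem, with_model, warmup = ASR_METHODS[asr_method]
    fnb = stem + asr_model_name + '_' if with_model else stem
    return fnb, warmup
```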
```diff
@@ -1255,7 +1248,7 @@ def main():
         ofname = join(predict_location, fnb + 'summary_truest.txt')
 
     with open(ofname, 'w') as f:
-        print(f'
+        print(f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}', file=f)
         print(f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:',
               file=f)
         print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: "{:.2f}".format(x),
```
```diff
@@ -1318,7 +1311,7 @@ def main():
     label = f'Extraction statistics stats over {num_mix} mixtures:'
     pd.DataFrame([label]).to_csv(csv_name, **header_args)
     all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
-    label = f'
+    label = f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}'
     pd.DataFrame([label]).to_csv(csv_name, **header_args)
 
     if not truth_est_mode:
```