sonusai 0.16.1__py3-none-any.whl → 0.17.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
@@ -1,27 +1,25 @@
  """sonusai calc_metric_spenh

- usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e WER] [-m WMNAME] PLOC TLOC
+ usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-m MODEL] PLOC TLOC

  options:
  -h, --help
  -v, --verbose Be verbose.
- -i MIXID, --mixid MIXID Mixture ID(s) to process, can be range like 0:maxmix+1 [default: *].
+ -i MIXID, --mixid MIXID Mixture ID(s) to process, can be range like 0:maxmix+1. [default: *]
  -t, --truth-est-mode Calculate extraction and metrics using truth (instead of prediction).
  -p, --plot Enable PDF plots file generation per mixture.
  -w, --wav Generate WAV files per mixture.
  -s, --summary Enable summary files generation.
- -e WER, --wer-method WER Word-Error-Rate method: deepgram, google, aixplain_whisper
- or whisper (locally run) [default: none]
- -m WMNAME, --whisper-model Whisper model name used in aixplain_whisper and whisper WER methods.
- [default: tiny]
+ -e ASR, --asr-method ASR ASR method: deepgram, google, aixplain_whisper, whisper, or sensory. [default: none]
+ -m MODEL, --model ASR model name used in some ASR methods. [default: tiny]

- Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data
- in TLOC as truth/label reference. Metric and extraction data files are written into PLOC.
+ Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
+ reference. Metric and extraction data files are written into PLOC.

  PLOC directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
  TLOC directory with SonusAI mixture database of truth/label mixture data

- For whisper WER methods, the possible models used in local processing (WER = whisper) are:
+ For whisper ASR methods, the possible models used in local processing (ASR = whisper) are:
  {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
  but note most are very computationally demanding and can overwhelm/hang a local system.
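The renamed options feed straight into the docopt dictionary read in main(). A hypothetical invocation and the corresponding lookups, shown only as an illustration (the directory names are placeholders, not part of the package):

    # calc_metric_spenh -e whisper -m tiny ./predict ./mixdb
    # main() would then see roughly:
    args = {
        '--asr-method': 'whisper',   # was --wer-method
        '--model': 'tiny',           # was --whisper-model
        'PLOC': './predict',
        'TLOC': './mixdb',
    }
    asr_method = args['--asr-method'].lower()
    asr_model_name = args['--model'].lower()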

@@ -68,6 +66,7 @@ import matplotlib
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
+
  from sonusai.mixture import AudioF
  from sonusai.mixture import AudioT
  from sonusai.mixture import Feature
@@ -93,12 +92,12 @@ matplotlib.use('SVG')
  class MPGlobal:
  mixdb: MixtureDatabase = None
  predict_location: str = None
- predwav_mode: bool = None
+ predict_wav_mode: bool = None
  truth_est_mode: bool = None
  enable_plot: bool = None
  enable_wav: bool = None
- wer_method: str = None
- whisper_model: str = None
+ asr_method: str = None
+ asr_model_name: str = None


  MP_GLOBAL = MPGlobal()
@@ -132,64 +131,62 @@ def snr(clean_speech, processed_speech, sample_rate):
  overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))

  # Global Variables
- winlength = round(30 * sample_rate / 1000) # window length in samples
- skiprate = int(np.floor(winlength / 4)) # window skip in samples
- MIN_SNR = -10 # minimum SNR in dB
- MAX_SNR = 35 # maximum SNR in dB
+ win_length = round(30 * sample_rate / 1000) # window length in samples
+ skip_rate = int(np.floor(win_length / 4)) # window skip in samples
+ min_snr = -10 # minimum SNR in dB
+ max_snr = 35 # maximum SNR in dB

  # For each frame of input speech, calculate the Segmental SNR
- num_frames = int(clean_length / skiprate - (winlength / skiprate)) # number of frames
+ num_frames = int(clean_length / skip_rate - (win_length / skip_rate)) # number of frames
  start = 0 # starting sample
- window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
+ window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

  segmental_snr = np.empty(num_frames)
- EPS = np.spacing(1)
+ eps = np.spacing(1)
  for frame_count in range(num_frames):
  # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
- clean_frame = clean_speech[start:start + winlength]
- processed_frame = processed_speech[start:start + winlength]
+ clean_frame = clean_speech[start:start + win_length]
+ processed_frame = processed_speech[start:start + win_length]
  clean_frame = np.multiply(clean_frame, window)
  processed_frame = np.multiply(processed_frame, window)

  # (2) Compute the Segmental SNR
  signal_energy = np.sum(np.square(clean_frame))
  noise_energy = np.sum(np.square(clean_frame - processed_frame))
- segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + EPS) + EPS)
- segmental_snr[frame_count] = max(segmental_snr[frame_count], MIN_SNR)
- segmental_snr[frame_count] = min(segmental_snr[frame_count], MAX_SNR)
+ segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
+ segmental_snr[frame_count] = np.max(segmental_snr[frame_count], min_snr)
+ segmental_snr[frame_count] = np.min(segmental_snr[frame_count], max_snr)

- start = start + skiprate
+ start = start + skip_rate

  return overall_snr, segmental_snr


- def lpcoeff(speech_frame, model_order):
+ def lp_coefficients(speech_frame, model_order):
  # (1) Compute Autocorrelation Lags
- winlength = np.size(speech_frame)
- R = np.empty(model_order + 1)
- E = np.empty(model_order + 1)
+ win_length = np.size(speech_frame)
+ autocorrelation = np.empty(model_order + 1)
+ e = np.empty(model_order + 1)
  for k in range(model_order + 1):
- R[k] = np.dot(speech_frame[0:winlength - k], speech_frame[k: winlength])
+ autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])

  # (2) Levinson-Durbin
  a = np.ones(model_order)
  a_past = np.empty(model_order)
- rcoeff = np.empty(model_order)
- E[0] = R[0]
+ ref_coefficients = np.empty(model_order)
+ e[0] = autocorrelation[0]
  for i in range(model_order):
  a_past[0: i] = a[0: i]
- sum_term = np.dot(a_past[0: i], R[i:0:-1])
- rcoeff[i] = (R[i + 1] - sum_term) / E[i]
- a[i] = rcoeff[i]
+ sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
+ ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
+ a[i] = ref_coefficients[i]
  if i == 0:
- a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], rcoeff[i])
+ a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
  else:
- a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], rcoeff[i])
- E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i]
- acorr = R
- refcoeff = rcoeff
- lpparams = np.concatenate((np.array([1]), -a))
- return acorr, refcoeff, lpparams
+ a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
+ e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
+ lp_params = np.concatenate((np.array([1]), -a))
+ return autocorrelation, ref_coefficients, lp_params


  def llr(clean_speech, processed_speech, sample_rate):
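For orientation, the per-frame segmental SNR that snr() computes above can be restated compactly as follows. This is only a sketch of the same computation, using np.clip for the [-10, 35] dB clamp; it is not the package implementation:

    import numpy as np

    def segmental_snr_sketch(clean, processed, sample_rate=16000):
        """Hann-windowed per-frame SNR, clamped to [-10, 35] dB."""
        win_length = round(30 * sample_rate / 1000)        # 30 ms analysis window
        skip_rate = win_length // 4                        # 75% frame overlap
        n = np.arange(1, win_length + 1)
        window = 0.5 * (1 - np.cos(2 * np.pi * n / (win_length + 1)))
        eps = np.spacing(1)
        num_frames = int(len(clean) / skip_rate - win_length / skip_rate)
        seg_snr = np.empty(num_frames)
        start = 0
        for i in range(num_frames):
            c = clean[start:start + win_length] * window
            p = processed[start:start + win_length] * window
            signal_energy = np.sum(c ** 2)
            noise_energy = np.sum((c - p) ** 2)
            seg_snr[i] = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
            start += skip_rate
        return np.clip(seg_snr, -10, 35)                   # min/max SNR bounds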
@@ -199,38 +196,38 @@ def llr(clean_speech, processed_speech, sample_rate):
  clean_length = np.size(clean_speech)
  processed_length = np.size(processed_speech)
  if clean_length != processed_length:
- raise ValueError('Both Speech Files must be same length.')
+ raise ValueError('Both speech files must be same length.')

  # Global Variables
- winlength = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples
- skiprate = (np.floor(winlength / 4)).astype(int) # window skip in samples
+ win_length = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples
+ skip_rate = (np.floor(win_length / 4)).astype(int) # window skip in samples
  if sample_rate < 10000:
- P = 10 # LPC Analysis Order
+ p = 10 # LPC Analysis Order
  else:
- P = 16 # this could vary depending on sampling frequency.
+ p = 16 # this could vary depending on sampling frequency.

  # For each frame of input speech, calculate the Log Likelihood Ratio
- num_frames = int((clean_length - winlength) / skiprate) # number of frames
+ num_frames = int((clean_length - win_length) / skip_rate) # number of frames
  start = 0 # starting sample
- window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
+ window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

  distortion = np.empty(num_frames)
  for frame_count in range(num_frames):
  # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
- clean_frame = clean_speech[start: start + winlength]
- processed_frame = processed_speech[start: start + winlength]
+ clean_frame = clean_speech[start: start + win_length]
+ processed_frame = processed_speech[start: start + win_length]
  clean_frame = np.multiply(clean_frame, window)
  processed_frame = np.multiply(processed_frame, window)

  # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
- R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P)
- R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P)
+ r_clean, ref_clean, a_clean = lp_coefficients(clean_frame, p)
+ r_processed, ref_processed, a_processed = lp_coefficients(processed_frame, p)

  # (3) Compute the LLR measure
- numerator = np.dot(np.matmul(A_processed, toeplitz(R_clean)), A_processed)
- denominator = np.dot(np.matmul(A_clean, toeplitz(R_clean)), A_clean)
+ numerator = np.dot(np.matmul(a_processed, toeplitz(r_clean)), a_processed)
+ denominator = np.dot(np.matmul(a_clean, toeplitz(r_clean)), a_clean)
  distortion[frame_count] = np.log(numerator / denominator)
- start = start + skiprate
+ start = start + skip_rate
  return distortion
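Per frame, the LLR above is the log ratio of two LPC prediction-error energies, both measured against the clean frame's autocorrelation matrix. A minimal restatement (a sketch only; it reuses the lp_coefficients helper renamed in the hunk above and scipy's toeplitz):

    import numpy as np
    from scipy.linalg import toeplitz

    def llr_frame_sketch(clean_frame, processed_frame, lpc_order=16):
        # LP parameters for each windowed frame: [1, -a_1, ..., -a_p]
        r_clean, _, a_clean = lp_coefficients(clean_frame, lpc_order)
        _, _, a_processed = lp_coefficients(processed_frame, lpc_order)
        # Residual energy of each model against the clean-frame autocorrelation
        numerator = a_processed @ toeplitz(r_clean) @ a_processed
        denominator = a_clean @ toeplitz(r_clean) @ a_clean
        return np.log(numerator / denominator)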


@@ -244,16 +241,15 @@ def wss(clean_speech, processed_speech, sample_rate):
  raise ValueError('Files must have same length.')

  # Global variables
- winlength = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples
- skiprate = (np.floor(np.divide(winlength, 4))).astype(int) # window skip in samples
+ win_length = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples
+ skip_rate = (np.floor(np.divide(win_length, 4))).astype(int) # window skip in samples
  max_freq = (np.divide(sample_rate, 2)).astype(int) # maximum bandwidth
  num_crit = 25 # number of critical bands

- USE_FFT_SPECTRUM = 1 # defaults to 10th order LP spectrum
- n_fft = (np.power(2, np.ceil(np.log2(2 * winlength)))).astype(int)
- n_fftby2 = (np.multiply(0.5, n_fft)).astype(int) # FFT size/2
- Kmax = 20.0 # value suggested by Klatt, pg 1280
- Klocmax = 1.0 # value suggested by Klatt, pg 1280
+ n_fft = (np.power(2, np.ceil(np.log2(2 * win_length)))).astype(int)
+ n_fft_by_2 = (np.multiply(0.5, n_fft)).astype(int) # FFT size/2
+ k_max = 20.0 # value suggested by Klatt, pg 1280
+ k_loc_max = 1.0 # value suggested by Klatt, pg 1280

  # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
  cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
@@ -268,39 +264,38 @@ def wss(clean_speech, processed_speech, sample_rate):
  bw_min = bandwidth[0] # minimum critical bandwidth

  # Set up the critical band filters.
- # Note here that Gaussianly shaped filters are used.
+ # Note here that Gaussian-ly shaped filters are used.
  # Also, the sum of the filter weights are equivalent for each critical band filter.
  # Filter less than -30 dB and set to zero.
  min_factor = np.exp(-30.0 / (2.0 * 2.303)) # -30 dB point of filter
- crit_filter = np.empty((num_crit, n_fftby2))
+ crit_filter = np.empty((num_crit, n_fft_by_2))
  for i in range(num_crit):
- f0 = (cent_freq[i] / max_freq) * n_fftby2
- bw = (bandwidth[i] / max_freq) * n_fftby2
+ f0 = (cent_freq[i] / max_freq) * n_fft_by_2
+ bw = (bandwidth[i] / max_freq) * n_fft_by_2
  norm_factor = np.log(bw_min) - np.log(bandwidth[i])
- j = np.arange(n_fftby2)
+ j = np.arange(n_fft_by_2)
  crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
  cond = np.greater(crit_filter[i, :], min_factor)
  crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
  # For each frame of input speech, calculate the Weighted Spectral Slope Measure
- num_frames = int(clean_length / skiprate - (winlength / skiprate)) # number of frames
+ num_frames = int(clean_length / skip_rate - (win_length / skip_rate)) # number of frames
  start = 0 # starting sample
- window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
+ window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))

  distortion = np.empty(num_frames)
  for frame_count in range(num_frames):
  # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
- clean_frame = clean_speech[start: start + winlength] / 32768
- processed_frame = processed_speech[start: start + winlength] / 32768
+ clean_frame = clean_speech[start: start + win_length] / 32768
+ processed_frame = processed_speech[start: start + win_length] / 32768
  clean_frame = np.multiply(clean_frame, window)
  processed_frame = np.multiply(processed_frame, window)
  # (2) Compute the Power Spectrum of Clean and Processed
- # if USE_FFT_SPECTRUM:
  clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
  processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))

  # (3) Compute Filterbank Output Energies (in dB scale)
- clean_energy = np.matmul(crit_filter, clean_spec[0:n_fftby2])
- processed_energy = np.matmul(crit_filter, processed_spec[0:n_fftby2])
+ clean_energy = np.matmul(crit_filter, clean_spec[0:n_fft_by_2])
+ processed_energy = np.matmul(crit_filter, processed_spec[0:n_fft_by_2])

  clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
  processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))
@@ -340,39 +335,39 @@ def wss(clean_speech, processed_speech, sample_rate):
  processed_loc_peak[i] = processed_energy[n + 1]

  # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
- dBMax_clean = np.max(clean_energy)
- dBMax_processed = np.max(processed_energy)
+ db_max_clean = np.max(clean_energy)
+ db_max_processed = np.max(processed_energy)
  '''
  The weights are calculated by averaging individual weighting factors from the clean and processed frame.
- These weights W_clean and W_processed should range from 0 to 1 and place more emphasis on spectral peaks
+ These weights w_clean and w_processed should range from 0 to 1 and place more emphasis on spectral peaks
  and less emphasis on slope differences in spectral valleys.
  This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
  '''
- Wmax_clean = np.divide(Kmax, Kmax + dBMax_clean - clean_energy[0: num_crit - 1])
- Wlocmax_clean = np.divide(Klocmax, Klocmax + clean_loc_peak - clean_energy[0: num_crit - 1])
- W_clean = np.multiply(Wmax_clean, Wlocmax_clean)
+ w_max_clean = np.divide(k_max, k_max + db_max_clean - clean_energy[0: num_crit - 1])
+ w_loc_max_clean = np.divide(k_loc_max, k_loc_max + clean_loc_peak - clean_energy[0: num_crit - 1])
+ w_clean = np.multiply(w_max_clean, w_loc_max_clean)

- Wmax_processed = np.divide(Kmax, Kmax + dBMax_processed - processed_energy[0: num_crit - 1])
- Wlocmax_processed = np.divide(Klocmax, Klocmax + processed_loc_peak - processed_energy[0: num_crit - 1])
- W_processed = np.multiply(Wmax_processed, Wlocmax_processed)
+ w_max_processed = np.divide(k_max, k_max + db_max_processed - processed_energy[0: num_crit - 1])
+ w_loc_max_processed = np.divide(k_loc_max, k_loc_max + processed_loc_peak - processed_energy[0: num_crit - 1])
+ w_processed = np.multiply(w_max_processed, w_loc_max_processed)

- W = np.divide(np.add(W_clean, W_processed), 2.0)
+ w = np.divide(np.add(w_clean, w_processed), 2.0)
  slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1]
- distortion[frame_count] = np.dot(W, np.square(slope_diff)) / np.sum(W)
- # this normalization is not part of Klatt's paper, but helps to normalize the measure.
+ distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
+ # This normalization is not part of Klatt's paper, but helps to normalize the measure.
  # Here we scale the measure by the sum of the weights.
- start = start + skiprate
+ start = start + skip_rate
  return distortion


  def calc_speech_metrics(hypothesis: np.ndarray,
  reference: np.ndarray) -> tuple[float, int, int, int, float]:
  """
- Calculate speech metrics pesq_mos, CSIG, CBAK, COVL, segSNR. These are all related and thus included
+ Calculate speech metrics pesq_mos, c_sig, c_bak, c_ovl, seg_snr. These are all related and thus included
  in one function. Reference: matlab script "compute_metrics.m".

  Usage:
- pesq, csig, cbak, covl, ssnr = compute_metrics(hypothesis, reference, Fs, path)
+ pesq, c_sig, c_bak, c_ovl, ssnr = compute_metrics(hypothesis, reference, fs, path)
  reference: clean audio as array
  hypothesis: enhanced audio as array
  Audio must have sampling rate = 16000 Hz.
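The per-frame WSS weighting renamed in the hunk above follows Klatt's scheme: each band's weight combines a global-peak term and a local-peak term, the clean and processed weights are averaged, and the weighted squared slope difference is normalized by the sum of the weights. A compact restatement (a sketch; the array arguments are assumed to cover the first num_crit - 1 critical bands):

    import numpy as np

    def wss_frame_weight_sketch(clean_energy, processed_energy,
                                clean_loc_peak, processed_loc_peak,
                                clean_slope, processed_slope, num_crit=25):
        k_max, k_loc_max = 20.0, 1.0  # constants suggested by Klatt
        e_c = clean_energy[0:num_crit - 1]
        e_p = processed_energy[0:num_crit - 1]
        w_clean = (k_max / (k_max + np.max(clean_energy) - e_c)) * (
            k_loc_max / (k_loc_max + clean_loc_peak - e_c))
        w_processed = (k_max / (k_max + np.max(processed_energy) - e_p)) * (
            k_loc_max / (k_loc_max + processed_loc_peak - e_p))
        w = (w_clean + w_processed) / 2.0
        slope_diff = (clean_slope - processed_slope)[0:num_crit - 1]
        # normalize by the sum of the weights, as noted in the code comments
        return np.dot(w, np.square(slope_diff)) / np.sum(w)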
@@ -383,41 +378,39 @@ def calc_speech_metrics(hypothesis: np.ndarray,
  """
  from sonusai.metrics import calc_pesq

- Fs = 16000
+ fs = 16000

  # compute the WSS measure
- wss_dist_vec = wss(reference, hypothesis, Fs)
+ wss_dist_vec = wss(reference, hypothesis, fs)
  wss_dist_vec = np.sort(wss_dist_vec)
  alpha = 0.95 # value from CMGAN ref implementation
  wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])

  # compute the LLR measure
- llr_dist = llr(reference, hypothesis, Fs)
+ llr_dist = llr(reference, hypothesis, fs)
  ll_rs = np.sort(llr_dist)
  llr_len = round(np.size(llr_dist) * alpha)
  llr_mean = np.mean(ll_rs[0: llr_len])

  # compute the SNRseg
- snr_dist, segsnr_dist = snr(reference, hypothesis, Fs)
- snr_mean = snr_dist
- segSNR = np.mean(segsnr_dist)
+ snr_dist, segsnr_dist = snr(reference, hypothesis, fs)
+ seg_snr = np.mean(segsnr_dist)

  # compute the pesq (use Sonusai wrapper, only fs=16k, mode=wb support)
  pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
- # pesq_mos = pesq(sampling_rate1, data1, data2, 'wb')

  # now compute the composite measures
- CSIG = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
- CSIG = max(1, CSIG)
- CSIG = min(5, CSIG) # limit values to [1, 5]
- CBAK = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * segSNR
- CBAK = max(1, CBAK)
- CBAK = min(5, CBAK) # limit values to [1, 5]
- COVL = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
- COVL = max(1, COVL)
- COVL = min(5, COVL) # limit values to [1, 5]
+ c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
+ c_sig = max(1, c_sig)
+ c_sig = min(5, c_sig) # limit values to [1, 5]
+ c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
+ c_bak = max(1, c_bak)
+ c_bak = min(5, c_bak) # limit values to [1, 5]
+ c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
+ c_ovl = max(1, c_ovl)
+ c_ovl = min(5, c_ovl) # limit values to [1, 5]

- return pesq_mos, CSIG, CBAK, COVL, segSNR
+ return pesq_mos, c_sig, c_bak, c_ovl, seg_snr


  def mean_square_error(hypothesis: np.ndarray,
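The composite measures above are fixed linear combinations of PESQ, trimmed LLR, trimmed WSS, and mean segmental SNR, each clipped to the [1, 5] MOS range. Restated as a standalone sketch (not the package code):

    import numpy as np

    def composite_measures_sketch(pesq_mos, llr_mean, wss_dist, seg_snr):
        """CSIG/CBAK/COVL composites, each clipped to the [1, 5] MOS range."""
        c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
        c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
        c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
        return tuple(float(np.clip(v, 1, 5)) for v in (c_sig, c_bak, c_ovl))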
@@ -564,10 +557,8 @@ def plot_mixpred(mixture: AudioT,
  ax[p].plot(x_axis, mixture, label='Mixture', color='mistyrose')
  ax[0].set_ylabel('magnitude', color='tab:blue')
  ax[p].set_xlim(x_axis[0], x_axis[-1])
- # ax[p].set_ylim([-1.025, 1.025])
  if target is not None: # Plot target time-domain waveform on top of mixture
  ax[0].plot(x_axis, target, label='Target', color='tab:blue')
- # ax[0].tick_params(axis='y', labelcolor=color)
  ax[p].set_title('Waveform')

  # Plot the mixture spectrogram
@@ -589,10 +580,10 @@
  return fig


- def plot_pdb_predtruth(predict: np.ndarray,
- truth_f: Optional[np.ndarray] = None,
- metric: Optional[np.ndarray] = None,
- tp_title: str = '') -> plt.Figure:
+ def plot_pdb_predict_truth(predict: np.ndarray,
+ truth_f: Optional[np.ndarray] = None,
+ metric: Optional[np.ndarray] = None,
+ tp_title: str = '') -> plt.Figure:
  """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
  num_plots = 2
  if truth_f is not None:
@@ -636,16 +627,15 @@ def plot_pdb_predtruth(predict: np.ndarray,
  ax[p].set_title('SNR and SNR mse (mean over freq. db)')
  else:
  ax[p].set_title('SNR (mean over freq. db)')
- # ax[0].tick_params(axis='y', labelcolor=color)
  return fig


- def plot_epredtruth(predict: np.ndarray,
- predict_wav: np.ndarray,
- truth_f: Optional[np.ndarray] = None,
- truth_wav: Optional[np.ndarray] = None,
- metric: Optional[np.ndarray] = None,
- tp_title: str = '') -> plt.Figure:
+ def plot_e_predict_truth(predict: np.ndarray,
+ predict_wav: np.ndarray,
+ truth_f: Optional[np.ndarray] = None,
+ truth_wav: Optional[np.ndarray] = None,
+ metric: Optional[np.ndarray] = None,
+ tp_title: str = '') -> plt.Figure:
  """Plot predict spectrogram and waveform and optionally truth and a metric)"""
  num_plots = 2
  if truth_f is not None:
@@ -666,7 +656,7 @@ def plot_epredtruth(predict: np.ndarray,
  ax[p].imshow(truth_f.transpose(), im.cmap, aspect='auto', interpolation='nearest', origin='lower')
  ax[p].set_title('Truth')

- # Plot the predict wav, and optionally truth avg and metric lines
+ # Plot predict wav, and optionally truth avg and metric lines
  p += 1
  x_axis = np.arange(len(predict_wav), dtype=np.float32) # / SAMPLE_RATE
  ax[p].plot(x_axis, predict_wav, color='black', linestyle='dashed', label='Speech Estimate')
@@ -732,12 +722,12 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:

  mixdb = MP_GLOBAL.mixdb
  predict_location = MP_GLOBAL.predict_location
- predwav_mode = MP_GLOBAL.predwav_mode
+ predict_wav_mode = MP_GLOBAL.predict_wav_mode
  truth_est_mode = MP_GLOBAL.truth_est_mode
  enable_plot = MP_GLOBAL.enable_plot
  enable_wav = MP_GLOBAL.enable_wav
- wer_method = MP_GLOBAL.wer_method
- whisper_model = MP_GLOBAL.whisper_model
+ asr_method = MP_GLOBAL.asr_method
+ asr_model_name = MP_GLOBAL.asr_model_name

  # 1) Read predict data, var predict with shape [BatchSize,Classes] or [BatchSize,Tsteps,Classes]
  output_name = join(predict_location, mixdb.mixture(mixid).name)
@@ -749,7 +739,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  base_name = splitext(output_name)[0] + '_truest'
  else:
  base_name, ext = splitext(output_name) # base_name used later
- if not predwav_mode:
+ if not predict_wav_mode:
  try:
  with h5py.File(output_name, 'r') as f:
  predict = np.array(f['predict'])
@@ -761,8 +751,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
  else:
  base_name, ext = splitext(output_name)
- prfname = join(base_name + '.wav')
- audio = read_audio(prfname)
+ predict_name = join(base_name + '.wav')
+ audio = read_audio(predict_name)
  predict = forward_transform(audio, mixdb.ft_config)
  if mixdb.feature[0:1] == 'h':
  predict = power_compress(predict)
@@ -773,8 +763,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
  target = tmp[0]
  mixture = mixdb.mixture_mixture(mixid) # note: gives full reverberated/distorted target, but no specaugment
- # noise_wodist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
- # noise_wodist_f = mixdb.mixture_noise_f(mixid, noise=noise_wodist)
+ # noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
+ # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
  noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
  # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
  segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise) # note: uses pre-IR, pre-specaug audio
@@ -784,9 +774,9 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  segsnr_f[segsnr_f == inf] = 7.944e8 # 99db
  segsnr_f[segsnr_f == -inf] = 1.258e-10 # -99db
  # need to use inv-tf to match #samples & latency shift properties of predict inv tf
- targetfi = inverse_transform(target_f, mixdb.it_config)
- noisefi = inverse_transform(noise_f, mixdb.it_config)
- # mixturefi = mixdb.inverse_transform(mixture_f)
+ target_fi = inverse_transform(target_f, mixdb.it_config)
+ noise_fi = inverse_transform(noise_f, mixdb.it_config)
+ # mixture_fi = mixdb.inverse_transform(mixture_f)

  # gen feature, truth - note feature only used for plots
  # TBD parse truth_f for different formats and also multi-truth
@@ -798,17 +788,17 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:

  if not truth_est_mode:
  if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
- trimf = target_f.shape[0] - predict.shape[0]
- logger.debug(f'Warning: prediction frames less than mixture, trimming {trimf} frames from all truth.')
- target_f = target_f[0:-trimf, :]
- targetfi, _ = inverse_transform(target_f, mixdb.it_config)
- trimt = target.shape[0] - targetfi.shape[0]
- target = target[0:-trimt]
- noise_f = noise_f[0:-trimf, :]
- noise = noise[0:-trimt]
- mixture_f = mixture_f[0:-trimf, :]
- mixture = mixture[0:-trimt]
- truth_f = truth_f[0:-trimf, :]
+ trim_f = target_f.shape[0] - predict.shape[0]
+ logger.debug(f'Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.')
+ target_f = target_f[0:-trim_f, :]
+ target_fi, _ = inverse_transform(target_f, mixdb.it_config)
+ trim_t = target.shape[0] - target_fi.shape[0]
+ target = target[0:-trim_t]
+ noise_f = noise_f[0:-trim_f, :]
+ noise = noise[0:-trim_t]
+ mixture_f = mixture_f[0:-trim_f, :]
+ mixture = mixture[0:-trim_t]
+ truth_f = truth_f[0:-trim_f, :]
  elif predict.shape[0] > target_f.shape[0]:
  raise SonusAIError(
  f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
@@ -848,14 +838,14 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  phd, phd_bin, phd_frame = phase_distance(hypothesis=predict_complex, reference=truth_f_complex)

  # Noise td logerr
- # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noisefi, noise_truth_est_audio)
+ # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)

  # # SA-SDR (time-domain source-aggragated SDR)
- ytrue = np.concatenate((targetfi[:, np.newaxis], noisefi[:, np.newaxis]), axis=1)
+ ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
  ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
  # # note: w/o scale is more pessimistic number
  # sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
- target_stoi = stoi(targetfi, target_est_wav, 16000, extended=False)
+ target_stoi = stoi(target_fi, target_est_wav, 16000, extended=False)

  wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
  # logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
@@ -865,7 +855,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  # Speech intelligibility measure - PESQ
  if int(mixdb.mixture(mixid).snr) > -99:
  # len = target_est_wav.shape[0]
- pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, targetfi)
+ pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, target_fi)
  pesq_mixture, csig_mx, cbak_mx, covl_mx, sgsnr_mx = calc_speech_metrics(mixture, target)
  # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
  # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
@@ -884,23 +874,26 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  covl_mx = 0
  covl_tg = 0

- # Calc WER
- asr_tt = ''
- asr_mx = ''
- asr_tge = ''
- if wer_method == 'none' or mixdb.mixture(mixid).snr == -99: # noise only, ignore/reset target asr
+ # Calc ASR
+ asr_tt = None
+ asr_mx = None
+ asr_tge = None
+ if asr_method == 'none' or mixdb.mixture(mixid).snr == -99: # noise only, ignore/reset target asr
  wer_mx = float('nan')
  wer_tge = float('nan')
  wer_pi = float('nan')
  else:
- if MP_GLOBAL.mixdb.asr_manifests:
- asr_tt = MP_GLOBAL.mixdb.mixture_asr_data(mixid)[0] # ignore mixup
- else:
- asr_tt = calc_asr(target, engine=wer_method, whisper_model_name=whisper_model).text # target truth
+ asr_tt = MP_GLOBAL.mixdb.get_speech_metadata(mixid, 'text')[0] # ignore mixup
+ if asr_tt is None:
+ asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text # target truth
+ # if MP_GLOBAL.mixdb.asr_manifests:
+ # asr_tt = MP_GLOBAL.mixdb.mixture_asr_data(mixid)[0] # ignore mixup
+ # else:
+ # asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text # target truth

  if asr_tt:
- asr_mx = calc_asr(mixture, engine=wer_method, whisper_model=whisper_model).text
- asr_tge = calc_asr(target_est_wav, engine=wer_method, whisper_model=whisper_model).text
+ asr_mx = calc_asr(mixture, engine=asr_method, whisper_model_name=asr_model_name).text
+ asr_tge = calc_asr(target_est_wav, engine=asr_method, whisper_model_name=asr_model_name).text

  wer_mx = calc_wer(asr_mx, asr_tt).wer * 100 # mixture wer
  wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate wer
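In the new flow above, the reference transcript is taken from the mixture's speech metadata when available, with ASR on the clean target as the fallback, and both the noisy mixture and the speech estimate are then scored against it. Condensed into one sketch (calc_asr, calc_wer, and get_speech_metadata are used exactly as in the hunk and their imports live elsewhere in the module; the wrapper function itself is illustrative):

    def asr_wer_sketch(mixdb, mixid, target, mixture, target_est_wav,
                       asr_method='whisper', asr_model_name='tiny'):
        # Prefer reference text stored with the mixture's speech metadata
        asr_tt = mixdb.get_speech_metadata(mixid, 'text')[0]
        if asr_tt is None:
            # Fall back to running ASR on the clean target audio
            asr_tt = calc_asr(target, engine=asr_method,
                              whisper_model_name=asr_model_name).text
        # Score the noisy mixture and the speech estimate against the reference
        asr_mx = calc_asr(mixture, engine=asr_method,
                          whisper_model_name=asr_model_name).text
        asr_tge = calc_asr(target_est_wav, engine=asr_method,
                           whisper_model_name=asr_model_name).text
        wer_mx = calc_wer(asr_mx, asr_tt).wer * 100    # mixture WER, percent
        wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target-estimate WER, percent
        return wer_mx, wer_tge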
@@ -962,10 +955,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  print('', file=f)
  print(f'Target path: {mixdb.target_file(ti).name}', file=f)
  print(f'Noise path: {mixdb.noise_file(ni).name}', file=f)
- if wer_method != 'none':
- print(f'WER method: {wer_method} and whisper model (if used): {whisper_model}', file=f)
+ if asr_method != 'none':
+ print(f'ASR method: {asr_method} and whisper model (if used): {asr_model_name}', file=f)
  if mixdb.asr_manifests:
- print(f'ASR truth from manifest: {asr_tt}', file=f)
+ print(f'ASR truth from metadata: {asr_tt}', file=f)
  else:
  print(f'ASR truth from wer method: {asr_tt}', file=f)
  print(f'ASR result for mixture: {asr_mx}', file=f)
@@ -977,7 +970,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  if enable_wav:
  write_wav(name=base_name + '_mixture.wav', audio=float_to_int16(mixture))
  write_wav(name=base_name + '_target.wav', audio=float_to_int16(target))
- # write_wav(name=base_name + '_targetfi.wav', audio=float_to_int16(targetfi))
+ # write_wav(name=base_name + '_target_fi.wav', audio=float_to_int16(target_fi))
  write_wav(name=base_name + '_noise.wav', audio=float_to_int16(noise))
  write_wav(name=base_name + '_target_est.wav', audio=float_to_int16(target_est_wav))
  write_wav(name=base_name + '_noise_est.wav', audio=float_to_int16(noise_est_wav))
@@ -992,7 +985,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  # 8) Write out plot file
  if enable_plot:
  from matplotlib.backends.backend_pdf import PdfPages
- plot_fname = base_name + '_metric_spenh.pdf'
+ plot_name = base_name + '_metric_spenh.pdf'

  # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
  # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
@@ -1007,7 +1000,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  feat_sgram = feat_sgram[:, -step:, :] # decimate, Fx1xB
  feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))

- with PdfPages(plot_fname) as pdf:
+ with PdfPages(plot_name) as pdf:
  # page1 we always have a mixture and prediction, target optional if truth provided
  tfunc_name = mixdb.target_file(1).truth_settings[0].function # first target, assumes all have same
  if tfunc_name == 'mapped_snr_f':
@@ -1036,25 +1029,25 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
  tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
  tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
  # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
- pdf.savefig(plot_epredtruth(predict=tg_est_spec,
- predict_wav=target_est_wav,
- truth_f=tg_spec,
- truth_wav=targetfi,
- metric=np.vstack((lerr_tg_frame, phd_frame)).T,
- tp_title='speech estimate'))
+ pdf.savefig(plot_e_predict_truth(predict=tg_est_spec,
+ predict_wav=target_est_wav,
+ truth_f=tg_spec,
+ truth_wav=target_fi,
+ metric=np.vstack((lerr_tg_frame, phd_frame)).T,
+ tp_title='speech estimate'))

  # page 4 noise extraction
  n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
  n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
- pdf.savefig(plot_epredtruth(predict=n_est_spec,
- predict_wav=noise_est_wav,
- truth_f=n_spec,
- truth_wav=noisefi,
- metric=lerr_n_frame,
- tp_title='noise estimate'))
+ pdf.savefig(plot_e_predict_truth(predict=n_est_spec,
+ predict_wav=noise_est_wav,
+ truth_f=n_spec,
+ truth_wav=noise_fi,
+ metric=lerr_n_frame,
+ tp_title='noise estimate'))

  # Plot error waveforms
- # tg_err_wav = targetfi - target_est_wav
+ # tg_err_wav = target_fi - target_est_wav
  # tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))

  plt.close('all')
@@ -1072,14 +1065,14 @@ def main():

  verbose = args['--verbose']
  mixids = args['--mixid']
- predict_location = args['PLOC']
- wer_method = args['--wer-method'].lower()
+ asr_method = args['--asr-method'].lower()
+ asr_model_name = args['--model'].lower()
  truth_est_mode = args['--truth-est-mode']
  enable_plot = args['--plot']
  enable_wav = args['--wav']
  enable_summary = args['--summary']
+ predict_location = args['PLOC']
  truth_location = args['TLOC']
- whisper_model = args['--whisper-model'].lower()

  import glob
  from os.path import basename
@@ -1103,19 +1096,19 @@ def main():
  if not isdir(predict_location):
  print(f'The specified predict location {predict_location} is not a valid subdirectory path, exiting ...')

- # allpfiles = listdir(predict_location)
- allpfiles = glob.glob(predict_location + "/*.h5")
+ # all_predict_files = listdir(predict_location)
+ all_predict_files = glob.glob(predict_location + "/*.h5")
  predict_logfile = glob.glob(predict_location + "/*predict.log")
- predwav_mode = False
- if len(allpfiles) <= 0 and not truth_est_mode:
- allpfiles = glob.glob(predict_location + "/*.wav") # check for wav files
- if len(allpfiles) <= 0:
+ predict_wav_mode = False
+ if len(all_predict_files) <= 0 and not truth_est_mode:
+ all_predict_files = glob.glob(predict_location + "/*.wav") # check for wav files
+ if len(all_predict_files) <= 0:
  print(f'Subdirectory {predict_location} has no .h5 or .wav files, exiting ...')
  else:
- logger.info(f'Found {len(allpfiles)} prediction .wav files.')
- predwav_mode = True
+ logger.info(f'Found {len(all_predict_files)} prediction .wav files.')
+ predict_wav_mode = True
  else:
- logger.info(f'Found {len(allpfiles)} prediction .h5 files.')
+ logger.info(f'Found {len(all_predict_files)} prediction .h5 files.')

  if len(predict_logfile) == 0:
  logger.info(f'Warning, predict location {predict_location} has no prediction log files.')
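The discovery logic above looks for .h5 prediction files first and falls back to .wav files, which switches the tool into predict-wav mode. A condensed sketch of that decision (logging and exit handling omitted; the helper name is made up):

    import glob

    def find_predict_files_sketch(predict_location: str, truth_est_mode: bool):
        files = glob.glob(predict_location + "/*.h5")
        predict_wav_mode = False
        if not files and not truth_est_mode:
            files = glob.glob(predict_location + "/*.wav")  # fall back to wav predictions
            predict_wav_mode = bool(files)
        return files, predict_wav_mode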
@@ -1134,51 +1127,51 @@ def main():
  logger.info(f'Only running specified subset of {len(mixids)} mixtures')

  enable_asr_warmup = False
- if wer_method == 'none':
+ if asr_method == 'none':
  fnb = 'metric_spenh_'
- elif wer_method == 'google':
+ elif asr_method == 'google':
  fnb = 'metric_spenh_ggl_'
- logger.info(f'WER enabled with method {wer_method}')
+ logger.info(f'ASR enabled with method {asr_method}')
  enable_asr_warmup = True
- elif wer_method == 'deepgram':
+ elif asr_method == 'deepgram':
  fnb = 'metric_spenh_dgram_'
- logger.info(f'WER enabled with method {wer_method}')
+ logger.info(f'ASR enabled with method {asr_method}')
  enable_asr_warmup = True
- elif wer_method == 'aixplain_whisper':
- fnb = 'metric_spenh_whspx_' + whisper_model + '_'
- logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
+ elif asr_method == 'aixplain_whisper':
+ fnb = 'metric_spenh_whspx_' + asr_model_name + '_'
+ logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
  enable_asr_warmup = True
- elif wer_method == 'whisper':
- fnb = 'metric_spenh_whspl_' + whisper_model + '_'
- logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
+ elif asr_method == 'whisper':
+ fnb = 'metric_spenh_whspl_' + asr_model_name + '_'
+ logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
  enable_asr_warmup = True
- elif wer_method == 'aaware_whisper':
- fnb = 'metric_spenh_whspaaw_' + whisper_model + '_'
- logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
+ elif asr_method == 'aaware_whisper':
+ fnb = 'metric_spenh_whspaaw_' + asr_model_name + '_'
+ logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
  enable_asr_warmup = True
- elif wer_method == 'fastwhisper':
- fnb = 'metric_spenh_fwhsp_' + whisper_model + '_'
- logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
+ elif asr_method == 'faster_whisper':
+ fnb = 'metric_spenh_fwhsp_' + asr_model_name + '_'
+ logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
  enable_asr_warmup = True
  else:
- logger.error(f'Unrecognized WER method: {wer_method}')
+ logger.error(f'Unrecognized ASR method: {asr_method}')
  return

  if enable_asr_warmup:
  DEFAULT_SPEECH = split(DEFAULT_NOISE)[0] + '/speech_ma01_01.wav'
  audio = read_audio(DEFAULT_SPEECH)
  logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
- asr_chk = calc_asr(audio, engine=wer_method, whisper_model_name=whisper_model)
+ asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
  logger.info(f'Warmup completed, results {asr_chk}')

  MP_GLOBAL.mixdb = mixdb
  MP_GLOBAL.predict_location = predict_location
- MP_GLOBAL.predwav_mode = predwav_mode
+ MP_GLOBAL.predict_wav_mode = predict_wav_mode
  MP_GLOBAL.truth_est_mode = truth_est_mode
  MP_GLOBAL.enable_plot = enable_plot
  MP_GLOBAL.enable_wav = enable_wav
- MP_GLOBAL.wer_method = wer_method
- MP_GLOBAL.whisper_model = whisper_model
+ MP_GLOBAL.asr_method = asr_method
+ MP_GLOBAL.asr_model_name = asr_model_name

  # Individual mixtures use pandas print, set precision to 2 decimal places
  # pd.set_option('float_format', '{:.2f}'.format)
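The dispatch above maps each ASR method to an output file-name prefix, with the whisper-family methods also folding the model name in. The same mapping as a small table, shown only as an equivalent view of the logic, not as a change to the package:

    # File-name prefix per ASR method, as encoded in the if/elif chain above
    FNB_PREFIX = {
        'none': 'metric_spenh_',
        'google': 'metric_spenh_ggl_',
        'deepgram': 'metric_spenh_dgram_',
        'aixplain_whisper': 'metric_spenh_whspx_',
        'whisper': 'metric_spenh_whspl_',
        'aaware_whisper': 'metric_spenh_whspaaw_',
        'faster_whisper': 'metric_spenh_fwhsp_',
    }

    def fnb_for(asr_method: str, asr_model_name: str) -> str:
        prefix = FNB_PREFIX[asr_method]  # a KeyError plays the role of "Unrecognized ASR method"
        if asr_method.endswith('whisper'):
            prefix += asr_model_name + '_'
        return prefix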
@@ -1255,7 +1248,7 @@ def main():
  ofname = join(predict_location, fnb + 'summary_truest.txt')

  with open(ofname, 'w') as f:
- print(f'WER enabled with method {wer_method}, whisper model, if used: {whisper_model}', file=f)
+ print(f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}', file=f)
  print(f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:',
  file=f)
  print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: "{:.2f}".format(x),
@@ -1318,7 +1311,7 @@ def main():
  label = f'Extraction statistics stats over {num_mix} mixtures:'
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
  all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
- label = f'WER enabled with method {wer_method}, whisper model, if used: {whisper_model}'
+ label = f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}'
  pd.DataFrame([label]).to_csv(csv_name, **header_args)

  if not truth_est_mode: