sonusai 0.18.1__py3-none-any.whl → 0.18.4__py3-none-any.whl

sonusai/__init__.py CHANGED
@@ -10,6 +10,7 @@ commands_doc = """
   calc_metric_spenh             Run speech enhancement and analysis
   doc                           Documentation
   genft                         Generate feature and truth data
+  genmetrics                    Generate mixture metrics data
   genmix                        Generate mixture and truth data
   genmixdb                      Generate a mixture database
   gentcst                       Generate target configuration from a subdirectory tree
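
The new `genmetrics` command slots into the same top-level dispatcher as its siblings. A hedged sketch of invoking it (this assumes the usual `sonusai <command>` entry point and a conventional docopt-style help flag; the command's actual options are not shown in this diff):

    sonusai genmetrics -h   # print usage for the new mixture-metrics generator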
sonusai/audiofe.py CHANGED
@@ -142,7 +142,7 @@ def main() -> None:
      if hparams is None:
          logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
          raise SystemExit(1)
-     feature_mode = hparams.feature
+     feature_mode = hparams['feature']
      in0name = sess_inputs[0].name
      in0type = sess_inputs[0].type
      out_names = [n.name for n in session.get_outputs()]
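
The change from `hparams.feature` to `hparams['feature']` implies the hyperparameters now arrive as a plain dict rather than an attribute-style object. A minimal sketch of recovering such a dict from an ONNX model with onnxruntime; the `hyper_parameters` metadata key and the JSON encoding are illustrative assumptions, not SonusAI's confirmed storage format:

    import json

    import onnxruntime as ort

    session = ort.InferenceSession('model.onnx')
    meta = session.get_modelmeta().custom_metadata_map  # dict[str, str]
    # hypothetical metadata key; SonusAI's real key may differ
    hparams = json.loads(meta['hyper_parameters']) if 'hyper_parameters' in meta else None
    if hparams is not None:
        feature_mode = hparams['feature']  # dict-style access, matching the new code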
sonusai/calc_metric_spenh.py CHANGED
@@ -24,7 +24,7 @@ For whisper ASR methods, the possible models used in local processing (ASR = whi
  {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
  but note most are very computationally demanding and can overwhelm/hang a local system.

- Outputs the following to PLOC (where id is mixd number 0:num_mixtures):
+ Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
      <id>_metric_spenh.txt

  If --plot:
@@ -74,6 +74,9 @@ from sonusai.mixture import Feature
  from sonusai.mixture import MixtureDatabase
  from sonusai.mixture import Predict

+ DB_99 = np.power(10, 99 / 10)
+ DB_N99 = np.power(10, -99 / 10)
+


  def signal_handler(_sig, _frame):
      import sys
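
A quick check of the new module-level constants: `np.power(10, 99 / 10)` is the linear power ratio for +99 dB, roughly 7.943e9. Notably, the hard-coded `7.944e8` these constants replace later in this diff corresponds to +89 dB, so naming the values also fixes an off-by-10x clamp:

    import numpy as np

    DB_99 = np.power(10, 99 / 10)    # ~7.943e9, i.e. +99 dB as a power ratio
    DB_N99 = np.power(10, -99 / 10)  # ~1.259e-10, i.e. -99 dB
    assert round(float(10 * np.log10(DB_99)), 6) == 99.0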
@@ -122,298 +125,6 @@ def power_uncompress(spec):
      return real_uncompress + 1j * imag_uncompress


- def snr(clean_speech, processed_speech, sample_rate):
-     # Check the length of the clean and processed speech. Must be the same.
-     clean_length = len(clean_speech)
-     processed_length = len(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Both Speech Files must be same length.')
-
-     overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))
-
-     # Global Variables
-     win_length = round(30 * sample_rate / 1000)  # window length in samples
-     skip_rate = int(np.floor(win_length / 4))  # window skip in samples
-     min_snr = -10  # minimum SNR in dB
-     max_snr = 35  # maximum SNR in dB
-
-     # For each frame of input speech, calculate the Segmental SNR
-     num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-     segmental_snr = np.empty(num_frames)
-     eps = np.spacing(1)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start:start + win_length]
-         processed_frame = processed_speech[start:start + win_length]
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-
-         # (2) Compute the Segmental SNR
-         signal_energy = np.sum(np.square(clean_frame))
-         noise_energy = np.sum(np.square(clean_frame - processed_frame))
-         segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
-         segmental_snr[frame_count] = max(segmental_snr[frame_count], min_snr)
-         segmental_snr[frame_count] = min(segmental_snr[frame_count], max_snr)
-
-         start = start + skip_rate
-
-     return overall_snr, segmental_snr
-
-
- def lp_coefficients(speech_frame, model_order):
-     # (1) Compute Autocorrelation Lags
-     win_length = np.size(speech_frame)
-     autocorrelation = np.empty(model_order + 1)
-     e = np.empty(model_order + 1)
-     for k in range(model_order + 1):
-         autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
-
-     # (2) Levinson-Durbin
-     a = np.ones(model_order)
-     a_past = np.empty(model_order)
-     ref_coefficients = np.empty(model_order)
-     e[0] = autocorrelation[0]
-     for i in range(model_order):
-         a_past[0: i] = a[0: i]
-         sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
-         ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
-         a[i] = ref_coefficients[i]
-         if i == 0:
-             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
-         else:
-             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
-         e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
-     lp_params = np.concatenate((np.array([1]), -a))
-     return autocorrelation, ref_coefficients, lp_params
-
-
- def llr(clean_speech, processed_speech, sample_rate):
-     from scipy.linalg import toeplitz
-
-     # Check the length of the clean and processed speech. Must be the same.
-     clean_length = np.size(clean_speech)
-     processed_length = np.size(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Both speech files must be same length.')
-
-     # Global Variables
-     win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-     skip_rate = (np.floor(win_length / 4)).astype(int)  # window skip in samples
-     if sample_rate < 10000:
-         p = 10  # LPC Analysis Order
-     else:
-         p = 16  # this could vary depending on sampling frequency.
-
-     # For each frame of input speech, calculate the Log Likelihood Ratio
-     num_frames = int((clean_length - win_length) / skip_rate)  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-     distortion = np.empty(num_frames)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start: start + win_length]
-         processed_frame = processed_speech[start: start + win_length]
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-
-         # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
-         r_clean, ref_clean, a_clean = lp_coefficients(clean_frame, p)
-         r_processed, ref_processed, a_processed = lp_coefficients(processed_frame, p)
-
-         # (3) Compute the LLR measure
-         numerator = np.dot(np.matmul(a_processed, toeplitz(r_clean)), a_processed)
-         denominator = np.dot(np.matmul(a_clean, toeplitz(r_clean)), a_clean)
-         distortion[frame_count] = np.log(numerator / denominator)
-         start = start + skip_rate
-     return distortion
-
-
- def wss(clean_speech, processed_speech, sample_rate):
-     from scipy.fftpack import fft
-
-     # Check the length of the clean and processed speech, which must be the same.
-     clean_length = np.size(clean_speech)
-     processed_length = np.size(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Files must have same length.')
-
-     # Global variables
-     win_length = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-     skip_rate = (np.floor(np.divide(win_length, 4))).astype(int)  # window skip in samples
-     max_freq = (np.divide(sample_rate, 2)).astype(int)  # maximum bandwidth
-     num_crit = 25  # number of critical bands
-
-     n_fft = (np.power(2, np.ceil(np.log2(2 * win_length)))).astype(int)
-     n_fft_by_2 = (np.multiply(0.5, n_fft)).astype(int)  # FFT size/2
-     k_max = 20.0  # value suggested by Klatt, pg 1280
-     k_loc_max = 1.0  # value suggested by Klatt, pg 1280
-
-     # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
-     cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
-                           540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
-                           1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
-                           2701.97, 2978.04, 3276.17, 3597.63])
-     bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
-                           77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
-                           153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
-                           276.072, 298.126, 321.465, 346.136])
-
-     bw_min = bandwidth[0]  # minimum critical bandwidth
-
-     # Set up the critical band filters.
-     # Note here that Gaussian-ly shaped filters are used.
-     # Also, the sum of the filter weights are equivalent for each critical band filter.
-     # Filter less than -30 dB and set to zero.
-     min_factor = np.exp(-30.0 / (2.0 * 2.303))  # -30 dB point of filter
-     crit_filter = np.empty((num_crit, n_fft_by_2))
-     for i in range(num_crit):
-         f0 = (cent_freq[i] / max_freq) * n_fft_by_2
-         bw = (bandwidth[i] / max_freq) * n_fft_by_2
-         norm_factor = np.log(bw_min) - np.log(bandwidth[i])
-         j = np.arange(n_fft_by_2)
-         crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
-         cond = np.greater(crit_filter[i, :], min_factor)
-         crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
-     # For each frame of input speech, calculate the Weighted Spectral Slope Measure
-     num_frames = int(clean_length / skip_rate - (win_length / skip_rate))  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
-
-     distortion = np.empty(num_frames)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start: start + win_length] / 32768
-         processed_frame = processed_speech[start: start + win_length] / 32768
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-         # (2) Compute the Power Spectrum of Clean and Processed
-         clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
-         processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))
-
-         # (3) Compute Filterbank Output Energies (in dB scale)
-         clean_energy = np.matmul(crit_filter, clean_spec[0:n_fft_by_2])
-         processed_energy = np.matmul(crit_filter, processed_spec[0:n_fft_by_2])
-
-         clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
-         processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))
-
-         # (4) Compute Spectral Slope (dB[i+1]-dB[i])
-         clean_slope = clean_energy[1:num_crit] - clean_energy[0: num_crit - 1]
-         processed_slope = processed_energy[1:num_crit] - processed_energy[0: num_crit - 1]
-
-         # (5) Find the nearest peak locations in the spectra to each critical band.
-         # If the slope is negative, we search to the left. If positive, we search to the right.
-         clean_loc_peak = np.empty(num_crit - 1)
-         processed_loc_peak = np.empty(num_crit - 1)
-
-         for i in range(num_crit - 1):
-             # find the peaks in the clean speech signal
-             if clean_slope[i] > 0:  # search to the right
-                 n = i
-                 while (n < num_crit - 1) and (clean_slope[n] > 0):
-                     n = n + 1
-                 clean_loc_peak[i] = clean_energy[n - 1]
-             else:  # search to the left
-                 n = i
-                 while (n >= 0) and (clean_slope[n] <= 0):
-                     n = n - 1
-                 clean_loc_peak[i] = clean_energy[n + 1]
-
-             # find the peaks in the processed speech signal
-             if processed_slope[i] > 0:  # search to the right
-                 n = i
-                 while (n < num_crit - 1) and (processed_slope[n] > 0):
-                     n = n + 1
-                 processed_loc_peak[i] = processed_energy[n - 1]
-             else:  # search to the left
-                 n = i
-                 while (n >= 0) and (processed_slope[n] <= 0):
-                     n = n - 1
-                 processed_loc_peak[i] = processed_energy[n + 1]
-
-         # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
-         db_max_clean = np.max(clean_energy)
-         db_max_processed = np.max(processed_energy)
-         '''
-         The weights are calculated by averaging individual weighting factors from the clean and processed frame.
-         These weights w_clean and w_processed should range from 0 to 1 and place more emphasis on spectral peaks
-         and less emphasis on slope differences in spectral valleys.
-         This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
-         '''
-         w_max_clean = np.divide(k_max, k_max + db_max_clean - clean_energy[0: num_crit - 1])
-         w_loc_max_clean = np.divide(k_loc_max, k_loc_max + clean_loc_peak - clean_energy[0: num_crit - 1])
-         w_clean = np.multiply(w_max_clean, w_loc_max_clean)
-
-         w_max_processed = np.divide(k_max, k_max + db_max_processed - processed_energy[0: num_crit - 1])
-         w_loc_max_processed = np.divide(k_loc_max, k_loc_max + processed_loc_peak - processed_energy[0: num_crit - 1])
-         w_processed = np.multiply(w_max_processed, w_loc_max_processed)
-
-         w = np.divide(np.add(w_clean, w_processed), 2.0)
-         slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1]
-         distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
-         # This normalization is not part of Klatt's paper, but helps to normalize the measure.
-         # Here we scale the measure by the sum of the weights.
-         start = start + skip_rate
-     return distortion
-
-
- def calc_speech_metrics(hypothesis: np.ndarray,
-                         reference: np.ndarray) -> tuple[float, int, int, int, float]:
-     """
-     Calculate speech metrics pesq_mos, c_sig, c_bak, c_ovl, seg_snr. These are all related and thus included
-     in one function. Reference: matlab script "compute_metrics.m".
-
-     Usage:
-         pesq, c_sig, c_bak, c_ovl, ssnr = compute_metrics(hypothesis, reference, fs, path)
-         reference: clean audio as array
-         hypothesis: enhanced audio as array
-         Audio must have sampling rate = 16000 Hz.
-
-     Example call:
-         pesq_output, csig_output, cbak_output, covl_output, ssnr_output = \
-             calc_speech_metrics(predicted_audio, target_audio)
-     """
-     from sonusai.metrics import calc_pesq
-
-     fs = 16000
-
-     # compute the WSS measure
-     wss_dist_vec = wss(reference, hypothesis, fs)
-     wss_dist_vec = np.sort(wss_dist_vec)
-     alpha = 0.95  # value from CMGAN ref implementation
-     wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
-
-     # compute the LLR measure
-     llr_dist = llr(reference, hypothesis, fs)
-     ll_rs = np.sort(llr_dist)
-     llr_len = round(np.size(llr_dist) * alpha)
-     llr_mean = np.mean(ll_rs[0: llr_len])
-
-     # compute the SNRseg
-     snr_dist, segsnr_dist = snr(reference, hypothesis, fs)
-     seg_snr = np.mean(segsnr_dist)
-
-     # compute the pesq (use Sonusai wrapper, only fs=16k, mode=wb support)
-     pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
-
-     # now compute the composite measures
-     c_sig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
-     c_sig = max(1, c_sig)
-     c_sig = min(5, c_sig)  # limit values to [1, 5]
-     c_bak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
-     c_bak = max(1, c_bak)
-     c_bak = min(5, c_bak)  # limit values to [1, 5]
-     c_ovl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
-     c_ovl = max(1, c_ovl)
-     c_ovl = min(5, c_ovl)  # limit values to [1, 5]
-
-     return pesq_mos, c_sig, c_bak, c_ovl, seg_snr
-
-
  def mean_square_error(hypothesis: np.ndarray,
                        reference: np.ndarray,
                        squared: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
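
The deleted snr/lp_coefficients/llr/wss helpers and the composite `calc_speech_metrics` wrapper (PESQ plus the Hu-Loizou CSIG/CBAK/COVL fits) do not vanish from the package; the call sites later in this diff import `calc_speech` from `sonusai.metrics` instead. Based only on the call shown below in `_process_mixture` (the argument names here are illustrative), usage becomes:

    from sonusai.metrics import calc_speech

    # PESQ plus composite CSIG/CBAK/COVL; segmental SNR is no longer part of
    # the return value and is fetched via mixdb.mixture_metrics(mixid, ['ssnr'])
    pesq, csig, cbak, covl = calc_speech(hypothesis_audio, reference_audio)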
@@ -494,48 +205,6 @@ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray
      return err, err_b, err_f


- def phase_distance(reference: np.ndarray,
-                    hypothesis: np.ndarray,
-                    eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """Calculate weighted phase distance error (weight normalization over bins per frame)
-
-     :param reference: complex [frames, bins]
-     :param hypothesis: complex [frames, bins]
-     :param eps: epsilon value
-     :return: mean, mean per bin, mean per frame
-     """
-     ang_diff = np.angle(reference) - np.angle(hypothesis)
-     phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
-     rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg
-
-     # Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
-     # hyp_real = np.real(hypothesis)
-     # near_zeros = np.real(hyp_real) < eps
-     # hyp_real = hyp_real * (np.logical_not(near_zeros))
-     # hyp_real = hyp_real + (near_zeros * eps)
-     # hypothesis = hyp_real + 1j*np.imag(hypothesis)
-     # rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi  # angle diff +/-180
-
-     # weighted mean over all (scalar)
-     reference_mag = np.abs(reference)
-     ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
-     err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
-
-     # weighted mean over frames (value per bin)
-     err_b = np.zeros(reference.shape[1])
-     for bi in range(reference.shape[1]):
-         ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
-         err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)
-
-     # weighted mean over bins (value per frame)
-     err_f = np.zeros(reference.shape[0])
-     for fi in range(reference.shape[0]):
-         ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
-         err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)
-
-     return err, err_b, err_f
-
-
  def plot_mixpred(mixture: AudioT,
                   mixture_f: AudioF,
                   target: Optional[AudioT] = None,
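
`phase_distance` moves out of this script the same way: the call site below imports `calc_phase_distance` from `sonusai.metrics` and passes the same keyword arguments, so callers should only need the import swap. A sketch grounded in the call shown in this diff:

    from sonusai.metrics import calc_phase_distance

    # reference and hypothesis are complex spectrograms shaped [frames, bins]
    phd, phd_bin, phd_frame = calc_phase_distance(reference=truth_f_complex, hypothesis=predict_complex)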
@@ -543,7 +212,6 @@ def plot_mixpred(mixture: AudioT,
                   predict: Optional[Predict] = None,
                   tp_title: str = '') -> plt.Figure:
      from sonusai.mixture import SAMPLE_RATE
-
      num_plots = 2
      if feature is not None:
          num_plots += 1
@@ -706,12 +374,13 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
      import h5py
      import mgzip
      from matplotlib.backends.backend_pdf import PdfPages
-     from numpy import inf
      from pystoi import stoi

      from sonusai import SonusAIError
      from sonusai import logger
      from sonusai.metrics import calc_pcm
+     from sonusai.metrics import calc_phase_distance
+     from sonusai.metrics import calc_speech
      from sonusai.metrics import calc_wer
      from sonusai.metrics import calc_wsdr
      from sonusai.mixture import forward_transform
@@ -771,24 +440,28 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
      # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
      noise = mixture - target  # has time-domain distortion (ir,etc.) but does not have specaugment
      # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
-     segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)  # note: uses pre-IR, pre-specaug audio
+     # note: uses pre-IR, pre-specaug audio
+     segsnr_f: np.ndarray = mixdb.mixture_metrics(mixid, ['ssnr'])[0]  # type: ignore
      mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
      noise_f = mixture_f - target_f  # true noise in freq domain includes specaugment and time-domain ir,distortions
      # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
-     segsnr_f[segsnr_f == inf] = 7.944e8  # 99db
-     segsnr_f[segsnr_f == -inf] = 1.258e-10  # -99db
+     segsnr_f[segsnr_f == np.inf] = DB_99
+     # segsnr_f should never be -np.inf
+     segsnr_f[segsnr_f == -np.inf] = DB_N99
      # need to use inv-tf to match #samples & latency shift properties of predict inv tf
      target_fi = inverse_transform(target_f, mixdb.it_config)
      noise_fi = inverse_transform(noise_f, mixdb.it_config)
      # mixture_fi = mixdb.inverse_transform(mixture_f)

      # gen feature, truth - note feature only used for plots
-     # TBD parse truth_f for different formats and also multi-truth
-     feature, truth_f = mixdb.mixture_ft(mixid, mixture=mixture)
-     truth_type = mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings[0].function
-     if truth_type == 'target_mixture_f':
-         half = truth_f.shape[-1] // 2
-         truth_f = truth_f[..., :half]  # extract target_f only
+     # TODO: parse truth_f for different formats
+     feature, truth_f = mixdb.mixture_ft(mixid, mixture_f=mixture_f)
+     # ignore mixup
+     for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
+         if truth_setting.function == 'target_mixture_f':
+             half = truth_f.shape[-1] // 2
+             # extract target_f only
+             truth_f = truth_f[..., :half]

      if not truth_est_mode:
          if predict.shape[0] < target_f.shape[0]:  # target_f, truth_f, mixture_f, etc. same size
@@ -843,12 +516,12 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
          pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)

          # Phase distance
-         phd, phd_bin, phd_frame = phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
+         phd, phd_bin, phd_frame = calc_phase_distance(hypothesis=predict_complex, reference=truth_f_complex)

          # Noise td logerr
          # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)

-         # # SA-SDR (time-domain source-aggragated SDR)
+         # # SA-SDR (time-domain source-aggregated SDR)
          ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
          ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
          # # note: w/o scale is more pessimistic number
@@ -863,8 +536,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
          # Speech intelligibility measure - PESQ
          if int(mixdb.mixture(mixid).snr) > -99:
              # len = target_est_wav.shape[0]
-             pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, target_fi)
-             pesq_mixture, csig_mx, cbak_mx, covl_mx, sgsnr_mx = calc_speech_metrics(mixture, target)
+             pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
+             pesq_mixture, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(mixid, ['mxpesq', 'mxcsig', 'mxcbak', 'mxcovl'])
              # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
              # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
              # pesq improvement
@@ -886,20 +559,15 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
          asr_tt = None
          asr_mx = None
          asr_tge = None
-         if asr_method == 'none' or mixdb.mixture(mixid).snr == -99:  # noise only, ignore/reset target asr
-             wer_mx = float('nan')
-             wer_tge = float('nan')
-             wer_pi = float('nan')
-         else:
+         asr_engines = list(mixdb.asr_configs.keys())
+         if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96:  # noise only, ignore/reset target asr
+             wer_mx = float(mixdb.mixture_metrics(mixid, [f'mxwer.{asr_engines[0]}'])[0]) * 100
              asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, 'text')[0]  # ignore mixup
              if asr_tt is None:
                  asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text  # target truth

              if asr_tt:
-                 asr_mx = calc_asr(mixture, engine=asr_method, whisper_model_name=asr_model_name).text
                  asr_tge = calc_asr(target_est_wav, engine=asr_method, whisper_model_name=asr_model_name).text
-
-                 wer_mx = calc_wer(asr_mx, asr_tt).wer * 100  # mixture wer
                  wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target estimate wer
                  if wer_mx == 0.0:
                      if wer_tge == 0.0:
@@ -913,6 +581,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
                      wer_mx = float(0)
                      wer_tge = float(0)
                      wer_pi = float(0)
+         else:
+             wer_mx = float('nan')
+             wer_tge = float('nan')
+             wer_pi = float('nan')

          # 5) Save per mixture metric results
          # Single row in table of scalar metrics per mixture
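
The common thread in the last few hunks is the new `MixtureDatabase.mixture_metrics(mixid, names)` API: it takes a list of metric names and returns the corresponding values in order. The names appearing in this diff are 'ssnr', 'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', and 'mxwer.<engine name>', where the engine name comes from the mixture database's `asr_configs`. A consolidated sketch (the database location is a hypothetical placeholder; everything else follows the calls above):

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase('path/to/mixdb')  # illustrative location
    mixid = 0
    ssnr, mxpesq = mixdb.mixture_metrics(mixid, ['ssnr', 'mxpesq'])
    for engine in mixdb.asr_configs.keys():  # keys are the 'name' fields from asr_configs
        wer_pct = float(mixdb.mixture_metrics(mixid, [f'mxwer.{engine}'])[0]) * 100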
@@ -1088,7 +760,6 @@ def main():
      from os.path import basename
      from os.path import isdir
      from os.path import join
-     from os.path import split

      import psutil
      from tqdm import tqdm
@@ -1097,7 +768,7 @@ def main():
      from sonusai import initial_log_messages
      from sonusai import logger
      from sonusai import update_console_handler
-     from sonusai.mixture import DEFAULT_NOISE
+     from sonusai.mixture import DEFAULT_SPEECH
      from sonusai.mixture import MixtureDatabase
      from sonusai.mixture import read_audio
      from sonusai.utils import calc_asr
@@ -1173,8 +844,7 @@ def main():
          return

      if enable_asr_warmup:
-         default_speech = split(DEFAULT_NOISE)[0] + '/speech_ma01_01.wav'
-         audio = read_audio(default_speech)
+         audio = read_audio(DEFAULT_SPEECH)
          logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
          asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
          logger.info(f'Warmup completed, results {asr_chk}')
sonusai/data/genmixdb.yml CHANGED
@@ -62,3 +62,5 @@ spectral_masks:
      t_max_width: 100
      t_num: 0
      t_max_percent: 100
+
+ asr_configs: [ ]
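
The default is an empty list, so no ASR engines (and therefore no `mxwer.*` metrics) are configured out of the box. A populated example, taken from the `doc_asr_configs` help text added later in this diff:

    asr_configs:
      - name: faster_tiny_cuda
        engine: faster_whisper
        model: tiny
        device: cuda
        beam_size: 5
      - name: google
        engine: google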
sonusai/doc/doc.py CHANGED
@@ -199,7 +199,6 @@ def get_truth_functions() -> str:
  def doc_truth_settings() -> str:
      import yaml

-     from sonusai.mixture import get_default_config
      default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['truth_settings'])}"
      return """
  'truth_settings' is a mixture database configuration parameter that sets the truth
@@ -375,7 +374,6 @@ This rule expands to 6 unique augmentations being applied to each target
  def doc_target_augmentations() -> str:
      import yaml

-     from sonusai.mixture import get_default_config
      default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['target_augmentations'])}"
      return """
  'target_augmentations' is a mixture database configuration parameter that
@@ -388,7 +386,6 @@ See 'augmentations' for details on augmentation rules.
  def doc_class_balancing_augmentation() -> str:
      import yaml

-     from sonusai.mixture import get_default_config
      default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['class_balancing_augmentation'])}"
      return """
  'class_balancing_augmentation' is a mixture database configuration parameter
@@ -436,7 +433,6 @@ Required field:
  def doc_noise_augmentations() -> str:
      import yaml

-     from sonusai.mixture import get_default_config
      default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['noise_augmentations'])}"
      return """
  'noise_augmentations' is a mixture database configuration parameter that
@@ -536,3 +532,48 @@ def doc_config() -> str:
      for c in VALID_CONFIGS:
          text += f'    {c}\n'
      return text
+
+
+ def doc_asr_configs() -> str:
+     from sonusai.utils import get_available_engines
+
+     default = f"\nDefault value: {get_default_config()['asr_configs']}"
+     engines = get_available_engines()
+     text = """
+ 'asr_configs' is a mixture database configuration parameter that sets the list of
+ ASR engine(s) to use.
+
+ Required fields:
+
+     'name'          Unique identifier for the ASR engine.
+     'engine'        ASR engine to use. Available engines:
+ """
+     text += f'        {", ".join(engines)}\n'
+     text += """
+ Optional fields:
+
+     'model'         Some ASR engines allow the specification of a model, but note most are
+                     very computationally demanding and can overwhelm/hang a local system.
+                     Available whisper ASR models:
+                     tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large
+     'device'        Some ASR engines allow the specification of a device, either 'cpu' or 'cuda'.
+     'cpu_threads'   Some ASR engines allow the specification of the number of CPU threads to use.
+     'compute_type'  Some ASR engines allow the specification of a compute type, e.g. 'int8'.
+     'beam_size'     Some ASR engines allow the specification of a beam size.
+     <other>         Other parameters can be injected into the ASR engine as needed; all
+                     fields in each config are forwarded to the given engine.
+
+ Example:
+
+ asr_configs:
+   - name: faster_tiny_cuda
+     engine: faster_whisper
+     model: tiny
+     device: cuda
+     beam_size: 5
+   - name: google
+     engine: google
+
+ Creates two ASR engines for use named faster_tiny_cuda and google.
+ """
+     return text + default
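
The helper imported here reports which identifiers are valid for the 'engine' field. A quick way to list them, grounded in the import and the `", ".join(engines)` usage above:

    from sonusai.utils import get_available_engines

    # engine identifiers valid for the 'engine' field of asr_configs
    print(', '.join(get_available_engines()))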