sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. sonusai/calc_metric_spenh.py +265 -233
  2. sonusai/data/genmixdb.yml +4 -2
  3. sonusai/data/silero_vad_v5.1.jit +0 -0
  4. sonusai/data/silero_vad_v5.1.onnx +0 -0
  5. sonusai/doc/doc.py +14 -0
  6. sonusai/genft.py +1 -1
  7. sonusai/genmetrics.py +15 -18
  8. sonusai/genmix.py +1 -1
  9. sonusai/genmixdb.py +30 -52
  10. sonusai/ir_metric.py +555 -0
  11. sonusai/metrics_summary.py +322 -0
  12. sonusai/mixture/__init__.py +6 -2
  13. sonusai/mixture/audio.py +139 -15
  14. sonusai/mixture/augmentation.py +199 -84
  15. sonusai/mixture/config.py +9 -4
  16. sonusai/mixture/constants.py +0 -1
  17. sonusai/mixture/datatypes.py +19 -10
  18. sonusai/mixture/generation.py +52 -64
  19. sonusai/mixture/helpers.py +38 -26
  20. sonusai/mixture/ir_delay.py +63 -0
  21. sonusai/mixture/mixdb.py +190 -46
  22. sonusai/mixture/targets.py +3 -6
  23. sonusai/mixture/truth_functions/energy.py +9 -5
  24. sonusai/mixture/truth_functions/metrics.py +1 -1
  25. sonusai/mkwav.py +1 -1
  26. sonusai/onnx_predict.py +1 -1
  27. sonusai/queries/queries.py +1 -1
  28. sonusai/utils/__init__.py +2 -0
  29. sonusai/utils/asr.py +1 -1
  30. sonusai/utils/load_object.py +8 -2
  31. sonusai/utils/stratified_shuffle_split.py +1 -1
  32. sonusai/utils/temp_seed.py +13 -0
  33. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
  34. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
  35. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
  36. sonusai/mixture/soundfile_audio.py +0 -130
  37. sonusai/mixture/sox_audio.py +0 -476
  38. sonusai/mixture/sox_augmentation.py +0 -136
  39. sonusai/mixture/torchaudio_audio.py +0 -106
  40. sonusai/mixture/torchaudio_augmentation.py +0 -109
  41. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  """sonusai calc_metric_spenh
2
2
 
3
- usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-m MODEL] [-n NCPU] PLOC TLOC
3
+ usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-n NCPU] PLOC TLOC
4
4
 
5
5
  options:
6
6
  -h, --help
@@ -11,8 +11,10 @@ options:
11
11
  -w, --wav Generate WAV files per mixture.
12
12
  -s, --summary Enable summary files generation.
13
13
  -n, --num_process NCPU Number of parallel processes to use [default: auto]
14
- -e ASR, --asr-method ASR ASR method: deepgram, google, aixplain_whisper, whisper, or sensory. [default: none]
15
- -m MODEL, --model ASR model name used in some ASR methods. [default: tiny]
14
+ -e ASR, --asr-method ASR ASR method used for WER metrics. Must exist in the TLOC dataset as pre-calculated
15
+ metrics using SonusAI genmetrics. Can be either an integer index, i.e. 0, 1, ..., or the
16
+ name of the asr_engine configuration in the dataset. If an incorrect name is specified,
17
+ a list of asr_engines of the dataset will be printed.
16
18
 
17
19
  Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
18
20
  reference. Metric and extraction data files are written into PLOC.
@@ -20,9 +22,14 @@ reference. Metric and extraction data files are written into PLOC.
20
22
  PLOC directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
21
23
  TLOC directory with SonusAI mixture database of truth/label mixture data
22
24
 
23
- For whisper ASR methods, the possible models used in local processing (ASR = whisper) are:
24
- {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
25
- but note most are very computationally demanding and can overwhelm/hang a local system.
25
+ For ASR methods, the method must be defined in the TLOC dataset, for example the available faster_whisper models are:
26
+ {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large} and an example configuration looks like:
27
+ {'fwhsptiny_cpu': {'engine': 'faster_whisper',
28
+ 'model': 'tiny',
29
+ 'device': 'cpu',
30
+ 'beam_size': 5}}
31
+ Note: the ASR config can optionally include the model, device, and other fields the engine supports.
32
+ Most ASR engines are very computationally demanding and can overwhelm/hang a local system.
26
33
 
27
34
  Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
28
35
  <id>_metric_spenh.txt
@@ -61,8 +68,6 @@ Inputs:
61
68
  """
62
69
 
63
70
  import signal
64
- from contextlib import redirect_stdout
65
- from dataclasses import dataclass
66
71
 
67
72
  import matplotlib
68
73
  import matplotlib.pyplot as plt
@@ -93,24 +98,17 @@ signal.signal(signal.SIGINT, signal_handler)
93
98
  matplotlib.use("SVG")
94
99
 
95
100
 
96
- @dataclass
97
- class MPGlobal:
98
- mixdb: MixtureDatabase
99
- predict_location: str
100
- predict_wav_mode: bool
101
- truth_est_mode: bool
102
- enable_plot: bool
103
- enable_wav: bool
104
- asr_method: str
105
- asr_model_name: str
106
-
107
-
108
- MP_GLOBAL: MPGlobal
101
+ def first_key(x: dict) -> str:
102
+ for key in x:
103
+ return key
104
+ raise KeyError("No key found")
109
105
 
110
106
 
111
107
  def mean_square_error(
112
- hypothesis: np.ndarray, reference: np.ndarray, squared: bool = False
113
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
108
+ hypothesis: np.ndarray,
109
+ reference: np.ndarray,
110
+ squared: bool = False,
111
+ ) -> tuple[float, np.ndarray, np.ndarray]:
114
112
  """Calculate root-mean-square error or mean square error
115
113
 
116
114
  :param hypothesis: [frames, bins]
@@ -125,7 +123,7 @@ def mean_square_error(
125
123
  # mean over bins for value per frame
126
124
  err_f = np.mean(sq_err, axis=1)
127
125
  # mean over all
128
- err = np.mean(sq_err)
126
+ err = float(np.mean(sq_err))
129
127
 
130
128
  if not squared:
131
129
  err_b = np.sqrt(err_b)
@@ -135,9 +133,7 @@ def mean_square_error(
135
133
  return err, err_b, err_f
136
134
 
137
135
 
138
- def mean_abs_percentage_error(
139
- hypothesis: np.ndarray, reference: np.ndarray
140
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
136
+ def mean_abs_percentage_error(hypothesis: np.ndarray, reference: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
141
137
  """Calculate mean abs percentage error
142
138
 
143
139
  If inputs are complex, calculates average: mape(real)/2 + mape(imag)/2
@@ -162,12 +158,12 @@ def mean_abs_percentage_error(
162
158
  # mean over bins for value per frame
163
159
  err_f = np.around(np.mean(abs_err, axis=1), 3)
164
160
  # mean over all
165
- err = np.around(np.mean(abs_err), 3)
161
+ err = float(np.around(np.mean(abs_err), 3))
166
162
 
167
163
  return err, err_b, err_f
168
164
 
169
165
 
170
- def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
166
+ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
171
167
  """Calculate log error
172
168
 
173
169
  :param reference: complex or real [frames, bins]
@@ -184,7 +180,7 @@ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray
184
180
  # mean over bins for value per frame
185
181
  err_f = np.around(np.mean(log_err, axis=1), 3)
186
182
  # mean over all
187
- err = np.around(np.mean(log_err), 3)
183
+ err = float(np.around(np.mean(log_err), 3))
188
184
 
189
185
  return err, err_b, err_f
190
186
 
@@ -196,7 +192,7 @@ def plot_mixpred(
196
192
  feature: Feature | None = None,
197
193
  predict: Predict | None = None,
198
194
  tp_title: str = "",
199
- ) -> plt.Figure:
195
+ ) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
200
196
  from sonusai.mixture import SAMPLE_RATE
201
197
 
202
198
  num_plots = 2
@@ -224,22 +220,12 @@ def plot_mixpred(
224
220
 
225
221
  if feature is not None:
226
222
  p += 1
227
- ax[p].imshow(
228
- np.transpose(feature),
229
- aspect="auto",
230
- interpolation="nearest",
231
- origin="lower",
232
- )
223
+ ax[p].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
233
224
  ax[p].set_title("Feature")
234
225
 
235
226
  if predict is not None:
236
227
  p += 1
237
- im = ax[p].imshow(
238
- np.transpose(predict),
239
- aspect="auto",
240
- interpolation="nearest",
241
- origin="lower",
242
- )
228
+ im = ax[p].imshow(np.transpose(predict), aspect="auto", interpolation="nearest", origin="lower")
243
229
  ax[p].set_title("Predict " + tp_title)
244
230
  plt.colorbar(im, location="bottom")
245
231
 
@@ -251,7 +237,7 @@ def plot_pdb_predict_truth(
251
237
  truth_f: np.ndarray | None = None,
252
238
  metric: np.ndarray | None = None,
253
239
  tp_title: str = "",
254
- ) -> plt.Figure:
240
+ ) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
255
241
  """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
256
242
  num_plots = 2
257
243
  if truth_f is not None:
@@ -277,24 +263,12 @@ def plot_pdb_predict_truth(
277
263
  pred_avg = 10 * np.log10(np.mean(predict, axis=-1) + np.finfo(np.float32).eps)
278
264
  p += 1
279
265
  x_axis = np.arange(len(pred_avg), dtype=np.float32) # / SAMPLE_RATE
280
- ax[p].plot(
281
- x_axis,
282
- pred_avg,
283
- color="black",
284
- linestyle="dashed",
285
- label="Predict mean over freq.",
286
- )
266
+ ax[p].plot(x_axis, pred_avg, color="black", linestyle="dashed", label="Predict mean over freq.")
287
267
  ax[p].set_ylabel("mean db", color="black")
288
268
  ax[p].set_xlim(x_axis[0], x_axis[-1])
289
269
  if truth_f is not None:
290
270
  truth_avg = 10 * np.log10(np.mean(truth_f, axis=-1) + np.finfo(np.float32).eps)
291
- ax[p].plot(
292
- x_axis,
293
- truth_avg,
294
- color="green",
295
- linestyle="dashed",
296
- label="Truth mean over freq.",
297
- )
271
+ ax[p].plot(x_axis, truth_avg, color="green", linestyle="dashed", label="Truth mean over freq.")
298
272
 
299
273
  if metric is not None: # instantiate 2nd y-axis that shares the same x-axis
300
274
  ax2 = ax[p].twinx()
@@ -317,7 +291,7 @@ def plot_e_predict_truth(
317
291
  truth_wav: np.ndarray | None = None,
318
292
  metric: np.ndarray | None = None,
319
293
  tp_title: str = "",
320
- ) -> plt.Figure:
294
+ ) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
321
295
  """Plot predict spectrogram and waveform and optionally truth and a metric)"""
322
296
  num_plots = 2
323
297
  if truth_f is not None:
@@ -335,13 +309,7 @@ def plot_e_predict_truth(
335
309
 
336
310
  if truth_f is not None: # plot truth if provided and use same colormap as predict
337
311
  p += 1
338
- ax[p].imshow(
339
- truth_f.transpose(),
340
- im.cmap,
341
- aspect="auto",
342
- interpolation="nearest",
343
- origin="lower",
344
- )
312
+ ax[p].imshow(truth_f.transpose(), im.cmap, aspect="auto", interpolation="nearest", origin="lower")
345
313
  ax[p].set_title("Truth")
346
314
 
347
315
  # Plot predict wav, and optionally truth avg and metric lines
@@ -383,7 +351,17 @@ def plot_e_predict_truth(
383
351
  return fig
384
352
 
385
353
 
386
- def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
354
+ def _process_mixture(
355
+ m_id: int,
356
+ truth_location: str,
357
+ predict_location: str,
358
+ predict_wav_mode: bool,
359
+ truth_est_mode: bool,
360
+ enable_plot: bool,
361
+ enable_wav: bool,
362
+ asr_method: str,
363
+ target_f_key: str,
364
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
387
365
  import pickle
388
366
  from os.path import basename
389
367
  from os.path import join
@@ -412,19 +390,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
412
390
  from sonusai.utils import unstack_complex
413
391
  from sonusai.utils import write_audio
414
392
 
415
- global MP_GLOBAL
416
-
417
- mixdb = MP_GLOBAL.mixdb
418
- predict_location = MP_GLOBAL.predict_location
419
- predict_wav_mode = MP_GLOBAL.predict_wav_mode
420
- truth_est_mode = MP_GLOBAL.truth_est_mode
421
- enable_plot = MP_GLOBAL.enable_plot
422
- enable_wav = MP_GLOBAL.enable_wav
423
- asr_method = MP_GLOBAL.asr_method
424
- asr_model_name = MP_GLOBAL.asr_model_name
393
+ mixdb = MixtureDatabase(truth_location)
425
394
 
426
- # 1) Read predict data, var predict with shape [BatchSize,Classes] or [BatchSize,Tsteps,Classes]
427
- output_name = join(predict_location, mixdb.mixture(mixid).name)
395
+ # 1) Read predict data, var predict with shape [BatchSize,Classes] or [batch, timesteps, classes]
396
+ output_name = join(predict_location, mixdb.mixture(m_id).name + ".h5")
428
397
  predict = None
429
398
  if truth_est_mode:
430
399
  # in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
@@ -439,31 +408,31 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
439
408
  predict = np.array(f["predict"])
440
409
  except Exception as e:
441
410
  raise OSError(f"Error reading {output_name}: {e}") from e
442
- # reshape to always be [frames,classes] where ndim==3 case frames = batch * tsteps
411
+ # reshape to always be [frames, classes] where ndim==3 case frames = batch * timesteps
443
412
  if predict.ndim > 2: # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
444
413
  # logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
445
414
  predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
446
415
  else:
447
416
  base_name, ext = splitext(output_name)
448
417
  predict_name = join(base_name + ".wav")
449
- audio = read_audio(predict_name)
418
+ audio = read_audio(predict_name, use_cache=True)
450
419
  predict = forward_transform(audio, mixdb.ft_config)
451
420
  if mixdb.feature[0:1] == "h":
452
421
  predict = power_compress(predict)
453
422
  predict = stack_complex(predict)
454
423
 
455
424
  # 2) Collect true target, noise, mixture data, trim to predict size if needed
456
- tmp = mixdb.mixture_targets(mixid) # targets is list of pre-IR and pre-specaugment targets
457
- target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
425
+ tmp = mixdb.mixture_targets(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
426
+ target_f = mixdb.mixture_targets_f(m_id, targets=tmp)[0]
458
427
  target = tmp[0]
459
- mixture = mixdb.mixture_mixture(mixid) # note: gives full reverberated/distorted target, but no specaugment
428
+ mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
460
429
  # noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
461
430
  # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
462
431
  noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
463
432
  # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
464
433
  # note: uses pre-IR, pre-specaug audio
465
- segsnr_f: np.ndarray = mixdb.mixture_metrics(mixid, ["ssnr"])[0] # type: ignore[assignment]
466
- mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
434
+ segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"][0]
435
+ mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
467
436
  noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
468
437
  # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
469
438
  segsnr_f[segsnr_f == np.inf] = DB_99
@@ -476,13 +445,21 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
476
445
 
477
446
  # gen feature, truth - note feature only used for plots
478
447
  # TODO: parse truth_f for different formats
479
- feature, truth_f = mixdb.mixture_ft(mixid, mixture_f=mixture_f)
448
+ feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
449
+ truth_f = truth_all[target_f_key]
450
+ if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
451
+ if truth_f.shape[1] != 1:
452
+ logger.info("Error: target_f truth has stride > 1, exiting.")
453
+ raise SystemExit(1)
454
+ else:
455
+ truth_f = truth_f[:, 0, :] # remove stride dimension
456
+
480
457
  # ignore mixup
481
- for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_configs:
482
- if truth_setting.function == "target_mixture_f":
483
- half = truth_f.shape[-1] // 2
484
- # extract target_f only
485
- truth_f = truth_f[..., :half]
458
+ # for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
459
+ # if truth_setting.function == 'target_mixture_f':
460
+ # half = truth_f.shape[-1] // 2
461
+ # # extract target_f only
462
+ # truth_f = truth_f[..., :half]
486
463
 
487
464
  if not truth_est_mode:
488
465
  if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
@@ -503,15 +480,17 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
503
480
  )
504
481
  trim_f = predict.shape[0] - target_f.shape[0]
505
482
  predict = predict[0:-trim_f, :]
506
- # raise ValueError(
483
+ # raise SonusAIError(
507
484
  # f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
508
485
 
509
486
  # 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
510
487
  if truth_est_mode:
511
488
  predict = truth_f # substitute truth for the prediction (for test/debug)
512
489
  predict_complex = unstack_complex(predict) # unstack
513
- # if feat has compressed mag and truth does not, compress it
514
- if mixdb.feature[0:1] == "h" and mixdb.target_file(1).truth_configs[0].function[0:10] != "targetcmpr":
490
+ # if feature has compressed mag and truth does not, compress it
491
+ if mixdb.feature[0:1] == "h" and not mixdb.truth_configs[first_key(mixdb.truth_configs)].function.startswith(
492
+ "targetcmpr"
493
+ ):
515
494
  predict_complex = power_compress(predict_complex) # from uncompressed truth
516
495
  else:
517
496
  predict_complex = unstack_complex(predict)
@@ -556,10 +535,14 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
556
535
  # logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
557
536
 
558
537
  # Speech intelligibility measure - PESQ
559
- if int(mixdb.mixture(mixid).snr) > -99:
538
+ if int(mixdb.mixture(m_id).snr) > -99:
560
539
  # len = target_est_wav.shape[0]
561
540
  pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
562
- pesq_mixture, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(mixid, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
541
+ metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
542
+ pesq_mixture = metrics["mxpesq"]
543
+ csig_mx = metrics["mxcsig"]
544
+ cbak_mx = metrics["mxcbak"]
545
+ covl_mx = metrics["mxcovl"]
563
546
  # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
564
547
  # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
565
548
  # pesq improvement
@@ -581,25 +564,37 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
581
564
  asr_tt = None
582
565
  asr_mx = None
583
566
  asr_tge = None
584
- asr_engines = list(mixdb.asr_configs.keys())
585
- if len(asr_engines) > 0 and not mixdb.mixture(mixid).is_noise_only: # noise only, ignore/reset target asr
586
- wer_mx = float(mixdb.mixture_metrics(mixid, [f"mxwer.{asr_engines[0]}"])[0]) * 100
587
- asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, "text")[0] # ignore mixup
588
- if asr_tt is None:
589
- asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text # target truth
567
+ # asr_engines = list(mixdb.asr_configs.keys())
568
+ if asr_method is not None and mixdb.mixture(m_id).snr >= -96: # noise only, ignore/reset target ASR
569
+ asr_mx_name = f"mxasr.{asr_method}"
570
+ wer_mx_name = f"mxwer.{asr_method}"
571
+ asr_tt_name = f"tasr.{asr_method}"
572
+ metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
573
+ asr_mx = metrics[asr_mx_name][0]
574
+ wer_mx = metrics[wer_mx_name][0]
575
+ asr_tt = metrics[asr_tt_name][0]
590
576
 
591
577
  if asr_tt:
592
- asr_tge = calc_asr(target_est_wav, engine=asr_method, whisper_model_name=asr_model_name).text
593
- wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate wer
578
+ noiseadd = None # TBD add as switch, default -30
579
+ if noiseadd is not None:
580
+ ngain = np.power(10, min(float(noiseadd), 0.0) / 20.0) # limit to gain <1, convert to float
581
+ tgasr_est_wav = target_est_wav + ngain * noise_est_wav # add back noise at low level
582
+ else:
583
+ tgasr_est_wav = target_est_wav
584
+
585
+ # logger.info(f'Calculating prediction ASR for mixid {mixid}')
586
+ asr_cfg = mixdb.asr_configs[asr_method]
587
+ asr_tge = calc_asr(tgasr_est_wav, **asr_cfg).text
588
+ wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate WER
594
589
  if wer_mx == 0.0:
595
590
  if wer_tge == 0.0:
596
591
  wer_pi = 0.0
597
592
  else:
598
- wer_pi = -999.0
593
+ wer_pi = -999.0 # instead of -Inf
599
594
  else:
600
595
  wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
601
596
  else:
602
- print(f"Warning: mixid {mixid} asr truth is empty, setting to 0% wer")
597
+ logger.warning(f"Warning: mixid {m_id} ASR truth is empty, setting to 0% WER")
603
598
  wer_mx = float(0)
604
599
  wer_tge = float(0)
605
600
  wer_pi = float(0)
@@ -633,10 +628,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
633
628
  "SPFILE",
634
629
  "NFILE",
635
630
  ]
636
- ti = mixdb.mixture(mixid).targets[0].file_id
637
- ni = mixdb.mixture(mixid).noise.file_id
631
+ ti = mixdb.mixture(m_id).targets[0].file_id
632
+ ni = mixdb.mixture(m_id).noise.file_id
638
633
  metr1 = [
639
- mixdb.mixture(mixid).snr,
634
+ mixdb.mixture(m_id).snr,
640
635
  pesq_mixture,
641
636
  pesq_speech,
642
637
  pesq_impr_pc,
@@ -658,17 +653,11 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
658
653
  basename(mixdb.target_file(ti).name),
659
654
  basename(mixdb.noise_file(ni).name),
660
655
  ]
661
- mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[mixid])
656
+ mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
662
657
 
663
658
  # Stats of per frame estimation metrics
664
659
  metr2 = pd.DataFrame(
665
- {
666
- "SSNR": segsnr_f,
667
- "PCM": pcm_frame,
668
- "SLERR": lerr_tg_frame,
669
- "NLERR": lerr_n_frame,
670
- "SPD": phd_frame,
671
- }
660
+ {"SSNR": segsnr_f, "PCM": pcm_frame, "SLERR": lerr_tg_frame, "NLERR": lerr_n_frame, "SPD": phd_frame}
672
661
  )
673
662
  metr2 = metr2.describe() # Use pandas stat function
674
663
  # Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
@@ -679,29 +668,33 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
679
668
  [metr2.columns, ["Avg", "Min", "Med", "Max", "Std"]], names=["Metric", "Stat"]
680
669
  )
681
670
  dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
682
- mtab2 = pd.DataFrame(dat1row, index=[mixid], columns=new_labels)
683
- mtab2.insert(0, "MXSNR", mixdb.mixture(mixid).snr, False) # add MXSNR as the first metric column
671
+ mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
672
+ mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).snr, False) # add MXSNR as the first metric column
684
673
 
685
674
  all_metrics_table_1 = mtab1 # return to be collected by process
686
675
  all_metrics_table_2 = mtab2 # return to be collected by process
687
676
 
688
- metric_name = base_name + "_metric_spenh.txt"
689
- with open(metric_name, "w") as f, redirect_stdout(f):
690
- print("Speech enhancement metrics:")
691
- print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
692
- print("")
693
- print(f"Extraction statistics over {mixture_f.shape[0]} frames:")
694
- print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
695
- print("")
696
- print(f"Target path: {mixdb.target_file(ti).name}")
697
- print(f"Noise path: {mixdb.noise_file(ni).name}")
677
+ if asr_method is None:
678
+ metric_name = base_name + "_metric_spenh.txt"
679
+ else:
680
+ metric_name = base_name + "_metric_spenh_" + asr_method + ".txt"
681
+
682
+ with open(metric_name, "w") as f:
683
+ print("Speech enhancement metrics:", file=f)
684
+ print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
685
+ print("", file=f)
686
+ print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
687
+ print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
688
+ print("", file=f)
689
+ print(f"Target path: {mixdb.target_file(ti).name}", file=f)
690
+ print(f"Noise path: {mixdb.noise_file(ni).name}", file=f)
698
691
  if asr_method != "none":
699
- print(f"ASR method: {asr_method} and whisper model (if used): {asr_model_name}")
700
- print(f"ASR truth: {asr_tt}")
701
- print(f"ASR result for mixture: {asr_mx}")
702
- print(f"ASR result for prediction: {asr_tge}")
692
+ print(f"ASR method: {asr_method}", file=f)
693
+ print(f"ASR truth: {asr_tt}", file=f)
694
+ print(f"ASR result for mixture: {asr_mx}", file=f)
695
+ print(f"ASR result for prediction: {asr_tge}", file=f)
703
696
 
704
- print(f"Augmentations: {mixdb.mixture(mixid)}")
697
+ print(f"Augmentations: {mixdb.mixture(m_id)}", file=f)
705
698
 
706
699
  # 7) write wav files
707
700
  if enable_wav:
@@ -728,7 +721,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
728
721
  # Reshape to get frames*decimated_stride, num_bands
729
722
  step = int(mixdb.feature_samples / mixdb.feature_step_samples)
730
723
  if feature.ndim != 3:
731
- raise ValueError("feature does not have 3 dimensions: frames, stride, num_bands")
724
+ raise OSError("feature does not have 3 dimensions: frames, stride, num_bands")
732
725
 
733
726
  # for feature cn*00n**
734
727
  feat_sgram = unstack_complex(feature)
@@ -738,17 +731,19 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
738
731
 
739
732
  with PdfPages(plot_name) as pdf:
740
733
  # page1 we always have a mixture and prediction, target optional if truth provided
741
- tfunc_name = mixdb.target_file(1).truth_configs[0].function # first target, assumes all have same
742
- if tfunc_name == "mapped_snr_f":
743
- # leave as unmapped snr
744
- predplot = predict
745
- tfunc_name = mixdb.target_file(1).truth_configs[0].function
746
- elif tfunc_name in ("target_f", "target_mixture_f"):
747
- predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
748
- else:
749
- # use dB scale
750
- predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
751
- tfunc_name = tfunc_name + " (db)"
734
+ # For speech enhancement, target_f is definitely included:
735
+ predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
736
+ tfunc_name = "target_f"
737
+ # if tfunc_name == 'mapped_snr_f':
738
+ # # leave as unmapped snr
739
+ # predplot = predict
740
+ # tfunc_name = mixdb.target_file(1).truth_settings[0].function
741
+ # elif tfunc_name == 'target_f' or 'target_mixture_f':
742
+ # predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
743
+ # else:
744
+ # # use dB scale
745
+ # predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
746
+ # tfunc_name = tfunc_name + ' (db)'
752
747
 
753
748
  mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
754
749
  fig_obj = plot_mixpred(
@@ -816,8 +811,7 @@ def main():
816
811
 
817
812
  verbose = args["--verbose"]
818
813
  mixids = args["--mixid"]
819
- asr_method = args["--asr-method"].lower()
820
- asr_model_name = args["--model"].lower()
814
+ asr_method = args["--asr-method"]
821
815
  truth_est_mode = args["--truth-est-mode"]
822
816
  enable_plot = args["--plot"]
823
817
  enable_wav = args["--wav"]
@@ -827,6 +821,7 @@ def main():
827
821
  truth_location = args["TLOC"]
828
822
 
829
823
  import glob
824
+ from functools import partial
830
825
  from os.path import basename
831
826
  from os.path import isdir
832
827
  from os.path import join
@@ -837,16 +832,13 @@ def main():
837
832
  from sonusai import initial_log_messages
838
833
  from sonusai import logger
839
834
  from sonusai import update_console_handler
840
- from sonusai.mixture import DEFAULT_SPEECH
841
835
  from sonusai.mixture import MixtureDatabase
842
- from sonusai.mixture import read_audio
843
- from sonusai.utils import calc_asr
844
836
  from sonusai.utils import par_track
845
837
  from sonusai.utils import track
846
838
 
847
839
  # Check prediction subdirectory
848
840
  if not isdir(predict_location):
849
- print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting ...")
841
+ print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting.")
850
842
 
851
843
  # all_predict_files = listdir(predict_location)
852
844
  all_predict_files = glob.glob(predict_location + "/*.h5")
@@ -855,7 +847,7 @@ def main():
855
847
  if len(all_predict_files) <= 0 and not truth_est_mode:
856
848
  all_predict_files = glob.glob(predict_location + "/*.wav") # check for wav files
857
849
  if len(all_predict_files) <= 0:
858
- print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting ...")
850
+ print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting.")
859
851
  else:
860
852
  logger.info(f"Found {len(all_predict_files)} prediction .wav files.")
861
853
  predict_wav_mode = True
@@ -877,59 +869,40 @@ def main():
877
869
  logger.info(
878
870
  f"Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}"
879
871
  )
880
- logger.info(f"Only running specified subset of {len(mixids)} mixtures")
872
+ # Speech enhancement metrics and audio truth require the target_f truth type; check that it is present.
873
+ target_f_key = None
874
+ logger.info(f"mixdb has {len(mixdb.truth_configs)} truth types defined, checking that target_f type is present.")
875
+ for key in mixdb.truth_configs:
876
+ if mixdb.truth_configs[key].function == "target_f":
877
+ target_f_key = key
878
+ if target_f_key is None:
879
+ logger.error("mixdb does not have target_f truth define, required for speech enhancement metrics, exiting.")
880
+ raise SystemExit(1)
881
881
 
882
- enable_asr_warmup = False
883
- if asr_method == "none":
884
- fnb = "metric_spenh_"
885
- elif asr_method == "google":
886
- fnb = "metric_spenh_ggl_"
887
- logger.info(f"ASR enabled with method {asr_method}")
888
- enable_asr_warmup = True
889
- elif asr_method == "deepgram":
890
- fnb = "metric_spenh_dgram_"
891
- logger.info(f"ASR enabled with method {asr_method}")
892
- enable_asr_warmup = True
893
- elif asr_method == "aixplain_whisper":
894
- fnb = "metric_spenh_whspx_" + asr_model_name + "_"
895
- logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
896
- enable_asr_warmup = True
897
- elif asr_method == "whisper":
898
- fnb = "metric_spenh_whspl_" + asr_model_name + "_"
899
- logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
900
- enable_asr_warmup = True
901
- elif asr_method == "aaware_whisper":
902
- fnb = "metric_spenh_whspaaw_" + asr_model_name + "_"
903
- logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
904
- enable_asr_warmup = True
905
- elif asr_method == "faster_whisper":
906
- fnb = "metric_spenh_fwhsp_" + asr_model_name + "_"
907
- logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
908
- enable_asr_warmup = True
909
- elif asr_method == "sensory":
910
- fnb = "metric_spenh_snsr_" + asr_model_name + "_"
911
- logger.info(f"ASR enabled with method {asr_method} and model {asr_model_name}")
912
- enable_asr_warmup = True
913
- else:
914
- logger.error(f"Unrecognized ASR method: {asr_method}")
915
- return
916
-
917
- if enable_asr_warmup:
918
- audio = read_audio(DEFAULT_SPEECH)
919
- logger.info("Warming up asr method, note for cloud service this could take up to a few min ...")
920
- asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
921
- logger.info(f"Warmup completed, results {asr_chk}")
922
-
923
- global MP_GLOBAL
882
+ logger.info(f"Only running specified subset of {len(mixids)} mixtures")
924
883
 
925
- MP_GLOBAL.mixdb = mixdb
926
- MP_GLOBAL.predict_location = predict_location
927
- MP_GLOBAL.predict_wav_mode = predict_wav_mode
928
- MP_GLOBAL.truth_est_mode = truth_est_mode
929
- MP_GLOBAL.enable_plot = enable_plot
930
- MP_GLOBAL.enable_wav = enable_wav
931
- MP_GLOBAL.asr_method = asr_method
932
- MP_GLOBAL.asr_model_name = asr_model_name
884
+ asr_config_en = None
885
+ fnb = "metric_spenh_"
886
+ if asr_method is not None:
887
+ if asr_method in mixdb.asr_configs:
888
+ logger.info(f"Specified ASR method {asr_method} exists in mixdb.asr_configs, it will be used for ")
889
+ logger.info("prediction ASR and WER, and pre-calculated target and mixture ASR if available.")
890
+ asr_config_en = True
891
+ asr_cfg = mixdb.asr_configs[asr_method]
892
+ fnb = "metric_spenh_" + asr_method + "_"
893
+ logger.info(f"Using ASR cfg: {asr_cfg} ")
894
+ # audio = read_audio(DEFAULT_SPEECH, use_cache=True)
895
+ # logger.info(f'Warming up {asr_method}, note for cloud service this could take up to a few minutes.')
896
+ # asr_chk = calc_asr(audio, **asr_cfg)
897
+ # logger.info(f'Warmup completed, results {asr_chk}')
898
+ else:
899
+ logger.info(
900
+ f"Specified ASR method {asr_method} does not exists in mixdb.asr_configs."
901
+ f"Must choose one of the following (or none):"
902
+ )
903
+ logger.info(f"{', '.join(mixdb.asr_configs)}")
904
+ logger.error("Unrecognized ASR method, exiting.")
905
+ raise SystemExit(1)
933
906
 
934
907
  num_cpu = psutil.cpu_count()
935
908
  cpu_percent = psutil.cpu_percent(interval=1)
@@ -944,12 +917,33 @@ def main():
944
917
 
945
918
  # Individual mixtures use pandas print, set precision to 2 decimal places
946
919
  # pd.set_option('float_format', '{:.2f}'.format)
947
- logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes ...")
948
- progress = track(total=len(mixids), desc="calc_metric_spenh")
920
+ logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
921
+ # progress = tqdm(total=len(mixids), desc='calc_metric_spenh', mininterval=1)
922
+ progress = track(total=len(mixids))
949
923
  if use_cpu is None:
950
- all_metrics_tables = par_track(_process_mixture, mixids, progress=progress, no_par=True)
924
+ no_par = True
925
+ num_cpus = None
951
926
  else:
952
- all_metrics_tables = par_track(_process_mixture, mixids, progress=progress, num_cpus=use_cpu)
927
+ no_par = True
928
+ num_cpus = None
929
+
930
+ all_metrics_tables = par_track(
931
+ partial(
932
+ _process_mixture,
933
+ truth_location=truth_location,
934
+ predict_location=predict_location,
935
+ predict_wav_mode=predict_wav_mode,
936
+ truth_est_mode=truth_est_mode,
937
+ enable_plot=enable_plot,
938
+ enable_wav=enable_wav,
939
+ asr_method=asr_method,
940
+ target_f_key=target_f_key,
941
+ ),
942
+ mixids,
943
+ progress=progress,
944
+ num_cpus=num_cpus,
945
+ no_par=no_par,
946
+ )
953
947
  progress.close()
954
948
 
955
949
  all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
@@ -1010,7 +1004,7 @@ def main():
1010
1004
  all_nom99_mean["WERi%"] = 0.0
1011
1005
  else:
1012
1006
  all_nom99_mean["WERi%"] = -999.0
1013
- else: # wer%
1007
+ else: # WER%
1014
1008
  all_nom99_mean["WERi%"] = 100 * (all_nom99_mean["MXWER"] - all_nom99_mean["WER"]) / all_nom99_mean["MXWER"]
1015
1009
 
1016
1010
  num_mix = len(mixids)
@@ -1023,33 +1017,37 @@ def main():
1023
1017
  else:
1024
1018
  ofname = join(predict_location, fnb + "summary_truest.txt")
1025
1019
 
1026
- with open(ofname, "w") as f, redirect_stdout(f):
1027
- print(f"ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}")
1028
- print(f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:")
1029
- print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False))
1030
- print("\nSpeech enhancement metrics avg over each SNR:")
1031
- print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False))
1032
- print("")
1033
- print("Extraction statistics stats avg over each SNR:")
1020
+ with open(ofname, "w") as f:
1021
+ print(f"ASR enabled with method {asr_method}", file=f)
1022
+ print(
1023
+ f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:", file=f
1024
+ )
1025
+ print(
1026
+ all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f
1027
+ )
1028
+ print("\nSpeech enhancement metrics avg over each SNR:", file=f)
1029
+ print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
1030
+ print("", file=f)
1031
+ print("Extraction statistics stats avg over each SNR:", file=f)
1034
1032
  # with pd.option_context('display.max_colwidth', 9):
1035
1033
  # with pd.set_option('float_format', '{:.1f}'.format):
1036
- print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False))
1037
- print("")
1034
+ print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False), file=f)
1035
+ print("", file=f)
1038
1036
  # pd.set_option('float_format', '{:.2f}'.format)
1039
1037
 
1040
- print(f"Speech enhancement metrics stats over all {num_mix} mixtures:")
1041
- print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"))
1042
- print("")
1043
- print(f"Extraction statistics stats over all {num_mix} mixtures:")
1044
- print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"))
1045
- print("")
1038
+ print(f"Speech enhancement metrics stats over all {num_mix} mixtures:", file=f)
1039
+ print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
1040
+ print("", file=f)
1041
+ print(f"Extraction statistics stats over all {num_mix} mixtures:", file=f)
1042
+ print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)
1043
+ print("", file=f)
1046
1044
 
1047
- print("Speech enhancement metrics all-mixtures list:")
1048
- # print(all_metrics_table_1.head().style.format(precision=2))
1049
- print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
1050
- print("")
1051
- print("Extraction statistics all-mixtures list:")
1052
- print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"))
1045
+ print("Speech enhancement metrics all-mixtures list:", file=f)
1046
+ # print(all_metrics_table_1.head().style.format(precision=2), file=f)
1047
+ print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
1048
+ print("", file=f)
1049
+ print("Extraction statistics all-mixtures list:", file=f)
1050
+ print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)
1053
1051
 
1054
1052
  # Write summary to .csv file
1055
1053
  if not truth_est_mode:
@@ -1084,7 +1082,7 @@ def main():
1084
1082
  label = f"Extraction statistics stats over {num_mix} mixtures:"
1085
1083
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
1086
1084
  all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
1087
- label = f"ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}"
1085
+ label = f"ASR enabled with method {asr_method}"
1088
1086
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
1089
1087
 
1090
1088
  if not truth_est_mode:
@@ -1104,3 +1102,37 @@ def main():
1104
1102
 
1105
1103
  if __name__ == "__main__":
1106
1104
  main()
1105
+
1106
+ # if asr_method == 'none':
1107
+ # fnb = 'metric_spenh_'
1108
+ # elif asr_method == 'google':
1109
+ # fnb = 'metric_spenh_ggl_'
1110
+ # logger.info(f'ASR enabled with method {asr_method}')
1111
+ # enable_asr_warmup = True
1112
+ # elif asr_method == 'deepgram':
1113
+ # fnb = 'metric_spenh_dgram_'
1114
+ # logger.info(f'ASR enabled with method {asr_method}')
1115
+ # enable_asr_warmup = True
1116
+ # elif asr_method == 'aixplain_whisper':
1117
+ # fnb = 'metric_spenh_whspx_' + mixdb.asr_configs[asr_method]['model'] + '_'
1118
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1119
+ # enable_asr_warmup = True
1120
+ # elif asr_method == 'whisper':
1121
+ # fnb = 'metric_spenh_whspl_' + mixdb.asr_configs[asr_method]['model'] + '_'
1122
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1123
+ # enable_asr_warmup = True
1124
+ # elif asr_method == 'aaware_whisper':
1125
+ # fnb = 'metric_spenh_whspaaw_' + mixdb.asr_configs[asr_method]['model'] + '_'
1126
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1127
+ # enable_asr_warmup = True
1128
+ # elif asr_method == 'faster_whisper':
1129
+ # fnb = 'metric_spenh_fwhsp_' + mixdb.asr_configs[asr_method]['model'] + '_'
1130
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1131
+ # enable_asr_warmup = True
1132
+ # elif asr_method == 'sensory':
1133
+ # fnb = 'metric_spenh_snsr_' + mixdb.asr_configs[asr_method]['model'] + '_'
1134
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1135
+ # enable_asr_warmup = True
1136
+ # else:
1137
+ # logger.error(f'Unrecognized ASR method: {asr_method}')
1138
+ # return