sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/genmixdb.yml +4 -2
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/doc/doc.py +14 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/ir_metric.py +555 -0
- sonusai/metrics_summary.py +322 -0
- sonusai/mixture/__init__.py +6 -2
- sonusai/mixture/audio.py +139 -15
- sonusai/mixture/augmentation.py +199 -84
- sonusai/mixture/config.py +9 -4
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +19 -10
- sonusai/mixture/generation.py +52 -64
- sonusai/mixture/helpers.py +38 -26
- sonusai/mixture/ir_delay.py +63 -0
- sonusai/mixture/mixdb.py +190 -46
- sonusai/mixture/targets.py +3 -6
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- sonusai/utils/temp_seed.py +13 -0
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
- sonusai/mixture/soundfile_audio.py +0 -130
- sonusai/mixture/sox_audio.py +0 -476
- sonusai/mixture/sox_augmentation.py +0 -136
- sonusai/mixture/torchaudio_audio.py +0 -106
- sonusai/mixture/torchaudio_augmentation.py +0 -109
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/calc_metric_spenh.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""sonusai calc_metric_spenh
|
2
2
|
|
3
|
-
usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-
|
3
|
+
usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-n NCPU] PLOC TLOC
|
4
4
|
|
5
5
|
options:
|
6
6
|
-h, --help
|
@@ -11,8 +11,10 @@ options:
|
|
11
11
|
-w, --wav Generate WAV files per mixture.
|
12
12
|
-s, --summary Enable summary files generation.
|
13
13
|
-n, --num_process NCPU Number of parallel processes to use [default: auto]
|
14
|
-
-e ASR, --asr-method ASR ASR method
|
15
|
-
|
14
|
+
-e ASR, --asr-method ASR ASR method used for WER metrics. Must exist in the TLOC dataset as pre-calculated
|
15
|
+
metrics using SonusAI genmetrics. Can be either an integer index, e.g. 0,1,... or the
|
16
|
+
name of the asr_engine configuration in the dataset. If an incorrect name is specified,
|
17
|
+
a list of asr_engines of the dataset will be printed.
|
16
18
|
|
17
19
|
Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
|
18
20
|
reference. Metric and extraction data files are written into PLOC.
|
@@ -20,9 +22,14 @@ reference. Metric and extraction data files are written into PLOC.
|
|
20
22
|
PLOC directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
|
21
23
|
TLOC directory with SonusAI mixture database of truth/label mixture data
|
22
24
|
|
23
|
-
For
|
24
|
-
|
25
|
-
|
25
|
+
For ASR methods, the method must be defined in the TLOC dataset, for example possible faster_whisper available models are:
|
26
|
+
{tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large} and an example configuration looks like:
|
27
|
+
{'fwhsptiny_cpu': {'engine': 'faster_whisper',
|
28
|
+
'model': 'tiny',
|
29
|
+
'device': 'cpu',
|
30
|
+
'beam_size': 5}}
|
31
|
+
Note: the ASR config can optionally include the model, device, and other fields the engine supports.
|
32
|
+
Most ASR are very computationally demanding and can overwhelm/hang a local system.
|
26
33
|
|
27
34
|
Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
|
28
35
|
<id>_metric_spenh.txt
|
@@ -61,8 +68,6 @@ Inputs:
|
|
61
68
|
"""
|
62
69
|
|
63
70
|
import signal
|
64
|
-
from contextlib import redirect_stdout
|
65
|
-
from dataclasses import dataclass
|
66
71
|
|
67
72
|
import matplotlib
|
68
73
|
import matplotlib.pyplot as plt
|
@@ -93,24 +98,17 @@ signal.signal(signal.SIGINT, signal_handler)
|
|
93
98
|
matplotlib.use("SVG")
|
94
99
|
|
95
100
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
predict_wav_mode: bool
|
101
|
-
truth_est_mode: bool
|
102
|
-
enable_plot: bool
|
103
|
-
enable_wav: bool
|
104
|
-
asr_method: str
|
105
|
-
asr_model_name: str
|
106
|
-
|
107
|
-
|
108
|
-
MP_GLOBAL: MPGlobal
|
101
|
+
def first_key(x: dict) -> str:
|
102
|
+
for key in x:
|
103
|
+
return key
|
104
|
+
raise KeyError("No key found")
|
109
105
|
|
110
106
|
|
111
107
|
def mean_square_error(
|
112
|
-
hypothesis: np.ndarray,
|
113
|
-
|
108
|
+
hypothesis: np.ndarray,
|
109
|
+
reference: np.ndarray,
|
110
|
+
squared: bool = False,
|
111
|
+
) -> tuple[float, np.ndarray, np.ndarray]:
|
114
112
|
"""Calculate root-mean-square error or mean square error
|
115
113
|
|
116
114
|
:param hypothesis: [frames, bins]
|
@@ -125,7 +123,7 @@ def mean_square_error(
|
|
125
123
|
# mean over bins for value per frame
|
126
124
|
err_f = np.mean(sq_err, axis=1)
|
127
125
|
# mean over all
|
128
|
-
err = np.mean(sq_err)
|
126
|
+
err = float(np.mean(sq_err))
|
129
127
|
|
130
128
|
if not squared:
|
131
129
|
err_b = np.sqrt(err_b)
|
@@ -135,9 +133,7 @@ def mean_square_error(
|
|
135
133
|
return err, err_b, err_f
|
136
134
|
|
137
135
|
|
138
|
-
def mean_abs_percentage_error(
|
139
|
-
hypothesis: np.ndarray, reference: np.ndarray
|
140
|
-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
136
|
+
def mean_abs_percentage_error(hypothesis: np.ndarray, reference: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
|
141
137
|
"""Calculate mean abs percentage error
|
142
138
|
|
143
139
|
If inputs are complex, calculates average: mape(real)/2 + mape(imag)/2
|
@@ -162,12 +158,12 @@ def mean_abs_percentage_error(
|
|
162
158
|
# mean over bins for value per frame
|
163
159
|
err_f = np.around(np.mean(abs_err, axis=1), 3)
|
164
160
|
# mean over all
|
165
|
-
err = np.around(np.mean(abs_err), 3)
|
161
|
+
err = float(np.around(np.mean(abs_err), 3))
|
166
162
|
|
167
163
|
return err, err_b, err_f
|
168
164
|
|
169
165
|
|
170
|
-
def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[
|
166
|
+
def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
|
171
167
|
"""Calculate log error
|
172
168
|
|
173
169
|
:param reference: complex or real [frames, bins]
|
@@ -184,7 +180,7 @@ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray
|
|
184
180
|
# mean over bins for value per frame
|
185
181
|
err_f = np.around(np.mean(log_err, axis=1), 3)
|
186
182
|
# mean over all
|
187
|
-
err = np.around(np.mean(log_err), 3)
|
183
|
+
err = float(np.around(np.mean(log_err), 3))
|
188
184
|
|
189
185
|
return err, err_b, err_f
|
190
186
|
|
@@ -196,7 +192,7 @@ def plot_mixpred(
|
|
196
192
|
feature: Feature | None = None,
|
197
193
|
predict: Predict | None = None,
|
198
194
|
tp_title: str = "",
|
199
|
-
) -> plt.Figure:
|
195
|
+
) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
|
200
196
|
from sonusai.mixture import SAMPLE_RATE
|
201
197
|
|
202
198
|
num_plots = 2
|
@@ -224,22 +220,12 @@ def plot_mixpred(
|
|
224
220
|
|
225
221
|
if feature is not None:
|
226
222
|
p += 1
|
227
|
-
ax[p].imshow(
|
228
|
-
np.transpose(feature),
|
229
|
-
aspect="auto",
|
230
|
-
interpolation="nearest",
|
231
|
-
origin="lower",
|
232
|
-
)
|
223
|
+
ax[p].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
|
233
224
|
ax[p].set_title("Feature")
|
234
225
|
|
235
226
|
if predict is not None:
|
236
227
|
p += 1
|
237
|
-
im = ax[p].imshow(
|
238
|
-
np.transpose(predict),
|
239
|
-
aspect="auto",
|
240
|
-
interpolation="nearest",
|
241
|
-
origin="lower",
|
242
|
-
)
|
228
|
+
im = ax[p].imshow(np.transpose(predict), aspect="auto", interpolation="nearest", origin="lower")
|
243
229
|
ax[p].set_title("Predict " + tp_title)
|
244
230
|
plt.colorbar(im, location="bottom")
|
245
231
|
|
@@ -251,7 +237,7 @@ def plot_pdb_predict_truth(
|
|
251
237
|
truth_f: np.ndarray | None = None,
|
252
238
|
metric: np.ndarray | None = None,
|
253
239
|
tp_title: str = "",
|
254
|
-
) -> plt.Figure:
|
240
|
+
) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
|
255
241
|
"""Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
|
256
242
|
num_plots = 2
|
257
243
|
if truth_f is not None:
|
@@ -277,24 +263,12 @@ def plot_pdb_predict_truth(
|
|
277
263
|
pred_avg = 10 * np.log10(np.mean(predict, axis=-1) + np.finfo(np.float32).eps)
|
278
264
|
p += 1
|
279
265
|
x_axis = np.arange(len(pred_avg), dtype=np.float32) # / SAMPLE_RATE
|
280
|
-
ax[p].plot(
|
281
|
-
x_axis,
|
282
|
-
pred_avg,
|
283
|
-
color="black",
|
284
|
-
linestyle="dashed",
|
285
|
-
label="Predict mean over freq.",
|
286
|
-
)
|
266
|
+
ax[p].plot(x_axis, pred_avg, color="black", linestyle="dashed", label="Predict mean over freq.")
|
287
267
|
ax[p].set_ylabel("mean db", color="black")
|
288
268
|
ax[p].set_xlim(x_axis[0], x_axis[-1])
|
289
269
|
if truth_f is not None:
|
290
270
|
truth_avg = 10 * np.log10(np.mean(truth_f, axis=-1) + np.finfo(np.float32).eps)
|
291
|
-
ax[p].plot(
|
292
|
-
x_axis,
|
293
|
-
truth_avg,
|
294
|
-
color="green",
|
295
|
-
linestyle="dashed",
|
296
|
-
label="Truth mean over freq.",
|
297
|
-
)
|
271
|
+
ax[p].plot(x_axis, truth_avg, color="green", linestyle="dashed", label="Truth mean over freq.")
|
298
272
|
|
299
273
|
if metric is not None: # instantiate 2nd y-axis that shares the same x-axis
|
300
274
|
ax2 = ax[p].twinx()
|
@@ -317,7 +291,7 @@ def plot_e_predict_truth(
|
|
317
291
|
truth_wav: np.ndarray | None = None,
|
318
292
|
metric: np.ndarray | None = None,
|
319
293
|
tp_title: str = "",
|
320
|
-
) -> plt.Figure:
|
294
|
+
) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
|
321
295
|
"""Plot predict spectrogram and waveform and optionally truth and a metric)"""
|
322
296
|
num_plots = 2
|
323
297
|
if truth_f is not None:
|
@@ -335,13 +309,7 @@ def plot_e_predict_truth(
|
|
335
309
|
|
336
310
|
if truth_f is not None: # plot truth if provided and use same colormap as predict
|
337
311
|
p += 1
|
338
|
-
ax[p].imshow(
|
339
|
-
truth_f.transpose(),
|
340
|
-
im.cmap,
|
341
|
-
aspect="auto",
|
342
|
-
interpolation="nearest",
|
343
|
-
origin="lower",
|
344
|
-
)
|
312
|
+
ax[p].imshow(truth_f.transpose(), im.cmap, aspect="auto", interpolation="nearest", origin="lower")
|
345
313
|
ax[p].set_title("Truth")
|
346
314
|
|
347
315
|
# Plot predict wav, and optionally truth avg and metric lines
|
@@ -383,7 +351,17 @@ def plot_e_predict_truth(
|
|
383
351
|
return fig
|
384
352
|
|
385
353
|
|
386
|
-
def _process_mixture(
|
354
|
+
def _process_mixture(
|
355
|
+
m_id: int,
|
356
|
+
truth_location: str,
|
357
|
+
predict_location: str,
|
358
|
+
predict_wav_mode: bool,
|
359
|
+
truth_est_mode: bool,
|
360
|
+
enable_plot: bool,
|
361
|
+
enable_wav: bool,
|
362
|
+
asr_method: str,
|
363
|
+
target_f_key: str,
|
364
|
+
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
387
365
|
import pickle
|
388
366
|
from os.path import basename
|
389
367
|
from os.path import join
|
@@ -412,19 +390,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
412
390
|
from sonusai.utils import unstack_complex
|
413
391
|
from sonusai.utils import write_audio
|
414
392
|
|
415
|
-
|
416
|
-
|
417
|
-
mixdb = MP_GLOBAL.mixdb
|
418
|
-
predict_location = MP_GLOBAL.predict_location
|
419
|
-
predict_wav_mode = MP_GLOBAL.predict_wav_mode
|
420
|
-
truth_est_mode = MP_GLOBAL.truth_est_mode
|
421
|
-
enable_plot = MP_GLOBAL.enable_plot
|
422
|
-
enable_wav = MP_GLOBAL.enable_wav
|
423
|
-
asr_method = MP_GLOBAL.asr_method
|
424
|
-
asr_model_name = MP_GLOBAL.asr_model_name
|
393
|
+
mixdb = MixtureDatabase(truth_location)
|
425
394
|
|
426
|
-
# 1) Read predict data, var predict with shape [BatchSize,Classes] or [
|
427
|
-
output_name = join(predict_location, mixdb.mixture(
|
395
|
+
# 1) Read predict data, var predict with shape [BatchSize,Classes] or [batch, timesteps, classes]
|
396
|
+
output_name = join(predict_location, mixdb.mixture(m_id).name + ".h5")
|
428
397
|
predict = None
|
429
398
|
if truth_est_mode:
|
430
399
|
# in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
|
@@ -439,31 +408,31 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
439
408
|
predict = np.array(f["predict"])
|
440
409
|
except Exception as e:
|
441
410
|
raise OSError(f"Error reading {output_name}: {e}") from e
|
442
|
-
# reshape to always be [frames,classes] where ndim==3 case frames = batch *
|
411
|
+
# reshape to always be [frames, classes] where ndim==3 case frames = batch * timesteps
|
443
412
|
if predict.ndim > 2: # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
|
444
413
|
# logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
|
445
414
|
predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
|
446
415
|
else:
|
447
416
|
base_name, ext = splitext(output_name)
|
448
417
|
predict_name = join(base_name + ".wav")
|
449
|
-
audio = read_audio(predict_name)
|
418
|
+
audio = read_audio(predict_name, use_cache=True)
|
450
419
|
predict = forward_transform(audio, mixdb.ft_config)
|
451
420
|
if mixdb.feature[0:1] == "h":
|
452
421
|
predict = power_compress(predict)
|
453
422
|
predict = stack_complex(predict)
|
454
423
|
|
455
424
|
# 2) Collect true target, noise, mixture data, trim to predict size if needed
|
456
|
-
tmp = mixdb.mixture_targets(
|
457
|
-
target_f = mixdb.mixture_targets_f(
|
425
|
+
tmp = mixdb.mixture_targets(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
|
426
|
+
target_f = mixdb.mixture_targets_f(m_id, targets=tmp)[0]
|
458
427
|
target = tmp[0]
|
459
|
-
mixture = mixdb.mixture_mixture(
|
428
|
+
mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
|
460
429
|
# noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
|
461
430
|
# noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
|
462
431
|
noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
|
463
432
|
# noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
|
464
433
|
# note: uses pre-IR, pre-specaug audio
|
465
|
-
segsnr_f
|
466
|
-
mixture_f = mixdb.mixture_mixture_f(
|
434
|
+
segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"][0]
|
435
|
+
mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
|
467
436
|
noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
|
468
437
|
# segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
|
469
438
|
segsnr_f[segsnr_f == np.inf] = DB_99
|
@@ -476,13 +445,21 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
476
445
|
|
477
446
|
# gen feature, truth - note feature only used for plots
|
478
447
|
# TODO: parse truth_f for different formats
|
479
|
-
feature,
|
448
|
+
feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
|
449
|
+
truth_f = truth_all[target_f_key]
|
450
|
+
if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
|
451
|
+
if truth_f.shape[1] != 1:
|
452
|
+
logger.info("Error: target_f truth has stride > 1, exiting.")
|
453
|
+
raise SystemExit(1)
|
454
|
+
else:
|
455
|
+
truth_f = truth_f[:, 0, :] # remove stride dimension
|
456
|
+
|
480
457
|
# ignore mixup
|
481
|
-
for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
458
|
+
# for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
|
459
|
+
# if truth_setting.function == 'target_mixture_f':
|
460
|
+
# half = truth_f.shape[-1] // 2
|
461
|
+
# # extract target_f only
|
462
|
+
# truth_f = truth_f[..., :half]
|
486
463
|
|
487
464
|
if not truth_est_mode:
|
488
465
|
if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
|
@@ -503,15 +480,17 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
503
480
|
)
|
504
481
|
trim_f = predict.shape[0] - target_f.shape[0]
|
505
482
|
predict = predict[0:-trim_f, :]
|
506
|
-
# raise
|
483
|
+
# raise SonusAIError(
|
507
484
|
# f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
|
508
485
|
|
509
486
|
# 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
|
510
487
|
if truth_est_mode:
|
511
488
|
predict = truth_f # substitute truth for the prediction (for test/debug)
|
512
489
|
predict_complex = unstack_complex(predict) # unstack
|
513
|
-
# if
|
514
|
-
if mixdb.feature[0:1] == "h" and mixdb.
|
490
|
+
# if feature has compressed mag and truth does not, compress it
|
491
|
+
if mixdb.feature[0:1] == "h" and not mixdb.truth_configs[first_key(mixdb.truth_configs)].function.startswith(
|
492
|
+
"targetcmpr"
|
493
|
+
):
|
515
494
|
predict_complex = power_compress(predict_complex) # from uncompressed truth
|
516
495
|
else:
|
517
496
|
predict_complex = unstack_complex(predict)
|
@@ -556,10 +535,14 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
556
535
|
# logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
|
557
536
|
|
558
537
|
# Speech intelligibility measure - PESQ
|
559
|
-
if int(mixdb.mixture(
|
538
|
+
if int(mixdb.mixture(m_id).snr) > -99:
|
560
539
|
# len = target_est_wav.shape[0]
|
561
540
|
pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
|
562
|
-
|
541
|
+
metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
|
542
|
+
pesq_mixture = metrics["mxpesq"]
|
543
|
+
csig_mx = metrics["mxcsig"]
|
544
|
+
cbak_mx = metrics["mxcbak"]
|
545
|
+
covl_mx = metrics["mxcovl"]
|
563
546
|
# pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
|
564
547
|
# pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
|
565
548
|
# pesq improvement
|
@@ -581,25 +564,37 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
581
564
|
asr_tt = None
|
582
565
|
asr_mx = None
|
583
566
|
asr_tge = None
|
584
|
-
asr_engines = list(mixdb.asr_configs.keys())
|
585
|
-
if
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
567
|
+
# asr_engines = list(mixdb.asr_configs.keys())
|
568
|
+
if asr_method is not None and mixdb.mixture(m_id).snr >= -96: # noise only, ignore/reset target ASR
|
569
|
+
asr_mx_name = f"mxasr.{asr_method}"
|
570
|
+
wer_mx_name = f"mxwer.{asr_method}"
|
571
|
+
asr_tt_name = f"tasr.{asr_method}"
|
572
|
+
metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
|
573
|
+
asr_mx = metrics[asr_mx_name][0]
|
574
|
+
wer_mx = metrics[wer_mx_name][0]
|
575
|
+
asr_tt = metrics[asr_tt_name][0]
|
590
576
|
|
591
577
|
if asr_tt:
|
592
|
-
|
593
|
-
|
578
|
+
noiseadd = None # TBD add as switch, default -30
|
579
|
+
if noiseadd is not None:
|
580
|
+
ngain = np.power(10, min(float(noiseadd), 0.0) / 20.0) # limit to gain <1, convert to float
|
581
|
+
tgasr_est_wav = target_est_wav + ngain * noise_est_wav # add back noise at low level
|
582
|
+
else:
|
583
|
+
tgasr_est_wav = target_est_wav
|
584
|
+
|
585
|
+
# logger.info(f'Calculating prediction ASR for mixid {mixid}')
|
586
|
+
asr_cfg = mixdb.asr_configs[asr_method]
|
587
|
+
asr_tge = calc_asr(tgasr_est_wav, **asr_cfg).text
|
588
|
+
wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate WER
|
594
589
|
if wer_mx == 0.0:
|
595
590
|
if wer_tge == 0.0:
|
596
591
|
wer_pi = 0.0
|
597
592
|
else:
|
598
|
-
wer_pi = -999.0
|
593
|
+
wer_pi = -999.0 # instead of -Inf
|
599
594
|
else:
|
600
595
|
wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
|
601
596
|
else:
|
602
|
-
|
597
|
+
logger.warning(f"Warning: mixid {m_id} ASR truth is empty, setting to 0% WER")
|
603
598
|
wer_mx = float(0)
|
604
599
|
wer_tge = float(0)
|
605
600
|
wer_pi = float(0)
|
@@ -633,10 +628,10 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
633
628
|
"SPFILE",
|
634
629
|
"NFILE",
|
635
630
|
]
|
636
|
-
ti = mixdb.mixture(
|
637
|
-
ni = mixdb.mixture(
|
631
|
+
ti = mixdb.mixture(m_id).targets[0].file_id
|
632
|
+
ni = mixdb.mixture(m_id).noise.file_id
|
638
633
|
metr1 = [
|
639
|
-
mixdb.mixture(
|
634
|
+
mixdb.mixture(m_id).snr,
|
640
635
|
pesq_mixture,
|
641
636
|
pesq_speech,
|
642
637
|
pesq_impr_pc,
|
@@ -658,17 +653,11 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
658
653
|
basename(mixdb.target_file(ti).name),
|
659
654
|
basename(mixdb.noise_file(ni).name),
|
660
655
|
]
|
661
|
-
mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[
|
656
|
+
mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
|
662
657
|
|
663
658
|
# Stats of per frame estimation metrics
|
664
659
|
metr2 = pd.DataFrame(
|
665
|
-
{
|
666
|
-
"SSNR": segsnr_f,
|
667
|
-
"PCM": pcm_frame,
|
668
|
-
"SLERR": lerr_tg_frame,
|
669
|
-
"NLERR": lerr_n_frame,
|
670
|
-
"SPD": phd_frame,
|
671
|
-
}
|
660
|
+
{"SSNR": segsnr_f, "PCM": pcm_frame, "SLERR": lerr_tg_frame, "NLERR": lerr_n_frame, "SPD": phd_frame}
|
672
661
|
)
|
673
662
|
metr2 = metr2.describe() # Use pandas stat function
|
674
663
|
# Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
|
@@ -679,29 +668,33 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
679
668
|
[metr2.columns, ["Avg", "Min", "Med", "Max", "Std"]], names=["Metric", "Stat"]
|
680
669
|
)
|
681
670
|
dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
|
682
|
-
mtab2 = pd.DataFrame(dat1row, index=[
|
683
|
-
mtab2.insert(0, "MXSNR", mixdb.mixture(
|
671
|
+
mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
|
672
|
+
mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).snr, False) # add MXSNR as the first metric column
|
684
673
|
|
685
674
|
all_metrics_table_1 = mtab1 # return to be collected by process
|
686
675
|
all_metrics_table_2 = mtab2 # return to be collected by process
|
687
676
|
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
print(
|
695
|
-
print("")
|
696
|
-
print(
|
697
|
-
print(f"
|
677
|
+
if asr_method is None:
|
678
|
+
metric_name = base_name + "_metric_spenh.txt"
|
679
|
+
else:
|
680
|
+
metric_name = base_name + "_metric_spenh_" + asr_method + ".txt"
|
681
|
+
|
682
|
+
with open(metric_name, "w") as f:
|
683
|
+
print("Speech enhancement metrics:", file=f)
|
684
|
+
print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
685
|
+
print("", file=f)
|
686
|
+
print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
|
687
|
+
print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
688
|
+
print("", file=f)
|
689
|
+
print(f"Target path: {mixdb.target_file(ti).name}", file=f)
|
690
|
+
print(f"Noise path: {mixdb.noise_file(ni).name}", file=f)
|
698
691
|
if asr_method != "none":
|
699
|
-
print(f"ASR method: {asr_method}
|
700
|
-
print(f"ASR truth: {asr_tt}")
|
701
|
-
print(f"ASR result for mixture: {asr_mx}")
|
702
|
-
print(f"ASR result for prediction: {asr_tge}")
|
692
|
+
print(f"ASR method: {asr_method}", file=f)
|
693
|
+
print(f"ASR truth: {asr_tt}", file=f)
|
694
|
+
print(f"ASR result for mixture: {asr_mx}", file=f)
|
695
|
+
print(f"ASR result for prediction: {asr_tge}", file=f)
|
703
696
|
|
704
|
-
print(f"Augmentations: {mixdb.mixture(
|
697
|
+
print(f"Augmentations: {mixdb.mixture(m_id)}", file=f)
|
705
698
|
|
706
699
|
# 7) write wav files
|
707
700
|
if enable_wav:
|
@@ -728,7 +721,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
728
721
|
# Reshape to get frames*decimated_stride, num_bands
|
729
722
|
step = int(mixdb.feature_samples / mixdb.feature_step_samples)
|
730
723
|
if feature.ndim != 3:
|
731
|
-
raise
|
724
|
+
raise OSError("feature does not have 3 dimensions: frames, stride, num_bands")
|
732
725
|
|
733
726
|
# for feature cn*00n**
|
734
727
|
feat_sgram = unstack_complex(feature)
|
@@ -738,17 +731,19 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
738
731
|
|
739
732
|
with PdfPages(plot_name) as pdf:
|
740
733
|
# page1 we always have a mixture and prediction, target optional if truth provided
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
734
|
+
# For speech enhancement, target_f is definitely included:
|
735
|
+
predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
736
|
+
tfunc_name = "target_f"
|
737
|
+
# if tfunc_name == 'mapped_snr_f':
|
738
|
+
# # leave as unmapped snr
|
739
|
+
# predplot = predict
|
740
|
+
# tfunc_name = mixdb.target_file(1).truth_settings[0].function
|
741
|
+
# elif tfunc_name == 'target_f' or 'target_mixture_f':
|
742
|
+
# predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
743
|
+
# else:
|
744
|
+
# # use dB scale
|
745
|
+
# predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
|
746
|
+
# tfunc_name = tfunc_name + ' (db)'
|
752
747
|
|
753
748
|
mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
|
754
749
|
fig_obj = plot_mixpred(
|
@@ -816,8 +811,7 @@ def main():
|
|
816
811
|
|
817
812
|
verbose = args["--verbose"]
|
818
813
|
mixids = args["--mixid"]
|
819
|
-
asr_method = args["--asr-method"]
|
820
|
-
asr_model_name = args["--model"].lower()
|
814
|
+
asr_method = args["--asr-method"]
|
821
815
|
truth_est_mode = args["--truth-est-mode"]
|
822
816
|
enable_plot = args["--plot"]
|
823
817
|
enable_wav = args["--wav"]
|
@@ -827,6 +821,7 @@ def main():
|
|
827
821
|
truth_location = args["TLOC"]
|
828
822
|
|
829
823
|
import glob
|
824
|
+
from functools import partial
|
830
825
|
from os.path import basename
|
831
826
|
from os.path import isdir
|
832
827
|
from os.path import join
|
@@ -837,16 +832,13 @@ def main():
|
|
837
832
|
from sonusai import initial_log_messages
|
838
833
|
from sonusai import logger
|
839
834
|
from sonusai import update_console_handler
|
840
|
-
from sonusai.mixture import DEFAULT_SPEECH
|
841
835
|
from sonusai.mixture import MixtureDatabase
|
842
|
-
from sonusai.mixture import read_audio
|
843
|
-
from sonusai.utils import calc_asr
|
844
836
|
from sonusai.utils import par_track
|
845
837
|
from sonusai.utils import track
|
846
838
|
|
847
839
|
# Check prediction subdirectory
|
848
840
|
if not isdir(predict_location):
|
849
|
-
print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting
|
841
|
+
print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting.")
|
850
842
|
|
851
843
|
# all_predict_files = listdir(predict_location)
|
852
844
|
all_predict_files = glob.glob(predict_location + "/*.h5")
|
@@ -855,7 +847,7 @@ def main():
|
|
855
847
|
if len(all_predict_files) <= 0 and not truth_est_mode:
|
856
848
|
all_predict_files = glob.glob(predict_location + "/*.wav") # check for wav files
|
857
849
|
if len(all_predict_files) <= 0:
|
858
|
-
print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting
|
850
|
+
print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting.")
|
859
851
|
else:
|
860
852
|
logger.info(f"Found {len(all_predict_files)} prediction .wav files.")
|
861
853
|
predict_wav_mode = True
|
@@ -877,59 +869,40 @@ def main():
|
|
877
869
|
logger.info(
|
878
870
|
f"Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}"
|
879
871
|
)
|
880
|
-
|
872
|
+
# speech enhancement metrics and audio truth requires target_f truth type, check it is present
|
873
|
+
target_f_key = None
|
874
|
+
logger.info(f"mixdb has {len(mixdb.truth_configs)} truth types defined, checking that target_f type is present.")
|
875
|
+
for key in mixdb.truth_configs:
|
876
|
+
if mixdb.truth_configs[key].function == "target_f":
|
877
|
+
target_f_key = key
|
878
|
+
if target_f_key is None:
|
879
|
+
logger.error("mixdb does not have target_f truth define, required for speech enhancement metrics, exiting.")
|
880
|
+
raise SystemExit(1)
|
881
881
|
|
882
|
-
|
883
|
-
if asr_method == "none":
|
884
|
-
fnb = "metric_spenh_"
|
885
|
-
elif asr_method == "google":
|
886
|
-
fnb = "metric_spenh_ggl_"
|
887
|
-
logger.info(f"ASR enabled with method {asr_method}")
|
888
|
-
enable_asr_warmup = True
|
889
|
-
elif asr_method == "deepgram":
|
890
|
-
fnb = "metric_spenh_dgram_"
|
891
|
-
logger.info(f"ASR enabled with method {asr_method}")
|
892
|
-
enable_asr_warmup = True
|
893
|
-
elif asr_method == "aixplain_whisper":
|
894
|
-
fnb = "metric_spenh_whspx_" + asr_model_name + "_"
|
895
|
-
logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
|
896
|
-
enable_asr_warmup = True
|
897
|
-
elif asr_method == "whisper":
|
898
|
-
fnb = "metric_spenh_whspl_" + asr_model_name + "_"
|
899
|
-
logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
|
900
|
-
enable_asr_warmup = True
|
901
|
-
elif asr_method == "aaware_whisper":
|
902
|
-
fnb = "metric_spenh_whspaaw_" + asr_model_name + "_"
|
903
|
-
logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
|
904
|
-
enable_asr_warmup = True
|
905
|
-
elif asr_method == "faster_whisper":
|
906
|
-
fnb = "metric_spenh_fwhsp_" + asr_model_name + "_"
|
907
|
-
logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
|
908
|
-
enable_asr_warmup = True
|
909
|
-
elif asr_method == "sensory":
|
910
|
-
fnb = "metric_spenh_snsr_" + asr_model_name + "_"
|
911
|
-
logger.info(f"ASR enabled with method {asr_method} and model {asr_model_name}")
|
912
|
-
enable_asr_warmup = True
|
913
|
-
else:
|
914
|
-
logger.error(f"Unrecognized ASR method: {asr_method}")
|
915
|
-
return
|
916
|
-
|
917
|
-
if enable_asr_warmup:
|
918
|
-
audio = read_audio(DEFAULT_SPEECH)
|
919
|
-
logger.info("Warming up asr method, note for cloud service this could take up to a few min ...")
|
920
|
-
asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
|
921
|
-
logger.info(f"Warmup completed, results {asr_chk}")
|
922
|
-
|
923
|
-
global MP_GLOBAL
|
882
|
+
logger.info(f"Only running specified subset of {len(mixids)} mixtures")
|
924
883
|
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
884
|
+
asr_config_en = None
|
885
|
+
fnb = "metric_spenh_"
|
886
|
+
if asr_method is not None:
|
887
|
+
if asr_method in mixdb.asr_configs:
|
888
|
+
logger.info(f"Specified ASR method {asr_method} exists in mixdb.asr_configs, it will be used for ")
|
889
|
+
logger.info("prediction ASR and WER, and pre-calculated target and mixture ASR if available.")
|
890
|
+
asr_config_en = True
|
891
|
+
asr_cfg = mixdb.asr_configs[asr_method]
|
892
|
+
fnb = "metric_spenh_" + asr_method + "_"
|
893
|
+
logger.info(f"Using ASR cfg: {asr_cfg} ")
|
894
|
+
# audio = read_audio(DEFAULT_SPEECH, use_cache=True)
|
895
|
+
# logger.info(f'Warming up {asr_method}, note for cloud service this could take up to a few minutes.')
|
896
|
+
# asr_chk = calc_asr(audio, **asr_cfg)
|
897
|
+
# logger.info(f'Warmup completed, results {asr_chk}')
|
898
|
+
else:
|
899
|
+
logger.info(
|
900
|
+
f"Specified ASR method {asr_method} does not exists in mixdb.asr_configs."
|
901
|
+
f"Must choose one of the following (or none):"
|
902
|
+
)
|
903
|
+
logger.info(f"{', '.join(mixdb.asr_configs)}")
|
904
|
+
logger.error("Unrecognized ASR method, exiting.")
|
905
|
+
raise SystemExit(1)
|
933
906
|
|
934
907
|
num_cpu = psutil.cpu_count()
|
935
908
|
cpu_percent = psutil.cpu_percent(interval=1)
|
@@ -944,12 +917,33 @@ def main():
|
|
944
917
|
|
945
918
|
# Individual mixtures use pandas print, set precision to 2 decimal places
|
946
919
|
# pd.set_option('float_format', '{:.2f}'.format)
|
947
|
-
logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes
|
948
|
-
progress =
|
920
|
+
logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
|
921
|
+
# progress = tqdm(total=len(mixids), desc='calc_metric_spenh', mininterval=1)
|
922
|
+
progress = track(total=len(mixids))
|
949
923
|
if use_cpu is None:
|
950
|
-
|
924
|
+
no_par = True
|
925
|
+
num_cpus = None
|
951
926
|
else:
|
952
|
-
|
927
|
+
no_par = True
|
928
|
+
num_cpus = None
|
929
|
+
|
930
|
+
all_metrics_tables = par_track(
|
931
|
+
partial(
|
932
|
+
_process_mixture,
|
933
|
+
truth_location=truth_location,
|
934
|
+
predict_location=predict_location,
|
935
|
+
predict_wav_mode=predict_wav_mode,
|
936
|
+
truth_est_mode=truth_est_mode,
|
937
|
+
enable_plot=enable_plot,
|
938
|
+
enable_wav=enable_wav,
|
939
|
+
asr_method=asr_method,
|
940
|
+
target_f_key=target_f_key,
|
941
|
+
),
|
942
|
+
mixids,
|
943
|
+
progress=progress,
|
944
|
+
num_cpus=num_cpus,
|
945
|
+
no_par=no_par,
|
946
|
+
)
|
953
947
|
progress.close()
|
954
948
|
|
955
949
|
all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
|
@@ -1010,7 +1004,7 @@ def main():
|
|
1010
1004
|
all_nom99_mean["WERi%"] = 0.0
|
1011
1005
|
else:
|
1012
1006
|
all_nom99_mean["WERi%"] = -999.0
|
1013
|
-
else: #
|
1007
|
+
else: # WER%
|
1014
1008
|
all_nom99_mean["WERi%"] = 100 * (all_nom99_mean["MXWER"] - all_nom99_mean["WER"]) / all_nom99_mean["MXWER"]
|
1015
1009
|
|
1016
1010
|
num_mix = len(mixids)
|
@@ -1023,33 +1017,37 @@ def main():
|
|
1023
1017
|
else:
|
1024
1018
|
ofname = join(predict_location, fnb + "summary_truest.txt")
|
1025
1019
|
|
1026
|
-
with open(ofname, "w") as f
|
1027
|
-
print(f"ASR enabled with method {asr_method},
|
1028
|
-
print(
|
1029
|
-
|
1030
|
-
|
1031
|
-
print(
|
1032
|
-
|
1033
|
-
|
1020
|
+
with open(ofname, "w") as f:
|
1021
|
+
print(f"ASR enabled with method {asr_method}", file=f)
|
1022
|
+
print(
|
1023
|
+
f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:", file=f
|
1024
|
+
)
|
1025
|
+
print(
|
1026
|
+
all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f
|
1027
|
+
)
|
1028
|
+
print("\nSpeech enhancement metrics avg over each SNR:", file=f)
|
1029
|
+
print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
|
1030
|
+
print("", file=f)
|
1031
|
+
print("Extraction statistics stats avg over each SNR:", file=f)
|
1034
1032
|
# with pd.option_context('display.max_colwidth', 9):
|
1035
1033
|
# with pd.set_option('float_format', '{:.1f}'.format):
|
1036
|
-
print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False))
|
1037
|
-
print("")
|
1034
|
+
print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False), file=f)
|
1035
|
+
print("", file=f)
|
1038
1036
|
# pd.set_option('float_format', '{:.2f}'.format)
|
1039
1037
|
|
1040
|
-
print(f"Speech enhancement metrics stats over all {num_mix} mixtures:")
|
1041
|
-
print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"))
|
1042
|
-
print("")
|
1043
|
-
print(f"Extraction statistics stats over all {num_mix} mixtures:")
|
1044
|
-
print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"))
|
1045
|
-
print("")
|
1038
|
+
print(f"Speech enhancement metrics stats over all {num_mix} mixtures:", file=f)
|
1039
|
+
print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
1040
|
+
print("", file=f)
|
1041
|
+
print(f"Extraction statistics stats over all {num_mix} mixtures:", file=f)
|
1042
|
+
print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)
|
1043
|
+
print("", file=f)
|
1046
1044
|
|
1047
|
-
print("Speech enhancement metrics all-mixtures list:")
|
1048
|
-
# print(all_metrics_table_1.head().style.format(precision=2))
|
1049
|
-
print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
|
1050
|
-
print("")
|
1051
|
-
print("Extraction statistics all-mixtures list:")
|
1052
|
-
print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"))
|
1045
|
+
print("Speech enhancement metrics all-mixtures list:", file=f)
|
1046
|
+
# print(all_metrics_table_1.head().style.format(precision=2), file=f)
|
1047
|
+
print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
1048
|
+
print("", file=f)
|
1049
|
+
print("Extraction statistics all-mixtures list:", file=f)
|
1050
|
+
print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)
|
1053
1051
|
|
1054
1052
|
# Write summary to .csv file
|
1055
1053
|
if not truth_est_mode:
|
@@ -1084,7 +1082,7 @@ def main():
|
|
1084
1082
|
label = f"Extraction statistics stats over {num_mix} mixtures:"
|
1085
1083
|
pd.DataFrame([label]).to_csv(csv_name, **header_args)
|
1086
1084
|
all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
|
1087
|
-
label = f"ASR enabled with method {asr_method}
|
1085
|
+
label = f"ASR enabled with method {asr_method}"
|
1088
1086
|
pd.DataFrame([label]).to_csv(csv_name, **header_args)
|
1089
1087
|
|
1090
1088
|
if not truth_est_mode:
|
@@ -1104,3 +1102,37 @@ def main():
|
|
1104
1102
|
|
1105
1103
|
if __name__ == "__main__":
|
1106
1104
|
main()
|
1105
|
+
|
1106
|
+
# if asr_method == 'none':
|
1107
|
+
# fnb = 'metric_spenh_'
|
1108
|
+
# elif asr_method == 'google':
|
1109
|
+
# fnb = 'metric_spenh_ggl_'
|
1110
|
+
# logger.info(f'ASR enabled with method {asr_method}')
|
1111
|
+
# enable_asr_warmup = True
|
1112
|
+
# elif asr_method == 'deepgram':
|
1113
|
+
# fnb = 'metric_spenh_dgram_'
|
1114
|
+
# logger.info(f'ASR enabled with method {asr_method}')
|
1115
|
+
# enable_asr_warmup = True
|
1116
|
+
# elif asr_method == 'aixplain_whisper':
|
1117
|
+
# fnb = 'metric_spenh_whspx_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1118
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1119
|
+
# enable_asr_warmup = True
|
1120
|
+
# elif asr_method == 'whisper':
|
1121
|
+
# fnb = 'metric_spenh_whspl_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1122
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1123
|
+
# enable_asr_warmup = True
|
1124
|
+
# elif asr_method == 'aaware_whisper':
|
1125
|
+
# fnb = 'metric_spenh_whspaaw_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1126
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1127
|
+
# enable_asr_warmup = True
|
1128
|
+
# elif asr_method == 'faster_whisper':
|
1129
|
+
# fnb = 'metric_spenh_fwhsp_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1130
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1131
|
+
# enable_asr_warmup = True
|
1132
|
+
# elif asr_method == 'sensory':
|
1133
|
+
# fnb = 'metric_spenh_snsr_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1134
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1135
|
+
# enable_asr_warmup = True
|
1136
|
+
# else:
|
1137
|
+
# logger.error(f'Unrecognized ASR method: {asr_method}')
|
1138
|
+
# return
|