sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/calc_metric_spenh.py
CHANGED
@@ -67,34 +67,23 @@ Inputs:
|
|
67
67
|
|
68
68
|
"""
|
69
69
|
|
70
|
-
import
|
70
|
+
from typing import Any
|
71
71
|
|
72
72
|
import matplotlib
|
73
73
|
import matplotlib.pyplot as plt
|
74
74
|
import numpy as np
|
75
75
|
import pandas as pd
|
76
76
|
|
77
|
-
from sonusai.
|
78
|
-
from sonusai.
|
79
|
-
from sonusai.
|
77
|
+
from sonusai.datatypes import AudioF
|
78
|
+
from sonusai.datatypes import AudioT
|
79
|
+
from sonusai.datatypes import Feature
|
80
|
+
from sonusai.datatypes import Predict
|
80
81
|
from sonusai.mixture import MixtureDatabase
|
81
|
-
from sonusai.mixture import Predict
|
82
82
|
|
83
83
|
DB_99 = np.power(10, 99 / 10)
|
84
84
|
DB_N99 = np.power(10, -99 / 10)
|
85
85
|
|
86
86
|
|
87
|
-
def signal_handler(_sig, _frame):
|
88
|
-
import sys
|
89
|
-
|
90
|
-
from sonusai import logger
|
91
|
-
|
92
|
-
logger.info("Canceled due to keyboard interrupt")
|
93
|
-
sys.exit(1)
|
94
|
-
|
95
|
-
|
96
|
-
signal.signal(signal.SIGINT, signal_handler)
|
97
|
-
|
98
87
|
matplotlib.use("SVG")
|
99
88
|
|
100
89
|
|
@@ -192,8 +181,8 @@ def plot_mixpred(
|
|
192
181
|
feature: Feature | None = None,
|
193
182
|
predict: Predict | None = None,
|
194
183
|
tp_title: str = "",
|
195
|
-
) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
|
196
|
-
from sonusai.
|
184
|
+
) -> tuple[plt.Figure, Any]: # pyright: ignore [reportPrivateImportUsage]
|
185
|
+
from sonusai.constants import SAMPLE_RATE
|
197
186
|
|
198
187
|
num_plots = 2
|
199
188
|
if feature is not None:
|
@@ -229,7 +218,7 @@ def plot_mixpred(
|
|
229
218
|
ax[p].set_title("Predict " + tp_title)
|
230
219
|
plt.colorbar(im, location="bottom")
|
231
220
|
|
232
|
-
return fig
|
221
|
+
return fig, ax
|
233
222
|
|
234
223
|
|
235
224
|
def plot_pdb_predict_truth(
|
@@ -291,7 +280,7 @@ def plot_e_predict_truth(
|
|
291
280
|
truth_wav: np.ndarray | None = None,
|
292
281
|
metric: np.ndarray | None = None,
|
293
282
|
tp_title: str = "",
|
294
|
-
) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
|
283
|
+
) -> tuple[plt.Figure, Any]: # pyright: ignore [reportPrivateImportUsage]
|
295
284
|
"""Plot predict spectrogram and waveform and optionally truth and a metric)"""
|
296
285
|
num_plots = 2
|
297
286
|
if truth_f is not None:
|
@@ -337,18 +326,19 @@ def plot_e_predict_truth(
|
|
337
326
|
ax[p].set_xlim(x_axis[0], x_axis[-1])
|
338
327
|
ax[p].set_ylim([-0.01, np.max(metric1) + 0.01])
|
339
328
|
if metric.ndim > 1 and metric.shape[1] > 1:
|
329
|
+
p += 1
|
340
330
|
metr2 = metric[:, 1]
|
341
|
-
|
331
|
+
ax = np.append(ax, np.array(ax[p - 1].twinx()))
|
342
332
|
color2 = "blue"
|
343
|
-
|
333
|
+
ax[p].plot(x_axis, metr2, color=color2, label="phase dist (deg)")
|
344
334
|
# ax2.set_ylim([-180.0, +180.0])
|
345
335
|
if np.max(metr2) - np.min(metr2) > 0.1:
|
346
|
-
|
347
|
-
|
348
|
-
|
336
|
+
ax[p].set_ylim([np.min(metr2), np.max(metr2)])
|
337
|
+
ax[p].set_ylabel("phase dist (deg)", color=color2)
|
338
|
+
ax[p].tick_params(axis="y", labelcolor=color2)
|
349
339
|
# ax[p].set_title('SNR and SNR mse (mean over freq. db)')
|
350
340
|
|
351
|
-
return fig
|
341
|
+
return fig, ax
|
352
342
|
|
353
343
|
|
354
344
|
def _process_mixture(
|
@@ -368,12 +358,13 @@ def _process_mixture(
|
|
368
358
|
from os.path import splitext
|
369
359
|
|
370
360
|
import h5py
|
371
|
-
import
|
361
|
+
import pgzip
|
372
362
|
from matplotlib.backends.backend_pdf import PdfPages
|
373
363
|
from pystoi import stoi
|
374
364
|
|
375
365
|
from sonusai import logger
|
376
366
|
from sonusai.metrics import calc_pcm
|
367
|
+
from sonusai.metrics import calc_pesq
|
377
368
|
from sonusai.metrics import calc_phase_distance
|
378
369
|
from sonusai.metrics import calc_speech
|
379
370
|
from sonusai.metrics import calc_wer
|
@@ -422,16 +413,16 @@ def _process_mixture(
|
|
422
413
|
predict = stack_complex(predict)
|
423
414
|
|
424
415
|
# 2) Collect true target, noise, mixture data, trim to predict size if needed
|
425
|
-
tmp = mixdb.
|
426
|
-
target_f = mixdb.
|
427
|
-
target = tmp[
|
416
|
+
tmp = mixdb.mixture_sources(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
|
417
|
+
target_f = mixdb.mixture_sources_f(m_id, sources=tmp)["primary"]
|
418
|
+
target = tmp["primary"]
|
428
419
|
mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
|
429
420
|
# noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
|
430
421
|
# noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
|
431
422
|
noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
|
432
423
|
# noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
|
433
424
|
# note: uses pre-IR, pre-specaug audio
|
434
|
-
segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"][0]
|
425
|
+
segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"] # Why [0] removed?
|
435
426
|
mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
|
436
427
|
noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
|
437
428
|
# segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
|
@@ -446,7 +437,7 @@ def _process_mixture(
|
|
446
437
|
# gen feature, truth - note feature only used for plots
|
447
438
|
# TODO: parse truth_f for different formats
|
448
439
|
feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
|
449
|
-
truth_f = truth_all[target_f_key]
|
440
|
+
truth_f = truth_all["primary"][target_f_key]
|
450
441
|
if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
|
451
442
|
if truth_f.shape[1] != 1:
|
452
443
|
logger.info("Error: target_f truth has stride > 1, exiting.")
|
@@ -488,7 +479,7 @@ def _process_mixture(
|
|
488
479
|
predict = truth_f # substitute truth for the prediction (for test/debug)
|
489
480
|
predict_complex = unstack_complex(predict) # unstack
|
490
481
|
# if feature has compressed mag and truth does not, compress it
|
491
|
-
if mixdb.feature[0:1] == "h" and not
|
482
|
+
if mixdb.feature[0:1] == "h" and not first_key(mixdb.category_truth_configs("primary")).startswith(
|
492
483
|
"targetcmpr"
|
493
484
|
):
|
494
485
|
predict_complex = power_compress(predict_complex) # from uncompressed truth
|
@@ -535,23 +526,24 @@ def _process_mixture(
|
|
535
526
|
# logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
|
536
527
|
|
537
528
|
# Speech intelligibility measure - PESQ
|
538
|
-
if int(mixdb.mixture(m_id).snr) > -99:
|
529
|
+
if int(mixdb.mixture(m_id).noise.snr) > -99:
|
539
530
|
# len = target_est_wav.shape[0]
|
540
|
-
pesq_speech
|
531
|
+
pesq_speech = calc_pesq(target_est_wav, target_fi)
|
532
|
+
csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
|
541
533
|
metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
|
542
|
-
|
543
|
-
csig_mx = metrics["mxcsig"]
|
544
|
-
cbak_mx = metrics["mxcbak"]
|
545
|
-
covl_mx = metrics["mxcovl"]
|
534
|
+
pesq_mx = metrics["mxpesq"][0] if isinstance(metrics["mxpesq"], list) else metrics["mxpesq"]
|
535
|
+
csig_mx = metrics["mxcsig"][0] if isinstance(metrics["mxcsig"], list) else metrics["mxcsig"]
|
536
|
+
cbak_mx = metrics["mxcbak"][0] if isinstance(metrics["mxcbak"], list) else metrics["mxcbak"]
|
537
|
+
covl_mx = metrics["mxcovl"][0] if isinstance(metrics["mxcovl"], list) else metrics["mxcovl"]
|
546
538
|
# pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
|
547
539
|
# pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
|
548
540
|
# pesq improvement
|
549
|
-
pesq_impr = pesq_speech -
|
541
|
+
pesq_impr = pesq_speech - pesq_mx
|
550
542
|
# pesq improvement %
|
551
|
-
pesq_impr_pc = pesq_impr / (
|
543
|
+
pesq_impr_pc = pesq_impr / (pesq_mx + np.finfo(np.float32).eps) * 100
|
552
544
|
else:
|
553
545
|
pesq_speech = 0
|
554
|
-
|
546
|
+
pesq_mx = 0
|
555
547
|
pesq_impr_pc = np.float32(0)
|
556
548
|
csig_mx = 0
|
557
549
|
csig_tg = 0
|
@@ -565,14 +557,14 @@ def _process_mixture(
|
|
565
557
|
asr_mx = None
|
566
558
|
asr_tge = None
|
567
559
|
# asr_engines = list(mixdb.asr_configs.keys())
|
568
|
-
if asr_method is not None and mixdb.mixture(m_id).snr >= -96: # noise only, ignore/reset target ASR
|
560
|
+
if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
|
569
561
|
asr_mx_name = f"mxasr.{asr_method}"
|
570
562
|
wer_mx_name = f"mxwer.{asr_method}"
|
571
563
|
asr_tt_name = f"tasr.{asr_method}"
|
572
564
|
metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
|
573
|
-
asr_mx = metrics[asr_mx_name][0]
|
574
|
-
wer_mx = metrics[wer_mx_name][0]
|
575
|
-
asr_tt = metrics[asr_tt_name][0]
|
565
|
+
asr_mx = metrics[asr_mx_name][0] if isinstance(metrics[asr_mx_name], list) else metrics[asr_mx_name]
|
566
|
+
wer_mx = metrics[wer_mx_name][0] if isinstance(metrics[wer_mx_name], list) else metrics[wer_mx_name]
|
567
|
+
asr_tt = metrics[asr_tt_name][0] if isinstance(metrics[asr_tt_name], list) else metrics[asr_tt_name]
|
576
568
|
|
577
569
|
if asr_tt:
|
578
570
|
noiseadd = None # TBD add as switch, default -30
|
@@ -628,11 +620,11 @@ def _process_mixture(
|
|
628
620
|
"SPFILE",
|
629
621
|
"NFILE",
|
630
622
|
]
|
631
|
-
ti = mixdb.mixture(m_id).
|
623
|
+
ti = mixdb.mixture(m_id).sources["primary"].file_id
|
632
624
|
ni = mixdb.mixture(m_id).noise.file_id
|
633
625
|
metr1 = [
|
634
|
-
mixdb.mixture(m_id).snr,
|
635
|
-
|
626
|
+
mixdb.mixture(m_id).noise.snr,
|
627
|
+
pesq_mx,
|
636
628
|
pesq_speech,
|
637
629
|
pesq_impr_pc,
|
638
630
|
wer_mx,
|
@@ -650,8 +642,8 @@ def _process_mixture(
|
|
650
642
|
cbak_tg,
|
651
643
|
covl_mx,
|
652
644
|
covl_tg,
|
653
|
-
basename(mixdb.
|
654
|
-
basename(mixdb.
|
645
|
+
basename(mixdb.source_file(ti).name),
|
646
|
+
basename(mixdb.source_file(ni).name),
|
655
647
|
]
|
656
648
|
mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
|
657
649
|
|
@@ -669,7 +661,7 @@ def _process_mixture(
|
|
669
661
|
)
|
670
662
|
dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
|
671
663
|
mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
|
672
|
-
mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).snr, False) # add MXSNR as the first metric column
|
664
|
+
mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).noise.snr, False) # add MXSNR as the first metric column
|
673
665
|
|
674
666
|
all_metrics_table_1 = mtab1 # return to be collected by process
|
675
667
|
all_metrics_table_2 = mtab2 # return to be collected by process
|
@@ -686,8 +678,8 @@ def _process_mixture(
|
|
686
678
|
print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
|
687
679
|
print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
688
680
|
print("", file=f)
|
689
|
-
print(f"Target path: {mixdb.
|
690
|
-
print(f"Noise path: {mixdb.
|
681
|
+
print(f"Target path: {mixdb.source_file(ti).name}", file=f)
|
682
|
+
print(f"Noise path: {mixdb.source_file(ni).name}", file=f)
|
691
683
|
if asr_method != "none":
|
692
684
|
print(f"ASR method: {asr_method}", file=f)
|
693
685
|
print(f"ASR truth: {asr_tt}", file=f)
|
@@ -746,7 +738,7 @@ def _process_mixture(
|
|
746
738
|
# tfunc_name = tfunc_name + ' (db)'
|
747
739
|
|
748
740
|
mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
|
749
|
-
|
741
|
+
fig, ax = plot_mixpred(
|
750
742
|
mixture=mixture,
|
751
743
|
mixture_f=mixspec,
|
752
744
|
target=target,
|
@@ -754,9 +746,8 @@ def _process_mixture(
|
|
754
746
|
predict=predplot,
|
755
747
|
tp_title=tfunc_name,
|
756
748
|
)
|
757
|
-
pdf.savefig(
|
758
|
-
|
759
|
-
pickle.dump(fig_obj, f)
|
749
|
+
pdf.savefig(fig)
|
750
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig1.pkl.gz", "wb"))
|
760
751
|
|
761
752
|
# ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
|
762
753
|
# pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))
|
@@ -765,7 +756,7 @@ def _process_mixture(
|
|
765
756
|
tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
|
766
757
|
tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
767
758
|
# n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
|
768
|
-
|
759
|
+
fig, ax = plot_e_predict_truth(
|
769
760
|
predict=tg_est_spec,
|
770
761
|
predict_wav=target_est_wav,
|
771
762
|
truth_f=tg_spec,
|
@@ -773,14 +764,13 @@ def _process_mixture(
|
|
773
764
|
metric=np.vstack((lerr_tg_frame, phd_frame)).T,
|
774
765
|
tp_title="speech estimate",
|
775
766
|
)
|
776
|
-
pdf.savefig(
|
777
|
-
|
778
|
-
pickle.dump(fig_obj, f)
|
767
|
+
pdf.savefig(fig)
|
768
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig2.pkl.gz", "wb"))
|
779
769
|
|
780
770
|
# page 4 noise extraction
|
781
771
|
n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
|
782
772
|
n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
|
783
|
-
|
773
|
+
fig, ax = plot_e_predict_truth(
|
784
774
|
predict=n_est_spec,
|
785
775
|
predict_wav=noise_est_wav,
|
786
776
|
truth_f=n_spec,
|
@@ -788,9 +778,8 @@ def _process_mixture(
|
|
788
778
|
metric=lerr_n_frame,
|
789
779
|
tp_title="noise estimate",
|
790
780
|
)
|
791
|
-
pdf.savefig(
|
792
|
-
|
793
|
-
pickle.dump(fig_obj, f)
|
781
|
+
pdf.savefig(fig)
|
782
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig4.pkl.gz", "wb"))
|
794
783
|
|
795
784
|
# Plot error waveforms
|
796
785
|
# tg_err_wav = target_fi - target_est_wav
|
@@ -871,12 +860,14 @@ def main():
|
|
871
860
|
)
|
872
861
|
# speech enhancement metrics and audio truth requires target_f truth type, check it is present
|
873
862
|
target_f_key = None
|
874
|
-
logger.info(
|
875
|
-
|
876
|
-
|
863
|
+
logger.info(
|
864
|
+
f"mixdb has {len(mixdb.category_truth_configs('primary'))} truth types defined for primary, checking that target_f type is present."
|
865
|
+
)
|
866
|
+
for key in mixdb.category_truth_configs("primary"):
|
867
|
+
if mixdb.category_truth_configs("primary")[key] == "target_f":
|
877
868
|
target_f_key = key
|
878
869
|
if target_f_key is None:
|
879
|
-
logger.error("mixdb does not have target_f truth
|
870
|
+
logger.error("mixdb does not have target_f truth defined, required for speech enhancement metrics, exiting.")
|
880
871
|
raise SystemExit(1)
|
881
872
|
|
882
873
|
logger.info(f"Only running specified subset of {len(mixids)} mixtures")
|
@@ -924,8 +915,8 @@ def main():
|
|
924
915
|
no_par = True
|
925
916
|
num_cpus = None
|
926
917
|
else:
|
927
|
-
no_par =
|
928
|
-
num_cpus =
|
918
|
+
no_par = False
|
919
|
+
num_cpus = use_cpu
|
929
920
|
|
930
921
|
all_metrics_tables = par_track(
|
931
922
|
partial(
|
@@ -1101,7 +1092,14 @@ def main():
|
|
1101
1092
|
|
1102
1093
|
|
1103
1094
|
if __name__ == "__main__":
|
1104
|
-
|
1095
|
+
from sonusai import exception_handler
|
1096
|
+
from sonusai.utils import register_keyboard_interrupt
|
1097
|
+
|
1098
|
+
register_keyboard_interrupt()
|
1099
|
+
try:
|
1100
|
+
main()
|
1101
|
+
except Exception as e:
|
1102
|
+
exception_handler(e)
|
1105
1103
|
|
1106
1104
|
# if asr_method == 'none':
|
1107
1105
|
# fnb = 'metric_spenh_'
|
sonusai/config/config.py
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
def _load_yaml(name: str) -> dict:
|
2
|
+
"""Load YAML file
|
3
|
+
|
4
|
+
:param name: File name
|
5
|
+
:return: Dictionary of config data
|
6
|
+
"""
|
7
|
+
import yaml
|
8
|
+
|
9
|
+
with open(file=name) as f:
|
10
|
+
config = yaml.safe_load(f)
|
11
|
+
|
12
|
+
return config
|
13
|
+
|
14
|
+
|
15
|
+
def _default_config() -> dict:
|
16
|
+
"""Load default SonusAI config
|
17
|
+
|
18
|
+
:return: Dictionary of default config data
|
19
|
+
"""
|
20
|
+
from .constants import DEFAULT_CONFIG
|
21
|
+
|
22
|
+
try:
|
23
|
+
return _load_yaml(DEFAULT_CONFIG)
|
24
|
+
except Exception as e:
|
25
|
+
raise OSError(f"Error loading default config: {e}") from e
|
26
|
+
|
27
|
+
|
28
|
+
def _update_config_from_file(filename: str, given_config: dict) -> dict:
|
29
|
+
"""Update the given config with the config in the specified YAML file
|
30
|
+
|
31
|
+
:param filename: File name
|
32
|
+
:param given_config: Config dictionary to update
|
33
|
+
:return: Updated config dictionary
|
34
|
+
"""
|
35
|
+
from copy import deepcopy
|
36
|
+
|
37
|
+
updated_config = deepcopy(given_config)
|
38
|
+
|
39
|
+
try:
|
40
|
+
file_config = _load_yaml(filename)
|
41
|
+
except Exception as e:
|
42
|
+
raise OSError(f"Error loading config from {filename}: {e}") from e
|
43
|
+
|
44
|
+
# Use default config as base and overwrite with given config keys as found
|
45
|
+
if file_config:
|
46
|
+
for key in updated_config:
|
47
|
+
if key in file_config:
|
48
|
+
updated_config[key] = file_config[key]
|
49
|
+
|
50
|
+
return updated_config
|
51
|
+
|
52
|
+
|
53
|
+
def load_config(name: str) -> dict:
|
54
|
+
"""Load SonusAI default config and update with given location (performing SonusAI variable substitution)
|
55
|
+
|
56
|
+
:param name: Directory containing mixture database
|
57
|
+
:return: Dictionary of config data
|
58
|
+
"""
|
59
|
+
from os.path import join
|
60
|
+
|
61
|
+
return _update_config_from_file(filename=join(name, "config.yml"), given_config=_default_config())
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# Default configuration for sonusai
|
2
|
+
|
3
|
+
# The values in this file are the defaults used if they are not specified in a
|
4
|
+
# local config.
|
5
|
+
|
6
|
+
feature: ""
|
7
|
+
|
8
|
+
target_level_type: default
|
9
|
+
|
10
|
+
class_indices: 1
|
11
|
+
|
12
|
+
num_classes: 1
|
13
|
+
|
14
|
+
class_labels: [ ]
|
15
|
+
|
16
|
+
seed: 0
|
17
|
+
|
18
|
+
class_weights_threshold: 0.5
|
19
|
+
|
20
|
+
asr_configs: { }
|
@@ -0,0 +1,8 @@
|
|
1
|
+
from importlib.resources import as_file
|
2
|
+
from importlib.resources import files
|
3
|
+
|
4
|
+
REQUIRED_TRUTH_CONFIG_FIELDS = ["function", "stride_reduction"]
|
5
|
+
REQUIRED_ASR_CONFIG_FIELDS = ["engine"]
|
6
|
+
|
7
|
+
with as_file(files("sonusai.config").joinpath("config.yml")) as path:
|
8
|
+
DEFAULT_CONFIG = str(path)
|
sonusai/constants.py
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
from importlib.resources import as_file
|
2
|
+
from importlib.resources import files
|
3
|
+
|
4
|
+
SAMPLE_RATE = 16000
|
5
|
+
CHANNEL_COUNT = 1
|
6
|
+
BIT_DEPTH = 32
|
7
|
+
SAMPLE_BYTES = BIT_DEPTH // 8
|
8
|
+
FLOAT_BYTES = 4
|
9
|
+
|
10
|
+
with as_file(files("sonusai.data").joinpath("whitenoise.wav")) as path:
|
11
|
+
DEFAULT_NOISE = str(path)
|
sonusai/data/genmixdb.yml
CHANGED
@@ -3,54 +3,41 @@
|
|
3
3
|
# The values in this file are the defaults used if they are not specified in a
|
4
4
|
# local config.
|
5
5
|
|
6
|
-
|
7
|
-
|
8
|
-
target_level_type: default
|
9
|
-
|
10
|
-
class_indices: 1
|
6
|
+
seed: 0
|
11
7
|
|
12
|
-
|
8
|
+
feature: ""
|
13
9
|
|
14
10
|
num_classes: 1
|
15
11
|
|
16
|
-
|
12
|
+
asr_configs: { }
|
17
13
|
|
18
|
-
|
14
|
+
level_type: default
|
19
15
|
|
16
|
+
class_indices: 1
|
17
|
+
class_labels: [ ]
|
20
18
|
class_weights_threshold: 0.5
|
21
19
|
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
-
|
28
|
-
|
29
|
-
class_balancing_augmentation:
|
30
|
-
normalize: -3.5
|
31
|
-
pitch: "rand(-300, 300)"
|
32
|
-
tempo: "rand(0.8, 1.2)"
|
33
|
-
eq1: [ "rand(50, 250)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
|
34
|
-
eq2: [ "rand(250, 1200)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
|
35
|
-
eq3: [ "rand(1200, 6000)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
|
20
|
+
class_balancing_effect:
|
21
|
+
- norm -3.5
|
22
|
+
- pitch sai_rand(-300, 300)
|
23
|
+
- tempo -s sai_rand(0.8, 1.2)
|
24
|
+
- equalizer sai_rand(50, 250) sai_rand(0.2, 2.0) sai_rand(-6, 6)
|
25
|
+
- equalizer sai_rand(250, 1200) sai_rand(0.2, 2.0) sai_rand(-6, 6)
|
26
|
+
- equalizer sai_rand(1200, 6000) sai_rand(0.2, 2.0) sai_rand(-6, 6)
|
36
27
|
|
37
28
|
class_balancing: false
|
38
29
|
|
39
|
-
|
40
|
-
- "${default_noise}"
|
41
|
-
|
42
|
-
noise_augmentations:
|
43
|
-
- pre:
|
44
|
-
normalize: -3.5
|
45
|
-
|
46
|
-
snrs:
|
47
|
-
- 99
|
30
|
+
impulse_responses: [ ]
|
48
31
|
|
49
|
-
|
32
|
+
sources:
|
33
|
+
primary:
|
34
|
+
files: [ ]
|
35
|
+
noise:
|
36
|
+
files: [ ]
|
50
37
|
|
51
|
-
|
38
|
+
summed_source_effects: [ ]
|
52
39
|
|
53
|
-
|
40
|
+
mixture_effects: [ ]
|
54
41
|
|
55
42
|
spectral_masks:
|
56
43
|
- f_max_width: 27
|
@@ -58,5 +45,3 @@ spectral_masks:
|
|
58
45
|
t_max_width: 100
|
59
46
|
t_num: 0
|
60
47
|
t_max_percent: 100
|
61
|
-
|
62
|
-
asr_configs: { }
|