sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -67,34 +67,23 @@ Inputs:
67
67
 
68
68
  """
69
69
 
70
- import signal
70
+ from typing import Any
71
71
 
72
72
  import matplotlib
73
73
  import matplotlib.pyplot as plt
74
74
  import numpy as np
75
75
  import pandas as pd
76
76
 
77
- from sonusai.mixture import AudioF
78
- from sonusai.mixture import AudioT
79
- from sonusai.mixture import Feature
77
+ from sonusai.datatypes import AudioF
78
+ from sonusai.datatypes import AudioT
79
+ from sonusai.datatypes import Feature
80
+ from sonusai.datatypes import Predict
80
81
  from sonusai.mixture import MixtureDatabase
81
- from sonusai.mixture import Predict
82
82
 
83
83
  DB_99 = np.power(10, 99 / 10)
84
84
  DB_N99 = np.power(10, -99 / 10)
85
85
 
86
86
 
87
- def signal_handler(_sig, _frame):
88
- import sys
89
-
90
- from sonusai import logger
91
-
92
- logger.info("Canceled due to keyboard interrupt")
93
- sys.exit(1)
94
-
95
-
96
- signal.signal(signal.SIGINT, signal_handler)
97
-
98
87
  matplotlib.use("SVG")
99
88
 
100
89
 
@@ -192,8 +181,8 @@ def plot_mixpred(
192
181
  feature: Feature | None = None,
193
182
  predict: Predict | None = None,
194
183
  tp_title: str = "",
195
- ) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
196
- from sonusai.mixture import SAMPLE_RATE
184
+ ) -> tuple[plt.Figure, Any]: # pyright: ignore [reportPrivateImportUsage]
185
+ from sonusai.constants import SAMPLE_RATE
197
186
 
198
187
  num_plots = 2
199
188
  if feature is not None:
@@ -229,7 +218,7 @@ def plot_mixpred(
229
218
  ax[p].set_title("Predict " + tp_title)
230
219
  plt.colorbar(im, location="bottom")
231
220
 
232
- return fig
221
+ return fig, ax
233
222
 
234
223
 
235
224
  def plot_pdb_predict_truth(
@@ -291,7 +280,7 @@ def plot_e_predict_truth(
291
280
  truth_wav: np.ndarray | None = None,
292
281
  metric: np.ndarray | None = None,
293
282
  tp_title: str = "",
294
- ) -> plt.Figure: # pyright: ignore [reportPrivateImportUsage]
283
+ ) -> tuple[plt.Figure, Any]: # pyright: ignore [reportPrivateImportUsage]
295
284
  """Plot predict spectrogram and waveform and optionally truth and a metric)"""
296
285
  num_plots = 2
297
286
  if truth_f is not None:
@@ -337,18 +326,19 @@ def plot_e_predict_truth(
337
326
  ax[p].set_xlim(x_axis[0], x_axis[-1])
338
327
  ax[p].set_ylim([-0.01, np.max(metric1) + 0.01])
339
328
  if metric.ndim > 1 and metric.shape[1] > 1:
329
+ p += 1
340
330
  metr2 = metric[:, 1]
341
- ax2 = ax[p].twinx()
331
+ ax = np.append(ax, np.array(ax[p - 1].twinx()))
342
332
  color2 = "blue"
343
- ax2.plot(x_axis, metr2, color=color2, label="phase dist (deg)")
333
+ ax[p].plot(x_axis, metr2, color=color2, label="phase dist (deg)")
344
334
  # ax2.set_ylim([-180.0, +180.0])
345
335
  if np.max(metr2) - np.min(metr2) > 0.1:
346
- ax2.set_ylim([np.min(metr2), np.max(metr2)])
347
- ax2.set_ylabel("phase dist (deg)", color=color2)
348
- ax2.tick_params(axis="y", labelcolor=color2)
336
+ ax[p].set_ylim([np.min(metr2), np.max(metr2)])
337
+ ax[p].set_ylabel("phase dist (deg)", color=color2)
338
+ ax[p].tick_params(axis="y", labelcolor=color2)
349
339
  # ax[p].set_title('SNR and SNR mse (mean over freq. db)')
350
340
 
351
- return fig
341
+ return fig, ax
352
342
 
353
343
 
354
344
  def _process_mixture(
@@ -368,12 +358,13 @@ def _process_mixture(
368
358
  from os.path import splitext
369
359
 
370
360
  import h5py
371
- import mgzip
361
+ import pgzip
372
362
  from matplotlib.backends.backend_pdf import PdfPages
373
363
  from pystoi import stoi
374
364
 
375
365
  from sonusai import logger
376
366
  from sonusai.metrics import calc_pcm
367
+ from sonusai.metrics import calc_pesq
377
368
  from sonusai.metrics import calc_phase_distance
378
369
  from sonusai.metrics import calc_speech
379
370
  from sonusai.metrics import calc_wer
@@ -422,16 +413,16 @@ def _process_mixture(
422
413
  predict = stack_complex(predict)
423
414
 
424
415
  # 2) Collect true target, noise, mixture data, trim to predict size if needed
425
- tmp = mixdb.mixture_targets(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
426
- target_f = mixdb.mixture_targets_f(m_id, targets=tmp)[0]
427
- target = tmp[0]
416
+ tmp = mixdb.mixture_sources(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
417
+ target_f = mixdb.mixture_sources_f(m_id, sources=tmp)["primary"]
418
+ target = tmp["primary"]
428
419
  mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
429
420
  # noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
430
421
  # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
431
422
  noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
432
423
  # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
433
424
  # note: uses pre-IR, pre-specaug audio
434
- segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"][0]
425
+ segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"] # Why [0] removed?
435
426
  mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
436
427
  noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
437
428
  # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
@@ -446,7 +437,7 @@ def _process_mixture(
446
437
  # gen feature, truth - note feature only used for plots
447
438
  # TODO: parse truth_f for different formats
448
439
  feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
449
- truth_f = truth_all[target_f_key]
440
+ truth_f = truth_all["primary"][target_f_key]
450
441
  if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
451
442
  if truth_f.shape[1] != 1:
452
443
  logger.info("Error: target_f truth has stride > 1, exiting.")
@@ -488,7 +479,7 @@ def _process_mixture(
488
479
  predict = truth_f # substitute truth for the prediction (for test/debug)
489
480
  predict_complex = unstack_complex(predict) # unstack
490
481
  # if feature has compressed mag and truth does not, compress it
491
- if mixdb.feature[0:1] == "h" and not mixdb.truth_configs[first_key(mixdb.truth_configs)].function.startswith(
482
+ if mixdb.feature[0:1] == "h" and not first_key(mixdb.category_truth_configs("primary")).startswith(
492
483
  "targetcmpr"
493
484
  ):
494
485
  predict_complex = power_compress(predict_complex) # from uncompressed truth
@@ -535,23 +526,24 @@ def _process_mixture(
535
526
  # logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
536
527
 
537
528
  # Speech intelligibility measure - PESQ
538
- if int(mixdb.mixture(m_id).snr) > -99:
529
+ if int(mixdb.mixture(m_id).noise.snr) > -99:
539
530
  # len = target_est_wav.shape[0]
540
- pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
531
+ pesq_speech = calc_pesq(target_est_wav, target_fi)
532
+ csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
541
533
  metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
542
- pesq_mixture = metrics["mxpesq"]
543
- csig_mx = metrics["mxcsig"]
544
- cbak_mx = metrics["mxcbak"]
545
- covl_mx = metrics["mxcovl"]
534
+ pesq_mx = metrics["mxpesq"][0] if isinstance(metrics["mxpesq"], list) else metrics["mxpesq"]
535
+ csig_mx = metrics["mxcsig"][0] if isinstance(metrics["mxcsig"], list) else metrics["mxcsig"]
536
+ cbak_mx = metrics["mxcbak"][0] if isinstance(metrics["mxcbak"], list) else metrics["mxcbak"]
537
+ covl_mx = metrics["mxcovl"][0] if isinstance(metrics["mxcovl"], list) else metrics["mxcovl"]
546
538
  # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
547
539
  # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
548
540
  # pesq improvement
549
- pesq_impr = pesq_speech - pesq_mixture
541
+ pesq_impr = pesq_speech - pesq_mx
550
542
  # pesq improvement %
551
- pesq_impr_pc = pesq_impr / (pesq_mixture + np.finfo(np.float32).eps) * 100
543
+ pesq_impr_pc = pesq_impr / (pesq_mx + np.finfo(np.float32).eps) * 100
552
544
  else:
553
545
  pesq_speech = 0
554
- pesq_mixture = 0
546
+ pesq_mx = 0
555
547
  pesq_impr_pc = np.float32(0)
556
548
  csig_mx = 0
557
549
  csig_tg = 0
@@ -565,14 +557,14 @@ def _process_mixture(
565
557
  asr_mx = None
566
558
  asr_tge = None
567
559
  # asr_engines = list(mixdb.asr_configs.keys())
568
- if asr_method is not None and mixdb.mixture(m_id).snr >= -96: # noise only, ignore/reset target ASR
560
+ if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
569
561
  asr_mx_name = f"mxasr.{asr_method}"
570
562
  wer_mx_name = f"mxwer.{asr_method}"
571
563
  asr_tt_name = f"tasr.{asr_method}"
572
564
  metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
573
- asr_mx = metrics[asr_mx_name][0]
574
- wer_mx = metrics[wer_mx_name][0]
575
- asr_tt = metrics[asr_tt_name][0]
565
+ asr_mx = metrics[asr_mx_name][0] if isinstance(metrics[asr_mx_name], list) else metrics[asr_mx_name]
566
+ wer_mx = metrics[wer_mx_name][0] if isinstance(metrics[wer_mx_name], list) else metrics[wer_mx_name]
567
+ asr_tt = metrics[asr_tt_name][0] if isinstance(metrics[asr_tt_name], list) else metrics[asr_tt_name]
576
568
 
577
569
  if asr_tt:
578
570
  noiseadd = None # TBD add as switch, default -30
@@ -628,11 +620,11 @@ def _process_mixture(
628
620
  "SPFILE",
629
621
  "NFILE",
630
622
  ]
631
- ti = mixdb.mixture(m_id).targets[0].file_id
623
+ ti = mixdb.mixture(m_id).sources["primary"].file_id
632
624
  ni = mixdb.mixture(m_id).noise.file_id
633
625
  metr1 = [
634
- mixdb.mixture(m_id).snr,
635
- pesq_mixture,
626
+ mixdb.mixture(m_id).noise.snr,
627
+ pesq_mx,
636
628
  pesq_speech,
637
629
  pesq_impr_pc,
638
630
  wer_mx,
@@ -650,8 +642,8 @@ def _process_mixture(
650
642
  cbak_tg,
651
643
  covl_mx,
652
644
  covl_tg,
653
- basename(mixdb.target_file(ti).name),
654
- basename(mixdb.noise_file(ni).name),
645
+ basename(mixdb.source_file(ti).name),
646
+ basename(mixdb.source_file(ni).name),
655
647
  ]
656
648
  mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
657
649
 
@@ -669,7 +661,7 @@ def _process_mixture(
669
661
  )
670
662
  dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
671
663
  mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
672
- mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).snr, False) # add MXSNR as the first metric column
664
+ mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).noise.snr, False) # add MXSNR as the first metric column
673
665
 
674
666
  all_metrics_table_1 = mtab1 # return to be collected by process
675
667
  all_metrics_table_2 = mtab2 # return to be collected by process
@@ -686,8 +678,8 @@ def _process_mixture(
686
678
  print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
687
679
  print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
688
680
  print("", file=f)
689
- print(f"Target path: {mixdb.target_file(ti).name}", file=f)
690
- print(f"Noise path: {mixdb.noise_file(ni).name}", file=f)
681
+ print(f"Target path: {mixdb.source_file(ti).name}", file=f)
682
+ print(f"Noise path: {mixdb.source_file(ni).name}", file=f)
691
683
  if asr_method != "none":
692
684
  print(f"ASR method: {asr_method}", file=f)
693
685
  print(f"ASR truth: {asr_tt}", file=f)
@@ -746,7 +738,7 @@ def _process_mixture(
746
738
  # tfunc_name = tfunc_name + ' (db)'
747
739
 
748
740
  mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
749
- fig_obj = plot_mixpred(
741
+ fig, ax = plot_mixpred(
750
742
  mixture=mixture,
751
743
  mixture_f=mixspec,
752
744
  target=target,
@@ -754,9 +746,8 @@ def _process_mixture(
754
746
  predict=predplot,
755
747
  tp_title=tfunc_name,
756
748
  )
757
- pdf.savefig(fig_obj)
758
- with mgzip.open(base_name + "_metric_spenh_fig1.mfigz", "wb") as f:
759
- pickle.dump(fig_obj, f)
749
+ pdf.savefig(fig)
750
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig1.pkl.gz", "wb"))
760
751
 
761
752
  # ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
762
753
  # pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))
@@ -765,7 +756,7 @@ def _process_mixture(
765
756
  tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
766
757
  tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
767
758
  # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
768
- fig_obj = plot_e_predict_truth(
759
+ fig, ax = plot_e_predict_truth(
769
760
  predict=tg_est_spec,
770
761
  predict_wav=target_est_wav,
771
762
  truth_f=tg_spec,
@@ -773,14 +764,13 @@ def _process_mixture(
773
764
  metric=np.vstack((lerr_tg_frame, phd_frame)).T,
774
765
  tp_title="speech estimate",
775
766
  )
776
- pdf.savefig(fig_obj)
777
- with mgzip.open(base_name + "_metric_spenh_fig2.mfigz", "wb") as f:
778
- pickle.dump(fig_obj, f)
767
+ pdf.savefig(fig)
768
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig2.pkl.gz", "wb"))
779
769
 
780
770
  # page 4 noise extraction
781
771
  n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
782
772
  n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
783
- fig_obj = plot_e_predict_truth(
773
+ fig, ax = plot_e_predict_truth(
784
774
  predict=n_est_spec,
785
775
  predict_wav=noise_est_wav,
786
776
  truth_f=n_spec,
@@ -788,9 +778,8 @@ def _process_mixture(
788
778
  metric=lerr_n_frame,
789
779
  tp_title="noise estimate",
790
780
  )
791
- pdf.savefig(fig_obj)
792
- with mgzip.open(base_name + "_metric_spenh_fig4.mfigz", "wb") as f:
793
- pickle.dump(fig_obj, f)
781
+ pdf.savefig(fig)
782
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig4.pkl.gz", "wb"))
794
783
 
795
784
  # Plot error waveforms
796
785
  # tg_err_wav = target_fi - target_est_wav
@@ -871,12 +860,14 @@ def main():
871
860
  )
872
861
  # speech enhancement metrics and audio truth requires target_f truth type, check it is present
873
862
  target_f_key = None
874
- logger.info(f"mixdb has {len(mixdb.truth_configs)} truth types defined, checking that target_f type is present.")
875
- for key in mixdb.truth_configs:
876
- if mixdb.truth_configs[key].function == "target_f":
863
+ logger.info(
864
+ f"mixdb has {len(mixdb.category_truth_configs('primary'))} truth types defined for primary, checking that target_f type is present."
865
+ )
866
+ for key in mixdb.category_truth_configs("primary"):
867
+ if mixdb.category_truth_configs("primary")[key] == "target_f":
877
868
  target_f_key = key
878
869
  if target_f_key is None:
879
- logger.error("mixdb does not have target_f truth define, required for speech enhancement metrics, exiting.")
870
+ logger.error("mixdb does not have target_f truth defined, required for speech enhancement metrics, exiting.")
880
871
  raise SystemExit(1)
881
872
 
882
873
  logger.info(f"Only running specified subset of {len(mixids)} mixtures")
@@ -924,8 +915,8 @@ def main():
924
915
  no_par = True
925
916
  num_cpus = None
926
917
  else:
927
- no_par = True
928
- num_cpus = None
918
+ no_par = False
919
+ num_cpus = use_cpu
929
920
 
930
921
  all_metrics_tables = par_track(
931
922
  partial(
@@ -1101,7 +1092,14 @@ def main():
1101
1092
 
1102
1093
 
1103
1094
  if __name__ == "__main__":
1104
- main()
1095
+ from sonusai import exception_handler
1096
+ from sonusai.utils import register_keyboard_interrupt
1097
+
1098
+ register_keyboard_interrupt()
1099
+ try:
1100
+ main()
1101
+ except Exception as e:
1102
+ exception_handler(e)
1105
1103
 
1106
1104
  # if asr_method == 'none':
1107
1105
  # fnb = 'metric_spenh_'
@@ -0,0 +1,3 @@
1
+ # ruff: noqa: F401
2
+
3
+ from .config import load_config
@@ -0,0 +1,61 @@
1
def _load_yaml(name: str) -> dict:
    """Load a YAML file into a dictionary.

    The file is decoded as UTF-8 regardless of the platform's locale encoding,
    and an empty document is returned as an empty dictionary instead of
    ``None`` so the result always matches the annotated return type.

    :param name: File name
    :return: Dictionary of config data (empty if the file holds no document)
    """
    import yaml

    # Be explicit about the encoding; the built-in default depends on the locale.
    with open(file=name, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load() yields None for an empty document
    return {} if config is None else config
13
+
14
+
15
def _default_config() -> dict:
    """Read the default SonusAI config that ships with the package.

    :return: Dictionary of default config data
    :raises OSError: If the default config file cannot be loaded
    """
    from .constants import DEFAULT_CONFIG

    try:
        defaults = _load_yaml(DEFAULT_CONFIG)
    except Exception as err:
        # Report any load failure as a single exception type, keeping the cause chained
        raise OSError(f"Error loading default config: {err}") from err

    return defaults
26
+
27
+
28
def _update_config_from_file(filename: str, given_config: dict) -> dict:
    """Return a copy of a config with values replaced by those found in a YAML file.

    Only keys already present in the given config are taken from the file; keys
    that appear only in the file are ignored. The given config itself is left
    unmodified because the work is done on a deep copy.

    :param filename: File name
    :param given_config: Config dictionary to update
    :return: Updated config dictionary
    :raises OSError: If the YAML file cannot be loaded
    """
    from copy import deepcopy

    merged = deepcopy(given_config)

    try:
        overrides = _load_yaml(filename)
    except Exception as err:
        raise OSError(f"Error loading config from {filename}: {err}") from err

    # Nothing loaded from the file: hand back the untouched copy
    if not overrides:
        return merged

    # Replace only the keys the base config already knows about
    merged.update({key: overrides[key] for key in merged if key in overrides})
    return merged
51
+
52
+
53
def load_config(name: str) -> dict:
    """Load the SonusAI default config and overlay the ``config.yml`` found in a directory.

    :param name: Directory containing mixture database
    :return: Dictionary of config data
    """
    from os.path import join

    local_file = join(name, "config.yml")
    defaults = _default_config()
    return _update_config_from_file(filename=local_file, given_config=defaults)
@@ -0,0 +1,20 @@
1
+ # Default configuration for sonusai
2
+
3
+ # The values in this file are the defaults used if they are not specified in a
4
+ # local config.
5
+
6
+ feature: ""
7
+
8
+ target_level_type: default
9
+
10
+ class_indices: 1
11
+
12
+ num_classes: 1
13
+
14
+ class_labels: [ ]
15
+
16
+ seed: 0
17
+
18
+ class_weights_threshold: 0.5
19
+
20
+ asr_configs: { }
@@ -0,0 +1,8 @@
1
+ from importlib.resources import as_file
2
+ from importlib.resources import files
3
+
4
# Field names required in truth and ASR config entries (as the constant names suggest;
# how they are enforced is not visible in this module).
REQUIRED_TRUTH_CONFIG_FIELDS = ["function", "stride_reduction"]
REQUIRED_ASR_CONFIG_FIELDS = ["engine"]

# Filesystem path, as a string, of the config.yml resource in the sonusai.config package.
# NOTE(review): as_file() may remove an extracted temporary copy when the `with` block
# exits, so this path presumably relies on the package being installed as plain files -
# confirm.
with as_file(files("sonusai.config") / "config.yml") as path:
    DEFAULT_CONFIG = str(path)
sonusai/constants.py ADDED
@@ -0,0 +1,11 @@
1
+ from importlib.resources import as_file
2
+ from importlib.resources import files
3
+
4
# Audio format constants.
# NOTE(review): units are not stated here; SAMPLE_RATE is presumably in Hz - confirm.
SAMPLE_RATE = 16000
CHANNEL_COUNT = 1
BIT_DEPTH = 32
SAMPLE_BYTES = BIT_DEPTH // 8  # whole bytes per sample, derived from BIT_DEPTH
FLOAT_BYTES = 4

# Filesystem path, as a string, of the whitenoise.wav resource in the sonusai.data package.
# NOTE(review): as_file() may remove an extracted temporary copy when the `with` block
# exits, so this path presumably relies on the package being installed as plain files -
# confirm.
with as_file(files("sonusai.data") / "whitenoise.wav") as path:
    DEFAULT_NOISE = str(path)
sonusai/data/genmixdb.yml CHANGED
@@ -3,54 +3,41 @@
3
3
  # The values in this file are the defaults used if they are not specified in a
4
4
  # local config.
5
5
 
6
- feature: ""
7
-
8
- target_level_type: default
9
-
10
- class_indices: 1
6
+ seed: 0
11
7
 
12
- targets: [ ]
8
+ feature: ""
13
9
 
14
10
  num_classes: 1
15
11
 
16
- class_labels: [ ]
12
+ asr_configs: { }
17
13
 
18
- seed: 0
14
+ level_type: default
19
15
 
16
+ class_indices: 1
17
+ class_labels: [ ]
20
18
  class_weights_threshold: 0.5
21
19
 
22
- truth_configs: { }
23
-
24
- asr_manifest: [ ]
25
-
26
- target_augmentations:
27
- - pre:
28
-
29
- class_balancing_augmentation:
30
- normalize: -3.5
31
- pitch: "rand(-300, 300)"
32
- tempo: "rand(0.8, 1.2)"
33
- eq1: [ "rand(50, 250)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
34
- eq2: [ "rand(250, 1200)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
35
- eq3: [ "rand(1200, 6000)", "rand(0.6, 1.0)", "rand(-6, 6)" ]
20
+ class_balancing_effect:
21
+ - norm -3.5
22
+ - pitch sai_rand(-300, 300)
23
+ - tempo -s sai_rand(0.8, 1.2)
24
+ - equalizer sai_rand(50, 250) sai_rand(0.2, 2.0) sai_rand(-6, 6)
25
+ - equalizer sai_rand(250, 1200) sai_rand(0.2, 2.0) sai_rand(-6, 6)
26
+ - equalizer sai_rand(1200, 6000) sai_rand(0.2, 2.0) sai_rand(-6, 6)
36
27
 
37
28
  class_balancing: false
38
29
 
39
- noises:
40
- - "${default_noise}"
41
-
42
- noise_augmentations:
43
- - pre:
44
- normalize: -3.5
45
-
46
- snrs:
47
- - 99
30
+ impulse_responses: [ ]
48
31
 
49
- random_snrs: [ ]
32
+ sources:
33
+ primary:
34
+ files: [ ]
35
+ noise:
36
+ files: [ ]
50
37
 
51
- noise_mix_mode: exhaustive
38
+ summed_source_effects: [ ]
52
39
 
53
- impulse_responses: [ ]
40
+ mixture_effects: [ ]
54
41
 
55
42
  spectral_masks:
56
43
  - f_max_width: 27
@@ -58,5 +45,3 @@ spectral_masks:
58
45
  t_max_width: 100
59
46
  t_num: 0
60
47
  t_max_percent: 100
61
-
62
- asr_configs: { }