sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118):
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -59,9 +59,10 @@ Metric and extraction data are written into prediction location PLOC as separate
59
59
  Inputs:
60
60
 
61
61
  """
62
+
62
63
  import signal
64
+ from contextlib import redirect_stdout
63
65
  from dataclasses import dataclass
64
- from typing import Optional
65
66
 
66
67
  import matplotlib
67
68
  import matplotlib.pyplot as plt
@@ -83,51 +84,33 @@ def signal_handler(_sig, _frame):
83
84
 
84
85
  from sonusai import logger
85
86
 
86
- logger.info('Canceled due to keyboard interrupt')
87
+ logger.info("Canceled due to keyboard interrupt")
87
88
  sys.exit(1)
88
89
 
89
90
 
90
91
  signal.signal(signal.SIGINT, signal_handler)
91
92
 
92
- matplotlib.use('SVG')
93
+ matplotlib.use("SVG")
93
94
 
94
95
 
95
96
  @dataclass
96
97
  class MPGlobal:
97
- mixdb: MixtureDatabase = None
98
- predict_location: str = None
99
- predict_wav_mode: bool = None
100
- truth_est_mode: bool = None
101
- enable_plot: bool = None
102
- enable_wav: bool = None
103
- asr_method: str = None
104
- asr_model_name: str = None
105
-
106
-
107
- MP_GLOBAL = MPGlobal()
108
-
109
-
110
- def power_compress(spec):
111
- mag = np.abs(spec)
112
- phase = np.angle(spec)
113
- mag = mag ** 0.3
114
- real_compress = mag * np.cos(phase)
115
- imag_compress = mag * np.sin(phase)
116
- return real_compress + 1j * imag_compress
98
+ mixdb: MixtureDatabase
99
+ predict_location: str
100
+ predict_wav_mode: bool
101
+ truth_est_mode: bool
102
+ enable_plot: bool
103
+ enable_wav: bool
104
+ asr_method: str
105
+ asr_model_name: str
117
106
 
118
107
 
119
- def power_uncompress(spec):
120
- mag = np.abs(spec)
121
- phase = np.angle(spec)
122
- mag = mag ** (1. / 0.3)
123
- real_uncompress = mag * np.cos(phase)
124
- imag_uncompress = mag * np.sin(phase)
125
- return real_uncompress + 1j * imag_uncompress
108
+ MP_GLOBAL: MPGlobal
126
109
 
127
110
 
128
- def mean_square_error(hypothesis: np.ndarray,
129
- reference: np.ndarray,
130
- squared: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
111
+ def mean_square_error(
112
+ hypothesis: np.ndarray, reference: np.ndarray, squared: bool = False
113
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
131
114
  """Calculate root-mean-square error or mean square error
132
115
 
133
116
  :param hypothesis: [frames, bins]
@@ -152,8 +135,9 @@ def mean_square_error(hypothesis: np.ndarray,
152
135
  return err, err_b, err_f
153
136
 
154
137
 
155
- def mean_abs_percentage_error(hypothesis: np.ndarray,
156
- reference: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
138
+ def mean_abs_percentage_error(
139
+ hypothesis: np.ndarray, reference: np.ndarray
140
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
157
141
  """Calculate mean abs percentage error
158
142
 
159
143
  If inputs are complex, calculates average: mape(real)/2 + mape(imag)/2
@@ -205,13 +189,16 @@ def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray
205
189
  return err, err_b, err_f
206
190
 
207
191
 
208
- def plot_mixpred(mixture: AudioT,
209
- mixture_f: AudioF,
210
- target: Optional[AudioT] = None,
211
- feature: Optional[Feature] = None,
212
- predict: Optional[Predict] = None,
213
- tp_title: str = '') -> plt.Figure:
192
+ def plot_mixpred(
193
+ mixture: AudioT,
194
+ mixture_f: AudioF,
195
+ target: AudioT | None = None,
196
+ feature: Feature | None = None,
197
+ predict: Predict | None = None,
198
+ tp_title: str = "",
199
+ ) -> plt.Figure:
214
200
  from sonusai.mixture import SAMPLE_RATE
201
+
215
202
  num_plots = 2
216
203
  if feature is not None:
217
204
  num_plots += 1
@@ -223,36 +210,48 @@ def plot_mixpred(mixture: AudioT,
223
210
  # Plot the waveform
224
211
  p = 0
225
212
  x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
226
- ax[p].plot(x_axis, mixture, label='Mixture', color='mistyrose')
227
- ax[0].set_ylabel('magnitude', color='tab:blue')
213
+ ax[p].plot(x_axis, mixture, label="Mixture", color="mistyrose")
214
+ ax[0].set_ylabel("magnitude", color="tab:blue")
228
215
  ax[p].set_xlim(x_axis[0], x_axis[-1])
229
216
  if target is not None: # Plot target time-domain waveform on top of mixture
230
- ax[0].plot(x_axis, target, label='Target', color='tab:blue')
231
- ax[p].set_title('Waveform')
217
+ ax[0].plot(x_axis, target, label="Target", color="tab:blue")
218
+ ax[p].set_title("Waveform")
232
219
 
233
220
  # Plot the mixture spectrogram
234
221
  p += 1
235
- ax[p].imshow(np.transpose(mixture_f), aspect='auto', interpolation='nearest', origin='lower')
236
- ax[p].set_title('Mixture')
222
+ ax[p].imshow(np.transpose(mixture_f), aspect="auto", interpolation="nearest", origin="lower")
223
+ ax[p].set_title("Mixture")
237
224
 
238
225
  if feature is not None:
239
226
  p += 1
240
- ax[p].imshow(np.transpose(feature), aspect='auto', interpolation='nearest', origin='lower')
241
- ax[p].set_title('Feature')
227
+ ax[p].imshow(
228
+ np.transpose(feature),
229
+ aspect="auto",
230
+ interpolation="nearest",
231
+ origin="lower",
232
+ )
233
+ ax[p].set_title("Feature")
242
234
 
243
235
  if predict is not None:
244
236
  p += 1
245
- im = ax[p].imshow(np.transpose(predict), aspect='auto', interpolation='nearest', origin='lower')
246
- ax[p].set_title('Predict ' + tp_title)
247
- plt.colorbar(im, location='bottom')
237
+ im = ax[p].imshow(
238
+ np.transpose(predict),
239
+ aspect="auto",
240
+ interpolation="nearest",
241
+ origin="lower",
242
+ )
243
+ ax[p].set_title("Predict " + tp_title)
244
+ plt.colorbar(im, location="bottom")
248
245
 
249
246
  return fig
250
247
 
251
248
 
252
- def plot_pdb_predict_truth(predict: np.ndarray,
253
- truth_f: Optional[np.ndarray] = None,
254
- metric: Optional[np.ndarray] = None,
255
- tp_title: str = '') -> plt.Figure:
249
+ def plot_pdb_predict_truth(
250
+ predict: np.ndarray,
251
+ truth_f: np.ndarray | None = None,
252
+ metric: np.ndarray | None = None,
253
+ tp_title: str = "",
254
+ ) -> plt.Figure:
256
255
  """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
257
256
  num_plots = 2
258
257
  if truth_f is not None:
@@ -263,48 +262,62 @@ def plot_pdb_predict_truth(predict: np.ndarray,
263
262
  # Plot the predict spectrogram
264
263
  p = 0
265
264
  tmp = 10 * np.log10(predict.transpose() + np.finfo(np.float32).eps)
266
- im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
267
- ax[p].set_title('Predict')
268
- plt.colorbar(im, location='bottom')
265
+ im = ax[p].imshow(tmp, aspect="auto", interpolation="nearest", origin="lower")
266
+ ax[p].set_title("Predict")
267
+ plt.colorbar(im, location="bottom")
269
268
 
270
269
  if truth_f is not None:
271
270
  p += 1
272
271
  tmp = 10 * np.log10(truth_f.transpose() + np.finfo(np.float32).eps)
273
- im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
274
- ax[p].set_title('Truth')
275
- plt.colorbar(im, location='bottom')
272
+ im = ax[p].imshow(tmp, aspect="auto", interpolation="nearest", origin="lower")
273
+ ax[p].set_title("Truth")
274
+ plt.colorbar(im, location="bottom")
276
275
 
277
276
  # Plot the predict avg, and optionally truth avg and metric lines
278
277
  pred_avg = 10 * np.log10(np.mean(predict, axis=-1) + np.finfo(np.float32).eps)
279
278
  p += 1
280
279
  x_axis = np.arange(len(pred_avg), dtype=np.float32) # / SAMPLE_RATE
281
- ax[p].plot(x_axis, pred_avg, color='black', linestyle='dashed', label='Predict mean over freq.')
282
- ax[p].set_ylabel('mean db', color='black')
280
+ ax[p].plot(
281
+ x_axis,
282
+ pred_avg,
283
+ color="black",
284
+ linestyle="dashed",
285
+ label="Predict mean over freq.",
286
+ )
287
+ ax[p].set_ylabel("mean db", color="black")
283
288
  ax[p].set_xlim(x_axis[0], x_axis[-1])
284
289
  if truth_f is not None:
285
290
  truth_avg = 10 * np.log10(np.mean(truth_f, axis=-1) + np.finfo(np.float32).eps)
286
- ax[p].plot(x_axis, truth_avg, color='green', linestyle='dashed', label='Truth mean over freq.')
291
+ ax[p].plot(
292
+ x_axis,
293
+ truth_avg,
294
+ color="green",
295
+ linestyle="dashed",
296
+ label="Truth mean over freq.",
297
+ )
287
298
 
288
299
  if metric is not None: # instantiate 2nd y-axis that shares the same x-axis
289
300
  ax2 = ax[p].twinx()
290
- color2 = 'red'
291
- ax2.plot(x_axis, metric, color=color2, label='sig distortion (mse db)')
301
+ color2 = "red"
302
+ ax2.plot(x_axis, metric, color=color2, label="sig distortion (mse db)")
292
303
  ax2.set_xlim(x_axis[0], x_axis[-1])
293
304
  ax2.set_ylim([0, np.max(metric)])
294
- ax2.set_ylabel('spectral distortion (mse db)', color=color2)
295
- ax2.tick_params(axis='y', labelcolor=color2)
296
- ax[p].set_title('SNR and SNR mse (mean over freq. db)')
305
+ ax2.set_ylabel("spectral distortion (mse db)", color=color2)
306
+ ax2.tick_params(axis="y", labelcolor=color2)
307
+ ax[p].set_title("SNR and SNR mse (mean over freq. db)")
297
308
  else:
298
- ax[p].set_title('SNR (mean over freq. db)')
309
+ ax[p].set_title("SNR (mean over freq. db)")
299
310
  return fig
300
311
 
301
312
 
302
- def plot_e_predict_truth(predict: np.ndarray,
303
- predict_wav: np.ndarray,
304
- truth_f: Optional[np.ndarray] = None,
305
- truth_wav: Optional[np.ndarray] = None,
306
- metric: Optional[np.ndarray] = None,
307
- tp_title: str = '') -> plt.Figure:
313
+ def plot_e_predict_truth(
314
+ predict: np.ndarray,
315
+ predict_wav: np.ndarray,
316
+ truth_f: np.ndarray | None = None,
317
+ truth_wav: np.ndarray | None = None,
318
+ metric: np.ndarray | None = None,
319
+ tp_title: str = "",
320
+ ) -> plt.Figure:
308
321
  """Plot predict spectrogram and waveform and optionally truth and a metric)"""
309
322
  num_plots = 2
310
323
  if truth_f is not None:
@@ -316,26 +329,32 @@ def plot_e_predict_truth(predict: np.ndarray,
316
329
 
317
330
  # Plot the predict spectrogram
318
331
  p = 0
319
- im = ax[p].imshow(predict.transpose(), aspect='auto', interpolation='nearest', origin='lower')
320
- ax[p].set_title('Predict')
321
- plt.colorbar(im, location='bottom')
332
+ im = ax[p].imshow(predict.transpose(), aspect="auto", interpolation="nearest", origin="lower")
333
+ ax[p].set_title("Predict")
334
+ plt.colorbar(im, location="bottom")
322
335
 
323
336
  if truth_f is not None: # plot truth if provided and use same colormap as predict
324
337
  p += 1
325
- ax[p].imshow(truth_f.transpose(), im.cmap, aspect='auto', interpolation='nearest', origin='lower')
326
- ax[p].set_title('Truth')
338
+ ax[p].imshow(
339
+ truth_f.transpose(),
340
+ im.cmap,
341
+ aspect="auto",
342
+ interpolation="nearest",
343
+ origin="lower",
344
+ )
345
+ ax[p].set_title("Truth")
327
346
 
328
347
  # Plot predict wav, and optionally truth avg and metric lines
329
348
  p += 1
330
349
  x_axis = np.arange(len(predict_wav), dtype=np.float32) # / SAMPLE_RATE
331
- ax[p].plot(x_axis, predict_wav, color='black', linestyle='dashed', label='Speech Estimate')
332
- ax[p].set_ylabel('Amplitude', color='black')
350
+ ax[p].plot(x_axis, predict_wav, color="black", linestyle="dashed", label="Speech Estimate")
351
+ ax[p].set_ylabel("Amplitude", color="black")
333
352
  ax[p].set_xlim(x_axis[0], x_axis[-1])
334
353
  if truth_wav is not None:
335
354
  ntrim = len(truth_wav) - len(predict_wav)
336
355
  if ntrim > 0:
337
356
  truth_wav = truth_wav[0:-ntrim]
338
- ax[p].plot(x_axis, truth_wav, color='green', linestyle='dashed', label='True Target')
357
+ ax[p].plot(x_axis, truth_wav, color="green", linestyle="dashed", label="True Target")
339
358
 
340
359
  # Plot the metric lines
341
360
  if metric is not None:
@@ -345,22 +364,21 @@ def plot_e_predict_truth(predict: np.ndarray,
345
364
  else:
346
365
  metric1 = metric # if single dim, plot it as 1st
347
366
  x_axis = np.arange(len(metric1), dtype=np.float32) # / SAMPLE_RATE
348
- ax[p].plot(x_axis, metric1, color='red', label='Target LogErr')
349
- ax[p].set_ylabel('log error db', color='red')
367
+ ax[p].plot(x_axis, metric1, color="red", label="Target LogErr")
368
+ ax[p].set_ylabel("log error db", color="red")
350
369
  ax[p].set_xlim(x_axis[0], x_axis[-1])
351
- ax[p].set_ylim([-0.01, np.max(metric1) + .01])
352
- if metric.ndim > 1:
353
- if metric.shape[1] > 1:
354
- metr2 = metric[:, 1]
355
- ax2 = ax[p].twinx()
356
- color2 = 'blue'
357
- ax2.plot(x_axis, metr2, color=color2, label='phase dist (deg)')
358
- # ax2.set_ylim([-180.0, +180.0])
359
- if np.max(metr2) - np.min(metr2) > .1:
360
- ax2.set_ylim([np.min(metr2), np.max(metr2)])
361
- ax2.set_ylabel('phase dist (deg)', color=color2)
362
- ax2.tick_params(axis='y', labelcolor=color2)
363
- # ax[p].set_title('SNR and SNR mse (mean over freq. db)')
370
+ ax[p].set_ylim([-0.01, np.max(metric1) + 0.01])
371
+ if metric.ndim > 1 and metric.shape[1] > 1:
372
+ metr2 = metric[:, 1]
373
+ ax2 = ax[p].twinx()
374
+ color2 = "blue"
375
+ ax2.plot(x_axis, metr2, color=color2, label="phase dist (deg)")
376
+ # ax2.set_ylim([-180.0, +180.0])
377
+ if np.max(metr2) - np.min(metr2) > 0.1:
378
+ ax2.set_ylim([np.min(metr2), np.max(metr2)])
379
+ ax2.set_ylabel("phase dist (deg)", color=color2)
380
+ ax2.tick_params(axis="y", labelcolor=color2)
381
+ # ax[p].set_title('SNR and SNR mse (mean over freq. db)')
364
382
 
365
383
  return fig
366
384
 
@@ -376,7 +394,6 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
376
394
  from matplotlib.backends.backend_pdf import PdfPages
377
395
  from pystoi import stoi
378
396
 
379
- from sonusai import SonusAIError
380
397
  from sonusai import logger
381
398
  from sonusai.metrics import calc_pcm
382
399
  from sonusai.metrics import calc_phase_distance
@@ -388,11 +405,15 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
388
405
  from sonusai.mixture import read_audio
389
406
  from sonusai.utils import calc_asr
390
407
  from sonusai.utils import float_to_int16
408
+ from sonusai.utils import power_compress
409
+ from sonusai.utils import power_uncompress
391
410
  from sonusai.utils import reshape_outputs
392
411
  from sonusai.utils import stack_complex
393
412
  from sonusai.utils import unstack_complex
394
413
  from sonusai.utils import write_audio
395
414
 
415
+ global MP_GLOBAL
416
+
396
417
  mixdb = MP_GLOBAL.mixdb
397
418
  predict_location = MP_GLOBAL.predict_location
398
419
  predict_wav_mode = MP_GLOBAL.predict_wav_mode
@@ -409,25 +430,25 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
409
430
  # in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
410
431
  # don't bother to read prediction, and predict var will get assigned to truth later
411
432
  # mark outputs with tru suffix, i.e. 0000_truest_*
412
- base_name = splitext(output_name)[0] + '_truest'
433
+ base_name = splitext(output_name)[0] + "_truest"
413
434
  else:
414
435
  base_name, ext = splitext(output_name) # base_name used later
415
436
  if not predict_wav_mode:
416
437
  try:
417
- with h5py.File(output_name, 'r') as f:
418
- predict = np.array(f['predict'])
438
+ with h5py.File(output_name, "r") as f:
439
+ predict = np.array(f["predict"])
419
440
  except Exception as e:
420
- raise SonusAIError(f'Error reading {output_name}: {e}')
441
+ raise OSError(f"Error reading {output_name}: {e}") from e
421
442
  # reshape to always be [frames,classes] where ndim==3 case frames = batch * tsteps
422
443
  if predict.ndim > 2: # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
423
444
  # logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
424
445
  predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
425
446
  else:
426
447
  base_name, ext = splitext(output_name)
427
- predict_name = join(base_name + '.wav')
448
+ predict_name = join(base_name + ".wav")
428
449
  audio = read_audio(predict_name)
429
450
  predict = forward_transform(audio, mixdb.ft_config)
430
- if mixdb.feature[0:1] == 'h':
451
+ if mixdb.feature[0:1] == "h":
431
452
  predict = power_compress(predict)
432
453
  predict = stack_complex(predict)
433
454
 
@@ -441,7 +462,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
441
462
  noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
442
463
  # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
443
464
  # note: uses pre-IR, pre-specaug audio
444
- segsnr_f: np.ndarray = mixdb.mixture_metrics(mixid, ['ssnr'])[0] # type: ignore
465
+ segsnr_f: np.ndarray = mixdb.mixture_metrics(mixid, ["ssnr"])[0] # type: ignore[assignment]
445
466
  mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
446
467
  noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
447
468
  # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
@@ -457,8 +478,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
457
478
  # TODO: parse truth_f for different formats
458
479
  feature, truth_f = mixdb.mixture_ft(mixid, mixture_f=mixture_f)
459
480
  # ignore mixup
460
- for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
461
- if truth_setting.function == 'target_mixture_f':
481
+ for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_configs:
482
+ if truth_setting.function == "target_mixture_f":
462
483
  half = truth_f.shape[-1] // 2
463
484
  # extract target_f only
464
485
  truth_f = truth_f[..., :half]
@@ -466,7 +487,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
466
487
  if not truth_est_mode:
467
488
  if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
468
489
  trim_f = target_f.shape[0] - predict.shape[0]
469
- logger.debug(f'Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.')
490
+ logger.debug(f"Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.")
470
491
  target_f = target_f[0:-trim_f, :]
471
492
  target_fi, _ = inverse_transform(target_f, mixdb.it_config)
472
493
  trim_t = target.shape[0] - target_fi.shape[0]
@@ -478,10 +499,11 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
478
499
  truth_f = truth_f[0:-trim_f, :]
479
500
  elif predict.shape[0] > target_f.shape[0]:
480
501
  logger.debug(
481
- f'Warning: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
502
+ f"Warning: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}"
503
+ )
482
504
  trim_f = predict.shape[0] - target_f.shape[0]
483
505
  predict = predict[0:-trim_f, :]
484
- # raise SonusAIError(
506
+ # raise ValueError(
485
507
  # f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
486
508
 
487
509
  # 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
@@ -489,13 +511,13 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
489
511
  predict = truth_f # substitute truth for the prediction (for test/debug)
490
512
  predict_complex = unstack_complex(predict) # unstack
491
513
  # if feat has compressed mag and truth does not, compress it
492
- if mixdb.feature[0:1] == 'h' and mixdb.target_file(1).truth_settings[0].function[0:10] != 'targetcmpr':
514
+ if mixdb.feature[0:1] == "h" and mixdb.target_file(1).truth_configs[0].function[0:10] != "targetcmpr":
493
515
  predict_complex = power_compress(predict_complex) # from uncompressed truth
494
516
  else:
495
517
  predict_complex = unstack_complex(predict)
496
518
 
497
519
  truth_f_complex = unstack_complex(truth_f)
498
- if mixdb.feature[0:1] == 'h': # 'hn' or 'ha' or 'hd', etc.: # if feat has compressed mag
520
+ if mixdb.feature[0:1] == "h": # 'hn' or 'ha' or 'hd', etc.: # if feat has compressed mag
499
521
  # estimate noise in uncompressed-mag domain
500
522
  noise_est_complex = mixture_f - power_uncompress(predict_complex)
501
523
  predict_complex = power_uncompress(predict_complex) # uncompress if truth is compressed
@@ -537,7 +559,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
537
559
  if int(mixdb.mixture(mixid).snr) > -99:
538
560
  # len = target_est_wav.shape[0]
539
561
  pesq_speech, csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi)
540
- pesq_mixture, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(mixid, ['mxpesq', 'mxcsig', 'mxcbak', 'mxcovl'])
562
+ pesq_mixture, csig_mx, cbak_mx, covl_mx = mixdb.mixture_metrics(mixid, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
541
563
  # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
542
564
  # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
543
565
  # pesq improvement
@@ -561,8 +583,8 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
561
583
  asr_tge = None
562
584
  asr_engines = list(mixdb.asr_configs.keys())
563
585
  if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96: # noise only, ignore/reset target asr
564
- wer_mx = float(mixdb.mixture_metrics(mixid, [f'mxwer.{asr_engines[0]}'])[0]) * 100
565
- asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, 'text')[0] # ignore mixup
586
+ wer_mx = float(mixdb.mixture_metrics(mixid, [f"mxwer.{asr_engines[0]}"])[0]) * 100
587
+ asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, "text")[0] # ignore mixup
566
588
  if asr_tt is None:
567
589
  asr_tt = calc_asr(target, engine=asr_method, whisper_model_name=asr_model_name).text # target truth
568
590
 
@@ -577,76 +599,118 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
577
599
  else:
578
600
  wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
579
601
  else:
580
- print(f'Warning: mixid {mixid} asr truth is empty, setting to 0% wer')
602
+ print(f"Warning: mixid {mixid} asr truth is empty, setting to 0% wer")
581
603
  wer_mx = float(0)
582
604
  wer_tge = float(0)
583
605
  wer_pi = float(0)
584
606
  else:
585
- wer_mx = float('nan')
586
- wer_tge = float('nan')
587
- wer_pi = float('nan')
607
+ wer_mx = float("nan")
608
+ wer_tge = float("nan")
609
+ wer_pi = float("nan")
588
610
 
589
611
  # 5) Save per mixture metric results
590
612
  # Single row in table of scalar metrics per mixture
591
- mtable1_col = ['MXSNR', 'MXPESQ', 'PESQ', 'PESQi%', 'MXWER', 'WER', 'WERi%', 'WSDR', 'STOI',
592
- 'PCM', 'SPLERR', 'NLERR', 'PD', 'MXCSIG', 'CSIG', 'MXCBAK', 'CBAK', 'MXCOVL', 'COVL',
593
- 'SPFILE', 'NFILE']
613
+ mtable1_col = [
614
+ "MXSNR",
615
+ "MXPESQ",
616
+ "PESQ",
617
+ "PESQi%",
618
+ "MXWER",
619
+ "WER",
620
+ "WERi%",
621
+ "WSDR",
622
+ "STOI",
623
+ "PCM",
624
+ "SPLERR",
625
+ "NLERR",
626
+ "PD",
627
+ "MXCSIG",
628
+ "CSIG",
629
+ "MXCBAK",
630
+ "CBAK",
631
+ "MXCOVL",
632
+ "COVL",
633
+ "SPFILE",
634
+ "NFILE",
635
+ ]
594
636
  ti = mixdb.mixture(mixid).targets[0].file_id
595
637
  ni = mixdb.mixture(mixid).noise.file_id
596
- metr1 = [mixdb.mixture(mixid).snr, pesq_mixture, pesq_speech, pesq_impr_pc, wer_mx, wer_tge, wer_pi, wsdr,
597
- target_stoi, pcm, lerr_tg, lerr_n, phd, csig_mx, csig_tg, cbak_mx, cbak_tg, covl_mx, covl_tg,
598
- basename(mixdb.target_file(ti).name), basename(mixdb.noise_file(ni).name)]
638
+ metr1 = [
639
+ mixdb.mixture(mixid).snr,
640
+ pesq_mixture,
641
+ pesq_speech,
642
+ pesq_impr_pc,
643
+ wer_mx,
644
+ wer_tge,
645
+ wer_pi,
646
+ wsdr,
647
+ target_stoi,
648
+ pcm,
649
+ lerr_tg,
650
+ lerr_n,
651
+ phd,
652
+ csig_mx,
653
+ csig_tg,
654
+ cbak_mx,
655
+ cbak_tg,
656
+ covl_mx,
657
+ covl_tg,
658
+ basename(mixdb.target_file(ti).name),
659
+ basename(mixdb.noise_file(ni).name),
660
+ ]
599
661
  mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[mixid])
600
662
 
601
663
  # Stats of per frame estimation metrics
602
- metr2 = pd.DataFrame({'SSNR': segsnr_f,
603
- 'PCM': pcm_frame,
604
- 'SLERR': lerr_tg_frame,
605
- 'NLERR': lerr_n_frame,
606
- 'SPD': phd_frame})
664
+ metr2 = pd.DataFrame(
665
+ {
666
+ "SSNR": segsnr_f,
667
+ "PCM": pcm_frame,
668
+ "SLERR": lerr_tg_frame,
669
+ "NLERR": lerr_n_frame,
670
+ "SPD": phd_frame,
671
+ }
672
+ )
607
673
  metr2 = metr2.describe() # Use pandas stat function
608
674
  # Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
609
675
  # metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
610
- metr2.iloc[1:, 0] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
676
+ metr2.iloc[1:, 0] = metr2["SSNR"][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
611
677
  # create a single row in multi-column header
612
- new_labels = pd.MultiIndex.from_product([metr2.columns,
613
- ['Avg', 'Min', 'Med', 'Max', 'Std']],
614
- names=['Metric', 'Stat'])
615
- dat1row = metr2.loc[['mean', 'min', '50%', 'max', 'std'], :].T.stack().to_numpy().reshape((1, -1))
616
- mtab2 = pd.DataFrame(dat1row,
617
- index=[mixid],
618
- columns=new_labels)
619
- mtab2.insert(0, 'MXSNR', mixdb.mixture(mixid).snr, False) # add MXSNR as the first metric column
678
+ new_labels = pd.MultiIndex.from_product(
679
+ [metr2.columns, ["Avg", "Min", "Med", "Max", "Std"]], names=["Metric", "Stat"]
680
+ )
681
+ dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
682
+ mtab2 = pd.DataFrame(dat1row, index=[mixid], columns=new_labels)
683
+ mtab2.insert(0, "MXSNR", mixdb.mixture(mixid).snr, False) # add MXSNR as the first metric column
620
684
 
621
685
  all_metrics_table_1 = mtab1 # return to be collected by process
622
686
  all_metrics_table_2 = mtab2 # return to be collected by process
623
687
 
624
- metric_name = base_name + '_metric_spenh.txt'
625
- with open(metric_name, 'w') as f:
626
- print('Speech enhancement metrics:', file=f)
627
- print(mtab1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
628
- print('', file=f)
629
- print(f'Extraction statistics over {mixture_f.shape[0]} frames:', file=f)
630
- print(metr2.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
631
- print('', file=f)
632
- print(f'Target path: {mixdb.target_file(ti).name}', file=f)
633
- print(f'Noise path: {mixdb.noise_file(ni).name}', file=f)
634
- if asr_method != 'none':
635
- print(f'ASR method: {asr_method} and whisper model (if used): {asr_model_name}', file=f)
636
- print(f'ASR truth: {asr_tt}', file=f)
637
- print(f'ASR result for mixture: {asr_mx}', file=f)
638
- print(f'ASR result for prediction: {asr_tge}', file=f)
639
-
640
- print(f'Augmentations: {mixdb.mixture(mixid)}', file=f)
688
+ metric_name = base_name + "_metric_spenh.txt"
689
+ with open(metric_name, "w") as f, redirect_stdout(f):
690
+ print("Speech enhancement metrics:")
691
+ print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
692
+ print("")
693
+ print(f"Extraction statistics over {mixture_f.shape[0]} frames:")
694
+ print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
695
+ print("")
696
+ print(f"Target path: {mixdb.target_file(ti).name}")
697
+ print(f"Noise path: {mixdb.noise_file(ni).name}")
698
+ if asr_method != "none":
699
+ print(f"ASR method: {asr_method} and whisper model (if used): {asr_model_name}")
700
+ print(f"ASR truth: {asr_tt}")
701
+ print(f"ASR result for mixture: {asr_mx}")
702
+ print(f"ASR result for prediction: {asr_tge}")
703
+
704
+ print(f"Augmentations: {mixdb.mixture(mixid)}")
641
705
 
642
706
  # 7) write wav files
643
707
  if enable_wav:
644
- write_audio(name=base_name + '_mixture.wav', audio=float_to_int16(mixture))
645
- write_audio(name=base_name + '_target.wav', audio=float_to_int16(target))
708
+ write_audio(name=base_name + "_mixture.wav", audio=float_to_int16(mixture))
709
+ write_audio(name=base_name + "_target.wav", audio=float_to_int16(target))
646
710
  # write_audio(name=base_name + '_target_fi.wav', audio=float_to_int16(target_fi))
647
- write_audio(name=base_name + '_noise.wav', audio=float_to_int16(noise))
648
- write_audio(name=base_name + '_target_est.wav', audio=float_to_int16(target_est_wav))
649
- write_audio(name=base_name + '_noise_est.wav', audio=float_to_int16(noise_est_wav))
711
+ write_audio(name=base_name + "_noise.wav", audio=float_to_int16(noise))
712
+ write_audio(name=base_name + "_target_est.wav", audio=float_to_int16(target_est_wav))
713
+ write_audio(name=base_name + "_noise_est.wav", audio=float_to_int16(noise_est_wav))
650
714
 
651
715
  # debug code to test for perfect reconstruction of the extraction method
652
716
  # note both 75% olsa-hanns and 50% olsa-hann modes checked to have perfect reconstruction
@@ -657,14 +721,14 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
657
721
 
658
722
  # 8) Write out plot file
659
723
  if enable_plot:
660
- plot_name = base_name + '_metric_spenh.pdf'
724
+ plot_name = base_name + "_metric_spenh.pdf"
661
725
 
662
726
  # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
663
727
  # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
664
728
  # Reshape to get frames*decimated_stride, num_bands
665
729
  step = int(mixdb.feature_samples / mixdb.feature_step_samples)
666
730
  if feature.ndim != 3:
667
- raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, num_bands')
731
+ raise ValueError("feature does not have 3 dimensions: frames, stride, num_bands")
668
732
 
669
733
  # for feature cn*00n**
670
734
  feat_sgram = unstack_complex(feature)
@@ -674,27 +738,29 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
674
738
 
675
739
  with PdfPages(plot_name) as pdf:
676
740
  # page1 we always have a mixture and prediction, target optional if truth provided
677
- tfunc_name = mixdb.target_file(1).truth_settings[0].function # first target, assumes all have same
678
- if tfunc_name == 'mapped_snr_f':
741
+ tfunc_name = mixdb.target_file(1).truth_configs[0].function # first target, assumes all have same
742
+ if tfunc_name == "mapped_snr_f":
679
743
  # leave as unmapped snr
680
744
  predplot = predict
681
- tfunc_name = mixdb.target_file(1).truth_settings[0].function
682
- elif tfunc_name == 'target_f' or 'target_mixture_f':
745
+ tfunc_name = mixdb.target_file(1).truth_configs[0].function
746
+ elif tfunc_name in ("target_f", "target_mixture_f"):
683
747
  predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
684
748
  else:
685
749
  # use dB scale
686
750
  predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
687
- tfunc_name = tfunc_name + ' (db)'
751
+ tfunc_name = tfunc_name + " (db)"
688
752
 
689
753
  mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
690
- fig_obj = plot_mixpred(mixture=mixture,
691
- mixture_f=mixspec,
692
- target=target,
693
- feature=feat_sgram,
694
- predict=predplot,
695
- tp_title=tfunc_name)
754
+ fig_obj = plot_mixpred(
755
+ mixture=mixture,
756
+ mixture_f=mixspec,
757
+ target=target,
758
+ feature=feat_sgram,
759
+ predict=predplot,
760
+ tp_title=tfunc_name,
761
+ )
696
762
  pdf.savefig(fig_obj)
697
- with mgzip.open(base_name + '_metric_spenh_fig1.mfigz', 'wb') as f:
763
+ with mgzip.open(base_name + "_metric_spenh_fig1.mfigz", "wb") as f:
698
764
  pickle.dump(fig_obj, f)
699
765
 
700
766
  # ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
@@ -704,34 +770,38 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
704
770
  tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
705
771
  tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
706
772
  # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
707
- fig_obj = plot_e_predict_truth(predict=tg_est_spec,
708
- predict_wav=target_est_wav,
709
- truth_f=tg_spec,
710
- truth_wav=target_fi,
711
- metric=np.vstack((lerr_tg_frame, phd_frame)).T,
712
- tp_title='speech estimate')
773
+ fig_obj = plot_e_predict_truth(
774
+ predict=tg_est_spec,
775
+ predict_wav=target_est_wav,
776
+ truth_f=tg_spec,
777
+ truth_wav=target_fi,
778
+ metric=np.vstack((lerr_tg_frame, phd_frame)).T,
779
+ tp_title="speech estimate",
780
+ )
713
781
  pdf.savefig(fig_obj)
714
- with mgzip.open(base_name + '_metric_spenh_fig2.mfigz', 'wb') as f:
782
+ with mgzip.open(base_name + "_metric_spenh_fig2.mfigz", "wb") as f:
715
783
  pickle.dump(fig_obj, f)
716
784
 
717
785
  # page 4 noise extraction
718
786
  n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
719
787
  n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
720
- fig_obj = plot_e_predict_truth(predict=n_est_spec,
721
- predict_wav=noise_est_wav,
722
- truth_f=n_spec,
723
- truth_wav=noise_fi,
724
- metric=lerr_n_frame,
725
- tp_title='noise estimate')
788
+ fig_obj = plot_e_predict_truth(
789
+ predict=n_est_spec,
790
+ predict_wav=noise_est_wav,
791
+ truth_f=n_spec,
792
+ truth_wav=noise_fi,
793
+ metric=lerr_n_frame,
794
+ tp_title="noise estimate",
795
+ )
726
796
  pdf.savefig(fig_obj)
727
- with mgzip.open(base_name + '_metric_spenh_fig4.mfigz', 'wb') as f:
797
+ with mgzip.open(base_name + "_metric_spenh_fig4.mfigz", "wb") as f:
728
798
  pickle.dump(fig_obj, f)
729
799
 
730
800
  # Plot error waveforms
731
801
  # tg_err_wav = target_fi - target_est_wav
732
802
  # tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))
733
803
 
734
- plt.close('all')
804
+ plt.close("all")
735
805
 
736
806
  return all_metrics_table_1, all_metrics_table_2
737
807
 
@@ -744,17 +814,17 @@ def main():
744
814
 
745
815
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
746
816
 
747
- verbose = args['--verbose']
748
- mixids = args['--mixid']
749
- asr_method = args['--asr-method'].lower()
750
- asr_model_name = args['--model'].lower()
751
- truth_est_mode = args['--truth-est-mode']
752
- enable_plot = args['--plot']
753
- enable_wav = args['--wav']
754
- enable_summary = args['--summary']
755
- predict_location = args['PLOC']
756
- num_proc = args['--num_process']
757
- truth_location = args['TLOC']
817
+ verbose = args["--verbose"]
818
+ mixids = args["--mixid"]
819
+ asr_method = args["--asr-method"].lower()
820
+ asr_model_name = args["--model"].lower()
821
+ truth_est_mode = args["--truth-est-mode"]
822
+ enable_plot = args["--plot"]
823
+ enable_wav = args["--wav"]
824
+ enable_summary = args["--summary"]
825
+ predict_location = args["PLOC"]
826
+ num_proc = args["--num_process"]
827
+ truth_location = args["TLOC"]
758
828
 
759
829
  import glob
760
830
  from os.path import basename
@@ -762,7 +832,6 @@ def main():
762
832
  from os.path import join
763
833
 
764
834
  import psutil
765
- from tqdm import tqdm
766
835
 
767
836
  from sonusai import create_file_handler
768
837
  from sonusai import initial_log_messages
@@ -772,11 +841,12 @@ def main():
772
841
  from sonusai.mixture import MixtureDatabase
773
842
  from sonusai.mixture import read_audio
774
843
  from sonusai.utils import calc_asr
775
- from sonusai.utils import pp_tqdm_imap
844
+ from sonusai.utils import par_track
845
+ from sonusai.utils import track
776
846
 
777
847
  # Check prediction subdirectory
778
848
  if not isdir(predict_location):
779
- print(f'The specified predict location {predict_location} is not a valid subdirectory path, exiting ...')
849
+ print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting ...")
780
850
 
781
851
  # all_predict_files = listdir(predict_location)
782
852
  all_predict_files = glob.glob(predict_location + "/*.h5")
@@ -785,69 +855,72 @@ def main():
785
855
  if len(all_predict_files) <= 0 and not truth_est_mode:
786
856
  all_predict_files = glob.glob(predict_location + "/*.wav") # check for wav files
787
857
  if len(all_predict_files) <= 0:
788
- print(f'Subdirectory {predict_location} has no .h5 or .wav files, exiting ...')
858
+ print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting ...")
789
859
  else:
790
- logger.info(f'Found {len(all_predict_files)} prediction .wav files.')
860
+ logger.info(f"Found {len(all_predict_files)} prediction .wav files.")
791
861
  predict_wav_mode = True
792
862
  else:
793
- logger.info(f'Found {len(all_predict_files)} prediction .h5 files.')
863
+ logger.info(f"Found {len(all_predict_files)} prediction .h5 files.")
794
864
 
795
865
  if len(predict_logfile) == 0:
796
- logger.info(f'Warning, predict location {predict_location} has no prediction log files.')
866
+ logger.info(f"Warning, predict location {predict_location} has no prediction log files.")
797
867
  else:
798
- logger.info(f'Found predict log {basename(predict_logfile[0])} in predict location.')
868
+ logger.info(f"Found predict log {basename(predict_logfile[0])} in predict location.")
799
869
 
800
870
  # Setup logging file
801
- create_file_handler(join(predict_location, 'calc_metric_spenh.log'))
871
+ create_file_handler(join(predict_location, "calc_metric_spenh.log"))
802
872
  update_console_handler(verbose)
803
- initial_log_messages('calc_metric_spenh')
873
+ initial_log_messages("calc_metric_spenh")
804
874
 
805
875
  mixdb = MixtureDatabase(truth_location)
806
876
  mixids = mixdb.mixids_to_list(mixids)
807
877
  logger.info(
808
- f'Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}')
809
- logger.info(f'Only running specified subset of {len(mixids)} mixtures')
878
+ f"Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}"
879
+ )
880
+ logger.info(f"Only running specified subset of {len(mixids)} mixtures")
810
881
 
811
882
  enable_asr_warmup = False
812
- if asr_method == 'none':
813
- fnb = 'metric_spenh_'
814
- elif asr_method == 'google':
815
- fnb = 'metric_spenh_ggl_'
816
- logger.info(f'ASR enabled with method {asr_method}')
883
+ if asr_method == "none":
884
+ fnb = "metric_spenh_"
885
+ elif asr_method == "google":
886
+ fnb = "metric_spenh_ggl_"
887
+ logger.info(f"ASR enabled with method {asr_method}")
817
888
  enable_asr_warmup = True
818
- elif asr_method == 'deepgram':
819
- fnb = 'metric_spenh_dgram_'
820
- logger.info(f'ASR enabled with method {asr_method}')
889
+ elif asr_method == "deepgram":
890
+ fnb = "metric_spenh_dgram_"
891
+ logger.info(f"ASR enabled with method {asr_method}")
821
892
  enable_asr_warmup = True
822
- elif asr_method == 'aixplain_whisper':
823
- fnb = 'metric_spenh_whspx_' + asr_model_name + '_'
824
- logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
893
+ elif asr_method == "aixplain_whisper":
894
+ fnb = "metric_spenh_whspx_" + asr_model_name + "_"
895
+ logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
825
896
  enable_asr_warmup = True
826
- elif asr_method == 'whisper':
827
- fnb = 'metric_spenh_whspl_' + asr_model_name + '_'
828
- logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
897
+ elif asr_method == "whisper":
898
+ fnb = "metric_spenh_whspl_" + asr_model_name + "_"
899
+ logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
829
900
  enable_asr_warmup = True
830
- elif asr_method == 'aaware_whisper':
831
- fnb = 'metric_spenh_whspaaw_' + asr_model_name + '_'
832
- logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
901
+ elif asr_method == "aaware_whisper":
902
+ fnb = "metric_spenh_whspaaw_" + asr_model_name + "_"
903
+ logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
833
904
  enable_asr_warmup = True
834
- elif asr_method == 'faster_whisper':
835
- fnb = 'metric_spenh_fwhsp_' + asr_model_name + '_'
836
- logger.info(f'ASR enabled with method {asr_method} and whisper model {asr_model_name}')
905
+ elif asr_method == "faster_whisper":
906
+ fnb = "metric_spenh_fwhsp_" + asr_model_name + "_"
907
+ logger.info(f"ASR enabled with method {asr_method} and whisper model {asr_model_name}")
837
908
  enable_asr_warmup = True
838
- elif asr_method == 'sensory':
839
- fnb = 'metric_spenh_snsr_' + asr_model_name + '_'
840
- logger.info(f'ASR enabled with method {asr_method} and model {asr_model_name}')
909
+ elif asr_method == "sensory":
910
+ fnb = "metric_spenh_snsr_" + asr_model_name + "_"
911
+ logger.info(f"ASR enabled with method {asr_method} and model {asr_model_name}")
841
912
  enable_asr_warmup = True
842
913
  else:
843
- logger.error(f'Unrecognized ASR method: {asr_method}')
914
+ logger.error(f"Unrecognized ASR method: {asr_method}")
844
915
  return
845
916
 
846
917
  if enable_asr_warmup:
847
918
  audio = read_audio(DEFAULT_SPEECH)
848
- logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
919
+ logger.info("Warming up asr method, note for cloud service this could take up to a few min ...")
849
920
  asr_chk = calc_asr(audio, engine=asr_method, whisper_model_name=asr_model_name)
850
- logger.info(f'Warmup completed, results {asr_chk}')
921
+ logger.info(f"Warmup completed, results {asr_chk}")
922
+
923
+ global MP_GLOBAL
851
924
 
852
925
  MP_GLOBAL.mixdb = mixdb
853
926
  MP_GLOBAL.predict_location = predict_location
@@ -862,9 +935,9 @@ def main():
862
935
  cpu_percent = psutil.cpu_percent(interval=1)
863
936
  logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
864
937
  logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
865
- if num_proc == 'auto':
938
+ if num_proc == "auto":
866
939
  use_cpu = int(num_cpu * (0.9 - cpu_percent / 100)) # default use 80% of available cpus
867
- elif num_proc == 'None':
940
+ elif num_proc == "None":
868
941
  use_cpu = None
869
942
  else:
870
943
  use_cpu = min(max(int(num_proc), 1), num_cpu)
@@ -872,11 +945,11 @@ def main():
872
945
  # Individual mixtures use pandas print, set precision to 2 decimal places
873
946
  # pd.set_option('float_format', '{:.2f}'.format)
874
947
  logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes ...")
875
- progress = tqdm(total=len(mixids), desc='calc_metric_spenh')
948
+ progress = track(total=len(mixids), desc="calc_metric_spenh")
876
949
  if use_cpu is None:
877
- all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, no_par=True)
950
+ all_metrics_tables = par_track(_process_mixture, mixids, progress=progress, no_par=True)
878
951
  else:
879
- all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, num_cpus=use_cpu)
952
+ all_metrics_tables = par_track(_process_mixture, mixids, progress=progress, num_cpus=use_cpu)
880
953
  progress.close()
881
954
 
882
955
  all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
@@ -887,37 +960,40 @@ def main():
887
960
 
888
961
  # 9) Done with mixtures, write out summary metrics
889
962
  # Calculate SNR summary avg of each non-random snr
890
- all_mtab1_sorted = all_metrics_table_1.sort_values(by=['MXSNR', 'SPFILE'])
891
- all_mtab2_sorted = all_metrics_table_2.sort_values(by=['MXSNR'])
963
+ all_mtab1_sorted = all_metrics_table_1.sort_values(by=["MXSNR", "SPFILE"])
964
+ all_mtab2_sorted = all_metrics_table_2.sort_values(by=["MXSNR"])
892
965
  mtab_snr_summary = None
893
966
  mtab_snr_summary_em = None
894
967
  for snri in range(0, len(mixdb.snrs)):
895
- tmp = all_mtab1_sorted.query('MXSNR==' + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
968
+ tmp = all_mtab1_sorted.query("MXSNR==" + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
896
969
  # avoid nan when subset of mixids specified
897
970
  if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
898
971
  mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])
899
972
 
900
- tmp = all_mtab2_sorted[all_mtab2_sorted['MXSNR'] == mixdb.snrs[snri]].mean(numeric_only=True).to_frame().T
973
+ tmp = all_mtab2_sorted[all_mtab2_sorted["MXSNR"] == mixdb.snrs[snri]].mean(numeric_only=True).to_frame().T
901
974
  # avoid nan when subset of mixids specified (mxsnr will be nan if no data):
902
975
  if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
903
976
  mtab_snr_summary_em = pd.concat([mtab_snr_summary_em, tmp])
904
977
 
905
- mtab_snr_summary = mtab_snr_summary.sort_values(by=['MXSNR'], ascending=False)
978
+ mtab_snr_summary = mtab_snr_summary.sort_values(by=["MXSNR"], ascending=False)
906
979
  # Correct percentages in snr summary table
907
- mtab_snr_summary['PESQi%'] = 100 * (mtab_snr_summary['PESQ'] - mtab_snr_summary['MXPESQ']) / np.maximum(
908
- mtab_snr_summary['MXPESQ'], 0.01)
980
+ mtab_snr_summary["PESQi%"] = (
981
+ 100 * (mtab_snr_summary["PESQ"] - mtab_snr_summary["MXPESQ"]) / np.maximum(mtab_snr_summary["MXPESQ"], 0.01)
982
+ )
909
983
  for i in range(len(mtab_snr_summary)):
910
- if mtab_snr_summary['MXWER'].iloc[i] == 0.0:
911
- if mtab_snr_summary['WER'].iloc[i] == 0.0:
984
+ if mtab_snr_summary["MXWER"].iloc[i] == 0.0:
985
+ if mtab_snr_summary["WER"].iloc[i] == 0.0:
912
986
  mtab_snr_summary.iloc[i, 6] = 0.0 # mtab_snr_summary['WERi%'].iloc[i] = 0.0
913
987
  else:
914
988
  mtab_snr_summary.iloc[i, 6] = -999.0 # mtab_snr_summary['WERi%'].iloc[i] = -999.0
915
989
  else:
916
- if ~np.isnan(mtab_snr_summary['WER'].iloc[i]) and ~np.isnan(mtab_snr_summary['MXWER'].iloc[i]):
990
+ if ~np.isnan(mtab_snr_summary["WER"].iloc[i]) and ~np.isnan(mtab_snr_summary["MXWER"].iloc[i]):
917
991
  # update WERi% in 6th col
918
- mtab_snr_summary.iloc[i, 6] = 100 * (mtab_snr_summary['MXWER'].iloc[i] -
919
- mtab_snr_summary['WER'].iloc[i]) / \
920
- mtab_snr_summary['MXWER'].iloc[i]
992
+ mtab_snr_summary.iloc[i, 6] = (
993
+ 100
994
+ * (mtab_snr_summary["MXWER"].iloc[i] - mtab_snr_summary["WER"].iloc[i])
995
+ / mtab_snr_summary["MXWER"].iloc[i]
996
+ )
921
997
 
922
998
  # Calculate avg metrics over all mixtures except -99
923
999
  all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
@@ -925,16 +1001,17 @@ def main():
925
1001
 
926
1002
  # correct the percentage averages with a direct calculation (PESQ% and WER%):
927
1003
  # ser.iloc[pos]
928
- all_nom99_mean['PESQi%'] = (100 * (all_nom99_mean['PESQ'] - all_nom99_mean['MXPESQ'])
929
- / np.maximum(all_nom99_mean['MXPESQ'], 0.01)) # pesq%
1004
+ all_nom99_mean["PESQi%"] = (
1005
+ 100 * (all_nom99_mean["PESQ"] - all_nom99_mean["MXPESQ"]) / np.maximum(all_nom99_mean["MXPESQ"], 0.01)
1006
+ ) # pesq%
930
1007
  # all_nom99_mean[3] = 100 * (all_nom99_mean[2] - all_nom99_mean[1]) / np.maximum(all_nom99_mean[1], 0.01) # pesq%
931
- if all_nom99_mean['MXWER'] == 0.0:
932
- if all_nom99_mean['WER'] == 0.0:
933
- all_nom99_mean['WERi%'] = 0.0
1008
+ if all_nom99_mean["MXWER"] == 0.0:
1009
+ if all_nom99_mean["WER"] == 0.0:
1010
+ all_nom99_mean["WERi%"] = 0.0
934
1011
  else:
935
- all_nom99_mean['WERi%'] = -999.0
1012
+ all_nom99_mean["WERi%"] = -999.0
936
1013
  else: # wer%
937
- all_nom99_mean['WERi%'] = 100 * (all_nom99_mean['MXWER'] - all_nom99_mean['WER']) / all_nom99_mean['MXWER']
1014
+ all_nom99_mean["WERi%"] = 100 * (all_nom99_mean["MXWER"] - all_nom99_mean["WER"]) / all_nom99_mean["MXWER"]
938
1015
 
939
1016
  num_mix = len(mixids)
940
1017
  if num_mix > 1:
@@ -942,91 +1019,88 @@ def main():
942
1019
  # pd.set_option('float_format', '{:.2f}'.format)
943
1020
 
944
1021
  if not truth_est_mode:
945
- ofname = join(predict_location, fnb + 'summary.txt')
1022
+ ofname = join(predict_location, fnb + "summary.txt")
946
1023
  else:
947
- ofname = join(predict_location, fnb + 'summary_truest.txt')
948
-
949
- with open(ofname, 'w') as f:
950
- print(f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}', file=f)
951
- print(f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:',
952
- file=f)
953
- print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: "{:.2f}".format(x),
954
- index=False), file=f)
955
- print(f'\nSpeech enhancement metrics avg over each SNR:', file=f)
956
- print(mtab_snr_summary.round(2).to_string(float_format=lambda x: "{:.2f}".format(x), index=False), file=f)
957
- print('', file=f)
958
- print(f'Extraction statistics stats avg over each SNR:', file=f)
1024
+ ofname = join(predict_location, fnb + "summary_truest.txt")
1025
+
1026
+ with open(ofname, "w") as f, redirect_stdout(f):
1027
+ print(f"ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}")
1028
+ print(f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:")
1029
+ print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False))
1030
+ print("\nSpeech enhancement metrics avg over each SNR:")
1031
+ print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False))
1032
+ print("")
1033
+ print("Extraction statistics stats avg over each SNR:")
959
1034
  # with pd.option_context('display.max_colwidth', 9):
960
1035
  # with pd.set_option('float_format', '{:.1f}'.format):
961
- print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: "{:.1f}".format(x), index=False),
962
- file=f)
963
- print('', file=f)
1036
+ print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False))
1037
+ print("")
964
1038
  # pd.set_option('float_format', '{:.2f}'.format)
965
1039
 
966
- print(f'Speech enhancement metrics stats over all {num_mix} mixtures:', file=f)
967
- print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
968
- print('', file=f)
969
- print(f'Extraction statistics stats over all {num_mix} mixtures:', file=f)
970
- print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)
971
- print('', file=f)
1040
+ print(f"Speech enhancement metrics stats over all {num_mix} mixtures:")
1041
+ print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"))
1042
+ print("")
1043
+ print(f"Extraction statistics stats over all {num_mix} mixtures:")
1044
+ print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"))
1045
+ print("")
972
1046
 
973
- print('Speech enhancement metrics all-mixtures list:', file=f)
974
- # print(all_metrics_table_1.head().style.format(precision=2), file=f)
975
- print(all_metrics_table_1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
976
- print('', file=f)
977
- print('Extraction statistics all-mixtures list:', file=f)
978
- print(all_metrics_table_2.round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)
1047
+ print("Speech enhancement metrics all-mixtures list:")
1048
+ # print(all_metrics_table_1.head().style.format(precision=2))
1049
+ print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"))
1050
+ print("")
1051
+ print("Extraction statistics all-mixtures list:")
1052
+ print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"))
979
1053
 
980
1054
  # Write summary to .csv file
981
1055
  if not truth_est_mode:
982
- csv_name = str(join(predict_location, fnb + 'summary.csv'))
1056
+ csv_name = str(join(predict_location, fnb + "summary.csv"))
983
1057
  else:
984
- csv_name = str(join(predict_location, fnb + 'truest_summary.csv'))
1058
+ csv_name = str(join(predict_location, fnb + "truest_summary.csv"))
985
1059
  header_args = {
986
- 'mode': 'a',
987
- 'encoding': 'utf-8',
988
- 'index': False,
989
- 'header': False,
1060
+ "mode": "a",
1061
+ "encoding": "utf-8",
1062
+ "index": False,
1063
+ "header": False,
990
1064
  }
991
1065
  table_args = {
992
- 'mode': 'a',
993
- 'encoding': 'utf-8',
1066
+ "mode": "a",
1067
+ "encoding": "utf-8",
994
1068
  }
995
- label = f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:'
1069
+ label = f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:"
996
1070
  pd.DataFrame([label]).to_csv(csv_name, header=False, index=False) # open as write
997
1071
  all_nom99_mean.to_frame().T.round(2).to_csv(csv_name, index=False, **table_args)
998
- pd.DataFrame(['']).to_csv(csv_name, **header_args)
999
- pd.DataFrame([f'Speech enhancement metrics avg over each SNR:']).to_csv(csv_name, **header_args)
1072
+ pd.DataFrame([""]).to_csv(csv_name, **header_args)
1073
+ pd.DataFrame(["Speech enhancement metrics avg over each SNR:"]).to_csv(csv_name, **header_args)
1000
1074
  mtab_snr_summary.round(2).to_csv(csv_name, index=False, **table_args)
1001
- pd.DataFrame(['']).to_csv(csv_name, **header_args)
1002
- pd.DataFrame([f'Extraction statistics stats avg over each SNR:']).to_csv(csv_name, **header_args)
1075
+ pd.DataFrame([""]).to_csv(csv_name, **header_args)
1076
+ pd.DataFrame(["Extraction statistics stats avg over each SNR:"]).to_csv(csv_name, **header_args)
1003
1077
  mtab_snr_summary_em.round(2).to_csv(csv_name, index=False, **table_args)
1004
- pd.DataFrame(['']).to_csv(csv_name, **header_args)
1005
- pd.DataFrame(['']).to_csv(csv_name, **header_args)
1006
- label = f'Speech enhancement metrics stats over {num_mix} mixtures:'
1078
+ pd.DataFrame([""]).to_csv(csv_name, **header_args)
1079
+ pd.DataFrame([""]).to_csv(csv_name, **header_args)
1080
+ label = f"Speech enhancement metrics stats over {num_mix} mixtures:"
1007
1081
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
1008
1082
  all_metrics_table_1.describe().round(2).to_csv(csv_name, **table_args)
1009
- pd.DataFrame(['']).to_csv(csv_name, **header_args)
1010
- label = f'Extraction statistics stats over {num_mix} mixtures:'
1083
+ pd.DataFrame([""]).to_csv(csv_name, **header_args)
1084
+ label = f"Extraction statistics stats over {num_mix} mixtures:"
1011
1085
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
1012
1086
  all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
1013
- label = f'ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}'
1087
+ label = f"ASR enabled with method {asr_method}, whisper model, if used: {asr_model_name}"
1014
1088
  pd.DataFrame([label]).to_csv(csv_name, **header_args)
1015
1089
 
1016
1090
  if not truth_est_mode:
1017
- csv_name = str(join(predict_location, fnb + 'list.csv'))
1091
+ csv_name = str(join(predict_location, fnb + "list.csv"))
1018
1092
  else:
1019
- csv_name = str(join(predict_location, fnb + 'truest_list.csv'))
1020
- pd.DataFrame(['Speech enhancement metrics list:']).to_csv(csv_name, header=False, index=False) # open as write
1093
+ csv_name = str(join(predict_location, fnb + "truest_list.csv"))
1094
+ pd.DataFrame(["Speech enhancement metrics list:"]).to_csv(csv_name, header=False, index=False) # open as write
1021
1095
  all_metrics_table_1.round(2).to_csv(csv_name, **table_args)
1022
1096
 
1023
1097
  if not truth_est_mode:
1024
- csv_name = str(join(predict_location, fnb + 'estats_list.csv'))
1098
+ csv_name = str(join(predict_location, fnb + "estats_list.csv"))
1025
1099
  else:
1026
- csv_name = str(join(predict_location, fnb + 'truest_estats_list.csv'))
1027
- pd.DataFrame(['Extraction statistics list:']).to_csv(csv_name, header=False, index=False) # open as write
1100
+ csv_name = str(join(predict_location, fnb + "truest_estats_list.csv"))
1101
+ pd.DataFrame(["Extraction statistics list:"]).to_csv(csv_name, header=False, index=False) # open as write
1028
1102
  all_metrics_table_2.round(2).to_csv(csv_name, **table_args)
1029
1103
 
1030
1104
 
1031
- if __name__ == '__main__':
1105
+ if __name__ == "__main__":
1032
1106
  main()