sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1136 @@
1
+ """sonusai calc_metric_spenh
2
+
3
+ usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-n NCPU] PLOC TLOC
4
+
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to process, can be range like 0:maxmix+1. [default: *]
9
+ -t, --truth-est-mode Calculate extraction and metrics using truth (instead of prediction).
10
+ -p, --plot Enable PDF plots file generation per mixture.
11
+ -w, --wav Generate WAV files per mixture.
12
+ -s, --summary Enable summary files generation.
13
+ -n, --num_process NCPU Number of parallel processes to use [default: auto]
14
+ -e ASR, --asr-method ASR ASR method used for WER metrics. Must exist in the TLOC dataset as pre-calculated
15
+ metrics using SonusAI genmetrics. Can be either an integer index, i.e 0,1,... or the
16
+ name of the asr_engine configuration in the dataset. If an incorrect name is specified,
17
+ a list of asr_engines of the dataset will be printed.
18
+
19
+ Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
20
+ reference. Metric and extraction data files are written into PLOC.
21
+
22
+ PLOC directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
23
+ TLOC directory with SonusAI mixture database of truth/label mixture data
24
+
25
+ For ASR methods, the method must be defined in the TLOC dataset. For example, the available faster_whisper models are:
26
+ {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large} and an example configuration looks like:
27
+ {'fwhsptiny_cpu': {'engine': 'faster_whisper',
28
+ 'model': 'tiny',
29
+ 'device': 'cpu',
30
+ 'beam_size': 5}}
31
+ Note: the ASR config can optionally include the model, device, and other fields the engine supports.
32
+ Most ASR are very computationally demanding and can overwhelm/hang a local system.
33
+
34
+ Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
35
+ <id>_metric_spenh.txt
36
+
37
+ If --plot:
38
+ <id>_metric_spenh.pdf
39
+
40
+ If --wav:
41
+ <id>_target.wav
42
+ <id>_target_est.wav
43
+ <id>_noise.wav
44
+ <id>_noise_est.wav
45
+ <id>_mixture.wav
46
+
47
+ If --truth-est-mode:
48
+ <id>_target_truth_est.wav
49
+ <id>_noise_truth_est.wav
50
+
51
+ If --summary:
52
+ metric_spenh_targetf_summary.txt
53
+ metric_spenh_targetf_summary.csv
54
+ metric_spenh_targetf_list.csv
55
+ metric_spenh_targetf_estats_list.csv
56
+
57
+ If --truth-est-mode:
58
+ metric_spenh_targetf_truth_list.csv
59
+ metric_spenh_targetf_estats_truth_list.csv
60
+
61
+ TBD
62
+ Metric and extraction data are written into prediction location PLOC as separate files per mixture.
63
+
64
+ -d PLOC, --ploc PLOC Location of SonusAI predict data.
65
+
66
+ Inputs:
67
+
68
+ """
69
+
70
+ from typing import Any
71
+
72
+ import matplotlib
73
+ import matplotlib.pyplot as plt
74
+ import numpy as np
75
+ import pandas as pd
76
+
77
+ from sonusai.datatypes import AudioF
78
+ from sonusai.datatypes import AudioT
79
+ from sonusai.datatypes import Feature
80
+ from sonusai.datatypes import Predict
81
+ from sonusai.mixture import MixtureDatabase
82
+
83
# Linear power ratios equivalent to +99 dB and -99 dB, i.e. 10 ** (dB / 10)
DB_99, DB_N99 = (np.power(10, db / 10) for db in (99, -99))


# Select the SVG backend for matplotlib
matplotlib.use("SVG")
88
+
89
+
90
def first_key(x: dict) -> str:
    """Return the first key of a dictionary in iteration order

    :param x: dictionary to inspect
    :return: the first key produced when iterating over x
    :raises KeyError: if x contains no keys
    """
    missing = object()
    key = next(iter(x), missing)
    if key is missing:
        raise KeyError("No key found")
    return key
94
+
95
+
96
def mean_square_error(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    squared: bool = False,
) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate root-mean-square error or mean square error

    The per-element squared error is the squared magnitude of (reference - hypothesis),
    so the results are real-valued for both real and complex inputs.

    :param hypothesis: [frames, bins]
    :param reference: [frames, bins]
    :param squared: calculate mean square rather than root-mean-square
    :return: mean, mean per bin, mean per frame
    """
    diff = reference - hypothesis
    # |diff|^2: equal to diff**2 for real data and real-valued for complex data
    sq_err = np.real(diff * np.conjugate(diff))

    # mean over frames for value per bin
    err_b = np.mean(sq_err, axis=0)
    # mean over bins for value per frame
    err_f = np.mean(sq_err, axis=1)
    # mean over all
    err = float(np.mean(sq_err))

    if not squared:
        err_b = np.sqrt(err_b)
        err_f = np.sqrt(err_f)
        # np.sqrt returns a NumPy scalar; convert so both modes return a builtin float
        err = float(np.sqrt(err))

    return err, err_b, err_f
123
+
124
+
125
def mean_abs_percentage_error(hypothesis: np.ndarray, reference: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate mean abs percentage error

    If inputs are complex, calculates average: mape(real)/2 + mape(imag)/2

    The denominator is |reference| + eps (eps of float32), so it is strictly positive and
    a reference value at or near zero (of either sign) cannot cause a division by zero.

    :param hypothesis: [frames, bins]
    :param reference: [frames, bins]
    :return: mean, mean per bin, mean per frame (each rounded to 3 decimals)
    """
    eps = np.finfo(np.float32).eps

    def _abs_percentage(ref: np.ndarray, hyp: np.ndarray) -> np.ndarray:
        # Guard with |ref| + eps rather than ref + eps: the latter is zero when ref == -eps
        return 100 * np.abs((ref - hyp) / (np.abs(ref) + eps))

    if not np.iscomplexobj(reference) and not np.iscomplexobj(hypothesis):
        abs_err = _abs_percentage(reference, hypothesis)
    else:
        # Treat real and imaginary parts independently and average the two errors
        abs_err_r = _abs_percentage(np.real(reference), np.real(hypothesis))
        abs_err_i = _abs_percentage(np.imag(reference), np.imag(hypothesis))
        abs_err = (abs_err_r / 2) + (abs_err_i / 2)

    # mean over frames for value per bin
    err_b = np.around(np.mean(abs_err, axis=0), 3)
    # mean over bins for value per frame
    err_f = np.around(np.mean(abs_err, axis=1), 3)
    # mean over all
    err = float(np.around(np.mean(abs_err), 3))

    return err, err_b, err_f
153
+
154
+
155
def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate log error

    The error per element is |10 * log10((|reference|^2 + eps) / (|hypothesis|^2 + eps))|,
    with eps taken from float32, and all returned means are rounded to 3 decimals.

    :param reference: complex or real [frames, bins]
    :param hypothesis: complex or real [frames, bins]
    :return: mean, mean per bin, mean per frame
    """
    tiny = np.finfo(np.float32).eps

    # Squared magnitude of each input
    power_ref = np.real(reference * np.conjugate(reference))
    power_hyp = np.real(hypothesis * np.conjugate(hypothesis))

    ratio = (power_ref + tiny) / (power_hyp + tiny)
    distortion = np.abs(10 * np.log10(ratio))

    overall = float(np.around(np.mean(distortion), 3))
    per_bin = np.around(np.mean(distortion, axis=0), 3)
    per_frame = np.around(np.mean(distortion, axis=1), 3)

    return overall, per_bin, per_frame
175
+
176
+
177
def plot_mixpred(
    mixture: AudioT,
    mixture_f: AudioF,
    target: AudioT | None = None,
    feature: Feature | None = None,
    predict: Predict | None = None,
    tp_title: str = "",
) -> tuple[plt.Figure, Any]:  # pyright: ignore [reportPrivateImportUsage]
    """Plot the mixture waveform and spectrogram, and optionally target, feature and predict

    Rows are stacked vertically: the waveform (with target overlaid when given), the
    transposed mixture_f image, then one image each for feature and predict when provided.

    :param mixture: time-domain mixture drawn in the waveform row
    :param mixture_f: mixture data shown (transposed) as an image
    :param target: optional time-domain target drawn on top of the mixture waveform
    :param feature: optional feature data shown (transposed) as an image
    :param predict: optional prediction data shown (transposed) as an image with a colorbar
    :param tp_title: text appended to the "Predict " row title
    :return: the figure and its axes
    """
    from sonusai.constants import SAMPLE_RATE

    image_opts = {"aspect": "auto", "interpolation": "nearest", "origin": "lower"}

    # Two fixed rows (waveform, mixture image) plus one per optional image input
    row_count = 2 + int(feature is not None) + int(predict is not None)
    fig, ax = plt.subplots(row_count, 1, constrained_layout=True, figsize=(11, 8.5))

    # Row 0: waveform; the x axis is the sample index divided by SAMPLE_RATE
    wave_ax = ax[0]
    time_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
    wave_ax.plot(time_axis, mixture, label="Mixture", color="mistyrose")
    wave_ax.set_ylabel("magnitude", color="tab:blue")
    wave_ax.set_xlim(time_axis[0], time_axis[-1])
    if target is not None:
        # Target is drawn after the mixture so it appears on top of it
        wave_ax.plot(time_axis, target, label="Target", color="tab:blue")
    wave_ax.set_title("Waveform")

    # Row 1: mixture spectrogram
    row = 1
    ax[row].imshow(np.transpose(mixture_f), **image_opts)
    ax[row].set_title("Mixture")

    if feature is not None:
        row += 1
        ax[row].imshow(np.transpose(feature), **image_opts)
        ax[row].set_title("Feature")

    if predict is not None:
        row += 1
        predict_im = ax[row].imshow(np.transpose(predict), **image_opts)
        ax[row].set_title("Predict " + tp_title)
        plt.colorbar(predict_im, location="bottom")

    return fig, ax
222
+
223
+
224
def plot_pdb_predict_truth(
    predict: np.ndarray,
    truth_f: np.ndarray | None = None,
    metric: np.ndarray | None = None,
    tp_title: str = "",
) -> plt.Figure:  # pyright: ignore [reportPrivateImportUsage]
    """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)

    Rows are stacked vertically: the predict image, the truth image (when given), then a
    line plot of the mean over the last axis of predict (and truth) in db. When metric is
    given it is drawn in red on a second y-axis that shares the line plot's x-axis.

    :param predict: prediction data; converted with 10*log10(x + eps) before plotting
    :param truth_f: optional truth data; converted the same way as predict
    :param metric: optional per-point metric plotted against the same x-axis as the means
    :param tp_title: currently not used by this function
    :return: the figure
    """
    eps = np.finfo(np.float32).eps
    image_opts = {"aspect": "auto", "interpolation": "nearest", "origin": "lower"}

    def to_db(values: np.ndarray) -> np.ndarray:
        return 10 * np.log10(values + eps)

    # One image row per available spectrogram, followed by the line-plot row
    images = [("Predict", predict)]
    if truth_f is not None:
        images.append(("Truth", truth_f))

    fig, ax = plt.subplots(len(images) + 1, 1, constrained_layout=True, figsize=(11, 8.5))

    for row, (title, data) in enumerate(images):
        im = ax[row].imshow(to_db(data.transpose()), **image_opts)
        ax[row].set_title(title)
        plt.colorbar(im, location="bottom")

    # Line plot: predict mean, and optionally truth mean and the metric
    line_ax = ax[len(images)]
    pred_avg = to_db(np.mean(predict, axis=-1))
    x_axis = np.arange(len(pred_avg), dtype=np.float32)
    line_ax.plot(x_axis, pred_avg, color="black", linestyle="dashed", label="Predict mean over freq.")
    line_ax.set_ylabel("mean db", color="black")
    line_ax.set_xlim(x_axis[0], x_axis[-1])
    if truth_f is not None:
        truth_avg = to_db(np.mean(truth_f, axis=-1))
        line_ax.plot(x_axis, truth_avg, color="green", linestyle="dashed", label="Truth mean over freq.")

    if metric is None:
        line_ax.set_title("SNR (mean over freq. db)")
        return fig

    # Second y-axis sharing the same x-axis, used for the metric
    metric_ax = line_ax.twinx()
    metric_color = "red"
    metric_ax.plot(x_axis, metric, color=metric_color, label="sig distortion (mse db)")
    metric_ax.set_xlim(x_axis[0], x_axis[-1])
    metric_ax.set_ylim([0, np.max(metric)])
    metric_ax.set_ylabel("spectral distortion (mse db)", color=metric_color)
    metric_ax.tick_params(axis="y", labelcolor=metric_color)
    line_ax.set_title("SNR and SNR mse (mean over freq. db)")
    return fig
274
+
275
+
276
def plot_e_predict_truth(
    predict: np.ndarray,
    predict_wav: np.ndarray,
    truth_f: np.ndarray | None = None,
    truth_wav: np.ndarray | None = None,
    metric: np.ndarray | None = None,
    tp_title: str = "",
) -> tuple[plt.Figure, Any]:  # pyright: ignore [reportPrivateImportUsage]
    """Plot predict spectrogram and waveform and optionally truth and a metric

    Rows are stacked vertically: the predict image, the truth image (when given), the
    waveform row, then the metric row (when given).

    :param predict: prediction data shown (transposed) as an image with a colorbar
    :param predict_wav: estimated waveform drawn in the waveform row
    :param truth_f: optional truth data shown (transposed) with the same colormap as predict
    :param truth_wav: optional true waveform drawn over predict_wav; samples beyond the length
        of predict_wav are dropped, and a shorter truth_wav is drawn over the samples it covers
    :param metric: optional metric; 1-D, or 2-D where column 0 is plotted as "Target LogErr"
        and column 1 (if present) as "phase dist (deg)" on a twin y-axis
    :param tp_title: currently not used by this function
    :return: the figure and its axes (with the twin axis appended when column 1 is plotted)
    """
    num_plots = 2
    if truth_f is not None:
        num_plots += 1
    if metric is not None:
        num_plots += 1

    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the predict spectrogram
    p = 0
    im = ax[p].imshow(predict.transpose(), aspect="auto", interpolation="nearest", origin="lower")
    ax[p].set_title("Predict")
    plt.colorbar(im, location="bottom")

    if truth_f is not None:  # plot truth if provided and use same colormap as predict
        p += 1
        ax[p].imshow(truth_f.transpose(), im.cmap, aspect="auto", interpolation="nearest", origin="lower")
        ax[p].set_title("Truth")

    # Plot predict wav, and optionally truth wav (x axis is the sample index)
    p += 1
    x_axis = np.arange(len(predict_wav), dtype=np.float32)
    ax[p].plot(x_axis, predict_wav, color="black", linestyle="dashed", label="Speech Estimate")
    ax[p].set_ylabel("Amplitude", color="black")
    ax[p].set_xlim(x_axis[0], x_axis[-1])
    if truth_wav is not None:
        # Drop truth samples beyond the estimate; if truth is shorter, use the matching
        # x-axis prefix so x and y always have the same length
        truth_wav = truth_wav[: len(predict_wav)]
        ax[p].plot(x_axis[: len(truth_wav)], truth_wav, color="green", linestyle="dashed", label="True Target")

    # Plot the metric lines
    if metric is not None:
        p += 1
        if metric.ndim > 1:  # if it has multiple dims, plot 1st
            metric1 = metric[:, 0]
        else:
            metric1 = metric  # if single dim, plot it as 1st
        x_axis = np.arange(len(metric1), dtype=np.float32)
        ax[p].plot(x_axis, metric1, color="red", label="Target LogErr")
        ax[p].set_ylabel("log error db", color="red")
        ax[p].set_xlim(x_axis[0], x_axis[-1])
        ax[p].set_ylim([-0.01, np.max(metric1) + 0.01])
        if metric.ndim > 1 and metric.shape[1] > 1:
            p += 1
            metr2 = metric[:, 1]
            # The twin axis of the metric row is appended, so ax[p] now refers to it
            ax = np.append(ax, np.array(ax[p - 1].twinx()))
            color2 = "blue"
            ax[p].plot(x_axis, metr2, color=color2, label="phase dist (deg)")
            # Only set explicit limits when the data spans a visible range
            if np.max(metr2) - np.min(metr2) > 0.1:
                ax[p].set_ylim([np.min(metr2), np.max(metr2)])
            ax[p].set_ylabel("phase dist (deg)", color=color2)
            ax[p].tick_params(axis="y", labelcolor=color2)

    return fig, ax
342
+
343
+
344
+ def _process_mixture(
345
+ m_id: int,
346
+ truth_location: str,
347
+ predict_location: str,
348
+ predict_wav_mode: bool,
349
+ truth_est_mode: bool,
350
+ enable_plot: bool,
351
+ enable_wav: bool,
352
+ asr_method: str,
353
+ target_f_key: str,
354
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
355
+ import pickle
356
+ from os.path import basename
357
+ from os.path import join
358
+ from os.path import splitext
359
+
360
+ import h5py
361
+ import pgzip
362
+ from matplotlib.backends.backend_pdf import PdfPages
363
+ from pystoi import stoi
364
+
365
+ from sonusai import logger
366
+ from sonusai.metrics import calc_pcm
367
+ from sonusai.metrics import calc_pesq
368
+ from sonusai.metrics import calc_phase_distance
369
+ from sonusai.metrics import calc_speech
370
+ from sonusai.metrics import calc_wer
371
+ from sonusai.metrics import calc_wsdr
372
+ from sonusai.mixture import forward_transform
373
+ from sonusai.mixture import inverse_transform
374
+ from sonusai.mixture.audio import read_audio
375
+ from sonusai.utils.asr import calc_asr
376
+ from sonusai.utils.compress import power_compress
377
+ from sonusai.utils.compress import power_uncompress
378
+ from sonusai.utils.numeric_conversion import float_to_int16
379
+ from sonusai.utils.reshape import reshape_outputs
380
+ from sonusai.utils.stacked_complex import stack_complex
381
+ from sonusai.utils.stacked_complex import unstack_complex
382
+ from sonusai.utils.write_audio import write_audio
383
+
384
+ mixdb = MixtureDatabase(truth_location)
385
+
386
+ # 1) Read predict data, var predict with shape [BatchSize,Classes] or [batch, timesteps, classes]
387
+ output_name = join(predict_location, mixdb.mixture(m_id).name + ".h5")
388
+ predict = None
389
+ if truth_est_mode:
390
+ # in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
391
+ # don't bother to read prediction, and predict var will get assigned to truth later
392
+ # mark outputs with tru suffix, i.e. 0000_truest_*
393
+ base_name = splitext(output_name)[0] + "_truest"
394
+ else:
395
+ base_name, ext = splitext(output_name) # base_name used later
396
+ if not predict_wav_mode:
397
+ try:
398
+ with h5py.File(output_name, "r") as f:
399
+ predict = np.array(f["predict"])
400
+ except Exception as e:
401
+ raise OSError(f"Error reading {output_name}: {e}") from e
402
+ # reshape to always be [frames, classes] where ndim==3 case frames = batch * timesteps
403
+ if predict.ndim > 2: # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
404
+ # logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
405
+ predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
406
+ else:
407
+ base_name, ext = splitext(output_name)
408
+ predict_name = join(base_name + ".wav")
409
+ audio = read_audio(predict_name, use_cache=True)
410
+ predict = forward_transform(audio, mixdb.ft_config)
411
+ if mixdb.feature[0:1] == "h":
412
+ predict = power_compress(predict)
413
+ predict = stack_complex(predict)
414
+
415
+ # 2) Collect true target, noise, mixture data, trim to predict size if needed
416
+ tmp = mixdb.mixture_sources(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
417
+ target_f = mixdb.mixture_sources_f(m_id, sources=tmp)["primary"]
418
+ target = tmp["primary"]
419
+ mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
420
+ # noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
421
+ # noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
422
+ noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
423
+ # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
424
+ # note: uses pre-IR, pre-specaug audio
425
+ segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"] # Why [0] removed?
426
+ mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
427
+ noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
428
+ # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
429
+ segsnr_f[segsnr_f == np.inf] = DB_99
430
+ # segsnr_f should never be -np.inf
431
+ segsnr_f[segsnr_f == -np.inf] = DB_N99
432
+ # need to use inv-tf to match #samples & latency shift properties of predict inv tf
433
+ target_fi = inverse_transform(target_f, mixdb.it_config)
434
+ noise_fi = inverse_transform(noise_f, mixdb.it_config)
435
+ # mixture_fi = mixdb.inverse_transform(mixture_f)
436
+
437
+ # gen feature, truth - note feature only used for plots
438
+ # TODO: parse truth_f for different formats
439
+ feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
440
+ truth_f = truth_all["primary"][target_f_key]
441
+ if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
442
+ if truth_f.shape[1] != 1:
443
+ logger.info("Error: target_f truth has stride > 1, exiting.")
444
+ raise SystemExit(1)
445
+ else:
446
+ truth_f = truth_f[:, 0, :] # remove stride dimension
447
+
448
+ # ignore mixup
449
+ # for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
450
+ # if truth_setting.function == 'target_mixture_f':
451
+ # half = truth_f.shape[-1] // 2
452
+ # # extract target_f only
453
+ # truth_f = truth_f[..., :half]
454
+
455
+ if not truth_est_mode:
456
+ if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
457
+ trim_f = target_f.shape[0] - predict.shape[0]
458
+ logger.debug(f"Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.")
459
+ target_f = target_f[0:-trim_f, :]
460
+ target_fi, _ = inverse_transform(target_f, mixdb.it_config)
461
+ trim_t = target.shape[0] - target_fi.shape[0]
462
+ target = target[0:-trim_t]
463
+ noise_f = noise_f[0:-trim_f, :]
464
+ noise = noise[0:-trim_t]
465
+ mixture_f = mixture_f[0:-trim_f, :]
466
+ mixture = mixture[0:-trim_t]
467
+ truth_f = truth_f[0:-trim_f, :]
468
+ elif predict.shape[0] > target_f.shape[0]:
469
+ logger.debug(
470
+ f"Warning: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}"
471
+ )
472
+ trim_f = predict.shape[0] - target_f.shape[0]
473
+ predict = predict[0:-trim_f, :]
474
+ # raise SonusAIError(
475
+ # f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
476
+
477
+ # 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
478
+ if truth_est_mode:
479
+ predict = truth_f # substitute truth for the prediction (for test/debug)
480
+ predict_complex = unstack_complex(predict) # unstack
481
+ # if feature has compressed mag and truth does not, compress it
482
+ if mixdb.feature[0:1] == "h" and not first_key(mixdb.category_truth_configs("primary")).startswith(
483
+ "targetcmpr"
484
+ ):
485
+ predict_complex = power_compress(predict_complex) # from uncompressed truth
486
+ else:
487
+ predict_complex = unstack_complex(predict)
488
+
489
+ truth_f_complex = unstack_complex(truth_f)
490
+ if mixdb.feature[0:1] == "h": # 'hn' or 'ha' or 'hd', etc.: # if feat has compressed mag
491
+ # estimate noise in uncompressed-mag domain
492
+ noise_est_complex = mixture_f - power_uncompress(predict_complex)
493
+ predict_complex = power_uncompress(predict_complex) # uncompress if truth is compressed
494
+ else: # cn, c8, ..
495
+ noise_est_complex = mixture_f - predict_complex
496
+
497
+ target_est_wav = inverse_transform(predict_complex, mixdb.it_config)
498
+ noise_est_wav = inverse_transform(noise_est_complex, mixdb.it_config)
499
+
500
+ # 4) Metrics
501
+ # Target/Speech logerr - PSD estimation accuracy symmetric mean log-spectral distortion
502
+ lerr_tg, lerr_tg_bin, lerr_tg_frame = log_error(reference=truth_f_complex, hypothesis=predict_complex)
503
+ # Noise logerr - PSD estimation accuracy
504
+ lerr_n, lerr_n_bin, lerr_n_frame = log_error(reference=noise_f, hypothesis=noise_est_complex)
505
+ # PCM loss metric
506
+ ytrue_f = np.concatenate((truth_f_complex[:, np.newaxis, :], noise_f[:, np.newaxis, :]), axis=1)
507
+ ypred_f = np.concatenate((predict_complex[:, np.newaxis, :], noise_est_complex[:, np.newaxis, :]), axis=1)
508
+ pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)
509
+
510
+ # Phase distance
511
+ phd, phd_bin, phd_frame = calc_phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
512
+
513
+ # Noise td logerr
514
+ # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)
515
+
516
+ # # SA-SDR (time-domain source-aggregated SDR)
517
+ ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
518
+ ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
519
+ # # note: w/o scale is more pessimistic number
520
+ # sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
521
+ target_stoi = stoi(target_fi, target_est_wav, 16000, extended=False)
522
+
523
+ wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
524
+ # logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
525
+ # logger.debug(f'wsdr cweights = {wsdr_cw}.')
526
+ # logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
527
+
528
+ # Speech intelligibility measure - PESQ
529
+ if int(mixdb.mixture(m_id).noise.snr) > -99:
530
+ # len = target_est_wav.shape[0]
531
+ pesq_speech = calc_pesq(target_est_wav, target_fi)
532
+ csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
533
+ metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
534
+ pesq_mx = metrics["mxpesq"]["primary"] if isinstance(metrics["mxpesq"], dict) else metrics["mxpesq"]
535
+ csig_mx = metrics["mxcsig"]["primary"] if isinstance(metrics["mxcsig"], dict) else metrics["mxcsig"]
536
+ cbak_mx = metrics["mxcbak"]["primary"] if isinstance(metrics["mxcbak"], dict) else metrics["mxcbak"]
537
+ covl_mx = metrics["mxcovl"]["primary"] if isinstance(metrics["mxcovl"], dict) else metrics["mxcovl"]
538
+ # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
539
+ # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
540
+ # pesq improvement
541
+ pesq_impr = pesq_speech - pesq_mx
542
+ # pesq improvement %
543
+ pesq_impr_pc = pesq_impr / (pesq_mx + np.finfo(np.float32).eps) * 100
544
+ else:
545
+ pesq_speech = 0
546
+ pesq_mx = 0
547
+ pesq_impr_pc = np.float32(0)
548
+ csig_mx = 0
549
+ csig_tg = 0
550
+ cbak_mx = 0
551
+ cbak_tg = 0
552
+ covl_mx = 0
553
+ covl_tg = 0
554
+
555
+ # Calc ASR
556
+ asr_tt = None
557
+ asr_mx = None
558
+ asr_tge = None
559
+ # asr_engines = list(mixdb.asr_configs.keys())
560
+ if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
561
+ asr_mx_name = f"mxasr.{asr_method}"
562
+ wer_mx_name = f"mxwer.{asr_method}"
563
+ asr_tt_name = f"sasr.{asr_method}"
564
+ metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
565
+ asr_mx = metrics[asr_mx_name]["primary"] if isinstance(metrics[asr_mx_name], dict) else metrics[asr_mx_name]
566
+ wer_mx = metrics[wer_mx_name]["primary"] if isinstance(metrics[wer_mx_name], dict) else metrics[wer_mx_name]
567
+ asr_tt = metrics[asr_tt_name]["primary"] if isinstance(metrics[asr_tt_name], dict) else metrics[asr_tt_name]
568
+
569
+ if asr_tt:
570
+ noiseadd = None # TBD add as switch, default -30
571
+ if noiseadd is not None:
572
+ ngain = np.power(10, min(float(noiseadd), 0.0) / 20.0) # limit to gain <1, convert to float
573
+ tgasr_est_wav = target_est_wav + ngain * noise_est_wav # add back noise at low level
574
+ else:
575
+ tgasr_est_wav = target_est_wav
576
+
577
+ # logger.info(f'Calculating prediction ASR for mixid {mixid}')
578
+ asr_cfg = mixdb.asr_configs[asr_method]
579
+ asr_tge = calc_asr(tgasr_est_wav, **asr_cfg).text
580
+ wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate WER
581
+ if wer_mx == 0.0:
582
+ if wer_tge == 0.0:
583
+ wer_pi = 0.0
584
+ else:
585
+ wer_pi = -999.0 # instead of -Inf
586
+ else:
587
+ wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
588
+ else:
589
+ logger.warning(f"Warning: mixid {m_id} ASR truth is empty, setting to 0% WER")
590
+ wer_mx = float(0)
591
+ wer_tge = float(0)
592
+ wer_pi = float(0)
593
+ else:
594
+ wer_mx = float("nan")
595
+ wer_tge = float("nan")
596
+ wer_pi = float("nan")
597
+
598
+ # 5) Save per mixture metric results
599
+ # Single row in table of scalar metrics per mixture
600
+ mtable1_col = [
601
+ "MXSNR",
602
+ "MXPESQ",
603
+ "PESQ",
604
+ "PESQi%",
605
+ "MXWER",
606
+ "WER",
607
+ "WERi%",
608
+ "WSDR",
609
+ "STOI",
610
+ "PCM",
611
+ "SPLERR",
612
+ "NLERR",
613
+ "PD",
614
+ "MXCSIG",
615
+ "CSIG",
616
+ "MXCBAK",
617
+ "CBAK",
618
+ "MXCOVL",
619
+ "COVL",
620
+ "SPFILE",
621
+ "NFILE",
622
+ ]
623
+ ti = mixdb.mixture(m_id).sources["primary"].file_id
624
+ ni = mixdb.mixture(m_id).noise.file_id
625
+ metr1 = [
626
+ mixdb.mixture(m_id).noise.snr,
627
+ pesq_mx,
628
+ pesq_speech,
629
+ pesq_impr_pc,
630
+ wer_mx,
631
+ wer_tge,
632
+ wer_pi,
633
+ wsdr,
634
+ target_stoi,
635
+ pcm,
636
+ lerr_tg,
637
+ lerr_n,
638
+ phd,
639
+ csig_mx,
640
+ csig_tg,
641
+ cbak_mx,
642
+ cbak_tg,
643
+ covl_mx,
644
+ covl_tg,
645
+ basename(mixdb.source_file(ti).name),
646
+ basename(mixdb.source_file(ni).name),
647
+ ]
648
+ mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
649
+
650
+ # Stats of per frame estimation metrics
651
+ metr2 = pd.DataFrame(
652
+ {"SSNR": segsnr_f, "PCM": pcm_frame, "SLERR": lerr_tg_frame, "NLERR": lerr_n_frame, "SPD": phd_frame}
653
+ )
654
+ metr2 = metr2.describe() # Use pandas stat function
655
+ # Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
656
+ # metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
657
+ metr2.iloc[1:, 0] = metr2["SSNR"][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
658
+ # create a single row in multi-column header
659
+ new_labels = pd.MultiIndex.from_product(
660
+ [metr2.columns, ["Avg", "Min", "Med", "Max", "Std"]], names=["Metric", "Stat"]
661
+ )
662
+ dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
663
+ mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
664
+ mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).noise.snr, False) # add MXSNR as the first metric column
665
+
666
+ all_metrics_table_1 = mtab1 # return to be collected by process
667
+ all_metrics_table_2 = mtab2 # return to be collected by process
668
+
669
+ if asr_method is None:
670
+ metric_name = base_name + "_metric_spenh.txt"
671
+ else:
672
+ metric_name = base_name + "_metric_spenh_" + asr_method + ".txt"
673
+
674
+ with open(metric_name, "w") as f:
675
+ print("Speech enhancement metrics:", file=f)
676
+ print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
677
+ print("", file=f)
678
+ print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
679
+ print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
680
+ print("", file=f)
681
+ print(f"Target path: {mixdb.source_file(ti).name}", file=f)
682
+ print(f"Noise path: {mixdb.source_file(ni).name}", file=f)
683
+ if asr_method != "none":
684
+ print(f"ASR method: {asr_method}", file=f)
685
+ print(f"ASR truth: {asr_tt}", file=f)
686
+ print(f"ASR result for mixture: {asr_mx}", file=f)
687
+ print(f"ASR result for prediction: {asr_tge}", file=f)
688
+
689
+ print(f"Augmentations: {mixdb.mixture(m_id)}", file=f)
690
+
691
+ # 7) write wav files
692
+ if enable_wav:
693
+ write_audio(name=base_name + "_mixture.wav", audio=float_to_int16(mixture))
694
+ write_audio(name=base_name + "_target.wav", audio=float_to_int16(target))
695
+ # write_audio(name=base_name + '_target_fi.wav', audio=float_to_int16(target_fi))
696
+ write_audio(name=base_name + "_noise.wav", audio=float_to_int16(noise))
697
+ write_audio(name=base_name + "_target_est.wav", audio=float_to_int16(target_est_wav))
698
+ write_audio(name=base_name + "_noise_est.wav", audio=float_to_int16(noise_est_wav))
699
+
700
+ # debug code to test for perfect reconstruction of the extraction method
701
+ # note both 75% olsa-hanns and 50% olsa-hann modes checked to have perfect reconstruction
702
+ # target_r = mixdb.inverse_transform(target_f)
703
+ # noise_r = mixdb.inverse_transform(noise_f)
704
+ # _write_wav(name=base_name + '_target_r.wav', audio=float_to_int16(target_r))
705
+ # _write_wav(name=base_name + '_noise_r.wav', audio=float_to_int16(noise_r)) # chk perfect rec
706
+
707
+ # 8) Write out plot file
708
+ if enable_plot:
709
+ plot_name = base_name + "_metric_spenh.pdf"
710
+
711
+ # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
712
+ # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
713
+ # Reshape to get frames*decimated_stride, num_bands
714
+ step = int(mixdb.feature_samples / mixdb.feature_step_samples)
715
+ if feature.ndim != 3:
716
+ raise OSError("feature does not have 3 dimensions: frames, stride, num_bands")
717
+
718
+ # for feature cn*00n**
719
+ feat_sgram = unstack_complex(feature)
720
+ feat_sgram = 20 * np.log10(abs(feat_sgram) + np.finfo(np.float32).eps)
721
+ feat_sgram = feat_sgram[:, -step:, :] # decimate, Fx1xB
722
+ feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))
723
+
724
+ with PdfPages(plot_name) as pdf:
725
+ # page1 we always have a mixture and prediction, target optional if truth provided
726
+ # For speech enhancement, target_f is definitely included:
727
+ predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
728
+ tfunc_name = "target_f"
729
+ # if tfunc_name == 'mapped_snr_f':
730
+ # # leave as unmapped snr
731
+ # predplot = predict
732
+ # tfunc_name = mixdb.target_file(1).truth_settings[0].function
733
+ # elif tfunc_name == 'target_f' or 'target_mixture_f':
734
+ # predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
735
+ # else:
736
+ # # use dB scale
737
+ # predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
738
+ # tfunc_name = tfunc_name + ' (db)'
739
+
740
+ mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
741
+ fig, ax = plot_mixpred(
742
+ mixture=mixture,
743
+ mixture_f=mixspec,
744
+ target=target,
745
+ feature=feat_sgram,
746
+ predict=predplot,
747
+ tp_title=tfunc_name,
748
+ )
749
+ pdf.savefig(fig)
750
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig1.pkl.gz", "wb"))
751
+
752
+ # ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
753
+ # pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))
754
+
755
+ # page 3 speech extraction
756
+ tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
757
+ tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
758
+ # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
759
+ fig, ax = plot_e_predict_truth(
760
+ predict=tg_est_spec,
761
+ predict_wav=target_est_wav,
762
+ truth_f=tg_spec,
763
+ truth_wav=target_fi,
764
+ metric=np.vstack((lerr_tg_frame, phd_frame)).T,
765
+ tp_title="speech estimate",
766
+ )
767
+ pdf.savefig(fig)
768
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig2.pkl.gz", "wb"))
769
+
770
+ # page 4 noise extraction
771
+ n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
772
+ n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
773
+ fig, ax = plot_e_predict_truth(
774
+ predict=n_est_spec,
775
+ predict_wav=noise_est_wav,
776
+ truth_f=n_spec,
777
+ truth_wav=noise_fi,
778
+ metric=lerr_n_frame,
779
+ tp_title="noise estimate",
780
+ )
781
+ pdf.savefig(fig)
782
+ pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig4.pkl.gz", "wb"))
783
+
784
+ # Plot error waveforms
785
+ # tg_err_wav = target_fi - target_est_wav
786
+ # tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))
787
+
788
+ plt.close("all")
789
+
790
+ return all_metrics_table_1, all_metrics_table_2
791
+
792
+
793
def main():
    """Run the ``calc_metric_spenh`` command-line tool.

    Parses the command line with docopt (using this module's docstring), locates
    prediction files in PLOC and the mixture database in TLOC, computes the
    per-mixture metric tables with ``_process_mixture`` (optionally in parallel)
    and, when ``--summary`` is set and more than one mixture was processed,
    writes text and CSV summary reports into PLOC.

    Raises:
        SystemExit: if PLOC is not a directory, if no prediction files are found
            (outside truth-estimation mode), if the mixture database has no
            ``target_f`` truth for the primary category, or if the requested ASR
            method is not present in ``mixdb.asr_configs``.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils.docstring import trim_docstring

    # ``__doc__`` here resolves to the module docstring (global lookup), not this function's.
    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args["--verbose"]
    mixids = args["--mixid"]
    asr_method = args["--asr-method"]
    truth_est_mode = args["--truth-est-mode"]
    enable_plot = args["--plot"]
    enable_wav = args["--wav"]
    enable_summary = args["--summary"]
    predict_location = args["PLOC"]
    num_proc = args["--num_process"]
    truth_location = args["TLOC"]

    import glob
    from functools import partial
    from os.path import basename
    from os.path import isdir
    from os.path import join

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils.parallel import par_track
    from sonusai.utils.parallel import track

    def _wer_improvement_pct(mx_wer, wer):
        """Return the WER improvement in percent relative to the mixture WER.

        A zero mixture WER cannot be improved on: the result is 0.0 when the
        estimate WER is also zero and -999.0 (stand-in for -Inf) otherwise.
        NaN inputs propagate to a NaN result.
        """
        if mx_wer == 0.0:
            return 0.0 if wer == 0.0 else -999.0
        return 100 * (mx_wer - wer) / mx_wer

    def _stack_rows(rows, template):
        """Concatenate one-row summary frames, or return an empty frame shaped like them."""
        if rows:
            return pd.concat(rows)
        # No row matched: keep the numeric column layout so later steps still work.
        return template.mean(numeric_only=True).to_frame().T.iloc[:0]

    # Check prediction subdirectory. The log file is created inside it below, so
    # there is nothing useful to do if it does not exist.
    if not isdir(predict_location):
        print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting.")
        raise SystemExit(1)

    all_predict_files = glob.glob(predict_location + "/*.h5")
    predict_logfile = glob.glob(predict_location + "/*predict.log")
    predict_wav_mode = False
    if not all_predict_files and not truth_est_mode:
        # No .h5 predictions: fall back to .wav predictions.
        all_predict_files = glob.glob(predict_location + "/*.wav")
        if not all_predict_files:
            print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting.")
            raise SystemExit(1)
        logger.info(f"Found {len(all_predict_files)} prediction .wav files.")
        predict_wav_mode = True
    else:
        logger.info(f"Found {len(all_predict_files)} prediction .h5 files.")

    if len(predict_logfile) == 0:
        logger.info(f"Warning, predict location {predict_location} has no prediction log files.")
    else:
        logger.info(f"Found predict log {basename(predict_logfile[0])} in predict location.")

    # Setup logging file
    create_file_handler(join(predict_location, "calc_metric_spenh.log"), verbose)
    update_console_handler(verbose)
    initial_log_messages("calc_metric_spenh")

    mixdb = MixtureDatabase(truth_location)
    mixids = mixdb.mixids_to_list(mixids)
    logger.info(
        f"Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}"
    )

    # Speech enhancement metrics and audio truth require the target_f truth type; check it is present.
    primary_truth_configs = mixdb.category_truth_configs("primary")
    logger.info(
        f"mixdb has {len(primary_truth_configs)} truth types defined for primary, checking that target_f type is present."
    )
    target_f_key = None
    for key in primary_truth_configs:
        if primary_truth_configs[key] == "target_f":
            target_f_key = key  # no break: the last matching key wins, as before
    if target_f_key is None:
        logger.error("mixdb does not have target_f truth defined, required for speech enhancement metrics, exiting.")
        raise SystemExit(1)

    logger.info(f"Only running specified subset of {len(mixids)} mixtures")

    # fnb is the filename base for all summary outputs; it embeds the ASR method when one is used.
    fnb = "metric_spenh_"
    if asr_method is not None:
        if asr_method not in mixdb.asr_configs:
            logger.info(
                f"Specified ASR method {asr_method} does not exists in mixdb.asr_configs. "
                f"Must choose one of the following (or none):"
            )
            logger.info(f"{', '.join(mixdb.asr_configs)}")
            logger.error("Unrecognized ASR method, exiting.")
            raise SystemExit(1)

        logger.info(f"Specified ASR method {asr_method} exists in mixdb.asr_configs, it will be used for ")
        logger.info("prediction ASR and WER, and pre-calculated target and mixture ASR if available.")
        asr_cfg = mixdb.asr_configs[asr_method]
        fnb = "metric_spenh_" + asr_method + "_"
        logger.info(f"Using ASR cfg: {asr_cfg} ")

    num_cpu = psutil.cpu_count()
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Use 90% of the CPUs minus the share that is already busy, but never
        # fewer than one worker (a loaded host would otherwise yield 0 or less).
        use_cpu = max(int(num_cpu * (0.9 - cpu_percent / 100)), 1)
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    # A use_cpu of None selects serial (non-parallel) processing.
    no_par = use_cpu is None
    num_cpus = use_cpu

    logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
    progress = track(total=len(mixids))
    try:
        all_metrics_tables = par_track(
            partial(
                _process_mixture,
                truth_location=truth_location,
                predict_location=predict_location,
                predict_wav_mode=predict_wav_mode,
                truth_est_mode=truth_est_mode,
                enable_plot=enable_plot,
                enable_wav=enable_wav,
                asr_method=asr_method,
                target_f_key=target_f_key,
            ),
            mixids,
            progress=progress,
            num_cpus=num_cpus,
            no_par=no_par,
        )
    finally:
        # Always release the progress display, even if a worker raised.
        progress.close()

    # Each worker returns (scalar-metrics row, per-frame-statistics row).
    all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
    all_metrics_table_2 = pd.concat([item[1] for item in all_metrics_tables])

    if not enable_summary:
        return

    # 9) Done with mixtures, write out summary metrics
    # Calculate SNR summary avg of each non-random snr
    all_mtab1_sorted = all_metrics_table_1.sort_values(by=["MXSNR", "SPFILE"])
    all_mtab2_sorted = all_metrics_table_2.sort_values(by=["MXSNR"])
    snr_rows = []
    snr_rows_em = []
    for snr in mixdb.snrs:
        tmp = all_mtab1_sorted[all_mtab1_sorted["MXSNR"] == snr].mean(numeric_only=True).to_frame().T
        # The first column is the MXSNR mean; it is NaN when no processed mixture
        # has this SNR (e.g. only a subset of mixids was specified) - skip those.
        if not np.isnan(tmp.iloc[0].to_numpy()[0]):
            snr_rows.append(tmp)

        tmp = all_mtab2_sorted[all_mtab2_sorted["MXSNR"] == snr].mean(numeric_only=True).to_frame().T
        if not np.isnan(tmp.iloc[0].to_numpy()[0]):
            snr_rows_em.append(tmp)

    if not snr_rows:
        logger.warning("No processed mixture matches any SNR in mixdb.snrs, per-SNR summaries will be empty")
    mtab_snr_summary = _stack_rows(snr_rows, all_mtab1_sorted)
    mtab_snr_summary_em = _stack_rows(snr_rows_em, all_mtab2_sorted)

    mtab_snr_summary = mtab_snr_summary.sort_values(by=["MXSNR"], ascending=False)
    # Correct percentages in snr summary table: an average of per-mixture
    # percentages is not the percentage of the averages, so recompute directly.
    mtab_snr_summary["PESQi%"] = (
        100 * (mtab_snr_summary["PESQ"] - mtab_snr_summary["MXPESQ"]) / np.maximum(mtab_snr_summary["MXPESQ"], 0.01)
    )
    # Look the column up by name rather than relying on a fixed position.
    weri_col = mtab_snr_summary.columns.get_loc("WERi%")
    mx_wers = mtab_snr_summary["MXWER"].tolist()
    wers = mtab_snr_summary["WER"].tolist()
    for row, (mx_wer, wer) in enumerate(zip(mx_wers, wers)):
        if mx_wer != 0.0 and (np.isnan(mx_wer) or np.isnan(wer)):
            continue  # WER unavailable: keep the averaged WERi% value as is
        mtab_snr_summary.iloc[row, weri_col] = _wer_improvement_pct(mx_wer, wer)

    # Calculate avg metrics over all mixtures except -99
    all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
    all_nom99_mean = all_mtab1_sorted_nom99.mean(numeric_only=True)

    # correct the percentage averages with a direct calculation (PESQ% and WER%):
    all_nom99_mean["PESQi%"] = (
        100 * (all_nom99_mean["PESQ"] - all_nom99_mean["MXPESQ"]) / np.maximum(all_nom99_mean["MXPESQ"], 0.01)
    )
    all_nom99_mean["WERi%"] = _wer_improvement_pct(all_nom99_mean["MXWER"], all_nom99_mean["WER"])

    num_mix = len(mixids)
    if num_mix <= 1:
        # Summary reports are only produced for more than one mixture.
        return

    # Print pandas data to files using precision to 2 decimals
    if not truth_est_mode:
        ofname = join(predict_location, fnb + "summary.txt")
    else:
        ofname = join(predict_location, fnb + "summary_truest.txt")

    with open(ofname, "w") as f:
        print(f"ASR enabled with method {asr_method}", file=f)
        print(
            f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:", file=f
        )
        print(
            all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f
        )
        print("\nSpeech enhancement metrics avg over each SNR:", file=f)
        print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
        print("", file=f)
        print("Extraction statistics stats avg over each SNR:", file=f)
        print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: f"{x:.1f}", index=False), file=f)
        print("", file=f)

        print(f"Speech enhancement metrics stats over all {num_mix} mixtures:", file=f)
        print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
        print("", file=f)
        print(f"Extraction statistics stats over all {num_mix} mixtures:", file=f)
        print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)
        print("", file=f)

        print("Speech enhancement metrics all-mixtures list:", file=f)
        print(all_metrics_table_1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
        print("", file=f)
        print("Extraction statistics all-mixtures list:", file=f)
        print(all_metrics_table_2.round(2).to_string(float_format=lambda x: f"{x:.1f}"), file=f)

    # Write summary to .csv file
    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "summary.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_summary.csv"))
    # header_args appends a bare label/blank row; table_args appends a full table.
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    label = f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, header=False, index=False)  # open as write
    all_nom99_mean.to_frame().T.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame(["Speech enhancement metrics avg over each SNR:"]).to_csv(csv_name, **header_args)
    mtab_snr_summary.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame(["Extraction statistics stats avg over each SNR:"]).to_csv(csv_name, **header_args)
    mtab_snr_summary_em.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    label = f"Speech enhancement metrics stats over {num_mix} mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_1.describe().round(2).to_csv(csv_name, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    label = f"Extraction statistics stats over {num_mix} mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
    label = f"ASR enabled with method {asr_method}"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)

    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "list.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_list.csv"))
    pd.DataFrame(["Speech enhancement metrics list:"]).to_csv(csv_name, header=False, index=False)  # open as write
    all_metrics_table_1.round(2).to_csv(csv_name, **table_args)

    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "estats_list.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_estats_list.csv"))
    pd.DataFrame(["Extraction statistics list:"]).to_csv(csv_name, header=False, index=False)  # open as write
    all_metrics_table_2.round(2).to_csv(csv_name, **table_args)
1092
+
1093
+
1094
if __name__ == "__main__":
    # Script entry point: register the keyboard-interrupt hook first, then run
    # main() and hand any uncaught exception to sonusai's exception handler.
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()

    try:
        main()
    except Exception as err:
        exception_handler(err)
1103
+
1104
+ # if asr_method == 'none':
1105
+ # fnb = 'metric_spenh_'
1106
+ # elif asr_method == 'google':
1107
+ # fnb = 'metric_spenh_ggl_'
1108
+ # logger.info(f'ASR enabled with method {asr_method}')
1109
+ # enable_asr_warmup = True
1110
+ # elif asr_method == 'deepgram':
1111
+ # fnb = 'metric_spenh_dgram_'
1112
+ # logger.info(f'ASR enabled with method {asr_method}')
1113
+ # enable_asr_warmup = True
1114
+ # elif asr_method == 'aixplain_whisper':
1115
+ # fnb = 'metric_spenh_whspx_' + mixdb.asr_configs[asr_method]['model'] + '_'
1116
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1117
+ # enable_asr_warmup = True
1118
+ # elif asr_method == 'whisper':
1119
+ # fnb = 'metric_spenh_whspl_' + mixdb.asr_configs[asr_method]['model'] + '_'
1120
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1121
+ # enable_asr_warmup = True
1122
+ # elif asr_method == 'aaware_whisper':
1123
+ # fnb = 'metric_spenh_whspaaw_' + mixdb.asr_configs[asr_method]['model'] + '_'
1124
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1125
+ # enable_asr_warmup = True
1126
+ # elif asr_method == 'faster_whisper':
1127
+ # fnb = 'metric_spenh_fwhsp_' + mixdb.asr_configs[asr_method]['model'] + '_'
1128
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1129
+ # enable_asr_warmup = True
1130
+ # elif asr_method == 'sensory':
1131
+ # fnb = 'metric_spenh_snsr_' + mixdb.asr_configs[asr_method]['model'] + '_'
1132
+ # asr_model_name = mixdb.asr_configs[asr_method]['model']
1133
+ # enable_asr_warmup = True
1134
+ # else:
1135
+ # logger.error(f'Unrecognized ASR method: {asr_method}')
1136
+ # return