sonusai 1.0.16__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1136 @@
|
|
1
|
+
"""sonusai calc_metric_spenh
|
2
|
+
|
3
|
+
usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e ASR] [-n NCPU] PLOC TLOC
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to process, can be range like 0:maxmix+1. [default: *]
|
9
|
+
-t, --truth-est-mode Calculate extraction and metrics using truth (instead of prediction).
|
10
|
+
-p, --plot Enable PDF plots file generation per mixture.
|
11
|
+
-w, --wav Generate WAV files per mixture.
|
12
|
+
-s, --summary Enable summary files generation.
|
13
|
+
-n, --num_process NCPU Number of parallel processes to use [default: auto]
|
14
|
+
-e ASR, --asr-method ASR ASR method used for WER metrics. Must exist in the TLOC dataset as pre-calculated
|
15
|
+
metrics using SonusAI genmetrics. Can be either an integer index, i.e 0,1,... or the
|
16
|
+
name of the asr_engine configuration in the dataset. If an incorrect name is specified,
|
17
|
+
a list of asr_engines of the dataset will be printed.
|
18
|
+
|
19
|
+
Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data in TLOC as truth/label
|
20
|
+
reference. Metric and extraction data files are written into PLOC.
|
21
|
+
|
22
|
+
PLOC directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
|
23
|
+
TLOC directory with SonusAI mixture database of truth/label mixture data
|
24
|
+
|
25
|
+
For ASR methods, the method must be defined in the TLOC dataset. For example, the available faster_whisper models are:
|
26
|
+
{tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large} and an example configuration looks like:
|
27
|
+
{'fwhsptiny_cpu': {'engine': 'faster_whisper',
|
28
|
+
'model': 'tiny',
|
29
|
+
'device': 'cpu',
|
30
|
+
'beam_size': 5}}
|
31
|
+
Note: the ASR config can optionally include the model, device, and other fields the engine supports.
|
32
|
+
Most ASR are very computationally demanding and can overwhelm/hang a local system.
|
33
|
+
|
34
|
+
Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
|
35
|
+
<id>_metric_spenh.txt
|
36
|
+
|
37
|
+
If --plot:
|
38
|
+
<id>_metric_spenh.pdf
|
39
|
+
|
40
|
+
If --wav:
|
41
|
+
<id>_target.wav
|
42
|
+
<id>_target_est.wav
|
43
|
+
<id>_noise.wav
|
44
|
+
<id>_noise_est.wav
|
45
|
+
<id>_mixture.wav
|
46
|
+
|
47
|
+
If --truth-est-mode:
|
48
|
+
<id>_target_truth_est.wav
|
49
|
+
<id>_noise_truth_est.wav
|
50
|
+
|
51
|
+
If --summary:
|
52
|
+
metric_spenh_targetf_summary.txt
|
53
|
+
metric_spenh_targetf_summary.csv
|
54
|
+
metric_spenh_targetf_list.csv
|
55
|
+
metric_spenh_targetf_estats_list.csv
|
56
|
+
|
57
|
+
If --truth-est-mode:
|
58
|
+
metric_spenh_targetf_truth_list.csv
|
59
|
+
metric_spenh_targetf_estats_truth_list.csv
|
60
|
+
|
61
|
+
TBD
|
62
|
+
Metric and extraction data are written into prediction location PLOC as separate files per mixture.
|
63
|
+
|
64
|
+
-d PLOC, --ploc PLOC Location of SonusAI predict data.
|
65
|
+
|
66
|
+
Inputs:
|
67
|
+
|
68
|
+
"""
|
69
|
+
|
70
|
+
from typing import Any
|
71
|
+
|
72
|
+
import matplotlib
|
73
|
+
import matplotlib.pyplot as plt
|
74
|
+
import numpy as np
|
75
|
+
import pandas as pd
|
76
|
+
|
77
|
+
from sonusai.datatypes import AudioF
|
78
|
+
from sonusai.datatypes import AudioT
|
79
|
+
from sonusai.datatypes import Feature
|
80
|
+
from sonusai.datatypes import Predict
|
81
|
+
from sonusai.mixture import MixtureDatabase
|
82
|
+
|
83
|
+
# Linear power ratios for +99 dB and -99 dB, i.e. 10 ** (dB / 10). They are used as
# finite stand-ins when a segmental SNR value comes out as +/- infinity.
DB_99, DB_N99 = (np.power(10, level_db / 10) for level_db in (99, -99))
|
85
|
+
|
86
|
+
|
87
|
+
# Select matplotlib's SVG (file-output) backend at import time; presumably so plots can be
# rendered without a display, since figures in this module are written to files.
matplotlib.use("SVG")
|
88
|
+
|
89
|
+
|
90
|
+
def first_key(x: dict) -> str:
    """Return the first key of ``x`` in iteration order.

    :param x: mapping to inspect
    :return: the first key yielded when iterating ``x``
    :raises KeyError: if ``x`` has no keys
    """
    keys = iter(x)
    try:
        return next(keys)
    except StopIteration:
        raise KeyError("No key found") from None
|
94
|
+
|
95
|
+
|
96
|
+
def mean_square_error(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    squared: bool = False,
) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate root-mean-square error or mean square error.

    The per-element error is the squared magnitude ``|reference - hypothesis| ** 2``.
    For real inputs this is identical to the plain square. For complex inputs it yields
    a real, non-negative error instead of the complex ``(reference - hypothesis) ** 2``.

    :param hypothesis: [frames, bins], real or complex
    :param reference: [frames, bins], real or complex
    :param squared: calculate mean square rather than root-mean-square
    :return: mean (Python float), mean per bin, mean per frame
    """
    # np.square() of a complex difference is complex (e**2, not |e|**2), which would make
    # err_b/err_f complex and make float() silently drop the imaginary part; use magnitude.
    sq_err = np.abs(reference - hypothesis) ** 2

    # mean over frames for value per bin
    err_b = np.mean(sq_err, axis=0)
    # mean over bins for value per frame
    err_f = np.mean(sq_err, axis=1)
    # mean over all
    err = float(np.mean(sq_err))

    if not squared:
        err_b = np.sqrt(err_b)
        err_f = np.sqrt(err_f)
        # keep the scalar a plain Python float, matching the annotated return type
        err = float(np.sqrt(err))

    return err, err_b, err_f
|
123
|
+
|
124
|
+
|
125
|
+
def mean_abs_percentage_error(hypothesis: np.ndarray, reference: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate mean absolute percentage error (MAPE).

    If either input is complex, the element error is the average of the real-part and
    imaginary-part percentage errors: mape(real)/2 + mape(imag)/2. A float32 epsilon is
    added to the reference in the denominator so a zero reference does not divide by zero.
    All returned values are rounded to 3 decimals.

    :param hypothesis: [frames, bins]
    :param reference: [frames, bins]
    :return: mean, mean per bin, mean per frame
    """
    eps = np.finfo(np.float32).eps

    def _pct_err(ref: np.ndarray, hyp: np.ndarray) -> np.ndarray:
        # element-wise absolute percentage error relative to the epsilon-offset reference
        return 100 * np.abs((ref - hyp) / (ref + eps))

    if np.iscomplexobj(reference) or np.iscomplexobj(hypothesis):
        real_part = _pct_err(np.real(reference), np.real(hypothesis))
        imag_part = _pct_err(np.imag(reference), np.imag(hypothesis))
        abs_err = (real_part / 2) + (imag_part / 2)
    else:
        abs_err = _pct_err(reference, hypothesis)

    overall = float(np.around(np.mean(abs_err), 3))
    per_bin = np.around(np.mean(abs_err, axis=0), 3)  # averaged over frames
    per_frame = np.around(np.mean(abs_err, axis=1), 3)  # averaged over bins
    return overall, per_bin, per_frame
|
153
|
+
|
154
|
+
|
155
|
+
def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate the absolute log power error in dB.

    Per element: ``|10 * log10((|reference|**2 + eps) / (|hypothesis|**2 + eps))|``, where
    eps is the float32 machine epsilon guarding against zero power. All returned values
    are rounded to 3 decimals.

    :param reference: complex or real [frames, bins]
    :param hypothesis: complex or real [frames, bins]
    :return: mean, mean per bin, mean per frame
    """
    eps = np.finfo(np.float32).eps

    # power as x * conj(x); np.real drops the zero imaginary residue for complex input
    ref_power = np.real(reference * np.conjugate(reference)) + eps
    hyp_power = np.real(hypothesis * np.conjugate(hypothesis)) + eps
    db_err = np.abs(10 * np.log10(ref_power / hyp_power))

    def _rounded_mean(axis=None):
        return np.around(np.mean(db_err, axis=axis), 3)

    # overall, per bin (mean over frames), per frame (mean over bins)
    return float(_rounded_mean()), _rounded_mean(axis=0), _rounded_mean(axis=1)
|
175
|
+
|
176
|
+
|
177
|
+
def plot_mixpred(
    mixture: AudioT,
    mixture_f: AudioF,
    target: AudioT | None = None,
    feature: Feature | None = None,
    predict: Predict | None = None,
    tp_title: str = "",
) -> tuple[plt.Figure, Any]:  # pyright: ignore [reportPrivateImportUsage]
    """Plot a mixture waveform and image panels for mixture_f, feature and predict.

    Row 0 shows the mixture waveform against time in seconds (sample index / SAMPLE_RATE),
    with the target waveform drawn over it when given. The following rows show, in order,
    transposed images of mixture_f, feature (if given) and predict (if given). The predict
    row is titled "Predict " + tp_title and gets a colorbar.

    :param mixture: time-domain mixture samples
    :param mixture_f: mixture_f data, shown transposed as an image
    :param target: optional time-domain target, plotted on the mixture's time axis
    :param feature: optional feature data, shown transposed as an image
    :param predict: optional prediction data, shown transposed as an image
    :param tp_title: suffix appended to the predict panel title
    :return: (figure, array of axes)
    """
    from sonusai.constants import SAMPLE_RATE

    optional_rows = sum(item is not None for item in (feature, predict))
    fig, ax = plt.subplots(2 + optional_rows, 1, constrained_layout=True, figsize=(11, 8.5))

    def _show_image(axis, data, title):
        # transpose so the data's first axis runs along x; origin="lower" puts index 0 at the bottom
        image = axis.imshow(np.transpose(data), aspect="auto", interpolation="nearest", origin="lower")
        axis.set_title(title)
        return image

    # Row 0: waveform; the target is drawn after the mixture so it sits on top
    wave_ax = ax[0]
    seconds = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
    wave_ax.plot(seconds, mixture, label="Mixture", color="mistyrose")
    wave_ax.set_ylabel("magnitude", color="tab:blue")
    wave_ax.set_xlim(seconds[0], seconds[-1])
    if target is not None:
        wave_ax.plot(seconds, target, label="Target", color="tab:blue")
    wave_ax.set_title("Waveform")

    # Remaining rows are consumed in order as each image panel is drawn
    image_axes = iter(ax[1:])
    _show_image(next(image_axes), mixture_f, "Mixture")
    if feature is not None:
        _show_image(next(image_axes), feature, "Feature")
    if predict is not None:
        predict_image = _show_image(next(image_axes), predict, "Predict " + tp_title)
        plt.colorbar(predict_image, location="bottom")

    return fig, ax
|
222
|
+
|
223
|
+
|
224
|
+
def plot_pdb_predict_truth(
    predict: np.ndarray,
    truth_f: np.ndarray | None = None,
    metric: np.ndarray | None = None,
    tp_title: str = "",
) -> plt.Figure:  # pyright: ignore [reportPrivateImportUsage]
    """Plot predict and optionally truth and a metric in power dB, i.e. applies 10*log10(x).

    Rows, in order:
    - the predict image;
    - the truth_f image, if given;
    - a line plot of the mean over the last axis of predict (and of truth_f), in dB.

    When metric is given it is drawn in red on a secondary y-axis of the line plot. A
    float32 epsilon is added before every log10.

    :param predict: array shown transposed as an image; averaged over its last axis for the line plot
    :param truth_f: optional array handled the same way as predict
    :param metric: optional values plotted against the same x axis as the averaged lines
    :param tp_title: currently unused
    :return: the figure
    """
    eps = np.finfo(np.float32).eps

    def _db(values):
        return 10 * np.log10(values + eps)

    panels = [(predict, "Predict")]
    if truth_f is not None:
        panels.append((truth_f, "Truth"))

    fig, ax = plt.subplots(len(panels) + 1, 1, constrained_layout=True, figsize=(11, 8.5))

    # Image rows: one per panel, each with its own colorbar
    for axis, (data, title) in zip(ax, panels):
        im = axis.imshow(_db(data.transpose()), aspect="auto", interpolation="nearest", origin="lower")
        axis.set_title(title)
        plt.colorbar(im, location="bottom")

    # Last row: mean over the last axis, in dB, against the row index
    line_ax = ax[len(panels)]
    pred_avg = _db(np.mean(predict, axis=-1))
    frames = np.arange(len(pred_avg), dtype=np.float32)
    line_ax.plot(frames, pred_avg, color="black", linestyle="dashed", label="Predict mean over freq.")
    line_ax.set_ylabel("mean db", color="black")
    line_ax.set_xlim(frames[0], frames[-1])
    if truth_f is not None:
        truth_avg = _db(np.mean(truth_f, axis=-1))
        line_ax.plot(frames, truth_avg, color="green", linestyle="dashed", label="Truth mean over freq.")

    if metric is None:
        line_ax.set_title("SNR (mean over freq. db)")
        return fig

    # secondary y-axis sharing the same x-axis
    metric_color = "red"
    ax2 = line_ax.twinx()
    ax2.plot(frames, metric, color=metric_color, label="sig distortion (mse db)")
    ax2.set_xlim(frames[0], frames[-1])
    ax2.set_ylim([0, np.max(metric)])
    ax2.set_ylabel("spectral distortion (mse db)", color=metric_color)
    ax2.tick_params(axis="y", labelcolor=metric_color)
    line_ax.set_title("SNR and SNR mse (mean over freq. db)")
    return fig
|
274
|
+
|
275
|
+
|
276
|
+
def plot_e_predict_truth(
    predict: np.ndarray,
    predict_wav: np.ndarray,
    truth_f: np.ndarray | None = None,
    truth_wav: np.ndarray | None = None,
    metric: np.ndarray | None = None,
    tp_title: str = "",
) -> tuple[plt.Figure, Any]:  # pyright: ignore [reportPrivateImportUsage]
    """Plot predict spectrogram and waveform and optionally truth and a metric.

    Rows, in order:
    - the predict image, with a colorbar;
    - the truth_f image using the same colormap, if given;
    - the predict_wav waveform, with truth_wav overlaid if given;
    - the first metric column, if metric is given.

    If metric has a second column, that column is drawn in blue on a twin y-axis of the
    metric row. The twin axis is appended to the returned axes array.

    truth_wav is trimmed to len(predict_wav) when longer. When it is shorter, it is plotted
    over its own length only, instead of raising an x/y length mismatch.

    :param predict: array shown transposed as an image
    :param predict_wav: time-domain estimate; the waveform x axis is its sample index
    :param truth_f: optional array shown transposed as an image
    :param truth_wav: optional time-domain reference overlaid on predict_wav
    :param metric: optional 1-D array, or 2-D array whose column 0 is labelled log error
        and column 1 (if present) phase distance
    :param tp_title: currently unused
    :return: (figure, array of axes including any appended twin axis)
    """
    num_plots = 2
    if truth_f is not None:
        num_plots += 1
    if metric is not None:
        num_plots += 1

    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the predict spectrogram
    p = 0
    im = ax[p].imshow(predict.transpose(), aspect="auto", interpolation="nearest", origin="lower")
    ax[p].set_title("Predict")
    plt.colorbar(im, location="bottom")

    if truth_f is not None:  # plot truth if provided and use same colormap as predict
        p += 1
        ax[p].imshow(truth_f.transpose(), cmap=im.cmap, aspect="auto", interpolation="nearest", origin="lower")
        ax[p].set_title("Truth")

    # Plot predict wav, and optionally truth wav
    p += 1
    x_axis = np.arange(len(predict_wav), dtype=np.float32)  # sample index, not seconds
    ax[p].plot(x_axis, predict_wav, color="black", linestyle="dashed", label="Speech Estimate")
    ax[p].set_ylabel("Amplitude", color="black")
    ax[p].set_xlim(x_axis[0], x_axis[-1])
    if truth_wav is not None:
        # Align lengths: drop truth samples beyond predict_wav. When truth_wav is shorter,
        # plot it against the matching prefix of x_axis; plotting it against the full
        # x_axis would raise a length-mismatch ValueError.
        truth_wav = truth_wav[: len(predict_wav)]
        ax[p].plot(x_axis[: len(truth_wav)], truth_wav, color="green", linestyle="dashed", label="True Target")

    # Plot the metric lines
    if metric is not None:
        p += 1
        # for a multi-column metric plot column 0 here; column 1 (if any) goes on a twin axis
        metric1 = metric[:, 0] if metric.ndim > 1 else metric
        x_axis = np.arange(len(metric1), dtype=np.float32)
        ax[p].plot(x_axis, metric1, color="red", label="Target LogErr")
        ax[p].set_ylabel("log error db", color="red")
        ax[p].set_xlim(x_axis[0], x_axis[-1])
        ax[p].set_ylim([-0.01, np.max(metric1) + 0.01])
        if metric.ndim > 1 and metric.shape[1] > 1:
            p += 1
            metr2 = metric[:, 1]
            # append the twin axis so callers receive it in the returned axes array
            ax = np.append(ax, np.array(ax[p - 1].twinx()))
            color2 = "blue"
            ax[p].plot(x_axis, metr2, color=color2, label="phase dist (deg)")
            # only tighten the y limits when the data spans a visible range
            if np.max(metr2) - np.min(metr2) > 0.1:
                ax[p].set_ylim([np.min(metr2), np.max(metr2)])
            ax[p].set_ylabel("phase dist (deg)", color=color2)
            ax[p].tick_params(axis="y", labelcolor=color2)

    return fig, ax
|
342
|
+
|
343
|
+
|
344
|
+
def _process_mixture(
|
345
|
+
m_id: int,
|
346
|
+
truth_location: str,
|
347
|
+
predict_location: str,
|
348
|
+
predict_wav_mode: bool,
|
349
|
+
truth_est_mode: bool,
|
350
|
+
enable_plot: bool,
|
351
|
+
enable_wav: bool,
|
352
|
+
asr_method: str,
|
353
|
+
target_f_key: str,
|
354
|
+
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
355
|
+
import pickle
|
356
|
+
from os.path import basename
|
357
|
+
from os.path import join
|
358
|
+
from os.path import splitext
|
359
|
+
|
360
|
+
import h5py
|
361
|
+
import pgzip
|
362
|
+
from matplotlib.backends.backend_pdf import PdfPages
|
363
|
+
from pystoi import stoi
|
364
|
+
|
365
|
+
from sonusai import logger
|
366
|
+
from sonusai.metrics import calc_pcm
|
367
|
+
from sonusai.metrics import calc_pesq
|
368
|
+
from sonusai.metrics import calc_phase_distance
|
369
|
+
from sonusai.metrics import calc_speech
|
370
|
+
from sonusai.metrics import calc_wer
|
371
|
+
from sonusai.metrics import calc_wsdr
|
372
|
+
from sonusai.mixture import forward_transform
|
373
|
+
from sonusai.mixture import inverse_transform
|
374
|
+
from sonusai.mixture.audio import read_audio
|
375
|
+
from sonusai.utils.asr import calc_asr
|
376
|
+
from sonusai.utils.compress import power_compress
|
377
|
+
from sonusai.utils.compress import power_uncompress
|
378
|
+
from sonusai.utils.numeric_conversion import float_to_int16
|
379
|
+
from sonusai.utils.reshape import reshape_outputs
|
380
|
+
from sonusai.utils.stacked_complex import stack_complex
|
381
|
+
from sonusai.utils.stacked_complex import unstack_complex
|
382
|
+
from sonusai.utils.write_audio import write_audio
|
383
|
+
|
384
|
+
mixdb = MixtureDatabase(truth_location)
|
385
|
+
|
386
|
+
# 1) Read predict data, var predict with shape [BatchSize,Classes] or [batch, timesteps, classes]
|
387
|
+
output_name = join(predict_location, mixdb.mixture(m_id).name + ".h5")
|
388
|
+
predict = None
|
389
|
+
if truth_est_mode:
|
390
|
+
# in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
|
391
|
+
# don't bother to read prediction, and predict var will get assigned to truth later
|
392
|
+
# mark outputs with tru suffix, i.e. 0000_truest_*
|
393
|
+
base_name = splitext(output_name)[0] + "_truest"
|
394
|
+
else:
|
395
|
+
base_name, ext = splitext(output_name) # base_name used later
|
396
|
+
if not predict_wav_mode:
|
397
|
+
try:
|
398
|
+
with h5py.File(output_name, "r") as f:
|
399
|
+
predict = np.array(f["predict"])
|
400
|
+
except Exception as e:
|
401
|
+
raise OSError(f"Error reading {output_name}: {e}") from e
|
402
|
+
# reshape to always be [frames, classes] where ndim==3 case frames = batch * timesteps
|
403
|
+
if predict.ndim > 2: # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
|
404
|
+
# logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
|
405
|
+
predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
|
406
|
+
else:
|
407
|
+
base_name, ext = splitext(output_name)
|
408
|
+
predict_name = join(base_name + ".wav")
|
409
|
+
audio = read_audio(predict_name, use_cache=True)
|
410
|
+
predict = forward_transform(audio, mixdb.ft_config)
|
411
|
+
if mixdb.feature[0:1] == "h":
|
412
|
+
predict = power_compress(predict)
|
413
|
+
predict = stack_complex(predict)
|
414
|
+
|
415
|
+
# 2) Collect true target, noise, mixture data, trim to predict size if needed
|
416
|
+
tmp = mixdb.mixture_sources(m_id) # time-dom augmented targets is list of pre-IR and pre-specaugment targets
|
417
|
+
target_f = mixdb.mixture_sources_f(m_id, sources=tmp)["primary"]
|
418
|
+
target = tmp["primary"]
|
419
|
+
mixture = mixdb.mixture_mixture(m_id) # note: gives full reverberated/distorted target, but no specaugment
|
420
|
+
# noise_wo_dist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
|
421
|
+
# noise_wo_dist_f = mixdb.mixture_noise_f(mixid, noise=noise_wo_dist)
|
422
|
+
noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
|
423
|
+
# noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
|
424
|
+
# note: uses pre-IR, pre-specaug audio
|
425
|
+
segsnr_f = mixdb.mixture_metrics(m_id, ["ssnr"])["ssnr"] # Why [0] removed?
|
426
|
+
mixture_f = mixdb.mixture_mixture_f(m_id, mixture=mixture)
|
427
|
+
noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
|
428
|
+
# segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
|
429
|
+
segsnr_f[segsnr_f == np.inf] = DB_99
|
430
|
+
# segsnr_f should never be -np.inf
|
431
|
+
segsnr_f[segsnr_f == -np.inf] = DB_N99
|
432
|
+
# need to use inv-tf to match #samples & latency shift properties of predict inv tf
|
433
|
+
target_fi = inverse_transform(target_f, mixdb.it_config)
|
434
|
+
noise_fi = inverse_transform(noise_f, mixdb.it_config)
|
435
|
+
# mixture_fi = mixdb.inverse_transform(mixture_f)
|
436
|
+
|
437
|
+
# gen feature, truth - note feature only used for plots
|
438
|
+
# TODO: parse truth_f for different formats
|
439
|
+
feature, truth_all = mixdb.mixture_ft(m_id, mixture_f=mixture_f)
|
440
|
+
truth_f = truth_all["primary"][target_f_key]
|
441
|
+
if truth_f.ndim > 2: # note this may not be needed anymore as all target_f truth is 3 dims
|
442
|
+
if truth_f.shape[1] != 1:
|
443
|
+
logger.info("Error: target_f truth has stride > 1, exiting.")
|
444
|
+
raise SystemExit(1)
|
445
|
+
else:
|
446
|
+
truth_f = truth_f[:, 0, :] # remove stride dimension
|
447
|
+
|
448
|
+
# ignore mixup
|
449
|
+
# for truth_setting in mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings:
|
450
|
+
# if truth_setting.function == 'target_mixture_f':
|
451
|
+
# half = truth_f.shape[-1] // 2
|
452
|
+
# # extract target_f only
|
453
|
+
# truth_f = truth_f[..., :half]
|
454
|
+
|
455
|
+
if not truth_est_mode:
|
456
|
+
if predict.shape[0] < target_f.shape[0]: # target_f, truth_f, mixture_f, etc. same size
|
457
|
+
trim_f = target_f.shape[0] - predict.shape[0]
|
458
|
+
logger.debug(f"Warning: prediction frames less than mixture, trimming {trim_f} frames from all truth.")
|
459
|
+
target_f = target_f[0:-trim_f, :]
|
460
|
+
target_fi, _ = inverse_transform(target_f, mixdb.it_config)
|
461
|
+
trim_t = target.shape[0] - target_fi.shape[0]
|
462
|
+
target = target[0:-trim_t]
|
463
|
+
noise_f = noise_f[0:-trim_f, :]
|
464
|
+
noise = noise[0:-trim_t]
|
465
|
+
mixture_f = mixture_f[0:-trim_f, :]
|
466
|
+
mixture = mixture[0:-trim_t]
|
467
|
+
truth_f = truth_f[0:-trim_f, :]
|
468
|
+
elif predict.shape[0] > target_f.shape[0]:
|
469
|
+
logger.debug(
|
470
|
+
f"Warning: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}"
|
471
|
+
)
|
472
|
+
trim_f = predict.shape[0] - target_f.shape[0]
|
473
|
+
predict = predict[0:-trim_f, :]
|
474
|
+
# raise SonusAIError(
|
475
|
+
# f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
|
476
|
+
|
477
|
+
# 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
|
478
|
+
if truth_est_mode:
|
479
|
+
predict = truth_f # substitute truth for the prediction (for test/debug)
|
480
|
+
predict_complex = unstack_complex(predict) # unstack
|
481
|
+
# if feature has compressed mag and truth does not, compress it
|
482
|
+
if mixdb.feature[0:1] == "h" and not first_key(mixdb.category_truth_configs("primary")).startswith(
|
483
|
+
"targetcmpr"
|
484
|
+
):
|
485
|
+
predict_complex = power_compress(predict_complex) # from uncompressed truth
|
486
|
+
else:
|
487
|
+
predict_complex = unstack_complex(predict)
|
488
|
+
|
489
|
+
truth_f_complex = unstack_complex(truth_f)
|
490
|
+
if mixdb.feature[0:1] == "h": # 'hn' or 'ha' or 'hd', etc.: # if feat has compressed mag
|
491
|
+
# estimate noise in uncompressed-mag domain
|
492
|
+
noise_est_complex = mixture_f - power_uncompress(predict_complex)
|
493
|
+
predict_complex = power_uncompress(predict_complex) # uncompress if truth is compressed
|
494
|
+
else: # cn, c8, ..
|
495
|
+
noise_est_complex = mixture_f - predict_complex
|
496
|
+
|
497
|
+
target_est_wav = inverse_transform(predict_complex, mixdb.it_config)
|
498
|
+
noise_est_wav = inverse_transform(noise_est_complex, mixdb.it_config)
|
499
|
+
|
500
|
+
# 4) Metrics
|
501
|
+
# Target/Speech logerr - PSD estimation accuracy symmetric mean log-spectral distortion
|
502
|
+
lerr_tg, lerr_tg_bin, lerr_tg_frame = log_error(reference=truth_f_complex, hypothesis=predict_complex)
|
503
|
+
# Noise logerr - PSD estimation accuracy
|
504
|
+
lerr_n, lerr_n_bin, lerr_n_frame = log_error(reference=noise_f, hypothesis=noise_est_complex)
|
505
|
+
# PCM loss metric
|
506
|
+
ytrue_f = np.concatenate((truth_f_complex[:, np.newaxis, :], noise_f[:, np.newaxis, :]), axis=1)
|
507
|
+
ypred_f = np.concatenate((predict_complex[:, np.newaxis, :], noise_est_complex[:, np.newaxis, :]), axis=1)
|
508
|
+
pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)
|
509
|
+
|
510
|
+
# Phase distance
|
511
|
+
phd, phd_bin, phd_frame = calc_phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
|
512
|
+
|
513
|
+
# Noise td logerr
|
514
|
+
# lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noise_fi, noise_truth_est_audio)
|
515
|
+
|
516
|
+
# # SA-SDR (time-domain source-aggregated SDR)
|
517
|
+
ytrue = np.concatenate((target_fi[:, np.newaxis], noise_fi[:, np.newaxis]), axis=1)
|
518
|
+
ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
|
519
|
+
# # note: w/o scale is more pessimistic number
|
520
|
+
# sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
|
521
|
+
target_stoi = stoi(target_fi, target_est_wav, 16000, extended=False)
|
522
|
+
|
523
|
+
wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
|
524
|
+
# logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
|
525
|
+
# logger.debug(f'wsdr cweights = {wsdr_cw}.')
|
526
|
+
# logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
|
527
|
+
|
528
|
+
# Speech intelligibility measure - PESQ
|
529
|
+
if int(mixdb.mixture(m_id).noise.snr) > -99:
|
530
|
+
# len = target_est_wav.shape[0]
|
531
|
+
pesq_speech = calc_pesq(target_est_wav, target_fi)
|
532
|
+
csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
|
533
|
+
metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
|
534
|
+
pesq_mx = metrics["mxpesq"]["primary"] if isinstance(metrics["mxpesq"], dict) else metrics["mxpesq"]
|
535
|
+
csig_mx = metrics["mxcsig"]["primary"] if isinstance(metrics["mxcsig"], dict) else metrics["mxcsig"]
|
536
|
+
cbak_mx = metrics["mxcbak"]["primary"] if isinstance(metrics["mxcbak"], dict) else metrics["mxcbak"]
|
537
|
+
covl_mx = metrics["mxcovl"]["primary"] if isinstance(metrics["mxcovl"], dict) else metrics["mxcovl"]
|
538
|
+
# pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
|
539
|
+
# pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
|
540
|
+
# pesq improvement
|
541
|
+
pesq_impr = pesq_speech - pesq_mx
|
542
|
+
# pesq improvement %
|
543
|
+
pesq_impr_pc = pesq_impr / (pesq_mx + np.finfo(np.float32).eps) * 100
|
544
|
+
else:
|
545
|
+
pesq_speech = 0
|
546
|
+
pesq_mx = 0
|
547
|
+
pesq_impr_pc = np.float32(0)
|
548
|
+
csig_mx = 0
|
549
|
+
csig_tg = 0
|
550
|
+
cbak_mx = 0
|
551
|
+
cbak_tg = 0
|
552
|
+
covl_mx = 0
|
553
|
+
covl_tg = 0
|
554
|
+
|
555
|
+
# Calc ASR
|
556
|
+
asr_tt = None
|
557
|
+
asr_mx = None
|
558
|
+
asr_tge = None
|
559
|
+
# asr_engines = list(mixdb.asr_configs.keys())
|
560
|
+
if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
|
561
|
+
asr_mx_name = f"mxasr.{asr_method}"
|
562
|
+
wer_mx_name = f"mxwer.{asr_method}"
|
563
|
+
asr_tt_name = f"sasr.{asr_method}"
|
564
|
+
metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
|
565
|
+
asr_mx = metrics[asr_mx_name]["primary"] if isinstance(metrics[asr_mx_name], dict) else metrics[asr_mx_name]
|
566
|
+
wer_mx = metrics[wer_mx_name]["primary"] if isinstance(metrics[wer_mx_name], dict) else metrics[wer_mx_name]
|
567
|
+
asr_tt = metrics[asr_tt_name]["primary"] if isinstance(metrics[asr_tt_name], dict) else metrics[asr_tt_name]
|
568
|
+
|
569
|
+
if asr_tt:
|
570
|
+
noiseadd = None # TBD add as switch, default -30
|
571
|
+
if noiseadd is not None:
|
572
|
+
ngain = np.power(10, min(float(noiseadd), 0.0) / 20.0) # limit to gain <1, convert to float
|
573
|
+
tgasr_est_wav = target_est_wav + ngain * noise_est_wav # add back noise at low level
|
574
|
+
else:
|
575
|
+
tgasr_est_wav = target_est_wav
|
576
|
+
|
577
|
+
# logger.info(f'Calculating prediction ASR for mixid {mixid}')
|
578
|
+
asr_cfg = mixdb.asr_configs[asr_method]
|
579
|
+
asr_tge = calc_asr(tgasr_est_wav, **asr_cfg).text
|
580
|
+
wer_tge = calc_wer(asr_tge, asr_tt).wer * 100 # target estimate WER
|
581
|
+
if wer_mx == 0.0:
|
582
|
+
if wer_tge == 0.0:
|
583
|
+
wer_pi = 0.0
|
584
|
+
else:
|
585
|
+
wer_pi = -999.0 # instead of -Inf
|
586
|
+
else:
|
587
|
+
wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
|
588
|
+
else:
|
589
|
+
logger.warning(f"Warning: mixid {m_id} ASR truth is empty, setting to 0% WER")
|
590
|
+
wer_mx = float(0)
|
591
|
+
wer_tge = float(0)
|
592
|
+
wer_pi = float(0)
|
593
|
+
else:
|
594
|
+
wer_mx = float("nan")
|
595
|
+
wer_tge = float("nan")
|
596
|
+
wer_pi = float("nan")
|
597
|
+
|
598
|
+
# 5) Save per mixture metric results
|
599
|
+
# Single row in table of scalar metrics per mixture
|
600
|
+
mtable1_col = [
|
601
|
+
"MXSNR",
|
602
|
+
"MXPESQ",
|
603
|
+
"PESQ",
|
604
|
+
"PESQi%",
|
605
|
+
"MXWER",
|
606
|
+
"WER",
|
607
|
+
"WERi%",
|
608
|
+
"WSDR",
|
609
|
+
"STOI",
|
610
|
+
"PCM",
|
611
|
+
"SPLERR",
|
612
|
+
"NLERR",
|
613
|
+
"PD",
|
614
|
+
"MXCSIG",
|
615
|
+
"CSIG",
|
616
|
+
"MXCBAK",
|
617
|
+
"CBAK",
|
618
|
+
"MXCOVL",
|
619
|
+
"COVL",
|
620
|
+
"SPFILE",
|
621
|
+
"NFILE",
|
622
|
+
]
|
623
|
+
ti = mixdb.mixture(m_id).sources["primary"].file_id
|
624
|
+
ni = mixdb.mixture(m_id).noise.file_id
|
625
|
+
metr1 = [
|
626
|
+
mixdb.mixture(m_id).noise.snr,
|
627
|
+
pesq_mx,
|
628
|
+
pesq_speech,
|
629
|
+
pesq_impr_pc,
|
630
|
+
wer_mx,
|
631
|
+
wer_tge,
|
632
|
+
wer_pi,
|
633
|
+
wsdr,
|
634
|
+
target_stoi,
|
635
|
+
pcm,
|
636
|
+
lerr_tg,
|
637
|
+
lerr_n,
|
638
|
+
phd,
|
639
|
+
csig_mx,
|
640
|
+
csig_tg,
|
641
|
+
cbak_mx,
|
642
|
+
cbak_tg,
|
643
|
+
covl_mx,
|
644
|
+
covl_tg,
|
645
|
+
basename(mixdb.source_file(ti).name),
|
646
|
+
basename(mixdb.source_file(ni).name),
|
647
|
+
]
|
648
|
+
mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[m_id])
|
649
|
+
|
650
|
+
# Stats of per frame estimation metrics
|
651
|
+
metr2 = pd.DataFrame(
|
652
|
+
{"SSNR": segsnr_f, "PCM": pcm_frame, "SLERR": lerr_tg_frame, "NLERR": lerr_n_frame, "SPD": phd_frame}
|
653
|
+
)
|
654
|
+
metr2 = metr2.describe() # Use pandas stat function
|
655
|
+
# Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
|
656
|
+
# metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
|
657
|
+
metr2.iloc[1:, 0] = metr2["SSNR"][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
|
658
|
+
# create a single row in multi-column header
|
659
|
+
new_labels = pd.MultiIndex.from_product(
|
660
|
+
[metr2.columns, ["Avg", "Min", "Med", "Max", "Std"]], names=["Metric", "Stat"]
|
661
|
+
)
|
662
|
+
dat1row = metr2.loc[["mean", "min", "50%", "max", "std"], :].T.stack().to_numpy().reshape((1, -1))
|
663
|
+
mtab2 = pd.DataFrame(dat1row, index=[m_id], columns=new_labels)
|
664
|
+
mtab2.insert(0, "MXSNR", mixdb.mixture(m_id).noise.snr, False) # add MXSNR as the first metric column
|
665
|
+
|
666
|
+
all_metrics_table_1 = mtab1 # return to be collected by process
|
667
|
+
all_metrics_table_2 = mtab2 # return to be collected by process
|
668
|
+
|
669
|
+
if asr_method is None:
|
670
|
+
metric_name = base_name + "_metric_spenh.txt"
|
671
|
+
else:
|
672
|
+
metric_name = base_name + "_metric_spenh_" + asr_method + ".txt"
|
673
|
+
|
674
|
+
with open(metric_name, "w") as f:
|
675
|
+
print("Speech enhancement metrics:", file=f)
|
676
|
+
print(mtab1.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
677
|
+
print("", file=f)
|
678
|
+
print(f"Extraction statistics over {mixture_f.shape[0]} frames:", file=f)
|
679
|
+
print(metr2.round(2).to_string(float_format=lambda x: f"{x:.2f}"), file=f)
|
680
|
+
print("", file=f)
|
681
|
+
print(f"Target path: {mixdb.source_file(ti).name}", file=f)
|
682
|
+
print(f"Noise path: {mixdb.source_file(ni).name}", file=f)
|
683
|
+
if asr_method != "none":
|
684
|
+
print(f"ASR method: {asr_method}", file=f)
|
685
|
+
print(f"ASR truth: {asr_tt}", file=f)
|
686
|
+
print(f"ASR result for mixture: {asr_mx}", file=f)
|
687
|
+
print(f"ASR result for prediction: {asr_tge}", file=f)
|
688
|
+
|
689
|
+
print(f"Augmentations: {mixdb.mixture(m_id)}", file=f)
|
690
|
+
|
691
|
+
# 7) write wav files
|
692
|
+
if enable_wav:
|
693
|
+
write_audio(name=base_name + "_mixture.wav", audio=float_to_int16(mixture))
|
694
|
+
write_audio(name=base_name + "_target.wav", audio=float_to_int16(target))
|
695
|
+
# write_audio(name=base_name + '_target_fi.wav', audio=float_to_int16(target_fi))
|
696
|
+
write_audio(name=base_name + "_noise.wav", audio=float_to_int16(noise))
|
697
|
+
write_audio(name=base_name + "_target_est.wav", audio=float_to_int16(target_est_wav))
|
698
|
+
write_audio(name=base_name + "_noise_est.wav", audio=float_to_int16(noise_est_wav))
|
699
|
+
|
700
|
+
# debug code to test for perfect reconstruction of the extraction method
|
701
|
+
# note both 75% olsa-hanns and 50% olsa-hann modes checked to have perfect reconstruction
|
702
|
+
# target_r = mixdb.inverse_transform(target_f)
|
703
|
+
# noise_r = mixdb.inverse_transform(noise_f)
|
704
|
+
# _write_wav(name=base_name + '_target_r.wav', audio=float_to_int16(target_r))
|
705
|
+
# _write_wav(name=base_name + '_noise_r.wav', audio=float_to_int16(noise_r)) # chk perfect rec
|
706
|
+
|
707
|
+
# 8) Write out plot file
|
708
|
+
if enable_plot:
|
709
|
+
plot_name = base_name + "_metric_spenh.pdf"
|
710
|
+
|
711
|
+
# Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
|
712
|
+
# Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
|
713
|
+
# Reshape to get frames*decimated_stride, num_bands
|
714
|
+
step = int(mixdb.feature_samples / mixdb.feature_step_samples)
|
715
|
+
if feature.ndim != 3:
|
716
|
+
raise OSError("feature does not have 3 dimensions: frames, stride, num_bands")
|
717
|
+
|
718
|
+
# for feature cn*00n**
|
719
|
+
feat_sgram = unstack_complex(feature)
|
720
|
+
feat_sgram = 20 * np.log10(abs(feat_sgram) + np.finfo(np.float32).eps)
|
721
|
+
feat_sgram = feat_sgram[:, -step:, :] # decimate, Fx1xB
|
722
|
+
feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))
|
723
|
+
|
724
|
+
with PdfPages(plot_name) as pdf:
|
725
|
+
# page1 we always have a mixture and prediction, target optional if truth provided
|
726
|
+
# For speech enhancement, target_f is definitely included:
|
727
|
+
predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
728
|
+
tfunc_name = "target_f"
|
729
|
+
# if tfunc_name == 'mapped_snr_f':
|
730
|
+
# # leave as unmapped snr
|
731
|
+
# predplot = predict
|
732
|
+
# tfunc_name = mixdb.target_file(1).truth_settings[0].function
|
733
|
+
# elif tfunc_name == 'target_f' or 'target_mixture_f':
|
734
|
+
# predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
735
|
+
# else:
|
736
|
+
# # use dB scale
|
737
|
+
# predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
|
738
|
+
# tfunc_name = tfunc_name + ' (db)'
|
739
|
+
|
740
|
+
mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
|
741
|
+
fig, ax = plot_mixpred(
|
742
|
+
mixture=mixture,
|
743
|
+
mixture_f=mixspec,
|
744
|
+
target=target,
|
745
|
+
feature=feat_sgram,
|
746
|
+
predict=predplot,
|
747
|
+
tp_title=tfunc_name,
|
748
|
+
)
|
749
|
+
pdf.savefig(fig)
|
750
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig1.pkl.gz", "wb"))
|
751
|
+
|
752
|
+
# ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
|
753
|
+
# pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))
|
754
|
+
|
755
|
+
# page 3 speech extraction
|
756
|
+
tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
|
757
|
+
tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
|
758
|
+
# n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
|
759
|
+
fig, ax = plot_e_predict_truth(
|
760
|
+
predict=tg_est_spec,
|
761
|
+
predict_wav=target_est_wav,
|
762
|
+
truth_f=tg_spec,
|
763
|
+
truth_wav=target_fi,
|
764
|
+
metric=np.vstack((lerr_tg_frame, phd_frame)).T,
|
765
|
+
tp_title="speech estimate",
|
766
|
+
)
|
767
|
+
pdf.savefig(fig)
|
768
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig2.pkl.gz", "wb"))
|
769
|
+
|
770
|
+
# page 4 noise extraction
|
771
|
+
n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
|
772
|
+
n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
|
773
|
+
fig, ax = plot_e_predict_truth(
|
774
|
+
predict=n_est_spec,
|
775
|
+
predict_wav=noise_est_wav,
|
776
|
+
truth_f=n_spec,
|
777
|
+
truth_wav=noise_fi,
|
778
|
+
metric=lerr_n_frame,
|
779
|
+
tp_title="noise estimate",
|
780
|
+
)
|
781
|
+
pdf.savefig(fig)
|
782
|
+
pickle.dump((fig, ax), pgzip.open(base_name + "_metric_spenh_fig4.pkl.gz", "wb"))
|
783
|
+
|
784
|
+
# Plot error waveforms
|
785
|
+
# tg_err_wav = target_fi - target_est_wav
|
786
|
+
# tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))
|
787
|
+
|
788
|
+
plt.close("all")
|
789
|
+
|
790
|
+
return all_metrics_table_1, all_metrics_table_2
|
791
|
+
|
792
|
+
|
793
|
+
def _wer_improvement_pct(wer_mx: float, wer_est: float) -> float:
    """Return the relative WER improvement (in %) of an estimate over the mixture.

    :param wer_mx: WER of the unprocessed mixture (percent)
    :param wer_est: WER of the enhanced estimate (percent)
    :return: 100 * (wer_mx - wer_est) / wer_mx. When wer_mx is 0 the ratio is undefined,
        so 0.0 is returned if wer_est is also 0, otherwise -999.0 (finite stand-in for -Inf).
    """
    if wer_mx == 0.0:
        return 0.0 if wer_est == 0.0 else -999.0
    return 100 * (wer_mx - wer_est) / wer_mx


def main():
    """Command-line entry point: compute speech-enhancement metrics for a mixture database.

    Parses the module usage string with docopt and locates prediction files (.h5, or .wav as a
    fallback) in PLOC. It then runs ``_process_mixture`` for every selected mixture of the
    mixdb at TLOC, serially or in parallel. When ``--summary`` is given and more than one mixture
    was processed, text and CSV summaries are written into PLOC.

    :raises SystemExit: if PLOC is not a directory, PLOC holds no prediction files (outside
        truth-estimate mode), the mixdb has no target_f truth, or the ASR method is unknown.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args["--verbose"]
    mixids = args["--mixid"]
    asr_method = args["--asr-method"]
    truth_est_mode = args["--truth-est-mode"]
    enable_plot = args["--plot"]
    enable_wav = args["--wav"]
    enable_summary = args["--summary"]
    predict_location = args["PLOC"]
    num_proc = args["--num_process"]
    truth_location = args["TLOC"]

    import glob
    from functools import partial
    from os.path import basename
    from os.path import isdir
    from os.path import join

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils.parallel import par_track
    from sonusai.utils.parallel import track

    # Check prediction subdirectory; the log file and all outputs are written there,
    # so there is nothing useful to do without it.
    if not isdir(predict_location):
        print(f"The specified predict location {predict_location} is not a valid subdirectory path, exiting.")
        raise SystemExit(1)

    all_predict_files = glob.glob(join(predict_location, "*.h5"))
    predict_logfile = glob.glob(join(predict_location, "*predict.log"))
    predict_wav_mode = False
    if not all_predict_files and not truth_est_mode:
        # No .h5 predictions: fall back to .wav predictions
        all_predict_files = glob.glob(join(predict_location, "*.wav"))
        if not all_predict_files:
            print(f"Subdirectory {predict_location} has no .h5 or .wav files, exiting.")
            raise SystemExit(1)
        logger.info(f"Found {len(all_predict_files)} prediction .wav files.")
        predict_wav_mode = True
    else:
        logger.info(f"Found {len(all_predict_files)} prediction .h5 files.")

    if not predict_logfile:
        logger.warning(f"Warning, predict location {predict_location} has no prediction log files.")
    else:
        logger.info(f"Found predict log {basename(predict_logfile[0])} in predict location.")

    # Setup logging file
    create_file_handler(join(predict_location, "calc_metric_spenh.log"), verbose)
    update_console_handler(verbose)
    initial_log_messages("calc_metric_spenh")

    mixdb = MixtureDatabase(truth_location)
    mixids = mixdb.mixids_to_list(mixids)
    logger.info(
        f"Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}"
    )

    # Speech enhancement metrics and audio truth require the target_f truth type; check it is present.
    primary_truth_configs = mixdb.category_truth_configs("primary")
    logger.info(
        f"mixdb has {len(primary_truth_configs)} truth types defined for primary, checking that target_f type is present."
    )
    target_f_key = None
    for key in primary_truth_configs:
        if primary_truth_configs[key] == "target_f":
            target_f_key = key  # no break: the last matching key wins
    if target_f_key is None:
        logger.error("mixdb does not have target_f truth defined, required for speech enhancement metrics, exiting.")
        raise SystemExit(1)

    logger.info(f"Only running specified subset of {len(mixids)} mixtures")

    # Output file-name base; includes the ASR method name when ASR is enabled
    fnb = "metric_spenh_"
    if asr_method is not None:
        if asr_method not in mixdb.asr_configs:
            logger.info(
                f"Specified ASR method {asr_method} does not exists in mixdb.asr_configs. "
                "Must choose one of the following (or none):"
            )
            logger.info(f"{', '.join(mixdb.asr_configs)}")
            logger.error("Unrecognized ASR method, exiting.")
            raise SystemExit(1)
        logger.info(f"Specified ASR method {asr_method} exists in mixdb.asr_configs, it will be used for ")
        logger.info("prediction ASR and WER, and pre-calculated target and mixture ASR if available.")
        asr_cfg = mixdb.asr_configs[asr_method]
        fnb = "metric_spenh_" + asr_method + "_"
        logger.info(f"Using ASR cfg: {asr_cfg} ")

    # psutil.cpu_count() returns None when the count cannot be determined
    num_cpu = psutil.cpu_count() or 1
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Aim for ~90% total utilization given the current load, but always use at least one process
        use_cpu = max(1, int(num_cpu * (0.9 - cpu_percent / 100)))
    elif num_proc == "None":
        use_cpu = None  # run serially
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    logger.info(f"Calculating metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
    progress = track(total=len(mixids))
    all_metrics_tables = par_track(
        partial(
            _process_mixture,
            truth_location=truth_location,
            predict_location=predict_location,
            predict_wav_mode=predict_wav_mode,
            truth_est_mode=truth_est_mode,
            enable_plot=enable_plot,
            enable_wav=enable_wav,
            asr_method=asr_method,
            target_f_key=target_f_key,
        ),
        mixids,
        progress=progress,
        num_cpus=use_cpu,
        no_par=use_cpu is None,
    )
    progress.close()

    # Each per-mixture result is a (scalar metrics row, per-frame stats row) pair
    all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
    all_metrics_table_2 = pd.concat([item[1] for item in all_metrics_tables])

    if not enable_summary:
        return

    # 9) Done with mixtures, write out summary metrics
    # Calculate SNR summary avg of each non-random snr
    all_mtab1_sorted = all_metrics_table_1.sort_values(by=["MXSNR", "SPFILE"])
    all_mtab2_sorted = all_metrics_table_2.sort_values(by=["MXSNR"])
    mtab_snr_summary = None
    mtab_snr_summary_em = None
    for snr in mixdb.snrs:
        tmp = all_mtab1_sorted.query("MXSNR==" + str(snr)).mean(numeric_only=True).to_frame().T
        # avoid nan when subset of mixids specified (MXSNR mean is nan if no rows matched)
        if not np.isnan(tmp.iloc[0].to_numpy()[0]).any():
            mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])

        tmp = all_mtab2_sorted[all_mtab2_sorted["MXSNR"] == snr].mean(numeric_only=True).to_frame().T
        # avoid nan when subset of mixids specified (mxsnr will be nan if no data)
        if not np.isnan(tmp.iloc[0].to_numpy()[0]).any():
            mtab_snr_summary_em = pd.concat([mtab_snr_summary_em, tmp])

    mtab_snr_summary = mtab_snr_summary.sort_values(by=["MXSNR"], ascending=False)
    # A mean of per-mixture percentages is not the percentage of the means: recompute directly
    mtab_snr_summary["PESQi%"] = (
        100 * (mtab_snr_summary["PESQ"] - mtab_snr_summary["MXPESQ"]) / np.maximum(mtab_snr_summary["MXPESQ"], 0.01)
    )
    weri_col = mtab_snr_summary.columns.get_loc("WERi%")
    mx_wers = mtab_snr_summary["MXWER"].to_numpy()
    est_wers = mtab_snr_summary["WER"].to_numpy()
    for i, (wer_mx, wer_est) in enumerate(zip(mx_wers, est_wers)):
        # Keep the averaged WERi% when either WER is nan (ASR disabled or skipped)
        if wer_mx == 0.0 or not (np.isnan(wer_mx) or np.isnan(wer_est)):
            mtab_snr_summary.iloc[i, weri_col] = _wer_improvement_pct(wer_mx, wer_est)

    # Calculate avg metrics over all mixtures except -99
    all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
    all_nom99_mean = all_mtab1_sorted_nom99.mean(numeric_only=True)

    # correct the percentage averages with a direct calculation (PESQ% and WER%)
    all_nom99_mean["PESQi%"] = (
        100 * (all_nom99_mean["PESQ"] - all_nom99_mean["MXPESQ"]) / np.maximum(all_nom99_mean["MXPESQ"], 0.01)
    )
    all_nom99_mean["WERi%"] = _wer_improvement_pct(all_nom99_mean["MXWER"], all_nom99_mean["WER"])

    num_mix = len(mixids)
    if num_mix <= 1:
        return

    fmt1 = "{:.1f}".format
    fmt2 = "{:.2f}".format

    if not truth_est_mode:
        ofname = join(predict_location, fnb + "summary.txt")
    else:
        ofname = join(predict_location, fnb + "summary_truest.txt")

    with open(ofname, "w", encoding="utf-8") as f:
        print(f"ASR enabled with method {asr_method}", file=f)
        print(f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:", file=f)
        print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=fmt2, index=False), file=f)
        print("\nSpeech enhancement metrics avg over each SNR:", file=f)
        print(mtab_snr_summary.round(2).to_string(float_format=fmt2, index=False), file=f)
        print("", file=f)
        print("Extraction statistics stats avg over each SNR:", file=f)
        print(mtab_snr_summary_em.round(1).to_string(float_format=fmt1, index=False), file=f)
        print("", file=f)

        print(f"Speech enhancement metrics stats over all {num_mix} mixtures:", file=f)
        print(all_metrics_table_1.describe().round(2).to_string(float_format=fmt2), file=f)
        print("", file=f)
        print(f"Extraction statistics stats over all {num_mix} mixtures:", file=f)
        print(all_metrics_table_2.describe().round(2).to_string(float_format=fmt1), file=f)
        print("", file=f)

        print("Speech enhancement metrics all-mixtures list:", file=f)
        print(all_metrics_table_1.round(2).to_string(float_format=fmt2), file=f)
        print("", file=f)
        print("Extraction statistics all-mixtures list:", file=f)
        print(all_metrics_table_2.round(2).to_string(float_format=fmt1), file=f)

    # Write summary to .csv file: first write truncates, later writes append
    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "summary.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_summary.csv"))
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    label = f"Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, header=False, index=False)  # open as write
    all_nom99_mean.to_frame().T.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame(["Speech enhancement metrics avg over each SNR:"]).to_csv(csv_name, **header_args)
    mtab_snr_summary.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame(["Extraction statistics stats avg over each SNR:"]).to_csv(csv_name, **header_args)
    mtab_snr_summary_em.round(2).to_csv(csv_name, index=False, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    label = f"Speech enhancement metrics stats over {num_mix} mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_1.describe().round(2).to_csv(csv_name, **table_args)
    pd.DataFrame([""]).to_csv(csv_name, **header_args)
    label = f"Extraction statistics stats over {num_mix} mixtures:"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)
    all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
    label = f"ASR enabled with method {asr_method}"
    pd.DataFrame([label]).to_csv(csv_name, **header_args)

    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "list.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_list.csv"))
    pd.DataFrame(["Speech enhancement metrics list:"]).to_csv(csv_name, header=False, index=False)  # open as write
    all_metrics_table_1.round(2).to_csv(csv_name, **table_args)

    if not truth_est_mode:
        csv_name = str(join(predict_location, fnb + "estats_list.csv"))
    else:
        csv_name = str(join(predict_location, fnb + "truest_estats_list.csv"))
    pd.DataFrame(["Extraction statistics list:"]).to_csv(csv_name, header=False, index=False)  # open as write
    all_metrics_table_2.round(2).to_csv(csv_name, **table_args)
|
1092
|
+
|
1093
|
+
|
1094
|
+
if __name__ == "__main__":
    # Script entry point: register sonusai's keyboard-interrupt handler, run main(),
    # and hand any uncaught exception to sonusai's exception_handler for reporting.
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt
    from sonusai import exception_handler

    register_keyboard_interrupt()

    try:
        main()
    except Exception as exc:  # top-level boundary: delegate reporting to sonusai
        exception_handler(exc)
|
1103
|
+
|
1104
|
+
# if asr_method == 'none':
|
1105
|
+
# fnb = 'metric_spenh_'
|
1106
|
+
# elif asr_method == 'google':
|
1107
|
+
# fnb = 'metric_spenh_ggl_'
|
1108
|
+
# logger.info(f'ASR enabled with method {asr_method}')
|
1109
|
+
# enable_asr_warmup = True
|
1110
|
+
# elif asr_method == 'deepgram':
|
1111
|
+
# fnb = 'metric_spenh_dgram_'
|
1112
|
+
# logger.info(f'ASR enabled with method {asr_method}')
|
1113
|
+
# enable_asr_warmup = True
|
1114
|
+
# elif asr_method == 'aixplain_whisper':
|
1115
|
+
# fnb = 'metric_spenh_whspx_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1116
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1117
|
+
# enable_asr_warmup = True
|
1118
|
+
# elif asr_method == 'whisper':
|
1119
|
+
# fnb = 'metric_spenh_whspl_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1120
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1121
|
+
# enable_asr_warmup = True
|
1122
|
+
# elif asr_method == 'aaware_whisper':
|
1123
|
+
# fnb = 'metric_spenh_whspaaw_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1124
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1125
|
+
# enable_asr_warmup = True
|
1126
|
+
# elif asr_method == 'faster_whisper':
|
1127
|
+
# fnb = 'metric_spenh_fwhsp_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1128
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1129
|
+
# enable_asr_warmup = True
|
1130
|
+
# elif asr_method == 'sensory':
|
1131
|
+
# fnb = 'metric_spenh_snsr_' + mixdb.asr_configs[asr_method]['model'] + '_'
|
1132
|
+
# asr_model_name = mixdb.asr_configs[asr_method]['model']
|
1133
|
+
# enable_asr_warmup = True
|
1134
|
+
# else:
|
1135
|
+
# logger.error(f'Unrecognized ASR method: {asr_method}')
|
1136
|
+
# return
|