sonusai 0.17.0__py3-none-any.whl → 0.17.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. sonusai/audiofe.py +25 -54
  2. sonusai/calc_metric_spenh.py +212 -219
  3. sonusai/doc/doc.py +1 -1
  4. sonusai/mixture/__init__.py +2 -0
  5. sonusai/mixture/audio.py +12 -0
  6. sonusai/mixture/datatypes.py +11 -3
  7. sonusai/mixture/mixdb.py +100 -0
  8. sonusai/mixture/soundfile_audio.py +39 -0
  9. sonusai/mixture/sox_augmentation.py +3 -0
  10. sonusai/mixture/speaker_metadata.py +35 -0
  11. sonusai/mixture/torchaudio_audio.py +22 -0
  12. sonusai/mkmanifest.py +1 -1
  13. sonusai/mkwav.py +4 -4
  14. sonusai/onnx_predict.py +114 -410
  15. sonusai/post_spenh_targetf.py +2 -2
  16. sonusai/queries/queries.py +1 -1
  17. sonusai/speech/__init__.py +3 -0
  18. sonusai/speech/l2arctic.py +116 -0
  19. sonusai/speech/librispeech.py +99 -0
  20. sonusai/speech/mcgill.py +70 -0
  21. sonusai/speech/textgrid.py +100 -0
  22. sonusai/speech/timit.py +135 -0
  23. sonusai/speech/types.py +12 -0
  24. sonusai/speech/vctk.py +52 -0
  25. sonusai/speech/voxceleb.py +102 -0
  26. sonusai/utils/__init__.py +3 -2
  27. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  28. sonusai/utils/asr_manifest_functions/__init__.py +0 -1
  29. sonusai/utils/asr_manifest_functions/data.py +0 -8
  30. sonusai/utils/asr_manifest_functions/librispeech.py +1 -1
  31. sonusai/utils/asr_manifest_functions/mcgill_speech.py +1 -1
  32. sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +1 -1
  33. sonusai/utils/braced_glob.py +7 -3
  34. sonusai/utils/onnx_utils.py +110 -106
  35. sonusai/utils/path_info.py +7 -0
  36. sonusai/utils/{wave.py → write_audio.py} +2 -2
  37. {sonusai-0.17.0.dist-info → sonusai-0.17.3.dist-info}/METADATA +3 -1
  38. {sonusai-0.17.0.dist-info → sonusai-0.17.3.dist-info}/RECORD +40 -35
  39. {sonusai-0.17.0.dist-info → sonusai-0.17.3.dist-info}/WHEEL +1 -1
  40. sonusai/calc_metric_spenh-save.py +0 -1334
  41. sonusai/onnx_predict-old.py +0 -240
  42. sonusai/onnx_predict-save.py +0 -487
  43. sonusai/ovino_predict.py +0 -508
  44. sonusai/ovino_query_devices.py +0 -47
  45. sonusai/torchl_onnx-old.py +0 -216
  46. {sonusai-0.17.0.dist-info → sonusai-0.17.3.dist-info}/entry_points.txt +0 -0
@@ -1,1334 +0,0 @@
- """sonusai calc_metric_spenh
-
- usage: calc_metric_spenh [-hvtpws] [-i MIXID] [-e WER] [-m WMNAME] PLOC TLOC
-
- options:
-     -h, --help
-     -v, --verbose                Be verbose.
-     -i MIXID, --mixid MIXID      Mixture ID(s) to process, can be range like 0:maxmix+1 [default: *].
-     -t, --truth-est-mode         Calculate extraction and metrics using truth (instead of prediction).
-     -p, --plot                   Enable PDF plots file generation per mixture.
-     -w, --wav                    Generate WAV files per mixture.
-     -s, --summary                Enable summary files generation.
-     -e WER, --wer-method WER     Word-Error-Rate method: deepgram, google, aixplain_whisper
-                                  or whisper (locally run) [default: none]
-     -m WMNAME, --whisper-model   Whisper model name used in aixplain_whisper and whisper WER methods.
-                                  [default: tiny]
-
- Calculate speech enhancement metrics of prediction data in PLOC using SonusAI mixture data
- in TLOC as truth/label reference. Metric and extraction data files are written into PLOC.
-
- PLOC  directory containing prediction data in .h5 files created from truth/label mixture data in TLOC
- TLOC  directory with SonusAI mixture database of truth/label mixture data
-
- For whisper WER methods, the possible models used in local processing (WER = whisper) are:
- {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}
- but note most are very computationally demanding and can overwhelm/hang a local system.
-
- Outputs the following to PLOC (where id is mixid number 0:num_mixtures):
-     <id>_metric_spenh.txt
-
- If --plot:
-     <id>_metric_spenh.pdf
-
- If --wav:
-     <id>_target.wav
-     <id>_target_est.wav
-     <id>_noise.wav
-     <id>_noise_est.wav
-     <id>_mixture.wav
-
- If --truth-est-mode:
-     <id>_target_truth_est.wav
-     <id>_noise_truth_est.wav
-
- If --summary:
-     metric_spenh_targetf_summary.txt
-     metric_spenh_targetf_summary.csv
-     metric_spenh_targetf_list.csv
-     metric_spenh_targetf_estats_list.csv
-
- If --truth-est-mode:
-     metric_spenh_targetf_truth_list.csv
-     metric_spenh_targetf_estats_truth_list.csv
-
- TBD
- Metric and extraction data are written into prediction location PLOC as separate files per mixture.
-
- -d PLOC, --ploc PLOC         Location of SonusAI predict data.
-
- Inputs:
-
- """
- from dataclasses import dataclass
- from typing import Optional
-
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
-
- from sonusai import logger
- from sonusai.mixture import AudioF
- from sonusai.mixture import AudioT
- from sonusai.mixture import Feature
- from sonusai.mixture import MixtureDatabase
- from sonusai.mixture import Predict
-
- matplotlib.use('SVG')
-
-
- @dataclass
- class MPGlobal:
-     mixdb: MixtureDatabase = None
-     predict_location: str = None
-     predwav_mode: bool = None
-     truth_est_mode: bool = None
-     enable_plot: bool = None
-     enable_wav: bool = None
-     wer_method: str = None
-     whisper_model: str = None
-
-
- MP_GLOBAL = MPGlobal()
-
-
- def power_compress(spec):
-     mag = np.abs(spec)
-     phase = np.angle(spec)
-     mag = mag ** 0.3
-     real_compress = mag * np.cos(phase)
-     imag_compress = mag * np.sin(phase)
-     return real_compress + 1j * imag_compress
-
-
- def power_uncompress(spec):
-     mag = np.abs(spec)
-     phase = np.angle(spec)
-     mag = mag ** (1. / 0.3)
-     real_uncompress = mag * np.cos(phase)
-     imag_uncompress = mag * np.sin(phase)
-     return real_uncompress + 1j * imag_uncompress
-
-
- def snr(clean_speech, processed_speech, sample_rate):
-     # Check the length of the clean and processed speech. Must be the same.
-     clean_length = len(clean_speech)
-     processed_length = len(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Both Speech Files must be same length.')
-
-     overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech)))
-
-     # Global Variables
-     winlength = round(30 * sample_rate / 1000)  # window length in samples
-     skiprate = int(np.floor(winlength / 4))  # window skip in samples
-     MIN_SNR = -10  # minimum SNR in dB
-     MAX_SNR = 35  # maximum SNR in dB
-
-     # For each frame of input speech, calculate the Segmental SNR
-     num_frames = int(clean_length / skiprate - (winlength / skiprate))  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
-
-     segmental_snr = np.empty(num_frames)
-     EPS = np.spacing(1)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start:start + winlength]
-         processed_frame = processed_speech[start:start + winlength]
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-
-         # (2) Compute the Segmental SNR
-         signal_energy = np.sum(np.square(clean_frame))
-         noise_energy = np.sum(np.square(clean_frame - processed_frame))
-         segmental_snr[frame_count] = 10 * np.log10(signal_energy / (noise_energy + EPS) + EPS)
-         segmental_snr[frame_count] = max(segmental_snr[frame_count], MIN_SNR)
-         segmental_snr[frame_count] = min(segmental_snr[frame_count], MAX_SNR)
-
-         start = start + skiprate
-
-     return overall_snr, segmental_snr
-
-
- def lpcoeff(speech_frame, model_order):
-     # (1) Compute Autocorrelation Lags
-     winlength = np.size(speech_frame)
-     R = np.empty(model_order + 1)
-     E = np.empty(model_order + 1)
-     for k in range(model_order + 1):
-         R[k] = np.dot(speech_frame[0:winlength - k], speech_frame[k: winlength])
-
-     # (2) Levinson-Durbin
-     a = np.ones(model_order)
-     a_past = np.empty(model_order)
-     rcoeff = np.empty(model_order)
-     E[0] = R[0]
-     for i in range(model_order):
-         a_past[0: i] = a[0: i]
-         sum_term = np.dot(a_past[0: i], R[i:0:-1])
-         rcoeff[i] = (R[i + 1] - sum_term) / E[i]
-         a[i] = rcoeff[i]
-         if i == 0:
-             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], rcoeff[i])
-         else:
-             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], rcoeff[i])
-         E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i]
-     acorr = R
-     refcoeff = rcoeff
-     lpparams = np.concatenate((np.array([1]), -a))
-     return acorr, refcoeff, lpparams
-
-
- def llr(clean_speech, processed_speech, sample_rate):
-     from scipy.linalg import toeplitz
-
-     # Check the length of the clean and processed speech. Must be the same.
-     clean_length = np.size(clean_speech)
-     processed_length = np.size(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Both Speech Files must be same length.')
-
-     # Global Variables
-     winlength = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-     skiprate = (np.floor(winlength / 4)).astype(int)  # window skip in samples
-     if sample_rate < 10000:
-         P = 10  # LPC Analysis Order
-     else:
-         P = 16  # this could vary depending on sampling frequency.
-
-     # For each frame of input speech, calculate the Log Likelihood Ratio
-     num_frames = int((clean_length - winlength) / skiprate)  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
-
-     distortion = np.empty(num_frames)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start: start + winlength]
-         processed_frame = processed_speech[start: start + winlength]
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-
-         # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure.
-         R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P)
-         R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P)
-
-         # (3) Compute the LLR measure
-         numerator = np.dot(np.matmul(A_processed, toeplitz(R_clean)), A_processed)
-         denominator = np.dot(np.matmul(A_clean, toeplitz(R_clean)), A_clean)
-         distortion[frame_count] = np.log(numerator / denominator)
-         start = start + skiprate
-     return distortion
-
-
- def wss(clean_speech, processed_speech, sample_rate):
-     from scipy.fftpack import fft
-
-     # Check the length of the clean and processed speech, which must be the same.
-     clean_length = np.size(clean_speech)
-     processed_length = np.size(processed_speech)
-     if clean_length != processed_length:
-         raise ValueError('Files must have same length.')
-
-     # Global variables
-     winlength = (np.round(30 * sample_rate / 1000)).astype(int)  # window length in samples
-     skiprate = (np.floor(np.divide(winlength, 4))).astype(int)  # window skip in samples
-     max_freq = (np.divide(sample_rate, 2)).astype(int)  # maximum bandwidth
-     num_crit = 25  # number of critical bands
-
-     USE_FFT_SPECTRUM = 1  # defaults to 10th order LP spectrum
-     n_fft = (np.power(2, np.ceil(np.log2(2 * winlength)))).astype(int)
-     n_fftby2 = (np.multiply(0.5, n_fft)).astype(int)  # FFT size/2
-     Kmax = 20.0  # value suggested by Klatt, pg 1280
-     Klocmax = 1.0  # value suggested by Klatt, pg 1280
-
-     # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
-     cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
-                           540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
-                           1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
-                           2701.97, 2978.04, 3276.17, 3597.63])
-     bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
-                           77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
-                           153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
-                           276.072, 298.126, 321.465, 346.136])
-
-     bw_min = bandwidth[0]  # minimum critical bandwidth
-
-     # Set up the critical band filters.
-     # Note here that Gaussianly shaped filters are used.
-     # Also, the sum of the filter weights are equivalent for each critical band filter.
-     # Filter less than -30 dB and set to zero.
-     min_factor = np.exp(-30.0 / (2.0 * 2.303))  # -30 dB point of filter
-     crit_filter = np.empty((num_crit, n_fftby2))
-     for i in range(num_crit):
-         f0 = (cent_freq[i] / max_freq) * n_fftby2
-         bw = (bandwidth[i] / max_freq) * n_fftby2
-         norm_factor = np.log(bw_min) - np.log(bandwidth[i])
-         j = np.arange(n_fftby2)
-         crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
-         cond = np.greater(crit_filter[i, :], min_factor)
-         crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
-     # For each frame of input speech, calculate the Weighted Spectral Slope Measure
-     num_frames = int(clean_length / skiprate - (winlength / skiprate))  # number of frames
-     start = 0  # starting sample
-     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
-
-     distortion = np.empty(num_frames)
-     for frame_count in range(num_frames):
-         # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window.
-         clean_frame = clean_speech[start: start + winlength] / 32768
-         processed_frame = processed_speech[start: start + winlength] / 32768
-         clean_frame = np.multiply(clean_frame, window)
-         processed_frame = np.multiply(processed_frame, window)
-         # (2) Compute the Power Spectrum of Clean and Processed
-         # if USE_FFT_SPECTRUM:
-         clean_spec = np.square(np.abs(fft(clean_frame, n_fft)))
-         processed_spec = np.square(np.abs(fft(processed_frame, n_fft)))
-
-         # (3) Compute Filterbank Output Energies (in dB scale)
-         clean_energy = np.matmul(crit_filter, clean_spec[0:n_fftby2])
-         processed_energy = np.matmul(crit_filter, processed_spec[0:n_fftby2])
-
-         clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10))
-         processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10))
-
-         # (4) Compute Spectral Slope (dB[i+1]-dB[i])
-         clean_slope = clean_energy[1:num_crit] - clean_energy[0: num_crit - 1]
-         processed_slope = processed_energy[1:num_crit] - processed_energy[0: num_crit - 1]
-
-         # (5) Find the nearest peak locations in the spectra to each critical band.
-         #     If the slope is negative, we search to the left. If positive, we search to the right.
-         clean_loc_peak = np.empty(num_crit - 1)
-         processed_loc_peak = np.empty(num_crit - 1)
-
-         for i in range(num_crit - 1):
-             # find the peaks in the clean speech signal
-             if clean_slope[i] > 0:  # search to the right
-                 n = i
-                 while (n < num_crit - 1) and (clean_slope[n] > 0):
-                     n = n + 1
-                 clean_loc_peak[i] = clean_energy[n - 1]
-             else:  # search to the left
-                 n = i
-                 while (n >= 0) and (clean_slope[n] <= 0):
-                     n = n - 1
-                 clean_loc_peak[i] = clean_energy[n + 1]
-
-             # find the peaks in the processed speech signal
-             if processed_slope[i] > 0:  # search to the right
-                 n = i
-                 while (n < num_crit - 1) and (processed_slope[n] > 0):
-                     n = n + 1
-                 processed_loc_peak[i] = processed_energy[n - 1]
-             else:  # search to the left
-                 n = i
-                 while (n >= 0) and (processed_slope[n] <= 0):
-                     n = n - 1
-                 processed_loc_peak[i] = processed_energy[n + 1]
-
-         # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function.
-         dBMax_clean = np.max(clean_energy)
-         dBMax_processed = np.max(processed_energy)
-         '''
-         The weights are calculated by averaging individual weighting factors from the clean and processed frame.
-         These weights W_clean and W_processed should range from 0 to 1 and place more emphasis on spectral peaks
-         and less emphasis on slope differences in spectral valleys.
-         This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
-         '''
-         Wmax_clean = np.divide(Kmax, Kmax + dBMax_clean - clean_energy[0: num_crit - 1])
-         Wlocmax_clean = np.divide(Klocmax, Klocmax + clean_loc_peak - clean_energy[0: num_crit - 1])
-         W_clean = np.multiply(Wmax_clean, Wlocmax_clean)
-
-         Wmax_processed = np.divide(Kmax, Kmax + dBMax_processed - processed_energy[0: num_crit - 1])
-         Wlocmax_processed = np.divide(Klocmax, Klocmax + processed_loc_peak - processed_energy[0: num_crit - 1])
-         W_processed = np.multiply(Wmax_processed, Wlocmax_processed)
-
-         W = np.divide(np.add(W_clean, W_processed), 2.0)
-         slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1]
-         distortion[frame_count] = np.dot(W, np.square(slope_diff)) / np.sum(W)
-         # this normalization is not part of Klatt's paper, but helps to normalize the measure.
-         # Here we scale the measure by the sum of the weights.
-         start = start + skiprate
-     return distortion
-
-
- def calc_speech_metrics(hypothesis: np.ndarray,
-                         reference: np.ndarray) -> tuple[float, float, float, float, float]:
-     """
-     Calculate speech metrics pesq_mos, CSIG, CBAK, COVL, segSNR. These are all related and thus included
-     in one function. Reference: matlab script "compute_metrics.m".
-
-     Usage:
-         pesq, csig, cbak, covl, ssnr = compute_metrics(hypothesis, reference, Fs, path)
-         reference: clean audio as array
-         hypothesis: enhanced audio as array
-         Audio must have sampling rate = 16000 Hz.
-
-     Example call:
-         pesq_output, csig_output, cbak_output, covl_output, ssnr_output = \
-             calc_speech_metrics(predicted_audio, target_audio)
-     """
-     from sonusai.metrics import calc_pesq
-
-     Fs = 16000
-
-     # compute the WSS measure
-     wss_dist_vec = wss(reference, hypothesis, Fs)
-     wss_dist_vec = np.sort(wss_dist_vec)
-     alpha = 0.95  # value from CMGAN ref implementation
-     wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
-
-     # compute the LLR measure
-     llr_dist = llr(reference, hypothesis, Fs)
-     ll_rs = np.sort(llr_dist)
-     llr_len = round(np.size(llr_dist) * alpha)
-     llr_mean = np.mean(ll_rs[0: llr_len])
-
-     # compute the SNRseg
-     snr_dist, segsnr_dist = snr(reference, hypothesis, Fs)
-     snr_mean = snr_dist
-     segSNR = np.mean(segsnr_dist)
-
-     # compute the pesq (use Sonusai wrapper, only fs=16k, mode=wb support)
-     pesq_mos = calc_pesq(hypothesis=hypothesis, reference=reference)
-     # pesq_mos = pesq(sampling_rate1, data1, data2, 'wb')
-
-     # now compute the composite measures
-     CSIG = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
-     CSIG = max(1, CSIG)
-     CSIG = min(5, CSIG)  # limit values to [1, 5]
-     CBAK = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * segSNR
-     CBAK = max(1, CBAK)
-     CBAK = min(5, CBAK)  # limit values to [1, 5]
-     COVL = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
-     COVL = max(1, COVL)
-     COVL = min(5, COVL)  # limit values to [1, 5]
-
-     return pesq_mos, CSIG, CBAK, COVL, segSNR
-
-
- def mean_square_error(hypothesis: np.ndarray,
-                       reference: np.ndarray,
-                       squared: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """Calculate root-mean-square error or mean square error
-
-     :param hypothesis: [frames, bins]
-     :param reference: [frames, bins]
-     :param squared: calculate mean square rather than root-mean-square
-     :return: mean, mean per bin, mean per frame
-     """
-     sq_err = np.square(reference - hypothesis)
-
-     # mean over frames for value per bin
-     err_b = np.mean(sq_err, axis=0)
-     # mean over bins for value per frame
-     err_f = np.mean(sq_err, axis=1)
-     # mean over all
-     err = np.mean(sq_err)
-
-     if not squared:
-         err_b = np.sqrt(err_b)
-         err_f = np.sqrt(err_f)
-         err = np.sqrt(err)
-
-     return err, err_b, err_f
-
-
- def mean_abs_percentage_error(hypothesis: np.ndarray,
-                               reference: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """Calculate mean abs percentage error
-
-     If inputs are complex, calculates average: mape(real)/2 + mape(imag)/2
-
-     :param hypothesis: [frames, bins]
-     :param reference: [frames, bins]
-     :return: mean, mean per bin, mean per frame
-     """
-     if not np.iscomplexobj(reference) and not np.iscomplexobj(hypothesis):
-         abs_err = 100 * np.abs((reference - hypothesis) / (reference + np.finfo(np.float32).eps))
-     else:
-         reference_r = np.real(reference)
-         reference_i = np.imag(reference)
-         hypothesis_r = np.real(hypothesis)
-         hypothesis_i = np.imag(hypothesis)
-         abs_err_r = 100 * np.abs((reference_r - hypothesis_r) / (reference_r + np.finfo(np.float32).eps))
-         abs_err_i = 100 * np.abs((reference_i - hypothesis_i) / (reference_i + np.finfo(np.float32).eps))
-         abs_err = (abs_err_r / 2) + (abs_err_i / 2)
-
-     # mean over frames for value per bin
-     err_b = np.around(np.mean(abs_err, axis=0), 3)
-     # mean over bins for value per frame
-     err_f = np.around(np.mean(abs_err, axis=1), 3)
-     # mean over all
-     err = np.around(np.mean(abs_err), 3)
-
-     return err, err_b, err_f
-
-
- def log_error(reference: np.ndarray, hypothesis: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """Calculate log error
-
-     :param reference: complex or real [frames, bins]
-     :param hypothesis: complex or real [frames, bins]
-     :return: mean, mean per bin, mean per frame
-     """
-     reference_sq = np.real(reference * np.conjugate(reference))
-     hypothesis_sq = np.real(hypothesis * np.conjugate(hypothesis))
-     log_err = abs(10 * np.log10((reference_sq + np.finfo(np.float32).eps) / (hypothesis_sq + np.finfo(np.float32).eps)))
-     # log_err = abs(10 * np.log10(reference_sq / (hypothesis_sq + np.finfo(np.float32).eps) + np.finfo(np.float32).eps))
-
-     # mean over frames for value per bin
-     err_b = np.around(np.mean(log_err, axis=0), 3)
-     # mean over bins for value per frame
-     err_f = np.around(np.mean(log_err, axis=1), 3)
-     # mean over all
-     err = np.around(np.mean(log_err), 3)
-
-     return err, err_b, err_f
-
-
- def phase_distance(reference: np.ndarray,
-                    hypothesis: np.ndarray,
-                    eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """Calculate weighted phase distance error (weight normalization over bins per frame)
-
-     :param reference: complex [frames, bins]
-     :param hypothesis: complex [frames, bins]
-     :param eps: epsilon value
-     :return: mean, mean per bin, mean per frame
-     """
-     ang_diff = np.angle(reference) - np.angle(hypothesis)
-     phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
-     rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg
-
-     # Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
-     # hyp_real = np.real(hypothesis)
-     # near_zeros = np.real(hyp_real) < eps
-     # hyp_real = hyp_real * (np.logical_not(near_zeros))
-     # hyp_real = hyp_real + (near_zeros * eps)
-     # hypothesis = hyp_real + 1j*np.imag(hypothesis)
-     # rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi  # angle diff +/-180
-
-     # weighted mean over all (scalar)
-     reference_mag = np.abs(reference)
-     ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
-     err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
-
-     # weighted mean over frames (value per bin)
-     err_b = np.zeros(reference.shape[1])
-     for bi in range(reference.shape[1]):
-         ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
-         err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)
-
-     # weighted mean over bins (value per frame)
-     err_f = np.zeros(reference.shape[0])
-     for fi in range(reference.shape[0]):
-         ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
-         err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)
-
-     return err, err_b, err_f
-
-
- def plot_mixpred(mixture: AudioT,
-                  mixture_f: AudioF,
-                  target: Optional[AudioT] = None,
-                  feature: Optional[Feature] = None,
-                  predict: Optional[Predict] = None,
-                  tp_title: str = '') -> plt.Figure:
-     from sonusai.mixture import SAMPLE_RATE
-
-     num_plots = 2
-     if feature is not None:
-         num_plots += 1
-     if predict is not None:
-         num_plots += 1
-
-     fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))
-
-     # Plot the waveform
-     p = 0
-     x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
-     ax[p].plot(x_axis, mixture, label='Mixture', color='mistyrose')
-     ax[0].set_ylabel('magnitude', color='tab:blue')
-     ax[p].set_xlim(x_axis[0], x_axis[-1])
-     # ax[p].set_ylim([-1.025, 1.025])
-     if target is not None:  # Plot target time-domain waveform on top of mixture
-         ax[0].plot(x_axis, target, label='Target', color='tab:blue')
-         # ax[0].tick_params(axis='y', labelcolor=color)
-     ax[p].set_title('Waveform')
-
-     # Plot the mixture spectrogram
-     p += 1
-     ax[p].imshow(np.transpose(mixture_f), aspect='auto', interpolation='nearest', origin='lower')
-     ax[p].set_title('Mixture')
-
-     if feature is not None:
-         p += 1
-         ax[p].imshow(np.transpose(feature), aspect='auto', interpolation='nearest', origin='lower')
-         ax[p].set_title('Feature')
-
-     if predict is not None:
-         p += 1
-         im = ax[p].imshow(np.transpose(predict), aspect='auto', interpolation='nearest', origin='lower')
-         ax[p].set_title('Predict ' + tp_title)
-         plt.colorbar(im, location='bottom')
-
-     return fig
-
-
- def plot_pdb_predtruth(predict: np.ndarray,
-                        truth_f: Optional[np.ndarray] = None,
-                        metric: Optional[np.ndarray] = None,
-                        tp_title: str = '') -> plt.Figure:
-     """Plot predict and optionally truth and a metric in power db, e.g. applies 10*log10(predict)"""
-     num_plots = 2
-     if truth_f is not None:
-         num_plots += 1
-
-     fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))
-
-     # Plot the predict spectrogram
-     p = 0
-     tmp = 10 * np.log10(predict.transpose() + np.finfo(np.float32).eps)
-     im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
-     ax[p].set_title('Predict')
-     plt.colorbar(im, location='bottom')
-
-     if truth_f is not None:
-         p += 1
-         tmp = 10 * np.log10(truth_f.transpose() + np.finfo(np.float32).eps)
-         im = ax[p].imshow(tmp, aspect='auto', interpolation='nearest', origin='lower')
-         ax[p].set_title('Truth')
-         plt.colorbar(im, location='bottom')
-
-     # Plot the predict avg, and optionally truth avg and metric lines
-     pred_avg = 10 * np.log10(np.mean(predict, axis=-1) + np.finfo(np.float32).eps)
-     p += 1
-     x_axis = np.arange(len(pred_avg), dtype=np.float32)  # / SAMPLE_RATE
-     ax[p].plot(x_axis, pred_avg, color='black', linestyle='dashed', label='Predict mean over freq.')
-     ax[p].set_ylabel('mean db', color='black')
-     ax[p].set_xlim(x_axis[0], x_axis[-1])
-     if truth_f is not None:
-         truth_avg = 10 * np.log10(np.mean(truth_f, axis=-1) + np.finfo(np.float32).eps)
-         ax[p].plot(x_axis, truth_avg, color='green', linestyle='dashed', label='Truth mean over freq.')
-
-     if metric is not None:  # instantiate 2nd y-axis that shares the same x-axis
-         ax2 = ax[p].twinx()
-         color2 = 'red'
-         ax2.plot(x_axis, metric, color=color2, label='sig distortion (mse db)')
-         ax2.set_xlim(x_axis[0], x_axis[-1])
-         ax2.set_ylim([0, np.max(metric)])
-         ax2.set_ylabel('spectral distortion (mse db)', color=color2)
-         ax2.tick_params(axis='y', labelcolor=color2)
-         ax[p].set_title('SNR and SNR mse (mean over freq. db)')
-     else:
-         ax[p].set_title('SNR (mean over freq. db)')
-         # ax[0].tick_params(axis='y', labelcolor=color)
-     return fig
-
-
- def plot_epredtruth(predict: np.ndarray,
-                     predict_wav: np.ndarray,
-                     truth_f: Optional[np.ndarray] = None,
-                     truth_wav: Optional[np.ndarray] = None,
-                     metric: Optional[np.ndarray] = None,
-                     tp_title: str = '') -> plt.Figure:
-     """Plot predict spectrogram and waveform and optionally truth and a metric"""
-     num_plots = 2
-     if truth_f is not None:
-         num_plots += 1
-     if metric is not None:
-         num_plots += 1
-
-     fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))
-
-     # Plot the predict spectrogram
-     p = 0
-     im = ax[p].imshow(predict.transpose(), aspect='auto', interpolation='nearest', origin='lower')
-     ax[p].set_title('Predict')
-     plt.colorbar(im, location='bottom')
-
-     if truth_f is not None:  # plot truth if provided and use same colormap as predict
-         p += 1
-         ax[p].imshow(truth_f.transpose(), im.cmap, aspect='auto', interpolation='nearest', origin='lower')
-         ax[p].set_title('Truth')
-
-     # Plot the predict wav, and optionally truth avg and metric lines
-     p += 1
-     x_axis = np.arange(len(predict_wav), dtype=np.float32)  # / SAMPLE_RATE
-     ax[p].plot(x_axis, predict_wav, color='black', linestyle='dashed', label='Speech Estimate')
-     ax[p].set_ylabel('Amplitude', color='black')
-     ax[p].set_xlim(x_axis[0], x_axis[-1])
-     if truth_wav is not None:
-         ntrim = len(truth_wav) - len(predict_wav)
-         if ntrim > 0:
-             truth_wav = truth_wav[0:-ntrim]
-         ax[p].plot(x_axis, truth_wav, color='green', linestyle='dashed', label='True Target')
-
-     # Plot the metric lines
-     if metric is not None:
-         p += 1
-         if metric.ndim > 1:  # if it has multiple dims, plot 1st
-             metric1 = metric[:, 0]
-         else:
-             metric1 = metric  # if single dim, plot it as 1st
-         x_axis = np.arange(len(metric1), dtype=np.float32)  # / SAMPLE_RATE
-         ax[p].plot(x_axis, metric1, color='red', label='Target LogErr')
-         ax[p].set_ylabel('log error db', color='red')
-         ax[p].set_xlim(x_axis[0], x_axis[-1])
-         ax[p].set_ylim([-0.01, np.max(metric1) + .01])
-         if metric.ndim > 1:
-             if metric.shape[1] > 1:
-                 metr2 = metric[:, 1]
-                 ax2 = ax[p].twinx()
-                 color2 = 'blue'
-                 ax2.plot(x_axis, metr2, color=color2, label='phase dist (deg)')
-                 # ax2.set_ylim([-180.0, +180.0])
-                 if np.max(metr2) - np.min(metr2) > .1:
-                     ax2.set_ylim([np.min(metr2), np.max(metr2)])
-                 ax2.set_ylabel('phase dist (deg)', color=color2)
-                 ax2.tick_params(axis='y', labelcolor=color2)
-     # ax[p].set_title('SNR and SNR mse (mean over freq. db)')
-
-     return fig
-
-
- def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
-     from os.path import basename
-     from os.path import join
-     from os.path import splitext
-
-     import h5py
-     from numpy import inf
-     from pystoi import stoi
-
-     from sonusai import SonusAIError
-     from sonusai import logger
-     from sonusai.metrics import calc_pcm
-     from sonusai.metrics import calc_wer
-     from sonusai.metrics import calc_wsdr
-     from sonusai.mixture import forward_transform
-     from sonusai.mixture import inverse_transform
-     from sonusai.mixture import read_audio
-     from sonusai.utils import calc_asr
-     from sonusai.utils import float_to_int16
-     from sonusai.utils import reshape_outputs
-     from sonusai.utils import stack_complex
-     from sonusai.utils import unstack_complex
-     from sonusai.utils import write_wav
-
-     mixdb = MP_GLOBAL.mixdb
-     predict_location = MP_GLOBAL.predict_location
-     predwav_mode = MP_GLOBAL.predwav_mode
-     truth_est_mode = MP_GLOBAL.truth_est_mode
-     enable_plot = MP_GLOBAL.enable_plot
-     enable_wav = MP_GLOBAL.enable_wav
-     wer_method = MP_GLOBAL.wer_method
-     whisper_model = MP_GLOBAL.whisper_model
-
-     # 1) Read predict data, var predict with shape [BatchSize,Classes] or [BatchSize,Tsteps,Classes]
-     output_name = join(predict_location, mixdb.mixture(mixid).name)
-     predict = None
-     if truth_est_mode:
-         # in truth estimation mode we use the truth in place of prediction to see metrics with perfect input
-         # don't bother to read prediction, and predict var will get assigned to truth later
-         # mark outputs with tru suffix, i.e. 0000_truest_*
-         base_name = splitext(output_name)[0] + '_truest'
-     else:
-         base_name, ext = splitext(output_name)  # base_name used later
-         if not predwav_mode:
-             try:
-                 with h5py.File(output_name, 'r') as f:
-                     predict = np.array(f['predict'])
-             except Exception as e:
-                 raise SonusAIError(f'Error reading {output_name}: {e}')
-             # reshape to always be [frames,classes] where ndim==3 case frames = batch * tsteps
-             if predict.ndim > 2:  # TBD generalize to somehow detect if timestep dim exists, some cases > 2 don't have
-                 # logger.debug(f'Prediction reshape from {predict.shape} to remove timestep dimension.')
-                 predict, _ = reshape_outputs(predict=predict, truth=None, timesteps=predict.shape[1])
-         else:
-             base_name, ext = splitext(output_name)
-             prfname = join(base_name + '.wav')
-             audio = read_audio(prfname)
-             predict = forward_transform(audio, mixdb.ft_config)
-             if mixdb.feature[0:1] == 'h':
-                 predict = power_compress(predict)
-             predict = stack_complex(predict)
-
-     # 2) Collect true target, noise, mixture data, trim to predict size if needed
-     tmp = mixdb.mixture_targets(mixid)  # targets is list of pre-IR and pre-specaugment targets
-     target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
-     target = tmp[0]
-     mixture = mixdb.mixture_mixture(mixid)  # note: gives full reverberated/distorted target, but no specaugment
-     # noise_wodist = mixdb.mixture_noise(mixid)  # noise without specaugment and distortion
-     # noise_wodist_f = mixdb.mixture_noise_f(mixid, noise=noise_wodist)
-     noise = mixture - target  # has time-domain distortion (ir,etc.) but does not have specaugment
-     # noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
-     segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)  # note: uses pre-IR, pre-specaug audio
-     mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
-     noise_f = mixture_f - target_f  # true noise in freq domain includes specaugment and time-domain ir,distortions
-     # segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
-     segsnr_f[segsnr_f == inf] = 7.944e8  # 99db
-     segsnr_f[segsnr_f == -inf] = 1.258e-10  # -99db
-     # need to use inv-tf to match #samples & latency shift properties of predict inv tf
-     targetfi = inverse_transform(target_f, mixdb.it_config)
-     noisefi = inverse_transform(noise_f, mixdb.it_config)
-     # mixturefi = mixdb.inverse_transform(mixture_f)
-
-     # gen feature, truth - note feature only used for plots
-     # TBD parse truth_f for different formats and also multi-truth
-     feature, truth_f = mixdb.mixture_ft(mixid, mixture=mixture)
-     truth_type = mixdb.target_file(mixdb.mixture(mixid).targets[0].file_id).truth_settings[0].function
-     if truth_type == 'target_mixture_f':
-         half = truth_f.shape[-1] // 2
-         truth_f = truth_f[..., :half]  # extract target_f only
-
-     if not truth_est_mode:
-         if predict.shape[0] < target_f.shape[0]:  # target_f, truth_f, mixture_f, etc. same size
-             trimf = target_f.shape[0] - predict.shape[0]
-             logger.debug(f'Warning: prediction frames less than mixture, trimming {trimf} frames from all truth.')
-             target_f = target_f[0:-trimf, :]
-             targetfi, _ = inverse_transform(target_f, mixdb.it_config)
-             trimt = target.shape[0] - targetfi.shape[0]
-             target = target[0:-trimt]
-             noise_f = noise_f[0:-trimf, :]
-             noise = noise[0:-trimt]
-             mixture_f = mixture_f[0:-trimf, :]
-             mixture = mixture[0:-trimt]
-             truth_f = truth_f[0:-trimf, :]
-         elif predict.shape[0] > target_f.shape[0]:
-             raise SonusAIError(
-                 f'Error: prediction has more frames than true mixture {predict.shape[0]} vs {truth_f.shape[0]}')
-
-     # 3) Extraction - format proper complex and wav estimates and truth (unstack, uncompress, inv tf, etc.)
-     if truth_est_mode:
-         predict = truth_f  # substitute truth for the prediction (for test/debug)
-         predict_complex = unstack_complex(predict)  # unstack
-         # if feat has compressed mag and truth does not, compress it
-         if mixdb.feature[0:1] == 'h' and mixdb.target_file(1).truth_settings[0].function[0:10] != 'targetcmpr':
-             predict_complex = power_compress(predict_complex)  # from uncompressed truth
-     else:
-         predict_complex = unstack_complex(predict)
-
-     truth_f_complex = unstack_complex(truth_f)
-     if mixdb.feature[0:1] == 'h':  # 'hn' or 'ha' or 'hd', etc.: # if feat has compressed mag
-         # estimate noise in uncompressed-mag domain
-         noise_est_complex = mixture_f - power_uncompress(predict_complex)
-         predict_complex = power_uncompress(predict_complex)  # uncompress if truth is compressed
-     else:  # cn, c8, ..
-         noise_est_complex = mixture_f - predict_complex
-
-     target_est_wav = inverse_transform(predict_complex, mixdb.it_config)
-     noise_est_wav = inverse_transform(noise_est_complex, mixdb.it_config)
-
-     # 4) Metrics
-     # Target/Speech logerr - PSD estimation accuracy symmetric mean log-spectral distortion
-     lerr_tg, lerr_tg_bin, lerr_tg_frame = log_error(reference=truth_f_complex, hypothesis=predict_complex)
-     # Noise logerr - PSD estimation accuracy
-     lerr_n, lerr_n_bin, lerr_n_frame = log_error(reference=noise_f, hypothesis=noise_est_complex)
-     # PCM loss metric
-     ytrue_f = np.concatenate((truth_f_complex[:, np.newaxis, :], noise_f[:, np.newaxis, :]), axis=1)
-     ypred_f = np.concatenate((predict_complex[:, np.newaxis, :], noise_est_complex[:, np.newaxis, :]), axis=1)
-     pcm, pcm_bin, pcm_frame = calc_pcm(hypothesis=ypred_f, reference=ytrue_f, with_log=True)
-
-     # Phase distance
-     phd, phd_bin, phd_frame = phase_distance(hypothesis=predict_complex, reference=truth_f_complex)
-
-     # Noise td logerr
-     # lerr_nt, lerr_nt_bin, lerr_nt_frame = log_error(noisefi, noise_truth_est_audio)
-
-     # # SA-SDR (time-domain source-aggregated SDR)
-     ytrue = np.concatenate((targetfi[:, np.newaxis], noisefi[:, np.newaxis]), axis=1)
-     ypred = np.concatenate((target_est_wav[:, np.newaxis], noise_est_wav[:, np.newaxis]), axis=1)
-     # # note: w/o scale is more pessimistic number
-     # sa_sdr, _ = calc_sa_sdr(hypothesis=ypred, reference=ytrue)
-     target_stoi = stoi(targetfi, target_est_wav, 16000, extended=False)
-
-     wsdr, wsdr_cc, wsdr_cw = calc_wsdr(hypothesis=ypred, reference=ytrue, with_log=True)
-     # logger.debug(f'wsdr weight sum for mixid {mixid} = {np.sum(wsdr_cw)}.')
-     # logger.debug(f'wsdr cweights = {wsdr_cw}.')
-     # logger.debug(f'wsdr ccoefs for mixid {mixid} = {wsdr_cc}.')
-
-     # Speech intelligibility measure - PESQ
-     if int(mixdb.mixture(mixid).snr) > -99:
-         # len = target_est_wav.shape[0]
-         pesq_speech, csig_tg, cbak_tg, covl_tg, sgsnr_tg = calc_speech_metrics(target_est_wav, targetfi)
-         pesq_mixture, csig_mx, cbak_mx, covl_mx, sgsnr_mx = calc_speech_metrics(mixture, target)
-         # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
-         # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
-         # pesq improvement
-         pesq_impr = pesq_speech - pesq_mixture
-         # pesq improvement %
-         pesq_impr_pc = pesq_impr / (pesq_mixture + np.finfo(np.float32).eps) * 100
-     else:
-         pesq_speech = 0
-         pesq_mixture = 0
-         pesq_impr_pc = np.float32(0)
-         csig_mx = 0
-         csig_tg = 0
-         cbak_mx = 0
-         cbak_tg = 0
-         covl_mx = 0
-         covl_tg = 0
-
-     # Calc WER
-     asr_tt = ''
-     asr_mx = ''
-     asr_tge = ''
-     if wer_method == 'none' or mixdb.mixture(mixid).snr == -99:  # noise only, ignore/reset target asr
-         wer_mx = float('nan')
-         wer_tge = float('nan')
-         wer_pi = float('nan')
-     else:
-         if MP_GLOBAL.mixdb.asr_manifests:
-             asr_tt = MP_GLOBAL.mixdb.mixture_asr_data(mixid)[0]  # ignore mixup
-         else:
-             asr_tt = calc_asr(target, engine=wer_method, whisper_model_name=whisper_model).text  # target truth
-
-         if asr_tt:
-             asr_mx = calc_asr(mixture, engine=wer_method, whisper_model=whisper_model).text
-             asr_tge = calc_asr(target_est_wav, engine=wer_method, whisper_model=whisper_model).text
-
-             wer_mx = calc_wer(asr_mx, asr_tt).wer * 100  # mixture wer
-             wer_tge = calc_wer(asr_tge, asr_tt).wer * 100  # target estimate wer
-             if wer_mx == 0.0:
-                 if wer_tge == 0.0:
-                     wer_pi = 0.0
-                 else:
-                     wer_pi = -999.0
-             else:
-                 wer_pi = 100 * (wer_mx - wer_tge) / wer_mx
-         else:
-             print(f'Warning: mixid {mixid} asr truth is empty, setting to 0% wer')
-             wer_mx = float(0)
-             wer_tge = float(0)
-             wer_pi = float(0)
-
-     # 5) Save per mixture metric results
-     # Single row in table of scalar metrics per mixture
-     mtable1_col = ['MXSNR', 'MXPESQ', 'PESQ', 'PESQi%', 'MXWER', 'WER', 'WERi%', 'WSDR', 'STOI',
-                    'PCM', 'SPLERR', 'NLERR', 'PD', 'MXCSIG', 'CSIG', 'MXCBAK', 'CBAK', 'MXCOVL', 'COVL',
-                    'SPFILE', 'NFILE']
-     ti = mixdb.mixture(mixid).targets[0].file_id
-     ni = mixdb.mixture(mixid).noise.file_id
-     metr1 = [mixdb.mixture(mixid).snr, pesq_mixture, pesq_speech, pesq_impr_pc, wer_mx, wer_tge, wer_pi, wsdr,
-              target_stoi, pcm, lerr_tg, lerr_n, phd, csig_mx, csig_tg, cbak_mx, cbak_tg, covl_mx, covl_tg,
-              basename(mixdb.target_file(ti).name), basename(mixdb.noise_file(ni).name)]
-     mtab1 = pd.DataFrame([metr1], columns=mtable1_col, index=[mixid])
-
-     # Stats of per frame estimation metrics
-     metr2 = pd.DataFrame({'SSNR': segsnr_f,
-                           'PCM': pcm_frame,
-                           'SLERR': lerr_tg_frame,
-                           'NLERR': lerr_n_frame,
-                           'SPD': phd_frame})
-     metr2 = metr2.describe()  # Use pandas stat function
-     # Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
-     # metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
-     metr2.iloc[1:, 0] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
-     # create a single row in multi-column header
-     new_labels = pd.MultiIndex.from_product([metr2.columns,
-                                              ['Avg', 'Min', 'Med', 'Max', 'Std']],
-                                             names=['Metric', 'Stat'])
-     dat1row = metr2.loc[['mean', 'min', '50%', 'max', 'std'], :].T.stack().to_numpy().reshape((1, -1))
-     mtab2 = pd.DataFrame(dat1row,
-                          index=[mixid],
-                          columns=new_labels)
-     mtab2.insert(0, 'MXSNR', mixdb.mixture(mixid).snr, False)  # add MXSNR as the first metric column
-
-     all_metrics_table_1 = mtab1  # return to be collected by process
-     all_metrics_table_2 = mtab2  # return to be collected by process
-
-     metric_name = base_name + '_metric_spenh.txt'
-     with open(metric_name, 'w') as f:
-         print('Speech enhancement metrics:', file=f)
-         print(mtab1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
-         print('', file=f)
-         print(f'Extraction statistics over {mixture_f.shape[0]} frames:', file=f)
-         print(metr2.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
-         print('', file=f)
-         print(f'Target path: {mixdb.target_file(ti).name}', file=f)
-         print(f'Noise path: {mixdb.noise_file(ni).name}', file=f)
-         if wer_method != 'none':
-             print(f'WER method: {wer_method} and whisper model (if used): {whisper_model}', file=f)
-             if mixdb.asr_manifests:
-                 print(f'ASR truth from manifest: {asr_tt}', file=f)
-             else:
-                 print(f'ASR truth from wer method: {asr_tt}', file=f)
-             print(f'ASR result for mixture: {asr_mx}', file=f)
-             print(f'ASR result for prediction: {asr_tge}', file=f)
-
-         print(f'Augmentations: {mixdb.mixture(mixid)}', file=f)
-
-     # 7) write wav files
-     if enable_wav:
-         write_wav(name=base_name + '_mixture.wav', audio=float_to_int16(mixture))
-         write_wav(name=base_name + '_target.wav', audio=float_to_int16(target))
-         # write_wav(name=base_name + '_targetfi.wav', audio=float_to_int16(targetfi))
-         write_wav(name=base_name + '_noise.wav', audio=float_to_int16(noise))
-         write_wav(name=base_name + '_target_est.wav', audio=float_to_int16(target_est_wav))
-         write_wav(name=base_name + '_noise_est.wav', audio=float_to_int16(noise_est_wav))
-
-         # debug code to test for perfect reconstruction of the extraction method
-         # note both 75% olsa-hanns and 50% olsa-hann modes checked to have perfect reconstruction
-         # target_r = mixdb.inverse_transform(target_f)
-         # noise_r = mixdb.inverse_transform(noise_f)
-         # _write_wav(name=base_name + '_target_r.wav', audio=float_to_int16(target_r))
-         # _write_wav(name=base_name + '_noise_r.wav', audio=float_to_int16(noise_r))  # chk perfect rec
-
-     # 8) Write out plot file
-     if enable_plot:
-         from matplotlib.backends.backend_pdf import PdfPages
-         plot_fname = base_name + '_metric_spenh.pdf'
-
-         # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
-         # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
-         # Reshape to get frames*decimated_stride, num_bands
-         step = int(mixdb.feature_samples / mixdb.feature_step_samples)
-         if feature.ndim != 3:
-             raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, num_bands')
-
-         # for feature cn*00n**
-         feat_sgram = unstack_complex(feature)
-         feat_sgram = 20 * np.log10(abs(feat_sgram) + np.finfo(np.float32).eps)
-         feat_sgram = feat_sgram[:, -step:, :]  # decimate, Fx1xB
-         feat_sgram = np.reshape(feat_sgram, (feat_sgram.shape[0] * feat_sgram.shape[1], feat_sgram.shape[2]))
-
-         with PdfPages(plot_fname) as pdf:
-             # page1 we always have a mixture and prediction, target optional if truth provided
-             tfunc_name = mixdb.target_file(1).truth_settings[0].function  # first target, assumes all have same
-             if tfunc_name == 'mapped_snr_f':
-                 # leave as unmapped snr
-                 predplot = predict
-                 tfunc_name = mixdb.target_file(1).truth_settings[0].function
-             elif tfunc_name in ('target_f', 'target_mixture_f'):
-                 predplot = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
-             else:
-                 # use dB scale
-                 predplot = 10 * np.log10(predict + np.finfo(np.float32).eps)
-                 tfunc_name = tfunc_name + ' (db)'
-
-             mixspec = 20 * np.log10(abs(mixture_f) + np.finfo(np.float32).eps)
-             pdf.savefig(plot_mixpred(mixture=mixture,
-                                      mixture_f=mixspec,
-                                      target=target,
-                                      feature=feat_sgram,
-                                      predict=predplot,
-                                      tp_title=tfunc_name))
-
-             # ----- page 2, plot unmapped predict, opt truth reconstructed and line plots of mean-over-f
-             # pdf.savefig(plot_pdb_predtruth(predict=pred_snr_f, tp_title='predict snr_f (db)'))
-
-             # page 3 speech extraction
-             tg_spec = 20 * np.log10(abs(target_f) + np.finfo(np.float32).eps)
-             tg_est_spec = 20 * np.log10(abs(predict_complex) + np.finfo(np.float32).eps)
-             # n_spec = np.reshape(n_spec,(n_spec.shape[0] * n_spec.shape[1], n_spec.shape[2]))
-             pdf.savefig(plot_epredtruth(predict=tg_est_spec,
-                                         predict_wav=target_est_wav,
-                                         truth_f=tg_spec,
-                                         truth_wav=targetfi,
-                                         metric=np.vstack((lerr_tg_frame, phd_frame)).T,
-                                         tp_title='speech estimate'))
-
-             # page 4 noise extraction
-             n_spec = 20 * np.log10(abs(noise_f) + np.finfo(np.float32).eps)
-             n_est_spec = 20 * np.log10(abs(noise_est_complex) + np.finfo(np.float32).eps)
-             pdf.savefig(plot_epredtruth(predict=n_est_spec,
-                                         predict_wav=noise_est_wav,
-                                         truth_f=n_spec,
-                                         truth_wav=noisefi,
-                                         metric=lerr_n_frame,
-                                         tp_title='noise estimate'))
-
-             # Plot error waveforms
-             # tg_err_wav = targetfi - target_est_wav
-             # tg_err_spec = 20*np.log10(np.abs(target_f - predict_complex))
-
-     plt.close('all')
-
-     return all_metrics_table_1, all_metrics_table_2
-
-
- def main():
-     from docopt import docopt
-
-     import sonusai
-     from sonusai.utils import trim_docstring
-
-     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
-
-     verbose = args['--verbose']
-     mixids = args['--mixid']
-     predict_location = args['PLOC']
-     wer_method = args['--wer-method'].lower()
-     truth_est_mode = args['--truth-est-mode']
-     enable_plot = args['--plot']
-     enable_wav = args['--wav']
-     enable_summary = args['--summary']
-     truth_location = args['TLOC']
-     whisper_model = args['--whisper-model'].lower()
-
-     import glob
-     from os.path import basename
-     from os.path import isdir
-     from os.path import join
-     from os.path import split
-
-     from tqdm import tqdm
-
-     from sonusai import create_file_handler
-     from sonusai import initial_log_messages
-     from sonusai import logger
-     from sonusai import update_console_handler
-     from sonusai.mixture import DEFAULT_NOISE
-     from sonusai.mixture import MixtureDatabase
-     from sonusai.mixture import read_audio
-     from sonusai.utils import calc_asr
-     from sonusai.utils import pp_tqdm_imap
-
-     # Check prediction subdirectory
-     if not isdir(predict_location):
-         print(f'The specified predict location {predict_location} is not a valid subdirectory path, exiting ...')
-
-     # allpfiles = listdir(predict_location)
-     allpfiles = glob.glob(predict_location + "/*.h5")
-     predict_logfile = glob.glob(predict_location + "/*predict.log")
-     predwav_mode = False
-     if len(allpfiles) <= 0 and not truth_est_mode:
-         allpfiles = glob.glob(predict_location + "/*.wav")  # check for wav files
-         if len(allpfiles) <= 0:
-             print(f'Subdirectory {predict_location} has no .h5 or .wav files, exiting ...')
-         else:
-             logger.info(f'Found {len(allpfiles)} prediction .wav files.')
-             predwav_mode = True
-     else:
-         logger.info(f'Found {len(allpfiles)} prediction .h5 files.')
-
-     if len(predict_logfile) == 0:
-         logger.info(f'Warning, predict location {predict_location} has no prediction log files.')
-     else:
-         logger.info(f'Found predict log {basename(predict_logfile[0])} in predict location.')
-
-     # Setup logging file
-     create_file_handler(join(predict_location, 'calc_metric_spenh.log'))
-     update_console_handler(verbose)
-     initial_log_messages('calc_metric_spenh')
-
-     mixdb = MixtureDatabase(truth_location)
-     mixids = mixdb.mixids_to_list(mixids)
-     logger.info(
-         f'Found mixdb of {mixdb.num_mixtures} total mixtures, with {mixdb.num_classes} classes in {truth_location}')
-     logger.info(f'Only running specified subset of {len(mixids)} mixtures')
-
-     enable_asr_warmup = False
-     if wer_method == 'none':
-         fnb = 'metric_spenh_'
-     elif wer_method == 'google':
-         fnb = 'metric_spenh_ggl_'
-         logger.info(f'WER enabled with method {wer_method}')
-         enable_asr_warmup = True
-     elif wer_method == 'deepgram':
-         fnb = 'metric_spenh_dgram_'
-         logger.info(f'WER enabled with method {wer_method}')
-         enable_asr_warmup = True
-     elif wer_method == 'aixplain_whisper':
-         fnb = 'metric_spenh_whspx_' + whisper_model + '_'
-         logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
-         enable_asr_warmup = True
-     elif wer_method == 'whisper':
-         fnb = 'metric_spenh_whspl_' + whisper_model + '_'
-         logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
-         enable_asr_warmup = True
-     elif wer_method == 'aaware_whisper':
-         fnb = 'metric_spenh_whspaaw_' + whisper_model + '_'
-         logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
-         enable_asr_warmup = True
-     elif wer_method == 'fastwhisper':
-         fnb = 'metric_spenh_fwhsp_' + whisper_model + '_'
-         logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
-         enable_asr_warmup = True
-     else:
-         logger.error(f'Unrecognized WER method: {wer_method}')
-         return
-
-     if enable_asr_warmup:
-         DEFAULT_SPEECH = split(DEFAULT_NOISE)[0] + '/speech_ma01_01.wav'
-         audio = read_audio(DEFAULT_SPEECH)
-         logger.info(f'Warming up asr method, note for cloud service this could take up to a few min ...')
-         asr_chk = calc_asr(audio, engine=wer_method, whisper_model_name=whisper_model)
-         logger.info(f'Warmup completed, results {asr_chk}')
-
-     MP_GLOBAL.mixdb = mixdb
-     MP_GLOBAL.predict_location = predict_location
-     MP_GLOBAL.predwav_mode = predwav_mode
-     MP_GLOBAL.truth_est_mode = truth_est_mode
-     MP_GLOBAL.enable_plot = enable_plot
-     MP_GLOBAL.enable_wav = enable_wav
-     MP_GLOBAL.wer_method = wer_method
-     MP_GLOBAL.whisper_model = whisper_model
-
-     # Individual mixtures use pandas print, set precision to 2 decimal places
-     # pd.set_option('float_format', '{:.2f}'.format)
-     progress = tqdm(total=len(mixids), desc='calc_metric_spenh')
-     all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, num_cpus=8)
-     progress.close()
-
-     all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
-     all_metrics_table_2 = pd.concat([item[1] for item in all_metrics_tables])
-
-     if not enable_summary:
-         return
-
-     # 9) Done with mixtures, write out summary metrics
-     # Calculate SNR summary avg of each non-random snr
-     all_mtab1_sorted = all_metrics_table_1.sort_values(by=['MXSNR', 'SPFILE'])
-     all_mtab2_sorted = all_metrics_table_2.sort_values(by=['MXSNR'])
-     mtab_snr_summary = None
-     mtab_snr_summary_em = None
-     for snri in range(0, len(mixdb.snrs)):
-         tmp = all_mtab1_sorted.query('MXSNR==' + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
-         # avoid nan when subset of mixids specified
-         if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
-             mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])
-
-         tmp = all_mtab2_sorted[all_mtab2_sorted['MXSNR'] == mixdb.snrs[snri]].mean(numeric_only=True).to_frame().T
-         # avoid nan when subset of mixids specified (mxsnr will be nan if no data):
-         if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
-             mtab_snr_summary_em = pd.concat([mtab_snr_summary_em, tmp])
-
-     mtab_snr_summary = mtab_snr_summary.sort_values(by=['MXSNR'], ascending=False)
-     # Correct percentages in snr summary table
-     mtab_snr_summary['PESQi%'] = 100 * (mtab_snr_summary['PESQ'] - mtab_snr_summary['MXPESQ']) / np.maximum(
-         mtab_snr_summary['MXPESQ'], 0.01)
-     for i in range(len(mtab_snr_summary)):
-         if mtab_snr_summary['MXWER'].iloc[i] == 0.0:
-             if mtab_snr_summary['WER'].iloc[i] == 0.0:
-                 mtab_snr_summary['WERi%'].iloc[i] = 0.0
-             else:
-                 mtab_snr_summary['WERi%'].iloc[i] = -999.0
-         else:
-             if ~np.isnan(mtab_snr_summary['WER'].iloc[i]) and ~np.isnan(mtab_snr_summary['MXWER'].iloc[i]):
-                 # update WERi% in 6th col
-                 mtab_snr_summary.iloc[i, 6] = 100 * (mtab_snr_summary['MXWER'].iloc[i] -
-                                                      mtab_snr_summary['WER'].iloc[i]) / \
-                                               mtab_snr_summary['MXWER'].iloc[i]
-
-     # Calculate avg metrics over all mixtures except -99
-     all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
-     all_nom99_mean = all_mtab1_sorted_nom99.mean(numeric_only=True)
-
-     # correct the percentage averages with a direct calculation (PESQ% and WER%):
-     # ser.iloc[pos]
-     all_nom99_mean['PESQi%'] = (100 * (all_nom99_mean['PESQ'] - all_nom99_mean['MXPESQ'])
-                                 / np.maximum(all_nom99_mean['MXPESQ'], 0.01))  # pesq%
-     # all_nom99_mean[3] = 100 * (all_nom99_mean[2] - all_nom99_mean[1]) / np.maximum(all_nom99_mean[1], 0.01)  # pesq%
-     if all_nom99_mean['MXWER'] == 0.0:
-         if all_nom99_mean['WER'] == 0.0:
-             all_nom99_mean['WERi%'] = 0.0
-         else:
-             all_nom99_mean['WERi%'] = -999.0
-     else:  # wer%
-         all_nom99_mean['WERi%'] = 100 * (all_nom99_mean['MXWER'] - all_nom99_mean['WER']) / all_nom99_mean['MXWER']
-
-     num_mix = len(mixids)
-     if num_mix > 1:
-         # Print pandas data to files using precision to 2 decimals
-         # pd.set_option('float_format', '{:.2f}'.format)
-         csp = 0
-
-         if not truth_est_mode:
-             ofname = join(predict_location, fnb + 'summary.txt')
-         else:
-             ofname = join(predict_location, fnb + 'summary_truest.txt')
-
-         with open(ofname, 'w') as f:
-             print(f'WER enabled with method {wer_method}, whisper model, if used: {whisper_model}', file=f)
-             print(f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:',
-                   file=f)
-             print(all_nom99_mean.to_frame().T.round(2).to_string(float_format=lambda x: "{:.2f}".format(x),
-                                                                  index=False), file=f)
-             print(f'\nSpeech enhancement metrics avg over each SNR:', file=f)
-             print(mtab_snr_summary.round(2).to_string(float_format=lambda x: "{:.2f}".format(x), index=False), file=f)
-             print('', file=f)
-             print(f'Extraction statistics stats avg over each SNR:', file=f)
-             # with pd.option_context('display.max_colwidth', 9):
-             # with pd.set_option('float_format', '{:.1f}'.format):
-             print(mtab_snr_summary_em.round(1).to_string(float_format=lambda x: "{:.1f}".format(x), index=False),
-                   file=f)
-             print('', file=f)
-             # pd.set_option('float_format', '{:.2f}'.format)
-
-             print(f'Speech enhancement metrics stats over all {num_mix} mixtures:', file=f)
-             print(all_metrics_table_1.describe().round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
-             print('', file=f)
-             print(f'Extraction statistics stats over all {num_mix} mixtures:', file=f)
-             print(all_metrics_table_2.describe().round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)
-             print('', file=f)
-
-             print('Speech enhancement metrics all-mixtures list:', file=f)
-             # print(all_metrics_table_1.head().style.format(precision=2), file=f)
-             print(all_metrics_table_1.round(2).to_string(float_format=lambda x: "{:.2f}".format(x)), file=f)
-             print('', file=f)
-             print('Extraction statistics all-mixtures list:', file=f)
-             print(all_metrics_table_2.round(2).to_string(float_format=lambda x: "{:.1f}".format(x)), file=f)
-
-         # Write summary to .csv file
-         if not truth_est_mode:
-             csv_name = join(predict_location, fnb + 'summary.csv')
-         else:
-             csv_name = join(predict_location, fnb + 'summary_truest.csv')
-         header_args = {
-             'mode': 'a',
-             'encoding': 'utf-8',
-             'index': False,
-             'header': False,
-         }
-         table_args = {
-             'mode': 'a',
-             'encoding': 'utf-8',
-         }
-         label = f'Speech enhancement metrics avg over all {len(all_mtab1_sorted_nom99)} non -99 SNR mixtures:'
-         pd.DataFrame([label]).to_csv(csv_name, header=False, index=False)  # open as write
-         all_nom99_mean.to_frame().T.round(2).to_csv(csv_name, index=False, **table_args)
-         pd.DataFrame(['']).to_csv(csv_name, **header_args)
-         pd.DataFrame([f'Speech enhancement metrics avg over each SNR:']).to_csv(csv_name, **header_args)
-         mtab_snr_summary.round(2).to_csv(csv_name, index=False, **table_args)
-         pd.DataFrame(['']).to_csv(csv_name, **header_args)
-         pd.DataFrame([f'Extraction statistics stats avg over each SNR:']).to_csv(csv_name, **header_args)
-         mtab_snr_summary_em.round(2).to_csv(csv_name, index=False, **table_args)
-         pd.DataFrame(['']).to_csv(csv_name, **header_args)
-         pd.DataFrame(['']).to_csv(csv_name, **header_args)
-         label = f'Speech enhancement metrics stats over {num_mix} mixtures:'
-         pd.DataFrame([label]).to_csv(csv_name, **header_args)
-         all_metrics_table_1.describe().round(2).to_csv(csv_name, **table_args)
-         pd.DataFrame(['']).to_csv(csv_name, **header_args)
-         label = f'Extraction statistics stats over {num_mix} mixtures:'
-         pd.DataFrame([label]).to_csv(csv_name, **header_args)
-         all_metrics_table_2.describe().round(2).to_csv(csv_name, **table_args)
-         label = f'WER enabled with method {wer_method}, whisper model, if used: {whisper_model}'
-         pd.DataFrame([label]).to_csv(csv_name, **header_args)
-
-         if not truth_est_mode:
-             csv_name = join(predict_location, fnb + 'list.csv')
-         else:
-             csv_name = join(predict_location, fnb + 'list_truest.csv')
-         pd.DataFrame(['Speech enhancement metrics list:']).to_csv(csv_name, header=False, index=False)  # open as write
-         all_metrics_table_1.round(2).to_csv(csv_name, **table_args)
-
-         if not truth_est_mode:
-             csv_name = join(predict_location, fnb + 'estats_list.csv')
-         else:
-             csv_name = join(predict_location, fnb + 'estats_list_truest.csv')
-         pd.DataFrame(['Extraction statistics list:']).to_csv(csv_name, header=False, index=False)  # open as write
-         all_metrics_table_2.round(2).to_csv(csv_name, **table_args)
-
-
- if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         exit()
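
The composite quality measures in the deleted calc_speech_metrics() above are fixed linear regressions over PESQ, LLR, WSS, and segmental SNR, each clipped to the MOS range [1, 5]. A minimal standalone sketch of just that final step, using the same coefficients as the removed code (the helper name composite_measures is illustrative only, not a sonusai API):

import numpy as np

def composite_measures(pesq_mos: float, llr_mean: float,
                       wss_dist: float, seg_snr: float) -> tuple[float, float, float]:
    # Hypothetical helper reproducing the composite step of the deleted
    # calc_speech_metrics(): each measure is a fixed linear combination of the
    # base metrics, clipped to the MOS range [1, 5].
    csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist, 1, 5))
    cbak = float(np.clip(1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
    covl = float(np.clip(1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
    return csig, cbak, covl

For example, composite_measures(pesq_mos=2.5, llr_mean=0.6, wss_dist=40.0, seg_snr=3.0) returns CSIG, CBAK, and COVL on the same 1-to-5 scale the per-mixture tables report.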