sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118):
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -2,12 +2,13 @@ import numpy as np
2
2
 
3
3
  from sonusai.mixture.constants import SAMPLE_RATE
4
4
  from sonusai.mixture.datatypes import SpeechMetrics
5
+
5
6
  from .calc_pesq import calc_pesq
6
7
 
7
8
 
8
9
  def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE) -> SpeechMetrics:
9
10
  """Calculate speech metrics pesq, c_sig, c_bak, and c_ovl.
10
-
11
+
11
12
  These are all related and thus included in one function. Reference: matlab script "compute_metrics.m".
12
13
 
13
14
  :param hypothesis: estimated audio
@@ -22,7 +23,7 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
22
23
 
23
24
  # Value from CMGAN reference implementation
24
25
  alpha = 0.95
25
- wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
26
+ wss_dist = np.mean(wss_dist_vec[0 : round(np.size(wss_dist_vec) * alpha)])
26
27
 
27
28
  # Log likelihood ratio measure
28
29
  llr_dist = _calc_log_likelihood_ratio_measure(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
@@ -45,16 +46,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
45
46
  return SpeechMetrics(_pesq, csig, cbak, covl)
46
47
 
47
48
 
48
- def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
49
- reference: np.ndarray,
50
- sample_rate: int = SAMPLE_RATE) -> np.ndarray:
49
+ def _calc_weighted_spectral_slope_measure(
50
+ hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
51
+ ) -> np.ndarray:
51
52
  from scipy.fftpack import fft
52
53
 
53
54
  # The lengths of the reference and hypothesis must be the same.
54
55
  reference_length = np.size(reference)
55
56
  hypothesis_length = np.size(hypothesis)
56
57
  if reference_length != hypothesis_length:
57
- raise ValueError('Hypothesis and reference must be the same length.')
58
+ raise ValueError("Hypothesis and reference must be the same length.")
58
59
 
59
60
  # Window length in samples
60
61
  win_length = int(np.round(30 * sample_rate / 1000))
@@ -72,14 +73,64 @@ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
72
73
  k_loc_max = 1.0
73
74
 
74
75
  # Critical band filter definitions (center frequency and bandwidths in Hz)
75
- cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
76
- 540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
77
- 1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
78
- 2701.97, 2978.04, 3276.17, 3597.63])
79
- bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
80
- 77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
81
- 153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
82
- 276.072, 298.126, 321.465, 346.136])
76
+ cent_freq = np.array(
77
+ [
78
+ 50.0000,
79
+ 120.000,
80
+ 190.000,
81
+ 260.000,
82
+ 330.000,
83
+ 400.000,
84
+ 470.000,
85
+ 540.000,
86
+ 617.372,
87
+ 703.378,
88
+ 798.717,
89
+ 904.128,
90
+ 1020.38,
91
+ 1148.30,
92
+ 1288.72,
93
+ 1442.54,
94
+ 1610.70,
95
+ 1794.16,
96
+ 1993.93,
97
+ 2211.08,
98
+ 2446.71,
99
+ 2701.97,
100
+ 2978.04,
101
+ 3276.17,
102
+ 3597.63,
103
+ ]
104
+ )
105
+ bandwidth = np.array(
106
+ [
107
+ 70.0000,
108
+ 70.0000,
109
+ 70.0000,
110
+ 70.0000,
111
+ 70.0000,
112
+ 70.0000,
113
+ 70.0000,
114
+ 77.3724,
115
+ 86.0056,
116
+ 95.3398,
117
+ 105.411,
118
+ 116.256,
119
+ 127.914,
120
+ 140.423,
121
+ 153.823,
122
+ 168.154,
123
+ 183.457,
124
+ 199.776,
125
+ 217.153,
126
+ 235.631,
127
+ 255.255,
128
+ 276.072,
129
+ 298.126,
130
+ 321.465,
131
+ 346.136,
132
+ ]
133
+ )
83
134
 
84
135
  # Minimum critical bandwidth
85
136
  bw_min = bandwidth[0]
@@ -109,8 +160,8 @@ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
109
160
  distortion = np.empty(num_frames)
110
161
  for frame_count in range(num_frames):
111
162
  # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
112
- reference_frame = reference[start: start + win_length] / 32768
113
- hypothesis_frame = hypothesis[start: start + win_length] / 32768
163
+ reference_frame = reference[start : start + win_length] / 32768
164
+ hypothesis_frame = hypothesis[start : start + win_length] / 32768
114
165
  reference_frame = np.multiply(reference_frame, window)
115
166
  hypothesis_frame = np.multiply(hypothesis_frame, window)
116
167
 
@@ -122,12 +173,12 @@ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
122
173
  reference_energy = np.matmul(crit_filter, reference_spec[0:n_fft_by_2])
123
174
  hypothesis_energy = np.matmul(crit_filter, hypothesis_spec[0:n_fft_by_2])
124
175
 
125
- reference_energy = 10 * np.log10(np.maximum(reference_energy, 1E-10))
126
- hypothesis_energy = 10 * np.log10(np.maximum(hypothesis_energy, 1E-10))
176
+ reference_energy = 10 * np.log10(np.maximum(reference_energy, 1e-10))
177
+ hypothesis_energy = 10 * np.log10(np.maximum(hypothesis_energy, 1e-10))
127
178
 
128
179
  # (4) Compute spectral slope (dB[i+1]-dB[i])
129
- reference_slope = reference_energy[1:num_crit] - reference_energy[0: num_crit - 1]
130
- hypothesis_slope = hypothesis_energy[1:num_crit] - hypothesis_energy[0: num_crit - 1]
180
+ reference_slope = reference_energy[1:num_crit] - reference_energy[0 : num_crit - 1]
181
+ hypothesis_slope = hypothesis_energy[1:num_crit] - hypothesis_energy[0 : num_crit - 1]
131
182
 
132
183
  # (5) Find the nearest peak locations in the spectra to each critical band.
133
184
  # If the slope is negative, we search to the left. If positive, we search to the right.
@@ -173,17 +224,22 @@ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
173
224
  # and less emphasis on slope differences in spectral valleys.
174
225
  # This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
175
226
 
176
- w_max_reference = np.divide(k_max, k_max + db_max_reference - reference_energy[0: num_crit - 1])
177
- w_loc_max_reference = np.divide(k_loc_max, k_loc_max + reference_loc_peak - reference_energy[0: num_crit - 1])
227
+ w_max_reference = np.divide(k_max, k_max + db_max_reference - reference_energy[0 : num_crit - 1])
228
+ w_loc_max_reference = np.divide(
229
+ k_loc_max,
230
+ k_loc_max + reference_loc_peak - reference_energy[0 : num_crit - 1],
231
+ )
178
232
  w_reference = np.multiply(w_max_reference, w_loc_max_reference)
179
233
 
180
- w_max_hypothesis = np.divide(k_max, k_max + db_max_hypothesis - hypothesis_energy[0: num_crit - 1])
181
- w_loc_max_hypothesis = np.divide(k_loc_max,
182
- k_loc_max + hypothesis_loc_peak - hypothesis_energy[0: num_crit - 1])
234
+ w_max_hypothesis = np.divide(k_max, k_max + db_max_hypothesis - hypothesis_energy[0 : num_crit - 1])
235
+ w_loc_max_hypothesis = np.divide(
236
+ k_loc_max,
237
+ k_loc_max + hypothesis_loc_peak - hypothesis_energy[0 : num_crit - 1],
238
+ )
183
239
  w_hypothesis = np.multiply(w_max_hypothesis, w_loc_max_hypothesis)
184
240
 
185
241
  w = np.divide(np.add(w_reference, w_hypothesis), 2.0)
186
- slope_diff = np.subtract(reference_slope, hypothesis_slope)[0: num_crit - 1]
242
+ slope_diff = np.subtract(reference_slope, hypothesis_slope)[0 : num_crit - 1]
187
243
  distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
188
244
 
189
245
  # This normalization is not part of Klatt's paper, but helps to normalize the measure.
@@ -193,16 +249,16 @@ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
193
249
  return distortion
194
250
 
195
251
 
196
- def _calc_log_likelihood_ratio_measure(hypothesis: np.ndarray,
197
- reference: np.ndarray,
198
- sample_rate: int = SAMPLE_RATE) -> np.ndarray:
252
+ def _calc_log_likelihood_ratio_measure(
253
+ hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
254
+ ) -> np.ndarray:
199
255
  from scipy.linalg import toeplitz
200
256
 
201
257
  # The lengths of the reference and hypothesis must be the same.
202
258
  reference_length = np.size(reference)
203
259
  hypothesis_length = np.size(hypothesis)
204
260
  if reference_length != hypothesis_length:
205
- raise ValueError('Hypothesis and reference must be the same length.')
261
+ raise ValueError("Hypothesis and reference must be the same length.")
206
262
 
207
263
  # window length in samples
208
264
  win_length = int(np.round(30 * sample_rate / 1000))
@@ -222,8 +278,8 @@ def _calc_log_likelihood_ratio_measure(hypothesis: np.ndarray,
222
278
  distortion = np.empty(num_frames)
223
279
  for frame_count in range(num_frames):
224
280
  # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
225
- reference_frame = reference[start: start + win_length]
226
- hypothesis_frame = hypothesis[start: start + win_length]
281
+ reference_frame = reference[start : start + win_length]
282
+ hypothesis_frame = hypothesis[start : start + win_length]
227
283
  reference_frame = np.multiply(reference_frame, window)
228
284
  hypothesis_frame = np.multiply(hypothesis_frame, window)
229
285
 
@@ -239,16 +295,18 @@ def _calc_log_likelihood_ratio_measure(hypothesis: np.ndarray,
239
295
  return distortion
240
296
 
241
297
 
242
- def _calc_snr(hypothesis: np.ndarray,
243
- reference: np.ndarray,
244
- sample_rate: int = SAMPLE_RATE) -> tuple[float, np.ndarray]:
298
+ def _calc_snr(
299
+ hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
300
+ ) -> tuple[float, np.ndarray]:
245
301
  # The lengths of the reference and hypothesis must be the same.
246
302
  reference_length = len(reference)
247
303
  hypothesis_length = len(hypothesis)
248
304
  if reference_length != hypothesis_length:
249
- raise ValueError('Hypothesis and reference must be the same length.')
305
+ raise ValueError("Hypothesis and reference must be the same length.")
250
306
 
251
- overall_snr = 10 * np.log10(np.sum(np.square(reference)) / np.sum(np.square(reference - hypothesis)))
307
+ overall_snr = 10 * np.log10(
308
+ np.sum(np.square(reference)) / (np.sum(np.square(reference - hypothesis))) + np.finfo(np.float32).eps
309
+ )
252
310
 
253
311
  # window length in samples
254
312
  win_length = round(30 * sample_rate / 1000)
@@ -268,17 +326,17 @@ def _calc_snr(hypothesis: np.ndarray,
268
326
  eps = np.spacing(1)
269
327
  for frame_count in range(num_frames):
270
328
  # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
271
- reference_frame = reference[start:start + win_length]
272
- hypothesis_frame = hypothesis[start:start + win_length]
329
+ reference_frame = reference[start : start + win_length]
330
+ hypothesis_frame = hypothesis[start : start + win_length]
273
331
  reference_frame = np.multiply(reference_frame, window)
274
332
  hypothesis_frame = np.multiply(hypothesis_frame, window)
275
333
 
276
334
  # (2) Compute the segmental SNR
277
335
  signal_energy = np.sum(np.square(reference_frame))
278
336
  noise_energy = np.sum(np.square(reference_frame - hypothesis_frame))
279
- segmental_snr[frame_count] = np.clip(10 * np.log10(signal_energy / (noise_energy + eps) + eps),
280
- min_snr,
281
- max_snr)
337
+ segmental_snr[frame_count] = np.clip(
338
+ 10 * np.log10(signal_energy / (noise_energy + eps) + eps), min_snr, max_snr
339
+ )
282
340
 
283
341
  start = start + skip_rate
284
342
 
@@ -291,7 +349,7 @@ def _lp_coefficients(speech_frame, model_order):
291
349
  autocorrelation = np.empty(model_order + 1)
292
350
  e = np.empty(model_order + 1)
293
351
  for k in range(model_order + 1):
294
- autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
352
+ autocorrelation[k] = np.dot(speech_frame[0 : win_length - k], speech_frame[k:win_length])
295
353
 
296
354
  # (2) Levinson-Durbin
297
355
  a = np.ones(model_order)
@@ -299,14 +357,14 @@ def _lp_coefficients(speech_frame, model_order):
299
357
  ref_coefficients = np.empty(model_order)
300
358
  e[0] = autocorrelation[0]
301
359
  for i in range(model_order):
302
- a_past[0: i] = a[0: i]
303
- sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
360
+ a_past[0:i] = a[0:i]
361
+ sum_term = np.dot(a_past[0:i], autocorrelation[i:0:-1])
304
362
  ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
305
363
  a[i] = ref_coefficients[i]
306
364
  if i == 0:
307
- a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
365
+ a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1 : -1 : -1], ref_coefficients[i])
308
366
  else:
309
- a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
367
+ a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1 :: -1], ref_coefficients[i])
310
368
  e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
311
369
  lp_params = np.concatenate((np.array([1]), -a))
312
370
  return autocorrelation, ref_coefficients, lp_params
@@ -18,15 +18,17 @@ def calc_wer(hypothesis: list[str] | str, reference: list[str] | str) -> WerResu
18
18
  """
19
19
  import jiwer
20
20
 
21
- transformation = jiwer.Compose([
22
- jiwer.ToLowerCase(),
23
- jiwer.RemovePunctuation(),
24
- jiwer.RemoveWhiteSpace(replace_by_space=True),
25
- jiwer.RemoveMultipleSpaces(),
26
- jiwer.Strip(),
27
- jiwer.RemoveEmptyStrings(),
28
- jiwer.ReduceToListOfListOfWords(word_delimiter=' ')
29
- ])
21
+ transformation = jiwer.Compose(
22
+ [
23
+ jiwer.ToLowerCase(),
24
+ jiwer.RemovePunctuation(),
25
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
26
+ jiwer.RemoveMultipleSpaces(),
27
+ jiwer.Strip(),
28
+ jiwer.RemoveEmptyStrings(),
29
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
30
+ ]
31
+ )
30
32
 
31
33
  if isinstance(reference, str):
32
34
  reference = [reference]
@@ -34,35 +36,36 @@ def calc_wer(hypothesis: list[str] | str, reference: list[str] | str) -> WerResu
34
36
  hypothesis = [hypothesis]
35
37
 
36
38
  # jiwer does not allow empty string
37
- measures = {'insertions': 0,
38
- 'substitutions': 0,
39
- 'deletions': 0,
40
- 'hits': 0}
39
+ measures = {"insertions": 0, "substitutions": 0, "deletions": 0, "hits": 0}
41
40
  if any(len(t) == 0 for t in reference):
42
41
  if any(len(t) != 0 for t in hypothesis):
43
- measures['insertions'] = len(hypothesis)
42
+ measures["insertions"] = len(hypothesis)
44
43
  else:
45
- measures = jiwer.compute_measures(truth=reference,
46
- hypothesis=hypothesis,
47
- truth_transform=transformation,
48
- hypothesis_transform=transformation)
44
+ measures = jiwer.compute_measures(
45
+ truth=reference,
46
+ hypothesis=hypothesis,
47
+ truth_transform=transformation,
48
+ hypothesis_transform=transformation,
49
+ )
49
50
 
50
- errors = measures['substitutions'] + measures['deletions'] + measures['insertions']
51
- words = measures['hits'] + measures['substitutions'] + measures['deletions']
51
+ errors = measures["substitutions"] + measures["deletions"] + measures["insertions"]
52
+ words = measures["hits"] + measures["substitutions"] + measures["deletions"]
52
53
 
53
54
  if words != 0:
54
55
  wer = errors / words
55
- substitutions_rate = measures['substitutions'] / words
56
- deletions_rate = measures['deletions'] / words
57
- insertions_rate = measures['insertions'] / words
56
+ substitutions_rate = measures["substitutions"] / words
57
+ deletions_rate = measures["deletions"] / words
58
+ insertions_rate = measures["insertions"] / words
58
59
  else:
59
- wer = float('inf')
60
- substitutions_rate = float('inf')
61
- deletions_rate = float('inf')
62
- insertions_rate = float('inf')
60
+ wer = float("inf")
61
+ substitutions_rate = float("inf")
62
+ deletions_rate = float("inf")
63
+ insertions_rate = float("inf")
63
64
 
64
- return WerResult(wer=wer,
65
- words=int(words),
66
- substitutions=substitutions_rate,
67
- deletions=deletions_rate,
68
- insertions=insertions_rate)
65
+ return WerResult(
66
+ wer=wer,
67
+ words=int(words),
68
+ substitutions=substitutions_rate,
69
+ deletions=deletions_rate,
70
+ insertions=insertions_rate,
71
+ )
@@ -1,10 +1,12 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def calc_wsdr(hypothesis: np.ndarray,
5
- reference: np.ndarray,
6
- with_log: bool = False,
7
- with_negate: bool = False) -> tuple[float, np.ndarray, np.ndarray]:
4
+ def calc_wsdr(
5
+ hypothesis: np.ndarray,
6
+ reference: np.ndarray,
7
+ with_log: bool = False,
8
+ with_negate: bool = False,
9
+ ) -> tuple[float, np.ndarray, np.ndarray]:
8
10
  """Calculate weighted SDR (signal distortion ratio) using all source inputs of size [samples, nsrc].
9
11
  Uses true reference energy ratios to weight each cross-correlation coefficient cc = <y,yˆ>/∥y∥∥yˆ∥
10
12
  in a sum over all sources.
@@ -26,11 +28,12 @@ def calc_wsdr(hypothesis: np.ndarray,
26
28
  :return: (wsdr, ccoef, cweights)
27
29
  """
28
30
  nsrc = reference.shape[-1]
29
- assert hypothesis.shape[-1] == nsrc
31
+ if hypothesis.shape[-1] != nsrc:
32
+ raise ValueError("hypothesis has wrong shape")
30
33
 
31
34
  # Calculate cc = <y,yˆ>/∥y∥∥yˆ∥ always in range -1 --> 1, size [1,nsrc]
32
- ref_e = np.sum(reference ** 2, axis=0, keepdims=True) # [1,nsrc]
33
- hy_e = np.sum(hypothesis ** 2, axis=0, keepdims=True)
35
+ ref_e = np.sum(reference**2, axis=0, keepdims=True) # [1,nsrc]
36
+ hy_e = np.sum(hypothesis**2, axis=0, keepdims=True)
34
37
  allref_e = np.sum(ref_e)
35
38
  cc = np.zeros(nsrc) # calc correlation coefficient
36
39
  cw = np.zeros(nsrc) # cc weights (energy ratio)
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: F821
1
2
  import numpy as np
2
3
  import pandas as pd
3
4
 
@@ -7,33 +8,35 @@ from sonusai.mixture import Predict
7
8
  from sonusai.mixture import Truth
8
9
 
9
10
 
10
- def class_summary(mixdb: MixtureDatabase,
11
- mixids: GeneralizedIDs,
12
- truth_f: Truth,
13
- predict: Predict,
14
- predict_thr: float | np.ndarray = 0,
15
- truth_thr: float = 0.5,
16
- timesteps: int = 0) -> pd.DataFrame:
17
- """ Calculate table of metrics per class, and averages for a list
18
- of mixtures using truth and prediction data [features, num_classes]
19
- Example:
20
- Generate multi-class metric summary into table, for example:
21
- PPV TPR F1 FPR ACC AP AUC Support
22
- Class 1 0.71 0.80 0.75 0.00 0.99 44
23
- Class 2 0.90 0.76 0.82 0.00 0.99 128
24
- Class 3 0.86 0.82 0.84 0.04 0.93 789
25
- Other 0.94 0.96 0.95 0.18 0.92 2807
11
+ def class_summary(
12
+ mixdb: MixtureDatabase,
13
+ mixids: GeneralizedIDs,
14
+ truth_f: Truth,
15
+ predict: Predict,
16
+ predict_thr: float | np.ndarray = 0,
17
+ truth_thr: float = 0.5,
18
+ timesteps: int = 0,
19
+ ) -> pd.DataFrame:
20
+ """Calculate table of metrics per class, and averages for a list
21
+ of mixtures using truth and prediction data [features, num_classes]
22
+ Example:
23
+ Generate multi-class metric summary into table, for example:
24
+ PPV TPR F1 FPR ACC AP AUC Support
25
+ Class 1 0.71 0.80 0.75 0.00 0.99 44
26
+ Class 2 0.90 0.76 0.82 0.00 0.99 128
27
+ Class 3 0.86 0.82 0.84 0.04 0.93 789
28
+ Other 0.94 0.96 0.95 0.18 0.92 2807
26
29
 
27
- micro-avg 0.92 0.027 3768
28
- macro avg 0.85 0.83 0.84 0.05 0.96 3768
29
- micro-avgwo
30
+ micro-avg 0.92 0.027 3768
31
+ macro avg 0.85 0.83 0.84 0.05 0.96 3768
32
+ micro-avgwo
30
33
  """
31
34
  from sonusai.metrics import one_hot
32
35
 
33
36
  num_classes = truth_f.shape[1]
34
37
 
35
38
  # TODO: re-work for modern mixdb API
36
- y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore
39
+ y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
37
40
 
38
41
  if not mixdb.truth_mutex and num_classes > 1:
39
42
  if not isinstance(predict_thr, np.ndarray):
@@ -49,25 +52,25 @@ def class_summary(mixdb: MixtureDatabase,
49
52
 
50
53
  # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
51
54
  table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
52
- col_n = ['PPV', 'TPR', 'F1', 'FPR', 'ACC', 'AP', 'AUC', 'Support']
55
+ col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
53
56
  if mixdb.truth_mutex:
54
57
  if len(mixdb.class_labels) >= num_classes - 1: # labels exist with or without Other
55
58
  row_n = mixdb.class_labels
56
59
  if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
57
- row_n.append('Other')
60
+ row_n.append("Other")
58
61
  else:
59
- row_n = ([f'Class {i}' for i in range(1, num_classes)])
60
- row_n.append('Other')
62
+ row_n = [f"Class {i}" for i in range(1, num_classes)]
63
+ row_n.append("Other")
61
64
  else:
62
65
  if len(mixdb.class_labels) == num_classes:
63
66
  row_n = mixdb.class_labels
64
67
  else:
65
- row_n = ([f'Class {i}' for i in range(1, num_classes + 1)])
68
+ row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
66
69
 
67
70
  df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
68
71
 
69
72
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
70
- avg_row_n = ['Macro-avg', 'Micro-avg', 'Weighted-avg']
73
+ avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
71
74
  dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
72
75
 
73
76
  # dfblank = pd.DataFrame([''])
@@ -75,6 +78,6 @@ def class_summary(mixdb: MixtureDatabase,
75
78
 
76
79
  classdf = pd.concat([df, dfavg])
77
80
  # classdf = classdf.round(2)
78
- classdf['Support'] = classdf['Support'].astype(int)
81
+ classdf["Support"] = classdf["Support"].astype(int)
79
82
 
80
83
  return classdf
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: F821
1
2
  import numpy as np
2
3
  import pandas as pd
3
4
 
@@ -7,31 +8,33 @@ from sonusai.mixture import Predict
7
8
  from sonusai.mixture import Truth
8
9
 
9
10
 
10
- def confusion_matrix_summary(mixdb: MixtureDatabase,
11
- mixids: GeneralizedIDs,
12
- truth_f: Truth,
13
- predict: Predict,
14
- class_idx: int,
15
- predict_thr: float | np.ndarray = 0,
16
- truth_thr: float = 0.5,
17
- timesteps: int = 0) -> tuple[pd.DataFrame, pd.DataFrame]:
11
+ def confusion_matrix_summary(
12
+ mixdb: MixtureDatabase,
13
+ mixids: GeneralizedIDs,
14
+ truth_f: Truth,
15
+ predict: Predict,
16
+ class_idx: int,
17
+ predict_thr: float | np.ndarray = 0,
18
+ truth_thr: float = 0.5,
19
+ timesteps: int = 0,
20
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
18
21
  """Calculate confusion matrix for specified class, using truth and prediction
19
- data [features, num_classes].
22
+ data [features, num_classes].
20
23
 
21
- predict_thr sets the decision threshold(s) applied to predict data, thus allowing
22
- predict to be continuous probabilities.
24
+ predict_thr sets the decision threshold(s) applied to predict data, thus allowing
25
+ predict to be continuous probabilities.
23
26
 
24
- Default predict_thr=0 will infer 0.5 for multi-label mode (truth_mutex = False), or
25
- if single-label mode (truth_mutex == True) then ignore and use argmax mode, and
26
- the confusion matrix is calculated for all classes.
27
+ Default predict_thr=0 will infer 0.5 for multi-label mode (truth_mutex = False), or
28
+ if single-label mode (truth_mutex == True) then ignore and use argmax mode, and
29
+ the confusion matrix is calculated for all classes.
27
30
 
28
- Returns pandas dataframes of confusion matrix cmdf and normalized confusion matrix cmndf.
31
+ Returns pandas dataframes of confusion matrix cmdf and normalized confusion matrix cmndf.
29
32
  """
30
33
  from sonusai.metrics import one_hot
31
34
 
32
35
  num_classes = truth_f.shape[1]
33
36
  # TODO: re-work for modern mixdb API
34
- ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore
37
+ ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
35
38
 
36
39
  # Check predict_thr array or scalar and return final scalar predict_thr value
37
40
  if not mixdb.truth_mutex and num_classes > 1:
@@ -56,16 +59,16 @@ def confusion_matrix_summary(mixdb: MixtureDatabase,
56
59
  if len(mixdb.class_labels) == num_classes:
57
60
  class_names = mixdb.class_labels
58
61
  else:
59
- class_names = ([f'Class {i}' for i in range(1, num_classes + 1)])
62
+ class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
60
63
 
61
- class_nums = ([f'{i}' for i in range(1, num_classes + 1)])
64
+ class_nums = [f"{i}" for i in range(1, num_classes + 1)]
62
65
 
63
66
  if mixdb.truth_mutex:
64
67
  # single-label mode force to argmax mode
65
68
  predict_thr = np.array(0, dtype=np.float32)
66
69
  _, _, cm, cmn, _, _ = one_hot(ytrue, ypred, predict_thr, truth_thr, timesteps)
67
70
  row_n = class_names
68
- row_n[-1] = 'Other'
71
+ row_n[-1] = "Other"
69
72
  # mux = pd.MultiIndex.from_product([['Single-label/mutex mode, truth thr = {}'.format(truth_thr)],
70
73
  # class_nums])
71
74
  # mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
@@ -76,12 +79,12 @@ def confusion_matrix_summary(mixdb: MixtureDatabase,
76
79
  else:
77
80
  _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
78
81
  cname = class_names[class_idx]
79
- row_n = ['TrueN', 'TrueP']
80
- col_n = ['N-' + cname, 'P-' + cname]
82
+ row_n = ["TrueN", "TrueP"]
83
+ col_n = ["N-" + cname, "P-" + cname]
81
84
  cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
82
85
  cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
83
86
  # add thresholds in 3rd row
84
- pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=['p/t thr:'], columns=col_n)
87
+ pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
85
88
  cmdf = pd.concat([cmdf, pdnote])
86
89
  cmndf = pd.concat([cmndf, pdnote])
87
90