sonusai-1.0.16-cp311-abi3-macosx_10_12_x86_64.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/metrics/calc_speech.py
@@ -0,0 +1,382 @@
+ import numpy as np
+
+ from ..constants import SAMPLE_RATE
+ from ..datatypes import SpeechMetrics
+ from .calc_pesq import calc_pesq
+
+
+ def calc_speech(
+     hypothesis: np.ndarray,
+     reference: np.ndarray,
+     pesq: float | None = None,
+     sample_rate: int = SAMPLE_RATE,
+ ) -> SpeechMetrics:
+     """Calculate speech metrics c_sig, c_bak, and c_ovl.
+
+     These are all related and thus included in one function. Reference: MATLAB script "compute_metrics.m".
+
+     :param hypothesis: estimated audio
+     :param reference: reference audio
+     :param pesq: optional precomputed PESQ score; computed from the signals when None
+     :param sample_rate: sample rate of audio
+     :return: SpeechMetrics named tuple
+     """
+
+     # Weighted spectral slope measure
+     wss_dist_vec = _calc_weighted_spectral_slope_measure(hypothesis=hypothesis, reference=reference)
+     wss_dist_vec = np.sort(wss_dist_vec)
+
+     # Value from CMGAN reference implementation
+     alpha = 0.95
+     wss_dist = np.mean(wss_dist_vec[0 : round(np.size(wss_dist_vec) * alpha)])
+
+     # Log likelihood ratio measure
+     llr_dist = _calc_log_likelihood_ratio_measure(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+     ll_rs = np.sort(llr_dist)
+     llr_len = round(np.size(llr_dist) * alpha)
+     llr_mean = np.mean(ll_rs[:llr_len])
+
+     # Segmental SNR
+     _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+     seg_snr = np.mean(segsnr_dist)
+
+     # PESQ
+     if pesq is None:
+         pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+
+     # Now compute the composite measures
+     csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * pesq - 0.009 * wss_dist, 1, 5))
+     cbak = float(np.clip(1.634 + 0.478 * pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
+     covl = float(np.clip(1.594 + 0.805 * pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
+
+     return SpeechMetrics(csig, cbak, covl)
+
+
+ def _calc_weighted_spectral_slope_measure(
+     hypothesis: np.ndarray,
+     reference: np.ndarray,
+     sample_rate: int = SAMPLE_RATE,
+ ) -> np.ndarray:
+     from scipy.fftpack import fft
+
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = np.size(reference)
+     hypothesis_length = np.size(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError("Hypothesis and reference must be the same length.")
+
+     # Window length in samples
+     win_length = int(np.round(30 * sample_rate / 1000))
+     # Window skip in samples
+     skip_rate = int(np.floor(np.divide(win_length, 4)))
+     # Maximum bandwidth
+     max_freq = int(np.divide(sample_rate, 2))
+     num_crit = 25
+
+     n_fft = int(np.power(2, np.ceil(np.log2(2 * win_length))))
+     n_fft_by_2 = int(np.multiply(0.5, n_fft))
+     # Value suggested by Klatt, pg 1280
+     k_max = 20.0
+     # Value suggested by Klatt, pg 1280
+     k_loc_max = 1.0
+
+     # Critical band filter definitions (center frequency and bandwidths in Hz)
+     cent_freq = np.array(
+         [
+             50.0000,
+             120.000,
+             190.000,
+             260.000,
+             330.000,
+             400.000,
+             470.000,
+             540.000,
+             617.372,
+             703.378,
+             798.717,
+             904.128,
+             1020.38,
+             1148.30,
+             1288.72,
+             1442.54,
+             1610.70,
+             1794.16,
+             1993.93,
+             2211.08,
+             2446.71,
+             2701.97,
+             2978.04,
+             3276.17,
+             3597.63,
+         ]
+     )
+     bandwidth = np.array(
+         [
+             70.0000,
+             70.0000,
+             70.0000,
+             70.0000,
+             70.0000,
+             70.0000,
+             70.0000,
+             77.3724,
+             86.0056,
+             95.3398,
+             105.411,
+             116.256,
+             127.914,
+             140.423,
+             153.823,
+             168.154,
+             183.457,
+             199.776,
+             217.153,
+             235.631,
+             255.255,
+             276.072,
+             298.126,
+             321.465,
+             346.136,
+         ]
+     )
+
+     # Minimum critical bandwidth
+     bw_min = bandwidth[0]
+
+     # Set up the critical band filters.
+     # Note here that Gaussian-ly shaped filters are used.
+     # Also, the sum of the filter weights are equivalent for each critical band filter.
+     # Filter less than -30 dB and set to zero.
+
+     # -30 dB point of filter
+     min_factor = np.exp(-30.0 / (2.0 * 2.303))
+     crit_filter = np.empty((num_crit, n_fft_by_2))
+     for i in range(num_crit):
+         f0 = (cent_freq[i] / max_freq) * n_fft_by_2
+         bw = (bandwidth[i] / max_freq) * n_fft_by_2
+         norm_factor = np.log(bw_min) - np.log(bandwidth[i])
+         j = np.arange(n_fft_by_2)
+         crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
+         cond = np.greater(crit_filter[i, :], min_factor)
+         crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
+
+     # For each frame of input speech, calculate the weighted spectral slope measure
+     num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     distortion = np.empty(num_frames)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start : start + win_length] / 32768
+         hypothesis_frame = hypothesis[start : start + win_length] / 32768
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Compute the power spectrum of reference and hypothesis
+         reference_spec = np.square(np.abs(fft(reference_frame, n_fft)))
+         hypothesis_spec = np.square(np.abs(fft(hypothesis_frame, n_fft)))
+
+         # (3) Compute filter bank output energies (in dB scale)
+         reference_energy = np.matmul(crit_filter, reference_spec[0:n_fft_by_2])
+         hypothesis_energy = np.matmul(crit_filter, hypothesis_spec[0:n_fft_by_2])
+
+         reference_energy = 10 * np.log10(np.maximum(reference_energy, 1e-10))
+         hypothesis_energy = 10 * np.log10(np.maximum(hypothesis_energy, 1e-10))
+
+         # (4) Compute spectral slope (dB[i+1]-dB[i])
+         reference_slope = reference_energy[1:num_crit] - reference_energy[0 : num_crit - 1]
+         hypothesis_slope = hypothesis_energy[1:num_crit] - hypothesis_energy[0 : num_crit - 1]
+
+         # (5) Find the nearest peak locations in the spectra to each critical band.
+         # If the slope is negative, we search to the left. If positive, we search to the right.
+         reference_loc_peak = np.empty(num_crit - 1)
+         hypothesis_loc_peak = np.empty(num_crit - 1)
+
+         for i in range(num_crit - 1):
+             # find the peaks in the reference speech signal
+             if reference_slope[i] > 0:
+                 # search to the right
+                 n = i
+                 while (n < num_crit - 1) and (reference_slope[n] > 0):
+                     n = n + 1
+                 reference_loc_peak[i] = reference_energy[n - 1]
+             else:
+                 # search to the left
+                 n = i
+                 while (n >= 0) and (reference_slope[n] <= 0):
+                     n = n - 1
+                 reference_loc_peak[i] = reference_energy[n + 1]
+
+             # find the peaks in the hypothesis speech signal
+             if hypothesis_slope[i] > 0:
+                 # search to the right
+                 n = i
+                 while (n < num_crit - 1) and (hypothesis_slope[n] > 0):
+                     n = n + 1
+                 hypothesis_loc_peak[i] = hypothesis_energy[n - 1]
+             else:
+                 # search to the left
+                 n = i
+                 while (n >= 0) and (hypothesis_slope[n] <= 0):
+                     n = n - 1
+                 hypothesis_loc_peak[i] = hypothesis_energy[n + 1]
+
+         # (6) Compute the weighted spectral slope measure for this frame.
+         # This includes determination of the weighting function.
+         db_max_reference = np.max(reference_energy)
+         db_max_hypothesis = np.max(hypothesis_energy)
+
+         # The weights are calculated by averaging individual weighting factors from the reference and hypothesis frame.
+         # These weights w_reference and w_hypothesis should range from 0 to 1 and place more emphasis on spectral peaks
+         # and less emphasis on slope differences in spectral valleys.
+         # This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
+
+         w_max_reference = np.divide(k_max, k_max + db_max_reference - reference_energy[0 : num_crit - 1])
+         w_loc_max_reference = np.divide(
+             k_loc_max,
+             k_loc_max + reference_loc_peak - reference_energy[0 : num_crit - 1],
+         )
+         w_reference = np.multiply(w_max_reference, w_loc_max_reference)
+
+         w_max_hypothesis = np.divide(k_max, k_max + db_max_hypothesis - hypothesis_energy[0 : num_crit - 1])
+         w_loc_max_hypothesis = np.divide(
+             k_loc_max,
+             k_loc_max + hypothesis_loc_peak - hypothesis_energy[0 : num_crit - 1],
+         )
+         w_hypothesis = np.multiply(w_max_hypothesis, w_loc_max_hypothesis)
+
+         w = np.divide(np.add(w_reference, w_hypothesis), 2.0)
+         slope_diff = np.subtract(reference_slope, hypothesis_slope)[0 : num_crit - 1]
+         distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
+
+         # This normalization is not part of Klatt's paper, but helps to normalize the measure.
+         # Here we scale the measure by the sum of the weights.
+         start = start + skip_rate
+
+     return distortion
+
+
+ def _calc_log_likelihood_ratio_measure(
+     hypothesis: np.ndarray,
+     reference: np.ndarray,
+     sample_rate: int = SAMPLE_RATE,
+ ) -> np.ndarray:
+     from scipy.linalg import toeplitz
+
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = np.size(reference)
+     hypothesis_length = np.size(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError("Hypothesis and reference must be the same length.")
+
+     # window length in samples
+     win_length = int(np.round(30 * sample_rate / 1000))
+     # window skip in samples
+     skip_rate = int(np.floor(win_length / 4))
+     # LPC analysis order; this could vary depending on sampling frequency.
+     if sample_rate < 10000:
+         p = 10
+     else:
+         p = 16
+
+     # For each frame of input speech, calculate the log likelihood ratio
+     num_frames = int((reference_length - win_length) / skip_rate)
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     distortion = np.empty(num_frames)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start : start + win_length]
+         hypothesis_frame = hypothesis[start : start + win_length]
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
+         r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
+         _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
+
+         # (3) Compute the log likelihood ratio measure
+         numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
+         denominator = np.dot(np.matmul(a_reference, toeplitz(r_reference)), a_reference)
+         distortion[frame_count] = np.log(numerator / denominator)
+         start = start + skip_rate
+     return distortion
+
+
+ def _calc_snr(
+     hypothesis: np.ndarray,
+     reference: np.ndarray,
+     sample_rate: int = SAMPLE_RATE,
+ ) -> tuple[float, np.ndarray]:
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = len(reference)
+     hypothesis_length = len(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError("Hypothesis and reference must be the same length.")
+
+     overall_snr = 10 * np.log10(
+         np.sum(np.square(reference)) / (np.sum(np.square(reference - hypothesis))) + np.finfo(np.float32).eps
+     )
+
+     # window length in samples
+     win_length = round(30 * sample_rate / 1000)
+     # window skip in samples
+     skip_rate = int(np.floor(win_length / 4))
+     # minimum SNR in dB
+     min_snr = -10
+     # maximum SNR in dB
+     max_snr = 35
+
+     # For each frame of input speech, calculate the segmental SNR
+     num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     segmental_snr = np.empty(num_frames)
+     eps = np.spacing(1)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start : start + win_length]
+         hypothesis_frame = hypothesis[start : start + win_length]
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Compute the segmental SNR
+         signal_energy = np.sum(np.square(reference_frame))
+         noise_energy = np.sum(np.square(reference_frame - hypothesis_frame))
+         segmental_snr[frame_count] = np.clip(
+             10 * np.log10(signal_energy / (noise_energy + eps) + eps), min_snr, max_snr
+         )
+
+         start = start + skip_rate
+
+     return overall_snr, segmental_snr
+
+
+ def _lp_coefficients(speech_frame, model_order):
+     # (1) Compute autocorrelation lags
+     win_length = np.size(speech_frame)
+     autocorrelation = np.empty(model_order + 1)
+     e = np.empty(model_order + 1)
+     for k in range(model_order + 1):
+         autocorrelation[k] = np.dot(speech_frame[0 : win_length - k], speech_frame[k:win_length])
+
+     # (2) Levinson-Durbin
+     a = np.ones(model_order)
+     a_past = np.empty(model_order)
+     ref_coefficients = np.empty(model_order)
+     e[0] = autocorrelation[0]
+     for i in range(model_order):
+         a_past[0:i] = a[0:i]
+         sum_term = np.dot(a_past[0:i], autocorrelation[i:0:-1])
+         ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
+         a[i] = ref_coefficients[i]
+         if i == 0:
+             a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1 : -1 : -1], ref_coefficients[i])
+         else:
+             a[0:i] = a_past[0:i] - np.multiply(a_past[i - 1 :: -1], ref_coefficients[i])
+         e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
+     lp_params = np.concatenate((np.array([1]), -a))
+     return autocorrelation, ref_coefficients, lp_params
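
For orientation: the composite measures above follow the Hu and Loizou convention, in which CSIG, CBAK, and COVL are fixed linear combinations of the LLR, WSS, segmental SNR, and PESQ scores, each clipped to the MOS range [1, 5]. A minimal usage sketch, assuming equal-length 16 kHz signals scaled like int16 PCM and a direct module import; the synthetic signal and the precomputed pesq value are illustrative only, not taken from the package:

    import numpy as np

    from sonusai.metrics.calc_speech import calc_speech

    rng = np.random.default_rng(0)
    reference = rng.standard_normal(16000) * 1000.0              # 1 s of synthetic audio at 16 kHz
    hypothesis = reference + rng.standard_normal(16000) * 50.0   # slightly degraded estimate

    # Passing a precomputed PESQ score skips the internal calc_pesq call.
    csig, cbak, covl = calc_speech(hypothesis=hypothesis, reference=reference, pesq=2.5)
    print(f"CSIG={csig:.2f} CBAK={cbak:.2f} COVL={covl:.2f}")
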
sonusai/metrics/calc_wer.py
@@ -0,0 +1,71 @@
+ from typing import NamedTuple
+
+
+ class WerResult(NamedTuple):
+     wer: float
+     words: int
+     substitutions: float
+     deletions: float
+     insertions: float
+
+
+ def calc_wer(hypothesis: list[str] | str, reference: list[str] | str) -> WerResult:
+     """Computes the average word error rate between two texts represented as corresponding strings or lists of strings.
+
+     :param hypothesis: the hypothesis sentence(s) as a string or list of strings
+     :param reference: the reference sentence(s) as a string or list of strings
+     :return: a WerResult named tuple with wer, words, substitutions, deletions, and insertions
+     """
+     import jiwer
+
+     transformation = jiwer.Compose(
+         [
+             jiwer.ToLowerCase(),
+             jiwer.RemovePunctuation(),
+             jiwer.RemoveWhiteSpace(replace_by_space=True),
+             jiwer.RemoveMultipleSpaces(),
+             jiwer.Strip(),
+             jiwer.RemoveEmptyStrings(),
+             jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
+         ]
+     )
+
+     if isinstance(reference, str):
+         reference = [reference]
+     if isinstance(hypothesis, str):
+         hypothesis = [hypothesis]
+
+     # jiwer does not allow empty string
+     measures = {"insertions": 0, "substitutions": 0, "deletions": 0, "hits": 0}
+     if any(len(t) == 0 for t in reference):
+         if any(len(t) != 0 for t in hypothesis):
+             measures["insertions"] = len(hypothesis)
+     else:
+         measures = jiwer.compute_measures(
+             truth=reference,
+             hypothesis=hypothesis,
+             truth_transform=transformation,
+             hypothesis_transform=transformation,
+         )
+
+     errors = measures["substitutions"] + measures["deletions"] + measures["insertions"]
+     words = measures["hits"] + measures["substitutions"] + measures["deletions"]
+
+     if words != 0:
+         wer = errors / words
+         substitutions_rate = measures["substitutions"] / words
+         deletions_rate = measures["deletions"] / words
+         insertions_rate = measures["insertions"] / words
+     else:
+         wer = float("inf")
+         substitutions_rate = float("inf")
+         deletions_rate = float("inf")
+         insertions_rate = float("inf")
+
+     return WerResult(
+         wer=wer,
+         words=int(words),
+         substitutions=substitutions_rate,
+         deletions=deletions_rate,
+         insertions=insertions_rate,
+     )
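
Note that all four rates are normalized by the same reference word count (hits + substitutions + deletions), so the substitution, deletion, and insertion rates always sum to the overall WER. A quick sketch of the semantics, assuming jiwer is installed; the example sentences are illustrative:

    from sonusai.metrics.calc_wer import calc_wer

    # 6 reference words, 3 of them deleted in the hypothesis:
    result = calc_wer(hypothesis="the cat sat", reference="the cat sat on the mat")
    print(result.wer, result.words, result.deletions)  # 0.5 6 0.5
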
sonusai/metrics/calc_wsdr.py
@@ -0,0 +1,57 @@
+ import numpy as np
+
+
+ def calc_wsdr(
+     hypothesis: np.ndarray,
+     reference: np.ndarray,
+     with_log: bool = False,
+     with_negate: bool = False,
+ ) -> tuple[float, np.ndarray, np.ndarray]:
+     """Calculate weighted SDR (signal-to-distortion ratio) using all source inputs of size [samples, nsrc].
+     Uses true reference energy ratios to weight each cross-correlation coefficient cc = <y,yˆ>/∥y∥∥yˆ∥
+     in a sum over all sources.
+
+     Range is -1 --> 1 as correlation/estimation improves, or with with_log, -3 dB --> 70 dB (1e7 max).
+     If with_negate, range is 1 --> -1 as correlation improves, and with with_log, 3 dB --> -70 dB (1e-7 min).
+
+     Returns: wsdr      scalar weighted signal-to-distortion ratio
+              ccoef     nsrc vector of cross correlation coefficients
+              cweights  nsrc vector of reference energy ratio weights
+
+     Reference:
+         WSDR: 2019-ICLR-dcunet-phase-aware-speech-enh
+
+     :param hypothesis: [samples, nsrc]
+     :param reference: [samples, nsrc]
+     :param with_log: enable scaling (return 10*log10)
+     :param with_negate: enable negation (for use as a loss function)
+     :return: (wsdr, ccoef, cweights)
+     """
+     nsrc = reference.shape[-1]
+     if hypothesis.shape[-1] != nsrc:
+         raise ValueError("hypothesis has wrong shape")
+
+     # Calculate cc = <y,yˆ>/∥y∥∥yˆ∥ always in range -1 --> 1, size [1,nsrc]
+     ref_e = np.sum(reference**2, axis=0, keepdims=True)  # [1,nsrc]
+     hy_e = np.sum(hypothesis**2, axis=0, keepdims=True)
+     allref_e = np.sum(ref_e)
+     cc = np.zeros(nsrc)  # cross correlation coefficients
+     cw = np.zeros(nsrc)  # cc weights (energy ratio)
+     for i in range(nsrc):
+         denom = np.sqrt(ref_e[0, i]) * np.sqrt(hy_e[0, i]) + 1e-7
+         cc[i] = np.sum(reference[:, i] * hypothesis[:, i], axis=0, keepdims=True) / denom
+         cw[i] = ref_e[0, i] / (allref_e + 1e-7)
+
+     # Note: tests show cw sums to 1.0 (+/- 7 digits), so just use cw for the weighted sum
+     if with_negate:  # for use as a loss function
+         wsdr = float(np.sum(cw * -cc))  # cc always in range 1 --> -1
+         if with_log:
+             wsdr = max(wsdr, -1.0)
+             wsdr = 10 * np.log10(wsdr + 1 + 1e-7)  # range 3 --> -inf (or 1e-7 limit of -70 dB)
+     else:
+         wsdr = float(np.sum(cw * cc))  # cc always in range -1 --> 1
+         if with_log:
+             wsdr = min(wsdr, 1.0)  # np.sum(cw * cc) must saturate at 1.0 for the log
+             wsdr = 10 * np.log10(-1 / (wsdr - 1 - 1e-7))  # range -3 --> inf (or 1e-7 limit of 70 dB)
+
+     return float(wsdr), cc, cw
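
Because the weights cw are reference energy ratios that sum to one, a perfect estimate gives wsdr = 1.0, which with_log maps to the 70 dB saturation point noted in the docstring. A minimal sketch with two synthetic sources, assuming a direct module import; the tones and noise level are illustrative:

    import numpy as np

    from sonusai.metrics.calc_wsdr import calc_wsdr

    t = np.arange(16000) / 16000.0
    reference = np.stack([np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 220 * t)], axis=1)  # [samples, nsrc]
    hypothesis = reference + 0.01 * np.random.default_rng(0).standard_normal(reference.shape)

    wsdr, cc, cw = calc_wsdr(hypothesis=hypothesis, reference=reference, with_log=True)
    print(wsdr, cc, cw)  # wsdr in dB; cc near 1.0 per source; cw sums to ~1.0
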