sonusai 0.18.1__py3-none-any.whl → 0.18.4__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -0,0 +1,312 @@
+ import numpy as np
+
+ from sonusai.mixture.constants import SAMPLE_RATE
+ from sonusai.mixture.datatypes import SpeechMetrics
+ from .calc_pesq import calc_pesq
+
+
+ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE) -> SpeechMetrics:
+     """Calculate the speech metrics pesq, c_sig, c_bak, and c_ovl.
+
+     These are all related and thus included in one function. Segmental SNR is computed
+     internally for c_bak but is not returned. Reference: MATLAB script "compute_metrics.m".
+
+     :param hypothesis: estimated audio
+     :param reference: reference audio
+     :param sample_rate: sample rate of audio
+     :return: SpeechMetrics named tuple
+     """
+
+     # Weighted spectral slope measure
+     wss_dist_vec = _calc_weighted_spectral_slope_measure(hypothesis=hypothesis,
+                                                          reference=reference,
+                                                          sample_rate=sample_rate)
+     wss_dist_vec = np.sort(wss_dist_vec)
+
+     # Value from CMGAN reference implementation
+     alpha = 0.95
+     wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)])
+
+     # Log likelihood ratio measure
+     llr_dist = _calc_log_likelihood_ratio_measure(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+     ll_rs = np.sort(llr_dist)
+     llr_len = round(np.size(llr_dist) * alpha)
+     llr_mean = np.mean(ll_rs[:llr_len])
+
+     # Segmental SNR
+     snr_dist, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+     seg_snr = np.mean(segsnr_dist)
+
+     # PESQ
+     _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+
+     # Now compute the composite measures
+     c_sig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
+     c_bak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
+     c_ovl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
+
+     return SpeechMetrics(_pesq, c_sig, c_bak, c_ovl)
+
+
+ def _calc_weighted_spectral_slope_measure(hypothesis: np.ndarray,
+                                           reference: np.ndarray,
+                                           sample_rate: int = SAMPLE_RATE) -> np.ndarray:
+     from scipy.fftpack import fft
+
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = np.size(reference)
+     hypothesis_length = np.size(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError('Hypothesis and reference must be the same length.')
+
+     # Window length in samples
+     win_length = int(np.round(30 * sample_rate / 1000))
+     # Window skip in samples
+     skip_rate = int(np.floor(np.divide(win_length, 4)))
+     # Maximum bandwidth
+     max_freq = int(np.divide(sample_rate, 2))
+     num_crit = 25
+
+     n_fft = int(np.power(2, np.ceil(np.log2(2 * win_length))))
+     n_fft_by_2 = int(np.multiply(0.5, n_fft))
+     # Value suggested by Klatt, pg 1280
+     k_max = 20.0
+     # Value suggested by Klatt, pg 1280
+     k_loc_max = 1.0
+
+     # Critical band filter definitions (center frequency and bandwidths in Hz)
+     cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000,
+                           540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30,
+                           1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71,
+                           2701.97, 2978.04, 3276.17, 3597.63])
+     bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000,
+                           77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423,
+                           153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255,
+                           276.072, 298.126, 321.465, 346.136])
+
+     # Minimum critical bandwidth
+     bw_min = bandwidth[0]
+
+     # Set up the critical band filters.
+     # Note that Gaussian-shaped filters are used.
+     # The sum of the filter weights is the same for each critical band filter.
+     # Filter values below -30 dB are set to zero.
+
+     # -30 dB point of filter
+     min_factor = np.exp(-30.0 / (2.0 * 2.303))
+     crit_filter = np.empty((num_crit, n_fft_by_2))
+     for i in range(num_crit):
+         f0 = (cent_freq[i] / max_freq) * n_fft_by_2
+         bw = (bandwidth[i] / max_freq) * n_fft_by_2
+         norm_factor = np.log(bw_min) - np.log(bandwidth[i])
+         j = np.arange(n_fft_by_2)
+         crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor)
+         cond = np.greater(crit_filter[i, :], min_factor)
+         crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0)
+
+     # For each frame of input speech, calculate the weighted spectral slope measure
+     num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     distortion = np.empty(num_frames)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start: start + win_length] / 32768
+         hypothesis_frame = hypothesis[start: start + win_length] / 32768
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Compute the power spectrum of reference and hypothesis
+         reference_spec = np.square(np.abs(fft(reference_frame, n_fft)))
+         hypothesis_spec = np.square(np.abs(fft(hypothesis_frame, n_fft)))
+
+         # (3) Compute filter bank output energies (in dB scale)
+         reference_energy = np.matmul(crit_filter, reference_spec[0:n_fft_by_2])
+         hypothesis_energy = np.matmul(crit_filter, hypothesis_spec[0:n_fft_by_2])
+
+         reference_energy = 10 * np.log10(np.maximum(reference_energy, 1E-10))
+         hypothesis_energy = 10 * np.log10(np.maximum(hypothesis_energy, 1E-10))
+
+         # (4) Compute spectral slope (dB[i+1]-dB[i])
+         reference_slope = reference_energy[1:num_crit] - reference_energy[0: num_crit - 1]
+         hypothesis_slope = hypothesis_energy[1:num_crit] - hypothesis_energy[0: num_crit - 1]
+
+         # (5) Find the nearest peak locations in the spectra to each critical band.
+         #     If the slope is negative, search to the left; if positive, search to the right.
+         reference_loc_peak = np.empty(num_crit - 1)
+         hypothesis_loc_peak = np.empty(num_crit - 1)
+
+         for i in range(num_crit - 1):
+             # Find the peaks in the reference speech signal
+             if reference_slope[i] > 0:
+                 # Search to the right
+                 n = i
+                 while (n < num_crit - 1) and (reference_slope[n] > 0):
+                     n = n + 1
+                 reference_loc_peak[i] = reference_energy[n - 1]
+             else:
+                 # Search to the left
+                 n = i
+                 while (n >= 0) and (reference_slope[n] <= 0):
+                     n = n - 1
+                 reference_loc_peak[i] = reference_energy[n + 1]
+
+             # Find the peaks in the hypothesis speech signal
+             if hypothesis_slope[i] > 0:
+                 # Search to the right
+                 n = i
+                 while (n < num_crit - 1) and (hypothesis_slope[n] > 0):
+                     n = n + 1
+                 hypothesis_loc_peak[i] = hypothesis_energy[n - 1]
+             else:
+                 # Search to the left
+                 n = i
+                 while (n >= 0) and (hypothesis_slope[n] <= 0):
+                     n = n - 1
+                 hypothesis_loc_peak[i] = hypothesis_energy[n + 1]
+
+         # (6) Compute the weighted spectral slope measure for this frame.
+         #     This includes determination of the weighting function.
+         db_max_reference = np.max(reference_energy)
+         db_max_hypothesis = np.max(hypothesis_energy)
+
+         # The weights are calculated by averaging individual weighting factors from the reference and hypothesis
+         # frame. These weights w_reference and w_hypothesis should range from 0 to 1 and place more emphasis on
+         # spectral peaks and less emphasis on slope differences in spectral valleys.
+         # This procedure is described on page 1280 of Klatt's 1982 ICASSP paper.
+         w_max_reference = np.divide(k_max, k_max + db_max_reference - reference_energy[0: num_crit - 1])
+         w_loc_max_reference = np.divide(k_loc_max, k_loc_max + reference_loc_peak - reference_energy[0: num_crit - 1])
+         w_reference = np.multiply(w_max_reference, w_loc_max_reference)
+
+         w_max_hypothesis = np.divide(k_max, k_max + db_max_hypothesis - hypothesis_energy[0: num_crit - 1])
+         w_loc_max_hypothesis = np.divide(k_loc_max,
+                                          k_loc_max + hypothesis_loc_peak - hypothesis_energy[0: num_crit - 1])
+         w_hypothesis = np.multiply(w_max_hypothesis, w_loc_max_hypothesis)
+
+         # Scaling by the sum of the weights is not part of Klatt's paper, but helps normalize the measure.
+         w = np.divide(np.add(w_reference, w_hypothesis), 2.0)
+         slope_diff = np.subtract(reference_slope, hypothesis_slope)[0: num_crit - 1]
+         distortion[frame_count] = np.dot(w, np.square(slope_diff)) / np.sum(w)
+
+         start = start + skip_rate
+
+     return distortion
+
+
+ def _calc_log_likelihood_ratio_measure(hypothesis: np.ndarray,
+                                        reference: np.ndarray,
+                                        sample_rate: int = SAMPLE_RATE) -> np.ndarray:
+     from scipy.linalg import toeplitz
+
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = np.size(reference)
+     hypothesis_length = np.size(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError('Hypothesis and reference must be the same length.')
+
+     # Window length in samples
+     win_length = int(np.round(30 * sample_rate / 1000))
+     # Window skip in samples
+     skip_rate = int(np.floor(win_length / 4))
+     # LPC analysis order; this could vary depending on sampling frequency.
+     if sample_rate < 10000:
+         p = 10
+     else:
+         p = 16
+
+     # For each frame of input speech, calculate the log likelihood ratio
+     num_frames = int((reference_length - win_length) / skip_rate)
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     distortion = np.empty(num_frames)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start: start + win_length]
+         hypothesis_frame = hypothesis[start: start + win_length]
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
+         r_reference, ref_reference, a_reference = _lp_coefficients(reference_frame, p)
+         r_hypothesis, ref_hypothesis, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
+
+         # (3) Compute the log likelihood ratio measure
+         numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
+         denominator = np.dot(np.matmul(a_reference, toeplitz(r_reference)), a_reference)
+         distortion[frame_count] = np.log(numerator / denominator)
+         start = start + skip_rate
+
+     return distortion
+
+
+ def _calc_snr(hypothesis: np.ndarray,
+               reference: np.ndarray,
+               sample_rate: int = SAMPLE_RATE) -> tuple[float, np.ndarray]:
+     # The lengths of the reference and hypothesis must be the same.
+     reference_length = len(reference)
+     hypothesis_length = len(hypothesis)
+     if reference_length != hypothesis_length:
+         raise ValueError('Hypothesis and reference must be the same length.')
+
+     overall_snr = 10 * np.log10(np.sum(np.square(reference)) / np.sum(np.square(reference - hypothesis)))
+
+     # Window length in samples
+     win_length = round(30 * sample_rate / 1000)
+     # Window skip in samples
+     skip_rate = int(np.floor(win_length / 4))
+     # Minimum SNR in dB
+     min_snr = -10
+     # Maximum SNR in dB
+     max_snr = 35
+
+     # For each frame of input speech, calculate the segmental SNR
+     num_frames = int(reference_length / skip_rate - (win_length / skip_rate))
+     start = 0
+     window = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, win_length + 1) / (win_length + 1)))
+
+     segmental_snr = np.empty(num_frames)
+     eps = np.spacing(1)
+     for frame_count in range(num_frames):
+         # (1) Get the frames for the test and reference speech. Multiply by Hanning window.
+         reference_frame = reference[start:start + win_length]
+         hypothesis_frame = hypothesis[start:start + win_length]
+         reference_frame = np.multiply(reference_frame, window)
+         hypothesis_frame = np.multiply(hypothesis_frame, window)
+
+         # (2) Compute the segmental SNR
+         signal_energy = np.sum(np.square(reference_frame))
+         noise_energy = np.sum(np.square(reference_frame - hypothesis_frame))
+         segmental_snr[frame_count] = np.clip(10 * np.log10(signal_energy / (noise_energy + eps) + eps),
+                                              min_snr,
+                                              max_snr)
+
+         start = start + skip_rate
+
+     return overall_snr, segmental_snr
+
+
+ def _lp_coefficients(speech_frame, model_order):
+     # (1) Compute autocorrelation lags
+     win_length = np.size(speech_frame)
+     autocorrelation = np.empty(model_order + 1)
+     e = np.empty(model_order + 1)
+     for k in range(model_order + 1):
+         autocorrelation[k] = np.dot(speech_frame[0:win_length - k], speech_frame[k: win_length])
+
+     # (2) Levinson-Durbin recursion
+     a = np.ones(model_order)
+     a_past = np.empty(model_order)
+     ref_coefficients = np.empty(model_order)
+     e[0] = autocorrelation[0]
+     for i in range(model_order):
+         a_past[0: i] = a[0: i]
+         sum_term = np.dot(a_past[0: i], autocorrelation[i:0:-1])
+         ref_coefficients[i] = (autocorrelation[i + 1] - sum_term) / e[i]
+         a[i] = ref_coefficients[i]
+         if i == 0:
+             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], ref_coefficients[i])
+         else:
+             a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], ref_coefficients[i])
+         e[i + 1] = (1 - ref_coefficients[i] * ref_coefficients[i]) * e[i]
+     lp_params = np.concatenate((np.array([1]), -a))
+     return autocorrelation, ref_coefficients, lp_params
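
The composite measures above follow the standard Hu and Loizou regressions: c_sig, c_bak, and c_ovl are linear combinations of PESQ, the trimmed LLR and WSS means, and (for c_bak) the mean segmental SNR, each clipped to the 1-5 MOS range. A minimal usage sketch follows; the new module's path is not shown in the diff, so the import location is an assumption, and the synthetic signals merely stand in for real speech at the SonusAI rate:

    import numpy as np

    from sonusai.metrics.calc_speech import calc_speech  # assumed location of the new file

    # Placeholder inputs: a "clean" reference and a lightly perturbed estimate of it.
    # Real usage would load matched-length speech audio; mismatched lengths raise ValueError.
    rng = np.random.default_rng(0)
    reference = rng.normal(size=16000).astype(np.float32)
    hypothesis = reference + 0.01 * rng.normal(size=16000).astype(np.float32)

    metrics = calc_speech(hypothesis=hypothesis, reference=reference)

    # SpeechMetrics is a NamedTuple, so fields are available by name or position
    print(metrics.pesq, metrics.c_sig, metrics.c_bak, metrics.c_ovl)
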
@@ -1,8 +1,7 @@
- from dataclasses import dataclass
+ from typing import NamedTuple


- @dataclass(frozen=True)
- class WerResult:
+ class WerResult(NamedTuple):
      wer: float
      words: int
      substitutions: float
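
Switching WerResult from a frozen dataclass to a NamedTuple keeps it immutable while also making it a real tuple, so results can be unpacked, indexed, and iterated. A small sketch of the behavioral difference, abbreviated to the fields visible in this hunk and with made-up values:

    from typing import NamedTuple


    class WerResult(NamedTuple):
        wer: float
        words: int
        substitutions: float


    result = WerResult(wer=0.12, words=250, substitutions=18.0)

    print(result.wer)                    # attribute access, as with the frozen dataclass
    wer, words, substitutions = result   # tuple unpacking, new with NamedTuple
    print(result[0] == wer)              # indexing also works now
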
@@ -52,62 +52,3 @@ def calc_wsdr(hypothesis: np.ndarray,
      wsdr = 10 * np.log10(-1 / (wsdr - 1 - 1e-7))  # range -3 --> inf (or 1e-7 limit of 70db)

      return float(wsdr), cc, cw
-
- # From calc_sa_sdr:
- # These should include a noise to be a complete mixture estimate, i.e.,
- # noise_est = sum-over-all-srcs(s_est(0:nsamples, :) - sum-over-non-noisesrc(s_est(0:nsamples, n))
- # should be one of the sources in reference (s_true) and hypothesis (s_est).
- #
- # Calculates -10*log10(sumn(||sn||^2) / sumn(||sn - shn||^2)
- # Note: for SA method, sums are done independently on ref and error before division, vs. SDR and SI-SDR
- # where sum over n is taken after divide (before log). This is more stable in noise-only cases and also
- # when some sources are poorly estimated.
- # TBD: add soft-max option with eps and tau params
- #
- # if with_scale:
- #     # calc 1 x nsrc scaling factors
- #     ref_energy = np.sum(reference ** 2, axis=0, keepdims=True)
- #     # if ref_energy is zero, just set scaling to 1.0
- #     with np.errstate(divide='ignore', invalid='ignore'):
- #         opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
- #     opt_scale[opt_scale == np.inf] = 1.0
- #     opt_scale = np.nan_to_num(opt_scale, nan=1.0)
- #     scaled_ref = opt_scale * reference
- # else:
- #     scaled_ref = reference
- #     opt_scale = np.ones((1, reference.shape[1]), dtype=float)
- #
- # # Calculate Lsdr = -<y,ŷ>/(||y|| ||ŷ||), always in range [1 --> -1], size [batch,]
- # t_tru_sq = torch.sum(torch.square(t_tru), -1)
- # t_denom = torch.sqrt(t_tru_sq) * torch.sqrt(torch.sum(torch.square(t_est), -1)) + 1e-7
- # t_wsdr = -torch.divide(torch.sum(torch.multiply(t_tru, t_est), -1), t_denom)
- # n_tru_sq = torch.sum(torch.square(n_tru), -1)
- # n_denom = torch.sqrt(torch.sum(torch.square(n_tru), -1)) \
- #           * torch.sqrt(torch.sum(torch.square(n_est), -1)) + 1e-7
- # n_wsdr = -torch.divide(torch.sum(torch.multiply(n_tru, n_est), -1), n_denom)
- # if self.cl_noise_wght > 0:
- #     wsdr = self.cl_target_wght * t_wsdr + self.cl_noise_wght * n_wsdr
- # else:  # adaptive per relative strength of target vs noise: alpha = ||y||^2 / (||y||^2 + ||z||^2)
- #     tweight = torch.divide(t_tru_sq, t_tru_sq + n_tru_sq + 1e-7)  # energy ratio target vs. noise
- #     wsdr = tweight * t_wsdr + (1 - tweight) * n_wsdr
- # wsdr = torch.mean(wsdr)  # reduction to scalar
- #
- # # multisrc sa-sdr, inputs must be [samples, nsrc]
- # err = scaled_ref - hypothesis
- #
- # # -10*log10(sumk(||sk||^2) / sumk(||sk - shk||^2)
- # # sum over samples and sources
- # num = np.sum(reference ** 2)
- # den = np.sum(err ** 2)
- # if num == 0 and den == 0:
- #     ratio = np.inf
- # else:
- #     ratio = num / (den + np.finfo(np.float32).eps)
- #
- # sa_sdr = 10 * np.log10(ratio)
- #
- # if with_negate:
- #     # for use as a loss function
- #     sa_sdr = -sa_sdr
- #
- # return sa_sdr, opt_scale
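
For reference, the mapping on the surviving line above takes wsdr, a weighted negative cosine similarity in [-1, 1], to a dB-like scale: wsdr - 1 - 1e-7 lies in [-2 - 1e-7, -1e-7], so -1 / (wsdr - 1 - 1e-7) lies in roughly [0.5, 1e7], and 10 * log10 of that spans about -3 dB up to the 70 dB ceiling set by the 1e-7 guard. A quick endpoint check:

    import numpy as np

    def wsdr_to_db(wsdr: float) -> float:
        # Same mapping as the final line of calc_wsdr
        return 10 * np.log10(-1 / (wsdr - 1 - 1e-7))

    print(wsdr_to_db(-1.0))  # worst case: about -3.01 dB
    print(wsdr_to_db(1.0))   # best case: 70.0 dB, limited by the 1e-7 guard
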
@@ -44,6 +44,7 @@ from .constants import VALID_CONFIGS
  from .constants import VALID_NOISE_MIX_MODES
  from .constants import VALID_TRUTH_SETTINGS
  from .datatypes import AudioF
+ from .datatypes import AudioStatsMetrics
  from .datatypes import AudioT
  from .datatypes import AudiosF
  from .datatypes import AudiosT
@@ -72,9 +73,11 @@ from .datatypes import NoiseFile
  from .datatypes import NoiseFiles
  from .datatypes import Predict
  from .datatypes import Segsnr
+ from .datatypes import SnrFMetrics
  from .datatypes import SpectralMask
  from .datatypes import SpectralMasks
  from .datatypes import SpeechMetadata
+ from .datatypes import SpeechMetrics
  from .datatypes import TargetFile
  from .datatypes import TargetFiles
  from .datatypes import TransformConfig
@@ -113,8 +116,6 @@ from .helpers import read_mixture_data
  from .helpers import write_mixture_data
  from .helpers import write_mixture_metadata
  from .log_duration_and_sizes import log_duration_and_sizes
- from .mapped_snr_f import calculate_mapped_snr_f
- from .mapped_snr_f import calculate_snr_f_statistics
  from .mixdb import MixtureDatabase
  from .mixdb import db_file
  from .sox_audio import Transformer
sonusai/mixture/audio.py CHANGED
@@ -1,4 +1,5 @@
  from functools import lru_cache
+ from pathlib import Path

  from sonusai.mixture.datatypes import AudioT
  from sonusai.mixture.datatypes import ImpulseResponseData
@@ -28,7 +29,7 @@ def get_duration(audio: AudioT) -> float:
      return len(audio) / SAMPLE_RATE


- def validate_input_file(input_filepath: str) -> None:
+ def validate_input_file(input_filepath: str | Path) -> None:
      from os.path import exists
      from os.path import splitext

@@ -46,7 +47,7 @@ def validate_input_file(input_filepath: str) -> None:


  @lru_cache
- def get_sample_rate(name: str) -> int:
+ def get_sample_rate(name: str | Path) -> int:
      """Get sample rate from audio file

      :param name: File name
@@ -58,7 +59,7 @@ def get_sample_rate(name: str) -> int:


  @lru_cache
- def read_audio(name: str) -> AudioT:
+ def read_audio(name: str | Path) -> AudioT:
      """Read audio data from a file

      :param name: File name
@@ -70,7 +71,7 @@ def read_audio(name: str) -> AudioT:


  @lru_cache
- def read_ir(name: str) -> ImpulseResponseData:
+ def read_ir(name: str | Path) -> ImpulseResponseData:
      """Read impulse response data

      :param name: File name
@@ -82,7 +83,7 @@ def read_ir(name: str) -> ImpulseResponseData:


  @lru_cache
- def get_num_samples(name: str) -> int:
+ def get_num_samples(name: str | Path) -> int:
      """Get the number of samples resampled to the SonusAI sample rate in the given file

      :param name: File name
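
Widening these signatures to str | Path works cleanly under @lru_cache because pathlib.Path is hashable; note, however, that a str and a Path naming the same file are unequal as cache keys, so they occupy separate entries. A standalone sketch of that caching behavior (not sonusai code; the file name is hypothetical):

    from functools import lru_cache
    from pathlib import Path


    @lru_cache
    def describe(name: str | Path) -> str:
        print(f'cache miss for {name!r}')
        return str(name)


    describe('audio.wav')        # miss
    describe('audio.wav')        # hit: same str key
    describe(Path('audio.wav'))  # miss: Path('audio.wav') != 'audio.wav' as a key
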
sonusai/mixture/config.py CHANGED
@@ -90,6 +90,19 @@ def update_config_from_file(name: str, config: dict) -> dict:

      updated_config['truth_settings'] = update_truth_settings(updated_config['truth_settings'], default)

+     # Handle 'asr_configs' special case
+     if 'asr_configs' in updated_config:
+         asr_configs = {}
+         for asr_config in updated_config['asr_configs']:
+             asr_name = asr_config.get('name', None)
+             asr_engine = asr_config.get('engine', None)
+             if asr_name is None or asr_engine is None:
+                 raise SonusAIError(f'Invalid config parameter in {name}: asr_configs.\n'
+                                    f'asr_configs must contain both name and engine.')
+             del asr_config['name']
+             asr_configs[asr_name] = asr_config
+         updated_config['asr_configs'] = asr_configs
+
      # Check for required keys
      for key in REQUIRED_CONFIGS:
          if key not in updated_config:
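
The effect of this special case is to reshape asr_configs from the list-of-dicts form written in a config file into a dict keyed by name, with the name key removed from each entry. A sketch with hypothetical names and engines:

    # As parsed from a config file: a list of dicts, each carrying 'name' and 'engine'
    updated_config = {
        'asr_configs': [
            {'name': 'asr1', 'engine': 'engine_a', 'model': 'tiny'},
            {'name': 'asr2', 'engine': 'engine_b'},
        ]
    }

    # The same reshaping the loop above performs
    asr_configs = {}
    for asr_config in updated_config['asr_configs']:
        asr_configs[asr_config.pop('name')] = asr_config

    print(asr_configs)
    # {'asr1': {'engine': 'engine_a', 'model': 'tiny'}, 'asr2': {'engine': 'engine_b'}}
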
@@ -4,6 +4,7 @@ from importlib.resources import as_file
  from importlib.resources import files

  REQUIRED_CONFIGS = [
+     'asr_configs',
      'class_balancing',
      'class_balancing_augmentation',
      'class_labels',
@@ -1,4 +1,6 @@
  from dataclasses import dataclass
+ from typing import Any
+ from typing import NamedTuple
  from typing import Optional
  from typing import TypeAlias

@@ -309,8 +311,12 @@ class FeatureGeneratorInfo:
      it_config: TransformConfig


+ ASRConfigs: TypeAlias = dict[str, dict[str, Any]]
+
+
  @dataclass
  class MixtureDatabaseConfig(DataClassSonusAIMixin):
+     asr_configs: Optional[ASRConfigs] = None
      class_balancing: Optional[bool] = False
      class_labels: Optional[list[str]] = None
      class_weights_threshold: Optional[list[float]] = None
@@ -327,3 +333,30 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):


  SpeechMetadata: TypeAlias = str | list[Interval] | None
+
+
+ class SnrFMetrics(NamedTuple):
+     mean: Optional[float] = None
+     var: Optional[float] = None
+     db_mean: Optional[float] = None
+     db_std: Optional[float] = None
+
+
+ class SpeechMetrics(NamedTuple):
+     pesq: Optional[float] = None
+     c_sig: Optional[float] = None
+     c_bak: Optional[float] = None
+     c_ovl: Optional[float] = None
+
+
+ class AudioStatsMetrics(NamedTuple):
+     dco: Optional[float] = None
+     min: Optional[float] = None
+     max: Optional[float] = None
+     pkdb: Optional[float] = None
+     lrms: Optional[float] = None
+     pkr: Optional[float] = None
+     tr: Optional[float] = None
+     cr: Optional[float] = None
+     fl: Optional[float] = None
+     pkc: Optional[float] = None
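
Because every field of these NamedTuples defaults to None, a caller can construct one before any metric is computed and fill in only what a given code path produces; _replace gives an updated copy, since the tuples are immutable. For example, with SpeechMetrics as defined above:

    from typing import NamedTuple, Optional


    class SpeechMetrics(NamedTuple):
        pesq: Optional[float] = None
        c_sig: Optional[float] = None
        c_bak: Optional[float] = None
        c_ovl: Optional[float] = None


    empty = SpeechMetrics()             # all fields default to None
    partial = SpeechMetrics(pesq=3.2)   # fill only what was computed
    print(partial.c_sig is None)        # True
    print(partial._replace(c_sig=4.1))  # immutable: _replace returns a new tuple
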
@@ -59,6 +59,7 @@ def initialize_db(location: str, test: bool = False) -> None:
      CREATE TABLE top (
          id INTEGER PRIMARY KEY NOT NULL,
          version INTEGER NOT NULL,
+         asr_configs TEXT NOT NULL,
          class_balancing BOOLEAN NOT NULL,
          feature TEXT NOT NULL,
          noise_mix_mode TEXT NOT NULL,
@@ -149,6 +150,8 @@ def initialize_db(location: str, test: bool = False) -> None:
  def populate_top_table(location: str, config: dict, test: bool = False) -> None:
      """Populate top table
      """
+     import json
+
      from sonusai import SonusAIError
      from .mixdb import db_connection

@@ -158,11 +161,12 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:

      con = db_connection(location=location, readonly=False, test=test)
      con.execute("""
-         INSERT INTO top (version, class_balancing, feature, noise_mix_mode, num_classes,
+         INSERT INTO top (version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
              seed, truth_mutex, truth_reduction_function, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
-         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
          """, (
          1,
+         json.dumps(config['asr_configs']),
          config['class_balancing'],
          config['feature'],
          config['noise_mix_mode'],
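
The new asr_configs column is declared TEXT, so the dict is serialized with json.dumps on insert; a reader would be expected to json.loads the value back into the ASRConfigs mapping. A minimal round-trip sketch against an in-memory database (hypothetical values, reduced schema):

    import json
    import sqlite3

    con = sqlite3.connect(':memory:')
    con.execute('CREATE TABLE top (id INTEGER PRIMARY KEY NOT NULL, asr_configs TEXT NOT NULL)')

    asr_configs = {'asr1': {'engine': 'engine_a', 'model': 'tiny'}}
    con.execute('INSERT INTO top (asr_configs) VALUES (?)', (json.dumps(asr_configs),))

    row = con.execute('SELECT asr_configs FROM top').fetchone()
    assert json.loads(row[0]) == asr_configs  # the TEXT column round-trips the dict
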