auditory-models 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ # This file is part of auditory_models
2
+ # Copyright (C) 2025 Max Zimmermann
3
+ #
4
+ # auditory_models is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # auditory_models is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
16
+
17
+
18
+ from .stoi import STOI
19
+ from .gpsmq import GPSMq
@@ -0,0 +1,18 @@
1
+ # This file is part of auditory_models
2
+ # Copyright (C) 2025 Max Zimmermann
3
+ #
4
+ # auditory_models is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # auditory_models is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
16
+
17
+
18
+ from .gpsmq import GPSMq
@@ -0,0 +1,349 @@
1
+ # This file is part of auditory_models
2
+ # Copyright (C) 2025 Max Zimmermann
3
+ #
4
+ # auditory_models is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # auditory_models is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
16
+
17
+
18
+ import numpy as np
19
+ from numpy.typing import NDArray
20
+ from scipy.signal import butter, fftconvolve, firwin, sosfilt
21
+ from warnings import warn
22
+
23
+ from auditory_models.helpers.utils import thirdoct, iso389_7_thresholds
24
+ from auditory_models.helpers.filterbank import BandpassFilterbank, GammatoneFilterbank
25
+
26
+
27
class GPSMq:
    """
    Perceptual model comparing a degraded signal with a reference signal (see `process`).

    Processing chain as implemented below: gammatone filterbank + Hilbert envelope, ILD cancellation (binaural
    only), 150 Hz envelope lowpass, decimation, modulation filterbank, multi-resolution envelope/short-time power,
    SNR increments/decrements, and a final linear mapping of the log-SNR onto the MUSHRA scale.
    """

    def __init__(self, binaural: bool = True, aud_filt_range: tuple[int, int] = (7, 24),
                 mod_filt_range: tuple[int, int] = (1, 7), corr_thres: float = 0.8, decimation_factor: int = 8,
                 limits: tuple[float, float] = (-65, -100), threshold_scaling: float = 1e-10):
        """
        Init method
        :param binaural: Determines if binaural processing should be used, if True the input of process() must have
            shapes of (2, signal_length).
        :param aud_filt_range: The range of Gammatone-Filterbank center-frequencies as (start, stop) slice indices
            (stop exclusive) into a vector of frequencies in third-octave distances, starting with 63Hz and ending
            with 16000Hz as defined by IEC 61260-1:2014.
        :param mod_filt_range: The range of Modulation-Filterbank center-frequencies as exponents of 2 (stop
            exclusive). The default of (1, 7) would result in [2, 4, 8, 16, 32, 64] Hz as center-frequencies.
        :param corr_thres: Midpoint of the sigmoid function applied to the correlation matrix in process(), i.e.
            the correlation value that receives a weight of 0.5.
        :param decimation_factor: Factor of decimation of the lowpass filtered Hilbert envelope signals.
        :param limits: Limits in dB for the multi resolution based power processing. limits[0] sets the level
            (relative to the unscaled ISO threshold) above which the dc2mod weight is 1; limits[0] - limits[1] is
            the dB range of the linear dc2mod ramp.
        :param threshold_scaling: Scaling factor for the squared hearing threshold. Default value is chosen with the
            assumption that 0 dB RMS equals 100 dB SPL.
        """
        self._sample_rate = 0
        self._binaural = binaural
        self._n_chan = self._binaural + 1  # 2 channels when binaural, 1 otherwise
        self._aud_filt_range = aud_filt_range
        self._mod_filt_range = mod_filt_range
        self._corr_thres = corr_thres
        self._decimation_factor = decimation_factor
        self._limits = limits
        self._threshold_scaling = threshold_scaling
        # slope (per dB) of the linear dc2mod ramp used in multi_resolution_based_power
        self._slope = 1 / (self._limits[0] - self._limits[1])
        self._env_lp_sos = None
        self._sig_len = 0
        self._sig_len_dec = 0
        # lower bound for the envelope power (10 ** -2.7, i.e. -27 dB)
        self._env_pow_lim = 10 ** -2.7

        self._decimation_order = 20 * self._decimation_factor
        self.decimation_filter = None
        self._sample_rate_dec = None
        self.gtfb = None
        # np.power instead of np.pow: np.pow only exists from NumPy 2.0 on
        self.mf_cf = np.power(2, np.arange(*self._mod_filt_range))
        self.mfb = None
        self._iso_thres = None
        self._upper_lim = None

    def _recompute_properties(self) -> None:
        """
        When process receives a sample_rate value that is different from the previous one, recompute all dependent
        properties (envelope lowpass, decimation filter, filterbanks and hearing-threshold limits).
        :return: None
        """
        self._env_lp_sos = butter(1, 150, fs=self._sample_rate, output="sos")
        # cutoff frequency for decimation filter is computed with the decimation factor and an extra factor for the
        # filter to be effective at the new nyquist frequency
        cutoff = self._sample_rate / (2 * self._decimation_factor) * (5 / 6)
        self.decimation_filter = firwin(self._decimation_order + 1, cutoff, fs=self._sample_rate, window=("kaiser", 5.),
                                        pass_zero=True)
        self._sample_rate_dec = self._sample_rate / self._decimation_factor
        gt_cf = thirdoct(self._sample_rate, nfft=1024, min_freq=63, max_freq=16000)[1]
        self.gtfb = GammatoneFilterbank(self._sample_rate, cf=gt_cf[self._aud_filt_range[0]:self._aud_filt_range[1]])
        self.mfb = BandpassFilterbank(self.mf_cf, self._sample_rate_dec)
        self._iso_thres = 10 ** (iso389_7_thresholds(self.gtfb.cf) / 10)
        # the upper limit is derived from the unscaled threshold; the scaling is applied afterwards on purpose
        self._upper_lim = self._iso_thres * 10 ** (self._limits[0] / 10)
        self._iso_thres *= self._threshold_scaling

    def process(self, reference: np.ndarray, degraded: np.ndarray, sample_rate: float) -> dict:
        """
        Process method to calculate the perceptual proximity of a degraded signal to a reference signal.
        :param reference: Reference signal, if binaural it must have shape (2, signal_length)
        :param degraded: Degraded signal, if binaural it must have shape (2, signal_length)
        :param sample_rate: sample rate of both input signals in Hz
        :raises ValueError: if the input shapes do not match the binaural/monaural configuration or each other
        :return: dictionary containing measurement data
            "snr_dc": SNR for the DC part
            "snr_ac": SNR for the modulation part
            "snr_ac_fix": SNR for the modulation part including weighting to reduce effects of IPDs/ITDs
            "opm": combined snr_dc and snr_ac into perceptual measure (MUSHRA scale)
            "opm_fix": combined snr_dc and snr_ac_fix into perceptual measure (MUSHRA scale)
        """
        reference = np.squeeze(reference)
        degraded = np.squeeze(degraded)
        if self._binaural:
            if reference.ndim != 2:
                raise ValueError(f"reference must have two dimensions, currently: {reference.ndim}")
            if reference.shape[0] != 2:
                raise ValueError(f"reference must have a size of 2 in first dimension, currently shape is: "
                                 f"{reference.shape}")
            if degraded.ndim != 2:
                raise ValueError(f"degraded must have two dimensions, currently: {degraded.ndim}")
            if degraded.shape[0] != 2:
                raise ValueError(f"degraded must have a size of 2 in first dimension, currently shape is: "
                                 f"{degraded.shape}")
            if reference.shape[1] != degraded.shape[1]:
                raise ValueError(f"reference and degraded must have same signal lengths, currently "
                                 f"{reference.shape[1]} and {degraded.shape[1]}.")
        else:
            if reference.ndim != 1:
                raise ValueError(f"reference must have one dimension, currently: {reference.ndim}")
            if degraded.ndim != 1:
                raise ValueError(f"degraded must have one dimension, currently: {degraded.ndim}")
            if reference.size != degraded.size:
                raise ValueError(f"reference and degraded must have same signal lengths, currently "
                                 f"{reference.size} and {degraded.size}.")
            # from here on both cases are handled as (channels, signal_length)
            reference = reference[np.newaxis, :]
            degraded = degraded[np.newaxis, :]

        if self._sig_len != reference.shape[-1]:
            self._sig_len = reference.shape[-1]
            self._sig_len_dec = round(np.ceil(self._sig_len / self._decimation_factor))

        if sample_rate != self._sample_rate:
            self._sample_rate = sample_rate
            self._recompute_properties()

        if self._sample_rate / 2 < self.gtfb.cf[-1]:
            warn(f"Given sample rate ({self._sample_rate}) does not fit Nyquist Theorem with the center-frequency of "
                 f"the highest band ({self.gtfb.cf[-1]}).")

        n_aud = self.gtfb.cf.size
        n_mod = self.mfb.n_filters

        # Auditory filtering via Gammatone Filterbank and Hilbert Envelope
        ref_hilbert = np.zeros((self._n_chan, n_aud, self._sig_len))
        dgr_hilbert = np.zeros((self._n_chan, n_aud, self._sig_len))
        for ch in range(self._n_chan):
            ref_hilbert[ch, :, :] = np.abs(
                self.gtfb.process(reference[ch, :].astype(np.complex128), save_state=False)) / np.sqrt(2)
            dgr_hilbert[ch, :, :] = np.abs(
                self.gtfb.process(degraded[ch, :].astype(np.complex128), save_state=False)) / np.sqrt(2)

        # ILD cancellation: per band, rescale envelopes so level differences between ears do not dominate
        if self._binaural:
            ref_rms = np.sqrt(np.mean(np.square(ref_hilbert), axis=2))
            dgr_rms = np.sqrt(np.mean(np.square(dgr_hilbert), axis=2))
            ref_lr_diff = np.squeeze(np.abs(np.diff(ref_rms, axis=0)))
            dgr_lr_diff = np.squeeze(np.abs(np.diff(dgr_rms, axis=0)))
            ref_dgr_l_diff = ref_rms[0, :] - dgr_rms[0, :]
            ref_dgr_r_diff = ref_rms[1, :] - dgr_rms[1, :]
            for band in range(n_aud):
                if np.sign(ref_dgr_l_diff[band]) * np.sign(ref_dgr_r_diff[band]) == -1:
                    # level changes in opposite directions on the two ears: align the quieter signal per ear
                    if ref_dgr_l_diff[band] > 0:
                        dgr_hilbert[0, band, :] *= ref_rms[0, band] / dgr_rms[0, band]
                    elif ref_dgr_l_diff[band] < 0:
                        ref_hilbert[0, band, :] *= dgr_rms[0, band] / ref_rms[0, band]

                    if ref_dgr_r_diff[band] > 0:
                        dgr_hilbert[1, band, :] *= ref_rms[1, band] / dgr_rms[1, band]
                    elif ref_dgr_r_diff[band] < 0:
                        ref_hilbert[1, band, :] *= dgr_rms[1, band] / ref_rms[1, band]
                else:
                    # same direction: equalize the ear with the larger change to the other ear
                    if np.abs(ref_dgr_l_diff[band]) > np.abs(ref_dgr_r_diff[band]):
                        if ref_lr_diff[band] > dgr_lr_diff[band]:
                            ref_hilbert[0, band, :] *= ref_rms[1, band] / ref_rms[0, band]
                        elif ref_lr_diff[band] < dgr_lr_diff[band]:
                            dgr_hilbert[0, band, :] *= dgr_rms[1, band] / dgr_rms[0, band]
                    elif np.abs(ref_dgr_l_diff[band]) < np.abs(ref_dgr_r_diff[band]):
                        if ref_lr_diff[band] > dgr_lr_diff[band]:
                            ref_hilbert[1, band, :] *= ref_rms[0, band] / ref_rms[1, band]
                        elif ref_lr_diff[band] < dgr_lr_diff[band]:
                            dgr_hilbert[1, band, :] *= dgr_rms[0, band] / dgr_rms[1, band]

        # apply Lowpass-Filter at 150Hz
        ref_lp = sosfilt(self._env_lp_sos, ref_hilbert, axis=2)
        dgr_lp = sosfilt(self._env_lp_sos, dgr_hilbert, axis=2)

        # decimate lowpass filtered signal (anti-alias FIR, then keep every decimation_factor-th sample)
        fir = self.decimation_filter[np.newaxis, np.newaxis, :]
        ref_lp_dec = fftconvolve(ref_lp, fir, mode="same", axes=-1)[:, :, ::self._decimation_factor]
        dgr_lp_dec = fftconvolve(dgr_lp, fir, mode="same", axes=-1)[:, :, ::self._decimation_factor]

        # apply modulation filterbank
        ref_mod = np.zeros((self._n_chan, n_aud, n_mod, self._sig_len_dec))
        dgr_mod = np.zeros((self._n_chan, n_aud, n_mod, self._sig_len_dec))
        for ch in range(self._n_chan):
            for gt in range(n_aud):
                ref_mod[ch, gt, :, :] = self.mfb.process(ref_lp_dec[ch, gt, :], save_state=False)
                dgr_mod[ch, gt, :, :] = self.mfb.process(dgr_lp_dec[ch, gt, :], save_state=False)

        # calculate multi-resolution-based envelope power and short-time power
        ref_epsm, ref_psm, ref_dc2mod = self.multi_resolution_based_power(ref_lp_dec, ref_mod)
        dgr_epsm, dgr_psm, dgr_dc2mod = self.multi_resolution_based_power(dgr_lp_dec, dgr_mod)

        # calculate correlation matrix (first modulation band excluded); the first fs_dec / 4 samples are skipped,
        # presumably to exclude filter onset transients
        corr_start = round(self._sample_rate_dec / 4) - 1
        corr_mat = np.ones((self._n_chan, n_aud, n_mod - 1))
        for ch in range(self._n_chan):
            for aud_idx in range(n_aud):
                for mf_idx in range(1, n_mod):
                    corr_mat[ch, aud_idx, mf_idx - 1] = np.corrcoef(
                        ref_mod[ch, aud_idx, mf_idx, corr_start::],
                        dgr_mod[ch, aud_idx, mf_idx, corr_start::])[0, 1]

        # psm above iso thresholds
        iso_thres_4d = self._iso_thres[np.newaxis, :, np.newaxis, np.newaxis]
        clipper = np.logical_and(ref_psm > iso_thres_4d, dgr_psm > iso_thres_4d)

        # calculate SNR increment and decrement (ratios limited to [0, 20])
        snr_inc_epsm = np.fmax(np.fmin(dgr_epsm / (ref_epsm + 1e-30) - 1, 20), 0.) * dgr_dc2mod
        snr_inc_epsm[np.logical_not(clipper)] = 0.
        snr_inc_psm = np.fmax(np.fmin(dgr_psm / (ref_psm + 1e-30) - 1, 20), 0.)

        snr_dec_epsm = np.fmax(np.fmin(ref_epsm / (dgr_epsm + 1e-30) - 1, 20), 0.) * ref_dc2mod
        snr_dec_epsm[np.logical_not(clipper)] = 0.
        snr_dec_psm = np.fmax(np.fmin(ref_psm / (dgr_psm + 1e-30) - 1, 20), 0.)

        snr_inc_mod = np.zeros((self._n_chan, n_aud, n_mod))
        snr_inc_stint = np.zeros(snr_inc_mod.shape)
        snr_dec_mod = np.zeros(snr_inc_mod.shape)
        snr_dec_stint = np.zeros(snr_inc_mod.shape)

        # average over the valid windows of each modulation band (window count depends on the band)
        for mf_idx in range(n_mod):
            n_windows = round(np.ceil(self._sig_len_dec / np.floor(self._sample_rate_dec / self.mfb.cf[mf_idx])))
            snr_inc_mod[:, :, mf_idx] = np.mean(snr_inc_epsm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_inc_stint[:, :, mf_idx] = np.mean(snr_inc_psm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_dec_mod[:, :, mf_idx] = np.mean(snr_dec_epsm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_dec_stint[:, :, mf_idx] = np.mean(snr_dec_psm[:, :, mf_idx, 0:n_windows], axis=-1)
        snr_inc_mod = np.fmax(snr_inc_mod, 0.)
        snr_inc_stint = np.fmax(snr_inc_stint, 0.)
        snr_dec_mod = np.fmax(snr_dec_mod, 0.)
        snr_dec_stint = np.fmax(snr_dec_stint, 0.)
        # only keep modulation bands whose center-frequency is below a quarter of the auditory center-frequency
        valid_mod_freq = self.gtfb.cf[:, np.newaxis] > 4 * self.mfb.cf[np.newaxis, :]
        snr_inc_mod *= valid_mod_freq
        snr_dec_mod *= valid_mod_freq

        # combine SNR increment and decrement
        snr_dc = np.mean(np.sqrt(np.sum(np.square((snr_inc_stint[:, :, 2] + snr_dec_stint[:, :, 2]) / 2), axis=-1)))
        inc_dec_mean = (snr_inc_mod[:, :, 1::] + 10 ** (-0.7) * snr_dec_mod[:, :, 1::]) / 2
        snr_ac = np.mean(np.sqrt(np.sum(np.sum(np.square(inc_dec_mean), axis=-1), axis=-1)))

        # include weighting via correlation matrix to make model less sensitive to IPDs/ITDs
        sigmoid_corr = 1 / (1 + np.exp(-50 * (corr_mat - self._corr_thres)))
        snr_ac_fix = np.mean(np.sqrt(np.sum(np.sum(np.square(sigmoid_corr * inc_dec_mean), axis=-1), axis=-1)))

        # transform to perceptual measure (MUSHRA scale)
        opm = -4.08695250298565 * 10 * np.log10(snr_dc + snr_ac + 1e-30) + 75.6467755339438
        opm_fix = -4.15742483856597 * 10 * np.log10(snr_dc + snr_ac_fix + 1e-30) + 74.7791007678067

        return {"snr_dc": snr_dc, "snr_ac": snr_ac, "snr_ac_fix": snr_ac_fix, "opm": opm, "opm_fix": opm_fix}

    def lowpass_filterbank(self, signal: NDArray) -> NDArray:
        """
        Compute output of a lowpass filterbank using moving averages with variable window sizes depending on the
        modulation frequency: 1.5 periods for modulation frequencies up to 8 Hz, a fixed window of
        sample_rate_dec / 8 samples for the first band above 8 Hz, whose output is reused for all higher bands.
        :param signal: Input signal, must be one-dimensional with length equal to the decimated signal length
        :return: two-dimensional matrix with filterbank output with shape (number_of_mod_filters, signal_length)
        """
        out = np.zeros((self.mfb.n_filters, self._sig_len_dec))
        window = np.ones(round(np.ceil(2 * self._sample_rate_dec / np.min(self.mfb.cf))))
        crit_mod_freq = 8.
        idx_crit = 0
        repeat = False
        for idx, mod_freq in enumerate(self.mfb.cf):
            if mod_freq <= crit_mod_freq:
                win_width = round(1.5 * self._sample_rate_dec / mod_freq)
                out[idx, :] = np.convolve(window[0:win_width] / win_width, signal, mode="full")[0:signal.size]
            elif not repeat:
                win_width = round(self._sample_rate_dec / 8)
                out[idx, :] = np.convolve(window[0:win_width] / win_width, signal, mode="full")[0:signal.size]
                idx_crit = idx
                repeat = True
            else:
                out[idx, :] = out[idx_crit, :]
        return out

    def multi_resolution_based_power(self, envelope: NDArray, modulation: NDArray) -> tuple[NDArray, NDArray, NDArray]:
        """
        Calculate the multi resolution based power from an envelope signal and a modulation signal.
        The input arrays are not modified.
        :param envelope: Lowpass filtered envelope signal for different frequency bands, must have shape
            (channels, auditory bands, signal_length)
        :param modulation: Modulation filtered signal for different frequency bands, must have shape
            (channels, auditory bands, modulation bands, signal_length)
        :return: Three arrays containing the envelope power spectrum model (EPSM), power spectrum model (PSM), and a
            correction matrix (dc2mod). Each of these arrays has shape
            (channels, auditory bands, modulation bands, new_length); windows beyond a band's window count keep their
            initial value (0 for EPSM/PSM, 1 for dc2mod).
        """
        epsm = np.zeros((self._n_chan, self.gtfb.cf.size, self.mfb.n_filters,
                         round(np.ceil(self._sig_len_dec / np.floor(self._sample_rate_dec / self.mfb.cf[-1])))))
        psm = np.zeros(epsm.shape)
        dc2mod = np.ones(epsm.shape)
        # scaled copy: an in-place `*=` would silently rescale the caller's array
        modulation = modulation * np.sqrt(2)
        for ch in range(self._n_chan):
            for aud_idx in range(self.gtfb.cf.size):
                iso_thres = self._iso_thres[aud_idx]
                tmp_pow = np.max((np.square(np.mean(envelope[ch, aud_idx, :])), iso_thres))
                lpfb = self.lowpass_filterbank(envelope[ch, aud_idx, :])
                for mf_idx in range(self.mfb.n_filters):
                    samples_per_mf = round(np.floor(self._sample_rate_dec / self.mfb.cf[mf_idx]))
                    n_windows = round(np.ceil(self._sig_len_dec / samples_per_mf))
                    if tmp_pow <= iso_thres:
                        # whole band below threshold: use the lower limits for every window
                        epsm[ch, aud_idx, mf_idx, 0:n_windows] = self._env_pow_lim
                        psm[ch, aud_idx, mf_idx, 0:n_windows] = iso_thres
                        continue
                    # DC power sampled at the end of each window; a trailing partial window uses the last sample
                    pow_dc_seg = lpfb[mf_idx, samples_per_mf-1::samples_per_mf] ** 2
                    if n_windows * samples_per_mf != self._sig_len_dec:
                        pow_dc_seg = np.concatenate((pow_dc_seg, [lpfb[mf_idx, -1] ** 2]))
                    for win_idx in range(n_windows):
                        ac = modulation[ch, aud_idx, mf_idx, win_idx * samples_per_mf:(win_idx + 1) * samples_per_mf]
                        if mf_idx == 0:
                            # first modulation band: variance, i.e. power with the mean removed
                            pow_ac = np.var(ac)
                        else:
                            pow_ac = np.mean(ac ** 2)
                        if pow_dc_seg[win_idx] <= iso_thres:
                            epsm[ch, aud_idx, mf_idx, win_idx] = self._env_pow_lim
                        else:
                            epsm[ch, aud_idx, mf_idx, win_idx] = np.max((pow_ac / pow_dc_seg[win_idx],
                                                                         self._env_pow_lim))
                        psm[ch, aud_idx, mf_idx, win_idx] = np.max((pow_dc_seg[win_idx], iso_thres))
                        if pow_dc_seg[win_idx] > self._upper_lim[aud_idx]:
                            dc2mod[ch, aud_idx, mf_idx, win_idx] = 1.
                        else:
                            # linear ramp in dB above the (scaled) threshold, clipped at 0
                            dc2mod[ch, aud_idx, mf_idx, win_idx] = np.max((
                                self._slope * (10 * np.log10(pow_dc_seg[win_idx] / iso_thres + 1e-30)),
                                0.))
        return epsm, psm, dc2mod

    def export_config(self) -> dict:
        """
        Export the constructor parameters as dictionary.
        :return: Dictionary containing parameters; it can be passed as keyword arguments to GPSMq()
        """
        out_config = dict(
            binaural=self._binaural,
            aud_filt_range=self._aud_filt_range,
            mod_filt_range=self._mod_filt_range,
            corr_thres=self._corr_thres,
            decimation_factor=self._decimation_factor,
            limits=self._limits,
            threshold_scaling=self._threshold_scaling
        )
        return out_config
@@ -0,0 +1,15 @@
1
+ # This file is part of auditory_models
2
+ # Copyright (C) 2025 Max Zimmermann
3
+ #
4
+ # auditory_models is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # auditory_models is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
@@ -0,0 +1,238 @@
1
+ # This file is part of auditory_models
2
+ # Copyright (C) 2025 Max Zimmermann
3
+ #
4
+ # auditory_models is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # auditory_models is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
16
+
17
+
18
+ import numpy as np
19
+ from numpy import pi, cos, exp, sqrt
20
+ from numpy.typing import ArrayLike, NDArray
21
+ from scipy.signal import butter, sosfilt
22
+ from scipy.special import factorial
23
+
24
+ from auditory_models.helpers.utils import generate_delta_impulse, hertz_to_erbscale, erb_aud, freq_erb_spaced
25
+
26
+
27
+ class FilterbankBase:
28
+ """
29
+ This class implements a base class for filterbank implementations.
30
+ """
31
+ def __init__(self, sample_rate: float, cf: ArrayLike | None):
32
+ if np.ndim(cf) > 1:
33
+ raise ValueError(f"cf must be one-dimensional, currently: {np.ndim(cf)}")
34
+ self._sample_rate = sample_rate
35
+ if cf is not None:
36
+ self._cf = np.array(cf)
37
+ else:
38
+ self._cf = None
39
+ self._coeffs = None
40
+ self._states = None
41
+ self.n_filters = None
42
+
43
+ @property
44
+ def cf(self) -> NDArray:
45
+ """ Indicate that `cf` is a read-only variable. """
46
+ return self._cf
47
+
48
+ def _design_filterbank(self) -> None:
49
+ """ Compute the filter coefficients. """
50
+ raise NotImplementedError
51
+
52
+ def reset_states(self) -> None:
53
+ """ Reset all filter states to zero. """
54
+ self._states = np.zeros(self._states.shape, dtype=self._states.dtype)
55
+
56
+ def process(self, signal: NDArray, save_state: bool = True) -> NDArray:
57
+ """
58
+ Compute the filterbank output of a given signal.
59
+ :param signal: Input signal, shape may be either 1-dim or 2-dim (n_filters, signal_length).
60
+ :param save_state: Indicating if state should be memorized.
61
+ :return: Output signal with shape (n_filters, signal_length)
62
+ """
63
+ if signal.ndim == 1:
64
+ signal = np.repeat(signal.reshape(1, signal.size), self.n_filters, axis=0)
65
+ elif signal.ndim == 2:
66
+ if signal.shape[0] != self.n_filters:
67
+ raise ValueError(f"Input signal must be of shape (number_of_filters, signal_length). The first "
68
+ f"dimension does not match. Is currently {signal.shape[0]}, must be "
69
+ f"{self.n_filters}")
70
+ else:
71
+ raise ValueError(f"Signal must have either one or two dimensions. In case of two dimensions, the first "
72
+ f"must have a length equal to the number of filters ({self.n_filters}).")
73
+ if save_state:
74
+ for i, sos_coeff in enumerate(self._coeffs):
75
+ signal[i, :], self._states[i, :, :] = sosfilt(sos_coeff, signal[i, :], zi=self._states[i, :, :])
76
+ else:
77
+ for i, sos_coeff in enumerate(self._coeffs):
78
+ signal[i, :], _ = sosfilt(sos_coeff, signal[i, :], zi=self._states[i, :, :])
79
+ return signal
80
+
81
+ def synthesize(self, bands: NDArray) -> NDArray:
82
+ raise NotImplementedError
83
+
84
+
85
class GammatoneFilterbank(FilterbankBase):
    """
    This class implements gammatone filters and a filtering routine.

    Each filter is realized as a cascade of `order` identical complex first-order sections, so the output of
    `process` is complex-valued for complex input.

    Reference:
    [Hohmann2002]
        Hohmann, V., Frequency analysis and synthesis using a Gammatone filterbank,
        Acta Acustica, Vol 88 (2002), 433--442
    """
    def __init__(self, sample_rate: float = 44100,
                 order: int = 4,
                 normfreq: float = 1000.0,
                 freq_range: tuple[float, float] | None = None,
                 band_range: tuple[int, int] = (-12, 12),
                 density: float = 1.0,
                 cf: ArrayLike | None = None,
                 bandwidths: ArrayLike | None = None,
                 bandwidth_factor: float = 1.0, attenuation_half_bandwidth_db: float = -3,
                 desired_delay_sec: float = 0.02):
        """
        Init method
        :param sample_rate: sample rate of filterbank in Hz
        :param order: order of filters
        :param normfreq: The reference frequency for `startband` and `endband`
        :param freq_range: two values of the lowest and highest possible center-frequency in Hz, overrides
            `band_range`, first value must be lower than `norm_freq`, second value must be higher than `norm_freq`
        :param band_range: two values defining the number of filters above and below the `normfreq`, if freq_range is
            given, this value will be overridden!
        :param density: ERB density of 1 would be `erb_aud`
        :param cf: Sequence of center-frequencies in Hz (a scalar gives a single filter), overrides automatic
            computation via `normfreq`, `freq_range`, `band_range`, and `density`
        :param bandwidths: array-like of bandwidths in Hz, size must be equal to number of filters (a scalar is used
            for all filters); if None given it will default to the ERB of the respective center-frequencies
        :param bandwidth_factor: if bandwidths is not specified, they will be computed via the erb_aud() of each
            center-frequency multiplied by this parameter
        :param attenuation_half_bandwidth_db: attenuation of the filters at half bandwidth in dB (only used when
            `bandwidths` is given)
        :param desired_delay_sec: desired delay in seconds; it sets the length of the impulse used to estimate the
            per-band peak indices and the per-band delays used by `delay`
        :raises ValueError: on malformed `band_range`/`freq_range`, or if `bandwidths` cannot be matched to the
            number of filters
        """
        super().__init__(sample_rate, cf)
        self._order = order
        if freq_range is None:
            if len(band_range) != 2:
                raise ValueError(f"`band_range` must have a size of 2, currently: {len(band_range)}")
            startband, endband = band_range[0], band_range[1]
        else:
            if len(freq_range) != 2:
                raise ValueError(f"`freq_range` must have a size of 2, currently: {len(freq_range)}")
            if freq_range[0] >= normfreq or freq_range[1] <= normfreq:
                raise ValueError("`freq_range` values must be lower/higher than `norm_freq`! Currently `freq_range` = "
                                 f"{freq_range}, `norm_freq` = {normfreq}")
            startband = round(np.fix(hertz_to_erbscale(freq_range[0]) - hertz_to_erbscale(normfreq)))
            endband = round(np.ceil(hertz_to_erbscale(freq_range[1]) - hertz_to_erbscale(normfreq)))

        if self._cf is None:
            self._cf = freq_erb_spaced(startband, endband, normfreq, density)
        else:
            # atleast_1d so that a scalar center-frequency yields a single filter instead of a 0-d array
            self._cf = np.atleast_1d(np.array(cf, dtype=np.float64))
        self.n_filters = self._cf.size
        self._use_erb = bandwidths is None
        if self._use_erb:
            self._bandwidths = bandwidth_factor * erb_aud(self._cf)
        else:
            # accept any array-like (e.g. lists, which cannot be multiplied by a float) or a scalar for all filters;
            # broadcast_to raises ValueError when the size does not match the number of filters
            self._bandwidths = np.broadcast_to(np.asarray(bandwidths, dtype=np.float64), self._cf.shape).copy()
        self._attenuation_half_bandwidth_db = attenuation_half_bandwidth_db
        self._states = np.zeros((self.n_filters, self._order, 2), dtype=np.complex128)
        self._design_filterbank()

        self._desired_delay_samples = int(self._sample_rate * desired_delay_sec)
        self._max_indices, self._slopes = self.estimate_max_indices_and_slopes()
        self._delay_samples = self._desired_delay_samples - self._max_indices
        self._delay_memory = np.zeros((len(self._cf), np.max(self._delay_samples)))
        self._phase_factors = np.abs(self._slopes) * 1j / self._slopes
        self._gains = np.ones(len(self._cf))

    def _design_filterbank(self) -> None:
        """ Compute the filter coefficients of the gammatone filters [Hohmann2002] as complex SOS cascades. """
        if self._use_erb:
            # [Hohmann2002] eq. (14)
            a_gamma = (pi * factorial(2 * self._order - 2) *
                       2 ** -(2 * self._order - 2) /
                       factorial(self._order - 1) ** 2)
            b = self._bandwidths / a_gamma
            lambda_ = np.exp(-2 * pi * b / self._sample_rate)
        else:
            # [Hohmann2002] eq. (12)
            phi = pi * self._bandwidths / self._sample_rate
            alpha = 10 ** (0.1 * self._attenuation_half_bandwidth_db / self._order)
            p = (-2 + 2 * alpha * cos(phi)) / (1 - alpha)
            lambda_ = -p / 2 - sqrt(p * p / 4 - 1)
        beta = 2 * pi * self._cf / self._sample_rate
        coef = lambda_ * exp(1j * beta)
        factor = 2 * (1 - np.abs(coef)) ** self._order
        # every section is the first-order complex filter b = [1, 0, 0], a = [1, -coef, 0];
        # the normalization gain is applied once, in the first section
        self._coeffs = np.zeros((self.n_filters, self._order, 6), dtype=np.complex128)
        self._coeffs[:, :, 0] = 1.
        self._coeffs[:, :, 3] = 1.
        self._coeffs[:, :, 4] = -coef[:, np.newaxis]
        self._coeffs[:, 0, 0] = factor

    # def synthesize(self, bands: NDArray) -> NDArray:
    #     return np.array(list(self.delay([b*g for b, g in zip(bands, self._gains)]))).sum(axis=0)

    def delay(self, bands):
        """
        Generator that delays each band by its per-band delay (`_delay_samples`), carrying the tail of each band
        over to the next call via `_delay_memory`.
        NOTE(review): bands with zero delay yield `real(band) * phase_factor` (complex), whereas delayed bands
        yield real samples without the phase factor — confirm this asymmetry against [Hohmann2002].
        :param bands: iterable of band signals, one per filter
        :return: generator yielding one delayed band per input band
        """
        for i, band in enumerate(bands):
            if self._delay_samples[i] == 0:
                yield np.real(band) * self._phase_factors[i]
            else:
                yield np.concatenate((self._delay_memory[i, :self._delay_samples[i]],
                                      np.real(band[:-self._delay_samples[i]])), axis=0)
                self._delay_memory[i, :self._delay_samples[i]] = np.real(band[-self._delay_samples[i]:])

    def estimate_max_indices_and_slopes(self):
        """
        Filter an impulse signal (from `generate_delta_impulse`) without saving the filter state and locate, per
        band, the sample index of the maximum magnitude and the central-difference slope at that index.
        :return: tuple (max_indices, slopes), each a 1-dim array with one entry per filter
        """
        sig = generate_delta_impulse(self._desired_delay_samples, dtype=np.complex128)
        bands = self.process(sig, save_state=False)
        # search along the sample axis of every band; `bands[:N]` would slice the filter axis instead
        ibandmax = np.argmax(np.abs(bands[:, :self._desired_delay_samples]), axis=-1)
        slopes = [band[idx + 1] - band[idx - 1] for band, idx in zip(bands, ibandmax)]
        return np.array(ibandmax), np.array(slopes)
201
+
202
+
203
class BandpassFilterbank(FilterbankBase):
    """
    Implements a Butterworth bandpass filterbank from given center-frequencies. The band edges of each filter have
    their geometric mean at the center-frequency and a bandwidth equal to the center-frequency (Q = 1). There is
    also the option to add a lowpass filter at 1Hz as first band to include the DC-component; in that case `cf`
    starts with 1.
    """
    def __init__(self, cf: ArrayLike, sample_rate: float, order: int = 2, dc_lowpass: bool = True):
        """
        Init method
        :param cf: Array of center-frequencies for filters in Hz (a scalar gives a single bandpass)
        :param sample_rate: Sampling rate of filterbank in Hz
        :param order: Order of bandpass filters, must be even
        :param dc_lowpass: Indicates if an additional band for the DC via a lowpass at 1 Hz should be computed
        :raises ValueError: if `order` is odd or `cf` has more than one dimension
        """
        super().__init__(sample_rate, cf)
        if order % 2 != 0:
            raise ValueError(f"order must be an even integer, currently: {order}")
        self._order = order
        self._dc_lp = dc_lowpass
        # bandpass center-frequencies are kept separately so that _design_filterbank can be re-run safely
        self._bp_cf = np.atleast_1d(np.array(cf))
        self._cf = self._bp_cf
        self.n_filters = self._bp_cf.size + self._dc_lp
        n_sections = self._order // 2
        self._states = np.zeros((self.n_filters, n_sections, 2))
        self._coeffs = np.zeros((self.n_filters, n_sections, 6))
        self._design_filterbank()

    def _design_filterbank(self) -> None:
        """ Compute the filter coefficients and update `cf` (with a leading 1 Hz entry if the DC lowpass is used). """
        if self._dc_lp:
            self._coeffs[0, :, :] = butter(self._order, 1, output="sos", fs=self._sample_rate)
        bp_cf = self._bp_cf
        # edges solve fhigh * flow == cf**2 and fhigh - flow == cf
        fhigh = bp_cf / 2 + np.sqrt(np.square(bp_cf / 2) + np.square(bp_cf))
        flow = np.square(bp_cf) / fhigh
        for idx, (low, high) in enumerate(zip(flow, fhigh)):
            # a bandpass of order N/2 has N poles, i.e. N/2 second-order sections
            self._coeffs[idx + self._dc_lp, :, :] = butter(self._order // 2, [low, high], btype="bandpass",
                                                           output="sos", fs=self._sample_rate)
        # rebuilt from _bp_cf on every call, so repeated calls do not prepend 1 Hz again
        self._cf = np.concatenate(([1], bp_cf)) if self._dc_lp else bp_cf