auditory-models 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- auditory_models/__init__.py +19 -0
- auditory_models/gpsmq/__init__.py +18 -0
- auditory_models/gpsmq/gpsmq.py +349 -0
- auditory_models/helpers/__init__.py +15 -0
- auditory_models/helpers/filterbank.py +238 -0
- auditory_models/helpers/utils.py +360 -0
- auditory_models/stoi/__init__.py +18 -0
- auditory_models/stoi/stoi.py +176 -0
- auditory_models-0.1.1.dist-info/METADATA +68 -0
- auditory_models-0.1.1.dist-info/RECORD +13 -0
- auditory_models-0.1.1.dist-info/WHEEL +5 -0
- auditory_models-0.1.1.dist-info/licenses/COPYING +173 -0
- auditory_models-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# This file is part of auditory_models
|
|
2
|
+
# Copyright (C) 2025 Max Zimmermann
|
|
3
|
+
#
|
|
4
|
+
# auditory_models is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# auditory_models is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .stoi import STOI
|
|
19
|
+
from .gpsmq import GPSMq
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# This file is part of auditory_models
|
|
2
|
+
# Copyright (C) 2025 Max Zimmermann
|
|
3
|
+
#
|
|
4
|
+
# auditory_models is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# auditory_models is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .gpsmq import GPSMq
|
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
# This file is part of auditory_models
|
|
2
|
+
# Copyright (C) 2025 Max Zimmermann
|
|
3
|
+
#
|
|
4
|
+
# auditory_models is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# auditory_models is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
from numpy.typing import NDArray
|
|
20
|
+
from scipy.signal import butter, fftconvolve, firwin, sosfilt
|
|
21
|
+
from warnings import warn
|
|
22
|
+
|
|
23
|
+
from auditory_models.helpers.utils import thirdoct, iso389_7_thresholds
|
|
24
|
+
from auditory_models.helpers.filterbank import BandpassFilterbank, GammatoneFilterbank
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GPSMq:
    """
    GPSMq audio-quality model.

    A degraded signal is compared with a reference signal. Both are passed through a gammatone auditory filterbank;
    their Hilbert envelopes are low-pass filtered, decimated and split by a modulation filterbank. SNRs derived
    from a multi-resolution envelope-power / power analysis are finally mapped onto a perceptual (MUSHRA) scale.
    All filter designs depending on the sample rate are computed on the first call to `process()` and recomputed
    whenever the sample rate changes.
    """

    def __init__(self, binaural: bool = True, aud_filt_range: tuple[int, int] = (7, 24),
                 mod_filt_range: tuple[int, int] = (1, 7), corr_thres: float = 0.8, decimation_factor: int = 8,
                 limits: tuple[float, float] = (-65, -100), threshold_scaling: float = 1e-10):
        """
        Init method
        :param binaural: Determines if binaural processing should be used, if True the input of process() must have
            shapes of (2, signal_length).
        :param aud_filt_range: The range (start inclusive, end exclusive) of Gammatone-Filterbank center-frequencies
            as indices of a vector of frequencies in third-octave distances, starting with 63Hz and ending with
            16000Hz as defined by IEC 61260-1:2014.
        :param mod_filt_range: The range (start inclusive, end exclusive) of Modulation-Filterbank center-frequencies
            as exponents of 2. The default of (1, 7) would result in [2, 4, 8, 16, 32, 64] Hz as center-frequencies.
        :param corr_thres: Centre of the sigmoid function that converts the correlation matrix into weights.
        :param decimation_factor: Factor of decimation of the lowpass filtered Hilbert envelope signals.
        :param limits: Limits in dB for the multi resolution based power processing. limits[0] is the level (relative
            to the unscaled hearing threshold) above which the dc2mod weight is 1; the weight slope in dB is
            1 / (limits[0] - limits[1]).
        :param threshold_scaling: Scaling factor for the squared hearing threshold. Default value is chosen with the
            assumption that 0 dB RMS equals 100 dB SPL.
        """
        self._sample_rate = 0
        self._binaural = binaural
        # number of processed channels: 1 (monaural) or 2 (binaural)
        self._n_chan = self._binaural + 1
        self._aud_filt_range = aud_filt_range
        self._mod_filt_range = mod_filt_range
        self._corr_thres = corr_thres
        self._decimation_factor = decimation_factor
        self._limits = limits
        self._threshold_scaling = threshold_scaling
        # slope of the linear dB-to-weight mapping used for dc2mod in multi_resolution_based_power()
        self._slope = 1 / (self._limits[0] - self._limits[1])
        self._env_lp_sos = None
        self._sig_len = 0
        self._sig_len_dec = 0
        # lower bound of the normalised envelope power (-27 dB)
        self._env_pow_lim = 10 ** -2.7

        self._decimation_order = 20 * self._decimation_factor
        self.decimation_filter = None
        self._sample_rate_dec = None
        self.gtfb = None
        # np.power rather than np.pow: the np.pow alias only exists in NumPy >= 2.0
        self.mf_cf = np.power(2, np.arange(*self._mod_filt_range))
        self.mfb = None
        self._iso_thres = None
        self._upper_lim = None

    def _recompute_properties(self) -> None:
        """
        When process receives a sample_rate value that is different from the previous one, recompute all dependent
        properties (envelope low-pass, decimation filter, both filterbanks and the threshold vectors).
        :return: None
        """
        self._env_lp_sos = butter(1, 150, fs=self._sample_rate, output="sos")
        # cutoff frequency for decimation filter is computed with the decimation factor and an extra factor for the
        # filter to be effective at the new nyquist frequency
        cutoff = self._sample_rate / (2 * self._decimation_factor) * (5 / 6)
        self.decimation_filter = firwin(self._decimation_order + 1, cutoff, fs=self._sample_rate, window=("kaiser", 5.),
                                        pass_zero=True)
        self._sample_rate_dec = self._sample_rate / self._decimation_factor
        gt_cf = thirdoct(self._sample_rate, nfft=1024, min_freq=63, max_freq=16000)[1]
        self.gtfb = GammatoneFilterbank(self._sample_rate, cf=gt_cf[self._aud_filt_range[0]:self._aud_filt_range[1]])
        self.mfb = BandpassFilterbank(self.mf_cf, self._sample_rate_dec)
        self._iso_thres = 10 ** (iso389_7_thresholds(self.gtfb.cf) / 10)
        # the upper limit is derived from the unscaled threshold, the threshold itself is scaled afterwards
        self._upper_lim = self._iso_thres * 10 ** (self._limits[0] / 10)
        self._iso_thres *= self._threshold_scaling

    def _prepare_inputs(self, reference: NDArray, degraded: NDArray) -> tuple[NDArray, NDArray]:
        """
        Validate the input signals and return both with shape (channels, signal_length).
        :param reference: Reference signal
        :param degraded: Degraded signal
        :return: reference and degraded signal, each two-dimensional
        :raises ValueError: if the shapes do not fit the binaural/monaural configuration or the lengths differ
        """
        reference = np.squeeze(reference)
        degraded = np.squeeze(degraded)
        if self._binaural:
            if reference.ndim != 2:
                raise ValueError(f"reference must have two dimensions, currently: {reference.ndim}")
            if reference.shape[0] != 2:
                raise ValueError(f"reference must have a size of 2 in first dimension, currently shape is: "
                                 f"{reference.shape}")
            if degraded.ndim != 2:
                raise ValueError(f"degraded must have two dimensions, currently: {degraded.ndim}")
            if degraded.shape[0] != 2:
                raise ValueError(f"degraded must have a size of 2 in first dimension, currently shape is: "
                                 f"{degraded.shape}")
            if reference.shape[1] != degraded.shape[1]:
                raise ValueError(f"reference and degraded must have same signal lengths, currently "
                                 f"{reference.shape[1]} and {degraded.shape[1]}.")
            return reference, degraded

        if reference.ndim != 1:
            raise ValueError(f"reference must have one dimension, currently: {reference.ndim}")
        if degraded.ndim != 1:
            raise ValueError(f"degraded must have one dimension, currently: {degraded.ndim}")
        if reference.size != degraded.size:
            raise ValueError(f"reference and degraded must have same signal lengths, currently "
                             f"{reference.size} and {degraded.size}.")
        return reference[np.newaxis, :], degraded[np.newaxis, :]

    @staticmethod
    def _cancel_ild(ref_hilbert: NDArray, dgr_hilbert: NDArray) -> None:
        """
        Compensate interaural level differences between reference and degraded envelopes, in place.

        Both arrays have shape (2, auditory bands, signal_length). The per-band RMS levels are computed once, before
        any rescaling. If the reference/degraded level differences of the two ears have opposite signs, the weaker
        signal of each ear is scaled to the level of the stronger one. Otherwise, in the ear with the larger
        reference/degraded difference, the signal with the larger left/right difference gets that ear's channel
        scaled to the level of its other ear.
        :param ref_hilbert: reference envelopes, modified in place
        :param dgr_hilbert: degraded envelopes, modified in place
        :return: None
        """
        ref_rms = np.sqrt(np.mean(np.square(ref_hilbert), axis=2))
        dgr_rms = np.sqrt(np.mean(np.square(dgr_hilbert), axis=2))
        # explicit indexing (instead of squeezing np.diff) keeps a 1-D array even for a single auditory band
        ref_lr_diff = np.abs(ref_rms[1] - ref_rms[0])
        dgr_lr_diff = np.abs(dgr_rms[1] - dgr_rms[0])
        ref_dgr_l_diff = ref_rms[0] - dgr_rms[0]
        ref_dgr_r_diff = ref_rms[1] - dgr_rms[1]
        for band in range(ref_rms.shape[1]):
            l_diff = ref_dgr_l_diff[band]
            r_diff = ref_dgr_r_diff[band]
            if np.sign(l_diff) * np.sign(r_diff) == -1:
                for ear, diff in ((0, l_diff), (1, r_diff)):
                    if diff > 0:
                        dgr_hilbert[ear, band, :] *= ref_rms[ear, band] / dgr_rms[ear, band]
                    elif diff < 0:
                        ref_hilbert[ear, band, :] *= dgr_rms[ear, band] / ref_rms[ear, band]
                continue

            if np.abs(l_diff) > np.abs(r_diff):
                ear, other = 0, 1
            elif np.abs(l_diff) < np.abs(r_diff):
                ear, other = 1, 0
            else:
                continue
            if ref_lr_diff[band] > dgr_lr_diff[band]:
                ref_hilbert[ear, band, :] *= ref_rms[other, band] / ref_rms[ear, band]
            elif ref_lr_diff[band] < dgr_lr_diff[band]:
                dgr_hilbert[ear, band, :] *= dgr_rms[other, band] / dgr_rms[ear, band]

    def process(self, reference: np.ndarray, degraded: np.ndarray, sample_rate: float) -> dict:
        """
        Process method to calculate the perceptual proximity of a degraded signal to a reference signal.
        :param reference: Reference signal, if binaural it must have shape (2, signal_length)
        :param degraded: Degraded signal, if binaural it must have shape (2, signal_length)
        :param sample_rate: sample rate of both input signals in Hz
        :return: dictionary containing measurement data
            "snr_dc": SNR for the DC part
            "snr_ac": SNR for the modulation part
            "snr_ac_fix": SNR for the modulation part including weighting to reduce effects of IPDs/ITDs
            "opm": combined snr_dc and snr_ac into perceptual measure (MUSHRA scale)
            "opm_fix": combined snr_dc and snr_ac_fix into perceptual measure (MUSHRA scale)
        :raises ValueError: if the input shapes do not fit the configuration or the signal lengths differ
        """
        reference, degraded = self._prepare_inputs(reference, degraded)

        if self._sig_len != reference.shape[-1]:
            self._sig_len = reference.shape[-1]
            self._sig_len_dec = round(np.ceil(self._sig_len / self._decimation_factor))

        if sample_rate != self._sample_rate:
            self._sample_rate = sample_rate
            self._recompute_properties()

        if self._sample_rate / 2 < self.gtfb.cf[-1]:
            warn(f"Given sample rate ({self._sample_rate}) does not fit Nyquist Theorem with the center-frequency of "
                 f"the highest band ({self.gtfb.cf[-1]}).")

        n_aud = self.gtfb.cf.size
        n_mod = self.mfb.n_filters

        # Auditory filtering via Gammatone Filterbank and Hilbert Envelope
        ref_hilbert = np.zeros((self._n_chan, n_aud, self._sig_len))
        dgr_hilbert = np.zeros((self._n_chan, n_aud, self._sig_len))
        for ch in range(self._n_chan):
            ref_hilbert[ch, :, :] = np.abs(
                self.gtfb.process(reference[ch, :].astype(np.complex128), save_state=False)) / np.sqrt(2)
            dgr_hilbert[ch, :, :] = np.abs(
                self.gtfb.process(degraded[ch, :].astype(np.complex128), save_state=False)) / np.sqrt(2)

        # ILD cancellation (modifies both envelope arrays in place)
        if self._binaural:
            self._cancel_ild(ref_hilbert, dgr_hilbert)

        # apply Lowpass-Filter at 150Hz
        ref_lp = sosfilt(self._env_lp_sos, ref_hilbert, axis=2)
        dgr_lp = sosfilt(self._env_lp_sos, dgr_hilbert, axis=2)

        # decimate: FIR anti-aliasing filter, then keep every decimation_factor-th sample
        dec_kernel = self.decimation_filter[np.newaxis, np.newaxis, :]
        ref_lp_dec = fftconvolve(ref_lp, dec_kernel, mode="same", axes=-1)[:, :, ::self._decimation_factor]
        dgr_lp_dec = fftconvolve(dgr_lp, dec_kernel, mode="same", axes=-1)[:, :, ::self._decimation_factor]

        # apply modulation filterbank
        ref_mod = np.zeros((self._n_chan, n_aud, n_mod, self._sig_len_dec))
        dgr_mod = np.zeros((self._n_chan, n_aud, n_mod, self._sig_len_dec))
        for ch in range(self._n_chan):
            for gt in range(n_aud):
                ref_mod[ch, gt, :, :] = self.mfb.process(ref_lp_dec[ch, gt, :], save_state=False)
                dgr_mod[ch, gt, :, :] = self.mfb.process(dgr_lp_dec[ch, gt, :], save_state=False)

        # calculate multi-resolution-based envelope power and short-time power
        ref_epsm, ref_psm, ref_dc2mod = self.multi_resolution_based_power(ref_lp_dec, ref_mod)
        dgr_epsm, dgr_psm, dgr_dc2mod = self.multi_resolution_based_power(dgr_lp_dec, dgr_mod)

        # correlation matrix of the modulation signals (first modulation filter excluded), skipping roughly the
        # first 0.25 s of the decimated signals
        start = round(self._sample_rate_dec / 4) - 1
        corr_mat = np.ones((self._n_chan, n_aud, n_mod - 1))
        for ch in range(self._n_chan):
            for aud_idx in range(n_aud):
                for mf_idx in range(1, n_mod):
                    corr_mat[ch, aud_idx, mf_idx - 1] = np.corrcoef(
                        ref_mod[ch, aud_idx, mf_idx, start::],
                        dgr_mod[ch, aud_idx, mf_idx, start::])[0, 1]

        # segments where both short-time powers exceed the hearing threshold
        iso_thres = self._iso_thres[np.newaxis, :, np.newaxis, np.newaxis]
        clipper = np.logical_and(ref_psm > iso_thres, dgr_psm > iso_thres)

        # calculate SNR increment and decrement, each limited to [0, 20]
        snr_inc_epsm = np.fmax(np.fmin(dgr_epsm / (ref_epsm + 1e-30) - 1, 20), 0.) * dgr_dc2mod
        snr_inc_epsm[np.logical_not(clipper)] = 0.
        snr_inc_psm = np.fmax(np.fmin(dgr_psm / (ref_psm + 1e-30) - 1, 20), 0.)

        snr_dec_epsm = np.fmax(np.fmin(ref_epsm / (dgr_epsm + 1e-30) - 1, 20), 0.) * ref_dc2mod
        snr_dec_epsm[np.logical_not(clipper)] = 0.
        snr_dec_psm = np.fmax(np.fmin(ref_psm / (dgr_psm + 1e-30) - 1, 20), 0.)

        # average over the analysis windows that are valid for each modulation filter
        snr_inc_mod = np.zeros((self._n_chan, n_aud, n_mod))
        snr_inc_stint = np.zeros(snr_inc_mod.shape)
        snr_dec_mod = np.zeros(snr_inc_mod.shape)
        snr_dec_stint = np.zeros(snr_inc_mod.shape)
        for mf_idx in range(n_mod):
            n_windows = round(np.ceil(self._sig_len_dec / np.floor(self._sample_rate_dec / self.mfb.cf[mf_idx])))
            snr_inc_mod[:, :, mf_idx] = np.mean(snr_inc_epsm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_inc_stint[:, :, mf_idx] = np.mean(snr_inc_psm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_dec_mod[:, :, mf_idx] = np.mean(snr_dec_epsm[:, :, mf_idx, 0:n_windows], axis=-1)
            snr_dec_stint[:, :, mf_idx] = np.mean(snr_dec_psm[:, :, mf_idx, 0:n_windows], axis=-1)
        # fmax also maps NaN (e.g. from empty means) to 0
        snr_inc_mod = np.fmax(snr_inc_mod, 0.)
        snr_inc_stint = np.fmax(snr_inc_stint, 0.)
        snr_dec_mod = np.fmax(snr_dec_mod, 0.)
        snr_dec_stint = np.fmax(snr_dec_stint, 0.)
        # only modulation frequencies below a quarter of the auditory center-frequency contribute
        valid_mod_freq = self.gtfb.cf[:, np.newaxis] > 4 * self.mfb.cf[np.newaxis, :]
        snr_inc_mod *= valid_mod_freq
        snr_dec_mod *= valid_mod_freq

        # combine SNR increment and decrement; the DC part uses modulation-filter index 2 (8 Hz by default),
        # decrements of the AC part are weighted by 10 ** -0.7 (-7 dB)
        snr_dc = np.mean(np.sqrt(np.sum(np.square((snr_inc_stint[:, :, 2] + snr_dec_stint[:, :, 2]) / 2), axis=-1)))
        inc_dec_mean = (snr_inc_mod[:, :, 1::] + 10 ** (-0.7) * snr_dec_mod[:, :, 1::]) / 2
        snr_ac = np.mean(np.sqrt(np.sum(np.sum(np.square(inc_dec_mean), axis=-1), axis=-1)))

        # include weighting via correlation matrix to make model less sensitive to IPDs/ITDs
        sigmoid_corr = 1 / (1 + np.exp(-50 * (corr_mat - self._corr_thres)))
        snr_ac_fix = np.mean(np.sqrt(np.sum(np.sum(np.square(sigmoid_corr * inc_dec_mean), axis=-1), axis=-1)))

        # transform to perceptual measure (MUSHRA scale)
        opm = -4.08695250298565 * 10 * np.log10(snr_dc + snr_ac + 1e-30) + 75.6467755339438
        opm_fix = -4.15742483856597 * 10 * np.log10(snr_dc + snr_ac_fix + 1e-30) + 74.7791007678067

        return {"snr_dc": snr_dc, "snr_ac": snr_ac, "snr_ac_fix": snr_ac_fix, "opm": opm, "opm_fix": opm_fix}

    def lowpass_filterbank(self, signal: NDArray) -> NDArray:
        """
        Compute output of a lowpass filterbank using moving averages with variable window sizes depending on the
        modulation frequency. Modulation frequencies up to 8 Hz use a window of 1.5 periods; all higher modulation
        frequencies share one output computed with a window of 1/8 s.
        :param signal: Input signal, must be one-dimensional with length equal to the decimated signal length
        :return: two-dimensional matrix with filterbank output with shape (number_of_mod_filters, signal_length)
        """
        out = np.zeros((self.mfb.n_filters, self._sig_len_dec))
        crit_mod_freq = 8.
        shared_idx = None
        for idx, mod_freq in enumerate(self.mfb.cf):
            if mod_freq <= crit_mod_freq:
                win_width = round(1.5 * self._sample_rate_dec / mod_freq)
            elif shared_idx is None:
                win_width = round(self._sample_rate_dec / 8)
                shared_idx = idx
            else:
                out[idx, :] = out[shared_idx, :]
                continue
            # a window of exactly win_width ones keeps the moving average normalised (steady state of 1 for a
            # constant input) for any set of modulation center-frequencies
            out[idx, :] = np.convolve(np.ones(win_width) / win_width, signal, mode="full")[0:signal.size]
        return out

    def multi_resolution_based_power(self, envelope: NDArray, modulation: NDArray) -> tuple[NDArray, NDArray, NDArray]:
        """
        Calculate the multi resolution based power from an envelope signal and a modulation signal.
        The analysis window of each modulation filter is one period of its center-frequency (in samples of the
        decimated signal). The input arrays are not modified.
        :param envelope: Lowpass filtered envelope signal for different frequency bands, must have shape
            (channels, auditory bands, signal_length)
        :param modulation: Modulation filtered signal for different frequency bands, must have shape
            (channels, auditory bands, modulation bands, signal_length)
        :return: Three arrays containing the envelope power spectrum model (EPSM), power spectrum model (PSM), and a
            correction matrix (dc2mod: 1 above the upper limit, otherwise a weight that grows linearly with the level
            above threshold in dB, floored at 0). Each of these arrays has shape
            (channels, auditory bands, modulation bands, new_length)
        """
        n_windows_max = round(np.ceil(self._sig_len_dec / np.floor(self._sample_rate_dec / self.mfb.cf[-1])))
        epsm = np.zeros((self._n_chan, self.gtfb.cf.size, self.mfb.n_filters, n_windows_max))
        psm = np.zeros(epsm.shape)
        dc2mod = np.ones(epsm.shape)
        # scale a copy: an in-place multiplication would overwrite the caller's array
        modulation = modulation * np.sqrt(2)
        for ch in range(self._n_chan):
            for aud_idx in range(self.gtfb.cf.size):
                tmp_pow = np.max((np.square(np.mean(envelope[ch, aud_idx, :])), self._iso_thres[aud_idx]))
                lpfb = self.lowpass_filterbank(envelope[ch, aud_idx, :])
                for mf_idx in range(self.mfb.n_filters):
                    samples_per_mf = round(np.floor(self._sample_rate_dec / self.mfb.cf[mf_idx]))
                    n_windows = round(np.ceil(self._sig_len_dec / samples_per_mf))
                    if tmp_pow <= self._iso_thres[aud_idx]:
                        # whole band below threshold: use the lower limits for every window
                        epsm[ch, aud_idx, mf_idx, 0:n_windows] = self._env_pow_lim
                        psm[ch, aud_idx, mf_idx, 0:n_windows] = self._iso_thres[aud_idx]
                        continue
                    # DC power at the end of each window; a trailing partial window uses the last sample
                    pow_dc_seg = lpfb[mf_idx, samples_per_mf - 1::samples_per_mf] ** 2
                    if n_windows * samples_per_mf != self._sig_len_dec:
                        pow_dc_seg = np.concatenate((pow_dc_seg, [lpfb[mf_idx, -1] ** 2]))
                    for win_idx in range(n_windows):
                        ac = modulation[ch, aud_idx, mf_idx, win_idx * samples_per_mf:(win_idx + 1) * samples_per_mf]
                        if mf_idx == 0:
                            pow_ac = np.var(ac)
                        else:
                            pow_ac = np.mean(ac ** 2)
                        if pow_dc_seg[win_idx] <= self._iso_thres[aud_idx]:
                            epsm[ch, aud_idx, mf_idx, win_idx] = self._env_pow_lim
                        else:
                            epsm[ch, aud_idx, mf_idx, win_idx] = np.max((pow_ac / pow_dc_seg[win_idx],
                                                                         self._env_pow_lim))
                        psm[ch, aud_idx, mf_idx, win_idx] = np.max((pow_dc_seg[win_idx], self._iso_thres[aud_idx]))
                        if pow_dc_seg[win_idx] > self._upper_lim[aud_idx]:
                            dc2mod[ch, aud_idx, mf_idx, win_idx] = 1.
                        else:
                            dc2mod[ch, aud_idx, mf_idx, win_idx] = np.max((
                                self._slope * (10 * np.log10(pow_dc_seg[win_idx] / self._iso_thres[aud_idx] + 1e-30)),
                                0.))
        return epsm, psm, dc2mod

    def export_config(self) -> dict:
        """
        Export the parameters as dictionary.
        :return: Dictionary containing parameters; its keys match the keyword arguments of __init__
        """
        out_config = dict(
            binaural=self._binaural,
            aud_filt_range=self._aud_filt_range,
            mod_filt_range=self._mod_filt_range,
            corr_thres=self._corr_thres,
            decimation_factor=self._decimation_factor,
            limits=self._limits,
            threshold_scaling=self._threshold_scaling
        )
        return out_config
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# This file is part of auditory_models
|
|
2
|
+
# Copyright (C) 2025 Max Zimmermann
|
|
3
|
+
#
|
|
4
|
+
# auditory_models is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# auditory_models is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# This file is part of auditory_models
|
|
2
|
+
# Copyright (C) 2025 Max Zimmermann
|
|
3
|
+
#
|
|
4
|
+
# auditory_models is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
#
|
|
9
|
+
# auditory_models is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
#
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with auditory_models. If not, see <https://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
from numpy import pi, cos, exp, sqrt
|
|
20
|
+
from numpy.typing import ArrayLike, NDArray
|
|
21
|
+
from scipy.signal import butter, sosfilt
|
|
22
|
+
from scipy.special import factorial
|
|
23
|
+
|
|
24
|
+
from auditory_models.helpers.utils import generate_delta_impulse, hertz_to_erbscale, erb_aud, freq_erb_spaced
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FilterbankBase:
|
|
28
|
+
"""
|
|
29
|
+
This class implements a base class for filterbank implementations.
|
|
30
|
+
"""
|
|
31
|
+
def __init__(self, sample_rate: float, cf: ArrayLike | None):
|
|
32
|
+
if np.ndim(cf) > 1:
|
|
33
|
+
raise ValueError(f"cf must be one-dimensional, currently: {np.ndim(cf)}")
|
|
34
|
+
self._sample_rate = sample_rate
|
|
35
|
+
if cf is not None:
|
|
36
|
+
self._cf = np.array(cf)
|
|
37
|
+
else:
|
|
38
|
+
self._cf = None
|
|
39
|
+
self._coeffs = None
|
|
40
|
+
self._states = None
|
|
41
|
+
self.n_filters = None
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def cf(self) -> NDArray:
|
|
45
|
+
""" Indicate that `cf` is a read-only variable. """
|
|
46
|
+
return self._cf
|
|
47
|
+
|
|
48
|
+
def _design_filterbank(self) -> None:
|
|
49
|
+
""" Compute the filter coefficients. """
|
|
50
|
+
raise NotImplementedError
|
|
51
|
+
|
|
52
|
+
def reset_states(self) -> None:
|
|
53
|
+
""" Reset all filter states to zero. """
|
|
54
|
+
self._states = np.zeros(self._states.shape, dtype=self._states.dtype)
|
|
55
|
+
|
|
56
|
+
def process(self, signal: NDArray, save_state: bool = True) -> NDArray:
|
|
57
|
+
"""
|
|
58
|
+
Compute the filterbank output of a given signal.
|
|
59
|
+
:param signal: Input signal, shape may be either 1-dim or 2-dim (n_filters, signal_length).
|
|
60
|
+
:param save_state: Indicating if state should be memorized.
|
|
61
|
+
:return: Output signal with shape (n_filters, signal_length)
|
|
62
|
+
"""
|
|
63
|
+
if signal.ndim == 1:
|
|
64
|
+
signal = np.repeat(signal.reshape(1, signal.size), self.n_filters, axis=0)
|
|
65
|
+
elif signal.ndim == 2:
|
|
66
|
+
if signal.shape[0] != self.n_filters:
|
|
67
|
+
raise ValueError(f"Input signal must be of shape (number_of_filters, signal_length). The first "
|
|
68
|
+
f"dimension does not match. Is currently {signal.shape[0]}, must be "
|
|
69
|
+
f"{self.n_filters}")
|
|
70
|
+
else:
|
|
71
|
+
raise ValueError(f"Signal must have either one or two dimensions. In case of two dimensions, the first "
|
|
72
|
+
f"must have a length equal to the number of filters ({self.n_filters}).")
|
|
73
|
+
if save_state:
|
|
74
|
+
for i, sos_coeff in enumerate(self._coeffs):
|
|
75
|
+
signal[i, :], self._states[i, :, :] = sosfilt(sos_coeff, signal[i, :], zi=self._states[i, :, :])
|
|
76
|
+
else:
|
|
77
|
+
for i, sos_coeff in enumerate(self._coeffs):
|
|
78
|
+
signal[i, :], _ = sosfilt(sos_coeff, signal[i, :], zi=self._states[i, :, :])
|
|
79
|
+
return signal
|
|
80
|
+
|
|
81
|
+
def synthesize(self, bands: NDArray) -> NDArray:
|
|
82
|
+
raise NotImplementedError
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class GammatoneFilterbank(FilterbankBase):
|
|
86
|
+
"""
|
|
87
|
+
This class implements gammatone filters and a filtering routine.
|
|
88
|
+
|
|
89
|
+
Reference:
|
|
90
|
+
[Hohmann2002]
|
|
91
|
+
Hohmann, V., Frequency analysis and synthesis using a Gammatone filterbank,
|
|
92
|
+
Acta Acustica, Vol 88 (2002), 433--442
|
|
93
|
+
"""
|
|
94
|
+
def __init__(self, sample_rate: float = 44100,
             order: int = 4,
             normfreq: float = 1000.0,
             freq_range: tuple[float, float] | None = None,
             band_range: tuple[int, int] = (-12, 12),
             density: float = 1.0,
             cf: ArrayLike | None = None,
             bandwidths: ArrayLike | None = None,
             bandwidth_factor: float = 1.0, attenuation_half_bandwidth_db: float = -3,
             desired_delay_sec: float = 0.02):
    """
    Init method
    :param sample_rate: sample rate of filterbank in Hz
    :param order: order of filters (number of cascaded first-order sections per filter)
    :param normfreq: reference frequency in Hz; the band indices given by `band_range` (or derived from
        `freq_range`) are counted on the ERB scale relative to this frequency
    :param freq_range: two values of the lowest and highest possible center-frequency in Hz, overrides
        `band_range`, first value must be lower than `normfreq`, second value must be higher than `normfreq`
    :param band_range: two values defining the number of filters below and above `normfreq`, if `freq_range` is
        given, this value will be overridden!
    :param density: ERB density passed to `freq_erb_spaced`; a density of 1 would be `erb_aud`
    :param cf: Sequence of center-frequencies in Hz, overrides automatic computation via `normfreq`, `freq_range`,
        `band_range`, and `density`
    :param bandwidths: bandwidths of the filters in Hz, either one value per filter or a single value used for all
        filters; if None, the ERB (`erb_aud`) of each center-frequency multiplied by `bandwidth_factor` is used
    :param bandwidth_factor: scaling of the ERB-based bandwidths, only used if `bandwidths` is None
    :param attenuation_half_bandwidth_db: attenuation of the filters at half bandwidth in dB, only used if
        `bandwidths` is given
    :param desired_delay_sec: desired overall delay in seconds; converted to samples and used to derive the
        per-band delays applied by `delay()`
    :raises ValueError: if `band_range`/`freq_range` do not have two values, `freq_range` does not enclose
        `normfreq`, `cf` has more than one dimension, or `bandwidths` does not match the number of filters
    """
    super().__init__(sample_rate, cf)
    self._order = order
    if freq_range is None:
        if len(band_range) != 2:
            raise ValueError(f"`band_range` must have a size of 2, currently: {len(band_range)}")
        startband, endband = band_range
    else:
        if len(freq_range) != 2:
            raise ValueError(f"`freq_range` must have a size of 2, currently: {len(freq_range)}")
        if freq_range[0] >= normfreq or freq_range[1] <= normfreq:
            raise ValueError("`freq_range` values must be lower/higher than `normfreq`! Currently `freq_range` = "
                             f"{freq_range}, `normfreq` = {normfreq}")
        # band indices relative to normfreq on the ERB scale: np.fix rounds the (negative) lower index towards
        # zero, np.ceil rounds the upper index up.
        # NOTE(review): the ceil may place the highest center-frequency above freq_range[1] - confirm intended.
        startband = round(np.fix(hertz_to_erbscale(freq_range[0]) - hertz_to_erbscale(normfreq)))
        endband = round(np.ceil(hertz_to_erbscale(freq_range[1]) - hertz_to_erbscale(normfreq)))

    if self._cf is None:
        self._cf = freq_erb_spaced(startband, endband, normfreq, density)
    else:
        self._cf = np.array(cf, dtype=np.float64)
    self.n_filters = self._cf.size

    if bandwidths is None:
        self._use_erb = True
        self._bandwidths = bandwidth_factor * erb_aud(self._cf)
    else:
        self._use_erb = False
        # convert array-likes (a plain list would break the arithmetic in _design_filterbank) and check that
        # there is a single value or one value per filter
        try:
            self._bandwidths = np.broadcast_to(np.asarray(bandwidths, dtype=np.float64), self._cf.shape).copy()
        except ValueError as err:
            raise ValueError(f"`bandwidths` must be a single value or have one value per filter "
                             f"({self.n_filters}), currently shape is: {np.shape(bandwidths)}") from err
    self._attenuation_half_bandwidth_db = attenuation_half_bandwidth_db
    # one zi array of shape (order, 2) per filter, complex because the filter coefficients are complex
    self._states = np.zeros((self.n_filters, self._order, 2), dtype=np.complex128)
    self._design_filterbank()

    self._desired_delay_samples = int(self._sample_rate * desired_delay_sec)
    self._max_indices, self._slopes = self.estimate_max_indices_and_slopes()
    self._delay_samples = self._desired_delay_samples - self._max_indices
    self._delay_memory = np.zeros((len(self._cf), np.max(self._delay_samples)))
    self._phase_factors = np.abs(self._slopes) * 1j / self._slopes
    self._gains = np.ones(len(self._cf))
|
|
159
|
+
|
|
160
|
+
def _design_filterbank(self) -> None:
|
|
161
|
+
""" Returns filter coefficients of a gammatone filter [Hohmann2002]. """
|
|
162
|
+
if self._use_erb:
|
|
163
|
+
# [Hohmann2002] eq. (14)
|
|
164
|
+
a_gamma = (pi * factorial(2 * self._order - 2) *
|
|
165
|
+
2 ** -(2 * self._order - 2) /
|
|
166
|
+
factorial(self._order - 1) ** 2)
|
|
167
|
+
b = self._bandwidths / a_gamma
|
|
168
|
+
lambda_ = np.exp(-2 * pi * b / self._sample_rate)
|
|
169
|
+
else:
|
|
170
|
+
# [Hohmann2002] eq. (12)
|
|
171
|
+
phi = pi * self._bandwidths / self._sample_rate
|
|
172
|
+
alpha = 10 ** (0.1 * self._attenuation_half_bandwidth_db / self._order)
|
|
173
|
+
p = (-2 + 2 * alpha * cos(phi)) / (1 - alpha)
|
|
174
|
+
lambda_ = -p / 2 - sqrt(p * p / 4 - 1)
|
|
175
|
+
beta = 2 * pi * self._cf / self._sample_rate
|
|
176
|
+
coef = lambda_ * exp(1j * beta)
|
|
177
|
+
factor = 2 * (1 - np.abs(coef)) ** self._order
|
|
178
|
+
self._coeffs = np.zeros((self.n_filters, self._order, 6), dtype=np.complex128)
|
|
179
|
+
for idx, c in enumerate(-coef):
|
|
180
|
+
self._coeffs[idx, :, :] = np.repeat([[1., 0., 0., 1., c, 0.]], self._order, axis=0)
|
|
181
|
+
self._coeffs[:, 0, 0] = factor
|
|
182
|
+
|
|
183
|
+
# def synthesize(self, bands: NDArray) -> NDArray:
|
|
184
|
+
# return np.array(list(self.delay([b*g for b, g in zip(bands, self._gains)]))).sum(axis=0)
|
|
185
|
+
|
|
186
|
+
def delay(self, bands):
    """
    Phase-align and delay every analysis band [Hohmann2002].

    Each band is multiplied by its phase factor (``self._phase_factors``) and reduced to its real part. Bands
    with a non-zero entry in ``self._delay_samples`` are then delayed by that many samples; the samples pushed
    past the end of the current block are kept in ``self._delay_memory`` and emitted at the start of the next call.

    :param bands: Iterable of per-band signals (one 1-D array per band), e.g. the output of ``process``
    :return: Generator yielding one real-valued array per band, each as long as the corresponding input band
    :raises ValueError: If a band has a negative delay (cannot be compensated causally)
    """
    for i, band in enumerate(bands):
        # The phase factor must be applied *before* taking the real part, and to every band (delayed or not).
        aligned = np.real(np.asarray(band) * self._phase_factors[i])
        n_delay = int(self._delay_samples[i])
        if n_delay < 0:
            raise ValueError(f"Band {i} has a negative delay of {n_delay} samples; increase the desired delay.")
        if n_delay == 0:
            yield aligned
            continue
        # Prepend the samples held back from the previous block, emit the first len(band) samples and keep
        # the rest. This also works when the block is shorter than the delay.
        n_samples = aligned.shape[-1]
        buffered = np.concatenate((self._delay_memory[i, :n_delay], aligned), axis=0)
        # Update the memory before yielding so the state is consistent even if the consumer stops early.
        self._delay_memory[i, :n_delay] = buffered[n_samples:]
        yield buffered[:n_samples]
|
|
194
|
+
|
|
195
|
+
def estimate_max_indices_and_slopes(self):
    """
    Locate the magnitude peak of each band's impulse response and the local slope of the response there.

    An impulse generated with ``generate_delta_impulse(self._desired_delay_samples, dtype=np.complex128)`` is
    filtered with ``process(..., save_state=False)``. For every band, the index of the largest magnitude within the
    first ``self._desired_delay_samples`` samples is found, together with the complex difference of the response
    across that index. These values feed the delay/phase-factor computation [Hohmann2002].

    :return: Tuple ``(max_indices, slopes)`` with one peak index and one complex slope per band
    """
    sig = generate_delta_impulse(self._desired_delay_samples, dtype=np.complex128)
    bands = np.asarray(self.process(sig, save_state=False))
    # Restrict the peak search along the *sample* axis (last axis). Slicing the first axis would instead drop
    # bands whenever there are more bands than samples in the search window.
    window = bands[..., :self._desired_delay_samples]
    max_indices = np.argmax(np.abs(window), axis=-1)
    n_samples = bands.shape[-1]
    slopes = []
    for band, idx in zip(bands, max_indices):
        # Central difference around the peak. At the edges, fall back to a one-sided difference instead of
        # wrapping around (idx == 0) or indexing past the end (idx == n_samples - 1).
        upper = min(int(idx) + 1, n_samples - 1)
        lower = max(int(idx) - 1, 0)
        slopes.append(band[upper] - band[lower])
    return np.asarray(max_indices), np.asarray(slopes)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class BandpassFilterbank(FilterbankBase):
    """
    Bandpass filterbank with one Butterworth bandpass per given center-frequency. There is also the option to add
    a lowpass filter at 1 Hz as an extra first band to include the DC-component.

    Each bandpass has a bandwidth equal to its center-frequency (Q = 1). Its edges satisfy ``f_high - f_low = cf``
    and ``f_high * f_low = cf ** 2``, i.e. they are geometrically centered around ``cf``. When the DC lowpass is
    enabled, ``1`` is prepended to the stored center-frequencies so that they line up with the filter rows.
    """

    # Cutoff of the optional DC lowpass band in Hz; also recorded as that band's nominal center-frequency.
    _DC_CUTOFF_HZ = 1

    def __init__(self, cf: ArrayLike, sample_rate: float, order: int = 2, dc_lowpass: bool = True):
        """
        Init method

        :param cf: Center-frequencies of the bandpass filters in Hz (scalar or 1-D array-like)
        :param sample_rate: Sampling rate of filterbank in Hz
        :param order: Order of the bandpass (and DC lowpass) filters; must be a positive even integer. Every band is
                      realised as ``order // 2`` second-order sections.
        :param dc_lowpass: Indicates if an additional band for the DC via a lowpass at 1 Hz should be computed
        :raises ValueError: If ``order`` is not a positive even integer
        """
        # Validate before any state (including the base class) is set up.
        if order <= 0 or order % 2 != 0:
            raise ValueError(f"order must be a positive even integer, currently: {order}")
        super().__init__(sample_rate, cf)
        # Keep the bandpass center-frequencies separately, so that `_design_filterbank` can be re-run without
        # prepending the DC band to `self._cf` again. `atleast_1d` lets a scalar `cf` be used as a single band.
        self._band_cf = np.atleast_1d(np.array(cf))
        self._cf = self._band_cf
        self._sample_rate = sample_rate
        self._order = order
        self._dc_lp = dc_lowpass
        n_sections = int(order) // 2
        self.n_filters = self._band_cf.size + int(bool(dc_lowpass))
        self._states = np.zeros((self.n_filters, n_sections, 2))
        self._coeffs = np.zeros((self.n_filters, n_sections, 6))
        self._design_filterbank()

    def _design_filterbank(self) -> None:
        """
        Compute the second-order-section coefficients of all bands and store them in ``self._coeffs``.

        Row 0 holds the DC lowpass (if enabled); the following rows hold one bandpass per center-frequency.
        Also sets ``self._cf`` to the center-frequencies of all rows, with the DC cutoff prepended when enabled.
        Calling this method repeatedly yields the same result.
        """
        offset = int(bool(self._dc_lp))
        n_sections = int(self._order) // 2
        if offset:
            # A lowpass of (even) order N yields N / 2 second-order sections.
            self._coeffs[0, :, :] = butter(self._order, self._DC_CUTOFF_HZ, output="sos", fs=self._sample_rate)
        band_cf = self._band_cf
        # f_high - f_low = cf and f_high * f_low = cf ** 2  =>  f_high = cf / 2 + sqrt(cf ** 2 / 4 + cf ** 2)
        fhigh = band_cf / 2 + np.sqrt(np.square(band_cf / 2) + np.square(band_cf))
        flow = np.square(band_cf) / fhigh
        for idx, (low, high) in enumerate(zip(flow, fhigh)):
            # A bandpass design of order N yields a filter of order 2N, i.e. N second-order sections.
            self._coeffs[idx + offset, :, :] = butter(n_sections, [low, high], btype="bandpass",
                                                      output="sos", fs=self._sample_rate)
        self._cf = np.concatenate(([self._DC_CUTOFF_HZ], band_cf)) if offset else band_cf
|