python-peass 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- peass/__init__.py +18 -0
- peass/auditory_model.py +158 -0
- peass/decomposition.py +470 -0
- peass/gammatone.py +267 -0
- peass/metrics.py +147 -0
- peass/parameters/paramTask1.npz +0 -0
- peass/parameters/paramTask2.npz +0 -0
- peass/parameters/paramTask3.npz +0 -0
- peass/parameters/paramTask4.npz +0 -0
- peass/predictor.py +131 -0
- python_peass-2.0.1.dist-info/METADATA +165 -0
- python_peass-2.0.1.dist-info/RECORD +14 -0
- python_peass-2.0.1.dist-info/WHEEL +4 -0
- python_peass-2.0.1.dist-info/licenses/LICENSE +680 -0
peass/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
python-peass: Perceptual Evaluation methods for Audio Source Separation
|
|
3
|
+
A modern, Pythonic port of the PEASS v2.0.1 toolkit [1].
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
__version__ = "2.0.1"
|
|
7
|
+
|
|
8
|
+
from .decomposition import extract_distortion_components
|
|
9
|
+
from .metrics import audio_quality_features
|
|
10
|
+
from .metrics import calculate_energy_ratios
|
|
11
|
+
from .predictor import predict_peass_scores
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"predict_peass_scores",
|
|
15
|
+
"extract_distortion_components",
|
|
16
|
+
"calculate_energy_ratios",
|
|
17
|
+
"audio_quality_features",
|
|
18
|
+
]
|
peass/auditory_model.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PEASS Auditory Package - Dau 1996/1997 Psychoacoustic Ear Model [1, 2]
|
|
3
|
+
|
|
4
|
+
This module ports the legacy C/MEX elements (haircell.c, adapt.c) into pure,
|
|
5
|
+
performant Python [1, 3]. It simulates the transduction process of the inner hair cells
|
|
6
|
+
and the temporal adaptation (forward masking) of the auditory nerve.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import scipy.signal as signal
|
|
13
|
+
|
|
14
|
+
from .gammatone import GammatoneAnalyzer
|
|
15
|
+
|
|
16
|
+
# Check for Numba availability
|
|
17
|
+
try:
|
|
18
|
+
import numba
|
|
19
|
+
|
|
20
|
+
_HAS_NUMBA = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
_HAS_NUMBA = False
|
|
23
|
+
|
|
24
|
+
if _HAS_NUMBA:
    @numba.jit(nopython=True, cache=True)
    def _numba_adaptation_loop(rx: np.ndarray, gain_val: float, sthresh: float, factor: np.ndarray) -> np.ndarray:
        """
        Run one divisive adaptation stage in place over a (bands, samples) array.

        Each sample is divided by the current per-band divisor held in
        ``factor``; the divisor is then updated as a one-pole lowpass of the
        divided output, never dropping below ``sthresh``. Both ``rx`` and
        ``factor`` are modified in place, and ``rx`` is returned.
        """
        band_count, sample_count = rx.shape
        input_weight = 1.0 - gain_val
        for t in range(sample_count):
            for band in range(band_count):
                divided = rx[band, t] / factor[band]
                rx[band, t] = divided
                # The divisor tracks the divided output and is floored at sthresh.
                factor[band] = max(input_weight * divided + gain_val * factor[band], sthresh)
        return rx
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def haircell_transduction(
    subband_signals: np.ndarray,
    sampling_frequency: float,
    cutoff_frequency: float = 1000.0
) -> np.ndarray:
    """
    Models the nonlinear mechanical-to-neural transduction of the inner hair cells.
    Replaces haircell.c MEX script [2, 3].

    Stages:
    1. Half-wave rectification (negative samples are set to zero)
    2. First-order lowpass filter along the sample axis (axis 1)

    Args:
        subband_signals: 2-D array whose second axis is time.
        sampling_frequency: Sampling rate of ``subband_signals`` in Hz.
        cutoff_frequency: Lowpass cutoff in Hz. The default of 1000 Hz
            reproduces the original fixed coefficient ``exp(-pi*2000/fs)``.

    Returns:
        Array of the same shape containing the rectified, lowpass-filtered signals.

    Raises:
        ValueError: If ``sampling_frequency`` or ``cutoff_frequency`` is not
            strictly positive (a non-positive rate would yield an unstable
            or undefined filter pole).
    """
    if sampling_frequency <= 0:
        raise ValueError("sampling_frequency must be strictly positive.")
    if cutoff_frequency <= 0:
        raise ValueError("cutoff_frequency must be strictly positive.")

    # % gain=exp(-pi*2000/fs);
    # % rx=filter(1-gain,[1 -gain],max(rx,0),[],2);
    gain_haircell = np.exp(-np.pi * (2.0 * cutoff_frequency) / sampling_frequency)
    b_hc = np.array([1.0 - gain_haircell])
    a_hc = np.array([1.0, -gain_haircell])

    # Process rectified signals over the sample dimension (axis 1)
    rectified_signals = np.maximum(subband_signals, 0.0)
    return signal.lfilter(b_hc, a_hc, rectified_signals, axis=1)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def adaptation_loops(subband_signals: np.ndarray, sampling_frequency: float) -> np.ndarray:
    """
    Simulates the physiological adaptive properties of the auditory nerve.
    Replaces adapt.c MEX script [2].

    Runs 5 consecutive non-linear feedback (divisive) loops. The signal is
    clamped to a -100 dB floor and processed in single precision, following
    the reference comment ``rx=max(single(rx),thresh)``; the final scaling is
    done in double precision, following the reference comment
    ``rx=double(dbrange/(1-sthresh))*(double(rx)-double(sthresh))``.

    Args:
        subband_signals: 2-D array of shape (bands, samples). The input is
            not modified.
        sampling_frequency: Sampling rate of ``subband_signals`` in Hz.

    Returns:
        float64 array of shape (bands, samples).

    Raises:
        ValueError: If ``subband_signals`` is not two-dimensional.
    """
    dbrange = 100.0
    thresh = 10.0 ** (-dbrange / 20.0)
    bw_loop = 1.0 / (np.pi * np.array([0.005, 0.05, 0.129, 0.253, 0.5]))

    # % rx=max(single(rx),thresh);
    rx = np.maximum(np.asarray(subband_signals).astype(np.float32), thresh)
    if rx.ndim != 2:
        raise ValueError("subband_signals must be a 2-D array of shape (bands, samples).")
    num_bands, num_samples = rx.shape

    # Process each of the 5 adaptive stages
    sthresh = thresh
    for stage_bandwidth in bw_loop:
        # Plain Python floats keep the float32 state arrays in float32 under
        # every NumPy promotion regime; NumPy float64 scalars would silently
        # widen the fallback loop to float64 on NumPy >= 2.
        gain_val = float(np.exp(-np.pi * stage_bandwidth / sampling_frequency))
        sthresh = float(np.sqrt(sthresh))
        factor = np.full(num_bands, sthresh, dtype=np.float32)  # divisor factor for each band

        if _HAS_NUMBA:
            # Compiled loop; updates rx and factor in place
            rx = _numba_adaptation_loop(rx, gain_val, sthresh, factor)
        else:
            # Fallback pure-Python loop, vectorized across bands
            for sample_idx in range(num_samples):
                # Divide current sample by current divisor factor
                val = rx[:, sample_idx] / factor
                rx[:, sample_idx] = val
                # Update divisor filter state
                factor = np.maximum((1.0 - gain_val) * val + gain_val * factor, sthresh)

    # % rx=double(dbrange/(1-sthresh))*(double(rx)-double(sthresh));
    return (dbrange / (1.0 - sthresh)) * (rx.astype(np.float64) - sthresh)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def generate_internal_representation(
    signal_data: np.ndarray,
    sampling_frequency: float,
    modulation_processing_type: str = 'lp'
) -> Tuple[np.ndarray, float]:
    """
    Generates the 3D internal auditory representation of a signal.
    Equivalent of pemo_internal.m [1].

    Pipeline: flatten the input, scale it by 10, resample by a factor of 1.5
    when the rate is below three times the upper analysis frequency, run the
    gammatone analysis, hair-cell transduction and adaptation loops, resample
    to the modulation rate, and apply the modulation filters.

    Args:
        signal_data: Input signal; arrays with more than one dimension are
            transposed when they have fewer rows than columns, then flattened.
        sampling_frequency: Sampling rate of ``signal_data`` in Hz.
        modulation_processing_type: ``'fb'`` selects the 8-channel modulation
            filterbank at 800 Hz; any other value selects the single lowpass
            channel at 100 Hz.

    Returns:
        Tuple of the complex-typed array of shape
        (bands, samples, modulation channels) and the output sampling rate.
    """
    if signal_data.ndim > 1:
        if signal_data.shape[0] < signal_data.shape[1]:
            signal_data = signal_data.T
        signal_data = signal_data.ravel()

    # Model input scaling (1.0 becomes 100 dB SPL)
    signal_data = 10.0 * signal_data

    # Frequency analysis boundaries
    lowest_frequency = 235.0
    highest_frequency = min(0.5 * sampling_frequency, 14500.0)
    if sampling_frequency < 3.0 * highest_frequency:
        raised_fs = int(round(1.5 * sampling_frequency))
        signal_data = signal_data.astype(float)
        resampled_length = int(round(len(signal_data) * raised_fs / sampling_frequency))
        signal_data = signal.resample(signal_data, resampled_length)
        sampling_frequency = float(raised_fs)

    analyzer = GammatoneAnalyzer(sampling_frequency, lowest_frequency, 1000.0, highest_frequency, 1.0)
    num_bands = len(analyzer.filters)

    # Subband analysis, transduction and adaptation
    subbands = np.real(analyzer.process(signal_data))
    adapted = adaptation_loops(haircell_transduction(subbands, sampling_frequency), sampling_frequency)

    # Modulation filter layout and target rate
    if modulation_processing_type == 'fb':
        modulation_fs = 800.0
        center_frequencies_mod = np.concatenate(([0.0, 5.0], 10.0 * (5.0 / 3.0) ** np.arange(6)))
        bandwidth_mod = np.concatenate(([5.0, 5.0], 5.0 * (5.0 / 3.0) ** np.arange(6)))
    else:
        modulation_fs = 100.0
        center_frequencies_mod = np.array([0.0])
        bandwidth_mod = np.array([15.92])

    # Downsample to the modulation rate
    adapted = signal.resample(adapted, int(round(adapted.shape[1] * modulation_fs / sampling_frequency)), axis=1)
    sampling_frequency = modulation_fs

    internal_representation = np.zeros(
        (num_bands, adapted.shape[1], len(center_frequencies_mod)), dtype=complex
    )

    for m, (center, bandwidth) in enumerate(zip(center_frequencies_mod, bandwidth_mod)):
        pole_radius = np.exp(-np.pi * bandwidth / sampling_frequency)
        numerator = np.array([1.0 - pole_radius])
        denominator = np.array([1.0, -pole_radius * np.exp(2j * np.pi * center / sampling_frequency)])
        filtered = signal.lfilter(numerator, denominator, adapted, axis=1)

        # Hilbert envelope extraction above 10 Hz; real part otherwise
        if center > 10.0:
            internal_representation[:, :, m] = np.abs(filtered)
        else:
            internal_representation[:, :, m] = np.real(filtered)

    return internal_representation, sampling_frequency
|
peass/decomposition.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PEASS Decomposition Package - Least-Squares Distortion Decomposer [1]
|
|
3
|
+
|
|
4
|
+
This module decomposes the separation error of a source estimate into:
|
|
5
|
+
1. Target distortion (filter-induced alterations)
|
|
6
|
+
2. Interference (leakage from overlapping sources)
|
|
7
|
+
3. Artifacts (artificial noise / musical noise components)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import pathlib
|
|
11
|
+
from typing import List
|
|
12
|
+
from typing import Tuple
|
|
13
|
+
from typing import Union
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import scipy.linalg as linalg
|
|
17
|
+
import scipy.signal as signal
|
|
18
|
+
import soundfile as sf
|
|
19
|
+
|
|
20
|
+
from .gammatone import GammatoneAnalyzer
|
|
21
|
+
from .gammatone import GammatoneSynthesizer
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def least_squares_decompose(
    source_estimates: np.ndarray,
    true_sources: np.ndarray,
    filter_half_length: int,
    analysis_window: np.ndarray
) -> np.ndarray:
    """
    Weighted least-squares projection of source estimate on the source subspaces.
    Equivalent of LSDecompose.m [1].

    Every estimate channel is approximated as the sum over sources of the
    source convolved with a filter of ``2 * filter_half_length + 1`` taps,
    fitted under the weighting given by ``analysis_window``.

    Args:
        source_estimates: Array of shape (num_samples, num_channels).
        true_sources: Array of shape
            (num_samples + 2 * filter_half_length, num_sources).
        filter_half_length: Half length of the fitted filters.
        analysis_window: 1-D weighting window of length ``num_samples``.

    Returns:
        Array of shape (num_samples, num_channels, num_sources) holding the
        windowed contribution of each source to each estimate channel. Its
        dtype is the common type of the estimates and the sources, so a
        complex projection is never truncated to its real part.

    Raises:
        ValueError: If the frame lengths of the inputs are inconsistent.
    """
    filter_length = 2 * filter_half_length + 1
    num_sources = true_sources.shape[1]
    num_samples = source_estimates.shape[0]

    if true_sources.shape[0] != num_samples + filter_length - 1:
        raise ValueError(
            "true_sources must have num_samples + 2 * filter_half_length rows."
        )
    if analysis_window.shape[0] != num_samples:
        raise ValueError("analysis_window must have one weight per estimate sample.")

    # One Toeplitz (convolution) block per source, stacked side by side.
    toeplitz_matrix = np.zeros((num_samples, num_sources * filter_length), dtype=true_sources.dtype)
    for j in range(num_sources):
        first_column = true_sources[filter_length - 1:, j]
        first_row = true_sources[filter_length - 1::-1, j]
        toeplitz_matrix[:, j * filter_length: (j + 1) * filter_length] = linalg.toeplitz(first_column, first_row)

    window_column = analysis_window[:, np.newaxis]
    weighted_sources = window_column * toeplitz_matrix
    weighted_estimates = window_column * source_estimates

    gram_matrix = weighted_sources.conj().T @ weighted_sources
    reg_lambda = 10.0 ** -15

    try:
        cholesky_factor = linalg.cholesky(gram_matrix + reg_lambda * np.eye(gram_matrix.shape[0]), lower=False)
    except (linalg.LinAlgError, ValueError):
        # Factorisation failed: fall back to the pseudo-inverse solution.
        projection_weights = np.linalg.pinv(weighted_sources) @ weighted_estimates
    else:
        # Normal equations solved with two triangular solves: U^H U w = A^H y.
        rhs = weighted_sources.conj().T @ weighted_estimates
        intermediate = linalg.solve_triangular(cholesky_factor.conj().T, rhs, lower=True)
        projection_weights = linalg.solve_triangular(cholesky_factor, intermediate, lower=False)

    output_dtype = np.result_type(source_estimates.dtype, true_sources.dtype)
    projections = np.zeros((num_samples, source_estimates.shape[1], num_sources), dtype=output_dtype)
    for j in range(num_sources):
        block = slice(j * filter_length, (j + 1) * filter_length)
        projections[:, :, j] = window_column * (toeplitz_matrix[:, block] @ projection_weights[block, :])

    return projections
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def least_squares_decompose_time_varying(
    source_estimates: np.ndarray,
    true_sources: np.ndarray,
    filter_length: int,
    window_length: int,
    hop_size: int
) -> np.ndarray:
    """
    Time-varying least-squares subband decomposer.
    Equivalent of LSDecompose_tv.m [1].

    The signals are zero-padded at the end, cut into overlapping frames,
    each frame is projected with ``least_squares_decompose`` and the frame
    results are overlap-added with square-root Hann analysis/synthesis
    windows, then normalised by the accumulated window energy.

    Args:
        source_estimates: Array of shape (samples, num_channels).
        true_sources: Array of shape (samples, num_sources).
        filter_length: Odd number of filter taps fitted per source.
        window_length: Frame length in samples.
        hop_size: Frame advance in samples.

    Returns:
        Array of shape (samples + filter_length - 1, num_channels, num_sources).

    Raises:
        ValueError: If ``filter_length`` is not a positive odd integer, or if
            ``window_length`` or ``hop_size`` is smaller than 1 (a
            non-positive hop would never terminate).
    """
    if (filter_length - 1) % 2 != 0:
        raise ValueError("Filter length must be an odd integer.")
    if filter_length < 1:
        raise ValueError("Filter length must be positive.")
    if window_length < 1:
        raise ValueError("Window length must be a positive integer.")
    if hop_size < 1:
        raise ValueError("Hop size must be a positive integer.")
    filter_half_length = (filter_length - 1) // 2

    pad_length = filter_length - 1 + window_length - 1
    true_sources = np.pad(true_sources, ((0, pad_length), (0, 0)), mode='constant')
    source_estimates = np.pad(source_estimates, ((0, pad_length), (0, 0)), mode='constant')

    total_samples, num_sources = true_sources.shape
    num_channels = source_estimates.shape[1]

    # Every source frame extends filter_half_length samples beyond the
    # estimate frame on both sides; padding once replaces per-frame padding.
    padded_sources = np.pad(
        true_sources, ((filter_half_length, filter_half_length), (0, 0)), mode='constant'
    )

    # Periodic Hann windows
    hann_win = signal.windows.hann(window_length, sym=False)
    analysis_window = np.sqrt(np.flipud(hann_win))
    synthesis_window = np.sqrt(np.flipud(hann_win))
    synthesis_weights = synthesis_window[:, np.newaxis, np.newaxis]

    accum_dtype = np.result_type(true_sources.dtype, source_estimates.dtype)
    projections_accum = np.zeros((total_samples, num_channels, num_sources), dtype=accum_dtype)
    window_accum = np.zeros(total_samples)

    w_begin = 0
    w_end = w_begin + window_length
    while w_end - window_length / 2.0 <= total_samples - window_length + 1:
        frame_estimates = source_estimates[w_begin:w_end, :]
        # Rows w_begin .. w_end + 2*half of the padded array correspond to
        # original rows w_begin - half .. w_end + half.
        frame_sources = padded_sources[w_begin:w_end + 2 * filter_half_length, :]

        frame_projections = least_squares_decompose(frame_estimates, frame_sources, filter_half_length, analysis_window)

        projections_accum[w_begin:w_end, :, :] += frame_projections[:window_length, :, :] * synthesis_weights
        window_accum[w_begin:w_end] += synthesis_window * analysis_window

        w_begin += hop_size
        w_end += hop_size

    # Normalise only where at least one frame contributed window energy.
    valid_indices = window_accum != 0
    projections_accum[valid_indices] /= window_accum[valid_indices, np.newaxis, np.newaxis]

    # An explicit end index keeps window_length == 1 from collapsing to [:-0].
    return projections_accum[:total_samples - (window_length - 1), :, :]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def extract_target_spatial_interference_artifacts(
    true_sources: np.ndarray,
    source_estimates: np.ndarray,
    filter_length: int,
    window_length: int,
    hop_size: int,
    flag_two_projections: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Splits multi-source signal mixtures into Target, Spatial Distortion,
    Interference, and Artifact components. Replaces extractTSIA.m [1].

    The target is the first source (index 0 on the last axis of
    ``true_sources``) for every estimate.

    Args:
        true_sources: Array of shape (samples, channels, sources).
        source_estimates: Array of shape (samples, channels, estimates), or
            (samples, channels) for a single estimate.
        filter_length: Odd filter length passed to the time-varying decomposer.
        window_length: Frame length passed to the time-varying decomposer.
        hop_size: Frame hop passed to the time-varying decomposer.
        flag_two_projections: When true, the spatial distortion is obtained
            from a second projection of each estimate onto the target source
            alone; otherwise it is taken from the joint projection.

    Returns:
        Four arrays of shape (samples, channels, estimates): the target
        reference, the spatial distortion, the interference and the artifacts.
    """
    total_samples, num_channels, num_sources = true_sources.shape
    num_estimates = source_estimates.shape[2] if len(source_estimates.shape) > 2 else 1
    if len(source_estimates.shape) == 2:
        source_estimates = source_estimates[:, :, np.newaxis]

    # Column layout after the Fortran-order reshape: channel + num_channels * index.
    flat_shape = (total_samples, num_estimates * num_channels)
    sources_reshaped = true_sources.reshape((total_samples, num_sources * num_channels), order='F')
    estimates_reshaped = source_estimates.reshape(flat_shape, order='F')

    projections_all = least_squares_decompose_time_varying(estimates_reshaped, sources_reshaped, filter_length,
                                                           window_length, hop_size)

    # Contribution of each source: sum over that source's channels.
    y_projected = np.zeros((total_samples, num_channels * num_estimates, num_sources), dtype=true_sources.dtype)
    for source_idx in range(num_sources):
        source_columns = slice(source_idx * num_channels, (source_idx + 1) * num_channels)
        y_projected[:, :, source_idx] = np.sum(projections_all[:total_samples, :, source_columns], axis=2)

    # Target reference: the first source's channels repeated for every estimate.
    true_reference = np.tile(sources_reshaped[:, :num_channels], (1, num_estimates))

    if flag_two_projections:
        target_projection = np.zeros(flat_shape, dtype=source_estimates.dtype)
        for est_idx in range(num_estimates):
            est_columns = slice(est_idx * num_channels, (est_idx + 1) * num_channels)
            spatial_proj = least_squares_decompose_time_varying(
                estimates_reshaped[:, est_columns],
                sources_reshaped[:, :num_channels],
                filter_length, window_length, hop_size
            )
            target_projection[:, est_columns] = np.sum(spatial_proj[:total_samples, :, :], axis=2)
    else:
        # The projection on the target already covers every estimate channel.
        # Reshaping y_projected[:, :, :num_estimates] only worked for a
        # single estimate and raised for more.
        target_projection = y_projected[:, :, 0]

    spatial_distortion = target_projection - true_reference
    interference = np.sum(y_projected, axis=2) - spatial_distortion - true_reference
    artifacts = estimates_reshaped - true_reference - spatial_distortion - interference

    cube_shape = (total_samples, num_channels, num_estimates)
    true_reference_3d = true_reference.reshape(cube_shape, order='F')
    spatial_distortion_3d = spatial_distortion.reshape(cube_shape, order='F')
    interference_3d = interference.reshape(cube_shape, order='F')
    artifacts_3d = artifacts.reshape(cube_shape, order='F')

    return true_reference_3d, spatial_distortion_3d, interference_3d, artifacts_3d
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _orient_samples_first(data):
    """Return ``data`` as a 2-D array, transposed when it has fewer rows than columns."""
    data = np.atleast_2d(data)
    if data.shape[0] < data.shape[1]:
        data = data.T
    return data


def _apply_shading(sig, fs, shade_in, shade_out):
    """Return a copy of ``sig`` with half-Hann fade-in and fade-out ramps (durations in ms)."""
    sig_shaded = sig.copy()
    if shade_in > 0:
        win_len = 2 * int(round(shade_in / 1000.0 * fs + 1))
        fade_in = signal.windows.hann(win_len, sym=False)[:win_len // 2]
        sig_shaded[:len(fade_in), :] *= fade_in[:, np.newaxis]
    if shade_out > 0:
        win_len = 2 * int(round(shade_out / 1000.0 * fs + 1))
        fade_out = np.flip(signal.windows.hann(win_len, sym=False)[:win_len // 2])
        sig_shaded[-len(fade_out):, :] *= fade_out[:, np.newaxis]
    return sig_shaded


def _fit_to_length(sig, target_len):
    """Truncate ``sig`` or zero-pad it at the end so that it has ``target_len`` samples."""
    if len(sig) >= target_len:
        return sig[:target_len]
    return np.pad(sig, (0, target_len - len(sig)), mode='constant')


def extract_distortion_components(
    src_files: List[Union[str, np.ndarray]],
    est_file: Union[str, np.ndarray],
    options: dict = None,
    sampling_frequency: float = None
) -> Tuple[List[str], List[np.ndarray]]:
    """
    Subband least-squares decomposes estimates into distinct physical components.
    Replaces extractDistortionComponents.m [1].

    Args:
        src_files: Source signals, the target first. File paths when
            ``est_file`` is a path, arrays otherwise.
        est_file: Path of the estimate audio file, or the estimate as an array.
        options: Optional overrides for ``destDir``, ``FLAG_2PROJ``,
            ``frameLength``, ``filterLength``, ``shadeInMs``, ``shadeOutMs``
            and ``segmentationFactor`` (the last one is accepted but not read
            by this function). Missing or ``None`` entries take the defaults.
            The dictionary passed in is not modified.
        sampling_frequency: Sampling rate in Hz; required for array input and
            replaced by the rate of the estimate file in file mode.

    Returns:
        ``(filenames, signals)``. ``signals`` holds the true target, target
        distortion, interference and artifact signals, each of shape
        (samples, channels). ``filenames`` lists the four WAV files written
        to ``destDir`` in file mode and is empty for array input.

    Raises:
        ValueError: If no source is given, the sampling rates or shapes do
            not match, or array input is used without a sampling rate.
    """
    default_options = {
        'destDir': './',
        'FLAG_2PROJ': False,
        'frameLength': 0.5,
        'filterLength': 0.04,
        'shadeInMs': 10,
        'shadeOutMs': 10,
        'segmentationFactor': 1
    }

    # Merge into a fresh dictionary so the caller's options are left untouched.
    merged_options = dict(default_options)
    if options is not None:
        for key, value in options.items():
            if key in default_options and value is None:
                continue
            merged_options[key] = value
    options = merged_options

    is_file_mode = isinstance(est_file, (str, pathlib.Path))

    if is_file_mode:
        est_data, sampling_frequency = sf.read(est_file)
        if est_data.ndim == 1:
            est_data = est_data[:, np.newaxis]

        src_data_list = []
        for src_path in src_files:
            data, fs_s = sf.read(src_path)
            if fs_s != sampling_frequency:
                raise ValueError("Sampling rates of all files must match.")
            if data.ndim == 1:
                data = data[:, np.newaxis]
            src_data_list.append(data)
    else:
        if sampling_frequency is None:
            raise ValueError("In-memory mode requires an explicit 'sampling_frequency'.")
        est_data = _orient_samples_first(est_file)
        src_data_list = [_orient_samples_first(s_arr) for s_arr in src_files]

    J = len(src_data_list)
    L_original = est_data.shape[0]
    NChan = est_data.shape[1]

    if J == 0:
        raise ValueError("At least one source signal is required.")
    for s_data in src_data_list:
        if s_data.shape != est_data.shape:
            raise ValueError("All source signals must be of matching dimensions.")

    src_shaded = [
        _apply_shading(s_data, sampling_frequency, options['shadeInMs'], options['shadeOutMs'])
        for s_data in src_data_list
    ]
    est_shaded = _apply_shading(est_data, sampling_frequency, options['shadeInMs'], options['shadeOutMs'])

    # Analysis Gammatone Filterbank. The demodulation matrix produced by the
    # first call is passed to all later calls; the analyzer kept for the
    # remaining steps is the one returned by the last call.
    sj_gamma = [[None for _ in range(NChan)] for _ in range(J)]
    Mmod = None
    analyzer = None
    for j in range(J):
        for nChan in range(NChan):
            sj_gamma[j][nChan], analyzer, Mmod = my_analysis_filter_bank(
                src_shaded[j][:, nChan], sampling_frequency, Mmod
            )

    sj_est_gamma = [None for _ in range(NChan)]
    for nChan in range(NChan):
        sj_est_gamma[nChan], analyzer, _ = my_analysis_filter_bank(est_shaded[:, nChan], sampling_frequency, Mmod)

    # Convert to subband blocks of shape (samples, channels, sources/estimates)
    Nb = len(sj_gamma[0][0])
    s = []
    sEst = []
    for b in range(Nb):
        L_band = len(sj_gamma[0][0][b])
        s_band = np.zeros((L_band, NChan, J), dtype=complex)
        sEst_band = np.zeros((L_band, NChan, 1), dtype=complex)
        for nChan in range(NChan):
            sEst_band[:, nChan, 0] = sj_est_gamma[nChan][b]
            for j in range(J):
                s_band[:, nChan, j] = sj_gamma[j][nChan][b]
        s.append(s_band)
        sEst.append(sEst_band)

    # Frame, hop and filter durations are given for the band closest to
    # 1 kHz and scaled per band by the bandwidth ratio.
    fRef = 1000.0
    TframeFRef = options['frameLength']
    ThopFRef = TframeFRef / 4.0
    idx_fref = np.argmin(np.abs(analyzer.center_frequencies - fRef))
    bwRef = analyzer.bandwidths[idx_fref]

    # Per-band sampling rate after decimation
    fsb = analyzer.sampling_frequency / analyzer.Ndec

    TfilterFRef = min(options['filterLength'], TframeFRef / NChan / J / 3.0)
    flens = np.maximum(3, 2 * np.round((TfilterFRef * bwRef / analyzer.bw * fsb - 1) / 2.0) + 1).astype(int)
    Lws = np.maximum(3, np.round(TframeFRef * bwRef / analyzer.bw * fsb)).astype(int)
    hops = np.maximum(1, np.round(ThopFRef * bwRef / analyzer.bw * fsb)).astype(int)

    sgTrue, egTarget, egInterf, egArtif = [], [], [], []
    for b in range(Nb):
        sTrue_b, eSpat_b, eInterf_b, eArtif_b = extract_target_spatial_interference_artifacts(
            s[b], sEst[b], flens[b], Lws[b], hops[b], flag_two_projections=options['FLAG_2PROJ']
        )
        sgTrue.append(sTrue_b)
        egTarget.append(eSpat_b)
        egInterf.append(eInterf_b)
        egArtif.append(eArtif_b)

    trueSynth = np.zeros((L_original, NChan))
    targetSynth = np.zeros((L_original, NChan))
    interfSynth = np.zeros((L_original, NChan))
    artifSynth = np.zeros((L_original, NChan))

    for nChan in range(NChan):
        # Regroup the per-band results into one list of bands for this channel.
        bands_true = [sgTrue[b][:, nChan, 0] for b in range(Nb)]
        bands_target = [egTarget[b][:, nChan, 0] for b in range(Nb)]
        bands_interf = [egInterf[b][:, nChan, 0] for b in range(Nb)]
        bands_artif = [egArtif[b][:, nChan, 0] for b in range(Nb)]

        synth_t, _ = my_synthesis_filter_bank(bands_true, analyzer)
        synth_s, _ = my_synthesis_filter_bank(bands_target, analyzer)
        synth_i, _ = my_synthesis_filter_bank(bands_interf, analyzer)
        synth_a, _ = my_synthesis_filter_bank(bands_artif, analyzer)

        trueSynth[:, nChan] = _fit_to_length(synth_t, L_original)
        targetSynth[:, nChan] = _fit_to_length(synth_s, L_original)
        interfSynth[:, nChan] = _fit_to_length(synth_i, L_original)
        artifSynth[:, nChan] = _fit_to_length(synth_a, L_original)

    components = [trueSynth, targetSynth, interfSynth, artifSynth]
    if not is_file_mode:
        return [], components

    dest_path = pathlib.Path(options['destDir'])
    filename = pathlib.Path(est_file).stem
    out_filenames = [
        str(dest_path / f"{filename}_true.wav"),
        str(dest_path / f"{filename}_eTarget.wav"),
        str(dest_path / f"{filename}_eInterf.wav"),
        str(dest_path / f"{filename}_eArtif.wav")
    ]
    for out_name, component in zip(out_filenames, components):
        sf.write(out_name, component, int(sampling_frequency))
    return out_filenames, components
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def my_analysis_filter_bank(x: np.ndarray, fs: float, Mmod: np.ndarray = None):
    """
    Gammatone analysis of ``x`` followed by demodulation and per-band decimation.

    The signal is resampled to 1.5 times ``fs``, analysed with a gammatone
    filterbank spanning 20 Hz to ``fs / 2``, each band is shifted to baseband
    by multiplying with ``Mmod``, and each band is decimated by its own
    integer factor. ``Mmod`` is computed when not supplied, so that a caller
    can reuse it for further signals of the same length.

    Returns:
        ``(bands, analyzer, Mmod)``: the list of decimated band signals, the
        analyzer (annotated with ``fsOrig``, ``Ndec``, ``fs`` and ``bw``) and
        the demodulation matrix that was applied.
    """
    from .gammatone import calculate_erb_bandwidth

    lowest_cf = 20.0
    highest_cf = fs / 2.0
    reference_cf = 1000.0
    density = 1.0

    input_fs = fs
    if fs / 2.0 < 1.5 * highest_cf:
        raised_fs = int(round(1.5 * fs))
        x = signal.resample(x, int(round(len(x) * raised_fs / fs)))
        fs = raised_fs

    analyzer = GammatoneAnalyzer(fs, lowest_cf, reference_cf, highest_cf, density)
    analyzer.fsOrig = input_fs

    band_signals = analyzer.process(x)
    band_count = band_signals.shape[0]

    if Mmod is None:
        # One complex exponential per band at minus its centre frequency.
        sample_index = np.arange(band_signals.shape[1])
        centre_column = analyzer.center_frequencies[:, np.newaxis]
        Mmod = np.exp(-2j * np.pi / fs * centre_column * sample_index)

    band_signals = band_signals * Mmod

    # Decimation factor per band from its ERB bandwidth (at least 1).
    erb_bandwidths = calculate_erb_bandwidth(analyzer.center_frequencies)
    oversampling = 2.0
    decimation = np.maximum(1, np.floor(fs / (erb_bandwidths * oversampling))).astype(int)

    analyzer.Ndec = decimation
    analyzer.fs = fs
    analyzer.bw = erb_bandwidths

    decimated_bands = [
        signal.resample_poly(band_signals[k, :], 1, decimation[k])
        for k in range(band_count)
    ]

    return decimated_bands, analyzer, Mmod
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def my_synthesis_filter_bank(xFB: list, analyzer: GammatoneAnalyzer):
    """
    Resynthesise a signal from decimated baseband gammatone bands.

    Each band is upsampled by its decimation factor ``analyzer.Ndec[k]``,
    shifted back to its centre frequency, and the bands are combined by a
    ``GammatoneSynthesizer``. The result is resampled from ``analyzer.fs`` to
    ``analyzer.fsOrig`` and the synthesis delay is removed from its start.

    Returns:
        ``(output, synthesizer)``: the resynthesised signal and the
        synthesizer that produced it.
    """
    fs = analyzer.fs
    band_count = len(xFB)

    # Length of every band once it is back at the analysis rate.
    full_lengths = [len(xFB[k]) * analyzer.Ndec[k] for k in range(band_count)]
    longest = max(full_lengths)

    bands = np.zeros((band_count, longest), dtype=complex)
    for k, full_length in enumerate(full_lengths):
        restored = signal.resample_poly(xFB[k], analyzer.Ndec[k], 1)
        # Copy at most full_length samples; the remainder of the row stays zero.
        usable = min(len(restored), full_length)
        bands[k, :usable] = restored[:usable]

    # Re-modulate every band to its centre frequency.
    sample_index = np.arange(longest)
    centre_column = analyzer.center_frequencies[:, np.newaxis]
    bands = bands * np.exp(2j * np.pi / fs * centre_column * sample_index)

    desired_delay_in_seconds = 1000.0 / fs
    synthesizer = GammatoneSynthesizer(analyzer, desired_delay_in_seconds)
    output = synthesizer.process(bands)

    # Back to the original rate, then drop the synthesis delay.
    original_fs = analyzer.fsOrig
    output = signal.resample(output, int(round(len(output) * original_fs / fs)))
    leading_delay = int(round(desired_delay_in_seconds * original_fs))

    return output[leading_delay:], synthesizer
|