python-peass 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- peass/__init__.py +18 -0
- peass/auditory_model.py +158 -0
- peass/decomposition.py +470 -0
- peass/gammatone.py +267 -0
- peass/metrics.py +147 -0
- peass/parameters/paramTask1.npz +0 -0
- peass/parameters/paramTask2.npz +0 -0
- peass/parameters/paramTask3.npz +0 -0
- peass/parameters/paramTask4.npz +0 -0
- peass/predictor.py +131 -0
- python_peass-2.0.1.dist-info/METADATA +165 -0
- python_peass-2.0.1.dist-info/RECORD +14 -0
- python_peass-2.0.1.dist-info/WHEEL +4 -0
- python_peass-2.0.1.dist-info/licenses/LICENSE +680 -0
peass/gammatone.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PEASS Auditory Package - Hohmann 2002 Gammatone Filterbank [1, 3]
|
|
3
|
+
|
|
4
|
+
This module implements the complex-valued Gammatone Filterbank as described in [3].
|
|
5
|
+
It provides complete physical modeling of frequency analysis, delay/phase alignment,
|
|
6
|
+
and synthesis capabilities to reconstruct fullband audio from subbands.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import List
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import scipy.signal as signal
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def calculate_erb_bandwidth(center_frequency: float) -> float:
    """
    Return the Equivalent Rectangular Bandwidth (ERB) of an auditory filter.

    Implements Eq. (13) of [3]:
        % bw = 24.7*(.00437*fc+1);

    Args:
        center_frequency: Center frequency of the filter. The expression is
            plain arithmetic, so a NumPy array is processed element-wise.

    Returns:
        The bandwidth given by the formula above, with the same shape as
        the input.
    """
    frequency_term = 0.00437 * center_frequency
    return 24.7 * (frequency_term + 1.0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def frequency_to_erb_scale(frequency_hz: float) -> float:
    """
    Map a frequency in Hz onto the Equivalent Rectangular Bandwidth (ERB) scale.

    Implements Eq. (16) of [3]:
        % ERBscale = GFB_Q * log(1 + Hz / (GFB_L * GFB_Q));
    with GFB_L = 24.7 and GFB_Q = 9.265.

    Args:
        frequency_hz: Frequency in Hz (scalar or NumPy array).

    Returns:
        The corresponding ERB-scale value(s). Inverse of
        ``erb_scale_to_frequency``.
    """
    gfb_l = 24.7
    gfb_q = 9.265
    return gfb_q * np.log(1.0 + frequency_hz / (gfb_l * gfb_q))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def erb_scale_to_frequency(erb_scale: float) -> float:
    """
    Map an Equivalent Rectangular Bandwidth (ERB) scale value back to Hz.

    Implements Eq. (17) of [3]:
        % Hz = (exp(ERBscale / GFB_Q) - 1) * (GFB_L * GFB_Q);
    with GFB_L = 24.7 and GFB_Q = 9.265.

    Args:
        erb_scale: Position on the ERB scale (scalar or NumPy array).

    Returns:
        The corresponding frequency in Hz. Inverse of
        ``frequency_to_erb_scale``.
    """
    gfb_l = 24.7
    gfb_q = 9.265
    growth = np.exp(erb_scale / gfb_q) - 1.0
    return growth * (gfb_l * gfb_q)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_center_frequencies(
        filters_per_erb: float,
        lower_cutoff_hz: float,
        specified_center_hz: float,
        upper_cutoff_hz: float
) -> np.ndarray:
    """
    Build the vector of center frequencies, equidistant on the ERB scale.

    Equivalent to Gfb_center_frequencies.m [3]. The grid is anchored so that
    ``specified_center_hz`` falls exactly on a grid point; it starts at the
    lowest grid point that is not below ``lower_cutoff_hz`` and runs up to
    ``upper_cutoff_hz``.

    Args:
        filters_per_erb: Filter density, i.e. number of filters per ERB.
        lower_cutoff_hz: Lower bound of the covered range in Hz.
        specified_center_hz: Frequency in Hz that must be one of the centers.
        upper_cutoff_hz: Upper bound of the covered range in Hz.

    Returns:
        1-D array of center frequencies in Hz, in ascending order.
    """
    erb_low, erb_anchor, erb_high = (
        frequency_to_erb_scale(hz)
        for hz in (lower_cutoff_hz, specified_center_hz, upper_cutoff_hz)
    )

    # Whole number of grid steps that fit between the lower cutoff and the anchor.
    steps_below_anchor = int(np.floor((erb_anchor - erb_low) * filters_per_erb))
    first_erb = erb_anchor - (steps_below_anchor / filters_per_erb)

    # The tiny epsilon keeps the upper cutoff itself when it lies on the grid.
    erb_grid = np.arange(first_erb, erb_high + 1e-9, 1.0 / filters_per_erb)
    return erb_scale_to_frequency(erb_grid)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class GammatoneFilter:
    """
    A single complex-valued all-pole Gammatone filter (4th order by default).

    Equivalent to the Gfb_Filter class in MATLAB [3]. The filter is realised
    as a cascade of ``gamma_order`` identical one-pole stages that share one
    complex pole. The memory of each stage is kept in ``state`` so that a
    signal can be fed through ``process`` in consecutive chunks.
    """

    def __init__(
            self,
            sampling_frequency: float,
            center_frequency: float,
            gamma_order: int = 4,
            bandwidth_factor: float = 1.0
    ):
        """
        Derive the pole coefficient and gain for one filter band.

        Args:
            sampling_frequency: Sampling rate of the signals to be filtered.
            center_frequency: Center frequency of this band.
            gamma_order: Number of cascaded one-pole stages.
            bandwidth_factor: Multiplier applied to the ERB bandwidth.
        """
        self.gamma_order: int = gamma_order
        self.sampling_frequency: float = sampling_frequency
        self.center_frequency: float = center_frequency

        # Auditory bandwidth scaling (Eq. 14 of [3])
        scaled_erb = calculate_erb_bandwidth(center_frequency) * bandwidth_factor
        twice_order = 2 * gamma_order - 2
        a_gamma = (np.pi * math_factorial(twice_order) * (2.0 ** -twice_order) /
                   (math_factorial(gamma_order - 1) ** 2))
        pole_bandwidth = scaled_erb / a_gamma

        # Pole radius and pole angle.
        self.lambda_val: float = np.exp(-2.0 * np.pi * pole_bandwidth / sampling_frequency)
        self.beta: float = 2.0 * np.pi * center_frequency / sampling_frequency

        # Complex pole coefficient
        self.coefficient: complex = self.lambda_val * np.exp(1j * self.beta)
        self.normalization_factor: float = 2.0 * (1.0 - np.abs(self.coefficient)) ** gamma_order
        self.state: np.ndarray = np.zeros(gamma_order, dtype=complex)

    def process(self, input_signal: np.ndarray) -> np.ndarray:
        """
        Filter one chunk of signal and update the internal state.

        Args:
            input_signal: 1-D array of samples.

        Returns:
            The complex filter output, one value per input sample.
        """
        pole = self.coefficient
        denominator = np.array([1.0, -pole], dtype=complex)
        gain_numerator = np.array([self.normalization_factor], dtype=complex)
        unity_numerator = np.array([1.0], dtype=complex)

        # The stored state is scaled by the pole before being handed to
        # lfilter as initial condition, and scaled back after filtering.
        carried_in = self.state * pole
        carried_out = np.zeros(self.gamma_order, dtype=complex)

        stage_signal = input_signal.copy()
        for stage in range(self.gamma_order):
            # Only the first stage applies the normalization gain.
            numerator = gain_numerator if stage == 0 else unity_numerator
            initial = np.array([carried_in[stage]], dtype=complex)
            stage_signal, final = signal.lfilter(numerator, denominator, stage_signal, zi=initial)
            carried_out[stage] = final[0]

        self.state = carried_out / pole
        return stage_signal

    def clear_state(self) -> None:
        """Resets the internal filter state to zeros."""
        self.state = np.zeros(self.gamma_order, dtype=complex)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class GammatoneAnalyzer:
    """
    Analysis filterbank made of one GammatoneFilter per center frequency.

    Equivalent to the Gfb_Analyzer class in MATLAB [3].
    """

    def __init__(
            self,
            sampling_frequency: float,
            lower_cutoff_hz: float,
            specified_center_hz: float,
            upper_cutoff_hz: float,
            filters_per_erb: float,
            gamma_order: int = 4,
            bandwidth_factor: float = 1.0
    ):
        """
        Lay out the center frequencies and create one filter per band.

        Args:
            sampling_frequency: Sampling rate of the signals to analyse.
            lower_cutoff_hz: Lower bound of the covered frequency range.
            specified_center_hz: Frequency that must be one of the centers.
            upper_cutoff_hz: Upper bound of the covered frequency range.
            filters_per_erb: Filter density on the ERB scale.
            gamma_order: Order passed on to every GammatoneFilter.
            bandwidth_factor: Bandwidth multiplier passed on to every filter.
        """
        self.sampling_frequency: float = sampling_frequency
        self.center_frequencies: np.ndarray = get_center_frequencies(
            filters_per_erb, lower_cutoff_hz, specified_center_hz, upper_cutoff_hz
        )
        band_filters = []
        for center in self.center_frequencies:
            band_filters.append(
                GammatoneFilter(sampling_frequency, center, gamma_order, bandwidth_factor)
            )
        self.filters: List[GammatoneFilter] = band_filters
        # Plain ERB bandwidth per band (bandwidth_factor is not applied here).
        self.bandwidths: np.ndarray = calculate_erb_bandwidth(self.center_frequencies)

    def process(self, input_signal: np.ndarray) -> np.ndarray:
        """
        Run the input through every band filter.

        Args:
            input_signal: 1-D array of samples.

        Returns:
            Complex array of shape (number of bands, number of samples).
        """
        subbands = np.zeros((len(self.filters), input_signal.shape[0]), dtype=complex)
        for row, band_filter in enumerate(self.filters):
            subbands[row, :] = band_filter.process(input_signal)
        return subbands

    def get_z_response(self, z: np.ndarray) -> np.ndarray:
        """
        Evaluate the transfer function of every band at the given z-plane points.

        Args:
            z: 1-D array of complex evaluation points.

        Returns:
            Complex array of shape (len(z), number of bands).
        """
        points = z[:, np.newaxis]
        response = np.ones((points.shape[0], len(self.filters)), dtype=complex)
        for column, band_filter in enumerate(self.filters):
            pole = band_filter.coefficient
            gain = band_filter.normalization_factor
            order = band_filter.gamma_order
            response[:, column] = ((1.0 - pole / points[:, 0]) ** -order) * gain
        return response

    def clear_state(self) -> None:
        """Resets all filters' states to zeros."""
        for band_filter in self.filters:
            band_filter.clear_state()
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class GammatoneDelay:
    """
    Handles phase alignment and group delay estimation across subbands.

    Equivalent to the Gfb_Delay class in MATLAB [3]. For every band the
    position of the impulse-response envelope maximum (searched within the
    first ``delay_samples + 1`` samples) determines how many samples the band
    still has to be delayed, and the local slope around that maximum
    determines a unit-magnitude phase factor.
    """

    def __init__(self, analyzer: GammatoneAnalyzer, delay_samples: int):
        """
        Measure the analyzer's impulse response and derive delays and phases.

        The analyzer's filter states are reset before the measurement and
        again afterwards, so the analyzer is handed back with zeroed states
        instead of carrying the tail of the probe impulse into the next
        signal it processes.

        Args:
            analyzer: Filterbank whose bands are to be aligned. Its filter
                states are cleared as a side effect.
            delay_samples: Desired overall delay in samples.

        Raises:
            ValueError: If ``delay_samples`` is negative.
        """
        if delay_samples < 0:
            raise ValueError("delay_samples must be non-negative")

        # Reset the analyzer states before analyzing the impulse response
        analyzer.clear_state()

        impulse = np.zeros(delay_samples + 2)
        impulse[0] = 1.0

        # Analyze impulse; always leave the analyzer clean afterwards.
        try:
            impulse_response = analyzer.process(impulse)
        finally:
            analyzer.clear_state()
        num_bands = impulse_response.shape[0]

        ir_slice = np.abs(impulse_response[:, :delay_samples + 1])
        max_indices = np.argmax(ir_slice, axis=1)

        self.delays_samples: np.ndarray = delay_samples - max_indices

        # Slope across the envelope maximum. When the maximum sits on the very
        # first sample there is no earlier sample; clamp to index 0 (one-sided
        # difference) instead of letting index -1 wrap to the end of the array.
        band_rows = np.arange(num_bands)
        before_peak = np.maximum(max_indices - 1, 0)
        slopes = (impulse_response[band_rows, max_indices + 1]
                  - impulse_response[band_rows, before_peak])

        slopes = slopes / np.abs(slopes)
        self.phase_factors: np.ndarray = 1j / slopes
        # One row of delay-line memory per band, sized for the longest delay.
        self.memory: np.ndarray = np.zeros((num_bands, int(np.max(self.delays_samples))), dtype=float)

    def process(self, input_data: np.ndarray) -> np.ndarray:
        """
        Phase-correct and delay every band of a block of subband signals.

        Args:
            input_data: Complex array of shape (number of bands, number of samples).

        Returns:
            Real array of the same shape. Samples that do not fit into the
            current block are kept in ``memory`` and emitted at the start of
            the next call.
        """
        num_bands, num_samples = input_data.shape
        output = np.zeros((num_bands, num_samples))
        for band in range(num_bands):
            delay_val = int(self.delays_samples[band])
            phase_corrected = np.real(input_data[band, :] * self.phase_factors[band])
            if delay_val == 0:
                output[band, :] = phase_corrected
            else:
                # Prepend the pending samples, emit the first num_samples and
                # keep the remaining delay_val samples for the next block.
                tmp_out = np.concatenate((self.memory[band, :delay_val], phase_corrected))
                self.memory[band, :delay_val] = tmp_out[num_samples:]
                output[band, :] = tmp_out[:num_samples]
        return output
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class GammatoneMixer:
    """
    Computes per-band synthesis gains that flatten the overall response.

    Equivalent to the Gfb_Mixer class in MATLAB [3].
    """

    def __init__(self, analyzer: GammatoneAnalyzer, delay: GammatoneDelay, iterations: int = 100):
        """
        Iteratively fit the band gains at the band center frequencies.

        Args:
            analyzer: Filterbank providing center frequencies, sampling rate
                and the z-domain response of each band.
            delay: Delay stage providing per-band phase factors and delays.
            iterations: Number of gain refinement passes.
        """
        band_centers = analyzer.center_frequencies
        band_count = len(band_centers)
        rate = analyzer.sampling_frequency

        # Points on the unit circle at the band center frequencies.
        unit_points = np.exp(2j * np.pi * band_centers / rate)
        self.gains: np.ndarray = np.ones(band_count)

        upper_response = analyzer.get_z_response(unit_points)
        lower_response = analyzer.get_z_response(np.conj(unit_points))

        # Fold the delay stage (phase factor and pure delay) into each band.
        for band in range(band_count):
            phase = delay.phase_factors[band]
            shift = -delay.delays_samples[band]
            upper_response[:, band] = upper_response[:, band] * phase * (
                unit_points ** shift)
            lower_response[:, band] = lower_response[:, band] * phase * (
                np.conj(unit_points) ** shift)

        combined_response = (upper_response + np.conj(lower_response)) / 2.0

        # Each pass divides the gains by the magnitude of the current mixed spectrum.
        for _ in range(iterations):
            mixed_spectrum = combined_response @ self.gains
            self.gains = self.gains / np.abs(mixed_spectrum)

    def process(self, input_data: np.ndarray) -> np.ndarray:
        """Return the gain-weighted sum over the band axis (first axis) of ``input_data``."""
        return self.gains @ input_data
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class GammatoneSynthesizer:
    """
    Recombines subband signals into a single fullband waveform.

    Equivalent to the Gfb_Synthesizer class in MATLAB [3]. It chains a
    GammatoneDelay (phase/delay alignment) and a GammatoneMixer (weighted sum).
    """

    def __init__(self, analyzer: GammatoneAnalyzer, desired_delay_seconds: float):
        """
        Build the delay and mixer stages for the given analyzer.

        Args:
            analyzer: Filterbank that produced (or will produce) the subbands.
            desired_delay_seconds: Target overall delay in seconds; converted
                to samples with the analyzer's sampling frequency and rounded.
        """
        delay_in_samples = int(round(desired_delay_seconds * analyzer.sampling_frequency))
        self.delay = GammatoneDelay(analyzer, delay_in_samples)
        self.mixer = GammatoneMixer(analyzer, self.delay)

    def process(self, input_data: np.ndarray) -> np.ndarray:
        """Align the subbands in ``input_data`` and mix them down to one signal."""
        aligned = self.delay.process(input_data)
        mixed = self.mixer.process(aligned)
        return mixed
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def math_factorial(n: int) -> int:
    """
    Standard integer factorial calculation.

    Computed with a loop rather than recursion, so large arguments do not
    hit Python's recursion limit.

    Args:
        n: Integer argument.

    Returns:
        ``n!`` for ``n >= 2``; 1 for every ``n <= 1`` (including negative
        values, which are not rejected).
    """
    product = 1
    for factor in range(2, n + 1):
        product *= factor
    return product
|
peass/metrics.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PEASS Metrics Package - Auditory Features & Similarity Metrics [1, 2]
|
|
3
|
+
|
|
4
|
+
This module computes the perceptual features and linear/energy ratio calculations
|
|
5
|
+
such as SDR, ISR, SIR, and SAR [1]. It houses the core PEMO-Q time-frequency
|
|
6
|
+
cross-correlation engine.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .auditory_model import generate_internal_representation
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def calculate_energy_ratios(
        s_true: np.ndarray,
        e_target: np.ndarray,
        e_interf: np.ndarray,
        e_artif: np.ndarray
) -> Tuple[float, float, float, float]:
    """
    Compute the BSS Eval energy ratios (in dB) from the decomposed components.

    Replaces ISR_SIR_SAR_fromNewDecomposition.m [1]. All inputs are flattened
    before summation, so multichannel arrays contribute all their samples.

    Args:
        s_true: True target signal.
        e_target: Target distortion component.
        e_interf: Interference component.
        e_artif: Artifacts component.

    Returns:
        Tuple ``(ISR, SIR, SAR, SDR)`` in that order.
    """
    target = s_true.ravel()
    target_error = e_target.ravel()
    interference = e_interf.ravel()
    artifacts = e_artif.ravel()

    def _ratio_db(numerator: np.ndarray, denominator: np.ndarray) -> float:
        # Energy ratio of two signals expressed in decibels.
        return 10.0 * np.log10(np.sum(numerator ** 2) / np.sum(denominator ** 2))

    # Eq. (11), (12), (13) of Emiya 2011 [4]:
    ISR = _ratio_db(target, target_error)
    SIR = _ratio_db(target + target_error, interference)
    SAR = _ratio_db(target + target_error + interference, artifacts)
    SDR = _ratio_db(target, target_error + interference + artifacts)

    return ISR, SIR, SAR, SDR
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def pemo_similarity_metric(internal_reference: np.ndarray, internal_test: np.ndarray,
                           sampling_frequency: float) -> float:
    """
    Compares two internal representations to produce an auditory similarity metric.
    Replaces pemo_metric.m [1].

    Performs assimilation of masked content, local framing, cross-correlation,
    moving RMS weighting, and percentile assessment [2].

    Neither input array is modified: the assimilation step operates on a
    private copy of ``internal_test``, so the same test representation can be
    compared against several references.

    Args:
        internal_reference: 3-D array indexed as (band, sample, modulation channel).
        internal_test: Array of the same shape as ``internal_reference``.
        sampling_frequency: Sampling rate of the internal representations;
            used to derive the 0.1 s frame length and the 1 s RMS window.

    Returns:
        The frame similarity at which the cumulative RMS weight (frames sorted
        by similarity) first reaches half of the total weight, or 0.0 when no
        frame qualifies.
    """
    _, nsampl, nmod = internal_reference.shape

    # Assimilation (Eq. of PEMO-Q [2]): wherever the test magnitude is below
    # the reference magnitude, pull the test value a quarter of the way towards
    # the reference. Done on a copy so the caller's array stays untouched.
    internal_test = internal_test.copy()
    assim = (np.abs(internal_test) < np.abs(internal_reference))
    internal_test[assim] = 0.25 * internal_reference[assim] + 0.75 * internal_test[assim]

    # Convert frame sizes: 0.1 s frames, trailing partial frame dropped.
    flen = int(min(nsampl, 0.1 * sampling_frequency))
    nfram = int(np.floor(nsampl / flen))
    nsampl = nfram * flen

    internal_reference = internal_reference[:, :nsampl, :]
    internal_test = internal_test[:, :nsampl, :]

    PSMt = np.zeros(nfram)
    lPSM = np.zeros(nmod)
    lNMS = np.zeros(nmod)

    for t in range(nfram):
        frame = slice(t * flen, (t + 1) * flen)
        for m in range(nmod):
            lref_flat = internal_reference[:, frame, m].ravel()
            lref_flat = lref_flat - np.mean(lref_flat)

            ltest_flat_orig = internal_test[:, frame, m].ravel()
            # Energy of the (non-centered) test frame weights this modulation channel.
            lNMS[m] = np.sum(ltest_flat_orig ** 2)

            ltest_flat = ltest_flat_orig - np.mean(ltest_flat_orig)
            denom = np.sqrt(np.sum(lref_flat ** 2) * np.sum(ltest_flat ** 2))
            # Normalized cross-correlation; defined as 0 for a constant frame.
            lPSM[m] = np.sum(lref_flat * ltest_flat) / denom if denom != 0 else 0.0

        sum_lnms = np.sum(lNMS)
        PSMt[t] = np.sum(lPSM * lNMS) / sum_lnms if sum_lnms != 0 else 0.0

    # From local to global similarity: weight each frame by the mean test
    # energy inside a 1 s window centered on the frame.
    ilen = int(1 * sampling_frequency)
    mtest_sq = np.sum(internal_test ** 2, axis=(0, 2))

    RMS = np.zeros(nfram)
    for t in range(nfram):
        start_idx = int(max(0, (t + 0.5) * flen - 0.5 * ilen))
        end_idx = int(min(nsampl, (t + 0.5) * flen + 0.5 * ilen))
        ltest = mtest_sq[start_idx:end_idx]
        RMS[t] = np.mean(ltest) if len(ltest) > 0 else 0.0

    # Sorted weighted percentile extraction
    ind = np.argsort(PSMt)
    PSMt_sorted = PSMt[ind]
    RMS_cum = np.cumsum(RMS[ind])

    cutoff = 0.5 * RMS_cum[-1]
    match_indices = np.where(RMS_cum >= cutoff)[0]

    return PSMt_sorted[match_indices[0]] if len(match_indices) > 0 else 0.0
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def audio_quality_features(decomposition_signals: list[np.ndarray], sampling_frequency: float = 16000.0) -> Tuple[
    float, float, float, float]:
    """
    Derive the four auditory similarity features from the decomposed signals.

    Replaces audioQualityFeatures.m [1]. For every channel, the internal
    representation of the full estimate (sum of all four components) is
    compared against the internal representation of four references: the
    estimate without the target distortion, without the interference, without
    the artifacts, and the true target alone.

    Args:
        decomposition_signals: Sequence of four arrays
            ``(sTrue, eTarget, eInterf, eArtif)``; 1-D arrays are treated as
            single-channel signals.
        sampling_frequency: Sampling rate handed to the auditory model.

    Returns:
        Tuple ``(qTarget, qInterf, qArtif, qGlobal)``, each the minimum of the
        respective feature over all channels.
    """
    true_src, target_err, interf, artif = decomposition_signals

    # Promote mono signals to a single-column 2-D layout (samples, channels).
    if len(true_src.shape) == 1:
        true_src = true_src[:, np.newaxis]
        target_err = target_err[:, np.newaxis]
        interf = interf[:, np.newaxis]
        artif = artif[:, np.newaxis]

    estimate = true_src + target_err + interf + artif
    channel_count = true_src.shape[1]

    # Rows: target, interference, artifacts, global.
    features = np.zeros((4, channel_count))

    for ch in range(channel_count):
        test_repr, repr_rate = generate_internal_representation(estimate[:, ch], sampling_frequency)

        # Each reference leaves out exactly one distortion component; the
        # last one is the clean target.
        references = (
            true_src[:, ch] + interf[:, ch] + artif[:, ch],
            true_src[:, ch] + target_err[:, ch] + artif[:, ch],
            true_src[:, ch] + target_err[:, ch] + interf[:, ch],
            true_src[:, ch],
        )
        for row, reference_wave in enumerate(references):
            ref_repr, _ = generate_internal_representation(reference_wave, sampling_frequency)
            features[row, ch] = pemo_similarity_metric(ref_repr, test_repr, repr_rate)

    return np.min(features[0]), np.min(features[1]), np.min(features[2]), np.min(features[3])
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
peass/predictor.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PEASS Predictor Package - Multi-Criteria Neural Network Regressor [1]
|
|
3
|
+
|
|
4
|
+
This module maps raw auditory similarity scores (qTarget, qInterf, qArtif, qGlobal)
|
|
5
|
+
to Predicted Perceptual Scores (OPS, TPS, IPS, APS) on a scale from 0 to 100
|
|
6
|
+
using modern .npz parameter loading [1].
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import pathlib
|
|
11
|
+
from typing import Any
|
|
12
|
+
from typing import Dict
|
|
13
|
+
from typing import Union
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import soundfile as sf
|
|
17
|
+
|
|
18
|
+
from .decomposition import extract_distortion_components
|
|
19
|
+
from .metrics import audio_quality_features
|
|
20
|
+
from .metrics import calculate_energy_ratios
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def my_mapping(features: np.ndarray, weights: np.ndarray, bias: np.ndarray, output_weights: np.ndarray,
               output_bias: np.ndarray) -> float:
    """
    Run one forward pass through the two-layer perceptron.

    Replaces myMapping.m [1]. Both layers use a logistic activation; the
    output activation is scaled to the range 0..100.

    Args:
        features: Input feature vector; a 1-D array is turned into a column.
        weights: Hidden-layer weight matrix.
        bias: Hidden-layer bias.
        output_weights: Output-layer weights (transposed before use).
        output_bias: Output-layer bias.

    Returns:
        The element ``[0, 0]`` of the network output as a Python float.
    """
    column = features[:, np.newaxis] if len(features.shape) == 1 else features

    # Hidden layer
    hidden_pre = weights @ column + bias
    hidden_act = 1.0 / (1.0 + np.exp(-hidden_pre))

    # Output layer
    output_pre = output_weights.T @ hidden_act + output_bias
    prediction = 100.0 / (1.0 + np.exp(-output_pre))

    return float(prediction[0, 0])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def predict_peass_scores(
        original_files: list[Union[str, np.ndarray]],
        estimate_file: Union[str, np.ndarray],
        options: dict = None,
        sampling_frequency: float = None,
        return_decomposition: bool = False
) -> Dict[str, Any]:
    """
    Wrapper entry point. Performs least-squares decomposition, generates auditory features,
    and predicts Perceptual Evaluation scores [1].

    Replaces PEASS_ObjectiveMeasure.m and map2SubjScale.m [1].

    Args:
        original_files: List of file paths or NumPy arrays of reference sources.
        estimate_file: File path or NumPy array of the separated estimate.
        options: Algorithmic tuning parameters dictionary.
        sampling_frequency: Rate in Hz (required for in-memory arrays). When
            omitted, it is taken from the header of ``estimate_file`` if that
            is a path, and defaults to 16000.0 otherwise.
        return_decomposition: If True, returns the calculated waveform arrays/saved filepaths.

    Returns:
        dict: Containing OPS, TPS, IPS, APS and decibel criteria.
              If return_decomposition=True, includes "decomposition_arrays" (and
              "decomposition_files" if inputting file paths).
    """
    # 1. Physical Decomposition
    file_paths, decomposed_arrays = extract_distortion_components(original_files, estimate_file, options,
                                                                  sampling_frequency)
    s_true, e_target, e_interf, e_artif = decomposed_arrays

    if sampling_frequency is None:
        if isinstance(estimate_file, (str, pathlib.Path)):
            # Only the header is needed here; avoid decoding the audio data.
            sampling_frequency = sf.info(estimate_file).samplerate
        else:
            sampling_frequency = 16000.0

    # 2. Traditional Energy Ratios
    ISR, SIR, SAR, SDR = calculate_energy_ratios(s_true, e_target, e_interf, e_artif)

    # 3. Auditory Feature Extraction
    q_target, q_interf, q_artif, q_global = audio_quality_features(decomposed_arrays, sampling_frequency)

    # 4. Neural Network Scoring Regressions
    q_features = np.array([q_global, q_target, q_interf, q_artif])
    # Logit-like stretch of the similarity features, limited to [-5.5, 5.5].
    q_mapped = np.clip(np.log((1.0 + q_features) / (1.0 - q_features)), -5.5, 5.5)

    scores = np.zeros(4)
    # Dynamically locate local parameters folder absolute path
    pkg_dir = os.path.dirname(os.path.realpath(__file__))

    for nTask in range(4):
        npz_path = os.path.join(pkg_dir, "parameters", f"paramTask{nTask + 1}.npz")
        # The archive is closed as soon as the arrays have been read, so no
        # file handle stays open per task.
        with np.load(npz_path) as mat_data:
            W = mat_data['W']
            b = mat_data['b']
            v = mat_data['v']
            a = mat_data['a']
            selec = mat_data['selec']

        # NOTE(review): `selec` is used directly as a NumPy index into the four
        # mapped features - presumably 0-based or boolean; confirm against the
        # parameter files.
        scores[nTask] = my_mapping(q_mapped[selec], W, b, v, a)

    results = {
        "OPS": float(scores[0]),  # Overall Perceptual Score
        "TPS": float(scores[1]),  # Target-related Perceptual Score
        "IPS": float(scores[2]),  # Interference-related Perceptual Score
        "APS": float(scores[3]),  # Artifact-related Perceptual Score
        "SDR": SDR,
        "ISR": ISR,
        "SIR": SIR,
        "SAR": SAR
    }

    if return_decomposition:
        results["decomposition_arrays"] = {
            "true_target": s_true,
            "target_distortion": e_target,
            "interference": e_interf,
            "artifacts": e_artif
        }
        if file_paths:
            results["decomposition_files"] = {
                "true_target": file_paths[0],
                "target_distortion": file_paths[1],
                "interference": file_paths[2],
                "artifacts": file_paths[3]
            }

    return results
|