sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/ir_metric.py
ADDED
@@ -0,0 +1,551 @@
|
|
1
|
+
"""sonusai ir_metric
|
2
|
+
|
3
|
+
usage: ir_metric [-h] [-n NCPU] IRLOC
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-n, --num_process NCPU Number of parallel processes to use [default: auto]
|
8
|
+
|
9
|
+
Calculate delay and gain metrics of impulse response (IR) files <filename>.wav in IRLOC.
|
10
|
+
Metrics include gain and multiple ways to calculate the IR delay:
|
11
|
+
- gmax: max abs(fft(ir)), with FFT length equal to the IR length
|
12
|
+
- dcc: cross-correlation of ir-filtered white Gaussian noise with the noise reference
|
13
|
+
- dmax: index of max(ir)
|
14
|
+
- dagd: average group delay method
|
15
|
+
- dccphat: generalized cross-correlation with PHAT weighting (16x interpolated)
|
16
|
+
|
17
|
+
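Results (also including rt20/rt60 reverberation times and min/max sample values) are written to ir_metric_summary.txt and ir_metric_list.csv in the IR base directory, or to NAME-irmetric.txt next to the file when IRLOC is a single file.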
|
18
|
+
|
19
|
+
IRLOC directory (searched recursively), single audio file, or glob pattern of impulse response audio files (.wav, .flac, etc.). Only the first channel is analyzed.
|
20
|
+
|
21
|
+
"""
|
22
|
+
|
23
|
+
import glob
|
24
|
+
from os.path import abspath
|
25
|
+
from os.path import basename
|
26
|
+
from os.path import commonprefix
|
27
|
+
from os.path import dirname
|
28
|
+
from os.path import isdir
|
29
|
+
from os.path import isfile
|
30
|
+
from os.path import join
|
31
|
+
from os.path import relpath
|
32
|
+
from os.path import splitext
|
33
|
+
|
34
|
+
import matplotlib.pyplot as plt
|
35
|
+
import numpy as np
|
36
|
+
import pandas as pd
|
37
|
+
import soundfile
|
38
|
+
from numpy import fft
|
39
|
+
|
40
|
+
from sonusai.utils.braced_glob import braced_iglob
|
41
|
+
|
42
|
+
|
43
|
+
def tdoa(signal, reference, interp=1, phat=False, fs=1, t_max=None):
    """
    Estimate the time shift of ``signal`` relative to ``reference`` using
    generalized cross-correlation (GCC).

    Parameters
    ----------
    signal: array_like
        The array whose time difference of arrival is measured.
    reference: array_like
        The reference array.
    interp: int, optional
        Interpolation factor applied to the correlation output (default 1).
    phat: bool, optional
        Apply PHAT weighting (default False).
    fs: int or float, optional
        Sampling frequency of both inputs (default 1, i.e. the result is in samples).
    t_max: optional
        Accepted for API compatibility; it is not used.

    Returns
    -------
    The estimated delay of ``signal`` with respect to ``reference``, in units of 1/fs.
    """
    sig = np.array(signal)
    ref = np.array(reference)

    # correlate() puts lag -(len(ref) - 1) at index 0, so shift the peak index back.
    zero_lag_offset = ref.shape[0] - 1
    corr = correlate(sig, ref, interp=interp, phat=phat)
    peak = np.argmax(np.abs(corr))

    return (peak / interp - zero_lag_offset) / fs
|
77
|
+
|
78
|
+
|
79
|
+
def correlate(x1, x2, interp=1, phat=False):
    """
    Compute the full linear cross-correlation between x1 and x2 via the FFT.

    With ``interp == 1`` and ``phat`` False the result equals
    ``np.correlate(x1, x2, mode="full")`` for real inputs (up to floating-point
    error). Element 0 corresponds to lag -(len(x2) - 1); the last element
    corresponds to lag len(x1) - 1.

    Parameters
    ----------
    x1,x2: array_like
        The 1-D data arrays (lists or numpy arrays).
    interp: int, optional
        The interpolation factor for the output array, default 1.
    phat: bool, optional
        Apply the PHAT weighting (default False).

    Returns
    -------
    The cross-correlation between the two arrays, of length
    interp * (len(x1) + len(x2) - 1).
    """
    # Accept any array_like, as documented (.shape below requires ndarrays).
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)

    n1 = x1.shape[0]
    n2 = x2.shape[0]

    # Zero-pad to the full linear-correlation length so the circular
    # correlation computed through the FFT has no wrap-around.
    n = n1 + n2 - 1

    spec1 = fft.rfft(x1, n=n)
    spec2 = fft.rfft(x2, n=n)

    if phat:
        # Whiten both spectra; the tiny relative epsilon avoids division by zero.
        spec1 /= np.abs(spec1) + np.mean(np.abs(spec1)) * 1e-10
        spec2 /= np.abs(spec2) + np.mean(np.abs(spec2)) * 1e-10

    # Evaluating the inverse transform on interp * n points interpolates the
    # correlation by the factor interp.
    out = fft.irfft(spec1 * np.conj(spec2), n=int(n * interp))

    # Negative lags live at the end of the circular result. Index from the
    # front: out[-0:] would select the whole array when len(x2) == 1.
    negative_lags = out[out.shape[0] - interp * (n2 - 1) :]
    non_negative_lags = out[: interp * n1]
    return np.concatenate([negative_lags, non_negative_lags])
|
116
|
+
|
117
|
+
|
118
|
+
def hilbert(u):
    """
    Compute the analytic signal of a real 1-D sequence using the FFT.

    Matches ``scipy.signal.hilbert`` for 1-D input: the real part of the result
    equals ``u`` and the imaginary part is its Hilbert transform. The DC bin is
    kept, and so is the Nyquist bin for even lengths. Strictly positive
    frequencies are doubled and negative frequencies are zeroed. For odd
    lengths the highest positive-frequency bin is (N - 1) // 2, and it is
    doubled too.

    Parameters
    ----------
    u: array_like
        Real input sequence.

    Returns
    -------
    Complex analytic signal with the same length as ``u``.
    """
    u = np.asarray(u)
    n = len(u)
    spectrum = fft.fft(u)

    # Spectral weights defining the analytic signal
    weights = np.zeros(n)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[n // 2] = 1.0
        weights[1 : n // 2] = 2.0
    else:
        weights[1 : (n + 1) // 2] = 2.0

    return fft.ifft(spectrum * weights)


def _first_index_below(curve_db, level_db):
    """Return the first index where ``curve_db`` is below ``level_db``, or None if it never is."""
    hits = np.flatnonzero(curve_db < level_db)
    return int(hits[0]) if hits.size else None


def measure_rt60(h, fs=1, decay_db=60, energy_thres=1.0, plot=False, rt60_tgt=None):
    """
    RT60 Measurement Routine (taken/modified from Pyroom acoustics.)

    Calculates reverberation time of an impulse response using the Schroeder
    method [1], applied to the Hilbert envelope of h. All decay slopes after
    the early decay are measured starting at the -5 dB point.

    Returns:
        rt60: Reverberation time to -60db (-5db to -65db); estimated as 3 * rt20
              if the energy curve never drops below -65db
        edt: Early decay time from 0db to -10db
        rt10: Reverberation time to -10db (-5db to -15db); the entire energy curve
              length is used if the curve never drops below -15db
        rt20: Reverberation time to -20db (-5db to -25db); estimated as 2 * rt10 if
              the curve never drops below -25db
        floor: 0 if the curve never drops below -10db (not a decay, or a silent IR);
                 all times are then 0.0
               1 if -10db is reached but not -15db (edt measured, rt10 from entire curve length)
               2 if -15db is reached but not -25db (rt10 measured, rt20 estimated from rt10)
               3 if -25db is reached but not -65db (rt20 measured, rt60 estimated from rt20)
               4 if -65db is reached (rt60, edt, rt10, rt20 are all measured)
    Times are in units of 1/fs. Optionally plots some useful information.

    Parameters
    ----------
    h: array_like
        The impulse response.
    fs: float or int, optional
        The sampling frequency of h (default to 1, i.e., samples).
    decay_db: float or int, optional
        Accepted for API compatibility; not used. The decay levels above are fixed.
    energy_thres: float
        This should be a value between 0.0 and 1.0.
        If provided, the fit will be done using a fraction energy_thres of the
        whole energy. This is useful when there is a long noisy tail for example.
    plot: bool, optional
        If set to ``True``, the power decay and different estimated values will
        be plotted (default False).
    rt60_tgt: float
        This parameter can be used to indicate a target RT60 to which we want
        to compare the estimated value.

    References
    ----------

    [1] M. R. Schroeder, "New Method of Measuring Reverberation Time,"
        J. Acoust. Soc. Am., vol. 37, no. 3, pp. 409-412, Mar. 1965.
    """
    failed = (0.0, 0.0, 0.0, 0.0, 0)

    h = np.array(h)
    fs = float(fs)
    h = np.abs(hilbert(h))  # envelope via the analytic signal

    # The power of the impulse response envelope
    power = h**2
    # Backward energy integration according to Schroeder (non-increasing, max at index 0)
    energy = np.cumsum(power[::-1])[::-1]

    if energy_thres < 1.0:
        assert 0.0 < energy_thres < 1.0
        energy -= energy[0] * (1.0 - energy_thres)
        energy = np.maximum(energy, 0.0)

    # Remove the possibly all-zero tail (the slice also drops the last non-zero sample)
    nonzero = np.flatnonzero(energy > 0)
    if nonzero.size == 0 or nonzero[-1] == 0:
        return failed  # silent IR, or energy only in the first sample: nothing to measure
    energy = energy[: nonzero[-1]]
    energy_db = 10 * np.log10(energy)
    energy_db -= energy_db[0]  # normalize to the first (largest) energy value

    # -5 dB headroom: start point of the decay slope measurements
    i_5db = _first_index_below(energy_db, -5)
    if i_5db is None:
        return failed  # energy curve is not a decay, or has noise floor tail above -5db
    e_5db = energy_db[i_5db]
    t_5db = i_5db / fs

    # Early decay time (EDT): 0db to -10db
    i_10db = _first_index_below(energy_db, -10)
    if i_10db is None:
        return failed  # energy curve is not a decay, or noise floor tail above -10db
    edt = i_10db / fs

    # RT10; if -15db is never reached use the entire curve length
    i_decay10db = _first_index_below(energy_db, -5 - 10)
    if i_decay10db is None:
        floor = 1
        i_decay10db = len(energy_db)
    else:
        floor = 2  # raised below if deeper decays are also measured
    rt10 = i_decay10db / fs - t_5db

    # RT20; extrapolate from rt10 if -25db is never reached
    i_decay20db = _first_index_below(energy_db, -5 - 20)
    if i_decay20db is None:
        rt20 = 2 * rt10
    else:
        floor = 3
        rt20 = i_decay20db / fs - t_5db

    # RT60; extrapolate from rt20 (60 dB / 20 dB = 3) if -65db is never reached
    i_decay60db = _first_index_below(energy_db, -5 - 60)
    if i_decay60db is None:
        rt60 = 3 * rt20
    else:
        floor = 4
        rt60 = i_decay60db / fs - t_5db

    if plot:
        # Remove clip power below to minimum energy (for plotting purpose mostly)
        energy_min = energy[-1]
        energy_db_min = energy_db[-1]
        power[power < energy[-1]] = energy_min
        power_db = 10 * np.log10(power)
        power_db -= np.max(power_db)

        # time vector
        def get_time(x, fs):
            return np.arange(x.shape[0]) / fs - i_5db / fs

        T = get_time(power_db, fs)

        # plot power and energy
        plt.plot(get_time(energy_db, fs), energy_db, label="Energy")

        # now the linear fit
        plt.plot([0, rt60], [e_5db, -65], "--", label="Linear Fit")
        plt.plot(T, np.ones_like(T) * -60, "--", label="-60 dB")
        plt.vlines(rt60, energy_db_min, 0, linestyles="dashed", label="Estimated RT60")

        if rt60_tgt is not None:
            plt.vlines(rt60_tgt, energy_db_min, 0, label="Target RT60")

        plt.legend()

    return rt60, edt, rt10, rt20, floor
|
284
|
+
|
285
|
+
|
286
|
+
def process_path(path: str, extensions: list[str] | None = None) -> tuple[list, str | None]:
|
287
|
+
"""
|
288
|
+
Check path which can be a single file, a subdirectory, or a regex
|
289
|
+
return:
|
290
|
+
- a list of files with matching extensions to any in extlist provided (i.e. ['.wav', '.mp3', '.acc'])
|
291
|
+
- the basedir of the path, if
|
292
|
+
"""
|
293
|
+
if extensions is None:
|
294
|
+
extensions = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]
|
295
|
+
|
296
|
+
# Check if the path is a single file, and return it as a list with the dirname
|
297
|
+
if isfile(path):
|
298
|
+
if any(path.endswith(ext) for ext in extensions):
|
299
|
+
basedir = dirname(path) # base directory
|
300
|
+
if not basedir:
|
301
|
+
basedir = "./"
|
302
|
+
return [path], basedir
|
303
|
+
|
304
|
+
return [], None
|
305
|
+
|
306
|
+
# Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
|
307
|
+
if isdir(path):
|
308
|
+
matching_files = []
|
309
|
+
for ext in extensions:
|
310
|
+
matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
|
311
|
+
return matching_files, path
|
312
|
+
|
313
|
+
# Process as a regex, return list of filenames and basedir
|
314
|
+
apath = abspath(path) # join(abspath(path), "**", "*.{wav,flac,WAV}")
|
315
|
+
matching_files = []
|
316
|
+
for file in braced_iglob(pathname=apath, recursive=True):
|
317
|
+
matching_files.append(file)
|
318
|
+
|
319
|
+
if matching_files:
|
320
|
+
basedir = commonprefix(matching_files) # Find basedir
|
321
|
+
return matching_files, basedir
|
322
|
+
|
323
|
+
return [], None
|
324
|
+
|
325
|
+
|
326
|
+
def _process_ir(pfile: tuple[int, str], irtab_col: list, basedir: str) -> pd.DataFrame:
    """
    Compute delay, gain and reverberation metrics for one impulse response file.

    Parameters
    ----------
    pfile: (index, path) pair; index labels the returned row, path is the IR audio file.
    irtab_col: column names for the metrics row, in the order the values are built below.
    basedir: directory the ``irfile`` column is reported relative to.

    Returns
    -------
    A single-row DataFrame of scalar metrics.
    """
    # 1) Read ir audio file, and calc basic stats
    ir_fname = pfile[1]
    irwav, sample_rate = soundfile.read(ir_fname)
    if irwav.ndim == 2:
        irwav = irwav[:, 0]  # Only first channel of multi-channel
    duration = len(irwav) / sample_rate
    srk = sample_rate / 1000
    ir_basename = relpath(ir_fname, basedir)

    # 2) Delay via the index of the (signed) maximum sample
    peak_index = np.argmax(irwav)
    peak_value = irwav[peak_index]
    dmax = peak_index

    # 3) Delay via cross-correlation with a white Gaussian noise reference.
    # A private RandomState(42) yields exactly the samples np.random.seed(42) +
    # np.random.normal would, without resetting the process-global RNG.
    rng = np.random.RandomState(42)
    wgn_ref = rng.normal(0, 0.2, int(np.ceil(0.05 * sample_rate)))  # (mean, std_dev, length)
    wgn_conv = np.convolve(irwav, wgn_ref)
    wgn_corr = np.correlate(wgn_conv, wgn_ref, mode="full")
    # abs() so a polarity-inverted IR still finds its peak; then undo the mode="full" lag offset
    dcc = np.argmax(np.abs(wgn_corr)) - len(wgn_ref) + 1
    # GCC with PHAT weighting, 16x interpolated (may disagree with dcc/dmax more often)
    dtdoa = tdoa(wgn_conv, wgn_ref, interp=16, phat=True)
    gdccmax = np.max(np.abs(wgn_conv)) / np.max(np.abs(wgn_ref))  # gain of max value

    # 4) Delay via the average group delay; FFT length equals the IR length
    fft_size = len(irwav)
    H = np.fft.fft(irwav, n=fft_size)
    phase = np.unwrap(np.angle(H))
    freq = np.fft.fftfreq(fft_size)  # cycles/sample (default d=1)
    group_delay = -np.gradient(phase) / (2 * np.pi * np.gradient(freq))
    dagd = np.mean(group_delay[np.isfinite(group_delay)])
    gmax = np.max(np.abs(H))

    # 5) Reverberation times
    rt60, edt, rt10, rt20, nfloor = measure_rt60(irwav, sample_rate, plot=False)

    # 6) Tabulate metrics as a single row, ordered to match irtab_col:
    # ["dmax", "dcc", "dccphat", "dagd", "gdccmax", "rt20", "rt60", "max", "min", "gmax", "dur", "sr", "irfile"]
    metr1 = [
        dmax,
        dcc,
        dtdoa,
        dagd,
        gdccmax,
        rt20,
        rt60,
        peak_value,
        np.min(irwav),
        gmax,
        duration,
        srk,
        ir_basename,
    ]
    return pd.DataFrame([metr1], columns=irtab_col, index=[pfile[0]])
|
412
|
+
|
413
|
+
|
414
|
+
def main():
    """Command-line entry point: compute IR metrics for IRLOC and write the result files."""
    from docopt import docopt

    from . import __version__ as sai_version
    from .utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    ir_location = args["IRLOC"]
    num_proc = args["--num_process"]

    from functools import partial

    import psutil

    from .utils.create_timestamp import create_timestamp
    from .utils.parallel import par_track
    from .utils.parallel import track

    # Check location, default ext are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac']
    pfiles, basedir = process_path(ir_location)
    pfiles = sorted(pfiles, key=basename)

    if not pfiles:
        print(f"No IR audio files found in {ir_location}, exiting ...")
        raise SystemExit(1)

    if len(pfiles) == 1:
        print(f"Found single IR audio file {ir_location} , writing to *-irmetric.txt ...")
        fbase, _ext = splitext(basename(pfiles[0]))
        wlcsv_name = None
        txt_fname = str(join(basedir, fbase + "-irmetric.txt"))
    else:
        print(f"Found {len(pfiles)} files under {basedir} for impulse response metric calculations")
        wlcsv_name = str(join(basedir, "ir_metric_list.csv"))
        txt_fname = str(join(basedir, "ir_metric_summary.txt"))

    # psutil.cpu_count() returns None when the count cannot be determined
    num_cpu = psutil.cpu_count() or 1
    cpu_percent = psutil.cpu_percent(interval=1)
    print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    print(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Aim for ~90% total utilization given the current load, but never fewer than one
        # process (a loaded machine would otherwise give 0 or a negative count).
        use_cpu = max(1, int(num_cpu * (0.9 - cpu_percent / 100)))
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    timestamp = create_timestamp()
    print(f"Calculating metrics for {len(pfiles)} impulse response files using {use_cpu} parallel processes ...")
    progress = track(total=len(pfiles))
    if use_cpu is None or len(pfiles) == 1:
        no_par = True
        num_cpus = None
    else:
        no_par = False
        num_cpus = use_cpu

    # Setup pandas table for summarizing ir metrics (order must match _process_ir)
    irtab_col = [
        "dmax",
        "dcc",
        "dccphat",
        "dagd",
        "gdccmax",
        "rt20",
        "rt60",
        "max",
        "min",
        "gmax",
        "dur",
        "sr",
        "irfile",
    ]
    # (index, path) pairs; the index becomes each result row's label
    llfiles = list(enumerate(pfiles))

    all_metrics_tables = par_track(
        partial(
            _process_ir,
            irtab_col=irtab_col,
            basedir=basedir,
        ),
        llfiles,
        progress=progress,
        num_cpus=num_cpus,
        no_par=no_par,
    )
    progress.close()

    all_metrics_tab = pd.concat(list(all_metrics_tables))
    mtabsort = all_metrics_tab.sort_values(by=["irfile"])

    # Write list to .csv
    if wlcsv_name:
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"IR metric list for {ir_location}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        mtabsort.round(2).to_csv(wlcsv_name, mode="a", encoding="utf-8")

    # Write summary and list to .txt
    with open(txt_fname, "w", encoding="utf-8") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print(f"IR metrics stats over {len(llfiles)} files:", file=f)
        print(mtabsort.describe().round(3).T.to_string(float_format=lambda x: f"{x:.3f}", index=True), file=f)
        print("", file=f)
        print("", file=f)
        # Print the heading as text, not as a one-element list repr
        print(f"IR metric list for {ir_location}:", file=f)
        print(mtabsort.round(3).to_string(), file=f)
|
541
|
+
|
542
|
+
|
543
|
+
if __name__ == "__main__":
    # Script entry point: install the Ctrl-C handler, then route any uncaught
    # exception through SonusAI's shared exception handler.
    from sonusai import exception_handler as _handle_exception
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt as _register_ctrl_c

    _register_ctrl_c()
    try:
        main()
    except Exception as exc:
        _handle_exception(exc)
|
sonusai/lsdb.py
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
"""sonusai lsdb
|
2
|
+
|
3
|
+
usage: lsdb [-hsa] [-i MIXID] [-c CID] LOC
|
4
|
+
|
5
|
+
Options:
|
6
|
+
-h, --help
|
7
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
|
8
|
+
-c CID, --class_index CID Analyze mixtures that contain this class index.
|
9
|
+
-s, --sources List all source files.
|
10
|
+
-a, --all_class_counts List all class counts.
|
11
|
+
|
12
|
+
List mixture data information from a SonusAI mixture database.
|
13
|
+
|
14
|
+
Inputs:
|
15
|
+
LOC A SonusAI mixture database directory.
|
16
|
+
|
17
|
+
"""
|
18
|
+
|
19
|
+
from sonusai.datatypes import GeneralizedIDs
|
20
|
+
from sonusai.mixture import MixtureDatabase
|
21
|
+
|
22
|
+
|
23
|
+
def lsdb(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    class_index: int | None = None,
    list_targets: bool = False,
    all_class_counts: bool = False,
) -> None:
    """Log summary information about a SonusAI mixture database.

    :param mixdb: Mixture database to describe.
    :param mixids: Mixture ID(s) to analyze; details are logged only when this resolves to one mixture.
    :param class_index: If given, also log the IDs of mixtures containing this class index.
    :param list_targets: If True, log the name and class indices of every source file.
    :param all_class_counts: Only meaningful for a single mixture; currently reported as not supported.
    :raises ValueError: If class_index is outside 1..mixdb.num_classes.
    """
    from sonusai import logger
    from sonusai.constants import SAMPLE_RATE
    from sonusai.queries.queries import get_mixids_from_class_indices
    from sonusai.utils.print_mixture_details import print_mixture_details
    from sonusai.utils.ranges import consolidate_range
    from sonusai.utils.seconds_to_hms import seconds_to_hms

    desc_len = 24

    total_samples = mixdb.total_samples()
    total_duration = total_samples / SAMPLE_RATE

    logger.info(f"{'Mixtures':{desc_len}} {mixdb.num_mixtures}")
    logger.info(f"{'Duration':{desc_len}} {seconds_to_hms(seconds=total_duration)}")
    logger.info(f"{'Sources':{desc_len}} {mixdb.num_source_files}")
    logger.info(f"{'Feature':{desc_len}} {mixdb.feature}")
    logger.info(
        f"{'Feature shape':{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} "
        f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
    )
    logger.info(f"{'Feature samples':{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
    logger.info(
        f"{'Feature step samples':{desc_len}} {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)"
    )
    logger.info(f"{'Feature overlap':{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)")
    logger.info(f"{'SNRs':{desc_len}} {mixdb.snrs}")
    logger.info(f"{'Random SNRs':{desc_len}} {mixdb.random_snrs}")
    logger.info(f"{'Classes':{desc_len}} {mixdb.num_classes}")
    # TODO: fix class count
    logger.info(f"{'Class count':{desc_len}} not supported")
    # TODO: add class weight calculations here
    logger.info("")

    if list_targets:
        logger.info("Source details:")
        for category, sources in mixdb.source_files.items():
            # Use the logger like every other line so the category header also reaches the log file
            logger.info(f" {category}:")
            for source in sources:
                logger.info(f"{' Name':{desc_len}} {source.name}")
                logger.info(f"{' Truth index':{desc_len}} {source.class_indices}")
        logger.info("")

    if class_index is not None:
        # Valid class indices are 1..num_classes. The former chained comparison
        # `0 <= class_index > num_classes` let negative values through.
        if not 1 <= class_index <= mixdb.num_classes:
            raise ValueError(f"Given class_index is outside valid range of 1-{mixdb.num_classes}")
        ids = get_mixids_from_class_indices(mixdb=mixdb, predicate=lambda x: x in [class_index])[class_index]
        logger.info(f"Mixtures with class index {class_index}: {ids}")
        logger.info("")

    mixids = mixdb.mixids_to_list(mixids)

    if len(mixids) == 1:
        print_mixture_details(mixdb=mixdb, mixid=mixids[0], print_fn=logger.info)
        if all_class_counts:
            # TODO: fix class count
            logger.info("All class count not supported")
    else:
        logger.info(
            f"Calculating statistics from truth_f files for {len(mixids):,} mixtures ({consolidate_range(mixids)})"
        )
        logger.info("Not supported")
|
93
|
+
|
94
|
+
|
95
|
+
def main() -> None:
    """Command-line entry point: parse arguments and log a summary of the mixture database at LOC."""
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    mixid = args["--mixid"]
    class_index = args["--class_index"]
    # The usage string defines "-s, --sources", and docopt keys options by their
    # long name; there is no "--targets" key.
    list_targets = args["--sources"]
    all_class_counts = args["--all_class_counts"]
    location = args["LOC"]

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler

    if class_index is not None:
        class_index = int(class_index)

    create_file_handler("lsdb.log")
    update_console_handler(False)
    initial_log_messages("lsdb")

    logger.info(f"Analyzing {location}")

    mixdb = MixtureDatabase(location)
    lsdb(
        mixdb=mixdb,
        mixids=mixid,
        class_index=class_index,
        list_targets=list_targets,
        all_class_counts=all_class_counts,
    )
|
131
|
+
|
132
|
+
|
133
|
+
if __name__ == "__main__":
    # Script entry point: register the Ctrl-C handler first, then delegate any
    # uncaught exception to SonusAI's shared exception handler.
    from sonusai import exception_handler as _handle_exception
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt as _register_ctrl_c

    _register_ctrl_c()
    try:
        main()
    except Exception as exc:
        _handle_exception(exc)
|