sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/genmixdb.yml +4 -2
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/doc/doc.py +14 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/ir_metric.py +555 -0
- sonusai/metrics_summary.py +322 -0
- sonusai/mixture/__init__.py +6 -2
- sonusai/mixture/audio.py +139 -15
- sonusai/mixture/augmentation.py +199 -84
- sonusai/mixture/config.py +9 -4
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +19 -10
- sonusai/mixture/generation.py +52 -64
- sonusai/mixture/helpers.py +38 -26
- sonusai/mixture/ir_delay.py +63 -0
- sonusai/mixture/mixdb.py +190 -46
- sonusai/mixture/targets.py +3 -6
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- sonusai/utils/temp_seed.py +13 -0
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
- sonusai/mixture/soundfile_audio.py +0 -130
- sonusai/mixture/sox_audio.py +0 -476
- sonusai/mixture/sox_augmentation.py +0 -136
- sonusai/mixture/torchaudio_audio.py +0 -106
- sonusai/mixture/torchaudio_augmentation.py +0 -109
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,322 @@
|
|
1
|
+
"""sonusai metrics_summary
|
2
|
+
|
3
|
+
usage: metrics_summary [-vlh] [-i MIXID] [-n NCPU] LOCATION
|
4
|
+
|
5
|
+
Options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose
|
8
|
+
-l, --write-list Write .csv file list of all mixture metrics
|
9
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
|
10
|
+
-n, --num_process NCPU Number of parallel processes to use [default: auto]
|
11
|
+
|
12
|
+
Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.
|
13
|
+
|
14
|
+
Inputs:
|
15
|
+
LOCATION A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.
|
16
|
+
|
17
|
+
"""
|
18
|
+
|
19
|
+
import signal
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
import pandas as pd
|
23
|
+
|
24
|
+
|
25
|
+
def signal_handler(_sig, _frame):
    """SIGINT handler: report the cancellation, then terminate with exit status 1.

    :param _sig: Signal number (unused)
    :param _frame: Interrupted stack frame (unused)
    """
    from sonusai import logger

    logger.info("Canceled due to keyboard interrupt")
    # Equivalent to sys.exit(1): both raise SystemExit with code 1
    raise SystemExit(1)
|
32
|
+
|
33
|
+
|
34
|
+
# Route Ctrl-C through signal_handler so an interrupted run logs a message and exits with status 1
signal.signal(signal.SIGINT, signal_handler)

# Linear power ratios equivalent to +99 dB and -99 dB, i.e. 10 ** (dB / 10).
# NOTE(review): not referenced by the code visible in this file; presumably kept for external use - confirm before removing
DB_99, DB_N99 = (np.power(10, level_db / 10) for level_db in (99, -99))
|
38
|
+
|
39
|
+
|
40
|
+
def _process_mixture(
    m_id: int,
    location: str,
    all_metric_names: list[str],
    scalar_metric_names: list[str],
    string_metric_names: list[str],
    frame_metric_names: list[str],
    bin_metric_names: list[str],
    ptab_labels: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build the summary-table row for one mixture.

    Opens the mixture database at ``location`` (every call opens its own instance, which keeps
    the function usable from worker processes), reads the pre-generated metrics of ``m_id`` and
    returns a one-row DataFrame indexed by ``m_id`` with columns ``ptab_labels``.

    :param m_id: Mixture ID to process
    :param location: Mixture database directory
    :param all_metric_names: Names of all metrics to fetch from the database
    :param scalar_metric_names: Metrics reported by value
    :param string_metric_names: Metrics reported as the word count of their string value
    :param frame_metric_names: Per-frame vector metrics (not summarized yet)
    :param bin_metric_names: Per-bin vector metrics (not summarized yet)
    :param ptab_labels: Column labels; their order must match the row assembled below:
        mxsnr, scalar metrics, string metrics, fcnt, duration, t0file, nfile
    :return: Tuple of DataFrames; both entries are currently the same scalar-metric row
        (frame and bin metric tables are still TODO)
    """
    from os.path import basename

    from sonusai.metrics import calc_wer
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.mixture import MixtureDatabase

    def first_value(value):
        # Metrics may be stored as lists (e.g. mixup); only the first entry is summarized
        return value[0] if isinstance(value, list) else value

    mixdb = MixtureDatabase(location)

    # Fetch the mixture record once instead of once per attribute
    mixture = mixdb.mixture(m_id)
    mxsnr = mixture.snr
    duration = mixture.samples / SAMPLE_RATE
    tf_frames = mixdb.mixture_transform_frames(m_id)
    t0file = basename(mixdb.target_file(mixture.targets[0].file_id).name)
    nfile = basename(mixdb.noise_file(mixture.noise.file_id).name)

    all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)

    scalar_values = [first_value(all_metrics[key]) for key in scalar_metric_names]

    # Convert strings into word count (calc_wer of a string against itself is used for its word count)
    word_counts = []
    for key in string_metric_names:
        text = first_value(all_metrics[key])
        word_counts.append(calc_wer(text, text).words)

    # Collect pandas table values; note: must match given ptab_labels
    ptab_data: list = [
        mxsnr,
        *scalar_values,
        *word_counts,
        tf_frames,
        duration,
        t0file,
        nfile,
    ]

    ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])

    # TODO: collect frame metrics and bin metrics
    return ptab1, ptab1
|
102
|
+
|
103
|
+
|
104
|
+
def main() -> None:
    """Summarize the pre-generated metrics of a SonusAI mixture database.

    Parses the command line (see the module docstring) and gathers the scalar and string
    metrics of every selected mixture in parallel. It then writes into LOCATION:

    - ``metric_summary_snr<suffix>.csv`` and ``.txt``: per-SNR means plus overall statistics
      (all mixtures, and all mixtures except -99 dB SNR)
    - ``metric_summary_list<suffix>.csv``: the per-mixture table (only with ``--write-list``)

    ``<suffix>`` is empty when every mixture is processed, otherwise ``_s<selected>t<total>``.
    """
    from docopt import docopt

    from sonusai import __version__ as sonusai_ver
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai_ver, options_first=True)

    verbose = args["--verbose"]
    wrlist = args["--write-list"]
    mixids = args["--mixid"]
    location = args["LOCATION"]
    num_proc = args["--num_process"]

    from functools import partial
    from os.path import basename
    from os.path import join
    from os.path import normpath

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import create_timestamp
    from sonusai.utils import par_track
    from sonusai.utils import track

    try:
        mixdb = MixtureDatabase(location)
        print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
    except Exception:
        # Catch ordinary errors only; SystemExit (raised by the SIGINT handler) must propagate
        print(f"Could not open SonusAI mixture database in {location}, exiting ...")
        return

    # Only check first and last mixture in order to save time
    metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])

    num_metrics_present = len(metrics_present)
    if num_metrics_present < 1:
        print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
        return

    # Setup logging file
    timestamp = create_timestamp()  # string good for embedding into filenames
    # normpath drops a trailing separator so basename does not return an empty string
    mixdb_fname = basename(normpath(location))
    if verbose:
        create_file_handler(join(location, "metrics_summary.log"))
        update_console_handler(verbose)
        initial_log_messages("metrics_summary")
        logger.info(f"Logging summary of SonusAI mixture database at {location}")
    else:
        update_console_handler(verbose)

    logger.info("")
    mixids = mixdb.mixids_to_list(mixids)
    if not mixids:
        # Everything below indexes mixids[0] and concatenates per-mixture tables
        logger.info(f"No mixtures selected, nothing to summarize in {location}, exiting ...")
        return
    if len(mixids) < mixdb.num_mixtures:
        logger.info(
            f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
            f"summary results will not include entire dataset."
        )
        fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
    else:
        logger.info(
            f"Summarizing SonusAI mixture database with {mixdb.num_mixtures} mixtures "
            f"and {num_metrics_present} pre-generated metrics ..."
        )
        fsuffix = ""

    ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
    # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
    # Collect list of names for each
    scalar_metric_names: list[str] = []
    string_metric_names: list[str] = []
    frame_metric_names: list[str] = []
    bin_metric_names: list[str] = []
    all_metrics = mixdb.mixture_metrics(mixids[0], metrics_present)
    tf_frames = mixdb.mixture_transform_frames(mixids[0])
    for metric in metrics_present:
        metval = all_metrics[metric]  # get metric value
        logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
        if isinstance(metval, list):
            if not metval:
                logger.warning(f"Mixid {mixids[0]} metric {metric} is an empty list, ignoring.")
                continue
            if len(metval) > 1:
                logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
            metval = metval[0]  # remove any list
        if isinstance(metval, float):
            logger.debug("Metric is scalar float, entering in summary table.")
            scalar_metric_names.append(metric)
        elif isinstance(metval, str):
            logger.debug("Metric is string, will summarize with word count.")
            string_metric_names.append(metric)
        elif isinstance(metval, np.ndarray) and metval.ndim == 1:
            if metval.size == tf_frames:
                logger.debug("Metric is frames vector.")
                frame_metric_names.append(metric)
            elif metval.size == ft_bins:
                logger.debug("Metric is bins vector.")
                bin_metric_names.append(metric)
            else:
                logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")
        else:
            # Previously dropped silently (e.g. ints, multi-dimensional arrays)
            logger.warning(
                f"Mixid {mixids[0]} metric {metric} has unsupported type {type(metval).__name__}, ignoring."
            )

    # Setup pandas table for summarizing scalar metrics
    ptab_labels = [
        "mxsnr",
        *scalar_metric_names,
        *string_metric_names,
        "fcnt",
        "duration",
        "t0file",
        "nfile",
    ]

    num_cpu = psutil.cpu_count() or 1  # cpu_count() may return None
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info("")
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Aim for ~90% total utilization given the current load, but never fewer than one process
        use_cpu = max(1, int(num_cpu * (0.9 - cpu_percent / 100)))
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")

    progress = track(total=len(mixids))
    # use_cpu of None means run serially
    all_metrics_tables = par_track(
        partial(
            _process_mixture,
            location=location,
            all_metric_names=metrics_present,
            scalar_metric_names=scalar_metric_names,
            string_metric_names=string_metric_names,
            frame_metric_names=frame_metric_names,
            bin_metric_names=bin_metric_names,
            ptab_labels=ptab_labels,
        ),
        mixids,
        progress=progress,
        num_cpus=use_cpu,
        no_par=use_cpu is None,
    )
    progress.close()

    # Done with mixtures, write out summary metrics
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    ptab1 = pd.concat([item[0] for item in all_metrics_tables])
    if wrlist:
        wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        ptab1.round(2).to_csv(wlcsv_name, **table_args)
    ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])

    # Create metrics table excluding -99 SNR
    ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]

    # Create summary by SNR for all scalar metrics, taking mean
    snr_means = []
    for snr in mixdb.snrs:
        snr_rows = ptab1_sorted[ptab1_sorted.mxsnr == snr]
        # Skip SNRs with no mixtures (happens when only a subset of mixids is processed)
        if snr_rows.empty:
            continue
        snr_means.append(snr_rows.mean(numeric_only=True).to_frame().T)
    if snr_means:
        mtab_snr_summary = pd.concat(snr_means).sort_values(by=["mxsnr"], ascending=False)
    else:
        logger.warning("None of the processed mixtures matched a mixdb SNR; per-SNR summary is empty.")
        mtab_snr_summary = pd.DataFrame(columns=ptab1.columns)

    # Write summary to .csv
    snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
    nmix = len(mixids)
    nmixtot = mixdb.num_mixtures
    pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
    pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
    mtab_snr_summary.round(2).to_csv(snrcsv_name, index=False, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
    ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
        snrcsv_name, **header_args
    )
    ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)

    # Write summary to .txt (explicit encoding: file names in the tables may be non-ASCII)
    snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
    with open(snrtxt_name, "w", encoding="utf-8") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print("Metrics avg over each SNR:", file=f)
        print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1_nom99.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)


if __name__ == "__main__":
    main()
|
sonusai/mixture/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from .audio import get_duration
|
|
5
5
|
from .audio import get_next_noise
|
6
6
|
from .audio import get_num_samples
|
7
7
|
from .audio import get_sample_rate
|
8
|
+
from .audio import raw_read_audio
|
8
9
|
from .audio import read_audio
|
9
10
|
from .audio import read_ir
|
10
11
|
from .audio import validate_input_file
|
@@ -53,7 +54,9 @@ from .datatypes import AudioF
|
|
53
54
|
from .datatypes import AudioStatsMetrics
|
54
55
|
from .datatypes import AudioT
|
55
56
|
from .datatypes import Augmentation
|
57
|
+
from .datatypes import AugmentationEffects
|
56
58
|
from .datatypes import AugmentationRule
|
59
|
+
from .datatypes import AugmentationRuleEffects
|
57
60
|
from .datatypes import AugmentedTarget
|
58
61
|
from .datatypes import ClassCount
|
59
62
|
from .datatypes import EnergyF
|
@@ -87,6 +90,7 @@ from .datatypes import TruthParameter
|
|
87
90
|
from .datatypes import UniversalSNR
|
88
91
|
from .feature import get_audio_from_feature
|
89
92
|
from .feature import get_feature_from_audio
|
93
|
+
from .generation import generate_mixtures
|
90
94
|
from .generation import get_all_snrs_from_config
|
91
95
|
from .generation import initialize_db
|
92
96
|
from .generation import populate_class_label_table
|
@@ -99,7 +103,7 @@ from .generation import populate_target_file_table
|
|
99
103
|
from .generation import populate_top_table
|
100
104
|
from .generation import populate_truth_parameters_table
|
101
105
|
from .generation import update_mixid_width
|
102
|
-
from .generation import
|
106
|
+
from .generation import update_mixture
|
103
107
|
from .helpers import augmented_noise_samples
|
104
108
|
from .helpers import augmented_target_samples
|
105
109
|
from .helpers import check_audio_files_exist
|
@@ -110,10 +114,10 @@ from .helpers import get_transform_from_audio
|
|
110
114
|
from .helpers import inverse_transform
|
111
115
|
from .helpers import mixture_metadata
|
112
116
|
from .helpers import write_mixture_metadata
|
117
|
+
from .ir_delay import get_impulse_response_delay
|
113
118
|
from .log_duration_and_sizes import log_duration_and_sizes
|
114
119
|
from .mixdb import MixtureDatabase
|
115
120
|
from .mixdb import db_file
|
116
|
-
from .sox_audio import Transformer
|
117
121
|
from .spectral_mask import apply_spectral_mask
|
118
122
|
from .target_class_balancing import balance_targets
|
119
123
|
from .targets import get_augmented_target_ids_by_class
|
sonusai/mixture/audio.py
CHANGED
@@ -44,49 +44,173 @@ def validate_input_file(input_filepath: str | Path) -> None:
|
|
44
44
|
raise OSError(f"This installation cannot process .{ext} files")
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
def get_sample_rate(name: str | Path) -> int:
|
47
|
+
def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
    """Return the native sample rate of an audio file.

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: Sample rate
    """
    # __wrapped__ is the undecorated function, i.e. a lookup that bypasses the LRU cache
    reader = _get_sample_rate if use_cache else _get_sample_rate.__wrapped__
    return reader(name)
|
57
57
|
|
58
58
|
|
59
59
|
@lru_cache
def _get_sample_rate(name: str | Path) -> int:
    """Get sample rate from audio file

    .mp3 and .m4a files are probed with pydub; all other formats with soundfile.

    :param name: File name (tokenized shell variables are expanded)
    :return: Sample rate
    :raises OSError: if the file cannot be read
    """
    import soundfile
    from pydub import AudioSegment

    from .tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)

    try:
        if expanded_name.endswith(".mp3"):
            return AudioSegment.from_mp3(expanded_name).frame_rate

        if expanded_name.endswith(".m4a"):
            return AudioSegment.from_file(expanded_name).frame_rate

        return soundfile.info(expanded_name).samplerate
    except Exception as e:
        # name may be a Path while expanded_name is a str; compare as strings so the
        # "(expanded: ...)" detail appears only when expansion actually changed the name
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e
|
86
|
+
|
87
|
+
|
88
|
+
def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
    """Read the first channel of an audio file at its native sample rate

    .mp3 and .m4a files are decoded with pydub; every other format is read with soundfile.
    pydub's integer samples are divided by 2 ** (bits - 1) and returned as float32 data.

    :param name: File name (tokenized shell variables are expanded)
    :return: Tuple of (first-channel audio as a 1-D float32 array, native sample rate)
    :raises OSError: if the file cannot be read
    """
    import numpy as np
    import soundfile
    from pydub import AudioSegment

    from .tokenized_shell_vars import tokenized_expand

    def from_segment(sound: AudioSegment) -> tuple[np.ndarray, int]:
        # Interleaved integer samples -> (frames, channels) float32 array
        data = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
        return data / 2 ** (sound.sample_width * 8 - 1), sound.frame_rate

    expanded_name, _ = tokenized_expand(name)

    try:
        if expanded_name.endswith(".mp3"):
            raw, sample_rate = from_segment(AudioSegment.from_mp3(expanded_name))
        elif expanded_name.endswith(".m4a"):
            raw, sample_rate = from_segment(AudioSegment.from_file(expanded_name))
        else:
            raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
    except Exception as e:
        # name may be a Path while expanded_name is a str; compare as strings
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e

    # raw[:, 0] is already 1-D; np.squeeze would turn a single-sample file into a 0-d array
    return raw[:, 0].astype(np.float32), sample_rate
|
117
|
+
|
118
|
+
|
119
|
+
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
    """Load time-domain audio from a file.

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: Array of time domain audio data
    """
    # __wrapped__ is the undecorated reader, i.e. a read that bypasses the LRU cache
    reader = _read_audio if use_cache else _read_audio.__wrapped__
    return reader(name)
|
69
129
|
|
70
130
|
|
71
131
|
@lru_cache
def _read_audio(name: str | Path) -> AudioT:
    """Read audio data from a file and resample it to SAMPLE_RATE

    The first channel at its native rate comes from raw_read_audio. librosa then resamples
    it with the "soxr_hq" method.

    :param name: File name
    :return: Array of time domain audio data at SAMPLE_RATE
    """
    import librosa

    from .constants import SAMPLE_RATE

    audio, native_rate = raw_read_audio(name)
    return librosa.resample(audio, orig_sr=native_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
|
146
|
+
|
147
|
+
|
148
|
+
def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
    """Load impulse response data from a file.

    :param name: File name
    :param delay: Delay in samples
    :param use_cache: If true, use LRU caching
    :return: ImpulseResponseData object
    """
    # __wrapped__ is the undecorated loader, i.e. a read that bypasses the LRU cache
    loader = _read_ir if use_cache else _read_ir.__wrapped__
    return loader(name, delay)
|
159
|
+
|
160
|
+
|
161
|
+
@lru_cache
def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
    """Read impulse response data without resampling

    :param name: File name
    :param delay: Delay in samples
    :return: ImpulseResponseData holding the file's first channel at its native sample rate
    """
    data, native_rate = raw_read_audio(name)
    return ImpulseResponseData(data=data, sample_rate=native_rate, delay=delay)
|
172
|
+
|
173
|
+
|
174
|
+
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
    """Return how many samples the file would have after resampling to the SonusAI sample rate.

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: number of samples in resampled audio
    """
    # __wrapped__ is the undecorated counter, i.e. a lookup that bypasses the LRU cache
    counter = _get_num_samples if use_cache else _get_num_samples.__wrapped__
    return counter(name)
|
81
184
|
|
82
185
|
|
83
186
|
@lru_cache
def _get_num_samples(name: str | Path) -> int:
    """Get the number of samples resampled to the SonusAI sample rate in the given file

    .mp3 and .m4a files are decoded with pydub to count frames; other formats are probed
    with soundfile.info.

    :param name: File name (tokenized shell variables are expanded)
    :return: number of samples in resampled audio (rounded up)
    """
    import math

    import soundfile
    from pydub import AudioSegment

    from .constants import SAMPLE_RATE
    from .tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)

    if expanded_name.endswith((".mp3", ".m4a")):
        if expanded_name.endswith(".mp3"):
            sound = AudioSegment.from_mp3(expanded_name)
        else:
            sound = AudioSegment.from_file(expanded_name)
        samples = sound.frame_count()
        sample_rate = sound.frame_rate
    else:
        # Must use the expanded name: soundfile cannot resolve tokenized shell variables itself
        info = soundfile.info(expanded_name)
        samples = info.frames
        sample_rate = info.samplerate

    # Scale the native frame count to SAMPLE_RATE, rounding up
    return math.ceil(SAMPLE_RATE * samples / sample_rate)
|