sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,314 @@
|
|
1
|
+
"""sonusai metrics_summary
|
2
|
+
|
3
|
+
usage: metrics_summary [-vlh] [-i MIXID] [-n NCPU] LOCATION
|
4
|
+
|
5
|
+
Options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose
|
8
|
+
-l, --write-list Write .csv file list of all mixture metrics
|
9
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
|
10
|
+
-n, --num_process NCPU Number of parallel processes to use [default: auto]
|
11
|
+
|
12
|
+
Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.
|
13
|
+
|
14
|
+
Inputs:
|
15
|
+
LOCATION A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.
|
16
|
+
|
17
|
+
"""
|
18
|
+
|
19
|
+
import numpy as np
|
20
|
+
import pandas as pd
|
21
|
+
|
22
|
+
# Power ratios for +99 dB and -99 dB, i.e. 10 ** (dB / 10)
DB_99 = np.power(10, 9.9)
DB_N99 = np.power(10, -9.9)
|
24
|
+
|
25
|
+
|
26
|
+
def _process_mixture(
    m_id: int,
    location: str,
    all_metric_names: list[str],
    scalar_metric_names: list[str],
    string_metric_names: list[str],
    frame_metric_names: list[str],
    bin_metric_names: list[str],
    ptab_labels: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build the one-row summary table for a single mixture.

    The mixture database is opened inside this function (rather than passed in) so it can run in a
    separate worker process. The row holds, in order: the mixture SNR, each scalar metric, the word
    count of each string metric, the transform frame count, the duration in seconds, and the base
    names of the primary source file and the noise file. This order must match ``ptab_labels``.

    :param m_id: Mixture ID to summarize
    :param location: Directory of the SonusAI mixture database
    :param all_metric_names: Names of all metrics to load for the mixture
    :param scalar_metric_names: Metrics summarized by their (primary) scalar value
    :param string_metric_names: Metrics summarized by their word count
    :param frame_metric_names: Per-frame vector metrics (not yet collected, see TODO)
    :param bin_metric_names: Per-bin vector metrics (not yet collected, see TODO)
    :param ptab_labels: Column labels for the returned table
    :return: Two DataFrames indexed by ``m_id``; currently both are the same scalar table
    """
    from os.path import basename

    from sonusai.constants import SAMPLE_RATE
    from sonusai.metrics import calc_wer
    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase(location)

    # Look the mixture record up once instead of re-querying the database for every field
    mixture = mixdb.mixture(m_id)
    duration = mixture.samples / SAMPLE_RATE
    tf_frames = mixdb.mixture_transform_frames(m_id)
    mxsnr = mixture.noise.snr
    t0file = basename(mixdb.source_file(mixture.sources["primary"].file_id).name)
    nfile = basename(mixdb.source_file(mixture.noise.file_id).name)

    all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)

    def _primary(key: str):
        # Metrics returned as a dict are reduced to their 'primary' entry (mixup is ignored)
        value = all_metrics[key]
        return value["primary"] if isinstance(value, dict) else value

    scalar_values = [_primary(key) for key in scalar_metric_names]
    # Convert strings into word count (the 'words' field of calc_wer on the string against itself)
    word_counts = [calc_wer(text, text).words for text in (_primary(key) for key in string_metric_names)]

    # Collect pandas table values note: must match given ptab_labels
    ptab_data: list = [
        mxsnr,
        *scalar_values,
        *word_counts,
        tf_frames,
        duration,
        t0file,
        nfile,
    ]

    ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])

    # TODO: collect frame metrics and bin metrics
    return ptab1, ptab1
|
88
|
+
|
89
|
+
|
90
|
+
def main() -> None:
    """Summarize the pre-generated metrics of a SonusAI mixture database (``metrics_summary`` CLI).

    Parses the module docstring with docopt. It then detects which metrics are cached for the database
    at LOCATION, checking only the first and last mixture, and classifies them using the first
    selected mixture. One summary row per mixture is built in parallel with ``_process_mixture``.
    Results are written into LOCATION:

    - ``metric_summary_list<suffix>.csv`` (only with ``--write-list``): every selected mixture's row
    - ``metric_summary_snr<suffix>.csv`` and ``.txt``: per-SNR means, plus ``describe()`` statistics
      over all selected mixtures and over those whose SNR is not -99

    ``<suffix>`` is ``_s<selected>t<total>`` when only a subset of mixtures is processed.
    """
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    verbose = args["--verbose"]
    wrlist = args["--write-list"]
    mixids = args["--mixid"]
    location = args["LOCATION"]
    num_proc = args["--num_process"]

    from functools import partial
    from os.path import basename
    from os.path import join
    from os.path import normpath

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils.create_timestamp import create_timestamp
    from sonusai.utils.parallel import par_track
    from sonusai.utils.parallel import track

    mixdb = MixtureDatabase(location)
    print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")

    # Only check first and last mixture in order to save time
    metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])  # return pre-generated metrics in mixdb tree
    if "mxsnr" in metrics_present:
        metrics_present.remove("mxsnr")

    num_metrics_present = len(metrics_present)
    if num_metrics_present < 1:
        print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
        return

    # Setup logging file
    timestamp = create_timestamp()  # string good for embedding into filenames
    # normpath drops a trailing separator, which would otherwise make basename() return ""
    mixdb_fname = basename(normpath(location))
    if verbose:
        create_file_handler(join(location, "metrics_summary.log"), verbose)
        update_console_handler(verbose)
        initial_log_messages("metrics_summary")
        logger.info(f"Logging summary of SonusAI mixture database at {location}")
    else:
        update_console_handler(verbose)

    logger.info("")
    mixids = mixdb.mixids_to_list(mixids)
    if not mixids:
        logger.warning(f"No mixtures selected in {location}, nothing to summarize.")
        return
    if len(mixids) < mixdb.num_mixtures:
        logger.info(
            f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
            f"summary results will not include entire dataset."
        )
        fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
    else:
        logger.info(
            f"Summarizing SonusAI mixture database with {mixdb.num_mixtures} mixtures "
            f"and {num_metrics_present} pre-generated metrics ..."
        )
        fsuffix = ""

    ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
    # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
    scalar_metric_names: list[str] = []
    string_metric_names: list[str] = []
    frame_metric_names: list[str] = []
    bin_metric_names: list[str] = []
    all_metrics = mixdb.mixture_metrics(mixids[0], metrics_present)
    tf_frames = mixdb.mixture_transform_frames(mixids[0])
    for metric in metrics_present:
        metval = all_metrics[metric]  # get metric value
        logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
        if isinstance(metval, dict):
            logger.warning(f"Mixid {mixids[0]} metric {metric} is a dict, using 'primary'.")
            metval = metval["primary"]  # remove any dict
        if isinstance(metval, float | int):
            logger.debug(f"Metric is scalar {type(metval)}, entering in summary table.")
            scalar_metric_names.append(metric)
        elif isinstance(metval, str):
            logger.debug("Metric is string, will summarize with word count.")
            string_metric_names.append(metric)
        elif isinstance(metval, np.ndarray):
            if metval.ndim != 1:
                logger.warning(f"Mixid {mixids[0]} metric {metric} is a {metval.ndim}-D array, ignoring.")
            elif metval.size == tf_frames:
                logger.debug("Metric is frames vector.")
                frame_metric_names.append(metric)
            elif metval.size == ft_bins:
                logger.debug("Metric is bins vector.")
                bin_metric_names.append(metric)
            else:
                logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")
        else:
            logger.warning(f"Mixid {mixids[0]} metric {metric} has unsupported type {type(metval)}, ignoring.")

    # Setup pandas table for summarizing scalar metrics, always include mxsnr first
    ptab_labels = [
        "mxsnr",
        *scalar_metric_names,
        *string_metric_names,
        "fcnt",
        "duration",
        "t0file",
        "nfile",
    ]

    # psutil.cpu_count() can return None when the count cannot be determined
    num_cpu = psutil.cpu_count() or 1
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info("")
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Use the currently idle share of CPUs minus a 10% margin, but always at least one process
        use_cpu = max(int(num_cpu * (0.9 - cpu_percent / 100)), 1)
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")

    progress = track(total=len(mixids))
    # num_proc "None" means run serially
    no_par = use_cpu is None
    num_cpus = use_cpu

    all_metrics_tables = par_track(
        partial(
            _process_mixture,
            location=location,
            all_metric_names=metrics_present,
            scalar_metric_names=scalar_metric_names,
            string_metric_names=string_metric_names,
            frame_metric_names=frame_metric_names,
            bin_metric_names=bin_metric_names,
            ptab_labels=ptab_labels,
        ),
        mixids,
        progress=progress,
        num_cpus=num_cpus,
        no_par=no_par,
    )
    progress.close()

    # Done with mixtures, write out summary metrics
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    ptab1 = pd.concat([item[0] for item in all_metrics_tables])
    if wrlist:
        wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        ptab1.round(2).to_csv(wlcsv_name, **table_args)
    ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])

    # Create metrics table except -99 SNR
    ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]

    # Create summary by SNR for all scalar metrics, taking mean
    snr_means = []
    for snr in mixdb.snrs:
        tmp = ptab1_sorted.query("mxsnr==" + str(snr)).mean(numeric_only=True).to_frame().T
        # Skip SNRs with no mixtures (e.g. when a subset of mixids is specified); their mean is NaN
        if not tmp.empty and not np.isnan(tmp.iloc[0, 0]):
            snr_means.append(tmp)
    if snr_means:
        mtab_snr_summary = pd.concat(snr_means).sort_values(by=["mxsnr"], ascending=False)
    else:
        logger.warning("No selected mixtures matched any mixdb SNR; the per-SNR summary will be empty.")
        mtab_snr_summary = pd.DataFrame()

    # Write summary to .csv
    snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
    nmix = len(mixids)
    nmixtot = mixdb.num_mixtures
    pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
    pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
    mtab_snr_summary.round(2).T.to_csv(snrcsv_name, index=True, header=False, mode="a", encoding="utf-8")
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
    ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
        snrcsv_name, **header_args
    )
    ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)

    # Write summary to text file (utf-8 to match the .csv outputs)
    snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
    with open(snrtxt_name, "w", encoding="utf-8") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print("Metrics avg over each SNR:", file=f)
        print(
            mtab_snr_summary.round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True, header=False), file=f
        )
        print("", file=f)
        print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1_nom99.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
|
304
|
+
|
305
|
+
|
306
|
+
if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    # Install the Ctrl-C handler before any work starts
    register_keyboard_interrupt()
    try:
        main()
    except Exception as err:  # route any uncaught error through SonusAI's reporter
        exception_handler(err)
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# SonusAI mixture utilities
|
2
|
+
|
3
|
+
from .feature import get_audio_from_feature
|
4
|
+
from .feature import get_feature_from_audio
|
5
|
+
from .helpers import forward_transform
|
6
|
+
from .helpers import inverse_transform
|
7
|
+
from .mixdb import MixtureDatabase
|
8
|
+
|
9
|
+
# Public names re-exported by sonusai.mixture
__all__ = [
    "MixtureDatabase",
    "forward_transform",
    "get_audio_from_feature",
    "get_feature_from_audio",
    "inverse_transform",
]  # fmt: keep alphabetical after the class name
|
sonusai/mixture/audio.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from ..datatypes import AudioT
|
5
|
+
|
6
|
+
|
7
|
+
def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
    """Return ``length`` samples of noise beginning at ``offset``.

    Indices past the end of ``audio`` wrap back to its start (``np.take`` with ``mode="wrap"``),
    so the noise behaves like a repeating loop.

    :param audio: Complete noise audio (the whole file's data)
    :param offset: Index of the first sample to return
    :param length: How many samples to return
    :return: The requested run of noise samples
    """
    import numpy as np

    window = range(offset, offset + length)
    return np.take(audio, window, mode="wrap")
|
18
|
+
|
19
|
+
|
20
|
+
def get_duration(audio: AudioT) -> float:
    """Return how long ``audio`` lasts, in seconds, at the SonusAI sample rate.

    :param audio: Time-domain samples
    :return: ``len(audio) / SAMPLE_RATE``
    """
    from ..constants import SAMPLE_RATE

    num_samples = len(audio)
    return num_samples / SAMPLE_RATE
|
29
|
+
|
30
|
+
|
31
|
+
def validate_input_file(input_filepath: str | Path) -> None:
    """Check that an audio file exists and has an extension this installation can read.

    An extension is accepted when it matches (case-insensitively) a format reported by
    ``soundfile.available_formats()``. ``mp3`` and ``m4a`` are also accepted because ``raw_read_audio``
    decodes them with pydub instead of soundfile.

    :param input_filepath: Path to the audio file
    :raises OSError: If the file does not exist or its extension is not readable
    """
    from os.path import exists
    from os.path import splitext

    from soundfile import available_formats

    if not exists(input_filepath):
        raise OSError(f"input_filepath {input_filepath} does not exist.")

    ext = splitext(input_filepath)[1][1:].lower()
    # mp3/m4a go through pydub in raw_read_audio, so they are readable even when libsndfile lacks them
    read_formats = {item.lower() for item in available_formats()} | {"mp3", "m4a"}
    if ext not in read_formats:
        raise OSError(f"This installation cannot process .{ext} files")
|
44
|
+
|
45
|
+
|
46
|
+
def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
    """Return the native sample rate of an audio file.

    :param name: File name
    :param use_cache: When true, go through the LRU-cached reader; otherwise bypass the cache
    :return: Sample rate in Hz
    """
    reader = _get_sample_rate if use_cache else _get_sample_rate.__wrapped__
    return reader(name)
|
56
|
+
|
57
|
+
|
58
|
+
@lru_cache
def _get_sample_rate(name: str | Path) -> int:
    """Read the native sample rate of an audio file (cached; see ``get_sample_rate``).

    ``name`` is first expanded with ``tokenized_expand``. Files ending in ``.mp3`` or ``.m4a``
    (matched case-insensitively) are opened with pydub; everything else uses ``soundfile.info``.

    :param name: File name
    :return: Sample rate in Hz
    :raises OSError: If the file cannot be read; the original error is chained
    """
    import soundfile
    from pydub import AudioSegment

    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)
    # Compare extensions case-insensitively so e.g. ".M4A" still takes the pydub path
    lowered = expanded_name.lower()

    try:
        if lowered.endswith(".mp3"):
            return AudioSegment.from_mp3(expanded_name).frame_rate

        if lowered.endswith(".m4a"):
            return AudioSegment.from_file(expanded_name).frame_rate

        return soundfile.info(expanded_name).samplerate
    except Exception as e:
        # str() so a Path argument is not always reported as "expanded"
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e
|
85
|
+
|
86
|
+
|
87
|
+
def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
    """Read the first channel of an audio file at its native sample rate.

    ``name`` is expanded with ``tokenized_expand``. ``.mp3`` and ``.m4a`` files (matched
    case-insensitively) are decoded with pydub, and their integer samples are divided by
    ``2 ** (bits - 1)``. All other files are read with soundfile as float32.

    :param name: File name
    :return: Tuple of (1-D float32 samples of channel 0, native sample rate)
    :raises OSError: If the file cannot be read; the original error is chained
    """
    import numpy as np
    import soundfile
    from pydub import AudioSegment

    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)
    # Compare extensions case-insensitively so e.g. ".M4A" still takes the pydub path
    lowered = expanded_name.lower()

    try:
        if lowered.endswith((".mp3", ".m4a")):
            if lowered.endswith(".mp3"):
                sound = AudioSegment.from_mp3(expanded_name)
            else:
                sound = AudioSegment.from_file(expanded_name)
            # Interleaved integer samples -> (samples, channels), scaled by the integer full-scale value
            raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
            raw = raw / 2 ** (sound.sample_width * 8 - 1)
            sample_rate = sound.frame_rate
        else:
            raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
    except Exception as e:
        # str() so a Path argument is not always reported as "expanded"
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e

    # raw is 2-D (samples, channels) on both paths; raw[:, 0] is already 1-D. Do not squeeze:
    # that would turn a single-sample file into a 0-d array.
    return raw[:, 0].astype(np.float32), sample_rate
|
116
|
+
|
117
|
+
|
118
|
+
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
    """Load an audio file resampled to the SonusAI sample rate.

    :param name: File name
    :param use_cache: When true, go through the LRU-cached reader; otherwise bypass the cache
    :return: Time-domain audio samples
    """
    reader = _read_audio if use_cache else _read_audio.__wrapped__
    return reader(name)
|
128
|
+
|
129
|
+
|
130
|
+
@lru_cache
def _read_audio(name: str | Path) -> AudioT:
    """Read an audio file and resample it to ``SAMPLE_RATE`` (cached; see ``read_audio``).

    NOTE(review): with caching, every caller receives the same array object; confirm callers
    never modify the returned audio in place.

    :param name: File name
    :return: Time-domain audio at the SonusAI sample rate
    """
    from ..constants import SAMPLE_RATE
    from .resample import resample

    audio, native_rate = raw_read_audio(name)
    return resample(audio, orig_sr=native_rate, target_sr=SAMPLE_RATE)
|
143
|
+
|
144
|
+
|
145
|
+
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
    """Return how many samples a file will have once resampled to the SonusAI sample rate.

    :param name: File name
    :param use_cache: When true, go through the LRU-cached helper; otherwise bypass the cache
    :return: Sample count after resampling
    """
    counter = _get_num_samples if use_cache else _get_num_samples.__wrapped__
    return counter(name)
|
155
|
+
|
156
|
+
|
157
|
+
@lru_cache
def _get_num_samples(name: str | Path) -> int:
    """Compute the sample count of a file after resampling to ``SAMPLE_RATE`` (cached).

    ``name`` is expanded with ``tokenized_expand``. ``.mp3`` and ``.m4a`` files (matched
    case-insensitively) are opened with pydub to get their frame count and rate. Other files use
    ``soundfile.info``.

    :param name: File name
    :return: ``ceil(SAMPLE_RATE * native_samples / native_rate)``
    """
    import math

    import soundfile
    from pydub import AudioSegment

    from ..constants import SAMPLE_RATE
    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)
    # Compare extensions case-insensitively so e.g. ".M4A" still takes the pydub path
    lowered = expanded_name.lower()

    if lowered.endswith((".mp3", ".m4a")):
        if lowered.endswith(".mp3"):
            sound = AudioSegment.from_mp3(expanded_name)
        else:
            sound = AudioSegment.from_file(expanded_name)
        samples = sound.frame_count()
        sample_rate = sound.frame_rate
    else:
        info = soundfile.info(expanded_name)
        samples = info.frames
        sample_rate = info.samplerate

    return math.ceil(SAMPLE_RATE * samples / sample_rate)
|
@@ -0,0 +1,103 @@
|
|
1
|
+
from ..datatypes import EffectList
|
2
|
+
from ..datatypes import EffectedFile
|
3
|
+
from ..datatypes import File
|
4
|
+
|
5
|
+
|
6
|
+
def balance_sources(
    effected_sources: list[EffectedFile],
    files: list[File],
    effects: list[EffectList],
    class_balancing_effect: EffectList,
    num_classes: int,
    num_ir: int,
    mixups: list[int] | None = None,
) -> tuple[list[EffectedFile], list[EffectList]]:
    """Pad under-represented classes with extra effected sources so that class counts match.

    For each mixup value other than 1, sources are grouped by class. Every non-empty class is padded
    up to the size of the largest class, rounded up to a multiple of ``mixup``. Padding cycles
    through the class's existing file IDs in sorted order. Each padded file is paired with a
    class-balancing effect it has not used yet, and a new effect is created when none is available.
    Classes with no sources at all are left empty, since there is no file to pad them with.

    ``effected_sources`` and ``effects`` are extended in place and are also returned.

    NOTE(review): confirm that ``sonusai.mixture.augmentation`` and ``sonusai.mixture.sources`` exist;
    neither module appears in the package's file listing.

    :param effected_sources: Existing (file_id, effect_id) pairs; extended in place
    :param files: Source files indexed by file_id
    :param effects: Existing effect lists; new balancing effects are appended
    :param class_balancing_effect: Default balancing effect rule, used when a file has no override
    :param num_classes: Number of classes
    :param num_ir: Number of impulse responses, forwarded to rule expansion
    :param mixups: Mixup values to balance; derived from ``effects`` when None
    :return: Tuple of (effected_sources, effects)
    """
    import math

    from .augmentation import get_mixups
    from .sources import get_augmented_target_ids_by_class

    # Effects at or beyond this index are balancing effects added by this function
    first_cba_id = len(effects)

    if mixups is None:
        mixups = get_mixups(effects)

    for mixup in mixups:
        if mixup == 1:
            continue

        effected_sources_indices_by_class = get_augmented_target_ids_by_class(
            augmented_targets=effected_sources,
            targets=files,
            target_augmentations=effects,
            mixup=mixup,
            num_classes=num_classes,
        )

        class_sizes = [len(item) for item in effected_sources_indices_by_class]
        if not class_sizes:
            continue
        largest = math.ceil(max(class_sizes) / mixup) * mixup

        for es_indices in effected_sources_indices_by_class:
            file_ids = sorted({effected_sources[es_index].file_id for es_index in es_indices})
            if not file_ids:
                # No sources in this class: nothing to draw padding from (previously an IndexError)
                continue

            for count in range(largest - len(es_indices)):
                # Round-robin over the class's files
                file_id = file_ids[count % len(file_ids)]
                effect_id, effects = _get_unused_balancing_effect(
                    effected_sources=effected_sources,
                    files=files,
                    effects=effects,
                    class_balancing_effect=class_balancing_effect,
                    file_id=file_id,
                    mixup=mixup,
                    num_ir=num_ir,
                    first_cbe_id=first_cba_id,
                )
                effected_sources.append(EffectedFile(file_id=file_id, effect_id=effect_id))

    return effected_sources, effects
|
60
|
+
|
61
|
+
|
62
|
+
def _get_unused_balancing_effect(
    effected_sources: list[EffectedFile],
    files: list[File],
    effects: list[EffectList],
    class_balancing_effect: EffectList,
    file_id: int,
    mixup: int,
    num_ir: int,
    first_cbe_id: int,
) -> tuple[int, list[EffectList]]:
    """Get an unused balancing augmentation for a given target file index.

    Balancing effects are those with index ``>= first_cbe_id``. The lowest such index with a
    matching ``mixup`` that ``file_id`` has not already been paired with is returned. If there is
    none, a new effect is built from the file's class-balancing rule (or the default rule), given
    this ``mixup``, and appended to ``effects`` in place.

    :return: Tuple of (effect index, effects list)
    """
    from dataclasses import asdict

    from .augmentation import get_augmentation_rules

    # range membership is O(1) for ints, unlike the list it replaces
    balancing_ids = range(max(first_cbe_id, 0), len(effects))
    used_ids = {
        effected_source.effect_id
        for effected_source in effected_sources
        if effected_source.file_id == file_id and effected_source.effect_id in balancing_ids
    }

    unused_id = next(
        (idx for idx in balancing_ids if idx not in used_ids and effects[idx].mixup == mixup),
        None,
    )
    if unused_id is not None:
        return unused_id, effects

    class_balancing_effect = get_class_balancing_effect(file=files[file_id], default_cbe=class_balancing_effect)
    new_effect = get_augmentation_rules(rules=asdict(class_balancing_effect), num_ir=num_ir)[0]
    new_effect.mixup = mixup
    effects.append(new_effect)
    return len(effects) - 1, effects
|
97
|
+
|
98
|
+
|
99
|
+
def get_class_balancing_effect(file: File, default_cbe: EffectList) -> EffectList:
    """Return the file's own class balancing effect, falling back to ``default_cbe`` when it has none."""
    override = file.class_balancing_effect
    return default_cbe if override is None else override
|