sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py
CHANGED
@@ -14,16 +14,19 @@ The ONNX Runtime (ort) inference engine is used to execute the inference.
|
|
14
14
|
|
15
15
|
Inputs:
|
16
16
|
MODEL ONNX model .onnx file of a trained model (weights are expected to be in the file).
|
17
|
+
The model must also include required Sonusai hyperparameters. See theSonusai torchl_onnx command.
|
17
18
|
|
18
|
-
DATA
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
19
|
+
DATA A string which must be one of the following:
|
20
|
+
1. Path to a single file. The prediction data is written to <filename_predict.*> in same location.
|
21
|
+
2. Path to a Sonusai Mixture Database directory.
|
22
|
+
- Sonusai mixture database directory, prediction files will be named mixid_predict.*
|
23
|
+
- MIXID will select a subset of mixture ids
|
24
|
+
3. Directory with audio files found recursively within. See GLOB audio file extensions below.
|
25
|
+
4. Regex resolving to a list of files.
|
26
|
+
- Subdirectory containing audio files with extension
|
27
|
+
- Regex resolving to a list of audio files
|
26
28
|
|
29
|
+
generate feature and truth data if not found.
|
27
30
|
|
28
31
|
Note there are multiple ways to process model prediction over multiple audio data files:
|
29
32
|
1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model run as
|
@@ -42,33 +45,68 @@ TBD not sure below make sense, need to continue ??
|
|
42
45
|
|
43
46
|
Outputs the following to opredict-<TIMESTAMP> directory:
|
44
47
|
<id>
|
45
|
-
predict.
|
48
|
+
predict.h5
|
46
49
|
onnx_predict.log
|
47
50
|
|
48
51
|
"""
|
49
52
|
|
50
|
-
import signal
|
51
|
-
|
52
|
-
|
53
|
-
def signal_handler(_sig, _frame):
|
54
|
-
import sys
|
55
|
-
|
56
|
-
from sonusai import logger
|
57
53
|
|
58
|
-
|
59
|
-
|
54
|
+
def process_path(path, ext_list: list[str] | None = None):
|
55
|
+
"""
|
56
|
+
Check path which can be a single file, a subdirectory, or a regex
|
57
|
+
return:
|
58
|
+
- a list of files with matching extensions to any in ext_list provided (i.e. ['.wav', '.mp3', '.acc'])
|
59
|
+
- the basedir of the path, if
|
60
|
+
"""
|
61
|
+
import glob
|
62
|
+
from os.path import abspath
|
63
|
+
from os.path import commonprefix
|
64
|
+
from os.path import dirname
|
65
|
+
from os.path import isdir
|
66
|
+
from os.path import isfile
|
67
|
+
from os.path import join
|
60
68
|
|
69
|
+
from sonusai.utils import braced_iglob
|
61
70
|
|
62
|
-
|
71
|
+
if ext_list is None:
|
72
|
+
ext_list = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]
|
73
|
+
|
74
|
+
# Check if the path is a single file, and return it as a list with the dirname
|
75
|
+
if isfile(path):
|
76
|
+
if any(path.endswith(ext) for ext in ext_list):
|
77
|
+
basedir = dirname(path) # base directory
|
78
|
+
if not basedir:
|
79
|
+
basedir = "./"
|
80
|
+
return [path], basedir
|
81
|
+
else:
|
82
|
+
return [], []
|
83
|
+
|
84
|
+
# Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
|
85
|
+
if isdir(path):
|
86
|
+
matching_files = []
|
87
|
+
for ext in ext_list:
|
88
|
+
matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
|
89
|
+
return matching_files, path
|
90
|
+
|
91
|
+
# Process as a regex, return list of filenames and basedir
|
92
|
+
apath = abspath(path) # join(abspath(path), "**", "*.{wav,flac,WAV}")
|
93
|
+
matching_files = []
|
94
|
+
for file in braced_iglob(pathname=apath, recursive=True):
|
95
|
+
matching_files.append(file)
|
96
|
+
if matching_files:
|
97
|
+
basedir = commonprefix(matching_files) # Find basedir
|
98
|
+
return matching_files, basedir
|
99
|
+
else:
|
100
|
+
return [], []
|
63
101
|
|
64
102
|
|
65
103
|
def main() -> None:
|
66
104
|
from docopt import docopt
|
67
105
|
|
68
|
-
import
|
106
|
+
from sonusai import __version__ as sai_version
|
69
107
|
from sonusai.utils import trim_docstring
|
70
108
|
|
71
|
-
args = docopt(trim_docstring(__doc__), version=
|
109
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
72
110
|
|
73
111
|
verbose = args["--verbose"]
|
74
112
|
wav = args["--write-wav"]
|
@@ -77,18 +115,23 @@ def main() -> None:
|
|
77
115
|
model_path = args["MODEL"]
|
78
116
|
data_paths = args["DATA"]
|
79
117
|
|
118
|
+
# Quick check of CPU and GPU devices
|
119
|
+
import re
|
120
|
+
import subprocess
|
121
|
+
import time
|
80
122
|
from os import makedirs
|
81
|
-
from os.path import abspath
|
82
123
|
from os.path import basename
|
124
|
+
from os.path import exists
|
83
125
|
from os.path import isdir
|
126
|
+
from os.path import isfile
|
84
127
|
from os.path import join
|
85
128
|
from os.path import normpath
|
86
|
-
from os.path import realpath
|
87
129
|
from os.path import splitext
|
88
130
|
|
89
131
|
import h5py
|
90
132
|
import numpy as np
|
91
133
|
import onnxruntime as ort
|
134
|
+
import psutil
|
92
135
|
|
93
136
|
from sonusai import create_file_handler
|
94
137
|
from sonusai import initial_log_messages
|
@@ -96,66 +139,122 @@ def main() -> None:
|
|
96
139
|
from sonusai import update_console_handler
|
97
140
|
from sonusai.mixture import MixtureDatabase
|
98
141
|
from sonusai.mixture import get_audio_from_feature
|
99
|
-
from sonusai.utils import PathInfo
|
100
|
-
from sonusai.utils import braced_iglob
|
101
142
|
from sonusai.utils import create_ts_name
|
102
143
|
from sonusai.utils import load_ort_session
|
103
|
-
from sonusai.utils import
|
144
|
+
from sonusai.utils import seconds_to_hms
|
104
145
|
from sonusai.utils import write_audio
|
105
146
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
147
|
+
num_cpu = psutil.cpu_count()
|
148
|
+
cpu_percent = psutil.cpu_percent(interval=1)
|
149
|
+
print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
|
150
|
+
print(f"Memory utilization: {psutil.virtual_memory().percent}%")
|
151
|
+
|
152
|
+
vga_devices = [
|
153
|
+
line.split(" ", 3)[-1]
|
154
|
+
for line in subprocess.check_output("lspci | grep -i vga", shell=True).decode().splitlines()
|
155
|
+
]
|
156
|
+
nv_devs = list(filter(lambda x: "nvidia" in x.lower(), vga_devices))
|
157
|
+
nv_mods = [re.search(r"\[.*?\]", device).group(0) if re.search(r"\[.*?\]", device) else None for device in nv_devs]
|
158
|
+
if len(nv_mods) > 0:
|
159
|
+
print(f"{len(nv_mods)} Nvidia devices present: {nv_mods}") # prints model names
|
115
160
|
else:
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
161
|
+
print("No cuda devices present, using cpu")
|
162
|
+
|
163
|
+
avail_providers = ort.get_available_providers()
|
164
|
+
print(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
|
165
|
+
if len(nv_mods) > 0:
|
166
|
+
print(
|
167
|
+
"If GPU is desired, need to replace onnxruntime with onnxruntime-gpu i.e. using pip:"
|
168
|
+
"> pip uninstall onnxruntime"
|
169
|
+
"> pip install onnxruntime-gpu\n\n"
|
170
|
+
)
|
171
|
+
|
172
|
+
# Quick check that model is valid
|
173
|
+
if exists(model_path) and isfile(model_path):
|
174
|
+
try:
|
175
|
+
session = ort.InferenceSession(model_path)
|
176
|
+
options = ort.SessionOptions()
|
177
|
+
except Exception as e:
|
178
|
+
print(f"Error: could not load ONNX model from {model_path}: {e}")
|
179
|
+
raise SystemExit(1) from e
|
180
|
+
else:
|
181
|
+
print(f"Error: model file path is not valid: {model_path}")
|
182
|
+
raise SystemExit(1)
|
126
183
|
|
127
|
-
|
128
|
-
|
184
|
+
# Check datapath is valid
|
185
|
+
if len(data_paths) == 1 and isdir(data_paths[0]): # Try opening as mixdb subdir
|
186
|
+
mixdb_path = data_paths[0]
|
187
|
+
try:
|
188
|
+
mixdb = MixtureDatabase(mixdb_path)
|
189
|
+
except Exception:
|
190
|
+
mixdb_path = None
|
191
|
+
in_basename = basename(normpath(data_paths[0]))
|
192
|
+
output_dir = create_ts_name("opredict-" + in_basename)
|
193
|
+
num_featparams = mixdb.feature_parameters
|
194
|
+
print(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
|
195
|
+
p_mixids = mixdb.mixids_to_list(mixids)
|
196
|
+
feature_mode = mixdb.feature
|
129
197
|
|
198
|
+
if mixdb_path is None:
|
199
|
+
if verbose:
|
200
|
+
print(f"Checking {len(data_paths)} locations ... ")
|
201
|
+
# Check location, default ext are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac']
|
202
|
+
pfiles, basedir = process_path(data_paths)
|
203
|
+
if pfiles is None or len(pfiles) < 1:
|
204
|
+
print(f"No audio files or Sonusai mixture database found in {data_paths}, exiting ...")
|
205
|
+
raise SystemExit(1)
|
206
|
+
else:
|
207
|
+
pfiles = sorted(pfiles, key=basename)
|
208
|
+
output_dir = basedir
|
209
|
+
|
210
|
+
if mixdb_path is not None or len(pfiles) > 1: # log file only if mixdb or more than one file
|
211
|
+
makedirs(output_dir, exist_ok=True)
|
212
|
+
# Setup logging file
|
213
|
+
create_file_handler(join(output_dir, "onnx-predict.log"))
|
214
|
+
update_console_handler(verbose)
|
215
|
+
initial_log_messages("onnx_predict")
|
216
|
+
# print some previous messages
|
217
|
+
logger.info(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
|
218
|
+
if mixdb_path:
|
219
|
+
logger.debug(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
|
220
|
+
if len(p_mixids) != mixdb.num_mixtures:
|
221
|
+
logger.info(f"Processing a subset of {len(p_mixids)} from available mixtures.")
|
222
|
+
|
223
|
+
# Reload model/session and do more thorough checking
|
130
224
|
session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
|
225
|
+
if "CUDAExecutionProvider" in avail_providers:
|
226
|
+
session.set_providers(["CUDAExecutionProvider"])
|
131
227
|
if hparams is None:
|
132
228
|
logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
|
133
229
|
raise SystemExit(1)
|
134
|
-
|
230
|
+
|
231
|
+
if len(sess_inputs) != 1: # TBD update to support state_in and state_out
|
135
232
|
logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")
|
136
233
|
|
137
234
|
in0name = sess_inputs[0].name
|
138
235
|
in0type = sess_inputs[0].type
|
139
|
-
|
140
|
-
|
141
|
-
|
236
|
+
in0shape = sess_inputs[0].shape # a list
|
237
|
+
# Check for 2 cases of model feature input shape: batch x timesteps x fparams or batch x fparams
|
238
|
+
if not isinstance(in0shape[0], str):
|
239
|
+
model_batchsz = int(in0shape[0])
|
240
|
+
logger.debug(f"Onnx model has fixed batch_size: {model_batchsz}.")
|
241
|
+
else:
|
242
|
+
model_batchsz = -1
|
243
|
+
logger.debug("Onnx model has a dynamic batch_size.")
|
142
244
|
|
143
|
-
if
|
144
|
-
|
145
|
-
|
146
|
-
mixdb = MixtureDatabase(mixdb_path)
|
147
|
-
logger.info(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes")
|
148
|
-
p_mixids = mixdb.mixids_to_list(mixids)
|
149
|
-
if len(p_mixids) != mixdb.num_mixtures:
|
150
|
-
logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
|
245
|
+
if len(in0shape) < 3:
|
246
|
+
model_tsteps = 0
|
247
|
+
model_featparams = int(in0shape[1])
|
151
248
|
else:
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
249
|
+
model_featparams = int(in0shape[2])
|
250
|
+
if not isinstance(in0shape[1], str):
|
251
|
+
model_tsteps = int(in0shape[1])
|
252
|
+
logger.debug(f"Onnx model has fixed timesteps: {model_tsteps}.")
|
253
|
+
else:
|
254
|
+
model_tsteps = -1
|
255
|
+
logger.debug("Onnx model has dynamic timesteps dimension size.")
|
256
|
+
|
257
|
+
out_names = [n.name for n in session.get_outputs()]
|
159
258
|
|
160
259
|
if in0type.find("float16") != -1:
|
161
260
|
model_is_fp16 = True
|
@@ -163,38 +262,40 @@ def main() -> None:
|
|
163
262
|
else:
|
164
263
|
model_is_fp16 = False
|
165
264
|
|
265
|
+
logger.info(f"Read and compiled ONNX model from {model_path}.")
|
266
|
+
|
267
|
+
start_time = time.monotonic()
|
268
|
+
|
166
269
|
if mixdb is not None and hparams["batch_size"] == 1:
|
167
|
-
#
|
168
|
-
# Assume (of course) that mixdb feature, etc. is what model expects
|
169
|
-
if hparams["feature"] != mixdb.feature:
|
270
|
+
if hparams["feature"] != feature_mode: # warn on mis-match, but TBD could be sov-mode
|
170
271
|
logger.warning("Mixture feature does not match model feature, this inference run may fail.")
|
171
|
-
|
172
|
-
|
272
|
+
logger.info(f"Processing {len(p_mixids)} mixtures from SonusAI mixdb ...")
|
273
|
+
logger.info(f"Using OnnxRT provider {session.get_providers()} ...")
|
173
274
|
|
174
275
|
for mixid in p_mixids:
|
175
|
-
# frames x stride x feature_params
|
176
|
-
|
276
|
+
# feature data is now always fp32 and frames x stride x feature_params
|
277
|
+
feat_dat, _ = mixdb.mixture_ft(mixid)
|
278
|
+
if feat_dat.shape[1] > 1: # stride mode num_frames overrides batch dim, no reshape
|
279
|
+
stride_mode = 1
|
280
|
+
batch_size = feat_dat.shape[0] # num_frames in stride mode becomes batch size
|
177
281
|
if hparams["timesteps"] == 0:
|
178
|
-
# no timestep dimension,
|
282
|
+
# no timestep dimension, remove the dimension
|
179
283
|
timesteps = 0
|
284
|
+
feat_dat = np.reshape(feat_dat, [batch_size, num_featparams])
|
180
285
|
else:
|
181
|
-
# fit frames into timestep dimension (TSE mode)
|
182
|
-
timesteps =
|
183
|
-
|
184
|
-
|
185
|
-
feature=feature,
|
186
|
-
batch_size=1,
|
187
|
-
timesteps=timesteps,
|
188
|
-
flatten=hparams["flatten"],
|
189
|
-
add1ch=hparams["add1ch"],
|
190
|
-
)
|
286
|
+
# fit frames into timestep dimension (TSE mode) and knowing batch_size = 1
|
287
|
+
timesteps = feat_dat.shape[0]
|
288
|
+
feat_dat = np.transpose(feat_dat, (1, 0, 2)) # transpose to 1 x frames=tsteps x feat_params
|
289
|
+
|
191
290
|
if model_is_fp16:
|
192
|
-
|
291
|
+
feat_dat = np.float16(feat_dat) # type: ignore[assignment]
|
292
|
+
|
193
293
|
# run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
|
194
|
-
predict = session.run(out_names, {in0name:
|
294
|
+
predict = session.run(out_names, {in0name: feat_dat})[0]
|
195
295
|
# predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
|
296
|
+
|
196
297
|
output_fname = join(output_dir, mixdb.mixture(mixid).name)
|
197
|
-
with h5py.File(output_fname, "a") as f:
|
298
|
+
with h5py.File(output_fname + ".h5", "a") as f:
|
198
299
|
if "predict" in f:
|
199
300
|
del f["predict"]
|
200
301
|
f.create_dataset("predict", data=predict)
|
@@ -206,6 +307,57 @@ def main() -> None:
|
|
206
307
|
owav_name = splitext(output_fname)[0] + "_predict.wav"
|
207
308
|
write_audio(owav_name, predict_audio)
|
208
309
|
|
310
|
+
else: # TBD add support
|
311
|
+
logger.info("Mixture database does not exist or batch_size is not equal to one, exiting ...")
|
312
|
+
|
313
|
+
end_time = time.monotonic()
|
314
|
+
logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
|
315
|
+
logger.info("")
|
316
|
+
|
209
317
|
|
210
318
|
if __name__ == "__main__":
|
211
|
-
|
319
|
+
from sonusai import exception_handler
|
320
|
+
from sonusai.utils import register_keyboard_interrupt
|
321
|
+
|
322
|
+
register_keyboard_interrupt()
|
323
|
+
try:
|
324
|
+
main()
|
325
|
+
except Exception as e:
|
326
|
+
exception_handler(e)
|
327
|
+
|
328
|
+
# mixdb_path = None
|
329
|
+
# mixdb: MixtureDatabase | None = None
|
330
|
+
# p_mixids: list[int] = []
|
331
|
+
# entries: list[PathInfo] = []
|
332
|
+
#
|
333
|
+
# if len(data_paths) == 1 and isdir(data_paths[0]):
|
334
|
+
# # Assume it's a single path to SonusAI mixdb subdir
|
335
|
+
# in_basename = basename(normpath(data_paths[0]))
|
336
|
+
# mixdb_path = data_paths[0]
|
337
|
+
# else:
|
338
|
+
# # search all data paths for .wav, .flac (or whatever is specified in include)
|
339
|
+
# in_basename = ""
|
340
|
+
|
341
|
+
# if mixdb_path is not None: # a mixdb is found and loaded
|
342
|
+
# # Assume it's a single path to SonusAI mixdb subdir
|
343
|
+
# num_featparams = mixdb.feature_parameters
|
344
|
+
# logger.debug(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {num_featparams} classes")
|
345
|
+
# p_mixids = mixdb.mixids_to_list(mixids)
|
346
|
+
# if len(p_mixids) != mixdb.num_mixtures:
|
347
|
+
# logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
|
348
|
+
# else:
|
349
|
+
# for p in data_paths:
|
350
|
+
# location = join(realpath(abspath(p)), "**", include)
|
351
|
+
# logger.debug(f"Processing files in {location}")
|
352
|
+
# for file in braced_iglob(pathname=location, recursive=True):
|
353
|
+
# name = file
|
354
|
+
# entries.append(PathInfo(abs_path=file, audio_filepath=name))
|
355
|
+
# logger.info(f"{len(data_paths)} data paths specified, found {len(pfile)} audio files.")
|
356
|
+
|
357
|
+
# feature, _ = reshape_inputs(
|
358
|
+
# feature=feature,
|
359
|
+
# batch_size=1,
|
360
|
+
# timesteps=timesteps,
|
361
|
+
# flatten=hparams["flatten"],
|
362
|
+
# add1ch=hparams["add1ch"],
|
363
|
+
# )
|
sonusai/queries/__init__.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
# SonusAI query utilities
|
2
2
|
# ruff: noqa: F401
|
3
3
|
|
4
|
+
from .queries import get_mixids_from_class_indices
|
4
5
|
from .queries import get_mixids_from_noise
|
5
6
|
from .queries import get_mixids_from_snr
|
6
|
-
from .queries import
|
7
|
+
from .queries import get_mixids_from_source
|
7
8
|
from .queries import get_mixids_from_truth_function
|
8
|
-
from .queries import get_mixids_from_class_indices
|
sonusai/queries/queries.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
from collections.abc import Callable
|
2
2
|
from typing import Any
|
3
3
|
|
4
|
-
from
|
5
|
-
from
|
4
|
+
from ..datatypes import GeneralizedIDs
|
5
|
+
from ..mixture.mixdb import MixtureDatabase
|
6
6
|
|
7
7
|
|
8
8
|
def _true_predicate(_: Any) -> bool:
|
@@ -29,8 +29,8 @@ def get_mixids_from_mixture_field_predicate(
|
|
29
29
|
criteria_set = set()
|
30
30
|
for m_id in mixid_out:
|
31
31
|
value = getattr(mixdb.mixture(m_id), field)
|
32
|
-
if isinstance(value,
|
33
|
-
for v in value:
|
32
|
+
if isinstance(value, dict):
|
33
|
+
for v in value.values():
|
34
34
|
if predicate(v):
|
35
35
|
criteria_set.add(v)
|
36
36
|
elif predicate(value):
|
@@ -42,8 +42,8 @@ def get_mixids_from_mixture_field_predicate(
|
|
42
42
|
result[criterion] = []
|
43
43
|
for m_id in mixid_out:
|
44
44
|
value = getattr(mixdb.mixture(m_id), field)
|
45
|
-
if isinstance(value,
|
46
|
-
for v in value:
|
45
|
+
if isinstance(value, dict):
|
46
|
+
for v in value.values():
|
47
47
|
if v == criterion:
|
48
48
|
result[criterion].append(m_id)
|
49
49
|
elif value == criterion:
|
@@ -64,7 +64,7 @@ def get_mixids_from_truth_configs_field_predicate(
|
|
64
64
|
- keys are the matching field values
|
65
65
|
- values are lists of the mixids that match the criteria
|
66
66
|
"""
|
67
|
-
from
|
67
|
+
from ..mixture.constants import REQUIRED_TRUTH_CONFIGS
|
68
68
|
|
69
69
|
mixid_out = mixdb.mixids_to_list(mixids)
|
70
70
|
|
@@ -79,23 +79,24 @@ def get_mixids_from_truth_configs_field_predicate(
|
|
79
79
|
|
80
80
|
result = {}
|
81
81
|
for value in values:
|
82
|
-
# Get a list of
|
82
|
+
# Get a list of sources for each field value
|
83
83
|
indices = []
|
84
|
-
for
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
if
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
84
|
+
for s_ids in mixdb.source_file_ids.values():
|
85
|
+
for s_id in s_ids:
|
86
|
+
source = mixdb.source_file(s_id)
|
87
|
+
for truth_config in source.truth_configs.values():
|
88
|
+
if field in REQUIRED_TRUTH_CONFIGS:
|
89
|
+
if value in getattr(truth_config, field):
|
90
|
+
indices.append(s_id)
|
91
|
+
else:
|
92
|
+
if value in getattr(truth_config.config, field):
|
93
|
+
indices.append(s_id)
|
93
94
|
indices = sorted(set(indices))
|
94
95
|
|
95
96
|
mixids = []
|
96
97
|
for index in indices:
|
97
98
|
for m_id in mixid_out:
|
98
|
-
if index in [
|
99
|
+
if index in [source.file_id for source in mixdb.mixture(m_id).all_sources.values()]:
|
99
100
|
mixids.append(m_id)
|
100
101
|
|
101
102
|
mixids = sorted(set(mixids))
|
@@ -109,18 +110,19 @@ def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str)
|
|
109
110
|
"""
|
110
111
|
Generate a list of all values corresponding to the given field in truth_configs
|
111
112
|
"""
|
112
|
-
from
|
113
|
+
from ..mixture.constants import REQUIRED_TRUTH_CONFIGS
|
113
114
|
|
114
115
|
result = []
|
115
|
-
for
|
116
|
-
for
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
116
|
+
for sources in mixdb.source_files.values():
|
117
|
+
for source in sources:
|
118
|
+
for truth_config in source.truth_configs.values():
|
119
|
+
if field in REQUIRED_TRUTH_CONFIGS:
|
120
|
+
value = getattr(truth_config, field)
|
121
|
+
else:
|
122
|
+
value = getattr(truth_config.config, field, None)
|
123
|
+
if not isinstance(value, list):
|
124
|
+
value = [value]
|
125
|
+
result.extend(value)
|
124
126
|
|
125
127
|
return sorted(set(result))
|
126
128
|
|
@@ -139,18 +141,18 @@ def get_mixids_from_noise(
|
|
139
141
|
return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="noise_id", predicate=predicate)
|
140
142
|
|
141
143
|
|
142
|
-
def
|
144
|
+
def get_mixids_from_source(
|
143
145
|
mixdb: MixtureDatabase,
|
144
146
|
mixids: GeneralizedIDs = "*",
|
145
147
|
predicate: Callable[[Any], bool] | None = None,
|
146
148
|
) -> dict[int, list[int]]:
|
147
149
|
"""
|
148
|
-
Generate mixids based on a
|
150
|
+
Generate mixids based on a source index predicate
|
149
151
|
Return a dictionary where:
|
150
|
-
- keys are the
|
151
|
-
- values are lists of the mixids that match the
|
152
|
+
- keys are the source indices
|
153
|
+
- values are lists of the mixids that match the source index
|
152
154
|
"""
|
153
|
-
return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="
|
155
|
+
return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="source_ids", predicate=predicate)
|
154
156
|
|
155
157
|
|
156
158
|
def get_mixids_from_snr(
|
@@ -178,7 +180,9 @@ def get_mixids_from_snr(
|
|
178
180
|
result: dict[float, list[int]] = {}
|
179
181
|
for snr in snrs:
|
180
182
|
# Get a list of mixids for each SNR
|
181
|
-
result[snr] = sorted(
|
183
|
+
result[snr] = sorted(
|
184
|
+
[i for i, mixture in enumerate(mixdb.mixtures) if mixture.noise.snr == snr and i in mixid_out]
|
185
|
+
)
|
182
186
|
|
183
187
|
return result
|
184
188
|
|
sonusai/speech/librispeech.py
CHANGED
sonusai/speech/mcgill.py
CHANGED
sonusai/speech/timit.py
CHANGED
@@ -12,7 +12,7 @@ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
|
|
12
12
|
"""
|
13
13
|
import string
|
14
14
|
|
15
|
-
from
|
15
|
+
from ..mixture.audio import get_sample_rate
|
16
16
|
|
17
17
|
file = Path(audio).with_suffix(".TXT")
|
18
18
|
if not os.path.exists(file):
|
@@ -52,7 +52,7 @@ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None
|
|
52
52
|
|
53
53
|
|
54
54
|
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
|
55
|
-
from
|
55
|
+
from ..mixture.audio import get_sample_rate
|
56
56
|
|
57
57
|
if tier == "words":
|
58
58
|
file = Path(audio).with_suffix(".WRD")
|
@@ -14,20 +14,6 @@ Inputs:
|
|
14
14
|
|
15
15
|
"""
|
16
16
|
|
17
|
-
import signal
|
18
|
-
|
19
|
-
|
20
|
-
def signal_handler(_sig, _frame):
|
21
|
-
import sys
|
22
|
-
|
23
|
-
from sonusai import logger
|
24
|
-
|
25
|
-
logger.info("Canceled due to keyboard interrupt")
|
26
|
-
sys.exit(1)
|
27
|
-
|
28
|
-
|
29
|
-
signal.signal(signal.SIGINT, signal_handler)
|
30
|
-
|
31
17
|
|
32
18
|
def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = False) -> str:
|
33
19
|
import glob
|
@@ -56,10 +42,10 @@ def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = Fal
|
|
56
42
|
def main():
|
57
43
|
from docopt import docopt
|
58
44
|
|
59
|
-
import
|
45
|
+
from sonusai import __version__ as sai_version
|
60
46
|
from sonusai.utils import trim_docstring
|
61
47
|
|
62
|
-
args = docopt(trim_docstring(__doc__), version=
|
48
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
63
49
|
|
64
50
|
by = args["--sort"]
|
65
51
|
reverse = args["--reverse"]
|
@@ -69,4 +55,11 @@ def main():
|
|
69
55
|
|
70
56
|
|
71
57
|
if __name__ == "__main__":
|
72
|
-
|
58
|
+
from sonusai import exception_handler
|
59
|
+
from sonusai.utils import register_keyboard_interrupt
|
60
|
+
|
61
|
+
register_keyboard_interrupt()
|
62
|
+
try:
|
63
|
+
main()
|
64
|
+
except Exception as e:
|
65
|
+
exception_handler(e)
|