sonusai 1.0.16__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/onnx_predict.py
ADDED
@@ -0,0 +1,363 @@
"""sonusai onnx_predict

usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...

options:
    -h, --help
    -v, --verbose               Be verbose.
    -i MIXID, --mixid MIXID     Mixture ID(s) to generate if input is a mixture database. [default: *].
    --include GLOB              Search only files whose base name matches GLOB. [default: *.{wav,flac}].
    -w, --write-wav             Calculate inverse transform of prediction and write .wav files

Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or audio files from a glob path.
The ONNX Runtime (ort) inference engine is used to execute the inference.

Inputs:
    MODEL   ONNX model .onnx file of a trained model (weights are expected to be in the file).
            The model must also include required Sonusai hyperparameters. See the Sonusai torchl_onnx command.

    DATA    A string which must be one of the following:
            1. Path to a single file. The prediction data is written to <filename_predict.*> in same location.
            2. Path to a Sonusai Mixture Database directory.
               - Sonusai mixture database directory, prediction files will be named mixid_predict.*
               - MIXID will select a subset of mixture ids
            3. Directory with audio files found recursively within. See GLOB audio file extensions below.
            4. Regex resolving to a list of files.
               - Subdirectory containing audio files with extension
               - Regex resolving to a list of audio files

            generate feature and truth data if not found.

Note there are multiple ways to process model prediction over multiple audio data files:
1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model run as
   a single inference call. If batch_size is > 1 then run multiple mixtures in one call with shorter mixtures
   zero-padded to the size of the largest mixture.
2. TME (timestep multi-extension): mixture is split into multiple timesteps, i.e. batch[0] is starting timesteps, ...
   Note that batches are run independently, thus sequential state from one set of timesteps to the next will not be
   maintained, thus results for such models (i.e. conv, LSTMs, in the timestep dimension) would not match using
   TSE mode.

TBD not sure below make sense, need to continue ??
2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This makes sense only if
   independent predictions are made on each frame w/o considering previous frames (timesteps=1) or there is no
   timestep dimension in the model (timesteps=0).
3. Classification

Outputs the following to opredict-<TIMESTAMP> directory:
    <id>
        predict.h5
    onnx_predict.log

"""


def process_path(path, ext_list: list[str] | None = None):
    """
    Check path which can be a single file, a subdirectory, or a regex
    return:
      - a list of files with matching extensions to any in ext_list provided (i.e. ['.wav', '.mp3', '.aac'])
      - the basedir of the path, if
    """
    import glob
    from os.path import abspath
    from os.path import commonprefix
    from os.path import dirname
    from os.path import isdir
    from os.path import isfile
    from os.path import join

    from sonusai.utils.braced_glob import braced_iglob

    if ext_list is None:
        ext_list = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]

    # Check if the path is a single file, and return it as a list with the dirname
    if isfile(path):
        if any(path.endswith(ext) for ext in ext_list):
            basedir = dirname(path)  # base directory
            if not basedir:
                basedir = "./"
            return [path], basedir
        else:
            return [], []

    # Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
    if isdir(path):
        matching_files = []
        for ext in ext_list:
            matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
        return matching_files, path

    # Process as a regex, return list of filenames and basedir
    apath = abspath(path)  # join(abspath(path), "**", "*.{wav,flac,WAV}")
    matching_files = []
    for file in braced_iglob(pathname=apath, recursive=True):
        matching_files.append(file)
    if matching_files:
        basedir = commonprefix(matching_files)  # Find basedir
        return matching_files, basedir
    else:
        return [], []


def main() -> None:
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    verbose = args["--verbose"]
    wav = args["--write-wav"]
    mixids = args["--mixid"]
    include = args["--include"]
    model_path = args["MODEL"]
    data_paths = args["DATA"]

    # Quick check of CPU and GPU devices
    import re
    import subprocess
    import time
    from os import makedirs
    from os.path import basename
    from os.path import exists
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import normpath
    from os.path import splitext

    import h5py
    import numpy as np
    import onnxruntime as ort
    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_audio_from_feature
    from sonusai.utils.create_ts_name import create_ts_name
    from sonusai.utils.onnx_utils import load_ort_session
    from sonusai.utils.seconds_to_hms import seconds_to_hms
    from sonusai.utils.write_audio import write_audio

    num_cpu = psutil.cpu_count()
    cpu_percent = psutil.cpu_percent(interval=1)
    print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    print(f"Memory utilization: {psutil.virtual_memory().percent}%")

    vga_devices = [
        line.split(" ", 3)[-1]
        for line in subprocess.check_output("lspci | grep -i vga", shell=True).decode().splitlines()
    ]
    nv_devs = list(filter(lambda x: "nvidia" in x.lower(), vga_devices))
    nv_mods = [re.search(r"\[.*?\]", device).group(0) if re.search(r"\[.*?\]", device) else None for device in nv_devs]
    if len(nv_mods) > 0:
        print(f"{len(nv_mods)} Nvidia devices present: {nv_mods}")  # prints model names
    else:
        print("No cuda devices present, using cpu")

    avail_providers = ort.get_available_providers()
    print(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
    if len(nv_mods) > 0:
        print(
            "If GPU is desired, need to replace onnxruntime with onnxruntime-gpu i.e. using pip:"
            "> pip uninstall onnxruntime"
            "> pip install onnxruntime-gpu\n\n"
        )

    # Quick check that model is valid
    if exists(model_path) and isfile(model_path):
        try:
            session = ort.InferenceSession(model_path)
            options = ort.SessionOptions()
        except Exception as e:
            print(f"Error: could not load ONNX model from {model_path}: {e}")
            raise SystemExit(1) from e
    else:
        print(f"Error: model file path is not valid: {model_path}")
        raise SystemExit(1)

    # Check datapath is valid
    if len(data_paths) == 1 and isdir(data_paths[0]):  # Try opening as mixdb subdir
        mixdb_path = data_paths[0]
        try:
            mixdb = MixtureDatabase(mixdb_path)
        except Exception:
            mixdb_path = None
        in_basename = basename(normpath(data_paths[0]))
        output_dir = create_ts_name("opredict-" + in_basename)
        num_featparams = mixdb.feature_parameters
        print(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
        p_mixids = mixdb.mixids_to_list(mixids)
        feature_mode = mixdb.feature

    if mixdb_path is None:
        if verbose:
            print(f"Checking {len(data_paths)} locations ... ")
        # Check location, default ext are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac']
        pfiles, basedir = process_path(data_paths)
        if pfiles is None or len(pfiles) < 1:
            print(f"No audio files or Sonusai mixture database found in {data_paths}, exiting ...")
            raise SystemExit(1)
        else:
            pfiles = sorted(pfiles, key=basename)
            output_dir = basedir

    if mixdb_path is not None or len(pfiles) > 1:  # log file only if mixdb or more than one file
        makedirs(output_dir, exist_ok=True)
        # Setup logging file
        create_file_handler(join(output_dir, "onnx-predict.log"), verbose)
        update_console_handler(verbose)
        initial_log_messages("onnx_predict")
        # print some previous messages
        logger.info(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
        if mixdb_path:
            logger.debug(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
            if len(p_mixids) != mixdb.num_mixtures:
                logger.info(f"Processing a subset of {len(p_mixids)} from available mixtures.")

    # Reload model/session and do more thorough checking
    session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
    if "CUDAExecutionProvider" in avail_providers:
        session.set_providers(["CUDAExecutionProvider"])
    if hparams is None:
        logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
        raise SystemExit(1)

    if len(sess_inputs) != 1:  # TBD update to support state_in and state_out
        logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")

    in0name = sess_inputs[0].name
    in0type = sess_inputs[0].type
    in0shape = sess_inputs[0].shape  # a list
    # Check for 2 cases of model feature input shape: batch x timesteps x fparams or batch x fparams
    if not isinstance(in0shape[0], str):
        model_batchsz = int(in0shape[0])
        logger.debug(f"Onnx model has fixed batch_size: {model_batchsz}.")
    else:
        model_batchsz = -1
        logger.debug("Onnx model has a dynamic batch_size.")

    if len(in0shape) < 3:
        model_tsteps = 0
        model_featparams = int(in0shape[1])
    else:
        model_featparams = int(in0shape[2])
        if not isinstance(in0shape[1], str):
            model_tsteps = int(in0shape[1])
            logger.debug(f"Onnx model has fixed timesteps: {model_tsteps}.")
        else:
            model_tsteps = -1
            logger.debug("Onnx model has dynamic timesteps dimension size.")

    out_names = [n.name for n in session.get_outputs()]

    if in0type.find("float16") != -1:
        model_is_fp16 = True
        logger.info("Detected input of float16, converting all feature inputs to that type.")
    else:
        model_is_fp16 = False

    logger.info(f"Read and compiled ONNX model from {model_path}.")

    start_time = time.monotonic()

    if mixdb is not None and hparams["batch_size"] == 1:
        if hparams["feature"] != feature_mode:  # warn on mis-match, but TBD could be sov-mode
            logger.warning("Mixture feature does not match model feature, this inference run may fail.")
        logger.info(f"Processing {len(p_mixids)} mixtures from SonusAI mixdb ...")
        logger.info(f"Using OnnxRT provider {session.get_providers()} ...")

        for mixid in p_mixids:
            # feature data is now always fp32 and frames x stride x feature_params
            feat_dat, _ = mixdb.mixture_ft(mixid)
            if feat_dat.shape[1] > 1:  # stride mode num_frames overrides batch dim, no reshape
                stride_mode = 1
                batch_size = feat_dat.shape[0]  # num_frames in stride mode becomes batch size
            if hparams["timesteps"] == 0:
                # no timestep dimension, remove the dimension
                timesteps = 0
                feat_dat = np.reshape(feat_dat, [batch_size, num_featparams])
            else:
                # fit frames into timestep dimension (TSE mode) and knowing batch_size = 1
                timesteps = feat_dat.shape[0]
                feat_dat = np.transpose(feat_dat, (1, 0, 2))  # transpose to 1 x frames=tsteps x feat_params

            if model_is_fp16:
                feat_dat = np.float16(feat_dat)  # type: ignore[assignment]

            # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
            predict = session.run(out_names, {in0name: feat_dat})[0]
            # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames)  # frames x feat_params

            output_fname = join(output_dir, mixdb.mixture(mixid).name)
            with h5py.File(output_fname + ".h5", "a") as f:
                if "predict" in f:
                    del f["predict"]
                f.create_dataset("predict", data=predict)
            if wav:
                # note only makes sense if model is predicting audio, i.e., timestep dimension exists
                # predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
                predict = np.transpose(predict, [1, 0, 2])
                predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
                owav_name = splitext(output_fname)[0] + "_predict.wav"
                write_audio(owav_name, predict_audio)

    else:  # TBD add support
        logger.info("Mixture database does not exist or batch_size is not equal to one, exiting ...")

    end_time = time.monotonic()
    logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
    logger.info("")


if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as e:
        exception_handler(e)

# mixdb_path = None
# mixdb: MixtureDatabase | None = None
# p_mixids: list[int] = []
# entries: list[PathInfo] = []
#
# if len(data_paths) == 1 and isdir(data_paths[0]):
#     # Assume it's a single path to SonusAI mixdb subdir
#     in_basename = basename(normpath(data_paths[0]))
#     mixdb_path = data_paths[0]
# else:
#     # search all data paths for .wav, .flac (or whatever is specified in include)
#     in_basename = ""

# if mixdb_path is not None:  # a mixdb is found and loaded
#     # Assume it's a single path to SonusAI mixdb subdir
#     num_featparams = mixdb.feature_parameters
#     logger.debug(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {num_featparams} classes")
#     p_mixids = mixdb.mixids_to_list(mixids)
#     if len(p_mixids) != mixdb.num_mixtures:
#         logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
# else:
#     for p in data_paths:
#         location = join(realpath(abspath(p)), "**", include)
#         logger.debug(f"Processing files in {location}")
#         for file in braced_iglob(pathname=location, recursive=True):
#             name = file
#             entries.append(PathInfo(abs_path=file, audio_filepath=name))
#     logger.info(f"{len(data_paths)} data paths specified, found {len(pfile)} audio files.")

# feature, _ = reshape_inputs(
#     feature=feature,
#     batch_size=1,
#     timesteps=timesteps,
#     flatten=hparams["flatten"],
#     add1ch=hparams["add1ch"],
# )
sonusai/parse/expand.py
ADDED
@@ -0,0 +1,156 @@
"""
Parse 'expand' expressions.

This module provides functionality to find, parse and evaluate 'expand'
expressions in text, supporting nested expressions and random value generation.
"""

import re
from dataclasses import dataclass

import pyparsing as pp

# Constants
SAI_EXPAND_PATTERN = r"expand\("
SAI_RAND_LITERAL = "rand"
SAI_EXPAND_LITERAL = "expand"


@dataclass
class Match:
    """Represents a matched 'expand' expression in text."""

    group: str
    span: tuple[int, int]

    def start(self) -> int:
        """Return the start position of the match."""
        return self.span[0]

    def end(self) -> int:
        """Return the end position of the match."""
        return self.span[1]


def find_matching_parenthesis(text: str, start_pos: int) -> int:
    """Find the position of the matching closing parenthesis.

    :param text: The text to search in
    :param start_pos: Position after the opening parenthesis
    :return: Position after the matching closing parenthesis
    :raises ValueError: If no matching parenthesis is found
    """
    num_lparen = 1
    pos = start_pos

    while num_lparen != 0 and pos < len(text):
        if text[pos] == "(":
            num_lparen += 1
        elif text[pos] == ")":
            num_lparen -= 1
        pos += 1

    if num_lparen != 0:
        raise ValueError(f"Unbalanced parenthesis in '{text}'")

    return pos


def find_expand(text: str) -> list[Match]:
    """Find all 'expand' expressions in the text.

    :param text: The text to search in
    :return: List of Match objects for each 'expand' expression
    :raises ValueError: If parentheses are unbalanced
    """
    results = []
    matches = re.finditer(SAI_EXPAND_PATTERN, text)

    for match in matches:
        start = match.start()
        end_pos = find_matching_parenthesis(text, match.end())
        results.append(Match(group=text[start:end_pos], span=(start, end_pos)))

    return results


def create_parser() -> pp.ParserElement:
    """Create a pyparsing parser for 'expand' expressions.

    :return: Parser for 'expand' expressions
    """
    lparen = pp.Literal("(")
    rparen = pp.Literal(")")
    comma = pp.Literal(",")

    # Define numeric types
    real_number = pp.pyparsing_common.real
    signed_integer = pp.pyparsing_common.signed_integer
    number = real_number | signed_integer

    # Define identifiers and expressions
    identifier = pp.Word(pp.alphanums + "_.-")

    # Define 'rand' expression
    rand_literal = pp.Literal(SAI_RAND_LITERAL)
    rand_expression = (rand_literal + lparen + number + comma + number + rparen).set_parse_action(
        lambda tokens: "".join(map(str, tokens))
    )

    # Define 'expand' expression
    expand_literal = pp.Literal(SAI_EXPAND_LITERAL)
    expand_args = pp.DelimitedList(rand_expression | identifier, min=1)
    expand_expression = expand_literal + lparen + expand_args("args") + rparen

    return expand_expression


def parse_expand(text: str) -> list[str]:
    """Parse an 'expand' expression and extract its arguments.

    :param text: Text containing an 'expand' expression
    :return: List of argument values
    :raises ValueError: If the expression cannot be parsed
    """
    parser = create_parser()

    try:
        result = parser.parse_string(text)
        return list(result.args)
    except pp.ParseException as e:
        raise ValueError(f"Could not parse '{text}'") from e


def expand(directive: str) -> list[str]:
    """Evaluate the 'expand' directive.

    Recursively processes and expands 'expand' expressions in the text,
    starting with the innermost expressions.

    :param directive: Directive to evaluate
    :return: A list of the expanded results
    """
    # Initialize with input
    expanded = [directive]

    # Look for 'expand' patterns
    matches = find_expand(directive)

    # If no pattern found, return the original text
    if not matches:
        return expanded

    # Remove the original text as we'll replace it with expanded versions
    expanded.pop()

    # Start with the innermost match (last in the list)
    match = matches[-1]
    prelude = directive[: match.start()]
    postlude = directive[match.end() :]

    # Process each value in the 'expand' expression
    for value in parse_expand(match.group):
        # Recursively expand the text with each replacement value
        expanded.extend(expand(prelude + value + postlude))

    return expanded
sonusai/parse/parse_source_directive.py
ADDED
@@ -0,0 +1,129 @@
from dataclasses import dataclass
from dataclasses import fields
from typing import Any

from pyparsing import Literal
from pyparsing import Optional
from pyparsing import ParseException
from pyparsing import ParseResults
from pyparsing import QuotedString
from pyparsing import Suppress
from pyparsing import Word
from pyparsing import ZeroOrMore
from pyparsing import alphanums
from pyparsing import alphas
from pyparsing import pyparsing_common


@dataclass
class SourceDirective:
    """Represents a parsed source directive with its parameters."""

    unique: str | None = None
    repeat: bool = False
    loop: bool = False
    start: int = 0


def parse_source_directive(directive: str) -> SourceDirective:
    """Parse a source directive into its components.

    Parses directives of the form:
    - choose(unique=None, repeat=False, loop=False, start=0)
    - sequence(unique=None, loop=False, start=0)

    :param directive: The directive string to parse
    :return: SourceDirective with parsed parameters
    :raises ValueError: If the directive format is invalid
    """
    # Check for a simple directive without parentheses
    if _is_simple_directive(directive):
        return SourceDirective()

    # Parse full directive with parameters
    parsed_tokens = _parse_directive_grammar(directive)
    params = _process_parsed_parameters(parsed_tokens, directive)

    return SourceDirective(**params)


def _get_valid_parameters() -> set[str]:
    """Get valid parameter names from SourceDirective dataclass fields."""
    return {field.name for field in fields(SourceDirective)}


def _is_simple_directive(directive: str) -> bool:
    """Check if the directive is just a function name without parentheses."""
    directive_type = Literal("choose") | Literal("sequence")
    try:
        directive_type.parseString(directive, parseAll=True)
    except ParseException:
        return False
    return True


def _parse_directive_grammar(directive: str) -> ParseResults:
    """Parse directive using pyparsing grammar and return tokens."""
    # Define grammar components
    directive_type = Literal("choose") | Literal("sequence")
    identifier = Word(alphas + "_", alphanums + "_")

    # Value types
    none_value = Literal("None")
    true_value = Literal("True")
    false_value = Literal("False")
    integer = pyparsing_common.signed_integer()
    quoted_string = QuotedString('"', escChar="\\") | QuotedString("'", escChar="\\")
    non_quoted_string = Word(alphanums + "_-./")
    rand_value = Literal("rand")

    # Combined value and parameter grammar
    value = none_value | true_value | false_value | integer | quoted_string | non_quoted_string | rand_value
    parameter = identifier + Suppress("=") + value
    param_list = Optional(parameter + ZeroOrMore(Suppress(",") + parameter) + Optional(Suppress(",")))
    directive_expr = Suppress(directive_type) + Suppress("(") + param_list + Suppress(")")

    try:
        return directive_expr.parseString(directive, parseAll=True)
    except ParseException as e:
        raise ValueError(f"Invalid directive format: '{directive}'. Error: {e}") from e


def _process_parsed_parameters(parsed_tokens: ParseResults, directive: str) -> dict:
    """Convert parsed tokens to a parameter dictionary with type conversion."""
    params = {}
    valid_params = _get_valid_parameters()

    for i in range(0, len(parsed_tokens), 2):
        param_name = parsed_tokens[i]
        param_value = parsed_tokens[i + 1]

        _validate_parameter_name(param_name, directive, valid_params)
        params[param_name] = _convert_parameter_value(param_value)

    return params


def _validate_parameter_name(param_name: Any, directive: str, valid_params: set[str]) -> None:
    """Validate that parameter name is allowed."""
    if param_name not in valid_params:
        raise ValueError(
            f"Invalid directive format: '{directive}'. Error: parameter must be one of {', '.join(sorted(valid_params))}."
        )


def _convert_parameter_value(param_value):
    """Convert string representations to appropriate Python types."""
    if param_value == "None":
        return None
    elif param_value in ("True", "true", "Yes", "yes"):
        return True
    elif param_value in ("False", "false", "No", "no"):
        return False
    elif param_value == "rand":
        return "rand"
    elif isinstance(param_value, int):
        return param_value
    else:
        # String value (already unquoted by pyparsing)
        return param_value