sonusai 1.0.16__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,363 @@
1
+ """sonusai onnx_predict
2
+
3
+ usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
4
+
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
9
+ --include GLOB Search only files whose base name matches GLOB. [default: *.{wav,flac}].
10
+ -w, --write-wav Calculate inverse transform of prediction and write .wav files
11
+
12
+ Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or audio files from a glob path.
13
+ The ONNX Runtime (ort) inference engine is used to execute the inference.
14
+
15
+ Inputs:
16
+ MODEL ONNX model .onnx file of a trained model (weights are expected to be in the file).
17
+ The model must also include required Sonusai hyperparameters. See the Sonusai torchl_onnx command.
18
+
19
+ DATA A string which must be one of the following:
20
+ 1. Path to a single file. The prediction data is written to <filename_predict.*> in same location.
21
+ 2. Path to a Sonusai Mixture Database directory.
22
+ - Sonusai mixture database directory, prediction files will be named mixid_predict.*
23
+ - MIXID will select a subset of mixture ids
24
+ 3. Directory with audio files found recursively within. See GLOB audio file extensions below.
25
+ 4. Regex resolving to a list of files.
26
+ - Subdirectory containing audio files with extension
27
+ - Regex resolving to a list of audio files
28
+
29
+ generate feature and truth data if not found.
30
+
31
+ Note there are multiple ways to process model prediction over multiple audio data files:
32
+ 1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model run as
33
+ a single inference call. If batch_size is > 1 then run multiple mixtures in one call with shorter mixtures
34
+ zero-padded to the size of the largest mixture.
35
+ 2. TME (timestep multi-extension): mixture is split into multiple timesteps, i.e. batch[0] is starting timesteps, ...
36
+ Note that batches are run independently, thus sequential state from one set of timesteps to the next will not be
37
+ maintained, thus results for such models (i.e. conv, LSTMs, in the timestep dimension) would not match using
38
+ TSE mode.
39
+
40
+ TBD: not sure the modes below make sense; this list needs further work.
41
+ 2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This makes sense only if
42
+ independent predictions are made on each frame w/o considering previous frames (timesteps=1) or there is no
43
+ timestep dimension in the model (timesteps=0).
44
+ 3. Classification
45
+
46
+ Outputs the following to opredict-<TIMESTAMP> directory:
47
+ <id>
48
+ predict.h5
49
+ onnx_predict.log
50
+
51
+ """
52
+
53
+
54
+ def process_path(path, ext_list: list[str] | None = None):
55
+ """
56
+ Check path which can be a single file, a subdirectory, or a regex
57
+ return:
58
+ - a list of files with matching extensions to any in ext_list provided (i.e. ['.wav', '.mp3', '.acc'])
59
+ - the basedir of the path, if
60
+ """
61
+ import glob
62
+ from os.path import abspath
63
+ from os.path import commonprefix
64
+ from os.path import dirname
65
+ from os.path import isdir
66
+ from os.path import isfile
67
+ from os.path import join
68
+
69
+ from sonusai.utils.braced_glob import braced_iglob
70
+
71
+ if ext_list is None:
72
+ ext_list = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]
73
+
74
+ # Check if the path is a single file, and return it as a list with the dirname
75
+ if isfile(path):
76
+ if any(path.endswith(ext) for ext in ext_list):
77
+ basedir = dirname(path) # base directory
78
+ if not basedir:
79
+ basedir = "./"
80
+ return [path], basedir
81
+ else:
82
+ return [], []
83
+
84
+ # Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
85
+ if isdir(path):
86
+ matching_files = []
87
+ for ext in ext_list:
88
+ matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
89
+ return matching_files, path
90
+
91
+ # Process as a regex, return list of filenames and basedir
92
+ apath = abspath(path) # join(abspath(path), "**", "*.{wav,flac,WAV}")
93
+ matching_files = []
94
+ for file in braced_iglob(pathname=apath, recursive=True):
95
+ matching_files.append(file)
96
+ if matching_files:
97
+ basedir = commonprefix(matching_files) # Find basedir
98
+ return matching_files, basedir
99
+ else:
100
+ return [], []
101
+
102
+
103
def main() -> None:
    """Command-line entry point: run ONNX Runtime inference on SonusAI mixtures.

    Parses the module usage text with docopt, prints a short hardware summary,
    loads the ONNX model, then writes a ``predict`` dataset (and optionally a
    ``*_predict.wav`` file) for each selected mixture of the mixture database.
    Inference on plain audio files is not implemented yet: such inputs are located
    and then reported as unsupported.

    :raises SystemExit: If the model path, the data paths or the model itself are unusable
    """
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    verbose = args["--verbose"]
    wav = args["--write-wav"]
    mixids = args["--mixid"]
    model_path = args["MODEL"]
    data_paths = args["DATA"]
    # NOTE: --include is accepted on the command line but is not used by this function.

    # Quick check of CPU and GPU devices
    import re
    import subprocess
    import time
    from os import makedirs
    from os.path import basename
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import normpath
    from os.path import splitext

    import h5py
    import numpy as np
    import onnxruntime as ort
    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_audio_from_feature
    from sonusai.utils.create_ts_name import create_ts_name
    from sonusai.utils.onnx_utils import load_ort_session
    from sonusai.utils.seconds_to_hms import seconds_to_hms
    from sonusai.utils.write_audio import write_audio

    num_cpu = psutil.cpu_count()
    cpu_percent = psutil.cpu_percent(interval=1)
    print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    print(f"Memory utilization: {psutil.virtual_memory().percent}%")

    # The device listing is informational only, so a missing or failing lspci must not
    # abort the run. Filtering in Python also avoids a shell pipeline whose grep stage
    # exits non-zero (raising CalledProcessError) whenever no VGA device is present.
    try:
        lspci_output = subprocess.run(["lspci"], capture_output=True, check=True).stdout.decode(errors="replace")
    except (OSError, subprocess.CalledProcessError):
        lspci_output = ""
    vga_devices = [line.split(" ", 3)[-1] for line in lspci_output.splitlines() if "vga" in line.lower()]
    nv_devs = [device for device in vga_devices if "nvidia" in device.lower()]
    # Model name is the first bracketed part of the device description, if any
    nv_mods = [found.group(0) if (found := re.search(r"\[.*?\]", device)) else None for device in nv_devs]
    if nv_mods:
        print(f"{len(nv_mods)} Nvidia devices present: {nv_mods}")  # prints model names
    else:
        print("No cuda devices present, using cpu")

    avail_providers = ort.get_available_providers()
    print(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
    # Only suggest onnxruntime-gpu when an Nvidia device exists but the CUDA provider is missing
    if nv_mods and "CUDAExecutionProvider" not in avail_providers:
        print(
            "If GPU is desired, need to replace onnxruntime with onnxruntime-gpu i.e. using pip:\n"
            "> pip uninstall onnxruntime\n"
            "> pip install onnxruntime-gpu\n\n"
        )

    # Quick check that model is valid
    if not isfile(model_path):
        print(f"Error: model file path is not valid: {model_path}")
        raise SystemExit(1)
    try:
        ort.InferenceSession(model_path)
    except Exception as e:
        print(f"Error: could not load ONNX model from {model_path}: {e}")
        raise SystemExit(1) from e

    # Check datapath is valid. Everything below is initialised up front so that the
    # audio-file path (no mixture database) never touches an unbound name.
    mixdb_path = None
    mixdb = None
    p_mixids: list[int] = []
    pfiles: list[str] = []
    output_dir = None
    num_featparams = None
    feature_mode = None

    if len(data_paths) == 1 and isdir(data_paths[0]):  # Try opening as mixdb subdir
        try:
            mixdb = MixtureDatabase(data_paths[0])
        except Exception:
            # Not a mixture database; fall through and treat DATA as audio file location(s)
            mixdb = None
        else:
            mixdb_path = data_paths[0]
            in_basename = basename(normpath(mixdb_path))
            output_dir = create_ts_name("opredict-" + in_basename)
            num_featparams = mixdb.feature_parameters
            print(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
            p_mixids = mixdb.mixids_to_list(mixids)
            feature_mode = mixdb.feature

    if mixdb_path is None:
        if verbose:
            print(f"Checking {len(data_paths)} locations ... ")
        # Check each location, default ext are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac'].
        # process_path() takes a single path, so DATA is resolved one entry at a time.
        for data_path in data_paths:
            found_files, basedir = process_path(data_path)
            if found_files:
                pfiles.extend(found_files)
                if output_dir is None:
                    output_dir = basedir  # the first location with audio files receives the log
        if not pfiles:
            print(f"No audio files or Sonusai mixture database found in {data_paths}, exiting ...")
            raise SystemExit(1)
        pfiles = sorted(pfiles, key=basename)

    if mixdb_path is not None or len(pfiles) > 1:  # log file only if mixdb or more than one file
        makedirs(output_dir, exist_ok=True)
        # Setup logging file
        create_file_handler(join(output_dir, "onnx-predict.log"), verbose)
        update_console_handler(verbose)
        initial_log_messages("onnx_predict")
        # print some previous messages
        logger.info(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
        if mixdb is not None:
            logger.debug(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
            if len(p_mixids) != mixdb.num_mixtures:
                logger.info(f"Processing a subset of {len(p_mixids)} from available mixtures.")

    # Reload model/session and do more thorough checking
    session, _options, _model_root, hparams, sess_inputs, _sess_outputs = load_ort_session(model_path)
    if "CUDAExecutionProvider" in avail_providers:
        session.set_providers(["CUDAExecutionProvider"])
    if hparams is None:
        logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
        raise SystemExit(1)

    if len(sess_inputs) != 1:  # TBD update to support state_in and state_out
        logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")
        raise SystemExit(1)

    in0name = sess_inputs[0].name
    in0type = sess_inputs[0].type
    in0shape = sess_inputs[0].shape  # a list
    # Check for 2 cases of model feature input shape: batch x timesteps x fparams or batch x fparams.
    # A string entry in the shape is treated as a dynamic dimension. The values are
    # derived for the debug log (and to fail early on an unexpected shape) only.
    if not isinstance(in0shape[0], str):
        model_batchsz = int(in0shape[0])
        logger.debug(f"Onnx model has fixed batch_size: {model_batchsz}.")
    else:
        model_batchsz = -1
        logger.debug("Onnx model has a dynamic batch_size.")

    if len(in0shape) < 3:
        model_tsteps = 0
        model_featparams = int(in0shape[1])
    else:
        model_featparams = int(in0shape[2])
        if not isinstance(in0shape[1], str):
            model_tsteps = int(in0shape[1])
            logger.debug(f"Onnx model has fixed timesteps: {model_tsteps}.")
        else:
            model_tsteps = -1
            logger.debug("Onnx model has dynamic timesteps dimension size.")

    out_names = [n.name for n in session.get_outputs()]

    model_is_fp16 = "float16" in in0type
    if model_is_fp16:
        logger.info("Detected input of float16, converting all feature inputs to that type.")

    logger.info(f"Read and compiled ONNX model from {model_path}.")

    start_time = time.monotonic()

    if mixdb is not None and hparams["batch_size"] == 1:
        if hparams["feature"] != feature_mode:  # warn on mis-match, but TBD could be sov-mode
            logger.warning("Mixture feature does not match model feature, this inference run may fail.")
        logger.info(f"Processing {len(p_mixids)} mixtures from SonusAI mixdb ...")
        logger.info(f"Using OnnxRT provider {session.get_providers()} ...")

        for mixid in p_mixids:
            # feature data is now always fp32 and frames x stride x feature_params
            feat_dat, _ = mixdb.mixture_ft(mixid)
            if hparams["timesteps"] == 0:
                # no timestep dimension, remove the dimension: frames become the batch.
                # The frame count is read directly so this also works when stride == 1.
                feat_dat = np.reshape(feat_dat, [feat_dat.shape[0], num_featparams])
            else:
                # fit frames into timestep dimension (TSE mode) and knowing batch_size = 1
                feat_dat = np.transpose(feat_dat, (1, 0, 2))  # transpose to 1 x frames=tsteps x feat_params

            if model_is_fp16:
                feat_dat = np.float16(feat_dat)  # type: ignore[assignment]

            # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
            predict = session.run(out_names, {in0name: feat_dat})[0]

            output_fname = join(output_dir, mixdb.mixture(mixid).name)
            with h5py.File(output_fname + ".h5", "a") as f:
                # Replace any prediction left over from an earlier run
                if "predict" in f:
                    del f["predict"]
                f.create_dataset("predict", data=predict)
            if wav:
                # note only makes sense if model is predicting audio, i.e., timestep dimension exists
                # predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
                predict = np.transpose(predict, [1, 0, 2])
                predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
                owav_name = splitext(output_fname)[0] + "_predict.wav"
                write_audio(owav_name, predict_audio)

    else:  # TBD add support
        logger.info("Mixture database does not exist or batch_size is not equal to one, exiting ...")

    end_time = time.monotonic()
    logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
    logger.info("")
317
+
318
if __name__ == "__main__":
    # Script entry: register the keyboard-interrupt hook first, then run main()
    # and hand any exception it raises to the package-level exception handler.
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as err:
        exception_handler(err)
327
+
328
+ # mixdb_path = None
329
+ # mixdb: MixtureDatabase | None = None
330
+ # p_mixids: list[int] = []
331
+ # entries: list[PathInfo] = []
332
+ #
333
+ # if len(data_paths) == 1 and isdir(data_paths[0]):
334
+ # # Assume it's a single path to SonusAI mixdb subdir
335
+ # in_basename = basename(normpath(data_paths[0]))
336
+ # mixdb_path = data_paths[0]
337
+ # else:
338
+ # # search all data paths for .wav, .flac (or whatever is specified in include)
339
+ # in_basename = ""
340
+
341
+ # if mixdb_path is not None: # a mixdb is found and loaded
342
+ # # Assume it's a single path to SonusAI mixdb subdir
343
+ # num_featparams = mixdb.feature_parameters
344
+ # logger.debug(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {num_featparams} classes")
345
+ # p_mixids = mixdb.mixids_to_list(mixids)
346
+ # if len(p_mixids) != mixdb.num_mixtures:
347
+ # logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
348
+ # else:
349
+ # for p in data_paths:
350
+ # location = join(realpath(abspath(p)), "**", include)
351
+ # logger.debug(f"Processing files in {location}")
352
+ # for file in braced_iglob(pathname=location, recursive=True):
353
+ # name = file
354
+ # entries.append(PathInfo(abs_path=file, audio_filepath=name))
355
+ # logger.info(f"{len(data_paths)} data paths specified, found {len(pfile)} audio files.")
356
+
357
+ # feature, _ = reshape_inputs(
358
+ # feature=feature,
359
+ # batch_size=1,
360
+ # timesteps=timesteps,
361
+ # flatten=hparams["flatten"],
362
+ # add1ch=hparams["add1ch"],
363
+ # )
File without changes
@@ -0,0 +1,156 @@
1
+ """
2
+ Parse 'expand' expressions.
3
+
4
+ This module provides functionality to find, parse and evaluate 'expand'
5
+ expressions in text, supporting nested expressions and random value generation.
6
+ """
7
+
8
+ import re
9
+ from dataclasses import dataclass
10
+
11
+ import pyparsing as pp
12
+
13
+ # Constants
14
+ SAI_EXPAND_PATTERN = r"expand\("
15
+ SAI_RAND_LITERAL = "rand"
16
+ SAI_EXPAND_LITERAL = "expand"
17
+
18
+
19
@dataclass
class Match:
    """One 'expand' expression located in a piece of text.

    ``group`` holds the matched text and ``span`` the (start, end) positions of
    that text, mirroring the accessors of an ``re.Match`` object.
    """

    group: str
    span: tuple[int, int]

    def start(self) -> int:
        """Return the position at which the match begins."""
        begin, _ = self.span
        return begin

    def end(self) -> int:
        """Return the position just past the end of the match."""
        _, finish = self.span
        return finish
33
+
34
+
35
def find_matching_parenthesis(text: str, start_pos: int) -> int:
    """Locate the parenthesis that closes an already-opened one.

    :param text: The text to scan
    :param start_pos: Position just after the opening parenthesis
    :return: Position just after the matching closing parenthesis
    :raises ValueError: If the text ends before the parenthesis is closed
    """
    depth = 1  # the caller has already consumed one opening parenthesis

    for index in range(start_pos, len(text)):
        char = text[index]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                # Report the position after the closing parenthesis
                return index + 1

    raise ValueError(f"Unbalanced parenthesis in '{text}'")
57
+
58
+
59
def find_expand(text: str) -> list[Match]:
    """Collect every 'expand' expression that occurs in the text.

    Matches are reported in order of their start position; a nested expression
    therefore follows the expression that contains it.

    :param text: The text to search in
    :return: One Match per 'expand(' occurrence, spanning through its closing parenthesis
    :raises ValueError: If parentheses are unbalanced
    """

    def _to_match(hit: re.Match) -> Match:
        # The regex only finds 'expand('; extend the span to the balancing ')'
        begin = hit.start()
        finish = find_matching_parenthesis(text, hit.end())
        return Match(group=text[begin:finish], span=(begin, finish))

    return [_to_match(hit) for hit in re.finditer(SAI_EXPAND_PATTERN, text)]
75
+
76
+
77
def create_parser() -> pp.ParserElement:
    """Build the pyparsing grammar for a single 'expand' expression.

    The grammar accepts ``expand(<arg>, <arg>, ...)`` where each argument is either
    a ``rand(<number>, <number>)`` call or a plain identifier. The argument list is
    exposed under the results name ``args``.

    :return: Parser for 'expand' expressions
    """
    common = pp.pyparsing_common

    # Punctuation
    open_paren = pp.Literal("(")
    close_paren = pp.Literal(")")
    separator = pp.Literal(",")

    # Numbers: try a real first, then fall back to a signed integer
    numeric = common.real | common.signed_integer

    # Plain arguments: letters, digits and the characters '_', '.', '-'
    name = pp.Word(pp.alphanums + "_.-")

    # rand(<number>, <number>) is glued back into a single string token so that it
    # is passed through as one argument
    rand_call = pp.Literal(SAI_RAND_LITERAL) + open_paren + numeric + separator + numeric + close_paren
    rand_call = rand_call.set_parse_action(lambda toks: "".join(map(str, toks)))

    # expand(<arg>[, <arg>]...)
    arguments = pp.DelimitedList(rand_call | name, min=1)
    return pp.Literal(SAI_EXPAND_LITERAL) + open_paren + arguments("args") + close_paren
106
+
107
+
108
def parse_expand(text: str) -> list[str]:
    """Return the arguments of an 'expand' expression.

    :param text: Text that starts with an 'expand' expression
    :return: The expression's arguments, in order
    :raises ValueError: If the text does not parse as an 'expand' expression
    """
    grammar = create_parser()

    try:
        parsed = grammar.parse_string(text)
    except pp.ParseException as err:
        # Surface grammar failures as a plain ValueError, keeping the cause attached
        raise ValueError(f"Could not parse '{text}'") from err

    return list(parsed.args)
122
+
123
+
124
def expand(directive: str) -> list[str]:
    """Expand every 'expand' expression in a directive.

    Each call substitutes the last 'expand' expression found in the text (which
    cannot contain a further one) with each of its arguments, then recurses on the
    resulting strings until none is left.

    :param directive: Directive to evaluate
    :return: A list of the expanded results; the directive itself when it holds no 'expand'
    """
    found = find_expand(directive)

    # Nothing to expand: the directive is its own single result
    if not found:
        return [directive]

    # The match that starts last has no 'expand' inside it
    innermost = found[-1]
    head = directive[: innermost.start()]
    tail = directive[innermost.end() :]

    results: list[str] = []
    for choice in parse_expand(innermost.group):
        # Substitute one argument and keep expanding what remains
        results += expand(head + choice + tail)

    return results
@@ -0,0 +1,129 @@
1
+ from dataclasses import dataclass
2
+ from dataclasses import fields
3
+ from typing import Any
4
+
5
+ from pyparsing import Literal
6
+ from pyparsing import Optional
7
+ from pyparsing import ParseException
8
+ from pyparsing import ParseResults
9
+ from pyparsing import QuotedString
10
+ from pyparsing import Suppress
11
+ from pyparsing import Word
12
+ from pyparsing import ZeroOrMore
13
+ from pyparsing import alphanums
14
+ from pyparsing import alphas
15
+ from pyparsing import pyparsing_common
16
+
17
+
18
+ @dataclass
19
+ class SourceDirective:
20
+ """Represents a parsed source directive with its parameters."""
21
+
22
+ unique: str | None = None
23
+ repeat: bool = False
24
+ loop: bool = False
25
+ start: int = 0
26
+
27
+
28
def parse_source_directive(directive: str) -> SourceDirective:
    """Turn a source directive string into a SourceDirective.

    Accepted forms:
    - choose(unique=None, repeat=False, loop=False, start=0)
    - sequence(unique=None, loop=False, start=0)
    - the bare words ``choose`` or ``sequence``, which yield all defaults

    :param directive: The directive string to parse
    :return: SourceDirective with parsed parameters
    :raises ValueError: If the directive format is invalid
    """
    if not _is_simple_directive(directive):
        # Full form: parse the parameter list and apply it over the defaults
        tokens = _parse_directive_grammar(directive)
        return SourceDirective(**_process_parsed_parameters(tokens, directive))

    # Bare 'choose' / 'sequence': nothing to override
    return SourceDirective()
48
+
49
+
50
def _get_valid_parameters() -> set[str]:
    """Return the names of the SourceDirective dataclass fields."""
    names: set[str] = set()
    for spec in fields(SourceDirective):
        names.add(spec.name)
    return names
53
+
54
+
55
def _is_simple_directive(directive: str) -> bool:
    """Check if the directive is just a function name without parentheses.

    :param directive: The directive string to test
    :return: True when the whole string parses as ``choose`` or ``sequence``
    """
    directive_type = Literal("choose") | Literal("sequence")
    try:
        # PEP 8 spelling of the pyparsing API (parse_string/parse_all), matching the
        # rest of the package; the camelCase synonyms are deprecated in pyparsing 3.
        directive_type.parse_string(directive, parse_all=True)
    except ParseException:
        return False
    return True
63
+
64
+
65
def _parse_directive_grammar(directive: str) -> ParseResults:
    """Parse a parenthesised directive and return its flat token list.

    The directive name, parentheses, '=' and ',' are suppressed, so the result is a
    flat sequence ``name, value, name, value, ...``. Integers arrive as ``int``;
    every other value arrives as ``str`` (quoted strings without their quotes).

    :param directive: The directive string to parse
    :return: Parsed tokens as alternating parameter names and values
    :raises ValueError: If the directive does not match the grammar
    """
    # Define grammar components
    directive_type = Literal("choose") | Literal("sequence")
    identifier = Word(alphas + "_", alphanums + "_")

    # Value types
    none_value = Literal("None")
    true_value = Literal("True")
    false_value = Literal("False")
    integer = pyparsing_common.signed_integer()
    quoted_string = QuotedString('"', escChar="\\") | QuotedString("'", escChar="\\")
    non_quoted_string = Word(alphanums + "_-./")
    rand_value = Literal("rand")

    # Combined value and parameter grammar; alternatives are tried left to right
    value = none_value | true_value | false_value | integer | quoted_string | non_quoted_string | rand_value
    parameter = identifier + Suppress("=") + value
    # Zero or more parameters, with an optional trailing comma
    param_list = Optional(parameter + ZeroOrMore(Suppress(",") + parameter) + Optional(Suppress(",")))
    directive_expr = Suppress(directive_type) + Suppress("(") + param_list + Suppress(")")

    try:
        # PEP 8 spelling of the pyparsing API (parse_string/parse_all), matching the
        # rest of the package; the camelCase synonyms are deprecated in pyparsing 3.
        return directive_expr.parse_string(directive, parse_all=True)
    except ParseException as e:
        raise ValueError(f"Invalid directive format: '{directive}'. Error: {e}") from e
90
+
91
+
92
def _process_parsed_parameters(parsed_tokens: ParseResults, directive: str) -> dict:
    """Build the keyword dictionary for SourceDirective from parsed tokens.

    :param parsed_tokens: Flat token sequence of alternating parameter names and values
    :param directive: The original directive text, used in error messages
    :return: Mapping of parameter name to converted value; a repeated name keeps its last value
    :raises ValueError: If a parameter name is not a SourceDirective field
    """
    allowed = _get_valid_parameters()
    tokens = list(parsed_tokens)
    converted = {}

    # Tokens come in (name, value) pairs
    for offset in range(0, len(tokens), 2):
        name, raw_value = tokens[offset], tokens[offset + 1]
        _validate_parameter_name(name, directive, allowed)
        converted[name] = _convert_parameter_value(raw_value)

    return converted
105
+
106
+
107
+ def _validate_parameter_name(param_name: Any, directive: str, valid_params: set[str]) -> None:
108
+ """Validate that parameter name is allowed."""
109
+ if param_name not in valid_params:
110
+ raise ValueError(
111
+ f"Invalid directive format: '{directive}'. Error: parameter must be one of {', '.join(sorted(valid_params))}."
112
+ )
113
+
114
+
115
+ def _convert_parameter_value(param_value):
116
+ """Convert string representations to appropriate Python types."""
117
+ if param_value == "None":
118
+ return None
119
+ elif param_value in ("True", "true", "Yes", "yes"):
120
+ return True
121
+ elif param_value in ("False", "false", "No", "no"):
122
+ return False
123
+ elif param_value == "rand":
124
+ return "rand"
125
+ elif isinstance(param_value, int):
126
+ return param_value
127
+ else:
128
+ # String value (already unquoted by pyparsing)
129
+ return param_value