sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py CHANGED
@@ -14,16 +14,19 @@ The ONNX Runtime (ort) inference engine is used to execute the inference.
14
14
 
15
15
  Inputs:
16
16
  MODEL ONNX model .onnx file of a trained model (weights are expected to be in the file).
17
+ The model must also include required Sonusai hyperparameters. See theSonusai torchl_onnx command.
17
18
 
18
- DATA The input data must be one of the following:
19
- * WAV
20
- Using the given model, generate feature data and run prediction. A model file must be
21
- provided. The MIXID is ignored.
22
-
23
- * directory
24
- Using the given SonusAI mixture database directory, generate feature and truth data if not found.
25
- Run prediction. The MIXID is required.
19
+ DATA A string which must be one of the following:
20
+ 1. Path to a single file. The prediction data is written to <filename_predict.*> in same location.
21
+ 2. Path to a Sonusai Mixture Database directory.
22
+ - Sonusai mixture database directory, prediction files will be named mixid_predict.*
23
+ - MIXID will select a subset of mixture ids
24
+ 3. Directory with audio files found recursively within. See GLOB audio file extensions below.
25
+ 4. Regex resolving to a list of files.
26
+ - Subdirectory containing audio files with extension
27
+ - Regex resolving to a list of audio files
26
28
 
29
+ generate feature and truth data if not found.
27
30
 
28
31
  Note there are multiple ways to process model prediction over multiple audio data files:
29
32
  1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model run as
@@ -42,33 +45,68 @@ TBD not sure below make sense, need to continue ??
42
45
 
43
46
  Outputs the following to opredict-<TIMESTAMP> directory:
44
47
  <id>
45
- predict.pkl
48
+ predict.h5
46
49
  onnx_predict.log
47
50
 
48
51
  """
49
52
 
50
- import signal
51
-
52
-
53
- def signal_handler(_sig, _frame):
54
- import sys
55
-
56
- from sonusai import logger
57
53
 
58
- logger.info("Canceled due to keyboard interrupt")
59
- sys.exit(1)
54
+ def process_path(path, ext_list: list[str] | None = None):
55
+ """
56
+ Check path which can be a single file, a subdirectory, or a regex
57
+ return:
58
+ - a list of files with matching extensions to any in ext_list provided (i.e. ['.wav', '.mp3', '.acc'])
59
+ - the basedir of the path, if
60
+ """
61
+ import glob
62
+ from os.path import abspath
63
+ from os.path import commonprefix
64
+ from os.path import dirname
65
+ from os.path import isdir
66
+ from os.path import isfile
67
+ from os.path import join
60
68
 
69
+ from sonusai.utils import braced_iglob
61
70
 
62
- signal.signal(signal.SIGINT, signal_handler)
71
+ if ext_list is None:
72
+ ext_list = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]
73
+
74
+ # Check if the path is a single file, and return it as a list with the dirname
75
+ if isfile(path):
76
+ if any(path.endswith(ext) for ext in ext_list):
77
+ basedir = dirname(path) # base directory
78
+ if not basedir:
79
+ basedir = "./"
80
+ return [path], basedir
81
+ else:
82
+ return [], []
83
+
84
+ # Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
85
+ if isdir(path):
86
+ matching_files = []
87
+ for ext in ext_list:
88
+ matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
89
+ return matching_files, path
90
+
91
+ # Process as a regex, return list of filenames and basedir
92
+ apath = abspath(path) # join(abspath(path), "**", "*.{wav,flac,WAV}")
93
+ matching_files = []
94
+ for file in braced_iglob(pathname=apath, recursive=True):
95
+ matching_files.append(file)
96
+ if matching_files:
97
+ basedir = commonprefix(matching_files) # Find basedir
98
+ return matching_files, basedir
99
+ else:
100
+ return [], []
63
101
 
64
102
 
65
103
  def main() -> None:
66
104
  from docopt import docopt
67
105
 
68
- import sonusai
106
+ from sonusai import __version__ as sai_version
69
107
  from sonusai.utils import trim_docstring
70
108
 
71
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
109
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
72
110
 
73
111
  verbose = args["--verbose"]
74
112
  wav = args["--write-wav"]
@@ -77,18 +115,23 @@ def main() -> None:
77
115
  model_path = args["MODEL"]
78
116
  data_paths = args["DATA"]
79
117
 
118
+ # Quick check of CPU and GPU devices
119
+ import re
120
+ import subprocess
121
+ import time
80
122
  from os import makedirs
81
- from os.path import abspath
82
123
  from os.path import basename
124
+ from os.path import exists
83
125
  from os.path import isdir
126
+ from os.path import isfile
84
127
  from os.path import join
85
128
  from os.path import normpath
86
- from os.path import realpath
87
129
  from os.path import splitext
88
130
 
89
131
  import h5py
90
132
  import numpy as np
91
133
  import onnxruntime as ort
134
+ import psutil
92
135
 
93
136
  from sonusai import create_file_handler
94
137
  from sonusai import initial_log_messages
@@ -96,66 +139,122 @@ def main() -> None:
96
139
  from sonusai import update_console_handler
97
140
  from sonusai.mixture import MixtureDatabase
98
141
  from sonusai.mixture import get_audio_from_feature
99
- from sonusai.utils import PathInfo
100
- from sonusai.utils import braced_iglob
101
142
  from sonusai.utils import create_ts_name
102
143
  from sonusai.utils import load_ort_session
103
- from sonusai.utils import reshape_inputs
144
+ from sonusai.utils import seconds_to_hms
104
145
  from sonusai.utils import write_audio
105
146
 
106
- mixdb_path = None
107
- mixdb: MixtureDatabase | None = None
108
- p_mixids: list[int] = []
109
- entries: list[PathInfo] = []
110
-
111
- if len(data_paths) == 1 and isdir(data_paths[0]):
112
- # Assume it's a single path to SonusAI mixdb subdir
113
- in_basename = basename(normpath(data_paths[0]))
114
- mixdb_path = data_paths[0]
147
+ num_cpu = psutil.cpu_count()
148
+ cpu_percent = psutil.cpu_percent(interval=1)
149
+ print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
150
+ print(f"Memory utilization: {psutil.virtual_memory().percent}%")
151
+
152
+ vga_devices = [
153
+ line.split(" ", 3)[-1]
154
+ for line in subprocess.check_output("lspci | grep -i vga", shell=True).decode().splitlines()
155
+ ]
156
+ nv_devs = list(filter(lambda x: "nvidia" in x.lower(), vga_devices))
157
+ nv_mods = [re.search(r"\[.*?\]", device).group(0) if re.search(r"\[.*?\]", device) else None for device in nv_devs]
158
+ if len(nv_mods) > 0:
159
+ print(f"{len(nv_mods)} Nvidia devices present: {nv_mods}") # prints model names
115
160
  else:
116
- # search all data paths for .wav, .flac (or whatever is specified in include)
117
- in_basename = ""
118
-
119
- output_dir = create_ts_name("opredict-" + in_basename)
120
- makedirs(output_dir, exist_ok=True)
121
-
122
- # Setup logging file
123
- create_file_handler(join(output_dir, "onnx-predict.log"))
124
- update_console_handler(verbose)
125
- initial_log_messages("onnx_predict")
161
+ print("No cuda devices present, using cpu")
162
+
163
+ avail_providers = ort.get_available_providers()
164
+ print(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
165
+ if len(nv_mods) > 0:
166
+ print(
167
+ "If GPU is desired, need to replace onnxruntime with onnxruntime-gpu i.e. using pip:"
168
+ "> pip uninstall onnxruntime"
169
+ "> pip install onnxruntime-gpu\n\n"
170
+ )
171
+
172
+ # Quick check that model is valid
173
+ if exists(model_path) and isfile(model_path):
174
+ try:
175
+ session = ort.InferenceSession(model_path)
176
+ options = ort.SessionOptions()
177
+ except Exception as e:
178
+ print(f"Error: could not load ONNX model from {model_path}: {e}")
179
+ raise SystemExit(1) from e
180
+ else:
181
+ print(f"Error: model file path is not valid: {model_path}")
182
+ raise SystemExit(1)
126
183
 
127
- providers = ort.get_available_providers()
128
- logger.info(f"Loaded ONNX Runtime, available providers: {providers}.")
184
+ # Check datapath is valid
185
+ if len(data_paths) == 1 and isdir(data_paths[0]): # Try opening as mixdb subdir
186
+ mixdb_path = data_paths[0]
187
+ try:
188
+ mixdb = MixtureDatabase(mixdb_path)
189
+ except Exception:
190
+ mixdb_path = None
191
+ in_basename = basename(normpath(data_paths[0]))
192
+ output_dir = create_ts_name("opredict-" + in_basename)
193
+ num_featparams = mixdb.feature_parameters
194
+ print(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
195
+ p_mixids = mixdb.mixids_to_list(mixids)
196
+ feature_mode = mixdb.feature
129
197
 
198
+ if mixdb_path is None:
199
+ if verbose:
200
+ print(f"Checking {len(data_paths)} locations ... ")
201
+ # Check location, default ext are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac']
202
+ pfiles, basedir = process_path(data_paths)
203
+ if pfiles is None or len(pfiles) < 1:
204
+ print(f"No audio files or Sonusai mixture database found in {data_paths}, exiting ...")
205
+ raise SystemExit(1)
206
+ else:
207
+ pfiles = sorted(pfiles, key=basename)
208
+ output_dir = basedir
209
+
210
+ if mixdb_path is not None or len(pfiles) > 1: # log file only if mixdb or more than one file
211
+ makedirs(output_dir, exist_ok=True)
212
+ # Setup logging file
213
+ create_file_handler(join(output_dir, "onnx-predict.log"))
214
+ update_console_handler(verbose)
215
+ initial_log_messages("onnx_predict")
216
+ # print some previous messages
217
+ logger.info(f"Loaded ONNX Runtime, available providers: {avail_providers}.")
218
+ if mixdb_path:
219
+ logger.debug(f"Loaded SonusAI mixdb with {mixdb.num_mixtures} mixtures and {num_featparams} classes")
220
+ if len(p_mixids) != mixdb.num_mixtures:
221
+ logger.info(f"Processing a subset of {len(p_mixids)} from available mixtures.")
222
+
223
+ # Reload model/session and do more thorough checking
130
224
  session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
225
+ if "CUDAExecutionProvider" in avail_providers:
226
+ session.set_providers(["CUDAExecutionProvider"])
131
227
  if hparams is None:
132
228
  logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
133
229
  raise SystemExit(1)
134
- if len(sess_inputs) != 1:
230
+
231
+ if len(sess_inputs) != 1: # TBD update to support state_in and state_out
135
232
  logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")
136
233
 
137
234
  in0name = sess_inputs[0].name
138
235
  in0type = sess_inputs[0].type
139
- out_names = [n.name for n in session.get_outputs()]
140
-
141
- logger.info(f"Read and compiled ONNX model from {model_path}.")
236
+ in0shape = sess_inputs[0].shape # a list
237
+ # Check for 2 cases of model feature input shape: batch x timesteps x fparams or batch x fparams
238
+ if not isinstance(in0shape[0], str):
239
+ model_batchsz = int(in0shape[0])
240
+ logger.debug(f"Onnx model has fixed batch_size: {model_batchsz}.")
241
+ else:
242
+ model_batchsz = -1
243
+ logger.debug("Onnx model has a dynamic batch_size.")
142
244
 
143
- if mixdb_path is not None:
144
- # Assume it's a single path to SonusAI mixdb subdir
145
- logger.debug(f"Attempting to load mixture database from {mixdb_path}")
146
- mixdb = MixtureDatabase(mixdb_path)
147
- logger.info(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes")
148
- p_mixids = mixdb.mixids_to_list(mixids)
149
- if len(p_mixids) != mixdb.num_mixtures:
150
- logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
245
+ if len(in0shape) < 3:
246
+ model_tsteps = 0
247
+ model_featparams = int(in0shape[1])
151
248
  else:
152
- for p in data_paths:
153
- location = join(realpath(abspath(p)), "**", include)
154
- logger.debug(f"Processing {location}")
155
- for file in braced_iglob(pathname=location, recursive=True):
156
- name = file
157
- entries.append(PathInfo(abs_path=file, audio_filepath=name))
158
- logger.info(f"{len(data_paths)} data paths specified, found {len(entries)} audio files.")
249
+ model_featparams = int(in0shape[2])
250
+ if not isinstance(in0shape[1], str):
251
+ model_tsteps = int(in0shape[1])
252
+ logger.debug(f"Onnx model has fixed timesteps: {model_tsteps}.")
253
+ else:
254
+ model_tsteps = -1
255
+ logger.debug("Onnx model has dynamic timesteps dimension size.")
256
+
257
+ out_names = [n.name for n in session.get_outputs()]
159
258
 
160
259
  if in0type.find("float16") != -1:
161
260
  model_is_fp16 = True
@@ -163,38 +262,40 @@ def main() -> None:
163
262
  else:
164
263
  model_is_fp16 = False
165
264
 
265
+ logger.info(f"Read and compiled ONNX model from {model_path}.")
266
+
267
+ start_time = time.monotonic()
268
+
166
269
  if mixdb is not None and hparams["batch_size"] == 1:
167
- # mixdb input
168
- # Assume (of course) that mixdb feature, etc. is what model expects
169
- if hparams["feature"] != mixdb.feature:
270
+ if hparams["feature"] != feature_mode: # warn on mis-match, but TBD could be sov-mode
170
271
  logger.warning("Mixture feature does not match model feature, this inference run may fail.")
171
- # no choice, can't use hparams.feature since it's different from the mixdb
172
- feature_mode = mixdb.feature
272
+ logger.info(f"Processing {len(p_mixids)} mixtures from SonusAI mixdb ...")
273
+ logger.info(f"Using OnnxRT provider {session.get_providers()} ...")
173
274
 
174
275
  for mixid in p_mixids:
175
- # frames x stride x feature_params
176
- feature, _ = mixdb.mixture_ft(mixid)
276
+ # feature data is now always fp32 and frames x stride x feature_params
277
+ feat_dat, _ = mixdb.mixture_ft(mixid)
278
+ if feat_dat.shape[1] > 1: # stride mode num_frames overrides batch dim, no reshape
279
+ stride_mode = 1
280
+ batch_size = feat_dat.shape[0] # num_frames in stride mode becomes batch size
177
281
  if hparams["timesteps"] == 0:
178
- # no timestep dimension, reshape will handle
282
+ # no timestep dimension, remove the dimension
179
283
  timesteps = 0
284
+ feat_dat = np.reshape(feat_dat, [batch_size, num_featparams])
180
285
  else:
181
- # fit frames into timestep dimension (TSE mode)
182
- timesteps = feature.shape[0]
183
-
184
- feature, _ = reshape_inputs(
185
- feature=feature,
186
- batch_size=1,
187
- timesteps=timesteps,
188
- flatten=hparams["flatten"],
189
- add1ch=hparams["add1ch"],
190
- )
286
+ # fit frames into timestep dimension (TSE mode) and knowing batch_size = 1
287
+ timesteps = feat_dat.shape[0]
288
+ feat_dat = np.transpose(feat_dat, (1, 0, 2)) # transpose to 1 x frames=tsteps x feat_params
289
+
191
290
  if model_is_fp16:
192
- feature = np.float16(feature) # type: ignore[assignment]
291
+ feat_dat = np.float16(feat_dat) # type: ignore[assignment]
292
+
193
293
  # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
194
- predict = session.run(out_names, {in0name: feature})[0]
294
+ predict = session.run(out_names, {in0name: feat_dat})[0]
195
295
  # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
296
+
196
297
  output_fname = join(output_dir, mixdb.mixture(mixid).name)
197
- with h5py.File(output_fname, "a") as f:
298
+ with h5py.File(output_fname + ".h5", "a") as f:
198
299
  if "predict" in f:
199
300
  del f["predict"]
200
301
  f.create_dataset("predict", data=predict)
@@ -206,6 +307,57 @@ def main() -> None:
206
307
  owav_name = splitext(output_fname)[0] + "_predict.wav"
207
308
  write_audio(owav_name, predict_audio)
208
309
 
310
+ else: # TBD add support
311
+ logger.info("Mixture database does not exist or batch_size is not equal to one, exiting ...")
312
+
313
+ end_time = time.monotonic()
314
+ logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
315
+ logger.info("")
316
+
209
317
 
210
318
  if __name__ == "__main__":
211
- main()
319
+ from sonusai import exception_handler
320
+ from sonusai.utils import register_keyboard_interrupt
321
+
322
+ register_keyboard_interrupt()
323
+ try:
324
+ main()
325
+ except Exception as e:
326
+ exception_handler(e)
327
+
328
+ # mixdb_path = None
329
+ # mixdb: MixtureDatabase | None = None
330
+ # p_mixids: list[int] = []
331
+ # entries: list[PathInfo] = []
332
+ #
333
+ # if len(data_paths) == 1 and isdir(data_paths[0]):
334
+ # # Assume it's a single path to SonusAI mixdb subdir
335
+ # in_basename = basename(normpath(data_paths[0]))
336
+ # mixdb_path = data_paths[0]
337
+ # else:
338
+ # # search all data paths for .wav, .flac (or whatever is specified in include)
339
+ # in_basename = ""
340
+
341
+ # if mixdb_path is not None: # a mixdb is found and loaded
342
+ # # Assume it's a single path to SonusAI mixdb subdir
343
+ # num_featparams = mixdb.feature_parameters
344
+ # logger.debug(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {num_featparams} classes")
345
+ # p_mixids = mixdb.mixids_to_list(mixids)
346
+ # if len(p_mixids) != mixdb.num_mixtures:
347
+ # logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
348
+ # else:
349
+ # for p in data_paths:
350
+ # location = join(realpath(abspath(p)), "**", include)
351
+ # logger.debug(f"Processing files in {location}")
352
+ # for file in braced_iglob(pathname=location, recursive=True):
353
+ # name = file
354
+ # entries.append(PathInfo(abs_path=file, audio_filepath=name))
355
+ # logger.info(f"{len(data_paths)} data paths specified, found {len(pfile)} audio files.")
356
+
357
+ # feature, _ = reshape_inputs(
358
+ # feature=feature,
359
+ # batch_size=1,
360
+ # timesteps=timesteps,
361
+ # flatten=hparams["flatten"],
362
+ # add1ch=hparams["add1ch"],
363
+ # )
@@ -1,8 +1,8 @@
1
1
  # SonusAI query utilities
2
2
  # ruff: noqa: F401
3
3
 
4
+ from .queries import get_mixids_from_class_indices
4
5
  from .queries import get_mixids_from_noise
5
6
  from .queries import get_mixids_from_snr
6
- from .queries import get_mixids_from_target
7
+ from .queries import get_mixids_from_source
7
8
  from .queries import get_mixids_from_truth_function
8
- from .queries import get_mixids_from_class_indices
@@ -1,8 +1,8 @@
1
1
  from collections.abc import Callable
2
2
  from typing import Any
3
3
 
4
- from sonusai.mixture.datatypes import GeneralizedIDs
5
- from sonusai.mixture.mixdb import MixtureDatabase
4
+ from ..datatypes import GeneralizedIDs
5
+ from ..mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
8
  def _true_predicate(_: Any) -> bool:
@@ -29,8 +29,8 @@ def get_mixids_from_mixture_field_predicate(
29
29
  criteria_set = set()
30
30
  for m_id in mixid_out:
31
31
  value = getattr(mixdb.mixture(m_id), field)
32
- if isinstance(value, list):
33
- for v in value:
32
+ if isinstance(value, dict):
33
+ for v in value.values():
34
34
  if predicate(v):
35
35
  criteria_set.add(v)
36
36
  elif predicate(value):
@@ -42,8 +42,8 @@ def get_mixids_from_mixture_field_predicate(
42
42
  result[criterion] = []
43
43
  for m_id in mixid_out:
44
44
  value = getattr(mixdb.mixture(m_id), field)
45
- if isinstance(value, list):
46
- for v in value:
45
+ if isinstance(value, dict):
46
+ for v in value.values():
47
47
  if v == criterion:
48
48
  result[criterion].append(m_id)
49
49
  elif value == criterion:
@@ -64,7 +64,7 @@ def get_mixids_from_truth_configs_field_predicate(
64
64
  - keys are the matching field values
65
65
  - values are lists of the mixids that match the criteria
66
66
  """
67
- from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
67
+ from ..mixture.constants import REQUIRED_TRUTH_CONFIGS
68
68
 
69
69
  mixid_out = mixdb.mixids_to_list(mixids)
70
70
 
@@ -79,23 +79,24 @@ def get_mixids_from_truth_configs_field_predicate(
79
79
 
80
80
  result = {}
81
81
  for value in values:
82
- # Get a list of targets for each field value
82
+ # Get a list of sources for each field value
83
83
  indices = []
84
- for t_id in mixdb.target_file_ids:
85
- target = mixdb.target_file(t_id)
86
- for truth_config in target.truth_configs.values():
87
- if field in REQUIRED_TRUTH_CONFIGS:
88
- if value in getattr(truth_config, field):
89
- indices.append(t_id)
90
- else:
91
- if value in getattr(truth_config.config, field):
92
- indices.append(t_id)
84
+ for s_ids in mixdb.source_file_ids.values():
85
+ for s_id in s_ids:
86
+ source = mixdb.source_file(s_id)
87
+ for truth_config in source.truth_configs.values():
88
+ if field in REQUIRED_TRUTH_CONFIGS:
89
+ if value in getattr(truth_config, field):
90
+ indices.append(s_id)
91
+ else:
92
+ if value in getattr(truth_config.config, field):
93
+ indices.append(s_id)
93
94
  indices = sorted(set(indices))
94
95
 
95
96
  mixids = []
96
97
  for index in indices:
97
98
  for m_id in mixid_out:
98
- if index in [target.file_id for target in mixdb.mixture(m_id).targets]:
99
+ if index in [source.file_id for source in mixdb.mixture(m_id).all_sources.values()]:
99
100
  mixids.append(m_id)
100
101
 
101
102
  mixids = sorted(set(mixids))
@@ -109,18 +110,19 @@ def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str)
109
110
  """
110
111
  Generate a list of all values corresponding to the given field in truth_configs
111
112
  """
112
- from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
113
+ from ..mixture.constants import REQUIRED_TRUTH_CONFIGS
113
114
 
114
115
  result = []
115
- for target in mixdb.target_files:
116
- for truth_config in target.truth_configs.values():
117
- if field in REQUIRED_TRUTH_CONFIGS:
118
- value = getattr(truth_config, field)
119
- else:
120
- value = getattr(truth_config.config, field, None)
121
- if not isinstance(value, list):
122
- value = [value]
123
- result.extend(value)
116
+ for sources in mixdb.source_files.values():
117
+ for source in sources:
118
+ for truth_config in source.truth_configs.values():
119
+ if field in REQUIRED_TRUTH_CONFIGS:
120
+ value = getattr(truth_config, field)
121
+ else:
122
+ value = getattr(truth_config.config, field, None)
123
+ if not isinstance(value, list):
124
+ value = [value]
125
+ result.extend(value)
124
126
 
125
127
  return sorted(set(result))
126
128
 
@@ -139,18 +141,18 @@ def get_mixids_from_noise(
139
141
  return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="noise_id", predicate=predicate)
140
142
 
141
143
 
142
- def get_mixids_from_target(
144
+ def get_mixids_from_source(
143
145
  mixdb: MixtureDatabase,
144
146
  mixids: GeneralizedIDs = "*",
145
147
  predicate: Callable[[Any], bool] | None = None,
146
148
  ) -> dict[int, list[int]]:
147
149
  """
148
- Generate mixids based on a target index predicate
150
+ Generate mixids based on a source index predicate
149
151
  Return a dictionary where:
150
- - keys are the target indices
151
- - values are lists of the mixids that match the target index
152
+ - keys are the source indices
153
+ - values are lists of the mixids that match the source index
152
154
  """
153
- return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="target_ids", predicate=predicate)
155
+ return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="source_ids", predicate=predicate)
154
156
 
155
157
 
156
158
  def get_mixids_from_snr(
@@ -178,7 +180,9 @@ def get_mixids_from_snr(
178
180
  result: dict[float, list[int]] = {}
179
181
  for snr in snrs:
180
182
  # Get a list of mixids for each SNR
181
- result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures()) if mixture.snr == snr and i in mixid_out])
183
+ result[snr] = sorted(
184
+ [i for i, mixture in enumerate(mixdb.mixtures) if mixture.noise.snr == snr and i in mixid_out]
185
+ )
182
186
 
183
187
  return result
184
188
 
@@ -30,7 +30,7 @@ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
30
30
  """
31
31
  import string
32
32
 
33
- from sonusai.mixture import get_sample_rate
33
+ from ..mixture.audio import get_sample_rate
34
34
 
35
35
  path = Path(audio)
36
36
  name = path.stem
sonusai/speech/mcgill.py CHANGED
@@ -12,7 +12,7 @@ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
12
12
  import string
13
13
  import struct
14
14
 
15
- from sonusai.mixture import get_sample_rate
15
+ from ..mixture.audio import get_sample_rate
16
16
 
17
17
  if not os.path.exists(audio):
18
18
  return None
sonusai/speech/timit.py CHANGED
@@ -12,7 +12,7 @@ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
12
12
  """
13
13
  import string
14
14
 
15
- from sonusai.mixture import get_sample_rate
15
+ from ..mixture.audio import get_sample_rate
16
16
 
17
17
  file = Path(audio).with_suffix(".TXT")
18
18
  if not os.path.exists(file):
@@ -52,7 +52,7 @@ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None
52
52
 
53
53
 
54
54
  def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
55
- from sonusai.mixture import get_sample_rate
55
+ from ..mixture.audio import get_sample_rate
56
56
 
57
57
  if tier == "words":
58
58
  file = Path(audio).with_suffix(".WRD")
@@ -14,20 +14,6 @@ Inputs:
14
14
 
15
15
  """
16
16
 
17
- import signal
18
-
19
-
20
- def signal_handler(_sig, _frame):
21
- import sys
22
-
23
- from sonusai import logger
24
-
25
- logger.info("Canceled due to keyboard interrupt")
26
- sys.exit(1)
27
-
28
-
29
- signal.signal(signal.SIGINT, signal_handler)
30
-
31
17
 
32
18
  def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = False) -> str:
33
19
  import glob
@@ -56,10 +42,10 @@ def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = Fal
56
42
  def main():
57
43
  from docopt import docopt
58
44
 
59
- import sonusai
45
+ from sonusai import __version__ as sai_version
60
46
  from sonusai.utils import trim_docstring
61
47
 
62
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
48
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
63
49
 
64
50
  by = args["--sort"]
65
51
  reverse = args["--reverse"]
@@ -69,4 +55,11 @@ def main():
69
55
 
70
56
 
71
57
  if __name__ == "__main__":
72
- main()
58
+ from sonusai import exception_handler
59
+ from sonusai.utils import register_keyboard_interrupt
60
+
61
+ register_keyboard_interrupt()
62
+ try:
63
+ main()
64
+ except Exception as e:
65
+ exception_handler(e)