sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
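Most of the churn in this release comes from collapsing the separate target/noise configuration into a single "sources" dictionary, visible in the sonusai/mixture/config.py diff below: get_target_files and get_noise_files are replaced by get_source_files, truth configs now live under each source category, and a source's "files" entry may be either a list or a string naming another category whose file list it reuses. A minimal standalone sketch of that reference-resolution rule, mirroring the loop added in update_sources (the category names used here are illustrative only, not required by the package):

```python
# Illustrative restatement of the "files" reference handling added in
# update_sources(); the category names below are examples only.
def resolve_file_references(sources: dict) -> dict:
    """Replace string 'files' entries with the file list of the named category."""
    count = 0
    # A bounded loop, as in update_sources(): chase string references until
    # every category holds a concrete list, or assume a circular reference.
    while any(isinstance(s["files"], str) for s in sources.values()) and count < 100:
        count += 1
        for category, source in sources.items():
            if isinstance(source["files"], str):
                sources[category]["files"] = sources[source["files"]]["files"]

    if count == 100:
        raise RuntimeError("Check config sources for circular references")

    return sources


sources = {
    "primary": {"files": ["speech/*.wav"]},
    "noise": {"files": "primary"},  # reuse the primary category's file list
}
print(resolve_file_references(sources)["noise"]["files"])  # ['speech/*.wav']
```

The real update_sources additionally enforces REQUIRED_SOURCES_CATEGORIES, REQUIRED_SOURCE_CONFIGS, and the primary/non-primary parameter whitelists before this resolution step.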
sonusai/mixture/config.py CHANGED
@@ -1,8 +1,7 @@
- from sonusai.mixture.datatypes import ImpulseResponseFile
- from sonusai.mixture.datatypes import NoiseFile
- from sonusai.mixture.datatypes import SpectralMask
- from sonusai.mixture.datatypes import TargetFile
- from sonusai.mixture.datatypes import TruthParameter
+ from ..datatypes import ImpulseResponseFile
+ from ..datatypes import SourceFile
+ from ..datatypes import SpectralMask
+ from ..datatypes import TruthParameter


  def raw_load_config(name: str) -> dict:
@@ -54,7 +53,6 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:

  from .constants import REQUIRED_CONFIGS
  from .constants import VALID_CONFIGS
- from .constants import VALID_NOISE_MIX_MODES

  updated_config = deepcopy(given_config)

@@ -81,6 +79,9 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:
  if key not in updated_config:
  raise AttributeError(f"{filename} is missing required '{key}'")

+ # Validate and update sources
+ updated_config = update_sources(updated_config)
+
  # Validate special cases
  validate_truth_configs(updated_config)
  validate_asr_configs(updated_config)
@@ -89,14 +90,76 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:
  if len(updated_config["spectral_masks"]) == 0:
  updated_config["spectral_masks"] = given_config["spectral_masks"]

- # Check for valid noise_mix_mode
- if updated_config["noise_mix_mode"] not in VALID_NOISE_MIX_MODES:
- nice_list = "\n".join([f" {item}" for item in VALID_NOISE_MIX_MODES])
- raise ValueError(f"{filename} contains invalid noise_mix_mode.\nValid noise mix modes are:\n{nice_list}")
-
  return updated_config


+ def update_sources(given: dict) -> dict:
+ """Validate and update fields in given 'sources'
+
+ :param given: The dictionary of given config
+ """
+ from .constants import REQUIRED_NON_PRIMARY_SOURCE_CONFIGS
+ from .constants import REQUIRED_SOURCE_CONFIGS
+ from .constants import REQUIRED_SOURCES_CATEGORIES
+ from .constants import VALID_NON_PRIMARY_SOURCE_CONFIGS
+ from .constants import VALID_PRIMARY_SOURCE_CONFIGS
+
+ sources = given["sources"]
+
+ for category in REQUIRED_SOURCES_CATEGORIES:
+ if category not in sources:
+ raise AttributeError(f"config sources is missing required '{category}'")
+
+ for category, source in sources.items():
+ for key in REQUIRED_SOURCE_CONFIGS:
+ if key not in source:
+ raise AttributeError(f"config source '{category}' is missing required '{key}'")
+
+ if category == "primary":
+ for key in source:
+ if key not in VALID_PRIMARY_SOURCE_CONFIGS:
+ nice_list = "\n".join([f" {item}" for item in VALID_PRIMARY_SOURCE_CONFIGS])
+ raise AttributeError(
+ f"Invalid source '{category}' config parameter: '{key}'.\nValid sources config parameters are:\n{nice_list}"
+ )
+ else:
+ for key in REQUIRED_NON_PRIMARY_SOURCE_CONFIGS:
+ if key not in source:
+ raise AttributeError(f"config source '{category}' is missing required '{key}'")
+
+ for key in source:
+ if key not in VALID_NON_PRIMARY_SOURCE_CONFIGS:
+ nice_list = "\n".join([f" {item}" for item in VALID_NON_PRIMARY_SOURCE_CONFIGS])
+ raise AttributeError(
+ f"Invalid source '{category}' config parameter: '{key}'.\nValid source config parameters are:\n{nice_list}"
+ )
+
+ files = source["files"]
+
+ if isinstance(files, str) and files in sources and files != category:
+ continue
+
+ if isinstance(files, list):
+ continue
+
+ raise TypeError(
+ f"'file' parameter of config source '{category}' is not a list or a reference to another source"
+ )
+
+ count = 0
+ while any(isinstance(source["files"], str) for source in sources.values()) and count < 100:
+ count += 1
+ for category, source in sources.items():
+ files = source["files"]
+ if isinstance(files, str):
+ given["sources"][category]["files"] = sources[files]["files"]
+
+ if count == 100:
+ raise RuntimeError("Check config sources for circular references")
+
+ return given
+
+
  def validate_truth_configs(given: dict) -> None:
  """Validate fields in given 'truth_configs'

@@ -104,27 +167,31 @@ def validate_truth_configs(given: dict) -> None:
  """
  from copy import deepcopy

- from sonusai.mixture import truth_functions
-
+ from . import truth_functions
  from .constants import REQUIRED_TRUTH_CONFIGS

- if "truth_configs" not in given:
- raise AttributeError("config is missing required 'truth_configs'")
+ sources = given["sources"]

- truth_configs = given["truth_configs"]
- if len(truth_configs) == 0:
- raise ValueError("'truth_configs' in config is empty")
+ for category, source in sources.items():
+ if "truth_configs" not in source:
+ continue

- for name, truth_config in truth_configs.items():
- for key in REQUIRED_TRUTH_CONFIGS:
- if key not in truth_config:
- raise AttributeError(f"'{name}' in truth_configs is missing required '{key}'")
+ truth_configs = source["truth_configs"]
+ if len(truth_configs) == 0:
+ raise ValueError(f"'truth_configs' in config source '{category}' is empty")

- optional_config = deepcopy(truth_config)
- for key in REQUIRED_TRUTH_CONFIGS:
- del optional_config[key]
+ for truth_name, truth_config in truth_configs.items():
+ for k in REQUIRED_TRUTH_CONFIGS:
+ if k not in truth_config:
+ raise AttributeError(
+ f"'{truth_name}' in source '{category}' truth_configs is missing required '{k}'"
+ )
+
+ optional_config = deepcopy(truth_config)
+ for k in REQUIRED_TRUTH_CONFIGS:
+ del optional_config[k]

- getattr(truth_functions, truth_config["function"] + "_validate")(optional_config)
+ getattr(truth_functions, truth_config["function"] + "_validate")(optional_config)


  def validate_asr_configs(given: dict) -> None:
@@ -132,8 +199,7 @@ def validate_asr_configs(given: dict) -> None:

  :param given: The dictionary of given config
  """
- from sonusai.utils import validate_asr
-
+ from ..utils.asr import validate_asr
  from .constants import REQUIRED_ASR_CONFIGS

  if "asr_configs" not in given:
@@ -209,69 +275,80 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
  return new_config


- def get_target_files(config: dict, show_progress: bool = False) -> list[TargetFile]:
- """Get the list of target files from a config
+ def get_source_files(config: dict, show_progress: bool = False) -> list[SourceFile]:
+ """Get the list of source files from a config

  :param config: Config dictionary
  :param show_progress: Show progress bar
- :return: List of target files
+ :return: List of source files
  """
  from itertools import chain

- from sonusai.utils import dataclass_from_dict
- from sonusai.utils import par_track
- from sonusai.utils import track
+ from ..utils.parallel import par_track
+ from ..utils.parallel import track

- from .datatypes import TargetFile
+ sources = config["sources"]
+ if not isinstance(sources, dict) and not all(isinstance(source, dict) for source in sources):
+ raise TypeError("'sources' must be a dictionary of dictionaries")
+
+ if "primary" not in sources:
+ raise AttributeError("'primary' is missing in 'sources'")

  class_indices = config["class_indices"]
  if not isinstance(class_indices, list):
  class_indices = [class_indices]

- target_files = list(
- chain.from_iterable(
- [
- append_target_files(
- entry=entry,
- class_indices=class_indices,
- truth_configs=config["truth_configs"],
- level_type=config["target_level_type"],
- )
- for entry in config["targets"]
- ]
+ level_type = config["level_type"]
+
+ source_files: list[SourceFile] = []
+ for category in sources:
+ source_files.extend(
+ chain.from_iterable(
+ [
+ append_source_files(
+ category=category,
+ entry=entry,
+ class_indices=class_indices,
+ truth_configs=sources[category].get("truth_configs", []),
+ level_type=level_type,
+ )
+ for entry in sources[category]["files"]
+ ]
+ )
  )
- )

- progress = track(total=len(target_files), disable=not show_progress)
- target_files = par_track(_get_num_samples, target_files, progress=progress)
+ progress = track(total=len(source_files), disable=not show_progress)
+ source_files = par_track(_get_num_samples, source_files, progress=progress)
  progress.close()

  num_classes = config["num_classes"]
- for target_file in target_files:
- if any(class_index < 0 for class_index in target_file["class_indices"]):
+ for source_file in source_files:
+ if any(class_index < 0 for class_index in source_file.class_indices):
  raise ValueError("class indices must contain only positive elements")

- if any(class_index > num_classes for class_index in target_file["class_indices"]):
+ if any(class_index > num_classes for class_index in source_file.class_indices):
  raise ValueError(f"class index elements must not be greater than {num_classes}")

- return dataclass_from_dict(list[TargetFile], target_files)
+ return source_files


- def append_target_files(
- entry: dict | str,
+ def append_source_files(
+ category: str,
+ entry: dict,
  class_indices: list[int],
  truth_configs: dict,
  level_type: str,
  tokens: dict | None = None,
- ) -> list[dict]:
- """Process target files list and append as needed
+ ) -> list[SourceFile]:
+ """Process source files list and append as needed

- :param entry: Target file entry to append to the list
+ :param category: Source file category name
+ :param entry: Source file entry to append to the list
  :param class_indices: Class indices
  :param truth_configs: Truth configs
- :param level_type: Target level type
+ :param level_type: Level type
  :param tokens: Tokens used for variable expansion
- :return: List of target files
+ :return: List of source files
  """
  from copy import deepcopy
  from glob import glob
@@ -282,41 +359,39 @@ def append_target_files(
  from os.path import join
  from os.path import splitext

- from sonusai.utils import dataclass_from_dict
-
+ from ..datatypes import TruthConfig
+ from ..utils.dataclass_from_dict import dataclass_from_dict
+ from ..utils.tokenized_shell_vars import tokenized_expand
+ from ..utils.tokenized_shell_vars import tokenized_replace
  from .audio import validate_input_file
  from .constants import REQUIRED_TRUTH_CONFIGS
- from .datatypes import TruthConfig
- from .tokenized_shell_vars import tokenized_expand
- from .tokenized_shell_vars import tokenized_replace

  if tokens is None:
  tokens = {}

  truth_configs_merged = deepcopy(truth_configs)
- if isinstance(entry, dict):
- if "name" in entry:
- in_name = entry["name"]
- else:
- raise AttributeError("Target list contained record without name")
-
- if "class_indices" in entry:
- if isinstance(entry["class_indices"], list):
- class_indices = entry["class_indices"]
- else:
- class_indices = [entry["class_indices"]]
-
- truth_configs_override = entry.get("truth_configs", {})
- for key in truth_configs_override:
- if key not in truth_configs:
- raise AttributeError(
- f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
- )
- if key in truth_configs_override:
- truth_configs_merged[key] |= truth_configs_override[key]
- level_type = entry.get("level_type", level_type)
- else:
- in_name = entry
+
+ if not isinstance(entry, dict):
+ raise TypeError("'entry' must be a dictionary")
+
+ in_name = entry.get("name")
+ if in_name is None:
+ raise KeyError("Source file list contained record without name")
+
+ class_indices = entry.get("class_indices", class_indices)
+ if not isinstance(class_indices, list):
+ class_indices = [class_indices]
+
+ truth_configs_override = entry.get("truth_configs", {})
+ for key in truth_configs_override:
+ if key not in truth_configs:
+ raise AttributeError(
+ f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
+ )
+ if key in truth_configs_override:
+ truth_configs_merged[key] |= truth_configs_override[key]
+
+ level_type = entry.get("level_type", level_type)

  in_name, new_tokens = tokenized_expand(in_name)
  tokens.update(new_tokens)
@@ -324,7 +399,7 @@ def append_target_files(
  if not names:
  raise OSError(f"Could not find {in_name}. Make sure path exists")

- target_files: list[dict] = []
+ source_files: list[SourceFile] = []
  for name in names:
  ext = splitext(name)[1].lower()
  dir_name = dirname(name)
@@ -333,9 +408,10 @@
  child = file
  if not isabs(child):
  child = join(dir_name, child)
- target_files.extend(
- append_target_files(
- entry=child,
+ source_files.extend(
+ append_source_files(
+ category=category,
+ entry={"name": child},
  class_indices=class_indices,
  truth_configs=truth_configs_merged,
  level_type=level_type,
@@ -355,41 +431,26 @@ def append_target_files(
  tokens.update(new_tokens)
  if not isabs(child):
  child = join(dir_name, child)
- target_files.extend(
- append_target_files(
- entry=child,
+ source_files.extend(
+ append_source_files(
+ category=category,
+ entry={"name": child},
  class_indices=class_indices,
  truth_configs=truth_configs_merged,
  level_type=level_type,
  tokens=tokens,
  )
  )
- elif ext == ".yml":
- try:
- yml_config = raw_load_config(name)
-
- if "targets" in yml_config:
- for record in yml_config["targets"]:
- target_files.extend(
- append_target_files(
- entry=record,
- class_indices=class_indices,
- truth_configs=truth_configs_merged,
- level_type=level_type,
- tokens=tokens,
- )
- )
- except Exception as e:
- raise OSError(f"Error processing {name}: {e}") from e

  else:
  validate_input_file(name)
- target_file: dict = {
- "expanded_name": name,
- "name": tokenized_replace(name, tokens),
- "class_indices": class_indices,
- "level_type": level_type,
- "truth_configs": {},
- }
+ source_file = SourceFile(
+ category=category,
+ name=tokenized_replace(name, tokens),
+ samples=0,
+ class_indices=class_indices,
+ level_type=level_type,
+ truth_configs={},
+ )
  if len(truth_configs_merged) > 0:
  for tc_key, tc_value in truth_configs_merged.items():
@@ -398,145 +459,58 @@ def append_target_files(
  truth_config[key] = config[key]
  del config[key]
  truth_config["config"] = config
- target_file["truth_configs"][tc_key] = dataclass_from_dict(TruthConfig, truth_config)
- for tc_key in target_file["truth_configs"]:
+ source_file.truth_configs[tc_key] = dataclass_from_dict(TruthConfig, truth_config)
+ for tc_key in source_file.truth_configs:
  if (
  "function" in truth_configs_merged[tc_key]
  and truth_configs_merged[tc_key]["function"] == "file"
  ):
- truth_configs_merged[tc_key]["file"] = splitext(target_file["name"])[0] + ".h5"
- target_files.append(target_file)
- except Exception as e:
- raise OSError(f"Error processing {name}: {e}") from e
-
- return target_files
-
-
- def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
- """Get the list of noise files from a config
-
- :param config: Config dictionary
- :param show_progress: Show progress bar
- :return: List of noise file
- """
- from itertools import chain
-
- from sonusai.utils import dataclass_from_dict
- from sonusai.utils import par_track
- from sonusai.utils import track
-
- from .datatypes import NoiseFile
-
- noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
-
- progress = track(total=len(noise_files), disable=not show_progress)
- noise_files = par_track(_get_num_samples, noise_files, progress=progress)
- progress.close()
-
- return dataclass_from_dict(list[NoiseFile], noise_files)
-
-
- def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
- """Process noise files list and append as needed
-
- :param entry: Noise file entry to append to the list
- :param tokens: Tokens used for variable expansion
- :return: List of noise files
- """
- from glob import glob
- from os import listdir
- from os.path import dirname
- from os.path import isabs
- from os.path import isdir
- from os.path import join
- from os.path import splitext
-
- from .audio import validate_input_file
- from .tokenized_shell_vars import tokenized_expand
- from .tokenized_shell_vars import tokenized_replace
-
- if tokens is None:
- tokens = {}
-
- if isinstance(entry, dict):
- if "name" in entry:
- in_name = entry["name"]
- else:
- raise AttributeError("Noise list contained record without name")
- else:
- in_name = entry
-
- in_name, new_tokens = tokenized_expand(in_name)
- tokens.update(new_tokens)
- names = sorted(glob(in_name))
- if not names:
- raise OSError(f"Could not find {in_name}. Make sure path exists")
-
- noise_files: list[dict] = []
- for name in names:
- ext = splitext(name)[1].lower()
- dir_name = dirname(name)
- if isdir(name):
- for file in listdir(name):
- child = file
- if not isabs(child):
- child = join(dir_name, child)
- noise_files.extend(append_noise_files(entry=child, tokens=tokens))
- else:
- try:
- if ext == ".txt":
- with open(file=name) as txt_file:
- for line in txt_file:
- # strip comments
- child = line.partition("#")[0]
- child = child.rstrip()
- if child:
- child, new_tokens = tokenized_expand(child)
- tokens.update(new_tokens)
- if not isabs(child):
- child = join(dir_name, child)
- noise_files.extend(append_noise_files(entry=child, tokens=tokens))
- elif ext == ".yml":
- try:
- yml_config = raw_load_config(name)
-
- if "noises" in yml_config:
- for record in yml_config["noises"]:
- noise_files.extend(append_noise_files(entry=record, tokens=tokens))
- except Exception as e:
- raise OSError(f"Error processing {name}: {e}") from e
- else:
- validate_input_file(name)
- noise_file: dict = {
- "expanded_name": name,
- "name": tokenized_replace(name, tokens),
- }
- noise_files.append(noise_file)
+ truth_configs_merged[tc_key]["file"] = splitext(source_file.name)[0] + ".h5"
+ source_files.append(source_file)
  except Exception as e:
  raise OSError(f"Error processing {name}: {e}") from e

- return noise_files
+ return source_files


- def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
+ def get_ir_files(config: dict, show_progress: bool = False) -> list[ImpulseResponseFile]:
  """Get the list of impulse response files from a config

  :param config: Config dictionary
+ :param show_progress: Show progress bar
  :return: List of impulse response files
  """
  from itertools import chain

- return list(
+ from ..utils.parallel import par_track
+ from ..utils.parallel import track
+
+ ir_files = list(
  chain.from_iterable(
  [
- append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", []), 0))
+ append_ir_files(
+ entry=ImpulseResponseFile(
+ name=entry["name"],
+ tags=entry.get("tags", []),
+ delay=entry.get("delay", "auto"),
+ )
+ )
  for entry in config["impulse_responses"]
  ]
  )
  )

+ if len(ir_files) == 0:
+ return []
+
+ progress = track(total=len(ir_files), disable=not show_progress)
+ ir_files = par_track(_get_ir_delay, ir_files, progress=progress)
+ progress.close()
+
+ return ir_files
+

- def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
+ def append_ir_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
  """Process impulse response files list and append as needed

  :param entry: Impulse response file entry to append to the list
@@ -551,21 +525,20 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
  from os.path import join
  from os.path import splitext

+ from ..utils.tokenized_shell_vars import tokenized_expand
+ from ..utils.tokenized_shell_vars import tokenized_replace
  from .audio import validate_input_file
- from .ir_delay import get_impulse_response_delay
- from .tokenized_shell_vars import tokenized_expand
- from .tokenized_shell_vars import tokenized_replace

  if tokens is None:
  tokens = {}

- in_name, new_tokens = tokenized_expand(entry.file)
+ in_name, new_tokens = tokenized_expand(entry.name)
  tokens.update(new_tokens)
  names = sorted(glob(in_name))
  if not names:
  raise OSError(f"Could not find {in_name}. Make sure path exists")

- impulse_response_files: list[ImpulseResponseFile] = []
+ ir_files: list[ImpulseResponseFile] = []
  for name in names:
  ext = splitext(name)[1].lower()
  dir_name = dirname(name)
@@ -573,8 +546,8 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
  for file in listdir(name):
  if not isabs(file):
  file = join(dir_name, file)
- child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
- impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
+ child = ImpulseResponseFile(file, entry.tags, entry.delay)
+ ir_files.extend(append_ir_files(entry=child, tokens=tokens))
  else:
  try:
  if ext == ".txt":
@@ -588,30 +561,24 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
  tokens.update(new_tokens)
  if not isabs(file):
  file = join(dir_name, file)
- child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
- impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
+ child = ImpulseResponseFile(file, entry.tags, entry.delay)
+ ir_files.extend(append_ir_files(entry=child, tokens=tokens))
  elif ext == ".yml":
  try:
  yml_config = raw_load_config(name)

  if "impulse_responses" in yml_config:
  for record in yml_config["impulse_responses"]:
- impulse_response_files.extend(
- append_impulse_response_files(entry=record, tokens=tokens)
- )
+ ir_files.extend(append_ir_files(entry=record, tokens=tokens))
  except Exception as e:
  raise OSError(f"Error processing {name}: {e}") from e
  else:
  validate_input_file(name)
- impulse_response_files.append(
- ImpulseResponseFile(
- tokenized_replace(name, tokens), entry.tags, get_impulse_response_delay(name)
- )
- )
+ ir_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags, entry.delay))
  except Exception as e:
  raise OSError(f"Error processing {name}: {e}") from e

- return impulse_response_files
+ return ir_files


  def get_spectral_masks(config: dict) -> list[SpectralMask]:
@@ -620,10 +587,10 @@ def get_spectral_masks(config: dict) -> list[SpectralMask]:
  :param config: Config dictionary
  :return: List of spectral masks
  """
- from sonusai.utils import dataclass_from_dict
+ from ..utils.dataclass_from_dict import list_dataclass_from_dict

  try:
- return dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
+ return list_dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
  except Exception as e:
  raise ValueError(f"Error in spectral_masks: {e}") from e

@@ -636,30 +603,43 @@ def get_truth_parameters(config: dict) -> list[TruthParameter]:
  """
  from copy import deepcopy

- from sonusai.mixture import truth_functions
-
+ from . import truth_functions
  from .constants import REQUIRED_TRUTH_CONFIGS
- from .datatypes import TruthParameter

  truth_parameters: list[TruthParameter] = []
- for name, truth_config in config["truth_configs"].items():
- optional_config = deepcopy(truth_config)
- for key in REQUIRED_TRUTH_CONFIGS:
- del optional_config[key]
-
- parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
- config["feature"],
- config["num_classes"],
- optional_config,
- )
- truth_parameters.append(TruthParameter(name, parameters))
+ for category, source_config in config["sources"].items():
+ if "truth_configs" in source_config:
+ for truth_name, truth_config in source_config["truth_configs"].items():
+ optional_config = deepcopy(truth_config)
+ for key in REQUIRED_TRUTH_CONFIGS:
+ del optional_config[key]
+
+ parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
+ config["feature"],
+ config["num_classes"],
+ optional_config,
+ )
+ truth_parameters.append(TruthParameter(category, truth_name, parameters))

  return truth_parameters


- def _get_num_samples(entry: dict) -> dict:
+ def _get_num_samples(entry: SourceFile) -> SourceFile:
  from .audio import get_num_samples

- entry["samples"] = get_num_samples(entry["expanded_name"])
- del entry["expanded_name"]
+ entry.samples = get_num_samples(entry.name)
+ return entry
+
+
+ def _get_ir_delay(entry: ImpulseResponseFile) -> ImpulseResponseFile:
+ from .ir_delay import get_ir_delay
+
+ if entry.delay == "auto":
+ entry.delay = get_ir_delay(entry.name)
+ else:
+ try:
+ entry.delay = int(entry.delay)
+ except ValueError as e:
+ raise ValueError(f"Invalid impulse response delay: {entry.delay}") from e
+
  return entry
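Impulse response handling also changes in this file: ImpulseResponseFile entries now carry a delay field read from the config (delay: auto by default), and get_ir_files resolves it after globbing via _get_ir_delay, measuring the delay only when it is "auto" and otherwise requiring an integer. A self-contained sketch of that resolution rule, using a stand-in dataclass and an injected measurement function (the package's own datatype and ir_delay.get_ir_delay are not reproduced here):

```python
from dataclasses import dataclass, field


@dataclass
class IrEntry:  # stand-in for sonusai.datatypes.ImpulseResponseFile
    name: str
    tags: list[str] = field(default_factory=list)
    delay: str | int = "auto"


def resolve_delay(entry: IrEntry, measure) -> IrEntry:
    """'auto' means measure the delay from the file; anything else must parse as int."""
    if entry.delay == "auto":
        entry.delay = measure(entry.name)  # the real code calls ir_delay.get_ir_delay()
    else:
        try:
            entry.delay = int(entry.delay)
        except ValueError as e:
            raise ValueError(f"Invalid impulse response delay: {entry.delay}") from e
    return entry


# Example: an explicit delay given as a string in the config is coerced to int.
print(resolve_delay(IrEntry("ir.wav", delay="12"), measure=len).delay)  # 12
```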