sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/config.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
|
-
from
|
2
|
-
from
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from sonusai.mixture.datatypes import TruthParameter
|
1
|
+
from ..datatypes import ImpulseResponseFile
|
2
|
+
from ..datatypes import SourceFile
|
3
|
+
from ..datatypes import SpectralMask
|
4
|
+
from ..datatypes import TruthParameter
|
6
5
|
|
7
6
|
|
8
7
|
def raw_load_config(name: str) -> dict:
|
@@ -54,7 +53,6 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:
|
|
54
53
|
|
55
54
|
from .constants import REQUIRED_CONFIGS
|
56
55
|
from .constants import VALID_CONFIGS
|
57
|
-
from .constants import VALID_NOISE_MIX_MODES
|
58
56
|
|
59
57
|
updated_config = deepcopy(given_config)
|
60
58
|
|
@@ -81,6 +79,9 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:
|
|
81
79
|
if key not in updated_config:
|
82
80
|
raise AttributeError(f"{filename} is missing required '{key}'")
|
83
81
|
|
82
|
+
# Validate and update sources
|
83
|
+
updated_config = update_sources(updated_config)
|
84
|
+
|
84
85
|
# Validate special cases
|
85
86
|
validate_truth_configs(updated_config)
|
86
87
|
validate_asr_configs(updated_config)
|
@@ -89,14 +90,76 @@ def update_config_from_file(filename: str, given_config: dict) -> dict:
|
|
89
90
|
if len(updated_config["spectral_masks"]) == 0:
|
90
91
|
updated_config["spectral_masks"] = given_config["spectral_masks"]
|
91
92
|
|
92
|
-
# Check for valid noise_mix_mode
|
93
|
-
if updated_config["noise_mix_mode"] not in VALID_NOISE_MIX_MODES:
|
94
|
-
nice_list = "\n".join([f" {item}" for item in VALID_NOISE_MIX_MODES])
|
95
|
-
raise ValueError(f"{filename} contains invalid noise_mix_mode.\nValid noise mix modes are:\n{nice_list}")
|
96
|
-
|
97
93
|
return updated_config
|
98
94
|
|
99
95
|
|
96
|
+
def update_sources(given: dict) -> dict:
|
97
|
+
"""Validate and update fields in given 'sources'
|
98
|
+
|
99
|
+
:param given: The dictionary of given config
|
100
|
+
"""
|
101
|
+
from .constants import REQUIRED_NON_PRIMARY_SOURCE_CONFIGS
|
102
|
+
from .constants import REQUIRED_SOURCE_CONFIGS
|
103
|
+
from .constants import REQUIRED_SOURCES_CATEGORIES
|
104
|
+
from .constants import VALID_NON_PRIMARY_SOURCE_CONFIGS
|
105
|
+
from .constants import VALID_PRIMARY_SOURCE_CONFIGS
|
106
|
+
|
107
|
+
sources = given["sources"]
|
108
|
+
|
109
|
+
for category in REQUIRED_SOURCES_CATEGORIES:
|
110
|
+
if category not in sources:
|
111
|
+
raise AttributeError(f"config sources is missing required '{category}'")
|
112
|
+
|
113
|
+
for category, source in sources.items():
|
114
|
+
for key in REQUIRED_SOURCE_CONFIGS:
|
115
|
+
if key not in source:
|
116
|
+
raise AttributeError(f"config source '{category}' is missing required '{key}'")
|
117
|
+
|
118
|
+
if category == "primary":
|
119
|
+
for key in source:
|
120
|
+
if key not in VALID_PRIMARY_SOURCE_CONFIGS:
|
121
|
+
nice_list = "\n".join([f" {item}" for item in VALID_PRIMARY_SOURCE_CONFIGS])
|
122
|
+
raise AttributeError(
|
123
|
+
f"Invalid source '{category}' config parameter: '{key}'.\nValid sources config parameters are:\n{nice_list}"
|
124
|
+
)
|
125
|
+
else:
|
126
|
+
for key in REQUIRED_NON_PRIMARY_SOURCE_CONFIGS:
|
127
|
+
if key not in source:
|
128
|
+
raise AttributeError(f"config source '{category}' is missing required '{key}'")
|
129
|
+
|
130
|
+
for key in source:
|
131
|
+
if key not in VALID_NON_PRIMARY_SOURCE_CONFIGS:
|
132
|
+
nice_list = "\n".join([f" {item}" for item in VALID_NON_PRIMARY_SOURCE_CONFIGS])
|
133
|
+
raise AttributeError(
|
134
|
+
f"Invalid source '{category}' config parameter: '{key}'.\nValid source config parameters are:\n{nice_list}"
|
135
|
+
)
|
136
|
+
|
137
|
+
files = source["files"]
|
138
|
+
|
139
|
+
if isinstance(files, str) and files in sources and files != category:
|
140
|
+
continue
|
141
|
+
|
142
|
+
if isinstance(files, list):
|
143
|
+
continue
|
144
|
+
|
145
|
+
raise TypeError(
|
146
|
+
f"'file' parameter of config source '{category}' is not a list or a reference to another source"
|
147
|
+
)
|
148
|
+
|
149
|
+
count = 0
|
150
|
+
while any(isinstance(source["files"], str) for source in sources.values()) and count < 100:
|
151
|
+
count += 1
|
152
|
+
for category, source in sources.items():
|
153
|
+
files = source["files"]
|
154
|
+
if isinstance(files, str):
|
155
|
+
given["sources"][category]["files"] = sources[files]["files"]
|
156
|
+
|
157
|
+
if count == 100:
|
158
|
+
raise RuntimeError("Check config sources for circular references")
|
159
|
+
|
160
|
+
return given
|
161
|
+
|
162
|
+
|
100
163
|
def validate_truth_configs(given: dict) -> None:
|
101
164
|
"""Validate fields in given 'truth_configs'
|
102
165
|
|
@@ -104,27 +167,31 @@ def validate_truth_configs(given: dict) -> None:
|
|
104
167
|
"""
|
105
168
|
from copy import deepcopy
|
106
169
|
|
107
|
-
from
|
108
|
-
|
170
|
+
from . import truth_functions
|
109
171
|
from .constants import REQUIRED_TRUTH_CONFIGS
|
110
172
|
|
111
|
-
|
112
|
-
raise AttributeError("config is missing required 'truth_configs'")
|
173
|
+
sources = given["sources"]
|
113
174
|
|
114
|
-
|
115
|
-
|
116
|
-
|
175
|
+
for category, source in sources.items():
|
176
|
+
if "truth_configs" not in source:
|
177
|
+
continue
|
117
178
|
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
raise AttributeError(f"'{name}' in truth_configs is missing required '{key}'")
|
179
|
+
truth_configs = source["truth_configs"]
|
180
|
+
if len(truth_configs) == 0:
|
181
|
+
raise ValueError(f"'truth_configs' in config source '{category}' is empty")
|
122
182
|
|
123
|
-
|
124
|
-
|
125
|
-
|
183
|
+
for truth_name, truth_config in truth_configs.items():
|
184
|
+
for k in REQUIRED_TRUTH_CONFIGS:
|
185
|
+
if k not in truth_config:
|
186
|
+
raise AttributeError(
|
187
|
+
f"'{truth_name}' in source '{category}' truth_configs is missing required '{k}'"
|
188
|
+
)
|
189
|
+
|
190
|
+
optional_config = deepcopy(truth_config)
|
191
|
+
for k in REQUIRED_TRUTH_CONFIGS:
|
192
|
+
del optional_config[k]
|
126
193
|
|
127
|
-
|
194
|
+
getattr(truth_functions, truth_config["function"] + "_validate")(optional_config)
|
128
195
|
|
129
196
|
|
130
197
|
def validate_asr_configs(given: dict) -> None:
|
@@ -132,8 +199,7 @@ def validate_asr_configs(given: dict) -> None:
|
|
132
199
|
|
133
200
|
:param given: The dictionary of given config
|
134
201
|
"""
|
135
|
-
from
|
136
|
-
|
202
|
+
from ..utils.asr import validate_asr
|
137
203
|
from .constants import REQUIRED_ASR_CONFIGS
|
138
204
|
|
139
205
|
if "asr_configs" not in given:
|
@@ -209,69 +275,80 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
|
|
209
275
|
return new_config
|
210
276
|
|
211
277
|
|
212
|
-
def
|
213
|
-
"""Get the list of
|
278
|
+
def get_source_files(config: dict, show_progress: bool = False) -> list[SourceFile]:
|
279
|
+
"""Get the list of source files from a config
|
214
280
|
|
215
281
|
:param config: Config dictionary
|
216
282
|
:param show_progress: Show progress bar
|
217
|
-
:return: List of
|
283
|
+
:return: List of source files
|
218
284
|
"""
|
219
285
|
from itertools import chain
|
220
286
|
|
221
|
-
from
|
222
|
-
from
|
223
|
-
from sonusai.utils import track
|
287
|
+
from ..utils.parallel import par_track
|
288
|
+
from ..utils.parallel import track
|
224
289
|
|
225
|
-
|
290
|
+
sources = config["sources"]
|
291
|
+
if not isinstance(sources, dict) and not all(isinstance(source, dict) for source in sources):
|
292
|
+
raise TypeError("'sources' must be a dictionary of dictionaries")
|
293
|
+
|
294
|
+
if "primary" not in sources:
|
295
|
+
raise AttributeError("'primary' is missing in 'sources'")
|
226
296
|
|
227
297
|
class_indices = config["class_indices"]
|
228
298
|
if not isinstance(class_indices, list):
|
229
299
|
class_indices = [class_indices]
|
230
300
|
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
301
|
+
level_type = config["level_type"]
|
302
|
+
|
303
|
+
source_files: list[SourceFile] = []
|
304
|
+
for category in sources:
|
305
|
+
source_files.extend(
|
306
|
+
chain.from_iterable(
|
307
|
+
[
|
308
|
+
append_source_files(
|
309
|
+
category=category,
|
310
|
+
entry=entry,
|
311
|
+
class_indices=class_indices,
|
312
|
+
truth_configs=sources[category].get("truth_configs", []),
|
313
|
+
level_type=level_type,
|
314
|
+
)
|
315
|
+
for entry in sources[category]["files"]
|
316
|
+
]
|
317
|
+
)
|
242
318
|
)
|
243
|
-
)
|
244
319
|
|
245
|
-
progress = track(total=len(
|
246
|
-
|
320
|
+
progress = track(total=len(source_files), disable=not show_progress)
|
321
|
+
source_files = par_track(_get_num_samples, source_files, progress=progress)
|
247
322
|
progress.close()
|
248
323
|
|
249
324
|
num_classes = config["num_classes"]
|
250
|
-
for
|
251
|
-
if any(class_index < 0 for class_index in
|
325
|
+
for source_file in source_files:
|
326
|
+
if any(class_index < 0 for class_index in source_file.class_indices):
|
252
327
|
raise ValueError("class indices must contain only positive elements")
|
253
328
|
|
254
|
-
if any(class_index > num_classes for class_index in
|
329
|
+
if any(class_index > num_classes for class_index in source_file.class_indices):
|
255
330
|
raise ValueError(f"class index elements must not be greater than {num_classes}")
|
256
331
|
|
257
|
-
return
|
332
|
+
return source_files
|
258
333
|
|
259
334
|
|
260
|
-
def
|
261
|
-
|
335
|
+
def append_source_files(
|
336
|
+
category: str,
|
337
|
+
entry: dict,
|
262
338
|
class_indices: list[int],
|
263
339
|
truth_configs: dict,
|
264
340
|
level_type: str,
|
265
341
|
tokens: dict | None = None,
|
266
|
-
) -> list[
|
267
|
-
"""Process
|
342
|
+
) -> list[SourceFile]:
|
343
|
+
"""Process source files list and append as needed
|
268
344
|
|
269
|
-
:param
|
345
|
+
:param category: Source file category name
|
346
|
+
:param entry: Source file entry to append to the list
|
270
347
|
:param class_indices: Class indices
|
271
348
|
:param truth_configs: Truth configs
|
272
|
-
:param level_type:
|
349
|
+
:param level_type: Level type
|
273
350
|
:param tokens: Tokens used for variable expansion
|
274
|
-
:return: List of
|
351
|
+
:return: List of source files
|
275
352
|
"""
|
276
353
|
from copy import deepcopy
|
277
354
|
from glob import glob
|
@@ -282,41 +359,39 @@ def append_target_files(
|
|
282
359
|
from os.path import join
|
283
360
|
from os.path import splitext
|
284
361
|
|
285
|
-
from
|
286
|
-
|
362
|
+
from ..datatypes import TruthConfig
|
363
|
+
from ..utils.dataclass_from_dict import dataclass_from_dict
|
364
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
365
|
+
from ..utils.tokenized_shell_vars import tokenized_replace
|
287
366
|
from .audio import validate_input_file
|
288
367
|
from .constants import REQUIRED_TRUTH_CONFIGS
|
289
|
-
from .datatypes import TruthConfig
|
290
|
-
from .tokenized_shell_vars import tokenized_expand
|
291
|
-
from .tokenized_shell_vars import tokenized_replace
|
292
368
|
|
293
369
|
if tokens is None:
|
294
370
|
tokens = {}
|
295
371
|
|
296
372
|
truth_configs_merged = deepcopy(truth_configs)
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
in_name = entry
|
373
|
+
|
374
|
+
if not isinstance(entry, dict):
|
375
|
+
raise TypeError("'entry' must be a dictionary")
|
376
|
+
|
377
|
+
in_name = entry.get("name")
|
378
|
+
if in_name is None:
|
379
|
+
raise KeyError("Source file list contained record without name")
|
380
|
+
|
381
|
+
class_indices = entry.get("class_indices", class_indices)
|
382
|
+
if not isinstance(class_indices, list):
|
383
|
+
class_indices = [class_indices]
|
384
|
+
|
385
|
+
truth_configs_override = entry.get("truth_configs", {})
|
386
|
+
for key in truth_configs_override:
|
387
|
+
if key not in truth_configs:
|
388
|
+
raise AttributeError(
|
389
|
+
f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
|
390
|
+
)
|
391
|
+
if key in truth_configs_override:
|
392
|
+
truth_configs_merged[key] |= truth_configs_override[key]
|
393
|
+
|
394
|
+
level_type = entry.get("level_type", level_type)
|
320
395
|
|
321
396
|
in_name, new_tokens = tokenized_expand(in_name)
|
322
397
|
tokens.update(new_tokens)
|
@@ -324,7 +399,7 @@ def append_target_files(
|
|
324
399
|
if not names:
|
325
400
|
raise OSError(f"Could not find {in_name}. Make sure path exists")
|
326
401
|
|
327
|
-
|
402
|
+
source_files: list[SourceFile] = []
|
328
403
|
for name in names:
|
329
404
|
ext = splitext(name)[1].lower()
|
330
405
|
dir_name = dirname(name)
|
@@ -333,9 +408,10 @@ def append_target_files(
|
|
333
408
|
child = file
|
334
409
|
if not isabs(child):
|
335
410
|
child = join(dir_name, child)
|
336
|
-
|
337
|
-
|
338
|
-
|
411
|
+
source_files.extend(
|
412
|
+
append_source_files(
|
413
|
+
category=category,
|
414
|
+
entry={"name": child},
|
339
415
|
class_indices=class_indices,
|
340
416
|
truth_configs=truth_configs_merged,
|
341
417
|
level_type=level_type,
|
@@ -355,41 +431,26 @@ def append_target_files(
|
|
355
431
|
tokens.update(new_tokens)
|
356
432
|
if not isabs(child):
|
357
433
|
child = join(dir_name, child)
|
358
|
-
|
359
|
-
|
360
|
-
|
434
|
+
source_files.extend(
|
435
|
+
append_source_files(
|
436
|
+
category=category,
|
437
|
+
entry={"name": child},
|
361
438
|
class_indices=class_indices,
|
362
439
|
truth_configs=truth_configs_merged,
|
363
440
|
level_type=level_type,
|
364
441
|
tokens=tokens,
|
365
442
|
)
|
366
443
|
)
|
367
|
-
elif ext == ".yml":
|
368
|
-
try:
|
369
|
-
yml_config = raw_load_config(name)
|
370
|
-
|
371
|
-
if "targets" in yml_config:
|
372
|
-
for record in yml_config["targets"]:
|
373
|
-
target_files.extend(
|
374
|
-
append_target_files(
|
375
|
-
entry=record,
|
376
|
-
class_indices=class_indices,
|
377
|
-
truth_configs=truth_configs_merged,
|
378
|
-
level_type=level_type,
|
379
|
-
tokens=tokens,
|
380
|
-
)
|
381
|
-
)
|
382
|
-
except Exception as e:
|
383
|
-
raise OSError(f"Error processing {name}: {e}") from e
|
384
444
|
else:
|
385
445
|
validate_input_file(name)
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
446
|
+
source_file = SourceFile(
|
447
|
+
category=category,
|
448
|
+
name=tokenized_replace(name, tokens),
|
449
|
+
samples=0,
|
450
|
+
class_indices=class_indices,
|
451
|
+
level_type=level_type,
|
452
|
+
truth_configs={},
|
453
|
+
)
|
393
454
|
if len(truth_configs_merged) > 0:
|
394
455
|
for tc_key, tc_value in truth_configs_merged.items():
|
395
456
|
config = deepcopy(tc_value)
|
@@ -398,145 +459,58 @@ def append_target_files(
|
|
398
459
|
truth_config[key] = config[key]
|
399
460
|
del config[key]
|
400
461
|
truth_config["config"] = config
|
401
|
-
|
402
|
-
for tc_key in
|
462
|
+
source_file.truth_configs[tc_key] = dataclass_from_dict(TruthConfig, truth_config)
|
463
|
+
for tc_key in source_file.truth_configs:
|
403
464
|
if (
|
404
465
|
"function" in truth_configs_merged[tc_key]
|
405
466
|
and truth_configs_merged[tc_key]["function"] == "file"
|
406
467
|
):
|
407
|
-
truth_configs_merged[tc_key]["file"] = splitext(
|
408
|
-
|
409
|
-
except Exception as e:
|
410
|
-
raise OSError(f"Error processing {name}: {e}") from e
|
411
|
-
|
412
|
-
return target_files
|
413
|
-
|
414
|
-
|
415
|
-
def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
|
416
|
-
"""Get the list of noise files from a config
|
417
|
-
|
418
|
-
:param config: Config dictionary
|
419
|
-
:param show_progress: Show progress bar
|
420
|
-
:return: List of noise file
|
421
|
-
"""
|
422
|
-
from itertools import chain
|
423
|
-
|
424
|
-
from sonusai.utils import dataclass_from_dict
|
425
|
-
from sonusai.utils import par_track
|
426
|
-
from sonusai.utils import track
|
427
|
-
|
428
|
-
from .datatypes import NoiseFile
|
429
|
-
|
430
|
-
noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
|
431
|
-
|
432
|
-
progress = track(total=len(noise_files), disable=not show_progress)
|
433
|
-
noise_files = par_track(_get_num_samples, noise_files, progress=progress)
|
434
|
-
progress.close()
|
435
|
-
|
436
|
-
return dataclass_from_dict(list[NoiseFile], noise_files)
|
437
|
-
|
438
|
-
|
439
|
-
def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
|
440
|
-
"""Process noise files list and append as needed
|
441
|
-
|
442
|
-
:param entry: Noise file entry to append to the list
|
443
|
-
:param tokens: Tokens used for variable expansion
|
444
|
-
:return: List of noise files
|
445
|
-
"""
|
446
|
-
from glob import glob
|
447
|
-
from os import listdir
|
448
|
-
from os.path import dirname
|
449
|
-
from os.path import isabs
|
450
|
-
from os.path import isdir
|
451
|
-
from os.path import join
|
452
|
-
from os.path import splitext
|
453
|
-
|
454
|
-
from .audio import validate_input_file
|
455
|
-
from .tokenized_shell_vars import tokenized_expand
|
456
|
-
from .tokenized_shell_vars import tokenized_replace
|
457
|
-
|
458
|
-
if tokens is None:
|
459
|
-
tokens = {}
|
460
|
-
|
461
|
-
if isinstance(entry, dict):
|
462
|
-
if "name" in entry:
|
463
|
-
in_name = entry["name"]
|
464
|
-
else:
|
465
|
-
raise AttributeError("Noise list contained record without name")
|
466
|
-
else:
|
467
|
-
in_name = entry
|
468
|
-
|
469
|
-
in_name, new_tokens = tokenized_expand(in_name)
|
470
|
-
tokens.update(new_tokens)
|
471
|
-
names = sorted(glob(in_name))
|
472
|
-
if not names:
|
473
|
-
raise OSError(f"Could not find {in_name}. Make sure path exists")
|
474
|
-
|
475
|
-
noise_files: list[dict] = []
|
476
|
-
for name in names:
|
477
|
-
ext = splitext(name)[1].lower()
|
478
|
-
dir_name = dirname(name)
|
479
|
-
if isdir(name):
|
480
|
-
for file in listdir(name):
|
481
|
-
child = file
|
482
|
-
if not isabs(child):
|
483
|
-
child = join(dir_name, child)
|
484
|
-
noise_files.extend(append_noise_files(entry=child, tokens=tokens))
|
485
|
-
else:
|
486
|
-
try:
|
487
|
-
if ext == ".txt":
|
488
|
-
with open(file=name) as txt_file:
|
489
|
-
for line in txt_file:
|
490
|
-
# strip comments
|
491
|
-
child = line.partition("#")[0]
|
492
|
-
child = child.rstrip()
|
493
|
-
if child:
|
494
|
-
child, new_tokens = tokenized_expand(child)
|
495
|
-
tokens.update(new_tokens)
|
496
|
-
if not isabs(child):
|
497
|
-
child = join(dir_name, child)
|
498
|
-
noise_files.extend(append_noise_files(entry=child, tokens=tokens))
|
499
|
-
elif ext == ".yml":
|
500
|
-
try:
|
501
|
-
yml_config = raw_load_config(name)
|
502
|
-
|
503
|
-
if "noises" in yml_config:
|
504
|
-
for record in yml_config["noises"]:
|
505
|
-
noise_files.extend(append_noise_files(entry=record, tokens=tokens))
|
506
|
-
except Exception as e:
|
507
|
-
raise OSError(f"Error processing {name}: {e}") from e
|
508
|
-
else:
|
509
|
-
validate_input_file(name)
|
510
|
-
noise_file: dict = {
|
511
|
-
"expanded_name": name,
|
512
|
-
"name": tokenized_replace(name, tokens),
|
513
|
-
}
|
514
|
-
noise_files.append(noise_file)
|
468
|
+
truth_configs_merged[tc_key]["file"] = splitext(source_file.name)[0] + ".h5"
|
469
|
+
source_files.append(source_file)
|
515
470
|
except Exception as e:
|
516
471
|
raise OSError(f"Error processing {name}: {e}") from e
|
517
472
|
|
518
|
-
return
|
473
|
+
return source_files
|
519
474
|
|
520
475
|
|
521
|
-
def
|
476
|
+
def get_ir_files(config: dict, show_progress: bool = False) -> list[ImpulseResponseFile]:
|
522
477
|
"""Get the list of impulse response files from a config
|
523
478
|
|
524
479
|
:param config: Config dictionary
|
480
|
+
:param show_progress: Show progress bar
|
525
481
|
:return: List of impulse response files
|
526
482
|
"""
|
527
483
|
from itertools import chain
|
528
484
|
|
529
|
-
|
485
|
+
from ..utils.parallel import par_track
|
486
|
+
from ..utils.parallel import track
|
487
|
+
|
488
|
+
ir_files = list(
|
530
489
|
chain.from_iterable(
|
531
490
|
[
|
532
|
-
|
491
|
+
append_ir_files(
|
492
|
+
entry=ImpulseResponseFile(
|
493
|
+
name=entry["name"],
|
494
|
+
tags=entry.get("tags", []),
|
495
|
+
delay=entry.get("delay", "auto"),
|
496
|
+
)
|
497
|
+
)
|
533
498
|
for entry in config["impulse_responses"]
|
534
499
|
]
|
535
500
|
)
|
536
501
|
)
|
537
502
|
|
503
|
+
if len(ir_files) == 0:
|
504
|
+
return []
|
505
|
+
|
506
|
+
progress = track(total=len(ir_files), disable=not show_progress)
|
507
|
+
ir_files = par_track(_get_ir_delay, ir_files, progress=progress)
|
508
|
+
progress.close()
|
509
|
+
|
510
|
+
return ir_files
|
511
|
+
|
538
512
|
|
539
|
-
def
|
513
|
+
def append_ir_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
|
540
514
|
"""Process impulse response files list and append as needed
|
541
515
|
|
542
516
|
:param entry: Impulse response file entry to append to the list
|
@@ -551,21 +525,20 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
551
525
|
from os.path import join
|
552
526
|
from os.path import splitext
|
553
527
|
|
528
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
529
|
+
from ..utils.tokenized_shell_vars import tokenized_replace
|
554
530
|
from .audio import validate_input_file
|
555
|
-
from .ir_delay import get_impulse_response_delay
|
556
|
-
from .tokenized_shell_vars import tokenized_expand
|
557
|
-
from .tokenized_shell_vars import tokenized_replace
|
558
531
|
|
559
532
|
if tokens is None:
|
560
533
|
tokens = {}
|
561
534
|
|
562
|
-
in_name, new_tokens = tokenized_expand(entry.
|
535
|
+
in_name, new_tokens = tokenized_expand(entry.name)
|
563
536
|
tokens.update(new_tokens)
|
564
537
|
names = sorted(glob(in_name))
|
565
538
|
if not names:
|
566
539
|
raise OSError(f"Could not find {in_name}. Make sure path exists")
|
567
540
|
|
568
|
-
|
541
|
+
ir_files: list[ImpulseResponseFile] = []
|
569
542
|
for name in names:
|
570
543
|
ext = splitext(name)[1].lower()
|
571
544
|
dir_name = dirname(name)
|
@@ -573,8 +546,8 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
573
546
|
for file in listdir(name):
|
574
547
|
if not isabs(file):
|
575
548
|
file = join(dir_name, file)
|
576
|
-
child = ImpulseResponseFile(file, entry.tags,
|
577
|
-
|
549
|
+
child = ImpulseResponseFile(file, entry.tags, entry.delay)
|
550
|
+
ir_files.extend(append_ir_files(entry=child, tokens=tokens))
|
578
551
|
else:
|
579
552
|
try:
|
580
553
|
if ext == ".txt":
|
@@ -588,30 +561,24 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
588
561
|
tokens.update(new_tokens)
|
589
562
|
if not isabs(file):
|
590
563
|
file = join(dir_name, file)
|
591
|
-
child = ImpulseResponseFile(file, entry.tags,
|
592
|
-
|
564
|
+
child = ImpulseResponseFile(file, entry.tags, entry.delay)
|
565
|
+
ir_files.extend(append_ir_files(entry=child, tokens=tokens))
|
593
566
|
elif ext == ".yml":
|
594
567
|
try:
|
595
568
|
yml_config = raw_load_config(name)
|
596
569
|
|
597
570
|
if "impulse_responses" in yml_config:
|
598
571
|
for record in yml_config["impulse_responses"]:
|
599
|
-
|
600
|
-
append_impulse_response_files(entry=record, tokens=tokens)
|
601
|
-
)
|
572
|
+
ir_files.extend(append_ir_files(entry=record, tokens=tokens))
|
602
573
|
except Exception as e:
|
603
574
|
raise OSError(f"Error processing {name}: {e}") from e
|
604
575
|
else:
|
605
576
|
validate_input_file(name)
|
606
|
-
|
607
|
-
ImpulseResponseFile(
|
608
|
-
tokenized_replace(name, tokens), entry.tags, get_impulse_response_delay(name)
|
609
|
-
)
|
610
|
-
)
|
577
|
+
ir_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags, entry.delay))
|
611
578
|
except Exception as e:
|
612
579
|
raise OSError(f"Error processing {name}: {e}") from e
|
613
580
|
|
614
|
-
return
|
581
|
+
return ir_files
|
615
582
|
|
616
583
|
|
617
584
|
def get_spectral_masks(config: dict) -> list[SpectralMask]:
|
@@ -620,10 +587,10 @@ def get_spectral_masks(config: dict) -> list[SpectralMask]:
|
|
620
587
|
:param config: Config dictionary
|
621
588
|
:return: List of spectral masks
|
622
589
|
"""
|
623
|
-
from
|
590
|
+
from ..utils.dataclass_from_dict import list_dataclass_from_dict
|
624
591
|
|
625
592
|
try:
|
626
|
-
return
|
593
|
+
return list_dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
|
627
594
|
except Exception as e:
|
628
595
|
raise ValueError(f"Error in spectral_masks: {e}") from e
|
629
596
|
|
@@ -636,30 +603,43 @@ def get_truth_parameters(config: dict) -> list[TruthParameter]:
|
|
636
603
|
"""
|
637
604
|
from copy import deepcopy
|
638
605
|
|
639
|
-
from
|
640
|
-
|
606
|
+
from . import truth_functions
|
641
607
|
from .constants import REQUIRED_TRUTH_CONFIGS
|
642
|
-
from .datatypes import TruthParameter
|
643
608
|
|
644
609
|
truth_parameters: list[TruthParameter] = []
|
645
|
-
for
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
610
|
+
for category, source_config in config["sources"].items():
|
611
|
+
if "truth_configs" in source_config:
|
612
|
+
for truth_name, truth_config in source_config["truth_configs"].items():
|
613
|
+
optional_config = deepcopy(truth_config)
|
614
|
+
for key in REQUIRED_TRUTH_CONFIGS:
|
615
|
+
del optional_config[key]
|
616
|
+
|
617
|
+
parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
|
618
|
+
config["feature"],
|
619
|
+
config["num_classes"],
|
620
|
+
optional_config,
|
621
|
+
)
|
622
|
+
truth_parameters.append(TruthParameter(category, truth_name, parameters))
|
656
623
|
|
657
624
|
return truth_parameters
|
658
625
|
|
659
626
|
|
660
|
-
def _get_num_samples(entry:
|
627
|
+
def _get_num_samples(entry: SourceFile) -> SourceFile:
|
661
628
|
from .audio import get_num_samples
|
662
629
|
|
663
|
-
entry
|
664
|
-
|
630
|
+
entry.samples = get_num_samples(entry.name)
|
631
|
+
return entry
|
632
|
+
|
633
|
+
|
634
|
+
def _get_ir_delay(entry: ImpulseResponseFile) -> ImpulseResponseFile:
|
635
|
+
from .ir_delay import get_ir_delay
|
636
|
+
|
637
|
+
if entry.delay == "auto":
|
638
|
+
entry.delay = get_ir_delay(entry.name)
|
639
|
+
else:
|
640
|
+
try:
|
641
|
+
entry.delay = int(entry.delay)
|
642
|
+
except ValueError as e:
|
643
|
+
raise ValueError(f"Invalid impulse response delay: {entry.delay}") from e
|
644
|
+
|
665
645
|
return entry
|