sonusai 0.18.9-py3-none-any.whl → 0.19.6-py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/mixture/config.py
CHANGED
@@ -1,7 +1,9 @@
+from sonusai.mixture.datatypes import ImpulseResponseFile
 from sonusai.mixture.datatypes import ImpulseResponseFiles
 from sonusai.mixture.datatypes import NoiseFiles
 from sonusai.mixture.datatypes import SpectralMasks
 from sonusai.mixture.datatypes import TargetFiles
+from sonusai.mixture.datatypes import TruthParameters


 def raw_load_config(name: str) -> dict:
@@ -12,7 +14,7 @@ def raw_load_config(name: str) -> dict:
     """
     import yaml

-    with open(file=name
+    with open(file=name) as f:
         config = yaml.safe_load(f)

     return config
@@ -23,13 +25,12 @@ def get_default_config() -> dict:

     :return: Dictionary of default config data
     """
-    from sonusai import SonusAIError
     from .constants import DEFAULT_CONFIG

     try:
         return raw_load_config(DEFAULT_CONFIG)
     except Exception as e:
-        raise
+        raise OSError(f"Error loading default config: {e}") from e


 def load_config(name: str) -> dict:
@@ -40,125 +41,115 @@ def load_config(name: str) -> dict:
     """
     from os.path import join

-    return update_config_from_file(
+    return update_config_from_file(filename=join(name, "config.yml"), given_config=get_default_config())


-def update_config_from_file(
-    """Update the given config with the config in the YAML file
+def update_config_from_file(filename: str, given_config: dict) -> dict:
+    """Update the given config with the config in the specified YAML file

-    :param
-    :param
+    :param filename: File name
+    :param given_config: Config dictionary to update
     :return: Updated config dictionary
     """
     from copy import deepcopy

-    from sonusai import SonusAIError
     from .constants import REQUIRED_CONFIGS
     from .constants import VALID_CONFIGS
     from .constants import VALID_NOISE_MIX_MODES

-    updated_config = deepcopy(
+    updated_config = deepcopy(given_config)

     try:
-
+        file_config = raw_load_config(filename)
     except Exception as e:
-        raise
+        raise OSError(f"Error loading config from {filename}: {e}") from e

     # Check for unrecognized keys
-    for key in
+    for key in file_config:
         if key not in VALID_CONFIGS:
-            nice_list =
-            raise
-
+            nice_list = "\n".join([f" {item}" for item in VALID_CONFIGS])
+            raise AttributeError(
+                f"Invalid config parameter in {filename}: {key}.\nValid config parameters are:\n{nice_list}"
+            )

     # Use default config as base and overwrite with given config keys as found
     for key in updated_config:
-        if key in
-
-            updated_config[key] = new_config[key]
-
-    # Handle 'truth_settings' special case
-    if 'truth_settings' in new_config:
-        updated_config['truth_settings'] = deepcopy(new_config['truth_settings'])
-
-        if not isinstance(updated_config['truth_settings'], list):
-            updated_config['truth_settings'] = [updated_config['truth_settings']]
-
-        default = deepcopy(config['truth_settings'])
-        if not isinstance(default, list):
-            default = [default]
-
-        updated_config['truth_settings'] = update_truth_settings(updated_config['truth_settings'], default)
-
-    # Handle 'asr_configs' special case
-    if 'asr_configs' in updated_config:
-        asr_configs = {}
-        for asr_config in updated_config['asr_configs']:
-            asr_name = asr_config.get('name', None)
-            asr_engine = asr_config.get('engine', None)
-            if asr_name is None or asr_engine is None:
-                raise SonusAIError(f'Invalid config parameter in {name}: asr_configs.\n'
-                                   f'asr_configs must contain both name and engine.')
-            del asr_config['name']
-            asr_configs[asr_name] = asr_config
-        updated_config['asr_configs'] = asr_configs
+        if key in file_config:
+            updated_config[key] = file_config[key]

     # Check for required keys
     for key in REQUIRED_CONFIGS:
         if key not in updated_config:
-            raise
+            raise AttributeError(f"{filename} is missing required '{key}'")
+
+    # Validate special cases
+    validate_truth_configs(updated_config)
+    validate_asr_configs(updated_config)

     # Check for non-empty spectral masks
-    if len(updated_config[
-        updated_config[
+    if len(updated_config["spectral_masks"]) == 0:
+        updated_config["spectral_masks"] = given_config["spectral_masks"]

     # Check for valid noise_mix_mode
-    if updated_config[
-        nice_list =
-        raise
-                           f'Valid noise mix modes are:\n{nice_list}')
+    if updated_config["noise_mix_mode"] not in VALID_NOISE_MIX_MODES:
+        nice_list = "\n".join([f" {item}" for item in VALID_NOISE_MIX_MODES])
+        raise ValueError(f"{filename} contains invalid noise_mix_mode.\nValid noise mix modes are:\n{nice_list}")

     return updated_config


-def
-    """
+def validate_truth_configs(given: dict) -> None:
+    """Validate fields in given 'truth_configs'

-    :param given: The dictionary of given
-    :param default: The dictionary of default truth settings
-    :return: Updated dictionary of truth settings
+    :param given: The dictionary of given config
     """
     from copy import deepcopy

-    from sonusai import
-    from .constants import VALID_TRUTH_SETTINGS
+    from sonusai.mixture import truth_functions

-
-        truth_settings = deepcopy(given)
-    else:
-        truth_settings = [deepcopy(given)]
+    from .constants import REQUIRED_TRUTH_CONFIGS

-    if
-        raise
+    if "truth_configs" not in given:
+        raise AttributeError("config is missing required 'truth_configs'")

-
-
-
-            nice_list = '\n'.join([f' {item}' for item in VALID_TRUTH_SETTINGS])
-            raise SonusAIError(f'Invalid truth_settings: {key}.\nValid truth_settings are:\n{nice_list}')
+    truth_configs = given["truth_configs"]
+    if len(truth_configs) == 0:
+        raise ValueError("'truth_configs' in config is empty")

-
-
-
-
-
-
+    for name, truth_config in truth_configs.items():
+        for key in REQUIRED_TRUTH_CONFIGS:
+            if key not in truth_config:
+                raise AttributeError(f"'{name}' in truth_configs is missing required '{key}'")
+
+        optional_config = deepcopy(truth_config)
+        for key in REQUIRED_TRUTH_CONFIGS:
+            del optional_config[key]
+
+        getattr(truth_functions, truth_config["function"] + "_validate")(optional_config)
+
+
+def validate_asr_configs(given: dict) -> None:
+    """Validate fields in given 'asr_config'
+
+    :param given: The dictionary of given config
+    """
+    from sonusai.utils import validate_asr
+
+    from .constants import REQUIRED_ASR_CONFIGS
+
+    if "asr_configs" not in given:
+        raise AttributeError("config is missing required 'asr_configs'")
+
+    asr_configs = given["asr_configs"]

-    for
-
-
+    for name, asr_config in asr_configs.items():
+        for key in REQUIRED_ASR_CONFIGS:
+            if key not in asr_config:
+                raise AttributeError(f"'{name}' in asr_configs is missing required '{key}'")

-
+        engine = asr_config["engine"]
+        config = {x: asr_config[x] for x in asr_config if x != "engine"}
+        validate_asr(engine, **config)


 def get_hierarchical_config_files(root: str, leaf: str) -> list[str]:
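The two validators added above replace the inline `truth_settings` / `asr_configs` normalization that `update_config_from_file` previously performed. Below is a rough sketch of the config shape they expect, based only on what is visible in this hunk: both sections are dictionaries keyed by a user-chosen name, `function` and `engine` are accessed as required keys, and the remaining keys are forwarded either to the per-function `<function>_validate` helpers or to `sonusai.utils.validate_asr`. The full required-key lists live in `REQUIRED_TRUTH_CONFIGS` and `REQUIRED_ASR_CONFIGS` in `constants.py`, which this diff does not show; the entry names and options below are illustrative only.

# Hypothetical 0.19.x-style config fragment, expressed as a Python dict for illustration.
config_fragment = {
    "truth_configs": {
        # keyed by a user-chosen name; "function" selects the truth function and is
        # dispatched as truth_functions.<function>_validate(optional_config)
        "sed_truth": {
            "function": "sed",
            # ... any remaining keys are function-specific options ...
        },
    },
    "asr_configs": {
        # also keyed by name; "engine" selects the ASR back end and the remaining
        # keys are forwarded as validate_asr(engine, **options)
        "whisper": {
            "engine": "aaware_whisper",
            # ... engine-specific options ...
        },
    },
}

Per-target overrides, handled later in `append_target_files`, merge a target's own `truth_configs` entries over these top-level definitions with `|=`.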
@@ -171,25 +162,23 @@ def get_hierarchical_config_files(root: str, leaf: str) -> list[str]:
     import os
     from pathlib import Path

-
-
-    config_file = 'config.yml'
+    config_file = "config.yml"

     root_path = Path(os.path.abspath(root))
     if not root_path.is_dir():
-        raise
+        raise OSError(f"Given root, {root_path}, is not a directory.")

     leaf_path = Path(os.path.abspath(leaf))
     if not leaf_path.is_dir():
-        raise
+        raise OSError(f"Given leaf, {leaf_path}, is not a directory.")

     common = os.path.commonpath((root_path, leaf_path))
     if os.path.normpath(common) != os.path.normpath(root_path):
-        raise
+        raise OSError(f"Given leaf, {leaf_path}, is not in the hierarchy of the given root, {root_path}")

     top_config_file = os.path.join(root_path, config_file)
     if not Path(top_config_file).is_file():
-        raise
+        raise OSError(f"Could not find {top_config_file}")

     current = leaf_path
     config_files = []
@@ -216,24 +205,11 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
     new_config = deepcopy(config)
     config_files = get_hierarchical_config_files(root=root, leaf=leaf)
     for config_file in config_files:
-        new_config = update_config_from_file(
+        new_config = update_config_from_file(filename=config_file, given_config=new_config)

     return new_config


-def get_max_class(num_classes: int, truth_mutex: bool) -> int:
-    """Get the maximum class index
-
-    :param num_classes: Number of classes
-    :param truth_mutex: Truth is mutex mode
-    :return: Highest class index
-    """
-    max_class = num_classes
-    if truth_mutex:
-        max_class -= 1
-    return max_class
-
-
 def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
     """Get the list of target files from a config

@@ -243,48 +219,62 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
     """
     from itertools import chain

-    from tqdm import tqdm
-
-    from sonusai import SonusAIError
     from sonusai.utils import dataclass_from_dict
-    from sonusai.utils import
-    from .
+    from sonusai.utils import par_track
+    from sonusai.utils import track

-
-    level_type = config.get('target_level_type', None)
-    target_files = list(chain.from_iterable([append_target_files(entry=entry,
-                                                                 truth_settings=truth_settings,
-                                                                 level_type=level_type)
-                                             for entry in config['targets']]))
+    from .datatypes import TargetFiles

-
-
+    class_indices = config["class_indices"]
+    if not isinstance(class_indices, list):
+        class_indices = [class_indices]
+
+    target_files = list(
+        chain.from_iterable(
+            [
+                append_target_files(
+                    entry=entry,
+                    class_indices=class_indices,
+                    truth_configs=config["truth_configs"],
+                    level_type=config["target_level_type"],
+                )
+                for entry in config["targets"]
+            ]
+        )
+    )
+
+    progress = track(total=len(target_files), disable=not show_progress)
+    target_files = par_track(_get_num_samples, target_files, progress=progress)
     progress.close()

-
-
+    num_classes = config["num_classes"]
     for target_file in target_files:
-
+        if any(class_index < 0 for class_index in target_file["class_indices"]):
+            raise ValueError("class indices must contain only positive elements")

-        for
-
-            raise SonusAIError('invalid truth index')
+        if any(class_index > num_classes for class_index in target_file["class_indices"]):
+            raise ValueError(f"class index elements must not be greater than {num_classes}")

     return dataclass_from_dict(TargetFiles, target_files)


-def append_target_files(
-
-
-
+def append_target_files(
+    entry: dict | str,
+    class_indices: list[int],
+    truth_configs: dict,
+    level_type: str,
+    tokens: dict | None = None,
+) -> list[dict]:
     """Process target files list and append as needed

     :param entry: Target file entry to append to the list
-    :param
+    :param class_indices: Class indices
+    :param truth_configs: Truth configs
     :param level_type: Target level type
     :param tokens: Tokens used for variable expansion
     :return: List of target files
     """
+    from copy import deepcopy
     from glob import glob
     from os import listdir
     from os.path import dirname
@@ -293,8 +283,11 @@ def append_target_files(entry: dict | str,
     from os.path import join
     from os.path import splitext

-    from sonusai import
+    from sonusai.utils import dataclass_from_dict
+
     from .audio import validate_input_file
+    from .constants import REQUIRED_TRUTH_CONFIGS
+    from .datatypes import TruthConfig
     from .tokenized_shell_vars import tokenized_expand
     from .tokenized_shell_vars import tokenized_replace

@@ -302,23 +295,38 @@ def append_target_files(entry: dict | str,
         tokens = {}

     if isinstance(entry, dict):
-        if
-            in_name = entry[
+        if "name" in entry:
+            in_name = entry["name"]
         else:
-            raise
-
-        if
-
-
-
+            raise AttributeError("Target list contained record without name")
+
+        if "class_indices" in entry:
+            if isinstance(entry["class_indices"], list):
+                class_indices = entry["class_indices"]
+            else:
+                class_indices = [entry["class_indices"]]
+
+        truth_configs_override = entry.get("truth_configs", {})
+        for key in truth_configs_override:
+            if key not in truth_configs:
+                raise AttributeError(
+                    f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
+                )
+        truth_configs_merged = {}
+        for key in truth_configs_override:
+            truth_configs_merged[key] = deepcopy(truth_configs[key])
+            if truth_configs_override[key] is not None:
+                truth_configs_merged[key] |= truth_configs_override[key]
+        level_type = entry.get("level_type", level_type)
     else:
         in_name = entry
+        truth_configs_merged = deepcopy(truth_configs)

     in_name, new_tokens = tokenized_expand(in_name)
     tokens.update(new_tokens)
     names = sorted(glob(in_name))
     if not names:
-        raise
+        raise OSError(f"Could not find {in_name}. Make sure path exists")

     target_files: list[dict] = []
     for name in names:
@@ -329,57 +337,81 @@ def append_target_files(entry: dict | str,
                 child = file
                 if not isabs(child):
                     child = join(dir_name, child)
-                target_files.extend(
-
-
-
+                target_files.extend(
+                    append_target_files(
+                        entry=child,
+                        class_indices=class_indices,
+                        truth_configs=truth_configs_merged,
+                        level_type=level_type,
+                        tokens=tokens,
+                    )
+                )
         else:
             try:
-                if ext ==
-                    with open(file=name
+                if ext == ".txt":
+                    with open(file=name) as txt_file:
                         for line in txt_file:
                             # strip comments
-                            child = line.partition(
+                            child = line.partition("#")[0]
                             child = child.rstrip()
                             if child:
                                 child, new_tokens = tokenized_expand(child)
                                 tokens.update(new_tokens)
                                 if not isabs(child):
                                     child = join(dir_name, child)
-                                target_files.extend(
-
-
-
-
+                                target_files.extend(
+                                    append_target_files(
+                                        entry=child,
+                                        class_indices=class_indices,
+                                        truth_configs=truth_configs_merged,
+                                        level_type=level_type,
+                                        tokens=tokens,
+                                    )
+                                )
+                elif ext == ".yml":
                     try:
                         yml_config = raw_load_config(name)

-                        if
-                            for record in yml_config[
-                                target_files.extend(
-
-
-
+                        if "targets" in yml_config:
+                            for record in yml_config["targets"]:
+                                target_files.extend(
+                                    append_target_files(
+                                        entry=record,
+                                        class_indices=class_indices,
+                                        truth_configs=truth_configs_merged,
+                                        level_type=level_type,
+                                        tokens=tokens,
+                                    )
+                                )
                     except Exception as e:
-                        raise
+                        raise OSError(f"Error processing {name}: {e}") from e
                 else:
                     validate_input_file(name)
                     target_file: dict = {
-
-
+                        "expanded_name": name,
+                        "name": tokenized_replace(name, tokens),
+                        "class_indices": class_indices,
+                        "level_type": level_type,
+                        "truth_configs": {},
                     }
-                    if len(
-
-
-
-
-
-
+                    if len(truth_configs_merged) > 0:
+                        for tc_key, tc_value in truth_configs_merged.items():
+                            config = deepcopy(tc_value)
+                            truth_config: dict = {}
+                            for key in REQUIRED_TRUTH_CONFIGS:
+                                truth_config[key] = config[key]
+                                del config[key]
+                            truth_config["config"] = config
+                            target_file["truth_configs"][tc_key] = dataclass_from_dict(TruthConfig, truth_config)
+                        for tc_key in target_file["truth_configs"]:
+                            if (
+                                "function" in truth_configs_merged[tc_key]
+                                and truth_configs_merged[tc_key]["function"] == "file"
+                            ):
+                                truth_configs_merged[tc_key]["file"] = splitext(target_file["name"])[0] + ".h5"
                     target_files.append(target_file)
-            except SonusAIError:
-                raise
             except Exception as e:
-                raise
+                raise OSError(f"Error processing {name}: {e}") from e

     return target_files

@@ -393,22 +425,22 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
     """
     from itertools import chain

-    from tqdm import tqdm
-
     from sonusai.utils import dataclass_from_dict
-    from sonusai.utils import
+    from sonusai.utils import par_track
+    from sonusai.utils import track
+
     from .datatypes import NoiseFiles

-    noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config[
+    noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))

-    progress =
-    noise_files =
+    progress = track(total=len(noise_files), disable=not show_progress)
+    noise_files = par_track(_get_num_samples, noise_files, progress=progress)
     progress.close()

     return dataclass_from_dict(NoiseFiles, noise_files)


-def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
+def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
     """Process noise files list and append as needed

     :param entry: Noise file entry to append to the list
@@ -423,7 +455,6 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
     from os.path import join
     from os.path import splitext

-    from sonusai import SonusAIError
     from .audio import validate_input_file
     from .tokenized_shell_vars import tokenized_expand
     from .tokenized_shell_vars import tokenized_replace
@@ -432,10 +463,10 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
         tokens = {}

     if isinstance(entry, dict):
-        if
-            in_name = entry[
+        if "name" in entry:
+            in_name = entry["name"]
         else:
-            raise
+            raise AttributeError("Noise list contained record without name")
     else:
         in_name = entry

@@ -443,7 +474,7 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
     tokens.update(new_tokens)
     names = sorted(glob(in_name))
     if not names:
-        raise
+        raise OSError(f"Could not find {in_name}. Make sure path exists")

     noise_files: list[dict] = []
     for name in names:
@@ -457,11 +488,11 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
                 noise_files.extend(append_noise_files(entry=child, tokens=tokens))
         else:
             try:
-                if ext ==
-                    with open(file=name
+                if ext == ".txt":
+                    with open(file=name) as txt_file:
                         for line in txt_file:
                             # strip comments
-                            child = line.partition(
+                            child = line.partition("#")[0]
                             child = child.rstrip()
                             if child:
                                 child, new_tokens = tokenized_expand(child)
@@ -469,26 +500,24 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
                                 if not isabs(child):
                                     child = join(dir_name, child)
                                 noise_files.extend(append_noise_files(entry=child, tokens=tokens))
-                elif ext ==
+                elif ext == ".yml":
                     try:
                         yml_config = raw_load_config(name)

-                        if
-                            for record in yml_config[
+                        if "noises" in yml_config:
+                            for record in yml_config["noises"]:
                                 noise_files.extend(append_noise_files(entry=record, tokens=tokens))
                     except Exception as e:
-                        raise
+                        raise OSError(f"Error processing {name}: {e}") from e
                 else:
                     validate_input_file(name)
                     noise_file: dict = {
-
-
+                        "expanded_name": name,
+                        "name": tokenized_replace(name, tokens),
                     }
                     noise_files.append(noise_file)
-            except SonusAIError:
-                raise
             except Exception as e:
-                raise
+                raise OSError(f"Error processing {name}: {e}") from e

     return noise_files

@@ -499,13 +528,20 @@ def get_impulse_response_files(config: dict) -> ImpulseResponseFiles:
     :param config: Config dictionary
     :return: List of impulse response files
     """
-
-
-
-
-
-
-
+    return [ImpulseResponseFile(entry["name"], entry["tags"]) for entry in config["impulse_responses"]]
+    # from itertools import chain
+    #
+    # return list(
+    #     chain.from_iterable(
+    #         [
+    #             append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry["tags"]))
+    #             for entry in config["impulse_responses"]
+    #         ]
+    #     )
+    # )
+
+
+def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
     """Process impulse response files list and append as needed

     :param entry: Impulse response file entry to append to the list
@@ -520,7 +556,6 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
     from os.path import join
     from os.path import splitext

-    from sonusai import SonusAIError
     from .audio import validate_input_file
     from .tokenized_shell_vars import tokenized_expand
     from .tokenized_shell_vars import tokenized_replace
@@ -528,11 +563,11 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
     if tokens is None:
         tokens = {}

-    in_name, new_tokens = tokenized_expand(entry)
+    in_name, new_tokens = tokenized_expand(entry.file)
     tokens.update(new_tokens)
     names = sorted(glob(in_name))
     if not names:
-        raise
+        raise OSError(f"Could not find {in_name}. Make sure path exists")

     impulse_response_files: list[str] = []
     for name in names:
@@ -540,41 +575,41 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
         dir_name = dirname(name)
         if isdir(name):
             for file in listdir(name):
-
-
-
+                if not isabs(file):
+                    file = join(dir_name, file)
+                child = ImpulseResponseFile(file, entry.tags)
                 impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
         else:
             try:
-                if ext ==
-                    with open(file=name
+                if ext == ".txt":
+                    with open(file=name) as txt_file:
                         for line in txt_file:
                             # strip comments
-
-
-                            if
-
+                            file = line.partition("#")[0]
+                            file = file.rstrip()
+                            if file:
+                                file, new_tokens = tokenized_expand(file)
                                 tokens.update(new_tokens)
-                                if not isabs(
-
+                                if not isabs(file):
+                                    file = join(dir_name, file)
+                                child = ImpulseResponseFile(file, entry.tags)
                                 impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
-                elif ext ==
+                elif ext == ".yml":
                     try:
                         yml_config = raw_load_config(name)

-                        if
-                            for record in yml_config[
+                        if "impulse_responses" in yml_config:
+                            for record in yml_config["impulse_responses"]:
                                 impulse_response_files.extend(
-                                    append_impulse_response_files(entry=record, tokens=tokens)
+                                    append_impulse_response_files(entry=record, tokens=tokens)
+                                )
                     except Exception as e:
-                        raise
+                        raise OSError(f"Error processing {name}: {e}") from e
                 else:
                     validate_input_file(name)
                     impulse_response_files.append(tokenized_replace(name, tokens))
-            except SonusAIError:
-                raise
             except Exception as e:
-                raise
+                raise OSError(f"Error processing {name}: {e}") from e

     return impulse_response_files

@@ -585,19 +620,51 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
     :param config: Config dictionary
     :return: List of spectral masks
     """
-    from sonusai import SonusAIError
     from sonusai.utils import dataclass_from_dict
-    from .datatypes import SpectralMasks

     try:
-        return dataclass_from_dict(SpectralMasks, config[
+        return dataclass_from_dict(SpectralMasks, config["spectral_masks"])
     except Exception as e:
-        raise
+        raise ValueError(f"Error in spectral_masks: {e}") from e
+
+
+def get_truth_parameters(config: dict) -> TruthParameters:
+    """Get the list of truth parameters from a config
+
+    :param config: Config dictionary
+    :return: List of truth parameters
+    """
+    from copy import deepcopy
+
+    from sonusai.mixture import truth_functions
+    from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
+
+    from .constants import REQUIRED_TRUTH_CONFIGS
+    from .datatypes import TruthParameter
+
+    truth_parameters: TruthParameters = []
+    for name, truth_config in config["truth_configs"].items():
+        optional_config = deepcopy(truth_config)
+        for key in REQUIRED_TRUTH_CONFIGS:
+            del optional_config[key]
+
+        t_config = TruthFunctionConfig(
+            feature=config["feature"],
+            num_classes=config["num_classes"],
+            class_indices=[1],
+            target_gain=1,
+            config=optional_config,
+        )
+
+        parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
+        truth_parameters.append(TruthParameter(name, parameters))
+
+    return truth_parameters


 def _get_num_samples(entry: dict) -> dict:
     from .audio import get_num_samples

-    entry[
-    del entry[
+    entry["samples"] = get_num_samples(entry["expanded_name"])
+    del entry["expanded_name"]
     return entry
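Taken together, the config.py changes swap SonusAIError for standard exceptions (OSError, AttributeError, ValueError), switch `update_config_from_file` to explicit `filename`/`given_config` keyword parameters, rename `truth_settings` to `truth_configs`, and add `get_truth_parameters` alongside the existing getters. Below is a minimal usage sketch of the reworked loader; it assumes sonusai 0.19.6 is installed and that a hypothetical my_mixdb/ directory contains a config.yml, and it uses only functions and exception types that appear in this diff.

from sonusai.mixture.config import get_impulse_response_files
from sonusai.mixture.config import get_noise_files
from sonusai.mixture.config import get_spectral_masks
from sonusai.mixture.config import get_target_files
from sonusai.mixture.config import get_truth_parameters
from sonusai.mixture.config import load_config

try:
    # Reads my_mixdb/config.yml on top of the packaged defaults and runs
    # validate_truth_configs()/validate_asr_configs() along the way.
    config = load_config("my_mixdb")

    targets = get_target_files(config, show_progress=True)  # expands globs, .txt lists, and .yml includes
    noises = get_noise_files(config, show_progress=True)
    impulse_responses = get_impulse_response_files(config)
    spectral_masks = get_spectral_masks(config)
    truth_parameters = get_truth_parameters(config)  # new in 0.19.x
except (OSError, AttributeError, ValueError) as e:
    # 0.19.x raises standard exceptions instead of the removed SonusAIError
    print(f"invalid mixture database config: {e}")

Using standard exception types lets callers distinguish I/O problems (OSError) from schema problems (AttributeError, ValueError) without importing a SonusAI-specific error class.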