sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/config.py CHANGED
@@ -1,7 +1,9 @@
1
+ from sonusai.mixture.datatypes import ImpulseResponseFile
1
2
  from sonusai.mixture.datatypes import ImpulseResponseFiles
2
3
  from sonusai.mixture.datatypes import NoiseFiles
3
4
  from sonusai.mixture.datatypes import SpectralMasks
4
5
  from sonusai.mixture.datatypes import TargetFiles
6
+ from sonusai.mixture.datatypes import TruthParameters
5
7
 
6
8
 
7
9
  def raw_load_config(name: str) -> dict:
@@ -12,7 +14,7 @@ def raw_load_config(name: str) -> dict:
12
14
  """
13
15
  import yaml
14
16
 
15
- with open(file=name, mode='r') as f:
17
+ with open(file=name) as f:
16
18
  config = yaml.safe_load(f)
17
19
 
18
20
  return config
@@ -23,13 +25,12 @@ def get_default_config() -> dict:
23
25
 
24
26
  :return: Dictionary of default config data
25
27
  """
26
- from sonusai import SonusAIError
27
28
  from .constants import DEFAULT_CONFIG
28
29
 
29
30
  try:
30
31
  return raw_load_config(DEFAULT_CONFIG)
31
32
  except Exception as e:
32
- raise SonusAIError(f'Error loading default config: {e}')
33
+ raise OSError(f"Error loading default config: {e}") from e
33
34
 
34
35
 
35
36
  def load_config(name: str) -> dict:
@@ -40,125 +41,115 @@ def load_config(name: str) -> dict:
40
41
  """
41
42
  from os.path import join
42
43
 
43
- return update_config_from_file(name=join(name, 'config.yml'), config=get_default_config())
44
+ return update_config_from_file(filename=join(name, "config.yml"), given_config=get_default_config())
44
45
 
45
46
 
46
- def update_config_from_file(name: str, config: dict) -> dict:
47
- """Update the given config with the config in the YAML file
47
+ def update_config_from_file(filename: str, given_config: dict) -> dict:
48
+ """Update the given config with the config in the specified YAML file
48
49
 
49
- :param name: File name
50
- :param config: Config dictionary to update
50
+ :param filename: File name
51
+ :param given_config: Config dictionary to update
51
52
  :return: Updated config dictionary
52
53
  """
53
54
  from copy import deepcopy
54
55
 
55
- from sonusai import SonusAIError
56
56
  from .constants import REQUIRED_CONFIGS
57
57
  from .constants import VALID_CONFIGS
58
58
  from .constants import VALID_NOISE_MIX_MODES
59
59
 
60
- updated_config = deepcopy(config)
60
+ updated_config = deepcopy(given_config)
61
61
 
62
62
  try:
63
- new_config = raw_load_config(name)
63
+ file_config = raw_load_config(filename)
64
64
  except Exception as e:
65
- raise SonusAIError(f'Error loading config from {name}: {e}')
65
+ raise OSError(f"Error loading config from {filename}: {e}") from e
66
66
 
67
67
  # Check for unrecognized keys
68
- for key in new_config:
68
+ for key in file_config:
69
69
  if key not in VALID_CONFIGS:
70
- nice_list = '\n'.join([f' {item}' for item in VALID_CONFIGS])
71
- raise SonusAIError(f'Invalid config parameter in {name}: {key}.\n'
72
- f'Valid config parameters are:\n{nice_list}')
70
+ nice_list = "\n".join([f" {item}" for item in VALID_CONFIGS])
71
+ raise AttributeError(
72
+ f"Invalid config parameter in {filename}: {key}.\nValid config parameters are:\n{nice_list}"
73
+ )
73
74
 
74
75
  # Use default config as base and overwrite with given config keys as found
75
76
  for key in updated_config:
76
- if key in new_config:
77
- if key not in ['truth_settings']:
78
- updated_config[key] = new_config[key]
79
-
80
- # Handle 'truth_settings' special case
81
- if 'truth_settings' in new_config:
82
- updated_config['truth_settings'] = deepcopy(new_config['truth_settings'])
83
-
84
- if not isinstance(updated_config['truth_settings'], list):
85
- updated_config['truth_settings'] = [updated_config['truth_settings']]
86
-
87
- default = deepcopy(config['truth_settings'])
88
- if not isinstance(default, list):
89
- default = [default]
90
-
91
- updated_config['truth_settings'] = update_truth_settings(updated_config['truth_settings'], default)
92
-
93
- # Handle 'asr_configs' special case
94
- if 'asr_configs' in updated_config:
95
- asr_configs = {}
96
- for asr_config in updated_config['asr_configs']:
97
- asr_name = asr_config.get('name', None)
98
- asr_engine = asr_config.get('engine', None)
99
- if asr_name is None or asr_engine is None:
100
- raise SonusAIError(f'Invalid config parameter in {name}: asr_configs.\n'
101
- f'asr_configs must contain both name and engine.')
102
- del asr_config['name']
103
- asr_configs[asr_name] = asr_config
104
- updated_config['asr_configs'] = asr_configs
77
+ if key in file_config:
78
+ updated_config[key] = file_config[key]
105
79
 
106
80
  # Check for required keys
107
81
  for key in REQUIRED_CONFIGS:
108
82
  if key not in updated_config:
109
- raise SonusAIError(f'Missing required config in {name}: {key}')
83
+ raise AttributeError(f"{filename} is missing required '{key}'")
84
+
85
+ # Validate special cases
86
+ validate_truth_configs(updated_config)
87
+ validate_asr_configs(updated_config)
110
88
 
111
89
  # Check for non-empty spectral masks
112
- if len(updated_config['spectral_masks']) == 0:
113
- updated_config['spectral_masks'] = config['spectral_masks']
90
+ if len(updated_config["spectral_masks"]) == 0:
91
+ updated_config["spectral_masks"] = given_config["spectral_masks"]
114
92
 
115
93
  # Check for valid noise_mix_mode
116
- if updated_config['noise_mix_mode'] not in VALID_NOISE_MIX_MODES:
117
- nice_list = '\n'.join([f' {item}' for item in VALID_NOISE_MIX_MODES])
118
- raise SonusAIError(f'Invalid noise_mix_mode in {name}.\n'
119
- f'Valid noise mix modes are:\n{nice_list}')
94
+ if updated_config["noise_mix_mode"] not in VALID_NOISE_MIX_MODES:
95
+ nice_list = "\n".join([f" {item}" for item in VALID_NOISE_MIX_MODES])
96
+ raise ValueError(f"{filename} contains invalid noise_mix_mode.\nValid noise mix modes are:\n{nice_list}")
120
97
 
121
98
  return updated_config
122
99
 
123
100
 
124
- def update_truth_settings(given: list[dict] | dict, default: list[dict] = None) -> list[dict]:
125
- """Update missing fields in given 'truth_settings' with default values
101
+ def validate_truth_configs(given: dict) -> None:
102
+ """Validate fields in given 'truth_configs'
126
103
 
127
- :param given: The dictionary of given truth settings
128
- :param default: The dictionary of default truth settings
129
- :return: Updated dictionary of truth settings
104
+ :param given: The dictionary of given config
130
105
  """
131
106
  from copy import deepcopy
132
107
 
133
- from sonusai import SonusAIError
134
- from .constants import VALID_TRUTH_SETTINGS
108
+ from sonusai.mixture import truth_functions
135
109
 
136
- if isinstance(given, list):
137
- truth_settings = deepcopy(given)
138
- else:
139
- truth_settings = [deepcopy(given)]
110
+ from .constants import REQUIRED_TRUTH_CONFIGS
140
111
 
141
- if default is not None and len(truth_settings) != len(default):
142
- raise SonusAIError(f'Length of given does not match default')
112
+ if "truth_configs" not in given:
113
+ raise AttributeError("config is missing required 'truth_configs'")
143
114
 
144
- for n in range(len(truth_settings)):
145
- for key in truth_settings[n]:
146
- if key not in VALID_TRUTH_SETTINGS:
147
- nice_list = '\n'.join([f' {item}' for item in VALID_TRUTH_SETTINGS])
148
- raise SonusAIError(f'Invalid truth_settings: {key}.\nValid truth_settings are:\n{nice_list}')
115
+ truth_configs = given["truth_configs"]
116
+ if len(truth_configs) == 0:
117
+ raise ValueError("'truth_configs' in config is empty")
149
118
 
150
- for key in VALID_TRUTH_SETTINGS:
151
- if key not in truth_settings[n]:
152
- if default is not None and key in default[n]:
153
- truth_settings[n][key] = default[n][key]
154
- else:
155
- raise SonusAIError(f'Missing required truth_settings: {key}')
119
+ for name, truth_config in truth_configs.items():
120
+ for key in REQUIRED_TRUTH_CONFIGS:
121
+ if key not in truth_config:
122
+ raise AttributeError(f"'{name}' in truth_configs is missing required '{key}'")
123
+
124
+ optional_config = deepcopy(truth_config)
125
+ for key in REQUIRED_TRUTH_CONFIGS:
126
+ del optional_config[key]
127
+
128
+ getattr(truth_functions, truth_config["function"] + "_validate")(optional_config)
129
+
130
+
131
+ def validate_asr_configs(given: dict) -> None:
132
+ """Validate fields in given 'asr_config'
133
+
134
+ :param given: The dictionary of given config
135
+ """
136
+ from sonusai.utils import validate_asr
137
+
138
+ from .constants import REQUIRED_ASR_CONFIGS
139
+
140
+ if "asr_configs" not in given:
141
+ raise AttributeError("config is missing required 'asr_configs'")
142
+
143
+ asr_configs = given["asr_configs"]
156
144
 
157
- for truth_setting in truth_settings:
158
- if not isinstance(truth_setting['index'], list):
159
- truth_setting['index'] = [truth_setting['index']]
145
+ for name, asr_config in asr_configs.items():
146
+ for key in REQUIRED_ASR_CONFIGS:
147
+ if key not in asr_config:
148
+ raise AttributeError(f"'{name}' in asr_configs is missing required '{key}'")
160
149
 
161
- return truth_settings
150
+ engine = asr_config["engine"]
151
+ config = {x: asr_config[x] for x in asr_config if x != "engine"}
152
+ validate_asr(engine, **config)
162
153
 
163
154
 
164
155
  def get_hierarchical_config_files(root: str, leaf: str) -> list[str]:
@@ -171,25 +162,23 @@ def get_hierarchical_config_files(root: str, leaf: str) -> list[str]:
171
162
  import os
172
163
  from pathlib import Path
173
164
 
174
- from sonusai import SonusAIError
175
-
176
- config_file = 'config.yml'
165
+ config_file = "config.yml"
177
166
 
178
167
  root_path = Path(os.path.abspath(root))
179
168
  if not root_path.is_dir():
180
- raise SonusAIError(f'Given root, {root_path}, is not a directory.')
169
+ raise OSError(f"Given root, {root_path}, is not a directory.")
181
170
 
182
171
  leaf_path = Path(os.path.abspath(leaf))
183
172
  if not leaf_path.is_dir():
184
- raise SonusAIError(f'Given leaf, {leaf_path}, is not a directory.')
173
+ raise OSError(f"Given leaf, {leaf_path}, is not a directory.")
185
174
 
186
175
  common = os.path.commonpath((root_path, leaf_path))
187
176
  if os.path.normpath(common) != os.path.normpath(root_path):
188
- raise SonusAIError(f'Given leaf, {leaf_path}, is not in the hierarchy of the given root, {root_path}')
177
+ raise OSError(f"Given leaf, {leaf_path}, is not in the hierarchy of the given root, {root_path}")
189
178
 
190
179
  top_config_file = os.path.join(root_path, config_file)
191
180
  if not Path(top_config_file).is_file():
192
- raise SonusAIError(f'Could not find {top_config_file}')
181
+ raise OSError(f"Could not find {top_config_file}")
193
182
 
194
183
  current = leaf_path
195
184
  config_files = []
@@ -216,24 +205,11 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
216
205
  new_config = deepcopy(config)
217
206
  config_files = get_hierarchical_config_files(root=root, leaf=leaf)
218
207
  for config_file in config_files:
219
- new_config = update_config_from_file(name=config_file, config=new_config)
208
+ new_config = update_config_from_file(filename=config_file, given_config=new_config)
220
209
 
221
210
  return new_config
222
211
 
223
212
 
224
- def get_max_class(num_classes: int, truth_mutex: bool) -> int:
225
- """Get the maximum class index
226
-
227
- :param num_classes: Number of classes
228
- :param truth_mutex: Truth is mutex mode
229
- :return: Highest class index
230
- """
231
- max_class = num_classes
232
- if truth_mutex:
233
- max_class -= 1
234
- return max_class
235
-
236
-
237
213
  def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
238
214
  """Get the list of target files from a config
239
215
 
@@ -243,48 +219,62 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
243
219
  """
244
220
  from itertools import chain
245
221
 
246
- from tqdm import tqdm
247
-
248
- from sonusai import SonusAIError
249
222
  from sonusai.utils import dataclass_from_dict
250
- from sonusai.utils import pp_tqdm_imap
251
- from .datatypes import TargetFiles
223
+ from sonusai.utils import par_track
224
+ from sonusai.utils import track
252
225
 
253
- truth_settings = config.get('truth_settings', list())
254
- level_type = config.get('target_level_type', None)
255
- target_files = list(chain.from_iterable([append_target_files(entry=entry,
256
- truth_settings=truth_settings,
257
- level_type=level_type)
258
- for entry in config['targets']]))
226
+ from .datatypes import TargetFiles
259
227
 
260
- progress = tqdm(total=len(target_files), disable=not show_progress)
261
- target_files = pp_tqdm_imap(_get_num_samples, target_files, progress=progress)
228
+ class_indices = config["class_indices"]
229
+ if not isinstance(class_indices, list):
230
+ class_indices = [class_indices]
231
+
232
+ target_files = list(
233
+ chain.from_iterable(
234
+ [
235
+ append_target_files(
236
+ entry=entry,
237
+ class_indices=class_indices,
238
+ truth_configs=config["truth_configs"],
239
+ level_type=config["target_level_type"],
240
+ )
241
+ for entry in config["targets"]
242
+ ]
243
+ )
244
+ )
245
+
246
+ progress = track(total=len(target_files), disable=not show_progress)
247
+ target_files = par_track(_get_num_samples, target_files, progress=progress)
262
248
  progress.close()
263
249
 
264
- max_class = get_max_class(config['num_classes'], config['truth_mode'] == 'mutex')
265
-
250
+ num_classes = config["num_classes"]
266
251
  for target_file in target_files:
267
- target_file['truth_settings'] = update_truth_settings(target_file['truth_settings'], config['truth_settings'])
252
+ if any(class_index < 0 for class_index in target_file["class_indices"]):
253
+ raise ValueError("class indices must contain only positive elements")
268
254
 
269
- for truth_setting in target_file['truth_settings']:
270
- if any(idx > max_class for idx in truth_setting['index']):
271
- raise SonusAIError('invalid truth index')
255
+ if any(class_index > num_classes for class_index in target_file["class_indices"]):
256
+ raise ValueError(f"class index elements must not be greater than {num_classes}")
272
257
 
273
258
  return dataclass_from_dict(TargetFiles, target_files)
274
259
 
275
260
 
276
- def append_target_files(entry: dict | str,
277
- truth_settings: list[dict],
278
- level_type: str | None,
279
- tokens: dict = None) -> list[dict]:
261
+ def append_target_files(
262
+ entry: dict | str,
263
+ class_indices: list[int],
264
+ truth_configs: dict,
265
+ level_type: str,
266
+ tokens: dict | None = None,
267
+ ) -> list[dict]:
280
268
  """Process target files list and append as needed
281
269
 
282
270
  :param entry: Target file entry to append to the list
283
- :param truth_settings: Truth settings
271
+ :param class_indices: Class indices
272
+ :param truth_configs: Truth configs
284
273
  :param level_type: Target level type
285
274
  :param tokens: Tokens used for variable expansion
286
275
  :return: List of target files
287
276
  """
277
+ from copy import deepcopy
288
278
  from glob import glob
289
279
  from os import listdir
290
280
  from os.path import dirname
@@ -293,8 +283,11 @@ def append_target_files(entry: dict | str,
293
283
  from os.path import join
294
284
  from os.path import splitext
295
285
 
296
- from sonusai import SonusAIError
286
+ from sonusai.utils import dataclass_from_dict
287
+
297
288
  from .audio import validate_input_file
289
+ from .constants import REQUIRED_TRUTH_CONFIGS
290
+ from .datatypes import TruthConfig
298
291
  from .tokenized_shell_vars import tokenized_expand
299
292
  from .tokenized_shell_vars import tokenized_replace
300
293
 
@@ -302,23 +295,38 @@ def append_target_files(entry: dict | str,
302
295
  tokens = {}
303
296
 
304
297
  if isinstance(entry, dict):
305
- if 'name' in entry:
306
- in_name = entry['name']
298
+ if "name" in entry:
299
+ in_name = entry["name"]
307
300
  else:
308
- raise SonusAIError('Target list contained record without name')
309
-
310
- if 'truth_settings' in entry:
311
- truth_settings = entry['truth_settings']
312
- if 'target_level_type' in entry:
313
- level_type = entry['target_level_type']
301
+ raise AttributeError("Target list contained record without name")
302
+
303
+ if "class_indices" in entry:
304
+ if isinstance(entry["class_indices"], list):
305
+ class_indices = entry["class_indices"]
306
+ else:
307
+ class_indices = [entry["class_indices"]]
308
+
309
+ truth_configs_override = entry.get("truth_configs", {})
310
+ for key in truth_configs_override:
311
+ if key not in truth_configs:
312
+ raise AttributeError(
313
+ f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
314
+ )
315
+ truth_configs_merged = {}
316
+ for key in truth_configs_override:
317
+ truth_configs_merged[key] = deepcopy(truth_configs[key])
318
+ if truth_configs_override[key] is not None:
319
+ truth_configs_merged[key] |= truth_configs_override[key]
320
+ level_type = entry.get("level_type", level_type)
314
321
  else:
315
322
  in_name = entry
323
+ truth_configs_merged = deepcopy(truth_configs)
316
324
 
317
325
  in_name, new_tokens = tokenized_expand(in_name)
318
326
  tokens.update(new_tokens)
319
327
  names = sorted(glob(in_name))
320
328
  if not names:
321
- raise SonusAIError(f'Could not find {in_name}. Make sure path exists')
329
+ raise OSError(f"Could not find {in_name}. Make sure path exists")
322
330
 
323
331
  target_files: list[dict] = []
324
332
  for name in names:
@@ -329,57 +337,81 @@ def append_target_files(entry: dict | str,
329
337
  child = file
330
338
  if not isabs(child):
331
339
  child = join(dir_name, child)
332
- target_files.extend(append_target_files(entry=child,
333
- truth_settings=truth_settings,
334
- level_type=level_type,
335
- tokens=tokens))
340
+ target_files.extend(
341
+ append_target_files(
342
+ entry=child,
343
+ class_indices=class_indices,
344
+ truth_configs=truth_configs_merged,
345
+ level_type=level_type,
346
+ tokens=tokens,
347
+ )
348
+ )
336
349
  else:
337
350
  try:
338
- if ext == '.txt':
339
- with open(file=name, mode='r') as txt_file:
351
+ if ext == ".txt":
352
+ with open(file=name) as txt_file:
340
353
  for line in txt_file:
341
354
  # strip comments
342
- child = line.partition('#')[0]
355
+ child = line.partition("#")[0]
343
356
  child = child.rstrip()
344
357
  if child:
345
358
  child, new_tokens = tokenized_expand(child)
346
359
  tokens.update(new_tokens)
347
360
  if not isabs(child):
348
361
  child = join(dir_name, child)
349
- target_files.extend(append_target_files(entry=child,
350
- truth_settings=truth_settings,
351
- level_type=level_type,
352
- tokens=tokens))
353
- elif ext == '.yml':
362
+ target_files.extend(
363
+ append_target_files(
364
+ entry=child,
365
+ class_indices=class_indices,
366
+ truth_configs=truth_configs_merged,
367
+ level_type=level_type,
368
+ tokens=tokens,
369
+ )
370
+ )
371
+ elif ext == ".yml":
354
372
  try:
355
373
  yml_config = raw_load_config(name)
356
374
 
357
- if 'targets' in yml_config:
358
- for record in yml_config['targets']:
359
- target_files.extend(append_target_files(entry=record,
360
- truth_settings=truth_settings,
361
- level_type=level_type,
362
- tokens=tokens))
375
+ if "targets" in yml_config:
376
+ for record in yml_config["targets"]:
377
+ target_files.extend(
378
+ append_target_files(
379
+ entry=record,
380
+ class_indices=class_indices,
381
+ truth_configs=truth_configs_merged,
382
+ level_type=level_type,
383
+ tokens=tokens,
384
+ )
385
+ )
363
386
  except Exception as e:
364
- raise SonusAIError(f'Error processing {name}: {e}')
387
+ raise OSError(f"Error processing {name}: {e}") from e
365
388
  else:
366
389
  validate_input_file(name)
367
390
  target_file: dict = {
368
- 'expanded_name': name,
369
- 'name': tokenized_replace(name, tokens),
391
+ "expanded_name": name,
392
+ "name": tokenized_replace(name, tokens),
393
+ "class_indices": class_indices,
394
+ "level_type": level_type,
395
+ "truth_configs": {},
370
396
  }
371
- if len(truth_settings) > 0:
372
- target_file['truth_settings'] = truth_settings
373
- for truth_setting in target_file['truth_settings']:
374
- if 'function' in truth_setting and truth_setting['function'] == 'file':
375
- truth_setting['config']['file'] = splitext(target_file['name'])[0] + '.h5'
376
- if level_type is not None:
377
- target_file['level_type'] = level_type
397
+ if len(truth_configs_merged) > 0:
398
+ for tc_key, tc_value in truth_configs_merged.items():
399
+ config = deepcopy(tc_value)
400
+ truth_config: dict = {}
401
+ for key in REQUIRED_TRUTH_CONFIGS:
402
+ truth_config[key] = config[key]
403
+ del config[key]
404
+ truth_config["config"] = config
405
+ target_file["truth_configs"][tc_key] = dataclass_from_dict(TruthConfig, truth_config)
406
+ for tc_key in target_file["truth_configs"]:
407
+ if (
408
+ "function" in truth_configs_merged[tc_key]
409
+ and truth_configs_merged[tc_key]["function"] == "file"
410
+ ):
411
+ truth_configs_merged[tc_key]["file"] = splitext(target_file["name"])[0] + ".h5"
378
412
  target_files.append(target_file)
379
- except SonusAIError:
380
- raise
381
413
  except Exception as e:
382
- raise SonusAIError(f'Error processing {name}: {e}')
414
+ raise OSError(f"Error processing {name}: {e}") from e
383
415
 
384
416
  return target_files
385
417
 
@@ -393,22 +425,22 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
393
425
  """
394
426
  from itertools import chain
395
427
 
396
- from tqdm import tqdm
397
-
398
428
  from sonusai.utils import dataclass_from_dict
399
- from sonusai.utils import pp_tqdm_imap
429
+ from sonusai.utils import par_track
430
+ from sonusai.utils import track
431
+
400
432
  from .datatypes import NoiseFiles
401
433
 
402
- noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config['noises']]))
434
+ noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
403
435
 
404
- progress = tqdm(total=len(noise_files), disable=not show_progress)
405
- noise_files = pp_tqdm_imap(_get_num_samples, noise_files, progress=progress)
436
+ progress = track(total=len(noise_files), disable=not show_progress)
437
+ noise_files = par_track(_get_num_samples, noise_files, progress=progress)
406
438
  progress.close()
407
439
 
408
440
  return dataclass_from_dict(NoiseFiles, noise_files)
409
441
 
410
442
 
411
- def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
443
+ def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
412
444
  """Process noise files list and append as needed
413
445
 
414
446
  :param entry: Noise file entry to append to the list
@@ -423,7 +455,6 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
423
455
  from os.path import join
424
456
  from os.path import splitext
425
457
 
426
- from sonusai import SonusAIError
427
458
  from .audio import validate_input_file
428
459
  from .tokenized_shell_vars import tokenized_expand
429
460
  from .tokenized_shell_vars import tokenized_replace
@@ -432,10 +463,10 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
432
463
  tokens = {}
433
464
 
434
465
  if isinstance(entry, dict):
435
- if 'name' in entry:
436
- in_name = entry['name']
466
+ if "name" in entry:
467
+ in_name = entry["name"]
437
468
  else:
438
- raise SonusAIError('Noise list contained record without name')
469
+ raise AttributeError("Noise list contained record without name")
439
470
  else:
440
471
  in_name = entry
441
472
 
@@ -443,7 +474,7 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
443
474
  tokens.update(new_tokens)
444
475
  names = sorted(glob(in_name))
445
476
  if not names:
446
- raise SonusAIError(f'Could not find {in_name}. Make sure path exists')
477
+ raise OSError(f"Could not find {in_name}. Make sure path exists")
447
478
 
448
479
  noise_files: list[dict] = []
449
480
  for name in names:
@@ -457,11 +488,11 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
457
488
  noise_files.extend(append_noise_files(entry=child, tokens=tokens))
458
489
  else:
459
490
  try:
460
- if ext == '.txt':
461
- with open(file=name, mode='r') as txt_file:
491
+ if ext == ".txt":
492
+ with open(file=name) as txt_file:
462
493
  for line in txt_file:
463
494
  # strip comments
464
- child = line.partition('#')[0]
495
+ child = line.partition("#")[0]
465
496
  child = child.rstrip()
466
497
  if child:
467
498
  child, new_tokens = tokenized_expand(child)
@@ -469,26 +500,24 @@ def append_noise_files(entry: dict | str, tokens: dict = None) -> list[dict]:
469
500
  if not isabs(child):
470
501
  child = join(dir_name, child)
471
502
  noise_files.extend(append_noise_files(entry=child, tokens=tokens))
472
- elif ext == '.yml':
503
+ elif ext == ".yml":
473
504
  try:
474
505
  yml_config = raw_load_config(name)
475
506
 
476
- if 'noises' in yml_config:
477
- for record in yml_config['noises']:
507
+ if "noises" in yml_config:
508
+ for record in yml_config["noises"]:
478
509
  noise_files.extend(append_noise_files(entry=record, tokens=tokens))
479
510
  except Exception as e:
480
- raise SonusAIError(f'Error processing {name}: {e}')
511
+ raise OSError(f"Error processing {name}: {e}") from e
481
512
  else:
482
513
  validate_input_file(name)
483
514
  noise_file: dict = {
484
- 'expanded_name': name,
485
- 'name': tokenized_replace(name, tokens),
515
+ "expanded_name": name,
516
+ "name": tokenized_replace(name, tokens),
486
517
  }
487
518
  noise_files.append(noise_file)
488
- except SonusAIError:
489
- raise
490
519
  except Exception as e:
491
- raise SonusAIError(f'Error processing {name}: {e}')
520
+ raise OSError(f"Error processing {name}: {e}") from e
492
521
 
493
522
  return noise_files
494
523
 
@@ -499,13 +528,20 @@ def get_impulse_response_files(config: dict) -> ImpulseResponseFiles:
499
528
  :param config: Config dictionary
500
529
  :return: List of impulse response files
501
530
  """
502
- from itertools import chain
503
-
504
- return list(
505
- chain.from_iterable([append_impulse_response_files(entry=entry) for entry in config['impulse_responses']]))
506
-
507
-
508
- def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
531
+ return [ImpulseResponseFile(entry["name"], entry["tags"]) for entry in config["impulse_responses"]]
532
+ # from itertools import chain
533
+ #
534
+ # return list(
535
+ # chain.from_iterable(
536
+ # [
537
+ # append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry["tags"]))
538
+ # for entry in config["impulse_responses"]
539
+ # ]
540
+ # )
541
+ # )
542
+
543
+
544
+ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
509
545
  """Process impulse response files list and append as needed
510
546
 
511
547
  :param entry: Impulse response file entry to append to the list
@@ -520,7 +556,6 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
520
556
  from os.path import join
521
557
  from os.path import splitext
522
558
 
523
- from sonusai import SonusAIError
524
559
  from .audio import validate_input_file
525
560
  from .tokenized_shell_vars import tokenized_expand
526
561
  from .tokenized_shell_vars import tokenized_replace
@@ -528,11 +563,11 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
528
563
  if tokens is None:
529
564
  tokens = {}
530
565
 
531
- in_name, new_tokens = tokenized_expand(entry)
566
+ in_name, new_tokens = tokenized_expand(entry.file)
532
567
  tokens.update(new_tokens)
533
568
  names = sorted(glob(in_name))
534
569
  if not names:
535
- raise SonusAIError(f'Could not find {in_name}. Make sure path exists')
570
+ raise OSError(f"Could not find {in_name}. Make sure path exists")
536
571
 
537
572
  impulse_response_files: list[str] = []
538
573
  for name in names:
@@ -540,41 +575,41 @@ def append_impulse_response_files(entry: str, tokens: dict = None) -> list[str]:
540
575
  dir_name = dirname(name)
541
576
  if isdir(name):
542
577
  for file in listdir(name):
543
- child = file
544
- if not isabs(child):
545
- child = join(dir_name, child)
578
+ if not isabs(file):
579
+ file = join(dir_name, file)
580
+ child = ImpulseResponseFile(file, entry.tags)
546
581
  impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
547
582
  else:
548
583
  try:
549
- if ext == '.txt':
550
- with open(file=name, mode='r') as txt_file:
584
+ if ext == ".txt":
585
+ with open(file=name) as txt_file:
551
586
  for line in txt_file:
552
587
  # strip comments
553
- child = line.partition('#')[0]
554
- child = child.rstrip()
555
- if child:
556
- child, new_tokens = tokenized_expand(child)
588
+ file = line.partition("#")[0]
589
+ file = file.rstrip()
590
+ if file:
591
+ file, new_tokens = tokenized_expand(file)
557
592
  tokens.update(new_tokens)
558
- if not isabs(child):
559
- child = join(dir_name, child)
593
+ if not isabs(file):
594
+ file = join(dir_name, file)
595
+ child = ImpulseResponseFile(file, entry.tags)
560
596
  impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
561
- elif ext == '.yml':
597
+ elif ext == ".yml":
562
598
  try:
563
599
  yml_config = raw_load_config(name)
564
600
 
565
- if 'impulse_responses' in yml_config:
566
- for record in yml_config['impulse_responses']:
601
+ if "impulse_responses" in yml_config:
602
+ for record in yml_config["impulse_responses"]:
567
603
  impulse_response_files.extend(
568
- append_impulse_response_files(entry=record, tokens=tokens))
604
+ append_impulse_response_files(entry=record, tokens=tokens)
605
+ )
569
606
  except Exception as e:
570
- raise SonusAIError(f'Error processing {name}: {e}')
607
+ raise OSError(f"Error processing {name}: {e}") from e
571
608
  else:
572
609
  validate_input_file(name)
573
610
  impulse_response_files.append(tokenized_replace(name, tokens))
574
- except SonusAIError:
575
- raise
576
611
  except Exception as e:
577
- raise SonusAIError(f'Error processing {name}: {e}')
612
+ raise OSError(f"Error processing {name}: {e}") from e
578
613
 
579
614
  return impulse_response_files
580
615
 
@@ -585,19 +620,51 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
585
620
  :param config: Config dictionary
586
621
  :return: List of spectral masks
587
622
  """
588
- from sonusai import SonusAIError
589
623
  from sonusai.utils import dataclass_from_dict
590
- from .datatypes import SpectralMasks
591
624
 
592
625
  try:
593
- return dataclass_from_dict(SpectralMasks, config['spectral_masks'])
626
+ return dataclass_from_dict(SpectralMasks, config["spectral_masks"])
594
627
  except Exception as e:
595
- raise SonusAIError(f'Error in spectral_masks: {e}')
628
+ raise ValueError(f"Error in spectral_masks: {e}") from e
629
+
630
+
631
+ def get_truth_parameters(config: dict) -> TruthParameters:
632
+ """Get the list of truth parameters from a config
633
+
634
+ :param config: Config dictionary
635
+ :return: List of truth parameters
636
+ """
637
+ from copy import deepcopy
638
+
639
+ from sonusai.mixture import truth_functions
640
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
641
+
642
+ from .constants import REQUIRED_TRUTH_CONFIGS
643
+ from .datatypes import TruthParameter
644
+
645
+ truth_parameters: TruthParameters = []
646
+ for name, truth_config in config["truth_configs"].items():
647
+ optional_config = deepcopy(truth_config)
648
+ for key in REQUIRED_TRUTH_CONFIGS:
649
+ del optional_config[key]
650
+
651
+ t_config = TruthFunctionConfig(
652
+ feature=config["feature"],
653
+ num_classes=config["num_classes"],
654
+ class_indices=[1],
655
+ target_gain=1,
656
+ config=optional_config,
657
+ )
658
+
659
+ parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
660
+ truth_parameters.append(TruthParameter(name, parameters))
661
+
662
+ return truth_parameters
596
663
 
597
664
 
598
665
  def _get_num_samples(entry: dict) -> dict:
599
666
  from .audio import get_num_samples
600
667
 
601
- entry['samples'] = get_num_samples(entry['expanded_name'])
602
- del entry['expanded_name']
668
+ entry["samples"] = get_num_samples(entry["expanded_name"])
669
+ del entry["expanded_name"]
603
670
  return entry