sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118):
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -1,24 +1,25 @@
1
- from sonusai.mixture.datatypes import AudioT
2
- from sonusai.mixture.datatypes import AudiosT
3
- from sonusai.mixture.datatypes import Augmentation
4
- from sonusai.mixture.datatypes import AugmentationRules
5
- from sonusai.mixture.datatypes import AugmentedTargets
6
- from sonusai.mixture.datatypes import GenMixData
7
- from sonusai.mixture.datatypes import ImpulseResponseFiles
8
- from sonusai.mixture.datatypes import Mixture
9
- from sonusai.mixture.datatypes import Mixtures
10
- from sonusai.mixture.datatypes import NoiseFiles
11
- from sonusai.mixture.datatypes import SpectralMasks
12
- from sonusai.mixture.datatypes import TargetFiles
13
- from sonusai.mixture.datatypes import Targets
14
- from sonusai.mixture.datatypes import UniversalSNRGenerator
15
- from sonusai.mixture.mixdb import MixtureDatabase
1
+ # ruff: noqa: S608
2
+ from .datatypes import AudiosT
3
+ from .datatypes import AudioT
4
+ from .datatypes import Augmentation
5
+ from .datatypes import AugmentationRules
6
+ from .datatypes import AugmentedTargets
7
+ from .datatypes import GenMixData
8
+ from .datatypes import ImpulseResponseFiles
9
+ from .datatypes import Mixture
10
+ from .datatypes import Mixtures
11
+ from .datatypes import NoiseFiles
12
+ from .datatypes import SpectralMasks
13
+ from .datatypes import TargetFiles
14
+ from .datatypes import Targets
15
+ from .datatypes import UniversalSNRGenerator
16
+ from .mixdb import MixtureDatabase
16
17
 
17
18
 
18
19
  def config_file(location: str) -> str:
19
20
  from os.path import join
20
21
 
21
- return join(location, 'config.yml')
22
+ return join(location, "config.yml")
22
23
 
23
24
 
24
25
  def initialize_db(location: str, test: bool = False) -> None:
@@ -27,9 +28,16 @@ def initialize_db(location: str, test: bool = False) -> None:
27
28
  con = db_connection(location=location, create=True, test=test)
28
29
 
29
30
  con.execute("""
30
- CREATE TABLE truth_setting(
31
+ CREATE TABLE truth_config(
31
32
  id INTEGER PRIMARY KEY NOT NULL,
32
- setting TEXT NOT NULL)
33
+ config TEXT NOT NULL)
34
+ """)
35
+
36
+ con.execute("""
37
+ CREATE TABLE truth_parameters(
38
+ id INTEGER PRIMARY KEY NOT NULL,
39
+ name TEXT NOT NULL,
40
+ parameters INTEGER NOT NULL)
33
41
  """)
34
42
 
35
43
  con.execute("""
@@ -37,6 +45,7 @@ def initialize_db(location: str, test: bool = False) -> None:
37
45
  id INTEGER PRIMARY KEY NOT NULL,
38
46
  name TEXT NOT NULL,
39
47
  samples INTEGER NOT NULL,
48
+ class_indices TEXT NOT NULL,
40
49
  level_type TEXT NOT NULL,
41
50
  speaker_id INTEGER,
42
51
  FOREIGN KEY(speaker_id) REFERENCES speaker (id))
@@ -65,8 +74,6 @@ def initialize_db(location: str, test: bool = False) -> None:
65
74
  noise_mix_mode TEXT NOT NULL,
66
75
  num_classes INTEGER NOT NULL,
67
76
  seed INTEGER NOT NULL,
68
- truth_mutex BOOLEAN NOT NULL,
69
- truth_reduction_function TEXT NOT NULL,
70
77
  mixid_width INTEGER NOT NULL,
71
78
  speaker_metadata_tiers TEXT NOT NULL,
72
79
  textgrid_metadata_tiers TEXT NOT NULL)
@@ -87,7 +94,8 @@ def initialize_db(location: str, test: bool = False) -> None:
87
94
  con.execute("""
88
95
  CREATE TABLE impulse_response_file (
89
96
  id INTEGER PRIMARY KEY NOT NULL,
90
- file TEXT NOT NULL)
97
+ file TEXT NOT NULL,
98
+ tags TEXT NOT NULL)
91
99
  """)
92
100
 
93
101
  con.execute("""
@@ -101,11 +109,11 @@ def initialize_db(location: str, test: bool = False) -> None:
101
109
  """)
102
110
 
103
111
  con.execute("""
104
- CREATE TABLE target_file_truth_setting (
112
+ CREATE TABLE target_file_truth_config (
105
113
  target_file_id INTEGER,
106
- truth_setting_id INTEGER,
114
+ truth_config_id INTEGER,
107
115
  FOREIGN KEY(target_file_id) REFERENCES target_file (id),
108
- FOREIGN KEY(truth_setting_id) REFERENCES truth_setting (id))
116
+ FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
109
117
  """)
110
118
 
111
119
  con.execute("""
@@ -148,59 +156,55 @@ def initialize_db(location: str, test: bool = False) -> None:
148
156
 
149
157
 
150
158
  def populate_top_table(location: str, config: dict, test: bool = False) -> None:
151
- """Populate top table
152
- """
159
+ """Populate top table"""
153
160
  import json
154
161
 
155
- from sonusai import SonusAIError
162
+ from .constants import MIXDB_VERSION
156
163
  from .mixdb import db_connection
157
164
 
158
- if config['truth_mode'] not in ['normal', 'mutex']:
159
- raise SonusAIError(f'invalid truth_mode: {config["truth_mode"]}')
160
- truth_mutex = config['truth_mode'] == 'mutex'
161
-
162
165
  con = db_connection(location=location, readonly=False, test=test)
163
- con.execute("""
166
+ con.execute(
167
+ """
164
168
  INSERT INTO top (version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
165
- seed, truth_mutex, truth_reduction_function, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
166
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
167
- """, (
168
- 1,
169
- json.dumps(config['asr_configs']),
170
- config['class_balancing'],
171
- config['feature'],
172
- config['noise_mix_mode'],
173
- config['num_classes'],
174
- config['seed'],
175
- truth_mutex,
176
- config['truth_reduction_function'],
177
- 0,
178
- '',
179
- ''))
169
+ seed, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
170
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
171
+ """,
172
+ (
173
+ MIXDB_VERSION,
174
+ json.dumps(config["asr_configs"]),
175
+ config["class_balancing"],
176
+ config["feature"],
177
+ config["noise_mix_mode"],
178
+ config["num_classes"],
179
+ config["seed"],
180
+ 0,
181
+ "",
182
+ "",
183
+ ),
184
+ )
180
185
  con.commit()
181
186
  con.close()
182
187
 
183
188
 
184
189
  def populate_class_label_table(location: str, config: dict, test: bool = False) -> None:
185
- """Populate class_label table
186
- """
190
+ """Populate class_label table"""
187
191
  from .mixdb import db_connection
188
192
 
189
193
  con = db_connection(location=location, readonly=False, test=test)
190
- con.executemany("INSERT INTO class_label (label) VALUES (?)",
191
- [(item,) for item in config['class_labels']])
194
+ con.executemany(
195
+ "INSERT INTO class_label (label) VALUES (?)",
196
+ [(item,) for item in config["class_labels"]],
197
+ )
192
198
  con.commit()
193
199
  con.close()
194
200
 
195
201
 
196
202
  def populate_class_weights_threshold_table(location: str, config: dict, test: bool = False) -> None:
197
- """Populate class_weights_threshold table
198
- """
199
- from sonusai import SonusAIError
203
+ """Populate class_weights_threshold table"""
200
204
  from .mixdb import db_connection
201
205
 
202
- class_weights_threshold = config['class_weights_threshold']
203
- num_classes = config['num_classes']
206
+ class_weights_threshold = config["class_weights_threshold"]
207
+ num_classes = config["num_classes"]
204
208
 
205
209
  if not isinstance(class_weights_threshold, list):
206
210
  class_weights_threshold = [class_weights_threshold]
@@ -209,43 +213,72 @@ def populate_class_weights_threshold_table(location: str, config: dict, test: bo
209
213
  class_weights_threshold = [class_weights_threshold[0]] * num_classes
210
214
 
211
215
  if len(class_weights_threshold) != num_classes:
212
- raise SonusAIError(f'invalid class_weights_threshold length: {len(class_weights_threshold)}')
216
+ raise ValueError(f"invalid class_weights_threshold length: {len(class_weights_threshold)}")
213
217
 
214
218
  con = db_connection(location=location, readonly=False, test=test)
215
- con.executemany("INSERT INTO class_weights_threshold (threshold) VALUES (?)",
216
- [(item,) for item in class_weights_threshold])
219
+ con.executemany(
220
+ "INSERT INTO class_weights_threshold (threshold) VALUES (?)",
221
+ [(item,) for item in class_weights_threshold],
222
+ )
217
223
  con.commit()
218
224
  con.close()
219
225
 
220
226
 
221
227
  def populate_spectral_mask_table(location: str, config: dict, test: bool = False) -> None:
222
- """Populate spectral_mask table
223
- """
228
+ """Populate spectral_mask table"""
224
229
  from .config import get_spectral_masks
225
230
  from .mixdb import db_connection
226
231
 
227
232
  con = db_connection(location=location, readonly=False, test=test)
228
- con.executemany("""
233
+ con.executemany(
234
+ """
229
235
  INSERT INTO spectral_mask (f_max_width, f_num, t_max_width, t_num, t_max_percent) VALUES (?, ?, ?, ?, ?)
230
- """, [(item.f_max_width,
231
- item.f_num,
232
- item.t_max_width,
233
- item.t_num,
234
- item.t_max_percent) for item in get_spectral_masks(config)]
235
- )
236
+ """,
237
+ [
238
+ (
239
+ item.f_max_width,
240
+ item.f_num,
241
+ item.t_max_width,
242
+ item.t_num,
243
+ item.t_max_percent,
244
+ )
245
+ for item in get_spectral_masks(config)
246
+ ],
247
+ )
248
+ con.commit()
249
+ con.close()
250
+
251
+
252
+ def populate_truth_parameters_table(location: str, config: dict, test: bool = False) -> None:
253
+ """Populate truth_parameters table"""
254
+ from .config import get_truth_parameters
255
+ from .mixdb import db_connection
256
+
257
+ con = db_connection(location=location, readonly=False, test=test)
258
+ con.executemany(
259
+ """
260
+ INSERT INTO truth_parameters (name, parameters) VALUES (?, ?)
261
+ """,
262
+ [
263
+ (
264
+ item.name,
265
+ item.parameters,
266
+ )
267
+ for item in get_truth_parameters(config)
268
+ ],
269
+ )
236
270
  con.commit()
237
271
  con.close()
238
272
 
239
273
 
240
274
  def populate_target_file_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
241
- """Populate target file table
242
- """
275
+ """Populate target file table"""
243
276
  import json
244
277
  from pathlib import Path
245
278
 
246
279
  from .mixdb import db_connection
247
280
 
248
- _populate_truth_setting_table(location, target_files, test)
281
+ _populate_truth_config_table(location, target_files, test)
249
282
  _populate_speaker_table(location, target_files, test)
250
283
 
251
284
  con = db_connection(location=location, readonly=False, test=test)
@@ -259,76 +292,106 @@ def populate_target_file_table(location: str, target_files: TargetFiles, test: b
259
292
  textgrid_metadata_tiers.add(tier)
260
293
 
261
294
  # Get truth settings for target file
262
- truth_setting_ids: list[int] = []
263
- for truth_setting in target_file.truth_settings:
264
- cur.execute("SELECT truth_setting.id FROM truth_setting WHERE ? = truth_setting.setting",
265
- (truth_setting.to_json(),))
266
- truth_setting_ids.append(cur.fetchone()[0])
295
+ truth_config_ids: list[int] = []
296
+ for name, config in target_file.truth_configs.items():
297
+ ts = json.dumps({"name": name} | config.to_dict())
298
+ cur.execute(
299
+ "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
300
+ (ts,),
301
+ )
302
+ truth_config_ids.append(cur.fetchone()[0])
267
303
 
268
304
  # Get speaker_id for target file
269
- cur.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
270
- (Path(target_file.name).parent.as_posix(),))
305
+ cur.execute(
306
+ "SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
307
+ (Path(target_file.name).parent.as_posix(),),
308
+ )
271
309
  result = cur.fetchone()
272
310
  speaker_id = None
273
311
  if result is not None:
274
312
  speaker_id = result[0]
275
313
 
276
314
  # Add entry
277
- cur.execute("INSERT INTO target_file (name, samples, level_type, speaker_id) VALUES (?, ?, ?, ?)",
278
- (target_file.name, target_file.samples, target_file.level_type, speaker_id))
315
+ cur.execute(
316
+ "INSERT INTO target_file (name, samples, class_indices, level_type, speaker_id) VALUES (?, ?, ?, ?, ?)",
317
+ (
318
+ target_file.name,
319
+ target_file.samples,
320
+ json.dumps(target_file.class_indices),
321
+ target_file.level_type,
322
+ speaker_id,
323
+ ),
324
+ )
279
325
  target_file_id = cur.lastrowid
280
- for truth_setting_id in truth_setting_ids:
281
- cur.execute("INSERT INTO target_file_truth_setting (target_file_id, truth_setting_id) VALUES (?, ?)",
282
- (target_file_id, truth_setting_id))
326
+ for truth_config_id in truth_config_ids:
327
+ cur.execute(
328
+ "INSERT INTO target_file_truth_config (target_file_id, truth_config_id) VALUES (?, ?)",
329
+ (target_file_id, truth_config_id),
330
+ )
283
331
 
284
332
  # Update textgrid_metadata_tiers in the top table
285
- con.execute("UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
286
- (json.dumps(sorted(textgrid_metadata_tiers)), 1))
333
+ con.execute(
334
+ "UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
335
+ (json.dumps(sorted(textgrid_metadata_tiers)), 1),
336
+ )
287
337
 
288
338
  con.commit()
289
339
  con.close()
290
340
 
291
341
 
292
342
  def populate_noise_file_table(location: str, noise_files: NoiseFiles, test: bool = False) -> None:
293
- """Populate noise file table
294
- """
343
+ """Populate noise file table"""
295
344
  from .mixdb import db_connection
296
345
 
297
346
  con = db_connection(location=location, readonly=False, test=test)
298
- con.executemany("INSERT INTO noise_file (name, samples) VALUES (?, ?)",
299
- [(noise_file.name, noise_file.samples) for noise_file in noise_files])
347
+ con.executemany(
348
+ "INSERT INTO noise_file (name, samples) VALUES (?, ?)",
349
+ [(noise_file.name, noise_file.samples) for noise_file in noise_files],
350
+ )
300
351
  con.commit()
301
352
  con.close()
302
353
 
303
354
 
304
- def populate_impulse_response_file_table(location: str, impulse_response_files: ImpulseResponseFiles,
305
- test: bool = False) -> None:
306
- """Populate impulse response file table
307
- """
355
+ def populate_impulse_response_file_table(
356
+ location: str, impulse_response_files: ImpulseResponseFiles, test: bool = False
357
+ ) -> None:
358
+ """Populate impulse response file table"""
359
+ import json
360
+
308
361
  from .mixdb import db_connection
309
362
 
310
363
  con = db_connection(location=location, readonly=False, test=test)
311
- con.executemany("INSERT INTO impulse_response_file (file) VALUES (?)",
312
- [(impulse_response_file,) for impulse_response_file in impulse_response_files])
364
+ con.executemany(
365
+ "INSERT INTO impulse_response_file (file, tags) VALUES (?, ?)",
366
+ [
367
+ (
368
+ impulse_response_file.file,
369
+ json.dumps(impulse_response_file.tags),
370
+ )
371
+ for impulse_response_file in impulse_response_files
372
+ ],
373
+ )
313
374
  con.commit()
314
375
  con.close()
315
376
 
316
377
 
317
378
  def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
318
- """Update the mixid width
319
- """
320
- from .mixdb import db_connection
379
+ """Update the mixid width"""
321
380
  from sonusai.utils import max_text_width
322
381
 
382
+ from .mixdb import db_connection
383
+
323
384
  con = db_connection(location=location, readonly=False, test=test)
324
- con.execute("UPDATE top SET mixid_width=? WHERE top.id = ?", (max_text_width(num_mixtures), 1))
385
+ con.execute(
386
+ "UPDATE top SET mixid_width=? WHERE top.id = ?",
387
+ (max_text_width(num_mixtures), 1),
388
+ )
325
389
  con.commit()
326
390
  con.close()
327
391
 
328
392
 
329
393
  def populate_mixture_table(location: str, mixtures: Mixtures, test: bool = False) -> None:
330
- """Populate mixture table
331
- """
394
+ """Populate mixture table"""
332
395
  from .helpers import from_mixture
333
396
  from .helpers import from_target
334
397
  from .mixdb import db_connection
@@ -348,29 +411,35 @@ def populate_mixture_table(location: str, mixtures: Mixtures, test: bool = False
348
411
  # Populate mixture table
349
412
  cur = con.cursor()
350
413
  for mixture in mixtures:
351
- cur.execute("""
414
+ cur.execute(
415
+ """
352
416
  INSERT INTO mixture (name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
353
417
  snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
354
418
  VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
355
- """, from_mixture(mixture))
419
+ """,
420
+ from_mixture(mixture),
421
+ )
356
422
 
357
423
  mixture_id = cur.lastrowid
358
424
  for target in mixture.targets:
359
- target_id = con.execute("""
425
+ target_id = con.execute(
426
+ """
360
427
  SELECT target.id
361
428
  FROM target
362
429
  WHERE ? = target.file_id AND ? = target.augmentation AND ? = target.gain
363
- """, from_target(target)).fetchone()[0]
364
- con.execute("INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
365
- (mixture_id, target_id))
430
+ """,
431
+ from_target(target),
432
+ ).fetchone()[0]
433
+ con.execute(
434
+ "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
435
+ (mixture_id, target_id),
436
+ )
366
437
 
367
438
  con.commit()
368
439
  con.close()
369
440
 
370
441
 
371
- def update_mixture(mixdb: MixtureDatabase,
372
- mixture: Mixture,
373
- with_data: bool = False) -> tuple[Mixture, GenMixData]:
442
+ def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
374
443
  """Update mixture record with name and gains
375
444
 
376
445
  :param mixdb: Mixture database
@@ -391,12 +460,11 @@ def update_mixture(mixdb: MixtureDatabase,
391
460
  # Apply IR and sum targets audio before initializing the mixture SNR gains
392
461
  target_audio = get_target(mixdb, mixture, targets_audio)
393
462
 
394
- mixture = _initialize_mixture_gains(mixdb=mixdb,
395
- mixture=mixture,
396
- target_audio=target_audio,
397
- noise_audio=noise_audio)
463
+ mixture = _initialize_mixture_gains(
464
+ mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise_audio
465
+ )
398
466
 
399
- mixture.name = f'{int(mixture.name):0{mixdb.mixid_width}}.h5'
467
+ mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
400
468
 
401
469
  if not with_data:
402
470
  return mixture, GenMixData()
@@ -409,10 +477,12 @@ def update_mixture(mixdb: MixtureDatabase,
409
477
  target_audio = get_target(mixdb, mixture, targets_audio)
410
478
  mixture_audio = target_audio + noise_audio
411
479
 
412
- return mixture, GenMixData(mixture=mixture_audio,
413
- targets=targets_audio,
414
- target=target_audio,
415
- noise=noise_audio)
480
+ return mixture, GenMixData(
481
+ mixture=mixture_audio,
482
+ targets=targets_audio,
483
+ target=target_audio,
484
+ noise=noise_audio,
485
+ )
416
486
 
417
487
 
418
488
  def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
@@ -439,9 +509,13 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
439
509
  targets_audio = []
440
510
  for target in mixture.targets:
441
511
  target_audio = mixdb.read_target_audio(target.file_id)
442
- targets_audio.append(apply_augmentation(audio=target_audio,
443
- augmentation=target.augmentation,
444
- frame_length=mixdb.feature_step_samples))
512
+ targets_audio.append(
513
+ apply_augmentation(
514
+ audio=target_audio,
515
+ augmentation=target.augmentation,
516
+ frame_length=mixdb.feature_step_samples,
517
+ )
518
+ )
445
519
 
446
520
  # target_gain is used to back out the gain augmentation in order to return the target audio
447
521
  # to its normalized level when calculating truth (if needed).
@@ -458,13 +532,11 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
458
532
  return mixture, targets_audio
459
533
 
460
534
 
461
- def _initialize_mixture_gains(mixdb: MixtureDatabase,
462
- mixture: Mixture,
463
- target_audio: AudioT,
464
- noise_audio: AudioT) -> Mixture:
535
+ def _initialize_mixture_gains(
536
+ mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT
537
+ ) -> Mixture:
465
538
  import numpy as np
466
539
 
467
- from sonusai import SonusAIError
468
540
  from sonusai.utils import asl_p56
469
541
  from sonusai.utils import db_to_linear
470
542
 
@@ -480,19 +552,20 @@ def _initialize_mixture_gains(mixdb: MixtureDatabase,
480
552
  mixture.target_snr_gain = 1
481
553
  mixture.noise_snr_gain = 0
482
554
  else:
483
- target_level_types = [target_file.level_type for target_file in
484
- [mixdb.target_file(target.file_id) for target in mixture.targets]]
555
+ target_level_types = [
556
+ target_file.level_type for target_file in [mixdb.target_file(target.file_id) for target in mixture.targets]
557
+ ]
485
558
  if not all(level_type == target_level_types[0] for level_type in target_level_types):
486
- raise SonusAIError(f'Not all target_level_types in mixup are the same')
559
+ raise ValueError("Not all target_level_types in mixup are the same")
487
560
 
488
561
  level_type = target_level_types[0]
489
562
  match level_type:
490
- case 'default':
563
+ case "default":
491
564
  target_energy = np.mean(np.square(target_audio))
492
- case 'speech':
565
+ case "speech":
493
566
  target_energy = asl_p56(target_audio)
494
567
  case _:
495
- raise SonusAIError(f'Unknown level_type: {level_type}')
568
+ raise ValueError(f"Unknown level_type: {level_type}")
496
569
 
497
570
  noise_energy = np.mean(np.square(noise_audio))
498
571
  if noise_energy == 0:
@@ -525,19 +598,20 @@ def _initialize_mixture_gains(mixdb: MixtureDatabase,
525
598
  return mixture
526
599
 
527
600
 
528
- def generate_mixtures(noise_mix_mode: str,
529
- augmented_targets: AugmentedTargets,
530
- target_files: TargetFiles,
531
- target_augmentations: AugmentationRules,
532
- noise_files: NoiseFiles,
533
- noise_augmentations: AugmentationRules,
534
- spectral_masks: SpectralMasks,
535
- all_snrs: list[UniversalSNRGenerator],
536
- mixups: list[int],
537
- num_classes: int,
538
- truth_mutex: bool,
539
- feature_step_samples: int,
540
- num_ir: int) -> tuple[int, int, Mixtures]:
601
+ def generate_mixtures(
602
+ noise_mix_mode: str,
603
+ augmented_targets: AugmentedTargets,
604
+ target_files: TargetFiles,
605
+ target_augmentations: AugmentationRules,
606
+ noise_files: NoiseFiles,
607
+ noise_augmentations: AugmentationRules,
608
+ spectral_masks: SpectralMasks,
609
+ all_snrs: list[UniversalSNRGenerator],
610
+ mixups: list[int],
611
+ num_classes: int,
612
+ feature_step_samples: int,
613
+ num_ir: int,
614
+ ) -> tuple[int, int, Mixtures]:
541
615
  """Generate mixtures
542
616
 
543
617
  :param noise_mix_mode: Noise mix mode
@@ -550,72 +624,72 @@ def generate_mixtures(noise_mix_mode: str,
550
624
  :param all_snrs: List of all SNRs
551
625
  :param mixups: List of mixup values
552
626
  :param num_classes: Number of classes
553
- :param truth_mutex: Truth mutex mode
554
627
  :param feature_step_samples: Number of samples in a feature step
555
628
  :param num_ir: Number of impulse response files
556
629
  :return: (Number of noise files used, number of noise samples used, list of mixture records)
557
630
  """
558
- from sonusai import SonusAIError
559
-
560
- if noise_mix_mode == 'exhaustive':
561
- return _exhaustive_noise_mix(augmented_targets=augmented_targets,
562
- target_files=target_files,
563
- target_augmentations=target_augmentations,
564
- noise_files=noise_files,
565
- noise_augmentations=noise_augmentations,
566
- spectral_masks=spectral_masks,
567
- all_snrs=all_snrs,
568
- mixups=mixups,
569
- num_classes=num_classes,
570
- truth_mutex=truth_mutex,
571
- feature_step_samples=feature_step_samples,
572
- num_ir=num_ir)
573
-
574
- if noise_mix_mode == 'non-exhaustive':
575
- return _non_exhaustive_noise_mix(augmented_targets=augmented_targets,
576
- target_files=target_files,
577
- target_augmentations=target_augmentations,
578
- noise_files=noise_files,
579
- noise_augmentations=noise_augmentations,
580
- spectral_masks=spectral_masks,
581
- all_snrs=all_snrs,
582
- mixups=mixups,
583
- num_classes=num_classes,
584
- truth_mutex=truth_mutex,
585
- feature_step_samples=feature_step_samples,
586
- num_ir=num_ir)
587
-
588
- if noise_mix_mode == 'non-combinatorial':
589
- return _non_combinatorial_noise_mix(augmented_targets=augmented_targets,
590
- target_files=target_files,
591
- target_augmentations=target_augmentations,
592
- noise_files=noise_files,
593
- noise_augmentations=noise_augmentations,
594
- spectral_masks=spectral_masks,
595
- all_snrs=all_snrs,
596
- mixups=mixups,
597
- num_classes=num_classes,
598
- truth_mutex=truth_mutex,
599
- feature_step_samples=feature_step_samples,
600
- num_ir=num_ir)
601
-
602
- raise SonusAIError(f'invalid noise_mix_mode: {noise_mix_mode}')
603
-
604
-
605
- def _exhaustive_noise_mix(augmented_targets: AugmentedTargets,
606
- target_files: TargetFiles,
607
- target_augmentations: AugmentationRules,
608
- noise_files: NoiseFiles,
609
- noise_augmentations: AugmentationRules,
610
- spectral_masks: SpectralMasks,
611
- all_snrs: list[UniversalSNRGenerator],
612
- mixups: list[int],
613
- num_classes: int,
614
- truth_mutex: bool,
615
- feature_step_samples: int,
616
- num_ir: int) -> tuple[int, int, Mixtures]:
617
- """ Use every noise/augmentation with every target/augmentation
618
- """
631
+ if noise_mix_mode == "exhaustive":
632
+ return _exhaustive_noise_mix(
633
+ augmented_targets=augmented_targets,
634
+ target_files=target_files,
635
+ target_augmentations=target_augmentations,
636
+ noise_files=noise_files,
637
+ noise_augmentations=noise_augmentations,
638
+ spectral_masks=spectral_masks,
639
+ all_snrs=all_snrs,
640
+ mixups=mixups,
641
+ num_classes=num_classes,
642
+ feature_step_samples=feature_step_samples,
643
+ num_ir=num_ir,
644
+ )
645
+
646
+ if noise_mix_mode == "non-exhaustive":
647
+ return _non_exhaustive_noise_mix(
648
+ augmented_targets=augmented_targets,
649
+ target_files=target_files,
650
+ target_augmentations=target_augmentations,
651
+ noise_files=noise_files,
652
+ noise_augmentations=noise_augmentations,
653
+ spectral_masks=spectral_masks,
654
+ all_snrs=all_snrs,
655
+ mixups=mixups,
656
+ num_classes=num_classes,
657
+ feature_step_samples=feature_step_samples,
658
+ num_ir=num_ir,
659
+ )
660
+
661
+ if noise_mix_mode == "non-combinatorial":
662
+ return _non_combinatorial_noise_mix(
663
+ augmented_targets=augmented_targets,
664
+ target_files=target_files,
665
+ target_augmentations=target_augmentations,
666
+ noise_files=noise_files,
667
+ noise_augmentations=noise_augmentations,
668
+ spectral_masks=spectral_masks,
669
+ all_snrs=all_snrs,
670
+ mixups=mixups,
671
+ num_classes=num_classes,
672
+ feature_step_samples=feature_step_samples,
673
+ num_ir=num_ir,
674
+ )
675
+
676
+ raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
677
+
678
+
679
+ def _exhaustive_noise_mix(
680
+ augmented_targets: AugmentedTargets,
681
+ target_files: TargetFiles,
682
+ target_augmentations: AugmentationRules,
683
+ noise_files: NoiseFiles,
684
+ noise_augmentations: AugmentationRules,
685
+ spectral_masks: SpectralMasks,
686
+ all_snrs: list[UniversalSNRGenerator],
687
+ mixups: list[int],
688
+ num_classes: int,
689
+ feature_step_samples: int,
690
+ num_ir: int,
691
+ ) -> tuple[int, int, Mixtures]:
692
+ """Use every noise/augmentation with every target/augmentation"""
619
693
  from random import randint
620
694
 
621
695
  import numpy as np
@@ -633,42 +707,53 @@ def _exhaustive_noise_mix(augmented_targets: AugmentedTargets,
633
707
  used_noise_files = len(noise_files) * len(noise_augmentations)
634
708
  used_noise_samples = 0
635
709
 
636
- augmented_target_ids_for_mixups = [get_augmented_target_ids_for_mixup(augmented_targets=augmented_targets,
637
- targets=target_files,
638
- target_augmentations=target_augmentations,
639
- mixup=mixup,
640
- num_classes=num_classes,
641
- truth_mutex=truth_mutex) for mixup in mixups]
710
+ augmented_target_ids_for_mixups = [
711
+ get_augmented_target_ids_for_mixup(
712
+ augmented_targets=augmented_targets,
713
+ targets=target_files,
714
+ target_augmentations=target_augmentations,
715
+ mixup=mixup,
716
+ num_classes=num_classes,
717
+ )
718
+ for mixup in mixups
719
+ ]
642
720
  for noise_file_id in range(len(noise_files)):
643
721
  for noise_augmentation_rule in noise_augmentations:
644
722
  noise_augmentation = augmentation_from_rule(noise_augmentation_rule, num_ir)
645
723
  noise_offset = 0
646
724
  noise_length = estimate_augmented_length_from_length(
647
725
  length=noise_files[noise_file_id].samples,
648
- tempo=noise_augmentation.tempo)
726
+ tempo=noise_augmentation.tempo,
727
+ )
649
728
 
650
729
  for augmented_target_ids_for_mixup in augmented_target_ids_for_mixups:
651
730
  for augmented_target_ids in augmented_target_ids_for_mixup:
652
- targets, target_length = _get_target_info(augmented_target_ids=augmented_target_ids,
653
- augmented_targets=augmented_targets,
654
- target_files=target_files,
655
- target_augmentations=target_augmentations,
656
- feature_step_samples=feature_step_samples,
657
- num_ir=num_ir)
731
+ targets, target_length = _get_target_info(
732
+ augmented_target_ids=augmented_target_ids,
733
+ augmented_targets=augmented_targets,
734
+ target_files=target_files,
735
+ target_augmentations=target_augmentations,
736
+ feature_step_samples=feature_step_samples,
737
+ num_ir=num_ir,
738
+ )
658
739
 
659
740
  for spectral_mask_id in range(len(spectral_masks)):
660
741
  for snr in all_snrs:
661
- mixtures.append(Mixture(
662
- targets=targets,
663
- name=str(m_id),
664
- noise=Noise(file_id=noise_file_id + 1,
665
- augmentation=noise_augmentation,
666
- offset=noise_offset),
667
- samples=target_length,
668
- snr=UniversalSNR(value=snr.value,
669
- is_random=snr.is_random),
670
- spectral_mask_id=spectral_mask_id + 1,
671
- spectral_mask_seed=randint(0, np.iinfo('i').max)))
742
+ mixtures.append(
743
+ Mixture(
744
+ targets=targets,
745
+ name=str(m_id),
746
+ noise=Noise(
747
+ file_id=noise_file_id + 1,
748
+ augmentation=noise_augmentation,
749
+ offset=noise_offset,
750
+ ),
751
+ samples=target_length,
752
+ snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
753
+ spectral_mask_id=spectral_mask_id + 1,
754
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
755
+ )
756
+ )
672
757
  m_id += 1
673
758
 
674
759
  noise_offset = int((noise_offset + target_length) % noise_length)
@@ -677,19 +762,20 @@ def _exhaustive_noise_mix(augmented_targets: AugmentedTargets,
677
762
  return used_noise_files, used_noise_samples, mixtures
678
763
 
679
764
 
680
- def _non_exhaustive_noise_mix(augmented_targets: AugmentedTargets,
681
- target_files: TargetFiles,
682
- target_augmentations: AugmentationRules,
683
- noise_files: NoiseFiles,
684
- noise_augmentations: AugmentationRules,
685
- spectral_masks: SpectralMasks,
686
- all_snrs: list[UniversalSNRGenerator],
687
- mixups: list[int],
688
- num_classes: int,
689
- truth_mutex: bool,
690
- feature_step_samples: int,
691
- num_ir: int) -> tuple[int, int, Mixtures]:
692
- """ Cycle through every target/augmentation without necessarily using all noise/augmentation combinations
765
+ def _non_exhaustive_noise_mix(
766
+ augmented_targets: AugmentedTargets,
767
+ target_files: TargetFiles,
768
+ target_augmentations: AugmentationRules,
769
+ noise_files: NoiseFiles,
770
+ noise_augmentations: AugmentationRules,
771
+ spectral_masks: SpectralMasks,
772
+ all_snrs: list[UniversalSNRGenerator],
773
+ mixups: list[int],
774
+ num_classes: int,
775
+ feature_step_samples: int,
776
+ num_ir: int,
777
+ ) -> tuple[int, int, Mixtures]:
778
+ """Cycle through every target/augmentation without necessarily using all noise/augmentation combinations
693
779
  (reduced data set).
694
780
  """
695
781
  from random import randint
@@ -710,67 +796,81 @@ def _non_exhaustive_noise_mix(augmented_targets: AugmentedTargets,
710
796
  noise_augmentation_id = None
711
797
  noise_offset = None
712
798
 
713
- augmented_target_indices_for_mixups = [get_augmented_target_ids_for_mixup(
714
- augmented_targets=augmented_targets,
715
- targets=target_files,
716
- target_augmentations=target_augmentations,
717
- mixup=mixup,
718
- num_classes=num_classes,
719
- truth_mutex=truth_mutex) for mixup in mixups]
799
+ augmented_target_indices_for_mixups = [
800
+ get_augmented_target_ids_for_mixup(
801
+ augmented_targets=augmented_targets,
802
+ targets=target_files,
803
+ target_augmentations=target_augmentations,
804
+ mixup=mixup,
805
+ num_classes=num_classes,
806
+ )
807
+ for mixup in mixups
808
+ ]
720
809
  for mixup in augmented_target_indices_for_mixups:
721
810
  for augmented_target_indices in mixup:
722
- targets, target_length = _get_target_info(augmented_target_ids=augmented_target_indices,
723
- augmented_targets=augmented_targets,
724
- target_files=target_files,
725
- target_augmentations=target_augmentations,
726
- feature_step_samples=feature_step_samples,
727
- num_ir=num_ir)
811
+ targets, target_length = _get_target_info(
812
+ augmented_target_ids=augmented_target_indices,
813
+ augmented_targets=augmented_targets,
814
+ target_files=target_files,
815
+ target_augmentations=target_augmentations,
816
+ feature_step_samples=feature_step_samples,
817
+ num_ir=num_ir,
818
+ )
728
819
 
729
820
  for spectral_mask_id in range(len(spectral_masks)):
730
821
  for snr in all_snrs:
731
- (noise_file_id,
732
- noise_augmentation_id,
733
- noise_augmentation,
734
- noise_offset) = _get_next_noise_offset(noise_file_id=noise_file_id,
735
- noise_augmentation_id=noise_augmentation_id,
736
- noise_offset=noise_offset,
737
- target_length=target_length,
738
- noise_files=noise_files,
739
- noise_augmentations=noise_augmentations,
740
- num_ir=num_ir)
822
+ (
823
+ noise_file_id,
824
+ noise_augmentation_id,
825
+ noise_augmentation,
826
+ noise_offset,
827
+ ) = _get_next_noise_offset(
828
+ noise_file_id=noise_file_id,
829
+ noise_augmentation_id=noise_augmentation_id,
830
+ noise_offset=noise_offset,
831
+ target_length=target_length,
832
+ noise_files=noise_files,
833
+ noise_augmentations=noise_augmentations,
834
+ num_ir=num_ir,
835
+ )
741
836
  used_noise_samples += target_length
742
837
 
743
- used_noise_files.add(f'{noise_file_id}_{noise_augmentation_id}')
744
-
745
- mixtures.append(Mixture(
746
- targets=targets,
747
- name=str(m_id),
748
- noise=Noise(file_id=noise_file_id + 1,
749
- augmentation=noise_augmentation,
750
- offset=noise_offset),
751
- samples=target_length,
752
- snr=UniversalSNR(value=snr.value,
753
- is_random=snr.is_random),
754
- spectral_mask_id=spectral_mask_id + 1,
755
- spectral_mask_seed=randint(0, np.iinfo('i').max)))
838
+ used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
839
+
840
+ mixtures.append(
841
+ Mixture(
842
+ targets=targets,
843
+ name=str(m_id),
844
+ noise=Noise(
845
+ file_id=noise_file_id + 1,
846
+ augmentation=noise_augmentation,
847
+ offset=noise_offset,
848
+ ),
849
+ samples=target_length,
850
+ snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
851
+ spectral_mask_id=spectral_mask_id + 1,
852
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
853
+ )
854
+ )
756
855
  m_id += 1
757
856
 
758
857
  return len(used_noise_files), used_noise_samples, mixtures
759
858
 
760
859
 
761
- def _non_combinatorial_noise_mix(augmented_targets: AugmentedTargets,
762
- target_files: TargetFiles,
763
- target_augmentations: AugmentationRules,
764
- noise_files: NoiseFiles,
765
- noise_augmentations: AugmentationRules,
766
- spectral_masks: SpectralMasks,
767
- all_snrs: list[UniversalSNRGenerator],
768
- mixups: list[int],
769
- num_classes: int,
770
- truth_mutex: bool,
771
- feature_step_samples: int,
772
- num_ir: int) -> tuple[int, int, Mixtures]:
773
- """ Combine a target/augmentation with a single cut of a noise/augmentation non-exhaustively
860
+ def _non_combinatorial_noise_mix(
861
+ augmented_targets: AugmentedTargets,
862
+ target_files: TargetFiles,
863
+ target_augmentations: AugmentationRules,
864
+ noise_files: NoiseFiles,
865
+ noise_augmentations: AugmentationRules,
866
+ spectral_masks: SpectralMasks,
867
+ all_snrs: list[UniversalSNRGenerator],
868
+ mixups: list[int],
869
+ num_classes: int,
870
+ feature_step_samples: int,
871
+ num_ir: int,
872
+ ) -> tuple[int, int, Mixtures]:
873
+ """Combine a target/augmentation with a single cut of a noise/augmentation non-exhaustively
774
874
  (each target/augmentation does not use each noise/augmentation). Cut has random start and loop back to
775
875
  beginning if end of noise/augmentation is reached.
776
876
  """
@@ -792,57 +892,72 @@ def _non_combinatorial_noise_mix(augmented_targets: AugmentedTargets,
792
892
  noise_file_id = None
793
893
  noise_augmentation_id = None
794
894
 
795
- augmented_target_indices_for_mixups = [get_augmented_target_ids_for_mixup(
796
- augmented_targets=augmented_targets,
797
- targets=target_files,
798
- target_augmentations=target_augmentations,
799
- mixup=mixup,
800
- num_classes=num_classes,
801
- truth_mutex=truth_mutex) for mixup in mixups]
895
+ augmented_target_indices_for_mixups = [
896
+ get_augmented_target_ids_for_mixup(
897
+ augmented_targets=augmented_targets,
898
+ targets=target_files,
899
+ target_augmentations=target_augmentations,
900
+ mixup=mixup,
901
+ num_classes=num_classes,
902
+ )
903
+ for mixup in mixups
904
+ ]
802
905
  for mixup in augmented_target_indices_for_mixups:
803
906
  for augmented_target_indices in mixup:
804
- targets, target_length = _get_target_info(augmented_target_ids=augmented_target_indices,
805
- augmented_targets=augmented_targets,
806
- target_files=target_files,
807
- target_augmentations=target_augmentations,
808
- feature_step_samples=feature_step_samples,
809
- num_ir=num_ir)
907
+ targets, target_length = _get_target_info(
908
+ augmented_target_ids=augmented_target_indices,
909
+ augmented_targets=augmented_targets,
910
+ target_files=target_files,
911
+ target_augmentations=target_augmentations,
912
+ feature_step_samples=feature_step_samples,
913
+ num_ir=num_ir,
914
+ )
810
915
 
811
916
  for spectral_mask_id in range(len(spectral_masks)):
812
917
  for snr in all_snrs:
813
- (noise_file_id,
814
- noise_augmentation_id,
815
- noise_augmentation,
816
- noise_length) = _get_next_noise_indices(noise_file_id=noise_file_id,
817
- noise_augmentation_id=noise_augmentation_id,
818
- noise_files=noise_files,
819
- noise_augmentations=noise_augmentations,
820
- num_ir=num_ir)
918
+ (
919
+ noise_file_id,
920
+ noise_augmentation_id,
921
+ noise_augmentation,
922
+ noise_length,
923
+ ) = _get_next_noise_indices(
924
+ noise_file_id=noise_file_id,
925
+ noise_augmentation_id=noise_augmentation_id,
926
+ noise_files=noise_files,
927
+ noise_augmentations=noise_augmentations,
928
+ num_ir=num_ir,
929
+ )
821
930
  used_noise_samples += target_length
822
931
 
823
- used_noise_files.add(f'{noise_file_id}_{noise_augmentation_id}')
824
-
825
- mixtures.append(Mixture(
826
- targets=targets,
827
- name=str(m_id),
828
- noise=Noise(file_id=noise_file_id + 1,
829
- augmentation=noise_augmentation,
830
- offset=choice(range(noise_length))),
831
- samples=target_length,
832
- snr=UniversalSNR(value=snr.value,
833
- is_random=snr.is_random),
834
- spectral_mask_id=spectral_mask_id + 1,
835
- spectral_mask_seed=randint(0, np.iinfo('i').max)))
932
+ used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
933
+
934
+ mixtures.append(
935
+ Mixture(
936
+ targets=targets,
937
+ name=str(m_id),
938
+ noise=Noise(
939
+ file_id=noise_file_id + 1,
940
+ augmentation=noise_augmentation,
941
+ offset=choice(range(noise_length)), # noqa: S311
942
+ ),
943
+ samples=target_length,
944
+ snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
945
+ spectral_mask_id=spectral_mask_id + 1,
946
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
947
+ )
948
+ )
836
949
  m_id += 1
837
950
 
838
951
  return len(used_noise_files), used_noise_samples, mixtures
839
952
 
840
953
 
841
- def _get_next_noise_indices(noise_file_id: int,
842
- noise_augmentation_id: int,
843
- noise_files: NoiseFiles,
844
- noise_augmentations: AugmentationRules,
845
- num_ir: int) -> tuple[int, int, Augmentation, int]:
954
+ def _get_next_noise_indices(
955
+ noise_file_id: int | None,
956
+ noise_augmentation_id: int | None,
957
+ noise_files: NoiseFiles,
958
+ noise_augmentations: AugmentationRules,
959
+ num_ir: int,
960
+ ) -> tuple[int, int, Augmentation, int]:
846
961
  from .augmentation import augmentation_from_rule
847
962
  from .augmentation import estimate_augmented_length_from_length
848
963
 
@@ -858,19 +973,21 @@ def _get_next_noise_indices(noise_file_id: int,
858
973
  noise_file_id = 0
859
974
 
860
975
  noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
861
- noise_length = estimate_augmented_length_from_length(length=noise_files[noise_file_id].samples,
862
- tempo=noise_augmentation.tempo)
976
+ noise_length = estimate_augmented_length_from_length(
977
+ length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
978
+ )
863
979
  return noise_file_id, noise_augmentation_id, noise_augmentation, noise_length
864
980
 
865
981
 
866
- def _get_next_noise_offset(noise_file_id: int | None,
867
- noise_augmentation_id: int | None,
868
- noise_offset: int | None,
869
- target_length: int,
870
- noise_files: NoiseFiles,
871
- noise_augmentations: AugmentationRules,
872
- num_ir: int) -> tuple[int, int, Augmentation, int]:
873
- from sonusai import SonusAIError
982
+ def _get_next_noise_offset(
983
+ noise_file_id: int | None,
984
+ noise_augmentation_id: int | None,
985
+ noise_offset: int | None,
986
+ target_length: int,
987
+ noise_files: NoiseFiles,
988
+ noise_augmentations: AugmentationRules,
989
+ num_ir: int,
990
+ ) -> tuple[int, int, Augmentation, int]:
874
991
  from .augmentation import augmentation_from_rule
875
992
  from .augmentation import estimate_augmented_length_from_length
876
993
 
@@ -880,11 +997,12 @@ def _get_next_noise_offset(noise_file_id: int | None,
880
997
  noise_offset = 0
881
998
 
882
999
  noise_augmentation = augmentation_from_rule(noise_augmentations[noise_file_id], num_ir)
883
- noise_length = estimate_augmented_length_from_length(length=noise_files[noise_file_id].samples,
884
- tempo=noise_augmentation.tempo)
1000
+ noise_length = estimate_augmented_length_from_length(
1001
+ length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
1002
+ )
885
1003
  if noise_offset + target_length >= noise_length:
886
1004
  if noise_offset == 0:
887
- raise SonusAIError('Length of target audio exceeds length of noise audio')
1005
+ raise ValueError("Length of target audio exceeds length of noise audio")
888
1006
 
889
1007
  noise_offset = 0
890
1008
  noise_augmentation_id += 1
@@ -898,12 +1016,14 @@ def _get_next_noise_offset(noise_file_id: int | None,
898
1016
  return noise_file_id, noise_augmentation_id, noise_augmentation, noise_offset
899
1017
 
900
1018
 
901
- def _get_target_info(augmented_target_ids: list[int],
902
- augmented_targets: AugmentedTargets,
903
- target_files: TargetFiles,
904
- target_augmentations: AugmentationRules,
905
- feature_step_samples: int,
906
- num_ir: int) -> tuple[Targets, int]:
1019
+ def _get_target_info(
1020
+ augmented_target_ids: list[int],
1021
+ augmented_targets: AugmentedTargets,
1022
+ target_files: TargetFiles,
1023
+ target_augmentations: AugmentationRules,
1024
+ feature_step_samples: int,
1025
+ num_ir: int,
1026
+ ) -> tuple[Targets, int]:
907
1027
  from .augmentation import augmentation_from_rule
908
1028
  from .augmentation import estimate_augmented_length_from_length
909
1029
  from .datatypes import Target
@@ -918,18 +1038,23 @@ def _get_target_info(augmented_target_ids: list[int],
918
1038
 
919
1039
  mixups.append(Target(file_id=tfi + 1, augmentation=target_augmentation))
920
1040
 
921
- target_length = max(estimate_augmented_length_from_length(length=target_files[tfi].samples,
922
- tempo=target_augmentation.tempo,
923
- frame_length=feature_step_samples),
924
- target_length)
1041
+ target_length = max(
1042
+ estimate_augmented_length_from_length(
1043
+ length=target_files[tfi].samples,
1044
+ tempo=target_augmentation.tempo,
1045
+ frame_length=feature_step_samples,
1046
+ ),
1047
+ target_length,
1048
+ )
925
1049
  return mixups, target_length
926
1050
 
927
1051
 
928
1052
  def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
929
1053
  from .datatypes import UniversalSNRGenerator
930
1054
 
931
- return ([UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config['snrs']] +
932
- [UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config['random_snrs']])
1055
+ return [UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config["snrs"]] + [
1056
+ UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config["random_snrs"]
1057
+ ]
933
1058
 
934
1059
 
935
1060
  def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
@@ -939,7 +1064,7 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
939
1064
 
940
1065
  from sonusai.mixture import tokenized_expand
941
1066
 
942
- textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix('.TextGrid')
1067
+ textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
943
1068
  if not textgrid_file.exists():
944
1069
  return []
945
1070
 
@@ -949,8 +1074,7 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
949
1074
 
950
1075
 
951
1076
  def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
952
- """Populate speaker table
953
- """
1077
+ """Populate speaker table"""
954
1078
  import json
955
1079
  from pathlib import Path
956
1080
 
@@ -960,65 +1084,74 @@ def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool
960
1084
  from .tokenized_shell_vars import tokenized_expand
961
1085
 
962
1086
  # Determine columns for speaker table
963
- all_parents = set([Path(target_file.name).parent for target_file in target_files])
964
- speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / 'speaker.yml')[0]).exists())
1087
+ all_parents = {Path(target_file.name).parent for target_file in target_files}
1088
+ speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists())
965
1089
 
966
1090
  speakers: dict[Path, dict[str, str]] = {}
967
1091
  for parent in sorted(speaker_parents):
968
- with open(tokenized_expand(parent / 'speaker.yml')[0], 'r') as f:
1092
+ with open(tokenized_expand(parent / "speaker.yml")[0]) as f:
969
1093
  speakers[parent] = yaml.safe_load(f)
970
1094
 
971
1095
  new_columns: list[str] = []
972
- for keys in speakers.keys():
973
- for column in speakers[keys].keys():
1096
+ for keys in speakers:
1097
+ for column in speakers[keys]:
974
1098
  new_columns.append(column)
975
1099
  new_columns = sorted(set(new_columns))
976
1100
 
977
1101
  con = db_connection(location=location, readonly=False, test=test)
978
1102
 
979
1103
  for new_column in new_columns:
980
- con.execute(f'ALTER TABLE speaker ADD COLUMN {new_column} TEXT')
1104
+ con.execute(f"ALTER TABLE speaker ADD COLUMN {new_column} TEXT")
981
1105
 
982
1106
  # Populate speaker table
983
1107
  speaker_rows: list[tuple[str, ...]] = []
984
- for key in speakers.keys():
1108
+ for key in speakers:
985
1109
  entry = (speakers[key].get(column, None) for column in new_columns)
986
- speaker_rows.append((key.as_posix(), *entry))
1110
+ speaker_rows.append((key.as_posix(), *entry)) # type: ignore[arg-type]
987
1111
 
988
- column_ids = ', '.join(['parent', *new_columns])
989
- column_values = ', '.join(['?'] * (len(new_columns) + 1))
990
- con.executemany(f'INSERT INTO speaker ({column_ids}) VALUES ({column_values})', speaker_rows)
1112
+ column_ids = ", ".join(["parent", *new_columns])
1113
+ column_values = ", ".join(["?"] * (len(new_columns) + 1))
1114
+ con.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)
991
1115
 
992
1116
  con.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")
993
1117
 
994
1118
  # Update speaker_metadata_tiers in the top table
995
- tiers = [description[0] for description in con.execute("SELECT * FROM speaker").description if
996
- description[0] not in ('id', 'parent')]
997
- con.execute("UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?", (json.dumps(tiers), 1))
998
-
999
- if 'speaker_id' in tiers:
1119
+ tiers = [
1120
+ description[0]
1121
+ for description in con.execute("SELECT * FROM speaker").description
1122
+ if description[0] not in ("id", "parent")
1123
+ ]
1124
+ con.execute(
1125
+ "UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?",
1126
+ (json.dumps(tiers), 1),
1127
+ )
1128
+
1129
+ if "speaker_id" in tiers:
1000
1130
  con.execute("CREATE INDEX speaker_speaker_id_idx ON speaker (speaker_id)")
1001
1131
 
1002
1132
  con.commit()
1003
1133
  con.close()
1004
1134
 
1005
1135
 
1006
- def _populate_truth_setting_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
1007
- """Populate truth_setting table
1008
- """
1136
+ def _populate_truth_config_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
1137
+ """Populate truth_config table"""
1138
+ import json
1139
+
1009
1140
  from .mixdb import db_connection
1010
1141
 
1011
1142
  con = db_connection(location=location, readonly=False, test=test)
1012
1143
 
1013
- # Populate truth_setting table
1014
- truth_settings: list[str] = []
1015
- for truth_setting in [truth_setting for target_file in target_files
1016
- for truth_setting in target_file.truth_settings]:
1017
- ts = truth_setting.to_json()
1018
- if ts not in truth_settings:
1019
- truth_settings.append(ts)
1020
- con.executemany("INSERT INTO truth_setting (setting) VALUES (?)",
1021
- [(item,) for item in truth_settings])
1144
+ # Populate truth_config table
1145
+ truth_configs: list[str] = []
1146
+ for target_file in target_files:
1147
+ for name, config in target_file.truth_configs.items():
1148
+ ts = json.dumps({"name": name} | config.to_dict())
1149
+ if ts not in truth_configs:
1150
+ truth_configs.append(ts)
1151
+ con.executemany(
1152
+ "INSERT INTO truth_config (config) VALUES (?)",
1153
+ [(item,) for item in truth_configs],
1154
+ )
1022
1155
 
1023
1156
  con.commit()
1024
1157
  con.close()