sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -1,16 +1,13 @@
1
1
  # ruff: noqa: S608
2
- from .datatypes import AudioT
3
- from .datatypes import Augmentation
4
- from .datatypes import AugmentationRule
5
- from .datatypes import AugmentedTarget
6
- from .datatypes import GenMixData
7
- from .datatypes import ImpulseResponseFile
8
- from .datatypes import Mixture
9
- from .datatypes import NoiseFile
10
- from .datatypes import SpectralMask
11
- from .datatypes import Target
12
- from .datatypes import TargetFile
13
- from .datatypes import UniversalSNRGenerator
2
+ from ..datatypes import AudioT
3
+ from ..datatypes import Effects
4
+ from ..datatypes import GenMixData
5
+ from ..datatypes import ImpulseResponseFile
6
+ from ..datatypes import Mixture
7
+ from ..datatypes import Source
8
+ from ..datatypes import SourceFile
9
+ from ..datatypes import SourcesAudioT
10
+ from ..datatypes import UniversalSNRGenerator
14
11
  from .mixdb import MixtureDatabase
15
12
 
16
13
 
@@ -34,47 +31,62 @@ def initialize_db(location: str, test: bool = False) -> None:
34
31
  con.execute("""
35
32
  CREATE TABLE truth_parameters(
36
33
  id INTEGER PRIMARY KEY NOT NULL,
34
+ category TEXT NOT NULL,
37
35
  name TEXT NOT NULL,
38
36
  parameters INTEGER)
39
37
  """)
40
38
 
41
39
  con.execute("""
42
- CREATE TABLE target_file (
40
+ CREATE TABLE source_file (
43
41
  id INTEGER PRIMARY KEY NOT NULL,
42
+ category TEXT NOT NULL,
43
+ class_indices TEXT,
44
+ level_type TEXT NOT NULL,
44
45
  name TEXT NOT NULL,
45
46
  samples INTEGER NOT NULL,
46
- class_indices TEXT NOT NULL,
47
- level_type TEXT NOT NULL,
48
47
  speaker_id INTEGER,
49
48
  FOREIGN KEY(speaker_id) REFERENCES speaker (id))
50
49
  """)
51
50
 
52
51
  con.execute("""
53
- CREATE TABLE speaker (
52
+ CREATE TABLE ir_file (
54
53
  id INTEGER PRIMARY KEY NOT NULL,
55
- parent TEXT NOT NULL)
54
+ delay INTEGER NOT NULL,
55
+ name TEXT NOT NULL)
56
56
  """)
57
57
 
58
58
  con.execute("""
59
- CREATE TABLE noise_file (
59
+ CREATE TABLE ir_tag (
60
60
  id INTEGER PRIMARY KEY NOT NULL,
61
- name TEXT NOT NULL,
62
- samples INTEGER NOT NULL)
61
+ tag TEXT NOT NULL UNIQUE)
62
+ """)
63
+
64
+ con.execute("""
65
+ CREATE TABLE ir_file_ir_tag (
66
+ file_id INTEGER NOT NULL,
67
+ tag_id INTEGER NOT NULL,
68
+ FOREIGN KEY(file_id) REFERENCES ir_file (id),
69
+ FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
70
+ """)
71
+
72
+ con.execute("""
73
+ CREATE TABLE speaker (
74
+ id INTEGER PRIMARY KEY NOT NULL,
75
+ parent TEXT NOT NULL)
63
76
  """)
64
77
 
65
78
  con.execute("""
66
79
  CREATE TABLE top (
67
80
  id INTEGER PRIMARY KEY NOT NULL,
68
- version INTEGER NOT NULL,
69
81
  asr_configs TEXT NOT NULL,
70
82
  class_balancing BOOLEAN NOT NULL,
71
83
  feature TEXT NOT NULL,
72
- noise_mix_mode TEXT NOT NULL,
84
+ mixid_width INTEGER NOT NULL,
73
85
  num_classes INTEGER NOT NULL,
74
86
  seed INTEGER NOT NULL,
75
- mixid_width INTEGER NOT NULL,
76
87
  speaker_metadata_tiers TEXT NOT NULL,
77
- textgrid_metadata_tiers TEXT NOT NULL)
88
+ textgrid_metadata_tiers TEXT NOT NULL,
89
+ version INTEGER NOT NULL)
78
90
  """)
79
91
 
80
92
  con.execute("""
@@ -89,64 +101,54 @@ def initialize_db(location: str, test: bool = False) -> None:
89
101
  threshold FLOAT NOT NULL)
90
102
  """)
91
103
 
92
- con.execute("""
93
- CREATE TABLE impulse_response_file (
94
- id INTEGER PRIMARY KEY NOT NULL,
95
- file TEXT NOT NULL,
96
- tags TEXT NOT NULL,
97
- delay INTEGER NOT NULL)
98
- """)
99
-
100
104
  con.execute("""
101
105
  CREATE TABLE spectral_mask (
102
106
  id INTEGER PRIMARY KEY NOT NULL,
103
107
  f_max_width INTEGER NOT NULL,
104
108
  f_num INTEGER NOT NULL,
109
+ t_max_percent INTEGER NOT NULL,
105
110
  t_max_width INTEGER NOT NULL,
106
- t_num INTEGER NOT NULL,
107
- t_max_percent INTEGER NOT NULL)
111
+ t_num INTEGER NOT NULL)
108
112
  """)
109
113
 
110
114
  con.execute("""
111
- CREATE TABLE target_file_truth_config (
112
- target_file_id INTEGER,
113
- truth_config_id INTEGER,
114
- FOREIGN KEY(target_file_id) REFERENCES target_file (id),
115
+ CREATE TABLE source_file_truth_config (
116
+ source_file_id INTEGER NOT NULL,
117
+ truth_config_id INTEGER NOT NULL,
118
+ FOREIGN KEY(source_file_id) REFERENCES source_file (id),
115
119
  FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
116
120
  """)
117
121
 
118
122
  con.execute("""
119
- CREATE TABLE target (
123
+ CREATE TABLE source (
120
124
  id INTEGER PRIMARY KEY NOT NULL,
125
+ effects TEXT NOT NULL,
121
126
  file_id INTEGER NOT NULL,
122
- augmentation TEXT NOT NULL,
123
- FOREIGN KEY(file_id) REFERENCES target_file (id))
127
+ pre_tempo FLOAT NOT NULL,
128
+ repeat BOOLEAN NOT NULL,
129
+ snr FLOAT NOT NULL,
130
+ snr_gain FLOAT NOT NULL,
131
+ snr_random BOOLEAN NOT NULL,
132
+ start INTEGER NOT NULL,
133
+ FOREIGN KEY(file_id) REFERENCES source_file (id))
124
134
  """)
125
135
 
126
136
  con.execute("""
127
137
  CREATE TABLE mixture (
128
138
  id INTEGER PRIMARY KEY NOT NULL,
129
- name VARCHAR NOT NULL,
130
- noise_file_id INTEGER NOT NULL,
131
- noise_augmentation TEXT NOT NULL,
132
- noise_offset INTEGER NOT NULL,
133
- noise_snr_gain FLOAT,
134
- random_snr BOOLEAN NOT NULL,
135
- snr FLOAT NOT NULL,
139
+ name TEXT NOT NULL,
136
140
  samples INTEGER NOT NULL,
137
141
  spectral_mask_id INTEGER NOT NULL,
138
142
  spectral_mask_seed INTEGER NOT NULL,
139
- target_snr_gain FLOAT,
140
- FOREIGN KEY(noise_file_id) REFERENCES noise_file (id),
141
143
  FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
142
144
  """)
143
145
 
144
146
  con.execute("""
145
- CREATE TABLE mixture_target (
146
- mixture_id INTEGER,
147
- target_id INTEGER,
147
+ CREATE TABLE mixture_source (
148
+ mixture_id INTEGER NOT NULL,
149
+ source_id INTEGER NOT NULL,
148
150
  FOREIGN KEY(mixture_id) REFERENCES mixture (id),
149
- FOREIGN KEY(target_id) REFERENCES target (id))
151
+ FOREIGN KEY(source_id) REFERENCES source (id))
150
152
  """)
151
153
 
152
154
  con.commit()
@@ -163,22 +165,21 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
163
165
  con = db_connection(location=location, readonly=False, test=test)
164
166
  con.execute(
165
167
  """
166
- INSERT INTO top (id, version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
167
- seed, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
168
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
168
+ INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
169
+ seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
170
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
169
171
  """,
170
172
  (
171
173
  1,
172
- MIXDB_VERSION,
173
174
  json.dumps(config["asr_configs"]),
174
175
  config["class_balancing"],
175
176
  config["feature"],
176
- config["noise_mix_mode"],
177
+ 0,
177
178
  config["num_classes"],
178
179
  config["seed"],
179
- 0,
180
180
  "",
181
181
  "",
182
+ MIXDB_VERSION,
182
183
  ),
183
184
  )
184
185
  con.commit()
@@ -231,15 +232,15 @@ def populate_spectral_mask_table(location: str, config: dict, test: bool = False
231
232
  con = db_connection(location=location, readonly=False, test=test)
232
233
  con.executemany(
233
234
  """
234
- INSERT INTO spectral_mask (f_max_width, f_num, t_max_width, t_num, t_max_percent) VALUES (?, ?, ?, ?, ?)
235
+ INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
235
236
  """,
236
237
  [
237
238
  (
238
239
  item.f_max_width,
239
240
  item.f_num,
241
+ item.t_max_percent,
240
242
  item.t_max_width,
241
243
  item.t_num,
242
- item.t_max_percent,
243
244
  )
244
245
  for item in get_spectral_masks(config)
245
246
  ],
@@ -256,10 +257,11 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
256
257
  con = db_connection(location=location, readonly=False, test=test)
257
258
  con.executemany(
258
259
  """
259
- INSERT INTO truth_parameters (name, parameters) VALUES (?, ?)
260
+ INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
260
261
  """,
261
262
  [
262
263
  (
264
+ item.category,
263
265
  item.name,
264
266
  item.parameters,
265
267
  )
@@ -270,41 +272,39 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
270
272
  con.close()
271
273
 
272
274
 
273
- def populate_target_file_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
274
- """Populate target file table"""
275
+ def populate_source_file_table(location: str, files: list[SourceFile], test: bool = False) -> None:
276
+ """Populate source file table"""
275
277
  import json
276
278
  from pathlib import Path
277
279
 
278
280
  from .mixdb import db_connection
279
281
 
280
- _populate_truth_config_table(location, target_files, test)
281
- _populate_speaker_table(location, target_files, test)
282
+ _populate_truth_config_table(location, files, test)
283
+ _populate_speaker_table(location, files, test)
282
284
 
283
285
  con = db_connection(location=location, readonly=False, test=test)
284
286
 
285
287
  cur = con.cursor()
286
288
  textgrid_metadata_tiers: set[str] = set()
287
- for target_file in target_files:
288
- # Get TextGrid tiers for target file and add to collection
289
- tiers = _get_textgrid_tiers_from_target_file(target_file.name)
289
+ for file in files:
290
+ # Get TextGrid tiers for source file and add to collection
291
+ tiers = _get_textgrid_tiers_from_source_file(file.name)
290
292
  for tier in tiers:
291
293
  textgrid_metadata_tiers.add(tier)
292
294
 
293
- # Get truth settings for target file
295
+ # Get truth settings for file
294
296
  truth_config_ids: list[int] = []
295
- for name, config in target_file.truth_configs.items():
296
- ts = json.dumps({"name": name} | config.to_dict())
297
- cur.execute(
298
- "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
299
- (ts,),
300
- )
301
- truth_config_ids.append(cur.fetchone()[0])
302
-
303
- # Get speaker_id for target file
304
- cur.execute(
305
- "SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
306
- (Path(target_file.name).parent.as_posix(),),
307
- )
297
+ if file.truth_configs:
298
+ for name, config in file.truth_configs.items():
299
+ ts = json.dumps({"name": name} | config.to_dict())
300
+ cur.execute(
301
+ "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
302
+ (ts,),
303
+ )
304
+ truth_config_ids.append(cur.fetchone()[0])
305
+
306
+ # Get speaker_id for source file
307
+ cur.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent", (Path(file.name).parent.as_posix(),))
308
308
  result = cur.fetchone()
309
309
  speaker_id = None
310
310
  if result is not None:
@@ -312,121 +312,110 @@ def populate_target_file_table(location: str, target_files: list[TargetFile], te
312
312
 
313
313
  # Add entry
314
314
  cur.execute(
315
- "INSERT INTO target_file (name, samples, class_indices, level_type, speaker_id) VALUES (?, ?, ?, ?, ?)",
315
+ """
316
+ INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
317
+ VALUES (?, ?, ?, ?, ?, ?)
318
+ """,
316
319
  (
317
- target_file.name,
318
- target_file.samples,
319
- json.dumps(target_file.class_indices),
320
- target_file.level_type,
320
+ file.category,
321
+ json.dumps(file.class_indices),
322
+ file.level_type,
323
+ file.name,
324
+ file.samples,
321
325
  speaker_id,
322
326
  ),
323
327
  )
324
- target_file_id = cur.lastrowid
328
+ source_file_id = cur.lastrowid
325
329
  for truth_config_id in truth_config_ids:
326
330
  cur.execute(
327
- "INSERT INTO target_file_truth_config (target_file_id, truth_config_id) VALUES (?, ?)",
328
- (target_file_id, truth_config_id),
331
+ "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
332
+ (source_file_id, truth_config_id),
329
333
  )
330
334
 
331
335
  # Update textgrid_metadata_tiers in the top table
332
336
  con.execute(
333
- "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = top.id",
334
- (json.dumps(sorted(textgrid_metadata_tiers)), 1),
337
+ "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = id", (json.dumps(sorted(textgrid_metadata_tiers)), 1)
335
338
  )
336
339
 
337
340
  con.commit()
338
341
  con.close()
339
342
 
340
343
 
341
- def populate_noise_file_table(location: str, noise_files: list[NoiseFile], test: bool = False) -> None:
342
- """Populate noise file table"""
344
+ def populate_impulse_response_file_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
345
+ """Populate impulse response file table"""
343
346
  from .mixdb import db_connection
344
347
 
348
+ _populate_impulse_response_tag_table(location, files, test)
349
+
345
350
  con = db_connection(location=location, readonly=False, test=test)
346
- con.executemany(
347
- "INSERT INTO noise_file (name, samples) VALUES (?, ?)",
348
- [(noise_file.name, noise_file.samples) for noise_file in noise_files],
349
- )
350
- con.commit()
351
- con.close()
352
351
 
352
+ cur = con.cursor()
353
+ for file in files:
354
+ # Get tags for file
355
+ tag_ids: list[int] = []
356
+ for tag in file.tags:
357
+ cur.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,))
358
+ tag_ids.append(cur.fetchone()[0])
353
359
 
354
- def populate_impulse_response_file_table(
355
- location: str, impulse_response_files: list[ImpulseResponseFile], test: bool = False
356
- ) -> None:
357
- """Populate impulse response file table"""
358
- import json
360
+ cur.execute("INSERT INTO ir_file (delay, name) VALUES (?, ?)", (file.delay, file.name))
359
361
 
360
- from .mixdb import db_connection
362
+ file_id = cur.lastrowid
363
+ for tag_id in tag_ids:
364
+ cur.execute("INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)", (file_id, tag_id))
361
365
 
362
- con = db_connection(location=location, readonly=False, test=test)
363
- con.executemany(
364
- "INSERT INTO impulse_response_file (file, tags, delay) VALUES (?, ?, ?)",
365
- [
366
- (
367
- impulse_response_file.file,
368
- json.dumps(impulse_response_file.tags),
369
- impulse_response_file.delay,
370
- )
371
- for impulse_response_file in impulse_response_files
372
- ],
373
- )
374
366
  con.commit()
375
367
  con.close()
376
368
 
377
369
 
378
370
  def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
379
371
  """Update the mixid width"""
380
- from sonusai.utils import max_text_width
381
-
372
+ from ..utils.max_text_width import max_text_width
382
373
  from .mixdb import db_connection
383
374
 
384
375
  con = db_connection(location=location, readonly=False, test=test)
385
- con.execute(
386
- "UPDATE top SET mixid_width=? WHERE ? = top.id",
387
- (max_text_width(num_mixtures), 1),
388
- )
376
+ con.execute("UPDATE top SET mixid_width=? WHERE ? = id", (max_text_width(num_mixtures), 1))
389
377
  con.commit()
390
378
  con.close()
391
379
 
392
380
 
393
381
  def generate_mixtures(
394
- noise_mix_mode: str,
395
- augmented_targets: list[AugmentedTarget],
396
- target_files: list[TargetFile],
397
- target_augmentations: list[AugmentationRule],
398
- noise_files: list[NoiseFile],
399
- noise_augmentations: list[AugmentationRule],
400
- spectral_masks: list[SpectralMask],
401
- all_snrs: list[UniversalSNRGenerator],
402
- mixups: list[int],
403
- num_classes: int,
404
- feature_step_samples: int,
405
- num_ir: int,
406
- ) -> tuple[int, int, list[Mixture]]:
382
+ location: str,
383
+ config: dict,
384
+ effects: dict[str, list[Effects]],
385
+ test: bool = False,
386
+ ) -> list[Mixture]:
407
387
  """Generate mixtures"""
408
- if noise_mix_mode == "exhaustive":
409
- func = _exhaustive_noise_mix
410
- elif noise_mix_mode == "non-exhaustive":
411
- func = _non_exhaustive_noise_mix
412
- elif noise_mix_mode == "non-combinatorial":
413
- func = _non_combinatorial_noise_mix
414
- else:
415
- raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
416
-
417
- return func(
418
- augmented_targets=augmented_targets,
419
- target_files=target_files,
420
- target_augmentations=target_augmentations,
421
- noise_files=noise_files,
422
- noise_augmentations=noise_augmentations,
423
- spectral_masks=spectral_masks,
424
- all_snrs=all_snrs,
425
- mixups=mixups,
426
- num_classes=num_classes,
427
- feature_step_samples=feature_step_samples,
428
- num_ir=num_ir,
429
- )
388
+ mixdb = MixtureDatabase(location, test)
389
+
390
+ effected_sources: dict[str, list[tuple[SourceFile, Effects]]] = {}
391
+ for category in mixdb.source_files:
392
+ effected_sources[category] = []
393
+ for file in mixdb.source_files[category]:
394
+ for effect in effects[category]:
395
+ effected_sources[category].append((file, effect))
396
+
397
+ mixtures: list[Mixture] = []
398
+ for noise_mix_rule in config["sources"]["noise"]["mix_rules"]:
399
+ match noise_mix_rule["mode"]:
400
+ case "exhaustive":
401
+ func = _exhaustive_noise_mix
402
+ case "non-exhaustive":
403
+ func = _non_exhaustive_noise_mix
404
+ case "non-combinatorial":
405
+ func = _non_combinatorial_noise_mix
406
+ case _:
407
+ raise ValueError(f"invalid noise mix_rule mode: {noise_mix_rule['mode']}")
408
+
409
+ mixtures.extend(
410
+ func(
411
+ location=location,
412
+ config=config,
413
+ effected_sources=effected_sources,
414
+ test=test,
415
+ )
416
+ )
417
+
418
+ return mixtures
430
419
 
431
420
 
432
421
  def populate_mixture_table(
@@ -437,579 +426,428 @@ def populate_mixture_table(
437
426
  show_progress: bool = False,
438
427
  ) -> None:
439
428
  """Populate mixture table"""
440
- from sonusai import logger
441
- from sonusai.utils import track
442
-
429
+ from .. import logger
430
+ from ..utils.parallel import track
443
431
  from .helpers import from_mixture
444
- from .helpers import from_target
432
+ from .helpers import from_source
445
433
  from .mixdb import db_connection
446
434
 
447
435
  con = db_connection(location=location, readonly=False, test=test)
448
436
 
449
- # Populate target table
437
+ # Populate source table
450
438
  if logging:
451
- logger.info("Populating target table")
452
- targets: list[tuple[int, str]] = []
439
+ logger.info("Populating source table")
440
+ sources: list[tuple[str, int, float, bool, float, float, bool, int]] = []
453
441
  for mixture in mixtures:
454
- for target in mixture.targets:
455
- entry = from_target(target)
456
- if entry not in targets:
457
- targets.append(entry)
458
- for target in track(targets, disable=not show_progress):
459
- con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
442
+ for source in mixture.all_sources.values():
443
+ entry = from_source(source)
444
+ if entry not in sources:
445
+ sources.append(entry)
446
+ for source in track(sources, disable=not show_progress):
447
+ con.execute(
448
+ """
449
+ INSERT INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
450
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
451
+ """,
452
+ source,
453
+ )
454
+
455
+ con.commit()
460
456
 
461
457
  # Populate mixture table
462
458
  if logging:
463
459
  logger.info("Populating mixture table")
464
460
  for mixture in track(mixtures, disable=not show_progress):
465
- m_id = int(mixture.name)
461
+ m_id = int(mixture.name) + 1
466
462
  con.execute(
467
463
  """
468
- INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
469
- snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
470
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
464
+ INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
465
+ VALUES (?, ?, ?, ?, ?)
471
466
  """,
472
- (m_id + 1, *from_mixture(mixture)),
467
+ (m_id, *from_mixture(mixture)),
473
468
  )
474
469
 
475
- for target in mixture.targets:
476
- target_id = con.execute(
470
+ for source in mixture.all_sources.values():
471
+ source_id = con.execute(
477
472
  """
478
- SELECT target.id
479
- FROM target
480
- WHERE ? = target.file_id AND ? = target.augmentation
473
+ SELECT id
474
+ FROM source
475
+ WHERE ? = effects
476
+ AND ? = file_id
477
+ AND ? = pre_tempo
478
+ AND ? = repeat
479
+ AND ? = snr
480
+ AND ? = snr_gain
481
+ AND ? = snr_random
482
+ AND ? = start
481
483
  """,
482
- from_target(target),
484
+ from_source(source),
483
485
  ).fetchone()[0]
484
- con.execute(
485
- "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
486
- (m_id + 1, target_id),
487
- )
486
+ con.execute("INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)", (m_id, source_id))
488
487
 
488
+ if logging:
489
+ logger.info("Closing mixture table")
489
490
  con.commit()
490
491
  con.close()
491
492
 
492
493
 
493
494
  def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
494
- """Update mixture record with name and gains"""
495
- from .audio import get_next_noise
496
- from .augmentation import apply_gain
497
- from .helpers import get_target
498
-
499
- mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
500
-
501
- noise_audio = _augmented_noise_audio(mixdb, mixture)
502
- noise_audio = get_next_noise(audio=noise_audio, offset=mixture.noise_offset, length=mixture.samples)
495
+ """Update mixture record with name, samples, and gains"""
496
+ import numpy as np
503
497
 
504
- # Apply IR and sum targets audio before initializing the mixture SNR gains
505
- target_audio = get_target(mixdb, mixture, targets_audio)
498
+ sources_audio: SourcesAudioT = {}
499
+ post_audio: SourcesAudioT = {}
500
+ for category in mixture.all_sources:
501
+ mixture, sources_audio[category], post_audio[category] = _update_source(mixdb, mixture, category)
506
502
 
507
- mixture = _initialize_mixture_gains(
508
- mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise_audio
509
- )
503
+ mixture = _initialize_mixture_gains(mixdb, mixture, post_audio)
510
504
 
511
505
  mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
512
506
 
513
507
  if not with_data:
514
508
  return mixture, GenMixData()
515
509
 
516
- # Apply SNR gains
517
- targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
518
- noise_audio = apply_gain(audio=noise_audio, gain=mixture.noise_snr_gain)
510
+ # Apply gains
511
+ post_audio = {
512
+ category: post_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources
513
+ }
519
514
 
520
- # Apply IR and sum targets audio after applying the mixture SNR gains
521
- target_audio = get_target(mixdb, mixture, targets_audio)
522
- mixture_audio = target_audio + noise_audio
515
+ # Sum sources, noise, and mixture
516
+ source_audio = np.sum([post_audio[category] for category in mixture.sources], axis=0)
517
+ noise_audio = post_audio["noise"]
518
+ mixture_audio = source_audio + noise_audio
523
519
 
524
520
  return mixture, GenMixData(
525
- mixture=mixture_audio,
526
- targets=targets_audio,
527
- target=target_audio,
521
+ sources=sources_audio,
522
+ source=source_audio,
528
523
  noise=noise_audio,
524
+ mixture=mixture_audio,
529
525
  )
530
526
 
531
527
 
532
- def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
533
- from .audio import read_audio
534
- from .augmentation import apply_augmentation
535
-
536
- noise = mixdb.noise_file(mixture.noise.file_id)
537
- noise_augmentation = mixture.noise.augmentation
538
-
539
- audio = read_audio(noise.name)
540
- audio = apply_augmentation(mixdb, audio, noise_augmentation.pre)
541
-
542
- return audio
528
+ def _update_source(mixdb: MixtureDatabase, mixture: Mixture, category: str) -> tuple[Mixture, AudioT, AudioT]:
529
+ from .effects import apply_effects
530
+ from .effects import conform_audio_to_length
543
531
 
532
+ source = mixture.all_sources[category]
533
+ org_audio = mixdb.read_source_audio(source.file_id)
544
534
 
545
- def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple[Mixture, list[AudioT]]:
546
- from .augmentation import apply_augmentation
547
- from .augmentation import pad_audio_to_length
535
+ org_samples = len(org_audio)
536
+ pre_audio = apply_effects(mixdb, org_audio, source.effects, pre=True, post=False)
548
537
 
549
- targets_audio = []
550
- for target in mixture.targets:
551
- target_audio = mixdb.read_target_audio(target.file_id)
552
- targets_audio.append(
553
- apply_augmentation(
554
- mixdb=mixdb,
555
- audio=target_audio,
556
- augmentation=target.augmentation.pre,
557
- frame_length=mixdb.feature_step_samples,
558
- )
559
- )
538
+ pre_samples = len(pre_audio)
539
+ mixture.all_sources[category].pre_tempo = org_samples / pre_samples
560
540
 
561
- mixture.samples = max([len(item) for item in targets_audio])
541
+ pre_audio = conform_audio_to_length(pre_audio, mixture.samples, source.repeat, source.start)
562
542
 
563
- for idx in range(len(targets_audio)):
564
- targets_audio[idx] = pad_audio_to_length(audio=targets_audio[idx], length=mixture.samples)
543
+ post_audio = apply_effects(mixdb, pre_audio, source.effects, pre=False, post=True)
544
+ if len(pre_audio) != len(post_audio):
545
+ raise RuntimeError(f"post-truth effects changed length: {source.effects.post}")
565
546
 
566
- return mixture, targets_audio
547
+ return mixture, pre_audio, post_audio
567
548
 
568
549
 
569
- def _initialize_mixture_gains(
570
- mixdb: MixtureDatabase,
571
- mixture: Mixture,
572
- target_audio: AudioT,
573
- noise_audio: AudioT,
574
- ) -> Mixture:
550
+ def _initialize_mixture_gains(mixdb: MixtureDatabase, mixture: Mixture, sources_audio: SourcesAudioT) -> Mixture:
575
551
  import numpy as np
576
552
 
577
- from sonusai.utils import asl_p56
578
- from sonusai.utils import db_to_linear
579
-
580
- if mixture.is_noise_only:
581
- # Special case for zeroing out target data
582
- mixture.target_snr_gain = 0
583
- mixture.noise_snr_gain = 1
584
- elif mixture.is_target_only:
585
- # Special case for zeroing out noise data
586
- mixture.target_snr_gain = 1
587
- mixture.noise_snr_gain = 0
588
- else:
589
- target_level_types = [
590
- target_file.level_type for target_file in [mixdb.target_file(target.file_id) for target in mixture.targets]
591
- ]
592
- if not all(level_type == target_level_types[0] for level_type in target_level_types):
593
- raise ValueError("Not all target_level_types in mixup are the same")
594
-
595
- level_type = target_level_types[0]
553
+ from ..utils.asl_p56 import asl_p56
554
+ from ..utils.db import db_to_linear
555
+
556
+ sources_energy: dict[str, float] = {}
557
+ for category in mixture.all_sources:
558
+ level_type = mixdb.source_file(mixture.all_sources[category].file_id).level_type
596
559
  match level_type:
597
560
  case "default":
598
- target_energy = np.mean(np.square(target_audio))
561
+ sources_energy[category] = float(np.mean(np.square(sources_audio[category])))
599
562
  case "speech":
600
- target_energy = asl_p56(target_audio)
563
+ sources_energy[category] = asl_p56(sources_audio[category])
601
564
  case _:
602
565
  raise ValueError(f"Unknown level_type: {level_type}")
603
566
 
604
- noise_energy = np.mean(np.square(noise_audio))
605
- if noise_energy == 0:
606
- noise_gain = 1
607
- else:
608
- noise_gain = np.sqrt(target_energy / noise_energy) / db_to_linear(mixture.snr)
609
-
610
- # Check for noise_gain > 1 to avoid clipping
611
- if noise_gain > 1:
612
- mixture.target_snr_gain = 1 / noise_gain
613
- mixture.noise_snr_gain = 1
614
- else:
615
- mixture.target_snr_gain = 1
616
- mixture.noise_snr_gain = noise_gain
567
+ # Initialize all gains to 1
568
+ for category in mixture.all_sources:
569
+ mixture.all_sources[category].snr_gain = 1
570
+
571
+ # Resolve gains
572
+ for category in mixture.all_sources:
573
+ if mixture.is_noise_only and category != "noise":
574
+ # Special case for zeroing out source data
575
+ mixture.all_sources[category].snr_gain = 0
576
+ elif mixture.is_source_only and category == "noise":
577
+ # Special case for zeroing out noise data
578
+ mixture.all_sources[category].snr_gain = 0
579
+ elif category != "primary":
580
+ if sources_energy["primary"] == 0:
581
+ # Avoid divide-by-zero
582
+ mixture.all_sources[category].snr_gain = 1
583
+ else:
584
+ mixture.all_sources[category].snr_gain = float(
585
+ np.sqrt(sources_energy["primary"] / sources_energy[category])
586
+ ) / db_to_linear(mixture.all_sources[category].snr)
587
+
588
+ # Normalize gains
589
+ max_snr_gain = max([source.snr_gain for source in mixture.all_sources.values()])
590
+ for category in mixture.all_sources:
591
+ mixture.all_sources[category].snr_gain = mixture.all_sources[category].snr_gain / max_snr_gain
617
592
 
618
593
  # Check for clipping in mixture
619
- gain_adjusted_target_audio = target_audio * mixture.target_snr_gain
620
- gain_adjusted_noise_audio = noise_audio * mixture.noise_snr_gain
621
- mixture_audio = gain_adjusted_target_audio + gain_adjusted_noise_audio
622
- max_abs_audio = max(abs(mixture_audio))
594
+ mixture_audio = np.sum(
595
+ [sources_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources], axis=0
596
+ )
597
+ max_abs_audio = float(np.max(np.abs(mixture_audio)))
623
598
  clip_level = db_to_linear(-0.25)
624
599
  if max_abs_audio > clip_level:
625
- # Clipping occurred; lower gains to bring audio within +/-1
626
600
  gain_adjustment = clip_level / max_abs_audio
627
- mixture.target_snr_gain *= gain_adjustment
628
- mixture.noise_snr_gain *= gain_adjustment
601
+ for category in mixture.all_sources:
602
+ mixture.all_sources[category].snr_gain *= gain_adjustment
603
+
604
+ # To improve repeatability, round results
605
+ for category in mixture.all_sources:
606
+ mixture.all_sources[category].snr_gain = round(mixture.all_sources[category].snr_gain, ndigits=5)
629
607
 
630
- mixture.target_snr_gain = round(mixture.target_snr_gain, ndigits=5)
631
- mixture.noise_snr_gain = round(mixture.noise_snr_gain, ndigits=5)
632
608
  return mixture
633
609
 
634
610
 
635
611
def _exhaustive_noise_mix(
    location: str,
    config: dict,
    effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
    test: bool = False,
) -> list[Mixture]:
    """Use every noise/effect with every source/effect+interferences/effect.

    For each effected noise, every effected primary source is mixed at every spectral mask
    and every noise SNR. After each mixture the noise start position moves forward by the
    primary length, wrapping modulo the effected noise length.

    :param location: Mixture database location
    :param config: Configuration dictionary (``spectral_masks`` and per-source SNRs are read)
    :param effected_sources: Per-category list of ``(source file, effect rules)`` pairs
    :param test: Forwarded to ``MixtureDatabase``
    :return: The generated mixtures, named with consecutive integers starting at "0"
    """
    from random import randint

    import numpy as np

    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    mixdb = MixtureDatabase(location, test)
    snrs = get_all_snrs_from_config(config)

    mixtures: list[Mixture] = []
    for n_file, n_rule in effected_sources["noise"]:
        cursor = 0  # where the next noise cut begins
        n_effect = effects_from_rules(mixdb, n_rule)
        n_length = estimate_effected_length(n_file.samples, n_effect)

        for p_file, p_rule in effected_sources["primary"]:
            p_effect = effects_from_rules(mixdb, p_rule)
            p_length = estimate_effected_length(p_file.samples, p_effect, mixdb.feature_step_samples)

            for mask_number in range(1, len(config["spectral_masks"]) + 1):
                for snr in snrs["noise"]:
                    primary_source = Source(file_id=p_file.id, effects=p_effect)
                    noise_source = Source(
                        file_id=n_file.id,
                        effects=n_effect,
                        start=cursor,
                        repeat=True,
                        snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                    )
                    mixtures.append(
                        Mixture(
                            # the running count before appending equals the mixture index
                            name=str(len(mixtures)),
                            all_sources={"primary": primary_source, "noise": noise_source},
                            samples=p_length,
                            spectral_mask_id=mask_number,
                            spectral_mask_seed=randint(0, np.iinfo("i").max),  # noqa: S311
                        )
                    )
                    cursor = int((cursor + p_length) % n_length)

    return mixtures
716
668
 
717
669
 
718
670
def _non_exhaustive_noise_mix(
    location: str,
    config: dict,
    effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
    test: bool = False,
) -> list[Mixture]:
    """Cycle through every source/effect+interferences/effect without necessarily using all
    noise/effect combinations (reduced data set).

    Noise is consumed sequentially through ``NextNoise``. Each mixture takes the next
    non-overlapping noise segment whose length equals the mixture length.

    :param location: Mixture database location
    :param config: Configuration dictionary (``spectral_masks`` and per-source SNRs are read)
    :param effected_sources: Per-category list of ``(source file, effect rules)`` pairs
    :param test: Forwarded to ``MixtureDatabase``
    :return: The generated mixtures, named with consecutive integers starting at "0"
    :raises ValueError: Propagated from ``NextNoise.generate`` when a mixture is longer than a noise
    """
    from random import randint

    import numpy as np

    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    mixdb = MixtureDatabase(location, test)
    snrs = get_all_snrs_from_config(config)

    next_noise = NextNoise(mixdb, effected_sources["noise"])

    m_id = 0
    mixtures: list[Mixture] = []
    for primary_file, primary_rule in effected_sources["primary"]:
        primary_effect = effects_from_rules(mixdb, primary_rule)
        primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)

        for spectral_mask_id in range(len(config["spectral_masks"])):
            for snr in snrs["noise"]:
                # Request exactly as much noise as the mixture spans (the effected, frame-aligned
                # length), not the raw file length, so consecutive noise cuts do not overlap/overrun
                noise_file_id, noise_effect, noise_start = next_noise.generate(primary_length)

                mixtures.append(
                    Mixture(
                        name=str(m_id),
                        all_sources={
                            "primary": Source(
                                file_id=primary_file.id,
                                effects=primary_effect,
                            ),
                            "noise": Source(
                                file_id=noise_file_id,
                                effects=noise_effect,
                                start=noise_start,
                                repeat=True,
                                snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                            ),
                        },
                        samples=primary_length,
                        spectral_mask_id=spectral_mask_id + 1,
                        spectral_mask_seed=randint(0, np.iinfo("i").max),  # noqa: S311
                    )
                )
                m_id += 1

    return mixtures
808
727
 
809
728
 
810
729
def _non_combinatorial_noise_mix(
    location: str,
    config: dict,
    effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
    test: bool = False,
) -> list[Mixture]:
    """Combine a source/effect+interferences/effect with a single cut of a noise/effect
    non-exhaustively (each source/effect+interferences/effect does not use each noise/effect).
    Cut has random start and loop back to beginning if end of noise/effect is reached.

    Noises are taken round-robin, one per mixture. Each noise cut starts at a uniformly random
    sample within the effected noise, and the noise source is marked to repeat.

    :param location: Mixture database location
    :param config: Configuration dictionary (``spectral_masks`` and per-source SNRs are read)
    :param effected_sources: Per-category list of ``(source file, effect rules)`` pairs
    :param test: Forwarded to ``MixtureDatabase``
    :return: The generated mixtures, named with consecutive integers starting at "0"
    """
    from random import choice
    from random import randint

    import numpy as np

    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    mixdb = MixtureDatabase(location, test)
    snrs = get_all_snrs_from_config(config)

    rotation = 0  # index of the noise used by the next mixture
    mixtures: list[Mixture] = []
    for p_file, p_rule in effected_sources["primary"]:
        p_effect = effects_from_rules(mixdb, p_rule)
        p_length = estimate_effected_length(p_file.samples, p_effect, mixdb.feature_step_samples)

        for mask_number in range(1, len(config["spectral_masks"]) + 1):
            for snr in snrs["noise"]:
                n_file, n_rule = effected_sources["noise"][rotation]
                n_effect = effects_from_rules(mixdb, n_rule)
                n_length = estimate_effected_length(n_file.samples, n_effect)

                primary_source = Source(file_id=p_file.id, effects=p_effect)
                noise_source = Source(
                    file_id=n_file.id,
                    effects=n_effect,
                    start=choice(range(n_length)),  # noqa: S311
                    repeat=True,
                    snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                )
                mixtures.append(
                    Mixture(
                        # the running count before appending equals the mixture index
                        name=str(len(mixtures)),
                        all_sources={"primary": primary_source, "noise": noise_source},
                        samples=p_length,
                        spectral_mask_id=mask_number,
                        spectral_mask_seed=randint(0, np.iinfo("i").max),  # noqa: S311
                    )
                )
                rotation = (rotation + 1) % len(effected_sources["noise"])

    return mixtures
792
class NextNoise:
    """Hand out consecutive, non-overlapping segments of effected noise sources.

    Noises are consumed in list order. When the current noise has too few samples left for a
    request, the next noise (wrapping to the first) is loaded and the cut starts at its
    beginning.
    """

    def __init__(self, mixdb: MixtureDatabase, effected_noises: list[tuple[SourceFile, Effects]]) -> None:
        """:param mixdb: Mixture database passed to ``effects_from_rules``
        :param effected_noises: ``(noise file, effect rules)`` pairs to cycle through
        """
        self.mixdb = mixdb
        self.effected_noises = effected_noises

        self.noise_start = 0
        self.noise_id = 0
        self._load_current_noise()

    @property
    def noise_file(self):
        """Source file of the noise currently being consumed."""
        return self.effected_noises[self.noise_id][0]

    @property
    def noise_rule(self):
        """Effect rules of the noise currently being consumed."""
        return self.effected_noises[self.noise_id][1]

    def _load_current_noise(self) -> None:
        """Resolve effects for the current noise and cache its effected length."""
        from .effects import effects_from_rules
        from .effects import estimate_effected_length

        self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
        self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)

    def generate(self, length: int) -> tuple[int, Effects, int]:
        """Reserve the next *length* samples of noise.

        :param length: Number of noise samples needed
        :return: ``(noise file id, noise effects, start sample)`` of the reserved segment
        :raises ValueError: If *length* exceeds the effected length of a noise that is being
            started from its beginning
        """
        if self.noise_start + length > self.noise_length:
            # Not enough samples left in the current noise
            if self.noise_start == 0:
                raise ValueError("Length of primary audio exceeds length of noise audio")

            # Move to the next noise (wrapping around) and cut from its beginning
            self.noise_start = 0
            self.noise_id = (self.noise_id + 1) % len(self.effected_noises)
            self._load_current_noise()

            if length > self.noise_length:
                raise ValueError("Length of primary audio exceeds length of noise audio")

        noise_start = self.noise_start
        # Always advance past the segment just handed out, including right after switching
        # noises, so the same segment is never returned twice
        self.noise_start += length

        return self.noise_file.id, self.noise_effect, noise_start
833
+
834
+
835
def get_all_snrs_from_config(config: dict) -> dict[str, list[UniversalSNRGenerator]]:
    """Build SNR generators for every non-primary source category.

    :param config: Configuration dictionary; ``config["sources"][category]["snrs"]`` is read
        for each category other than ``"primary"``
    :return: Mapping from category name to its list of ``UniversalSNRGenerator`` objects,
        in the category order of ``config["sources"]``
    """
    return {
        category: [UniversalSNRGenerator(raw_snr) for raw_snr in settings["snrs"]]
        for category, settings in config["sources"].items()
        if category != "primary"
    }
841
+
842
+
843
+ def _get_textgrid_tiers_from_source_file(file: str) -> list[str]:
1006
844
  from pathlib import Path
1007
845
 
1008
846
  from praatio import textgrid
1009
847
 
1010
- from sonusai.mixture import tokenized_expand
848
+ from ..utils.tokenized_shell_vars import tokenized_expand
1011
849
 
1012
- textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
850
+ textgrid_file = Path(tokenized_expand(file)[0]).with_suffix(".TextGrid")
1013
851
  if not textgrid_file.exists():
1014
852
  return []
1015
853
 
@@ -1018,18 +856,18 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
1018
856
  return sorted(tg.tierNames)
1019
857
 
1020
858
 
1021
- def _populate_speaker_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
859
+ def _populate_speaker_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
1022
860
  """Populate speaker table"""
1023
861
  import json
1024
862
  from pathlib import Path
1025
863
 
1026
864
  import yaml
1027
865
 
866
+ from ..utils.tokenized_shell_vars import tokenized_expand
1028
867
  from .mixdb import db_connection
1029
- from .tokenized_shell_vars import tokenized_expand
1030
868
 
1031
869
  # Determine columns for speaker table
1032
- all_parents = {Path(target_file.name).parent for target_file in target_files}
870
+ all_parents = {Path(file.name).parent for file in source_files}
1033
871
  speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists())
1034
872
 
1035
873
  speakers: dict[Path, dict[str, str]] = {}
@@ -1066,19 +904,16 @@ def _populate_speaker_table(location: str, target_files: list[TargetFile], test:
1066
904
  for description in con.execute("SELECT * FROM speaker").description
1067
905
  if description[0] not in ("id", "parent")
1068
906
  ]
1069
- con.execute(
1070
- "UPDATE top SET speaker_metadata_tiers=? WHERE ? = top.id",
1071
- (json.dumps(tiers), 1),
1072
- )
907
+ con.execute("UPDATE top SET speaker_metadata_tiers=? WHERE ? = id", (json.dumps(tiers), 1))
1073
908
 
1074
909
  if "speaker_id" in tiers:
1075
- con.execute("CREATE INDEX speaker_speaker_id_idx ON speaker (speaker_id)")
910
+ con.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
1076
911
 
1077
912
  con.commit()
1078
913
  con.close()
1079
914
 
1080
915
 
1081
- def _populate_truth_config_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
916
+ def _populate_truth_config_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
1082
917
  """Populate truth_config table"""
1083
918
  import json
1084
919
 
@@ -1088,8 +923,8 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
1088
923
 
1089
924
  # Populate truth_config table
1090
925
  truth_configs: list[str] = []
1091
- for target_file in target_files:
1092
- for name, config in target_file.truth_configs.items():
926
+ for file in source_files:
927
+ for name, config in file.truth_configs.items():
1093
928
  ts = json.dumps({"name": name} | config.to_dict())
1094
929
  if ts not in truth_configs:
1095
930
  truth_configs.append(ts)
@@ -1100,3 +935,18 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
1100
935
 
1101
936
  con.commit()
1102
937
  con.close()
938
+
939
+
940
def _populate_impulse_response_tag_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
    """Populate ir_tag table with the unique tags found across all impulse response files.

    Tags are de-duplicated and inserted in sorted order. Iterating a ``set`` of strings gives
    an order that varies between interpreter runs (string hash randomization), so sorting
    makes the insertion order, and any row ids the table assigns, reproducible.
    NOTE(review): assumes tags are mutually comparable (e.g. all ``str``) — confirm.

    :param location: Mixture database location
    :param files: Impulse response files whose ``tags`` are collected
    :param test: Forwarded to ``db_connection``
    """
    from .mixdb import db_connection

    unique_tags = sorted({tag for file in files for tag in file.tags})

    con = db_connection(location=location, readonly=False, test=test)
    try:
        con.executemany(
            "INSERT INTO ir_tag (tag) VALUES (?)",
            [(tag,) for tag in unique_tags],
        )
        con.commit()
    finally:
        # Release the connection even if the insert fails
        con.close()