sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +240 -76
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +23 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -17
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +5 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +484 -611
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +931 -669
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
  89. sonusai-1.0.1.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.2.dist-info/RECORD +0 -128
  96. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,14 @@
1
1
  # ruff: noqa: S608
2
- from .datatypes import AudioT
3
- from .datatypes import Augmentation
4
- from .datatypes import AugmentationRule
5
- from .datatypes import AugmentedTarget
6
- from .datatypes import GenMixData
7
- from .datatypes import ImpulseResponseFile
8
- from .datatypes import Mixture
9
- from .datatypes import NoiseFile
10
- from .datatypes import SpectralMask
11
- from .datatypes import Target
12
- from .datatypes import TargetFile
13
- from .datatypes import UniversalSNRGenerator
14
2
  from .mixdb import MixtureDatabase
3
+ from ..datatypes import AudioT
4
+ from ..datatypes import Effects
5
+ from ..datatypes import GenMixData
6
+ from ..datatypes import ImpulseResponseFile
7
+ from ..datatypes import Mixture
8
+ from ..datatypes import Source
9
+ from ..datatypes import SourceFile
10
+ from ..datatypes import SourcesAudioT
11
+ from ..datatypes import UniversalSNRGenerator
15
12
 
16
13
 
17
14
  def config_file(location: str) -> str:
@@ -34,47 +31,62 @@ def initialize_db(location: str, test: bool = False) -> None:
34
31
  con.execute("""
35
32
  CREATE TABLE truth_parameters(
36
33
  id INTEGER PRIMARY KEY NOT NULL,
34
+ category TEXT NOT NULL,
37
35
  name TEXT NOT NULL,
38
36
  parameters INTEGER)
39
37
  """)
40
38
 
41
39
  con.execute("""
42
- CREATE TABLE target_file (
40
+ CREATE TABLE source_file (
43
41
  id INTEGER PRIMARY KEY NOT NULL,
42
+ category TEXT NOT NULL,
43
+ class_indices TEXT,
44
+ level_type TEXT NOT NULL,
44
45
  name TEXT NOT NULL,
45
46
  samples INTEGER NOT NULL,
46
- class_indices TEXT NOT NULL,
47
- level_type TEXT NOT NULL,
48
47
  speaker_id INTEGER,
49
48
  FOREIGN KEY(speaker_id) REFERENCES speaker (id))
50
49
  """)
51
50
 
52
51
  con.execute("""
53
- CREATE TABLE speaker (
52
+ CREATE TABLE ir_file (
54
53
  id INTEGER PRIMARY KEY NOT NULL,
55
- parent TEXT NOT NULL)
54
+ delay INTEGER NOT NULL,
55
+ name TEXT NOT NULL)
56
56
  """)
57
57
 
58
58
  con.execute("""
59
- CREATE TABLE noise_file (
59
+ CREATE TABLE ir_tag (
60
60
  id INTEGER PRIMARY KEY NOT NULL,
61
- name TEXT NOT NULL,
62
- samples INTEGER NOT NULL)
61
+ tag TEXT NOT NULL UNIQUE)
62
+ """)
63
+
64
+ con.execute("""
65
+ CREATE TABLE ir_file_ir_tag (
66
+ file_id INTEGER NOT NULL,
67
+ tag_id INTEGER NOT NULL,
68
+ FOREIGN KEY(file_id) REFERENCES ir_file (id),
69
+ FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
70
+ """)
71
+
72
+ con.execute("""
73
+ CREATE TABLE speaker (
74
+ id INTEGER PRIMARY KEY NOT NULL,
75
+ parent TEXT NOT NULL)
63
76
  """)
64
77
 
65
78
  con.execute("""
66
79
  CREATE TABLE top (
67
80
  id INTEGER PRIMARY KEY NOT NULL,
68
- version INTEGER NOT NULL,
69
81
  asr_configs TEXT NOT NULL,
70
82
  class_balancing BOOLEAN NOT NULL,
71
83
  feature TEXT NOT NULL,
72
- noise_mix_mode TEXT NOT NULL,
84
+ mixid_width INTEGER NOT NULL,
73
85
  num_classes INTEGER NOT NULL,
74
86
  seed INTEGER NOT NULL,
75
- mixid_width INTEGER NOT NULL,
76
87
  speaker_metadata_tiers TEXT NOT NULL,
77
- textgrid_metadata_tiers TEXT NOT NULL)
88
+ textgrid_metadata_tiers TEXT NOT NULL,
89
+ version INTEGER NOT NULL)
78
90
  """)
79
91
 
80
92
  con.execute("""
@@ -89,64 +101,54 @@ def initialize_db(location: str, test: bool = False) -> None:
89
101
  threshold FLOAT NOT NULL)
90
102
  """)
91
103
 
92
- con.execute("""
93
- CREATE TABLE impulse_response_file (
94
- id INTEGER PRIMARY KEY NOT NULL,
95
- file TEXT NOT NULL,
96
- tags TEXT NOT NULL,
97
- delay INTEGER NOT NULL)
98
- """)
99
-
100
104
  con.execute("""
101
105
  CREATE TABLE spectral_mask (
102
106
  id INTEGER PRIMARY KEY NOT NULL,
103
107
  f_max_width INTEGER NOT NULL,
104
108
  f_num INTEGER NOT NULL,
109
+ t_max_percent INTEGER NOT NULL,
105
110
  t_max_width INTEGER NOT NULL,
106
- t_num INTEGER NOT NULL,
107
- t_max_percent INTEGER NOT NULL)
111
+ t_num INTEGER NOT NULL)
108
112
  """)
109
113
 
110
114
  con.execute("""
111
- CREATE TABLE target_file_truth_config (
112
- target_file_id INTEGER,
113
- truth_config_id INTEGER,
114
- FOREIGN KEY(target_file_id) REFERENCES target_file (id),
115
+ CREATE TABLE source_file_truth_config (
116
+ source_file_id INTEGER NOT NULL,
117
+ truth_config_id INTEGER NOT NULL,
118
+ FOREIGN KEY(source_file_id) REFERENCES source_file (id),
115
119
  FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
116
120
  """)
117
121
 
118
122
  con.execute("""
119
- CREATE TABLE target (
123
+ CREATE TABLE source (
120
124
  id INTEGER PRIMARY KEY NOT NULL,
125
+ effects TEXT NOT NULL,
121
126
  file_id INTEGER NOT NULL,
122
- augmentation TEXT NOT NULL,
123
- FOREIGN KEY(file_id) REFERENCES target_file (id))
127
+ pre_tempo FLOAT NOT NULL,
128
+ repeat BOOLEAN NOT NULL,
129
+ snr FLOAT NOT NULL,
130
+ snr_gain FLOAT NOT NULL,
131
+ snr_random BOOLEAN NOT NULL,
132
+ start INTEGER NOT NULL,
133
+ FOREIGN KEY(file_id) REFERENCES source_file (id))
124
134
  """)
125
135
 
126
136
  con.execute("""
127
137
  CREATE TABLE mixture (
128
138
  id INTEGER PRIMARY KEY NOT NULL,
129
- name VARCHAR NOT NULL,
130
- noise_file_id INTEGER NOT NULL,
131
- noise_augmentation TEXT NOT NULL,
132
- noise_offset INTEGER NOT NULL,
133
- noise_snr_gain FLOAT,
134
- random_snr BOOLEAN NOT NULL,
135
- snr FLOAT NOT NULL,
139
+ name TEXT NOT NULL,
136
140
  samples INTEGER NOT NULL,
137
141
  spectral_mask_id INTEGER NOT NULL,
138
142
  spectral_mask_seed INTEGER NOT NULL,
139
- target_snr_gain FLOAT,
140
- FOREIGN KEY(noise_file_id) REFERENCES noise_file (id),
141
143
  FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
142
144
  """)
143
145
 
144
146
  con.execute("""
145
- CREATE TABLE mixture_target (
146
- mixture_id INTEGER,
147
- target_id INTEGER,
147
+ CREATE TABLE mixture_source (
148
+ mixture_id INTEGER NOT NULL,
149
+ source_id INTEGER NOT NULL,
148
150
  FOREIGN KEY(mixture_id) REFERENCES mixture (id),
149
- FOREIGN KEY(target_id) REFERENCES target (id))
151
+ FOREIGN KEY(source_id) REFERENCES source (id))
150
152
  """)
151
153
 
152
154
  con.commit()
@@ -163,22 +165,21 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
163
165
  con = db_connection(location=location, readonly=False, test=test)
164
166
  con.execute(
165
167
  """
166
- INSERT INTO top (id, version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
167
- seed, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
168
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
168
+ INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
169
+ seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
170
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
169
171
  """,
170
172
  (
171
173
  1,
172
- MIXDB_VERSION,
173
174
  json.dumps(config["asr_configs"]),
174
175
  config["class_balancing"],
175
176
  config["feature"],
176
- config["noise_mix_mode"],
177
+ 0,
177
178
  config["num_classes"],
178
179
  config["seed"],
179
- 0,
180
180
  "",
181
181
  "",
182
+ MIXDB_VERSION,
182
183
  ),
183
184
  )
184
185
  con.commit()
@@ -231,15 +232,15 @@ def populate_spectral_mask_table(location: str, config: dict, test: bool = False
231
232
  con = db_connection(location=location, readonly=False, test=test)
232
233
  con.executemany(
233
234
  """
234
- INSERT INTO spectral_mask (f_max_width, f_num, t_max_width, t_num, t_max_percent) VALUES (?, ?, ?, ?, ?)
235
+ INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
235
236
  """,
236
237
  [
237
238
  (
238
239
  item.f_max_width,
239
240
  item.f_num,
241
+ item.t_max_percent,
240
242
  item.t_max_width,
241
243
  item.t_num,
242
- item.t_max_percent,
243
244
  )
244
245
  for item in get_spectral_masks(config)
245
246
  ],
@@ -256,10 +257,11 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
256
257
  con = db_connection(location=location, readonly=False, test=test)
257
258
  con.executemany(
258
259
  """
259
- INSERT INTO truth_parameters (name, parameters) VALUES (?, ?)
260
+ INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
260
261
  """,
261
262
  [
262
263
  (
264
+ item.category,
263
265
  item.name,
264
266
  item.parameters,
265
267
  )
@@ -270,40 +272,41 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
270
272
  con.close()
271
273
 
272
274
 
273
- def populate_target_file_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
274
- """Populate target file table"""
275
+ def populate_source_file_table(location: str, files: list[SourceFile], test: bool = False) -> None:
276
+ """Populate source file table"""
275
277
  import json
276
278
  from pathlib import Path
277
279
 
278
280
  from .mixdb import db_connection
279
281
 
280
- _populate_truth_config_table(location, target_files, test)
281
- _populate_speaker_table(location, target_files, test)
282
+ _populate_truth_config_table(location, files, test)
283
+ _populate_speaker_table(location, files, test)
282
284
 
283
285
  con = db_connection(location=location, readonly=False, test=test)
284
286
 
285
287
  cur = con.cursor()
286
288
  textgrid_metadata_tiers: set[str] = set()
287
- for target_file in target_files:
288
- # Get TextGrid tiers for target file and add to collection
289
- tiers = _get_textgrid_tiers_from_target_file(target_file.name)
289
+ for file in files:
290
+ # Get TextGrid tiers for source file and add to collection
291
+ tiers = _get_textgrid_tiers_from_source_file(file.name)
290
292
  for tier in tiers:
291
293
  textgrid_metadata_tiers.add(tier)
292
294
 
293
- # Get truth settings for target file
295
+ # Get truth settings for file
294
296
  truth_config_ids: list[int] = []
295
- for name, config in target_file.truth_configs.items():
296
- ts = json.dumps({"name": name} | config.to_dict())
297
- cur.execute(
298
- "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
299
- (ts,),
300
- )
301
- truth_config_ids.append(cur.fetchone()[0])
302
-
303
- # Get speaker_id for target file
297
+ if file.truth_configs:
298
+ for name, config in file.truth_configs.items():
299
+ ts = json.dumps({"name": name} | config.to_dict())
300
+ cur.execute(
301
+ "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
302
+ (ts,),
303
+ )
304
+ truth_config_ids.append(cur.fetchone()[0])
305
+
306
+ # Get speaker_id for source file
304
307
  cur.execute(
305
308
  "SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
306
- (Path(target_file.name).parent.as_posix(),),
309
+ (Path(file.name).parent.as_posix(),),
307
310
  )
308
311
  result = cur.fetchone()
309
312
  speaker_id = None
@@ -312,20 +315,24 @@ def populate_target_file_table(location: str, target_files: list[TargetFile], te
312
315
 
313
316
  # Add entry
314
317
  cur.execute(
315
- "INSERT INTO target_file (name, samples, class_indices, level_type, speaker_id) VALUES (?, ?, ?, ?, ?)",
318
+ """
319
+ INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
320
+ VALUES (?, ?, ?, ?, ?, ?)
321
+ """,
316
322
  (
317
- target_file.name,
318
- target_file.samples,
319
- json.dumps(target_file.class_indices),
320
- target_file.level_type,
323
+ file.category,
324
+ json.dumps(file.class_indices),
325
+ file.level_type,
326
+ file.name,
327
+ file.samples,
321
328
  speaker_id,
322
329
  ),
323
330
  )
324
- target_file_id = cur.lastrowid
331
+ source_file_id = cur.lastrowid
325
332
  for truth_config_id in truth_config_ids:
326
333
  cur.execute(
327
- "INSERT INTO target_file_truth_config (target_file_id, truth_config_id) VALUES (?, ?)",
328
- (target_file_id, truth_config_id),
334
+ "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
335
+ (source_file_id, truth_config_id),
329
336
  )
330
337
 
331
338
  # Update textgrid_metadata_tiers in the top table
@@ -338,47 +345,47 @@ def populate_target_file_table(location: str, target_files: list[TargetFile], te
338
345
  con.close()
339
346
 
340
347
 
341
- def populate_noise_file_table(location: str, noise_files: list[NoiseFile], test: bool = False) -> None:
342
- """Populate noise file table"""
348
+ def populate_impulse_response_file_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
349
+ """Populate impulse response file table"""
343
350
  from .mixdb import db_connection
344
351
 
345
- con = db_connection(location=location, readonly=False, test=test)
346
- con.executemany(
347
- "INSERT INTO noise_file (name, samples) VALUES (?, ?)",
348
- [(noise_file.name, noise_file.samples) for noise_file in noise_files],
349
- )
350
- con.commit()
351
- con.close()
352
-
352
+ _populate_impulse_response_tag_table(location, files, test)
353
353
 
354
- def populate_impulse_response_file_table(
355
- location: str, impulse_response_files: list[ImpulseResponseFile], test: bool = False
356
- ) -> None:
357
- """Populate impulse response file table"""
358
- import json
354
+ con = db_connection(location=location, readonly=False, test=test)
359
355
 
360
- from .mixdb import db_connection
356
+ cur = con.cursor()
357
+ for file in files:
358
+ # Get tags for file
359
+ tag_ids: list[int] = []
360
+ for tag in file.tags:
361
+ cur.execute(
362
+ "SELECT ir_tag.id FROM ir_tag WHERE ? = ir_tag.tag",
363
+ (tag,),
364
+ )
365
+ tag_ids.append(cur.fetchone()[0])
361
366
 
362
- con = db_connection(location=location, readonly=False, test=test)
363
- con.executemany(
364
- "INSERT INTO impulse_response_file (file, tags, delay) VALUES (?, ?, ?)",
365
- [
367
+ cur.execute(
368
+ "INSERT INTO ir_file (delay, name) VALUES (?, ?)",
366
369
  (
367
- impulse_response_file.file,
368
- json.dumps(impulse_response_file.tags),
369
- impulse_response_file.delay,
370
+ file.delay,
371
+ file.name,
372
+ ),
373
+ )
374
+
375
+ file_id = cur.lastrowid
376
+ for tag_id in tag_ids:
377
+ cur.execute(
378
+ "INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)",
379
+ (file_id, tag_id),
370
380
  )
371
- for impulse_response_file in impulse_response_files
372
- ],
373
- )
381
+
374
382
  con.commit()
375
383
  con.close()
376
384
 
377
385
 
378
386
  def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
379
387
  """Update the mixid width"""
380
- from sonusai.utils import max_text_width
381
-
388
+ from ..utils.max_text_width import max_text_width
382
389
  from .mixdb import db_connection
383
390
 
384
391
  con = db_connection(location=location, readonly=False, test=test)
@@ -391,42 +398,43 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
391
398
 
392
399
 
393
400
  def generate_mixtures(
394
- noise_mix_mode: str,
395
- augmented_targets: list[AugmentedTarget],
396
- target_files: list[TargetFile],
397
- target_augmentations: list[AugmentationRule],
398
- noise_files: list[NoiseFile],
399
- noise_augmentations: list[AugmentationRule],
400
- spectral_masks: list[SpectralMask],
401
- all_snrs: list[UniversalSNRGenerator],
402
- mixups: list[int],
403
- num_classes: int,
404
- feature_step_samples: int,
405
- num_ir: int,
406
- ) -> tuple[int, int, list[Mixture]]:
401
+ location: str,
402
+ config: dict,
403
+ effects: dict[str, list[Effects]],
404
+ test: bool = False,
405
+ ) -> list[Mixture]:
407
406
  """Generate mixtures"""
408
- if noise_mix_mode == "exhaustive":
409
- func = _exhaustive_noise_mix
410
- elif noise_mix_mode == "non-exhaustive":
411
- func = _non_exhaustive_noise_mix
412
- elif noise_mix_mode == "non-combinatorial":
413
- func = _non_combinatorial_noise_mix
414
- else:
415
- raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
416
-
417
- return func(
418
- augmented_targets=augmented_targets,
419
- target_files=target_files,
420
- target_augmentations=target_augmentations,
421
- noise_files=noise_files,
422
- noise_augmentations=noise_augmentations,
423
- spectral_masks=spectral_masks,
424
- all_snrs=all_snrs,
425
- mixups=mixups,
426
- num_classes=num_classes,
427
- feature_step_samples=feature_step_samples,
428
- num_ir=num_ir,
429
- )
407
+ mixdb = MixtureDatabase(location, test)
408
+
409
+ effected_sources: dict[str, list[tuple[SourceFile, Effects]]] = {}
410
+ for category in mixdb.source_files:
411
+ effected_sources[category] = []
412
+ for file in mixdb.source_files[category]:
413
+ for effect in effects[category]:
414
+ effected_sources[category].append((file, effect))
415
+
416
+ mixtures: list[Mixture] = []
417
+ for noise_mix_rule in config["sources"]["noise"]["mix_rules"]:
418
+ match noise_mix_rule["mode"]:
419
+ case "exhaustive":
420
+ func = _exhaustive_noise_mix
421
+ case "non-exhaustive":
422
+ func = _non_exhaustive_noise_mix
423
+ case "non-combinatorial":
424
+ func = _non_combinatorial_noise_mix
425
+ case _:
426
+ raise ValueError(f"invalid noise mix_rule mode: {noise_mix_rule['mode']}")
427
+
428
+ mixtures.extend(
429
+ func(
430
+ location=location,
431
+ config=config,
432
+ effected_sources=effected_sources,
433
+ test=test,
434
+ )
435
+ )
436
+
437
+ return mixtures
430
438
 
431
439
 
432
440
  def populate_mixture_table(
@@ -437,26 +445,33 @@ def populate_mixture_table(
437
445
  show_progress: bool = False,
438
446
  ) -> None:
439
447
  """Populate mixture table"""
440
- from sonusai import logger
441
- from sonusai.utils import track
442
-
448
+ from .. import logger
449
+ from ..utils.parallel import track
443
450
  from .helpers import from_mixture
444
- from .helpers import from_target
451
+ from .helpers import from_source
445
452
  from .mixdb import db_connection
446
453
 
447
454
  con = db_connection(location=location, readonly=False, test=test)
448
455
 
449
- # Populate target table
456
+ # Populate source table
450
457
  if logging:
451
- logger.info("Populating target table")
452
- targets: list[tuple[int, str]] = []
458
+ logger.info("Populating source table")
459
+ sources: list[tuple[str, int, float, bool, float, float, bool, int]] = []
453
460
  for mixture in mixtures:
454
- for target in mixture.targets:
455
- entry = from_target(target)
456
- if entry not in targets:
457
- targets.append(entry)
458
- for target in track(targets, disable=not show_progress):
459
- con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
461
+ for source in mixture.all_sources.values():
462
+ entry = from_source(source)
463
+ if entry not in sources:
464
+ sources.append(entry)
465
+ for source in track(sources, disable=not show_progress):
466
+ con.execute(
467
+ """
468
+ INSERT INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
469
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
470
+ """,
471
+ source,
472
+ )
473
+
474
+ con.commit()
460
475
 
461
476
  # Populate mixture table
462
477
  if logging:
@@ -465,25 +480,31 @@ def populate_mixture_table(
465
480
  m_id = int(mixture.name)
466
481
  con.execute(
467
482
  """
468
- INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
469
- snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
470
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
483
+ INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
484
+ VALUES (?, ?, ?, ?, ?)
471
485
  """,
472
486
  (m_id + 1, *from_mixture(mixture)),
473
487
  )
474
488
 
475
- for target in mixture.targets:
476
- target_id = con.execute(
489
+ for source in mixture.all_sources.values():
490
+ source_id = con.execute(
477
491
  """
478
- SELECT target.id
479
- FROM target
480
- WHERE ? = target.file_id AND ? = target.augmentation
492
+ SELECT source.id
493
+ FROM source
494
+ WHERE ? = source.effects
495
+ AND ? = source.file_id
496
+ AND ? = source.pre_tempo
497
+ AND ? = source.repeat
498
+ AND ? = source.snr
499
+ AND ? = source.snr_gain
500
+ AND ? = source.snr_random
501
+ AND ? = source.start
481
502
  """,
482
- from_target(target),
503
+ from_source(source),
483
504
  ).fetchone()[0]
484
505
  con.execute(
485
- "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
486
- (m_id + 1, target_id),
506
+ "INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)",
507
+ (m_id + 1, source_id),
487
508
  )
488
509
 
489
510
  con.commit()
@@ -491,525 +512,362 @@ def populate_mixture_table(
491
512
 
492
513
 
493
514
  def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
494
- """Update mixture record with name and gains"""
495
- from .audio import get_next_noise
496
- from .augmentation import apply_gain
497
- from .helpers import get_target
498
-
499
- mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
500
-
501
- noise_audio = _augmented_noise_audio(mixdb, mixture)
502
- noise_audio = get_next_noise(audio=noise_audio, offset=mixture.noise_offset, length=mixture.samples)
515
+ """Update mixture record with name, samples, and gains"""
516
+ import numpy as np
503
517
 
504
- # Apply IR and sum targets audio before initializing the mixture SNR gains
505
- target_audio = get_target(mixdb, mixture, targets_audio)
518
+ sources_audio: SourcesAudioT = {}
519
+ post_audio: SourcesAudioT = {}
520
+ for category in mixture.all_sources:
521
+ mixture, sources_audio[category], post_audio[category] = _update_source(mixdb, mixture, category)
506
522
 
507
- mixture = _initialize_mixture_gains(
508
- mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise_audio
509
- )
523
+ mixture = _initialize_mixture_gains(mixdb, mixture, post_audio)
510
524
 
511
525
  mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
512
526
 
513
527
  if not with_data:
514
528
  return mixture, GenMixData()
515
529
 
516
- # Apply SNR gains
517
- targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
518
- noise_audio = apply_gain(audio=noise_audio, gain=mixture.noise_snr_gain)
530
+ # Apply gains
531
+ post_audio = {
532
+ category: post_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources
533
+ }
519
534
 
520
- # Apply IR and sum targets audio after applying the mixture SNR gains
521
- target_audio = get_target(mixdb, mixture, targets_audio)
522
- mixture_audio = target_audio + noise_audio
535
+ # Sum sources, noise, and mixture
536
+ source_audio = np.sum([post_audio[category] for category in mixture.sources], axis=0)
537
+ noise_audio = post_audio["noise"]
538
+ mixture_audio = source_audio + noise_audio
523
539
 
524
540
  return mixture, GenMixData(
525
- mixture=mixture_audio,
526
- targets=targets_audio,
527
- target=target_audio,
541
+ sources=sources_audio,
542
+ source=source_audio,
528
543
  noise=noise_audio,
544
+ mixture=mixture_audio,
529
545
  )
530
546
 
531
547
 
532
- def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
533
- from .audio import read_audio
534
- from .augmentation import apply_augmentation
535
-
536
- noise = mixdb.noise_file(mixture.noise.file_id)
537
- noise_augmentation = mixture.noise.augmentation
548
+ def _update_source(mixdb: MixtureDatabase, mixture: Mixture, category: str) -> tuple[Mixture, AudioT, AudioT]:
549
+ from .effects import apply_effects
550
+ from .effects import conform_audio_to_length
538
551
 
539
- audio = read_audio(noise.name)
540
- audio = apply_augmentation(mixdb, audio, noise_augmentation.pre)
552
+ source = mixture.all_sources[category]
553
+ org_audio = mixdb.read_source_audio(source.file_id)
541
554
 
542
- return audio
555
+ org_samples = len(org_audio)
556
+ pre_audio = apply_effects(mixdb, org_audio, source.effects, pre=True, post=False)
543
557
 
558
+ pre_samples = len(pre_audio)
559
+ mixture.all_sources[category].pre_tempo = org_samples / pre_samples
544
560
 
545
- def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple[Mixture, list[AudioT]]:
546
- from .augmentation import apply_augmentation
547
- from .augmentation import pad_audio_to_length
561
+ pre_audio = conform_audio_to_length(pre_audio, mixture.samples, source.repeat, source.start)
548
562
 
549
- targets_audio = []
550
- for target in mixture.targets:
551
- target_audio = mixdb.read_target_audio(target.file_id)
552
- targets_audio.append(
553
- apply_augmentation(
554
- mixdb=mixdb,
555
- audio=target_audio,
556
- augmentation=target.augmentation.pre,
557
- frame_length=mixdb.feature_step_samples,
558
- )
559
- )
560
-
561
- mixture.samples = max([len(item) for item in targets_audio])
563
+ post_audio = apply_effects(mixdb, pre_audio, source.effects, pre=False, post=True)
564
+ if len(pre_audio) != len(post_audio):
565
+ raise RuntimeError(f"post-truth effects changed length: {source.effects.post}")
562
566
 
563
- for idx in range(len(targets_audio)):
564
- targets_audio[idx] = pad_audio_to_length(audio=targets_audio[idx], length=mixture.samples)
567
+ return mixture, pre_audio, post_audio
565
568
 
566
- return mixture, targets_audio
567
569
 
568
-
569
- def _initialize_mixture_gains(
570
- mixdb: MixtureDatabase,
571
- mixture: Mixture,
572
- target_audio: AudioT,
573
- noise_audio: AudioT,
574
- ) -> Mixture:
570
+ def _initialize_mixture_gains(mixdb: MixtureDatabase, mixture: Mixture, sources_audio: SourcesAudioT) -> Mixture:
575
571
  import numpy as np
576
572
 
577
- from sonusai.utils import asl_p56
578
- from sonusai.utils import db_to_linear
579
-
580
- if mixture.is_noise_only:
581
- # Special case for zeroing out target data
582
- mixture.target_snr_gain = 0
583
- mixture.noise_snr_gain = 1
584
- elif mixture.is_target_only:
585
- # Special case for zeroing out noise data
586
- mixture.target_snr_gain = 1
587
- mixture.noise_snr_gain = 0
588
- else:
589
- target_level_types = [
590
- target_file.level_type for target_file in [mixdb.target_file(target.file_id) for target in mixture.targets]
591
- ]
592
- if not all(level_type == target_level_types[0] for level_type in target_level_types):
593
- raise ValueError("Not all target_level_types in mixup are the same")
594
-
595
- level_type = target_level_types[0]
573
+ from ..utils.asl_p56 import asl_p56
574
+ from ..utils.db import db_to_linear
575
+
576
+ sources_energy: dict[str, float] = {}
577
+ for category in mixture.all_sources:
578
+ level_type = mixdb.source_file(mixture.all_sources[category].file_id).level_type
596
579
  match level_type:
597
580
  case "default":
598
- target_energy = np.mean(np.square(target_audio))
581
+ sources_energy[category] = float(np.mean(np.square(sources_audio[category])))
599
582
  case "speech":
600
- target_energy = asl_p56(target_audio)
583
+ sources_energy[category] = asl_p56(sources_audio[category])
601
584
  case _:
602
585
  raise ValueError(f"Unknown level_type: {level_type}")
603
586
 
604
- noise_energy = np.mean(np.square(noise_audio))
605
- if noise_energy == 0:
606
- noise_gain = 1
607
- else:
608
- noise_gain = np.sqrt(target_energy / noise_energy) / db_to_linear(mixture.snr)
609
-
610
- # Check for noise_gain > 1 to avoid clipping
611
- if noise_gain > 1:
612
- mixture.target_snr_gain = 1 / noise_gain
613
- mixture.noise_snr_gain = 1
614
- else:
615
- mixture.target_snr_gain = 1
616
- mixture.noise_snr_gain = noise_gain
587
+ # Initialize all gains to 1
588
+ for category in mixture.all_sources:
589
+ mixture.all_sources[category].snr_gain = 1
590
+
591
+ # Resolve gains
592
+ for category in mixture.all_sources:
593
+ if mixture.is_noise_only and category != "noise":
594
+ # Special case for zeroing out source data
595
+ mixture.all_sources[category].snr_gain = 0
596
+ elif mixture.is_source_only and category == "noise":
597
+ # Special case for zeroing out noise data
598
+ mixture.all_sources[category].snr_gain = 0
599
+ elif category != "primary":
600
+ if sources_energy["primary"] == 0:
601
+ # Avoid divide-by-zero
602
+ mixture.all_sources[category].snr_gain = 1
603
+ else:
604
+ mixture.all_sources[category].snr_gain = float(
605
+ np.sqrt(sources_energy["primary"] / sources_energy[category])
606
+ ) / db_to_linear(mixture.all_sources[category].snr)
607
+
608
+ # Normalize gains
609
+ max_snr_gain = max([source.snr_gain for source in mixture.all_sources.values()])
610
+ for category in mixture.all_sources:
611
+ mixture.all_sources[category].snr_gain = mixture.all_sources[category].snr_gain / max_snr_gain
617
612
 
618
613
  # Check for clipping in mixture
619
- gain_adjusted_target_audio = target_audio * mixture.target_snr_gain
620
- gain_adjusted_noise_audio = noise_audio * mixture.noise_snr_gain
621
- mixture_audio = gain_adjusted_target_audio + gain_adjusted_noise_audio
622
- max_abs_audio = max(abs(mixture_audio))
614
+ mixture_audio = np.sum(
615
+ [sources_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources], axis=0
616
+ )
617
+ max_abs_audio = float(np.max(np.abs(mixture_audio)))
623
618
  clip_level = db_to_linear(-0.25)
624
619
  if max_abs_audio > clip_level:
625
- # Clipping occurred; lower gains to bring audio within +/-1
626
620
  gain_adjustment = clip_level / max_abs_audio
627
- mixture.target_snr_gain *= gain_adjustment
628
- mixture.noise_snr_gain *= gain_adjustment
621
+ for category in mixture.all_sources:
622
+ mixture.all_sources[category].snr_gain *= gain_adjustment
623
+
624
+ # To improve repeatability, round results
625
+ for category in mixture.all_sources:
626
+ mixture.all_sources[category].snr_gain = round(mixture.all_sources[category].snr_gain, ndigits=5)
629
627
 
630
- mixture.target_snr_gain = round(mixture.target_snr_gain, ndigits=5)
631
- mixture.noise_snr_gain = round(mixture.noise_snr_gain, ndigits=5)
632
628
  return mixture
633
629
 
634
630
 
635
631
  def _exhaustive_noise_mix(
636
- augmented_targets: list[AugmentedTarget],
637
- target_files: list[TargetFile],
638
- target_augmentations: list[AugmentationRule],
639
- noise_files: list[NoiseFile],
640
- noise_augmentations: list[AugmentationRule],
641
- spectral_masks: list[SpectralMask],
642
- all_snrs: list[UniversalSNRGenerator],
643
- mixups: list[int],
644
- num_classes: int,
645
- feature_step_samples: int,
646
- num_ir: int,
647
- ) -> tuple[int, int, list[Mixture]]:
648
- """Use every noise/augmentation with every target/augmentation+interferences/augmentation"""
632
+ location: str,
633
+ config: dict,
634
+ effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
635
+ test: bool = False,
636
+ ) -> list[Mixture]:
637
+ """Use every noise/effect with every source/effect+interferences/effect"""
649
638
  from random import randint
650
639
 
651
640
  import numpy as np
652
641
 
653
- from .augmentation import augmentation_from_rule
654
- from .augmentation import estimate_augmented_length_from_length
655
- from .datatypes import Mixture
656
- from .datatypes import Noise
657
- from .datatypes import UniversalSNR
658
- from .targets import get_augmented_target_ids_for_mixup
642
+ from ..datatypes import Mixture
643
+ from ..datatypes import UniversalSNR
644
+ from .effects import effects_from_rules
645
+ from .effects import estimate_effected_length
659
646
 
660
- m_id = 0
661
- used_noise_files = len(noise_files) * len(noise_augmentations)
662
- used_noise_samples = 0
663
-
664
- augmented_target_ids_for_mixups = [
665
- get_augmented_target_ids_for_mixup(
666
- augmented_targets=augmented_targets,
667
- targets=target_files,
668
- target_augmentations=target_augmentations,
669
- mixup=mixup,
670
- num_classes=num_classes,
671
- )
672
- for mixup in mixups
673
- ]
647
+ mixdb = MixtureDatabase(location, test)
648
+ snrs = get_all_snrs_from_config(config)
674
649
 
650
+ m_id = 0
675
651
  mixtures: list[Mixture] = []
676
- for noise_file_id in range(len(noise_files)):
677
- for noise_augmentation_rule in noise_augmentations:
678
- noise_augmentation = augmentation_from_rule(noise_augmentation_rule, num_ir)
679
- noise_offset = 0
680
- noise_length = estimate_augmented_length_from_length(
681
- length=noise_files[noise_file_id].samples,
682
- tempo=noise_augmentation.pre.tempo,
683
- )
652
+ for noise_file, noise_rule in effected_sources["noise"]:
653
+ noise_start = 0
654
+ noise_effect = effects_from_rules(mixdb, noise_rule)
655
+ noise_length = estimate_effected_length(noise_file.samples, noise_effect)
684
656
 
685
- for augmented_target_ids_for_mixup in augmented_target_ids_for_mixups:
686
- for augmented_target_ids in augmented_target_ids_for_mixup:
687
- targets, target_length = _get_target_info(
688
- augmented_target_ids=augmented_target_ids,
689
- augmented_targets=augmented_targets,
690
- target_files=target_files,
691
- target_augmentations=target_augmentations,
692
- feature_step_samples=feature_step_samples,
693
- num_ir=num_ir,
694
- )
657
+ for primary_file, primary_rule in effected_sources["primary"]:
658
+ primary_effect = effects_from_rules(mixdb, primary_rule)
659
+ primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
695
660
 
696
- for spectral_mask_id in range(len(spectral_masks)):
697
- for snr in all_snrs:
698
- mixtures.append(
699
- Mixture(
700
- targets=targets,
701
- name=str(m_id),
702
- noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
703
- noise_offset=noise_offset,
704
- samples=target_length,
661
+ for spectral_mask_id in range(len(config["spectral_masks"])):
662
+ for snr in snrs["noise"]:
663
+ mixtures.append(
664
+ Mixture(
665
+ name=str(m_id),
666
+ all_sources={
667
+ "primary": Source(
668
+ file_id=primary_file.id,
669
+ effects=primary_effect,
670
+ ),
671
+ "noise": Source(
672
+ file_id=noise_file.id,
673
+ effects=noise_effect,
674
+ start=noise_start,
675
+ repeat=True,
705
676
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
706
- spectral_mask_id=spectral_mask_id + 1,
707
- spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
708
- )
709
- )
710
- m_id += 1
711
-
712
- noise_offset = int((noise_offset + target_length) % noise_length)
713
- used_noise_samples += target_length
677
+ ),
678
+ },
679
+ samples=primary_length,
680
+ spectral_mask_id=spectral_mask_id + 1,
681
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
682
+ )
683
+ )
684
+ noise_start = int((noise_start + primary_length) % noise_length)
685
+ m_id += 1
714
686
 
715
- return used_noise_files, used_noise_samples, mixtures
687
+ return mixtures
716
688
 
717
689
 
718
690
  def _non_exhaustive_noise_mix(
719
- augmented_targets: list[AugmentedTarget],
720
- target_files: list[TargetFile],
721
- target_augmentations: list[AugmentationRule],
722
- noise_files: list[NoiseFile],
723
- noise_augmentations: list[AugmentationRule],
724
- spectral_masks: list[SpectralMask],
725
- all_snrs: list[UniversalSNRGenerator],
726
- mixups: list[int],
727
- num_classes: int,
728
- feature_step_samples: int,
729
- num_ir: int,
730
- ) -> tuple[int, int, list[Mixture]]:
731
- """Cycle through every target/augmentation+interferences/augmentation without necessarily using all
732
- noise/augmentation combinations (reduced data set).
691
+ location: str,
692
+ config: dict,
693
+ effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
694
+ test: bool = False,
695
+ ) -> list[Mixture]:
696
+ """Cycle through every source/effect+interferences/effect without necessarily using all
697
+ noise/effect combinations (reduced data set).
733
698
  """
734
699
  from random import randint
735
700
 
736
701
  import numpy as np
737
702
 
738
- from .datatypes import Mixture
739
- from .datatypes import Noise
740
- from .datatypes import UniversalSNR
741
- from .targets import get_augmented_target_ids_for_mixup
703
+ from ..datatypes import Mixture
704
+ from ..datatypes import UniversalSNR
705
+ from .effects import effects_from_rules
706
+ from .effects import estimate_effected_length
742
707
 
743
- m_id = 0
744
- used_noise_files = set()
745
- used_noise_samples = 0
746
- noise_file_id = None
747
- noise_augmentation_id = None
748
- noise_offset = None
749
-
750
- augmented_target_indices_for_mixups = [
751
- get_augmented_target_ids_for_mixup(
752
- augmented_targets=augmented_targets,
753
- targets=target_files,
754
- target_augmentations=target_augmentations,
755
- mixup=mixup,
756
- num_classes=num_classes,
757
- )
758
- for mixup in mixups
759
- ]
708
+ mixdb = MixtureDatabase(location, test)
709
+ snrs = get_all_snrs_from_config(config)
760
710
 
761
- mixtures: list[Mixture] = []
762
- for mixup in augmented_target_indices_for_mixups:
763
- for augmented_target_indices in mixup:
764
- targets, target_length = _get_target_info(
765
- augmented_target_ids=augmented_target_indices,
766
- augmented_targets=augmented_targets,
767
- target_files=target_files,
768
- target_augmentations=target_augmentations,
769
- feature_step_samples=feature_step_samples,
770
- num_ir=num_ir,
771
- )
772
-
773
- for spectral_mask_id in range(len(spectral_masks)):
774
- for snr in all_snrs:
775
- (
776
- noise_file_id,
777
- noise_augmentation_id,
778
- noise_augmentation,
779
- noise_offset,
780
- ) = _get_next_noise_offset(
781
- noise_file_id=noise_file_id,
782
- noise_augmentation_id=noise_augmentation_id,
783
- noise_offset=noise_offset,
784
- target_length=target_length,
785
- noise_files=noise_files,
786
- noise_augmentations=noise_augmentations,
787
- num_ir=num_ir,
788
- )
789
- used_noise_samples += target_length
711
+ next_noise = NextNoise(mixdb, effected_sources["noise"])
790
712
 
791
- used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
792
-
793
- mixtures.append(
794
- Mixture(
795
- targets=targets,
796
- name=str(m_id),
797
- noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
798
- noise_offset=noise_offset,
799
- samples=target_length,
800
- snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
801
- spectral_mask_id=spectral_mask_id + 1,
802
- spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
803
- )
713
+ m_id = 0
714
+ mixtures: list[Mixture] = []
715
+ for primary_file, primary_rule in effected_sources["primary"]:
716
+ primary_effect = effects_from_rules(mixdb, primary_rule)
717
+ primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
718
+
719
+ for spectral_mask_id in range(len(config["spectral_masks"])):
720
+ for snr in snrs["noise"]:
721
+ noise_file_id, noise_effect, noise_start = next_noise.generate(primary_file.samples)
722
+
723
+ mixtures.append(
724
+ Mixture(
725
+ name=str(m_id),
726
+ all_sources={
727
+ "primary": Source(
728
+ file_id=primary_file.id,
729
+ effects=primary_effect,
730
+ ),
731
+ "noise": Source(
732
+ file_id=noise_file_id,
733
+ effects=noise_effect,
734
+ start=noise_start,
735
+ repeat=True,
736
+ snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
737
+ ),
738
+ },
739
+ samples=primary_length,
740
+ spectral_mask_id=spectral_mask_id + 1,
741
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
804
742
  )
805
- m_id += 1
743
+ )
744
+ m_id += 1
806
745
 
807
- return len(used_noise_files), used_noise_samples, mixtures
746
+ return mixtures
808
747
 
809
748
 
810
749
  def _non_combinatorial_noise_mix(
811
- augmented_targets: list[AugmentedTarget],
812
- target_files: list[TargetFile],
813
- target_augmentations: list[AugmentationRule],
814
- noise_files: list[NoiseFile],
815
- noise_augmentations: list[AugmentationRule],
816
- spectral_masks: list[SpectralMask],
817
- all_snrs: list[UniversalSNRGenerator],
818
- mixups: list[int],
819
- num_classes: int,
820
- feature_step_samples: int,
821
- num_ir: int,
822
- ) -> tuple[int, int, list[Mixture]]:
823
- """Combine a target/augmentation+interferences/augmentation with a single cut of a noise/augmentation
824
- non-exhaustively (each target/augmentation+interferences/augmentation does not use each noise/augmentation).
825
- Cut has random start and loop back to beginning if end of noise/augmentation is reached.
750
+ location: str,
751
+ config: dict,
752
+ effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
753
+ test: bool = False,
754
+ ) -> list[Mixture]:
755
+ """Combine a source/effect+interferences/effect with a single cut of a noise/effect
756
+ non-exhaustively (each source/effect+interferences/effect does not use each noise/effect).
757
+ Cut has random start and loop back to beginning if end of noise/effect is reached.
826
758
  """
827
759
  from random import choice
828
760
  from random import randint
829
761
 
830
762
  import numpy as np
831
763
 
832
- from .datatypes import Mixture
833
- from .datatypes import Noise
834
- from .datatypes import UniversalSNR
835
- from .targets import get_augmented_target_ids_for_mixup
764
+ from ..datatypes import Mixture
765
+ from ..datatypes import UniversalSNR
766
+ from .effects import effects_from_rules
767
+ from .effects import estimate_effected_length
836
768
 
837
- m_id = 0
838
- used_noise_files = set()
839
- used_noise_samples = 0
840
- noise_file_id = None
841
- noise_augmentation_id = None
842
-
843
- augmented_target_indices_for_mixups = [
844
- get_augmented_target_ids_for_mixup(
845
- augmented_targets=augmented_targets,
846
- targets=target_files,
847
- target_augmentations=target_augmentations,
848
- mixup=mixup,
849
- num_classes=num_classes,
850
- )
851
- for mixup in mixups
852
- ]
769
+ mixdb = MixtureDatabase(location, test)
770
+ snrs = get_all_snrs_from_config(config)
853
771
 
772
+ m_id = 0
773
+ noise_id = 0
854
774
  mixtures: list[Mixture] = []
855
- for mixup in augmented_target_indices_for_mixups:
856
- for augmented_target_indices in mixup:
857
- targets, target_length = _get_target_info(
858
- augmented_target_ids=augmented_target_indices,
859
- augmented_targets=augmented_targets,
860
- target_files=target_files,
861
- target_augmentations=target_augmentations,
862
- feature_step_samples=feature_step_samples,
863
- num_ir=num_ir,
864
- )
865
-
866
- for spectral_mask_id in range(len(spectral_masks)):
867
- for snr in all_snrs:
868
- (
869
- noise_file_id,
870
- noise_augmentation_id,
871
- noise_augmentation,
872
- noise_length,
873
- ) = _get_next_noise_indices(
874
- noise_file_id=noise_file_id,
875
- noise_augmentation_id=noise_augmentation_id,
876
- noise_files=noise_files,
877
- noise_augmentations=noise_augmentations,
878
- num_ir=num_ir,
775
+ for primary_file, primary_rule in effected_sources["primary"]:
776
+ primary_effect = effects_from_rules(mixdb, primary_rule)
777
+ primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
778
+
779
+ for spectral_mask_id in range(len(config["spectral_masks"])):
780
+ for snr in snrs["noise"]:
781
+ noise_file, noise_rule = effected_sources["noise"][noise_id]
782
+ noise_effect = effects_from_rules(mixdb, noise_rule)
783
+ noise_length = estimate_effected_length(noise_file.samples, noise_effect)
784
+
785
+ mixtures.append(
786
+ Mixture(
787
+ name=str(m_id),
788
+ all_sources={
789
+ "primary": Source(
790
+ file_id=primary_file.id,
791
+ effects=primary_effect,
792
+ ),
793
+ "noise": Source(
794
+ file_id=noise_file.id,
795
+ effects=noise_effect,
796
+ start=choice(range(noise_length)), # noqa: S311
797
+ repeat=True,
798
+ snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
799
+ ),
800
+ },
801
+ samples=primary_length,
802
+ spectral_mask_id=spectral_mask_id + 1,
803
+ spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
879
804
  )
880
- used_noise_samples += target_length
805
+ )
806
+ noise_id = (noise_id + 1) % len(effected_sources["noise"])
807
+ m_id += 1
881
808
 
882
- used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
809
+ return mixtures
883
810
 
884
- mixtures.append(
885
- Mixture(
886
- targets=targets,
887
- name=str(m_id),
888
- noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
889
- noise_offset=choice(range(noise_length)), # noqa: S311
890
- samples=target_length,
891
- snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
892
- spectral_mask_id=spectral_mask_id + 1,
893
- spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
894
- )
895
- )
896
- m_id += 1
897
811
 
898
- return len(used_noise_files), used_noise_samples, mixtures
899
-
900
-
901
- def _get_next_noise_indices(
902
- noise_file_id: int | None,
903
- noise_augmentation_id: int | None,
904
- noise_files: list[NoiseFile],
905
- noise_augmentations: list[AugmentationRule],
906
- num_ir: int,
907
- ) -> tuple[int, int, Augmentation, int]:
908
- from .augmentation import augmentation_from_rule
909
- from .augmentation import estimate_augmented_length_from_length
910
-
911
- if noise_file_id is None or noise_augmentation_id is None:
912
- noise_file_id = 0
913
- noise_augmentation_id = 0
914
- else:
915
- noise_augmentation_id += 1
916
- if noise_augmentation_id == len(noise_augmentations):
917
- noise_augmentation_id = 0
918
- noise_file_id += 1
919
- if noise_file_id == len(noise_files):
920
- noise_file_id = 0
921
-
922
- noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
923
- noise_length = estimate_augmented_length_from_length(
924
- length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
925
- )
926
- return noise_file_id, noise_augmentation_id, noise_augmentation, noise_length
927
-
928
-
929
- def _get_next_noise_offset(
930
- noise_file_id: int | None,
931
- noise_augmentation_id: int | None,
932
- noise_offset: int | None,
933
- target_length: int,
934
- noise_files: list[NoiseFile],
935
- noise_augmentations: list[AugmentationRule],
936
- num_ir: int,
937
- ) -> tuple[int, int, Augmentation, int]:
938
- from .augmentation import augmentation_from_rule
939
- from .augmentation import estimate_augmented_length_from_length
940
-
941
- if noise_file_id is None or noise_augmentation_id is None or noise_offset is None:
942
- noise_file_id = 0
943
- noise_augmentation_id = 0
944
- noise_offset = 0
945
-
946
- noise_augmentation = augmentation_from_rule(noise_augmentations[noise_file_id], num_ir)
947
- noise_length = estimate_augmented_length_from_length(
948
- length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
949
- )
950
- if noise_offset + target_length >= noise_length:
951
- if noise_offset == 0:
952
- raise ValueError("Length of target audio exceeds length of noise audio")
953
-
954
- noise_offset = 0
955
- noise_augmentation_id += 1
956
- if noise_augmentation_id == len(noise_augmentations):
957
- noise_augmentation_id = 0
958
- noise_file_id += 1
959
- if noise_file_id == len(noise_files):
960
- noise_file_id = 0
961
- noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
962
-
963
- return noise_file_id, noise_augmentation_id, noise_augmentation, noise_offset
964
-
965
-
966
- def _get_target_info(
967
- augmented_target_ids: list[int],
968
- augmented_targets: list[AugmentedTarget],
969
- target_files: list[TargetFile],
970
- target_augmentations: list[AugmentationRule],
971
- feature_step_samples: int,
972
- num_ir: int,
973
- ) -> tuple[list[Target], int]:
974
- from .augmentation import augmentation_from_rule
975
- from .augmentation import estimate_augmented_length_from_length
976
-
977
- mixups: list[Target] = []
978
- target_length = 0
979
- for idx in augmented_target_ids:
980
- tfi = augmented_targets[idx].target_id
981
- target_augmentation_rule = target_augmentations[augmented_targets[idx].target_augmentation_id]
982
- target_augmentation = augmentation_from_rule(target_augmentation_rule, num_ir)
983
-
984
- mixups.append(Target(file_id=tfi + 1, augmentation=target_augmentation))
985
-
986
- target_length = max(
987
- estimate_augmented_length_from_length(
988
- length=target_files[tfi].samples,
989
- tempo=target_augmentation.pre.tempo,
990
- frame_length=feature_step_samples,
991
- ),
992
- target_length,
993
- )
994
- return mixups, target_length
812
+ class NextNoise:
813
+ def __init__(self, mixdb: MixtureDatabase, effected_noises: list[tuple[SourceFile, Effects]]) -> None:
814
+ from .effects import effects_from_rules
815
+ from .effects import estimate_effected_length
995
816
 
817
+ self.mixdb = mixdb
818
+ self.effected_noises = effected_noises
996
819
 
997
- def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
998
- from .datatypes import UniversalSNRGenerator
820
+ self.noise_start = 0
821
+ self.noise_id = 0
822
+ self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
823
+ self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)
999
824
 
1000
- return [UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config["snrs"]] + [
1001
- UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config["random_snrs"]
1002
- ]
825
+ @property
826
+ def noise_file(self):
827
+ return self.effected_noises[self.noise_id][0]
828
+
829
+ @property
830
+ def noise_rule(self):
831
+ return self.effected_noises[self.noise_id][1]
1003
832
 
833
+ def generate(self, length: int) -> tuple[int, Effects, int]:
834
+ from .effects import effects_from_rules
835
+ from .effects import estimate_effected_length
836
+
837
+ if self.noise_start + length > self.noise_length:
838
+ # Not enough samples in current noise
839
+ if self.noise_start == 0:
840
+ raise ValueError("Length of primary audio exceeds length of noise audio")
841
+
842
+ self.noise_start = 0
843
+ self.noise_id = (self.noise_id + 1) % len(self.effected_noises)
844
+ self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
845
+ self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)
846
+ noise_start = self.noise_start
847
+ else:
848
+ # Current noise has enough samples
849
+ noise_start = self.noise_start
850
+ self.noise_start += length
1004
851
 
1005
- def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
852
+ return self.noise_file.id, self.noise_effect, noise_start
853
+
854
+
855
+ def get_all_snrs_from_config(config: dict) -> dict[str, list[UniversalSNRGenerator]]:
856
+ snrs: dict[str, list[UniversalSNRGenerator]] = {}
857
+ for category in config["sources"]:
858
+ if category != "primary":
859
+ snrs[category] = [UniversalSNRGenerator(snr) for snr in config["sources"][category]["snrs"]]
860
+ return snrs
861
+
862
+
863
+ def _get_textgrid_tiers_from_source_file(file: str) -> list[str]:
1006
864
  from pathlib import Path
1007
865
 
1008
866
  from praatio import textgrid
1009
867
 
1010
- from sonusai.mixture import tokenized_expand
868
+ from ..utils.tokenized_shell_vars import tokenized_expand
1011
869
 
1012
- textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
870
+ textgrid_file = Path(tokenized_expand(file)[0]).with_suffix(".TextGrid")
1013
871
  if not textgrid_file.exists():
1014
872
  return []
1015
873
 
@@ -1018,18 +876,18 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
1018
876
  return sorted(tg.tierNames)
1019
877
 
1020
878
 
1021
- def _populate_speaker_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
879
+ def _populate_speaker_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
1022
880
  """Populate speaker table"""
1023
881
  import json
1024
882
  from pathlib import Path
1025
883
 
1026
884
  import yaml
1027
885
 
886
+ from ..utils.tokenized_shell_vars import tokenized_expand
1028
887
  from .mixdb import db_connection
1029
- from .tokenized_shell_vars import tokenized_expand
1030
888
 
1031
889
  # Determine columns for speaker table
1032
- all_parents = {Path(target_file.name).parent for target_file in target_files}
890
+ all_parents = {Path(file.name).parent for file in source_files}
1033
891
  speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists())
1034
892
 
1035
893
  speakers: dict[Path, dict[str, str]] = {}
@@ -1072,13 +930,13 @@ def _populate_speaker_table(location: str, target_files: list[TargetFile], test:
1072
930
  )
1073
931
 
1074
932
  if "speaker_id" in tiers:
1075
- con.execute("CREATE INDEX speaker_speaker_id_idx ON speaker (speaker_id)")
933
+ con.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
1076
934
 
1077
935
  con.commit()
1078
936
  con.close()
1079
937
 
1080
938
 
1081
- def _populate_truth_config_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
939
+ def _populate_truth_config_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
1082
940
  """Populate truth_config table"""
1083
941
  import json
1084
942
 
@@ -1088,8 +946,8 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
1088
946
 
1089
947
  # Populate truth_config table
1090
948
  truth_configs: list[str] = []
1091
- for target_file in target_files:
1092
- for name, config in target_file.truth_configs.items():
949
+ for file in source_files:
950
+ for name, config in file.truth_configs.items():
1093
951
  ts = json.dumps({"name": name} | config.to_dict())
1094
952
  if ts not in truth_configs:
1095
953
  truth_configs.append(ts)
@@ -1100,3 +958,18 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
1100
958
 
1101
959
  con.commit()
1102
960
  con.close()
961
+
962
+
963
+ def _populate_impulse_response_tag_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
964
+ """Populate ir_tag table"""
965
+ from .mixdb import db_connection
966
+
967
+ con = db_connection(location=location, readonly=False, test=test)
968
+
969
+ con.executemany(
970
+ "INSERT INTO ir_tag (tag) VALUES (?)",
971
+ [(tag,) for tag in {tag for file in files for tag in file.tags}],
972
+ )
973
+
974
+ con.commit()
975
+ con.close()