sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1116 @@
1
+ # ruff: noqa: S608
2
+ import json
3
+ from functools import partial
4
+ from os.path import join
5
+ from pathlib import Path
6
+ from random import choice
7
+ from random import randint
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import yaml
12
+ from praatio import textgrid
13
+
14
+ from .. import logger
15
+ from ..config.config import load_config
16
+ from ..config.ir import get_ir_files
17
+ from ..config.source import get_source_files
18
+ from ..constants import SAMPLE_BYTES
19
+ from ..constants import SAMPLE_RATE
20
+ from ..datatypes import AudioT
21
+ from ..datatypes import Effects
22
+ from ..datatypes import GenMixData
23
+ from ..datatypes import ImpulseResponseFile
24
+ from ..datatypes import Mixture
25
+ from ..datatypes import Source
26
+ from ..datatypes import SourceFile
27
+ from ..datatypes import SourcesAudioT
28
+ from ..datatypes import UniversalSNRGenerator
29
+ from ..utils.human_readable_size import human_readable_size
30
+ from ..utils.seconds_to_hms import seconds_to_hms
31
+ from .db import SQLiteDatabase
32
+ from .effects import get_effect_rules
33
+ from .mixdb import MixtureDatabase
34
+
35
+
36
def config_file(location: str) -> str:
    """Return the path of the ``config.yml`` file inside ``location``.

    :param location: Directory that is expected to hold the configuration file.
    :return: ``location`` and ``"config.yml"`` combined with ``os.path.join``.
    """
    config_name = "config.yml"
    return join(location, config_name)
38
+
39
+
40
# Database schema definition: one CREATE TABLE statement per table, executed in
# list order by DatabaseManager.__init__ when a new database is created.
DATABASE_SCHEMA = [
    # JSON-encoded truth configurations, referenced by source_file_truth_config.
    """
    CREATE TABLE truth_config(
    id INTEGER PRIMARY KEY NOT NULL,
    config TEXT NOT NULL)
    """,
    # Rows produced by get_truth_parameters(config).
    """
    CREATE TABLE truth_parameters(
    id INTEGER PRIMARY KEY NOT NULL,
    category TEXT NOT NULL,
    name TEXT NOT NULL,
    parameters INTEGER)
    """,
    # Source audio files; speaker_id optionally points at a speaker row.
    """
    CREATE TABLE source_file (
    id INTEGER PRIMARY KEY NOT NULL,
    category TEXT NOT NULL,
    class_indices TEXT,
    level_type TEXT NOT NULL,
    name TEXT NOT NULL,
    samples INTEGER NOT NULL,
    speaker_id INTEGER,
    FOREIGN KEY(speaker_id) REFERENCES speaker (id))
    """,
    # Impulse response files.
    """
    CREATE TABLE ir_file (
    id INTEGER PRIMARY KEY NOT NULL,
    delay INTEGER NOT NULL,
    name TEXT NOT NULL)
    """,
    # Distinct impulse response tags.
    """
    CREATE TABLE ir_tag (
    id INTEGER PRIMARY KEY NOT NULL,
    tag TEXT NOT NULL UNIQUE)
    """,
    # Link table between ir_file and ir_tag.
    """
    CREATE TABLE ir_file_ir_tag (
    file_id INTEGER NOT NULL,
    tag_id INTEGER NOT NULL,
    FOREIGN KEY(file_id) REFERENCES ir_file (id),
    FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
    """,
    # Speaker directories; metadata columns are appended later with ALTER TABLE.
    """
    CREATE TABLE speaker (
    id INTEGER PRIMARY KEY NOT NULL,
    parent TEXT NOT NULL)
    """,
    # Database-wide settings (populate_top_table inserts the row with id 1).
    """
    CREATE TABLE top (
    id INTEGER PRIMARY KEY NOT NULL,
    asr_configs TEXT NOT NULL,
    class_balancing BOOLEAN NOT NULL,
    feature TEXT NOT NULL,
    mixid_width INTEGER NOT NULL,
    num_classes INTEGER NOT NULL,
    seed INTEGER NOT NULL,
    speaker_metadata_tiers TEXT NOT NULL,
    textgrid_metadata_tiers TEXT NOT NULL,
    version INTEGER NOT NULL)
    """,
    # Class labels from the config.
    """
    CREATE TABLE class_label (
    id INTEGER PRIMARY KEY NOT NULL,
    label TEXT NOT NULL)
    """,
    # Per-class weight thresholds.
    """
    CREATE TABLE class_weights_threshold (
    id INTEGER PRIMARY KEY NOT NULL,
    threshold FLOAT NOT NULL)
    """,
    # Spectral mask parameter sets.
    """
    CREATE TABLE spectral_mask (
    id INTEGER PRIMARY KEY NOT NULL,
    f_max_width INTEGER NOT NULL,
    f_num INTEGER NOT NULL,
    t_max_percent INTEGER NOT NULL,
    t_max_width INTEGER NOT NULL,
    t_num INTEGER NOT NULL)
    """,
    # Link table between source_file and truth_config.
    """
    CREATE TABLE source_file_truth_config (
    source_file_id INTEGER NOT NULL,
    truth_config_id INTEGER NOT NULL,
    FOREIGN KEY(source_file_id) REFERENCES source_file (id),
    FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
    """,
    # Mixture sources; the UNIQUE constraint lets identical sources be shared.
    """
    CREATE TABLE source (
    id INTEGER PRIMARY KEY NOT NULL,
    effects TEXT NOT NULL,
    file_id INTEGER NOT NULL,
    pre_tempo FLOAT NOT NULL,
    repeat BOOLEAN NOT NULL,
    snr FLOAT NOT NULL,
    snr_gain FLOAT NOT NULL,
    snr_random BOOLEAN NOT NULL,
    start INTEGER NOT NULL,
    UNIQUE(effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start),
    FOREIGN KEY(file_id) REFERENCES source_file (id))
    """,
    # Mixtures.
    """
    CREATE TABLE mixture (
    id INTEGER PRIMARY KEY NOT NULL,
    name TEXT NOT NULL,
    samples INTEGER NOT NULL,
    spectral_mask_id INTEGER NOT NULL,
    spectral_mask_seed INTEGER NOT NULL,
    FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
    """,
    # Link table between mixture and source.
    """
    CREATE TABLE mixture_source (
    mixture_id INTEGER NOT NULL,
    source_id INTEGER NOT NULL,
    FOREIGN KEY(mixture_id) REFERENCES mixture (id),
    FOREIGN KEY(source_id) REFERENCES source (id))
    """,
]
158
+
159
+
160
+ class DatabaseManager:
161
+ """Manages database operations for mixture database generation."""
162
+
163
def __init__(self, location: str, test: bool = False, verbose: bool = False, logging: bool = False) -> None:
    """Load the configuration and create the (empty) mixture database tables.

    :param location: Directory handed to ``load_config``, ``SQLiteDatabase`` and ``MixtureDatabase``.
    :param test: Forwarded to ``SQLiteDatabase`` and ``MixtureDatabase``.
    :param verbose: Forwarded to ``SQLiteDatabase``.
    :param logging: When true, the populate/generate methods emit progress messages.
    """
    # Keep the construction options; the other methods read them back from self.
    self.location, self.test, self.verbose, self.logging = location, test, verbose, logging

    self.config = load_config(self.location)

    # Connection factory sharing location/test/verbose; callers add readonly/create flags.
    self.db = partial(SQLiteDatabase, location=self.location, test=self.test, verbose=self.verbose)

    with self.db(create=True) as cursor:
        for create_statement in DATABASE_SCHEMA:
            cursor.execute(create_statement)

    self.mixdb = MixtureDatabase(location=self.location, test=self.test)
177
+
178
def populate_top_table(self) -> None:
    """Insert the ``top`` row with id 1.

    ``mixid_width`` starts at 0 and both metadata-tier columns start as empty
    strings; they are filled in later with ``UPDATE top ... WHERE ? = id``.
    """
    from .constants import MIXDB_VERSION

    cfg = self.config
    row = (
        1,  # id
        json.dumps(cfg["asr_configs"]),
        cfg["class_balancing"],
        cfg["feature"],
        0,  # mixid_width placeholder
        cfg["num_classes"],
        cfg["seed"],
        "",  # speaker_metadata_tiers placeholder
        "",  # textgrid_metadata_tiers placeholder
        MIXDB_VERSION,
    )

    insert_sql = """
        INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
        seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """

    with self.db(readonly=False) as cursor:
        cursor.execute(insert_sql, row)
204
+
205
def populate_class_label_table(self) -> None:
    """Insert one ``class_label`` row per entry of ``config["class_labels"]``, in order."""
    with self.db(readonly=False) as cursor:
        label_rows = [(label,) for label in self.config["class_labels"]]
        cursor.executemany("INSERT INTO class_label (label) VALUES (?)", label_rows)
212
+
213
def populate_class_weights_threshold_table(self) -> None:
    """Insert one ``class_weights_threshold`` row per class.

    ``config["class_weights_threshold"]`` may be a scalar or a one-element list
    (both are broadcast to every class), or a list with exactly
    ``config["num_classes"]`` entries.

    :raises ValueError: If the resulting list length differs from ``num_classes``.
    """
    thresholds = self.config["class_weights_threshold"]
    num_classes = self.config["num_classes"]

    # Normalize a scalar to a list, then broadcast a single value to every class.
    thresholds = thresholds if isinstance(thresholds, list) else [thresholds]
    if len(thresholds) == 1:
        thresholds = thresholds * num_classes

    if len(thresholds) != num_classes:
        raise ValueError(f"invalid class_weights_threshold length: {len(thresholds)}")

    with self.db(readonly=False) as cursor:
        cursor.executemany(
            "INSERT INTO class_weights_threshold (threshold) VALUES (?)",
            [(value,) for value in thresholds],
        )
232
+
233
def populate_spectral_mask_table(self) -> None:
    """Insert one ``spectral_mask`` row per mask returned by ``get_spectral_masks(self.config)``."""
    from ..config.spectral_masks import get_spectral_masks

    # Attribute names double as column names, in insertion order.
    columns = ("f_max_width", "f_num", "t_max_percent", "t_max_width", "t_num")

    with self.db(readonly=False) as cursor:
        mask_rows = [tuple(getattr(mask, column) for column in columns) for mask in get_spectral_masks(self.config)]
        cursor.executemany(
            """
            INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
            """,
            mask_rows,
        )
253
+
254
def populate_truth_parameters_table(self) -> None:
    """Insert the entries of ``get_truth_parameters(self.config)`` into ``truth_parameters``."""
    from ..config.truth import get_truth_parameters

    insert_sql = """
        INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
        """

    with self.db(readonly=False) as cursor:
        parameter_rows = [
            (entry.category, entry.name, entry.parameters) for entry in get_truth_parameters(self.config)
        ]
        cursor.executemany(insert_sql, parameter_rows)
272
+
273
def populate_source_file_table(self, show_progress: bool = False) -> None:
    """Collect the configured source files and populate the ``source_file`` table.

    Besides ``source_file``, this fills ``truth_config`` and ``speaker`` (through the
    helper methods). It also links each file to its truth configurations in
    ``source_file_truth_config``, and stores the sorted union of the TextGrid tiers
    found for the files in ``top.textgrid_metadata_tiers``.

    :param show_progress: Forwarded to ``get_source_files``.
    :raises RuntimeError: If none of the collected files has category ``"primary"``.
    """
    if self.logging:
        logger.info("Collecting sources")

    files = get_source_files(self.config, show_progress)

    if self.logging:
        # Blank separator line; previously printed even when logging was disabled.
        logger.info("")

    if not any(file.category == "primary" for file in files):
        raise RuntimeError("Canceled due to no primary sources")

    if self.logging:
        logger.info("Populating source file table")

    self._populate_truth_config_table(files)
    self._populate_speaker_table(files)

    with self.db(readonly=False) as c:
        # Build the id lookups once. truth_config.config has no index, so issuing a
        # "SELECT ... WHERE ? = config" per file scanned the whole table each time.
        # setdefault keeps the lowest id should a value ever appear twice.
        truth_config_id_by_text: dict[str, int] = {}
        for row_id, config_text in c.execute("SELECT id, config FROM truth_config ORDER BY id").fetchall():
            truth_config_id_by_text.setdefault(config_text, row_id)

        speaker_id_by_parent: dict[str, int] = {}
        for row_id, parent in c.execute("SELECT id, parent FROM speaker ORDER BY id").fetchall():
            speaker_id_by_parent.setdefault(parent, row_id)

        textgrid_metadata_tiers: set[str] = set()
        for file in files:
            # Collect the TextGrid tiers available for this source file.
            textgrid_metadata_tiers.update(_get_textgrid_tiers_from_source_file(file.name))

            # Same JSON encoding that _populate_truth_config_table stores in truth_config.config.
            truth_config_ids = [
                truth_config_id_by_text[json.dumps({"name": name} | config.to_dict())]
                for name, config in (file.truth_configs or {}).items()
            ]

            # None when the file's directory has no speaker row.
            speaker_id = speaker_id_by_parent.get(Path(file.name).parent.as_posix())

            c.execute(
                """
                INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    file.category,
                    json.dumps(file.class_indices),
                    file.level_type,
                    file.name,
                    file.samples,
                    speaker_id,
                ),
            )
            source_file_id = c.lastrowid
            for truth_config_id in truth_config_ids:
                c.execute(
                    "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
                    (source_file_id, truth_config_id),
                )

        # Update textgrid_metadata_tiers in the top table
        c.execute(
            "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = id",
            (json.dumps(sorted(textgrid_metadata_tiers)), 1),
        )

    if not self.logging:
        return

    logger.info("Sources summary")
    data: dict[str, list] = {
        "category": [],
        "files": [],
        "size": [],
        "duration": [],
    }
    for category, category_files in self.mixdb.source_files.items():
        audio_samples = sum(source.samples for source in category_files)
        audio_duration = audio_samples / SAMPLE_RATE
        data["category"].append(category)
        data["files"].append(self.mixdb.num_source_files(category))
        data["size"].append(human_readable_size(audio_samples * SAMPLE_BYTES, 1))
        data["duration"].append(seconds_to_hms(seconds=audio_duration))

    logger.info(pd.DataFrame(data).to_string(index=False, header=False))
    logger.info("")

    for category, category_files in self.mixdb.source_files.items():
        logger.debug(f"List of {category} sources:")
        logger.debug(yaml.dump([file.name for file in category_files], default_flow_style=False))
369
+
370
def populate_impulse_response_file_table(self, show_progress: bool = False) -> None:
    """Collect the configured impulse response files and populate ``ir_file`` / ``ir_file_ir_tag``.

    Tags are first inserted into ``ir_tag`` by ``_populate_impulse_response_tag_table``.
    Each file row is then linked to the ids of its tags.

    :param show_progress: Forwarded to ``get_ir_files``.
    """
    if self.logging:
        logger.info("Collecting impulse responses")

    files = get_ir_files(self.config, show_progress=show_progress)

    if self.logging:
        # Blank separator line; previously printed even when logging was disabled.
        logger.info("")
        logger.info("Populating impulse response file table")

    self._populate_impulse_response_tag_table(files)

    with self.db(readonly=False) as c:
        for file in files:
            # Resolve tag ids before inserting the file row; ir_tag.tag is UNIQUE, so one row per tag.
            tag_ids = [c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,)).fetchone()[0] for tag in file.tags]

            c.execute("INSERT INTO ir_file (delay, name) VALUES (?, ?)", (file.delay, file.name))

            file_id = c.lastrowid
            for tag_id in tag_ids:
                c.execute("INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)", (file_id, tag_id))

    if self.logging:
        logger.debug("List of impulse responses:")
        for idx, file in enumerate(files):
            logger.debug(f"id: {idx}, name:{file.name}, delay: {file.delay}, tags: [{', '.join(file.tags)}]")
        logger.debug("")
402
+
403
def populate_mixture_table(self, mixtures: list[Mixture], show_progress: bool = False) -> None:
    """Insert mixtures into ``mixture`` and their sources into ``source`` / ``mixture_source``.

    Each mixture id is ``int(mixture.name) + 1``, so every mixture name must be parseable
    by ``int()``. Identical sources are stored once: ``INSERT OR IGNORE`` relies on the
    UNIQUE constraint of ``source``, and the existing row id is then looked up and linked
    through ``mixture_source``.

    :param mixtures: Mixtures to store.
    :param show_progress: Show a progress bar via ``track``.
    """
    from ..utils.parallel import track
    from .helpers import from_mixture
    from .helpers import from_source

    if self.logging:
        logger.info("Populating mixture and source tables")

    with self.db(readonly=False) as c:
        # Populate the mixture, source and mixture_source tables
        for mixture in track(mixtures, disable=not show_progress):
            m_id = int(mixture.name) + 1
            c.execute(
                """
                INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
                VALUES (?, ?, ?, ?, ?)
                """,
                (m_id, *from_mixture(mixture)),
            )

            for source in mixture.all_sources.values():
                # Compute the row once, so the INSERT and the id lookup use identical values.
                source_row = from_source(source)
                c.execute(
                    """
                    INSERT OR IGNORE INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    source_row,
                )

                source_id = c.execute(
                    """
                    SELECT id
                    FROM source
                    WHERE ? = effects
                    AND ? = file_id
                    AND ? = pre_tempo
                    AND ? = repeat
                    AND ? = snr
                    AND ? = snr_gain
                    AND ? = snr_random
                    AND ? = start
                    """,
                    source_row,
                ).fetchone()[0]
                c.execute("INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)", (m_id, source_id))

    if self.logging:
        logger.info("Closing mixture and source tables")
452
+
453
def _populate_speaker_table(self, source_files: list[SourceFile]) -> None:
    """Fill the ``speaker`` table from ``speaker.yml`` files next to the source files.

    Each distinct parent directory of ``source_files`` is checked for a ``speaker.yml``
    (path resolved with ``tokenized_expand``). Every top-level key found in those YAML
    files becomes an extra TEXT column on ``speaker``. Each directory becomes one row whose
    ``parent`` is its POSIX path. The metadata column names are stored as JSON in
    ``top.speaker_metadata_tiers``.
    """
    from ..utils.tokenized_shell_vars import tokenized_expand

    def _quote(identifier: object) -> str:
        # YAML keys are user text: quote them as SQL identifiers so names containing
        # spaces, keywords or quotes cannot break (or inject into) the generated SQL.
        return '"' + str(identifier).replace('"', '""') + '"'

    # Resolve each candidate speaker.yml path once per parent directory.
    speaker_yml_by_parent = {
        parent: Path(tokenized_expand(parent / "speaker.yml")[0])
        for parent in {Path(file.name).parent for file in source_files}
    }

    speakers: dict[Path, dict] = {}
    for parent in sorted(speaker_yml_by_parent):
        speaker_yml = speaker_yml_by_parent[parent]
        if speaker_yml.exists():
            with open(speaker_yml) as f:
                # An empty speaker.yml loads as None; treat it as "no metadata" instead of crashing.
                speakers[parent] = yaml.safe_load(f) or {}

    new_columns = sorted({column for metadata in speakers.values() for column in metadata})

    with self.db(readonly=False) as c:
        for new_column in new_columns:
            c.execute(f"ALTER TABLE speaker ADD COLUMN {_quote(new_column)} TEXT")

        # Populate the speaker table; missing keys are stored as NULL.
        speaker_rows = [
            (parent.as_posix(), *(metadata.get(column, None) for column in new_columns))
            for parent, metadata in speakers.items()
        ]

        column_ids = ", ".join(["parent", *(_quote(column) for column in new_columns)])
        column_values = ", ".join(["?"] * (len(new_columns) + 1))
        c.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)

        c.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")

        # Update speaker_metadata_tiers in the top table
        tiers = [
            description[0]
            for description in c.execute("SELECT * FROM speaker").description
            if description[0] not in ("id", "parent")
        ]
        c.execute("UPDATE top SET speaker_metadata_tiers=? WHERE ? = id", (json.dumps(tiers), 1))

        # NOTE(review): the condition tests for a speaker.yml key named "speaker_id", but the
        # index is created on source_file.speaker_id; confirm this pairing is intended.
        if "speaker_id" in tiers:
            c.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
500
+
501
+ def _populate_truth_config_table(self, source_files: list[SourceFile]) -> None:
502
+ """Populate the truth_config table"""
503
+ with self.db(readonly=False) as c:
504
+ # Populate truth_config table
505
+ truth_configs: list[str] = []
506
+ for file in source_files:
507
+ for name, config in file.truth_configs.items():
508
+ ts = json.dumps({"name": name} | config.to_dict())
509
+ if ts not in truth_configs:
510
+ truth_configs.append(ts)
511
+ c.executemany(
512
+ "INSERT INTO truth_config (config) VALUES (?)",
513
+ [(item,) for item in truth_configs],
514
+ )
515
+
516
+ def _populate_impulse_response_tag_table(self, files: list[ImpulseResponseFile]) -> None:
517
+ """Populate the ir_tag table"""
518
+ with self.db(readonly=False) as c:
519
+ c.executemany(
520
+ "INSERT INTO ir_tag (tag) VALUES (?)",
521
+ [(tag,) for tag in {tag for file in files for tag in file.tags}],
522
+ )
523
+
524
def generate_mixtures(self) -> list[Mixture]:
    """Build every mixture described by the configuration.

    Each primary source file is paired with each effect rule of its category. These
    primaries are combined with noise once per noise ``mix_rule``. Each remaining source
    category then derives further mixtures from those built so far, via its own
    ``mix_rules``. Finally ``top.mixid_width`` is set from the total mixture count.

    :return: The generated mixtures (created with ``name=""``).
    """
    from ..utils.max_text_width import max_text_width

    if self.logging:
        logger.info("Collecting effects")

    rules = get_effect_rules(self.location, self.config, self.test)

    if self.logging:
        logger.info("")
        for category, category_rules in rules.items():
            logger.debug(f"List of {category} rules:")
            logger.debug(yaml.dump([entry.to_dict() for entry in category_rules], default_flow_style=False))

        logger.debug("SNRS:")
        for category, source in self.config["sources"].items():
            if category == "primary":
                continue
            logger.debug(f" {category}")
            for snr in source["snrs"]:
                logger.debug(f" - {snr}")
        logger.debug("")

        logger.debug("Mix Rules:")
        for category, source in self.config["sources"].items():
            if category == "primary":
                continue
            logger.debug(f" {category}")
            for mix_rule in source["mix_rules"]:
                logger.debug(f" - {mix_rule}")
        logger.debug("")

        logger.debug("Spectral masks:")
        for spectral_mask in self.mixdb.spectral_masks:
            logger.debug(f"- {spectral_mask}")
        logger.debug("")

        logger.info("Generating mixtures")

    # Every (file, rule) pairing per source category.
    effected_sources: dict[str, list[tuple[SourceFile, Effects]]] = {
        category: [(file, rule) for file in category_files for rule in rules[category]]
        for category, category_files in self.mixdb.source_files.items()
    }

    # Stage 1: primary sources combined with noise, once per noise mix rule.
    mixtures: list[Mixture] = []
    for mix_rule in self.config["sources"]["noise"]["mix_rules"]:
        mixtures.extend(
            self._process_noise_sources(
                primary_effected_sources=effected_sources["primary"],
                noise_effected_sources=effected_sources["noise"],
                mix_rule=mix_rule,
            )
        )

    # Stage 2: each additional category derives mixtures from everything built before it.
    # The new mixtures of a category are appended only after all its rules have run.
    additional_categories = [name for name in self.mixdb.source_files if name not in ("primary", "noise")]
    for category in additional_categories:
        category_mixtures = [
            mixture
            for mix_rule in self.config["sources"][category]["mix_rules"]
            for mixture in self._process_additional_sources(
                effected_sources=effected_sources[category],
                mixtures=mixtures,
                category=category,
                mix_rule=mix_rule,
            )
        ]
        mixtures.extend(category_mixtures)

    # Update the mixid width
    with self.db(readonly=False) as c:
        c.execute("UPDATE top SET mixid_width=? WHERE ? = id", (max_text_width(len(mixtures)), 1))

    return mixtures
600
+
601
+ def _process_noise_sources(
602
+ self,
603
+ primary_effected_sources: list[tuple[SourceFile, Effects]],
604
+ noise_effected_sources: list[tuple[SourceFile, Effects]],
605
+ mix_rule: str,
606
+ ) -> list[Mixture]:
607
+ match mix_rule:
608
+ case "exhaustive":
609
+ return self._noise_exhaustive(primary_effected_sources, noise_effected_sources)
610
+ case "non-exhaustive":
611
+ return self._noise_non_exhaustive(primary_effected_sources, noise_effected_sources)
612
+ case "non-combinatorial":
613
+ return self._noise_non_combinatorial(primary_effected_sources, noise_effected_sources)
614
+ case _:
615
+ raise ValueError(f"invalid noise mix_rule: {mix_rule}")
616
+
617
def _noise_exhaustive(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
) -> list[Mixture]:
    """Use every noise/effect with every source/effect+interferences/effect.

    One mixture is produced for every (noise, primary, spectral mask, noise SNR)
    combination. Within one noise/effect, the noise start offset advances by the
    effected primary length after each primary/effect, wrapping modulo the effected
    noise length.
    """
    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    snrs = self.all_snrs()
    seed_ceiling = np.iinfo("i").max

    mixtures: list[Mixture] = []
    for noise_file, noise_rule in noise_effected_sources:
        offset = 0
        noise_effect = effects_from_rules(self.mixdb, noise_rule)
        noise_length = estimate_effected_length(noise_file.samples, noise_effect)

        for primary_file, primary_rule in primary_effected_sources:
            primary_effect = effects_from_rules(self.mixdb, primary_rule)
            primary_length = estimate_effected_length(
                primary_file.samples, primary_effect, self.mixdb.feature_step_samples
            )

            for mask_index in range(len(self.config["spectral_masks"])):
                for snr in snrs["noise"]:
                    primary = Source(file_id=primary_file.id, effects=primary_effect)
                    noise = Source(
                        file_id=noise_file.id,
                        effects=noise_effect,
                        start=offset,
                        loop=True,
                        snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                    )
                    seed = randint(0, seed_ceiling)  # noqa: S311
                    mixtures.append(
                        Mixture(
                            name="",
                            all_sources={"primary": primary, "noise": noise},
                            samples=primary_length,
                            spectral_mask_id=mask_index + 1,
                            spectral_mask_seed=seed,
                        )
                    )

            # Next primary picks up the noise where this one left off.
            offset = int((offset + primary_length) % noise_length)

    return mixtures
668
+
669
def _noise_non_exhaustive(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
) -> list[Mixture]:
    """Cycle through every source/effect+interferences/effect without necessarily using all
    noise/effect combinations (reduced data set).

    For each (primary, spectral mask, noise SNR), the next noise file/effect/start is
    taken from a ``NextNoise`` cursor over ``noise_effected_sources``.
    """
    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    snrs = self.all_snrs()
    seed_ceiling = np.iinfo("i").max

    noise_cursor = NextNoise(self.mixdb, noise_effected_sources)

    mixtures: list[Mixture] = []
    for primary_file, primary_rule in primary_effected_sources:
        primary_effect = effects_from_rules(self.mixdb, primary_rule)
        primary_length = estimate_effected_length(
            primary_file.samples, primary_effect, self.mixdb.feature_step_samples
        )

        for mask_index in range(len(self.config["spectral_masks"])):
            for snr in snrs["noise"]:
                # NOTE(review): the cursor is advanced with the raw primary sample count, while
                # the mixture uses the effected primary_length (and _noise_exhaustive advances by
                # the effected length); confirm which length NextNoise.generate expects.
                noise_file_id, noise_effect, noise_start = noise_cursor.generate(primary_file.samples)

                primary = Source(file_id=primary_file.id, effects=primary_effect)
                noise = Source(
                    file_id=noise_file_id,
                    effects=noise_effect,
                    start=noise_start,
                    loop=True,
                    snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                )
                seed = randint(0, seed_ceiling)  # noqa: S311
                mixtures.append(
                    Mixture(
                        name="",
                        all_sources={"primary": primary, "noise": noise},
                        samples=primary_length,
                        spectral_mask_id=mask_index + 1,
                        spectral_mask_seed=seed,
                    )
                )

    return mixtures
720
+
721
def _noise_non_combinatorial(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
) -> list[Mixture]:
    """Combine a source/effect+interferences/effect with a single cut of a noise/effect
    non-exhaustively (each source/effect+interferences/effect does not use each noise/effect).
    Cut has a random start and loop back to the beginning if the end of noise/effect is reached.

    Noise entries are taken round-robin: one per mixture, in list order, wrapping around.
    """
    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    snrs = self.all_snrs()
    seed_ceiling = np.iinfo("i").max

    noise_index = 0
    mixtures: list[Mixture] = []
    for primary_file, primary_rule in primary_effected_sources:
        primary_effect = effects_from_rules(self.mixdb, primary_rule)
        primary_length = estimate_effected_length(
            primary_file.samples, primary_effect, self.mixdb.feature_step_samples
        )

        for mask_index in range(len(self.config["spectral_masks"])):
            for snr in snrs["noise"]:
                noise_file, noise_rule = noise_effected_sources[noise_index]
                noise_effect = effects_from_rules(self.mixdb, noise_rule)
                noise_length = estimate_effected_length(noise_file.samples, noise_effect)

                # Random draws keep their original order: cut start first, then mask seed.
                primary = Source(file_id=primary_file.id, effects=primary_effect)
                noise = Source(
                    file_id=noise_file.id,
                    effects=noise_effect,
                    start=choice(range(noise_length)),  # noqa: S311
                    loop=True,
                    snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                )
                seed = randint(0, seed_ceiling)  # noqa: S311
                mixtures.append(
                    Mixture(
                        name="",
                        all_sources={"primary": primary, "noise": noise},
                        samples=primary_length,
                        spectral_mask_id=mask_index + 1,
                        spectral_mask_seed=seed,
                    )
                )
                noise_index = (noise_index + 1) % len(noise_effected_sources)

    return mixtures
775
+
776
+ def _process_additional_sources(
777
+ self,
778
+ effected_sources: list[tuple[SourceFile, Effects]],
779
+ mixtures: list[Mixture],
780
+ category: str,
781
+ mix_rule: str,
782
+ ) -> list[Mixture]:
783
+ if mix_rule == "none":
784
+ return []
785
+ if mix_rule.startswith("choose"):
786
+ return self._additional_choose(
787
+ effected_sources=effected_sources,
788
+ mixtures=mixtures,
789
+ category=category,
790
+ mix_rule=mix_rule,
791
+ )
792
+ if mix_rule.startswith("sequence"):
793
+ return self._additional_sequence(
794
+ effected_sources=effected_sources,
795
+ mixtures=mixtures,
796
+ category=category,
797
+ mix_rule=mix_rule,
798
+ )
799
+ raise ValueError(f"invalid {category} mix_rule: {mix_rule}")
800
+
801
def _additional_choose(
    self,
    effected_sources: list[tuple[SourceFile, Effects]],
    mixtures: list[Mixture],
    category: str,
    mix_rule: str,
) -> list[Mixture]:
    """Attach a ``category`` source, drawn with ``RandomChoice``, to copies of each mixture.

    For every input mixture and every SNR configured for ``category``, the mixture is
    deep-copied and a (source file, effects) pair drawn from ``effected_sources`` is stored
    under ``all_sources[category]`` with the directive's ``start``/``loop`` and that SNR.

    With ``unique=speaker_id`` one chooser is built per distinct speaker-id tuple found in
    ``mixtures``. It only offers candidates whose speaker is not already in the mixture.

    :param effected_sources: Candidate (source file, effects) pairs for ``category``
    :param mixtures: Mixtures to extend (not modified)
    :param category: Source category key to add to ``all_sources``
    :param mix_rule: A ``choose`` directive understood by ``parse_source_directive``
    :return: ``len(mixtures) * len(snrs)`` new mixtures
    :raises ValueError: On an unparsable directive, an unsupported ``unique`` value, or when
        speaker filtering leaves no candidates
    """
    from copy import deepcopy

    from ..datatypes import UniversalSNR
    from ..parse.parse_source_directive import parse_source_directive
    from ..utils.choice import RandomChoice

    # Parse the mix rule
    try:
        params = parse_source_directive(mix_rule)
    except ValueError as e:
        raise ValueError(f"Error parsing choose directive: {e}") from e

    snrs = self.all_snrs()[category]

    # Build every chooser up front (validating ``unique`` once) so construction order is unchanged.
    choice_objs: dict[tuple[int, ...], RandomChoice] = {}
    if params.unique == "speaker_id":
        # One chooser per distinct speaker-id tuple, restricted to speakers not already present.
        all_speaker_ids = {self.get_mixture_speaker_ids(mixture) for mixture in mixtures}
        for speaker_ids in all_speaker_ids:
            filtered_sources = _filter_sources(effected_sources, speaker_ids)
            if not filtered_sources:
                raise ValueError(
                    f"Additional source, {category}, has no valid entries for speaker_ids unique from {speaker_ids}"
                )
            choice_objs[speaker_ids] = RandomChoice(data=filtered_sources, repetition=params.repeat)
    elif params.unique is None:
        choice_objs = {(0,): RandomChoice(data=effected_sources, repetition=params.repeat)}
    else:
        raise ValueError(f"Invalid unique value: {params.unique}")

    # Loop over mixtures and add additional sources
    new_mixtures: list[Mixture] = []
    for mixture in mixtures:
        # The chooser depends only on the mixture, so resolve it once instead of once per SNR.
        key = self.get_mixture_speaker_ids(mixture) if params.unique == "speaker_id" else (0,)
        chooser = choice_objs[key]
        for snr in snrs:
            new_mixture = deepcopy(mixture)
            source, effect = chooser.next()
            new_mixture.all_sources[category] = Source(
                file_id=source.id,
                effects=effect,
                start=params.start,
                loop=params.loop,
                snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
            )
            new_mixtures.append(new_mixture)

    return new_mixtures
863
+
864
def _additional_sequence(
    self,
    effected_sources: list[tuple[SourceFile, Effects]],
    mixtures: list[Mixture],
    category: str,
    mix_rule: str,
) -> list[Mixture]:
    """Attach a ``category`` source, drawn with ``SequentialChoice``, to copies of each mixture.

    For every input mixture and every SNR configured for ``category``, the mixture is
    deep-copied and the next (source file, effects) pair from a ``SequentialChoice`` over
    ``effected_sources`` is stored under ``all_sources[category]``. The pair presumably comes
    in list order (see ``..utils.choice``). The source gets the directive's ``start``/``loop``
    and that SNR.

    With ``unique=speaker_id`` one sequence is built per distinct speaker-id tuple found in
    ``mixtures``. It only contains candidates whose speaker is not already in the mixture.

    :param effected_sources: Candidate (source file, effects) pairs for ``category``
    :param mixtures: Mixtures to extend (not modified)
    :param category: Source category key to add to ``all_sources``
    :param mix_rule: A ``sequence`` directive understood by ``parse_source_directive``
    :return: ``len(mixtures) * len(snrs)`` new mixtures
    :raises ValueError: On an unparsable directive, an unsupported ``unique`` value, or when
        speaker filtering leaves no candidates
    """
    from copy import deepcopy

    from ..datatypes import UniversalSNR
    from ..parse.parse_source_directive import parse_source_directive
    from ..utils.choice import SequentialChoice

    # Parse the mix rule
    try:
        params = parse_source_directive(mix_rule)
    except ValueError as e:
        raise ValueError(f"Error parsing sequence directive: {e}") from e

    snrs = self.all_snrs()[category]

    # Build every sequence up front (validating ``unique`` once) so construction order is unchanged.
    sequence_objs: dict[tuple[int, ...], SequentialChoice] = {}
    if params.unique == "speaker_id":
        # One sequence per distinct speaker-id tuple, restricted to speakers not already present.
        all_speaker_ids = {self.get_mixture_speaker_ids(mixture) for mixture in mixtures}
        for speaker_ids in all_speaker_ids:
            filtered_sources = _filter_sources(effected_sources, speaker_ids)
            if not filtered_sources:
                raise ValueError(
                    f"Additional source, {category}, has no valid entries for speaker_ids unique from {speaker_ids}"
                )
            sequence_objs[speaker_ids] = SequentialChoice(data=filtered_sources)
    elif params.unique is None:
        sequence_objs = {(0,): SequentialChoice(data=effected_sources)}
    else:
        raise ValueError(f"Invalid unique value: {params.unique}")

    # Loop over mixtures and add additional sources
    new_mixtures: list[Mixture] = []
    for mixture in mixtures:
        # The sequence depends only on the mixture, so resolve it once instead of once per SNR.
        key = self.get_mixture_speaker_ids(mixture) if params.unique == "speaker_id" else (0,)
        sequence = sequence_objs[key]
        for snr in snrs:
            new_mixture = deepcopy(mixture)
            source, effect = sequence.next()
            new_mixture.all_sources[category] = Source(
                file_id=source.id,
                effects=effect,
                start=params.start,
                loop=params.loop,
                snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
            )
            new_mixtures.append(new_mixture)

    return new_mixtures
926
+
927
def get_mixture_speaker_ids(self, mixture: Mixture) -> tuple[int, ...]:
    """Return the speaker IDs of the mixture's sources, in source order.

    Each source's file record is looked up via ``self.mixdb.source_file``. Sources whose
    record has ``speaker_id`` of ``None`` are omitted.
    """
    speaker_ids: list[int] = []
    for source in mixture.all_sources.values():
        speaker_id = self.mixdb.source_file(source.file_id).speaker_id
        if speaker_id is None:
            continue
        speaker_ids.append(speaker_id)
    return tuple(speaker_ids)
935
+
936
def all_snrs(self) -> dict[str, list[UniversalSNRGenerator]]:
    """Build fresh SNR generators for every non-primary source category.

    :return: Mapping of category name to one ``UniversalSNRGenerator`` per entry of that
        category's ``snrs`` list in ``self.config["sources"]``
    """
    sources_config = self.config["sources"]
    return {
        category: [UniversalSNRGenerator(snr) for snr in sources_config[category]["snrs"]]
        for category in sources_config
        if category != "primary"
    }
942
+
943
+
944
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
    """Fill in each source's ``pre_tempo`` and ``snr_gain``, optionally returning the audio.

    Every source is loaded and run through its effects (``_update_source``). Gains are then
    computed from the post-effect audio (``_initialize_mixture_gains``).

    :param mixdb: Mixture database used to read source audio and file metadata
    :param mixture: Mixture record to update (modified in place and returned)
    :param with_data: When True, also return the generated audio
    :return: The updated mixture and a ``GenMixData``. The ``GenMixData`` is empty unless
        ``with_data``. Otherwise ``sources`` holds the per-category pre-post-effect audio
        (ungained), and ``source``/``noise``/``mixture`` hold the gained post-effect sums.
    """
    sources_audio: SourcesAudioT = {}
    post_audio: SourcesAudioT = {}
    for category in mixture.all_sources:
        mixture, sources_audio[category], post_audio[category] = _update_source(mixdb, mixture, category)

    mixture = _initialize_mixture_gains(mixdb, mixture, post_audio)

    if not with_data:
        return mixture, GenMixData()

    gained_audio: SourcesAudioT = {}
    for category, source in mixture.all_sources.items():
        gained_audio[category] = post_audio[category] * source.snr_gain

    # ``mixture.sources`` names the non-noise categories that make up the target signal.
    source_audio = np.sum([gained_audio[category] for category in mixture.sources], axis=0)
    noise_audio = gained_audio["noise"]

    return mixture, GenMixData(
        sources=sources_audio,
        source=source_audio,
        noise=noise_audio,
        mixture=source_audio + noise_audio,
    )
972
+
973
+
974
def _update_source(mixdb: MixtureDatabase, mixture: Mixture, category: str) -> tuple[Mixture, AudioT, AudioT]:
    """Load one source's audio, apply its effects, and record its ``pre_tempo``.

    ``pre_tempo`` is stored as the ratio of raw length to length after the pre-truth
    effects. The pre-effect audio is then conformed to ``mixture.samples`` (honoring the
    source's ``loop`` and ``start``) before the post-truth effects are applied.

    :return: ``(mixture, pre_audio, post_audio)``, where ``mixture`` is the same object,
        updated in place
    :raises RuntimeError: If the post-truth effects change the audio length
    """
    from .effects import apply_effects
    from .effects import conform_audio_to_length

    source = mixture.all_sources[category]
    raw_audio = mixdb.read_source_audio(source.file_id)
    raw_length = len(raw_audio)

    effected_audio = apply_effects(mixdb, raw_audio, source.effects, pre=True, post=False)
    source.pre_tempo = raw_length / len(effected_audio)

    pre_audio = conform_audio_to_length(effected_audio, mixture.samples, source.loop, source.start)
    post_audio = apply_effects(mixdb, pre_audio, source.effects, pre=False, post=True)

    # Post-truth effects must be length-preserving so truth and audio stay aligned.
    if len(post_audio) != len(pre_audio):
        raise RuntimeError(f"post-truth effects changed length: {source.effects.post}")

    return mixture, pre_audio, post_audio
994
+
995
+
996
+ def _initialize_mixture_gains(mixdb: MixtureDatabase, mixture: Mixture, sources_audio: SourcesAudioT) -> Mixture:
997
+ from ..utils.asl_p56 import asl_p56
998
+ from ..utils.db import db_to_linear
999
+
1000
+ sources_energy: dict[str, float] = {}
1001
+ for category in mixture.all_sources:
1002
+ level_type = mixdb.source_file(mixture.all_sources[category].file_id).level_type
1003
+ match level_type:
1004
+ case "default":
1005
+ sources_energy[category] = float(np.mean(np.square(sources_audio[category])))
1006
+ case "speech":
1007
+ sources_energy[category] = asl_p56(sources_audio[category])
1008
+ case _:
1009
+ raise ValueError(f"Unknown level_type: {level_type}")
1010
+
1011
+ # Initialize all gains to 1
1012
+ for category in mixture.all_sources:
1013
+ mixture.all_sources[category].snr_gain = 1
1014
+
1015
+ # Resolve gains
1016
+ for category in mixture.all_sources:
1017
+ if mixture.is_noise_only and category != "noise":
1018
+ # Special case for zeroing out source data
1019
+ mixture.all_sources[category].snr_gain = 0
1020
+ elif mixture.is_source_only and category == "noise":
1021
+ # Special case for zeroing out noise data
1022
+ mixture.all_sources[category].snr_gain = 0
1023
+ elif category != "primary":
1024
+ if sources_energy["primary"] == 0 or sources_energy[category] == 0:
1025
+ # Avoid divide-by-zero
1026
+ mixture.all_sources[category].snr_gain = 1
1027
+ else:
1028
+ mixture.all_sources[category].snr_gain = float(
1029
+ np.sqrt(sources_energy["primary"] / sources_energy[category])
1030
+ ) / db_to_linear(mixture.all_sources[category].snr)
1031
+
1032
+ # Normalize gains
1033
+ max_snr_gain = max([source.snr_gain for source in mixture.all_sources.values()])
1034
+ for category in mixture.all_sources:
1035
+ mixture.all_sources[category].snr_gain = mixture.all_sources[category].snr_gain / max_snr_gain
1036
+
1037
+ # Check for clipping in mixture
1038
+ mixture_audio = np.sum(
1039
+ [sources_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources], axis=0
1040
+ )
1041
+ max_abs_audio = float(np.max(np.abs(mixture_audio)))
1042
+ clip_level = db_to_linear(-0.25)
1043
+ if max_abs_audio > clip_level:
1044
+ gain_adjustment = clip_level / max_abs_audio
1045
+ for category in mixture.all_sources:
1046
+ mixture.all_sources[category].snr_gain *= gain_adjustment
1047
+
1048
+ # To improve repeatability, round results
1049
+ for category in mixture.all_sources:
1050
+ mixture.all_sources[category].snr_gain = round(mixture.all_sources[category].snr_gain, ndigits=5)
1051
+
1052
+ return mixture
1053
+
1054
+
1055
class NextNoise:
    """Hand out consecutive, non-overlapping cuts from a rotating list of effected noises.

    Each :meth:`generate` call returns where the next ``length``-sample window starts in the
    current noise. When the current noise cannot fit another full window, the next noise in
    ``effected_noises`` is loaded (wrapping around after the last) and cutting restarts at
    its beginning.
    """

    def __init__(self, mixdb: MixtureDatabase, effected_noises: list[tuple[SourceFile, Effects]]) -> None:
        self.mixdb = mixdb
        self.effected_noises = effected_noises

        # Offset of the next unused sample in the current noise, and which noise is current.
        self.noise_start = 0
        self.noise_id = 0
        self._load_current_noise()

    @property
    def noise_file(self) -> SourceFile:
        """Source file of the current noise."""
        return self.effected_noises[self.noise_id][0]

    @property
    def noise_rule(self) -> Effects:
        """Effect rule of the current noise (resolved via ``effects_from_rules``)."""
        return self.effected_noises[self.noise_id][1]

    def _load_current_noise(self) -> None:
        """Resolve the current noise's effects and cache its estimated effected length."""
        from .effects import effects_from_rules
        from .effects import estimate_effected_length

        self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
        self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)

    def generate(self, length: int) -> tuple[int, Effects, int]:
        """Reserve the next ``length`` samples of noise.

        :param length: Number of samples needed
        :return: ``(noise_file_id, noise_effect, start)`` describing the reserved cut
        :raises ValueError: If ``length`` exceeds the effected length of the noise that has to
            hold the cut
        """
        if self.noise_start + length > self.noise_length:
            # Not enough samples in current noise
            if self.noise_start == 0:
                raise ValueError("Length of primary audio exceeds length of noise audio")

            # Move on to the next noise and start cutting from its beginning.
            self.noise_start = 0
            self.noise_id = (self.noise_id + 1) % len(self.effected_noises)
            self._load_current_noise()
            if length > self.noise_length:
                raise ValueError("Length of primary audio exceeds length of noise audio")

        # Reserve the window in every path (including right after a switch) so the next call
        # never reuses the same cut.
        noise_start = self.noise_start
        self.noise_start += length

        return self.noise_file.id, self.noise_effect, noise_start
1096
+
1097
+
1098
def _get_textgrid_tiers_from_source_file(file: str) -> list[str]:
    """Return the sorted tier names of the TextGrid next to a source file.

    The path is expanded with ``tokenized_expand`` and its suffix swapped for ``.TextGrid``.
    If no such file exists, an empty list is returned.
    """
    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded = tokenized_expand(file)[0]
    textgrid_path = Path(expanded).with_suffix(".TextGrid")

    if textgrid_path.exists():
        grid = textgrid.openTextgrid(str(textgrid_path), includeEmptyIntervals=False)
        return sorted(grid.tierNames)

    return []
1108
+
1109
+
1110
def _filter_sources(
    effected_sources: list[tuple[SourceFile, Effects]],
    speaker_id: tuple[int, ...],
) -> list[tuple[SourceFile, Effects]]:
    """Keep only the (source file, effects) pairs whose speaker is not in ``speaker_id``.

    :param effected_sources: Candidate pairs to filter
    :param speaker_id: Speaker IDs to exclude
    :return: New list of the surviving pairs, in their original order
    """
    kept: list[tuple[SourceFile, Effects]] = []
    for source_file, effects in effected_sources:
        if source_file.speaker_id in speaker_id:
            continue
        kept.append((source_file, effects))
    return kept