sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1116 @@
|
|
1
|
+
# ruff: noqa: S608
|
2
|
+
import json
|
3
|
+
from functools import partial
|
4
|
+
from os.path import join
|
5
|
+
from pathlib import Path
|
6
|
+
from random import choice
|
7
|
+
from random import randint
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import pandas as pd
|
11
|
+
import yaml
|
12
|
+
from praatio import textgrid
|
13
|
+
|
14
|
+
from .. import logger
|
15
|
+
from ..config.config import load_config
|
16
|
+
from ..config.ir import get_ir_files
|
17
|
+
from ..config.source import get_source_files
|
18
|
+
from ..constants import SAMPLE_BYTES
|
19
|
+
from ..constants import SAMPLE_RATE
|
20
|
+
from ..datatypes import AudioT
|
21
|
+
from ..datatypes import Effects
|
22
|
+
from ..datatypes import GenMixData
|
23
|
+
from ..datatypes import ImpulseResponseFile
|
24
|
+
from ..datatypes import Mixture
|
25
|
+
from ..datatypes import Source
|
26
|
+
from ..datatypes import SourceFile
|
27
|
+
from ..datatypes import SourcesAudioT
|
28
|
+
from ..datatypes import UniversalSNRGenerator
|
29
|
+
from ..utils.human_readable_size import human_readable_size
|
30
|
+
from ..utils.seconds_to_hms import seconds_to_hms
|
31
|
+
from .db import SQLiteDatabase
|
32
|
+
from .effects import get_effect_rules
|
33
|
+
from .mixdb import MixtureDatabase
|
34
|
+
|
35
|
+
|
36
|
+
def config_file(location: str) -> str:
    """Return the path of the mixture database config file under *location*."""
    name = "config.yml"
    return join(location, name)
|
38
|
+
|
39
|
+
|
40
|
+
# Database schema definition
# One CREATE TABLE statement per entry; DatabaseManager.__init__ executes
# them in order when creating a new mixture database.
DATABASE_SCHEMA = [
    # JSON-serialized truth function configuration, one row per distinct config.
    """
    CREATE TABLE truth_config(
    id INTEGER PRIMARY KEY NOT NULL,
    config TEXT NOT NULL)
    """,
    # Truth parameter counts keyed by (category, name).
    """
    CREATE TABLE truth_parameters(
    id INTEGER PRIMARY KEY NOT NULL,
    category TEXT NOT NULL,
    name TEXT NOT NULL,
    parameters INTEGER)
    """,
    # One row per source audio file; speaker_id is NULL when the file's
    # directory has no speaker metadata.
    """
    CREATE TABLE source_file (
    id INTEGER PRIMARY KEY NOT NULL,
    category TEXT NOT NULL,
    class_indices TEXT,
    level_type TEXT NOT NULL,
    name TEXT NOT NULL,
    samples INTEGER NOT NULL,
    speaker_id INTEGER,
    FOREIGN KEY(speaker_id) REFERENCES speaker (id))
    """,
    # One row per impulse response file.
    """
    CREATE TABLE ir_file (
    id INTEGER PRIMARY KEY NOT NULL,
    delay INTEGER NOT NULL,
    name TEXT NOT NULL)
    """,
    # Distinct impulse response tags.
    """
    CREATE TABLE ir_tag (
    id INTEGER PRIMARY KEY NOT NULL,
    tag TEXT NOT NULL UNIQUE)
    """,
    # Many-to-many link between IR files and IR tags.
    """
    CREATE TABLE ir_file_ir_tag (
    file_id INTEGER NOT NULL,
    tag_id INTEGER NOT NULL,
    FOREIGN KEY(file_id) REFERENCES ir_file (id),
    FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
    """,
    # Speaker metadata; additional columns are added at runtime from the
    # keys found in speaker.yml files (see _populate_speaker_table).
    """
    CREATE TABLE speaker (
    id INTEGER PRIMARY KEY NOT NULL,
    parent TEXT NOT NULL)
    """,
    # Single-row table of global database metadata (the row has id 1).
    """
    CREATE TABLE top (
    id INTEGER PRIMARY KEY NOT NULL,
    asr_configs TEXT NOT NULL,
    class_balancing BOOLEAN NOT NULL,
    feature TEXT NOT NULL,
    mixid_width INTEGER NOT NULL,
    num_classes INTEGER NOT NULL,
    seed INTEGER NOT NULL,
    speaker_metadata_tiers TEXT NOT NULL,
    textgrid_metadata_tiers TEXT NOT NULL,
    version INTEGER NOT NULL)
    """,
    # Class labels, one row per class.
    """
    CREATE TABLE class_label (
    id INTEGER PRIMARY KEY NOT NULL,
    label TEXT NOT NULL)
    """,
    # Class weight thresholds, one row per class.
    """
    CREATE TABLE class_weights_threshold (
    id INTEGER PRIMARY KEY NOT NULL,
    threshold FLOAT NOT NULL)
    """,
    # Spectral mask augmentation parameters.
    """
    CREATE TABLE spectral_mask (
    id INTEGER PRIMARY KEY NOT NULL,
    f_max_width INTEGER NOT NULL,
    f_num INTEGER NOT NULL,
    t_max_percent INTEGER NOT NULL,
    t_max_width INTEGER NOT NULL,
    t_num INTEGER NOT NULL)
    """,
    # Many-to-many link between source files and truth configs.
    """
    CREATE TABLE source_file_truth_config (
    source_file_id INTEGER NOT NULL,
    truth_config_id INTEGER NOT NULL,
    FOREIGN KEY(source_file_id) REFERENCES source_file (id),
    FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
    """,
    # One row per distinct source instance; the UNIQUE constraint allows
    # deduplication via INSERT OR IGNORE (see populate_mixture_table).
    """
    CREATE TABLE source (
    id INTEGER PRIMARY KEY NOT NULL,
    effects TEXT NOT NULL,
    file_id INTEGER NOT NULL,
    pre_tempo FLOAT NOT NULL,
    repeat BOOLEAN NOT NULL,
    snr FLOAT NOT NULL,
    snr_gain FLOAT NOT NULL,
    snr_random BOOLEAN NOT NULL,
    start INTEGER NOT NULL,
    UNIQUE(effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start),
    FOREIGN KEY(file_id) REFERENCES source_file (id))
    """,
    # One row per generated mixture.
    """
    CREATE TABLE mixture (
    id INTEGER PRIMARY KEY NOT NULL,
    name TEXT NOT NULL,
    samples INTEGER NOT NULL,
    spectral_mask_id INTEGER NOT NULL,
    spectral_mask_seed INTEGER NOT NULL,
    FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
    """,
    # Many-to-many link between mixtures and their sources.
    """
    CREATE TABLE mixture_source (
    mixture_id INTEGER NOT NULL,
    source_id INTEGER NOT NULL,
    FOREIGN KEY(mixture_id) REFERENCES mixture (id),
    FOREIGN KEY(source_id) REFERENCES source (id))
    """,
]
|
158
|
+
|
159
|
+
|
160
|
+
class DatabaseManager:
|
161
|
+
"""Manages database operations for mixture database generation."""
|
162
|
+
|
163
|
+
def __init__(self, location: str, test: bool = False, verbose: bool = False, logging: bool = False) -> None:
    """Create the mixture database file with an empty schema.

    :param location: Path of the mixture database directory.
    :param test: Operate on the test variant of the database.
    :param verbose: Enable verbose database access.
    :param logging: Emit progress/summary information via the package logger.
    """
    self.location = location
    self.test = test
    self.verbose = verbose
    self.logging = logging

    self.config = load_config(self.location)
    # Connection factory; callers pass create/readonly flags per use.
    self.db = partial(SQLiteDatabase, location=self.location, test=self.test, verbose=self.verbose)

    # Create all tables up front so the populate_* methods can assume they exist.
    with self.db(create=True) as c:
        for table_sql in DATABASE_SCHEMA:
            c.execute(table_sql)

    self.mixdb = MixtureDatabase(location=self.location, test=self.test)
|
177
|
+
|
178
|
+
def populate_top_table(self) -> None:
    """Insert the single row of global metadata (id 1) into the top table.

    The mixid_width, speaker_metadata_tiers, and textgrid_metadata_tiers
    columns start as placeholders and are filled in by later populate steps.
    """
    from .constants import MIXDB_VERSION

    row = (
        1,
        json.dumps(self.config["asr_configs"]),
        self.config["class_balancing"],
        self.config["feature"],
        0,
        self.config["num_classes"],
        self.config["seed"],
        "",
        "",
        MIXDB_VERSION,
    )

    with self.db(readonly=False) as c:
        c.execute(
            """
            INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
            seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
204
|
+
|
205
|
+
def populate_class_label_table(self) -> None:
    """Insert one row per configured class label into the class_label table."""
    rows = [(label,) for label in self.config["class_labels"]]
    with self.db(readonly=False) as c:
        c.executemany("INSERT INTO class_label (label) VALUES (?)", rows)
|
212
|
+
|
213
|
+
def populate_class_weights_threshold_table(self) -> None:
    """Insert one class-weights threshold row per class.

    A scalar (or single-element list) threshold in the config is broadcast
    to every class.

    :raises ValueError: If the configured list length does not match num_classes.
    """
    thresholds = self.config["class_weights_threshold"]
    num_classes = self.config["num_classes"]

    if not isinstance(thresholds, list):
        thresholds = [thresholds]
    if len(thresholds) == 1:
        thresholds = thresholds * num_classes

    if len(thresholds) != num_classes:
        raise ValueError(f"invalid class_weights_threshold length: {len(thresholds)}")

    with self.db(readonly=False) as c:
        c.executemany(
            "INSERT INTO class_weights_threshold (threshold) VALUES (?)",
            [(threshold,) for threshold in thresholds],
        )
|
232
|
+
|
233
|
+
def populate_spectral_mask_table(self) -> None:
    """Insert one row per configured spectral mask into the spectral_mask table."""
    from ..config.spectral_masks import get_spectral_masks

    rows = [
        (mask.f_max_width, mask.f_num, mask.t_max_percent, mask.t_max_width, mask.t_num)
        for mask in get_spectral_masks(self.config)
    ]
    with self.db(readonly=False) as c:
        c.executemany(
            """
            INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
            """,
            rows,
        )
|
253
|
+
|
254
|
+
def populate_truth_parameters_table(self) -> None:
    """Insert one row per configured truth parameter entry into truth_parameters."""
    from ..config.truth import get_truth_parameters

    rows = [
        (entry.category, entry.name, entry.parameters)
        for entry in get_truth_parameters(self.config)
    ]
    with self.db(readonly=False) as c:
        c.executemany(
            """
            INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
            """,
            rows,
        )
|
272
|
+
|
273
|
+
def populate_source_file_table(self, show_progress: bool = False) -> None:
    """Populate the source_file table and its truth-config/speaker links.

    Collects the configured source files, inserts one source_file row per
    file (linked to its truth configs and, when present, its speaker row),
    and records the TextGrid metadata tiers discovered across all files in
    the top table.

    :param show_progress: Show a progress bar while collecting sources.
    :raises RuntimeError: If no primary sources were found.
    """
    if self.logging:
        logger.info("Collecting sources")

    files = get_source_files(self.config, show_progress)
    if self.logging:
        # Fix: this blank line was previously logged unconditionally,
        # unlike every other log call in this method.
        logger.info("")

    if not any(file.category == "primary" for file in files):
        raise RuntimeError("Canceled due to no primary sources")

    if self.logging:
        logger.info("Populating source file table")

    # Truth configs and speakers must be inserted first so the lookups
    # below can resolve their ids.
    self._populate_truth_config_table(files)
    self._populate_speaker_table(files)

    with self.db(readonly=False) as c:
        textgrid_metadata_tiers: set[str] = set()
        for file in files:
            # Get TextGrid tiers for the source file and add to the collection
            tiers = _get_textgrid_tiers_from_source_file(file.name)
            for tier in tiers:
                textgrid_metadata_tiers.add(tier)

            # Resolve the ids of this file's truth configs (inserted above).
            truth_config_ids: list[int] = []
            if file.truth_configs:
                for name, config in file.truth_configs.items():
                    ts = json.dumps({"name": name} | config.to_dict())
                    c.execute(
                        "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
                        (ts,),
                    )
                    truth_config_ids.append(c.fetchone()[0])

            # Get speaker_id for the source file; NULL when the file's parent
            # directory has no speaker entry.
            c.execute(
                "SELECT speaker.id FROM speaker WHERE ? = speaker.parent", (Path(file.name).parent.as_posix(),)
            )
            result = c.fetchone()
            speaker_id = None if result is None else result[0]

            # Add entry
            c.execute(
                """
                INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    file.category,
                    json.dumps(file.class_indices),
                    file.level_type,
                    file.name,
                    file.samples,
                    speaker_id,
                ),
            )
            source_file_id = c.lastrowid
            for truth_config_id in truth_config_ids:
                c.execute(
                    "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
                    (source_file_id, truth_config_id),
                )

        # Update textgrid_metadata_tiers in the top table (single row, id 1)
        c.execute(
            "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = id",
            (json.dumps(sorted(textgrid_metadata_tiers)), 1),
        )

    if self.logging:
        logger.info("Sources summary")
        data = {
            "category": [],
            "files": [],
            "size": [],
            "duration": [],
        }
        # Fix: use a distinct loop variable so the outer `files` list is not
        # shadowed (it is still referenced conceptually above).
        for category, category_files in self.mixdb.source_files.items():
            audio_samples = sum(source.samples for source in category_files)
            audio_duration = audio_samples / SAMPLE_RATE
            data["category"].append(category)
            data["files"].append(self.mixdb.num_source_files(category))
            data["size"].append(human_readable_size(audio_samples * SAMPLE_BYTES, 1))
            data["duration"].append(seconds_to_hms(seconds=audio_duration))

        df = pd.DataFrame(data)
        logger.info(df.to_string(index=False, header=False))
        logger.info("")

        for category, category_files in self.mixdb.source_files.items():
            logger.debug(f"List of {category} sources:")
            logger.debug(yaml.dump([file.name for file in category_files], default_flow_style=False))
|
369
|
+
|
370
|
+
def populate_impulse_response_file_table(self, show_progress: bool = False) -> None:
    """Populate the ir_file table and its tag links.

    :param show_progress: Show a progress bar while collecting impulse responses.
    """
    if self.logging:
        logger.info("Collecting impulse responses")

    files = get_ir_files(self.config, show_progress=show_progress)
    if self.logging:
        # Fix: this blank line was previously logged unconditionally,
        # unlike every other log call in this method.
        logger.info("")

    if self.logging:
        logger.info("Populating impulse response file table")

    # Tags must exist before ir_file rows can be linked to them.
    self._populate_impulse_response_tag_table(files)

    with self.db(readonly=False) as c:
        for file in files:
            # Map this file's tag names to their ir_tag ids.
            tag_ids: list[int] = []
            for tag in file.tags:
                c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,))
                tag_ids.append(c.fetchone()[0])

            c.execute("INSERT INTO ir_file (delay, name) VALUES (?, ?)", (file.delay, file.name))

            file_id = c.lastrowid
            for tag_id in tag_ids:
                c.execute("INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)", (file_id, tag_id))

    if self.logging:
        logger.debug("List of impulse responses:")
        for idx, file in enumerate(files):
            logger.debug(f"id: {idx}, name:{file.name}, delay: {file.delay}, tags: [{', '.join(file.tags)}]")
        logger.debug("")
|
402
|
+
|
403
|
+
def populate_mixture_table(self, mixtures: list[Mixture], show_progress: bool = False) -> None:
    """Populate the mixture table

    Inserts one mixture row per generated mixture plus its sources, linking
    them through the mixture_source join table.

    :param mixtures: Mixtures to insert. Each mixture's name is converted
        with int(), so it is assumed to be its 0-based index as a string
        (TODO confirm against the caller that assigns names).
    :param show_progress: Show a progress bar while inserting.
    """
    from ..utils.parallel import track
    from .helpers import from_mixture
    from .helpers import from_source

    if self.logging:
        logger.info("Populating mixture and source tables")

    with self.db(readonly=False) as c:
        # Populate the source table
        for mixture in track(mixtures, disable=not show_progress):
            # Mixture ids are 1-based in the database.
            m_id = int(mixture.name) + 1
            c.execute(
                """
                INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
                VALUES (?, ?, ?, ?, ?)
                """,
                (m_id, *from_mixture(mixture)),
            )

            for source in mixture.all_sources.values():
                # Sources are deduplicated by the UNIQUE constraint on the
                # source table; INSERT OR IGNORE skips existing rows.
                c.execute(
                    """
                    INSERT OR IGNORE INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    from_source(source),
                )

                # Re-query for the row id because the insert may have been
                # ignored (lastrowid would be stale for duplicates).
                source_id = c.execute(
                    """
                    SELECT id
                    FROM source
                    WHERE ? = effects
                    AND ? = file_id
                    AND ? = pre_tempo
                    AND ? = repeat
                    AND ? = snr
                    AND ? = snr_gain
                    AND ? = snr_random
                    AND ? = start
                    """,
                    from_source(source),
                ).fetchone()[0]
                c.execute("INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)", (m_id, source_id))

    if self.logging:
        logger.info("Closing mixture and source tables")
|
452
|
+
|
453
|
+
def _populate_speaker_table(self, source_files: list[SourceFile]) -> None:
    """Populate the speaker table"""
    from ..utils.tokenized_shell_vars import tokenized_expand

    # Determine the columns for the speaker table: speaker metadata comes
    # from a speaker.yml next to the source files, so only parent
    # directories containing one contribute rows.
    all_parents = {Path(file.name).parent for file in source_files}
    speaker_parents = (
        parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists()
    )

    speakers: dict[Path, dict[str, str]] = {}
    for parent in sorted(speaker_parents):
        with open(tokenized_expand(parent / "speaker.yml")[0]) as f:
            speakers[parent] = yaml.safe_load(f)

    # The union of keys across all speaker.yml files becomes the set of
    # dynamically added columns.
    new_columns: list[str] = []
    for keys in speakers:
        for column in speakers[keys]:
            new_columns.append(column)
    new_columns = sorted(set(new_columns))

    with self.db(readonly=False) as c:
        for new_column in new_columns:
            # NOTE(review): column names are interpolated from speaker.yml
            # content; this assumes trusted configuration data.
            c.execute(f"ALTER TABLE speaker ADD COLUMN {new_column} TEXT")

        # Populate the speaker table; keys missing from a given speaker.yml
        # yield NULL in the corresponding column.
        speaker_rows: list[tuple[str, ...]] = []
        for key in speakers:
            entry = (speakers[key].get(column, None) for column in new_columns)
            speaker_rows.append((key.as_posix(), *entry))  # type: ignore[arg-type]

        column_ids = ", ".join(["parent", *new_columns])
        column_values = ", ".join(["?"] * (len(new_columns) + 1))
        c.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)

        c.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")

        # Update speaker_metadata_tiers in the top table; cursor.description
        # after SELECT * reveals the dynamically added columns.
        tiers = [
            description[0]
            for description in c.execute("SELECT * FROM speaker").description
            if description[0] not in ("id", "parent")
        ]
        c.execute("UPDATE top SET speaker_metadata_tiers=? WHERE ? = id", (json.dumps(tiers), 1))

        if "speaker_id" in tiers:
            c.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
|
500
|
+
|
501
|
+
def _populate_truth_config_table(self, source_files: list[SourceFile]) -> None:
    """Populate the truth_config table with the distinct truth configs of the given files.

    :param source_files: Source files whose truth configs are serialized
        (as JSON, with the config name folded in) and inserted once each.
    """
    # Deduplicate while preserving first-seen order; a dict avoids the
    # O(n^2) list-membership scan of the previous implementation.
    truth_configs: dict[str, None] = {}
    for file in source_files:
        for name, config in file.truth_configs.items():
            ts = json.dumps({"name": name} | config.to_dict())
            truth_configs.setdefault(ts, None)

    with self.db(readonly=False) as c:
        c.executemany(
            "INSERT INTO truth_config (config) VALUES (?)",
            [(item,) for item in truth_configs],
        )
|
515
|
+
|
516
|
+
def _populate_impulse_response_tag_table(self, files: list[ImpulseResponseFile]) -> None:
    """Populate the ir_tag table with the distinct tags across all IR files."""
    unique_tags = {tag for file in files for tag in file.tags}
    rows = [(tag,) for tag in unique_tags]
    with self.db(readonly=False) as c:
        c.executemany("INSERT INTO ir_tag (tag) VALUES (?)", rows)
|
523
|
+
|
524
|
+
def generate_mixtures(self) -> list[Mixture]:
    """Generate mixtures

    Builds the full list of mixtures from the effect rules and mix rules in
    the configuration: first primary+noise combinations, then one pass per
    additional source category. Also records the mixid width in the top table.

    :return: The generated (not yet inserted) mixtures.
    """
    from ..utils.max_text_width import max_text_width

    if self.logging:
        logger.info("Collecting effects")

    rules = get_effect_rules(self.location, self.config, self.test)

    if self.logging:
        logger.info("")
        for category, effect in rules.items():
            logger.debug(f"List of {category} rules:")
            logger.debug(yaml.dump([entry.to_dict() for entry in effect], default_flow_style=False))

    if self.logging:
        logger.debug("SNRS:")
        for category, source in self.config["sources"].items():
            if category != "primary":
                logger.debug(f"  {category}")
                for snr in source["snrs"]:
                    logger.debug(f"  - {snr}")
        logger.debug("")
        logger.debug("Mix Rules:")
        for category, source in self.config["sources"].items():
            if category != "primary":
                logger.debug(f"  {category}")
                for mix_rule in source["mix_rules"]:
                    logger.debug(f"  - {mix_rule}")
        logger.debug("")
        logger.debug("Spectral masks:")
        for spectral_mask in self.mixdb.spectral_masks:
            logger.debug(f"- {spectral_mask}")
        logger.debug("")

    if self.logging:
        logger.info("Generating mixtures")

    # Pair every source file with every effect rule for its category.
    effected_sources: dict[str, list[tuple[SourceFile, Effects]]] = {}
    for category in self.mixdb.source_files:
        effected_sources[category] = []
        for file in self.mixdb.source_files[category]:
            for rule in rules[category]:
                effected_sources[category].append((file, rule))

    # First, create mixtures of primary and noise
    mixtures: list[Mixture] = []
    for mix_rule in self.config["sources"]["noise"]["mix_rules"]:
        mixtures.extend(
            self._process_noise_sources(
                primary_effected_sources=effected_sources["primary"],
                noise_effected_sources=effected_sources["noise"],
                mix_rule=mix_rule,
            )
        )

    # Next, cycle through any additional sources and apply mix rules for each
    additional_sources = [cat for cat in self.mixdb.source_files if cat not in ("primary", "noise")]
    for category in additional_sources:
        new_mixtures: list[Mixture] = []
        for mix_rule in self.config["sources"][category]["mix_rules"]:
            new_mixtures.extend(
                self._process_additional_sources(
                    effected_sources=effected_sources[category],
                    mixtures=mixtures,
                    category=category,
                    mix_rule=mix_rule,
                )
            )
        mixtures.extend(new_mixtures)

    # Update the mixid width (number of digits needed to render mixture ids)
    with self.db(readonly=False) as c:
        c.execute("UPDATE top SET mixid_width=? WHERE ? = id", (max_text_width(len(mixtures)), 1))

    return mixtures
|
600
|
+
|
601
|
+
def _process_noise_sources(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
    mix_rule: str,
) -> list[Mixture]:
    """Dispatch to the noise mixing strategy selected by mix_rule.

    :raises ValueError: If mix_rule is not a recognized noise mix rule.
    """
    handlers = {
        "exhaustive": self._noise_exhaustive,
        "non-exhaustive": self._noise_non_exhaustive,
        "non-combinatorial": self._noise_non_combinatorial,
    }
    handler = handlers.get(mix_rule)
    if handler is None:
        raise ValueError(f"invalid noise mix_rule: {mix_rule}")
    return handler(primary_effected_sources, noise_effected_sources)
|
616
|
+
|
617
|
+
def _noise_exhaustive(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
) -> list[Mixture]:
    """Use every noise/effect with every source/effect+interferences/effect"""
    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    snrs = self.all_snrs()

    mixtures: list[Mixture] = []
    for noise_file, noise_rule in noise_effected_sources:
        # Noise is consumed as a rolling cut: each mixture starts where the
        # previous one left off, wrapping at the effected noise length.
        noise_start = 0
        noise_effect = effects_from_rules(self.mixdb, noise_rule)
        noise_length = estimate_effected_length(noise_file.samples, noise_effect)

        for primary_file, primary_rule in primary_effected_sources:
            primary_effect = effects_from_rules(self.mixdb, primary_rule)
            # Primary length is aligned to feature_step_samples.
            primary_length = estimate_effected_length(
                primary_file.samples, primary_effect, self.mixdb.feature_step_samples
            )

            for spectral_mask_id in range(len(self.config["spectral_masks"])):
                for snr in snrs["noise"]:
                    mixtures.append(
                        Mixture(
                            name="",
                            all_sources={
                                "primary": Source(
                                    file_id=primary_file.id,
                                    effects=primary_effect,
                                ),
                                "noise": Source(
                                    file_id=noise_file.id,
                                    effects=noise_effect,
                                    start=noise_start,
                                    loop=True,
                                    snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                                ),
                            },
                            samples=primary_length,
                            # spectral_mask ids are 1-based in the database
                            spectral_mask_id=spectral_mask_id + 1,
                            spectral_mask_seed=randint(0, np.iinfo("i").max),  # noqa: S311
                        )
                    )
                    noise_start = int((noise_start + primary_length) % noise_length)

    return mixtures
|
668
|
+
|
669
|
+
def _noise_non_exhaustive(
    self,
    primary_effected_sources: list[tuple[SourceFile, Effects]],
    noise_effected_sources: list[tuple[SourceFile, Effects]],
) -> list[Mixture]:
    """Cycle through every source/effect+interferences/effect without necessarily using all
    noise/effect combinations (reduced data set).
    """
    from ..datatypes import Mixture
    from ..datatypes import UniversalSNR
    from .effects import effects_from_rules
    from .effects import estimate_effected_length

    snrs = self.all_snrs()

    # NextNoise hands out successive noise cuts as samples are consumed.
    next_noise = NextNoise(self.mixdb, noise_effected_sources)

    mixtures: list[Mixture] = []
    for primary_file, primary_rule in primary_effected_sources:
        primary_effect = effects_from_rules(self.mixdb, primary_rule)
        # Primary length is aligned to feature_step_samples.
        primary_length = estimate_effected_length(
            primary_file.samples, primary_effect, self.mixdb.feature_step_samples
        )

        for spectral_mask_id in range(len(self.config["spectral_masks"])):
            for snr in snrs["noise"]:
                noise_file_id, noise_effect, noise_start = next_noise.generate(primary_file.samples)

                mixtures.append(
                    Mixture(
                        name="",
                        all_sources={
                            "primary": Source(
                                file_id=primary_file.id,
                                effects=primary_effect,
                            ),
                            "noise": Source(
                                file_id=noise_file_id,
                                effects=noise_effect,
                                start=noise_start,
                                loop=True,
                                snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
                            ),
                        },
                        samples=primary_length,
                        # spectral_mask ids are 1-based in the database
                        spectral_mask_id=spectral_mask_id + 1,
                        spectral_mask_seed=randint(0, np.iinfo("i").max),  # noqa: S311
                    )
                )

    return mixtures
|
720
|
+
|
721
|
+
def _noise_non_combinatorial(
|
722
|
+
self,
|
723
|
+
primary_effected_sources: list[tuple[SourceFile, Effects]],
|
724
|
+
noise_effected_sources: list[tuple[SourceFile, Effects]],
|
725
|
+
) -> list[Mixture]:
|
726
|
+
"""Combine a source/effect+interferences/effect with a single cut of a noise/effect
|
727
|
+
non-exhaustively (each source/effect+interferences/effect does not use each noise/effect).
|
728
|
+
Cut has a random start and loop back to the beginning if the end of noise/effect is reached.
|
729
|
+
"""
|
730
|
+
from ..datatypes import Mixture
|
731
|
+
from ..datatypes import UniversalSNR
|
732
|
+
from .effects import effects_from_rules
|
733
|
+
from .effects import estimate_effected_length
|
734
|
+
|
735
|
+
snrs = self.all_snrs()
|
736
|
+
|
737
|
+
noise_id = 0
|
738
|
+
mixtures: list[Mixture] = []
|
739
|
+
for primary_file, primary_rule in primary_effected_sources:
|
740
|
+
primary_effect = effects_from_rules(self.mixdb, primary_rule)
|
741
|
+
primary_length = estimate_effected_length(
|
742
|
+
primary_file.samples, primary_effect, self.mixdb.feature_step_samples
|
743
|
+
)
|
744
|
+
|
745
|
+
for spectral_mask_id in range(len(self.config["spectral_masks"])):
|
746
|
+
for snr in snrs["noise"]:
|
747
|
+
noise_file, noise_rule = noise_effected_sources[noise_id]
|
748
|
+
noise_effect = effects_from_rules(self.mixdb, noise_rule)
|
749
|
+
noise_length = estimate_effected_length(noise_file.samples, noise_effect)
|
750
|
+
|
751
|
+
mixtures.append(
|
752
|
+
Mixture(
|
753
|
+
name="",
|
754
|
+
all_sources={
|
755
|
+
"primary": Source(
|
756
|
+
file_id=primary_file.id,
|
757
|
+
effects=primary_effect,
|
758
|
+
),
|
759
|
+
"noise": Source(
|
760
|
+
file_id=noise_file.id,
|
761
|
+
effects=noise_effect,
|
762
|
+
start=choice(range(noise_length)), # noqa: S311
|
763
|
+
loop=True,
|
764
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
765
|
+
),
|
766
|
+
},
|
767
|
+
samples=primary_length,
|
768
|
+
spectral_mask_id=spectral_mask_id + 1,
|
769
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
770
|
+
)
|
771
|
+
)
|
772
|
+
noise_id = (noise_id + 1) % len(noise_effected_sources)
|
773
|
+
|
774
|
+
return mixtures
|
775
|
+
|
776
|
+
def _process_additional_sources(
|
777
|
+
self,
|
778
|
+
effected_sources: list[tuple[SourceFile, Effects]],
|
779
|
+
mixtures: list[Mixture],
|
780
|
+
category: str,
|
781
|
+
mix_rule: str,
|
782
|
+
) -> list[Mixture]:
|
783
|
+
if mix_rule == "none":
|
784
|
+
return []
|
785
|
+
if mix_rule.startswith("choose"):
|
786
|
+
return self._additional_choose(
|
787
|
+
effected_sources=effected_sources,
|
788
|
+
mixtures=mixtures,
|
789
|
+
category=category,
|
790
|
+
mix_rule=mix_rule,
|
791
|
+
)
|
792
|
+
if mix_rule.startswith("sequence"):
|
793
|
+
return self._additional_sequence(
|
794
|
+
effected_sources=effected_sources,
|
795
|
+
mixtures=mixtures,
|
796
|
+
category=category,
|
797
|
+
mix_rule=mix_rule,
|
798
|
+
)
|
799
|
+
raise ValueError(f"invalid {category} mix_rule: {mix_rule}")
|
800
|
+
|
801
|
+
def _additional_choose(
|
802
|
+
self,
|
803
|
+
effected_sources: list[tuple[SourceFile, Effects]],
|
804
|
+
mixtures: list[Mixture],
|
805
|
+
category: str,
|
806
|
+
mix_rule: str,
|
807
|
+
) -> list[Mixture]:
|
808
|
+
from copy import deepcopy
|
809
|
+
|
810
|
+
from ..datatypes import UniversalSNR
|
811
|
+
from ..parse.parse_source_directive import parse_source_directive
|
812
|
+
from ..utils.choice import RandomChoice
|
813
|
+
|
814
|
+
# Parse the mix rule
|
815
|
+
try:
|
816
|
+
params = parse_source_directive(mix_rule)
|
817
|
+
except ValueError as e:
|
818
|
+
raise ValueError(f"Error parsing choose directive: {e}") from e
|
819
|
+
|
820
|
+
snrs = self.all_snrs()[category]
|
821
|
+
|
822
|
+
choice_objs: dict[tuple[int, ...], RandomChoice] = {}
|
823
|
+
if params.unique == "speaker_id" :
|
824
|
+
# Get a set of speaker_id values in use in the existing mixtures.
|
825
|
+
all_speaker_ids = {self.get_mixture_speaker_ids(mixture) for mixture in mixtures}
|
826
|
+
|
827
|
+
# Create a set of RandomChoice objects that are filtered on those values.
|
828
|
+
for speaker_ids in all_speaker_ids:
|
829
|
+
filtered_sources = _filter_sources(effected_sources, speaker_ids)
|
830
|
+
if not filtered_sources:
|
831
|
+
raise ValueError(
|
832
|
+
f"Additional source, {category}, has no valid entries for speaker_ids unique from {speaker_ids}"
|
833
|
+
)
|
834
|
+
choice_objs[speaker_ids] = RandomChoice(data=filtered_sources, repetition=params.repeat)
|
835
|
+
elif params.unique is None:
|
836
|
+
choice_objs = {(0,): RandomChoice(data=effected_sources, repetition=params.repeat)}
|
837
|
+
else:
|
838
|
+
raise ValueError(f"Invalid unique value: {params.unique}")
|
839
|
+
|
840
|
+
# Loop over mixtures and add additional sources
|
841
|
+
new_mixtures: list[Mixture] = []
|
842
|
+
for mixture in mixtures:
|
843
|
+
for snr in snrs:
|
844
|
+
new_mixture = deepcopy(mixture)
|
845
|
+
if params.unique == "speaker_id" :
|
846
|
+
speaker_ids = self.get_mixture_speaker_ids(mixture)
|
847
|
+
elif params.unique is None:
|
848
|
+
speaker_ids = (0,)
|
849
|
+
else:
|
850
|
+
raise ValueError(f"Invalid unique value: {params.unique}")
|
851
|
+
source, effect = choice_objs[speaker_ids].next()
|
852
|
+
new_source = Source(
|
853
|
+
file_id=source.id,
|
854
|
+
effects=effect,
|
855
|
+
start=params.start,
|
856
|
+
loop=params.loop,
|
857
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
858
|
+
)
|
859
|
+
new_mixture.all_sources[category] = new_source
|
860
|
+
new_mixtures.append(new_mixture)
|
861
|
+
|
862
|
+
return new_mixtures
|
863
|
+
|
864
|
+
def _additional_sequence(
|
865
|
+
self,
|
866
|
+
effected_sources: list[tuple[SourceFile, Effects]],
|
867
|
+
mixtures: list[Mixture],
|
868
|
+
category: str,
|
869
|
+
mix_rule: str,
|
870
|
+
) -> list[Mixture]:
|
871
|
+
from copy import deepcopy
|
872
|
+
|
873
|
+
from ..datatypes import UniversalSNR
|
874
|
+
from ..parse.parse_source_directive import parse_source_directive
|
875
|
+
from ..utils.choice import SequentialChoice
|
876
|
+
|
877
|
+
# Parse the mix rule
|
878
|
+
try:
|
879
|
+
params = parse_source_directive(mix_rule)
|
880
|
+
except ValueError as e:
|
881
|
+
raise ValueError(f"Error parsing choose directive: {e}") from e
|
882
|
+
|
883
|
+
snrs = self.all_snrs()[category]
|
884
|
+
|
885
|
+
sequence_objs: dict[tuple[int, ...], SequentialChoice] = {}
|
886
|
+
if params.unique == "speaker_id" :
|
887
|
+
# Get a set of speaker_id values in use in the existing mixtures.
|
888
|
+
all_speaker_ids = {self.get_mixture_speaker_ids(mixture) for mixture in mixtures}
|
889
|
+
|
890
|
+
# Create a set of SequentialChoice objects that are filtered on those values.
|
891
|
+
for speaker_ids in all_speaker_ids:
|
892
|
+
filtered_sources = _filter_sources(effected_sources, speaker_ids)
|
893
|
+
if not filtered_sources:
|
894
|
+
raise ValueError(
|
895
|
+
f"Additional source, {category}, has no valid entries for speaker_ids unique from {speaker_ids}"
|
896
|
+
)
|
897
|
+
sequence_objs[speaker_ids] = SequentialChoice(data=filtered_sources)
|
898
|
+
elif params.unique is None:
|
899
|
+
sequence_objs = {(0,): SequentialChoice(data=effected_sources)}
|
900
|
+
else:
|
901
|
+
raise ValueError(f"Invalid unique value: {params.unique}")
|
902
|
+
|
903
|
+
# Loop over mixtures and add additional sources
|
904
|
+
new_mixtures: list[Mixture] = []
|
905
|
+
for mixture in mixtures:
|
906
|
+
for snr in snrs:
|
907
|
+
new_mixture = deepcopy(mixture)
|
908
|
+
if params.unique == "speaker_id" :
|
909
|
+
speaker_ids = self.get_mixture_speaker_ids(mixture)
|
910
|
+
elif params.unique is None:
|
911
|
+
speaker_ids = (0,)
|
912
|
+
else:
|
913
|
+
raise ValueError(f"Invalid unique value: {params.unique}")
|
914
|
+
source, effect = sequence_objs[speaker_ids].next()
|
915
|
+
new_source = Source(
|
916
|
+
file_id=source.id,
|
917
|
+
effects=effect,
|
918
|
+
start=params.start,
|
919
|
+
loop=params.loop,
|
920
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
921
|
+
)
|
922
|
+
new_mixture.all_sources[category] = new_source
|
923
|
+
new_mixtures.append(new_mixture)
|
924
|
+
|
925
|
+
return new_mixtures
|
926
|
+
|
927
|
+
def get_mixture_speaker_ids(self, mixture: Mixture) -> tuple[int, ...]:
|
928
|
+
"""Get the speaker IDs used in a mixture, excluding None values"""
|
929
|
+
valid_speaker_ids = [
|
930
|
+
speaker_id
|
931
|
+
for source in mixture.all_sources.values()
|
932
|
+
if (speaker_id := self.mixdb.source_file(source.file_id).speaker_id) is not None
|
933
|
+
]
|
934
|
+
return tuple(valid_speaker_ids)
|
935
|
+
|
936
|
+
def all_snrs(self) -> dict[str, list[UniversalSNRGenerator]]:
|
937
|
+
snrs: dict[str, list[UniversalSNRGenerator]] = {}
|
938
|
+
for category in self.config["sources"]:
|
939
|
+
if category != "primary":
|
940
|
+
snrs[category] = [UniversalSNRGenerator(snr) for snr in self.config["sources"][category]["snrs"]]
|
941
|
+
return snrs
|
942
|
+
|
943
|
+
|
944
|
+
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
    """Update mixture record with name, samples, and gains"""
    sources_audio: SourcesAudioT = {}
    post_audio: SourcesAudioT = {}
    for category in mixture.all_sources:
        mixture, pre, post = _update_source(mixdb, mixture, category)
        sources_audio[category] = pre
        post_audio[category] = post

    mixture = _initialize_mixture_gains(mixdb, mixture, post_audio)

    if not with_data:
        return mixture, GenMixData()

    # Scale each post-effect source by its resolved SNR gain.
    scaled = {cat: post_audio[cat] * mixture.all_sources[cat].snr_gain for cat in mixture.all_sources}

    # Sum the non-noise sources, then add the noise to form the mixture.
    source_audio = np.sum([scaled[cat] for cat in mixture.sources], axis=0)
    noise_audio = scaled["noise"]
    mixture_audio = source_audio + noise_audio

    return mixture, GenMixData(
        sources=sources_audio,
        source=source_audio,
        noise=noise_audio,
        mixture=mixture_audio,
    )
|
972
|
+
|
973
|
+
|
974
|
+
def _update_source(mixdb: MixtureDatabase, mixture: Mixture, category: str) -> tuple[Mixture, AudioT, AudioT]:
    """Read, effect, and length-conform the audio for one source category of a mixture.

    Returns the (updated) mixture along with the pre-truth and post-truth effected
    audio for the category. Updates the source's pre_tempo in place.
    """
    from .effects import apply_effects
    from .effects import conform_audio_to_length

    source = mixture.all_sources[category]
    org_audio = mixdb.read_source_audio(source.file_id)

    org_samples = len(org_audio)
    # Apply only the pre-truth effects first.
    pre_audio = apply_effects(mixdb, org_audio, source.effects, pre=True, post=False)

    # Record how much the pre-truth effects changed the length (original/effected ratio).
    pre_samples = len(pre_audio)
    mixture.all_sources[category].pre_tempo = org_samples / pre_samples

    # Pad/trim/loop the effected audio to the mixture length.
    pre_audio = conform_audio_to_length(pre_audio, mixture.samples, source.loop, source.start)

    # Post-truth effects must not change the number of samples.
    post_audio = apply_effects(mixdb, pre_audio, source.effects, pre=False, post=True)
    if len(pre_audio) != len(post_audio):
        raise RuntimeError(f"post-truth effects changed length: {source.effects.post}")

    return mixture, pre_audio, post_audio
|
994
|
+
|
995
|
+
|
996
|
+
def _initialize_mixture_gains(mixdb: MixtureDatabase, mixture: Mixture, sources_audio: SourcesAudioT) -> Mixture:
    """Resolve the per-source snr_gain values of a mixture from the sources' energies.

    Gains are computed so each non-primary source sits at its requested SNR relative
    to primary, then normalized, reduced if the summed mixture would clip, and
    rounded for repeatability. Returns the updated mixture.
    """
    from ..utils.asl_p56 import asl_p56
    from ..utils.db import db_to_linear

    # Measure each source's energy using its configured level method.
    sources_energy: dict[str, float] = {}
    for category in mixture.all_sources:
        level_type = mixdb.source_file(mixture.all_sources[category].file_id).level_type
        match level_type:
            case "default":
                # Mean-square energy over all samples.
                sources_energy[category] = float(np.mean(np.square(sources_audio[category])))
            case "speech":
                # Active speech level per ITU-T P.56.
                sources_energy[category] = asl_p56(sources_audio[category])
            case _:
                raise ValueError(f"Unknown level_type: {level_type}")

    # Initialize all gains to 1
    for category in mixture.all_sources:
        mixture.all_sources[category].snr_gain = 1

    # Resolve gains
    for category in mixture.all_sources:
        if mixture.is_noise_only and category != "noise":
            # Special case for zeroing out source data
            mixture.all_sources[category].snr_gain = 0
        elif mixture.is_source_only and category == "noise":
            # Special case for zeroing out noise data
            mixture.all_sources[category].snr_gain = 0
        elif category != "primary":
            if sources_energy["primary"] == 0 or sources_energy[category] == 0:
                # Avoid divide-by-zero
                mixture.all_sources[category].snr_gain = 1
            else:
                # Amplitude ratio that levels this source to primary, then offset by the SNR.
                mixture.all_sources[category].snr_gain = float(
                    np.sqrt(sources_energy["primary"] / sources_energy[category])
                ) / db_to_linear(mixture.all_sources[category].snr)

    # Normalize gains
    max_snr_gain = max([source.snr_gain for source in mixture.all_sources.values()])
    for category in mixture.all_sources:
        mixture.all_sources[category].snr_gain = mixture.all_sources[category].snr_gain / max_snr_gain

    # Check for clipping in mixture
    mixture_audio = np.sum(
        [sources_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources], axis=0
    )
    max_abs_audio = float(np.max(np.abs(mixture_audio)))
    # Headroom target: -0.25 dBFS.
    clip_level = db_to_linear(-0.25)
    if max_abs_audio > clip_level:
        # Scale all sources down uniformly so the mixture peaks at the clip level.
        gain_adjustment = clip_level / max_abs_audio
        for category in mixture.all_sources:
            mixture.all_sources[category].snr_gain *= gain_adjustment

    # To improve repeatability, round results
    for category in mixture.all_sources:
        mixture.all_sources[category].snr_gain = round(mixture.all_sources[category].snr_gain, ndigits=5)

    return mixture
|
1053
|
+
|
1054
|
+
|
1055
|
+
class NextNoise:
    """Stateful generator handing out successive, non-overlapping cuts of noise.

    Walks through the effected noise files in order, advancing the start offset by
    the requested length on each call and moving to the next noise file when the
    current one cannot supply a full cut.
    """

    def __init__(self, mixdb: MixtureDatabase, effected_noises: list[tuple[SourceFile, Effects]]) -> None:
        from .effects import effects_from_rules
        from .effects import estimate_effected_length

        self.mixdb = mixdb
        self.effected_noises = effected_noises

        # Offset into the current noise where the next cut begins.
        self.noise_start = 0
        # Index of the current (noise file, rule) pair.
        self.noise_id = 0
        self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
        self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)

    @property
    def noise_file(self):
        """Source file of the current noise."""
        return self.effected_noises[self.noise_id][0]

    @property
    def noise_rule(self):
        """Effect rule of the current noise."""
        return self.effected_noises[self.noise_id][1]

    def generate(self, length: int) -> tuple[int, Effects, int]:
        """Return (noise file id, effects, start) for the next cut of the given length.

        :raises ValueError: When the requested length exceeds the full (effected)
            length of the noise even when starting from the beginning.
        """
        from .effects import effects_from_rules
        from .effects import estimate_effected_length

        if self.noise_start + length > self.noise_length:
            # Not enough samples left in the current noise; move to the next one.
            if self.noise_start == 0:
                raise ValueError("Length of primary audio exceeds length of noise audio")

            self.noise_start = 0
            self.noise_id = (self.noise_id + 1) % len(self.effected_noises)
            self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
            self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)
            if length > self.noise_length:
                # Preserve the original error when even a fresh noise is too short.
                raise ValueError("Length of primary audio exceeds length of noise audio")

        noise_start = self.noise_start
        # Bug fix: always advance past the cut that was handed out. Previously the
        # offset was not advanced on the branch that switched noise files, so the
        # following call returned the same [0, length) segment a second time.
        self.noise_start += length

        return self.noise_file.id, self.noise_effect, noise_start
|
1096
|
+
|
1097
|
+
|
1098
|
+
def _get_textgrid_tiers_from_source_file(file: str) -> list[str]:
    """Return the sorted tier names from the TextGrid paired with a source file.

    An empty list is returned when no sibling .TextGrid file exists.
    """
    from ..utils.tokenized_shell_vars import tokenized_expand

    textgrid_path = Path(tokenized_expand(file)[0]).with_suffix(".TextGrid")
    if not textgrid_path.exists():
        return []

    grid = textgrid.openTextgrid(str(textgrid_path), includeEmptyIntervals=False)
    return sorted(grid.tierNames)
|
1108
|
+
|
1109
|
+
|
1110
|
+
def _filter_sources(
    effected_sources: list[tuple[SourceFile, Effects]],
    speaker_id: tuple[int, ...],
) -> list[tuple[SourceFile, Effects]]:
    """Return the effected sources whose speaker_id is not among the given IDs."""
    kept: list[tuple[SourceFile, Effects]] = []
    for source_file, effects in effected_sources:
        if source_file.speaker_id in speaker_id:
            continue
        kept.append((source_file, effects))
    return kept
|