sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/mixture/generation.py
CHANGED
@@ -1,24 +1,25 @@
|
|
1
|
-
|
2
|
-
from
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
1
|
+
# ruff: noqa: S608
|
2
|
+
from .datatypes import AudiosT
|
3
|
+
from .datatypes import AudioT
|
4
|
+
from .datatypes import Augmentation
|
5
|
+
from .datatypes import AugmentationRules
|
6
|
+
from .datatypes import AugmentedTargets
|
7
|
+
from .datatypes import GenMixData
|
8
|
+
from .datatypes import ImpulseResponseFiles
|
9
|
+
from .datatypes import Mixture
|
10
|
+
from .datatypes import Mixtures
|
11
|
+
from .datatypes import NoiseFiles
|
12
|
+
from .datatypes import SpectralMasks
|
13
|
+
from .datatypes import TargetFiles
|
14
|
+
from .datatypes import Targets
|
15
|
+
from .datatypes import UniversalSNRGenerator
|
16
|
+
from .mixdb import MixtureDatabase
|
16
17
|
|
17
18
|
|
18
19
|
def config_file(location: str) -> str:
|
19
20
|
from os.path import join
|
20
21
|
|
21
|
-
return join(location,
|
22
|
+
return join(location, "config.yml")
|
22
23
|
|
23
24
|
|
24
25
|
def initialize_db(location: str, test: bool = False) -> None:
|
@@ -27,9 +28,16 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
27
28
|
con = db_connection(location=location, create=True, test=test)
|
28
29
|
|
29
30
|
con.execute("""
|
30
|
-
CREATE TABLE
|
31
|
+
CREATE TABLE truth_config(
|
31
32
|
id INTEGER PRIMARY KEY NOT NULL,
|
32
|
-
|
33
|
+
config TEXT NOT NULL)
|
34
|
+
""")
|
35
|
+
|
36
|
+
con.execute("""
|
37
|
+
CREATE TABLE truth_parameters(
|
38
|
+
id INTEGER PRIMARY KEY NOT NULL,
|
39
|
+
name TEXT NOT NULL,
|
40
|
+
parameters INTEGER NOT NULL)
|
33
41
|
""")
|
34
42
|
|
35
43
|
con.execute("""
|
@@ -37,6 +45,7 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
37
45
|
id INTEGER PRIMARY KEY NOT NULL,
|
38
46
|
name TEXT NOT NULL,
|
39
47
|
samples INTEGER NOT NULL,
|
48
|
+
class_indices TEXT NOT NULL,
|
40
49
|
level_type TEXT NOT NULL,
|
41
50
|
speaker_id INTEGER,
|
42
51
|
FOREIGN KEY(speaker_id) REFERENCES speaker (id))
|
@@ -65,8 +74,6 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
65
74
|
noise_mix_mode TEXT NOT NULL,
|
66
75
|
num_classes INTEGER NOT NULL,
|
67
76
|
seed INTEGER NOT NULL,
|
68
|
-
truth_mutex BOOLEAN NOT NULL,
|
69
|
-
truth_reduction_function TEXT NOT NULL,
|
70
77
|
mixid_width INTEGER NOT NULL,
|
71
78
|
speaker_metadata_tiers TEXT NOT NULL,
|
72
79
|
textgrid_metadata_tiers TEXT NOT NULL)
|
@@ -87,7 +94,8 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
87
94
|
con.execute("""
|
88
95
|
CREATE TABLE impulse_response_file (
|
89
96
|
id INTEGER PRIMARY KEY NOT NULL,
|
90
|
-
file TEXT NOT NULL
|
97
|
+
file TEXT NOT NULL,
|
98
|
+
tags TEXT NOT NULL)
|
91
99
|
""")
|
92
100
|
|
93
101
|
con.execute("""
|
@@ -101,11 +109,11 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
101
109
|
""")
|
102
110
|
|
103
111
|
con.execute("""
|
104
|
-
CREATE TABLE
|
112
|
+
CREATE TABLE target_file_truth_config (
|
105
113
|
target_file_id INTEGER,
|
106
|
-
|
114
|
+
truth_config_id INTEGER,
|
107
115
|
FOREIGN KEY(target_file_id) REFERENCES target_file (id),
|
108
|
-
FOREIGN KEY(
|
116
|
+
FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
|
109
117
|
""")
|
110
118
|
|
111
119
|
con.execute("""
|
@@ -148,59 +156,55 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
148
156
|
|
149
157
|
|
150
158
|
def populate_top_table(location: str, config: dict, test: bool = False) -> None:
|
151
|
-
"""Populate top table
|
152
|
-
"""
|
159
|
+
"""Populate top table"""
|
153
160
|
import json
|
154
161
|
|
155
|
-
from
|
162
|
+
from .constants import MIXDB_VERSION
|
156
163
|
from .mixdb import db_connection
|
157
164
|
|
158
|
-
if config['truth_mode'] not in ['normal', 'mutex']:
|
159
|
-
raise SonusAIError(f'invalid truth_mode: {config["truth_mode"]}')
|
160
|
-
truth_mutex = config['truth_mode'] == 'mutex'
|
161
|
-
|
162
165
|
con = db_connection(location=location, readonly=False, test=test)
|
163
|
-
con.execute(
|
166
|
+
con.execute(
|
167
|
+
"""
|
164
168
|
INSERT INTO top (version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
|
165
|
-
seed,
|
166
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,
|
167
|
-
""",
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
169
|
+
seed, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
|
170
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
171
|
+
""",
|
172
|
+
(
|
173
|
+
MIXDB_VERSION,
|
174
|
+
json.dumps(config["asr_configs"]),
|
175
|
+
config["class_balancing"],
|
176
|
+
config["feature"],
|
177
|
+
config["noise_mix_mode"],
|
178
|
+
config["num_classes"],
|
179
|
+
config["seed"],
|
180
|
+
0,
|
181
|
+
"",
|
182
|
+
"",
|
183
|
+
),
|
184
|
+
)
|
180
185
|
con.commit()
|
181
186
|
con.close()
|
182
187
|
|
183
188
|
|
184
189
|
def populate_class_label_table(location: str, config: dict, test: bool = False) -> None:
|
185
|
-
"""Populate class_label table
|
186
|
-
"""
|
190
|
+
"""Populate class_label table"""
|
187
191
|
from .mixdb import db_connection
|
188
192
|
|
189
193
|
con = db_connection(location=location, readonly=False, test=test)
|
190
|
-
con.executemany(
|
191
|
-
|
194
|
+
con.executemany(
|
195
|
+
"INSERT INTO class_label (label) VALUES (?)",
|
196
|
+
[(item,) for item in config["class_labels"]],
|
197
|
+
)
|
192
198
|
con.commit()
|
193
199
|
con.close()
|
194
200
|
|
195
201
|
|
196
202
|
def populate_class_weights_threshold_table(location: str, config: dict, test: bool = False) -> None:
|
197
|
-
"""Populate class_weights_threshold table
|
198
|
-
"""
|
199
|
-
from sonusai import SonusAIError
|
203
|
+
"""Populate class_weights_threshold table"""
|
200
204
|
from .mixdb import db_connection
|
201
205
|
|
202
|
-
class_weights_threshold = config[
|
203
|
-
num_classes = config[
|
206
|
+
class_weights_threshold = config["class_weights_threshold"]
|
207
|
+
num_classes = config["num_classes"]
|
204
208
|
|
205
209
|
if not isinstance(class_weights_threshold, list):
|
206
210
|
class_weights_threshold = [class_weights_threshold]
|
@@ -209,43 +213,72 @@ def populate_class_weights_threshold_table(location: str, config: dict, test: bo
|
|
209
213
|
class_weights_threshold = [class_weights_threshold[0]] * num_classes
|
210
214
|
|
211
215
|
if len(class_weights_threshold) != num_classes:
|
212
|
-
raise
|
216
|
+
raise ValueError(f"invalid class_weights_threshold length: {len(class_weights_threshold)}")
|
213
217
|
|
214
218
|
con = db_connection(location=location, readonly=False, test=test)
|
215
|
-
con.executemany(
|
216
|
-
|
219
|
+
con.executemany(
|
220
|
+
"INSERT INTO class_weights_threshold (threshold) VALUES (?)",
|
221
|
+
[(item,) for item in class_weights_threshold],
|
222
|
+
)
|
217
223
|
con.commit()
|
218
224
|
con.close()
|
219
225
|
|
220
226
|
|
221
227
|
def populate_spectral_mask_table(location: str, config: dict, test: bool = False) -> None:
|
222
|
-
"""Populate spectral_mask table
|
223
|
-
"""
|
228
|
+
"""Populate spectral_mask table"""
|
224
229
|
from .config import get_spectral_masks
|
225
230
|
from .mixdb import db_connection
|
226
231
|
|
227
232
|
con = db_connection(location=location, readonly=False, test=test)
|
228
|
-
con.executemany(
|
233
|
+
con.executemany(
|
234
|
+
"""
|
229
235
|
INSERT INTO spectral_mask (f_max_width, f_num, t_max_width, t_num, t_max_percent) VALUES (?, ?, ?, ?, ?)
|
230
|
-
""",
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
+
""",
|
237
|
+
[
|
238
|
+
(
|
239
|
+
item.f_max_width,
|
240
|
+
item.f_num,
|
241
|
+
item.t_max_width,
|
242
|
+
item.t_num,
|
243
|
+
item.t_max_percent,
|
244
|
+
)
|
245
|
+
for item in get_spectral_masks(config)
|
246
|
+
],
|
247
|
+
)
|
248
|
+
con.commit()
|
249
|
+
con.close()
|
250
|
+
|
251
|
+
|
252
|
+
def populate_truth_parameters_table(location: str, config: dict, test: bool = False) -> None:
|
253
|
+
"""Populate truth_parameters table"""
|
254
|
+
from .config import get_truth_parameters
|
255
|
+
from .mixdb import db_connection
|
256
|
+
|
257
|
+
con = db_connection(location=location, readonly=False, test=test)
|
258
|
+
con.executemany(
|
259
|
+
"""
|
260
|
+
INSERT INTO truth_parameters (name, parameters) VALUES (?, ?)
|
261
|
+
""",
|
262
|
+
[
|
263
|
+
(
|
264
|
+
item.name,
|
265
|
+
item.parameters,
|
266
|
+
)
|
267
|
+
for item in get_truth_parameters(config)
|
268
|
+
],
|
269
|
+
)
|
236
270
|
con.commit()
|
237
271
|
con.close()
|
238
272
|
|
239
273
|
|
240
274
|
def populate_target_file_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
|
241
|
-
"""Populate target file table
|
242
|
-
"""
|
275
|
+
"""Populate target file table"""
|
243
276
|
import json
|
244
277
|
from pathlib import Path
|
245
278
|
|
246
279
|
from .mixdb import db_connection
|
247
280
|
|
248
|
-
|
281
|
+
_populate_truth_config_table(location, target_files, test)
|
249
282
|
_populate_speaker_table(location, target_files, test)
|
250
283
|
|
251
284
|
con = db_connection(location=location, readonly=False, test=test)
|
@@ -259,76 +292,106 @@ def populate_target_file_table(location: str, target_files: TargetFiles, test: b
|
|
259
292
|
textgrid_metadata_tiers.add(tier)
|
260
293
|
|
261
294
|
# Get truth settings for target file
|
262
|
-
|
263
|
-
for
|
264
|
-
|
265
|
-
|
266
|
-
|
295
|
+
truth_config_ids: list[int] = []
|
296
|
+
for name, config in target_file.truth_configs.items():
|
297
|
+
ts = json.dumps({"name": name} | config.to_dict())
|
298
|
+
cur.execute(
|
299
|
+
"SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
|
300
|
+
(ts,),
|
301
|
+
)
|
302
|
+
truth_config_ids.append(cur.fetchone()[0])
|
267
303
|
|
268
304
|
# Get speaker_id for target file
|
269
|
-
cur.execute(
|
270
|
-
|
305
|
+
cur.execute(
|
306
|
+
"SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
|
307
|
+
(Path(target_file.name).parent.as_posix(),),
|
308
|
+
)
|
271
309
|
result = cur.fetchone()
|
272
310
|
speaker_id = None
|
273
311
|
if result is not None:
|
274
312
|
speaker_id = result[0]
|
275
313
|
|
276
314
|
# Add entry
|
277
|
-
cur.execute(
|
278
|
-
|
315
|
+
cur.execute(
|
316
|
+
"INSERT INTO target_file (name, samples, class_indices, level_type, speaker_id) VALUES (?, ?, ?, ?, ?)",
|
317
|
+
(
|
318
|
+
target_file.name,
|
319
|
+
target_file.samples,
|
320
|
+
json.dumps(target_file.class_indices),
|
321
|
+
target_file.level_type,
|
322
|
+
speaker_id,
|
323
|
+
),
|
324
|
+
)
|
279
325
|
target_file_id = cur.lastrowid
|
280
|
-
for
|
281
|
-
cur.execute(
|
282
|
-
|
326
|
+
for truth_config_id in truth_config_ids:
|
327
|
+
cur.execute(
|
328
|
+
"INSERT INTO target_file_truth_config (target_file_id, truth_config_id) VALUES (?, ?)",
|
329
|
+
(target_file_id, truth_config_id),
|
330
|
+
)
|
283
331
|
|
284
332
|
# Update textgrid_metadata_tiers in the top table
|
285
|
-
con.execute(
|
286
|
-
|
333
|
+
con.execute(
|
334
|
+
"UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
|
335
|
+
(json.dumps(sorted(textgrid_metadata_tiers)), 1),
|
336
|
+
)
|
287
337
|
|
288
338
|
con.commit()
|
289
339
|
con.close()
|
290
340
|
|
291
341
|
|
292
342
|
def populate_noise_file_table(location: str, noise_files: NoiseFiles, test: bool = False) -> None:
|
293
|
-
"""Populate noise file table
|
294
|
-
"""
|
343
|
+
"""Populate noise file table"""
|
295
344
|
from .mixdb import db_connection
|
296
345
|
|
297
346
|
con = db_connection(location=location, readonly=False, test=test)
|
298
|
-
con.executemany(
|
299
|
-
|
347
|
+
con.executemany(
|
348
|
+
"INSERT INTO noise_file (name, samples) VALUES (?, ?)",
|
349
|
+
[(noise_file.name, noise_file.samples) for noise_file in noise_files],
|
350
|
+
)
|
300
351
|
con.commit()
|
301
352
|
con.close()
|
302
353
|
|
303
354
|
|
304
|
-
def populate_impulse_response_file_table(
|
305
|
-
|
306
|
-
|
307
|
-
"""
|
355
|
+
def populate_impulse_response_file_table(
|
356
|
+
location: str, impulse_response_files: ImpulseResponseFiles, test: bool = False
|
357
|
+
) -> None:
|
358
|
+
"""Populate impulse response file table"""
|
359
|
+
import json
|
360
|
+
|
308
361
|
from .mixdb import db_connection
|
309
362
|
|
310
363
|
con = db_connection(location=location, readonly=False, test=test)
|
311
|
-
con.executemany(
|
312
|
-
|
364
|
+
con.executemany(
|
365
|
+
"INSERT INTO impulse_response_file (file, tags) VALUES (?, ?)",
|
366
|
+
[
|
367
|
+
(
|
368
|
+
impulse_response_file.file,
|
369
|
+
json.dumps(impulse_response_file.tags),
|
370
|
+
)
|
371
|
+
for impulse_response_file in impulse_response_files
|
372
|
+
],
|
373
|
+
)
|
313
374
|
con.commit()
|
314
375
|
con.close()
|
315
376
|
|
316
377
|
|
317
378
|
def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
|
318
|
-
"""Update the mixid width
|
319
|
-
"""
|
320
|
-
from .mixdb import db_connection
|
379
|
+
"""Update the mixid width"""
|
321
380
|
from sonusai.utils import max_text_width
|
322
381
|
|
382
|
+
from .mixdb import db_connection
|
383
|
+
|
323
384
|
con = db_connection(location=location, readonly=False, test=test)
|
324
|
-
con.execute(
|
385
|
+
con.execute(
|
386
|
+
"UPDATE top SET mixid_width=? WHERE top.id = ?",
|
387
|
+
(max_text_width(num_mixtures), 1),
|
388
|
+
)
|
325
389
|
con.commit()
|
326
390
|
con.close()
|
327
391
|
|
328
392
|
|
329
393
|
def populate_mixture_table(location: str, mixtures: Mixtures, test: bool = False) -> None:
|
330
|
-
"""Populate mixture table
|
331
|
-
"""
|
394
|
+
"""Populate mixture table"""
|
332
395
|
from .helpers import from_mixture
|
333
396
|
from .helpers import from_target
|
334
397
|
from .mixdb import db_connection
|
@@ -348,29 +411,35 @@ def populate_mixture_table(location: str, mixtures: Mixtures, test: bool = False
|
|
348
411
|
# Populate mixture table
|
349
412
|
cur = con.cursor()
|
350
413
|
for mixture in mixtures:
|
351
|
-
cur.execute(
|
414
|
+
cur.execute(
|
415
|
+
"""
|
352
416
|
INSERT INTO mixture (name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
|
353
417
|
snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
|
354
418
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
355
|
-
""",
|
419
|
+
""",
|
420
|
+
from_mixture(mixture),
|
421
|
+
)
|
356
422
|
|
357
423
|
mixture_id = cur.lastrowid
|
358
424
|
for target in mixture.targets:
|
359
|
-
target_id = con.execute(
|
425
|
+
target_id = con.execute(
|
426
|
+
"""
|
360
427
|
SELECT target.id
|
361
428
|
FROM target
|
362
429
|
WHERE ? = target.file_id AND ? = target.augmentation AND ? = target.gain
|
363
|
-
""",
|
364
|
-
|
365
|
-
|
430
|
+
""",
|
431
|
+
from_target(target),
|
432
|
+
).fetchone()[0]
|
433
|
+
con.execute(
|
434
|
+
"INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
|
435
|
+
(mixture_id, target_id),
|
436
|
+
)
|
366
437
|
|
367
438
|
con.commit()
|
368
439
|
con.close()
|
369
440
|
|
370
441
|
|
371
|
-
def update_mixture(mixdb: MixtureDatabase,
|
372
|
-
mixture: Mixture,
|
373
|
-
with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
442
|
+
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
374
443
|
"""Update mixture record with name and gains
|
375
444
|
|
376
445
|
:param mixdb: Mixture database
|
@@ -391,12 +460,11 @@ def update_mixture(mixdb: MixtureDatabase,
|
|
391
460
|
# Apply IR and sum targets audio before initializing the mixture SNR gains
|
392
461
|
target_audio = get_target(mixdb, mixture, targets_audio)
|
393
462
|
|
394
|
-
mixture = _initialize_mixture_gains(
|
395
|
-
|
396
|
-
|
397
|
-
noise_audio=noise_audio)
|
463
|
+
mixture = _initialize_mixture_gains(
|
464
|
+
mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise_audio
|
465
|
+
)
|
398
466
|
|
399
|
-
mixture.name = f
|
467
|
+
mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
|
400
468
|
|
401
469
|
if not with_data:
|
402
470
|
return mixture, GenMixData()
|
@@ -409,10 +477,12 @@ def update_mixture(mixdb: MixtureDatabase,
|
|
409
477
|
target_audio = get_target(mixdb, mixture, targets_audio)
|
410
478
|
mixture_audio = target_audio + noise_audio
|
411
479
|
|
412
|
-
return mixture, GenMixData(
|
413
|
-
|
414
|
-
|
415
|
-
|
480
|
+
return mixture, GenMixData(
|
481
|
+
mixture=mixture_audio,
|
482
|
+
targets=targets_audio,
|
483
|
+
target=target_audio,
|
484
|
+
noise=noise_audio,
|
485
|
+
)
|
416
486
|
|
417
487
|
|
418
488
|
def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
|
@@ -439,9 +509,13 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
439
509
|
targets_audio = []
|
440
510
|
for target in mixture.targets:
|
441
511
|
target_audio = mixdb.read_target_audio(target.file_id)
|
442
|
-
targets_audio.append(
|
443
|
-
|
444
|
-
|
512
|
+
targets_audio.append(
|
513
|
+
apply_augmentation(
|
514
|
+
audio=target_audio,
|
515
|
+
augmentation=target.augmentation,
|
516
|
+
frame_length=mixdb.feature_step_samples,
|
517
|
+
)
|
518
|
+
)
|
445
519
|
|
446
520
|
# target_gain is used to back out the gain augmentation in order to return the target audio
|
447
521
|
# to its normalized level when calculating truth (if needed).
|
@@ -458,13 +532,11 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
458
532
|
return mixture, targets_audio
|
459
533
|
|
460
534
|
|
461
|
-
def _initialize_mixture_gains(
|
462
|
-
|
463
|
-
|
464
|
-
noise_audio: AudioT) -> Mixture:
|
535
|
+
def _initialize_mixture_gains(
|
536
|
+
mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT
|
537
|
+
) -> Mixture:
|
465
538
|
import numpy as np
|
466
539
|
|
467
|
-
from sonusai import SonusAIError
|
468
540
|
from sonusai.utils import asl_p56
|
469
541
|
from sonusai.utils import db_to_linear
|
470
542
|
|
@@ -480,19 +552,20 @@ def _initialize_mixture_gains(mixdb: MixtureDatabase,
|
|
480
552
|
mixture.target_snr_gain = 1
|
481
553
|
mixture.noise_snr_gain = 0
|
482
554
|
else:
|
483
|
-
target_level_types = [
|
484
|
-
|
555
|
+
target_level_types = [
|
556
|
+
target_file.level_type for target_file in [mixdb.target_file(target.file_id) for target in mixture.targets]
|
557
|
+
]
|
485
558
|
if not all(level_type == target_level_types[0] for level_type in target_level_types):
|
486
|
-
raise
|
559
|
+
raise ValueError("Not all target_level_types in mixup are the same")
|
487
560
|
|
488
561
|
level_type = target_level_types[0]
|
489
562
|
match level_type:
|
490
|
-
case
|
563
|
+
case "default":
|
491
564
|
target_energy = np.mean(np.square(target_audio))
|
492
|
-
case
|
565
|
+
case "speech":
|
493
566
|
target_energy = asl_p56(target_audio)
|
494
567
|
case _:
|
495
|
-
raise
|
568
|
+
raise ValueError(f"Unknown level_type: {level_type}")
|
496
569
|
|
497
570
|
noise_energy = np.mean(np.square(noise_audio))
|
498
571
|
if noise_energy == 0:
|
@@ -525,19 +598,20 @@ def _initialize_mixture_gains(mixdb: MixtureDatabase,
|
|
525
598
|
return mixture
|
526
599
|
|
527
600
|
|
528
|
-
def generate_mixtures(
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
601
|
+
def generate_mixtures(
|
602
|
+
noise_mix_mode: str,
|
603
|
+
augmented_targets: AugmentedTargets,
|
604
|
+
target_files: TargetFiles,
|
605
|
+
target_augmentations: AugmentationRules,
|
606
|
+
noise_files: NoiseFiles,
|
607
|
+
noise_augmentations: AugmentationRules,
|
608
|
+
spectral_masks: SpectralMasks,
|
609
|
+
all_snrs: list[UniversalSNRGenerator],
|
610
|
+
mixups: list[int],
|
611
|
+
num_classes: int,
|
612
|
+
feature_step_samples: int,
|
613
|
+
num_ir: int,
|
614
|
+
) -> tuple[int, int, Mixtures]:
|
541
615
|
"""Generate mixtures
|
542
616
|
|
543
617
|
:param noise_mix_mode: Noise mix mode
|
@@ -550,72 +624,72 @@ def generate_mixtures(noise_mix_mode: str,
|
|
550
624
|
:param all_snrs: List of all SNRs
|
551
625
|
:param mixups: List of mixup values
|
552
626
|
:param num_classes: Number of classes
|
553
|
-
:param truth_mutex: Truth mutex mode
|
554
627
|
:param feature_step_samples: Number of samples in a feature step
|
555
628
|
:param num_ir: Number of impulse response files
|
556
629
|
:return: (Number of noise files used, number of noise samples used, list of mixture records)
|
557
630
|
"""
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
if noise_mix_mode ==
|
589
|
-
return _non_combinatorial_noise_mix(
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
631
|
+
if noise_mix_mode == "exhaustive":
|
632
|
+
return _exhaustive_noise_mix(
|
633
|
+
augmented_targets=augmented_targets,
|
634
|
+
target_files=target_files,
|
635
|
+
target_augmentations=target_augmentations,
|
636
|
+
noise_files=noise_files,
|
637
|
+
noise_augmentations=noise_augmentations,
|
638
|
+
spectral_masks=spectral_masks,
|
639
|
+
all_snrs=all_snrs,
|
640
|
+
mixups=mixups,
|
641
|
+
num_classes=num_classes,
|
642
|
+
feature_step_samples=feature_step_samples,
|
643
|
+
num_ir=num_ir,
|
644
|
+
)
|
645
|
+
|
646
|
+
if noise_mix_mode == "non-exhaustive":
|
647
|
+
return _non_exhaustive_noise_mix(
|
648
|
+
augmented_targets=augmented_targets,
|
649
|
+
target_files=target_files,
|
650
|
+
target_augmentations=target_augmentations,
|
651
|
+
noise_files=noise_files,
|
652
|
+
noise_augmentations=noise_augmentations,
|
653
|
+
spectral_masks=spectral_masks,
|
654
|
+
all_snrs=all_snrs,
|
655
|
+
mixups=mixups,
|
656
|
+
num_classes=num_classes,
|
657
|
+
feature_step_samples=feature_step_samples,
|
658
|
+
num_ir=num_ir,
|
659
|
+
)
|
660
|
+
|
661
|
+
if noise_mix_mode == "non-combinatorial":
|
662
|
+
return _non_combinatorial_noise_mix(
|
663
|
+
augmented_targets=augmented_targets,
|
664
|
+
target_files=target_files,
|
665
|
+
target_augmentations=target_augmentations,
|
666
|
+
noise_files=noise_files,
|
667
|
+
noise_augmentations=noise_augmentations,
|
668
|
+
spectral_masks=spectral_masks,
|
669
|
+
all_snrs=all_snrs,
|
670
|
+
mixups=mixups,
|
671
|
+
num_classes=num_classes,
|
672
|
+
feature_step_samples=feature_step_samples,
|
673
|
+
num_ir=num_ir,
|
674
|
+
)
|
675
|
+
|
676
|
+
raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
|
677
|
+
|
678
|
+
|
679
|
+
def _exhaustive_noise_mix(
|
680
|
+
augmented_targets: AugmentedTargets,
|
681
|
+
target_files: TargetFiles,
|
682
|
+
target_augmentations: AugmentationRules,
|
683
|
+
noise_files: NoiseFiles,
|
684
|
+
noise_augmentations: AugmentationRules,
|
685
|
+
spectral_masks: SpectralMasks,
|
686
|
+
all_snrs: list[UniversalSNRGenerator],
|
687
|
+
mixups: list[int],
|
688
|
+
num_classes: int,
|
689
|
+
feature_step_samples: int,
|
690
|
+
num_ir: int,
|
691
|
+
) -> tuple[int, int, Mixtures]:
|
692
|
+
"""Use every noise/augmentation with every target/augmentation"""
|
619
693
|
from random import randint
|
620
694
|
|
621
695
|
import numpy as np
|
@@ -633,42 +707,53 @@ def _exhaustive_noise_mix(augmented_targets: AugmentedTargets,
|
|
633
707
|
used_noise_files = len(noise_files) * len(noise_augmentations)
|
634
708
|
used_noise_samples = 0
|
635
709
|
|
636
|
-
augmented_target_ids_for_mixups = [
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
710
|
+
augmented_target_ids_for_mixups = [
|
711
|
+
get_augmented_target_ids_for_mixup(
|
712
|
+
augmented_targets=augmented_targets,
|
713
|
+
targets=target_files,
|
714
|
+
target_augmentations=target_augmentations,
|
715
|
+
mixup=mixup,
|
716
|
+
num_classes=num_classes,
|
717
|
+
)
|
718
|
+
for mixup in mixups
|
719
|
+
]
|
642
720
|
for noise_file_id in range(len(noise_files)):
|
643
721
|
for noise_augmentation_rule in noise_augmentations:
|
644
722
|
noise_augmentation = augmentation_from_rule(noise_augmentation_rule, num_ir)
|
645
723
|
noise_offset = 0
|
646
724
|
noise_length = estimate_augmented_length_from_length(
|
647
725
|
length=noise_files[noise_file_id].samples,
|
648
|
-
tempo=noise_augmentation.tempo
|
726
|
+
tempo=noise_augmentation.tempo,
|
727
|
+
)
|
649
728
|
|
650
729
|
for augmented_target_ids_for_mixup in augmented_target_ids_for_mixups:
|
651
730
|
for augmented_target_ids in augmented_target_ids_for_mixup:
|
652
|
-
targets, target_length = _get_target_info(
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
731
|
+
targets, target_length = _get_target_info(
|
732
|
+
augmented_target_ids=augmented_target_ids,
|
733
|
+
augmented_targets=augmented_targets,
|
734
|
+
target_files=target_files,
|
735
|
+
target_augmentations=target_augmentations,
|
736
|
+
feature_step_samples=feature_step_samples,
|
737
|
+
num_ir=num_ir,
|
738
|
+
)
|
658
739
|
|
659
740
|
for spectral_mask_id in range(len(spectral_masks)):
|
660
741
|
for snr in all_snrs:
|
661
|
-
mixtures.append(
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
742
|
+
mixtures.append(
|
743
|
+
Mixture(
|
744
|
+
targets=targets,
|
745
|
+
name=str(m_id),
|
746
|
+
noise=Noise(
|
747
|
+
file_id=noise_file_id + 1,
|
748
|
+
augmentation=noise_augmentation,
|
749
|
+
offset=noise_offset,
|
750
|
+
),
|
751
|
+
samples=target_length,
|
752
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
753
|
+
spectral_mask_id=spectral_mask_id + 1,
|
754
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
755
|
+
)
|
756
|
+
)
|
672
757
|
m_id += 1
|
673
758
|
|
674
759
|
noise_offset = int((noise_offset + target_length) % noise_length)
|
@@ -677,19 +762,20 @@ def _exhaustive_noise_mix(augmented_targets: AugmentedTargets,
|
|
677
762
|
return used_noise_files, used_noise_samples, mixtures
|
678
763
|
|
679
764
|
|
680
|
-
def _non_exhaustive_noise_mix(
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
765
|
+
def _non_exhaustive_noise_mix(
|
766
|
+
augmented_targets: AugmentedTargets,
|
767
|
+
target_files: TargetFiles,
|
768
|
+
target_augmentations: AugmentationRules,
|
769
|
+
noise_files: NoiseFiles,
|
770
|
+
noise_augmentations: AugmentationRules,
|
771
|
+
spectral_masks: SpectralMasks,
|
772
|
+
all_snrs: list[UniversalSNRGenerator],
|
773
|
+
mixups: list[int],
|
774
|
+
num_classes: int,
|
775
|
+
feature_step_samples: int,
|
776
|
+
num_ir: int,
|
777
|
+
) -> tuple[int, int, Mixtures]:
|
778
|
+
"""Cycle through every target/augmentation without necessarily using all noise/augmentation combinations
|
693
779
|
(reduced data set).
|
694
780
|
"""
|
695
781
|
from random import randint
|
@@ -710,67 +796,81 @@ def _non_exhaustive_noise_mix(augmented_targets: AugmentedTargets,
|
|
710
796
|
noise_augmentation_id = None
|
711
797
|
noise_offset = None
|
712
798
|
|
713
|
-
augmented_target_indices_for_mixups = [
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
799
|
+
augmented_target_indices_for_mixups = [
|
800
|
+
get_augmented_target_ids_for_mixup(
|
801
|
+
augmented_targets=augmented_targets,
|
802
|
+
targets=target_files,
|
803
|
+
target_augmentations=target_augmentations,
|
804
|
+
mixup=mixup,
|
805
|
+
num_classes=num_classes,
|
806
|
+
)
|
807
|
+
for mixup in mixups
|
808
|
+
]
|
720
809
|
for mixup in augmented_target_indices_for_mixups:
|
721
810
|
for augmented_target_indices in mixup:
|
722
|
-
targets, target_length = _get_target_info(
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
811
|
+
targets, target_length = _get_target_info(
|
812
|
+
augmented_target_ids=augmented_target_indices,
|
813
|
+
augmented_targets=augmented_targets,
|
814
|
+
target_files=target_files,
|
815
|
+
target_augmentations=target_augmentations,
|
816
|
+
feature_step_samples=feature_step_samples,
|
817
|
+
num_ir=num_ir,
|
818
|
+
)
|
728
819
|
|
729
820
|
for spectral_mask_id in range(len(spectral_masks)):
|
730
821
|
for snr in all_snrs:
|
731
|
-
(
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
822
|
+
(
|
823
|
+
noise_file_id,
|
824
|
+
noise_augmentation_id,
|
825
|
+
noise_augmentation,
|
826
|
+
noise_offset,
|
827
|
+
) = _get_next_noise_offset(
|
828
|
+
noise_file_id=noise_file_id,
|
829
|
+
noise_augmentation_id=noise_augmentation_id,
|
830
|
+
noise_offset=noise_offset,
|
831
|
+
target_length=target_length,
|
832
|
+
noise_files=noise_files,
|
833
|
+
noise_augmentations=noise_augmentations,
|
834
|
+
num_ir=num_ir,
|
835
|
+
)
|
741
836
|
used_noise_samples += target_length
|
742
837
|
|
743
|
-
used_noise_files.add(f
|
744
|
-
|
745
|
-
mixtures.append(
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
838
|
+
used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
|
839
|
+
|
840
|
+
mixtures.append(
|
841
|
+
Mixture(
|
842
|
+
targets=targets,
|
843
|
+
name=str(m_id),
|
844
|
+
noise=Noise(
|
845
|
+
file_id=noise_file_id + 1,
|
846
|
+
augmentation=noise_augmentation,
|
847
|
+
offset=noise_offset,
|
848
|
+
),
|
849
|
+
samples=target_length,
|
850
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
851
|
+
spectral_mask_id=spectral_mask_id + 1,
|
852
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
853
|
+
)
|
854
|
+
)
|
756
855
|
m_id += 1
|
757
856
|
|
758
857
|
return len(used_noise_files), used_noise_samples, mixtures
|
759
858
|
|
760
859
|
|
761
|
-
def _non_combinatorial_noise_mix(
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
860
|
+
def _non_combinatorial_noise_mix(
|
861
|
+
augmented_targets: AugmentedTargets,
|
862
|
+
target_files: TargetFiles,
|
863
|
+
target_augmentations: AugmentationRules,
|
864
|
+
noise_files: NoiseFiles,
|
865
|
+
noise_augmentations: AugmentationRules,
|
866
|
+
spectral_masks: SpectralMasks,
|
867
|
+
all_snrs: list[UniversalSNRGenerator],
|
868
|
+
mixups: list[int],
|
869
|
+
num_classes: int,
|
870
|
+
feature_step_samples: int,
|
871
|
+
num_ir: int,
|
872
|
+
) -> tuple[int, int, Mixtures]:
|
873
|
+
"""Combine a target/augmentation with a single cut of a noise/augmentation non-exhaustively
|
774
874
|
(each target/augmentation does not use each noise/augmentation). Cut has random start and loop back to
|
775
875
|
beginning if end of noise/augmentation is reached.
|
776
876
|
"""
|
@@ -792,57 +892,72 @@ def _non_combinatorial_noise_mix(augmented_targets: AugmentedTargets,
|
|
792
892
|
noise_file_id = None
|
793
893
|
noise_augmentation_id = None
|
794
894
|
|
795
|
-
augmented_target_indices_for_mixups = [
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
895
|
+
augmented_target_indices_for_mixups = [
|
896
|
+
get_augmented_target_ids_for_mixup(
|
897
|
+
augmented_targets=augmented_targets,
|
898
|
+
targets=target_files,
|
899
|
+
target_augmentations=target_augmentations,
|
900
|
+
mixup=mixup,
|
901
|
+
num_classes=num_classes,
|
902
|
+
)
|
903
|
+
for mixup in mixups
|
904
|
+
]
|
802
905
|
for mixup in augmented_target_indices_for_mixups:
|
803
906
|
for augmented_target_indices in mixup:
|
804
|
-
targets, target_length = _get_target_info(
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
907
|
+
targets, target_length = _get_target_info(
|
908
|
+
augmented_target_ids=augmented_target_indices,
|
909
|
+
augmented_targets=augmented_targets,
|
910
|
+
target_files=target_files,
|
911
|
+
target_augmentations=target_augmentations,
|
912
|
+
feature_step_samples=feature_step_samples,
|
913
|
+
num_ir=num_ir,
|
914
|
+
)
|
810
915
|
|
811
916
|
for spectral_mask_id in range(len(spectral_masks)):
|
812
917
|
for snr in all_snrs:
|
813
|
-
(
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
918
|
+
(
|
919
|
+
noise_file_id,
|
920
|
+
noise_augmentation_id,
|
921
|
+
noise_augmentation,
|
922
|
+
noise_length,
|
923
|
+
) = _get_next_noise_indices(
|
924
|
+
noise_file_id=noise_file_id,
|
925
|
+
noise_augmentation_id=noise_augmentation_id,
|
926
|
+
noise_files=noise_files,
|
927
|
+
noise_augmentations=noise_augmentations,
|
928
|
+
num_ir=num_ir,
|
929
|
+
)
|
821
930
|
used_noise_samples += target_length
|
822
931
|
|
823
|
-
used_noise_files.add(f
|
824
|
-
|
825
|
-
mixtures.append(
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
932
|
+
used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
|
933
|
+
|
934
|
+
mixtures.append(
|
935
|
+
Mixture(
|
936
|
+
targets=targets,
|
937
|
+
name=str(m_id),
|
938
|
+
noise=Noise(
|
939
|
+
file_id=noise_file_id + 1,
|
940
|
+
augmentation=noise_augmentation,
|
941
|
+
offset=choice(range(noise_length)), # noqa: S311
|
942
|
+
),
|
943
|
+
samples=target_length,
|
944
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
945
|
+
spectral_mask_id=spectral_mask_id + 1,
|
946
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
947
|
+
)
|
948
|
+
)
|
836
949
|
m_id += 1
|
837
950
|
|
838
951
|
return len(used_noise_files), used_noise_samples, mixtures
|
839
952
|
|
840
953
|
|
841
|
-
def _get_next_noise_indices(
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
954
|
+
def _get_next_noise_indices(
|
955
|
+
noise_file_id: int | None,
|
956
|
+
noise_augmentation_id: int | None,
|
957
|
+
noise_files: NoiseFiles,
|
958
|
+
noise_augmentations: AugmentationRules,
|
959
|
+
num_ir: int,
|
960
|
+
) -> tuple[int, int, Augmentation, int]:
|
846
961
|
from .augmentation import augmentation_from_rule
|
847
962
|
from .augmentation import estimate_augmented_length_from_length
|
848
963
|
|
@@ -858,19 +973,21 @@ def _get_next_noise_indices(noise_file_id: int,
|
|
858
973
|
noise_file_id = 0
|
859
974
|
|
860
975
|
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
|
861
|
-
noise_length = estimate_augmented_length_from_length(
|
862
|
-
|
976
|
+
noise_length = estimate_augmented_length_from_length(
|
977
|
+
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
|
978
|
+
)
|
863
979
|
return noise_file_id, noise_augmentation_id, noise_augmentation, noise_length
|
864
980
|
|
865
981
|
|
866
|
-
def _get_next_noise_offset(
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
982
|
+
def _get_next_noise_offset(
|
983
|
+
noise_file_id: int | None,
|
984
|
+
noise_augmentation_id: int | None,
|
985
|
+
noise_offset: int | None,
|
986
|
+
target_length: int,
|
987
|
+
noise_files: NoiseFiles,
|
988
|
+
noise_augmentations: AugmentationRules,
|
989
|
+
num_ir: int,
|
990
|
+
) -> tuple[int, int, Augmentation, int]:
|
874
991
|
from .augmentation import augmentation_from_rule
|
875
992
|
from .augmentation import estimate_augmented_length_from_length
|
876
993
|
|
@@ -880,11 +997,12 @@ def _get_next_noise_offset(noise_file_id: int | None,
|
|
880
997
|
noise_offset = 0
|
881
998
|
|
882
999
|
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_file_id], num_ir)
|
883
|
-
noise_length = estimate_augmented_length_from_length(
|
884
|
-
|
1000
|
+
noise_length = estimate_augmented_length_from_length(
|
1001
|
+
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
|
1002
|
+
)
|
885
1003
|
if noise_offset + target_length >= noise_length:
|
886
1004
|
if noise_offset == 0:
|
887
|
-
raise
|
1005
|
+
raise ValueError("Length of target audio exceeds length of noise audio")
|
888
1006
|
|
889
1007
|
noise_offset = 0
|
890
1008
|
noise_augmentation_id += 1
|
@@ -898,12 +1016,14 @@ def _get_next_noise_offset(noise_file_id: int | None,
|
|
898
1016
|
return noise_file_id, noise_augmentation_id, noise_augmentation, noise_offset
|
899
1017
|
|
900
1018
|
|
901
|
-
def _get_target_info(
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
1019
|
+
def _get_target_info(
|
1020
|
+
augmented_target_ids: list[int],
|
1021
|
+
augmented_targets: AugmentedTargets,
|
1022
|
+
target_files: TargetFiles,
|
1023
|
+
target_augmentations: AugmentationRules,
|
1024
|
+
feature_step_samples: int,
|
1025
|
+
num_ir: int,
|
1026
|
+
) -> tuple[Targets, int]:
|
907
1027
|
from .augmentation import augmentation_from_rule
|
908
1028
|
from .augmentation import estimate_augmented_length_from_length
|
909
1029
|
from .datatypes import Target
|
@@ -918,18 +1038,23 @@ def _get_target_info(augmented_target_ids: list[int],
|
|
918
1038
|
|
919
1039
|
mixups.append(Target(file_id=tfi + 1, augmentation=target_augmentation))
|
920
1040
|
|
921
|
-
target_length = max(
|
922
|
-
|
923
|
-
|
924
|
-
|
1041
|
+
target_length = max(
|
1042
|
+
estimate_augmented_length_from_length(
|
1043
|
+
length=target_files[tfi].samples,
|
1044
|
+
tempo=target_augmentation.tempo,
|
1045
|
+
frame_length=feature_step_samples,
|
1046
|
+
),
|
1047
|
+
target_length,
|
1048
|
+
)
|
925
1049
|
return mixups, target_length
|
926
1050
|
|
927
1051
|
|
928
1052
|
def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
|
929
1053
|
from .datatypes import UniversalSNRGenerator
|
930
1054
|
|
931
|
-
return
|
932
|
-
|
1055
|
+
return [UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config["snrs"]] + [
|
1056
|
+
UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config["random_snrs"]
|
1057
|
+
]
|
933
1058
|
|
934
1059
|
|
935
1060
|
def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
|
@@ -939,7 +1064,7 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
|
|
939
1064
|
|
940
1065
|
from sonusai.mixture import tokenized_expand
|
941
1066
|
|
942
|
-
textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(
|
1067
|
+
textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
|
943
1068
|
if not textgrid_file.exists():
|
944
1069
|
return []
|
945
1070
|
|
@@ -949,8 +1074,7 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
|
|
949
1074
|
|
950
1075
|
|
951
1076
|
def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
|
952
|
-
"""Populate speaker table
|
953
|
-
"""
|
1077
|
+
"""Populate speaker table"""
|
954
1078
|
import json
|
955
1079
|
from pathlib import Path
|
956
1080
|
|
@@ -960,65 +1084,74 @@ def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool
|
|
960
1084
|
from .tokenized_shell_vars import tokenized_expand
|
961
1085
|
|
962
1086
|
# Determine columns for speaker table
|
963
|
-
all_parents =
|
964
|
-
speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent /
|
1087
|
+
all_parents = {Path(target_file.name).parent for target_file in target_files}
|
1088
|
+
speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists())
|
965
1089
|
|
966
1090
|
speakers: dict[Path, dict[str, str]] = {}
|
967
1091
|
for parent in sorted(speaker_parents):
|
968
|
-
with open(tokenized_expand(parent /
|
1092
|
+
with open(tokenized_expand(parent / "speaker.yml")[0]) as f:
|
969
1093
|
speakers[parent] = yaml.safe_load(f)
|
970
1094
|
|
971
1095
|
new_columns: list[str] = []
|
972
|
-
for keys in speakers
|
973
|
-
for column in speakers[keys]
|
1096
|
+
for keys in speakers:
|
1097
|
+
for column in speakers[keys]:
|
974
1098
|
new_columns.append(column)
|
975
1099
|
new_columns = sorted(set(new_columns))
|
976
1100
|
|
977
1101
|
con = db_connection(location=location, readonly=False, test=test)
|
978
1102
|
|
979
1103
|
for new_column in new_columns:
|
980
|
-
con.execute(f
|
1104
|
+
con.execute(f"ALTER TABLE speaker ADD COLUMN {new_column} TEXT")
|
981
1105
|
|
982
1106
|
# Populate speaker table
|
983
1107
|
speaker_rows: list[tuple[str, ...]] = []
|
984
|
-
for key in speakers
|
1108
|
+
for key in speakers:
|
985
1109
|
entry = (speakers[key].get(column, None) for column in new_columns)
|
986
|
-
speaker_rows.append((key.as_posix(), *entry))
|
1110
|
+
speaker_rows.append((key.as_posix(), *entry)) # type: ignore[arg-type]
|
987
1111
|
|
988
|
-
column_ids =
|
989
|
-
column_values =
|
990
|
-
con.executemany(f
|
1112
|
+
column_ids = ", ".join(["parent", *new_columns])
|
1113
|
+
column_values = ", ".join(["?"] * (len(new_columns) + 1))
|
1114
|
+
con.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)
|
991
1115
|
|
992
1116
|
con.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")
|
993
1117
|
|
994
1118
|
# Update speaker_metadata_tiers in the top table
|
995
|
-
tiers = [
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1119
|
+
tiers = [
|
1120
|
+
description[0]
|
1121
|
+
for description in con.execute("SELECT * FROM speaker").description
|
1122
|
+
if description[0] not in ("id", "parent")
|
1123
|
+
]
|
1124
|
+
con.execute(
|
1125
|
+
"UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?",
|
1126
|
+
(json.dumps(tiers), 1),
|
1127
|
+
)
|
1128
|
+
|
1129
|
+
if "speaker_id" in tiers:
|
1000
1130
|
con.execute("CREATE INDEX speaker_speaker_id_idx ON speaker (speaker_id)")
|
1001
1131
|
|
1002
1132
|
con.commit()
|
1003
1133
|
con.close()
|
1004
1134
|
|
1005
1135
|
|
1006
|
-
def
|
1007
|
-
"""Populate
|
1008
|
-
|
1136
|
+
def _populate_truth_config_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
|
1137
|
+
"""Populate truth_config table"""
|
1138
|
+
import json
|
1139
|
+
|
1009
1140
|
from .mixdb import db_connection
|
1010
1141
|
|
1011
1142
|
con = db_connection(location=location, readonly=False, test=test)
|
1012
1143
|
|
1013
|
-
# Populate
|
1014
|
-
|
1015
|
-
for
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
con.executemany(
|
1021
|
-
|
1144
|
+
# Populate truth_config table
|
1145
|
+
truth_configs: list[str] = []
|
1146
|
+
for target_file in target_files:
|
1147
|
+
for name, config in target_file.truth_configs.items():
|
1148
|
+
ts = json.dumps({"name": name} | config.to_dict())
|
1149
|
+
if ts not in truth_configs:
|
1150
|
+
truth_configs.append(ts)
|
1151
|
+
con.executemany(
|
1152
|
+
"INSERT INTO truth_config (config) VALUES (?)",
|
1153
|
+
[(item,) for item in truth_configs],
|
1154
|
+
)
|
1022
1155
|
|
1023
1156
|
con.commit()
|
1024
1157
|
con.close()
|