sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +240 -76
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +23 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -17
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +5 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +484 -611
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +931 -669
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
- sonusai-1.0.1.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.2.dist-info/RECORD +0 -128
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/mixture/generation.py
CHANGED
@@ -1,17 +1,14 @@
|
|
1
1
|
# ruff: noqa: S608
|
2
|
-
from .datatypes import AudioT
|
3
|
-
from .datatypes import Augmentation
|
4
|
-
from .datatypes import AugmentationRule
|
5
|
-
from .datatypes import AugmentedTarget
|
6
|
-
from .datatypes import GenMixData
|
7
|
-
from .datatypes import ImpulseResponseFile
|
8
|
-
from .datatypes import Mixture
|
9
|
-
from .datatypes import NoiseFile
|
10
|
-
from .datatypes import SpectralMask
|
11
|
-
from .datatypes import Target
|
12
|
-
from .datatypes import TargetFile
|
13
|
-
from .datatypes import UniversalSNRGenerator
|
14
2
|
from .mixdb import MixtureDatabase
|
3
|
+
from ..datatypes import AudioT
|
4
|
+
from ..datatypes import Effects
|
5
|
+
from ..datatypes import GenMixData
|
6
|
+
from ..datatypes import ImpulseResponseFile
|
7
|
+
from ..datatypes import Mixture
|
8
|
+
from ..datatypes import Source
|
9
|
+
from ..datatypes import SourceFile
|
10
|
+
from ..datatypes import SourcesAudioT
|
11
|
+
from ..datatypes import UniversalSNRGenerator
|
15
12
|
|
16
13
|
|
17
14
|
def config_file(location: str) -> str:
|
@@ -34,47 +31,62 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
34
31
|
con.execute("""
|
35
32
|
CREATE TABLE truth_parameters(
|
36
33
|
id INTEGER PRIMARY KEY NOT NULL,
|
34
|
+
category TEXT NOT NULL,
|
37
35
|
name TEXT NOT NULL,
|
38
36
|
parameters INTEGER)
|
39
37
|
""")
|
40
38
|
|
41
39
|
con.execute("""
|
42
|
-
CREATE TABLE
|
40
|
+
CREATE TABLE source_file (
|
43
41
|
id INTEGER PRIMARY KEY NOT NULL,
|
42
|
+
category TEXT NOT NULL,
|
43
|
+
class_indices TEXT,
|
44
|
+
level_type TEXT NOT NULL,
|
44
45
|
name TEXT NOT NULL,
|
45
46
|
samples INTEGER NOT NULL,
|
46
|
-
class_indices TEXT NOT NULL,
|
47
|
-
level_type TEXT NOT NULL,
|
48
47
|
speaker_id INTEGER,
|
49
48
|
FOREIGN KEY(speaker_id) REFERENCES speaker (id))
|
50
49
|
""")
|
51
50
|
|
52
51
|
con.execute("""
|
53
|
-
CREATE TABLE
|
52
|
+
CREATE TABLE ir_file (
|
54
53
|
id INTEGER PRIMARY KEY NOT NULL,
|
55
|
-
|
54
|
+
delay INTEGER NOT NULL,
|
55
|
+
name TEXT NOT NULL)
|
56
56
|
""")
|
57
57
|
|
58
58
|
con.execute("""
|
59
|
-
CREATE TABLE
|
59
|
+
CREATE TABLE ir_tag (
|
60
60
|
id INTEGER PRIMARY KEY NOT NULL,
|
61
|
-
|
62
|
-
|
61
|
+
tag TEXT NOT NULL UNIQUE)
|
62
|
+
""")
|
63
|
+
|
64
|
+
con.execute("""
|
65
|
+
CREATE TABLE ir_file_ir_tag (
|
66
|
+
file_id INTEGER NOT NULL,
|
67
|
+
tag_id INTEGER NOT NULL,
|
68
|
+
FOREIGN KEY(file_id) REFERENCES ir_file (id),
|
69
|
+
FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
|
70
|
+
""")
|
71
|
+
|
72
|
+
con.execute("""
|
73
|
+
CREATE TABLE speaker (
|
74
|
+
id INTEGER PRIMARY KEY NOT NULL,
|
75
|
+
parent TEXT NOT NULL)
|
63
76
|
""")
|
64
77
|
|
65
78
|
con.execute("""
|
66
79
|
CREATE TABLE top (
|
67
80
|
id INTEGER PRIMARY KEY NOT NULL,
|
68
|
-
version INTEGER NOT NULL,
|
69
81
|
asr_configs TEXT NOT NULL,
|
70
82
|
class_balancing BOOLEAN NOT NULL,
|
71
83
|
feature TEXT NOT NULL,
|
72
|
-
|
84
|
+
mixid_width INTEGER NOT NULL,
|
73
85
|
num_classes INTEGER NOT NULL,
|
74
86
|
seed INTEGER NOT NULL,
|
75
|
-
mixid_width INTEGER NOT NULL,
|
76
87
|
speaker_metadata_tiers TEXT NOT NULL,
|
77
|
-
textgrid_metadata_tiers TEXT NOT NULL
|
88
|
+
textgrid_metadata_tiers TEXT NOT NULL,
|
89
|
+
version INTEGER NOT NULL)
|
78
90
|
""")
|
79
91
|
|
80
92
|
con.execute("""
|
@@ -89,64 +101,54 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
89
101
|
threshold FLOAT NOT NULL)
|
90
102
|
""")
|
91
103
|
|
92
|
-
con.execute("""
|
93
|
-
CREATE TABLE impulse_response_file (
|
94
|
-
id INTEGER PRIMARY KEY NOT NULL,
|
95
|
-
file TEXT NOT NULL,
|
96
|
-
tags TEXT NOT NULL,
|
97
|
-
delay INTEGER NOT NULL)
|
98
|
-
""")
|
99
|
-
|
100
104
|
con.execute("""
|
101
105
|
CREATE TABLE spectral_mask (
|
102
106
|
id INTEGER PRIMARY KEY NOT NULL,
|
103
107
|
f_max_width INTEGER NOT NULL,
|
104
108
|
f_num INTEGER NOT NULL,
|
109
|
+
t_max_percent INTEGER NOT NULL,
|
105
110
|
t_max_width INTEGER NOT NULL,
|
106
|
-
t_num INTEGER NOT NULL
|
107
|
-
t_max_percent INTEGER NOT NULL)
|
111
|
+
t_num INTEGER NOT NULL)
|
108
112
|
""")
|
109
113
|
|
110
114
|
con.execute("""
|
111
|
-
CREATE TABLE
|
112
|
-
|
113
|
-
truth_config_id INTEGER,
|
114
|
-
FOREIGN KEY(
|
115
|
+
CREATE TABLE source_file_truth_config (
|
116
|
+
source_file_id INTEGER NOT NULL,
|
117
|
+
truth_config_id INTEGER NOT NULL,
|
118
|
+
FOREIGN KEY(source_file_id) REFERENCES source_file (id),
|
115
119
|
FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
|
116
120
|
""")
|
117
121
|
|
118
122
|
con.execute("""
|
119
|
-
CREATE TABLE
|
123
|
+
CREATE TABLE source (
|
120
124
|
id INTEGER PRIMARY KEY NOT NULL,
|
125
|
+
effects TEXT NOT NULL,
|
121
126
|
file_id INTEGER NOT NULL,
|
122
|
-
|
123
|
-
|
127
|
+
pre_tempo FLOAT NOT NULL,
|
128
|
+
repeat BOOLEAN NOT NULL,
|
129
|
+
snr FLOAT NOT NULL,
|
130
|
+
snr_gain FLOAT NOT NULL,
|
131
|
+
snr_random BOOLEAN NOT NULL,
|
132
|
+
start INTEGER NOT NULL,
|
133
|
+
FOREIGN KEY(file_id) REFERENCES source_file (id))
|
124
134
|
""")
|
125
135
|
|
126
136
|
con.execute("""
|
127
137
|
CREATE TABLE mixture (
|
128
138
|
id INTEGER PRIMARY KEY NOT NULL,
|
129
|
-
name
|
130
|
-
noise_file_id INTEGER NOT NULL,
|
131
|
-
noise_augmentation TEXT NOT NULL,
|
132
|
-
noise_offset INTEGER NOT NULL,
|
133
|
-
noise_snr_gain FLOAT,
|
134
|
-
random_snr BOOLEAN NOT NULL,
|
135
|
-
snr FLOAT NOT NULL,
|
139
|
+
name TEXT NOT NULL,
|
136
140
|
samples INTEGER NOT NULL,
|
137
141
|
spectral_mask_id INTEGER NOT NULL,
|
138
142
|
spectral_mask_seed INTEGER NOT NULL,
|
139
|
-
target_snr_gain FLOAT,
|
140
|
-
FOREIGN KEY(noise_file_id) REFERENCES noise_file (id),
|
141
143
|
FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
|
142
144
|
""")
|
143
145
|
|
144
146
|
con.execute("""
|
145
|
-
CREATE TABLE
|
146
|
-
mixture_id INTEGER,
|
147
|
-
|
147
|
+
CREATE TABLE mixture_source (
|
148
|
+
mixture_id INTEGER NOT NULL,
|
149
|
+
source_id INTEGER NOT NULL,
|
148
150
|
FOREIGN KEY(mixture_id) REFERENCES mixture (id),
|
149
|
-
FOREIGN KEY(
|
151
|
+
FOREIGN KEY(source_id) REFERENCES source (id))
|
150
152
|
""")
|
151
153
|
|
152
154
|
con.commit()
|
@@ -163,22 +165,21 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
|
|
163
165
|
con = db_connection(location=location, readonly=False, test=test)
|
164
166
|
con.execute(
|
165
167
|
"""
|
166
|
-
INSERT INTO top (id,
|
167
|
-
seed,
|
168
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,
|
168
|
+
INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
|
169
|
+
seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
|
170
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
169
171
|
""",
|
170
172
|
(
|
171
173
|
1,
|
172
|
-
MIXDB_VERSION,
|
173
174
|
json.dumps(config["asr_configs"]),
|
174
175
|
config["class_balancing"],
|
175
176
|
config["feature"],
|
176
|
-
|
177
|
+
0,
|
177
178
|
config["num_classes"],
|
178
179
|
config["seed"],
|
179
|
-
0,
|
180
180
|
"",
|
181
181
|
"",
|
182
|
+
MIXDB_VERSION,
|
182
183
|
),
|
183
184
|
)
|
184
185
|
con.commit()
|
@@ -231,15 +232,15 @@ def populate_spectral_mask_table(location: str, config: dict, test: bool = False
|
|
231
232
|
con = db_connection(location=location, readonly=False, test=test)
|
232
233
|
con.executemany(
|
233
234
|
"""
|
234
|
-
INSERT INTO spectral_mask (f_max_width, f_num, t_max_width, t_num
|
235
|
+
INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
|
235
236
|
""",
|
236
237
|
[
|
237
238
|
(
|
238
239
|
item.f_max_width,
|
239
240
|
item.f_num,
|
241
|
+
item.t_max_percent,
|
240
242
|
item.t_max_width,
|
241
243
|
item.t_num,
|
242
|
-
item.t_max_percent,
|
243
244
|
)
|
244
245
|
for item in get_spectral_masks(config)
|
245
246
|
],
|
@@ -256,10 +257,11 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
|
|
256
257
|
con = db_connection(location=location, readonly=False, test=test)
|
257
258
|
con.executemany(
|
258
259
|
"""
|
259
|
-
INSERT INTO truth_parameters (name, parameters) VALUES (?, ?)
|
260
|
+
INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
|
260
261
|
""",
|
261
262
|
[
|
262
263
|
(
|
264
|
+
item.category,
|
263
265
|
item.name,
|
264
266
|
item.parameters,
|
265
267
|
)
|
@@ -270,40 +272,41 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
|
|
270
272
|
con.close()
|
271
273
|
|
272
274
|
|
273
|
-
def
|
274
|
-
"""Populate
|
275
|
+
def populate_source_file_table(location: str, files: list[SourceFile], test: bool = False) -> None:
|
276
|
+
"""Populate source file table"""
|
275
277
|
import json
|
276
278
|
from pathlib import Path
|
277
279
|
|
278
280
|
from .mixdb import db_connection
|
279
281
|
|
280
|
-
_populate_truth_config_table(location,
|
281
|
-
_populate_speaker_table(location,
|
282
|
+
_populate_truth_config_table(location, files, test)
|
283
|
+
_populate_speaker_table(location, files, test)
|
282
284
|
|
283
285
|
con = db_connection(location=location, readonly=False, test=test)
|
284
286
|
|
285
287
|
cur = con.cursor()
|
286
288
|
textgrid_metadata_tiers: set[str] = set()
|
287
|
-
for
|
288
|
-
# Get TextGrid tiers for
|
289
|
-
tiers =
|
289
|
+
for file in files:
|
290
|
+
# Get TextGrid tiers for source file and add to collection
|
291
|
+
tiers = _get_textgrid_tiers_from_source_file(file.name)
|
290
292
|
for tier in tiers:
|
291
293
|
textgrid_metadata_tiers.add(tier)
|
292
294
|
|
293
|
-
# Get truth settings for
|
295
|
+
# Get truth settings for file
|
294
296
|
truth_config_ids: list[int] = []
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
297
|
+
if file.truth_configs:
|
298
|
+
for name, config in file.truth_configs.items():
|
299
|
+
ts = json.dumps({"name": name} | config.to_dict())
|
300
|
+
cur.execute(
|
301
|
+
"SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
|
302
|
+
(ts,),
|
303
|
+
)
|
304
|
+
truth_config_ids.append(cur.fetchone()[0])
|
305
|
+
|
306
|
+
# Get speaker_id for source file
|
304
307
|
cur.execute(
|
305
308
|
"SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
|
306
|
-
(Path(
|
309
|
+
(Path(file.name).parent.as_posix(),),
|
307
310
|
)
|
308
311
|
result = cur.fetchone()
|
309
312
|
speaker_id = None
|
@@ -312,20 +315,24 @@ def populate_target_file_table(location: str, target_files: list[TargetFile], te
|
|
312
315
|
|
313
316
|
# Add entry
|
314
317
|
cur.execute(
|
315
|
-
"
|
318
|
+
"""
|
319
|
+
INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
|
320
|
+
VALUES (?, ?, ?, ?, ?, ?)
|
321
|
+
""",
|
316
322
|
(
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
323
|
+
file.category,
|
324
|
+
json.dumps(file.class_indices),
|
325
|
+
file.level_type,
|
326
|
+
file.name,
|
327
|
+
file.samples,
|
321
328
|
speaker_id,
|
322
329
|
),
|
323
330
|
)
|
324
|
-
|
331
|
+
source_file_id = cur.lastrowid
|
325
332
|
for truth_config_id in truth_config_ids:
|
326
333
|
cur.execute(
|
327
|
-
"INSERT INTO
|
328
|
-
(
|
334
|
+
"INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
|
335
|
+
(source_file_id, truth_config_id),
|
329
336
|
)
|
330
337
|
|
331
338
|
# Update textgrid_metadata_tiers in the top table
|
@@ -338,47 +345,47 @@ def populate_target_file_table(location: str, target_files: list[TargetFile], te
|
|
338
345
|
con.close()
|
339
346
|
|
340
347
|
|
341
|
-
def
|
342
|
-
"""Populate
|
348
|
+
def populate_impulse_response_file_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
|
349
|
+
"""Populate impulse response file table"""
|
343
350
|
from .mixdb import db_connection
|
344
351
|
|
345
|
-
|
346
|
-
con.executemany(
|
347
|
-
"INSERT INTO noise_file (name, samples) VALUES (?, ?)",
|
348
|
-
[(noise_file.name, noise_file.samples) for noise_file in noise_files],
|
349
|
-
)
|
350
|
-
con.commit()
|
351
|
-
con.close()
|
352
|
-
|
352
|
+
_populate_impulse_response_tag_table(location, files, test)
|
353
353
|
|
354
|
-
|
355
|
-
location: str, impulse_response_files: list[ImpulseResponseFile], test: bool = False
|
356
|
-
) -> None:
|
357
|
-
"""Populate impulse response file table"""
|
358
|
-
import json
|
354
|
+
con = db_connection(location=location, readonly=False, test=test)
|
359
355
|
|
360
|
-
|
356
|
+
cur = con.cursor()
|
357
|
+
for file in files:
|
358
|
+
# Get tags for file
|
359
|
+
tag_ids: list[int] = []
|
360
|
+
for tag in file.tags:
|
361
|
+
cur.execute(
|
362
|
+
"SELECT ir_tag.id FROM ir_tag WHERE ? = ir_tag.tag",
|
363
|
+
(tag,),
|
364
|
+
)
|
365
|
+
tag_ids.append(cur.fetchone()[0])
|
361
366
|
|
362
|
-
|
363
|
-
|
364
|
-
"INSERT INTO impulse_response_file (file, tags, delay) VALUES (?, ?, ?)",
|
365
|
-
[
|
367
|
+
cur.execute(
|
368
|
+
"INSERT INTO ir_file (delay, name) VALUES (?, ?)",
|
366
369
|
(
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
+
file.delay,
|
371
|
+
file.name,
|
372
|
+
),
|
373
|
+
)
|
374
|
+
|
375
|
+
file_id = cur.lastrowid
|
376
|
+
for tag_id in tag_ids:
|
377
|
+
cur.execute(
|
378
|
+
"INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)",
|
379
|
+
(file_id, tag_id),
|
370
380
|
)
|
371
|
-
|
372
|
-
],
|
373
|
-
)
|
381
|
+
|
374
382
|
con.commit()
|
375
383
|
con.close()
|
376
384
|
|
377
385
|
|
378
386
|
def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
|
379
387
|
"""Update the mixid width"""
|
380
|
-
from
|
381
|
-
|
388
|
+
from ..utils.max_text_width import max_text_width
|
382
389
|
from .mixdb import db_connection
|
383
390
|
|
384
391
|
con = db_connection(location=location, readonly=False, test=test)
|
@@ -391,42 +398,43 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
|
|
391
398
|
|
392
399
|
|
393
400
|
def generate_mixtures(
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
noise_augmentations: list[AugmentationRule],
|
400
|
-
spectral_masks: list[SpectralMask],
|
401
|
-
all_snrs: list[UniversalSNRGenerator],
|
402
|
-
mixups: list[int],
|
403
|
-
num_classes: int,
|
404
|
-
feature_step_samples: int,
|
405
|
-
num_ir: int,
|
406
|
-
) -> tuple[int, int, list[Mixture]]:
|
401
|
+
location: str,
|
402
|
+
config: dict,
|
403
|
+
effects: dict[str, list[Effects]],
|
404
|
+
test: bool = False,
|
405
|
+
) -> list[Mixture]:
|
407
406
|
"""Generate mixtures"""
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
407
|
+
mixdb = MixtureDatabase(location, test)
|
408
|
+
|
409
|
+
effected_sources: dict[str, list[tuple[SourceFile, Effects]]] = {}
|
410
|
+
for category in mixdb.source_files:
|
411
|
+
effected_sources[category] = []
|
412
|
+
for file in mixdb.source_files[category]:
|
413
|
+
for effect in effects[category]:
|
414
|
+
effected_sources[category].append((file, effect))
|
415
|
+
|
416
|
+
mixtures: list[Mixture] = []
|
417
|
+
for noise_mix_rule in config["sources"]["noise"]["mix_rules"]:
|
418
|
+
match noise_mix_rule["mode"]:
|
419
|
+
case "exhaustive":
|
420
|
+
func = _exhaustive_noise_mix
|
421
|
+
case "non-exhaustive":
|
422
|
+
func = _non_exhaustive_noise_mix
|
423
|
+
case "non-combinatorial":
|
424
|
+
func = _non_combinatorial_noise_mix
|
425
|
+
case _:
|
426
|
+
raise ValueError(f"invalid noise mix_rule mode: {noise_mix_rule['mode']}")
|
427
|
+
|
428
|
+
mixtures.extend(
|
429
|
+
func(
|
430
|
+
location=location,
|
431
|
+
config=config,
|
432
|
+
effected_sources=effected_sources,
|
433
|
+
test=test,
|
434
|
+
)
|
435
|
+
)
|
436
|
+
|
437
|
+
return mixtures
|
430
438
|
|
431
439
|
|
432
440
|
def populate_mixture_table(
|
@@ -437,26 +445,33 @@ def populate_mixture_table(
|
|
437
445
|
show_progress: bool = False,
|
438
446
|
) -> None:
|
439
447
|
"""Populate mixture table"""
|
440
|
-
from
|
441
|
-
from
|
442
|
-
|
448
|
+
from .. import logger
|
449
|
+
from ..utils.parallel import track
|
443
450
|
from .helpers import from_mixture
|
444
|
-
from .helpers import
|
451
|
+
from .helpers import from_source
|
445
452
|
from .mixdb import db_connection
|
446
453
|
|
447
454
|
con = db_connection(location=location, readonly=False, test=test)
|
448
455
|
|
449
|
-
# Populate
|
456
|
+
# Populate source table
|
450
457
|
if logging:
|
451
|
-
logger.info("Populating
|
452
|
-
|
458
|
+
logger.info("Populating source table")
|
459
|
+
sources: list[tuple[str, int, float, bool, float, float, bool, int]] = []
|
453
460
|
for mixture in mixtures:
|
454
|
-
for
|
455
|
-
entry =
|
456
|
-
if entry not in
|
457
|
-
|
458
|
-
for
|
459
|
-
con.execute(
|
461
|
+
for source in mixture.all_sources.values():
|
462
|
+
entry = from_source(source)
|
463
|
+
if entry not in sources:
|
464
|
+
sources.append(entry)
|
465
|
+
for source in track(sources, disable=not show_progress):
|
466
|
+
con.execute(
|
467
|
+
"""
|
468
|
+
INSERT INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
|
469
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
470
|
+
""",
|
471
|
+
source,
|
472
|
+
)
|
473
|
+
|
474
|
+
con.commit()
|
460
475
|
|
461
476
|
# Populate mixture table
|
462
477
|
if logging:
|
@@ -465,25 +480,31 @@ def populate_mixture_table(
|
|
465
480
|
m_id = int(mixture.name)
|
466
481
|
con.execute(
|
467
482
|
"""
|
468
|
-
INSERT INTO mixture (id, name,
|
469
|
-
|
470
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
483
|
+
INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
|
484
|
+
VALUES (?, ?, ?, ?, ?)
|
471
485
|
""",
|
472
486
|
(m_id + 1, *from_mixture(mixture)),
|
473
487
|
)
|
474
488
|
|
475
|
-
for
|
476
|
-
|
489
|
+
for source in mixture.all_sources.values():
|
490
|
+
source_id = con.execute(
|
477
491
|
"""
|
478
|
-
SELECT
|
479
|
-
FROM
|
480
|
-
WHERE ? =
|
492
|
+
SELECT source.id
|
493
|
+
FROM source
|
494
|
+
WHERE ? = source.effects
|
495
|
+
AND ? = source.file_id
|
496
|
+
AND ? = source.pre_tempo
|
497
|
+
AND ? = source.repeat
|
498
|
+
AND ? = source.snr
|
499
|
+
AND ? = source.snr_gain
|
500
|
+
AND ? = source.snr_random
|
501
|
+
AND ? = source.start
|
481
502
|
""",
|
482
|
-
|
503
|
+
from_source(source),
|
483
504
|
).fetchone()[0]
|
484
505
|
con.execute(
|
485
|
-
"INSERT INTO
|
486
|
-
(m_id + 1,
|
506
|
+
"INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)",
|
507
|
+
(m_id + 1, source_id),
|
487
508
|
)
|
488
509
|
|
489
510
|
con.commit()
|
@@ -491,525 +512,362 @@ def populate_mixture_table(
|
|
491
512
|
|
492
513
|
|
493
514
|
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
494
|
-
"""Update mixture record with name and gains"""
|
495
|
-
|
496
|
-
from .augmentation import apply_gain
|
497
|
-
from .helpers import get_target
|
498
|
-
|
499
|
-
mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
|
500
|
-
|
501
|
-
noise_audio = _augmented_noise_audio(mixdb, mixture)
|
502
|
-
noise_audio = get_next_noise(audio=noise_audio, offset=mixture.noise_offset, length=mixture.samples)
|
515
|
+
"""Update mixture record with name, samples, and gains"""
|
516
|
+
import numpy as np
|
503
517
|
|
504
|
-
|
505
|
-
|
518
|
+
sources_audio: SourcesAudioT = {}
|
519
|
+
post_audio: SourcesAudioT = {}
|
520
|
+
for category in mixture.all_sources:
|
521
|
+
mixture, sources_audio[category], post_audio[category] = _update_source(mixdb, mixture, category)
|
506
522
|
|
507
|
-
mixture = _initialize_mixture_gains(
|
508
|
-
mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise_audio
|
509
|
-
)
|
523
|
+
mixture = _initialize_mixture_gains(mixdb, mixture, post_audio)
|
510
524
|
|
511
525
|
mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
|
512
526
|
|
513
527
|
if not with_data:
|
514
528
|
return mixture, GenMixData()
|
515
529
|
|
516
|
-
# Apply
|
517
|
-
|
518
|
-
|
530
|
+
# Apply gains
|
531
|
+
post_audio = {
|
532
|
+
category: post_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources
|
533
|
+
}
|
519
534
|
|
520
|
-
#
|
521
|
-
|
522
|
-
|
535
|
+
# Sum sources, noise, and mixture
|
536
|
+
source_audio = np.sum([post_audio[category] for category in mixture.sources], axis=0)
|
537
|
+
noise_audio = post_audio["noise"]
|
538
|
+
mixture_audio = source_audio + noise_audio
|
523
539
|
|
524
540
|
return mixture, GenMixData(
|
525
|
-
|
526
|
-
|
527
|
-
target=target_audio,
|
541
|
+
sources=sources_audio,
|
542
|
+
source=source_audio,
|
528
543
|
noise=noise_audio,
|
544
|
+
mixture=mixture_audio,
|
529
545
|
)
|
530
546
|
|
531
547
|
|
532
|
-
def
|
533
|
-
from .
|
534
|
-
from .
|
535
|
-
|
536
|
-
noise = mixdb.noise_file(mixture.noise.file_id)
|
537
|
-
noise_augmentation = mixture.noise.augmentation
|
548
|
+
def _update_source(mixdb: MixtureDatabase, mixture: Mixture, category: str) -> tuple[Mixture, AudioT, AudioT]:
|
549
|
+
from .effects import apply_effects
|
550
|
+
from .effects import conform_audio_to_length
|
538
551
|
|
539
|
-
|
540
|
-
|
552
|
+
source = mixture.all_sources[category]
|
553
|
+
org_audio = mixdb.read_source_audio(source.file_id)
|
541
554
|
|
542
|
-
|
555
|
+
org_samples = len(org_audio)
|
556
|
+
pre_audio = apply_effects(mixdb, org_audio, source.effects, pre=True, post=False)
|
543
557
|
|
558
|
+
pre_samples = len(pre_audio)
|
559
|
+
mixture.all_sources[category].pre_tempo = org_samples / pre_samples
|
544
560
|
|
545
|
-
|
546
|
-
from .augmentation import apply_augmentation
|
547
|
-
from .augmentation import pad_audio_to_length
|
561
|
+
pre_audio = conform_audio_to_length(pre_audio, mixture.samples, source.repeat, source.start)
|
548
562
|
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
targets_audio.append(
|
553
|
-
apply_augmentation(
|
554
|
-
mixdb=mixdb,
|
555
|
-
audio=target_audio,
|
556
|
-
augmentation=target.augmentation.pre,
|
557
|
-
frame_length=mixdb.feature_step_samples,
|
558
|
-
)
|
559
|
-
)
|
560
|
-
|
561
|
-
mixture.samples = max([len(item) for item in targets_audio])
|
563
|
+
post_audio = apply_effects(mixdb, pre_audio, source.effects, pre=False, post=True)
|
564
|
+
if len(pre_audio) != len(post_audio):
|
565
|
+
raise RuntimeError(f"post-truth effects changed length: {source.effects.post}")
|
562
566
|
|
563
|
-
|
564
|
-
targets_audio[idx] = pad_audio_to_length(audio=targets_audio[idx], length=mixture.samples)
|
567
|
+
return mixture, pre_audio, post_audio
|
565
568
|
|
566
|
-
return mixture, targets_audio
|
567
569
|
|
568
|
-
|
569
|
-
def _initialize_mixture_gains(
|
570
|
-
mixdb: MixtureDatabase,
|
571
|
-
mixture: Mixture,
|
572
|
-
target_audio: AudioT,
|
573
|
-
noise_audio: AudioT,
|
574
|
-
) -> Mixture:
|
570
|
+
def _initialize_mixture_gains(mixdb: MixtureDatabase, mixture: Mixture, sources_audio: SourcesAudioT) -> Mixture:
|
575
571
|
import numpy as np
|
576
572
|
|
577
|
-
from
|
578
|
-
from
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
mixture.noise_snr_gain = 1
|
584
|
-
elif mixture.is_target_only:
|
585
|
-
# Special case for zeroing out noise data
|
586
|
-
mixture.target_snr_gain = 1
|
587
|
-
mixture.noise_snr_gain = 0
|
588
|
-
else:
|
589
|
-
target_level_types = [
|
590
|
-
target_file.level_type for target_file in [mixdb.target_file(target.file_id) for target in mixture.targets]
|
591
|
-
]
|
592
|
-
if not all(level_type == target_level_types[0] for level_type in target_level_types):
|
593
|
-
raise ValueError("Not all target_level_types in mixup are the same")
|
594
|
-
|
595
|
-
level_type = target_level_types[0]
|
573
|
+
from ..utils.asl_p56 import asl_p56
|
574
|
+
from ..utils.db import db_to_linear
|
575
|
+
|
576
|
+
sources_energy: dict[str, float] = {}
|
577
|
+
for category in mixture.all_sources:
|
578
|
+
level_type = mixdb.source_file(mixture.all_sources[category].file_id).level_type
|
596
579
|
match level_type:
|
597
580
|
case "default":
|
598
|
-
|
581
|
+
sources_energy[category] = float(np.mean(np.square(sources_audio[category])))
|
599
582
|
case "speech":
|
600
|
-
|
583
|
+
sources_energy[category] = asl_p56(sources_audio[category])
|
601
584
|
case _:
|
602
585
|
raise ValueError(f"Unknown level_type: {level_type}")
|
603
586
|
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
mixture.
|
613
|
-
|
614
|
-
|
615
|
-
mixture.
|
616
|
-
|
587
|
+
# Initialize all gains to 1
|
588
|
+
for category in mixture.all_sources:
|
589
|
+
mixture.all_sources[category].snr_gain = 1
|
590
|
+
|
591
|
+
# Resolve gains
|
592
|
+
for category in mixture.all_sources:
|
593
|
+
if mixture.is_noise_only and category != "noise":
|
594
|
+
# Special case for zeroing out source data
|
595
|
+
mixture.all_sources[category].snr_gain = 0
|
596
|
+
elif mixture.is_source_only and category == "noise":
|
597
|
+
# Special case for zeroing out noise data
|
598
|
+
mixture.all_sources[category].snr_gain = 0
|
599
|
+
elif category != "primary":
|
600
|
+
if sources_energy["primary"] == 0:
|
601
|
+
# Avoid divide-by-zero
|
602
|
+
mixture.all_sources[category].snr_gain = 1
|
603
|
+
else:
|
604
|
+
mixture.all_sources[category].snr_gain = float(
|
605
|
+
np.sqrt(sources_energy["primary"] / sources_energy[category])
|
606
|
+
) / db_to_linear(mixture.all_sources[category].snr)
|
607
|
+
|
608
|
+
# Normalize gains
|
609
|
+
max_snr_gain = max([source.snr_gain for source in mixture.all_sources.values()])
|
610
|
+
for category in mixture.all_sources:
|
611
|
+
mixture.all_sources[category].snr_gain = mixture.all_sources[category].snr_gain / max_snr_gain
|
617
612
|
|
618
613
|
# Check for clipping in mixture
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
max_abs_audio = max(abs(mixture_audio))
|
614
|
+
mixture_audio = np.sum(
|
615
|
+
[sources_audio[category] * mixture.all_sources[category].snr_gain for category in mixture.all_sources], axis=0
|
616
|
+
)
|
617
|
+
max_abs_audio = float(np.max(np.abs(mixture_audio)))
|
623
618
|
clip_level = db_to_linear(-0.25)
|
624
619
|
if max_abs_audio > clip_level:
|
625
|
-
# Clipping occurred; lower gains to bring audio within +/-1
|
626
620
|
gain_adjustment = clip_level / max_abs_audio
|
627
|
-
mixture.
|
628
|
-
|
621
|
+
for category in mixture.all_sources:
|
622
|
+
mixture.all_sources[category].snr_gain *= gain_adjustment
|
623
|
+
|
624
|
+
# To improve repeatability, round results
|
625
|
+
for category in mixture.all_sources:
|
626
|
+
mixture.all_sources[category].snr_gain = round(mixture.all_sources[category].snr_gain, ndigits=5)
|
629
627
|
|
630
|
-
mixture.target_snr_gain = round(mixture.target_snr_gain, ndigits=5)
|
631
|
-
mixture.noise_snr_gain = round(mixture.noise_snr_gain, ndigits=5)
|
632
628
|
return mixture
|
633
629
|
|
634
630
|
|
635
631
|
def _exhaustive_noise_mix(
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
all_snrs: list[UniversalSNRGenerator],
|
643
|
-
mixups: list[int],
|
644
|
-
num_classes: int,
|
645
|
-
feature_step_samples: int,
|
646
|
-
num_ir: int,
|
647
|
-
) -> tuple[int, int, list[Mixture]]:
|
648
|
-
"""Use every noise/augmentation with every target/augmentation+interferences/augmentation"""
|
632
|
+
location: str,
|
633
|
+
config: dict,
|
634
|
+
effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
|
635
|
+
test: bool = False,
|
636
|
+
) -> list[Mixture]:
|
637
|
+
"""Use every noise/effect with every source/effect+interferences/effect"""
|
649
638
|
from random import randint
|
650
639
|
|
651
640
|
import numpy as np
|
652
641
|
|
653
|
-
from
|
654
|
-
from
|
655
|
-
from .
|
656
|
-
from .
|
657
|
-
from .datatypes import UniversalSNR
|
658
|
-
from .targets import get_augmented_target_ids_for_mixup
|
642
|
+
from ..datatypes import Mixture
|
643
|
+
from ..datatypes import UniversalSNR
|
644
|
+
from .effects import effects_from_rules
|
645
|
+
from .effects import estimate_effected_length
|
659
646
|
|
660
|
-
|
661
|
-
|
662
|
-
used_noise_samples = 0
|
663
|
-
|
664
|
-
augmented_target_ids_for_mixups = [
|
665
|
-
get_augmented_target_ids_for_mixup(
|
666
|
-
augmented_targets=augmented_targets,
|
667
|
-
targets=target_files,
|
668
|
-
target_augmentations=target_augmentations,
|
669
|
-
mixup=mixup,
|
670
|
-
num_classes=num_classes,
|
671
|
-
)
|
672
|
-
for mixup in mixups
|
673
|
-
]
|
647
|
+
mixdb = MixtureDatabase(location, test)
|
648
|
+
snrs = get_all_snrs_from_config(config)
|
674
649
|
|
650
|
+
m_id = 0
|
675
651
|
mixtures: list[Mixture] = []
|
676
|
-
for
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
noise_length = estimate_augmented_length_from_length(
|
681
|
-
length=noise_files[noise_file_id].samples,
|
682
|
-
tempo=noise_augmentation.pre.tempo,
|
683
|
-
)
|
652
|
+
for noise_file, noise_rule in effected_sources["noise"]:
|
653
|
+
noise_start = 0
|
654
|
+
noise_effect = effects_from_rules(mixdb, noise_rule)
|
655
|
+
noise_length = estimate_effected_length(noise_file.samples, noise_effect)
|
684
656
|
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
augmented_target_ids=augmented_target_ids,
|
689
|
-
augmented_targets=augmented_targets,
|
690
|
-
target_files=target_files,
|
691
|
-
target_augmentations=target_augmentations,
|
692
|
-
feature_step_samples=feature_step_samples,
|
693
|
-
num_ir=num_ir,
|
694
|
-
)
|
657
|
+
for primary_file, primary_rule in effected_sources["primary"]:
|
658
|
+
primary_effect = effects_from_rules(mixdb, primary_rule)
|
659
|
+
primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
|
695
660
|
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
661
|
+
for spectral_mask_id in range(len(config["spectral_masks"])):
|
662
|
+
for snr in snrs["noise"]:
|
663
|
+
mixtures.append(
|
664
|
+
Mixture(
|
665
|
+
name=str(m_id),
|
666
|
+
all_sources={
|
667
|
+
"primary": Source(
|
668
|
+
file_id=primary_file.id,
|
669
|
+
effects=primary_effect,
|
670
|
+
),
|
671
|
+
"noise": Source(
|
672
|
+
file_id=noise_file.id,
|
673
|
+
effects=noise_effect,
|
674
|
+
start=noise_start,
|
675
|
+
repeat=True,
|
705
676
|
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
677
|
+
),
|
678
|
+
},
|
679
|
+
samples=primary_length,
|
680
|
+
spectral_mask_id=spectral_mask_id + 1,
|
681
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
682
|
+
)
|
683
|
+
)
|
684
|
+
noise_start = int((noise_start + primary_length) % noise_length)
|
685
|
+
m_id += 1
|
714
686
|
|
715
|
-
return
|
687
|
+
return mixtures
|
716
688
|
|
717
689
|
|
718
690
|
def _non_exhaustive_noise_mix(
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
mixups: list[int],
|
727
|
-
num_classes: int,
|
728
|
-
feature_step_samples: int,
|
729
|
-
num_ir: int,
|
730
|
-
) -> tuple[int, int, list[Mixture]]:
|
731
|
-
"""Cycle through every target/augmentation+interferences/augmentation without necessarily using all
|
732
|
-
noise/augmentation combinations (reduced data set).
|
691
|
+
location: str,
|
692
|
+
config: dict,
|
693
|
+
effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
|
694
|
+
test: bool = False,
|
695
|
+
) -> list[Mixture]:
|
696
|
+
"""Cycle through every source/effect+interferences/effect without necessarily using all
|
697
|
+
noise/effect combinations (reduced data set).
|
733
698
|
"""
|
734
699
|
from random import randint
|
735
700
|
|
736
701
|
import numpy as np
|
737
702
|
|
738
|
-
from
|
739
|
-
from
|
740
|
-
from .
|
741
|
-
from .
|
703
|
+
from ..datatypes import Mixture
|
704
|
+
from ..datatypes import UniversalSNR
|
705
|
+
from .effects import effects_from_rules
|
706
|
+
from .effects import estimate_effected_length
|
742
707
|
|
743
|
-
|
744
|
-
|
745
|
-
used_noise_samples = 0
|
746
|
-
noise_file_id = None
|
747
|
-
noise_augmentation_id = None
|
748
|
-
noise_offset = None
|
749
|
-
|
750
|
-
augmented_target_indices_for_mixups = [
|
751
|
-
get_augmented_target_ids_for_mixup(
|
752
|
-
augmented_targets=augmented_targets,
|
753
|
-
targets=target_files,
|
754
|
-
target_augmentations=target_augmentations,
|
755
|
-
mixup=mixup,
|
756
|
-
num_classes=num_classes,
|
757
|
-
)
|
758
|
-
for mixup in mixups
|
759
|
-
]
|
708
|
+
mixdb = MixtureDatabase(location, test)
|
709
|
+
snrs = get_all_snrs_from_config(config)
|
760
710
|
|
761
|
-
|
762
|
-
for mixup in augmented_target_indices_for_mixups:
|
763
|
-
for augmented_target_indices in mixup:
|
764
|
-
targets, target_length = _get_target_info(
|
765
|
-
augmented_target_ids=augmented_target_indices,
|
766
|
-
augmented_targets=augmented_targets,
|
767
|
-
target_files=target_files,
|
768
|
-
target_augmentations=target_augmentations,
|
769
|
-
feature_step_samples=feature_step_samples,
|
770
|
-
num_ir=num_ir,
|
771
|
-
)
|
772
|
-
|
773
|
-
for spectral_mask_id in range(len(spectral_masks)):
|
774
|
-
for snr in all_snrs:
|
775
|
-
(
|
776
|
-
noise_file_id,
|
777
|
-
noise_augmentation_id,
|
778
|
-
noise_augmentation,
|
779
|
-
noise_offset,
|
780
|
-
) = _get_next_noise_offset(
|
781
|
-
noise_file_id=noise_file_id,
|
782
|
-
noise_augmentation_id=noise_augmentation_id,
|
783
|
-
noise_offset=noise_offset,
|
784
|
-
target_length=target_length,
|
785
|
-
noise_files=noise_files,
|
786
|
-
noise_augmentations=noise_augmentations,
|
787
|
-
num_ir=num_ir,
|
788
|
-
)
|
789
|
-
used_noise_samples += target_length
|
711
|
+
next_noise = NextNoise(mixdb, effected_sources["noise"])
|
790
712
|
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
)
|
713
|
+
m_id = 0
|
714
|
+
mixtures: list[Mixture] = []
|
715
|
+
for primary_file, primary_rule in effected_sources["primary"]:
|
716
|
+
primary_effect = effects_from_rules(mixdb, primary_rule)
|
717
|
+
primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
|
718
|
+
|
719
|
+
for spectral_mask_id in range(len(config["spectral_masks"])):
|
720
|
+
for snr in snrs["noise"]:
|
721
|
+
noise_file_id, noise_effect, noise_start = next_noise.generate(primary_file.samples)
|
722
|
+
|
723
|
+
mixtures.append(
|
724
|
+
Mixture(
|
725
|
+
name=str(m_id),
|
726
|
+
all_sources={
|
727
|
+
"primary": Source(
|
728
|
+
file_id=primary_file.id,
|
729
|
+
effects=primary_effect,
|
730
|
+
),
|
731
|
+
"noise": Source(
|
732
|
+
file_id=noise_file_id,
|
733
|
+
effects=noise_effect,
|
734
|
+
start=noise_start,
|
735
|
+
repeat=True,
|
736
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
737
|
+
),
|
738
|
+
},
|
739
|
+
samples=primary_length,
|
740
|
+
spectral_mask_id=spectral_mask_id + 1,
|
741
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
804
742
|
)
|
805
|
-
|
743
|
+
)
|
744
|
+
m_id += 1
|
806
745
|
|
807
|
-
return
|
746
|
+
return mixtures
|
808
747
|
|
809
748
|
|
810
749
|
def _non_combinatorial_noise_mix(
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
num_classes: int,
|
820
|
-
feature_step_samples: int,
|
821
|
-
num_ir: int,
|
822
|
-
) -> tuple[int, int, list[Mixture]]:
|
823
|
-
"""Combine a target/augmentation+interferences/augmentation with a single cut of a noise/augmentation
|
824
|
-
non-exhaustively (each target/augmentation+interferences/augmentation does not use each noise/augmentation).
|
825
|
-
Cut has random start and loop back to beginning if end of noise/augmentation is reached.
|
750
|
+
location: str,
|
751
|
+
config: dict,
|
752
|
+
effected_sources: dict[str, list[tuple[SourceFile, Effects]]],
|
753
|
+
test: bool = False,
|
754
|
+
) -> list[Mixture]:
|
755
|
+
"""Combine a source/effect+interferences/effect with a single cut of a noise/effect
|
756
|
+
non-exhaustively (each source/effect+interferences/effect does not use each noise/effect).
|
757
|
+
Cut has random start and loop back to beginning if end of noise/effect is reached.
|
826
758
|
"""
|
827
759
|
from random import choice
|
828
760
|
from random import randint
|
829
761
|
|
830
762
|
import numpy as np
|
831
763
|
|
832
|
-
from
|
833
|
-
from
|
834
|
-
from .
|
835
|
-
from .
|
764
|
+
from ..datatypes import Mixture
|
765
|
+
from ..datatypes import UniversalSNR
|
766
|
+
from .effects import effects_from_rules
|
767
|
+
from .effects import estimate_effected_length
|
836
768
|
|
837
|
-
|
838
|
-
|
839
|
-
used_noise_samples = 0
|
840
|
-
noise_file_id = None
|
841
|
-
noise_augmentation_id = None
|
842
|
-
|
843
|
-
augmented_target_indices_for_mixups = [
|
844
|
-
get_augmented_target_ids_for_mixup(
|
845
|
-
augmented_targets=augmented_targets,
|
846
|
-
targets=target_files,
|
847
|
-
target_augmentations=target_augmentations,
|
848
|
-
mixup=mixup,
|
849
|
-
num_classes=num_classes,
|
850
|
-
)
|
851
|
-
for mixup in mixups
|
852
|
-
]
|
769
|
+
mixdb = MixtureDatabase(location, test)
|
770
|
+
snrs = get_all_snrs_from_config(config)
|
853
771
|
|
772
|
+
m_id = 0
|
773
|
+
noise_id = 0
|
854
774
|
mixtures: list[Mixture] = []
|
855
|
-
for
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
775
|
+
for primary_file, primary_rule in effected_sources["primary"]:
|
776
|
+
primary_effect = effects_from_rules(mixdb, primary_rule)
|
777
|
+
primary_length = estimate_effected_length(primary_file.samples, primary_effect, mixdb.feature_step_samples)
|
778
|
+
|
779
|
+
for spectral_mask_id in range(len(config["spectral_masks"])):
|
780
|
+
for snr in snrs["noise"]:
|
781
|
+
noise_file, noise_rule = effected_sources["noise"][noise_id]
|
782
|
+
noise_effect = effects_from_rules(mixdb, noise_rule)
|
783
|
+
noise_length = estimate_effected_length(noise_file.samples, noise_effect)
|
784
|
+
|
785
|
+
mixtures.append(
|
786
|
+
Mixture(
|
787
|
+
name=str(m_id),
|
788
|
+
all_sources={
|
789
|
+
"primary": Source(
|
790
|
+
file_id=primary_file.id,
|
791
|
+
effects=primary_effect,
|
792
|
+
),
|
793
|
+
"noise": Source(
|
794
|
+
file_id=noise_file.id,
|
795
|
+
effects=noise_effect,
|
796
|
+
start=choice(range(noise_length)), # noqa: S311
|
797
|
+
repeat=True,
|
798
|
+
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
799
|
+
),
|
800
|
+
},
|
801
|
+
samples=primary_length,
|
802
|
+
spectral_mask_id=spectral_mask_id + 1,
|
803
|
+
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
879
804
|
)
|
880
|
-
|
805
|
+
)
|
806
|
+
noise_id = (noise_id + 1) % len(effected_sources["noise"])
|
807
|
+
m_id += 1
|
881
808
|
|
882
|
-
|
809
|
+
return mixtures
|
883
810
|
|
884
|
-
mixtures.append(
|
885
|
-
Mixture(
|
886
|
-
targets=targets,
|
887
|
-
name=str(m_id),
|
888
|
-
noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
|
889
|
-
noise_offset=choice(range(noise_length)), # noqa: S311
|
890
|
-
samples=target_length,
|
891
|
-
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
892
|
-
spectral_mask_id=spectral_mask_id + 1,
|
893
|
-
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
894
|
-
)
|
895
|
-
)
|
896
|
-
m_id += 1
|
897
811
|
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
noise_file_id: int | None,
|
903
|
-
noise_augmentation_id: int | None,
|
904
|
-
noise_files: list[NoiseFile],
|
905
|
-
noise_augmentations: list[AugmentationRule],
|
906
|
-
num_ir: int,
|
907
|
-
) -> tuple[int, int, Augmentation, int]:
|
908
|
-
from .augmentation import augmentation_from_rule
|
909
|
-
from .augmentation import estimate_augmented_length_from_length
|
910
|
-
|
911
|
-
if noise_file_id is None or noise_augmentation_id is None:
|
912
|
-
noise_file_id = 0
|
913
|
-
noise_augmentation_id = 0
|
914
|
-
else:
|
915
|
-
noise_augmentation_id += 1
|
916
|
-
if noise_augmentation_id == len(noise_augmentations):
|
917
|
-
noise_augmentation_id = 0
|
918
|
-
noise_file_id += 1
|
919
|
-
if noise_file_id == len(noise_files):
|
920
|
-
noise_file_id = 0
|
921
|
-
|
922
|
-
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
|
923
|
-
noise_length = estimate_augmented_length_from_length(
|
924
|
-
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
|
925
|
-
)
|
926
|
-
return noise_file_id, noise_augmentation_id, noise_augmentation, noise_length
|
927
|
-
|
928
|
-
|
929
|
-
def _get_next_noise_offset(
|
930
|
-
noise_file_id: int | None,
|
931
|
-
noise_augmentation_id: int | None,
|
932
|
-
noise_offset: int | None,
|
933
|
-
target_length: int,
|
934
|
-
noise_files: list[NoiseFile],
|
935
|
-
noise_augmentations: list[AugmentationRule],
|
936
|
-
num_ir: int,
|
937
|
-
) -> tuple[int, int, Augmentation, int]:
|
938
|
-
from .augmentation import augmentation_from_rule
|
939
|
-
from .augmentation import estimate_augmented_length_from_length
|
940
|
-
|
941
|
-
if noise_file_id is None or noise_augmentation_id is None or noise_offset is None:
|
942
|
-
noise_file_id = 0
|
943
|
-
noise_augmentation_id = 0
|
944
|
-
noise_offset = 0
|
945
|
-
|
946
|
-
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_file_id], num_ir)
|
947
|
-
noise_length = estimate_augmented_length_from_length(
|
948
|
-
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
|
949
|
-
)
|
950
|
-
if noise_offset + target_length >= noise_length:
|
951
|
-
if noise_offset == 0:
|
952
|
-
raise ValueError("Length of target audio exceeds length of noise audio")
|
953
|
-
|
954
|
-
noise_offset = 0
|
955
|
-
noise_augmentation_id += 1
|
956
|
-
if noise_augmentation_id == len(noise_augmentations):
|
957
|
-
noise_augmentation_id = 0
|
958
|
-
noise_file_id += 1
|
959
|
-
if noise_file_id == len(noise_files):
|
960
|
-
noise_file_id = 0
|
961
|
-
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
|
962
|
-
|
963
|
-
return noise_file_id, noise_augmentation_id, noise_augmentation, noise_offset
|
964
|
-
|
965
|
-
|
966
|
-
def _get_target_info(
|
967
|
-
augmented_target_ids: list[int],
|
968
|
-
augmented_targets: list[AugmentedTarget],
|
969
|
-
target_files: list[TargetFile],
|
970
|
-
target_augmentations: list[AugmentationRule],
|
971
|
-
feature_step_samples: int,
|
972
|
-
num_ir: int,
|
973
|
-
) -> tuple[list[Target], int]:
|
974
|
-
from .augmentation import augmentation_from_rule
|
975
|
-
from .augmentation import estimate_augmented_length_from_length
|
976
|
-
|
977
|
-
mixups: list[Target] = []
|
978
|
-
target_length = 0
|
979
|
-
for idx in augmented_target_ids:
|
980
|
-
tfi = augmented_targets[idx].target_id
|
981
|
-
target_augmentation_rule = target_augmentations[augmented_targets[idx].target_augmentation_id]
|
982
|
-
target_augmentation = augmentation_from_rule(target_augmentation_rule, num_ir)
|
983
|
-
|
984
|
-
mixups.append(Target(file_id=tfi + 1, augmentation=target_augmentation))
|
985
|
-
|
986
|
-
target_length = max(
|
987
|
-
estimate_augmented_length_from_length(
|
988
|
-
length=target_files[tfi].samples,
|
989
|
-
tempo=target_augmentation.pre.tempo,
|
990
|
-
frame_length=feature_step_samples,
|
991
|
-
),
|
992
|
-
target_length,
|
993
|
-
)
|
994
|
-
return mixups, target_length
|
812
|
+
class NextNoise:
|
813
|
+
def __init__(self, mixdb: MixtureDatabase, effected_noises: list[tuple[SourceFile, Effects]]) -> None:
|
814
|
+
from .effects import effects_from_rules
|
815
|
+
from .effects import estimate_effected_length
|
995
816
|
|
817
|
+
self.mixdb = mixdb
|
818
|
+
self.effected_noises = effected_noises
|
996
819
|
|
997
|
-
|
998
|
-
|
820
|
+
self.noise_start = 0
|
821
|
+
self.noise_id = 0
|
822
|
+
self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
|
823
|
+
self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)
|
999
824
|
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
825
|
+
@property
|
826
|
+
def noise_file(self):
|
827
|
+
return self.effected_noises[self.noise_id][0]
|
828
|
+
|
829
|
+
@property
|
830
|
+
def noise_rule(self):
|
831
|
+
return self.effected_noises[self.noise_id][1]
|
1003
832
|
|
833
|
+
def generate(self, length: int) -> tuple[int, Effects, int]:
|
834
|
+
from .effects import effects_from_rules
|
835
|
+
from .effects import estimate_effected_length
|
836
|
+
|
837
|
+
if self.noise_start + length > self.noise_length:
|
838
|
+
# Not enough samples in current noise
|
839
|
+
if self.noise_start == 0:
|
840
|
+
raise ValueError("Length of primary audio exceeds length of noise audio")
|
841
|
+
|
842
|
+
self.noise_start = 0
|
843
|
+
self.noise_id = (self.noise_id + 1) % len(self.effected_noises)
|
844
|
+
self.noise_effect = effects_from_rules(self.mixdb, self.noise_rule)
|
845
|
+
self.noise_length = estimate_effected_length(self.noise_file.samples, self.noise_effect)
|
846
|
+
noise_start = self.noise_start
|
847
|
+
else:
|
848
|
+
# Current noise has enough samples
|
849
|
+
noise_start = self.noise_start
|
850
|
+
self.noise_start += length
|
1004
851
|
|
1005
|
-
|
852
|
+
return self.noise_file.id, self.noise_effect, noise_start
|
853
|
+
|
854
|
+
|
855
|
+
def get_all_snrs_from_config(config: dict) -> dict[str, list[UniversalSNRGenerator]]:
|
856
|
+
snrs: dict[str, list[UniversalSNRGenerator]] = {}
|
857
|
+
for category in config["sources"]:
|
858
|
+
if category != "primary":
|
859
|
+
snrs[category] = [UniversalSNRGenerator(snr) for snr in config["sources"][category]["snrs"]]
|
860
|
+
return snrs
|
861
|
+
|
862
|
+
|
863
|
+
def _get_textgrid_tiers_from_source_file(file: str) -> list[str]:
|
1006
864
|
from pathlib import Path
|
1007
865
|
|
1008
866
|
from praatio import textgrid
|
1009
867
|
|
1010
|
-
from
|
868
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
1011
869
|
|
1012
|
-
textgrid_file = Path(tokenized_expand(
|
870
|
+
textgrid_file = Path(tokenized_expand(file)[0]).with_suffix(".TextGrid")
|
1013
871
|
if not textgrid_file.exists():
|
1014
872
|
return []
|
1015
873
|
|
@@ -1018,18 +876,18 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
|
|
1018
876
|
return sorted(tg.tierNames)
|
1019
877
|
|
1020
878
|
|
1021
|
-
def _populate_speaker_table(location: str,
|
879
|
+
def _populate_speaker_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
|
1022
880
|
"""Populate speaker table"""
|
1023
881
|
import json
|
1024
882
|
from pathlib import Path
|
1025
883
|
|
1026
884
|
import yaml
|
1027
885
|
|
886
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
1028
887
|
from .mixdb import db_connection
|
1029
|
-
from .tokenized_shell_vars import tokenized_expand
|
1030
888
|
|
1031
889
|
# Determine columns for speaker table
|
1032
|
-
all_parents = {Path(
|
890
|
+
all_parents = {Path(file.name).parent for file in source_files}
|
1033
891
|
speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / "speaker.yml")[0]).exists())
|
1034
892
|
|
1035
893
|
speakers: dict[Path, dict[str, str]] = {}
|
@@ -1072,13 +930,13 @@ def _populate_speaker_table(location: str, target_files: list[TargetFile], test:
|
|
1072
930
|
)
|
1073
931
|
|
1074
932
|
if "speaker_id" in tiers:
|
1075
|
-
con.execute("CREATE INDEX speaker_speaker_id_idx ON
|
933
|
+
con.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
|
1076
934
|
|
1077
935
|
con.commit()
|
1078
936
|
con.close()
|
1079
937
|
|
1080
938
|
|
1081
|
-
def _populate_truth_config_table(location: str,
|
939
|
+
def _populate_truth_config_table(location: str, source_files: list[SourceFile], test: bool = False) -> None:
|
1082
940
|
"""Populate truth_config table"""
|
1083
941
|
import json
|
1084
942
|
|
@@ -1088,8 +946,8 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
|
|
1088
946
|
|
1089
947
|
# Populate truth_config table
|
1090
948
|
truth_configs: list[str] = []
|
1091
|
-
for
|
1092
|
-
for name, config in
|
949
|
+
for file in source_files:
|
950
|
+
for name, config in file.truth_configs.items():
|
1093
951
|
ts = json.dumps({"name": name} | config.to_dict())
|
1094
952
|
if ts not in truth_configs:
|
1095
953
|
truth_configs.append(ts)
|
@@ -1100,3 +958,18 @@ def _populate_truth_config_table(location: str, target_files: list[TargetFile],
|
|
1100
958
|
|
1101
959
|
con.commit()
|
1102
960
|
con.close()
|
961
|
+
|
962
|
+
|
963
|
+
def _populate_impulse_response_tag_table(location: str, files: list[ImpulseResponseFile], test: bool = False) -> None:
|
964
|
+
"""Populate ir_tag table"""
|
965
|
+
from .mixdb import db_connection
|
966
|
+
|
967
|
+
con = db_connection(location=location, readonly=False, test=test)
|
968
|
+
|
969
|
+
con.executemany(
|
970
|
+
"INSERT INTO ir_tag (tag) VALUES (?)",
|
971
|
+
[(tag,) for tag in {tag for file in files for tag in file.tags}],
|
972
|
+
)
|
973
|
+
|
974
|
+
con.commit()
|
975
|
+
con.close()
|