sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +50 -46
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +677 -473
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.8.dist-info/RECORD +0 -125
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py
CHANGED
@@ -76,8 +76,8 @@ generation functions. By default, these are included with the feature data in a
|
|
76
76
|
truth generation is turned on with default settings (see truth section) and a single class, i.e., detecting a single
|
77
77
|
type of sound. The truth format is a single float per class representing the probability of activity/presence, and
|
78
78
|
multi-class truth is possible by specifying the number of classes and either a scalar index or a vector of indices in
|
79
|
-
which to put the truth result. For example, 'num_class: 3' and '
|
80
|
-
with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
|
79
|
+
which to put the truth result. For example, 'num_class: 3' and 'class_indices: [ 2 ]' adds a 1x3 vector to the feature
|
80
|
+
data with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
|
81
81
|
|
82
82
|
The mixture is created with potential data augmentation functions in the following way:
|
83
83
|
1. apply noise augmentation rule
|
@@ -112,6 +112,7 @@ targets:
|
|
112
112
|
will find all .wav files in the specified directories and process them as targets.
|
113
113
|
|
114
114
|
"""
|
115
|
+
|
115
116
|
import signal
|
116
117
|
from dataclasses import dataclass
|
117
118
|
|
@@ -124,7 +125,7 @@ def signal_handler(_sig, _frame):
|
|
124
125
|
|
125
126
|
from sonusai import logger
|
126
127
|
|
127
|
-
logger.info(
|
128
|
+
logger.info("Canceled due to keyboard interrupt")
|
128
129
|
sys.exit(1)
|
129
130
|
|
130
131
|
|
@@ -133,34 +134,34 @@ signal.signal(signal.SIGINT, signal_handler)
|
|
133
134
|
|
134
135
|
@dataclass
|
135
136
|
class MPGlobal:
|
136
|
-
mixdb: MixtureDatabase
|
137
|
-
save_mix: bool
|
138
|
-
save_ft: bool
|
139
|
-
save_segsnr: bool
|
137
|
+
mixdb: MixtureDatabase
|
138
|
+
save_mix: bool
|
139
|
+
save_ft: bool
|
140
|
+
save_segsnr: bool
|
140
141
|
|
141
142
|
|
142
|
-
MP_GLOBAL
|
143
|
+
MP_GLOBAL: MPGlobal
|
143
144
|
|
144
145
|
|
145
|
-
def genmixdb(
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
146
|
+
def genmixdb(
|
147
|
+
location: str,
|
148
|
+
save_mix: bool = False,
|
149
|
+
save_ft: bool = False,
|
150
|
+
save_segsnr: bool = False,
|
151
|
+
logging: bool = True,
|
152
|
+
show_progress: bool = False,
|
153
|
+
test: bool = False,
|
154
|
+
save_json: bool = False,
|
155
|
+
) -> MixtureDatabase:
|
153
156
|
from random import seed
|
154
157
|
|
155
158
|
import yaml
|
156
|
-
from tqdm import tqdm
|
157
159
|
|
158
|
-
from sonusai import SonusAIError
|
159
160
|
from sonusai import logger
|
160
|
-
from sonusai.mixture import AugmentationRule
|
161
|
-
from sonusai.mixture import MixtureDatabase
|
162
161
|
from sonusai.mixture import SAMPLE_BYTES
|
163
162
|
from sonusai.mixture import SAMPLE_RATE
|
163
|
+
from sonusai.mixture import AugmentationRule
|
164
|
+
from sonusai.mixture import MixtureDatabase
|
164
165
|
from sonusai.mixture import balance_targets
|
165
166
|
from sonusai.mixture import generate_mixtures
|
166
167
|
from sonusai.mixture import get_all_snrs_from_config
|
@@ -182,11 +183,13 @@ def genmixdb(location: str,
|
|
182
183
|
from sonusai.mixture import populate_spectral_mask_table
|
183
184
|
from sonusai.mixture import populate_target_file_table
|
184
185
|
from sonusai.mixture import populate_top_table
|
186
|
+
from sonusai.mixture import populate_truth_parameters_table
|
185
187
|
from sonusai.mixture import update_mixid_width
|
186
188
|
from sonusai.utils import dataclass_from_dict
|
187
189
|
from sonusai.utils import human_readable_size
|
188
|
-
from sonusai.utils import
|
190
|
+
from sonusai.utils import par_track
|
189
191
|
from sonusai.utils import seconds_to_hms
|
192
|
+
from sonusai.utils import track
|
190
193
|
|
191
194
|
config = load_config(location)
|
192
195
|
initialize_db(location=location, test=test)
|
@@ -197,113 +200,116 @@ def genmixdb(location: str,
|
|
197
200
|
populate_class_label_table(location, config, test)
|
198
201
|
populate_class_weights_threshold_table(location, config, test)
|
199
202
|
populate_spectral_mask_table(location, config, test)
|
203
|
+
populate_truth_parameters_table(location, config, test)
|
200
204
|
|
201
|
-
seed(config[
|
205
|
+
seed(config["seed"])
|
202
206
|
|
203
207
|
if logging:
|
204
|
-
logger.debug(f
|
205
|
-
logger.debug(
|
208
|
+
logger.debug(f"Seed: {config['seed']}")
|
209
|
+
logger.debug("Configuration:")
|
206
210
|
logger.debug(yaml.dump(config))
|
207
211
|
|
208
212
|
if logging:
|
209
|
-
logger.info(
|
213
|
+
logger.info("Collecting targets")
|
210
214
|
|
211
215
|
target_files = get_target_files(config, show_progress=show_progress)
|
212
216
|
|
213
217
|
if len(target_files) == 0:
|
214
|
-
raise
|
218
|
+
raise RuntimeError("Canceled due to no targets")
|
215
219
|
|
216
220
|
populate_target_file_table(location, target_files, test)
|
217
221
|
|
218
222
|
if logging:
|
219
|
-
logger.debug(
|
223
|
+
logger.debug("List of targets:")
|
220
224
|
logger.debug(yaml.dump([target.name for target in mixdb.target_files], default_flow_style=False))
|
221
|
-
logger.debug(
|
225
|
+
logger.debug("")
|
222
226
|
|
223
227
|
if logging:
|
224
|
-
logger.info(
|
228
|
+
logger.info("Collecting noises")
|
225
229
|
|
226
230
|
noise_files = get_noise_files(config, show_progress=show_progress)
|
227
231
|
|
228
232
|
populate_noise_file_table(location, noise_files, test)
|
229
233
|
|
230
234
|
if logging:
|
231
|
-
logger.debug(
|
235
|
+
logger.debug("List of noises:")
|
232
236
|
logger.debug(yaml.dump([noise.name for noise in mixdb.noise_files], default_flow_style=False))
|
233
|
-
logger.debug(
|
237
|
+
logger.debug("")
|
234
238
|
|
235
239
|
if logging:
|
236
|
-
logger.info(
|
240
|
+
logger.info("Collecting impulse responses")
|
237
241
|
|
238
242
|
impulse_response_files = get_impulse_response_files(config)
|
239
243
|
|
240
244
|
populate_impulse_response_file_table(location, impulse_response_files, test)
|
241
245
|
|
242
246
|
if logging:
|
243
|
-
logger.debug(
|
247
|
+
logger.debug("List of impulse responses:")
|
244
248
|
logger.debug(
|
245
|
-
yaml.dump(
|
246
|
-
|
247
|
-
|
249
|
+
yaml.dump(
|
250
|
+
[entry.file for entry in mixdb.impulse_response_files],
|
251
|
+
default_flow_style=False,
|
252
|
+
)
|
253
|
+
)
|
254
|
+
logger.debug("")
|
248
255
|
|
249
256
|
if logging:
|
250
|
-
logger.info(
|
257
|
+
logger.info("Collecting target augmentations")
|
251
258
|
|
252
|
-
target_augmentations = get_augmentation_rules(
|
253
|
-
|
259
|
+
target_augmentations = get_augmentation_rules(
|
260
|
+
rules=config["target_augmentations"], num_ir=mixdb.num_impulse_response_files
|
261
|
+
)
|
254
262
|
mixups = get_mixups(target_augmentations)
|
255
263
|
|
256
264
|
if logging:
|
257
265
|
for mixup in mixups:
|
258
|
-
logger.debug(f
|
266
|
+
logger.debug(f"Expanded list of target augmentation rules for mixup of {mixup}:")
|
259
267
|
for target_augmentation in get_target_augmentations_for_mixup(target_augmentations, mixup):
|
260
268
|
ta_dict = target_augmentation.to_dict()
|
261
|
-
del ta_dict[
|
262
|
-
logger.debug(f
|
263
|
-
logger.debug(
|
269
|
+
del ta_dict["mixup"]
|
270
|
+
logger.debug(f"- {ta_dict}")
|
271
|
+
logger.debug("")
|
264
272
|
|
265
273
|
if logging:
|
266
|
-
logger.info(
|
274
|
+
logger.info("Collecting noise augmentations")
|
267
275
|
|
268
|
-
noise_augmentations = get_augmentation_rules(
|
269
|
-
|
276
|
+
noise_augmentations = get_augmentation_rules(
|
277
|
+
rules=config["noise_augmentations"], num_ir=mixdb.num_impulse_response_files
|
278
|
+
)
|
270
279
|
|
271
280
|
if logging:
|
272
|
-
logger.debug(
|
281
|
+
logger.debug("Expanded list of noise augmentations:")
|
273
282
|
for noise_augmentation in noise_augmentations:
|
274
283
|
na_dict = noise_augmentation.to_dict()
|
275
|
-
del na_dict[
|
276
|
-
logger.debug(f
|
277
|
-
logger.debug(
|
284
|
+
del na_dict["mixup"]
|
285
|
+
logger.debug(f"- {na_dict}")
|
286
|
+
logger.debug("")
|
278
287
|
|
279
288
|
if logging:
|
280
|
-
logger.debug(f
|
281
|
-
logger.debug(f
|
282
|
-
logger.debug(f
|
283
|
-
logger.debug(
|
289
|
+
logger.debug(f"SNRs: {config['snrs']}\n")
|
290
|
+
logger.debug(f"Random SNRs: {config['random_snrs']}\n")
|
291
|
+
logger.debug(f"Noise mix mode: {mixdb.noise_mix_mode}\n")
|
292
|
+
logger.debug("Spectral masks:")
|
284
293
|
for spectral_mask in mixdb.spectral_masks:
|
285
|
-
logger.debug(f
|
286
|
-
logger.debug(
|
287
|
-
|
288
|
-
if mixdb.truth_mutex and any(mixup > 1 for mixup in mixups):
|
289
|
-
raise SonusAIError(f'Mutex truth mode is not compatible with mixup')
|
294
|
+
logger.debug(f"- {spectral_mask}")
|
295
|
+
logger.debug("")
|
290
296
|
|
291
297
|
if logging:
|
292
|
-
logger.info(
|
298
|
+
logger.info("Collecting augmented targets")
|
293
299
|
|
294
300
|
augmented_targets = get_augmented_targets(target_files, target_augmentations, mixups)
|
295
301
|
|
296
|
-
if config[
|
297
|
-
class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config[
|
302
|
+
if config["class_balancing"]:
|
303
|
+
class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config["class_balancing_augmentation"])
|
298
304
|
augmented_targets, target_augmentations = balance_targets(
|
299
305
|
augmented_targets=augmented_targets,
|
300
306
|
targets=target_files,
|
301
307
|
target_augmentations=target_augmentations,
|
302
308
|
class_balancing_augmentation=class_balancing_augmentation,
|
303
309
|
num_classes=mixdb.num_classes,
|
304
|
-
truth_mutex=mixdb.truth_mutex,
|
305
310
|
num_ir=mixdb.num_impulse_response_files,
|
306
|
-
mixups=mixups
|
311
|
+
mixups=mixups,
|
312
|
+
)
|
307
313
|
|
308
314
|
target_audio_samples = sum([targets.samples for targets in mixdb.target_files])
|
309
315
|
target_audio_duration = target_audio_samples / SAMPLE_RATE
|
@@ -311,13 +317,17 @@ def genmixdb(location: str,
|
|
311
317
|
noise_audio_samples = noise_audio_duration * SAMPLE_RATE
|
312
318
|
|
313
319
|
if logging:
|
314
|
-
logger.info(
|
315
|
-
logger.info(
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
320
|
+
logger.info("")
|
321
|
+
logger.info(
|
322
|
+
f"Target audio: {mixdb.num_target_files} files, "
|
323
|
+
f"{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, "
|
324
|
+
f"{seconds_to_hms(seconds=target_audio_duration)}"
|
325
|
+
)
|
326
|
+
logger.info(
|
327
|
+
f"Noise audio: {mixdb.num_noise_files} files, "
|
328
|
+
f"{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, "
|
329
|
+
f"{seconds_to_hms(seconds=noise_audio_duration)}"
|
330
|
+
)
|
321
331
|
|
322
332
|
used_noise_files, used_noise_samples, mixtures = generate_mixtures(
|
323
333
|
noise_mix_mode=mixdb.noise_mix_mode,
|
@@ -330,41 +340,48 @@ def genmixdb(location: str,
|
|
330
340
|
all_snrs=get_all_snrs_from_config(config),
|
331
341
|
mixups=mixups,
|
332
342
|
num_classes=mixdb.num_classes,
|
333
|
-
truth_mutex=mixdb.truth_mutex,
|
334
343
|
feature_step_samples=mixdb.feature_step_samples,
|
335
|
-
num_ir=mixdb.num_impulse_response_files
|
344
|
+
num_ir=mixdb.num_impulse_response_files,
|
345
|
+
)
|
336
346
|
|
337
347
|
num_mixtures = len(mixtures)
|
338
348
|
update_mixid_width(location, num_mixtures, test)
|
339
349
|
|
340
350
|
if logging:
|
341
|
-
logger.info(
|
342
|
-
logger.info(f
|
351
|
+
logger.info("")
|
352
|
+
logger.info(f"Found {num_mixtures:,} mixtures to process")
|
343
353
|
|
344
354
|
total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
|
345
355
|
|
346
356
|
if logging:
|
347
|
-
log_duration_and_sizes(
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
357
|
+
log_duration_and_sizes(
|
358
|
+
total_duration=total_duration,
|
359
|
+
num_classes=mixdb.num_classes,
|
360
|
+
feature_step_samples=mixdb.feature_step_samples,
|
361
|
+
feature_parameters=mixdb.feature_parameters,
|
362
|
+
stride=mixdb.fg_stride,
|
363
|
+
desc="Estimated",
|
364
|
+
)
|
365
|
+
logger.info(
|
366
|
+
f"Feature shape: "
|
367
|
+
f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
|
368
|
+
f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
|
369
|
+
)
|
370
|
+
logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
|
371
|
+
logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
|
372
|
+
logger.info("")
|
359
373
|
|
360
374
|
# Fill in the details
|
361
375
|
if logging:
|
362
|
-
logger.info(
|
363
|
-
progress =
|
364
|
-
mixtures =
|
365
|
-
|
366
|
-
|
367
|
-
|
376
|
+
logger.info("Generating mixtures")
|
377
|
+
progress = track(total=num_mixtures, disable=not show_progress)
|
378
|
+
mixtures = par_track(
|
379
|
+
_process_mixture,
|
380
|
+
mixtures,
|
381
|
+
progress=progress,
|
382
|
+
initializer=_initializer,
|
383
|
+
initargs=(location, save_mix, save_ft, save_segsnr, test),
|
384
|
+
)
|
368
385
|
progress.close()
|
369
386
|
|
370
387
|
populate_mixture_table(location, mixtures, test)
|
@@ -378,20 +395,22 @@ def genmixdb(location: str,
|
|
378
395
|
noise_samples_percent = (float(used_noise_samples) / float(noise_audio_samples)) * 100
|
379
396
|
|
380
397
|
if logging:
|
381
|
-
log_duration_and_sizes(
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
logger.info(
|
390
|
-
logger.info(
|
398
|
+
log_duration_and_sizes(
|
399
|
+
total_duration=total_duration,
|
400
|
+
num_classes=mixdb.num_classes,
|
401
|
+
feature_step_samples=mixdb.feature_step_samples,
|
402
|
+
feature_parameters=mixdb.feature_parameters,
|
403
|
+
stride=mixdb.fg_stride,
|
404
|
+
desc="Actual",
|
405
|
+
)
|
406
|
+
logger.info("")
|
407
|
+
logger.info(f"Used {noise_files_percent:,.0f}% of noise files")
|
408
|
+
logger.info(f"Used {noise_samples_percent:,.0f}% of noise audio")
|
409
|
+
logger.info("")
|
391
410
|
|
392
411
|
if not test and save_json:
|
393
412
|
if logging:
|
394
|
-
logger.info(f
|
413
|
+
logger.info(f"Writing JSON version of database to {location}")
|
395
414
|
mixdb = MixtureDatabase(location)
|
396
415
|
mixdb.save()
|
397
416
|
|
@@ -399,10 +418,14 @@ def genmixdb(location: str,
|
|
399
418
|
|
400
419
|
|
401
420
|
def _initializer(location: str, save_mix: bool, save_ft: bool, save_segsnr: bool, test: bool) -> None:
|
402
|
-
MP_GLOBAL
|
403
|
-
|
404
|
-
MP_GLOBAL
|
405
|
-
|
421
|
+
global MP_GLOBAL
|
422
|
+
|
423
|
+
MP_GLOBAL = MPGlobal(
|
424
|
+
mixdb=MixtureDatabase(location, test),
|
425
|
+
save_mix=save_mix,
|
426
|
+
save_ft=save_ft,
|
427
|
+
save_segsnr=save_segsnr,
|
428
|
+
)
|
406
429
|
|
407
430
|
|
408
431
|
def _process_mixture(mixture: Mixture) -> Mixture:
|
@@ -410,11 +433,13 @@ def _process_mixture(mixture: Mixture) -> Mixture:
|
|
410
433
|
|
411
434
|
from sonusai.mixture import get_ft
|
412
435
|
from sonusai.mixture import get_segsnr
|
413
|
-
from sonusai.mixture import
|
436
|
+
from sonusai.mixture import get_truth
|
414
437
|
from sonusai.mixture import update_mixture
|
415
|
-
from sonusai.mixture import
|
438
|
+
from sonusai.mixture import write_cached_data
|
416
439
|
from sonusai.mixture import write_mixture_metadata
|
417
440
|
|
441
|
+
global MP_GLOBAL
|
442
|
+
|
418
443
|
with_data = MP_GLOBAL.save_mix or MP_GLOBAL.save_ft
|
419
444
|
mixdb = MP_GLOBAL.mixdb
|
420
445
|
|
@@ -424,31 +449,41 @@ def _process_mixture(mixture: Mixture) -> Mixture:
|
|
424
449
|
write_data: list[tuple[str, Any]] = []
|
425
450
|
|
426
451
|
if MP_GLOBAL.save_mix:
|
427
|
-
write_data.append((
|
428
|
-
write_data.append((
|
429
|
-
write_data.append((
|
452
|
+
write_data.append(("targets", genmix_data.targets))
|
453
|
+
write_data.append(("noise", genmix_data.noise))
|
454
|
+
write_data.append(("mixture", genmix_data.mixture))
|
430
455
|
|
431
456
|
if MP_GLOBAL.save_ft:
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
457
|
+
if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
|
458
|
+
raise RuntimeError("Mixture data was not generated properly")
|
459
|
+
truth_t = get_truth(
|
460
|
+
mixdb=mixdb,
|
461
|
+
mixture=mixture,
|
462
|
+
targets_audio=genmix_data.targets,
|
463
|
+
noise_audio=genmix_data.noise,
|
464
|
+
mixture_audio=genmix_data.mixture,
|
465
|
+
)
|
466
|
+
feature, truth_f = get_ft(
|
467
|
+
mixdb=mixdb,
|
468
|
+
mixture=mixture,
|
469
|
+
mixture_audio=genmix_data.mixture,
|
470
|
+
truth_t=truth_t,
|
471
|
+
)
|
472
|
+
write_data.append(("feature", feature))
|
473
|
+
write_data.append(("truth_f", truth_f))
|
443
474
|
|
444
475
|
if MP_GLOBAL.save_segsnr:
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
476
|
+
if genmix_data.target is None:
|
477
|
+
raise RuntimeError("Target data was not generated properly")
|
478
|
+
segsnr = get_segsnr(
|
479
|
+
mixdb=mixdb,
|
480
|
+
mixture=mixture,
|
481
|
+
target_audio=genmix_data.target,
|
482
|
+
noise=genmix_data.noise,
|
483
|
+
)
|
484
|
+
write_data.append(("segsnr", segsnr))
|
485
|
+
|
486
|
+
write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
|
452
487
|
write_mixture_metadata(mixdb, mixture)
|
453
488
|
|
454
489
|
return mixture
|
@@ -478,13 +513,13 @@ def main() -> None:
|
|
478
513
|
from sonusai.mixture import load_config
|
479
514
|
from sonusai.utils import seconds_to_hms
|
480
515
|
|
481
|
-
verbose = args[
|
482
|
-
save_mix = args[
|
483
|
-
save_ft = args[
|
484
|
-
save_segsnr = args[
|
485
|
-
dryrun = args[
|
486
|
-
save_json = args[
|
487
|
-
location = args[
|
516
|
+
verbose = args["--verbose"]
|
517
|
+
save_mix = args["--mix"]
|
518
|
+
save_ft = args["--ft"]
|
519
|
+
save_segsnr = args["--segsnr"]
|
520
|
+
dryrun = args["--dryrun"]
|
521
|
+
save_json = args["--json"]
|
522
|
+
location = args["LOC"]
|
488
523
|
|
489
524
|
start_time = time.monotonic()
|
490
525
|
|
@@ -493,30 +528,36 @@ def main() -> None:
|
|
493
528
|
|
494
529
|
makedirs(location, exist_ok=True)
|
495
530
|
|
496
|
-
create_file_handler(join(location,
|
531
|
+
create_file_handler(join(location, "genmixdb.log"))
|
497
532
|
update_console_handler(verbose)
|
498
|
-
initial_log_messages(
|
533
|
+
initial_log_messages("genmixdb")
|
499
534
|
|
500
535
|
if dryrun:
|
501
536
|
config = load_config(location)
|
502
|
-
logger.info(
|
537
|
+
logger.info("Dryrun configuration:")
|
503
538
|
logger.info(yaml.dump(config))
|
504
539
|
return
|
505
540
|
|
506
|
-
logger.info(f
|
507
|
-
logger.info(
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
541
|
+
logger.info(f"Creating mixture database for {location}")
|
542
|
+
logger.info("")
|
543
|
+
|
544
|
+
try:
|
545
|
+
genmixdb(
|
546
|
+
location=location,
|
547
|
+
save_mix=save_mix,
|
548
|
+
save_ft=save_ft,
|
549
|
+
save_segsnr=save_segsnr,
|
550
|
+
show_progress=True,
|
551
|
+
save_json=save_json,
|
552
|
+
)
|
553
|
+
except Exception as e:
|
554
|
+
logger.debug(e)
|
555
|
+
raise
|
515
556
|
|
516
557
|
end_time = time.monotonic()
|
517
|
-
logger.info(f
|
518
|
-
logger.info(
|
558
|
+
logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
|
559
|
+
logger.info("")
|
519
560
|
|
520
561
|
|
521
|
-
if __name__ ==
|
562
|
+
if __name__ == "__main__":
|
522
563
|
main()
|